├── mask.png
├── inputs.png
├── future-bach.png
├── masked-inputs.png
├── util.py
├── cog.yaml
├── README.md
├── data.py
├── output.py
├── predict.py
└── train.py
/mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andreasjansson/cantable-diffuguesion/HEAD/mask.png
--------------------------------------------------------------------------------
/inputs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andreasjansson/cantable-diffuguesion/HEAD/inputs.png
--------------------------------------------------------------------------------
/future-bach.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andreasjansson/cantable-diffuguesion/HEAD/future-bach.png
--------------------------------------------------------------------------------
/masked-inputs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andreasjansson/cantable-diffuguesion/HEAD/masked-inputs.png
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from contextlib import contextmanager
3 |
4 |
5 | args = {}
6 |
7 |
8 | @contextmanager
9 | def define_args():
10 |     """Collect the module-level names the caller defines inside the `with` block.
11 |
12 |     The caller's globals are found via frame introspection (two frames up,
13 |     past the contextmanager machinery), so the names land in `args` even
14 |     when define_args() is used from another module (see train.py).
15 |     """
16 |     global args
17 |
18 |     caller_globals = sys._getframe(2).f_globals  # CPython-specific
19 |     before = set(caller_globals)
20 |     yield
21 |     args = {k: v for k, v in caller_globals.items() if k not in before}
22 |     print(args)
23 |
24 |
25 | def get_args():
26 |     return args
27 |
--------------------------------------------------------------------------------
/cog.yaml:
--------------------------------------------------------------------------------
1 | build:
2 | cuda: "11.7"
3 | python_version: "3.7"
4 | gpu: true
5 | python_packages:
6 | - torch==1.13.1
7 | - diffusers==0.11.1
8 | - decorator==4.3.0
9 | - flask==1.0.2
10 | - flask-cors==3.0.7
11 | - itsdangerous==1.1.0
12 | - jinja2==2.10
13 | - markupsafe==1.1.0
14 | - music21==6.7.1
15 | - tqdm==4.29.0
16 | - werkzeug==0.14.1
17 | - numpy==1.19.5
18 | - midi2audio==0.1.1
19 | - jupyterlab==3.5.2
20 | - matplotlib==3.5.3
21 | - pretty_midi==0.2.9
22 | - midiSynth==0.3
23 | - pyfluidsynth==1.3.1
24 | - wandb==0.13.7
25 | - mypy==0.991
26 | - soundfile==0.11.0
27 | system_packages:
28 | - musescore
29 | - fluidsynth --fix-missing
30 | - ffmpeg
31 | - lilypond
32 |
33 | predict: predict.py:Predictor
34 | image: r8.im/andreasjansson/cantable-diffuguesion
35 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Cantable Diffuguesion
2 |
3 | [](https://replicate.com/andreasjansson/cantable-diffuguesion)
4 |
5 | _Bach chorale generation and harmonization_
6 |
7 | 
8 |
9 | ## Usage
10 |
11 | You can use Cantable Diffuguesion to generate Bach chorales unconditionally, or harmonize melodies or parts of melodies.
12 |
13 | For harmonization we use [tinyNotation](https://web.mit.edu/music21/doc/moduleReference/moduleTinyNotation.html), with a few modifications:
14 | * The `?` symbol followed by a duration denotes a section that the model should in-paint, e.g. `?2` will in-paint a half note duration.
15 | * The `?*` symbol in-paints everything between a given beginning and end, e.g. `c2 ?* B4 c2` starts the piece with `c2`, lets the model generate the middle section, and ends the melody with `B4 c2`; the overall length is set by the requested duration.
16 | * Bar lines `|` are optional and are ignored by the parser; they can be used to make the melody notation easier to read.
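
For example, the following input fixes the first bar and the closing `e4 d c2`, asks the model to in-paint a half note after the `g2`, and lets it generate everything in between:

```
c4 d e f | g2 ?2 | ?* | e4 d c2
```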
17 |
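With [Cog](https://github.com/replicate/cog) installed, and assuming the trained weights are present under `checkpoints/unet` (as expected by `predict.py`) on a GPU machine, a prediction can be run locally along these lines:

```
cog predict -i melody="c2 ?* B4 c2" -i duration=32 -i tempo=90
```
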
18 | ## Training
19 |
20 | Cantable Diffuguesion is a diffusion model trained to generate Bach chorales. Four-part chorales are presented to the network as 4-channel piano-roll arrays: each part's active pitches are marked in the corresponding channel, with time steps along one axis and pitch along the other. Here is a plot of a single input example, with the four channels shown as separate images:
21 |
22 | 
23 |
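Concretely, the encoding built by `piece_to_array` in `data.py` gives each voice its own channel, with time steps (at 16th-note resolution) along one axis and MIDI pitch bins along the other. A minimal illustration, mirroring the constants in `data.py`:

```python
import torch

MIN_PITCH, MAX_PITCH, NUM_PARTS, RESOLUTION = 30, 94, 4, 4  # as in data.py
steps = 8 * RESOLUTION                                       # two bars of 4/4
arr = torch.zeros(NUM_PARTS, steps, MAX_PITCH - MIN_PITCH)
# Soprano (channel 0) holds middle C (MIDI 60) for one quarter note from beat 0:
arr[0, 0:RESOLUTION, 60 - MIN_PITCH] = 1
```
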
24 | As in Stable Diffusion, a U-Net is trained to predict the noise residual.
25 |
26 | After training the generative model we add 8 channels to the inputs: the middle four channels hold a mask and the last four hold the masked chorale. The four voice channels are masked independently at random, unlike [Stable Diffusion Inpainting](https://huggingface.co/stabilityai/stable-diffusion-2-inpainting), which uses a single-channel mask.
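
The sketch below (illustrative, not taken from the repo) shows how the 12-channel input is assembled, mirroring the concatenation in `train.py`:

```python
import torch

B, VOICES, STEPS, PITCHES = 1, 4, 64, 64

noisy = torch.randn(B, VOICES, STEPS, PITCHES)                   # noised 4-voice piano roll
mask = torch.zeros(B, VOICES, STEPS, PITCHES, dtype=torch.bool)  # True where notes are given
mask[:, 0, :16] = True                                           # e.g. condition on the soprano's first bar
conditioning = torch.zeros(B, VOICES, STEPS, PITCHES)            # piano roll containing only the given notes
model_input = torch.cat([noisy, mask.float(), conditioning], dim=1)
print(model_input.shape)  # torch.Size([1, 12, 64, 64])
```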
27 |
28 | The two plots below show a mask and a masked input array:
29 |
30 | 
31 | 
32 | ## Dataset
33 |
34 | We use all four-part pieces in the [Music21 Bach Chorales corpus](https://web.mit.edu/music21/doc/moduleReference/moduleCorpusChorales.html). 80% of the pieces are used for training; the remaining 20% are split evenly between validation and testing.
35 |
36 | ## Inspiration
37 |
38 | * [Riffusion](https://github.com/riffusion/riffusion)
39 | * [DeepBach](https://arxiv.org/abs/1612.01010)
40 | * [Dreambooth Inpainting](https://github.com/huggingface/diffusers/blob/50b6513531da7e258204871a9c675a56875d9e69/examples/research_projects/dreambooth_inpaint/README.md)
41 |
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import random
2 | from pathlib import Path
3 | import music21
4 | import torch
5 | from torch import nn
6 | from torch.utils.data import Dataset, DataLoader, Subset
7 |
8 |
9 | MIN_PITCH = 30
10 | MAX_PITCH = 94
11 | MIN_TRANSPOSE = -6
12 | MAX_TRANSPOSE = 6
13 | NUM_PARTS = 4
14 | RESOLUTION = 4 # steps per quarter-note, e.g. 4 == 16th note resolution
15 | WIDTH = 64
16 | HEIGHT = MAX_PITCH - MIN_PITCH
17 | BATCH_SIZE = 32
18 |
19 |
20 | def train_val_test_dataloaders():
21 | ds = BachDataset()
22 |
23 | rng = random.Random(0)
24 | idx = list(range(len(ds)))
25 | rng.shuffle(idx)
26 |
27 | train_idx = idx[: int(len(ds) * 0.8)]
28 | val_idx = idx[len(train_idx) : int(len(ds) * 0.9)]
29 | test_idx = idx[len(train_idx) + len(val_idx) :]
30 |
31 | train_ds = Subset(ds, indices=train_idx)
32 |     val_ds = Subset(ds, indices=val_idx)
33 |     test_ds = Subset(ds, indices=test_idx)
34 |
35 | train_ds = TransformDataset(train_ds, transform=RandomCropAndTranspose())
36 | train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
37 | val_dl = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=0)
38 | test_dl = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=0)
39 |
40 | return train_dl, val_dl, test_dl
41 |
42 |
43 | class BachDataset(Dataset):
44 | def __init__(self):
45 | self.data = load_data()
46 |
47 | def __len__(self):
48 | return len(self.data)
49 |
50 | def __getitem__(self, idx):
51 | return self.data[idx]
52 |
53 |
54 | def load_data():
55 | cache_path = Path("dataset.cache.pt")
56 | if cache_path.exists():
57 | return torch.load(cache_path)
58 |
59 | data = []
60 | corpuses = music21.corpus.search("bach")
61 | for i, piece in enumerate(corpuses):
62 | if i % 10 == 0:
63 | print(f"{i+1}/{len(corpuses)}")
64 | piece = piece.parse()
65 | if len(piece.parts) == NUM_PARTS:
66 | data.append(piece_to_array(piece))
67 | torch.save(data, cache_path)
68 | return data
69 |
70 |
71 | def piece_to_array(piece):
72 | duration = int(piece.expandRepeats().duration.quarterLength * RESOLUTION)
73 | arr = torch.zeros([NUM_PARTS, duration, MAX_PITCH - MIN_PITCH])
74 |
75 | for part_i, part in enumerate(piece.parts):
76 | notes = part.expandRepeats().flat.notes
77 | for note in notes:
78 | next_note = note.next("Note")
79 | next_is_same = next_note and note.pitch.midi == next_note.pitch.midi
80 | subtract = 1 if next_is_same else 0
81 | start_column = int(note.offset * RESOLUTION)
82 | end_column = int(
83 | (note.offset + note.duration.quarterLength) * RESOLUTION - subtract
84 | )
85 | pitch_row = note.pitch.midi - MIN_PITCH
86 | arr[part_i, start_column:end_column, pitch_row] = 1
87 |
88 | return arr
89 |
90 |
91 | class RandomCropAndTranspose(nn.Module):
92 | def forward(self, arr):
93 | duration = arr.shape[1]
94 | t = torch.randint(0, duration - WIDTH, size=(1,)).item()
95 | t = 4 * (t // 16)
96 | cropped = arr[:, t : t + WIDTH]
97 | transpose = torch.randint(MIN_TRANSPOSE, MAX_TRANSPOSE, size=(1,)).item()
98 |         transposed = torch.roll(cropped, transpose, dims=2)  # rolling along the pitch axis transposes by `transpose` semitones
99 | return transposed
100 |
101 |
102 | class TransformDataset(Dataset):
103 | def __init__(self, subset, transform=None):
104 | self.subset = subset
105 | self.transform = transform
106 |
107 | def __getitem__(self, index):
108 | x = self.subset[index]
109 | if self.transform:
110 | x = self.transform(x)
111 | return x
112 |
113 | def __len__(self):
114 | return len(self.subset)
115 |
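
A minimal sketch of how the module above is meant to be used (illustrative, not part of the file):

    from data import train_val_test_dataloaders

    train_dl, val_dl, test_dl = train_val_test_dataloaders()
    batch = next(iter(train_dl))
    print(batch.shape)  # typically torch.Size([32, 4, 64, 64]) == [BATCH_SIZE, NUM_PARTS, WIDTH, HEIGHT]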
--------------------------------------------------------------------------------
/output.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | import subprocess
4 | import tempfile
5 | from PIL import Image
6 | import torch
7 | import pretty_midi as pm
8 | import matplotlib.pyplot as plt
9 | from midiSynth.synth import MidiSynth
10 |
11 | from data import RESOLUTION, MIN_PITCH
12 |
13 |
14 | PALETTE = torch.tensor(
15 | [
16 | [0, 0, 0], # black
17 | [255, 0, 0], # red
18 | [0, 255, 0], # green
19 | [0, 0, 255], # blue
20 | [255, 0, 255], # magenta
21 | [255, 255, 0], # yellow
22 | [255, 255, 255], # white
23 | ]
24 | )
25 |
26 |
27 | def array_to_plot(arr):
28 | color_matrix = PALETTE[
29 | ((arr.cpu() > 0.75).permute([1, 2, 0]) * torch.arange(1, 5))
30 | .max(axis=2)[0]
31 | .T.to(int)
32 | ]
33 | resize = 0.1
34 | height, width, _ = color_matrix.shape
35 | figsize = width * resize, height * resize
36 |
37 | fig = plt.figure(figsize=figsize)
38 | ax = fig.add_axes([0, 0, 1, 1])
39 | ax.axis("off")
40 |
41 | ax.imshow(color_matrix, origin="lower")
42 | fig.canvas.draw()
43 | print(fig.canvas.get_width_height())
44 |
45 | img = Image.frombytes(
46 | "RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb()
47 | )
48 | plt.close(fig)
49 | return img
50 |
51 |
52 | def array_to_plots(arr):
53 | fig, axs = plt.subplots(2, 2, figsize=[8, 8])
54 | for i in range(2):
55 | for j in range(2):
56 | ax = axs[i][j]
57 | ax.imshow(arr[i * 2 + j].T, origin="lower")
58 | fig.tight_layout()
59 | fig.canvas.draw()
60 |
61 | img = Image.frombytes(
62 | "RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb()
63 | )
64 | plt.close(fig)
65 | return img
66 |
67 |
68 | def midi_to_wav(midi_path, wav_path):
69 | MidiSynth().midi2audio(str(midi_path), str(wav_path))
70 | return str(wav_path)
71 |
72 |
73 | def midi_to_mp3(midi_path, mp3_path):
74 | with tempfile.TemporaryDirectory() as temp_dir:
75 | wav_path = temp_dir + "/audio.wav"
76 | MidiSynth().midi2audio(str(midi_path), str(wav_path))
77 | subprocess.check_output(
78 | [
79 | "ffmpeg",
80 | "-y",
81 | "-i",
82 | str(wav_path),
83 | "-af",
84 | "silenceremove=1:0:-50dB,aformat=dblp,areverse,silenceremove=1:0:-50dB,aformat=dblp,areverse", # strip silence
85 | str(mp3_path),
86 | ],
87 | )
88 | return mp3_path
89 |
90 |
91 | def midi_to_score(midi_path, score_path):
92 | with tempfile.TemporaryDirectory() as temp_dir:
93 | lilypond_path = temp_dir + "/score.ly"
94 | subprocess.check_output(
95 | ["midi2ly", str(midi_path), "--output", str(lilypond_path)],
96 | )
97 | subprocess.check_output(
98 | [
99 | "lilypond",
100 | "-fpng",
101 | "-dresolution=300",
102 | '-dpaper-size="a5landscape"',
103 | "-dcrop",
104 | "-o",
105 | str(Path(score_path).with_suffix("")),
106 | str(lilypond_path),
107 | ]
108 | )
109 |         cropped_path = str(Path(score_path).with_suffix("")) + ".cropped" + Path(score_path).suffix
110 | if Path(cropped_path).exists():
111 | os.rename(cropped_path, str(score_path))
112 | return score_path
113 |
114 |
115 | def array_to_midi(
116 | arr, midi_path, instrument_name="Lead 6 (voice)", tempo=90, time_sig=4
117 | ):
118 | sec_per_beat = 60 / tempo
119 | track = pm.PrettyMIDI(initial_tempo=tempo)
120 | track.time_signature_changes.append(pm.TimeSignature(time_sig, 4, 0))
121 |
122 | for mat in arr:
123 | instrument = pm.Instrument(pm.instrument_name_to_program(instrument_name))
124 | write_notes(instrument, mat, sec_per_beat)
125 | track.instruments.append(instrument)
126 |
127 | track.write(str(midi_path))
128 | return midi_path
129 |
130 |
131 | def write_notes(instrument, mat, sec_per_beat):
132 | def append_note(pitch, start_beat, end_beat):
133 | note = pm.Note(
134 | pitch=pitch,
135 | velocity=120,
136 | start=start_beat * sec_per_beat,
137 | end=end_beat * sec_per_beat,
138 | )
139 | instrument.notes.append(note)
140 |
141 | cur_pitch = None
142 | start_beat = None
143 | for t, vec in enumerate(mat):
144 | beat = t / RESOLUTION
145 |
146 | max_i = int(torch.argmax(vec).item())
147 | if vec[max_i] > 0.75:
148 | pitch = max_i + MIN_PITCH
149 | if pitch != cur_pitch:
150 | if cur_pitch:
151 | append_note(cur_pitch, start_beat, beat)
152 | cur_pitch = pitch
153 | start_beat = beat
154 | if start_beat is None:
155 | start_beat = beat
156 | cur_pitch = pitch
157 | else:
158 | if cur_pitch is not None:
159 | append_note(cur_pitch, start_beat, beat)
160 | start_beat = None
161 | cur_pitch = None
162 |
163 | if cur_pitch is not None:
164 |         append_note(cur_pitch, start_beat, mat.shape[0] / RESOLUTION)
165 |
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | import random
2 | from typing import Optional
3 | import sys
4 | import torch
5 | import torch.nn.functional as F
6 | import music21
7 | from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
8 | from cog import BaseModel, BasePredictor, Input, Path
9 |
10 | from data import RESOLUTION, MIN_PITCH, MAX_PITCH
11 | from output import array_to_midi, midi_to_score, midi_to_mp3
12 |
13 |
14 | class Output(BaseModel):
15 | mp3: Optional[Path]
16 | #score: Optional[Path]
17 | midi: Optional[Path]
18 |
19 |
20 | class Predictor(BasePredictor):
21 | def setup(self):
22 |         self.model = UNet2DModel.from_pretrained("checkpoints/unet").to("cuda")
23 |
24 | def predict(
25 | self,
26 | duration: int = Input(
27 | description="Duration in quarter notes",
28 | choices=(
29 | 64 // RESOLUTION,
30 | 128 // RESOLUTION,
31 | 256 // RESOLUTION,
32 | 512 // RESOLUTION,
33 | 1024 // RESOLUTION,
34 | ),
35 | default=128 // RESOLUTION,
36 | ),
37 | tempo: float = Input(
38 | description="Tempo in quarter notes per minute", default=90, ge=40, le=200
39 | ),
40 | melody: str = Input(
41 |             description="Melody in tinyNotation format. Use ? followed by a duration to in-paint a single note, and ?* to in-paint everything between two written phrases",
42 | default="",
43 | ),
44 | # not working :(
45 | # return_score: bool = Input(
46 | # description="Return sheet music score", default=True
47 | # ),
48 | return_mp3: bool = Input(description="Return mp3 audio", default=True),
49 | return_midi: bool = Input(description="Return midi", default=True),
50 | seed: int = Input(description="Random seed. Random if seed == -1", default=-1),
51 | ) -> Output:
52 | num_outputs = 1
53 |
54 | #if not return_score and not return_mp3 and not return_midi:
55 | if not return_mp3 and not return_midi:
56 | raise Exception(
57 | "At least one of return_score, return_mp3, return_midi must be true"
58 | )
59 |
60 | if seed == -1:
61 | seed = random.randint(0, 100000)
62 |
63 | length = duration * RESOLUTION
64 |
65 | if melody:
66 | mel_inputs, mel_mask, length = parse_melody(melody, length, num_outputs)
67 | else:
68 | mel_inputs = torch.zeros([num_outputs, 4, length, 64]).to("cuda")
69 | mel_mask = torch.zeros_like(mel_inputs, dtype=torch.bool).to("cuda")
70 |
71 | generator = torch.Generator(device="cuda")
72 | generator.manual_seed(seed)
73 | array = sample(self.model, generator, mel_inputs, mel_mask, 1000, length)[0]
74 | midi = array_to_midi(array, "/tmp/midi.mid", tempo=tempo)
75 |
76 | output = Output(
77 | midi=Path(midi) if return_midi else None,
78 | #score=Path(midi_to_score(midi, "/tmp/score.png")) if return_score else None,
79 | mp3=Path(midi_to_mp3(midi, "/tmp/audio.mp3")) if return_mp3 else None,
80 | )
81 |
82 | return output
83 |
84 |
85 | def parse_melody(text, length, num_outputs):
86 | text = text.replace("|", "")
87 | if "?*" in text:
88 | if text.count("?*") > 1:
89 |             raise Exception("Can only have one '?*' in the input")
90 | text1, text2 = text.split("?*")
91 | mel_inputs1, mel_mask1 = parse_notes(text1, num_outputs)
92 | mel_inputs2, mel_mask2 = parse_notes(text2, num_outputs)
93 |
94 | notes_length = mel_inputs1.shape[2] + mel_inputs2.shape[2]
95 | if notes_length > length:
96 | length = notes_length
97 | mel_inputs = torch.zeros([num_outputs, 4, length, 64]).to("cuda")
98 | mel_mask = torch.zeros_like(mel_inputs, dtype=torch.bool).to("cuda")
99 |
100 | mel_inputs[:, :, : mel_inputs1.shape[2]] = mel_inputs1
101 | mel_mask[:, :, : mel_mask1.shape[2]] = mel_mask1
102 | mel_inputs[:, :, -mel_inputs2.shape[2] :] = mel_inputs2
103 | mel_mask[:, :, -mel_mask2.shape[2] :] = mel_mask2
104 | else:
105 | mel_inputs1, mel_mask1 = parse_notes(text, num_outputs)
106 |
107 | notes_length = mel_inputs1.shape[2]
108 | if notes_length > length:
109 | length = notes_length
110 | mel_inputs = torch.zeros([num_outputs, 4, length, 64]).to("cuda")
111 | mel_mask = torch.zeros_like(mel_inputs, dtype=torch.bool).to("cuda")
112 |
113 | mel_inputs[:, :, : mel_inputs1.shape[2]] = mel_inputs1
114 | mel_mask[:, :, : mel_mask1.shape[2]] = mel_mask1
115 |
116 | if length % 64 != 0:
117 | new_length = length - (length % 64) + 64
118 | pad = new_length - length
119 | mel_inputs = F.pad(mel_inputs, (0, 0, 0, pad), "constant", 0)
120 | mel_mask = F.pad(mel_mask, (0, 0, 0, pad), "constant", True)
121 |
122 | return mel_inputs, mel_mask, length
123 |
124 |
125 | def parse_notes(text, num_outputs):
126 | text = text.replace("?", "CC") # hack for inpainting masks
127 |
128 | notes = music21.converter.parse("tinyNotation: 4/4 " + text).flat.notes
129 | if len(notes) > 0:
130 | mel_length = int((notes[-1].offset + notes[-1].duration.quarterLength) * RESOLUTION)
131 | else:
132 | mel_length = 0
133 |
134 | mel_inputs = torch.zeros([num_outputs, 4, mel_length, 64]).to("cuda")
135 | mel_mask = torch.zeros_like(mel_inputs, dtype=torch.bool).to("cuda")
136 |
137 | for note in notes:
138 | if note.pitch.midi != 36: # == CC == "?"
139 | pitch = note.pitch.midi + 12
140 | if pitch < MIN_PITCH:
141 | raise Exception(f"Pitch is too low: {note}")
142 | if pitch > MAX_PITCH:
143 | raise Exception(f"Pitch is too high: {note}")
144 | start_index = int(note.offset * RESOLUTION)
145 | end_index = int(
146 | note.offset * RESOLUTION + note.duration.quarterLength * RESOLUTION
147 | )
148 | mel_inputs[0, 0, start_index:end_index, pitch - MIN_PITCH] = 1
149 | mel_mask[:, 0, start_index:end_index] = True
150 |
151 | # staccato
152 | mel_inputs[0, 0, start_index - 1, pitch - MIN_PITCH] = 0
153 |
154 | return mel_inputs, mel_mask
155 |
156 |
157 | @torch.no_grad()
158 | def sample(model, generator, inputs, mask, num_inference_steps, length):
159 | num_outputs = inputs.shape[0]
160 |     padded_length = inputs.shape[2]  # time length after padding to a multiple of 64
161 |     noise_scheduler = DDPMScheduler(num_train_timesteps=num_inference_steps)
162 |     image = torch.randn(
163 |         (num_outputs, 4, padded_length, model.sample_size),
164 | generator=generator,
165 | device="cuda",
166 | )
167 | noise_scheduler.set_timesteps(num_inference_steps)
168 |
169 | for t in noise_scheduler.timesteps:
170 | model_input = torch.cat([image, mask, inputs], dim=1)
171 | model_output = model(model_input, t).sample
172 |
173 | image = noise_scheduler.step(
174 | model_output, t, image, generator=generator
175 | ).prev_sample
176 |
177 | image = (image / 2 + 0.5).clamp(0, 1)
178 | image[mask] = inputs[mask]
179 | return image[:, :, :length]
180 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 |
4 | import wandb
5 | import torch
6 | from torch import nn
7 | from tqdm.auto import tqdm
8 | from torch.nn.utils import clip_grad_norm_
9 | import torch.nn.functional as F
10 | from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
11 | from diffusers.optimization import get_scheduler
12 | from diffusers.training_utils import EMAModel
13 |
14 | from data import train_val_test_dataloaders
15 | from output import array_to_plot, array_to_midi, midi_to_mp3
16 | from util import define_args, get_args
17 |
18 | with define_args():
19 | from data import BATCH_SIZE, WIDTH, HEIGHT
20 |
21 | EVAL_BATCH_SIZE = 4
22 | MIXED_PRECISION = "no"
23 | LEARNING_RATE = 1e-4
24 | ADAM_BETA1 = 0.95
25 | ADAM_BETA2 = 0.999
26 | ADAM_WEIGHT_DECAY = 1e-6
27 | ADAM_EPSILON = 1e-08
28 | LR_SCHEDULER = "cosine"
29 | LR_WARMUP_STEPS = 500
30 | NUM_EPOCHS = 100000
31 | GRADIENT_ACCUMULATION_STEPS = 10
32 | USE_EMA = True
33 | EMA_INV_GAMMA = 1.0
34 | EMA_POWER = 3 / 4
35 | EMA_MAX_DECAY = 0.9999
36 | SAVE_MEDIA_EPOCHS = 100
37 | SAVE_MODEL_EPOCHS = 200
38 | OUTPUT_DIR = "checkpoints"
39 | RESUME_FROM = "checkpoints-570000"
40 | TRAIN_INPAINTER = True
41 | SAMPLE_COUNT = 4
42 |
43 |
44 | def main():
45 | os.makedirs(OUTPUT_DIR, exist_ok=True)
46 |
47 | if RESUME_FROM:
48 | model = UNet2DModel.from_pretrained(RESUME_FROM + "/unet").to("cuda")
49 |
50 | if model.conv_in.weight.shape[1] == 4 and TRAIN_INPAINTER:
51 | new_conv = nn.Conv2d(12, 128, kernel_size=3, padding=(1, 1)).to("cuda")
52 | new_conv.weight.data[:, :4, :, :] = model.conv_in.weight
53 | new_conv.bias.data = model.conv_in.bias
54 | model.conv_in = new_conv
55 | else:
56 | model = UNet2DModel(
57 | sample_size=64,
58 | in_channels=12 if TRAIN_INPAINTER else 4,
59 | out_channels=4,
60 | layers_per_block=2,
61 | block_out_channels=(128, 128, 256, 256, 512, 512),
62 | down_block_types=(
63 | "DownBlock2D",
64 | "DownBlock2D",
65 | "DownBlock2D",
66 | "DownBlock2D",
67 | "AttnDownBlock2D",
68 | "DownBlock2D",
69 | ),
70 | up_block_types=(
71 | "UpBlock2D",
72 | "AttnUpBlock2D",
73 | "UpBlock2D",
74 | "UpBlock2D",
75 | "UpBlock2D",
76 | "UpBlock2D",
77 | ),
78 | ).to("cuda")
79 |
80 | noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
81 | optimizer = torch.optim.AdamW(
82 | model.parameters(),
83 | lr=LEARNING_RATE,
84 | betas=(ADAM_BETA1, ADAM_BETA2),
85 | weight_decay=ADAM_WEIGHT_DECAY,
86 | eps=ADAM_EPSILON,
87 | )
88 |
89 |     train_dl, val_dl, test_dl = train_val_test_dataloaders()
90 |
91 | lr_scheduler = get_scheduler(
92 | LR_SCHEDULER,
93 | optimizer=optimizer,
94 | num_warmup_steps=LR_WARMUP_STEPS,
95 | num_training_steps=(len(train_dl) * NUM_EPOCHS) // GRADIENT_ACCUMULATION_STEPS,
96 | )
97 |
98 | num_update_steps_per_epoch = math.ceil(len(train_dl) / GRADIENT_ACCUMULATION_STEPS)
99 |
100 | ema_model = EMAModel(
101 | model, inv_gamma=EMA_INV_GAMMA, power=EMA_POWER, max_value=EMA_MAX_DECAY
102 | )
103 |
104 | print("args", get_args())
105 | wandb_run = wandb.init(project="cantable-diffuguesion", config=get_args())
106 |
107 | pipeline = DDPMPipeline(
108 | unet=ema_model.averaged_model if USE_EMA else model,
109 | scheduler=noise_scheduler,
110 | ).to("cuda")
111 |
112 | global_step = 0
113 | for epoch in range(NUM_EPOCHS):
114 | model.train()
115 | progress_bar = tqdm(
116 | total=num_update_steps_per_epoch,
117 | )
118 | progress_bar.set_description(f"Epoch {epoch}")
119 |
120 | # Generate sample images for visual inspection
121 | if epoch % SAVE_MEDIA_EPOCHS == 0:
122 | log_media(
123 | model,
124 | noise_scheduler,
125 | global_step,
126 | val_dl,
127 | )
128 |
129 | if epoch % SAVE_MODEL_EPOCHS == 0:
130 | # save the model
131 | pipeline.save_pretrained(OUTPUT_DIR)
132 | artifact = wandb.Artifact("checkpoints", type="checkpoints")
133 | artifact.add_dir("checkpoints") # Adds multiple files to artifact
134 | wandb_run.log_artifact(artifact)
135 |
136 | for step, batch in enumerate(train_dl):
137 | clean_images = batch.to("cuda")
138 |
139 | # Sample noise that we'll add to the images
140 | noise = torch.randn(clean_images.shape).to(clean_images.device)
141 | bsz = clean_images.shape[0]
142 | # Sample a random timestep for each image
143 | timesteps = torch.randint(
144 | 0,
145 | noise_scheduler.config.num_train_timesteps,
146 | (bsz,),
147 | device=clean_images.device,
148 | ).long()
149 |
150 | # Add noise to the clean images according to the noise magnitude at each timestep
151 | # (this is the forward diffusion process)
152 | noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
153 |
154 | if TRAIN_INPAINTER:
155 |                 mask = random_mask(clean_images.shape[0])  # an independent random time span per voice
156 |                 masked_images = clean_images.clone()
157 |                 masked_images[~mask] *= 0  # keep notes only where the mask is True
158 | model_input = torch.cat([noisy_images, mask, masked_images], dim=1)
159 | else:
160 | model_input = noisy_images
161 |
162 | # Predict the noise residual
163 | noise_pred = model(model_input, timesteps).sample
164 | loss = F.mse_loss(noise_pred, noise)
165 | loss.backward()
166 | clip_grad_norm_(model.parameters(), 1.0)
167 | optimizer.step()
168 | lr_scheduler.step()
169 | if USE_EMA:
170 | ema_model.step(model)
171 | optimizer.zero_grad()
172 |
173 | progress_bar.update(1)
174 | global_step += 1
175 |
176 | logs = {"loss": loss.item(), "lr": lr_scheduler.get_last_lr()[0]}
177 | if USE_EMA:
178 | logs["ema_decay"] = ema_model.decay
179 |
180 | if global_step % 10 == 0:
181 | wandb.log(logs, step=global_step)
182 |
183 | progress_bar.set_postfix(**logs)
184 |
185 | progress_bar.close()
186 | log_media(
187 | ema_model.averaged_model if USE_EMA else model,
188 | noise_scheduler,
189 | global_step,
190 | val_dl,
191 | )
192 | pipeline.save_pretrained(OUTPUT_DIR)
193 | wandb.save(OUTPUT_DIR)
194 |
195 |
196 | @torch.no_grad()
197 | def sample(model, noise_scheduler, mask=None, masked_images=None):
198 | generator = torch.Generator(device="cuda")
199 | generator.manual_seed(0)
200 | num_inference_steps = 1000
201 | noise_scheduler = DDPMScheduler(num_train_timesteps=num_inference_steps)
202 | image = torch.randn(
203 |         (SAMPLE_COUNT, 4, WIDTH, model.sample_size),  # the chorale array itself always has 4 channels
204 | generator=generator,
205 | device="cuda",
206 | )
207 | noise_scheduler.set_timesteps(num_inference_steps)
208 |
209 | for t in noise_scheduler.timesteps:
210 | # 1. predict noise model_output
211 | if TRAIN_INPAINTER:
212 |             model_input = torch.cat([image, mask, masked_images], dim=1)
213 | else:
214 | model_input = image
215 | model_output = model(model_input, t).sample
216 |
217 | # 2. compute previous image: x_t -> x_t-1
218 | image = noise_scheduler.step(model_output, t, image).prev_sample
219 |
220 | return (image / 2 + 0.5).clamp(0, 1)
221 |
222 |
223 | def random_mask(count):
224 | mask = torch.zeros([count, 4, WIDTH, HEIGHT], dtype=bool).to("cuda")
225 | for i in range(count):
226 | for c in range(4):
227 | start_index, end_index = torch.randint(0, WIDTH, (2,))
228 | if start_index > end_index:
229 | start_index, end_index = end_index, start_index
230 | mask[i, c, start_index:end_index] = True
231 | return mask
232 |
233 |
234 | def log_media(model, noise_scheduler, global_step, val_dl=None):
235 | logs = {}
236 | # run pipeline in inference (sample random noise and denoise)
237 | if TRAIN_INPAINTER:
238 | mask = random_mask(SAMPLE_COUNT)
239 | masked_images = torch.zeros_like(mask, dtype=torch.float32).to("cuda")
240 | for i, batch in enumerate(val_dl):
241 | if i == SAMPLE_COUNT:
242 | break
243 | masked_images[i] = batch[0, :, :WIDTH]
244 | masked_images[~mask] *= 0
245 | logs["masked_inputs"] = [wandb.Image(array_to_plot(i)) for i in masked_images]
246 |
247 | arrays = sample(model, noise_scheduler, mask.to("cuda"), masked_images.to("cuda"))
248 | else:
249 | arrays = sample(model, noise_scheduler)
250 | #arrays = torch.from_numpy(arrays).permute([0, 3, 1, 2])
251 |
252 | plots = [array_to_plot(a) for a in arrays]
253 | midis = [array_to_midi(a, f"/tmp/midi-output-{i}.mid") for i, a in enumerate(arrays)]
254 | audios = [
255 | midi_to_mp3(midi, f"/tmp/audio-output-{i}.mp3") for i, midi in enumerate(midis)
256 | ]
257 |
258 | logs["plots"] = [wandb.Image(plot) for plot in plots]
259 | logs["audios"] = [wandb.Audio(audio) for audio in audios]
260 |
261 | wandb.log(logs, step=global_step)
262 |
263 |
264 | if __name__ == "__main__":
265 | main()
266 |
--------------------------------------------------------------------------------