├── polyffusion ├── __init__.py ├── chord_extractor │ ├── mir │ │ ├── io │ │ │ ├── implement │ │ │ │ ├── __init__.py │ │ │ │ ├── unknown_io.py │ │ │ │ ├── midi_io.py │ │ │ │ ├── music_io.py │ │ │ │ ├── scalar_io.py │ │ │ │ ├── chroma_io.py │ │ │ │ ├── spectrogram_io.py │ │ │ │ └── regional_spectrogram_io.py │ │ │ ├── __init__.py │ │ │ └── feature_io_base.py │ │ ├── .gitignore │ │ ├── extractors │ │ │ ├── __init__.py │ │ │ ├── misc.py │ │ │ ├── librosa_extractor.py │ │ │ ├── extractor_base.py │ │ │ └── vamp_extractor.py │ │ ├── requirements.txt │ │ ├── settings.py │ │ ├── __init__.py │ │ ├── common.py │ │ ├── data │ │ │ ├── sparse_tag_template.svl │ │ │ ├── midi_template.svl │ │ │ ├── curve_template.svl │ │ │ ├── tuning.n3 │ │ │ ├── spectrogram_template.svl │ │ │ ├── pitch_template.svl │ │ │ ├── bothchroma.n3 │ │ │ ├── chroma.n3 │ │ │ ├── tunedlogfreqspec.n3 │ │ │ └── chordino.n3 │ │ ├── music_base.py │ │ └── cache.py │ ├── example.sh │ ├── .gitignore │ ├── example.mid │ ├── requirements.txt │ ├── setup.py │ ├── io_new │ │ ├── list_io.py │ │ ├── jams_io.py │ │ ├── air_io.py │ │ ├── madmom_io.py │ │ ├── beatlab_io.py │ │ ├── key_io.py │ │ ├── lyric_io.py │ │ ├── tag_io.py │ │ ├── chordlab_io.py │ │ ├── jointbeat_io.py │ │ ├── downbeat_io.py │ │ ├── salami_io.py │ │ ├── complex_chord_io.py │ │ ├── midilab_io.py │ │ ├── osu_io.py │ │ └── beat_align_io.py │ ├── README.TXT │ ├── extractors │ │ └── rule_based_channel_reweight.py │ ├── __init__.py │ ├── main.py │ └── example.out ├── pyrightconfig.json ├── stable_diffusion │ ├── losses │ │ ├── __init__.py │ │ ├── discriminator.py │ │ ├── lpips.py │ │ └── contperceptual.py │ ├── util.py │ └── sampler │ │ └── __init__.py ├── data │ ├── example.mid │ ├── pop909_extractor.py │ ├── dataloader_musicalion.py │ ├── polydis_format_to_mine.py │ └── dataloader.py ├── setup.py ├── dl_modules │ ├── __init__.py │ ├── naive_nn.py │ ├── chord_enc.py │ ├── txt_enc.py │ ├── chord_dec.py │ └── pianotree_enc.py ├── params │ ├── chd_8bar.yaml │ ├── autoencoder.yaml │ ├── ddpm.yaml │ ├── sdf_pnotree.yaml │ ├── sdf_txtvnl.yaml │ ├── sdf.yaml │ ├── sdf_concat.yaml │ ├── sdf_txt.yaml │ ├── sdf_chd8bar.yaml │ ├── sdf_chdvnl.yaml │ ├── sdf_chd8bar_txt.yaml │ └── sdf_chd8bar_txt_mix2.yaml ├── ddpm │ ├── utils.py │ └── __init__.py ├── remove_pickle.py ├── mir_eval │ ├── __init__.py │ ├── setup.py │ ├── onset.py │ └── tempo.py ├── cleanup_checkpoints.py ├── dirs.py ├── polydis_aftertouch.py ├── models │ ├── model_autoencoder.py │ ├── model_ddpm.py │ └── model_chd_8bar.py ├── train │ ├── train_ddpm.py │ ├── train_chd_8bar.py │ ├── train_autoencoder.py │ ├── scheduler.py │ ├── __init__.py │ └── train_ldm.py ├── main.py ├── lightning_learner.py └── prepare_data.py ├── data └── train_split_pnt │ ├── pop909.pickle │ ├── musicalion.pickle │ ├── split_dict.pickle │ └── musicalion.pickle_unclean ├── requirements.txt ├── LICENSE ├── .gitignore └── README.md /polyffusion/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/mir/io/implement/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/example.sh: -------------------------------------------------------------------------------- 1 | python main.py ./example.mid ./example.out 
-------------------------------------------------------------------------------- /polyffusion/pyrightconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "reportInvalidStringEscapeSequence": false 3 | } 4 | -------------------------------------------------------------------------------- /polyffusion/stable_diffusion/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .contperceptual import LPIPSWithDiscriminator 2 | -------------------------------------------------------------------------------- /polyffusion/data/example.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aik2mlj/polyffusion/HEAD/polyffusion/data/example.mid -------------------------------------------------------------------------------- /polyffusion/chord_extractor/.gitignore: -------------------------------------------------------------------------------- 1 | /cache_data 2 | /temp 3 | /.idea 4 | /__pycache__ 5 | *.pyc 6 | /output 7 | -------------------------------------------------------------------------------- /data/train_split_pnt/pop909.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aik2mlj/polyffusion/HEAD/data/train_split_pnt/pop909.pickle -------------------------------------------------------------------------------- /polyffusion/chord_extractor/mir/.gitignore: -------------------------------------------------------------------------------- 1 | /cache_data 2 | /temp 3 | /.idea 4 | /__pycache__ 5 | *.pyc 6 | /output 7 | /lib 8 | -------------------------------------------------------------------------------- /data/train_split_pnt/musicalion.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aik2mlj/polyffusion/HEAD/data/train_split_pnt/musicalion.pickle -------------------------------------------------------------------------------- /data/train_split_pnt/split_dict.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aik2mlj/polyffusion/HEAD/data/train_split_pnt/split_dict.pickle -------------------------------------------------------------------------------- /polyffusion/chord_extractor/example.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aik2mlj/polyffusion/HEAD/polyffusion/chord_extractor/example.mid -------------------------------------------------------------------------------- /data/train_split_pnt/musicalion.pickle_unclean: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aik2mlj/polyffusion/HEAD/data/train_split_pnt/musicalion.pickle_unclean -------------------------------------------------------------------------------- /polyffusion/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="polyffusion", 5 | packages=find_packages(), 6 | ) 7 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/mir/extractors/__init__.py: -------------------------------------------------------------------------------- 1 | from mir.extractors.extractor_base import ExtractorBase 2 | 3 | __all__ = ["ExtractorBase"] 4 | 
-------------------------------------------------------------------------------- /polyffusion/chord_extractor/requirements.txt: -------------------------------------------------------------------------------- 1 | pydub>=0.23.1 2 | pretty_midi>=0.2.9 3 | joblib>=0.13.2 4 | librosa>=0.7.2 5 | mir_eval>=0.5 6 | numpy>=1.16 -------------------------------------------------------------------------------- /polyffusion/chord_extractor/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="chord_extractor", 5 | packages=find_packages(), 6 | ) 7 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/mir/requirements.txt: -------------------------------------------------------------------------------- 1 | librosa>=0.6.1 2 | joblib>=0.12.2 3 | pydub>=0.22.1 4 | numpy>=1.15.4 5 | h5py>=2.8.0 6 | torch>=0.4.1 7 | pretty_midi>=0.2.8 8 | 9 | -------------------------------------------------------------------------------- /polyffusion/dl_modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .chord_dec import ChordDecoder 2 | from .chord_enc import RnnEncoder as ChordEncoder 3 | from .naive_nn import NaiveNN 4 | from .pianotree_dec import PianoTreeDecoder 5 | from .pianotree_enc import PianoTreeEncoder 6 | from .txt_enc import TextureEncoder 7 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/mir/settings.py: -------------------------------------------------------------------------------- 1 | SONIC_VISUALIZER_PATH = "C:/Program Files (x86)/Sonic Visualiser/Sonic Visualiser.exe" 2 | SONIC_ANNOTATOR_PATH = ( 3 | "C:/Program Files (x86)/Sonic Visualiser/annotator/sonic-annotator.exe" 4 | ) 5 | DEFAULT_DATA_STORAGE_PATH = "E:/dataset/" 6 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/mir/__init__.py: -------------------------------------------------------------------------------- 1 | from mir.common import PACKAGE_PATH, WORKING_PATH 2 | from mir.data_file import DataEntry, DataPool, TextureBuilder 3 | 4 | __all__ = [ 5 | "TextureBuilder", 6 | "DataEntry", 7 | "WORKING_PATH", 8 | "PACKAGE_PATH", 9 | "DataPool", 10 | "io", 11 | ] 12 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/mir/common.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from mir.settings import * 4 | 5 | WORKING_PATH = os.getcwd() 6 | PACKAGE_PATH = os.path.dirname(os.path.abspath(__file__)) 7 | 8 | DEFAULT_DATA_STORAGE_PATH = DEFAULT_DATA_STORAGE_PATH.replace( 9 | "$project_name$", os.path.basename(os.getcwd()) 10 | ) 11 | -------------------------------------------------------------------------------- /polyffusion/params/chd_8bar.yaml: -------------------------------------------------------------------------------- 1 | model_name: chd_8bar 2 | batch_size: 128 3 | max_epoch: 1000 4 | learning_rate: 0.001 5 | max_grad_norm: 10 6 | fp16: true 7 | tfr_chd: 8 | - 0.5 9 | - 0 10 | num_workers: 4 11 | pin_memory: true 12 | chd_n_step: 32 13 | chd_input_dim: 36 14 | chd_z_input_dim: 512 15 | chd_hidden_dim: 512 16 | chd_z_dim: 512 17 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/io_new/list_io.py: 
-------------------------------------------------------------------------------- 1 | from mir.io.feature_io_base import * 2 | 3 | 4 | class ListIO(FeatureIO): 5 | def read(self, filename, entry): 6 | return pickle_read(self, filename) 7 | 8 | def write(self, data, filename, entry): 9 | pickle_write(self, data, filename) 10 | 11 | def visualize(self, data, filename, entry, override_sr): 12 | raise NotImplementedError() 13 |
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | python==3.10.13 2 | torch==2.1.2 3 | imageio==2.33.1 4 | jams==0.3.4 5 | joblib==1.3.2 6 | labml==0.4.168 7 | labml_helpers==0.4.89 8 | librosa==0.10.1 9 | lightning==2.1.2 10 | matplotlib==3.8.0 11 | muspy==0.5.0 12 | numpy==1.26.3 13 | omegaconf==2.3.0 14 | Pillow==10.2.0 15 | pretty_midi==0.2.10 16 | pydub==0.25.1 17 | Requests==2.31.0 18 | setuptools==68.2.2 19 | torchvision==0.16.2 20 | tqdm==4.65.0 21 | wandb==0.16.1 22 |
-------------------------------------------------------------------------------- /polyffusion/params/autoencoder.yaml: -------------------------------------------------------------------------------- 1 | model_name: autoencoder 2 | batch_size: 16 3 | max_epoch: 100 4 | learning_rate: 5.0e-05 5 | max_grad_norm: 10 6 | fp16: false 7 | num_workers: 4 8 | pin_memory: true 9 | in_channels: 3 10 | out_channels: 3 11 | z_channels: 4 12 | channels: 64 13 | n_res_blocks: 2 14 | channel_multipliers: 15 | - 1 16 | - 2 17 | - 4 18 | - 4 19 | emb_channels: 4 20 | disc_start: 50001 21 | kl_weight: 1.0e-06 22 | disc_weight: 0.5 23 |
-------------------------------------------------------------------------------- /polyffusion/ddpm/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | --- 3 | title: Utility functions for DDPM experiment 4 | summary: > 5 | Utility functions for DDPM experiment 6 | --- 7 | # Utility functions for [DDPM](index.html) experiment 8 | """ 9 | import torch 10 | import torch.utils.data 11 | 12 | 13 | def gather(consts: torch.Tensor, t: torch.Tensor): 14 | """Gather consts for $t$ and reshape to feature map shape""" 15 | c = consts.gather(-1, t) 16 | return c.reshape(-1, 1, 1, 1) 17 |
-------------------------------------------------------------------------------- /polyffusion/params/ddpm.yaml: -------------------------------------------------------------------------------- 1 | model_name: ddpm 2 | batch_size: 16 3 | max_epoch: 100 4 | learning_rate: 2.0e-05 5 | max_grad_norm: 10 6 | fp16: false 7 | num_workers: 4 8 | pin_memory: true 9 | beta: 0.1 10 | weights: 11 | - 1 12 | - 0.5 13 | image_channels: 2 14 | image_size_h: 128 15 | image_size_w: 128 16 | n_channels: 64 17 | channel_multipliers: 18 | - 1 19 | - 2 20 | - 2 21 | - 4 22 | is_attention: 23 | - false 24 | - false 25 | - false 26 | - true 27 | n_steps: 1000 28 |
-------------------------------------------------------------------------------- /polyffusion/chord_extractor/mir/io/implement/unknown_io.py: -------------------------------------------------------------------------------- 1 | from mir.io.feature_io_base import * 2 | 3 | 4 | class UnknownIO(FeatureIO): 5 | def read(self, filename, entry): 6 | raise Exception("Unknown type cannot be read") 7 | 8 | def write(self, data, filename, entry): 9 | raise Exception("Unknown type cannot be written") 10 | 11 | def visualize(self, data, filename, entry, override_sr): 12 | raise Exception("Unknown type
cannot be visualized") 13 | -------------------------------------------------------------------------------- /polyffusion/remove_pickle.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import sys 3 | 4 | from dirs import * 5 | 6 | if __name__ == "__main__": 7 | remove = sys.argv[1] 8 | with open(f"{TRAIN_SPLIT_DIR}/pop909.pickle", "rb") as f: 9 | pic = pickle.load(f) 10 | assert remove in pic[0] or remove in pic[1] 11 | if remove in pic[0]: 12 | pic[0].remove(remove) 13 | else: 14 | pic[1].remove(remove) 15 | with open(f"{TRAIN_SPLIT_DIR}/pop909.pickle", "wb") as f: 16 | pickle.dump(pic, f) 17 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/mir/io/implement/midi_io.py: -------------------------------------------------------------------------------- 1 | import pretty_midi 2 | from mir.io.feature_io_base import * 3 | 4 | 5 | class MidiIO(FeatureIO): 6 | def read(self, filename, entry): 7 | midi_data = pretty_midi.PrettyMIDI(filename) 8 | return midi_data 9 | 10 | def write(self, data, filename, entry): 11 | data.write(filename) 12 | 13 | def visualize(self, data, filename, entry, override_sr): 14 | data.write(filename) 15 | 16 | def get_visualize_extention_name(self): 17 | return "mid" 18 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/mir/data/sparse_tag_template.svl: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | [__DATA__] 8 | 9 | 10 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /polyffusion/params/sdf_pnotree.yaml: -------------------------------------------------------------------------------- 1 | model_name: sdf_pnotree 2 | batch_size: 16 3 | max_epoch: 50 4 | learning_rate: 5.0e-05 5 | max_grad_norm: 10 6 | fp16: true 7 | num_workers: 4 8 | pin_memory: true 9 | in_channels: 2 10 | out_channels: 2 11 | channels: 64 12 | attention_levels: 13 | - 2 14 | - 3 15 | n_res_blocks: 2 16 | channel_multipliers: 17 | - 1 18 | - 2 19 | - 4 20 | - 4 21 | n_heads: 4 22 | tf_layers: 1 23 | d_cond: 2048 24 | linear_start: 0.00085 25 | linear_end: 0.012 26 | n_steps: 1000 27 | latent_scaling_factor: 0.18215 28 | img_h: 128 29 | img_w: 128 30 | cond_type: pnotree 31 | cond_mode: mix 32 | -------------------------------------------------------------------------------- /polyffusion/params/sdf_txtvnl.yaml: -------------------------------------------------------------------------------- 1 | model_name: sdf_txtvnl 2 | batch_size: 16 3 | max_epoch: 200 4 | learning_rate: 5.0e-05 5 | max_grad_norm: 10 6 | fp16: true 7 | num_workers: 4 8 | pin_memory: true 9 | in_channels: 2 10 | out_channels: 2 11 | channels: 64 12 | attention_levels: 13 | - 2 14 | - 3 15 | n_res_blocks: 2 16 | channel_multipliers: 17 | - 1 18 | - 2 19 | - 4 20 | - 4 21 | n_heads: 4 22 | tf_layers: 1 23 | d_cond: 128 24 | linear_start: 0.00085 25 | linear_end: 0.012 26 | n_steps: 1000 27 | latent_scaling_factor: 0.18215 28 | img_h: 128 29 | img_w: 128 30 | cond_type: txt 31 | cond_mode: mix 32 | use_enc: false 33 | -------------------------------------------------------------------------------- /polyffusion/mir_eval/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Top-level module for mir_eval""" 3 | 4 | # Import all submodules (for each task) 5 | from . 
import alignment 6 | from . import beat 7 | from . import chord 8 | from . import io 9 | from . import onset 10 | from . import segment 11 | from . import separation 12 | from . import util 13 | from . import sonify 14 | from . import melody 15 | from . import multipitch 16 | from . import pattern 17 | from . import tempo 18 | from . import hierarchy 19 | from . import transcription 20 | from . import transcription_velocity 21 | from . import key 22 | 23 | __version__ = '0.7' 24 | -------------------------------------------------------------------------------- /polyffusion/params/sdf.yaml: -------------------------------------------------------------------------------- 1 | model_name: sdf 2 | batch_size: 16 3 | max_epoch: 100 4 | learning_rate: 5.0e-05 5 | max_grad_norm: 10 6 | fp16: true 7 | num_workers: 4 8 | pin_memory: true 9 | in_channels: 2 10 | out_channels: 2 11 | channels: 64 12 | attention_levels: 13 | - 2 14 | - 3 15 | n_res_blocks: 2 16 | channel_multipliers: 17 | - 1 18 | - 2 19 | - 4 20 | - 4 21 | n_heads: 4 22 | tf_layers: 1 23 | d_cond: 1152 24 | linear_start: 0.00085 25 | linear_end: 0.012 26 | n_steps: 1000 27 | latent_scaling_factor: 0.18215 28 | img_h: 128 29 | img_w: 128 30 | cond_type: chord 31 | cond_mode: uncond 32 | use_enc: false 33 | chd_n_step: 32 34 | chd_input_dim: 36 35 | -------------------------------------------------------------------------------- /polyffusion/params/sdf_concat.yaml: -------------------------------------------------------------------------------- 1 | model_name: sdf_concat 2 | batch_size: 16 3 | max_epoch: 100 4 | learning_rate: 5.0e-05 5 | max_grad_norm: 10 6 | fp16: true 7 | num_workers: 4 8 | pin_memory: true 9 | in_channels: 3 10 | out_channels: 2 11 | channels: 64 12 | attention_levels: 13 | - 2 14 | - 3 15 | n_res_blocks: 2 16 | channel_multipliers: 17 | - 1 18 | - 2 19 | - 4 20 | - 4 21 | n_heads: 4 22 | tf_layers: 1 23 | d_cond: 1152 24 | linear_start: 0.00085 25 | linear_end: 0.012 26 | n_steps: 1000 27 | latent_scaling_factor: 0.18215 28 | img_h: 128 29 | img_w: 128 30 | cond_type: chord 31 | cond_mode: uncond 32 | use_enc: false 33 | concat_blurry: true 34 | concat_ratio: 0.25 35 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/mir/data/midi_template.svl: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | [__DATA__] 8 | 9 | 10 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/mir/data/curve_template.svl: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | [__DATA__] 8 | 9 | 10 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /polyffusion/dl_modules/naive_nn.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class NaiveNN(nn.Module): 5 | def __init__( 6 | self, 7 | input_dim=512, 8 | output_dim=512, 9 | ): 10 | """Only two linear layers""" 11 | super(NaiveNN, self).__init__() 12 | self.linear1 = nn.Linear(input_dim, output_dim) 13 | self.linear2 = nn.Linear(output_dim, output_dim) 14 | self.input_dim = input_dim 15 | self.output_dim = output_dim 16 | 17 | def forward(self, z_x): 18 | output = self.linear1(z_x) 19 | output = self.linear2(output) 20 | # print(output_mu.size(), output_var.size()) 21 | return output 22 | 
-------------------------------------------------------------------------------- /polyffusion/chord_extractor/io_new/jams_io.py: -------------------------------------------------------------------------------- 1 | import jams 2 | from mir.io import FeatureIO 3 | 4 | 5 | class JamsIO(FeatureIO): 6 | def read(self, filename, entry): 7 | return jams.load(filename) 8 | 9 | def write(self, data: jams.JAMS, filename, entry): 10 | data.save(filename) 11 | 12 | def visualize(self, data: jams.JAMS, filename, entry, override_sr): 13 | f = open(filename, "w") 14 | for annotation in data.annotations: 15 | for obs in annotation.data: 16 | f.write( 17 | "%f\t%f\t%s\n" % (obs.time, obs.time + obs.duration, str(obs.value)) 18 | ) 19 | f.close() 20 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/mir/data/tuning.n3: -------------------------------------------------------------------------------- 1 | @prefix xsd: . 2 | @prefix vamp: . 3 | @prefix : <#> . 4 | 5 | :transform a vamp:Transform ; 6 | vamp:plugin ; 7 | vamp:step_size "[__WIN_SHIFT__]"^^xsd:int ; 8 | vamp:block_size "[__WIN_SIZE__]"^^xsd:int ; 9 | vamp:plugin_version """5""" ; 10 | vamp:parameter_binding [ 11 | vamp:parameter [ vamp:identifier "rollon" ] ; 12 | vamp:value "0"^^xsd:float ; 13 | ] ; 14 | vamp:output . 15 | -------------------------------------------------------------------------------- /polyffusion/cleanup_checkpoints.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | if __name__ == "__main__": 4 | result = "./result" 5 | for dir in os.listdir(result): 6 | dpath = f"{result}/{dir}" 7 | if dir == "try": 8 | os.system(f"rm -rf {dpath}/*") 9 | continue 10 | for item in os.listdir(dpath): 11 | item_dpath = f"{dpath}/{item}" 12 | chkpt_dir = f"{item_dpath}/chkpts" 13 | if not os.path.exists(f"{chkpt_dir}/weights.pt"): 14 | os.system(f"ls -l {chkpt_dir}") 15 | y = input(f"Remove {item_dpath} (y/n)?") 16 | if y == "y": 17 | os.system(f"rm -rf {item_dpath}") 18 | -------------------------------------------------------------------------------- /polyffusion/params/sdf_txt.yaml: -------------------------------------------------------------------------------- 1 | model_name: sdf_txt 2 | batch_size: 16 3 | max_epoch: 200 4 | learning_rate: 5.0e-05 5 | max_grad_norm: 10 6 | fp16: true 7 | num_workers: 4 8 | pin_memory: true 9 | in_channels: 2 10 | out_channels: 2 11 | channels: 64 12 | attention_levels: 13 | - 2 14 | - 3 15 | n_res_blocks: 2 16 | channel_multipliers: 17 | - 1 18 | - 2 19 | - 4 20 | - 4 21 | n_heads: 4 22 | tf_layers: 1 23 | d_cond: 1024 24 | linear_start: 0.00085 25 | linear_end: 0.012 26 | n_steps: 1000 27 | latent_scaling_factor: 0.18215 28 | img_h: 128 29 | img_w: 128 30 | cond_type: txt 31 | cond_mode: mix 32 | use_enc: true 33 | txt_emb_size: 256 34 | txt_hidden_dim: 1024 35 | txt_z_dim: 256 36 | txt_num_channel: 10 37 | -------------------------------------------------------------------------------- /polyffusion/params/sdf_chd8bar.yaml: -------------------------------------------------------------------------------- 1 | model_name: sdf_chd8bar 2 | batch_size: 16 3 | max_epoch: 100 4 | learning_rate: 5.0e-05 5 | max_grad_norm: 10 6 | fp16: true 7 | num_workers: 4 8 | pin_memory: true 9 | in_channels: 2 10 | out_channels: 2 11 | channels: 64 12 | attention_levels: 13 | - 2 14 | - 3 15 | n_res_blocks: 2 16 | channel_multipliers: 17 | - 1 18 | - 2 19 | - 4 20 | - 4 21 | n_heads: 4 22 | tf_layers: 1 23 | d_cond: 
512 24 | linear_start: 0.00085 25 | linear_end: 0.012 26 | n_steps: 1000 27 | latent_scaling_factor: 0.18215 28 | img_h: 128 29 | img_w: 128 30 | cond_type: chord 31 | cond_mode: mix 32 | use_enc: true 33 | chd_n_step: 32 34 | chd_input_dim: 36 35 | chd_z_input_dim: 512 36 | chd_hidden_dim: 512 37 | chd_z_dim: 512 38 | -------------------------------------------------------------------------------- /polyffusion/params/sdf_chdvnl.yaml: -------------------------------------------------------------------------------- 1 | model_name: sdf_chdvnl 2 | batch_size: 16 3 | max_epoch: 100 4 | learning_rate: 5.0e-05 5 | max_grad_norm: 10 6 | fp16: true 7 | num_workers: 4 8 | pin_memory: true 9 | in_channels: 2 10 | out_channels: 2 11 | channels: 64 12 | attention_levels: 13 | - 2 14 | - 3 15 | n_res_blocks: 2 16 | channel_multipliers: 17 | - 1 18 | - 2 19 | - 4 20 | - 4 21 | n_heads: 4 22 | tf_layers: 1 23 | d_cond: 1152 24 | linear_start: 0.00085 25 | linear_end: 0.012 26 | n_steps: 1000 27 | latent_scaling_factor: 0.18215 28 | img_h: 128 29 | img_w: 128 30 | cond_type: chord 31 | cond_mode: mix 32 | use_enc: false 33 | chd_n_step: 32 34 | chd_input_dim: 36 35 | chd_z_input_dim: 512 36 | chd_hidden_dim: 512 37 | chd_z_dim: 512 38 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/io_new/air_io.py: -------------------------------------------------------------------------------- 1 | from mir.io.feature_io_base import * 2 | 3 | 4 | class AirIO(FeatureIO): 5 | def read(self, filename, entry): 6 | return pickle_read(self, filename) 7 | 8 | def write(self, data, filename, entry): 9 | return pickle_write(self, data, filename) 10 | 11 | def visualize(self, data, filename, entry, override_sr): 12 | arr = data.export_to_array() 13 | from mir.io.implement.regional_spectrogram_io import RegionalSpectrogramIO 14 | 15 | return RegionalSpectrogramIO().visualize( 16 | arr, filename, entry, override_sr=override_sr 17 | ) 18 | 19 | def get_visualize_extention_name(self): 20 | return "svl" 21 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/mir/io/__init__.py: -------------------------------------------------------------------------------- 1 | from .feature_io_base import FeatureIO, LoadingPlaceholder 2 | from .implement.chroma_io import ChromaIO 3 | from .implement.midi_io import MidiIO 4 | from .implement.music_io import MusicIO 5 | from .implement.regional_spectrogram_io import RegionalSpectrogramIO 6 | from .implement.scalar_io import FloatIO, IntegerIO 7 | from .implement.spectrogram_io import SpectrogramIO 8 | from .implement.unknown_io import UnknownIO 9 | 10 | __all__ = [ 11 | "FeatureIO", 12 | "LoadingPlaceholder", 13 | "ChromaIO", 14 | "MidiIO", 15 | "MusicIO", 16 | "SpectrogramIO", 17 | "IntegerIO", 18 | "FloatIO", 19 | "RegionalSpectrogramIO", 20 | "UnknownIO", 21 | ] 22 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/mir/data/spectrogram_template.svl: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | [__DATA__] 8 | 9 | 10 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/mir/io/implement/music_io.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | from mir.io.feature_io_base import * 3 | 4 | 5 | class MusicIO(FeatureIO): 6 | def 
read(self, filename, entry): 7 | y, sr = librosa.load(filename, sr=entry.prop.sr, mono=True) 8 | return y # (y-np.mean(y))/np.std(y) 9 | 10 | def write(self, data, filename, entry): 11 | sr = entry.prop.sr 12 | librosa.output.write_wav(filename, y=data, sr=sr, norm=False) 13 | 14 | def visualize(self, data, filename, entry, override_sr): 15 | sr = entry.prop.sr 16 | librosa.output.write_wav( 17 | filename, y=data, sr=sr, norm=True 18 | ) # otherwise I would be deaf 19 | 20 | def get_visualize_extention_name(self): 21 | return "wav" 22 |
-------------------------------------------------------------------------------- /polyffusion/chord_extractor/README.TXT: -------------------------------------------------------------------------------- 1 | In theory this code can be run directly by following the README.TXT inside, but it targets pop-music chords (more precisely, the multi-track pop MIDI files in the lmd dataset). It may not work well for classical music, musicalian, or single-track music (but you can try). 2 | 3 | The algorithm is template matching plus dynamic-programming smoothing. It first identifies which tracks are chord tracks (as soft labels, the weights at line 47 of main.py), then matches the chroma and bass chroma built from the chord tracks against the templates. Chord transitions take the metrical structure of pop music into account: they are relatively lenient on downbeats and stricter on upbeats and within beats. How to adjust it: 4 | 5 | 1. First make sure the beat information in the MIDI (i.e. the meter and tempo markings) is accurate, which means midi.get_beats() and midi.get_downbeats() obtained with pretty_midi need to be accurate. If they are not, these markings in the MIDI have to be corrected manually. 6 | 2. The chord categories need to be modified accordingly. The chord_class.py file currently contains pop and jazz chord labels: QUALITIES stores the possible chord qualities, and INVERSIONS stores the chord qualities whose inversions should be considered, together with the inversion types for each quality. Adjust them to match the musical style. 7 | 8 | Junyan 9 | 10 | 11 | Installation 12 | >> pip3 install -r requirements.txt 13 | 14 | Usage 15 | >> python3 main.py ./example.mid ./example.out 16 | See main.py and example.sh for reference. 17 | 18 | Miscellaneous 19 | The chord dictionary is in chord_class.py 20 | 21 | Junyan 22 |
-------------------------------------------------------------------------------- /polyffusion/params/sdf_chd8bar_txt.yaml: -------------------------------------------------------------------------------- 1 | model_name: sdf_chd8bar_txt 2 | batch_size: 16 3 | max_epoch: 100 4 | learning_rate: 5.0e-05 5 | max_grad_norm: 10 6 | fp16: true 7 | num_workers: 4 8 | pin_memory: true 9 | in_channels: 2 10 | out_channels: 2 11 | channels: 64 12 | attention_levels: 13 | - 2 14 | - 3 15 | n_res_blocks: 2 16 | channel_multipliers: 17 | - 1 18 | - 2 19 | - 4 20 | - 4 21 | n_heads: 4 22 | tf_layers: 1 23 | d_cond: 1536 24 | linear_start: 0.00085 25 | linear_end: 0.012 26 | n_steps: 1000 27 | latent_scaling_factor: 0.18215 28 | img_h: 128 29 | img_w: 128 30 | cond_type: chord+txt 31 | cond_mode: mix 32 | use_enc: true 33 | chd_n_step: 32 34 | chd_input_dim: 36 35 | chd_z_input_dim: 512 36 | chd_hidden_dim: 512 37 | chd_z_dim: 512 38 | txt_emb_size: 256 39 | txt_hidden_dim: 1024 40 | txt_z_dim: 256 41 | txt_num_channel: 10 42 |
-------------------------------------------------------------------------------- /polyffusion/params/sdf_chd8bar_txt_mix2.yaml: -------------------------------------------------------------------------------- 1 | model_name: sdf_chd8bar_txt_mix2 2 | batch_size: 16 3 | max_epoch: 100 4 | learning_rate: 5.0e-05 5 | max_grad_norm: 10 6 | fp16: true 7 | num_workers: 0 8 | pin_memory: false 9 | in_channels: 2 10 | out_channels: 2 11 | channels: 64 12 | attention_levels: 13 | - 2 14 | - 3 15 | n_res_blocks: 2 16 | channel_multipliers: 17 | - 1 18 | - 2 19 | - 4 20 | - 4 21 | n_heads: 4 22 | tf_layers: 1 23 | d_cond: 1536 24 | linear_start: 0.00085 25 | linear_end: 0.012 26 | n_steps: 1000 27 | latent_scaling_factor: 0.18215 28 | img_h: 128 29 | img_w: 128 30 | cond_type: chord+txt 31 | cond_mode: mix2 32 | use_enc: true 33 | chd_n_step: 32 34 | chd_input_dim: 36 35 | chd_z_input_dim: 512 36 | chd_hidden_dim: 512 37 | chd_z_dim: 512 38 | txt_emb_size: 256 39 | txt_hidden_dim: 1024 40 | txt_z_dim: 256 41 |
txt_num_channel: 10 42 | -------------------------------------------------------------------------------- /polyffusion/dl_modules/chord_enc.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.distributions import Normal 3 | 4 | 5 | class RnnEncoder(nn.Module): 6 | def __init__(self, input_dim, hidden_dim, z_dim): 7 | super(RnnEncoder, self).__init__() 8 | self.gru = nn.GRU(input_dim, hidden_dim, batch_first=True, bidirectional=True) 9 | self.linear_mu = nn.Linear(hidden_dim * 2, z_dim) 10 | self.linear_var = nn.Linear(hidden_dim * 2, z_dim) 11 | self.input_dim = input_dim 12 | self.hidden_dim = hidden_dim 13 | self.z_dim = z_dim 14 | 15 | def forward(self, x): 16 | x = self.gru(x)[-1] 17 | x = x.transpose_(0, 1).contiguous() 18 | x = x.view(x.size(0), -1) 19 | mu = self.linear_mu(x) 20 | var = self.linear_var(x).exp_() 21 | dist = Normal(mu, var) 22 | return dist 23 | -------------------------------------------------------------------------------- /polyffusion/dirs.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | # dataset paths 4 | DATA_DIR = "data/LOP_4_bin_pnt" 5 | TRAIN_SPLIT_DIR = "data/train_split_pnt" 6 | 7 | MUSICALION_DATA_DIR = "data/musicalion_solo_piano_4_bin_pnt" 8 | POP909_DATA_DIR = "data/POP909_4_bin_pnt_8bar" 9 | 10 | # pretrained path 11 | PT_PNOTREE_PATH = "pretrained/pnotree_20/train_20-last-model.pt" 12 | 13 | PT_POLYDIS_PATH = "pretrained/polydis/model_master_final.pt" 14 | PT_A2S_PATH = "pretrained/a2s/a2s-stage3a.pt" 15 | 16 | # pretrained chd_8bar 17 | PT_CHD_8BAR_PATH = "pretrained/chd8bar/weights.pt" 18 | 19 | # the path to store demo. 20 | DEMO_FOLDER = "./demo" 21 | 22 | # the path to save trained model params and tensorboard log. 
23 | RESULT_PATH = "./result" 24 | 25 | if not os.path.exists(DEMO_FOLDER): 26 | os.mkdir(DEMO_FOLDER) 27 | 28 | if not os.path.exists(RESULT_PATH): 29 | os.mkdir(RESULT_PATH) 30 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/mir/music_base.py: -------------------------------------------------------------------------------- 1 | NUM_TO_ABS_SCALE = ["C", "C#", "D", "Eb", "E", "F", "F#", "G", "Ab", "A", "Bb", "B"] 2 | 3 | 4 | def get_scale_and_suffix(name): 5 | result = "C*D*EF*G*A*B".index(name[0]) 6 | prefix_length = 1 7 | if len(name) > 1: 8 | if name[1] == "b": 9 | result = result - 1 10 | if result < 0: 11 | result += 12 12 | prefix_length = 2 13 | if name[1] == "#": 14 | result = result + 1 15 | if result >= 12: 16 | result -= 12 17 | prefix_length = 2 18 | return result, name[prefix_length:] 19 | 20 | 21 | def scale_name_to_value(name): 22 | result = "1*2*34*5*6*78*9".index( 23 | name[-1] 24 | ) # 8 and 9 are for weird tagging in some mirex chords 25 | return (result - name.count("b") + name.count("#") + 12) % 12 26 | -------------------------------------------------------------------------------- /polyffusion/polydis_aftertouch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from polydis.model import DisentangleVAE 4 | from utils import estx_to_midi_file 5 | 6 | model_path = "pretrained/polydis/model_master_final.pt" 7 | # readme_fn = "./train.py" 8 | batch_size = 128 9 | length = 16 # 16 bars for inference 10 | # n_epoch = 6 11 | # clip = 1 12 | # parallel = False 13 | # weights = [1, 0.5] 14 | # beta = 0.1 15 | # tf_rates = [(0.6, 0), (0.5, 0), (0.5, 0)] 16 | # lr = 1e-3 17 | 18 | 19 | class PolydisAftertouch: 20 | def __init__(self) -> None: 21 | model = DisentangleVAE.init_model() 22 | model.load_model(model_path) 23 | print(f"loaded model {model_path}.") 24 | self.model = model 25 | 26 | def reconstruct(self, prmat, chd, fn, chd_sample=False): 27 | chd = chd.float() 28 | prmat = prmat.float() 29 | est_x = self.model.inference(prmat, chd, sample=False, chd_sample=chd_sample) 30 | estx_to_midi_file(est_x, fn) 31 | 32 | 33 | if __name__ == "__main__": 34 | aftertouch = PolydisAftertouch() 35 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/mir/io/implement/scalar_io.py: -------------------------------------------------------------------------------- 1 | from mir.io.feature_io_base import * 2 | 3 | 4 | class FloatIO(FeatureIO): 5 | def read(self, filename, entry): 6 | f = open(filename, "r") 7 | result = float(f.readline().strip()) 8 | f.close() 9 | return result 10 | 11 | def write(self, data, filename, entry): 12 | f = open(filename, "w") 13 | f.write(str(float(data))) 14 | f.close() 15 | 16 | def visualize(self, data, filename, entry, override_sr): 17 | raise Exception("Cannot visualize a scalar") 18 | 19 | 20 | class IntegerIO(FeatureIO): 21 | def read(self, filename, entry): 22 | f = open(filename, "r") 23 | result = int(f.readline().strip()) 24 | f.close() 25 | return result 26 | 27 | def write(self, data, filename, entry): 28 | f = open(filename, "w") 29 | f.write(str(int(data))) 30 | f.close() 31 | 32 | def visualize(self, data, filename, entry, override_sr): 33 | raise Exception("Cannot visualize a scalar") 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT 
License 2 | 3 | Copyright (c) 2023 Lejun Min 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /polyffusion/models/model_autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from stable_diffusion.model.autoencoder import Autoencoder 6 | from utils import * 7 | 8 | 9 | class Polyffusion_Autoencoder(nn.Module): 10 | def __init__(self, autoencoder: Autoencoder): 11 | super(Polyffusion_Autoencoder, self).__init__() 12 | self.autoencoder = autoencoder 13 | 14 | @classmethod 15 | def load_trained(cls, ldm, model_dir): 16 | model = cls(ldm) 17 | trained_leaner = torch.load(f"{model_dir}/weights.pt") 18 | model.load_state_dict(trained_leaner["model"]) 19 | return model 20 | 21 | def get_loss_dict(self, batch, step): 22 | """ 23 | z_y is the stuff the diffusion model needs to learn 24 | """ 25 | prmat, _, chord = batch 26 | # (#B, 2, 128, 128) 27 | print(f"prmat: {prmat.requires_grad}") 28 | prmat = F.pad(prmat, [0, 0, 0, 0, 0, 1], "constant", 0) 29 | print(f"prmat: {prmat.requires_grad}") 30 | # (#B, 3, 128, 128) 31 | return self.autoencoder.get_loss_dict(prmat, step) 32 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/mir/data/pitch_template.svl: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | [__DATA_FREQ__] 9 | 10 | 11 | [__DATA_ENERGY__] 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /polyffusion/models/model_ddpm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from ddpm import DenoiseDiffusion 5 | from utils import * 6 | 7 | 8 | class Polyffusion_DDPM(nn.Module): 9 | def __init__( 10 | self, 11 | ddpm: DenoiseDiffusion, 12 | params, 13 | max_simu_note=20, 14 | ): 15 | super(Polyffusion_DDPM, self).__init__() 16 | self.params = params 17 | self.ddpm = ddpm 18 | 19 | @classmethod 20 | def load_trained(cls, ddpm, chkpt_fpath, params, max_simu_note=20): 21 | model = cls(ddpm, params, max_simu_note) 22 | trained_leaner = torch.load(chkpt_fpath) 23 | model.load_state_dict(trained_leaner["model"]) 24 | return model 25 | 26 | def p_sample(self, xt: torch.Tensor, t: 
torch.Tensor): 27 | return self.ddpm.p_sample(xt, t) 28 | 29 | def q_sample(self, x0: torch.Tensor, t: torch.Tensor): 30 | return self.ddpm.q_sample(x0, t) 31 | 32 | def get_loss_dict(self, batch, step): 33 | """ 34 | z_y is the stuff the diffusion model needs to learn 35 | """ 36 | prmat2c, pnotree, chord, prmat = batch 37 | return {"loss": self.ddpm.loss(prmat2c)} 38 | -------------------------------------------------------------------------------- /polyffusion/mir_eval/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="mir_eval", 5 | version="0.7", 6 | description="Common metrics for common audio/music processing tasks.", 7 | author="Colin Raffel", 8 | author_email="craffel@gmail.com", 9 | url="https://github.com/craffel/mir_eval", 10 | packages=find_packages(), 11 | classifiers=[ 12 | "License :: OSI Approved :: MIT License", 13 | "Programming Language :: Python", 14 | "Development Status :: 5 - Production/Stable", 15 | "Intended Audience :: Developers", 16 | "Topic :: Multimedia :: Sound/Audio :: Analysis", 17 | "Programming Language :: Python :: 3", 18 | ], 19 | keywords="audio music mir dsp", 20 | license="MIT", 21 | install_requires=[ 22 | "numpy >= 1.7.0", 23 | "scipy >= 1.0.0", 24 | ], 25 | extras_require={ 26 | "display": ["matplotlib>=1.5.0"], 27 | "testing": [ 28 | "matplotlib>=2.1.0", 29 | "decorator", 30 | "pytest", 31 | "pytest-cov", 32 | "pytest-mpl", 33 | "nose", 34 | ], 35 | }, 36 | python_requires=">=3", 37 | ) 38 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/io_new/madmom_io.py: -------------------------------------------------------------------------------- 1 | from mir.common import PACKAGE_PATH 2 | from mir.io.feature_io_base import * 3 | 4 | 5 | class MadmomBeatProbIO(FeatureIO): 6 | def read(self, filename, entry): 7 | return pickle_read(self, filename) 8 | 9 | def write(self, data, filename, entry): 10 | pickle_write(self, data, filename) 11 | 12 | def visualize(self, data, filename, entry, override_sr): 13 | f = open(os.path.join(PACKAGE_PATH, "data/spectrogram_template.svl"), "r") 14 | content = f.read() 15 | f.close() 16 | content = content.replace("[__SR__]", str(100)) 17 | content = content.replace("[__WIN_SHIFT__]", str(1)) 18 | content = content.replace("[__SHAPE_1__]", str(data.shape[1])) 19 | content = content.replace("[__COLOR__]", str(1)) 20 | labels = [str(i) for i in range(data.shape[1])] 21 | content = content.replace("[__DATA__]", create_svl_3d_data(labels, data)) 22 | f = open(filename, "w") 23 | f.write(content) 24 | f.close() 25 | 26 | def pre_assign(self, entry, proxy): 27 | entry.prop.set("n_frame", LoadingPlaceholder(proxy, entry)) 28 | 29 | def post_load(self, data, entry): 30 | entry.prop.set("n_frame", data.shape[0]) 31 | 32 | def get_visualize_extention_name(self): 33 | return "svl" 34 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/mir/data/bothchroma.n3: -------------------------------------------------------------------------------- 1 | @prefix xsd: . 2 | @prefix vamp: . 3 | @prefix : <#> . 
4 | 5 | :transform a vamp:Transform ; 6 | vamp:plugin ; 7 | vamp:step_size "[__WIN_SHIFT__]"^^xsd:int ; 8 | vamp:block_size "[__WIN_SIZE__]"^^xsd:int ; 9 | vamp:plugin_version """5""" ; 10 | vamp:parameter_binding [ 11 | vamp:parameter [ vamp:identifier "chromanormalize" ] ; 12 | vamp:value "0"^^xsd:float ; 13 | ] ; 14 | vamp:parameter_binding [ 15 | vamp:parameter [ vamp:identifier "rollon" ] ; 16 | vamp:value "0"^^xsd:float ; 17 | ] ; 18 | vamp:parameter_binding [ 19 | vamp:parameter [ vamp:identifier "s" ] ; 20 | vamp:value "0.7"^^xsd:float ; 21 | ] ; 22 | vamp:parameter_binding [ 23 | vamp:parameter [ vamp:identifier "tuningmode" ] ; 24 | vamp:value "0"^^xsd:float ; 25 | ] ; 26 | vamp:parameter_binding [ 27 | vamp:parameter [ vamp:identifier "useNNLS" ] ; 28 | vamp:value "1"^^xsd:float ; 29 | ] ; 30 | vamp:parameter_binding [ 31 | vamp:parameter [ vamp:identifier "whitening" ] ; 32 | vamp:value "1"^^xsd:float ; 33 | ] ; 34 | vamp:output [ vamp:identifier "bothchroma" ] . 35 | -------------------------------------------------------------------------------- /polyffusion/dl_modules/txt_enc.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.distributions import Normal 3 | 4 | 5 | class TextureEncoder(nn.Module): 6 | def __init__(self, emb_size, hidden_dim, z_dim, num_channel=10): 7 | """input must be piano_mat: (B, 32, 128)""" 8 | super(TextureEncoder, self).__init__() 9 | self.cnn = nn.Sequential( 10 | nn.Conv2d(1, num_channel, kernel_size=(4, 12), stride=(4, 1), padding=0), 11 | nn.ReLU(), 12 | nn.MaxPool2d(kernel_size=(1, 4), stride=(1, 4)), 13 | ) 14 | self.fc1 = nn.Linear(num_channel * 29, 1000) 15 | self.fc2 = nn.Linear(1000, emb_size) 16 | self.gru = nn.GRU(emb_size, hidden_dim, batch_first=True, bidirectional=True) 17 | self.linear_mu = nn.Linear(hidden_dim * 2, z_dim) 18 | self.linear_var = nn.Linear(hidden_dim * 2, z_dim) 19 | self.emb_size = emb_size 20 | self.hidden_dim = hidden_dim 21 | self.z_dim = z_dim 22 | 23 | def forward(self, pr): 24 | # pr: (bs, 32, 128) 25 | bs = pr.size(0) 26 | pr = pr.unsqueeze(1) 27 | pr = self.cnn(pr).view(bs, 8, -1) 28 | pr = self.fc2(self.fc1(pr)) # (bs, 8, emb_size) 29 | pr = self.gru(pr)[-1] 30 | pr = pr.transpose_(0, 1).contiguous() 31 | pr = pr.view(pr.size(0), -1) 32 | mu = self.linear_mu(pr) 33 | var = self.linear_var(pr).exp_() 34 | dist = Normal(mu, var) 35 | return dist 36 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/mir/data/chroma.n3: -------------------------------------------------------------------------------- 1 | @prefix xsd: . 2 | @prefix vamp: . 3 | @prefix : <#> . 4 | 5 | :transform_plugin a vamp:Plugin ; 6 | vamp:identifier "nnls-chroma" . 7 | 8 | :transform_library a vamp:PluginLibrary ; 9 | vamp:identifier "nnls-chroma" ; 10 | vamp:available_plugin :transform_plugin . 
11 | 12 | :transform a vamp:Transform ; 13 | vamp:plugin :transform_plugin ; 14 | vamp:step_size "[__WIN_SHIFT__]"^^xsd:int ; 15 | vamp:block_size "[__WIN_SIZE__]"^^xsd:int ; 16 | vamp:plugin_version """3""" ; 17 | vamp:sample_rate "[__SR__]"^^xsd:int ; 18 | vamp:parameter_binding [ 19 | vamp:parameter [ vamp:identifier "chromanormalize" ] ; 20 | vamp:value "0"^^xsd:float ; 21 | ] ; 22 | vamp:parameter_binding [ 23 | vamp:parameter [ vamp:identifier "rollon" ] ; 24 | vamp:value "0"^^xsd:float ; 25 | ] ; 26 | vamp:parameter_binding [ 27 | vamp:parameter [ vamp:identifier "s" ] ; 28 | vamp:value "0.7"^^xsd:float ; 29 | ] ; 30 | vamp:parameter_binding [ 31 | vamp:parameter [ vamp:identifier "tuningmode" ] ; 32 | vamp:value "0"^^xsd:float ; 33 | ] ; 34 | vamp:parameter_binding [ 35 | vamp:parameter [ vamp:identifier "useNNLS" ] ; 36 | vamp:value "1"^^xsd:float ; 37 | ] ; 38 | vamp:parameter_binding [ 39 | vamp:parameter [ vamp:identifier "whitening" ] ; 40 | vamp:value "1"^^xsd:float ; 41 | ] ; 42 | vamp:output [ vamp:identifier "chroma" ] . 43 | -------------------------------------------------------------------------------- /polyffusion/data/pop909_extractor.py: -------------------------------------------------------------------------------- 1 | from data.dataloader import get_train_val_dataloaders 2 | from utils import prmat_to_midi_file 3 | 4 | if __name__ == "__main__": 5 | # val_dataset = PianoOrchDataset.load_valid_set(use_track=[1, 2]) 6 | # val_dl = get_val_dataloader(1000, use_track=[0, 1, 2]) 7 | train_dl, val_dl = get_train_val_dataloaders() 8 | print(len(val_dl)) 9 | for i, batch in enumerate(val_dl): 10 | prmat2c, pnotree, chord, prmat = batch 11 | prmat_to_midi_file(prmat, "exp/ref_wm.mid") 12 | break 13 | 14 | # dir = "data/POP909_MIDIs" 15 | # os.makedirs(dir, exist_ok=True) 16 | 17 | # for i in range(1, 910): 18 | # fpath = os.path.join(POP909_DATA_DIR, f"{i:03}.npz") 19 | # print(fpath) 20 | # if not os.path.exists(fpath): 21 | # continue 22 | # data = np.load(fpath, allow_pickle=True) 23 | # notes = data["notes"] 24 | # midi = pm.PrettyMIDI() 25 | # piano_program = pm.instrument_name_to_program("Acoustic Grand Piano") 26 | # piano = pm.Instrument(program=piano_program) 27 | # one_beat = 0.125 28 | # for track in notes: 29 | # for note in track: 30 | # onset, pitch, duration, velocity, program = note 31 | # note = pm.Note( 32 | # velocity=velocity, 33 | # pitch=pitch, 34 | # start=onset * one_beat, 35 | # end=(onset + duration) * one_beat 36 | # ) 37 | # piano.notes.append(note) 38 | 39 | # midi.instruments.append(piano) 40 | # midi.write(f"{dir}/{i:03}.mid") 41 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/mir/data/tunedlogfreqspec.n3: -------------------------------------------------------------------------------- 1 | @prefix xsd: . 2 | @prefix vamp: . 3 | @prefix : <#> . 4 | 5 | :transform_plugin a vamp:Plugin ; 6 | vamp:identifier "nnls-chroma" . 7 | 8 | :transform_library a vamp:PluginLibrary ; 9 | vamp:identifier "nnls-chroma" ; 10 | vamp:available_plugin :transform_plugin . 
11 | 12 | :transform a vamp:Transform ; 13 | vamp:plugin ; 14 | vamp:step_size "[__WIN_SHIFT__]"^^xsd:int ; 15 | vamp:block_size "[__WIN_SIZE__]"^^xsd:int ; 16 | vamp:plugin_version """5""" ; 17 | vamp:sample_rate "[__SR__]"^^xsd:int ; 18 | vamp:parameter_binding [ 19 | vamp:parameter [ vamp:identifier "chromanormalize" ] ; 20 | vamp:value "0"^^xsd:float ; 21 | ] ; 22 | vamp:parameter_binding [ 23 | vamp:parameter [ vamp:identifier "rollon" ] ; 24 | vamp:value "0"^^xsd:float ; 25 | ] ; 26 | vamp:parameter_binding [ 27 | vamp:parameter [ vamp:identifier "s" ] ; 28 | vamp:value "0.7"^^xsd:float ; 29 | ] ; 30 | vamp:parameter_binding [ 31 | vamp:parameter [ vamp:identifier "tuningmode" ] ; 32 | vamp:value "0"^^xsd:float ; 33 | ] ; 34 | vamp:parameter_binding [ 35 | vamp:parameter [ vamp:identifier "useNNLS" ] ; 36 | vamp:value "1"^^xsd:float ; 37 | ] ; 38 | vamp:parameter_binding [ 39 | vamp:parameter [ vamp:identifier "whitening" ] ; 40 | vamp:value "1"^^xsd:float ; 41 | ] ; 42 | vamp:output [ vamp:identifier "tunedlogfreqspec" ] . 43 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/extractors/rule_based_channel_reweight.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from extractors.midi_utilities import is_percussive_channel 3 | 4 | 5 | def get_channel_thickness(piano_roll): 6 | chroma_matrix = np.zeros((piano_roll.shape[0], 12)) 7 | for note in range(12): 8 | chroma_matrix[:, note] = np.sum(piano_roll[:, note::12], axis=1) 9 | thickness_array = (chroma_matrix > 0).sum(axis=1) 10 | if thickness_array.sum() == 0: 11 | return 0 12 | return thickness_array[thickness_array > 0].mean() 13 | 14 | 15 | def get_channel_bass_property(piano_roll): 16 | result = np.argwhere(piano_roll > 0)[:, 1] 17 | if len(result) == 0: 18 | return 0.0, 1.0 19 | return result.mean(), min(1.0, len(result) / len(piano_roll)) 20 | 21 | 22 | def midi_to_thickness_weights(midi): 23 | thickness = np.array( 24 | [ 25 | get_channel_thickness(ins.get_piano_roll().T) 26 | for ins in midi.instruments 27 | if not is_percussive_channel(ins) 28 | ] 29 | ) 30 | result = 1 - np.exp(-(thickness - 0.95)) 31 | result /= result.max() 32 | return result 33 | 34 | 35 | def midi_to_thickness_and_bass_weights(midi): 36 | rolls = [ 37 | ins.get_piano_roll().T 38 | for ins in midi.instruments 39 | if not is_percussive_channel(ins) 40 | ] 41 | thickness = np.array([get_channel_thickness(roll) for roll in rolls]) 42 | bass = np.array([get_channel_bass_property(roll) for roll in rolls]) 43 | bass[bass[:, 1] < 0.2, 0] = 128 44 | result = 1 - np.exp(-(thickness - 0.95)) 45 | result /= result.max() 46 | result[np.argmin(bass[:, 0])] = 1.0 47 | 48 | return result 49 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/io_new/beatlab_io.py: -------------------------------------------------------------------------------- 1 | from mir.common import PACKAGE_PATH 2 | from mir.io.feature_io_base import * 3 | 4 | 5 | class BeatLabIO(FeatureIO): 6 | def read(self, filename, entry): 7 | f = open(filename, "r") 8 | content = f.read() 9 | lines = content.split("\n") 10 | f.close() 11 | result = [] 12 | for i in range(len(lines)): 13 | line = lines[i].strip() 14 | if line == "": 15 | continue 16 | tokens = line.split("\t") 17 | result.append([float(tokens[0]), float(tokens[2])]) 18 | return result 19 | 20 | def write(self, data, filename, entry): 21 | f = open(filename, "w") 22 
| for i in range(0, data.shape[0]): 23 | f.write("\t".join([str(item) for item in data[i, :]])) 24 | f.write("\n") 25 | f.close() 26 | 27 | def visualize(self, data, filename, entry, override_sr): 28 | f = open(os.path.join(PACKAGE_PATH, "data/sparse_tag_template.svl"), "r") 29 | sr = override_sr 30 | content = f.read() 31 | f.close() 32 | content = content.replace("[__SR__]", str(sr)) 33 | content = content.replace("[__STYLE__]", str(1)) 34 | output_text = "" 35 | for beat_info in data: 36 | output_text += '\n' % ( 37 | int(beat_info[0] * sr), 38 | int(beat_info[1]), 39 | ) 40 | content = content.replace("[__DATA__]", output_text) 41 | f = open(filename, "w") 42 | f.write(content) 43 | f.close() 44 | 45 | def get_visualize_extention_name(self): 46 | return "svl" 47 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/io_new/key_io.py: -------------------------------------------------------------------------------- 1 | from mir.io import FeatureIO 2 | 3 | 4 | class KeyIO(FeatureIO): 5 | def read(self, filename, entry): 6 | f = open(filename, "r") 7 | content = f.read() 8 | lines = content.split("\n") 9 | f.close() 10 | result = [] 11 | for i in range(len(lines)): 12 | line = lines[i].strip() 13 | if line == "": 14 | continue 15 | tokens = line.split("\t") 16 | assert len(tokens) == 3 17 | result.append([float(tokens[0]), float(tokens[1]), tokens[2]]) 18 | return result 19 | 20 | def write(self, data, filename, entry): 21 | f = open(filename, "w") 22 | for i in range(0, len(data)): 23 | f.write("\t".join([str(item).replace("\t", " ") for item in data[i]])) 24 | f.write("\n") 25 | f.close() 26 | 27 | def visualize(self, data, filename, entry, override_sr): 28 | sr = override_sr 29 | f = open(os.path.join(PACKAGE_PATH, "data/sparse_tag_template.svl"), "r") 30 | content = f.read() 31 | f.close() 32 | content = content.replace("[__SR__]", str(sr)) 33 | content = content.replace("[__STYLE__]", str(1)) 34 | results = [] 35 | for item in data: 36 | results.append( 37 | '' 38 | % (int(np.round(item[0] * sr)), item[2]) 39 | ) 40 | content = content.replace("[__DATA__]", "\n".join(results)) 41 | f = open(filename, "w") 42 | f.write(content) 43 | f.close() 44 | 45 | def get_visualize_extention_name(self): 46 | return "svl" 47 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/io_new/lyric_io.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | 3 | from mir.io.feature_io_base import * 4 | 5 | 6 | class LyricIO(FeatureIO): 7 | def read(self, filename, entry): 8 | f = open(filename, "r", encoding="utf-16-le") 9 | content = f.read() 10 | if content.startswith("\ufeff"): 11 | content = content[1:] 12 | lines = content.split("\n") 13 | f.close() 14 | result = [] 15 | for i in range(len(lines)): 16 | line = lines[i].strip() 17 | if line == "": 18 | continue 19 | tokens = line.split("\t") 20 | if len(tokens) == 3: 21 | result.append([float(tokens[0]), float(tokens[1]), tokens[2]]) 22 | elif len(tokens) == 4: # Contains sentence information 23 | result.append( 24 | [float(tokens[0]), float(tokens[1]), tokens[2], int(tokens[3])] 25 | ) 26 | else: 27 | raise Exception("Not supported format") 28 | return result 29 | 30 | def write(self, data, filename, entry): 31 | f = open(filename, "wb") 32 | f.write(codecs.BOM_UTF16_LE) 33 | for i in range(0, len(data)): 34 | f.write("\t".join([str(item) for item in data[i]]).encode("utf-16-le")) 35 | 
f.write("\n".encode("utf-16-le")) 36 | f.close() 37 | 38 | def visualize(self, data, filename, entry, override_sr): 39 | f = open(filename, "w") 40 | for i in range(0, len(data)): 41 | f.write("\t".join([str(item) for item in data[i]])) 42 | f.write("\n") 43 | f.close() 44 | 45 | def get_visualize_extention_name(self): 46 | return "txt" 47 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/io_new/tag_io.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from mir.common import PACKAGE_PATH 3 | from mir.io.feature_io_base import * 4 | 5 | 6 | class TimedTagIO(FeatureIO): 7 | def read(self, filename, entry): 8 | f = open(filename, "r") 9 | content = f.read() 10 | lines = content.split("\n") 11 | f.close() 12 | result = [] 13 | for i in range(len(lines)): 14 | line = lines[i].strip() 15 | if line == "": 16 | continue 17 | tokens = line.split("\t") 18 | assert len(tokens) == 2 19 | result.append([float(tokens[0]), tokens[1]]) 20 | return result 21 | 22 | def write(self, data, filename, entry): 23 | f = open(filename, "w") 24 | for i in range(0, len(data)): 25 | f.write("\t".join([str(item) for item in data[i]])) 26 | f.write("\n") 27 | f.close() 28 | 29 | def visualize(self, data, filename, entry, override_sr): 30 | sr = override_sr 31 | f = open(os.path.join(PACKAGE_PATH, "data/sparse_tag_template.svl"), "r") 32 | content = f.read() 33 | f.close() 34 | content = content.replace("[__SR__]", str(sr)) 35 | content = content.replace("[__STYLE__]", str(1)) 36 | results = [] 37 | for item in data: 38 | results.append( 39 | '' 40 | % (int(np.round(item[0] * sr)), item[1]) 41 | ) 42 | content = content.replace("[__DATA__]", "\n".join(results)) 43 | f = open(filename, "w") 44 | f.write(content) 45 | f.close() 46 | 47 | def get_visualize_extention_name(self): 48 | return "svl" 49 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/io_new/chordlab_io.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from mir.common import PACKAGE_PATH 3 | from mir.io.feature_io_base import * 4 | 5 | 6 | class ChordLabIO(FeatureIO): 7 | def read(self, filename, entry): 8 | f = open(filename, "r") 9 | content = f.read() 10 | lines = content.split("\n") 11 | f.close() 12 | result = [] 13 | for i in range(len(lines)): 14 | line = lines[i].strip() 15 | if line == "": 16 | continue 17 | tokens = line.split("\t") 18 | assert len(tokens) == 3 19 | result.append([float(tokens[0]), float(tokens[1]), tokens[2]]) 20 | return result 21 | 22 | def write(self, data, filename, entry): 23 | f = open(filename, "w") 24 | for i in range(0, len(data)): 25 | f.write("\t".join([str(item) for item in data[i]])) 26 | f.write("\n") 27 | f.close() 28 | 29 | def visualize(self, data, filename, entry, override_sr): 30 | sr = override_sr 31 | f = open(os.path.join(PACKAGE_PATH, "data/sparse_tag_template.svl"), "r") 32 | content = f.read() 33 | f.close() 34 | content = content.replace("[__SR__]", str(sr)) 35 | content = content.replace("[__STYLE__]", str(1)) 36 | results = [] 37 | for item in data: 38 | results.append( 39 | '' 40 | % (int(np.round(item[0] * sr)), item[2]) 41 | ) 42 | content = content.replace("[__DATA__]", "\n".join(results)) 43 | f = open(filename, "w") 44 | f.write(content) 45 | f.close() 46 | 47 | def get_visualize_extention_name(self): 48 | return "svl" 49 | 
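Side note on the io_new label formats above (BeatLabIO, KeyIO, LyricIO, TimedTagIO, ChordLabIO): they all follow the same contract, where read() parses one tab-separated record per line, write() serializes the same list of records back, and visualize() fills one of the SVL templates under mir/data for inspection in Sonic Visualiser. The snippet below is a minimal usage sketch, not code from the repository; it assumes it is run from the chord_extractor directory (so io_new and the vendored mir package are importable) and passes entry=None because ChordLabIO.read/write ignore the entry argument.

# Hypothetical round-trip through ChordLabIO: one "start<TAB>end<TAB>label" line per chord segment.
from io_new.chordlab_io import ChordLabIO

chord_io = ChordLabIO()
segments = [[0.0, 2.0, "C:maj"], [2.0, 4.0, "G:maj"]]  # start time (s), end time (s), chord label
chord_io.write(segments, "example_chords.lab", entry=None)
assert chord_io.read("example_chords.lab", entry=None) == segments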
-------------------------------------------------------------------------------- /polyffusion/chord_extractor/mir/data/chordino.n3: -------------------------------------------------------------------------------- 1 | @prefix xsd: . 2 | @prefix vamp: . 3 | @prefix : <#> . 4 | 5 | :transform_plugin a vamp:Plugin ; 6 | vamp:identifier "chordino" . 7 | 8 | :transform_library a vamp:PluginLibrary ; 9 | vamp:identifier "nnls-chroma" ; 10 | vamp:available_plugin :transform_plugin . 11 | 12 | :transform a vamp:Transform ; 13 | vamp:plugin ; 14 | vamp:step_size "[__WIN_SHIFT__]"^^xsd:int ; 15 | vamp:block_size "[__WIN_SIZE__]"^^xsd:int ; 16 | vamp:plugin_version """5""" ; 17 | vamp:sample_rate "[__SR__]"^^xsd:int ; 18 | vamp:parameter_binding [ 19 | vamp:parameter [ vamp:identifier "boostn" ] ; 20 | vamp:value "0.1"^^xsd:float ; 21 | ] ; 22 | vamp:parameter_binding [ 23 | vamp:parameter [ vamp:identifier "rollon" ] ; 24 | vamp:value "0"^^xsd:float ; 25 | ] ; 26 | vamp:parameter_binding [ 27 | vamp:parameter [ vamp:identifier "s" ] ; 28 | vamp:value "0.7"^^xsd:float ; 29 | ] ; 30 | vamp:parameter_binding [ 31 | vamp:parameter [ vamp:identifier "tuningmode" ] ; 32 | vamp:value "0"^^xsd:float ; 33 | ] ; 34 | vamp:parameter_binding [ 35 | vamp:parameter [ vamp:identifier "useHMM" ] ; 36 | vamp:value "1"^^xsd:float ; 37 | ] ; 38 | vamp:parameter_binding [ 39 | vamp:parameter [ vamp:identifier "useNNLS" ] ; 40 | vamp:value "1"^^xsd:float ; 41 | ] ; 42 | vamp:parameter_binding [ 43 | vamp:parameter [ vamp:identifier "whitening" ] ; 44 | vamp:value "1"^^xsd:float ; 45 | ] ; 46 | vamp:output [ vamp:identifier "simplechord" ] . 47 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/io_new/jointbeat_io.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from mir.common import PACKAGE_PATH 3 | from mir.io.feature_io_base import * 4 | 5 | 6 | class JointBeatIO(FeatureIO): 7 | def read(self, filename, entry): 8 | f = open(filename, "r") 9 | content = f.read() 10 | lines = content.split("\n") 11 | f.close() 12 | result = [] 13 | for i in range(len(lines)): 14 | line = lines[i].strip() 15 | if line == "": 16 | continue 17 | tokens = line.split("\t") 18 | assert len(tokens) == 3 19 | result.append([float(tokens[0]), int(tokens[1]), int(tokens[2])]) 20 | return np.array(result) 21 | 22 | def write(self, data, filename, entry): 23 | f = open(filename, "w") 24 | for i in range(0, len(data)): 25 | f.write("\t".join([str(item) for item in data[i]])) 26 | f.write("\n") 27 | f.close() 28 | 29 | def visualize(self, data, filename, entry, override_sr): 30 | sr = override_sr 31 | f = open(os.path.join(PACKAGE_PATH, "data/sparse_tag_template.svl"), "r") 32 | content = f.read() 33 | f.close() 34 | content = content.replace("[__SR__]", str(sr)) 35 | content = content.replace("[__STYLE__]", str(1)) 36 | results = [] 37 | for item in data: 38 | results.append( 39 | '' 40 | % (int(np.round(item[0] * sr)), int(item[1]), int(item[2])) 41 | ) 42 | content = content.replace("[__DATA__]", "\n".join(results)) 43 | f = open(filename, "w") 44 | f.write(content) 45 | f.close() 46 | 47 | def get_visualize_extention_name(self): 48 | return "svl" 49 | -------------------------------------------------------------------------------- /polyffusion/train/train_ddpm.py: -------------------------------------------------------------------------------- 1 | from data.dataloader import get_custom_train_val_dataloaders, 
get_train_val_dataloaders 2 | from ddpm import DenoiseDiffusion 3 | from ddpm.unet import UNet 4 | from models.model_ddpm import Polyffusion_DDPM 5 | 6 | from . import * 7 | 8 | 9 | class DDPM_TrainConfig(TrainConfig): 10 | # U-Net model for $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$ 11 | eps_model: UNet 12 | # [DDPM algorithm](index.html) 13 | diffusion: DenoiseDiffusion 14 | 15 | # Adam optimizer 16 | optimizer: torch.optim.Adam 17 | 18 | def __init__(self, params, output_dir, data_dir=None): 19 | super().__init__(params, None, output_dir) 20 | 21 | self.eps_model = UNet( 22 | image_channels=params.image_channels, 23 | n_channels=params.n_channels, 24 | ch_mults=params.channel_multipliers, 25 | is_attn=params.is_attention, 26 | ) 27 | 28 | # Create [DDPM class](index.html) 29 | self.diffusion = DenoiseDiffusion( 30 | eps_model=self.eps_model, 31 | n_steps=params.n_steps, 32 | ) 33 | 34 | self.model = Polyffusion_DDPM(self.diffusion, params) 35 | # Create dataloader 36 | if data_dir is None: 37 | self.train_dl, self.val_dl = get_train_val_dataloaders( 38 | params.batch_size, params.num_workers, params.pin_memory 39 | ) 40 | else: 41 | self.train_dl, self.val_dl = get_custom_train_val_dataloaders( 42 | params.batch_size, 43 | data_dir, 44 | num_workers=params.num_workers, 45 | pin_memory=params.pin_memory, 46 | ) 47 | # Create optimizer 48 | self.optimizer = torch.optim.Adam( 49 | self.eps_model.parameters(), lr=params.learning_rate 50 | ) 51 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/io_new/downbeat_io.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from mir.common import PACKAGE_PATH 3 | from mir.io.feature_io_base import * 4 | 5 | 6 | class DownbeatIO(FeatureIO): 7 | def read(self, filename, entry): 8 | f = open(filename, "r") 9 | lines = f.readlines() 10 | lines = [line.strip("\n\r") for line in lines] 11 | lines = [line for line in lines if line != ""] 12 | f.close() 13 | result = np.zeros((len(lines), 2)) 14 | for i in range(len(lines)): 15 | line = lines[i] 16 | tokens = line.split("\t") 17 | assert len(tokens) == 2 18 | result[i, 0] = float(tokens[0]) 19 | result[i, 1] = float(tokens[1]) 20 | return result 21 | 22 | def write(self, data, filename, entry): 23 | f = open(filename, "w") 24 | for i in range(0, len(data)): 25 | f.write("\t".join([str(item) for item in data[i, :]])) 26 | f.write("\n") 27 | f.close() 28 | 29 | def visualize(self, data, filename, entry, override_sr): 30 | f = open(os.path.join(PACKAGE_PATH, "data/sparse_tag_template.svl"), "r") 31 | sr = override_sr 32 | content = f.read() 33 | f.close() 34 | content = content.replace("[__SR__]", str(sr)) 35 | content = content.replace("[__STYLE__]", str(1)) 36 | output_text = "" 37 | for beat_info in data: 38 | output_text += '\n' % ( 39 | int(beat_info[0] * sr), 40 | int(beat_info[1]), 41 | ) 42 | content = content.replace("[__DATA__]", output_text) 43 | f = open(filename, "w") 44 | f.write(content) 45 | f.close() 46 | 47 | def get_visualize_extention_name(self): 48 | return "svl" 49 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/__init__.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import sys 3 | 4 | import mir_eval 5 | import numpy as np 6 | 7 | from .main import transcribe_cb1000_midi 8 | 9 | 10 | def get_chord_from_chdfile(fpath, one_beat=0.5, 
rounding=True): 11 | """ 12 | chord matrix [M * 14], each line represent the chord of a beat 13 | same format as mir_eval.chord.encode(): 14 | root_number(1), semitone_bitmap(12), bass_number(1) 15 | inputs are generated from junyan's algorithm 16 | """ 17 | file = csv.reader(open(fpath), delimiter="\t") 18 | beat_cnt = 0 19 | chords = [] 20 | for line in file: 21 | start = float(line[0]) 22 | end = float(line[1]) 23 | chord = line[2] 24 | if not rounding: 25 | assert ((end - start) / one_beat).is_integer() 26 | beat_num = int((end - start) / one_beat) 27 | else: 28 | beat_num = round((end - start) / one_beat) 29 | for _ in range(beat_num): 30 | beat_cnt += 1 31 | # see https://craffel.github.io/mir_eval/#mir_eval.chord.encode 32 | chd_enc = mir_eval.chord.encode(chord) 33 | 34 | root = chd_enc[0] 35 | # make chroma and bass absolute 36 | chroma_bitmap = chd_enc[1] 37 | chroma_bitmap = np.roll(chroma_bitmap, root) 38 | bass = (chd_enc[2] + root) % 12 39 | 40 | line = [root] 41 | for _ in chroma_bitmap: 42 | line.append(_) 43 | line.append(bass) 44 | 45 | chords.append(line) 46 | return np.array(chords, dtype=np.float32) 47 | 48 | 49 | def extract_chords_from_midi_file(fpath, chdfile_path): 50 | transcribe_cb1000_midi(fpath, chdfile_path) 51 | return get_chord_from_chdfile(chdfile_path) 52 | 53 | 54 | if __name__ == "__main__": 55 | extract_chords_from_midi_file(sys.argv[1], sys.argv[2]) 56 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/mir/cache.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import pickle 4 | 5 | from mir.common import WORKING_PATH 6 | 7 | __all__ = ["load", "save"] 8 | 9 | 10 | def mkdir_for_file(path): 11 | folder_path = os.path.dirname(path) 12 | if not os.path.isdir(folder_path): 13 | os.makedirs(folder_path) 14 | return path 15 | 16 | 17 | def dumptofile(obj, filename, protocol): 18 | f = open(filename, "wb") 19 | # If you are awared of the compatibility issues 20 | # Well, you use cache only on your own computer, right? 
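    # (pickle.HIGHEST_PROTOCOL yields the most compact cache, but a cache written
    # with it may not load under an older Python; pass an explicit lower protocol
    # to save() if a cache ever needs to be shared across interpreters.)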
21 | pickle.dump(obj, f, protocol=protocol) 22 | f.close() 23 | 24 | 25 | def loadfromfile(filename): 26 | if os.path.isfile(filename): 27 | f = open(filename, "rb") 28 | obj = pickle.load(f) 29 | f.close() 30 | return obj 31 | else: 32 | raise Exception("No cache of %s" % filename) 33 | 34 | 35 | def load(*names): 36 | if len(names) == 1: 37 | return loadfromfile( 38 | os.path.join(WORKING_PATH, "cache_data/%s.cache" % names[0]) 39 | ) 40 | result = [None] * len(names) 41 | for i in range(len(names)): 42 | result[i] = loadfromfile( 43 | os.path.join(WORKING_PATH, "cache_data/%s.cache" % names[i]) 44 | ) 45 | return result 46 | 47 | 48 | def save(obj, name, protocol=pickle.HIGHEST_PROTOCOL): 49 | path = os.path.join(WORKING_PATH, "cache_data/%s.cache" % name) 50 | mkdir_for_file(path) 51 | dumptofile(obj, path, protocol) 52 | 53 | 54 | def hasher(obj): 55 | if isinstance(obj, list): 56 | m = hashlib.md5() 57 | for item in obj: 58 | m.update(item) 59 | return m.hexdigest() 60 | if isinstance(obj, str): 61 | return hashlib.md5(obj.encode("utf8")).hexdigest() 62 | return hashlib.md5(obj).hexdigest() 63 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/io_new/salami_io.py: -------------------------------------------------------------------------------- 1 | from mir.io.feature_io_base import * 2 | from mir.music_base import get_scale_and_suffix 3 | 4 | 5 | class SalamiIO(FeatureIO): 6 | def read(self, filename, entry): 7 | f = open(filename, "r") 8 | data = f.read() 9 | lines = data.split("\n") 10 | result = [] 11 | metre_up = -1 12 | metre_down = -1 13 | tonic = -1 14 | for line in lines: 15 | if line == "": 16 | continue 17 | if line.startswith("#"): 18 | if ":" in line: 19 | seperator_index = line.index(":") 20 | keyword = line[1:seperator_index].strip() 21 | if keyword == "metre": 22 | slash_index = line.index("/") 23 | metre_up = int(line[seperator_index + 1 : slash_index].strip()) 24 | metre_down = int(line[slash_index + 1 :].strip()) 25 | # print('metre changed to %d/%d'%(metre_up,metre_down)) 26 | if keyword == "tonic": 27 | tonic = int( 28 | get_scale_and_suffix(line[seperator_index + 1 :].strip())[0] 29 | ) 30 | 31 | else: 32 | tokens = line.split("\t") 33 | assert len(tokens) == 2 34 | start_time = float(tokens[0]) 35 | result.append((start_time, tokens[1], metre_up, metre_down, tonic)) 36 | f.close() 37 | return result 38 | 39 | def write(self, data, filename, entry): 40 | raise NotImplementedError() 41 | 42 | def visualize(self, data, filename, entry, override_sr): 43 | f = open(filename, "w") 44 | for time, token, _, _, _ in data: 45 | f.write("%f\t%s\n" % (time, token)) 46 | f.close() 47 | -------------------------------------------------------------------------------- /polyffusion/models/model_chd_8bar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from dl_modules import ChordDecoder, ChordEncoder 5 | from utils import * 6 | 7 | 8 | class Chord_8Bar(nn.Module): 9 | def __init__(self, chord_enc: ChordEncoder, chord_dec: ChordDecoder): 10 | super(Chord_8Bar, self).__init__() 11 | self.chord_enc = chord_enc 12 | self.chord_dec = chord_dec 13 | 14 | @classmethod 15 | def load_trained(cls, chord_enc, chord_dec, model_dir): 16 | model = cls(chord_enc, chord_dec) 17 | trained_leaner = torch.load(f"{model_dir}/weights.pt") 18 | model.load_state_dict(trained_leaner["model"]) 19 | return model 20 | 21 | def chord_loss(self, c, recon_root, 
recon_chroma, recon_bass): 22 | loss_fun = nn.CrossEntropyLoss() 23 | root = c[:, :, 0:12].max(-1)[-1].view(-1).contiguous() 24 | chroma = c[:, :, 12:24].long().view(-1).contiguous() 25 | bass = c[:, :, 24:].max(-1)[-1].view(-1).contiguous() 26 | 27 | recon_root = recon_root.view(-1, 12).contiguous() 28 | recon_chroma = recon_chroma.view(-1, 2).contiguous() 29 | recon_bass = recon_bass.view(-1, 12).contiguous() 30 | root_loss = loss_fun(recon_root, root) 31 | chroma_loss = loss_fun(recon_chroma, chroma) 32 | bass_loss = loss_fun(recon_bass, bass) 33 | chord_loss = root_loss + chroma_loss + bass_loss 34 | return { 35 | "loss": chord_loss, 36 | "root": root_loss, 37 | "chroma": chroma_loss, 38 | "bass": bass_loss, 39 | } 40 | 41 | def get_loss_dict(self, batch, step, tfr_chd): 42 | _, _, chord, _ = batch 43 | 44 | z_chd = self.chord_enc(chord).rsample() 45 | recon_root, recon_chroma, recon_bass = self.chord_dec( 46 | z_chd, False, tfr_chd, chord 47 | ) 48 | return self.chord_loss(chord, recon_root, recon_chroma, recon_bass) 49 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/mir/io/implement/chroma_io.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from mir.io.feature_io_base import * 3 | 4 | 5 | class ChromaIO(FeatureIO): 6 | def read(self, filename, entry): 7 | if filename.endswith(".csv"): 8 | f = open(filename, "r") 9 | lines = f.readlines() 10 | result = [] 11 | for line in lines: 12 | line = line.strip() 13 | if line == "": 14 | continue 15 | arr = np.array(list(map(float, line.split(",")[2:]))) 16 | arr = arr.reshape((2, 12))[::-1].T 17 | arr = np.roll(arr, -3, axis=0).reshape((24)) 18 | result.append(arr) 19 | data = np.array(result) 20 | else: 21 | data = pickle_read(self, filename) 22 | return data 23 | 24 | def write(self, data, filename, entry): 25 | pickle_write(self, data, filename) 26 | 27 | def visualize(self, data, filename, entry, override_sr): 28 | sr = entry.prop.sr 29 | win_shift = entry.prop.hop_length 30 | feature_tuple_size = entry.prop.chroma_tuple_size 31 | # if(FEATURETUPLESIZE==2): 32 | features = data 33 | f = open(filename, "w") 34 | for i in range(0, features.shape[0]): 35 | time = win_shift * i / sr 36 | f.write(str(time)) 37 | for j in range(0, feature_tuple_size): 38 | if j > 0: 39 | f.write("\t0") 40 | for k in range(0, 12): 41 | f.write("\t" + str(features[i][k * feature_tuple_size + j])) 42 | f.write("\n") 43 | f.close() 44 | 45 | def pre_assign(self, entry, proxy): 46 | entry.prop.set("n_frame", LoadingPlaceholder(proxy, entry)) 47 | 48 | def post_load(self, data, entry): 49 | entry.prop.set("n_frame", data.shape[0]) 50 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/mir/extractors/misc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from mir.extractors.extractor_base import * 3 | 4 | 5 | class BlankMusic(ExtractorBase): 6 | def get_feature_class(self): 7 | return io.MusicIO 8 | 9 | def extract(self, entry, **kwargs): 10 | time = 60.0 # seconds 11 | if "time" in kwargs: 12 | time = kwargs["time"] 13 | return np.zeros((int(np.ceil(time * entry.prop.sr)))) 14 | 15 | 16 | class FrameCount(ExtractorBase): 17 | def get_feature_class(self): 18 | return io.IntegerIO 19 | 20 | def extract(self, entry, **kwargs): 21 | # self.require(entry.prop.hop_length) 22 | return 
entry.dict[kwargs["source"]].get(entry).shape[0] 23 | 24 | 25 | class Evaluate: 26 | def __init__(self, io): 27 | self.__io = io 28 | 29 | def __call__(self, *args, **kwargs): 30 | inner_instance = Evaluate.InnerEvaluate() 31 | inner_instance.io = self.__io 32 | return inner_instance 33 | 34 | class InnerEvaluate(ExtractorBase): 35 | def __init__(self): 36 | self.io = None 37 | 38 | def get_feature_class(self): 39 | return self.io 40 | 41 | class __ProxyReflector: 42 | def __init__(self, entry): 43 | self.__entry = entry 44 | 45 | def __getattr__(self, item): 46 | if item in self.__entry.dict: 47 | print("Getting %s" % item) 48 | return self.__entry.dict[item].get(self.__entry) 49 | else: 50 | raise AttributeError( 51 | "No key '%s' found in entry %s" % (item, self.__entry.name) 52 | ) 53 | 54 | def extract(self, entry, **kwargs): 55 | eval_proxy_ref__ = __class__.__ProxyReflector(entry) 56 | expr = kwargs["expr"].replace("$", "eval_proxy_ref__.") 57 | return eval(expr) 58 | -------------------------------------------------------------------------------- /polyffusion/main.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | from omegaconf import OmegaConf 4 | 5 | from train.train_autoencoder import Autoencoder_TrainConfig 6 | from train.train_chd_8bar import Chord8bar_TrainConfig 7 | from train.train_ddpm import DDPM_TrainConfig 8 | from train.train_ldm import LDM_TrainConfig 9 | 10 | if __name__ == "__main__": 11 | parser = ArgumentParser( 12 | description="train (or resume training) a Polyffusion model" 13 | ) 14 | parser.add_argument( 15 | "--output_dir", 16 | default=None, 17 | help="directory in which to store model checkpoints and training logs", 18 | ) 19 | parser.add_argument( 20 | "--data_dir", default=None, help="directory of custom training data, in npzs" 21 | ) 22 | parser.add_argument( 23 | "--pop909_use_track", 24 | default="0,1,2", 25 | help="which tracks to use for pop909 (default dataset) training. 
(0: melody, 1: bridge, 2: piano accompaniment)", 26 | ) 27 | parser.add_argument("--model", help="which model to train (autoencoder, ldm, ddpm)") 28 | args = parser.parse_args() 29 | 30 | use_track = [int(x) for x in args.pop909_use_track.split(",")] 31 | 32 | params = OmegaConf.load(f"polyffusion/params/{args.model}.yaml") 33 | 34 | if args.model.startswith("sdf"): 35 | use_musicalion = "musicalion" in args.model 36 | config = LDM_TrainConfig( 37 | params, 38 | args.output_dir, 39 | use_musicalion, 40 | use_track=use_track, 41 | data_dir=args.data_dir, 42 | ) 43 | elif args.model == "ddpm": 44 | config = DDPM_TrainConfig(params, args.output_dir, data_dir=args.data_dir) 45 | elif args.model == "autoencoder": 46 | config = Autoencoder_TrainConfig( 47 | params, args.output_dir, data_dir=args.data_dir 48 | ) 49 | elif args.model == "chd_8bar": 50 | config = Chord8bar_TrainConfig(params, args.output_dir, data_dir=args.data_dir) 51 | else: 52 | raise NotImplementedError 53 | config.train() 54 | -------------------------------------------------------------------------------- /polyffusion/train/train_chd_8bar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from data.dataloader import get_custom_train_val_dataloaders, get_train_val_dataloaders 4 | from dl_modules import ChordDecoder, ChordEncoder 5 | from models.model_chd_8bar import Chord_8Bar 6 | from train.scheduler import ParameterScheduler, TeacherForcingScheduler 7 | 8 | # from stable_diffusion.model.autoencoder import Autoencoder, Encoder, Decoder 9 | from . import * 10 | 11 | 12 | class Chord8bar_TrainConfig(TrainConfig): 13 | def __init__(self, params, output_dir, data_dir=None) -> None: 14 | # Teacher-forcing rate for Chord VAE training 15 | tfr_chd = params.tfr_chd 16 | tfr_chd_scheduler = TeacherForcingScheduler(*tfr_chd) 17 | params_dict = dict(tfr_chd=tfr_chd_scheduler) 18 | param_scheduler = ParameterScheduler(**params_dict) 19 | 20 | super().__init__(params, param_scheduler, output_dir) 21 | 22 | self.chord_enc = ChordEncoder( 23 | input_dim=params.chd_input_dim, 24 | hidden_dim=params.chd_hidden_dim, 25 | z_dim=params.chd_z_dim, 26 | ) 27 | self.chord_dec = ChordDecoder( 28 | input_dim=params.chd_input_dim, 29 | z_input_dim=params.chd_z_input_dim, 30 | hidden_dim=params.chd_hidden_dim, 31 | z_dim=params.chd_z_dim, 32 | n_step=params.chd_n_step, 33 | ) 34 | self.model = Chord_8Bar( 35 | self.chord_enc, 36 | self.chord_dec, 37 | ) 38 | 39 | # Create dataloader 40 | if data_dir is None: 41 | self.train_dl, self.val_dl = get_train_val_dataloaders( 42 | params.batch_size, params.num_workers, params.pin_memory 43 | ) 44 | else: 45 | self.train_dl, self.val_dl = get_custom_train_val_dataloaders( 46 | params.batch_size, 47 | data_dir, 48 | num_workers=params.num_workers, 49 | pin_memory=params.pin_memory, 50 | ) 51 | 52 | # Create optimizer 53 | self.optimizer = torch.optim.Adam( 54 | self.model.parameters(), lr=params.learning_rate 55 | ) 56 | -------------------------------------------------------------------------------- /polyffusion/train/train_autoencoder.py: -------------------------------------------------------------------------------- 1 | # This file is unused 2 | 3 | import torch 4 | 5 | from data.dataloader import get_custom_train_val_dataloaders, get_train_val_dataloaders 6 | from dirs import * 7 | from models.model_autoencoder import Polyffusion_Autoencoder 8 | from stable_diffusion.model.autoencoder import Autoencoder, Decoder, Encoder 9 | 10 | from . 
import TrainConfig 11 | 12 | 13 | class Autoencoder_TrainConfig(TrainConfig): 14 | model: Autoencoder 15 | optimizer: torch.optim.Adam 16 | 17 | def __init__(self, params, output_dir, data_dir=None) -> None: 18 | super().__init__(params, None, output_dir) 19 | encoder = Encoder( 20 | in_channels=params.in_channels, 21 | z_channels=params.z_channels, 22 | channels=params.channels, 23 | channel_multipliers=params.channel_multipliers, 24 | n_resnet_blocks=params.n_res_blocks, 25 | ) 26 | 27 | decoder = Decoder( 28 | out_channels=params.out_channels, 29 | z_channels=params.z_channels, 30 | channels=params.channels, 31 | channel_multipliers=params.channel_multipliers, 32 | n_resnet_blocks=params.n_res_blocks, 33 | ) 34 | 35 | autoencoder = Autoencoder( 36 | encoder=encoder, 37 | decoder=decoder, 38 | emb_channels=params.emb_channels, 39 | z_channels=params.z_channels, 40 | ) 41 | 42 | self.model = Polyffusion_Autoencoder(autoencoder) 43 | 44 | # Create dataloader 45 | if data_dir is None: 46 | self.train_dl, self.val_dl = get_train_val_dataloaders( 47 | params.batch_size, params.num_workers, params.pin_memory 48 | ) 49 | else: 50 | self.train_dl, self.val_dl = get_custom_train_val_dataloaders( 51 | params.batch_size, 52 | data_dir, 53 | num_workers=params.num_workers, 54 | pin_memory=params.pin_memory, 55 | ) 56 | 57 | # Create optimizer 58 | self.optimizer = torch.optim.Adam( 59 | self.model.parameters(), lr=params.learning_rate 60 | ) 61 | -------------------------------------------------------------------------------- /polyffusion/lightning_learner.py: -------------------------------------------------------------------------------- 1 | import lightning 2 | import torch 3 | 4 | 5 | class LightningLearner(lightning.LightningModule): 6 | def __init__(self, model, optimizer, params, param_scheduler): 7 | super().__init__() 8 | self.model = model 9 | self.optimizer = optimizer 10 | self.params = params 11 | self.param_scheduler = param_scheduler # teacher-forcing stuff 12 | 13 | self.save_hyperparameters("params", "param_scheduler") 14 | 15 | def _categorize_loss_dict(self, loss_dict, category): 16 | return {f"{category}/{k}": v for k, v in loss_dict.items()} 17 | 18 | def training_step(self, batch, batch_idx): 19 | if self.param_scheduler is not None: 20 | scheduled_params = self.param_scheduler.step() 21 | loss_dict = self.model.get_loss_dict( 22 | batch, self.global_step, **scheduled_params 23 | ) 24 | else: 25 | scheduled_params = None 26 | loss_dict = self.model.get_loss_dict(batch, self.global_step) 27 | 28 | # check NaN 29 | for loss_value in list(loss_dict.values()): 30 | if isinstance(loss_value, torch.Tensor) and torch.isnan(loss_value).any(): 31 | raise RuntimeError( 32 | f"Detected NaN loss at step {self.global_step}, epoch {self.epoch}" 33 | ) 34 | loss = loss_dict["loss"] 35 | 36 | loss_dict = self._categorize_loss_dict(loss_dict, "train") 37 | self.log_dict(loss_dict, prog_bar=True) 38 | 39 | return loss 40 | 41 | def validation_step(self, batch, batch_idx): 42 | if self.param_scheduler is not None: 43 | scheduled_params = self.param_scheduler.step() 44 | loss_dict = self.model.get_loss_dict( 45 | batch, self.global_step, **scheduled_params 46 | ) 47 | else: 48 | scheduled_params = None 49 | loss_dict = self.model.get_loss_dict(batch, self.global_step) 50 | 51 | loss_dict = self._categorize_loss_dict(loss_dict, "val") 52 | self.log_dict(loss_dict) 53 | 54 | def configure_optimizers(self): 55 | return self.optimizer 56 | 
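LightningLearner above only bundles a model, its optimizer and an optional parameter scheduler; how the TrainConfig classes wire it into a trainer lives in train/__init__.py, which is not shown here. The sketch below is therefore illustrative only: the toy model, the synthetic tensors and the Trainer flags are made up for the example, and any module exposing get_loss_dict(batch, step, **scheduled_params) that returns a dict with a "loss" entry can be trained the same way.

# Illustrative only: a toy stand-in model and synthetic data, not the repository's TrainConfig wiring.
import lightning
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from lightning_learner import LightningLearner


class ToyModel(nn.Module):
    # Any module exposing get_loss_dict(batch, step) -> {"loss": ...} fits LightningLearner.
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 1)

    def get_loss_dict(self, batch, step):
        x, y = batch
        return {"loss": nn.functional.mse_loss(self.lin(x), y)}


model = ToyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
learner = LightningLearner(model, optimizer, params={}, param_scheduler=None)

xs, ys = torch.randn(64, 4), torch.randn(64, 1)
dl = DataLoader(TensorDataset(xs, ys), batch_size=8)
lightning.Trainer(max_epochs=1, logger=False, enable_checkpointing=False).fit(learner, dl, dl)

Note that configure_optimizers() returns only the optimizer, so no learning-rate scheduler is registered with Lightning; the schedulers in train/scheduler.py are instead used for scheduled parameters such as the teacher-forcing rate passed through param_scheduler.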
-------------------------------------------------------------------------------- /polyffusion/chord_extractor/io_new/complex_chord_io.py: -------------------------------------------------------------------------------- 1 | import complex_chord 2 | import numpy as np 3 | from mir.common import PACKAGE_PATH 4 | from mir.io.feature_io_base import * 5 | 6 | 7 | class ComplexChordIO(FeatureIO): 8 | def read(self, filename, entry): 9 | n_frame = entry.n_frame 10 | f = open(filename, "r") 11 | line_list = f.readlines() 12 | tags = np.ones((n_frame, 6)) * -2 13 | for line in line_list: 14 | line = line.strip() 15 | if line == "": 16 | continue 17 | if "\t" in line: 18 | tokens = line.split("\t") 19 | else: 20 | tokens = line.split(" ") 21 | sr = entry.prop.sr 22 | win_shift = entry.prop.hop_length 23 | begin = int(round(float(tokens[0]) * sr / win_shift)) 24 | end = int(round(float(tokens[1]) * sr / win_shift)) 25 | if end > n_frame: 26 | end = n_frame 27 | if begin < 0: 28 | begin = 0 29 | tags[begin:end, :] = ( 30 | complex_chord.Chord(tokens[2]).to_numpy().reshape((1, 6)) 31 | ) 32 | f.close() 33 | return tags 34 | 35 | def write(self, data, filename, entry): 36 | raise NotImplementedError() 37 | 38 | def visualize(self, data, filename, entry, override_sr): 39 | f = open(os.path.join(PACKAGE_PATH, "data/spectrogram_template.svl"), "r") 40 | sr = entry.prop.sr 41 | win_shift = entry.prop.hop_length 42 | content = f.read() 43 | f.close() 44 | content = content.replace("[__SR__]", str(sr)) 45 | content = content.replace("[__WIN_SHIFT__]", str(win_shift)) 46 | content = content.replace("[__SHAPE_1__]", str(data.shape[1])) 47 | content = content.replace("[__COLOR__]", str(1)) 48 | labels = [str(i) for i in range(data.shape[1])] 49 | content = content.replace("[__DATA__]", create_svl_3d_data(labels, data)) 50 | f = open(filename, "w") 51 | f.write(content) 52 | f.close() 53 | 54 | def get_visualize_extention_name(self): 55 | return "svl" 56 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/io_new/midilab_io.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from mir.common import PACKAGE_PATH 3 | from mir.io.feature_io_base import * 4 | 5 | 6 | class MidiLabIO(FeatureIO): 7 | def read(self, filename, entry): 8 | f = open(filename, "r") 9 | lines = f.readlines() 10 | lines = [line.strip("\n\r") for line in lines] 11 | lines = [line for line in lines if line != ""] 12 | f.close() 13 | result = np.zeros((len(lines), 3)) 14 | for i in range(len(lines)): 15 | line = lines[i] 16 | tokens = line.split("\t") 17 | assert len(tokens) == 3 18 | result[i, 0] = float(tokens[0]) 19 | result[i, 1] = float(tokens[1]) 20 | result[i, 2] = float(tokens[2]) 21 | return result 22 | 23 | def write(self, data, filename, entry): 24 | f = open(filename, "w") 25 | for i in range(0, len(data)): 26 | f.write("\t".join([str(item) for item in data[i]])) 27 | f.write("\n") 28 | f.close() 29 | 30 | def visualize(self, data, filename, entry, override_sr): 31 | f = open(os.path.join(PACKAGE_PATH, "data/midi_template.svl"), "r") 32 | sr = override_sr 33 | content = f.read() 34 | f.close() 35 | content = content.replace("[__SR__]", str(sr)) 36 | content = content.replace("[__WIN_SHIFT__]", "1") 37 | output_text = "" 38 | for note_info in data: 39 | output_text += self.__get_midi_note_text( 40 | note_info[0] * sr, note_info[1] * sr - 1, note_info[2] 41 | ) 42 | content = content.replace("[__DATA__]", output_text) 43 
| f = open(filename, "w") 44 | f.write(content) 45 | f.close() 46 | 47 | def __get_midi_note_text(self, start_frame, end_frame, midi_height, level=0.78125): 48 | return '\n' % ( 49 | int(round(start_frame)), 50 | midi_height, 51 | int(round(end_frame - start_frame)), 52 | level, 53 | ) 54 | 55 | def get_visualize_extention_name(self): 56 | return "svl" 57 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/io_new/osu_io.py: -------------------------------------------------------------------------------- 1 | from mir.io.feature_io_base import * 2 | 3 | 4 | class MetaDict: 5 | def __init__(self): 6 | self.dict = {} 7 | 8 | def set(self, key, value): 9 | self.dict[key] = value 10 | 11 | def __getattr__(self, item): 12 | return self.dict[item.lower()] 13 | 14 | 15 | class OsuMapIO(FeatureIO): 16 | def read(self, filename, entry): 17 | f = open(filename, "r", encoding="UTF-8") 18 | result = MetaDict() 19 | lines = f.readlines() 20 | current_state = 0 21 | current_dict = None 22 | for line in lines: 23 | line = line.strip() 24 | if line == "": 25 | continue 26 | if line.startswith("["): 27 | assert line.endswith("]") 28 | namespace = line[1:-1].lower() 29 | if namespace in [ 30 | "general", 31 | "editor", 32 | "metadata", 33 | "difficulty", 34 | "colours", 35 | ]: 36 | current_state = 1 37 | current_dict = MetaDict() 38 | result.set(namespace, current_dict) 39 | elif namespace in ["hitobjects", "events", "timingpoints"]: 40 | current_state = 2 41 | current_dict = [] 42 | result.set(namespace, current_dict) 43 | else: 44 | raise Exception( 45 | "Unknown namespace %s in %s" % (namespace, filename) 46 | ) 47 | else: 48 | if current_state == 1: 49 | split_index = line.index(":") 50 | key = line[:split_index].strip() 51 | value = line[split_index + 1 :].strip() 52 | current_dict.set(key.lower(), value) 53 | elif current_state == 2: 54 | current_dict.append(line) 55 | return result 56 | 57 | def write(self, data, filename, entry): 58 | raise NotImplementedError() 59 | 60 | def visualize(self, data, filename, entry, override_sr): 61 | raise NotImplementedError() 62 | 63 | def get_visualize_extention_name(self): 64 | raise NotImplementedError() 65 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/mir/extractors/librosa_extractor.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import numpy as np 3 | from mir.extractors.extractor_base import * 4 | 5 | 6 | class HPSS(ExtractorBase): 7 | def get_feature_class(self): 8 | return io.MusicIO 9 | 10 | def extract(self, entry, **kwargs): 11 | if "source" in kwargs: 12 | y = entry.dict[kwargs["source"]].get(entry) 13 | else: 14 | y = entry.music 15 | y_h = librosa.effects.harmonic(y, margin=kwargs["margin"]) 16 | # y_h, y_p = librosa.effects.hpss(y, margin=(1.0, 5.0)) 17 | return y_h 18 | 19 | 20 | class CQT(ExtractorBase): 21 | def get_feature_class(self): 22 | return io.SpectrogramIO 23 | 24 | # Warning this spectrum has a 1/3 half note stepping 25 | def extract(self, entry, **kwargs): 26 | n_bins = 262 27 | y = entry.music 28 | logspec = librosa.core.cqt( 29 | y, 30 | sr=kwargs["sr"], 31 | hop_length=kwargs["hop_length"], 32 | bins_per_octave=36, 33 | n_bins=n_bins, 34 | filter_scale=1.5, 35 | ).T 36 | logspec = np.abs(logspec) 37 | return logspec 38 | 39 | 40 | class STFT(ExtractorBase): 41 | def get_feature_class(self): 42 | return io.SpectrogramIO 43 | 44 | # Warning this 
spectrum has a 1/3 half note stepping 45 | def extract(self, entry, **kwargs): 46 | y = entry.music 47 | logspec = librosa.core.stft( 48 | y, win_length=kwargs["win_size"], hop_length=kwargs["hop_length"] 49 | ).T 50 | logspec = np.abs(logspec) 51 | return logspec 52 | 53 | 54 | class Onset(ExtractorBase): 55 | def get_feature_class(self): 56 | return io.SpectrogramIO 57 | 58 | def extract(self, entry, **kwargs): 59 | onset = librosa.onset.onset_strength( 60 | entry.music, sr=kwargs["sr"], hop_length=kwargs["hop_length"] 61 | ).reshape((-1, 1)) 62 | return onset 63 | 64 | 65 | class Energy(ExtractorBase): 66 | def get_feature_class(self): 67 | return io.SpectrogramIO 68 | 69 | def extract(self, entry, **kwargs): 70 | energy = librosa.feature.rmse( 71 | y=entry.dict[kwargs["source"]].get(entry), 72 | hop_length=kwargs["hop_length"], 73 | frame_length=kwargs["win_size"], 74 | center=True, 75 | ).reshape((-1, 1)) 76 | return energy 77 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # local 2 | /data/ 3 | /pretrained/ 4 | /result*/ 5 | /exp/ 6 | /eval/ 7 | /wandb/ 8 | 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | pip-wheel-metadata/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | *.py,cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 104 | __pypackages__/ 105 | 106 | # Celery stuff 107 | celerybeat-schedule 108 | celerybeat.pid 109 | 110 | # SageMath parsed files 111 | *.sage.py 112 | 113 | # Environments 114 | .env 115 | .venv 116 | env/ 117 | venv/ 118 | ENV/ 119 | env.bak/ 120 | venv.bak/ 121 | 122 | # Spyder project settings 123 | .spyderproject 124 | .spyproject 125 | 126 | # Rope project settings 127 | .ropeproject 128 | 129 | # mkdocs documentation 130 | /site 131 | 132 | # mypy 133 | .mypy_cache/ 134 | .dmypy.json 135 | dmypy.json 136 | 137 | # Pyre type checker 138 | .pyre/ 139 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/mir/io/implement/spectrogram_io.py: -------------------------------------------------------------------------------- 1 | from mir.common import PACKAGE_PATH 2 | from mir.io.feature_io_base import * 3 | 4 | 5 | class SpectrogramIO(FeatureIO): 6 | def read(self, filename, entry): 7 | return pickle_read(self, filename) 8 | 9 | def write(self, data, filename, entry): 10 | pickle_write(self, data, filename) 11 | 12 | def visualize(self, data, filename, entry, override_sr): 13 | if type(data) is tuple: 14 | labels = data[0] 15 | data = data[1] 16 | else: 17 | labels = None 18 | if len(data.shape) == 1: 19 | data = data.reshape((-1, 1)) 20 | if data.shape[1] > 1: 21 | f = open(os.path.join(PACKAGE_PATH, "data/spectrogram_template.svl"), "r") 22 | sr = entry.prop.sr 23 | win_shift = entry.prop.hop_length 24 | content = f.read() 25 | f.close() 26 | content = content.replace("[__SR__]", str(sr)) 27 | content = content.replace("[__WIN_SHIFT__]", str(win_shift)) 28 | content = content.replace("[__SHAPE_1__]", str(data.shape[1])) 29 | content = content.replace("[__COLOR__]", str(1)) 30 | if labels is None: 31 | labels = [str(i) for i in range(data.shape[1])] 32 | content = content.replace("[__DATA__]", create_svl_3d_data(labels, data)) 33 | else: 34 | f = open(os.path.join(PACKAGE_PATH, "data/curve_template.svl"), "r") 35 | sr = entry.prop.sr 36 | win_shift = entry.prop.hop_length 37 | content = f.read() 38 | f.close() 39 | content = content.replace("[__SR__]", str(sr)) 40 | content = content.replace("[__STYLE__]", str(1)) 41 | results = [] 42 | for i in range(0, len(data)): 43 | results.append( 44 | '' 45 | % (int(override_sr / sr * i * win_shift), data[i, 0]) 46 | ) 47 | content = content.replace("[__DATA__]", "\n".join(results)) 48 | content = content.replace("[__NAME__]", "curve") 49 | 50 | f = open(filename, "w") 51 | f.write(content) 52 | f.close() 53 | 54 | def pre_assign(self, entry, proxy): 55 | entry.prop.set("n_frame", LoadingPlaceholder(proxy, entry)) 56 | 57 | def post_load(self, data, entry): 58 | entry.prop.set("n_frame", data.shape[0]) 59 | 60 | def get_visualize_extention_name(self): 61 | return "svl" 62 | -------------------------------------------------------------------------------- /polyffusion/data/dataloader_musicalion.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import DataLoader 4 | 5 | from data.dataset_musicalion import PianoOrchDataset_Musicalion 6 | from utils import ( 7 | estx_to_midi_file, 8 | pianotree_pitch_shift, 9 | pr_mat_pitch_shift, 10 | prmat2c_to_midi_file, 11 | prmat_to_midi_file, 12 | ) 13 | 14 | 15 | def collate_fn(batch, shift): 16 | def sample_shift(): 17 | return np.random.choice(np.arange(-6, 6), 1)[0] 18 | 19 | prmat2c = [] 20 | pnotree = [] 21 | prmat = [] 
22 | song_fn = [] 23 | for b in batch: 24 | # b[0]: seg_pnotree; b[1]: seg_pnotree_y 25 | seg_prmat2c = b[0] 26 | seg_pnotree = b[1] 27 | seg_prmat = b[3] 28 | 29 | if shift: 30 | shift_pitch = sample_shift() 31 | seg_prmat2c = pr_mat_pitch_shift(seg_prmat2c, shift_pitch) 32 | seg_pnotree = pianotree_pitch_shift(seg_pnotree, shift_pitch) 33 | seg_prmat = pr_mat_pitch_shift(seg_prmat, shift_pitch) 34 | 35 | prmat2c.append(seg_prmat2c) 36 | pnotree.append(seg_pnotree) 37 | prmat.append(seg_prmat) 38 | 39 | if len(b) > 4: 40 | song_fn.append(b[4]) 41 | 42 | prmat2c = torch.Tensor(np.array(prmat2c, np.float32)).float() 43 | pnotree = torch.Tensor(np.array(pnotree, np.int64)).long() 44 | prmat = torch.Tensor(np.array(prmat, np.float32)).float() 45 | # prmat = prmat.unsqueeze(1) # (B, 1, 128, 128) 46 | if len(song_fn) > 0: 47 | return prmat2c, pnotree, None, prmat, song_fn 48 | else: 49 | return prmat2c, pnotree, None, prmat 50 | 51 | 52 | def get_train_val_dataloaders(batch_size, num_workers=0, pin_memory=False, debug=False): 53 | train_dataset, val_dataset = PianoOrchDataset_Musicalion.load_train_and_valid_sets( 54 | debug 55 | ) 56 | train_dl = DataLoader( 57 | train_dataset, 58 | batch_size, 59 | True, 60 | collate_fn=lambda x: collate_fn(x, shift=True), 61 | num_workers=num_workers, 62 | pin_memory=pin_memory, 63 | ) 64 | val_dl = DataLoader( 65 | val_dataset, 66 | batch_size, 67 | False, 68 | collate_fn=lambda x: collate_fn(x, shift=False), 69 | num_workers=num_workers, 70 | pin_memory=pin_memory, 71 | ) 72 | print( 73 | f"Dataloader ready: batch_size={batch_size}, num_workers={num_workers}, pin_memory={pin_memory}" 74 | ) 75 | return train_dl, val_dl 76 | 77 | 78 | if __name__ == "__main__": 79 | train_dl, val_dl = get_train_val_dataloaders(16) 80 | print(len(train_dl)) 81 | for batch in train_dl: 82 | print(len(batch)) 83 | prmat2c, pnotree, _, prmat = batch 84 | print(prmat2c.shape) 85 | print(pnotree.shape) 86 | print(prmat.shape) 87 | prmat2c = prmat2c.cpu().numpy() 88 | pnotree = pnotree.cpu().numpy() 89 | prmat = prmat.cpu().numpy() 90 | prmat2c_to_midi_file(prmat2c, "exp/m_dl_prmat2c.mid") 91 | estx_to_midi_file(pnotree, "exp/m_dl_pnotree.mid") 92 | prmat_to_midi_file(prmat, "exp/m_dl_prmat.mid") 93 | exit(0) 94 | -------------------------------------------------------------------------------- /polyffusion/train/scheduler.py: -------------------------------------------------------------------------------- 1 | # Copied from torch_plus 2 | 3 | import numpy as np 4 | 5 | 6 | def scheduled_sampling(i, high=0.7, low=0.05): 7 | i /= 1000 * 40 # new update 8 | x = 10 * (i - 0.5) 9 | z = 1 / (1 + np.exp(x)) 10 | y = (high - low) * z + low 11 | return y 12 | 13 | 14 | class _Scheduler: 15 | def __init__(self, step=0, mode="train"): 16 | self._step = step 17 | self._mode = mode 18 | 19 | def _update_step(self): 20 | if self._mode == "train": 21 | self._step += 1 22 | elif self._mode == "val": 23 | pass 24 | else: 25 | raise NotImplementedError 26 | 27 | def step(self): 28 | raise NotImplementedError 29 | 30 | def train(self): 31 | self._mode = "train" 32 | 33 | def eval(self): 34 | self._mode = "val" 35 | 36 | 37 | class ConstantScheduler(_Scheduler): 38 | def __init__(self, param, step=0.0): 39 | super(ConstantScheduler, self).__init__(step) 40 | self.param = param 41 | 42 | def step(self): 43 | self._update_step() 44 | return self.param 45 | 46 | 47 | class TeacherForcingScheduler(_Scheduler): 48 | def __init__(self, high, low, f=scheduled_sampling, step=0): 49 | 
super(TeacherForcingScheduler, self).__init__(step) 50 | self.high = high 51 | self.low = low 52 | self._step = step 53 | self.schedule_f = f 54 | 55 | def get_tfr(self): 56 | return self.schedule_f(self._step, self.high, self.low) 57 | 58 | def step(self): 59 | tfr = self.get_tfr() 60 | self._update_step() 61 | return tfr 62 | 63 | 64 | class OptimizerScheduler(_Scheduler): 65 | def __init__(self, optimizer, scheduler, clip, step=0): 66 | # optimizer and scheduler are pytorch class 67 | super(OptimizerScheduler, self).__init__(step) 68 | self.optimizer = optimizer 69 | self.scheduler = scheduler 70 | self.clip = clip 71 | 72 | def optimizer_zero_grad(self): 73 | self.optimizer.zero_grad() 74 | 75 | def step(self, require_zero_grad=False): 76 | self.optimizer.step() 77 | self.scheduler.step() 78 | if require_zero_grad: 79 | self.optimizer_zero_grad() 80 | self._update_step() 81 | 82 | 83 | class ParameterScheduler(_Scheduler): 84 | def __init__(self, step=0, mode="train", **schedulers): 85 | # optimizer and scheduler are pytorch class 86 | super(ParameterScheduler, self).__init__(step) 87 | self.schedulers = schedulers 88 | self.mode = mode 89 | 90 | def train(self): 91 | self.mode = "train" 92 | for scheduler in self.schedulers.values(): 93 | scheduler.train() 94 | 95 | def eval(self): 96 | self.mode = "val" 97 | for scheduler in self.schedulers.values(): 98 | scheduler.eval() 99 | 100 | def step(self, require_zero_grad=False): 101 | params_dic = {} 102 | for key, scheduler in self.schedulers.items(): 103 | params_dic[key] = scheduler.step() 104 | return params_dic 105 | -------------------------------------------------------------------------------- /polyffusion/stable_diffusion/losses/discriminator.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch.nn as nn 4 | 5 | from .util import ActNorm 6 | 7 | 8 | def weights_init(m): 9 | classname = m.__class__.__name__ 10 | if classname.find("Conv") != -1: 11 | nn.init.normal_(m.weight.data, 0.0, 0.02) 12 | elif classname.find("BatchNorm") != -1: 13 | nn.init.normal_(m.weight.data, 1.0, 0.02) 14 | nn.init.constant_(m.bias.data, 0) 15 | 16 | 17 | class NLayerDiscriminator(nn.Module): 18 | """Defines a PatchGAN discriminator as in Pix2Pix 19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 20 | """ 21 | 22 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 23 | """Construct a PatchGAN discriminator 24 | Parameters: 25 | input_nc (int) -- the number of channels in input images 26 | ndf (int) -- the number of filters in the last conv layer 27 | n_layers (int) -- the number of conv layers in the discriminator 28 | norm_layer -- normalization layer 29 | """ 30 | super(NLayerDiscriminator, self).__init__() 31 | if not use_actnorm: 32 | norm_layer = nn.BatchNorm2d 33 | else: 34 | norm_layer = ActNorm 35 | if ( 36 | type(norm_layer) == functools.partial 37 | ): # no need to use bias as BatchNorm2d has affine parameters 38 | use_bias = norm_layer.func != nn.BatchNorm2d 39 | else: 40 | use_bias = norm_layer != nn.BatchNorm2d 41 | 42 | kw = 4 43 | padw = 1 44 | sequence = [ 45 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), 46 | nn.LeakyReLU(0.2, True), 47 | ] 48 | nf_mult = 1 49 | nf_mult_prev = 1 50 | for n in range(1, n_layers): # gradually increase the number of filters 51 | nf_mult_prev = nf_mult 52 | nf_mult = min(2**n, 8) 53 | sequence += [ 54 | nn.Conv2d( 55 | ndf * nf_mult_prev, 56 | 
ndf * nf_mult, 57 | kernel_size=kw, 58 | stride=2, 59 | padding=padw, 60 | bias=use_bias, 61 | ), 62 | norm_layer(ndf * nf_mult), 63 | nn.LeakyReLU(0.2, True), 64 | ] 65 | 66 | nf_mult_prev = nf_mult 67 | nf_mult = min(2**n_layers, 8) 68 | sequence += [ 69 | nn.Conv2d( 70 | ndf * nf_mult_prev, 71 | ndf * nf_mult, 72 | kernel_size=kw, 73 | stride=1, 74 | padding=padw, 75 | bias=use_bias, 76 | ), 77 | norm_layer(ndf * nf_mult), 78 | nn.LeakyReLU(0.2, True), 79 | ] 80 | 81 | sequence += [ 82 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) 83 | ] # output 1 channel prediction map 84 | self.main = nn.Sequential(*sequence) 85 | 86 | def forward(self, input): 87 | """Standard forward.""" 88 | return self.main(input) 89 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join 3 | 4 | import numpy as np 5 | from chord_class import ChordClass 6 | from extractors.midi_utilities import ( 7 | MidiBeatExtractor, 8 | ) 9 | from extractors.rule_based_channel_reweight import midi_to_thickness_and_bass_weights 10 | from io_new.chordlab_io import ChordLabIO 11 | from midi_chord import ChordRecognition 12 | from mir import DataEntry, io 13 | from tqdm import tqdm 14 | 15 | 16 | def process_chord(entry, extra_division): 17 | """ 18 | 19 | Parameters 20 | ---------- 21 | entry: the song to be processed. Properties required: 22 | entry.midi: the pretry midi object 23 | entry.beat: extracted beat and downbeat 24 | extra_division: extra divisions to each beat. 25 | For chord recognition on beat-level, use extra_division=1 26 | For chord recognition on half-beat-level, use extra_division=2 27 | 28 | Returns 29 | ------- 30 | Extracted chord sequence 31 | """ 32 | 33 | midi = entry.midi 34 | beats = midi.get_beats() 35 | if extra_division > 1: 36 | beat_interp = np.linspace(beats[:-1], beats[1:], extra_division + 1).T 37 | last_beat = beat_interp[-1, -1] 38 | beats = np.append(beat_interp[:, :-1].reshape((-1)), last_beat) 39 | downbeats = midi.get_downbeats() 40 | j = 0 41 | beat_pos = -2 42 | beat = [] 43 | for i in range(len(beats)): 44 | if j < len(downbeats) and beats[i] == downbeats[j]: 45 | beat_pos = 1 46 | j += 1 47 | else: 48 | beat_pos = beat_pos + 1 49 | assert beat_pos > 0 50 | beat.append([beats[i], beat_pos]) 51 | rec = ChordRecognition(entry, ChordClass()) 52 | weights = midi_to_thickness_and_bass_weights(entry.midi) 53 | rec.process_feature(weights) 54 | chord = rec.decode() 55 | return chord 56 | 57 | 58 | def transcribe_cb1000_midi(midi_path, output_path): 59 | """ 60 | Perform chord recognition on a midi 61 | :param midi_path: the path to the midi file 62 | :param output_path: the path to the output file 63 | """ 64 | entry = DataEntry() 65 | entry.append_file(midi_path, io.MidiIO, "midi") 66 | entry.append_extractor(MidiBeatExtractor, "beat") 67 | result = process_chord(entry, extra_division=2) 68 | entry.append_data(result, ChordLabIO, "pred") 69 | entry.save("pred", output_path) 70 | 71 | 72 | def extract_in_folder(): 73 | dpath = sys.argv[1] 74 | dpath_output = sys.argv[2] 75 | os.system(f"rm -rf {dpath_output}") 76 | os.system(f"mkdir -p {dpath_output}") 77 | for piece in tqdm(os.listdir(dpath)): 78 | os.system(f"mkdir -p {join(dpath_output, piece)}") 79 | for ver in os.listdir(join(dpath, piece)): 80 | transcribe_cb1000_midi( 81 | join(dpath, piece, ver), join(dpath_output, piece, 
ver[:-4]) + ".out" 82 | ) 83 | 84 | 85 | if __name__ == "__main__": 86 | import sys 87 | 88 | if len(sys.argv) != 3: 89 | print("Usage: main.py midi_path output_path") 90 | exit(0) 91 | 92 | transcribe_cb1000_midi(sys.argv[1], sys.argv[2]) 93 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/mir/io/feature_io_base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from abc import ABC, abstractmethod 4 | 5 | 6 | class LoadingPlaceholder: 7 | def __init__(self, proxy, entry): 8 | self.proxy = proxy 9 | self.entry = entry 10 | pass 11 | 12 | def fire(self): 13 | self.proxy.get(self.entry) 14 | 15 | 16 | class FeatureIO(ABC): 17 | @abstractmethod 18 | def read(self, filename, entry): 19 | pass 20 | 21 | def safe_read(self, filename, entry): 22 | entry.prop.start_record_reading() 23 | try: 24 | result = self.read(filename, entry) 25 | except Exception: 26 | entry.prop.end_record_reading() 27 | raise 28 | entry.prop.end_record_reading() 29 | return result 30 | 31 | def try_mkdir(self, filename): 32 | folder = os.path.dirname(filename) 33 | if not os.path.isdir(folder): 34 | os.makedirs(folder) 35 | 36 | def create(self, data, filename, entry): 37 | self.try_mkdir(filename) 38 | self.write(data, filename, entry) 39 | 40 | @abstractmethod 41 | def write(self, data, filename, entry): 42 | pass 43 | 44 | # override iif writing and visualizing use different methods 45 | # (i.e. compressed vs uncompressed) 46 | def visualize(self, data, filename, entry, override_sr): 47 | self.write(data, filename, entry) 48 | 49 | # override iff entry properties will be updated upon loading 50 | def pre_assign(self, entry, proxy): 51 | pass 52 | 53 | # override iff entry properties need updated upon loading 54 | def post_load(self, data, entry): 55 | pass 56 | 57 | # override iif it will save as other formats (e.g. 
wav) 58 | def get_visualize_extention_name(self): 59 | return "txt" 60 | 61 | def file_to_evaluation_format(self, filename, entry): 62 | raise Exception("Not supported by the io class") 63 | 64 | def data_to_evaluation_format(self, data, entry): 65 | raise Exception("Not supported by the io class") 66 | 67 | 68 | def pickle_read(self, filename): 69 | f = open(filename, "rb") 70 | obj = pickle.load(f) 71 | f.close() 72 | return obj 73 | 74 | 75 | def pickle_write(self, data, filename): 76 | f = open(filename, "wb") 77 | pickle.dump(data, f) 78 | f.close() 79 | 80 | 81 | def create_svl_3d_data(labels, data): 82 | assert len(labels) == data.shape[1] 83 | results_part1 = [ 84 | '' % (i, str(labels[i])) for i in range(len(labels)) 85 | ] 86 | results_part2 = [ 87 | '%s' % (i, " ".join([str(s) for s in data[i]])) 88 | for i in range(data.shape[0]) 89 | ] 90 | return "\n".join(results_part1) + "\n" + "\n".join(results_part2) 91 | 92 | 93 | def framed_2d_feature_visualizer(entry, features, filename): 94 | f = open(filename, "w") 95 | for i in range(0, features.shape[0]): 96 | time = entry.prop.hop_length * i / entry.prop.sr 97 | f.write(str(time)) 98 | for j in range(0, features.shape[1]): 99 | f.write("\t" + str(features[i][j])) 100 | f.write("\n") 101 | f.close() 102 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/io_new/beat_align_io.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from mir import PACKAGE_PATH 3 | from mir.io.feature_io_base import * 4 | 5 | 6 | class BeatAlignCQTIO(FeatureIO): 7 | def read(self, filename, entry): 8 | return pickle_read(self, filename) 9 | 10 | def write(self, data, filename, entry): 11 | pickle_write(self, data, filename) 12 | 13 | def visualize(self, data, filename, entry, override_sr): 14 | sr = entry.prop.sr 15 | win_shift = entry.prop.hop_length 16 | beat = entry.beat 17 | assert len(beat) - 1 == data.shape[0] 18 | n_frame = int(beat[-1] * sr / win_shift) + data.shape[1] + 1 19 | new_data = np.ones((n_frame, data.shape[2])) * -1 20 | for i in range(len(beat) - 1): 21 | time = int(np.round(beat[i] * sr / win_shift)) 22 | for j in range(data.shape[1]): 23 | time_j = time + j 24 | if time_j >= 0 and time_j < n_frame: 25 | new_data[time_j, :] = data[i, j, :] 26 | f = open(os.path.join(PACKAGE_PATH, "data/spectogram_template.svl"), "r") 27 | content = f.read() 28 | f.close() 29 | content = content.replace("[__SR__]", str(sr)) 30 | content = content.replace("[__WIN_SHIFT__]", str(win_shift)) 31 | content = content.replace("[__SHAPE_1__]", str(new_data.shape[1])) 32 | content = content.replace("[__COLOR__]", str(1)) 33 | labels = [str(i) for i in range(new_data.shape[1])] 34 | content = content.replace("[__DATA__]", create_svl_3d_data(labels, new_data)) 35 | f = open(filename, "w") 36 | f.write(content) 37 | f.close() 38 | 39 | def get_visualize_extention_name(self): 40 | return "svl" 41 | 42 | 43 | class BeatSpectrogramIO(FeatureIO): 44 | def read(self, filename, entry): 45 | return pickle_read(self, filename) 46 | 47 | def write(self, data, filename, entry): 48 | pickle_write(self, data, filename) 49 | 50 | def visualize(self, data, filename, entry, override_sr): 51 | sr = entry.prop.sr 52 | win_shift = entry.prop.hop_length 53 | beat = entry.beat 54 | assert len(beat) - 1 == data.shape[0] 55 | n_frame = int(beat[-1] * sr / win_shift) + data.shape[1] + 1 56 | new_data = np.ones((n_frame, data.shape[1])) * -1 57 | for i in 
range(len(beat) - 1): 58 | start_time = int(np.round(beat[i] * sr / win_shift)) 59 | end_time = int(np.round(beat[i + 1] * sr / win_shift)) 60 | for j in range(start_time, end_time): 61 | new_data[j, :] = data[i, :] 62 | f = open(os.path.join(PACKAGE_PATH, "data/spectogram_template.svl"), "r") 63 | content = f.read() 64 | f.close() 65 | content = content.replace("[__SR__]", str(sr)) 66 | content = content.replace("[__WIN_SHIFT__]", str(win_shift)) 67 | content = content.replace("[__SHAPE_1__]", str(new_data.shape[1])) 68 | content = content.replace("[__COLOR__]", str(1)) 69 | labels = [str(i) for i in range(new_data.shape[1])] 70 | content = content.replace("[__DATA__]", create_svl_3d_data(labels, new_data)) 71 | f = open(filename, "w") 72 | f.write(content) 73 | f.close() 74 | 75 | def get_visualize_extention_name(self): 76 | return "svl" 77 | -------------------------------------------------------------------------------- /polyffusion/dl_modules/chord_dec.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class ChordDecoder(nn.Module): 8 | def __init__( 9 | self, input_dim=36, z_input_dim=256, hidden_dim=512, z_dim=256, n_step=8 10 | ): 11 | super(ChordDecoder, self).__init__() 12 | self.z2dec_hid = nn.Linear(z_dim, hidden_dim) 13 | self.z2dec_in = nn.Linear(z_dim, z_input_dim) 14 | self.gru = nn.GRU( 15 | input_dim + z_input_dim, hidden_dim, batch_first=True, bidirectional=False 16 | ) 17 | self.init_input = nn.Parameter(torch.rand(36)) 18 | self.input_dim = input_dim 19 | self.hidden_dim = hidden_dim 20 | self.z_dim = z_dim 21 | self.root_out = nn.Linear(hidden_dim, 12) 22 | self.chroma_out = nn.Linear(hidden_dim, 24) 23 | self.bass_out = nn.Linear(hidden_dim, 12) 24 | self.n_step = n_step 25 | self.loss_func = nn.CrossEntropyLoss() 26 | 27 | def forward(self, z_chd, inference, tfr, gt_chd=None): 28 | # z_chd: (B, z_chd_size) 29 | bs = z_chd.size(0) 30 | z_chd_hid = self.z2dec_hid(z_chd).unsqueeze(0) 31 | z_chd_in = self.z2dec_in(z_chd).unsqueeze(1) 32 | if inference: 33 | tfr = 0.0 34 | token = self.init_input.repeat(bs, 1).unsqueeze(1) 35 | recon_root = [] 36 | recon_chroma = [] 37 | recon_bass = [] 38 | 39 | for t in range(self.n_step): 40 | chd_t, z_chd_hid = self.gru(torch.cat([token, z_chd_in], dim=-1), z_chd_hid) 41 | 42 | # compute output distribution 43 | r_root = self.root_out(chd_t) # (bs, 1, 12) 44 | r_chroma = self.chroma_out(chd_t).view(bs, 1, 12, 2).contiguous() 45 | r_bass = self.bass_out(chd_t) # (bs, 1, 12) 46 | 47 | # write distribution on the list 48 | recon_root.append(r_root) 49 | recon_chroma.append(r_chroma) 50 | recon_bass.append(r_bass) 51 | 52 | # prepare the input to the next step 53 | if t == self.n_step - 1: 54 | break 55 | teacher_force = random.random() < tfr 56 | if teacher_force and not inference: 57 | token = gt_chd[:, t].unsqueeze(1) 58 | else: 59 | t_root = torch.zeros(bs, 1, 12).to(z_chd.device).float() 60 | t_root[torch.arange(0, bs), 0, r_root.max(-1)[-1]] = 1.0 61 | t_chroma = r_chroma.max(-1)[-1].float() 62 | t_bass = torch.zeros(bs, 1, 12).to(z_chd.device).float() 63 | t_bass[torch.arange(0, bs), 0, r_bass.max(-1)[-1]] = 1.0 64 | token = torch.cat([t_root, t_chroma, t_bass], dim=-1) 65 | 66 | recon_root = torch.cat(recon_root, dim=1) 67 | recon_chroma = torch.cat(recon_chroma, dim=1) 68 | recon_bass = torch.cat(recon_bass, dim=1) 69 | # print(recon_root.shape, recon_chroma.shape, recon_bass.shape) 70 | return recon_root, recon_chroma, 
recon_bass 71 | 72 | def recon_loss(self, c, recon_root, recon_chroma, recon_bass): 73 | loss_fun = self.loss_func 74 | root = c[:, :, 0:12].max(-1)[-1].view(-1).contiguous() 75 | chroma = c[:, :, 12:24].long().view(-1).contiguous() 76 | bass = c[:, :, 24:].max(-1)[-1].view(-1).contiguous() 77 | 78 | recon_root = recon_root.view(-1, 12).contiguous() 79 | recon_chroma = recon_chroma.view(-1, 2).contiguous() 80 | recon_bass = recon_bass.view(-1, 12).contiguous() 81 | root_loss = loss_fun(recon_root, root) 82 | chroma_loss = loss_fun(recon_chroma, chroma) 83 | bass_loss = loss_fun(recon_bass, bass) 84 | chord_loss = root_loss + chroma_loss + bass_loss 85 | return chord_loss, root_loss, chroma_loss, bass_loss 86 | -------------------------------------------------------------------------------- /polyffusion/ddpm/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import torch.utils.data 6 | from torch import nn 7 | 8 | from .utils import gather 9 | 10 | 11 | class DenoiseDiffusion(nn.Module): 12 | """ 13 | ## Denoise Diffusion 14 | """ 15 | 16 | def __init__(self, eps_model: nn.Module, n_steps: int): 17 | """ 18 | * `eps_model` is $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$ model 19 | * `n_steps` is $t$ 20 | """ 21 | super().__init__() 22 | self.eps_model = eps_model 23 | 24 | # Create $\beta_1, \dots, \beta_T$ linearly increasing variance schedule 25 | self.register_buffer("beta", torch.linspace(0.0001, 0.02, n_steps)) 26 | 27 | # $\alpha_t = 1 - \beta_t$ 28 | self.alpha = 1.0 - self.beta 29 | # $\bar\alpha_t = \prod_{s=1}^t \alpha_s$ 30 | self.alpha_bar = torch.cumprod(self.alpha, dim=0) 31 | # $T$ 32 | self.n_steps = n_steps 33 | # $\sigma^2 = \beta$ 34 | self.sigma2 = self.beta 35 | 36 | def q_xt_x0( 37 | self, x0: torch.Tensor, t: torch.Tensor 38 | ) -> Tuple[torch.Tensor, torch.Tensor]: 39 | """ 40 | #### Get $q(x_t|x_0)$ distribution 41 | """ 42 | 43 | # [gather](utils.html) $\alpha_t$ and compute $\sqrt{\bar\alpha_t} x_0$ 44 | mean = gather(self.alpha_bar, t) ** 0.5 * x0 45 | # $(1-\bar\alpha_t) \mathbf{I}$ 46 | var = 1 - gather(self.alpha_bar, t) 47 | # 48 | return mean, var 49 | 50 | def q_sample( 51 | self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None 52 | ): 53 | """ 54 | #### Sample from $q(x_t|x_0)$ 55 | """ 56 | 57 | # $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ 58 | if eps is None: 59 | eps = torch.randn_like(x0) 60 | 61 | # get $q(x_t|x_0)$ 62 | mean, var = self.q_xt_x0(x0, t) 63 | # Sample from $q(x_t|x_0)$ 64 | return mean + (var**0.5) * eps 65 | 66 | def p_sample(self, xt: torch.Tensor, t: torch.Tensor): 67 | """ 68 | #### Sample from $\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)$ 69 | """ 70 | 71 | # $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$ 72 | eps_theta = self.eps_model(xt, t) 73 | # [gather](utils.html) $\bar\alpha_t$ 74 | alpha_bar = gather(self.alpha_bar, t) 75 | # $\alpha_t$ 76 | alpha = gather(self.alpha, t) 77 | # $\frac{\beta}{\sqrt{1-\bar\alpha_t}}$ 78 | eps_coef = (1 - alpha) / (1 - alpha_bar) ** 0.5 79 | # $$\frac{1}{\sqrt{\alpha_t}} \Big(x_t - 80 | # \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\textcolor{lightgreen}{\epsilon_\theta}(x_t, t) \Big)$$ 81 | mean = 1 / (alpha**0.5) * (xt - eps_coef * eps_theta) 82 | # $\sigma^2$ 83 | var = gather(self.sigma2, t) 84 | 85 | # $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ 86 | eps = torch.randn(xt.shape, device=xt.device) 87 | # 
Sample 88 | return mean + (var**0.5) * eps 89 | 90 | def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None): 91 | """ 92 | #### Simplified Loss 93 | """ 94 | # Get batch size 95 | batch_size = x0.shape[0] 96 | # Get random $t$ for each sample in the batch 97 | t = torch.randint( 98 | 0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long 99 | ) 100 | 101 | # $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ 102 | if noise is None: 103 | noise = torch.randn_like(x0) 104 | 105 | # Sample $x_t$ for $q(x_t|x_0)$ 106 | xt = self.q_sample(x0, t, eps=noise) 107 | # Get $\textcolor{lightgreen}{\epsilon_\theta}(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t)$ 108 | eps_theta = self.eps_model(xt, t) 109 | 110 | # MSE loss 111 | return F.mse_loss(noise, eps_theta) 112 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/mir/io/implement/regional_spectrogram_io.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from mir.common import PACKAGE_PATH 3 | from mir.io.feature_io_base import * 4 | 5 | 6 | class RegionalSpectrogramIO(FeatureIO): 7 | def read(self, filename, entry): 8 | data = pickle_read(self, filename) 9 | assert len(data) == 3 or len(data) == 2 10 | return data 11 | 12 | def write(self, data, filename, entry): 13 | assert len(data) == 3 or len(data) == 2 14 | pickle_write(self, data, filename) 15 | 16 | def visualize(self, data, filename, entry, override_sr): 17 | if len(data) == 2: 18 | timing, data = data 19 | labels = None 20 | elif len(data) == 3: 21 | labels, timing, data = data 22 | else: 23 | raise Exception("Format error") 24 | data = np.array(data) 25 | if len(data.shape) == 1: 26 | data = data.reshape((-1, 1)) 27 | sr = entry.prop.sr 28 | win_shift = entry.prop.hop_length 29 | timing = np.array(timing).reshape((len(timing), -1)) 30 | n_frame = max(1, int(np.round(np.max(timing * sr / win_shift)))) 31 | data_indices = (-1) * np.ones(n_frame, dtype=np.int32) 32 | timing_start = timing[: len(data), 0] 33 | if timing.shape[1] == 1: 34 | assert len(timing) == len(data) or len(timing) == len(data) + 1 35 | if len(timing) == len(data) + 1: 36 | timing_end = timing[1:, 0] 37 | else: 38 | timing_end = np.append( 39 | timing[1:, 0], 40 | timing[-1, 0] * 2 - timing[-2, 0] if (len(timing) > 1) else 1.0, 41 | ) 42 | else: 43 | timing_end = timing[:, 1] 44 | for i in range(len(data)): 45 | frame_start = max(0, int(np.round(timing_start[i] * sr / win_shift))) 46 | frame_end = max(0, int(np.round(timing_end[i] * sr / win_shift))) 47 | data_indices[frame_start:frame_end] = i 48 | if data.shape[1] >= 1: 49 | f = open(os.path.join(PACKAGE_PATH, "data/spectrogram_template.svl"), "r") 50 | content = f.read() 51 | f.close() 52 | content = content.replace("[__SR__]", str(sr)) 53 | content = content.replace("[__WIN_SHIFT__]", str(win_shift)) 54 | content = content.replace("[__SHAPE_1__]", str(data.shape[1])) 55 | content = content.replace("[__COLOR__]", str(1)) 56 | if labels is None: 57 | labels = [str(i) for i in range(data.shape[1])] 58 | assert len(labels) == len(data[0]) 59 | result = ( 60 | "\n".join( 61 | [ 62 | '' % (i, str(labels[i])) 63 | for i in range(len(labels)) 64 | ] 65 | ) 66 | + "\n" 67 | ) 68 | for i in range(n_frame): 69 | if data_indices[i] >= 0: 70 | result += '%s\n' % ( 71 | i, 72 | " ".join([str(s) for s in data[data_indices[i]]]), 73 | ) 74 | content = content.replace("[__DATA__]", result) 75 | else: 76 | f = 
open(os.path.join(PACKAGE_PATH, "data/curve_template.svl"), "r") 77 | content = f.read() 78 | f.close() 79 | content = content.replace("[__SR__]", str(sr)) 80 | content = content.replace("[__STYLE__]", str(1)) 81 | results = [] 82 | raise NotImplementedError() 83 | # for i in range(0, len(data)): 84 | # results.append(''%(int(override_sr/sr*i*win_shift),data[i,0])) 85 | # content = content.replace('[__DATA__]','\n'.join(results)) 86 | # content = content.replace('[__NAME__]', 'curve') 87 | 88 | f = open(filename, "w") 89 | f.write(content) 90 | f.close() 91 | 92 | def get_visualize_extention_name(self): 93 | return "svl" 94 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Polyffusion: A Diffusion Model for Polyphonic Score Generation with Internal and External Controls 2 | 3 | - [Paper link](https://arxiv.org/abs/2307.10304) 4 | - Check our [demo page](https://polyffusion.github.io/) and give it a listen! 5 | 6 | ``` 7 | @inproceedings{polyffusion2023, 8 | author = {Lejun Min and Junyan Jiang and Gus Xia and Jingwei Zhao}, 9 | title = {Polyffusion: A Diffusion Model for Polyphonic Score Generation with Internal and External Controls}, 10 | booktitle = {Proceedings of the 24th International Society for Music Information Retrieval Conference, {ISMIR}}, 11 | year = {2023} 12 | } 13 | ``` 14 | 15 | ## Installation 16 | 17 | ```shell 18 | pip install -r requirements.txt 19 | pip install -e polyffusion 20 | pip install -e polyffusion/chord_extractor 21 | pip install -e polyffusion/mir_eval 22 | ``` 23 | 24 | ## Some Clarifications 25 | 26 | - The abbreviation "sdf" means Stable Diffusion, and "ldm" means Latent Diffusion; in this codebase they refer to the same thing. We only borrow the cross-attention conditioning mechanism from Latent Diffusion, without utilizing its encoder and decoder; the latter is left for future experiments. 27 | - `prmat2c` in the code is the piano-roll image representation. 28 | 29 | ## Preparations 30 | 31 | - The extracted features of the POP909 dataset can be accessed [here](https://yukisaki-my.sharepoint.com/:u:/g/personal/aik2_yukisaki_io/EdUovlRZvExJrGatAR8BlTsBDC8udJiuhnIimPuD2PQ3FQ?e=WwD7Dl). Please put them under `/data/` after extraction. 32 | 33 | - The pre-trained models needed for training can be accessed [here](https://yukisaki-my.sharepoint.com/:u:/g/personal/aik2_yukisaki_io/Eca406YwV1tMgwHdoepC7G8B5l-4GRBGv7TzrI9OOg3eIA?e=uecJdU). Please put them under `/pretrained/` after extraction. 34 | 35 | ## Training 36 | 37 | ### Modifications 38 | 39 | - You can modify the parameters in the corresponding `*.yaml` files under `/polyffusion/params/`, or create your own (a sketch of the expected fields follows the model list below). 40 | 41 | ### Commands 42 | 43 | ```shell 44 | python polyffusion/main.py --model [model] --output_dir [output_dir] 45 | ``` 46 | 47 | The models can be selected from `/polyffusion/params/[model].yaml`. Here are some examples: 48 | 49 | - `sdf_chd8bar`: conditioned on latent chord representations encoded by a pre-trained chord encoder. 50 | - `sdf_txt`: conditioned on latent texture representations encoded by a pre-trained texture encoder. 51 | - `sdf_chdvnl`: conditioned on vanilla chord representations. 52 | - `sdf_txtvnl`: conditioned on vanilla texture representations. 53 | - `ddpm`: vanilla diffusion model from DDPM without conditioning. 
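For orientation, a conditional model's params file roughly takes the following shape. This is a minimal sketch assembled from the fields that `polyffusion/train/train_ldm.py` and `polyffusion/train/__init__.py` read; the values shown are placeholders rather than the shipped defaults, so consult the actual files under `/polyffusion/params/` before training.

```yaml
# Illustrative sketch only: field names follow what train_ldm.py reads,
# but the values below are placeholders, not the shipped defaults.
model_name: sdf_chd8bar
# training
batch_size: 16
max_epoch: 100
learning_rate: 5.0e-5
fp16: true
num_workers: 4
pin_memory: true
# conditioning
cond_type: chord       # chord / txt / pnotree (see train_ldm.py)
cond_mode: cond        # placeholder; the provided yaml files define the exact modes
use_enc: true          # encode the condition with the pre-trained encoder
# UNet / diffusion backbone
in_channels: 2
out_channels: 2
channels: 64
attention_levels: [1, 2, 3]
n_res_blocks: 2
channel_multipliers: [1, 2, 4, 4]
n_heads: 4
tf_layers: 1
d_cond: 512
linear_start: 0.00085
linear_end: 0.012
n_steps: 1000
latent_scaling_factor: 0.18215
```

When `use_enc` is true, the pre-trained encoder dimensions (the `chd_*` fields for chord conditioning and the `txt_*` fields for texture conditioning) are also required; `train/train_ldm.py` lists exactly which fields each conditioning type reads.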
54 | 55 | Examples: 56 | 57 | ```shell 58 | python polyffusion/main.py --model sdf_chd8bar --output_dir result/sdf_chd8bar 59 | ``` 60 | 61 | ### Trained Checkpoints 62 | 63 | If you'd like to test our trained checkpoints, please access the folder [here](https://yukisaki-my.sharepoint.com/:f:/g/personal/aik2_yukisaki_io/EjG0IB8Xb_1CoVfYCmNUB-ABMLVSRqJST4VTrYJxjJFdnw?e=OqmZpp). We suggest putting them under `/result/` after extraction for inference. 64 | 65 | ## Inference 66 | 67 | Please see the help messages by running 68 | 69 | ```shell 70 | python polyffusion/inference_sdf.py --help 71 | ``` 72 | 73 | Examples: 74 | 75 | ```shell 76 | # unconditional generation of length 10x8 bars 77 | python polyffusion/inference_sdf.py --chkpt_path=/path/to/checkpoint --uncond_scale=0. --length=10 78 | 79 | # conditional generation using DDIM sampler (default guidance scale = 1) 80 | python polyffusion/inference_sdf.py --chkpt_path=/path/to/checkpoint --ddim --ddim_steps=50 --ddim_eta=0.0 --ddim_discretize=uniform 81 | 82 | # conditional generation with guidance scale = 5, chord progressions chosen from a song in the POP909 validation set. 83 | python polyffusion/inference_sdf.py --chkpt_path=/path/to/checkpoint --uncond_scale=5. 84 | 85 | # conditional iterative inpainting (i.e. autoregressive generation) (default guidance scale = 1) 86 | python polyffusion/inference_sdf.py --chkpt_path=/path/to/checkpoint --autoreg 87 | 88 | # unconditional melody generation given accompaniment 89 | python polyffusion/inference_sdf.py --chkpt_path=/path/to/checkpoint --uncond_scale=0. --inpaint_from_midi=/path/to/accompaniment.mid --inpaint_type=above 90 | 91 | # accompaniment generation given melody, conditioned on chord progressions of another midi file (default guidance scale = 1) 92 | python polyffusion/inference_sdf.py --chkpt_path=/path/to/checkpoint --inpaint_from_midi=/path/to/melody.mid --inpaint_type=below --from_midi=/path/to/cond_midi.mid 93 | ``` 94 | -------------------------------------------------------------------------------- /polyffusion/train/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | from pathlib import Path 4 | from shutil import copy2 5 | 6 | import torch 7 | from lightning.pytorch import Trainer 8 | from lightning.pytorch.callbacks import ModelCheckpoint 9 | from lightning.pytorch.loggers import WandbLogger 10 | from omegaconf import OmegaConf 11 | from torch.optim import Optimizer 12 | from torch.utils.data import DataLoader 13 | 14 | from lightning_learner import LightningLearner 15 | from utils import convert_json_to_yaml 16 | 17 | 18 | class TrainConfig: 19 | model: torch.nn.Module 20 | train_dl: DataLoader 21 | val_dl: DataLoader 22 | optimizer: Optimizer 23 | 24 | def __init__(self, params, param_scheduler, output_dir) -> None: 25 | self.model_name = params.model_name 26 | self.params = params 27 | self.param_scheduler = param_scheduler 28 | 29 | self.resume = False 30 | if os.path.exists(f"{output_dir}/chkpts/last.ckpt"): 31 | print("Checkpoint already exists.") 32 | if input("Resume training? 
(y/n)") == "y": 33 | self.resume = True 34 | else: 35 | print("Aborting...") 36 | exit(0) 37 | else: 38 | output_dir = f"{output_dir}/{datetime.now().strftime('%y-%m-%d_%H%M%S')}" 39 | print(f"Creating new log folder as {output_dir}") 40 | os.makedirs(output_dir, exist_ok=True) 41 | 42 | self.output_dir = output_dir 43 | self.log_dir = f"{output_dir}/logs" 44 | self.checkpoint_dir = f"{output_dir}/chkpts" 45 | self.time_stamp = Path(output_dir).name 46 | 47 | os.makedirs(self.log_dir, exist_ok=True) 48 | os.makedirs(self.checkpoint_dir, exist_ok=True) 49 | 50 | # json to yaml (compatibility) 51 | if os.path.exists(f"{output_dir}/params.json"): 52 | convert_json_to_yaml(f"{output_dir}/params.json") 53 | 54 | if os.path.exists(f"{output_dir}/params.yaml"): 55 | old_params = OmegaConf.load(f"{output_dir}/params.yaml") 56 | 57 | # The "weights" attribute is a tuple in AttrDict, but saved as a list. To compare these two, we make them both tuples: 58 | # if "weights" in old_params: 59 | # old_params["weights"] = tuple(old_params["weights"]) 60 | 61 | if old_params != self.params: 62 | print("New params differ, using new params could break things.") 63 | if ( 64 | input( 65 | "Do you want to keep the old params file (y/n)? The model will be trained on new params regardless." 66 | ) 67 | == "y" 68 | ): 69 | time_stamp = datetime.now().strftime("%y-%m-%d_%H%M%S") 70 | copy2( 71 | f"{output_dir}/params.yaml", 72 | f"{output_dir}/old_params_{time_stamp}.yaml", 73 | ) 74 | print(f"Old params saved as old_params_{time_stamp}.yaml") 75 | # save params 76 | OmegaConf.save(self.params, f"{output_dir}/params.yaml") 77 | 78 | def train(self): 79 | total_parameters = sum( 80 | p.numel() for p in self.model.parameters() if p.requires_grad 81 | ) 82 | print(f"Total parameters: {total_parameters}") 83 | print(OmegaConf.to_yaml(self.params)) 84 | 85 | checkpoint_callback = ModelCheckpoint( 86 | dirpath=self.checkpoint_dir, 87 | monitor="val/loss", 88 | filename="epoch{epoch}-val_loss{val/loss:.6f}", 89 | save_last=True, 90 | save_top_k=3, 91 | auto_insert_metric_name=False, 92 | ) 93 | logger = WandbLogger( 94 | project=f"Polyff-{self.model_name}", 95 | save_dir=self.log_dir, 96 | name=self.time_stamp, 97 | ) 98 | trainer = Trainer( 99 | default_root_dir=self.output_dir, 100 | callbacks=[checkpoint_callback], 101 | max_epochs=self.params.max_epoch, 102 | logger=logger, 103 | precision="16-mixed" if self.params.fp16 else "32-true", 104 | ) 105 | learner = LightningLearner( 106 | self.model, 107 | self.optimizer, 108 | self.params, 109 | self.param_scheduler, 110 | ) 111 | trainer.fit( 112 | learner, 113 | self.train_dl, 114 | self.val_dl, 115 | ckpt_path="last" if self.resume else None, 116 | ) 117 | -------------------------------------------------------------------------------- /polyffusion/dl_modules/pianotree_enc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.distributions import Normal 4 | from torch.nn.utils.rnn import pack_padded_sequence 5 | 6 | 7 | class PianoTreeEncoder(nn.Module): 8 | def __init__( 9 | self, 10 | max_simu_note=20, 11 | max_pitch=127, 12 | min_pitch=0, 13 | pitch_sos=128, 14 | pitch_eos=129, 15 | pitch_pad=130, 16 | dur_pad=2, 17 | dur_width=5, 18 | num_step=32, 19 | note_emb_size=128, 20 | enc_notes_hid_size=256, 21 | enc_time_hid_size=512, 22 | z_size=512, 23 | ): 24 | super(PianoTreeEncoder, self).__init__() 25 | 26 | # Parameters 27 | # note and time 28 | self.max_pitch = max_pitch # the 
highest pitch in train/val set. 29 | self.min_pitch = min_pitch # the lowest pitch in train/val set. 30 | self.pitch_sos = pitch_sos 31 | self.pitch_eos = pitch_eos 32 | self.pitch_pad = pitch_pad 33 | self.pitch_range = max_pitch - min_pitch + 3 # not including pad. 34 | self.dur_pad = dur_pad 35 | self.dur_width = dur_width 36 | self.note_size = self.pitch_range + dur_width 37 | self.max_simu_note = max_simu_note # the max # of notes at each ts. 38 | self.num_step = num_step # 32 39 | self.note_emb_size = note_emb_size 40 | self.z_size = z_size 41 | self.enc_notes_hid_size = enc_notes_hid_size 42 | self.enc_time_hid_size = enc_time_hid_size 43 | 44 | self.note_embedding = nn.Linear(self.note_size, note_emb_size) 45 | self.enc_notes_gru = nn.GRU( 46 | note_emb_size, 47 | enc_notes_hid_size, 48 | num_layers=1, 49 | batch_first=True, 50 | bidirectional=True, 51 | ) 52 | self.enc_time_gru = nn.GRU( 53 | 2 * enc_notes_hid_size, 54 | enc_time_hid_size, 55 | num_layers=1, 56 | batch_first=True, 57 | bidirectional=True, 58 | ) 59 | self.linear_mu = nn.Linear(2 * enc_time_hid_size, z_size) 60 | self.linear_std = nn.Linear(2 * enc_time_hid_size, z_size) 61 | 62 | @property 63 | def device(self): 64 | """ 65 | ### Get model device 66 | """ 67 | return next(iter(self.parameters())).device 68 | 69 | def get_len_index_tensor(self, ind_x): 70 | """Calculate the lengths ((B, 32), torch.LongTensor) of pgrid.""" 71 | with torch.no_grad(): 72 | lengths = self.max_simu_note - ( 73 | ind_x[:, :, :, 0] - self.pitch_pad == 0 74 | ).sum(dim=-1) 75 | return lengths.to("cpu") 76 | 77 | def index_tensor_to_multihot_tensor(self, ind_x): 78 | """Transfer piano_grid to multi-hot piano_grid.""" 79 | # ind_x: (B, 32, max_simu_note, 1 + dur_width) 80 | with torch.no_grad(): 81 | dur_part = ind_x[:, :, :, 1:].float() 82 | out = torch.zeros( 83 | [ 84 | ind_x.size(0) * self.num_step * self.max_simu_note, 85 | self.pitch_range + 1, 86 | ], 87 | dtype=torch.float, 88 | ).to(self.device) 89 | 90 | out[range(0, out.size(0)), ind_x[:, :, :, 0].reshape(-1)] = 1.0 91 | out = out.view(-1, 32, self.max_simu_note, self.pitch_range + 1) 92 | out = torch.cat([out[:, :, :, 0 : self.pitch_range], dur_part], dim=-1) 93 | return out 94 | 95 | def encoder(self, x, lengths): 96 | embedded = self.note_embedding(x) 97 | # x: (B, num_step, max_simu_note, note_emb_size) 98 | # now x are notes 99 | x = embedded.view(-1, self.max_simu_note, self.note_emb_size) 100 | x = pack_padded_sequence( 101 | x, lengths.view(-1), batch_first=True, enforce_sorted=False 102 | ) 103 | x = self.enc_notes_gru(x)[-1].transpose(0, 1).contiguous() 104 | x = x.view(-1, self.num_step, 2 * self.enc_notes_hid_size) 105 | # now, x is simu_notes. 
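# x: (B, num_step, 2 * enc_notes_hid_size); each of the num_step (= 32) time steps is now a
# single vector summarizing its simultaneous notes. The bidirectional time GRU below
# aggregates these per-step summaries into a single sequence-level code.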
106 | x = self.enc_time_gru(x)[-1].transpose(0, 1).contiguous() 107 | # x: (B, 2, enc_time_hid_size) 108 | x = x.view(x.size(0), -1) 109 | mu = self.linear_mu(x) # (B, z_size) 110 | std = self.linear_std(x).exp_() # (B, z_size) 111 | dist = Normal(mu, std) 112 | return dist, embedded 113 | 114 | def forward(self, x, return_iterators=False): 115 | lengths = self.get_len_index_tensor(x) 116 | x = self.index_tensor_to_multihot_tensor(x) 117 | dist, embedded_x = self.encoder(x, lengths) 118 | if return_iterators: 119 | return dist.mean, dist.scale, embedded_x 120 | else: 121 | return dist, embedded_x, lengths 122 | -------------------------------------------------------------------------------- /polyffusion/mir_eval/onset.py: -------------------------------------------------------------------------------- 1 | ''' 2 | The goal of an onset detection algorithm is to automatically determine when 3 | notes are played in a piece of music. The primary method used to evaluate 4 | onset detectors is to first determine which estimated onsets are "correct", 5 | where correctness is defined as being within a small window of a reference 6 | onset. 7 | 8 | Based in part on this script: 9 | 10 | https://github.com/CPJKU/onset_detection/blob/master/onset_evaluation.py 11 | 12 | Conventions 13 | ----------- 14 | 15 | Onsets should be provided in the form of a 1-dimensional array of onset 16 | times in seconds in increasing order. 17 | 18 | Metrics 19 | ------- 20 | 21 | * :func:`mir_eval.onset.f_measure`: Precision, Recall, and F-measure scores 22 | based on the number of esimated onsets which are sufficiently close to 23 | reference onsets. 24 | ''' 25 | 26 | import collections 27 | from . import util 28 | import warnings 29 | 30 | 31 | # The maximum allowable beat time 32 | MAX_TIME = 30000. 33 | 34 | 35 | def validate(reference_onsets, estimated_onsets): 36 | """Checks that the input annotations to a metric look like valid onset time 37 | arrays, and throws helpful errors if not. 38 | 39 | Parameters 40 | ---------- 41 | reference_onsets : np.ndarray 42 | reference onset locations, in seconds 43 | estimated_onsets : np.ndarray 44 | estimated onset locations, in seconds 45 | 46 | """ 47 | # If reference or estimated onsets are empty, warn because metric will be 0 48 | if reference_onsets.size == 0: 49 | warnings.warn("Reference onsets are empty.") 50 | if estimated_onsets.size == 0: 51 | warnings.warn("Estimated onsets are empty.") 52 | for onsets in [reference_onsets, estimated_onsets]: 53 | util.validate_events(onsets, MAX_TIME) 54 | 55 | 56 | def f_measure(reference_onsets, estimated_onsets, window=.05): 57 | """Compute the F-measure of correct vs incorrectly predicted onsets. 58 | "Corectness" is determined over a small window. 59 | 60 | Examples 61 | -------- 62 | >>> reference_onsets = mir_eval.io.load_events('reference.txt') 63 | >>> estimated_onsets = mir_eval.io.load_events('estimated.txt') 64 | >>> F, P, R = mir_eval.onset.f_measure(reference_onsets, 65 | ... 
estimated_onsets) 66 | 67 | Parameters 68 | ---------- 69 | reference_onsets : np.ndarray 70 | reference onset locations, in seconds 71 | estimated_onsets : np.ndarray 72 | estimated onset locations, in seconds 73 | window : float 74 | Window size, in seconds 75 | (Default value = .05) 76 | 77 | Returns 78 | ------- 79 | f_measure : float 80 | 2*precision*recall/(precision + recall) 81 | precision : float 82 | (# true positives)/(# true positives + # false positives) 83 | recall : float 84 | (# true positives)/(# true positives + # false negatives) 85 | 86 | """ 87 | validate(reference_onsets, estimated_onsets) 88 | # If either list is empty, return 0s 89 | if reference_onsets.size == 0 or estimated_onsets.size == 0: 90 | return 0., 0., 0. 91 | # Compute the best-case matching between reference and estimated onset 92 | # locations 93 | matching = util.match_events(reference_onsets, estimated_onsets, window) 94 | 95 | precision = float(len(matching))/len(estimated_onsets) 96 | recall = float(len(matching))/len(reference_onsets) 97 | # Compute F-measure and return all statistics 98 | return util.f_measure(precision, recall), precision, recall 99 | 100 | 101 | def evaluate(reference_onsets, estimated_onsets, **kwargs): 102 | """Compute all metrics for the given reference and estimated annotations. 103 | 104 | Examples 105 | -------- 106 | >>> reference_onsets = mir_eval.io.load_events('reference.txt') 107 | >>> estimated_onsets = mir_eval.io.load_events('estimated.txt') 108 | >>> scores = mir_eval.onset.evaluate(reference_onsets, 109 | ... estimated_onsets) 110 | 111 | Parameters 112 | ---------- 113 | reference_onsets : np.ndarray 114 | reference onset locations, in seconds 115 | estimated_onsets : np.ndarray 116 | estimated onset locations, in seconds 117 | kwargs 118 | Additional keyword arguments which will be passed to the 119 | appropriate metric or preprocessing functions. 120 | 121 | Returns 122 | ------- 123 | scores : dict 124 | Dictionary of scores, where the key is the metric name (str) and 125 | the value is the (float) score achieved. 
126 | 127 | """ 128 | # Compute all metrics 129 | scores = collections.OrderedDict() 130 | 131 | (scores['F-measure'], 132 | scores['Precision'], 133 | scores['Recall']) = util.filter_kwargs(f_measure, reference_onsets, 134 | estimated_onsets, **kwargs) 135 | 136 | return scores 137 | -------------------------------------------------------------------------------- /polyffusion/data/polydis_format_to_mine.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join 3 | 4 | import numpy as np 5 | import pretty_midi as pm 6 | from tqdm import tqdm 7 | 8 | ORIGIN_DIR = "data/POP09-PIANOROLL-4-bin-quantization" 9 | NEW_DIR = "data/POP909_4_bin_pnt_8bar" 10 | 11 | ONE_BEAT_TIME = 0.5 12 | SEG_LGTH = 32 13 | BEAT = 4 14 | BIN = 4 15 | SEG_LGTH_BIN = SEG_LGTH * BIN 16 | 17 | 18 | def get_note_matrix(mats): 19 | """ 20 | (onset_beat, onset_bin, bin, offset_beat, offset_bin, bin, pitch, velocity) 21 | """ 22 | notes = [] 23 | 24 | for mat in mats: 25 | assert mat[2] == mat[5] == BIN 26 | onset = mat[0] * BIN + mat[1] 27 | offset = mat[3] * BIN + mat[4] 28 | duration = offset - onset 29 | if duration > 0: 30 | # this is compulsory because there may be notes 31 | # with zero duration after adjusting resolution 32 | notes.append([onset, mat[6], duration, mat[7], 0]) 33 | # sort according to (start, duration) 34 | # notes.sort(key=lambda x: (x[0] * BIN + x[1], x[2])) 35 | notes.sort(key=lambda x: (x[0], x[1], x[2])) 36 | return notes 37 | 38 | 39 | def dedup_note_matrix(notes): 40 | """ 41 | remove duplicated notes (because of multiple tracks) 42 | """ 43 | 44 | last = [] 45 | notes_dedup = [] 46 | for i, note in enumerate(notes): 47 | if i != 0: 48 | if note[:2] != last[:2]: 49 | # if start and pitch are not the same 50 | notes_dedup.append(note) 51 | else: 52 | notes_dedup.append(note) 53 | last = note 54 | # print(f"dedup: {len(notes) - len(notes_dedup)} : {len(notes)}") 55 | 56 | return notes_dedup 57 | 58 | 59 | def retrieve_midi_from_nmat(notes, output_fpath): 60 | """ 61 | retrieve midi from note matrix 62 | """ 63 | midi = pm.PrettyMIDI() 64 | piano_program = pm.instrument_name_to_program("Acoustic Grand Piano") 65 | piano = pm.Instrument(program=piano_program) 66 | for note in notes: 67 | onset, pitch, duration, velocity, program = note 68 | start = onset * ONE_BEAT_TIME / float(BIN) 69 | end = start + duration * ONE_BEAT_TIME / float(BIN) 70 | pm_note = pm.Note(velocity, pitch, start, end) 71 | piano.notes.append(pm_note) 72 | 73 | midi.instruments.append(piano) 74 | midi.write(output_fpath) 75 | 76 | 77 | def get_downbeat_pos_and_filter(notes, beats): 78 | """ 79 | beats: [0-1, 2/3beat-cnt, 2, 0-3, 4/6beat-cnt, 4] 80 | """ 81 | end_time = notes[-1][0] 82 | db_pos = [] 83 | for i, beat in enumerate(beats): 84 | if beat[3] == 0: 85 | pos = i * BIN 86 | db_pos.append(pos) 87 | 88 | # print(db_pos) 89 | db_pos_filter = [] 90 | for idx, db in enumerate(db_pos): 91 | if ( 92 | idx + (SEG_LGTH / BEAT) <= len(db_pos) 93 | and db_pos[idx + 1] - db == BEAT * BIN 94 | ): 95 | db_pos_filter.append(True) 96 | else: 97 | db_pos_filter.append(False) 98 | # print(db_pos_filter) 99 | return db_pos, db_pos_filter 100 | 101 | 102 | def get_start_table(notes, db_pos): 103 | """ 104 | i-th row indicates the starting row of the "notes" array at i-th beat. 
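Illustrative example (hypothetical values): with db_pos = [0, 16, 32] and note onsets
[0, 3, 17, 18, 40], the returned table is {0: 0, 16: 2, 32: 4}; each downbeat position
maps to the index of the first note whose onset falls at or after that downbeat.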
105 | """ 106 | 107 | # simply add 8-beat padding in case of out-of-range index 108 | # total_beat = int(music.get_end_time()) + 8 109 | row_cnt = 0 110 | start_table = {} 111 | # for beat in range(total_beat): 112 | for db in db_pos: 113 | while row_cnt < len(notes) and notes[row_cnt][0] < db: 114 | row_cnt += 1 115 | start_table[db] = row_cnt 116 | 117 | return start_table 118 | 119 | 120 | def cat_note_mats(note_mats): 121 | return np.concatenate(note_mats, 0) 122 | 123 | 124 | if __name__ == "__main__": 125 | if os.path.exists(NEW_DIR): 126 | os.system(f"rm -rf {NEW_DIR}") 127 | os.makedirs(NEW_DIR) 128 | 129 | for piece in tqdm(os.listdir(ORIGIN_DIR)): 130 | fpath = os.path.join(ORIGIN_DIR, piece) 131 | f = np.load(fpath) 132 | melody = get_note_matrix(f["melody"]) 133 | bridge = get_note_matrix(f["bridge"]) 134 | piano = get_note_matrix(f["piano"]) 135 | beats = f["beat"] 136 | notes = cat_note_mats([melody, bridge, piano]) 137 | 138 | retrieve_midi_from_nmat( 139 | notes, os.path.join(NEW_DIR, piece[:-4] + "_flatten.mid") 140 | ) 141 | db_pos, db_pos_filter = get_downbeat_pos_and_filter(notes, beats) 142 | start_table_melody = get_start_table(melody, db_pos) 143 | start_table_bridge = get_start_table(bridge, db_pos) 144 | start_table_piano = get_start_table(piano, db_pos) 145 | np.savez( 146 | join(NEW_DIR, piece[:-4]), 147 | notes=[melody, bridge, piano], 148 | start_table=[start_table_melody, start_table_bridge, start_table_piano], 149 | db_pos=db_pos, 150 | db_pos_filter=db_pos_filter, 151 | chord=f["chord"], 152 | ) 153 | -------------------------------------------------------------------------------- /polyffusion/stable_diffusion/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | --- 3 | title: Utility functions for stable diffusion 4 | summary: > 5 | Utility functions for stable diffusion 6 | --- 7 | 8 | # Utility functions for [stable diffusion](index.html) 9 | """ 10 | 11 | import os 12 | import random 13 | from pathlib import Path 14 | 15 | import numpy as np 16 | import PIL 17 | import torch 18 | from labml import monit 19 | from labml.logger import inspect 20 | from PIL import Image 21 | 22 | from .latent_diffusion import LatentDiffusion 23 | from .model.autoencoder import Autoencoder, Decoder, Encoder 24 | 25 | # from model.clip_embedder import CLIPTextEmbedder 26 | from .model.unet import UNetModel 27 | 28 | 29 | def set_seed(seed: int): 30 | """ 31 | ### Set random seeds 32 | """ 33 | random.seed(seed) 34 | np.random.seed(seed) 35 | torch.manual_seed(seed) 36 | torch.cuda.manual_seed_all(seed) 37 | 38 | 39 | def load_model(path: Path = None) -> LatentDiffusion: 40 | """ 41 | ### Load [`LatentDiffusion` model](latent_diffusion.html) 42 | """ 43 | 44 | # Initialize the autoencoder 45 | with monit.section("Initialize autoencoder"): 46 | encoder = Encoder( 47 | z_channels=4, 48 | in_channels=3, 49 | channels=128, 50 | channel_multipliers=[1, 2, 4, 4], 51 | n_resnet_blocks=2, 52 | ) 53 | 54 | decoder = Decoder( 55 | out_channels=3, 56 | z_channels=4, 57 | channels=128, 58 | channel_multipliers=[1, 2, 4, 4], 59 | n_resnet_blocks=2, 60 | ) 61 | 62 | autoencoder = Autoencoder( 63 | emb_channels=4, encoder=encoder, decoder=decoder, z_channels=4 64 | ) 65 | 66 | # Initialize the U-Net 67 | with monit.section("Initialize U-Net"): 68 | unet_model = UNetModel( 69 | in_channels=4, 70 | out_channels=4, 71 | channels=320, 72 | attention_levels=[0, 1, 2], 73 | n_res_blocks=2, 74 | channel_multipliers=[1, 2, 4, 4], 75 | n_heads=8, 76 | 
tf_layers=1, 77 | d_cond=768, 78 | ) 79 | 80 | # Initialize the Latent Diffusion model 81 | with monit.section("Initialize Latent Diffusion model"): 82 | model = LatentDiffusion( 83 | linear_start=0.00085, 84 | linear_end=0.0120, 85 | n_steps=1000, 86 | latent_scaling_factor=0.18215, 87 | autoencoder=autoencoder, 88 | unet_model=unet_model, 89 | ) 90 | 91 | # Load the checkpoint 92 | with monit.section(f"Loading model from {path}"): 93 | checkpoint = torch.load(path, map_location="cpu") 94 | 95 | # Set model state 96 | with monit.section("Load state"): 97 | missing_keys, extra_keys = model.load_state_dict( 98 | checkpoint["state_dict"], strict=False 99 | ) 100 | 101 | # Debugging output 102 | inspect( 103 | global_step=checkpoint.get("global_step", -1), 104 | missing_keys=missing_keys, 105 | extra_keys=extra_keys, 106 | _expand=True, 107 | ) 108 | 109 | # 110 | model.eval() 111 | return model 112 | 113 | 114 | def load_img(path: str): 115 | """ 116 | ### Load an image 117 | 118 | This loads an image from a file and returns a PyTorch tensor. 119 | 120 | :param path: is the path of the image 121 | """ 122 | # Open Image 123 | image = Image.open(path).convert("RGB") 124 | # Get image size 125 | w, h = image.size 126 | # Resize to a multiple of 32 127 | w = w - w % 32 128 | h = h - h % 32 129 | image = image.resize((w, h), resample=PIL.Image.LANCZOS) 130 | # Convert to numpy and map to `[-1, 1]` for `[0, 255]` 131 | image = np.array(image).astype(np.float32) * (2.0 / 255.0) - 1 132 | # Transpose to shape `[batch_size, channels, height, width]` 133 | image = image[None].transpose(0, 3, 1, 2) 134 | # Convert to torch 135 | return torch.from_numpy(image) 136 | 137 | 138 | def save_images( 139 | images: torch.Tensor, dest_path: str, prefix: str = "", img_format: str = "jpeg" 140 | ): 141 | """ 142 | ### Save a images 143 | 144 | :param images: is the tensor with images of shape `[batch_size, channels, height, width]` 145 | :param dest_path: is the folder to save images in 146 | :param prefix: is the prefix to add to file names 147 | :param img_format: is the image format 148 | """ 149 | 150 | # Create the destination folder 151 | os.makedirs(dest_path, exist_ok=True) 152 | 153 | # Map images to `[0, 1]` space and clip 154 | images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0) 155 | # Transpose to `[batch_size, height, width, channels]` and convert to numpy 156 | images = images.cpu().permute(0, 2, 3, 1).numpy() 157 | 158 | # Save images 159 | for i, img in enumerate(images): 160 | img = Image.fromarray((255.0 * img).astype(np.uint8)) 161 | img.save( 162 | os.path.join(dest_path, f"{prefix}{i:05}.{img_format}"), format=img_format 163 | ) 164 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/mir/extractors/extractor_base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from abc import ABC, abstractmethod 4 | 5 | from mir import io 6 | from mir.common import WORKING_PATH 7 | 8 | 9 | def pickle_read(filename): 10 | f = open(filename, "rb") 11 | obj = pickle.load(f) 12 | f.close() 13 | return obj 14 | 15 | 16 | def pickle_write(data, filename): 17 | f = open(filename, "wb") 18 | pickle.dump(data, f) 19 | f.close() 20 | 21 | 22 | def try_mkdir(filename): 23 | folder = os.path.dirname(filename) 24 | if not os.path.isdir(folder): 25 | os.makedirs(folder) 26 | 27 | 28 | class ExtractorBase(ABC): 29 | def require(self, *args): 30 | pass 31 | 32 | def 
get_feature_class(self): 33 | return io.UnknownIO 34 | 35 | @abstractmethod 36 | def extract(self, entry, **kwargs): 37 | pass 38 | 39 | def __create_cache_path(self, entry, cached_prop_record, input_kwargs): 40 | items = {} 41 | items_entry = {} 42 | for k in input_kwargs: 43 | items[k] = input_kwargs[k] 44 | for prop_name in cached_prop_record: 45 | if prop_name not in items: 46 | items_entry[prop_name] = entry.prop.get_unrecorded(prop_name) 47 | 48 | if len(items) == 0: 49 | folder_name = self.__class__.__name__ 50 | else: 51 | folder_name = ( 52 | self.__class__.__name__ 53 | + "/" 54 | + ",".join([k + "=" + str(items[k]) for k in sorted(items.keys())]) 55 | ) 56 | 57 | if len(items_entry) == 0: 58 | entry_name = entry.name + ".cache" 59 | else: 60 | entry_name = ( 61 | entry.name 62 | + "." 63 | + ",".join( 64 | [k + "=" + str(items_entry[k]) for k in sorted(items_entry.keys())] 65 | ) 66 | + ".cache" 67 | ) 68 | 69 | return os.path.join(WORKING_PATH, "cache_data", folder_name, entry_name) 70 | 71 | def extract_and_cache(self, entry, cache_enabled=True, **kwargs): 72 | folder_name = os.path.join(WORKING_PATH, "cache_data", self.__class__.__name__) 73 | prop_cache_filename = os.path.join(folder_name, "_prop_records.cache") 74 | if "cached_prop_record" in self.__dict__: 75 | cached_prop_record = self.__dict__["cached_prop_record"] 76 | else: 77 | if os.path.exists(prop_cache_filename): 78 | cached_prop_record = pickle_read(prop_cache_filename) 79 | else: 80 | cached_prop_record = None 81 | 82 | if ( 83 | cache_enabled 84 | and entry.name != "" 85 | and self.get_feature_class() != io.UnknownIO 86 | ): 87 | # Need cache 88 | need_io_create = False 89 | if cached_prop_record is None: 90 | entry.prop.start_record_reading() 91 | feature = self.extract(entry, **kwargs) 92 | cached_prop_record = sorted(entry.prop.end_record_reading()) 93 | try_mkdir(prop_cache_filename) 94 | pickle_write(cached_prop_record, prop_cache_filename) 95 | cache_file_name = self.__create_cache_path( 96 | entry, cached_prop_record, kwargs 97 | ) 98 | need_io_create = True 99 | else: 100 | cache_file_name = self.__create_cache_path( 101 | entry, cached_prop_record, kwargs 102 | ) 103 | if not os.path.isfile(cache_file_name): 104 | entry.prop.start_record_reading() 105 | feature = self.extract(entry, **kwargs) 106 | new_prop_record = sorted(entry.prop.end_record_reading()) 107 | if cached_prop_record != new_prop_record: 108 | print( 109 | "[Warning] Inconsistent cached properity requirement in %s, overrode:" 110 | % self.__class__.__name__ 111 | ) 112 | print("Old:", cached_prop_record) 113 | print("New:", new_prop_record) 114 | cached_prop_record = new_prop_record 115 | pickle_write(cached_prop_record, prop_cache_filename) 116 | cache_file_name = self.__create_cache_path( 117 | entry, cached_prop_record, kwargs 118 | ) 119 | need_io_create = True 120 | else: 121 | io_obj = self.get_feature_class()() 122 | feature = io_obj.safe_read(cache_file_name, entry) 123 | if need_io_create: 124 | io_obj = self.get_feature_class()() 125 | io_obj.create(feature, cache_file_name, entry) 126 | else: 127 | feature = self.extract(entry, **kwargs) 128 | return feature 129 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/example.out: -------------------------------------------------------------------------------- 1 | 0.0 2.6666399999999997 N 2 | 2.6666399999999997 5.333279999999999 N 3 | 5.333279999999999 7.99992 C#:min 4 | 7.99992 10.66656 A:maj9 5 | 10.66656 
13.333200000000001 A:maj9 6 | 13.333200000000001 15.999840000000003 D:maj6(9) 7 | 15.999840000000003 18.666480000000004 D:maj6 8 | 18.666480000000004 21.333120000000005 A:maj9 9 | 21.333120000000005 23.999760000000006 A:maj9 10 | 23.999760000000006 26.666400000000007 D:maj6(9) 11 | 26.666400000000007 29.333040000000008 D:maj6(9) 12 | 29.333040000000008 31.99968000000001 A:maj 13 | 31.99968000000001 34.666320000000006 A:maj 14 | 34.666320000000006 37.33296000000001 D:maj6 15 | 37.33296000000001 39.99960000000001 D:maj6 16 | 39.99960000000001 42.66624000000001 A:maj 17 | 42.66624000000001 45.33288000000001 A:maj 18 | 45.33288000000001 47.99952000000001 D:maj6 19 | 47.99952000000001 50.66616000000001 D:maj6 20 | 50.66616000000001 51.99948000000001 C#:min/b3 21 | 51.99948000000001 53.33280000000001 F#:min 22 | 53.33280000000001 54.666120000000014 D:maj/3 23 | 54.666120000000014 55.999440000000014 E:7 24 | 55.999440000000014 57.332760000000015 C#:min 25 | 57.332760000000015 58.666080000000015 F#:min 26 | 58.666080000000015 59.999400000000016 D:maj/3 27 | 59.999400000000016 61.332720000000016 E:7/3 28 | 61.332720000000016 62.66604000000002 C#:min7/5 29 | 62.66604000000002 63.99936000000002 F#:min(9) 30 | 63.99936000000002 65.99934 B:min7 31 | 65.99934 66.666 E:maj 32 | 66.666 69.33263999999997 A:maj7 33 | 69.33263999999997 71.99927999999994 A:maj7 34 | 71.99927999999994 74.66591999999991 D:maj6 35 | 74.66591999999991 77.33255999999989 D:maj6 36 | 77.33255999999989 79.99919999999986 A:maj7 37 | 79.99919999999986 82.66583999999983 A:maj7 38 | 82.66583999999983 85.3324799999998 D:maj6(9) 39 | 85.3324799999998 87.99911999999978 D:maj6(9) 40 | 87.99911999999978 90.66575999999975 A:maj7 41 | 90.66575999999975 93.33239999999972 A:maj7 42 | 93.33239999999972 95.9990399999997 D:maj6(9) 43 | 95.9990399999997 98.66567999999967 D:maj6(9) 44 | 98.66567999999967 99.99899999999965 C#:min/b3 45 | 99.99899999999965 101.33231999999964 F#:min 46 | 101.33231999999964 102.66563999999963 D:maj/3 47 | 102.66563999999963 103.99895999999961 E:7 48 | 103.99895999999961 105.3322799999996 C#:min 49 | 105.3322799999996 106.66559999999959 F#:min 50 | 106.66559999999959 107.99891999999957 D:maj/3 51 | 107.99891999999957 109.33223999999956 E:7/3 52 | 109.33223999999956 110.66555999999954 C#:min7/5 53 | 110.66555999999954 111.99887999999953 F#:min(9) 54 | 111.99887999999953 113.99885999999951 B:min7 55 | 113.99885999999951 117.33215999999948 A:maj7/5 56 | 117.33215999999948 119.99879999999945 A:maj 57 | 119.99879999999945 122.66543999999942 D:maj6 58 | 122.66543999999942 125.3320799999994 D:maj6 59 | 125.3320799999994 127.99871999999937 A:maj 60 | 127.99871999999937 130.6653599999994 A:maj 61 | 130.6653599999994 133.33199999999943 D:maj6 62 | 133.33199999999943 135.99863999999945 D:maj6 63 | 135.99863999999945 138.66527999999948 A:maj9 64 | 138.66527999999948 141.3319199999995 A:maj9 65 | 141.3319199999995 143.99855999999954 C:aug 66 | 143.99855999999954 146.66519999999957 F#:min9 67 | 146.66519999999957 148.6651799999996 B:min(9) 68 | 148.6651799999996 151.99847999999963 E:maj6 69 | 151.99847999999963 153.99845999999965 G:maj9 70 | 153.99845999999965 157.3317599999997 E:min/b3 71 | 157.3317599999997 159.99839999999972 E:min11 72 | 159.99839999999972 162.66503999999975 E:min11 73 | 162.66503999999975 165.33167999999978 A:maj7 74 | 165.33167999999978 167.9983199999998 A:maj7 75 | 167.9983199999998 170.66495999999984 D:maj6(9) 76 | 170.66495999999984 173.33159999999987 D:maj6(9) 77 | 173.33159999999987 175.9982399999999 A:maj7 78 
| 175.9982399999999 178.66487999999993 A:maj7 79 | 178.66487999999993 181.33151999999995 D:maj6(9) 80 | 181.33151999999995 183.99815999999998 D:maj6(9) 81 | 183.99815999999998 185.33148 C#:min/b3 82 | 185.33148 186.6648 F#:min 83 | 186.6648 187.99812000000003 D:maj/3 84 | 187.99812000000003 189.33144000000004 E:7 85 | 189.33144000000004 190.66476000000006 C#:min 86 | 190.66476000000006 191.99808000000007 F#:min 87 | 191.99808000000007 193.3314000000001 D:maj/3 88 | 193.3314000000001 194.6647200000001 E:7/3 89 | 194.6647200000001 195.99804000000012 C#:min7/5 90 | 195.99804000000012 197.33136000000013 F#:min(9) 91 | 197.33136000000013 199.33134000000015 B:min7 92 | 199.33134000000015 199.99800000000016 E:maj 93 | 199.99800000000016 202.6646400000002 A:maj7 94 | 202.6646400000002 205.33128000000022 A:maj7 95 | 205.33128000000022 207.99792000000025 D:maj6(9) 96 | 207.99792000000025 210.66456000000028 D:maj6(9) 97 | 210.66456000000028 213.3312000000003 A:maj7 98 | 213.3312000000003 215.99784000000034 A:maj7 99 | 215.99784000000034 218.66448000000037 D:maj(9) 100 | 218.66448000000037 221.3311200000004 D:maj6(9) 101 | 221.3311200000004 223.99776000000043 A:maj7 102 | 223.99776000000043 226.66440000000046 A:maj7 103 | 226.66440000000046 229.33104000000048 D:maj6(9) 104 | 229.33104000000048 231.9976800000005 A:maj13 105 | 231.9976800000005 234.66432000000054 A:maj7 106 | 234.66432000000054 237.33096000000057 A:maj7 107 | 237.33096000000057 239.9976000000006 D:maj6(9) 108 | 239.9976000000006 242.66424000000063 D:maj6(9) 109 | 242.66424000000063 245.33088000000066 N 110 | 245.33088000000066 247.33086000000068 N 111 | -------------------------------------------------------------------------------- /polyffusion/prepare_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | 4 | import muspy 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | from data.midi_to_data import * 9 | 10 | 11 | def force_length(music: muspy.music, bars=8): 12 | """Loops a MIDI file if it's under the specified number of bars, in place.""" 13 | num_tracks_at_least_bars = sum( 14 | [ 15 | 1 if (track.get_end_time() + 15) // 16 >= bars else 0 16 | for track in music.tracks 17 | ] 18 | ) 19 | if num_tracks_at_least_bars > 0: 20 | return 21 | for track in music.tracks: 22 | timesteps = track.get_end_time() 23 | old_bars = (timesteps + 15) // 16 24 | div = bars // old_bars 25 | for i in range(1, div): 26 | tmp = track.deepcopy() 27 | tmp.adjust_time(lambda x: x + i * timesteps) 28 | track.notes.extend(tmp.notes) 29 | 30 | 31 | def get_note_matrix_melodies(music, ignore_non_melody=True): 32 | """Similar to get_note_matrix from data.midi_to_data, with an option to ignore non-melodies.""" 33 | notes = [] 34 | for inst in music.tracks: 35 | if ignore_non_melody and (inst.is_drum or inst.program >= 113): 36 | continue 37 | for note in inst.notes: 38 | onset = int(note.time) 39 | duration = int(note.duration) 40 | if duration > 0: 41 | notes.append( 42 | [ 43 | onset, 44 | note.pitch, 45 | duration, 46 | note.velocity, 47 | inst.program, 48 | ] 49 | ) 50 | notes.sort(key=lambda x: (x[0], x[1], x[2])) 51 | assert len(notes) # in case if a MIDI has only non-melodies 52 | return notes 53 | 54 | 55 | def prepare_npz(midi_dir, chords_dir, output_dir, force=False, ignore_non_melody=True): 56 | for dir in [chords_dir, output_dir]: 57 | if not os.path.exists(dir): 58 | os.makedirs(dir) 59 | ttl = 0 60 | success = 0 61 | downbeat_errors = 0 62 | 
chords_errors = 0 63 | for root, dirs, files in os.walk(midi_dir): 64 | for midi in tqdm(files, desc=f"Processing {root}"): 65 | ttl += 1 66 | fpath = os.path.join(root, midi) 67 | chdpath = os.path.join(chords_dir, os.path.splitext(midi)[0] + ".csv") 68 | music = muspy.read_midi(fpath) 69 | music.adjust_resolution(4) 70 | if len(music.time_signatures) == 0: 71 | music.time_signatures.append(muspy.TimeSignature(0, 4, 4)) 72 | if force: 73 | force_length(music) 74 | 75 | try: 76 | note_mat = get_note_matrix_melodies(music, ignore_non_melody) 77 | note_mat = dedup_note_matrix(note_mat) 78 | extract_chords_from_midi_file(fpath, chdpath) 79 | chord = get_chord_matrix(chdpath) 80 | except: 81 | chords_errors += 1 82 | continue 83 | 84 | try: 85 | db_pos, db_pos_filter = get_downbeat_pos_and_filter(music, fpath) 86 | except: 87 | downbeat_errors += 1 88 | continue 89 | if db_pos is not None and sum(filter(lambda x: x, db_pos_filter)) != 0: 90 | start_table = get_start_table(note_mat, db_pos) 91 | processed_data = { 92 | "notes": np.array(note_mat), 93 | "start_table": np.array(start_table), 94 | "db_pos": np.array(db_pos), 95 | "db_pos_filter": np.array(db_pos_filter), 96 | "chord": np.array(chord), 97 | } 98 | np.savez( 99 | os.path.join(output_dir, midi), 100 | notes=processed_data["notes"], 101 | start_table=processed_data["start_table"], 102 | db_pos=processed_data["db_pos"], 103 | db_pos_filter=processed_data["db_pos_filter"], 104 | chord=processed_data["chord"], 105 | ) 106 | success += 1 107 | else: 108 | downbeat_errors += 1 109 | 110 | print( 111 | f"""{ttl} tracks processed, {success} succeeded, {chords_errors} chords errors, {downbeat_errors} downbeat errors""" 112 | ) 113 | 114 | 115 | if __name__ == "__main__": 116 | parser = ArgumentParser( 117 | description="prepare data from midi for a Polyffusion model" 118 | ) 119 | parser.add_argument( 120 | "--midi_dir", type=str, help="directory of input midis to preparep" 121 | ) 122 | parser.add_argument( 123 | "--chords_dir", type=str, help="directory to store extracted chords" 124 | ) 125 | parser.add_argument( 126 | "--npz_dir", type=str, help="directory to store prepared data in npz" 127 | ) 128 | parser.add_argument( 129 | "--force_length", 130 | action="store_true", 131 | help="to repeat shorter samples into the desired number of bars", 132 | ) 133 | parser.add_argument( 134 | "--ignore_non_melody", 135 | action="store_false", 136 | help="whether ignore all non-melody instruments. 
default: true", 137 | ) 138 | args = parser.parse_args() 139 | prepare_npz( 140 | args.midi_dir, 141 | args.chords_dir, 142 | args.npz_dir, 143 | args.force_length, 144 | args.ignore_non_melody, 145 | ) 146 | -------------------------------------------------------------------------------- /polyffusion/train/train_ldm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from data.dataloader import get_custom_train_val_dataloaders, get_train_val_dataloaders 4 | from data.dataloader_musicalion import ( 5 | get_train_val_dataloaders as get_train_val_dataloaders_musicalion, 6 | ) 7 | from dirs import PT_CHD_8BAR_PATH, PT_PNOTREE_PATH, PT_POLYDIS_PATH 8 | from models.model_sdf import Polyffusion_SDF 9 | from stable_diffusion.latent_diffusion import LatentDiffusion 10 | from stable_diffusion.model.unet import UNetModel 11 | from utils import ( 12 | load_pretrained_chd_enc_dec, 13 | load_pretrained_pnotree_enc_dec, 14 | load_pretrained_txt_enc, 15 | ) 16 | 17 | # from stable_diffusion.model.autoencoder import Autoencoder, Encoder, Decoder 18 | from . import TrainConfig 19 | 20 | 21 | class LDM_TrainConfig(TrainConfig): 22 | def __init__( 23 | self, 24 | params, 25 | output_dir, 26 | use_autoencoder=False, 27 | use_musicalion=False, 28 | use_track=[0, 1, 2], 29 | data_dir=None, 30 | ) -> None: 31 | super().__init__(params, None, output_dir) 32 | self.autoencoder = None 33 | 34 | if use_autoencoder: 35 | # encoder = Encoder( 36 | # in_channels=2, 37 | # z_channels=4, 38 | # channels=64, 39 | # channel_multipliers=[1, 2, 4, 4], 40 | # n_resnet_blocks=2 41 | # ) 42 | 43 | # decoder = Decoder( 44 | # out_channels=2, 45 | # z_channels=4, 46 | # channels=64, 47 | # channel_multipliers=[1, 2, 4, 4], 48 | # n_resnet_blocks=2 49 | # ) 50 | 51 | # self.autoencoder = Autoencoder( 52 | # emb_channels=4, encoder=encoder, decoder=decoder, z_channels=4 53 | # ) 54 | raise NotImplementedError 55 | 56 | self.unet_model = UNetModel( 57 | in_channels=params.in_channels, 58 | out_channels=params.out_channels, 59 | channels=params.channels, 60 | attention_levels=params.attention_levels, 61 | n_res_blocks=params.n_res_blocks, 62 | channel_multipliers=params.channel_multipliers, 63 | n_heads=params.n_heads, 64 | tf_layers=params.tf_layers, 65 | d_cond=params.d_cond, 66 | ) 67 | 68 | self.ldm_model = LatentDiffusion( 69 | linear_start=params.linear_start, 70 | linear_end=params.linear_end, 71 | n_steps=params.n_steps, 72 | latent_scaling_factor=params.latent_scaling_factor, 73 | autoencoder=self.autoencoder, 74 | unet_model=self.unet_model, 75 | ) 76 | 77 | self.pnotree_enc, self.pnotree_dec = None, None 78 | self.chord_enc, self.chord_dec = None, None 79 | self.txt_enc = None 80 | if params.cond_type == "pnotree": 81 | self.pnotree_enc, self.pnotree_dec = load_pretrained_pnotree_enc_dec( 82 | PT_PNOTREE_PATH, 20 83 | ) 84 | if "chord" in params.cond_type: 85 | if params.use_enc: 86 | self.chord_enc, self.chord_dec = load_pretrained_chd_enc_dec( 87 | PT_CHD_8BAR_PATH, 88 | params.chd_input_dim, 89 | params.chd_z_input_dim, 90 | params.chd_hidden_dim, 91 | params.chd_z_dim, 92 | params.chd_n_step, 93 | ) 94 | if "txt" in params.cond_type: 95 | if params.use_enc: 96 | self.txt_enc = load_pretrained_txt_enc( 97 | PT_POLYDIS_PATH, 98 | params.txt_emb_size, 99 | params.txt_hidden_dim, 100 | params.txt_z_dim, 101 | params.txt_num_channel, 102 | ) 103 | self.model = Polyffusion_SDF( 104 | self.ldm_model, 105 | cond_type=params.cond_type, 106 | cond_mode=params.cond_mode, 
107 | chord_enc=self.chord_enc, 108 | chord_dec=self.chord_dec, 109 | pnotree_enc=self.pnotree_enc, 110 | pnotree_dec=self.pnotree_dec, 111 | txt_enc=self.txt_enc, 112 | concat_blurry=params.concat_blurry 113 | if hasattr(params, "concat_blurry") 114 | else False, 115 | concat_ratio=params.concat_ratio 116 | if hasattr(params, "concat_ratio") 117 | else 1 / 8, 118 | ) 119 | # Create dataloader 120 | if use_musicalion: 121 | self.train_dl, self.val_dl = get_train_val_dataloaders_musicalion( 122 | params.batch_size, params.num_workers, params.pin_memory 123 | ) 124 | else: 125 | if data_dir is None: 126 | self.train_dl, self.val_dl = get_train_val_dataloaders( 127 | params.batch_size, 128 | params.num_workers, 129 | params.pin_memory, 130 | use_track=use_track, 131 | ) 132 | else: 133 | self.train_dl, self.val_dl = get_custom_train_val_dataloaders( 134 | params.batch_size, 135 | data_dir, 136 | num_workers=params.num_workers, 137 | pin_memory=params.pin_memory, 138 | ) 139 | 140 | # Create optimizer 141 | self.optimizer = torch.optim.Adam( 142 | self.model.parameters(), lr=params.learning_rate 143 | ) 144 | -------------------------------------------------------------------------------- /polyffusion/stable_diffusion/losses/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | from collections import namedtuple 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from .util import get_ckpt_path 10 | 11 | 12 | class LPIPS(nn.Module): 13 | # Learned perceptual metric 14 | def __init__(self, use_dropout=True): 15 | super().__init__() 16 | self.scaling_layer = ScalingLayer() 17 | self.chns = [64, 128, 256, 512, 512] # vg16 features 18 | self.net = vgg16(pretrained=True, requires_grad=False) 19 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 20 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 21 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 22 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 23 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 24 | self.load_from_pretrained() 25 | for param in self.parameters(): 26 | param.requires_grad = False 27 | 28 | def load_from_pretrained(self, name="vgg_lpips"): 29 | ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips") 30 | self.load_state_dict( 31 | torch.load(ckpt, map_location=torch.device("cpu")), strict=False 32 | ) 33 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 34 | 35 | @classmethod 36 | def from_pretrained(cls, name="vgg_lpips"): 37 | if name != "vgg_lpips": 38 | raise NotImplementedError 39 | model = cls() 40 | ckpt = get_ckpt_path(name) 41 | model.load_state_dict( 42 | torch.load(ckpt, map_location=torch.device("cpu")), strict=False 43 | ) 44 | return model 45 | 46 | def forward(self, input, target): 47 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 48 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 49 | feats0, feats1, diffs = {}, {}, {} 50 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 51 | for kk in range(len(self.chns)): 52 | feats0[kk], feats1[kk] = ( 53 | normalize_tensor(outs0[kk]), 54 | normalize_tensor(outs1[kk]), 55 | ) 56 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 57 | 58 | res = [ 59 | spatial_average(lins[kk].model(diffs[kk]), keepdim=True) 60 | for kk in range(len(self.chns)) 61 | ] 
62 | val = res[0] 63 | for l in range(1, len(self.chns)): 64 | val += res[l] 65 | return val 66 | 67 | 68 | class ScalingLayer(nn.Module): 69 | def __init__(self): 70 | super(ScalingLayer, self).__init__() 71 | self.register_buffer( 72 | "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None] 73 | ) 74 | self.register_buffer( 75 | "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None] 76 | ) 77 | 78 | def forward(self, inp): 79 | return (inp - self.shift) / self.scale 80 | 81 | 82 | class NetLinLayer(nn.Module): 83 | """A single linear layer which does a 1x1 conv""" 84 | 85 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 86 | super(NetLinLayer, self).__init__() 87 | layers = ( 88 | [ 89 | nn.Dropout(), 90 | ] 91 | if (use_dropout) 92 | else [] 93 | ) 94 | layers += [ 95 | nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), 96 | ] 97 | self.model = nn.Sequential(*layers) 98 | 99 | 100 | class vgg16(torch.nn.Module): 101 | def __init__(self, requires_grad=False, pretrained=True): 102 | super(vgg16, self).__init__() 103 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 104 | self.slice1 = torch.nn.Sequential() 105 | self.slice2 = torch.nn.Sequential() 106 | self.slice3 = torch.nn.Sequential() 107 | self.slice4 = torch.nn.Sequential() 108 | self.slice5 = torch.nn.Sequential() 109 | self.N_slices = 5 110 | for x in range(4): 111 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 112 | for x in range(4, 9): 113 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 114 | for x in range(9, 16): 115 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 116 | for x in range(16, 23): 117 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 118 | for x in range(23, 30): 119 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 120 | if not requires_grad: 121 | for param in self.parameters(): 122 | param.requires_grad = False 123 | 124 | def forward(self, X): 125 | h = self.slice1(X) 126 | h_relu1_2 = h 127 | h = self.slice2(h) 128 | h_relu2_2 = h 129 | h = self.slice3(h) 130 | h_relu3_3 = h 131 | h = self.slice4(h) 132 | h_relu4_3 = h 133 | h = self.slice5(h) 134 | h_relu5_3 = h 135 | vgg_outputs = namedtuple( 136 | "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"] 137 | ) 138 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 139 | return out 140 | 141 | 142 | def normalize_tensor(x, eps=1e-10): 143 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) 144 | return x / (norm_factor + eps) 145 | 146 | 147 | def spatial_average(x, keepdim=True): 148 | return x.mean([2, 3], keepdim=keepdim) 149 | -------------------------------------------------------------------------------- /polyffusion/stable_diffusion/sampler/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | --- 3 | title: Sampling algorithms for stable diffusion 4 | summary: > 5 | Annotated PyTorch implementation/tutorial of 6 | sampling algorithms 7 | for stable diffusion model. 
8 | --- 9 | 10 | # Sampling algorithms for [stable diffusion](../index.html) 11 | 12 | We have implemented the following [sampling algorithms](sampler/index.html): 13 | 14 | * [Denoising Diffusion Probabilistic Models (DDPM) Sampling](ddpm.html) 15 | * [Denoising Diffusion Implicit Models (DDIM) Sampling](ddim.html) 16 | """ 17 | 18 | from typing import List, Optional 19 | 20 | import torch 21 | 22 | from ..latent_diffusion import LatentDiffusion 23 | 24 | 25 | class DiffusionSampler: 26 | """ 27 | ## Base class for sampling algorithms 28 | """ 29 | 30 | model: LatentDiffusion 31 | 32 | def __init__(self, model: LatentDiffusion): 33 | """ 34 | :param model: is the model to predict noise $\epsilon_\text{cond}(x_t, c)$ 35 | """ 36 | super().__init__() 37 | # Set the model $\epsilon_\text{cond}(x_t, c)$ 38 | self.model = model 39 | # Get number of steps the model was trained with $T$ 40 | self.n_steps = model.n_steps 41 | 42 | def get_eps( 43 | self, 44 | x: torch.Tensor, 45 | t: torch.Tensor, 46 | c: torch.Tensor, 47 | *, 48 | uncond_scale: float, 49 | uncond_cond: Optional[torch.Tensor], 50 | ): 51 | """ 52 | ## Get $\epsilon(x_t, c)$ 53 | 54 | :param x: is $x_t$ of shape `[batch_size, channels, height, width]` 55 | :param t: is $t$ of shape `[batch_size]` 56 | :param c: is the conditional embeddings $c$ of shape `[batch_size, emb_size]` 57 | :param uncond_scale: is the unconditional guidance scale $s$. This is used for 58 | $\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) - (s - 1)\epsilon_\text{cond}(x_t, c_u)$ 59 | :param uncond_cond: is the conditional embedding for empty prompt $c_u$ 60 | """ 61 | # When the scale $s = 1$ 62 | # $$\epsilon_\theta(x_t, c) = \epsilon_\text{cond}(x_t, c)$$ 63 | if uncond_cond is None or uncond_scale == 1.0: 64 | return self.model(x, t, c) 65 | elif uncond_scale == 0.0: # unconditional 66 | return self.model(x, t, uncond_cond) 67 | 68 | # Duplicate $x_t$ and $t$ 69 | x_in = torch.cat([x] * 2) 70 | t_in = torch.cat([t] * 2) 71 | # Concatenated $c$ and $c_u$ 72 | c_in = torch.cat([uncond_cond, c]) 73 | # Get $\epsilon_\text{cond}(x_t, c)$ and $\epsilon_\text{cond}(x_t, c_u)$ 74 | e_t_uncond, e_t_cond = self.model(x_in, t_in, c_in).chunk(2) 75 | # Calculate 76 | # $$\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) - (s - 1)\epsilon_\text{cond}(x_t, c_u)$$ 77 | e_t = e_t_uncond + uncond_scale * (e_t_cond - e_t_uncond) 78 | 79 | # 80 | return e_t 81 | 82 | def sample( 83 | self, 84 | shape: List[int], 85 | cond: torch.Tensor, 86 | repeat_noise: bool = False, 87 | temperature: float = 1.0, 88 | x_last: Optional[torch.Tensor] = None, 89 | uncond_scale: float = 1.0, 90 | uncond_cond: Optional[torch.Tensor] = None, 91 | skip_steps: int = 0, 92 | ): 93 | """ 94 | ### Sampling Loop 95 | 96 | :param shape: is the shape of the generated images in the 97 | form `[batch_size, channels, height, width]` 98 | :param cond: is the conditional embeddings $c$ 99 | :param temperature: is the noise temperature (random noise gets multiplied by this) 100 | :param x_last: is $x_T$. If not provided random noise will be used. 101 | :param uncond_scale: is the unconditional guidance scale $s$. This is used for 102 | $\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) - (s - 1)\epsilon_\text{cond}(x_t, c_u)$ 103 | :param uncond_cond: is the conditional embedding for empty prompt $c_u$ 104 | :param skip_steps: is the number of time steps to skip.
105 | """ 106 | raise NotImplementedError() 107 | 108 | def paint( 109 | self, 110 | x: torch.Tensor, 111 | cond: torch.Tensor, 112 | t_start: int, 113 | *, 114 | orig: Optional[torch.Tensor] = None, 115 | mask: Optional[torch.Tensor] = None, 116 | orig_noise: Optional[torch.Tensor] = None, 117 | uncond_scale: float = 1.0, 118 | uncond_cond: Optional[torch.Tensor] = None, 119 | ): 120 | """ 121 | ### Painting Loop 122 | 123 | :param x: is $x_{T'}$ of shape `[batch_size, channels, height, width]` 124 | :param cond: is the conditional embeddings $c$ 125 | :param t_start: is the sampling step to start from, $T'$ 126 | :param orig: is the original image in latent page which we are in paining. 127 | :param mask: is the mask to keep the original image. 128 | :param orig_noise: is fixed noise to be added to the original image. 129 | :param uncond_scale: is the unconditional guidance scale $s$. This is used for 130 | $\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (s - 1)\epsilon_\text{cond}(x_t, c_u)$ 131 | :param uncond_cond: is the conditional embedding for empty prompt $c_u$ 132 | """ 133 | raise NotImplementedError() 134 | 135 | def q_sample( 136 | self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None 137 | ): 138 | """ 139 | ### Sample from $q(x_t|x_0)$ 140 | 141 | :param x0: is $x_0$ of shape `[batch_size, channels, height, width]` 142 | :param index: is the time step $t$ index 143 | :param noise: is the noise, $\epsilon$ 144 | """ 145 | raise NotImplementedError() 146 | -------------------------------------------------------------------------------- /polyffusion/data/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import DataLoader 6 | 7 | from data.dataset import PianoOrchDataset 8 | from utils import ( 9 | chd_pitch_shift, 10 | chd_to_midi_file, 11 | chd_to_onehot, 12 | estx_to_midi_file, 13 | pianotree_pitch_shift, 14 | pr_mat_pitch_shift, 15 | prmat2c_to_midi_file, 16 | prmat_to_midi_file, 17 | ) 18 | 19 | # SEED = 7890 20 | # torch.manual_seed(SEED) 21 | # np.random.seed(SEED) 22 | # random.seed(SEED) 23 | 24 | 25 | def collate_fn(batch, shift): 26 | def sample_shift(): 27 | return np.random.choice(np.arange(-6, 6), 1)[0] 28 | 29 | prmat2c = [] 30 | pnotree = [] 31 | chord = [] 32 | prmat = [] 33 | song_fn = [] 34 | for b in batch: 35 | # b[0]: seg_pnotree; b[1]: seg_pnotree_y 36 | seg_prmat2c = b[0] 37 | seg_pnotree = b[1] 38 | seg_chord = b[2] 39 | seg_prmat = b[3] 40 | 41 | if shift: 42 | shift_pitch = sample_shift() 43 | seg_prmat2c = pr_mat_pitch_shift(seg_prmat2c, shift_pitch) 44 | seg_pnotree = pianotree_pitch_shift(seg_pnotree, shift_pitch) 45 | seg_chord = chd_pitch_shift(seg_chord, shift_pitch) 46 | seg_prmat = pr_mat_pitch_shift(seg_prmat, shift_pitch) 47 | 48 | seg_chord = chd_to_onehot(seg_chord) 49 | 50 | prmat2c.append(seg_prmat2c) 51 | pnotree.append(seg_pnotree) 52 | chord.append(seg_chord) 53 | prmat.append(seg_prmat) 54 | 55 | if len(b) > 4: 56 | song_fn.append(b[4]) 57 | 58 | prmat2c = torch.Tensor(np.array(prmat2c, np.float32)).float() 59 | pnotree = torch.Tensor(np.array(pnotree, np.int64)).long() 60 | chord = torch.Tensor(np.array(chord, np.float32)).float() 61 | prmat = torch.Tensor(np.array(prmat, np.float32)).float() 62 | # prmat = prmat.unsqueeze(1) # (B, 1, 128, 128) 63 | if len(song_fn) > 0: 64 | return prmat2c, pnotree, chord, prmat, song_fn 65 | else: 66 | return prmat2c, pnotree, chord, prmat 
67 | 68 | 69 | def get_custom_train_val_dataloaders( 70 | batch_size, 71 | data_dir, 72 | num_workers=0, 73 | pin_memory=False, 74 | debug=False, 75 | train_ratio=0.9, 76 | ): 77 | all_data = next(os.walk(data_dir))[2] 78 | train_num = int(len(all_data) * train_ratio) 79 | train_files = all_data[:train_num] 80 | val_files = all_data[train_num:] 81 | 82 | train_dataset = PianoOrchDataset.load_with_song_paths( 83 | song_paths=train_files, data_dir=data_dir 84 | ) 85 | val_dataset = PianoOrchDataset.load_with_song_paths( 86 | song_paths=val_files, data_dir=data_dir 87 | ) 88 | 89 | train_dl = DataLoader( 90 | train_dataset, 91 | batch_size, 92 | True, 93 | collate_fn=lambda x: collate_fn(x, shift=True), 94 | num_workers=num_workers, 95 | pin_memory=pin_memory, 96 | ) 97 | val_dl = DataLoader( 98 | val_dataset, 99 | batch_size, 100 | False, 101 | collate_fn=lambda x: collate_fn(x, shift=False), 102 | num_workers=num_workers, 103 | pin_memory=pin_memory, 104 | ) 105 | print( 106 | f"Dataloader ready: batch_size={batch_size}, num_workers={num_workers}, pin_memory={pin_memory}, train_segments={len(train_dataset)}, val_segments={len(val_dataset)}" 107 | ) 108 | return train_dl, val_dl 109 | 110 | 111 | def get_train_val_dataloaders( 112 | batch_size, num_workers=0, pin_memory=False, debug=False, **kwargs 113 | ): 114 | train_dataset, val_dataset = PianoOrchDataset.load_train_and_valid_sets( 115 | debug=debug, **kwargs 116 | ) 117 | train_dl = DataLoader( 118 | train_dataset, 119 | batch_size, 120 | True, 121 | collate_fn=lambda x: collate_fn(x, shift=True), 122 | num_workers=num_workers, 123 | pin_memory=pin_memory, 124 | ) 125 | val_dl = DataLoader( 126 | val_dataset, 127 | batch_size, 128 | False, 129 | collate_fn=lambda x: collate_fn(x, shift=False), 130 | num_workers=num_workers, 131 | pin_memory=pin_memory, 132 | ) 133 | print( 134 | f"Dataloader ready: batch_size={batch_size}, num_workers={num_workers}, pin_memory={pin_memory}, train_segments={len(train_dataset)}, val_segments={len(val_dataset)} {kwargs}" 135 | ) 136 | return train_dl, val_dl 137 | 138 | 139 | def get_val_dataloader( 140 | batch_size, num_workers=0, pin_memory=False, debug=False, **kwargs 141 | ): 142 | val_dataset = PianoOrchDataset.load_valid_set(debug, **kwargs) 143 | val_dl = DataLoader( 144 | val_dataset, 145 | batch_size, 146 | False, 147 | collate_fn=lambda x: collate_fn(x, shift=False), 148 | num_workers=num_workers, 149 | pin_memory=pin_memory, 150 | ) 151 | print( 152 | f"Dataloader ready: batch_size={batch_size}, num_workers={num_workers}, pin_memory={pin_memory}, {kwargs}" 153 | ) 154 | return val_dl 155 | 156 | 157 | if __name__ == "__main__": 158 | train_dl, val_dl = get_train_val_dataloaders(16) 159 | print(len(train_dl)) 160 | for batch in train_dl: 161 | print(len(batch)) 162 | prmat2c, pnotree, chord, prmat = batch 163 | print(prmat2c.shape) 164 | print(pnotree.shape) 165 | print(chord.shape) 166 | print(prmat.shape) 167 | prmat2c = prmat2c.cpu().numpy() 168 | pnotree = pnotree.cpu().numpy() 169 | chord = chord.cpu().numpy() 170 | prmat = prmat.cpu().numpy() 171 | # chord = [onehot_to_chd(onehot) for onehot in chord] 172 | prmat2c_to_midi_file(prmat2c, "exp/dl_prmat2c.mid") 173 | estx_to_midi_file(pnotree, "exp/dl_pnotree.mid") 174 | chd_to_midi_file(chord, "exp/dl_chord.mid") 175 | prmat_to_midi_file(prmat, "exp/dl_prmat.mid") 176 | exit(0) 177 | -------------------------------------------------------------------------------- /polyffusion/mir_eval/tempo.py: 
-------------------------------------------------------------------------------- 1 | ''' 2 | The goal of a tempo estimation algorithm is to automatically detect the tempo 3 | of a piece of music, measured in beats per minute (BPM). 4 | 5 | See http://www.music-ir.org/mirex/wiki/2014:Audio_Tempo_Estimation for a 6 | description of the task and evaluation criteria. 7 | 8 | Conventions 9 | ----------- 10 | 11 | Reference and estimated tempi should be positive, and provided in ascending 12 | order as a numpy array of length 2. 13 | 14 | The weighting value from the reference must be a float in the range [0, 1]. 15 | 16 | Metrics 17 | ------- 18 | * :func:`mir_eval.tempo.detection`: Relative error, hits, and weighted 19 | precision of tempo estimation. 20 | 21 | ''' 22 | 23 | import warnings 24 | import numpy as np 25 | import collections 26 | from . import util 27 | 28 | 29 | def validate_tempi(tempi, reference=True): 30 | """Checks that there are two non-negative tempi. 31 | For a reference value, at least one tempo has to be greater than zero. 32 | 33 | Parameters 34 | ---------- 35 | tempi : np.ndarray 36 | length-2 array of tempo, in bpm 37 | 38 | reference : bool 39 | indicates a reference value 40 | 41 | """ 42 | 43 | if tempi.size != 2: 44 | raise ValueError('tempi must have exactly two values') 45 | 46 | if not np.all(np.isfinite(tempi)) or np.any(tempi < 0): 47 | raise ValueError('tempi={} must be non-negative numbers'.format(tempi)) 48 | 49 | if reference and np.all(tempi == 0): 50 | raise ValueError('reference tempi={} must have one' 51 | ' value greater than zero'.format(tempi)) 52 | 53 | 54 | def validate(reference_tempi, reference_weight, estimated_tempi): 55 | """Checks that the input annotations to a metric look like valid tempo 56 | annotations. 57 | 58 | Parameters 59 | ---------- 60 | reference_tempi : np.ndarray 61 | reference tempo values, in bpm 62 | 63 | reference_weight : float 64 | perceptual weight of slow vs fast in reference 65 | 66 | estimated_tempi : np.ndarray 67 | estimated tempo values, in bpm 68 | 69 | """ 70 | validate_tempi(reference_tempi, reference=True) 71 | validate_tempi(estimated_tempi, reference=False) 72 | 73 | if reference_weight < 0 or reference_weight > 1: 74 | raise ValueError('Reference weight must lie in range [0, 1]') 75 | 76 | 77 | def detection(reference_tempi, reference_weight, estimated_tempi, tol=0.08): 78 | """Compute the tempo detection accuracy metric. 79 | 80 | Parameters 81 | ---------- 82 | reference_tempi : np.ndarray, shape=(2,) 83 | Two non-negative reference tempi 84 | 85 | reference_weight : float > 0 86 | The relative strength of ``reference_tempi[0]`` vs 87 | ``reference_tempi[1]``. 88 | 89 | estimated_tempi : np.ndarray, shape=(2,) 90 | Two non-negative estimated tempi. 91 | 92 | tol : float in [0, 1]: 93 | The maximum allowable deviation from a reference tempo to 94 | count as a hit. 95 | ``|est_t - ref_t| <= tol * ref_t`` 96 | (Default value = 0.08) 97 | 98 | Returns 99 | ------- 100 | p_score : float in [0, 1] 101 | Weighted average of recalls: 102 | ``reference_weight * hits[0] + (1 - reference_weight) * hits[1]`` 103 | 104 | one_correct : bool 105 | True if at least one reference tempo was correctly estimated 106 | 107 | both_correct : bool 108 | True if both reference tempi were correctly estimated 109 | 110 | Raises 111 | ------ 112 | ValueError 113 | If the input tempi are ill-formed 114 | 115 | If the reference weight is not in the range [0, 1] 116 | 117 | If ``tol < 0`` or ``tol > 1``. 
118 | """ 119 | 120 | validate(reference_tempi, reference_weight, estimated_tempi) 121 | 122 | if tol < 0 or tol > 1: 123 | raise ValueError('invalid tolerance {}: must lie in the range ' 124 | '[0, 1]'.format(tol)) 125 | if tol == 0.: 126 | warnings.warn('A tolerance of 0.0 may not ' 127 | 'lead to the results you expect.') 128 | 129 | hits = [False, False] 130 | 131 | for i, ref_t in enumerate(reference_tempi): 132 | if ref_t > 0: 133 | # Compute the relative error for this reference tempo 134 | f_ref_t = float(ref_t) 135 | relative_error = np.min(np.abs(ref_t - estimated_tempi) / f_ref_t) 136 | 137 | # Count the hits 138 | hits[i] = relative_error <= tol 139 | 140 | p_score = reference_weight * hits[0] + (1.0-reference_weight) * hits[1] 141 | 142 | one_correct = bool(np.max(hits)) 143 | both_correct = bool(np.min(hits)) 144 | 145 | return p_score, one_correct, both_correct 146 | 147 | 148 | def evaluate(reference_tempi, reference_weight, estimated_tempi, **kwargs): 149 | """Compute all metrics for the given reference and estimated annotations. 150 | 151 | Parameters 152 | ---------- 153 | reference_tempi : np.ndarray, shape=(2,) 154 | Two non-negative reference tempi 155 | 156 | reference_weight : float > 0 157 | The relative strength of ``reference_tempi[0]`` vs 158 | ``reference_tempi[1]``. 159 | 160 | estimated_tempi : np.ndarray, shape=(2,) 161 | Two non-negative estimated tempi. 162 | 163 | kwargs 164 | Additional keyword arguments which will be passed to the 165 | appropriate metric or preprocessing functions. 166 | 167 | Returns 168 | ------- 169 | scores : dict 170 | Dictionary of scores, where the key is the metric name (str) and 171 | the value is the (float) score achieved. 172 | """ 173 | # Compute all metrics 174 | scores = collections.OrderedDict() 175 | 176 | (scores['P-score'], 177 | scores['One-correct'], 178 | scores['Both-correct']) = util.filter_kwargs(detection, reference_tempi, 179 | reference_weight, 180 | estimated_tempi, 181 | **kwargs) 182 | 183 | return scores 184 | -------------------------------------------------------------------------------- /polyffusion/chord_extractor/mir/extractors/vamp_extractor.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | import numpy as np 4 | from mir.cache import hasher 5 | from mir.common import PACKAGE_PATH, SONIC_ANNOTATOR_PATH, WORKING_PATH 6 | from mir.extractors.extractor_base import * 7 | 8 | 9 | def rewrite_extract_n3(entry, inputfilename, outputfilename): 10 | f = open(inputfilename, "r") 11 | content = f.read() 12 | f.close() 13 | content = content.replace("[__SR__]", str(entry.prop.sr)) 14 | content = content.replace("[__WIN_SHIFT__]", str(entry.prop.hop_length)) 15 | content = content.replace("[__WIN_SIZE__]", str(entry.prop.win_size)) 16 | if not os.path.isdir(os.path.dirname(outputfilename)): 17 | os.makedirs(os.path.dirname(outputfilename)) 18 | f = open(outputfilename, "w") 19 | f.write(content) 20 | f.close() 21 | 22 | 23 | class NNLSChroma(ExtractorBase): 24 | def get_feature_class(self): 25 | return io.ChromaIO 26 | 27 | def extract(self, entry, **kwargs): 28 | print("NNLSChroma working on entry " + entry.name) 29 | if "margin" in kwargs: 30 | if kwargs["margin"] > 0: 31 | music = entry.music_h 32 | else: 33 | raise Exception("Error margin") 34 | 35 | else: 36 | music = entry.music 37 | music_io = io.MusicIO() 38 | temp_path = os.path.join( 39 | WORKING_PATH, "temp/nnlschroma_extractor_%s.wav" % hasher(entry.name) 40 | ) 41 | temp_n3_path = 
temp_path + ".n3" 42 | rewrite_extract_n3( 43 | entry, os.path.join(PACKAGE_PATH, "data/bothchroma.n3"), temp_n3_path 44 | ) 45 | music_io.write(music, temp_path, entry) 46 | proc = subprocess.Popen( 47 | [ 48 | SONIC_ANNOTATOR_PATH, 49 | "-t", 50 | temp_n3_path, 51 | temp_path, 52 | "-w", 53 | "lab", 54 | "--lab-stdout", 55 | ], 56 | stdout=subprocess.PIPE, 57 | stderr=subprocess.DEVNULL, 58 | ) 59 | # print('Begin processing') 60 | result = np.zeros((0, 24)) 61 | for line in proc.stdout: 62 | # the real code does filtering here 63 | line = bytes.decode(line) 64 | if line.endswith("\r\n"): 65 | line = line[: len(line) - 2] 66 | if line.endswith("\r"): 67 | line = line[: len(line) - 1] 68 | arr = np.array(list(map(float, line.split("\t")))[1:]) 69 | arr = arr.reshape((2, 12))[::-1].T 70 | arr = np.roll(arr, -3, axis=0).reshape((1, 24)) 71 | result = np.append(result, arr, axis=0) 72 | try: 73 | os.unlink(temp_path) 74 | os.unlink(temp_n3_path) 75 | except: 76 | pass 77 | if result.shape[0] == 0: 78 | raise Exception("Empty response") 79 | return result 80 | 81 | 82 | class TunedLogSpectrogram(ExtractorBase): 83 | def get_feature_class(self): 84 | return io.SpectrogramIO 85 | 86 | def extract(self, entry, **kwargs): 87 | print("TunedLogSpectrogram working on entry " + entry.name) 88 | music_io = io.MusicIO() 89 | temp_path = os.path.join( 90 | WORKING_PATH, 91 | "temp/tunedlogspectrogram_extractor_%s.wav" % hasher(entry.name), 92 | ) 93 | temp_n3_path = temp_path + ".n3" 94 | rewrite_extract_n3( 95 | entry, os.path.join(PACKAGE_PATH, "data/tunedlogfreqspec.n3"), temp_n3_path 96 | ) 97 | music_io.write(entry.music, temp_path, entry) 98 | proc = subprocess.Popen( 99 | [ 100 | SONIC_ANNOTATOR_PATH, 101 | "-t", 102 | temp_n3_path, 103 | temp_path, 104 | "-w", 105 | "lab", 106 | "--lab-stdout", 107 | ], 108 | stdout=subprocess.PIPE, 109 | stderr=subprocess.DEVNULL, 110 | ) 111 | # print('Begin processing') 112 | result = np.zeros((0, 256)) 113 | for line in proc.stdout: 114 | # the real code does filtering here 115 | line = bytes.decode(line) 116 | if line.endswith("\r\n"): 117 | line = line[: len(line) - 2] 118 | if line.endswith("\r"): 119 | line = line[: len(line) - 1] 120 | arr = np.array(list(map(float, line.split("\t")))[1:]) 121 | arr = arr.reshape((1, -1)) 122 | result = np.append(result, arr, axis=0) 123 | try: 124 | os.unlink(temp_path) 125 | os.unlink(temp_n3_path) 126 | except: 127 | pass 128 | if result.shape[0] == 0: 129 | raise Exception("Empty response") 130 | return result 131 | 132 | 133 | class GlobalTuning(ExtractorBase): 134 | def get_feature_class(self): 135 | return io.FloatIO 136 | 137 | def extract(self, entry, **kwargs): 138 | music_io = io.MusicIO() 139 | temp_path = os.path.join( 140 | WORKING_PATH, "temp/tuning_%s.wav" % hasher(entry.name) 141 | ) 142 | temp_n3_path = temp_path + ".n3" 143 | rewrite_extract_n3( 144 | entry, os.path.join(PACKAGE_PATH, "data/tuning.n3"), temp_n3_path 145 | ) 146 | if "source" in kwargs: 147 | music = entry.dict[kwargs["source"]].get(entry) 148 | else: 149 | music = entry.music 150 | music_io.write(music, temp_path, entry) 151 | proc = subprocess.Popen( 152 | [ 153 | SONIC_ANNOTATOR_PATH, 154 | "-t", 155 | temp_n3_path, 156 | temp_path, 157 | "-w", 158 | "lab", 159 | "--lab-stdout", 160 | ], 161 | stdout=subprocess.PIPE, 162 | stderr=subprocess.DEVNULL, 163 | ) 164 | # print('Begin processing') 165 | output = proc.stdout.readlines() 166 | result = ( 167 | np.log2(np.float64(output[0].decode().split("\t")[2])) - np.log2(440) 168 | ) * 12 
169 | try: 170 | os.unlink(temp_path) 171 | os.unlink(temp_n3_path) 172 | except: 173 | pass 174 | return result 175 | -------------------------------------------------------------------------------- /polyffusion/stable_diffusion/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | # from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 5 | from .vqperceptual import * 6 | 7 | 8 | class LPIPSWithDiscriminator(nn.Module): 9 | def __init__( 10 | self, 11 | disc_start, 12 | logvar_init=0.0, 13 | kl_weight=1.0, 14 | pixelloss_weight=1.0, 15 | disc_num_layers=3, 16 | disc_in_channels=3, 17 | disc_factor=1.0, 18 | disc_weight=1.0, 19 | perceptual_weight=1.0, 20 | use_actnorm=False, 21 | disc_conditional=False, 22 | disc_loss="hinge", 23 | ): 24 | super().__init__() 25 | assert disc_loss in ["hinge", "vanilla"] 26 | self.kl_weight = kl_weight 27 | self.pixel_weight = pixelloss_weight 28 | self.perceptual_loss = LPIPS().eval() 29 | self.perceptual_weight = perceptual_weight 30 | # output log variance 31 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 32 | 33 | self.discriminator = NLayerDiscriminator( 34 | input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm 35 | ).apply(weights_init) 36 | self.discriminator_iter_start = disc_start 37 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 38 | self.disc_factor = disc_factor 39 | self.discriminator_weight = disc_weight 40 | self.disc_conditional = disc_conditional 41 | 42 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 43 | if last_layer is not None: 44 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 45 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 46 | else: 47 | nll_grads = torch.autograd.grad( 48 | nll_loss, self.last_layer[0], retain_graph=True 49 | )[0] 50 | g_grads = torch.autograd.grad( 51 | g_loss, self.last_layer[0], retain_graph=True 52 | )[0] 53 | 54 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 55 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 56 | d_weight = d_weight * self.discriminator_weight 57 | return d_weight 58 | 59 | def forward( 60 | self, 61 | inputs, 62 | reconstructions, 63 | posteriors, 64 | optimizer_idx, 65 | global_step, 66 | last_layer=None, 67 | cond=None, 68 | split="train", 69 | weights=None, 70 | ): 71 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 72 | if self.perceptual_weight > 0: 73 | p_loss = self.perceptual_loss( 74 | inputs.contiguous(), reconstructions.contiguous() 75 | ) 76 | rec_loss = rec_loss + self.perceptual_weight * p_loss 77 | 78 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 79 | weighted_nll_loss = nll_loss 80 | if weights is not None: 81 | weighted_nll_loss = weights * nll_loss 82 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 83 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 84 | kl_loss = posteriors.kl() 85 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 86 | 87 | # now the GAN part 88 | if optimizer_idx == 0: 89 | # generator update 90 | if cond is None: 91 | assert not self.disc_conditional 92 | logits_fake = self.discriminator(reconstructions.contiguous()) 93 | else: 94 | assert self.disc_conditional 95 | logits_fake = self.discriminator( 96 | torch.cat((reconstructions.contiguous(), cond), dim=1) 97 | ) 98 | 
g_loss = -torch.mean(logits_fake) 99 | 100 | if self.disc_factor > 0.0: 101 | try: 102 | d_weight = self.calculate_adaptive_weight( 103 | nll_loss, g_loss, last_layer=last_layer 104 | ) 105 | except RuntimeError: 106 | assert not self.training 107 | d_weight = torch.tensor(0.0) 108 | else: 109 | d_weight = torch.tensor(0.0) 110 | 111 | disc_factor = adopt_weight( 112 | self.disc_factor, global_step, threshold=self.discriminator_iter_start 113 | ) 114 | loss = ( 115 | weighted_nll_loss 116 | + self.kl_weight * kl_loss 117 | + d_weight * disc_factor * g_loss 118 | ) 119 | 120 | log = { 121 | "{}/total_loss".format(split): loss.clone().detach().mean(), 122 | "{}/logvar".format(split): self.logvar.detach(), 123 | "{}/kl_loss".format(split): kl_loss.detach().mean(), 124 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 125 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 126 | "{}/d_weight".format(split): d_weight.detach(), 127 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 128 | "{}/g_loss".format(split): g_loss.detach().mean(), 129 | } 130 | return loss, log 131 | 132 | if optimizer_idx == 1: 133 | # second pass for discriminator update 134 | if cond is None: 135 | logits_real = self.discriminator(inputs.contiguous().detach()) 136 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 137 | else: 138 | logits_real = self.discriminator( 139 | torch.cat((inputs.contiguous().detach(), cond), dim=1) 140 | ) 141 | logits_fake = self.discriminator( 142 | torch.cat((reconstructions.contiguous().detach(), cond), dim=1) 143 | ) 144 | 145 | disc_factor = adopt_weight( 146 | self.disc_factor, global_step, threshold=self.discriminator_iter_start 147 | ) 148 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 149 | 150 | log = { 151 | "{}/disc_loss".format(split): d_loss.clone().detach().mean(), 152 | "{}/logits_real".format(split): logits_real.detach().mean(), 153 | "{}/logits_fake".format(split): logits_fake.detach().mean(), 154 | } 155 | return d_loss, log 156 | --------------------------------------------------------------------------------
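The two optimizer_idx branches in LPIPSWithDiscriminator.forward above are meant to be driven by an alternating two-optimizer loop, in the style of the taming-transformers/VQGAN recipe this file adapts. The following is only an editor's sketch of that calling pattern: the autoencoder ae, its decoder_last_layer() accessor, the data loader, and the hyperparameters are hypothetical stand-ins rather than objects defined in this repository, and running it also requires the vqperceptual helpers imported at the top of the file.

    # Editor's sketch: alternating generator / discriminator steps (hypothetical names).
    import torch

    loss_fn = LPIPSWithDiscriminator(disc_start=10000, disc_in_channels=3)
    opt_ae = torch.optim.Adam(ae.parameters(), lr=4.5e-6)
    opt_disc = torch.optim.Adam(loss_fn.discriminator.parameters(), lr=4.5e-6)

    for step, x in enumerate(loader):
        recon, posterior = ae(x)  # hypothetical autoencoder forward pass

        # optimizer_idx = 0: NLL/KL reconstruction loss plus adaptively weighted GAN term
        g_loss, g_log = loss_fn(
            x, recon, posterior, 0, step,
            last_layer=ae.decoder_last_layer(), split="train",
        )
        opt_ae.zero_grad()
        g_loss.backward()
        opt_ae.step()

        # optimizer_idx = 1: hinge (or vanilla) loss on real vs. reconstructed logits
        d_loss, d_log = loss_fn(x, recon.detach(), posterior, 1, step, split="train")
        opt_disc.zero_grad()
        d_loss.backward()
        opt_disc.step()

The generator step passes the decoder's last layer so that calculate_adaptive_weight can balance the adversarial term against the reconstruction term; the discriminator step needs no such weighting because reconstructions are detached inside forward.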