├── polyffusion
│ ├── __init__.py
│ ├── chord_extractor
│ │ ├── mir
│ │ │ ├── io
│ │ │ │ ├── implement
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── unknown_io.py
│ │ │ │ │ ├── midi_io.py
│ │ │ │ │ ├── music_io.py
│ │ │ │ │ ├── scalar_io.py
│ │ │ │ │ ├── chroma_io.py
│ │ │ │ │ ├── spectrogram_io.py
│ │ │ │ │ └── regional_spectrogram_io.py
│ │ │ │ ├── __init__.py
│ │ │ │ └── feature_io_base.py
│ │ │ ├── .gitignore
│ │ │ ├── extractors
│ │ │ │ ├── __init__.py
│ │ │ │ ├── misc.py
│ │ │ │ ├── librosa_extractor.py
│ │ │ │ ├── extractor_base.py
│ │ │ │ └── vamp_extractor.py
│ │ │ ├── requirements.txt
│ │ │ ├── settings.py
│ │ │ ├── __init__.py
│ │ │ ├── common.py
│ │ │ ├── data
│ │ │ │ ├── sparse_tag_template.svl
│ │ │ │ ├── midi_template.svl
│ │ │ │ ├── curve_template.svl
│ │ │ │ ├── tuning.n3
│ │ │ │ ├── spectrogram_template.svl
│ │ │ │ ├── pitch_template.svl
│ │ │ │ ├── bothchroma.n3
│ │ │ │ ├── chroma.n3
│ │ │ │ ├── tunedlogfreqspec.n3
│ │ │ │ └── chordino.n3
│ │ │ ├── music_base.py
│ │ │ └── cache.py
│ │ ├── example.sh
│ │ ├── .gitignore
│ │ ├── example.mid
│ │ ├── requirements.txt
│ │ ├── setup.py
│ │ ├── io_new
│ │ │ ├── list_io.py
│ │ │ ├── jams_io.py
│ │ │ ├── air_io.py
│ │ │ ├── madmom_io.py
│ │ │ ├── beatlab_io.py
│ │ │ ├── key_io.py
│ │ │ ├── lyric_io.py
│ │ │ ├── tag_io.py
│ │ │ ├── chordlab_io.py
│ │ │ ├── jointbeat_io.py
│ │ │ ├── downbeat_io.py
│ │ │ ├── salami_io.py
│ │ │ ├── complex_chord_io.py
│ │ │ ├── midilab_io.py
│ │ │ ├── osu_io.py
│ │ │ └── beat_align_io.py
│ │ ├── README.TXT
│ │ ├── extractors
│ │ │ └── rule_based_channel_reweight.py
│ │ ├── __init__.py
│ │ ├── main.py
│ │ └── example.out
│ ├── pyrightconfig.json
│ ├── stable_diffusion
│ │ ├── losses
│ │ │ ├── __init__.py
│ │ │ ├── discriminator.py
│ │ │ ├── lpips.py
│ │ │ └── contperceptual.py
│ │ ├── util.py
│ │ └── sampler
│ │   └── __init__.py
│ ├── data
│ │ ├── example.mid
│ │ ├── pop909_extractor.py
│ │ ├── dataloader_musicalion.py
│ │ ├── polydis_format_to_mine.py
│ │ └── dataloader.py
│ ├── setup.py
│ ├── dl_modules
│ │ ├── __init__.py
│ │ ├── naive_nn.py
│ │ ├── chord_enc.py
│ │ ├── txt_enc.py
│ │ ├── chord_dec.py
│ │ └── pianotree_enc.py
│ ├── params
│ │ ├── chd_8bar.yaml
│ │ ├── autoencoder.yaml
│ │ ├── ddpm.yaml
│ │ ├── sdf_pnotree.yaml
│ │ ├── sdf_txtvnl.yaml
│ │ ├── sdf.yaml
│ │ ├── sdf_concat.yaml
│ │ ├── sdf_txt.yaml
│ │ ├── sdf_chd8bar.yaml
│ │ ├── sdf_chdvnl.yaml
│ │ ├── sdf_chd8bar_txt.yaml
│ │ └── sdf_chd8bar_txt_mix2.yaml
│ ├── ddpm
│ │ ├── utils.py
│ │ └── __init__.py
│ ├── remove_pickle.py
│ ├── mir_eval
│ │ ├── __init__.py
│ │ ├── setup.py
│ │ ├── onset.py
│ │ └── tempo.py
│ ├── cleanup_checkpoints.py
│ ├── dirs.py
│ ├── polydis_aftertouch.py
│ ├── models
│ │ ├── model_autoencoder.py
│ │ ├── model_ddpm.py
│ │ └── model_chd_8bar.py
│ ├── train
│ │ ├── train_ddpm.py
│ │ ├── train_chd_8bar.py
│ │ ├── train_autoencoder.py
│ │ ├── scheduler.py
│ │ ├── __init__.py
│ │ └── train_ldm.py
│ ├── main.py
│ ├── lightning_learner.py
│ └── prepare_data.py
├── data
│ └── train_split_pnt
│   ├── pop909.pickle
│   ├── musicalion.pickle
│   ├── split_dict.pickle
│   └── musicalion.pickle_unclean
├── requirements.txt
├── LICENSE
├── .gitignore
└── README.md
/polyffusion/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/mir/io/implement/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/example.sh:
--------------------------------------------------------------------------------
1 | python main.py ./example.mid ./example.out
--------------------------------------------------------------------------------
/polyffusion/pyrightconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "reportInvalidStringEscapeSequence": false
3 | }
4 |
--------------------------------------------------------------------------------
/polyffusion/stable_diffusion/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from .contperceptual import LPIPSWithDiscriminator
2 |
--------------------------------------------------------------------------------
/polyffusion/data/example.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aik2mlj/polyffusion/HEAD/polyffusion/data/example.mid
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/.gitignore:
--------------------------------------------------------------------------------
1 | /cache_data
2 | /temp
3 | /.idea
4 | /__pycache__
5 | *.pyc
6 | /output
7 |
--------------------------------------------------------------------------------
/data/train_split_pnt/pop909.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aik2mlj/polyffusion/HEAD/data/train_split_pnt/pop909.pickle
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/mir/.gitignore:
--------------------------------------------------------------------------------
1 | /cache_data
2 | /temp
3 | /.idea
4 | /__pycache__
5 | *.pyc
6 | /output
7 | /lib
8 |
--------------------------------------------------------------------------------
/data/train_split_pnt/musicalion.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aik2mlj/polyffusion/HEAD/data/train_split_pnt/musicalion.pickle
--------------------------------------------------------------------------------
/data/train_split_pnt/split_dict.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aik2mlj/polyffusion/HEAD/data/train_split_pnt/split_dict.pickle
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/example.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aik2mlj/polyffusion/HEAD/polyffusion/chord_extractor/example.mid
--------------------------------------------------------------------------------
/data/train_split_pnt/musicalion.pickle_unclean:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aik2mlj/polyffusion/HEAD/data/train_split_pnt/musicalion.pickle_unclean
--------------------------------------------------------------------------------
/polyffusion/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
3 | setup(
4 | name="polyffusion",
5 | packages=find_packages(),
6 | )
7 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/mir/extractors/__init__.py:
--------------------------------------------------------------------------------
1 | from mir.extractors.extractor_base import ExtractorBase
2 |
3 | __all__ = ["ExtractorBase"]
4 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/requirements.txt:
--------------------------------------------------------------------------------
1 | pydub>=0.23.1
2 | pretty_midi>=0.2.9
3 | joblib>=0.13.2
4 | librosa>=0.7.2
5 | mir_eval>=0.5
6 | numpy>=1.16
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
3 | setup(
4 | name="chord_extractor",
5 | packages=find_packages(),
6 | )
7 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/mir/requirements.txt:
--------------------------------------------------------------------------------
1 | librosa>=0.6.1
2 | joblib>=0.12.2
3 | pydub>=0.22.1
4 | numpy>=1.15.4
5 | h5py>=2.8.0
6 | torch>=0.4.1
7 | pretty_midi>=0.2.8
8 |
9 |
--------------------------------------------------------------------------------
/polyffusion/dl_modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .chord_dec import ChordDecoder
2 | from .chord_enc import RnnEncoder as ChordEncoder
3 | from .naive_nn import NaiveNN
4 | from .pianotree_dec import PianoTreeDecoder
5 | from .pianotree_enc import PianoTreeEncoder
6 | from .txt_enc import TextureEncoder
7 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/mir/settings.py:
--------------------------------------------------------------------------------
1 | SONIC_VISUALIZER_PATH = "C:/Program Files (x86)/Sonic Visualiser/Sonic Visualiser.exe"
2 | SONIC_ANNOTATOR_PATH = (
3 | "C:/Program Files (x86)/Sonic Visualiser/annotator/sonic-annotator.exe"
4 | )
5 | DEFAULT_DATA_STORAGE_PATH = "E:/dataset/"
6 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/mir/__init__.py:
--------------------------------------------------------------------------------
1 | from mir.common import PACKAGE_PATH, WORKING_PATH
2 | from mir.data_file import DataEntry, DataPool, TextureBuilder
3 |
4 | __all__ = [
5 | "TextureBuilder",
6 | "DataEntry",
7 | "WORKING_PATH",
8 | "PACKAGE_PATH",
9 | "DataPool",
10 | "io",
11 | ]
12 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/mir/common.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from mir.settings import *
4 |
5 | WORKING_PATH = os.getcwd()
6 | PACKAGE_PATH = os.path.dirname(os.path.abspath(__file__))
7 |
8 | DEFAULT_DATA_STORAGE_PATH = DEFAULT_DATA_STORAGE_PATH.replace(
9 | "$project_name$", os.path.basename(os.getcwd())
10 | )
11 |
--------------------------------------------------------------------------------
/polyffusion/params/chd_8bar.yaml:
--------------------------------------------------------------------------------
1 | model_name: chd_8bar
2 | batch_size: 128
3 | max_epoch: 1000
4 | learning_rate: 0.001
5 | max_grad_norm: 10
6 | fp16: true
7 | tfr_chd:
8 | - 0.5
9 | - 0
10 | num_workers: 4
11 | pin_memory: true
12 | chd_n_step: 32
13 | chd_input_dim: 36
14 | chd_z_input_dim: 512
15 | chd_hidden_dim: 512
16 | chd_z_dim: 512
17 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/io_new/list_io.py:
--------------------------------------------------------------------------------
1 | from mir.io.feature_io_base import *
2 |
3 |
4 | class ListIO(FeatureIO):
5 | def read(self, filename, entry):
6 | return pickle_read(self, filename)
7 |
8 | def write(self, data, filename, entry):
9 | pickle_write(self, data, filename)
10 |
11 | def visualize(self, data, filename, entry, override_sr):
12 | return NotImplementedError()
13 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | python==3.10.13
2 | torch==2.1.2
3 | imageio==2.33.1
4 | jams==0.3.4
5 | joblib==1.3.2
6 | labml==0.4.168
7 | labml_helpers==0.4.89
8 | librosa==0.10.1
9 | lightning==2.1.2
10 | matplotlib==3.8.0
11 | muspy==0.5.0
12 | numpy==1.26.3
13 | omegaconf==2.3.0
14 | Pillow==10.2.0
15 | pretty_midi==0.2.10
16 | pydub==0.25.1
17 | Requests==2.31.0
18 | setuptools==68.2.2
19 | torchvision==0.16.2
20 | tqdm==4.65.0
21 | wandb==0.16.1
22 |
--------------------------------------------------------------------------------
/polyffusion/params/autoencoder.yaml:
--------------------------------------------------------------------------------
1 | model_name: autoencoder
2 | batch_size: 16
3 | max_epoch: 100
4 | learning_rate: 5.0e-05
5 | max_grad_norm: 10
6 | fp16: false
7 | num_workers: 4
8 | pin_memory: true
9 | in_channels: 3
10 | out_channels: 3
11 | z_channels: 4
12 | channels: 64
13 | n_res_blocks: 2
14 | channel_multipliers:
15 | - 1
16 | - 2
17 | - 4
18 | - 4
19 | emb_channels: 4
20 | disc_start: 50001
21 | kl_weight: 1.0e-06
22 | disc_weight: 0.5
23 |
--------------------------------------------------------------------------------
/polyffusion/ddpm/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | ---
3 | title: Utility functions for DDPM experiment
4 | summary: >
5 | Utility functions for DDPM experiment
6 | ---
7 | # Utility functions for [DDPM](index.html) experiemnt
8 | """
9 | import torch
10 | import torch.utils.data
11 |
12 |
13 | def gather(consts: torch.Tensor, t: torch.Tensor):
14 | """Gather consts for $t$ and reshape to feature map shape"""
15 | c = consts.gather(-1, t)
16 | return c.reshape(-1, 1, 1, 1)
17 |
--------------------------------------------------------------------------------
/polyffusion/params/ddpm.yaml:
--------------------------------------------------------------------------------
1 | model_name: ddpm
2 | batch_size: 16
3 | max_epoch: 100
4 | learning_rate: 2.0e-05
5 | max_grad_norm: 10
6 | fp16: false
7 | num_workers: 4
8 | pin_memory: true
9 | beta: 0.1
10 | weights:
11 | - 1
12 | - 0.5
13 | image_channels: 2
14 | image_size_h: 128
15 | image_size_w: 128
16 | n_channels: 64
17 | channel_multipliers:
18 | - 1
19 | - 2
20 | - 2
21 | - 4
22 | is_attention:
23 | - false
24 | - false
25 | - false
26 | - true
27 | n_steps: 1000
28 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/mir/io/implement/unknown_io.py:
--------------------------------------------------------------------------------
1 | from mir.io.feature_io_base import *
2 |
3 |
4 | class UnknownIO(FeatureIO):
5 | def read(self, filename, entry):
6 | raise Exception("Unknown type cannot be read")
7 |
8 | def write(self, data, filename, entry):
9 | raise Exception("Unknown type cannot be written")
10 |
11 | def visualize(self, data, filename, entry, override_sr):
12 | raise Exception("Unknown type cannot be visualized")
13 |
--------------------------------------------------------------------------------
/polyffusion/remove_pickle.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import sys
3 |
4 | from dirs import *
5 |
6 | if __name__ == "__main__":
7 | remove = sys.argv[1]
8 | with open(f"{TRAIN_SPLIT_DIR}/pop909.pickle", "rb") as f:
9 | pic = pickle.load(f)
10 | assert remove in pic[0] or remove in pic[1]
11 | if remove in pic[0]:
12 | pic[0].remove(remove)
13 | else:
14 | pic[1].remove(remove)
15 | with open(f"{TRAIN_SPLIT_DIR}/pop909.pickle", "wb") as f:
16 | pickle.dump(pic, f)
17 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/mir/io/implement/midi_io.py:
--------------------------------------------------------------------------------
1 | import pretty_midi
2 | from mir.io.feature_io_base import *
3 |
4 |
5 | class MidiIO(FeatureIO):
6 | def read(self, filename, entry):
7 | midi_data = pretty_midi.PrettyMIDI(filename)
8 | return midi_data
9 |
10 | def write(self, data, filename, entry):
11 | data.write(filename)
12 |
13 | def visualize(self, data, filename, entry, override_sr):
14 | data.write(filename)
15 |
16 | def get_visualize_extention_name(self):
17 | return "mid"
18 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/mir/data/sparse_tag_template.svl:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | [__DATA__]
8 |
9 |
10 |
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/polyffusion/params/sdf_pnotree.yaml:
--------------------------------------------------------------------------------
1 | model_name: sdf_pnotree
2 | batch_size: 16
3 | max_epoch: 50
4 | learning_rate: 5.0e-05
5 | max_grad_norm: 10
6 | fp16: true
7 | num_workers: 4
8 | pin_memory: true
9 | in_channels: 2
10 | out_channels: 2
11 | channels: 64
12 | attention_levels:
13 | - 2
14 | - 3
15 | n_res_blocks: 2
16 | channel_multipliers:
17 | - 1
18 | - 2
19 | - 4
20 | - 4
21 | n_heads: 4
22 | tf_layers: 1
23 | d_cond: 2048
24 | linear_start: 0.00085
25 | linear_end: 0.012
26 | n_steps: 1000
27 | latent_scaling_factor: 0.18215
28 | img_h: 128
29 | img_w: 128
30 | cond_type: pnotree
31 | cond_mode: mix
32 |
--------------------------------------------------------------------------------
/polyffusion/params/sdf_txtvnl.yaml:
--------------------------------------------------------------------------------
1 | model_name: sdf_txtvnl
2 | batch_size: 16
3 | max_epoch: 200
4 | learning_rate: 5.0e-05
5 | max_grad_norm: 10
6 | fp16: true
7 | num_workers: 4
8 | pin_memory: true
9 | in_channels: 2
10 | out_channels: 2
11 | channels: 64
12 | attention_levels:
13 | - 2
14 | - 3
15 | n_res_blocks: 2
16 | channel_multipliers:
17 | - 1
18 | - 2
19 | - 4
20 | - 4
21 | n_heads: 4
22 | tf_layers: 1
23 | d_cond: 128
24 | linear_start: 0.00085
25 | linear_end: 0.012
26 | n_steps: 1000
27 | latent_scaling_factor: 0.18215
28 | img_h: 128
29 | img_w: 128
30 | cond_type: txt
31 | cond_mode: mix
32 | use_enc: false
33 |
--------------------------------------------------------------------------------
/polyffusion/mir_eval/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """Top-level module for mir_eval"""
3 |
4 | # Import all submodules (for each task)
5 | from . import alignment
6 | from . import beat
7 | from . import chord
8 | from . import io
9 | from . import onset
10 | from . import segment
11 | from . import separation
12 | from . import util
13 | from . import sonify
14 | from . import melody
15 | from . import multipitch
16 | from . import pattern
17 | from . import tempo
18 | from . import hierarchy
19 | from . import transcription
20 | from . import transcription_velocity
21 | from . import key
22 |
23 | __version__ = '0.7'
24 |
--------------------------------------------------------------------------------
/polyffusion/params/sdf.yaml:
--------------------------------------------------------------------------------
1 | model_name: sdf
2 | batch_size: 16
3 | max_epoch: 100
4 | learning_rate: 5.0e-05
5 | max_grad_norm: 10
6 | fp16: true
7 | num_workers: 4
8 | pin_memory: true
9 | in_channels: 2
10 | out_channels: 2
11 | channels: 64
12 | attention_levels:
13 | - 2
14 | - 3
15 | n_res_blocks: 2
16 | channel_multipliers:
17 | - 1
18 | - 2
19 | - 4
20 | - 4
21 | n_heads: 4
22 | tf_layers: 1
23 | d_cond: 1152
24 | linear_start: 0.00085
25 | linear_end: 0.012
26 | n_steps: 1000
27 | latent_scaling_factor: 0.18215
28 | img_h: 128
29 | img_w: 128
30 | cond_type: chord
31 | cond_mode: uncond
32 | use_enc: false
33 | chd_n_step: 32
34 | chd_input_dim: 36
35 |
--------------------------------------------------------------------------------
/polyffusion/params/sdf_concat.yaml:
--------------------------------------------------------------------------------
1 | model_name: sdf_concat
2 | batch_size: 16
3 | max_epoch: 100
4 | learning_rate: 5.0e-05
5 | max_grad_norm: 10
6 | fp16: true
7 | num_workers: 4
8 | pin_memory: true
9 | in_channels: 3
10 | out_channels: 2
11 | channels: 64
12 | attention_levels:
13 | - 2
14 | - 3
15 | n_res_blocks: 2
16 | channel_multipliers:
17 | - 1
18 | - 2
19 | - 4
20 | - 4
21 | n_heads: 4
22 | tf_layers: 1
23 | d_cond: 1152
24 | linear_start: 0.00085
25 | linear_end: 0.012
26 | n_steps: 1000
27 | latent_scaling_factor: 0.18215
28 | img_h: 128
29 | img_w: 128
30 | cond_type: chord
31 | cond_mode: uncond
32 | use_enc: false
33 | concat_blurry: true
34 | concat_ratio: 0.25
35 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/mir/data/midi_template.svl:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | [__DATA__]
8 |
9 |
10 |
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/mir/data/curve_template.svl:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | [__DATA__]
8 |
9 |
10 |
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/polyffusion/dl_modules/naive_nn.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 |
4 | class NaiveNN(nn.Module):
5 | def __init__(
6 | self,
7 | input_dim=512,
8 | output_dim=512,
9 | ):
10 | """Only two linear layers"""
11 | super(NaiveNN, self).__init__()
12 | self.linear1 = nn.Linear(input_dim, output_dim)
13 | self.linear2 = nn.Linear(output_dim, output_dim)
14 | self.input_dim = input_dim
15 | self.output_dim = output_dim
16 |
17 | def forward(self, z_x):
18 | output = self.linear1(z_x)
19 | output = self.linear2(output)
20 | # print(output_mu.size(), output_var.size())
21 | return output
22 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/io_new/jams_io.py:
--------------------------------------------------------------------------------
1 | import jams
2 | from mir.io import FeatureIO
3 |
4 |
5 | class JamsIO(FeatureIO):
6 | def read(self, filename, entry):
7 | return jams.load(filename)
8 |
9 | def write(self, data: jams.JAMS, filename, entry):
10 | data.save(filename)
11 |
12 | def visualize(self, data: jams.JAMS, filename, entry, override_sr):
13 | f = open(filename, "w")
14 | for annotation in data.annotations:
15 | for obs in annotation.data:
16 | f.write(
17 | "%f\t%f\t%s\n" % (obs.time, obs.time + obs.duration, str(obs.value))
18 | )
19 | f.close()
20 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/mir/data/tuning.n3:
--------------------------------------------------------------------------------
1 | @prefix xsd: .
2 | @prefix vamp: .
3 | @prefix : <#> .
4 |
5 | :transform a vamp:Transform ;
6 | vamp:plugin ;
7 | vamp:step_size "[__WIN_SHIFT__]"^^xsd:int ;
8 | vamp:block_size "[__WIN_SIZE__]"^^xsd:int ;
9 | vamp:plugin_version """5""" ;
10 | vamp:parameter_binding [
11 | vamp:parameter [ vamp:identifier "rollon" ] ;
12 | vamp:value "0"^^xsd:float ;
13 | ] ;
14 | vamp:output .
15 |
--------------------------------------------------------------------------------
/polyffusion/cleanup_checkpoints.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | if __name__ == "__main__":
4 | result = "./result"
5 | for dir in os.listdir(result):
6 | dpath = f"{result}/{dir}"
7 | if dir == "try":
8 | os.system(f"rm -rf {dpath}/*")
9 | continue
10 | for item in os.listdir(dpath):
11 | item_dpath = f"{dpath}/{item}"
12 | chkpt_dir = f"{item_dpath}/chkpts"
13 | if not os.path.exists(f"{chkpt_dir}/weights.pt"):
14 | os.system(f"ls -l {chkpt_dir}")
15 | y = input(f"Remove {item_dpath} (y/n)?")
16 | if y == "y":
17 | os.system(f"rm -rf {item_dpath}")
18 |
--------------------------------------------------------------------------------
/polyffusion/params/sdf_txt.yaml:
--------------------------------------------------------------------------------
1 | model_name: sdf_txt
2 | batch_size: 16
3 | max_epoch: 200
4 | learning_rate: 5.0e-05
5 | max_grad_norm: 10
6 | fp16: true
7 | num_workers: 4
8 | pin_memory: true
9 | in_channels: 2
10 | out_channels: 2
11 | channels: 64
12 | attention_levels:
13 | - 2
14 | - 3
15 | n_res_blocks: 2
16 | channel_multipliers:
17 | - 1
18 | - 2
19 | - 4
20 | - 4
21 | n_heads: 4
22 | tf_layers: 1
23 | d_cond: 1024
24 | linear_start: 0.00085
25 | linear_end: 0.012
26 | n_steps: 1000
27 | latent_scaling_factor: 0.18215
28 | img_h: 128
29 | img_w: 128
30 | cond_type: txt
31 | cond_mode: mix
32 | use_enc: true
33 | txt_emb_size: 256
34 | txt_hidden_dim: 1024
35 | txt_z_dim: 256
36 | txt_num_channel: 10
37 |
--------------------------------------------------------------------------------
/polyffusion/params/sdf_chd8bar.yaml:
--------------------------------------------------------------------------------
1 | model_name: sdf_chd8bar
2 | batch_size: 16
3 | max_epoch: 100
4 | learning_rate: 5.0e-05
5 | max_grad_norm: 10
6 | fp16: true
7 | num_workers: 4
8 | pin_memory: true
9 | in_channels: 2
10 | out_channels: 2
11 | channels: 64
12 | attention_levels:
13 | - 2
14 | - 3
15 | n_res_blocks: 2
16 | channel_multipliers:
17 | - 1
18 | - 2
19 | - 4
20 | - 4
21 | n_heads: 4
22 | tf_layers: 1
23 | d_cond: 512
24 | linear_start: 0.00085
25 | linear_end: 0.012
26 | n_steps: 1000
27 | latent_scaling_factor: 0.18215
28 | img_h: 128
29 | img_w: 128
30 | cond_type: chord
31 | cond_mode: mix
32 | use_enc: true
33 | chd_n_step: 32
34 | chd_input_dim: 36
35 | chd_z_input_dim: 512
36 | chd_hidden_dim: 512
37 | chd_z_dim: 512
38 |
--------------------------------------------------------------------------------
/polyffusion/params/sdf_chdvnl.yaml:
--------------------------------------------------------------------------------
1 | model_name: sdf_chdvnl
2 | batch_size: 16
3 | max_epoch: 100
4 | learning_rate: 5.0e-05
5 | max_grad_norm: 10
6 | fp16: true
7 | num_workers: 4
8 | pin_memory: true
9 | in_channels: 2
10 | out_channels: 2
11 | channels: 64
12 | attention_levels:
13 | - 2
14 | - 3
15 | n_res_blocks: 2
16 | channel_multipliers:
17 | - 1
18 | - 2
19 | - 4
20 | - 4
21 | n_heads: 4
22 | tf_layers: 1
23 | d_cond: 1152
24 | linear_start: 0.00085
25 | linear_end: 0.012
26 | n_steps: 1000
27 | latent_scaling_factor: 0.18215
28 | img_h: 128
29 | img_w: 128
30 | cond_type: chord
31 | cond_mode: mix
32 | use_enc: false
33 | chd_n_step: 32
34 | chd_input_dim: 36
35 | chd_z_input_dim: 512
36 | chd_hidden_dim: 512
37 | chd_z_dim: 512
38 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/io_new/air_io.py:
--------------------------------------------------------------------------------
1 | from mir.io.feature_io_base import *
2 |
3 |
4 | class AirIO(FeatureIO):
5 | def read(self, filename, entry):
6 | return pickle_read(self, filename)
7 |
8 | def write(self, data, filename, entry):
9 | return pickle_write(self, data, filename)
10 |
11 | def visualize(self, data, filename, entry, override_sr):
12 | arr = data.export_to_array()
13 | from mir.io.implement.regional_spectrogram_io import RegionalSpectrogramIO
14 |
15 | return RegionalSpectrogramIO().visualize(
16 | arr, filename, entry, override_sr=override_sr
17 | )
18 |
19 | def get_visualize_extention_name(self):
20 | return "svl"
21 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/mir/io/__init__.py:
--------------------------------------------------------------------------------
1 | from .feature_io_base import FeatureIO, LoadingPlaceholder
2 | from .implement.chroma_io import ChromaIO
3 | from .implement.midi_io import MidiIO
4 | from .implement.music_io import MusicIO
5 | from .implement.regional_spectrogram_io import RegionalSpectrogramIO
6 | from .implement.scalar_io import FloatIO, IntegerIO
7 | from .implement.spectrogram_io import SpectrogramIO
8 | from .implement.unknown_io import UnknownIO
9 |
10 | __all__ = [
11 | "FeatureIO",
12 | "LoadingPlaceholder",
13 | "ChromaIO",
14 | "MidiIO",
15 | "MusicIO",
16 | "SpectrogramIO",
17 | "IntegerIO",
18 | "FloatIO",
19 | "RegionalSpectrogramIO",
20 | "UnknownIO",
21 | ]
22 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/mir/data/spectrogram_template.svl:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | [__DATA__]
8 |
9 |
10 |
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/mir/io/implement/music_io.py:
--------------------------------------------------------------------------------
1 | import librosa
2 | from mir.io.feature_io_base import *
3 |
4 |
5 | class MusicIO(FeatureIO):
6 | def read(self, filename, entry):
7 | y, sr = librosa.load(filename, sr=entry.prop.sr, mono=True)
8 | return y # (y-np.mean(y))/np.std(y)
9 |
10 | def write(self, data, filename, entry):
11 | sr = entry.prop.sr
12 | librosa.output.write_wav(filename, y=data, sr=sr, norm=False)
13 |
14 | def visualize(self, data, filename, entry, override_sr):
15 | sr = entry.prop.sr
16 | librosa.output.write_wav(
17 | filename, y=data, sr=sr, norm=True
18 | ) # otherwise I would be deaf
19 |
20 | def get_visualize_extention_name(self):
21 | return "wav"
22 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/README.TXT:
--------------------------------------------------------------------------------
1 | In theory this code can be run directly by following the README.TXT inside, but it targets pop-music chords (more precisely, the multi-track pop MIDI in the lmd dataset). It may not work well for classical music, musicalion, or single-track music (though you can try).
2 |
3 | The algorithm is template matching plus dynamic-programming smoothing. It first identifies which tracks are chord tracks (as soft labels, the weights at line 47 of main.py), then matches the chroma and bass chroma built from those chord tracks against the templates. Chord transitions take the metrical structure of pop music into account: they are relatively lenient on downbeats and stricter on upbeats and within beats. How to tune it:
4 |
5 | 1. First make sure the beat information in the MIDI (i.e., the meter and tempo markings) is accurate, meaning that midi.get_beats() and midi.get_downbeats() obtained with pretty_midi must be correct. If they are not, fix these markings in the MIDI by hand.
6 | 2. The chord categories need to be adjusted accordingly. chord_class.py currently contains only pop and jazz chord labels: QUALITIES stores the possible chord qualities, and INVERSIONS stores the chord qualities whose inversions should be considered, together with the inversion types for each quality. Adjust these to match the musical style.
7 |
8 | Junyan
9 |
10 |
11 | Installation
12 | >> pip3 install -r requirements.txt
13 |
14 | Usage
15 | >> python3 main.py ./example.mid ./example.out
16 | See main.py and example.sh for reference.
17 |
18 | Other
19 | See chord_class.py for the chord dictionary.
20 |
21 | Junyan
22 |
--------------------------------------------------------------------------------
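The README above describes the chord extractor as chroma template matching followed by dynamic-programming smoothing, with chord transitions penalized less on downbeats than on other beats. Below is a minimal, self-contained sketch of that general idea only; the function name, scoring, and penalty values are illustrative assumptions, not the code in main.py or chord_class.py.

import numpy as np


def dp_chord_smoothing(chroma, templates, is_downbeat,
                       downbeat_penalty=0.1, other_penalty=1.0):
    """chroma: (T, 12) beat-level chroma; templates: (K, 12) chord templates;
    is_downbeat: (T,) bool. Returns one chord index per frame."""
    score = chroma @ templates.T                  # frame-wise template match scores, (T, K)
    n_frames, n_chords = score.shape
    best = np.full((n_frames, n_chords), -np.inf)
    back = np.zeros((n_frames, n_chords), dtype=int)
    best[0] = score[0]
    for t in range(1, n_frames):
        # switching chords is cheap on downbeats and expensive elsewhere
        penalty = downbeat_penalty if is_downbeat[t] else other_penalty
        switch = best[t - 1][:, None] - penalty * (1 - np.eye(n_chords))
        back[t] = switch.argmax(axis=0)           # best previous chord for each current chord
        best[t] = score[t] + switch.max(axis=0)
    path = [int(best[-1].argmax())]               # backtrace the highest-scoring chord path
    for t in range(n_frames - 1, 0, -1):
        path.append(int(back[t][path[-1]]))
    return path[::-1]

--------------------------------------------------------------------------------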
/polyffusion/params/sdf_chd8bar_txt.yaml:
--------------------------------------------------------------------------------
1 | model_name: sdf_chd8bar_txt
2 | batch_size: 16
3 | max_epoch: 100
4 | learning_rate: 5.0e-05
5 | max_grad_norm: 10
6 | fp16: true
7 | num_workers: 4
8 | pin_memory: true
9 | in_channels: 2
10 | out_channels: 2
11 | channels: 64
12 | attention_levels:
13 | - 2
14 | - 3
15 | n_res_blocks: 2
16 | channel_multipliers:
17 | - 1
18 | - 2
19 | - 4
20 | - 4
21 | n_heads: 4
22 | tf_layers: 1
23 | d_cond: 1536
24 | linear_start: 0.00085
25 | linear_end: 0.012
26 | n_steps: 1000
27 | latent_scaling_factor: 0.18215
28 | img_h: 128
29 | img_w: 128
30 | cond_type: chord+txt
31 | cond_mode: mix
32 | use_enc: true
33 | chd_n_step: 32
34 | chd_input_dim: 36
35 | chd_z_input_dim: 512
36 | chd_hidden_dim: 512
37 | chd_z_dim: 512
38 | txt_emb_size: 256
39 | txt_hidden_dim: 1024
40 | txt_z_dim: 256
41 | txt_num_channel: 10
42 |
--------------------------------------------------------------------------------
/polyffusion/params/sdf_chd8bar_txt_mix2.yaml:
--------------------------------------------------------------------------------
1 | model_name: sdf_chd8bar_txt_mix2
2 | batch_size: 16
3 | max_epoch: 100
4 | learning_rate: 5.0e-05
5 | max_grad_norm: 10
6 | fp16: true
7 | num_workers: 0
8 | pin_memory: false
9 | in_channels: 2
10 | out_channels: 2
11 | channels: 64
12 | attention_levels:
13 | - 2
14 | - 3
15 | n_res_blocks: 2
16 | channel_multipliers:
17 | - 1
18 | - 2
19 | - 4
20 | - 4
21 | n_heads: 4
22 | tf_layers: 1
23 | d_cond: 1536
24 | linear_start: 0.00085
25 | linear_end: 0.012
26 | n_steps: 1000
27 | latent_scaling_factor: 0.18215
28 | img_h: 128
29 | img_w: 128
30 | cond_type: chord+txt
31 | cond_mode: mix2
32 | use_enc: true
33 | chd_n_step: 32
34 | chd_input_dim: 36
35 | chd_z_input_dim: 512
36 | chd_hidden_dim: 512
37 | chd_z_dim: 512
38 | txt_emb_size: 256
39 | txt_hidden_dim: 1024
40 | txt_z_dim: 256
41 | txt_num_channel: 10
42 |
--------------------------------------------------------------------------------
/polyffusion/dl_modules/chord_enc.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torch.distributions import Normal
3 |
4 |
5 | class RnnEncoder(nn.Module):
6 | def __init__(self, input_dim, hidden_dim, z_dim):
7 | super(RnnEncoder, self).__init__()
8 | self.gru = nn.GRU(input_dim, hidden_dim, batch_first=True, bidirectional=True)
9 | self.linear_mu = nn.Linear(hidden_dim * 2, z_dim)
10 | self.linear_var = nn.Linear(hidden_dim * 2, z_dim)
11 | self.input_dim = input_dim
12 | self.hidden_dim = hidden_dim
13 | self.z_dim = z_dim
14 |
15 | def forward(self, x):
16 | x = self.gru(x)[-1]
17 | x = x.transpose_(0, 1).contiguous()
18 | x = x.view(x.size(0), -1)
19 | mu = self.linear_mu(x)
20 | var = self.linear_var(x).exp_()
21 | dist = Normal(mu, var)
22 | return dist
23 |
--------------------------------------------------------------------------------
/polyffusion/dirs.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | # dataset paths
4 | DATA_DIR = "data/LOP_4_bin_pnt"
5 | TRAIN_SPLIT_DIR = "data/train_split_pnt"
6 |
7 | MUSICALION_DATA_DIR = "data/musicalion_solo_piano_4_bin_pnt"
8 | POP909_DATA_DIR = "data/POP909_4_bin_pnt_8bar"
9 |
10 | # pretrained path
11 | PT_PNOTREE_PATH = "pretrained/pnotree_20/train_20-last-model.pt"
12 |
13 | PT_POLYDIS_PATH = "pretrained/polydis/model_master_final.pt"
14 | PT_A2S_PATH = "pretrained/a2s/a2s-stage3a.pt"
15 |
16 | # pretrained chd_8bar
17 | PT_CHD_8BAR_PATH = "pretrained/chd8bar/weights.pt"
18 |
19 | # the path to store demo.
20 | DEMO_FOLDER = "./demo"
21 |
22 | # the path to save trained model params and tensorboard log.
23 | RESULT_PATH = "./result"
24 |
25 | if not os.path.exists(DEMO_FOLDER):
26 | os.mkdir(DEMO_FOLDER)
27 |
28 | if not os.path.exists(RESULT_PATH):
29 | os.mkdir(RESULT_PATH)
30 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/mir/music_base.py:
--------------------------------------------------------------------------------
1 | NUM_TO_ABS_SCALE = ["C", "C#", "D", "Eb", "E", "F", "F#", "G", "Ab", "A", "Bb", "B"]
2 |
3 |
4 | def get_scale_and_suffix(name):
5 | result = "C*D*EF*G*A*B".index(name[0])
6 | prefix_length = 1
7 | if len(name) > 1:
8 | if name[1] == "b":
9 | result = result - 1
10 | if result < 0:
11 | result += 12
12 | prefix_length = 2
13 | if name[1] == "#":
14 | result = result + 1
15 | if result >= 12:
16 | result -= 12
17 | prefix_length = 2
18 | return result, name[prefix_length:]
19 |
20 |
21 | def scale_name_to_value(name):
22 | result = "1*2*34*5*6*78*9".index(
23 | name[-1]
24 | ) # 8 and 9 are for weird tagging in some mirex chords
25 | return (result - name.count("b") + name.count("#") + 12) % 12
26 |
--------------------------------------------------------------------------------
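A short usage sketch of the helpers in music_base.py above. The import path follows the package's own "from mir...." convention, and the expected values are worked out from the code rather than taken from the repo:

from mir.music_base import get_scale_and_suffix, scale_name_to_value

print(get_scale_and_suffix("C#maj7"))  # (1, 'maj7'): 'C' is pitch class 0, '#' raises it to 1, the rest is the suffix
print(get_scale_and_suffix("Bb"))      # (10, ''): 'B' is pitch class 11, 'b' lowers it to 10
print(scale_name_to_value("b3"))       # 3: degree '3' is 4 semitones above the root, one flat lowers it to 3

--------------------------------------------------------------------------------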
/polyffusion/polydis_aftertouch.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from polydis.model import DisentangleVAE
4 | from utils import estx_to_midi_file
5 |
6 | model_path = "pretrained/polydis/model_master_final.pt"
7 | # readme_fn = "./train.py"
8 | batch_size = 128
9 | length = 16 # 16 bars for inference
10 | # n_epoch = 6
11 | # clip = 1
12 | # parallel = False
13 | # weights = [1, 0.5]
14 | # beta = 0.1
15 | # tf_rates = [(0.6, 0), (0.5, 0), (0.5, 0)]
16 | # lr = 1e-3
17 |
18 |
19 | class PolydisAftertouch:
20 | def __init__(self) -> None:
21 | model = DisentangleVAE.init_model()
22 | model.load_model(model_path)
23 | print(f"loaded model {model_path}.")
24 | self.model = model
25 |
26 | def reconstruct(self, prmat, chd, fn, chd_sample=False):
27 | chd = chd.float()
28 | prmat = prmat.float()
29 | est_x = self.model.inference(prmat, chd, sample=False, chd_sample=chd_sample)
30 | estx_to_midi_file(est_x, fn)
31 |
32 |
33 | if __name__ == "__main__":
34 | aftertouch = PolydisAftertouch()
35 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/mir/io/implement/scalar_io.py:
--------------------------------------------------------------------------------
1 | from mir.io.feature_io_base import *
2 |
3 |
4 | class FloatIO(FeatureIO):
5 | def read(self, filename, entry):
6 | f = open(filename, "r")
7 | result = float(f.readline().strip())
8 | f.close()
9 | return result
10 |
11 | def write(self, data, filename, entry):
12 | f = open(filename, "w")
13 | f.write(str(float(data)))
14 | f.close()
15 |
16 | def visualize(self, data, filename, entry, override_sr):
17 | raise Exception("Cannot visualize a scalar")
18 |
19 |
20 | class IntegerIO(FeatureIO):
21 | def read(self, filename, entry):
22 | f = open(filename, "r")
23 | result = int(f.readline().strip())
24 | f.close()
25 | return result
26 |
27 | def write(self, data, filename, entry):
28 | f = open(filename, "w")
29 | f.write(str(int(data)))
30 | f.close()
31 |
32 | def visualize(self, data, filename, entry, override_sr):
33 | raise Exception("Cannot visualize a scalar")
34 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Lejun Min
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/polyffusion/models/model_autoencoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from stable_diffusion.model.autoencoder import Autoencoder
6 | from utils import *
7 |
8 |
9 | class Polyffusion_Autoencoder(nn.Module):
10 | def __init__(self, autoencoder: Autoencoder):
11 | super(Polyffusion_Autoencoder, self).__init__()
12 | self.autoencoder = autoencoder
13 |
14 | @classmethod
15 | def load_trained(cls, ldm, model_dir):
16 | model = cls(ldm)
17 | trained_leaner = torch.load(f"{model_dir}/weights.pt")
18 | model.load_state_dict(trained_leaner["model"])
19 | return model
20 |
21 | def get_loss_dict(self, batch, step):
22 | """
23 | z_y is the stuff the diffusion model needs to learn
24 | """
25 | prmat, _, chord = batch
26 | # (#B, 2, 128, 128)
27 | print(f"prmat: {prmat.requires_grad}")
28 | prmat = F.pad(prmat, [0, 0, 0, 0, 0, 1], "constant", 0)
29 | print(f"prmat: {prmat.requires_grad}")
30 | # (#B, 3, 128, 128)
31 | return self.autoencoder.get_loss_dict(prmat, step)
32 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/mir/data/pitch_template.svl:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 | [__DATA_FREQ__]
9 |
10 |
11 | [__DATA_ENERGY__]
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
--------------------------------------------------------------------------------
/polyffusion/models/model_ddpm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from ddpm import DenoiseDiffusion
5 | from utils import *
6 |
7 |
8 | class Polyffusion_DDPM(nn.Module):
9 | def __init__(
10 | self,
11 | ddpm: DenoiseDiffusion,
12 | params,
13 | max_simu_note=20,
14 | ):
15 | super(Polyffusion_DDPM, self).__init__()
16 | self.params = params
17 | self.ddpm = ddpm
18 |
19 | @classmethod
20 | def load_trained(cls, ddpm, chkpt_fpath, params, max_simu_note=20):
21 | model = cls(ddpm, params, max_simu_note)
22 | trained_leaner = torch.load(chkpt_fpath)
23 | model.load_state_dict(trained_leaner["model"])
24 | return model
25 |
26 | def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
27 | return self.ddpm.p_sample(xt, t)
28 |
29 | def q_sample(self, x0: torch.Tensor, t: torch.Tensor):
30 | return self.ddpm.q_sample(x0, t)
31 |
32 | def get_loss_dict(self, batch, step):
33 | """
34 | z_y is the stuff the diffusion model needs to learn
35 | """
36 | prmat2c, pnotree, chord, prmat = batch
37 | return {"loss": self.ddpm.loss(prmat2c)}
38 |
--------------------------------------------------------------------------------
/polyffusion/mir_eval/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
3 | setup(
4 | name="mir_eval",
5 | version="0.7",
6 | description="Common metrics for common audio/music processing tasks.",
7 | author="Colin Raffel",
8 | author_email="craffel@gmail.com",
9 | url="https://github.com/craffel/mir_eval",
10 | packages=find_packages(),
11 | classifiers=[
12 | "License :: OSI Approved :: MIT License",
13 | "Programming Language :: Python",
14 | "Development Status :: 5 - Production/Stable",
15 | "Intended Audience :: Developers",
16 | "Topic :: Multimedia :: Sound/Audio :: Analysis",
17 | "Programming Language :: Python :: 3",
18 | ],
19 | keywords="audio music mir dsp",
20 | license="MIT",
21 | install_requires=[
22 | "numpy >= 1.7.0",
23 | "scipy >= 1.0.0",
24 | ],
25 | extras_require={
26 | "display": ["matplotlib>=1.5.0"],
27 | "testing": [
28 | "matplotlib>=2.1.0",
29 | "decorator",
30 | "pytest",
31 | "pytest-cov",
32 | "pytest-mpl",
33 | "nose",
34 | ],
35 | },
36 | python_requires=">=3",
37 | )
38 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/io_new/madmom_io.py:
--------------------------------------------------------------------------------
1 | from mir.common import PACKAGE_PATH
2 | from mir.io.feature_io_base import *
3 |
4 |
5 | class MadmomBeatProbIO(FeatureIO):
6 | def read(self, filename, entry):
7 | return pickle_read(self, filename)
8 |
9 | def write(self, data, filename, entry):
10 | pickle_write(self, data, filename)
11 |
12 | def visualize(self, data, filename, entry, override_sr):
13 | f = open(os.path.join(PACKAGE_PATH, "data/spectrogram_template.svl"), "r")
14 | content = f.read()
15 | f.close()
16 | content = content.replace("[__SR__]", str(100))
17 | content = content.replace("[__WIN_SHIFT__]", str(1))
18 | content = content.replace("[__SHAPE_1__]", str(data.shape[1]))
19 | content = content.replace("[__COLOR__]", str(1))
20 | labels = [str(i) for i in range(data.shape[1])]
21 | content = content.replace("[__DATA__]", create_svl_3d_data(labels, data))
22 | f = open(filename, "w")
23 | f.write(content)
24 | f.close()
25 |
26 | def pre_assign(self, entry, proxy):
27 | entry.prop.set("n_frame", LoadingPlaceholder(proxy, entry))
28 |
29 | def post_load(self, data, entry):
30 | entry.prop.set("n_frame", data.shape[0])
31 |
32 | def get_visualize_extention_name(self):
33 | return "svl"
34 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/mir/data/bothchroma.n3:
--------------------------------------------------------------------------------
1 | @prefix xsd: .
2 | @prefix vamp: .
3 | @prefix : <#> .
4 |
5 | :transform a vamp:Transform ;
6 | vamp:plugin ;
7 | vamp:step_size "[__WIN_SHIFT__]"^^xsd:int ;
8 | vamp:block_size "[__WIN_SIZE__]"^^xsd:int ;
9 | vamp:plugin_version """5""" ;
10 | vamp:parameter_binding [
11 | vamp:parameter [ vamp:identifier "chromanormalize" ] ;
12 | vamp:value "0"^^xsd:float ;
13 | ] ;
14 | vamp:parameter_binding [
15 | vamp:parameter [ vamp:identifier "rollon" ] ;
16 | vamp:value "0"^^xsd:float ;
17 | ] ;
18 | vamp:parameter_binding [
19 | vamp:parameter [ vamp:identifier "s" ] ;
20 | vamp:value "0.7"^^xsd:float ;
21 | ] ;
22 | vamp:parameter_binding [
23 | vamp:parameter [ vamp:identifier "tuningmode" ] ;
24 | vamp:value "0"^^xsd:float ;
25 | ] ;
26 | vamp:parameter_binding [
27 | vamp:parameter [ vamp:identifier "useNNLS" ] ;
28 | vamp:value "1"^^xsd:float ;
29 | ] ;
30 | vamp:parameter_binding [
31 | vamp:parameter [ vamp:identifier "whitening" ] ;
32 | vamp:value "1"^^xsd:float ;
33 | ] ;
34 | vamp:output [ vamp:identifier "bothchroma" ] .
35 |
--------------------------------------------------------------------------------
/polyffusion/dl_modules/txt_enc.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torch.distributions import Normal
3 |
4 |
5 | class TextureEncoder(nn.Module):
6 | def __init__(self, emb_size, hidden_dim, z_dim, num_channel=10):
7 | """input must be piano_mat: (B, 32, 128)"""
8 | super(TextureEncoder, self).__init__()
9 | self.cnn = nn.Sequential(
10 | nn.Conv2d(1, num_channel, kernel_size=(4, 12), stride=(4, 1), padding=0),
11 | nn.ReLU(),
12 | nn.MaxPool2d(kernel_size=(1, 4), stride=(1, 4)),
13 | )
14 | self.fc1 = nn.Linear(num_channel * 29, 1000)
15 | self.fc2 = nn.Linear(1000, emb_size)
16 | self.gru = nn.GRU(emb_size, hidden_dim, batch_first=True, bidirectional=True)
17 | self.linear_mu = nn.Linear(hidden_dim * 2, z_dim)
18 | self.linear_var = nn.Linear(hidden_dim * 2, z_dim)
19 | self.emb_size = emb_size
20 | self.hidden_dim = hidden_dim
21 | self.z_dim = z_dim
22 |
23 | def forward(self, pr):
24 | # pr: (bs, 32, 128)
25 | bs = pr.size(0)
26 | pr = pr.unsqueeze(1)
27 | pr = self.cnn(pr).view(bs, 8, -1)
28 | pr = self.fc2(self.fc1(pr)) # (bs, 8, emb_size)
29 | pr = self.gru(pr)[-1]
30 | pr = pr.transpose_(0, 1).contiguous()
31 | pr = pr.view(pr.size(0), -1)
32 | mu = self.linear_mu(pr)
33 | var = self.linear_var(pr).exp_()
34 | dist = Normal(mu, var)
35 | return dist
36 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/mir/data/chroma.n3:
--------------------------------------------------------------------------------
1 | @prefix xsd: .
2 | @prefix vamp: .
3 | @prefix : <#> .
4 |
5 | :transform_plugin a vamp:Plugin ;
6 | vamp:identifier "nnls-chroma" .
7 |
8 | :transform_library a vamp:PluginLibrary ;
9 | vamp:identifier "nnls-chroma" ;
10 | vamp:available_plugin :transform_plugin .
11 |
12 | :transform a vamp:Transform ;
13 | vamp:plugin :transform_plugin ;
14 | vamp:step_size "[__WIN_SHIFT__]"^^xsd:int ;
15 | vamp:block_size "[__WIN_SIZE__]"^^xsd:int ;
16 | vamp:plugin_version """3""" ;
17 | vamp:sample_rate "[__SR__]"^^xsd:int ;
18 | vamp:parameter_binding [
19 | vamp:parameter [ vamp:identifier "chromanormalize" ] ;
20 | vamp:value "0"^^xsd:float ;
21 | ] ;
22 | vamp:parameter_binding [
23 | vamp:parameter [ vamp:identifier "rollon" ] ;
24 | vamp:value "0"^^xsd:float ;
25 | ] ;
26 | vamp:parameter_binding [
27 | vamp:parameter [ vamp:identifier "s" ] ;
28 | vamp:value "0.7"^^xsd:float ;
29 | ] ;
30 | vamp:parameter_binding [
31 | vamp:parameter [ vamp:identifier "tuningmode" ] ;
32 | vamp:value "0"^^xsd:float ;
33 | ] ;
34 | vamp:parameter_binding [
35 | vamp:parameter [ vamp:identifier "useNNLS" ] ;
36 | vamp:value "1"^^xsd:float ;
37 | ] ;
38 | vamp:parameter_binding [
39 | vamp:parameter [ vamp:identifier "whitening" ] ;
40 | vamp:value "1"^^xsd:float ;
41 | ] ;
42 | vamp:output [ vamp:identifier "chroma" ] .
43 |
--------------------------------------------------------------------------------
/polyffusion/data/pop909_extractor.py:
--------------------------------------------------------------------------------
1 | from data.dataloader import get_train_val_dataloaders
2 | from utils import prmat_to_midi_file
3 |
4 | if __name__ == "__main__":
5 | # val_dataset = PianoOrchDataset.load_valid_set(use_track=[1, 2])
6 | # val_dl = get_val_dataloader(1000, use_track=[0, 1, 2])
7 | train_dl, val_dl = get_train_val_dataloaders()
8 | print(len(val_dl))
9 | for i, batch in enumerate(val_dl):
10 | prmat2c, pnotree, chord, prmat = batch
11 | prmat_to_midi_file(prmat, "exp/ref_wm.mid")
12 | break
13 |
14 | # dir = "data/POP909_MIDIs"
15 | # os.makedirs(dir, exist_ok=True)
16 |
17 | # for i in range(1, 910):
18 | # fpath = os.path.join(POP909_DATA_DIR, f"{i:03}.npz")
19 | # print(fpath)
20 | # if not os.path.exists(fpath):
21 | # continue
22 | # data = np.load(fpath, allow_pickle=True)
23 | # notes = data["notes"]
24 | # midi = pm.PrettyMIDI()
25 | # piano_program = pm.instrument_name_to_program("Acoustic Grand Piano")
26 | # piano = pm.Instrument(program=piano_program)
27 | # one_beat = 0.125
28 | # for track in notes:
29 | # for note in track:
30 | # onset, pitch, duration, velocity, program = note
31 | # note = pm.Note(
32 | # velocity=velocity,
33 | # pitch=pitch,
34 | # start=onset * one_beat,
35 | # end=(onset + duration) * one_beat
36 | # )
37 | # piano.notes.append(note)
38 |
39 | # midi.instruments.append(piano)
40 | # midi.write(f"{dir}/{i:03}.mid")
41 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/mir/data/tunedlogfreqspec.n3:
--------------------------------------------------------------------------------
1 | @prefix xsd: .
2 | @prefix vamp: .
3 | @prefix : <#> .
4 |
5 | :transform_plugin a vamp:Plugin ;
6 | vamp:identifier "nnls-chroma" .
7 |
8 | :transform_library a vamp:PluginLibrary ;
9 | vamp:identifier "nnls-chroma" ;
10 | vamp:available_plugin :transform_plugin .
11 |
12 | :transform a vamp:Transform ;
13 | vamp:plugin ;
14 | vamp:step_size "[__WIN_SHIFT__]"^^xsd:int ;
15 | vamp:block_size "[__WIN_SIZE__]"^^xsd:int ;
16 | vamp:plugin_version """5""" ;
17 | vamp:sample_rate "[__SR__]"^^xsd:int ;
18 | vamp:parameter_binding [
19 | vamp:parameter [ vamp:identifier "chromanormalize" ] ;
20 | vamp:value "0"^^xsd:float ;
21 | ] ;
22 | vamp:parameter_binding [
23 | vamp:parameter [ vamp:identifier "rollon" ] ;
24 | vamp:value "0"^^xsd:float ;
25 | ] ;
26 | vamp:parameter_binding [
27 | vamp:parameter [ vamp:identifier "s" ] ;
28 | vamp:value "0.7"^^xsd:float ;
29 | ] ;
30 | vamp:parameter_binding [
31 | vamp:parameter [ vamp:identifier "tuningmode" ] ;
32 | vamp:value "0"^^xsd:float ;
33 | ] ;
34 | vamp:parameter_binding [
35 | vamp:parameter [ vamp:identifier "useNNLS" ] ;
36 | vamp:value "1"^^xsd:float ;
37 | ] ;
38 | vamp:parameter_binding [
39 | vamp:parameter [ vamp:identifier "whitening" ] ;
40 | vamp:value "1"^^xsd:float ;
41 | ] ;
42 | vamp:output [ vamp:identifier "tunedlogfreqspec" ] .
43 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/extractors/rule_based_channel_reweight.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from extractors.midi_utilities import is_percussive_channel
3 |
4 |
5 | def get_channel_thickness(piano_roll):
6 | chroma_matrix = np.zeros((piano_roll.shape[0], 12))
7 | for note in range(12):
8 | chroma_matrix[:, note] = np.sum(piano_roll[:, note::12], axis=1)
9 | thickness_array = (chroma_matrix > 0).sum(axis=1)
10 | if thickness_array.sum() == 0:
11 | return 0
12 | return thickness_array[thickness_array > 0].mean()
13 |
14 |
15 | def get_channel_bass_property(piano_roll):
16 | result = np.argwhere(piano_roll > 0)[:, 1]
17 | if len(result) == 0:
18 | return 0.0, 1.0
19 | return result.mean(), min(1.0, len(result) / len(piano_roll))
20 |
21 |
22 | def midi_to_thickness_weights(midi):
23 | thickness = np.array(
24 | [
25 | get_channel_thickness(ins.get_piano_roll().T)
26 | for ins in midi.instruments
27 | if not is_percussive_channel(ins)
28 | ]
29 | )
30 | result = 1 - np.exp(-(thickness - 0.95))
31 | result /= result.max()
32 | return result
33 |
34 |
35 | def midi_to_thickness_and_bass_weights(midi):
36 | rolls = [
37 | ins.get_piano_roll().T
38 | for ins in midi.instruments
39 | if not is_percussive_channel(ins)
40 | ]
41 | thickness = np.array([get_channel_thickness(roll) for roll in rolls])
42 | bass = np.array([get_channel_bass_property(roll) for roll in rolls])
43 | bass[bass[:, 1] < 0.2, 0] = 128
44 | result = 1 - np.exp(-(thickness - 0.95))
45 | result /= result.max()
46 | result[np.argmin(bass[:, 0])] = 1.0
47 |
48 | return result
49 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/io_new/beatlab_io.py:
--------------------------------------------------------------------------------
1 | from mir.common import PACKAGE_PATH
2 | from mir.io.feature_io_base import *
3 |
4 |
5 | class BeatLabIO(FeatureIO):
6 | def read(self, filename, entry):
7 | f = open(filename, "r")
8 | content = f.read()
9 | lines = content.split("\n")
10 | f.close()
11 | result = []
12 | for i in range(len(lines)):
13 | line = lines[i].strip()
14 | if line == "":
15 | continue
16 | tokens = line.split("\t")
17 | result.append([float(tokens[0]), float(tokens[2])])
18 | return result
19 |
20 | def write(self, data, filename, entry):
21 | f = open(filename, "w")
22 | for i in range(0, data.shape[0]):
23 | f.write("\t".join([str(item) for item in data[i, :]]))
24 | f.write("\n")
25 | f.close()
26 |
27 | def visualize(self, data, filename, entry, override_sr):
28 | f = open(os.path.join(PACKAGE_PATH, "data/sparse_tag_template.svl"), "r")
29 | sr = override_sr
30 | content = f.read()
31 | f.close()
32 | content = content.replace("[__SR__]", str(sr))
33 | content = content.replace("[__STYLE__]", str(1))
34 | output_text = ""
35 | for beat_info in data:
36 |             output_text += '<point frame="%d" label="%d" />\n' % (
37 | int(beat_info[0] * sr),
38 | int(beat_info[1]),
39 | )
40 | content = content.replace("[__DATA__]", output_text)
41 | f = open(filename, "w")
42 | f.write(content)
43 | f.close()
44 |
45 | def get_visualize_extention_name(self):
46 | return "svl"
47 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/io_new/key_io.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | from mir.common import PACKAGE_PATH
5 | from mir.io import FeatureIO
6 |
7 |
8 | class KeyIO(FeatureIO):
9 |     def read(self, filename, entry):
10 |         f = open(filename, "r")
11 |         content = f.read()
12 |         lines = content.split("\n")
13 |         f.close()
14 |         result = []
15 |         for i in range(len(lines)):
16 |             line = lines[i].strip()
17 |             if line == "":
18 |                 continue
19 |             tokens = line.split("\t")
20 |             assert len(tokens) == 3
21 |             result.append([float(tokens[0]), float(tokens[1]), tokens[2]])
22 |         return result
23 |
24 |     def write(self, data, filename, entry):
25 |         f = open(filename, "w")
26 |         for i in range(0, len(data)):
27 |             f.write("\t".join([str(item).replace("\t", " ") for item in data[i]]))
28 |             f.write("\n")
29 |         f.close()
30 |
31 |     def visualize(self, data, filename, entry, override_sr):
32 |         sr = override_sr
33 |         f = open(os.path.join(PACKAGE_PATH, "data/sparse_tag_template.svl"), "r")
34 |         content = f.read()
35 |         f.close()
36 |         content = content.replace("[__SR__]", str(sr))
37 |         content = content.replace("[__STYLE__]", str(1))
38 |         results = []
39 |         for item in data:
40 |             results.append(
41 |                 '<point frame="%d" label="%s" />'
42 |                 % (int(np.round(item[0] * sr)), item[2])
43 |             )
44 |         content = content.replace("[__DATA__]", "\n".join(results))
45 |         f = open(filename, "w")
46 |         f.write(content)
47 |         f.close()
48 |
49 |     def get_visualize_extention_name(self):
50 |         return "svl"
51 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/io_new/lyric_io.py:
--------------------------------------------------------------------------------
1 | import codecs
2 |
3 | from mir.io.feature_io_base import *
4 |
5 |
6 | class LyricIO(FeatureIO):
7 | def read(self, filename, entry):
8 | f = open(filename, "r", encoding="utf-16-le")
9 | content = f.read()
10 | if content.startswith("\ufeff"):
11 | content = content[1:]
12 | lines = content.split("\n")
13 | f.close()
14 | result = []
15 | for i in range(len(lines)):
16 | line = lines[i].strip()
17 | if line == "":
18 | continue
19 | tokens = line.split("\t")
20 | if len(tokens) == 3:
21 | result.append([float(tokens[0]), float(tokens[1]), tokens[2]])
22 | elif len(tokens) == 4: # Contains sentence information
23 | result.append(
24 | [float(tokens[0]), float(tokens[1]), tokens[2], int(tokens[3])]
25 | )
26 | else:
27 | raise Exception("Not supported format")
28 | return result
29 |
30 | def write(self, data, filename, entry):
31 | f = open(filename, "wb")
32 | f.write(codecs.BOM_UTF16_LE)
33 | for i in range(0, len(data)):
34 | f.write("\t".join([str(item) for item in data[i]]).encode("utf-16-le"))
35 | f.write("\n".encode("utf-16-le"))
36 | f.close()
37 |
38 | def visualize(self, data, filename, entry, override_sr):
39 | f = open(filename, "w")
40 | for i in range(0, len(data)):
41 | f.write("\t".join([str(item) for item in data[i]]))
42 | f.write("\n")
43 | f.close()
44 |
45 | def get_visualize_extention_name(self):
46 | return "txt"
47 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/io_new/tag_io.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from mir.common import PACKAGE_PATH
3 | from mir.io.feature_io_base import *
4 |
5 |
6 | class TimedTagIO(FeatureIO):
7 | def read(self, filename, entry):
8 | f = open(filename, "r")
9 | content = f.read()
10 | lines = content.split("\n")
11 | f.close()
12 | result = []
13 | for i in range(len(lines)):
14 | line = lines[i].strip()
15 | if line == "":
16 | continue
17 | tokens = line.split("\t")
18 | assert len(tokens) == 2
19 | result.append([float(tokens[0]), tokens[1]])
20 | return result
21 |
22 | def write(self, data, filename, entry):
23 | f = open(filename, "w")
24 | for i in range(0, len(data)):
25 | f.write("\t".join([str(item) for item in data[i]]))
26 | f.write("\n")
27 | f.close()
28 |
29 | def visualize(self, data, filename, entry, override_sr):
30 | sr = override_sr
31 | f = open(os.path.join(PACKAGE_PATH, "data/sparse_tag_template.svl"), "r")
32 | content = f.read()
33 | f.close()
34 | content = content.replace("[__SR__]", str(sr))
35 | content = content.replace("[__STYLE__]", str(1))
36 | results = []
37 | for item in data:
38 | results.append(
39 |                 '<point frame="%d" label="%s"/>'
40 | % (int(np.round(item[0] * sr)), item[1])
41 | )
42 | content = content.replace("[__DATA__]", "\n".join(results))
43 | f = open(filename, "w")
44 | f.write(content)
45 | f.close()
46 |
47 | def get_visualize_extention_name(self):
48 | return "svl"
49 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/io_new/chordlab_io.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from mir.common import PACKAGE_PATH
3 | from mir.io.feature_io_base import *
4 |
5 |
6 | class ChordLabIO(FeatureIO):
7 | def read(self, filename, entry):
8 | f = open(filename, "r")
9 | content = f.read()
10 | lines = content.split("\n")
11 | f.close()
12 | result = []
13 | for i in range(len(lines)):
14 | line = lines[i].strip()
15 | if line == "":
16 | continue
17 | tokens = line.split("\t")
18 | assert len(tokens) == 3
19 | result.append([float(tokens[0]), float(tokens[1]), tokens[2]])
20 | return result
21 |
22 | def write(self, data, filename, entry):
23 | f = open(filename, "w")
24 | for i in range(0, len(data)):
25 | f.write("\t".join([str(item) for item in data[i]]))
26 | f.write("\n")
27 | f.close()
28 |
29 | def visualize(self, data, filename, entry, override_sr):
30 | sr = override_sr
31 | f = open(os.path.join(PACKAGE_PATH, "data/sparse_tag_template.svl"), "r")
32 | content = f.read()
33 | f.close()
34 | content = content.replace("[__SR__]", str(sr))
35 | content = content.replace("[__STYLE__]", str(1))
36 | results = []
37 | for item in data:
38 | results.append(
39 |                 '<point frame="%d" label="%s"/>'
40 | % (int(np.round(item[0] * sr)), item[2])
41 | )
42 | content = content.replace("[__DATA__]", "\n".join(results))
43 | f = open(filename, "w")
44 | f.write(content)
45 | f.close()
46 |
47 | def get_visualize_extention_name(self):
48 | return "svl"
49 |
--------------------------------------------------------------------------------
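
For reference, ChordLabIO.read() above expects the tab-separated chord-lab convention: "start_time<TAB>end_time<TAB>chord_label", one segment per line. A minimal, self-contained sketch with hypothetical example data (not taken from the repository):

# Hypothetical chord-lab content, parsed the same way as ChordLabIO.read().
lab_text = "0.000\t2.000\tC:maj\n2.000\t4.000\tG:maj\n"

rows = []
for line in lab_text.split("\n"):
    line = line.strip()
    if line == "":
        continue
    start, end, label = line.split("\t")
    rows.append([float(start), float(end), label])

print(rows)  # [[0.0, 2.0, 'C:maj'], [2.0, 4.0, 'G:maj']]
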
/polyffusion/chord_extractor/mir/data/chordino.n3:
--------------------------------------------------------------------------------
1 | @prefix xsd: <http://www.w3.org/2001/XMLSchema#> .
2 | @prefix vamp: <http://purl.org/ontology/vamp/> .
3 | @prefix : <#> .
4 |
5 | :transform_plugin a vamp:Plugin ;
6 | vamp:identifier "chordino" .
7 |
8 | :transform_library a vamp:PluginLibrary ;
9 | vamp:identifier "nnls-chroma" ;
10 | vamp:available_plugin :transform_plugin .
11 |
12 | :transform a vamp:Transform ;
13 |     vamp:plugin <http://vamp-plugins.org/rdf/plugins/nnls-chroma#chordino> ;
14 | vamp:step_size "[__WIN_SHIFT__]"^^xsd:int ;
15 | vamp:block_size "[__WIN_SIZE__]"^^xsd:int ;
16 | vamp:plugin_version """5""" ;
17 | vamp:sample_rate "[__SR__]"^^xsd:int ;
18 | vamp:parameter_binding [
19 | vamp:parameter [ vamp:identifier "boostn" ] ;
20 | vamp:value "0.1"^^xsd:float ;
21 | ] ;
22 | vamp:parameter_binding [
23 | vamp:parameter [ vamp:identifier "rollon" ] ;
24 | vamp:value "0"^^xsd:float ;
25 | ] ;
26 | vamp:parameter_binding [
27 | vamp:parameter [ vamp:identifier "s" ] ;
28 | vamp:value "0.7"^^xsd:float ;
29 | ] ;
30 | vamp:parameter_binding [
31 | vamp:parameter [ vamp:identifier "tuningmode" ] ;
32 | vamp:value "0"^^xsd:float ;
33 | ] ;
34 | vamp:parameter_binding [
35 | vamp:parameter [ vamp:identifier "useHMM" ] ;
36 | vamp:value "1"^^xsd:float ;
37 | ] ;
38 | vamp:parameter_binding [
39 | vamp:parameter [ vamp:identifier "useNNLS" ] ;
40 | vamp:value "1"^^xsd:float ;
41 | ] ;
42 | vamp:parameter_binding [
43 | vamp:parameter [ vamp:identifier "whitening" ] ;
44 | vamp:value "1"^^xsd:float ;
45 | ] ;
46 | vamp:output [ vamp:identifier "simplechord" ] .
47 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/io_new/jointbeat_io.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from mir.common import PACKAGE_PATH
3 | from mir.io.feature_io_base import *
4 |
5 |
6 | class JointBeatIO(FeatureIO):
7 | def read(self, filename, entry):
8 | f = open(filename, "r")
9 | content = f.read()
10 | lines = content.split("\n")
11 | f.close()
12 | result = []
13 | for i in range(len(lines)):
14 | line = lines[i].strip()
15 | if line == "":
16 | continue
17 | tokens = line.split("\t")
18 | assert len(tokens) == 3
19 | result.append([float(tokens[0]), int(tokens[1]), int(tokens[2])])
20 | return np.array(result)
21 |
22 | def write(self, data, filename, entry):
23 | f = open(filename, "w")
24 | for i in range(0, len(data)):
25 | f.write("\t".join([str(item) for item in data[i]]))
26 | f.write("\n")
27 | f.close()
28 |
29 | def visualize(self, data, filename, entry, override_sr):
30 | sr = override_sr
31 | f = open(os.path.join(PACKAGE_PATH, "data/sparse_tag_template.svl"), "r")
32 | content = f.read()
33 | f.close()
34 | content = content.replace("[__SR__]", str(sr))
35 | content = content.replace("[__STYLE__]", str(1))
36 | results = []
37 | for item in data:
38 | results.append(
39 |                 '<point frame="%d" label="%d %d"/>'
40 | % (int(np.round(item[0] * sr)), int(item[1]), int(item[2]))
41 | )
42 | content = content.replace("[__DATA__]", "\n".join(results))
43 | f = open(filename, "w")
44 | f.write(content)
45 | f.close()
46 |
47 | def get_visualize_extention_name(self):
48 | return "svl"
49 |
--------------------------------------------------------------------------------
/polyffusion/train/train_ddpm.py:
--------------------------------------------------------------------------------
1 | from data.dataloader import get_custom_train_val_dataloaders, get_train_val_dataloaders
2 | from ddpm import DenoiseDiffusion
3 | from ddpm.unet import UNet
4 | from models.model_ddpm import Polyffusion_DDPM
5 |
6 | from . import *
7 |
8 |
9 | class DDPM_TrainConfig(TrainConfig):
10 | # U-Net model for $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$
11 | eps_model: UNet
12 |     # DDPM denoising diffusion wrapper
13 | diffusion: DenoiseDiffusion
14 |
15 | # Adam optimizer
16 | optimizer: torch.optim.Adam
17 |
18 | def __init__(self, params, output_dir, data_dir=None):
19 | super().__init__(params, None, output_dir)
20 |
21 | self.eps_model = UNet(
22 | image_channels=params.image_channels,
23 | n_channels=params.n_channels,
24 | ch_mults=params.channel_multipliers,
25 | is_attn=params.is_attention,
26 | )
27 |
28 |         # Create the DDPM denoising diffusion wrapper
29 | self.diffusion = DenoiseDiffusion(
30 | eps_model=self.eps_model,
31 | n_steps=params.n_steps,
32 | )
33 |
34 | self.model = Polyffusion_DDPM(self.diffusion, params)
35 | # Create dataloader
36 | if data_dir is None:
37 | self.train_dl, self.val_dl = get_train_val_dataloaders(
38 | params.batch_size, params.num_workers, params.pin_memory
39 | )
40 | else:
41 | self.train_dl, self.val_dl = get_custom_train_val_dataloaders(
42 | params.batch_size,
43 | data_dir,
44 | num_workers=params.num_workers,
45 | pin_memory=params.pin_memory,
46 | )
47 | # Create optimizer
48 | self.optimizer = torch.optim.Adam(
49 | self.eps_model.parameters(), lr=params.learning_rate
50 | )
51 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/io_new/downbeat_io.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from mir.common import PACKAGE_PATH
3 | from mir.io.feature_io_base import *
4 |
5 |
6 | class DownbeatIO(FeatureIO):
7 | def read(self, filename, entry):
8 | f = open(filename, "r")
9 | lines = f.readlines()
10 | lines = [line.strip("\n\r") for line in lines]
11 | lines = [line for line in lines if line != ""]
12 | f.close()
13 | result = np.zeros((len(lines), 2))
14 | for i in range(len(lines)):
15 | line = lines[i]
16 | tokens = line.split("\t")
17 | assert len(tokens) == 2
18 | result[i, 0] = float(tokens[0])
19 | result[i, 1] = float(tokens[1])
20 | return result
21 |
22 | def write(self, data, filename, entry):
23 | f = open(filename, "w")
24 | for i in range(0, len(data)):
25 | f.write("\t".join([str(item) for item in data[i, :]]))
26 | f.write("\n")
27 | f.close()
28 |
29 | def visualize(self, data, filename, entry, override_sr):
30 | f = open(os.path.join(PACKAGE_PATH, "data/sparse_tag_template.svl"), "r")
31 | sr = override_sr
32 | content = f.read()
33 | f.close()
34 | content = content.replace("[__SR__]", str(sr))
35 | content = content.replace("[__STYLE__]", str(1))
36 | output_text = ""
37 | for beat_info in data:
38 |             output_text += '<point frame="%d" label="%d"/>\n' % (
39 | int(beat_info[0] * sr),
40 | int(beat_info[1]),
41 | )
42 | content = content.replace("[__DATA__]", output_text)
43 | f = open(filename, "w")
44 | f.write(content)
45 | f.close()
46 |
47 | def get_visualize_extention_name(self):
48 | return "svl"
49 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/__init__.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import sys
3 |
4 | import mir_eval
5 | import numpy as np
6 |
7 | from .main import transcribe_cb1000_midi
8 |
9 |
10 | def get_chord_from_chdfile(fpath, one_beat=0.5, rounding=True):
11 | """
12 |     Returns a chord matrix [M * 14], where each row represents the chord of one beat,
13 |     in the same format as mir_eval.chord.encode():
14 |     root_number(1), semitone_bitmap(12), bass_number(1).
15 |     The input chord file is generated by Junyan's chord extraction algorithm.
16 | """
17 | file = csv.reader(open(fpath), delimiter="\t")
18 | beat_cnt = 0
19 | chords = []
20 | for line in file:
21 | start = float(line[0])
22 | end = float(line[1])
23 | chord = line[2]
24 | if not rounding:
25 | assert ((end - start) / one_beat).is_integer()
26 | beat_num = int((end - start) / one_beat)
27 | else:
28 | beat_num = round((end - start) / one_beat)
29 | for _ in range(beat_num):
30 | beat_cnt += 1
31 | # see https://craffel.github.io/mir_eval/#mir_eval.chord.encode
32 | chd_enc = mir_eval.chord.encode(chord)
33 |
34 | root = chd_enc[0]
35 | # make chroma and bass absolute
36 | chroma_bitmap = chd_enc[1]
37 | chroma_bitmap = np.roll(chroma_bitmap, root)
38 | bass = (chd_enc[2] + root) % 12
39 |
40 | line = [root]
41 | for _ in chroma_bitmap:
42 | line.append(_)
43 | line.append(bass)
44 |
45 | chords.append(line)
46 | return np.array(chords, dtype=np.float32)
47 |
48 |
49 | def extract_chords_from_midi_file(fpath, chdfile_path):
50 | transcribe_cb1000_midi(fpath, chdfile_path)
51 | return get_chord_from_chdfile(chdfile_path)
52 |
53 |
54 | if __name__ == "__main__":
55 | extract_chords_from_midi_file(sys.argv[1], sys.argv[2])
56 |
--------------------------------------------------------------------------------
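
A minimal sketch of the 14-column per-beat encoding that get_chord_from_chdfile() above produces, assuming mir_eval is installed; the chord label "G:7/5" is an arbitrary example:

import mir_eval
import numpy as np

chord = "G:7/5"  # arbitrary example label with a non-root bass
root, chroma, bass = mir_eval.chord.encode(chord)  # root=7, relative chroma, relative bass
chroma = np.roll(chroma, root)  # make the 12-bit chroma absolute, as in the loop above
bass = (bass + root) % 12       # absolute bass pitch class (D = 2)
row = [root, *chroma, bass]     # one 14-element row per beat
print(len(row), row[0], row[-1])  # 14 7 2
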
/polyffusion/chord_extractor/mir/cache.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import pickle
4 |
5 | from mir.common import WORKING_PATH
6 |
7 | __all__ = ["load", "save"]
8 |
9 |
10 | def mkdir_for_file(path):
11 | folder_path = os.path.dirname(path)
12 | if not os.path.isdir(folder_path):
13 | os.makedirs(folder_path)
14 | return path
15 |
16 |
17 | def dumptofile(obj, filename, protocol):
18 | f = open(filename, "wb")
19 |     # If you are worried about pickle protocol compatibility issues:
20 |     # well, the cache is only used on your own computer, right?
21 | pickle.dump(obj, f, protocol=protocol)
22 | f.close()
23 |
24 |
25 | def loadfromfile(filename):
26 | if os.path.isfile(filename):
27 | f = open(filename, "rb")
28 | obj = pickle.load(f)
29 | f.close()
30 | return obj
31 | else:
32 | raise Exception("No cache of %s" % filename)
33 |
34 |
35 | def load(*names):
36 | if len(names) == 1:
37 | return loadfromfile(
38 | os.path.join(WORKING_PATH, "cache_data/%s.cache" % names[0])
39 | )
40 | result = [None] * len(names)
41 | for i in range(len(names)):
42 | result[i] = loadfromfile(
43 | os.path.join(WORKING_PATH, "cache_data/%s.cache" % names[i])
44 | )
45 | return result
46 |
47 |
48 | def save(obj, name, protocol=pickle.HIGHEST_PROTOCOL):
49 | path = os.path.join(WORKING_PATH, "cache_data/%s.cache" % name)
50 | mkdir_for_file(path)
51 | dumptofile(obj, path, protocol)
52 |
53 |
54 | def hasher(obj):
55 | if isinstance(obj, list):
56 | m = hashlib.md5()
57 | for item in obj:
58 | m.update(item)
59 | return m.hexdigest()
60 | if isinstance(obj, str):
61 | return hashlib.md5(obj.encode("utf8")).hexdigest()
62 | return hashlib.md5(obj).hexdigest()
63 |
--------------------------------------------------------------------------------
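
A minimal usage sketch of the helpers above (assuming the chord_extractor directory is on the Python path so that the mir package resolves); cache files are pickled under WORKING_PATH/cache_data/<name>.cache:

from mir import cache

features = {"tempo": 120.0, "beats": [0.0, 0.5, 1.0]}  # arbitrary example object
cache.save(features, "demo_features")    # writes cache_data/demo_features.cache
restored = cache.load("demo_features")   # a single name returns the object itself
assert restored == features
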
/polyffusion/chord_extractor/io_new/salami_io.py:
--------------------------------------------------------------------------------
1 | from mir.io.feature_io_base import *
2 | from mir.music_base import get_scale_and_suffix
3 |
4 |
5 | class SalamiIO(FeatureIO):
6 | def read(self, filename, entry):
7 | f = open(filename, "r")
8 | data = f.read()
9 | lines = data.split("\n")
10 | result = []
11 | metre_up = -1
12 | metre_down = -1
13 | tonic = -1
14 | for line in lines:
15 | if line == "":
16 | continue
17 | if line.startswith("#"):
18 | if ":" in line:
19 | seperator_index = line.index(":")
20 | keyword = line[1:seperator_index].strip()
21 | if keyword == "metre":
22 | slash_index = line.index("/")
23 | metre_up = int(line[seperator_index + 1 : slash_index].strip())
24 | metre_down = int(line[slash_index + 1 :].strip())
25 | # print('metre changed to %d/%d'%(metre_up,metre_down))
26 | if keyword == "tonic":
27 | tonic = int(
28 | get_scale_and_suffix(line[seperator_index + 1 :].strip())[0]
29 | )
30 |
31 | else:
32 | tokens = line.split("\t")
33 | assert len(tokens) == 2
34 | start_time = float(tokens[0])
35 | result.append((start_time, tokens[1], metre_up, metre_down, tonic))
36 | f.close()
37 | return result
38 |
39 | def write(self, data, filename, entry):
40 | raise NotImplementedError()
41 |
42 | def visualize(self, data, filename, entry, override_sr):
43 | f = open(filename, "w")
44 | for time, token, _, _, _ in data:
45 | f.write("%f\t%s\n" % (time, token))
46 | f.close()
47 |
--------------------------------------------------------------------------------
/polyffusion/models/model_chd_8bar.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from dl_modules import ChordDecoder, ChordEncoder
5 | from utils import *
6 |
7 |
8 | class Chord_8Bar(nn.Module):
9 | def __init__(self, chord_enc: ChordEncoder, chord_dec: ChordDecoder):
10 | super(Chord_8Bar, self).__init__()
11 | self.chord_enc = chord_enc
12 | self.chord_dec = chord_dec
13 |
14 | @classmethod
15 | def load_trained(cls, chord_enc, chord_dec, model_dir):
16 | model = cls(chord_enc, chord_dec)
17 |         trained_learner = torch.load(f"{model_dir}/weights.pt")
18 |         model.load_state_dict(trained_learner["model"])
19 | return model
20 |
21 | def chord_loss(self, c, recon_root, recon_chroma, recon_bass):
22 | loss_fun = nn.CrossEntropyLoss()
23 | root = c[:, :, 0:12].max(-1)[-1].view(-1).contiguous()
24 | chroma = c[:, :, 12:24].long().view(-1).contiguous()
25 | bass = c[:, :, 24:].max(-1)[-1].view(-1).contiguous()
26 |
27 | recon_root = recon_root.view(-1, 12).contiguous()
28 | recon_chroma = recon_chroma.view(-1, 2).contiguous()
29 | recon_bass = recon_bass.view(-1, 12).contiguous()
30 | root_loss = loss_fun(recon_root, root)
31 | chroma_loss = loss_fun(recon_chroma, chroma)
32 | bass_loss = loss_fun(recon_bass, bass)
33 | chord_loss = root_loss + chroma_loss + bass_loss
34 | return {
35 | "loss": chord_loss,
36 | "root": root_loss,
37 | "chroma": chroma_loss,
38 | "bass": bass_loss,
39 | }
40 |
41 | def get_loss_dict(self, batch, step, tfr_chd):
42 | _, _, chord, _ = batch
43 |
44 | z_chd = self.chord_enc(chord).rsample()
45 | recon_root, recon_chroma, recon_bass = self.chord_dec(
46 | z_chd, False, tfr_chd, chord
47 | )
48 | return self.chord_loss(chord, recon_root, recon_chroma, recon_bass)
49 |
--------------------------------------------------------------------------------
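
chord_loss() above slices each chord frame into three 12-dimensional parts: dims 0-11 are treated as a one-hot root, 12-23 as binary chroma bits, and 24-35 as a one-hot bass. The helper below is a hypothetical sketch inferred from that slicing (not an official spec of the dataloader format), building one such 36-dimensional frame for C major:

import torch


def chord_frame(root, chroma_bits, bass):
    # Hypothetical helper: 36-dim frame = one-hot root | 12 chroma bits | one-hot bass.
    frame = torch.zeros(36)
    frame[root] = 1.0
    frame[12:24] = torch.tensor(chroma_bits, dtype=torch.float32)
    frame[24 + bass] = 1.0
    return frame


c_major = chord_frame(0, [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0], 0)
print(c_major[:12].argmax().item(), c_major[24:].argmax().item())  # 0 0
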
/polyffusion/chord_extractor/mir/io/implement/chroma_io.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from mir.io.feature_io_base import *
3 |
4 |
5 | class ChromaIO(FeatureIO):
6 | def read(self, filename, entry):
7 | if filename.endswith(".csv"):
8 | f = open(filename, "r")
9 | lines = f.readlines()
10 | result = []
11 | for line in lines:
12 | line = line.strip()
13 | if line == "":
14 | continue
15 | arr = np.array(list(map(float, line.split(",")[2:])))
16 | arr = arr.reshape((2, 12))[::-1].T
17 | arr = np.roll(arr, -3, axis=0).reshape((24))
18 | result.append(arr)
19 | data = np.array(result)
20 | else:
21 | data = pickle_read(self, filename)
22 | return data
23 |
24 | def write(self, data, filename, entry):
25 | pickle_write(self, data, filename)
26 |
27 | def visualize(self, data, filename, entry, override_sr):
28 | sr = entry.prop.sr
29 | win_shift = entry.prop.hop_length
30 | feature_tuple_size = entry.prop.chroma_tuple_size
31 | # if(FEATURETUPLESIZE==2):
32 | features = data
33 | f = open(filename, "w")
34 | for i in range(0, features.shape[0]):
35 | time = win_shift * i / sr
36 | f.write(str(time))
37 | for j in range(0, feature_tuple_size):
38 | if j > 0:
39 | f.write("\t0")
40 | for k in range(0, 12):
41 | f.write("\t" + str(features[i][k * feature_tuple_size + j]))
42 | f.write("\n")
43 | f.close()
44 |
45 | def pre_assign(self, entry, proxy):
46 | entry.prop.set("n_frame", LoadingPlaceholder(proxy, entry))
47 |
48 | def post_load(self, data, entry):
49 | entry.prop.set("n_frame", data.shape[0])
50 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/mir/extractors/misc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from mir.extractors.extractor_base import *
3 |
4 |
5 | class BlankMusic(ExtractorBase):
6 | def get_feature_class(self):
7 | return io.MusicIO
8 |
9 | def extract(self, entry, **kwargs):
10 | time = 60.0 # seconds
11 | if "time" in kwargs:
12 | time = kwargs["time"]
13 | return np.zeros((int(np.ceil(time * entry.prop.sr))))
14 |
15 |
16 | class FrameCount(ExtractorBase):
17 | def get_feature_class(self):
18 | return io.IntegerIO
19 |
20 | def extract(self, entry, **kwargs):
21 | # self.require(entry.prop.hop_length)
22 | return entry.dict[kwargs["source"]].get(entry).shape[0]
23 |
24 |
25 | class Evaluate:
26 | def __init__(self, io):
27 | self.__io = io
28 |
29 | def __call__(self, *args, **kwargs):
30 | inner_instance = Evaluate.InnerEvaluate()
31 | inner_instance.io = self.__io
32 | return inner_instance
33 |
34 | class InnerEvaluate(ExtractorBase):
35 | def __init__(self):
36 | self.io = None
37 |
38 | def get_feature_class(self):
39 | return self.io
40 |
41 | class __ProxyReflector:
42 | def __init__(self, entry):
43 | self.__entry = entry
44 |
45 | def __getattr__(self, item):
46 | if item in self.__entry.dict:
47 | print("Getting %s" % item)
48 | return self.__entry.dict[item].get(self.__entry)
49 | else:
50 | raise AttributeError(
51 | "No key '%s' found in entry %s" % (item, self.__entry.name)
52 | )
53 |
54 | def extract(self, entry, **kwargs):
55 | eval_proxy_ref__ = __class__.__ProxyReflector(entry)
56 | expr = kwargs["expr"].replace("$", "eval_proxy_ref__.")
57 | return eval(expr)
58 |
--------------------------------------------------------------------------------
/polyffusion/main.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | from omegaconf import OmegaConf
4 |
5 | from train.train_autoencoder import Autoencoder_TrainConfig
6 | from train.train_chd_8bar import Chord8bar_TrainConfig
7 | from train.train_ddpm import DDPM_TrainConfig
8 | from train.train_ldm import LDM_TrainConfig
9 |
10 | if __name__ == "__main__":
11 | parser = ArgumentParser(
12 | description="train (or resume training) a Polyffusion model"
13 | )
14 | parser.add_argument(
15 | "--output_dir",
16 | default=None,
17 | help="directory in which to store model checkpoints and training logs",
18 | )
19 | parser.add_argument(
20 | "--data_dir", default=None, help="directory of custom training data, in npzs"
21 | )
22 | parser.add_argument(
23 | "--pop909_use_track",
24 | default="0,1,2",
25 | help="which tracks to use for pop909 (default dataset) training. (0: melody, 1: bridge, 2: piano accompaniment)",
26 | )
27 | parser.add_argument("--model", help="which model to train (autoencoder, ldm, ddpm)")
28 | args = parser.parse_args()
29 |
30 | use_track = [int(x) for x in args.pop909_use_track.split(",")]
31 |
32 | params = OmegaConf.load(f"polyffusion/params/{args.model}.yaml")
33 |
34 | if args.model.startswith("sdf"):
35 | use_musicalion = "musicalion" in args.model
36 | config = LDM_TrainConfig(
37 | params,
38 | args.output_dir,
39 | use_musicalion,
40 | use_track=use_track,
41 | data_dir=args.data_dir,
42 | )
43 | elif args.model == "ddpm":
44 | config = DDPM_TrainConfig(params, args.output_dir, data_dir=args.data_dir)
45 | elif args.model == "autoencoder":
46 | config = Autoencoder_TrainConfig(
47 | params, args.output_dir, data_dir=args.data_dir
48 | )
49 | elif args.model == "chd_8bar":
50 | config = Chord8bar_TrainConfig(params, args.output_dir, data_dir=args.data_dir)
51 | else:
52 | raise NotImplementedError
53 | config.train()
54 |
--------------------------------------------------------------------------------
/polyffusion/train/train_chd_8bar.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from data.dataloader import get_custom_train_val_dataloaders, get_train_val_dataloaders
4 | from dl_modules import ChordDecoder, ChordEncoder
5 | from models.model_chd_8bar import Chord_8Bar
6 | from train.scheduler import ParameterScheduler, TeacherForcingScheduler
7 |
8 | # from stable_diffusion.model.autoencoder import Autoencoder, Encoder, Decoder
9 | from . import *
10 |
11 |
12 | class Chord8bar_TrainConfig(TrainConfig):
13 | def __init__(self, params, output_dir, data_dir=None) -> None:
14 | # Teacher-forcing rate for Chord VAE training
15 | tfr_chd = params.tfr_chd
16 | tfr_chd_scheduler = TeacherForcingScheduler(*tfr_chd)
17 | params_dict = dict(tfr_chd=tfr_chd_scheduler)
18 | param_scheduler = ParameterScheduler(**params_dict)
19 |
20 | super().__init__(params, param_scheduler, output_dir)
21 |
22 | self.chord_enc = ChordEncoder(
23 | input_dim=params.chd_input_dim,
24 | hidden_dim=params.chd_hidden_dim,
25 | z_dim=params.chd_z_dim,
26 | )
27 | self.chord_dec = ChordDecoder(
28 | input_dim=params.chd_input_dim,
29 | z_input_dim=params.chd_z_input_dim,
30 | hidden_dim=params.chd_hidden_dim,
31 | z_dim=params.chd_z_dim,
32 | n_step=params.chd_n_step,
33 | )
34 | self.model = Chord_8Bar(
35 | self.chord_enc,
36 | self.chord_dec,
37 | )
38 |
39 | # Create dataloader
40 | if data_dir is None:
41 | self.train_dl, self.val_dl = get_train_val_dataloaders(
42 | params.batch_size, params.num_workers, params.pin_memory
43 | )
44 | else:
45 | self.train_dl, self.val_dl = get_custom_train_val_dataloaders(
46 | params.batch_size,
47 | data_dir,
48 | num_workers=params.num_workers,
49 | pin_memory=params.pin_memory,
50 | )
51 |
52 | # Create optimizer
53 | self.optimizer = torch.optim.Adam(
54 | self.model.parameters(), lr=params.learning_rate
55 | )
56 |
--------------------------------------------------------------------------------
/polyffusion/train/train_autoencoder.py:
--------------------------------------------------------------------------------
1 | # This file is unused
2 |
3 | import torch
4 |
5 | from data.dataloader import get_custom_train_val_dataloaders, get_train_val_dataloaders
6 | from dirs import *
7 | from models.model_autoencoder import Polyffusion_Autoencoder
8 | from stable_diffusion.model.autoencoder import Autoencoder, Decoder, Encoder
9 |
10 | from . import TrainConfig
11 |
12 |
13 | class Autoencoder_TrainConfig(TrainConfig):
14 | model: Autoencoder
15 | optimizer: torch.optim.Adam
16 |
17 | def __init__(self, params, output_dir, data_dir=None) -> None:
18 | super().__init__(params, None, output_dir)
19 | encoder = Encoder(
20 | in_channels=params.in_channels,
21 | z_channels=params.z_channels,
22 | channels=params.channels,
23 | channel_multipliers=params.channel_multipliers,
24 | n_resnet_blocks=params.n_res_blocks,
25 | )
26 |
27 | decoder = Decoder(
28 | out_channels=params.out_channels,
29 | z_channels=params.z_channels,
30 | channels=params.channels,
31 | channel_multipliers=params.channel_multipliers,
32 | n_resnet_blocks=params.n_res_blocks,
33 | )
34 |
35 | autoencoder = Autoencoder(
36 | encoder=encoder,
37 | decoder=decoder,
38 | emb_channels=params.emb_channels,
39 | z_channels=params.z_channels,
40 | )
41 |
42 | self.model = Polyffusion_Autoencoder(autoencoder)
43 |
44 | # Create dataloader
45 | if data_dir is None:
46 | self.train_dl, self.val_dl = get_train_val_dataloaders(
47 | params.batch_size, params.num_workers, params.pin_memory
48 | )
49 | else:
50 | self.train_dl, self.val_dl = get_custom_train_val_dataloaders(
51 | params.batch_size,
52 | data_dir,
53 | num_workers=params.num_workers,
54 | pin_memory=params.pin_memory,
55 | )
56 |
57 | # Create optimizer
58 | self.optimizer = torch.optim.Adam(
59 | self.model.parameters(), lr=params.learning_rate
60 | )
61 |
--------------------------------------------------------------------------------
/polyffusion/lightning_learner.py:
--------------------------------------------------------------------------------
1 | import lightning
2 | import torch
3 |
4 |
5 | class LightningLearner(lightning.LightningModule):
6 | def __init__(self, model, optimizer, params, param_scheduler):
7 | super().__init__()
8 | self.model = model
9 | self.optimizer = optimizer
10 | self.params = params
11 | self.param_scheduler = param_scheduler # teacher-forcing stuff
12 |
13 | self.save_hyperparameters("params", "param_scheduler")
14 |
15 | def _categorize_loss_dict(self, loss_dict, category):
16 | return {f"{category}/{k}": v for k, v in loss_dict.items()}
17 |
18 | def training_step(self, batch, batch_idx):
19 | if self.param_scheduler is not None:
20 | scheduled_params = self.param_scheduler.step()
21 | loss_dict = self.model.get_loss_dict(
22 | batch, self.global_step, **scheduled_params
23 | )
24 | else:
25 | scheduled_params = None
26 | loss_dict = self.model.get_loss_dict(batch, self.global_step)
27 |
28 | # check NaN
29 | for loss_value in list(loss_dict.values()):
30 | if isinstance(loss_value, torch.Tensor) and torch.isnan(loss_value).any():
31 | raise RuntimeError(
32 |                     f"Detected NaN loss at step {self.global_step}, epoch {self.current_epoch}"
33 | )
34 | loss = loss_dict["loss"]
35 |
36 | loss_dict = self._categorize_loss_dict(loss_dict, "train")
37 | self.log_dict(loss_dict, prog_bar=True)
38 |
39 | return loss
40 |
41 | def validation_step(self, batch, batch_idx):
42 | if self.param_scheduler is not None:
43 | scheduled_params = self.param_scheduler.step()
44 | loss_dict = self.model.get_loss_dict(
45 | batch, self.global_step, **scheduled_params
46 | )
47 | else:
48 | scheduled_params = None
49 | loss_dict = self.model.get_loss_dict(batch, self.global_step)
50 |
51 | loss_dict = self._categorize_loss_dict(loss_dict, "val")
52 | self.log_dict(loss_dict)
53 |
54 | def configure_optimizers(self):
55 | return self.optimizer
56 |
--------------------------------------------------------------------------------
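
A minimal sketch of how LightningLearner is meant to be driven. ToyModel and the random dataset are placeholders standing in for the Polyffusion models and dataloaders built by a TrainConfig, and the import assumes the polyffusion/ directory is on the Python path:

import lightning
import torch
from torch import nn

from lightning_learner import LightningLearner


class ToyModel(nn.Module):
    # Placeholder model exposing the get_loss_dict() interface the learner expects.
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 1)

    def get_loss_dict(self, batch, step):
        x, y = batch
        return {"loss": nn.functional.mse_loss(self.lin(x), y)}


model = ToyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
learner = LightningLearner(model, optimizer, params={}, param_scheduler=None)

dataset = torch.utils.data.TensorDataset(torch.randn(32, 4), torch.randn(32, 1))
loader = torch.utils.data.DataLoader(dataset, batch_size=8)
trainer = lightning.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
trainer.fit(learner, train_dataloaders=loader, val_dataloaders=loader)
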
/polyffusion/chord_extractor/io_new/complex_chord_io.py:
--------------------------------------------------------------------------------
1 | import complex_chord
2 | import numpy as np
3 | from mir.common import PACKAGE_PATH
4 | from mir.io.feature_io_base import *
5 |
6 |
7 | class ComplexChordIO(FeatureIO):
8 | def read(self, filename, entry):
9 | n_frame = entry.n_frame
10 | f = open(filename, "r")
11 | line_list = f.readlines()
12 | tags = np.ones((n_frame, 6)) * -2
13 | for line in line_list:
14 | line = line.strip()
15 | if line == "":
16 | continue
17 | if "\t" in line:
18 | tokens = line.split("\t")
19 | else:
20 | tokens = line.split(" ")
21 | sr = entry.prop.sr
22 | win_shift = entry.prop.hop_length
23 | begin = int(round(float(tokens[0]) * sr / win_shift))
24 | end = int(round(float(tokens[1]) * sr / win_shift))
25 | if end > n_frame:
26 | end = n_frame
27 | if begin < 0:
28 | begin = 0
29 | tags[begin:end, :] = (
30 | complex_chord.Chord(tokens[2]).to_numpy().reshape((1, 6))
31 | )
32 | f.close()
33 | return tags
34 |
35 | def write(self, data, filename, entry):
36 | raise NotImplementedError()
37 |
38 | def visualize(self, data, filename, entry, override_sr):
39 | f = open(os.path.join(PACKAGE_PATH, "data/spectrogram_template.svl"), "r")
40 | sr = entry.prop.sr
41 | win_shift = entry.prop.hop_length
42 | content = f.read()
43 | f.close()
44 | content = content.replace("[__SR__]", str(sr))
45 | content = content.replace("[__WIN_SHIFT__]", str(win_shift))
46 | content = content.replace("[__SHAPE_1__]", str(data.shape[1]))
47 | content = content.replace("[__COLOR__]", str(1))
48 | labels = [str(i) for i in range(data.shape[1])]
49 | content = content.replace("[__DATA__]", create_svl_3d_data(labels, data))
50 | f = open(filename, "w")
51 | f.write(content)
52 | f.close()
53 |
54 | def get_visualize_extention_name(self):
55 | return "svl"
56 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/io_new/midilab_io.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from mir.common import PACKAGE_PATH
3 | from mir.io.feature_io_base import *
4 |
5 |
6 | class MidiLabIO(FeatureIO):
7 | def read(self, filename, entry):
8 | f = open(filename, "r")
9 | lines = f.readlines()
10 | lines = [line.strip("\n\r") for line in lines]
11 | lines = [line for line in lines if line != ""]
12 | f.close()
13 | result = np.zeros((len(lines), 3))
14 | for i in range(len(lines)):
15 | line = lines[i]
16 | tokens = line.split("\t")
17 | assert len(tokens) == 3
18 | result[i, 0] = float(tokens[0])
19 | result[i, 1] = float(tokens[1])
20 | result[i, 2] = float(tokens[2])
21 | return result
22 |
23 | def write(self, data, filename, entry):
24 | f = open(filename, "w")
25 | for i in range(0, len(data)):
26 | f.write("\t".join([str(item) for item in data[i]]))
27 | f.write("\n")
28 | f.close()
29 |
30 | def visualize(self, data, filename, entry, override_sr):
31 | f = open(os.path.join(PACKAGE_PATH, "data/midi_template.svl"), "r")
32 | sr = override_sr
33 | content = f.read()
34 | f.close()
35 | content = content.replace("[__SR__]", str(sr))
36 | content = content.replace("[__WIN_SHIFT__]", "1")
37 | output_text = ""
38 | for note_info in data:
39 | output_text += self.__get_midi_note_text(
40 | note_info[0] * sr, note_info[1] * sr - 1, note_info[2]
41 | )
42 | content = content.replace("[__DATA__]", output_text)
43 | f = open(filename, "w")
44 | f.write(content)
45 | f.close()
46 |
47 | def __get_midi_note_text(self, start_frame, end_frame, midi_height, level=0.78125):
48 |         return '<point frame="%d" value="%s" duration="%d" level="%f" label="" />\n' % (
49 | int(round(start_frame)),
50 | midi_height,
51 | int(round(end_frame - start_frame)),
52 | level,
53 | )
54 |
55 | def get_visualize_extention_name(self):
56 | return "svl"
57 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/io_new/osu_io.py:
--------------------------------------------------------------------------------
1 | from mir.io.feature_io_base import *
2 |
3 |
4 | class MetaDict:
5 | def __init__(self):
6 | self.dict = {}
7 |
8 | def set(self, key, value):
9 | self.dict[key] = value
10 |
11 | def __getattr__(self, item):
12 | return self.dict[item.lower()]
13 |
14 |
15 | class OsuMapIO(FeatureIO):
16 | def read(self, filename, entry):
17 | f = open(filename, "r", encoding="UTF-8")
18 | result = MetaDict()
19 | lines = f.readlines()
20 | current_state = 0
21 | current_dict = None
22 | for line in lines:
23 | line = line.strip()
24 | if line == "":
25 | continue
26 | if line.startswith("["):
27 | assert line.endswith("]")
28 | namespace = line[1:-1].lower()
29 | if namespace in [
30 | "general",
31 | "editor",
32 | "metadata",
33 | "difficulty",
34 | "colours",
35 | ]:
36 | current_state = 1
37 | current_dict = MetaDict()
38 | result.set(namespace, current_dict)
39 | elif namespace in ["hitobjects", "events", "timingpoints"]:
40 | current_state = 2
41 | current_dict = []
42 | result.set(namespace, current_dict)
43 | else:
44 | raise Exception(
45 | "Unknown namespace %s in %s" % (namespace, filename)
46 | )
47 | else:
48 | if current_state == 1:
49 | split_index = line.index(":")
50 | key = line[:split_index].strip()
51 | value = line[split_index + 1 :].strip()
52 | current_dict.set(key.lower(), value)
53 | elif current_state == 2:
54 | current_dict.append(line)
55 | return result
56 |
57 | def write(self, data, filename, entry):
58 | raise NotImplementedError()
59 |
60 | def visualize(self, data, filename, entry, override_sr):
61 | raise NotImplementedError()
62 |
63 | def get_visualize_extention_name(self):
64 | raise NotImplementedError()
65 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/mir/extractors/librosa_extractor.py:
--------------------------------------------------------------------------------
1 | import librosa
2 | import numpy as np
3 | from mir.extractors.extractor_base import *
4 |
5 |
6 | class HPSS(ExtractorBase):
7 | def get_feature_class(self):
8 | return io.MusicIO
9 |
10 | def extract(self, entry, **kwargs):
11 | if "source" in kwargs:
12 | y = entry.dict[kwargs["source"]].get(entry)
13 | else:
14 | y = entry.music
15 | y_h = librosa.effects.harmonic(y, margin=kwargs["margin"])
16 | # y_h, y_p = librosa.effects.hpss(y, margin=(1.0, 5.0))
17 | return y_h
18 |
19 |
20 | class CQT(ExtractorBase):
21 | def get_feature_class(self):
22 | return io.SpectrogramIO
23 |
24 |     # Warning: this spectrogram uses 1/3-semitone bins (36 bins per octave)
25 | def extract(self, entry, **kwargs):
26 | n_bins = 262
27 | y = entry.music
28 | logspec = librosa.core.cqt(
29 | y,
30 | sr=kwargs["sr"],
31 | hop_length=kwargs["hop_length"],
32 | bins_per_octave=36,
33 | n_bins=n_bins,
34 | filter_scale=1.5,
35 | ).T
36 | logspec = np.abs(logspec)
37 | return logspec
38 |
39 |
40 | class STFT(ExtractorBase):
41 | def get_feature_class(self):
42 | return io.SpectrogramIO
43 |
44 |     # Linear-frequency STFT magnitude spectrogram
45 | def extract(self, entry, **kwargs):
46 | y = entry.music
47 | logspec = librosa.core.stft(
48 | y, win_length=kwargs["win_size"], hop_length=kwargs["hop_length"]
49 | ).T
50 | logspec = np.abs(logspec)
51 | return logspec
52 |
53 |
54 | class Onset(ExtractorBase):
55 | def get_feature_class(self):
56 | return io.SpectrogramIO
57 |
58 | def extract(self, entry, **kwargs):
59 | onset = librosa.onset.onset_strength(
60 | entry.music, sr=kwargs["sr"], hop_length=kwargs["hop_length"]
61 | ).reshape((-1, 1))
62 | return onset
63 |
64 |
65 | class Energy(ExtractorBase):
66 | def get_feature_class(self):
67 | return io.SpectrogramIO
68 |
69 | def extract(self, entry, **kwargs):
70 | energy = librosa.feature.rmse(
71 | y=entry.dict[kwargs["source"]].get(entry),
72 | hop_length=kwargs["hop_length"],
73 | frame_length=kwargs["win_size"],
74 | center=True,
75 | ).reshape((-1, 1))
76 | return energy
77 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # local
2 | /data/
3 | /pretrained/
4 | /result*/
5 | /exp/
6 | /eval/
7 | /wandb/
8 |
9 |
10 | # Byte-compiled / optimized / DLL files
11 | __pycache__/
12 | *.py[cod]
13 | *$py.class
14 |
15 | # C extensions
16 | *.so
17 |
18 | # Distribution / packaging
19 | .Python
20 | build/
21 | develop-eggs/
22 | dist/
23 | downloads/
24 | eggs/
25 | .eggs/
26 | lib/
27 | lib64/
28 | parts/
29 | sdist/
30 | var/
31 | wheels/
32 | pip-wheel-metadata/
33 | share/python-wheels/
34 | *.egg-info/
35 | .installed.cfg
36 | *.egg
37 | MANIFEST
38 |
39 | # PyInstaller
40 | # Usually these files are written by a python script from a template
41 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
42 | *.manifest
43 | *.spec
44 |
45 | # Installer logs
46 | pip-log.txt
47 | pip-delete-this-directory.txt
48 |
49 | # Unit test / coverage reports
50 | htmlcov/
51 | .tox/
52 | .nox/
53 | .coverage
54 | .coverage.*
55 | .cache
56 | nosetests.xml
57 | coverage.xml
58 | *.cover
59 | *.py,cover
60 | .hypothesis/
61 | .pytest_cache/
62 |
63 | # Translations
64 | *.mo
65 | *.pot
66 |
67 | # Django stuff:
68 | *.log
69 | local_settings.py
70 | db.sqlite3
71 | db.sqlite3-journal
72 |
73 | # Flask stuff:
74 | instance/
75 | .webassets-cache
76 |
77 | # Scrapy stuff:
78 | .scrapy
79 |
80 | # Sphinx documentation
81 | docs/_build/
82 |
83 | # PyBuilder
84 | target/
85 |
86 | # Jupyter Notebook
87 | .ipynb_checkpoints
88 |
89 | # IPython
90 | profile_default/
91 | ipython_config.py
92 |
93 | # pyenv
94 | .python-version
95 |
96 | # pipenv
97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
100 | # install all needed dependencies.
101 | #Pipfile.lock
102 |
103 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
104 | __pypackages__/
105 |
106 | # Celery stuff
107 | celerybeat-schedule
108 | celerybeat.pid
109 |
110 | # SageMath parsed files
111 | *.sage.py
112 |
113 | # Environments
114 | .env
115 | .venv
116 | env/
117 | venv/
118 | ENV/
119 | env.bak/
120 | venv.bak/
121 |
122 | # Spyder project settings
123 | .spyderproject
124 | .spyproject
125 |
126 | # Rope project settings
127 | .ropeproject
128 |
129 | # mkdocs documentation
130 | /site
131 |
132 | # mypy
133 | .mypy_cache/
134 | .dmypy.json
135 | dmypy.json
136 |
137 | # Pyre type checker
138 | .pyre/
139 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/mir/io/implement/spectrogram_io.py:
--------------------------------------------------------------------------------
1 | from mir.common import PACKAGE_PATH
2 | from mir.io.feature_io_base import *
3 |
4 |
5 | class SpectrogramIO(FeatureIO):
6 | def read(self, filename, entry):
7 | return pickle_read(self, filename)
8 |
9 | def write(self, data, filename, entry):
10 | pickle_write(self, data, filename)
11 |
12 | def visualize(self, data, filename, entry, override_sr):
13 | if type(data) is tuple:
14 | labels = data[0]
15 | data = data[1]
16 | else:
17 | labels = None
18 | if len(data.shape) == 1:
19 | data = data.reshape((-1, 1))
20 | if data.shape[1] > 1:
21 | f = open(os.path.join(PACKAGE_PATH, "data/spectrogram_template.svl"), "r")
22 | sr = entry.prop.sr
23 | win_shift = entry.prop.hop_length
24 | content = f.read()
25 | f.close()
26 | content = content.replace("[__SR__]", str(sr))
27 | content = content.replace("[__WIN_SHIFT__]", str(win_shift))
28 | content = content.replace("[__SHAPE_1__]", str(data.shape[1]))
29 | content = content.replace("[__COLOR__]", str(1))
30 | if labels is None:
31 | labels = [str(i) for i in range(data.shape[1])]
32 | content = content.replace("[__DATA__]", create_svl_3d_data(labels, data))
33 | else:
34 | f = open(os.path.join(PACKAGE_PATH, "data/curve_template.svl"), "r")
35 | sr = entry.prop.sr
36 | win_shift = entry.prop.hop_length
37 | content = f.read()
38 | f.close()
39 | content = content.replace("[__SR__]", str(sr))
40 | content = content.replace("[__STYLE__]", str(1))
41 | results = []
42 | for i in range(0, len(data)):
43 | results.append(
44 |                     '<point frame="%d" value="%s" label=""/>'
45 | % (int(override_sr / sr * i * win_shift), data[i, 0])
46 | )
47 | content = content.replace("[__DATA__]", "\n".join(results))
48 | content = content.replace("[__NAME__]", "curve")
49 |
50 | f = open(filename, "w")
51 | f.write(content)
52 | f.close()
53 |
54 | def pre_assign(self, entry, proxy):
55 | entry.prop.set("n_frame", LoadingPlaceholder(proxy, entry))
56 |
57 | def post_load(self, data, entry):
58 | entry.prop.set("n_frame", data.shape[0])
59 |
60 | def get_visualize_extention_name(self):
61 | return "svl"
62 |
--------------------------------------------------------------------------------
/polyffusion/data/dataloader_musicalion.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils.data import DataLoader
4 |
5 | from data.dataset_musicalion import PianoOrchDataset_Musicalion
6 | from utils import (
7 | estx_to_midi_file,
8 | pianotree_pitch_shift,
9 | pr_mat_pitch_shift,
10 | prmat2c_to_midi_file,
11 | prmat_to_midi_file,
12 | )
13 |
14 |
15 | def collate_fn(batch, shift):
16 | def sample_shift():
17 | return np.random.choice(np.arange(-6, 6), 1)[0]
18 |
19 | prmat2c = []
20 | pnotree = []
21 | prmat = []
22 | song_fn = []
23 | for b in batch:
24 |         # b[0]: seg_prmat2c; b[1]: seg_pnotree; b[3]: seg_prmat
25 | seg_prmat2c = b[0]
26 | seg_pnotree = b[1]
27 | seg_prmat = b[3]
28 |
29 | if shift:
30 | shift_pitch = sample_shift()
31 | seg_prmat2c = pr_mat_pitch_shift(seg_prmat2c, shift_pitch)
32 | seg_pnotree = pianotree_pitch_shift(seg_pnotree, shift_pitch)
33 | seg_prmat = pr_mat_pitch_shift(seg_prmat, shift_pitch)
34 |
35 | prmat2c.append(seg_prmat2c)
36 | pnotree.append(seg_pnotree)
37 | prmat.append(seg_prmat)
38 |
39 | if len(b) > 4:
40 | song_fn.append(b[4])
41 |
42 | prmat2c = torch.Tensor(np.array(prmat2c, np.float32)).float()
43 | pnotree = torch.Tensor(np.array(pnotree, np.int64)).long()
44 | prmat = torch.Tensor(np.array(prmat, np.float32)).float()
45 | # prmat = prmat.unsqueeze(1) # (B, 1, 128, 128)
46 | if len(song_fn) > 0:
47 | return prmat2c, pnotree, None, prmat, song_fn
48 | else:
49 | return prmat2c, pnotree, None, prmat
50 |
51 |
52 | def get_train_val_dataloaders(batch_size, num_workers=0, pin_memory=False, debug=False):
53 | train_dataset, val_dataset = PianoOrchDataset_Musicalion.load_train_and_valid_sets(
54 | debug
55 | )
56 | train_dl = DataLoader(
57 | train_dataset,
58 | batch_size,
59 | True,
60 | collate_fn=lambda x: collate_fn(x, shift=True),
61 | num_workers=num_workers,
62 | pin_memory=pin_memory,
63 | )
64 | val_dl = DataLoader(
65 | val_dataset,
66 | batch_size,
67 | False,
68 | collate_fn=lambda x: collate_fn(x, shift=False),
69 | num_workers=num_workers,
70 | pin_memory=pin_memory,
71 | )
72 | print(
73 | f"Dataloader ready: batch_size={batch_size}, num_workers={num_workers}, pin_memory={pin_memory}"
74 | )
75 | return train_dl, val_dl
76 |
77 |
78 | if __name__ == "__main__":
79 | train_dl, val_dl = get_train_val_dataloaders(16)
80 | print(len(train_dl))
81 | for batch in train_dl:
82 | print(len(batch))
83 | prmat2c, pnotree, _, prmat = batch
84 | print(prmat2c.shape)
85 | print(pnotree.shape)
86 | print(prmat.shape)
87 | prmat2c = prmat2c.cpu().numpy()
88 | pnotree = pnotree.cpu().numpy()
89 | prmat = prmat.cpu().numpy()
90 | prmat2c_to_midi_file(prmat2c, "exp/m_dl_prmat2c.mid")
91 | estx_to_midi_file(pnotree, "exp/m_dl_pnotree.mid")
92 | prmat_to_midi_file(prmat, "exp/m_dl_prmat.mid")
93 | exit(0)
94 |
--------------------------------------------------------------------------------
/polyffusion/train/scheduler.py:
--------------------------------------------------------------------------------
1 | # Copied from torch_plus
2 |
3 | import numpy as np
4 |
5 |
6 | def scheduled_sampling(i, high=0.7, low=0.05):
7 | i /= 1000 * 40 # new update
8 | x = 10 * (i - 0.5)
9 | z = 1 / (1 + np.exp(x))
10 | y = (high - low) * z + low
11 | return y
12 |
13 |
14 | class _Scheduler:
15 | def __init__(self, step=0, mode="train"):
16 | self._step = step
17 | self._mode = mode
18 |
19 | def _update_step(self):
20 | if self._mode == "train":
21 | self._step += 1
22 | elif self._mode == "val":
23 | pass
24 | else:
25 | raise NotImplementedError
26 |
27 | def step(self):
28 | raise NotImplementedError
29 |
30 | def train(self):
31 | self._mode = "train"
32 |
33 | def eval(self):
34 | self._mode = "val"
35 |
36 |
37 | class ConstantScheduler(_Scheduler):
38 | def __init__(self, param, step=0.0):
39 | super(ConstantScheduler, self).__init__(step)
40 | self.param = param
41 |
42 | def step(self):
43 | self._update_step()
44 | return self.param
45 |
46 |
47 | class TeacherForcingScheduler(_Scheduler):
48 | def __init__(self, high, low, f=scheduled_sampling, step=0):
49 | super(TeacherForcingScheduler, self).__init__(step)
50 | self.high = high
51 | self.low = low
52 | self._step = step
53 | self.schedule_f = f
54 |
55 | def get_tfr(self):
56 | return self.schedule_f(self._step, self.high, self.low)
57 |
58 | def step(self):
59 | tfr = self.get_tfr()
60 | self._update_step()
61 | return tfr
62 |
63 |
64 | class OptimizerScheduler(_Scheduler):
65 | def __init__(self, optimizer, scheduler, clip, step=0):
66 |         # optimizer and scheduler are PyTorch objects
67 | super(OptimizerScheduler, self).__init__(step)
68 | self.optimizer = optimizer
69 | self.scheduler = scheduler
70 | self.clip = clip
71 |
72 | def optimizer_zero_grad(self):
73 | self.optimizer.zero_grad()
74 |
75 | def step(self, require_zero_grad=False):
76 | self.optimizer.step()
77 | self.scheduler.step()
78 | if require_zero_grad:
79 | self.optimizer_zero_grad()
80 | self._update_step()
81 |
82 |
83 | class ParameterScheduler(_Scheduler):
84 | def __init__(self, step=0, mode="train", **schedulers):
85 |         # schedulers maps parameter names to _Scheduler instances (e.g. tfr_chd)
86 | super(ParameterScheduler, self).__init__(step)
87 | self.schedulers = schedulers
88 | self.mode = mode
89 |
90 | def train(self):
91 | self.mode = "train"
92 | for scheduler in self.schedulers.values():
93 | scheduler.train()
94 |
95 | def eval(self):
96 | self.mode = "val"
97 | for scheduler in self.schedulers.values():
98 | scheduler.eval()
99 |
100 | def step(self, require_zero_grad=False):
101 | params_dic = {}
102 | for key, scheduler in self.schedulers.items():
103 | params_dic[key] = scheduler.step()
104 | return params_dic
105 |
--------------------------------------------------------------------------------
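
A small sketch of the decay behaviour, using the function's own defaults high=0.7, low=0.05 (not necessarily the values set in the YAML configs); the import assumes the polyffusion/ directory is on the Python path:

from train.scheduler import ParameterScheduler, TeacherForcingScheduler, scheduled_sampling

# The teacher-forcing rate follows a sigmoid from `high` toward `low` as the
# step counter grows past ~40,000 updates (i is divided by 1000 * 40).
print(round(scheduled_sampling(0, 0.7, 0.05), 3))       # 0.696
print(round(scheduled_sampling(20_000, 0.7, 0.05), 3))  # 0.375 (midpoint)
print(round(scheduled_sampling(40_000, 0.7, 0.05), 3))  # 0.054

# ParameterScheduler bundles named schedulers; each .step() returns the kwargs
# later passed to get_loss_dict() (e.g. tfr_chd in Chord8bar_TrainConfig).
sched = ParameterScheduler(tfr_chd=TeacherForcingScheduler(high=0.7, low=0.05))
print(sched.step())  # {'tfr_chd': 0.695...}
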
/polyffusion/stable_diffusion/losses/discriminator.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import torch.nn as nn
4 |
5 | from .util import ActNorm
6 |
7 |
8 | def weights_init(m):
9 | classname = m.__class__.__name__
10 | if classname.find("Conv") != -1:
11 | nn.init.normal_(m.weight.data, 0.0, 0.02)
12 | elif classname.find("BatchNorm") != -1:
13 | nn.init.normal_(m.weight.data, 1.0, 0.02)
14 | nn.init.constant_(m.bias.data, 0)
15 |
16 |
17 | class NLayerDiscriminator(nn.Module):
18 | """Defines a PatchGAN discriminator as in Pix2Pix
19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
20 | """
21 |
22 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
23 | """Construct a PatchGAN discriminator
24 | Parameters:
25 | input_nc (int) -- the number of channels in input images
26 | ndf (int) -- the number of filters in the last conv layer
27 | n_layers (int) -- the number of conv layers in the discriminator
28 | norm_layer -- normalization layer
29 | """
30 | super(NLayerDiscriminator, self).__init__()
31 | if not use_actnorm:
32 | norm_layer = nn.BatchNorm2d
33 | else:
34 | norm_layer = ActNorm
35 | if (
36 | type(norm_layer) == functools.partial
37 | ): # no need to use bias as BatchNorm2d has affine parameters
38 | use_bias = norm_layer.func != nn.BatchNorm2d
39 | else:
40 | use_bias = norm_layer != nn.BatchNorm2d
41 |
42 | kw = 4
43 | padw = 1
44 | sequence = [
45 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
46 | nn.LeakyReLU(0.2, True),
47 | ]
48 | nf_mult = 1
49 | nf_mult_prev = 1
50 | for n in range(1, n_layers): # gradually increase the number of filters
51 | nf_mult_prev = nf_mult
52 | nf_mult = min(2**n, 8)
53 | sequence += [
54 | nn.Conv2d(
55 | ndf * nf_mult_prev,
56 | ndf * nf_mult,
57 | kernel_size=kw,
58 | stride=2,
59 | padding=padw,
60 | bias=use_bias,
61 | ),
62 | norm_layer(ndf * nf_mult),
63 | nn.LeakyReLU(0.2, True),
64 | ]
65 |
66 | nf_mult_prev = nf_mult
67 | nf_mult = min(2**n_layers, 8)
68 | sequence += [
69 | nn.Conv2d(
70 | ndf * nf_mult_prev,
71 | ndf * nf_mult,
72 | kernel_size=kw,
73 | stride=1,
74 | padding=padw,
75 | bias=use_bias,
76 | ),
77 | norm_layer(ndf * nf_mult),
78 | nn.LeakyReLU(0.2, True),
79 | ]
80 |
81 | sequence += [
82 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
83 | ] # output 1 channel prediction map
84 | self.main = nn.Sequential(*sequence)
85 |
86 | def forward(self, input):
87 | """Standard forward."""
88 | return self.main(input)
89 |
--------------------------------------------------------------------------------
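
A minimal forward-pass sketch, assuming the polyffusion/ directory is on the Python path and the module's relative ActNorm import resolves (ActNorm is only instantiated when use_actnorm=True). With the default three layers, a 1x3x256x256 input is reduced to a 30x30 grid of patch logits:

import torch

from stable_diffusion.losses.discriminator import NLayerDiscriminator, weights_init

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3).apply(weights_init)
logits = disc(torch.randn(1, 3, 256, 256))
print(logits.shape)  # torch.Size([1, 1, 30, 30]) -- one logit per overlapping patch
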
/polyffusion/chord_extractor/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os.path import join
3 |
4 | import numpy as np
5 | from chord_class import ChordClass
6 | from extractors.midi_utilities import (
7 | MidiBeatExtractor,
8 | )
9 | from extractors.rule_based_channel_reweight import midi_to_thickness_and_bass_weights
10 | from io_new.chordlab_io import ChordLabIO
11 | from midi_chord import ChordRecognition
12 | from mir import DataEntry, io
13 | from tqdm import tqdm
14 |
15 |
16 | def process_chord(entry, extra_division):
17 | """
18 |
19 | Parameters
20 | ----------
21 | entry: the song to be processed. Properties required:
22 |         entry.midi: the pretty_midi object of the song
23 | entry.beat: extracted beat and downbeat
24 | extra_division: extra divisions to each beat.
25 | For chord recognition on beat-level, use extra_division=1
26 | For chord recognition on half-beat-level, use extra_division=2
27 |
28 | Returns
29 | -------
30 | Extracted chord sequence
31 | """
32 |
33 | midi = entry.midi
34 | beats = midi.get_beats()
35 | if extra_division > 1:
36 | beat_interp = np.linspace(beats[:-1], beats[1:], extra_division + 1).T
37 | last_beat = beat_interp[-1, -1]
38 | beats = np.append(beat_interp[:, :-1].reshape((-1)), last_beat)
39 | downbeats = midi.get_downbeats()
40 | j = 0
41 | beat_pos = -2
42 | beat = []
43 | for i in range(len(beats)):
44 | if j < len(downbeats) and beats[i] == downbeats[j]:
45 | beat_pos = 1
46 | j += 1
47 | else:
48 | beat_pos = beat_pos + 1
49 | assert beat_pos > 0
50 | beat.append([beats[i], beat_pos])
51 | rec = ChordRecognition(entry, ChordClass())
52 | weights = midi_to_thickness_and_bass_weights(entry.midi)
53 | rec.process_feature(weights)
54 | chord = rec.decode()
55 | return chord
56 |
57 |
58 | def transcribe_cb1000_midi(midi_path, output_path):
59 | """
60 | Perform chord recognition on a midi
61 | :param midi_path: the path to the midi file
62 | :param output_path: the path to the output file
63 | """
64 | entry = DataEntry()
65 | entry.append_file(midi_path, io.MidiIO, "midi")
66 | entry.append_extractor(MidiBeatExtractor, "beat")
67 | result = process_chord(entry, extra_division=2)
68 | entry.append_data(result, ChordLabIO, "pred")
69 | entry.save("pred", output_path)
70 |
71 |
72 | def extract_in_folder():
73 | dpath = sys.argv[1]
74 | dpath_output = sys.argv[2]
75 | os.system(f"rm -rf {dpath_output}")
76 | os.system(f"mkdir -p {dpath_output}")
77 | for piece in tqdm(os.listdir(dpath)):
78 | os.system(f"mkdir -p {join(dpath_output, piece)}")
79 | for ver in os.listdir(join(dpath, piece)):
80 | transcribe_cb1000_midi(
81 | join(dpath, piece, ver), join(dpath_output, piece, ver[:-4]) + ".out"
82 | )
83 |
84 |
85 | if __name__ == "__main__":
86 | import sys
87 |
88 | if len(sys.argv) != 3:
89 | print("Usage: main.py midi_path output_path")
90 | exit(0)
91 |
92 | transcribe_cb1000_midi(sys.argv[1], sys.argv[2])
93 |
--------------------------------------------------------------------------------
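
A worked sketch of the extra_division subdivision performed in process_chord() above: np.linspace over consecutive beat pairs inserts evenly spaced sub-beats, and the final beat is re-appended. The beat times below are an arbitrary example.

import numpy as np

beats = np.array([0.0, 0.5, 1.0, 1.5])  # arbitrary example beat times (seconds)
extra_division = 2                      # half-beat resolution

beat_interp = np.linspace(beats[:-1], beats[1:], extra_division + 1).T
last_beat = beat_interp[-1, -1]
half_beats = np.append(beat_interp[:, :-1].reshape((-1)), last_beat)
print(half_beats)  # [0.   0.25 0.5  0.75 1.   1.25 1.5 ]
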
/polyffusion/chord_extractor/mir/io/feature_io_base.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | from abc import ABC, abstractmethod
4 |
5 |
6 | class LoadingPlaceholder:
7 | def __init__(self, proxy, entry):
8 | self.proxy = proxy
9 | self.entry = entry
10 | pass
11 |
12 | def fire(self):
13 | self.proxy.get(self.entry)
14 |
15 |
16 | class FeatureIO(ABC):
17 | @abstractmethod
18 | def read(self, filename, entry):
19 | pass
20 |
21 | def safe_read(self, filename, entry):
22 | entry.prop.start_record_reading()
23 | try:
24 | result = self.read(filename, entry)
25 | except Exception:
26 | entry.prop.end_record_reading()
27 | raise
28 | entry.prop.end_record_reading()
29 | return result
30 |
31 | def try_mkdir(self, filename):
32 | folder = os.path.dirname(filename)
33 | if not os.path.isdir(folder):
34 | os.makedirs(folder)
35 |
36 | def create(self, data, filename, entry):
37 | self.try_mkdir(filename)
38 | self.write(data, filename, entry)
39 |
40 | @abstractmethod
41 | def write(self, data, filename, entry):
42 | pass
43 |
44 |     # override iff writing and visualizing use different methods
45 | # (i.e. compressed vs uncompressed)
46 | def visualize(self, data, filename, entry, override_sr):
47 | self.write(data, filename, entry)
48 |
49 | # override iff entry properties will be updated upon loading
50 | def pre_assign(self, entry, proxy):
51 | pass
52 |
53 |     # override iff entry properties need to be updated after loading
54 | def post_load(self, data, entry):
55 | pass
56 |
57 |     # override iff it saves as another format (e.g. wav)
58 | def get_visualize_extention_name(self):
59 | return "txt"
60 |
61 | def file_to_evaluation_format(self, filename, entry):
62 | raise Exception("Not supported by the io class")
63 |
64 | def data_to_evaluation_format(self, data, entry):
65 | raise Exception("Not supported by the io class")
66 |
67 |
68 | def pickle_read(self, filename):
69 | f = open(filename, "rb")
70 | obj = pickle.load(f)
71 | f.close()
72 | return obj
73 |
74 |
75 | def pickle_write(self, data, filename):
76 | f = open(filename, "wb")
77 | pickle.dump(data, f)
78 | f.close()
79 |
80 |
81 | def create_svl_3d_data(labels, data):
82 | assert len(labels) == data.shape[1]
83 | results_part1 = [
84 |         '<bin number="%d" name="%s"/>' % (i, str(labels[i])) for i in range(len(labels))
85 | ]
86 | results_part2 = [
87 |         '<row n="%d">%s</row>' % (i, " ".join([str(s) for s in data[i]]))
88 | for i in range(data.shape[0])
89 | ]
90 | return "\n".join(results_part1) + "\n" + "\n".join(results_part2)
91 |
92 |
93 | def framed_2d_feature_visualizer(entry, features, filename):
94 | f = open(filename, "w")
95 | for i in range(0, features.shape[0]):
96 | time = entry.prop.hop_length * i / entry.prop.sr
97 | f.write(str(time))
98 | for j in range(0, features.shape[1]):
99 | f.write("\t" + str(features[i][j]))
100 | f.write("\n")
101 | f.close()
102 |
--------------------------------------------------------------------------------
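
A minimal sketch of a custom FeatureIO subclass following the contract above: read() and write() are required, visualize() falls back to write(), and pre_assign()/post_load() are only needed when loading should update entry properties. FloatListIO is a hypothetical example, not part of the toolkit:

from mir.io.feature_io_base import FeatureIO


class FloatListIO(FeatureIO):
    # Hypothetical IO class: one float per line, read back as a Python list.
    def read(self, filename, entry):
        with open(filename, "r") as f:
            return [float(line) for line in f if line.strip() != ""]

    def write(self, data, filename, entry):
        with open(filename, "w") as f:
            for value in data:
                f.write("%f\n" % value)
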
/polyffusion/chord_extractor/io_new/beat_align_io.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from mir import PACKAGE_PATH
3 | from mir.io.feature_io_base import *
4 |
5 |
6 | class BeatAlignCQTIO(FeatureIO):
7 | def read(self, filename, entry):
8 | return pickle_read(self, filename)
9 |
10 | def write(self, data, filename, entry):
11 | pickle_write(self, data, filename)
12 |
13 | def visualize(self, data, filename, entry, override_sr):
14 | sr = entry.prop.sr
15 | win_shift = entry.prop.hop_length
16 | beat = entry.beat
17 | assert len(beat) - 1 == data.shape[0]
18 | n_frame = int(beat[-1] * sr / win_shift) + data.shape[1] + 1
19 | new_data = np.ones((n_frame, data.shape[2])) * -1
20 | for i in range(len(beat) - 1):
21 | time = int(np.round(beat[i] * sr / win_shift))
22 | for j in range(data.shape[1]):
23 | time_j = time + j
24 | if time_j >= 0 and time_j < n_frame:
25 | new_data[time_j, :] = data[i, j, :]
26 | f = open(os.path.join(PACKAGE_PATH, "data/spectrogram_template.svl"), "r")
27 | content = f.read()
28 | f.close()
29 | content = content.replace("[__SR__]", str(sr))
30 | content = content.replace("[__WIN_SHIFT__]", str(win_shift))
31 | content = content.replace("[__SHAPE_1__]", str(new_data.shape[1]))
32 | content = content.replace("[__COLOR__]", str(1))
33 | labels = [str(i) for i in range(new_data.shape[1])]
34 | content = content.replace("[__DATA__]", create_svl_3d_data(labels, new_data))
35 | f = open(filename, "w")
36 | f.write(content)
37 | f.close()
38 |
39 | def get_visualize_extention_name(self):
40 | return "svl"
41 |
42 |
43 | class BeatSpectrogramIO(FeatureIO):
44 | def read(self, filename, entry):
45 | return pickle_read(self, filename)
46 |
47 | def write(self, data, filename, entry):
48 | pickle_write(self, data, filename)
49 |
50 | def visualize(self, data, filename, entry, override_sr):
51 | sr = entry.prop.sr
52 | win_shift = entry.prop.hop_length
53 | beat = entry.beat
54 | assert len(beat) - 1 == data.shape[0]
55 | n_frame = int(beat[-1] * sr / win_shift) + data.shape[1] + 1
56 | new_data = np.ones((n_frame, data.shape[1])) * -1
57 | for i in range(len(beat) - 1):
58 | start_time = int(np.round(beat[i] * sr / win_shift))
59 | end_time = int(np.round(beat[i + 1] * sr / win_shift))
60 | for j in range(start_time, end_time):
61 | new_data[j, :] = data[i, :]
62 | f = open(os.path.join(PACKAGE_PATH, "data/spectrogram_template.svl"), "r")
63 | content = f.read()
64 | f.close()
65 | content = content.replace("[__SR__]", str(sr))
66 | content = content.replace("[__WIN_SHIFT__]", str(win_shift))
67 | content = content.replace("[__SHAPE_1__]", str(new_data.shape[1]))
68 | content = content.replace("[__COLOR__]", str(1))
69 | labels = [str(i) for i in range(new_data.shape[1])]
70 | content = content.replace("[__DATA__]", create_svl_3d_data(labels, new_data))
71 | f = open(filename, "w")
72 | f.write(content)
73 | f.close()
74 |
75 | def get_visualize_extention_name(self):
76 | return "svl"
77 |
--------------------------------------------------------------------------------
/polyffusion/dl_modules/chord_dec.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import torch
4 | from torch import nn
5 |
6 |
7 | class ChordDecoder(nn.Module):
8 | def __init__(
9 | self, input_dim=36, z_input_dim=256, hidden_dim=512, z_dim=256, n_step=8
10 | ):
11 | super(ChordDecoder, self).__init__()
12 | self.z2dec_hid = nn.Linear(z_dim, hidden_dim)
13 | self.z2dec_in = nn.Linear(z_dim, z_input_dim)
14 | self.gru = nn.GRU(
15 | input_dim + z_input_dim, hidden_dim, batch_first=True, bidirectional=False
16 | )
17 | self.init_input = nn.Parameter(torch.rand(36))
18 | self.input_dim = input_dim
19 | self.hidden_dim = hidden_dim
20 | self.z_dim = z_dim
21 | self.root_out = nn.Linear(hidden_dim, 12)
22 | self.chroma_out = nn.Linear(hidden_dim, 24)
23 | self.bass_out = nn.Linear(hidden_dim, 12)
24 | self.n_step = n_step
25 | self.loss_func = nn.CrossEntropyLoss()
26 |
27 | def forward(self, z_chd, inference, tfr, gt_chd=None):
28 | # z_chd: (B, z_chd_size)
29 | bs = z_chd.size(0)
30 | z_chd_hid = self.z2dec_hid(z_chd).unsqueeze(0)
31 | z_chd_in = self.z2dec_in(z_chd).unsqueeze(1)
32 | if inference:
33 | tfr = 0.0
34 | token = self.init_input.repeat(bs, 1).unsqueeze(1)
35 | recon_root = []
36 | recon_chroma = []
37 | recon_bass = []
38 |
39 | for t in range(self.n_step):
40 | chd_t, z_chd_hid = self.gru(torch.cat([token, z_chd_in], dim=-1), z_chd_hid)
41 |
42 | # compute output distribution
43 | r_root = self.root_out(chd_t) # (bs, 1, 12)
44 | r_chroma = self.chroma_out(chd_t).view(bs, 1, 12, 2).contiguous()
45 | r_bass = self.bass_out(chd_t) # (bs, 1, 12)
46 |
47 | # write distribution on the list
48 | recon_root.append(r_root)
49 | recon_chroma.append(r_chroma)
50 | recon_bass.append(r_bass)
51 |
52 | # prepare the input to the next step
53 | if t == self.n_step - 1:
54 | break
55 | teacher_force = random.random() < tfr
56 | if teacher_force and not inference:
57 | token = gt_chd[:, t].unsqueeze(1)
58 | else:
59 | t_root = torch.zeros(bs, 1, 12).to(z_chd.device).float()
60 | t_root[torch.arange(0, bs), 0, r_root.max(-1)[-1]] = 1.0
61 | t_chroma = r_chroma.max(-1)[-1].float()
62 | t_bass = torch.zeros(bs, 1, 12).to(z_chd.device).float()
63 | t_bass[torch.arange(0, bs), 0, r_bass.max(-1)[-1]] = 1.0
64 | token = torch.cat([t_root, t_chroma, t_bass], dim=-1)
65 |
66 | recon_root = torch.cat(recon_root, dim=1)
67 | recon_chroma = torch.cat(recon_chroma, dim=1)
68 | recon_bass = torch.cat(recon_bass, dim=1)
69 | # print(recon_root.shape, recon_chroma.shape, recon_bass.shape)
70 | return recon_root, recon_chroma, recon_bass
71 |
72 | def recon_loss(self, c, recon_root, recon_chroma, recon_bass):
73 | loss_fun = self.loss_func
74 | root = c[:, :, 0:12].max(-1)[-1].view(-1).contiguous()
75 | chroma = c[:, :, 12:24].long().view(-1).contiguous()
76 | bass = c[:, :, 24:].max(-1)[-1].view(-1).contiguous()
77 |
78 | recon_root = recon_root.view(-1, 12).contiguous()
79 | recon_chroma = recon_chroma.view(-1, 2).contiguous()
80 | recon_bass = recon_bass.view(-1, 12).contiguous()
81 | root_loss = loss_fun(recon_root, root)
82 | chroma_loss = loss_fun(recon_chroma, chroma)
83 | bass_loss = loss_fun(recon_bass, bass)
84 | chord_loss = root_loss + chroma_loss + bass_loss
85 | return chord_loss, root_loss, chroma_loss, bass_loss
86 |
--------------------------------------------------------------------------------
/polyffusion/ddpm/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | import torch.utils.data
6 | from torch import nn
7 |
8 | from .utils import gather
9 |
10 |
11 | class DenoiseDiffusion(nn.Module):
12 | """
13 | ## Denoise Diffusion
14 | """
15 |
16 | def __init__(self, eps_model: nn.Module, n_steps: int):
17 | """
18 | * `eps_model` is $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$ model
19 | * `n_steps` is $T$, the number of diffusion steps
20 | """
21 | super().__init__()
22 | self.eps_model = eps_model
23 |
24 | # Create $\beta_1, \dots, \beta_T$ linearly increasing variance schedule
25 | self.register_buffer("beta", torch.linspace(0.0001, 0.02, n_steps))
26 |
27 | # $\alpha_t = 1 - \beta_t$
28 | self.alpha = 1.0 - self.beta
29 | # $\bar\alpha_t = \prod_{s=1}^t \alpha_s$
30 | self.alpha_bar = torch.cumprod(self.alpha, dim=0)
31 | # $T$
32 | self.n_steps = n_steps
33 | # $\sigma^2 = \beta$
34 | self.sigma2 = self.beta
35 |
36 | def q_xt_x0(
37 | self, x0: torch.Tensor, t: torch.Tensor
38 | ) -> Tuple[torch.Tensor, torch.Tensor]:
39 | """
40 | #### Get $q(x_t|x_0)$ distribution
41 | """
42 |
43 | # [gather](utils.html) $\alpha_t$ and compute $\sqrt{\bar\alpha_t} x_0$
44 | mean = gather(self.alpha_bar, t) ** 0.5 * x0
45 | # $(1-\bar\alpha_t) \mathbf{I}$
46 | var = 1 - gather(self.alpha_bar, t)
47 | #
48 | return mean, var
49 |
50 | def q_sample(
51 | self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None
52 | ):
53 | """
54 | #### Sample from $q(x_t|x_0)$
55 | """
56 |
57 | # $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
58 | if eps is None:
59 | eps = torch.randn_like(x0)
60 |
61 | # get $q(x_t|x_0)$
62 | mean, var = self.q_xt_x0(x0, t)
63 | # Sample from $q(x_t|x_0)$
64 | return mean + (var**0.5) * eps
65 |
66 | def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
67 | """
68 | #### Sample from $\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)$
69 | """
70 |
71 | # $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$
72 | eps_theta = self.eps_model(xt, t)
73 | # [gather](utils.html) $\bar\alpha_t$
74 | alpha_bar = gather(self.alpha_bar, t)
75 | # $\alpha_t$
76 | alpha = gather(self.alpha, t)
77 | # $\frac{\beta}{\sqrt{1-\bar\alpha_t}}$
78 | eps_coef = (1 - alpha) / (1 - alpha_bar) ** 0.5
79 | # $$\frac{1}{\sqrt{\alpha_t}} \Big(x_t -
80 | # \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\textcolor{lightgreen}{\epsilon_\theta}(x_t, t) \Big)$$
81 | mean = 1 / (alpha**0.5) * (xt - eps_coef * eps_theta)
82 | # $\sigma^2$
83 | var = gather(self.sigma2, t)
84 |
85 | # $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
86 | eps = torch.randn(xt.shape, device=xt.device)
87 | # Sample
88 | return mean + (var**0.5) * eps
89 |
90 | def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):
91 | """
92 | #### Simplified Loss
93 | """
94 | # Get batch size
95 | batch_size = x0.shape[0]
96 | # Get random $t$ for each sample in the batch
97 | t = torch.randint(
98 | 0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long
99 | )
100 |
101 | # $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
102 | if noise is None:
103 | noise = torch.randn_like(x0)
104 |
105 | # Sample $x_t$ for $q(x_t|x_0)$
106 | xt = self.q_sample(x0, t, eps=noise)
107 | # Get $\textcolor{lightgreen}{\epsilon_\theta}(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t)$
108 | eps_theta = self.eps_model(xt, t)
109 |
110 | # MSE loss
111 | return F.mse_loss(noise, eps_theta)
112 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/mir/io/implement/regional_spectrogram_io.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from mir.common import PACKAGE_PATH
3 | from mir.io.feature_io_base import *
4 |
5 |
6 | class RegionalSpectrogramIO(FeatureIO):
7 | def read(self, filename, entry):
8 | data = pickle_read(self, filename)
9 | assert len(data) == 3 or len(data) == 2
10 | return data
11 |
12 | def write(self, data, filename, entry):
13 | assert len(data) == 3 or len(data) == 2
14 | pickle_write(self, data, filename)
15 |
16 | def visualize(self, data, filename, entry, override_sr):
17 | if len(data) == 2:
18 | timing, data = data
19 | labels = None
20 | elif len(data) == 3:
21 | labels, timing, data = data
22 | else:
23 | raise Exception("Format error")
24 | data = np.array(data)
25 | if len(data.shape) == 1:
26 | data = data.reshape((-1, 1))
27 | sr = entry.prop.sr
28 | win_shift = entry.prop.hop_length
29 | timing = np.array(timing).reshape((len(timing), -1))
30 | n_frame = max(1, int(np.round(np.max(timing * sr / win_shift))))
31 | data_indices = (-1) * np.ones(n_frame, dtype=np.int32)
32 | timing_start = timing[: len(data), 0]
33 | if timing.shape[1] == 1:
34 | assert len(timing) == len(data) or len(timing) == len(data) + 1
35 | if len(timing) == len(data) + 1:
36 | timing_end = timing[1:, 0]
37 | else:
38 | timing_end = np.append(
39 | timing[1:, 0],
40 | timing[-1, 0] * 2 - timing[-2, 0] if (len(timing) > 1) else 1.0,
41 | )
42 | else:
43 | timing_end = timing[:, 1]
44 | for i in range(len(data)):
45 | frame_start = max(0, int(np.round(timing_start[i] * sr / win_shift)))
46 | frame_end = max(0, int(np.round(timing_end[i] * sr / win_shift)))
47 | data_indices[frame_start:frame_end] = i
48 | if data.shape[1] >= 1:
49 | f = open(os.path.join(PACKAGE_PATH, "data/spectrogram_template.svl"), "r")
50 | content = f.read()
51 | f.close()
52 | content = content.replace("[__SR__]", str(sr))
53 | content = content.replace("[__WIN_SHIFT__]", str(win_shift))
54 | content = content.replace("[__SHAPE_1__]", str(data.shape[1]))
55 | content = content.replace("[__COLOR__]", str(1))
56 | if labels is None:
57 | labels = [str(i) for i in range(data.shape[1])]
58 | assert len(labels) == len(data[0])
59 | result = (
60 | "\n".join(
61 | [
62 | '<bin number="%d" name="%s"/>' % (i, str(labels[i]))
63 | for i in range(len(labels))
64 | ]
65 | )
66 | + "\n"
67 | )
68 | for i in range(n_frame):
69 | if data_indices[i] >= 0:
70 | result += '<row n="%d">%s</row>\n' % (
71 | i,
72 | " ".join([str(s) for s in data[data_indices[i]]]),
73 | )
74 | content = content.replace("[__DATA__]", result)
75 | else:
76 | f = open(os.path.join(PACKAGE_PATH, "data/curve_template.svl"), "r")
77 | content = f.read()
78 | f.close()
79 | content = content.replace("[__SR__]", str(sr))
80 | content = content.replace("[__STYLE__]", str(1))
81 | results = []
82 | raise NotImplementedError()
83 | # for i in range(0, len(data)):
84 | # results.append('<point frame="%d" value="%s"/>'%(int(override_sr/sr*i*win_shift),data[i,0]))
85 | # content = content.replace('[__DATA__]','\n'.join(results))
86 | # content = content.replace('[__NAME__]', 'curve')
87 |
88 | f = open(filename, "w")
89 | f.write(content)
90 | f.close()
91 |
92 | def get_visualize_extention_name(self):
93 | return "svl"
94 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Polyffusion: A Diffusion Model for Polyphonic Score Generation with Internal and External Controls
2 |
3 | - [Paper link](https://arxiv.org/abs/2307.10304)
4 | - Check our [demo page](https://polyffusion.github.io/) and give it a listen!
5 |
6 | ```
7 | @inproceedings{polyffusion2023,
8 | author = {Lejun Min and Junyan Jiang and Gus Xia and Jingwei Zhao},
9 | title = {Polyffusion: A Diffusion Model for Polyphonic Score Generation with Internal and External Controls},
10 | booktitle = {Proceedings of the 24th International Society for Music Information Retrieval Conference, {ISMIR}},
11 | year = {2023}
12 | }
13 | ```
14 |
15 | ## Installation
16 |
17 | ```shell
18 | pip install -r requirements.txt
19 | pip install -e polyffusion
20 | pip install -e polyffusion/chord_extractor
21 | pip install -e polyffusion/mir_eval
22 | ```
23 |
24 | ## Some Clarifications
25 |
26 | - The abbreviation "sdf" means Stable Diffusion, and "ldm" means Latent Diffusion; here they refer to the same thing. However, we only borrow the cross-attention conditioning mechanism from Latent Diffusion, without using its encoder and decoder; the latter is left for future experiments.
27 | - `prmat2c` in the code is the piano-roll image representation (a rough sketch of the idea follows below).
28 |
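For intuition only, here is a minimal, hypothetical sketch of what such a piano-roll "image" could look like. The two-channel onset/sustain layout, the shapes, and the helper name below are illustrative assumptions, not the repository's actual `prmat2c` code.

```python
import numpy as np

def notes_to_piano_roll_image(notes, n_steps=128, n_pitches=128):
    """Hypothetical helper: render (onset, pitch, duration) triples into a
    two-channel piano-roll image (channel 0 = onsets, channel 1 = sustain)."""
    roll = np.zeros((2, n_steps, n_pitches), dtype=np.float32)
    for onset, pitch, duration in notes:
        if 0 <= onset < n_steps and 0 <= pitch < n_pitches:
            roll[0, onset, pitch] = 1.0  # mark the attack frame
            roll[1, onset : min(onset + duration, n_steps), pitch] = 1.0  # held frames
    return roll

# e.g. a C major triad held for four steps starting at step 0
example = notes_to_piano_roll_image([(0, 60, 4), (0, 64, 4), (0, 67, 4)])
print(example.shape)  # (2, 128, 128)
```
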
29 | ## Preparations
30 |
31 | - The extracted features of the dataset POP909 can be accessed [here](https://yukisaki-my.sharepoint.com/:u:/g/personal/aik2_yukisaki_io/EdUovlRZvExJrGatAR8BlTsBDC8udJiuhnIimPuD2PQ3FQ?e=WwD7Dl). Please put it under `/data/` after extraction.
32 |
33 | - The needed pre-trained models for training can be accessed [here](https://yukisaki-my.sharepoint.com/:u:/g/personal/aik2_yukisaki_io/Eca406YwV1tMgwHdoepC7G8B5l-4GRBGv7TzrI9OOg3eIA?e=uecJdU). Please put them under `/pretrained/` after extraction.
34 |
35 | ## Training
36 |
37 | ### Modifications
38 |
39 | - You can modify the parameters in the corresponding `*.yaml` files under `/polyffusion/params/`, or create your own. A minimal sketch of inspecting and tweaking one is shown below.
40 |
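As a quick aid, one way to inspect or tweak a configuration before training is to load it with OmegaConf, the same library the training code uses to read these files. The file path and the overridden keys below are placeholders; adjust them to whatever the chosen yaml actually defines.

```python
from omegaconf import OmegaConf

# Load an existing configuration (path shown for illustration).
params = OmegaConf.load("polyffusion/params/sdf_chd8bar.yaml")
print(OmegaConf.to_yaml(params))  # list every field the trainer will read

# Override a couple of fields (assumed to exist in the loaded file)
# and save a custom copy next to the shipped configurations.
params.batch_size = 16
params.learning_rate = 5e-5
OmegaConf.save(params, "polyffusion/params/my_custom.yaml")
```

Following the `[model].yaml` convention described under Commands, such a custom file should then be selectable with `--model my_custom`.
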
41 | ### Commands
42 |
43 | ```shell
44 | python polyffusion/main.py --model [model] --output_dir [output_dir]
45 | ```
46 |
47 | The model names correspond to the `/polyffusion/params/[model].yaml` files. Here are some examples:
48 |
49 | - `sdf_chd8bar`: conditioned on latent chord representations encoded by a pre-trained chord encoder.
50 | - `sdf_txt`: conditioned on latent texture representations encoded by a pre-trained texture encoder.
51 | - `sdf_chdvnl`: conditioned on vanilla chord representations.
52 | - `sdf_txtvnl`: conditioned on vanilla texture representations.
53 | - `ddpm`: vanilla diffusion model from DDPM without conditioning.
54 |
55 | Examples:
56 |
57 | ```shell
58 | python polyffusion/main.py --model sdf_chd8bar --output_dir result/sdf_chd8bar
59 | ```
60 |
61 | ### Trained Checkpoints
62 |
63 | If you'd like to test our trained checkpoints, please access the folder [here](https://yukisaki-my.sharepoint.com/:f:/g/personal/aik2_yukisaki_io/EjG0IB8Xb_1CoVfYCmNUB-ABMLVSRqJST4VTrYJxjJFdnw?e=OqmZpp). We suggest putting them under `/result/` after extraction for inference.
64 |
65 | ## Inference
66 |
67 | Please see the help messages by running
68 |
69 | ```shell
70 | python polyffusion/inference_sdf.py --help
71 | ```
72 |
73 | Examples:
74 |
75 | ```shell
76 | # unconditional generation of length 10x8 bars
77 | python polyffusion/inference_sdf.py --chkpt_path=/path/to/checkpoint --uncond_scale=0. --length=10
78 |
79 | # conditional generation using DDIM sampler (default guidance scale = 1)
80 | python polyffusion/inference_sdf.py --chkpt_path=/path/to/checkpoint --ddim --ddim_steps=50 --ddim_eta=0.0 --ddim_discretize=uniform
81 |
82 | # conditional generation with guidance scale = 5, conditional chord progressions chosen from a song from POP909 validation set.
83 | python polyffusion/inference_sdf.py --chkpt_path=/path/to/checkpoint --uncond_scale=5.
84 |
85 | # conditional iterative inpainting (i.e. autoregressive generation) (default guidance scale = 1)
86 | python polyffusion/inference_sdf.py --chkpt_path=/path/to/checkpoint --autoreg
87 |
88 | # unconditional melody generation given accompaniment
89 | python polyffusion/inference_sdf.py --chkpt_path=/path/to/checkpoint --uncond_scale=0. --inpaint_from_midi=/path/to/accompaniment.mid --inpaint_type=above
90 |
91 | # accompaniment generation given melody, conditioned on chord progressions of another midi file (default guidance scale = 1)
92 | python polyffusion/inference_sdf.py --chkpt_path=/path/to/checkpoint --inpaint_from_midi=/path/to/melody.mid --inpaint_type=below --from_midi=/path/to/cond_midi.mid
93 | ```
94 |
--------------------------------------------------------------------------------
/polyffusion/train/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from datetime import datetime
3 | from pathlib import Path
4 | from shutil import copy2
5 |
6 | import torch
7 | from lightning.pytorch import Trainer
8 | from lightning.pytorch.callbacks import ModelCheckpoint
9 | from lightning.pytorch.loggers import WandbLogger
10 | from omegaconf import OmegaConf
11 | from torch.optim import Optimizer
12 | from torch.utils.data import DataLoader
13 |
14 | from lightning_learner import LightningLearner
15 | from utils import convert_json_to_yaml
16 |
17 |
18 | class TrainConfig:
19 | model: torch.nn.Module
20 | train_dl: DataLoader
21 | val_dl: DataLoader
22 | optimizer: Optimizer
23 |
24 | def __init__(self, params, param_scheduler, output_dir) -> None:
25 | self.model_name = params.model_name
26 | self.params = params
27 | self.param_scheduler = param_scheduler
28 |
29 | self.resume = False
30 | if os.path.exists(f"{output_dir}/chkpts/last.ckpt"):
31 | print("Checkpoint already exists.")
32 | if input("Resume training? (y/n)") == "y":
33 | self.resume = True
34 | else:
35 | print("Aborting...")
36 | exit(0)
37 | else:
38 | output_dir = f"{output_dir}/{datetime.now().strftime('%y-%m-%d_%H%M%S')}"
39 | print(f"Creating new log folder as {output_dir}")
40 | os.makedirs(output_dir, exist_ok=True)
41 |
42 | self.output_dir = output_dir
43 | self.log_dir = f"{output_dir}/logs"
44 | self.checkpoint_dir = f"{output_dir}/chkpts"
45 | self.time_stamp = Path(output_dir).name
46 |
47 | os.makedirs(self.log_dir, exist_ok=True)
48 | os.makedirs(self.checkpoint_dir, exist_ok=True)
49 |
50 | # json to yaml (compatibility)
51 | if os.path.exists(f"{output_dir}/params.json"):
52 | convert_json_to_yaml(f"{output_dir}/params.json")
53 |
54 | if os.path.exists(f"{output_dir}/params.yaml"):
55 | old_params = OmegaConf.load(f"{output_dir}/params.yaml")
56 |
57 | # The "weights" attribute is a tuple in AttrDict, but saved as a list. To compare these two, we make them both tuples:
58 | # if "weights" in old_params:
59 | # old_params["weights"] = tuple(old_params["weights"])
60 |
61 | if old_params != self.params:
62 | print("New params differ, using new params could break things.")
63 | if (
64 | input(
65 | "Do you want to keep the old params file (y/n)? The model will be trained on new params regardless."
66 | )
67 | == "y"
68 | ):
69 | time_stamp = datetime.now().strftime("%y-%m-%d_%H%M%S")
70 | copy2(
71 | f"{output_dir}/params.yaml",
72 | f"{output_dir}/old_params_{time_stamp}.yaml",
73 | )
74 | print(f"Old params saved as old_params_{time_stamp}.yaml")
75 | # save params
76 | OmegaConf.save(self.params, f"{output_dir}/params.yaml")
77 |
78 | def train(self):
79 | total_parameters = sum(
80 | p.numel() for p in self.model.parameters() if p.requires_grad
81 | )
82 | print(f"Total parameters: {total_parameters}")
83 | print(OmegaConf.to_yaml(self.params))
84 |
85 | checkpoint_callback = ModelCheckpoint(
86 | dirpath=self.checkpoint_dir,
87 | monitor="val/loss",
88 | filename="epoch{epoch}-val_loss{val/loss:.6f}",
89 | save_last=True,
90 | save_top_k=3,
91 | auto_insert_metric_name=False,
92 | )
93 | logger = WandbLogger(
94 | project=f"Polyff-{self.model_name}",
95 | save_dir=self.log_dir,
96 | name=self.time_stamp,
97 | )
98 | trainer = Trainer(
99 | default_root_dir=self.output_dir,
100 | callbacks=[checkpoint_callback],
101 | max_epochs=self.params.max_epoch,
102 | logger=logger,
103 | precision="16-mixed" if self.params.fp16 else "32-true",
104 | )
105 | learner = LightningLearner(
106 | self.model,
107 | self.optimizer,
108 | self.params,
109 | self.param_scheduler,
110 | )
111 | trainer.fit(
112 | learner,
113 | self.train_dl,
114 | self.val_dl,
115 | ckpt_path="last" if self.resume else None,
116 | )
117 |
--------------------------------------------------------------------------------
/polyffusion/dl_modules/pianotree_enc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.distributions import Normal
4 | from torch.nn.utils.rnn import pack_padded_sequence
5 |
6 |
7 | class PianoTreeEncoder(nn.Module):
8 | def __init__(
9 | self,
10 | max_simu_note=20,
11 | max_pitch=127,
12 | min_pitch=0,
13 | pitch_sos=128,
14 | pitch_eos=129,
15 | pitch_pad=130,
16 | dur_pad=2,
17 | dur_width=5,
18 | num_step=32,
19 | note_emb_size=128,
20 | enc_notes_hid_size=256,
21 | enc_time_hid_size=512,
22 | z_size=512,
23 | ):
24 | super(PianoTreeEncoder, self).__init__()
25 |
26 | # Parameters
27 | # note and time
28 | self.max_pitch = max_pitch # the highest pitch in train/val set.
29 | self.min_pitch = min_pitch # the lowest pitch in train/val set.
30 | self.pitch_sos = pitch_sos
31 | self.pitch_eos = pitch_eos
32 | self.pitch_pad = pitch_pad
33 | self.pitch_range = max_pitch - min_pitch + 3 # not including pad.
34 | self.dur_pad = dur_pad
35 | self.dur_width = dur_width
36 | self.note_size = self.pitch_range + dur_width
37 | self.max_simu_note = max_simu_note # the max # of notes at each ts.
38 | self.num_step = num_step # 32
39 | self.note_emb_size = note_emb_size
40 | self.z_size = z_size
41 | self.enc_notes_hid_size = enc_notes_hid_size
42 | self.enc_time_hid_size = enc_time_hid_size
43 |
44 | self.note_embedding = nn.Linear(self.note_size, note_emb_size)
45 | self.enc_notes_gru = nn.GRU(
46 | note_emb_size,
47 | enc_notes_hid_size,
48 | num_layers=1,
49 | batch_first=True,
50 | bidirectional=True,
51 | )
52 | self.enc_time_gru = nn.GRU(
53 | 2 * enc_notes_hid_size,
54 | enc_time_hid_size,
55 | num_layers=1,
56 | batch_first=True,
57 | bidirectional=True,
58 | )
59 | self.linear_mu = nn.Linear(2 * enc_time_hid_size, z_size)
60 | self.linear_std = nn.Linear(2 * enc_time_hid_size, z_size)
61 |
62 | @property
63 | def device(self):
64 | """
65 | ### Get model device
66 | """
67 | return next(iter(self.parameters())).device
68 |
69 | def get_len_index_tensor(self, ind_x):
70 | """Calculate the lengths ((B, 32), torch.LongTensor) of pgrid."""
71 | with torch.no_grad():
72 | lengths = self.max_simu_note - (
73 | ind_x[:, :, :, 0] - self.pitch_pad == 0
74 | ).sum(dim=-1)
75 | return lengths.to("cpu")
76 |
77 | def index_tensor_to_multihot_tensor(self, ind_x):
78 | """Transfer piano_grid to multi-hot piano_grid."""
79 | # ind_x: (B, 32, max_simu_note, 1 + dur_width)
80 | with torch.no_grad():
81 | dur_part = ind_x[:, :, :, 1:].float()
82 | out = torch.zeros(
83 | [
84 | ind_x.size(0) * self.num_step * self.max_simu_note,
85 | self.pitch_range + 1,
86 | ],
87 | dtype=torch.float,
88 | ).to(self.device)
89 |
90 | out[range(0, out.size(0)), ind_x[:, :, :, 0].reshape(-1)] = 1.0
91 | out = out.view(-1, 32, self.max_simu_note, self.pitch_range + 1)
92 | out = torch.cat([out[:, :, :, 0 : self.pitch_range], dur_part], dim=-1)
93 | return out
94 |
95 | def encoder(self, x, lengths):
96 | embedded = self.note_embedding(x)
97 | # x: (B, num_step, max_simu_note, note_emb_size)
98 | # now x are notes
99 | x = embedded.view(-1, self.max_simu_note, self.note_emb_size)
100 | x = pack_padded_sequence(
101 | x, lengths.view(-1), batch_first=True, enforce_sorted=False
102 | )
103 | x = self.enc_notes_gru(x)[-1].transpose(0, 1).contiguous()
104 | x = x.view(-1, self.num_step, 2 * self.enc_notes_hid_size)
105 | # now, x is simu_notes.
106 | x = self.enc_time_gru(x)[-1].transpose(0, 1).contiguous()
107 | # x: (B, 2, enc_time_hid_size)
108 | x = x.view(x.size(0), -1)
109 | mu = self.linear_mu(x) # (B, z_size)
110 | std = self.linear_std(x).exp_() # (B, z_size)
111 | dist = Normal(mu, std)
112 | return dist, embedded
113 |
114 | def forward(self, x, return_iterators=False):
115 | lengths = self.get_len_index_tensor(x)
116 | x = self.index_tensor_to_multihot_tensor(x)
117 | dist, embedded_x = self.encoder(x, lengths)
118 | if return_iterators:
119 | return dist.mean, dist.scale, embedded_x
120 | else:
121 | return dist, embedded_x, lengths
122 |
--------------------------------------------------------------------------------
/polyffusion/mir_eval/onset.py:
--------------------------------------------------------------------------------
1 | '''
2 | The goal of an onset detection algorithm is to automatically determine when
3 | notes are played in a piece of music. The primary method used to evaluate
4 | onset detectors is to first determine which estimated onsets are "correct",
5 | where correctness is defined as being within a small window of a reference
6 | onset.
7 |
8 | Based in part on this script:
9 |
10 | https://github.com/CPJKU/onset_detection/blob/master/onset_evaluation.py
11 |
12 | Conventions
13 | -----------
14 |
15 | Onsets should be provided in the form of a 1-dimensional array of onset
16 | times in seconds in increasing order.
17 |
18 | Metrics
19 | -------
20 |
21 | * :func:`mir_eval.onset.f_measure`: Precision, Recall, and F-measure scores
22 | based on the number of estimated onsets which are sufficiently close to
23 | reference onsets.
24 | '''
25 |
26 | import collections
27 | from . import util
28 | import warnings
29 |
30 |
31 | # The maximum allowable beat time
32 | MAX_TIME = 30000.
33 |
34 |
35 | def validate(reference_onsets, estimated_onsets):
36 | """Checks that the input annotations to a metric look like valid onset time
37 | arrays, and throws helpful errors if not.
38 |
39 | Parameters
40 | ----------
41 | reference_onsets : np.ndarray
42 | reference onset locations, in seconds
43 | estimated_onsets : np.ndarray
44 | estimated onset locations, in seconds
45 |
46 | """
47 | # If reference or estimated onsets are empty, warn because metric will be 0
48 | if reference_onsets.size == 0:
49 | warnings.warn("Reference onsets are empty.")
50 | if estimated_onsets.size == 0:
51 | warnings.warn("Estimated onsets are empty.")
52 | for onsets in [reference_onsets, estimated_onsets]:
53 | util.validate_events(onsets, MAX_TIME)
54 |
55 |
56 | def f_measure(reference_onsets, estimated_onsets, window=.05):
57 | """Compute the F-measure of correct vs incorrectly predicted onsets.
58 | "Corectness" is determined over a small window.
59 |
60 | Examples
61 | --------
62 | >>> reference_onsets = mir_eval.io.load_events('reference.txt')
63 | >>> estimated_onsets = mir_eval.io.load_events('estimated.txt')
64 | >>> F, P, R = mir_eval.onset.f_measure(reference_onsets,
65 | ... estimated_onsets)
66 |
67 | Parameters
68 | ----------
69 | reference_onsets : np.ndarray
70 | reference onset locations, in seconds
71 | estimated_onsets : np.ndarray
72 | estimated onset locations, in seconds
73 | window : float
74 | Window size, in seconds
75 | (Default value = .05)
76 |
77 | Returns
78 | -------
79 | f_measure : float
80 | 2*precision*recall/(precision + recall)
81 | precision : float
82 | (# true positives)/(# true positives + # false positives)
83 | recall : float
84 | (# true positives)/(# true positives + # false negatives)
85 |
86 | """
87 | validate(reference_onsets, estimated_onsets)
88 | # If either list is empty, return 0s
89 | if reference_onsets.size == 0 or estimated_onsets.size == 0:
90 | return 0., 0., 0.
91 | # Compute the best-case matching between reference and estimated onset
92 | # locations
93 | matching = util.match_events(reference_onsets, estimated_onsets, window)
94 |
95 | precision = float(len(matching))/len(estimated_onsets)
96 | recall = float(len(matching))/len(reference_onsets)
97 | # Compute F-measure and return all statistics
98 | return util.f_measure(precision, recall), precision, recall
99 |
100 |
101 | def evaluate(reference_onsets, estimated_onsets, **kwargs):
102 | """Compute all metrics for the given reference and estimated annotations.
103 |
104 | Examples
105 | --------
106 | >>> reference_onsets = mir_eval.io.load_events('reference.txt')
107 | >>> estimated_onsets = mir_eval.io.load_events('estimated.txt')
108 | >>> scores = mir_eval.onset.evaluate(reference_onsets,
109 | ... estimated_onsets)
110 |
111 | Parameters
112 | ----------
113 | reference_onsets : np.ndarray
114 | reference onset locations, in seconds
115 | estimated_onsets : np.ndarray
116 | estimated onset locations, in seconds
117 | kwargs
118 | Additional keyword arguments which will be passed to the
119 | appropriate metric or preprocessing functions.
120 |
121 | Returns
122 | -------
123 | scores : dict
124 | Dictionary of scores, where the key is the metric name (str) and
125 | the value is the (float) score achieved.
126 |
127 | """
128 | # Compute all metrics
129 | scores = collections.OrderedDict()
130 |
131 | (scores['F-measure'],
132 | scores['Precision'],
133 | scores['Recall']) = util.filter_kwargs(f_measure, reference_onsets,
134 | estimated_onsets, **kwargs)
135 |
136 | return scores
137 |
--------------------------------------------------------------------------------
/polyffusion/data/polydis_format_to_mine.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os.path import join
3 |
4 | import numpy as np
5 | import pretty_midi as pm
6 | from tqdm import tqdm
7 |
8 | ORIGIN_DIR = "data/POP09-PIANOROLL-4-bin-quantization"
9 | NEW_DIR = "data/POP909_4_bin_pnt_8bar"
10 |
11 | ONE_BEAT_TIME = 0.5
12 | SEG_LGTH = 32
13 | BEAT = 4
14 | BIN = 4
15 | SEG_LGTH_BIN = SEG_LGTH * BIN
16 |
17 |
18 | def get_note_matrix(mats):
19 | """
20 | (onset_beat, onset_bin, bin, offset_beat, offset_bin, bin, pitch, velocity)
21 | """
22 | notes = []
23 |
24 | for mat in mats:
25 | assert mat[2] == mat[5] == BIN
26 | onset = mat[0] * BIN + mat[1]
27 | offset = mat[3] * BIN + mat[4]
28 | duration = offset - onset
29 | if duration > 0:
30 | # this is compulsory because there may be notes
31 | # with zero duration after adjusting resolution
32 | notes.append([onset, mat[6], duration, mat[7], 0])
33 | # sort according to (start, duration)
34 | # notes.sort(key=lambda x: (x[0] * BIN + x[1], x[2]))
35 | notes.sort(key=lambda x: (x[0], x[1], x[2]))
36 | return notes
37 |
38 |
39 | def dedup_note_matrix(notes):
40 | """
41 | remove duplicated notes (because of multiple tracks)
42 | """
43 |
44 | last = []
45 | notes_dedup = []
46 | for i, note in enumerate(notes):
47 | if i != 0:
48 | if note[:2] != last[:2]:
49 | # if start and pitch are not the same
50 | notes_dedup.append(note)
51 | else:
52 | notes_dedup.append(note)
53 | last = note
54 | # print(f"dedup: {len(notes) - len(notes_dedup)} : {len(notes)}")
55 |
56 | return notes_dedup
57 |
58 |
59 | def retrieve_midi_from_nmat(notes, output_fpath):
60 | """
61 | retrieve midi from note matrix
62 | """
63 | midi = pm.PrettyMIDI()
64 | piano_program = pm.instrument_name_to_program("Acoustic Grand Piano")
65 | piano = pm.Instrument(program=piano_program)
66 | for note in notes:
67 | onset, pitch, duration, velocity, program = note
68 | start = onset * ONE_BEAT_TIME / float(BIN)
69 | end = start + duration * ONE_BEAT_TIME / float(BIN)
70 | pm_note = pm.Note(velocity, pitch, start, end)
71 | piano.notes.append(pm_note)
72 |
73 | midi.instruments.append(piano)
74 | midi.write(output_fpath)
75 |
76 |
77 | def get_downbeat_pos_and_filter(notes, beats):
78 | """
79 | beats: [0-1, 2/3beat-cnt, 2, 0-3, 4/6beat-cnt, 4]
80 | """
81 | end_time = notes[-1][0]
82 | db_pos = []
83 | for i, beat in enumerate(beats):
84 | if beat[3] == 0:
85 | pos = i * BIN
86 | db_pos.append(pos)
87 |
88 | # print(db_pos)
89 | db_pos_filter = []
90 | for idx, db in enumerate(db_pos):
91 | if (
92 | idx + (SEG_LGTH / BEAT) <= len(db_pos)
93 | and db_pos[idx + 1] - db == BEAT * BIN
94 | ):
95 | db_pos_filter.append(True)
96 | else:
97 | db_pos_filter.append(False)
98 | # print(db_pos_filter)
99 | return db_pos, db_pos_filter
100 |
101 |
102 | def get_start_table(notes, db_pos):
103 | """
104 | i-th row indicates the starting row of the "notes" array at i-th beat.
105 | """
106 |
107 | # simply add 8-beat padding in case of out-of-range index
108 | # total_beat = int(music.get_end_time()) + 8
109 | row_cnt = 0
110 | start_table = {}
111 | # for beat in range(total_beat):
112 | for db in db_pos:
113 | while row_cnt < len(notes) and notes[row_cnt][0] < db:
114 | row_cnt += 1
115 | start_table[db] = row_cnt
116 |
117 | return start_table
118 |
119 |
120 | def cat_note_mats(note_mats):
121 | return np.concatenate(note_mats, 0)
122 |
123 |
124 | if __name__ == "__main__":
125 | if os.path.exists(NEW_DIR):
126 | os.system(f"rm -rf {NEW_DIR}")
127 | os.makedirs(NEW_DIR)
128 |
129 | for piece in tqdm(os.listdir(ORIGIN_DIR)):
130 | fpath = os.path.join(ORIGIN_DIR, piece)
131 | f = np.load(fpath)
132 | melody = get_note_matrix(f["melody"])
133 | bridge = get_note_matrix(f["bridge"])
134 | piano = get_note_matrix(f["piano"])
135 | beats = f["beat"]
136 | notes = cat_note_mats([melody, bridge, piano])
137 |
138 | retrieve_midi_from_nmat(
139 | notes, os.path.join(NEW_DIR, piece[:-4] + "_flatten.mid")
140 | )
141 | db_pos, db_pos_filter = get_downbeat_pos_and_filter(notes, beats)
142 | start_table_melody = get_start_table(melody, db_pos)
143 | start_table_bridge = get_start_table(bridge, db_pos)
144 | start_table_piano = get_start_table(piano, db_pos)
145 | np.savez(
146 | join(NEW_DIR, piece[:-4]),
147 | notes=[melody, bridge, piano],
148 | start_table=[start_table_melody, start_table_bridge, start_table_piano],
149 | db_pos=db_pos,
150 | db_pos_filter=db_pos_filter,
151 | chord=f["chord"],
152 | )
153 |
--------------------------------------------------------------------------------
/polyffusion/stable_diffusion/util.py:
--------------------------------------------------------------------------------
1 | """
2 | ---
3 | title: Utility functions for stable diffusion
4 | summary: >
5 | Utility functions for stable diffusion
6 | ---
7 |
8 | # Utility functions for [stable diffusion](index.html)
9 | """
10 |
11 | import os
12 | import random
13 | from pathlib import Path
14 |
15 | import numpy as np
16 | import PIL
17 | import torch
18 | from labml import monit
19 | from labml.logger import inspect
20 | from PIL import Image
21 |
22 | from .latent_diffusion import LatentDiffusion
23 | from .model.autoencoder import Autoencoder, Decoder, Encoder
24 |
25 | # from model.clip_embedder import CLIPTextEmbedder
26 | from .model.unet import UNetModel
27 |
28 |
29 | def set_seed(seed: int):
30 | """
31 | ### Set random seeds
32 | """
33 | random.seed(seed)
34 | np.random.seed(seed)
35 | torch.manual_seed(seed)
36 | torch.cuda.manual_seed_all(seed)
37 |
38 |
39 | def load_model(path: Path = None) -> LatentDiffusion:
40 | """
41 | ### Load [`LatentDiffusion` model](latent_diffusion.html)
42 | """
43 |
44 | # Initialize the autoencoder
45 | with monit.section("Initialize autoencoder"):
46 | encoder = Encoder(
47 | z_channels=4,
48 | in_channels=3,
49 | channels=128,
50 | channel_multipliers=[1, 2, 4, 4],
51 | n_resnet_blocks=2,
52 | )
53 |
54 | decoder = Decoder(
55 | out_channels=3,
56 | z_channels=4,
57 | channels=128,
58 | channel_multipliers=[1, 2, 4, 4],
59 | n_resnet_blocks=2,
60 | )
61 |
62 | autoencoder = Autoencoder(
63 | emb_channels=4, encoder=encoder, decoder=decoder, z_channels=4
64 | )
65 |
66 | # Initialize the U-Net
67 | with monit.section("Initialize U-Net"):
68 | unet_model = UNetModel(
69 | in_channels=4,
70 | out_channels=4,
71 | channels=320,
72 | attention_levels=[0, 1, 2],
73 | n_res_blocks=2,
74 | channel_multipliers=[1, 2, 4, 4],
75 | n_heads=8,
76 | tf_layers=1,
77 | d_cond=768,
78 | )
79 |
80 | # Initialize the Latent Diffusion model
81 | with monit.section("Initialize Latent Diffusion model"):
82 | model = LatentDiffusion(
83 | linear_start=0.00085,
84 | linear_end=0.0120,
85 | n_steps=1000,
86 | latent_scaling_factor=0.18215,
87 | autoencoder=autoencoder,
88 | unet_model=unet_model,
89 | )
90 |
91 | # Load the checkpoint
92 | with monit.section(f"Loading model from {path}"):
93 | checkpoint = torch.load(path, map_location="cpu")
94 |
95 | # Set model state
96 | with monit.section("Load state"):
97 | missing_keys, extra_keys = model.load_state_dict(
98 | checkpoint["state_dict"], strict=False
99 | )
100 |
101 | # Debugging output
102 | inspect(
103 | global_step=checkpoint.get("global_step", -1),
104 | missing_keys=missing_keys,
105 | extra_keys=extra_keys,
106 | _expand=True,
107 | )
108 |
109 | #
110 | model.eval()
111 | return model
112 |
113 |
114 | def load_img(path: str):
115 | """
116 | ### Load an image
117 |
118 | This loads an image from a file and returns a PyTorch tensor.
119 |
120 | :param path: is the path of the image
121 | """
122 | # Open Image
123 | image = Image.open(path).convert("RGB")
124 | # Get image size
125 | w, h = image.size
126 | # Resize to a multiple of 32
127 | w = w - w % 32
128 | h = h - h % 32
129 | image = image.resize((w, h), resample=PIL.Image.LANCZOS)
130 | # Convert to numpy and map `[0, 255]` to `[-1, 1]`
131 | image = np.array(image).astype(np.float32) * (2.0 / 255.0) - 1
132 | # Transpose to shape `[batch_size, channels, height, width]`
133 | image = image[None].transpose(0, 3, 1, 2)
134 | # Convert to torch
135 | return torch.from_numpy(image)
136 |
137 |
138 | def save_images(
139 | images: torch.Tensor, dest_path: str, prefix: str = "", img_format: str = "jpeg"
140 | ):
141 | """
142 | ### Save images
143 |
144 | :param images: is the tensor with images of shape `[batch_size, channels, height, width]`
145 | :param dest_path: is the folder to save images in
146 | :param prefix: is the prefix to add to file names
147 | :param img_format: is the image format
148 | """
149 |
150 | # Create the destination folder
151 | os.makedirs(dest_path, exist_ok=True)
152 |
153 | # Map images to `[0, 1]` space and clip
154 | images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
155 | # Transpose to `[batch_size, height, width, channels]` and convert to numpy
156 | images = images.cpu().permute(0, 2, 3, 1).numpy()
157 |
158 | # Save images
159 | for i, img in enumerate(images):
160 | img = Image.fromarray((255.0 * img).astype(np.uint8))
161 | img.save(
162 | os.path.join(dest_path, f"{prefix}{i:05}.{img_format}"), format=img_format
163 | )
164 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/mir/extractors/extractor_base.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | from abc import ABC, abstractmethod
4 |
5 | from mir import io
6 | from mir.common import WORKING_PATH
7 |
8 |
9 | def pickle_read(filename):
10 | f = open(filename, "rb")
11 | obj = pickle.load(f)
12 | f.close()
13 | return obj
14 |
15 |
16 | def pickle_write(data, filename):
17 | f = open(filename, "wb")
18 | pickle.dump(data, f)
19 | f.close()
20 |
21 |
22 | def try_mkdir(filename):
23 | folder = os.path.dirname(filename)
24 | if not os.path.isdir(folder):
25 | os.makedirs(folder)
26 |
27 |
28 | class ExtractorBase(ABC):
29 | def require(self, *args):
30 | pass
31 |
32 | def get_feature_class(self):
33 | return io.UnknownIO
34 |
35 | @abstractmethod
36 | def extract(self, entry, **kwargs):
37 | pass
38 |
39 | def __create_cache_path(self, entry, cached_prop_record, input_kwargs):
40 | items = {}
41 | items_entry = {}
42 | for k in input_kwargs:
43 | items[k] = input_kwargs[k]
44 | for prop_name in cached_prop_record:
45 | if prop_name not in items:
46 | items_entry[prop_name] = entry.prop.get_unrecorded(prop_name)
47 |
48 | if len(items) == 0:
49 | folder_name = self.__class__.__name__
50 | else:
51 | folder_name = (
52 | self.__class__.__name__
53 | + "/"
54 | + ",".join([k + "=" + str(items[k]) for k in sorted(items.keys())])
55 | )
56 |
57 | if len(items_entry) == 0:
58 | entry_name = entry.name + ".cache"
59 | else:
60 | entry_name = (
61 | entry.name
62 | + "."
63 | + ",".join(
64 | [k + "=" + str(items_entry[k]) for k in sorted(items_entry.keys())]
65 | )
66 | + ".cache"
67 | )
68 |
69 | return os.path.join(WORKING_PATH, "cache_data", folder_name, entry_name)
70 |
71 | def extract_and_cache(self, entry, cache_enabled=True, **kwargs):
72 | folder_name = os.path.join(WORKING_PATH, "cache_data", self.__class__.__name__)
73 | prop_cache_filename = os.path.join(folder_name, "_prop_records.cache")
74 | if "cached_prop_record" in self.__dict__:
75 | cached_prop_record = self.__dict__["cached_prop_record"]
76 | else:
77 | if os.path.exists(prop_cache_filename):
78 | cached_prop_record = pickle_read(prop_cache_filename)
79 | else:
80 | cached_prop_record = None
81 |
82 | if (
83 | cache_enabled
84 | and entry.name != ""
85 | and self.get_feature_class() != io.UnknownIO
86 | ):
87 | # Need cache
88 | need_io_create = False
89 | if cached_prop_record is None:
90 | entry.prop.start_record_reading()
91 | feature = self.extract(entry, **kwargs)
92 | cached_prop_record = sorted(entry.prop.end_record_reading())
93 | try_mkdir(prop_cache_filename)
94 | pickle_write(cached_prop_record, prop_cache_filename)
95 | cache_file_name = self.__create_cache_path(
96 | entry, cached_prop_record, kwargs
97 | )
98 | need_io_create = True
99 | else:
100 | cache_file_name = self.__create_cache_path(
101 | entry, cached_prop_record, kwargs
102 | )
103 | if not os.path.isfile(cache_file_name):
104 | entry.prop.start_record_reading()
105 | feature = self.extract(entry, **kwargs)
106 | new_prop_record = sorted(entry.prop.end_record_reading())
107 | if cached_prop_record != new_prop_record:
108 | print(
109 | "[Warning] Inconsistent cached properity requirement in %s, overrode:"
110 | % self.__class__.__name__
111 | )
112 | print("Old:", cached_prop_record)
113 | print("New:", new_prop_record)
114 | cached_prop_record = new_prop_record
115 | pickle_write(cached_prop_record, prop_cache_filename)
116 | cache_file_name = self.__create_cache_path(
117 | entry, cached_prop_record, kwargs
118 | )
119 | need_io_create = True
120 | else:
121 | io_obj = self.get_feature_class()()
122 | feature = io_obj.safe_read(cache_file_name, entry)
123 | if need_io_create:
124 | io_obj = self.get_feature_class()()
125 | io_obj.create(feature, cache_file_name, entry)
126 | else:
127 | feature = self.extract(entry, **kwargs)
128 | return feature
129 |
--------------------------------------------------------------------------------
/polyffusion/chord_extractor/example.out:
--------------------------------------------------------------------------------
1 | 0.0 2.6666399999999997 N
2 | 2.6666399999999997 5.333279999999999 N
3 | 5.333279999999999 7.99992 C#:min
4 | 7.99992 10.66656 A:maj9
5 | 10.66656 13.333200000000001 A:maj9
6 | 13.333200000000001 15.999840000000003 D:maj6(9)
7 | 15.999840000000003 18.666480000000004 D:maj6
8 | 18.666480000000004 21.333120000000005 A:maj9
9 | 21.333120000000005 23.999760000000006 A:maj9
10 | 23.999760000000006 26.666400000000007 D:maj6(9)
11 | 26.666400000000007 29.333040000000008 D:maj6(9)
12 | 29.333040000000008 31.99968000000001 A:maj
13 | 31.99968000000001 34.666320000000006 A:maj
14 | 34.666320000000006 37.33296000000001 D:maj6
15 | 37.33296000000001 39.99960000000001 D:maj6
16 | 39.99960000000001 42.66624000000001 A:maj
17 | 42.66624000000001 45.33288000000001 A:maj
18 | 45.33288000000001 47.99952000000001 D:maj6
19 | 47.99952000000001 50.66616000000001 D:maj6
20 | 50.66616000000001 51.99948000000001 C#:min/b3
21 | 51.99948000000001 53.33280000000001 F#:min
22 | 53.33280000000001 54.666120000000014 D:maj/3
23 | 54.666120000000014 55.999440000000014 E:7
24 | 55.999440000000014 57.332760000000015 C#:min
25 | 57.332760000000015 58.666080000000015 F#:min
26 | 58.666080000000015 59.999400000000016 D:maj/3
27 | 59.999400000000016 61.332720000000016 E:7/3
28 | 61.332720000000016 62.66604000000002 C#:min7/5
29 | 62.66604000000002 63.99936000000002 F#:min(9)
30 | 63.99936000000002 65.99934 B:min7
31 | 65.99934 66.666 E:maj
32 | 66.666 69.33263999999997 A:maj7
33 | 69.33263999999997 71.99927999999994 A:maj7
34 | 71.99927999999994 74.66591999999991 D:maj6
35 | 74.66591999999991 77.33255999999989 D:maj6
36 | 77.33255999999989 79.99919999999986 A:maj7
37 | 79.99919999999986 82.66583999999983 A:maj7
38 | 82.66583999999983 85.3324799999998 D:maj6(9)
39 | 85.3324799999998 87.99911999999978 D:maj6(9)
40 | 87.99911999999978 90.66575999999975 A:maj7
41 | 90.66575999999975 93.33239999999972 A:maj7
42 | 93.33239999999972 95.9990399999997 D:maj6(9)
43 | 95.9990399999997 98.66567999999967 D:maj6(9)
44 | 98.66567999999967 99.99899999999965 C#:min/b3
45 | 99.99899999999965 101.33231999999964 F#:min
46 | 101.33231999999964 102.66563999999963 D:maj/3
47 | 102.66563999999963 103.99895999999961 E:7
48 | 103.99895999999961 105.3322799999996 C#:min
49 | 105.3322799999996 106.66559999999959 F#:min
50 | 106.66559999999959 107.99891999999957 D:maj/3
51 | 107.99891999999957 109.33223999999956 E:7/3
52 | 109.33223999999956 110.66555999999954 C#:min7/5
53 | 110.66555999999954 111.99887999999953 F#:min(9)
54 | 111.99887999999953 113.99885999999951 B:min7
55 | 113.99885999999951 117.33215999999948 A:maj7/5
56 | 117.33215999999948 119.99879999999945 A:maj
57 | 119.99879999999945 122.66543999999942 D:maj6
58 | 122.66543999999942 125.3320799999994 D:maj6
59 | 125.3320799999994 127.99871999999937 A:maj
60 | 127.99871999999937 130.6653599999994 A:maj
61 | 130.6653599999994 133.33199999999943 D:maj6
62 | 133.33199999999943 135.99863999999945 D:maj6
63 | 135.99863999999945 138.66527999999948 A:maj9
64 | 138.66527999999948 141.3319199999995 A:maj9
65 | 141.3319199999995 143.99855999999954 C:aug
66 | 143.99855999999954 146.66519999999957 F#:min9
67 | 146.66519999999957 148.6651799999996 B:min(9)
68 | 148.6651799999996 151.99847999999963 E:maj6
69 | 151.99847999999963 153.99845999999965 G:maj9
70 | 153.99845999999965 157.3317599999997 E:min/b3
71 | 157.3317599999997 159.99839999999972 E:min11
72 | 159.99839999999972 162.66503999999975 E:min11
73 | 162.66503999999975 165.33167999999978 A:maj7
74 | 165.33167999999978 167.9983199999998 A:maj7
75 | 167.9983199999998 170.66495999999984 D:maj6(9)
76 | 170.66495999999984 173.33159999999987 D:maj6(9)
77 | 173.33159999999987 175.9982399999999 A:maj7
78 | 175.9982399999999 178.66487999999993 A:maj7
79 | 178.66487999999993 181.33151999999995 D:maj6(9)
80 | 181.33151999999995 183.99815999999998 D:maj6(9)
81 | 183.99815999999998 185.33148 C#:min/b3
82 | 185.33148 186.6648 F#:min
83 | 186.6648 187.99812000000003 D:maj/3
84 | 187.99812000000003 189.33144000000004 E:7
85 | 189.33144000000004 190.66476000000006 C#:min
86 | 190.66476000000006 191.99808000000007 F#:min
87 | 191.99808000000007 193.3314000000001 D:maj/3
88 | 193.3314000000001 194.6647200000001 E:7/3
89 | 194.6647200000001 195.99804000000012 C#:min7/5
90 | 195.99804000000012 197.33136000000013 F#:min(9)
91 | 197.33136000000013 199.33134000000015 B:min7
92 | 199.33134000000015 199.99800000000016 E:maj
93 | 199.99800000000016 202.6646400000002 A:maj7
94 | 202.6646400000002 205.33128000000022 A:maj7
95 | 205.33128000000022 207.99792000000025 D:maj6(9)
96 | 207.99792000000025 210.66456000000028 D:maj6(9)
97 | 210.66456000000028 213.3312000000003 A:maj7
98 | 213.3312000000003 215.99784000000034 A:maj7
99 | 215.99784000000034 218.66448000000037 D:maj(9)
100 | 218.66448000000037 221.3311200000004 D:maj6(9)
101 | 221.3311200000004 223.99776000000043 A:maj7
102 | 223.99776000000043 226.66440000000046 A:maj7
103 | 226.66440000000046 229.33104000000048 D:maj6(9)
104 | 229.33104000000048 231.9976800000005 A:maj13
105 | 231.9976800000005 234.66432000000054 A:maj7
106 | 234.66432000000054 237.33096000000057 A:maj7
107 | 237.33096000000057 239.9976000000006 D:maj6(9)
108 | 239.9976000000006 242.66424000000063 D:maj6(9)
109 | 242.66424000000063 245.33088000000066 N
110 | 245.33088000000066 247.33086000000068 N
111 |
--------------------------------------------------------------------------------
/polyffusion/prepare_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | from argparse import ArgumentParser
3 |
4 | import muspy
5 | import numpy as np
6 | from tqdm import tqdm
7 |
8 | from data.midi_to_data import *
9 |
10 |
11 | def force_length(music: muspy.Music, bars=8):
12 | """Loops a MIDI file if it's under the specified number of bars, in place."""
13 | num_tracks_at_least_bars = sum(
14 | [
15 | 1 if (track.get_end_time() + 15) // 16 >= bars else 0
16 | for track in music.tracks
17 | ]
18 | )
19 | if num_tracks_at_least_bars > 0:
20 | return
21 | for track in music.tracks:
22 | timesteps = track.get_end_time()
23 | old_bars = (timesteps + 15) // 16
24 | div = bars // old_bars
25 | for i in range(1, div):
26 | tmp = track.deepcopy()
27 | tmp.adjust_time(lambda x: x + i * timesteps)
28 | track.notes.extend(tmp.notes)
29 |
30 |
31 | def get_note_matrix_melodies(music, ignore_non_melody=True):
32 | """Similar to get_note_matrix from data.midi_to_data, with an option to ignore non-melodies."""
33 | notes = []
34 | for inst in music.tracks:
35 | if ignore_non_melody and (inst.is_drum or inst.program >= 113):
36 | continue
37 | for note in inst.notes:
38 | onset = int(note.time)
39 | duration = int(note.duration)
40 | if duration > 0:
41 | notes.append(
42 | [
43 | onset,
44 | note.pitch,
45 | duration,
46 | note.velocity,
47 | inst.program,
48 | ]
49 | )
50 | notes.sort(key=lambda x: (x[0], x[1], x[2]))
51 | assert len(notes)  # guard against a MIDI that has only non-melody tracks
52 | return notes
53 |
54 |
55 | def prepare_npz(midi_dir, chords_dir, output_dir, force=False, ignore_non_melody=True):
56 | for dir in [chords_dir, output_dir]:
57 | if not os.path.exists(dir):
58 | os.makedirs(dir)
59 | ttl = 0
60 | success = 0
61 | downbeat_errors = 0
62 | chords_errors = 0
63 | for root, dirs, files in os.walk(midi_dir):
64 | for midi in tqdm(files, desc=f"Processing {root}"):
65 | ttl += 1
66 | fpath = os.path.join(root, midi)
67 | chdpath = os.path.join(chords_dir, os.path.splitext(midi)[0] + ".csv")
68 | music = muspy.read_midi(fpath)
69 | music.adjust_resolution(4)
70 | if len(music.time_signatures) == 0:
71 | music.time_signatures.append(muspy.TimeSignature(0, 4, 4))
72 | if force:
73 | force_length(music)
74 |
75 | try:
76 | note_mat = get_note_matrix_melodies(music, ignore_non_melody)
77 | note_mat = dedup_note_matrix(note_mat)
78 | extract_chords_from_midi_file(fpath, chdpath)
79 | chord = get_chord_matrix(chdpath)
80 | except Exception:
81 | chords_errors += 1
82 | continue
83 |
84 | try:
85 | db_pos, db_pos_filter = get_downbeat_pos_and_filter(music, fpath)
86 | except Exception:
87 | downbeat_errors += 1
88 | continue
89 | if db_pos is not None and sum(filter(lambda x: x, db_pos_filter)) != 0:
90 | start_table = get_start_table(note_mat, db_pos)
91 | processed_data = {
92 | "notes": np.array(note_mat),
93 | "start_table": np.array(start_table),
94 | "db_pos": np.array(db_pos),
95 | "db_pos_filter": np.array(db_pos_filter),
96 | "chord": np.array(chord),
97 | }
98 | np.savez(
99 | os.path.join(output_dir, midi),
100 | notes=processed_data["notes"],
101 | start_table=processed_data["start_table"],
102 | db_pos=processed_data["db_pos"],
103 | db_pos_filter=processed_data["db_pos_filter"],
104 | chord=processed_data["chord"],
105 | )
106 | success += 1
107 | else:
108 | downbeat_errors += 1
109 |
110 | print(
111 | f"""{ttl} tracks processed, {success} succeeded, {chords_errors} chords errors, {downbeat_errors} downbeat errors"""
112 | )
113 |
114 |
115 | if __name__ == "__main__":
116 | parser = ArgumentParser(
117 | description="prepare data from midi for a Polyffusion model"
118 | )
119 | parser.add_argument(
120 | "--midi_dir", type=str, help="directory of input midis to preparep"
121 | )
122 | parser.add_argument(
123 | "--chords_dir", type=str, help="directory to store extracted chords"
124 | )
125 | parser.add_argument(
126 | "--npz_dir", type=str, help="directory to store prepared data in npz"
127 | )
128 | parser.add_argument(
129 | "--force_length",
130 | action="store_true",
131 | help="to repeat shorter samples into the desired number of bars",
132 | )
133 | parser.add_argument(
134 | "--ignore_non_melody",
135 | action="store_false",
136 | help="whether ignore all non-melody instruments. default: true",
137 | )
138 | args = parser.parse_args()
139 | prepare_npz(
140 | args.midi_dir,
141 | args.chords_dir,
142 | args.npz_dir,
143 | args.force_length,
144 | args.ignore_non_melody,
145 | )
146 |
--------------------------------------------------------------------------------
/polyffusion/train/train_ldm.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from data.dataloader import get_custom_train_val_dataloaders, get_train_val_dataloaders
4 | from data.dataloader_musicalion import (
5 | get_train_val_dataloaders as get_train_val_dataloaders_musicalion,
6 | )
7 | from dirs import PT_CHD_8BAR_PATH, PT_PNOTREE_PATH, PT_POLYDIS_PATH
8 | from models.model_sdf import Polyffusion_SDF
9 | from stable_diffusion.latent_diffusion import LatentDiffusion
10 | from stable_diffusion.model.unet import UNetModel
11 | from utils import (
12 | load_pretrained_chd_enc_dec,
13 | load_pretrained_pnotree_enc_dec,
14 | load_pretrained_txt_enc,
15 | )
16 |
17 | # from stable_diffusion.model.autoencoder import Autoencoder, Encoder, Decoder
18 | from . import TrainConfig
19 |
20 |
21 | class LDM_TrainConfig(TrainConfig):
22 | def __init__(
23 | self,
24 | params,
25 | output_dir,
26 | use_autoencoder=False,
27 | use_musicalion=False,
28 | use_track=[0, 1, 2],
29 | data_dir=None,
30 | ) -> None:
31 | super().__init__(params, None, output_dir)
32 | self.autoencoder = None
33 |
34 | if use_autoencoder:
35 | # encoder = Encoder(
36 | # in_channels=2,
37 | # z_channels=4,
38 | # channels=64,
39 | # channel_multipliers=[1, 2, 4, 4],
40 | # n_resnet_blocks=2
41 | # )
42 |
43 | # decoder = Decoder(
44 | # out_channels=2,
45 | # z_channels=4,
46 | # channels=64,
47 | # channel_multipliers=[1, 2, 4, 4],
48 | # n_resnet_blocks=2
49 | # )
50 |
51 | # self.autoencoder = Autoencoder(
52 | # emb_channels=4, encoder=encoder, decoder=decoder, z_channels=4
53 | # )
54 | raise NotImplementedError
55 |
56 | self.unet_model = UNetModel(
57 | in_channels=params.in_channels,
58 | out_channels=params.out_channels,
59 | channels=params.channels,
60 | attention_levels=params.attention_levels,
61 | n_res_blocks=params.n_res_blocks,
62 | channel_multipliers=params.channel_multipliers,
63 | n_heads=params.n_heads,
64 | tf_layers=params.tf_layers,
65 | d_cond=params.d_cond,
66 | )
67 |
68 | self.ldm_model = LatentDiffusion(
69 | linear_start=params.linear_start,
70 | linear_end=params.linear_end,
71 | n_steps=params.n_steps,
72 | latent_scaling_factor=params.latent_scaling_factor,
73 | autoencoder=self.autoencoder,
74 | unet_model=self.unet_model,
75 | )
76 |
77 | self.pnotree_enc, self.pnotree_dec = None, None
78 | self.chord_enc, self.chord_dec = None, None
79 | self.txt_enc = None
80 | if params.cond_type == "pnotree":
81 | self.pnotree_enc, self.pnotree_dec = load_pretrained_pnotree_enc_dec(
82 | PT_PNOTREE_PATH, 20
83 | )
84 | if "chord" in params.cond_type:
85 | if params.use_enc:
86 | self.chord_enc, self.chord_dec = load_pretrained_chd_enc_dec(
87 | PT_CHD_8BAR_PATH,
88 | params.chd_input_dim,
89 | params.chd_z_input_dim,
90 | params.chd_hidden_dim,
91 | params.chd_z_dim,
92 | params.chd_n_step,
93 | )
94 | if "txt" in params.cond_type:
95 | if params.use_enc:
96 | self.txt_enc = load_pretrained_txt_enc(
97 | PT_POLYDIS_PATH,
98 | params.txt_emb_size,
99 | params.txt_hidden_dim,
100 | params.txt_z_dim,
101 | params.txt_num_channel,
102 | )
103 | self.model = Polyffusion_SDF(
104 | self.ldm_model,
105 | cond_type=params.cond_type,
106 | cond_mode=params.cond_mode,
107 | chord_enc=self.chord_enc,
108 | chord_dec=self.chord_dec,
109 | pnotree_enc=self.pnotree_enc,
110 | pnotree_dec=self.pnotree_dec,
111 | txt_enc=self.txt_enc,
112 | concat_blurry=params.concat_blurry
113 | if hasattr(params, "concat_blurry")
114 | else False,
115 | concat_ratio=params.concat_ratio
116 | if hasattr(params, "concat_ratio")
117 | else 1 / 8,
118 | )
119 | # Create dataloader
120 | if use_musicalion:
121 | self.train_dl, self.val_dl = get_train_val_dataloaders_musicalion(
122 | params.batch_size, params.num_workers, params.pin_memory
123 | )
124 | else:
125 | if data_dir is None:
126 | self.train_dl, self.val_dl = get_train_val_dataloaders(
127 | params.batch_size,
128 | params.num_workers,
129 | params.pin_memory,
130 | use_track=use_track,
131 | )
132 | else:
133 | self.train_dl, self.val_dl = get_custom_train_val_dataloaders(
134 | params.batch_size,
135 | data_dir,
136 | num_workers=params.num_workers,
137 | pin_memory=params.pin_memory,
138 | )
139 |
140 | # Create optimizer
141 | self.optimizer = torch.optim.Adam(
142 | self.model.parameters(), lr=params.learning_rate
143 | )
144 |
--------------------------------------------------------------------------------
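
A hedged sketch of wiring up LDM_TrainConfig, assuming an attribute-style params object loaded from one of the YAML files under params/ (e.g. via OmegaConf); the repo's actual loader in main.py may differ, and the output path is a placeholder:

# Sketch only: the params loading mechanism and output_dir are assumptions.
from omegaconf import OmegaConf

from train.train_ldm import LDM_TrainConfig

params = OmegaConf.load("params/sdf_chd8bar.yaml")
config = LDM_TrainConfig(params, output_dir="result/sdf_chd8bar")
# After construction the config exposes the pieces assembled above:
#   config.model     -> Polyffusion_SDF wrapping the LatentDiffusion / UNetModel pair
#   config.train_dl  -> training DataLoader
#   config.val_dl    -> validation DataLoader
#   config.optimizer -> Adam over config.model.parameters()
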
/polyffusion/stable_diffusion/losses/lpips.py:
--------------------------------------------------------------------------------
1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
2 |
3 | from collections import namedtuple
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torchvision import models
8 |
9 | from .util import get_ckpt_path
10 |
11 |
12 | class LPIPS(nn.Module):
13 | # Learned perceptual metric
14 | def __init__(self, use_dropout=True):
15 | super().__init__()
16 | self.scaling_layer = ScalingLayer()
17 |         self.chns = [64, 128, 256, 512, 512]  # vgg16 feature channel widths
18 | self.net = vgg16(pretrained=True, requires_grad=False)
19 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
20 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
21 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
22 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
23 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
24 | self.load_from_pretrained()
25 | for param in self.parameters():
26 | param.requires_grad = False
27 |
28 | def load_from_pretrained(self, name="vgg_lpips"):
29 | ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
30 | self.load_state_dict(
31 | torch.load(ckpt, map_location=torch.device("cpu")), strict=False
32 | )
33 | print("loaded pretrained LPIPS loss from {}".format(ckpt))
34 |
35 | @classmethod
36 | def from_pretrained(cls, name="vgg_lpips"):
37 | if name != "vgg_lpips":
38 | raise NotImplementedError
39 | model = cls()
40 | ckpt = get_ckpt_path(name)
41 | model.load_state_dict(
42 | torch.load(ckpt, map_location=torch.device("cpu")), strict=False
43 | )
44 | return model
45 |
46 | def forward(self, input, target):
47 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
48 | outs0, outs1 = self.net(in0_input), self.net(in1_input)
49 | feats0, feats1, diffs = {}, {}, {}
50 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
51 | for kk in range(len(self.chns)):
52 | feats0[kk], feats1[kk] = (
53 | normalize_tensor(outs0[kk]),
54 | normalize_tensor(outs1[kk]),
55 | )
56 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
57 |
58 | res = [
59 | spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
60 | for kk in range(len(self.chns))
61 | ]
62 | val = res[0]
63 | for l in range(1, len(self.chns)):
64 | val += res[l]
65 | return val
66 |
67 |
68 | class ScalingLayer(nn.Module):
69 | def __init__(self):
70 | super(ScalingLayer, self).__init__()
71 | self.register_buffer(
72 | "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
73 | )
74 | self.register_buffer(
75 | "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
76 | )
77 |
78 | def forward(self, inp):
79 | return (inp - self.shift) / self.scale
80 |
81 |
82 | class NetLinLayer(nn.Module):
83 | """A single linear layer which does a 1x1 conv"""
84 |
85 | def __init__(self, chn_in, chn_out=1, use_dropout=False):
86 | super(NetLinLayer, self).__init__()
87 | layers = (
88 | [
89 | nn.Dropout(),
90 | ]
91 | if (use_dropout)
92 | else []
93 | )
94 | layers += [
95 | nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
96 | ]
97 | self.model = nn.Sequential(*layers)
98 |
99 |
100 | class vgg16(torch.nn.Module):
101 | def __init__(self, requires_grad=False, pretrained=True):
102 | super(vgg16, self).__init__()
103 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
104 | self.slice1 = torch.nn.Sequential()
105 | self.slice2 = torch.nn.Sequential()
106 | self.slice3 = torch.nn.Sequential()
107 | self.slice4 = torch.nn.Sequential()
108 | self.slice5 = torch.nn.Sequential()
109 | self.N_slices = 5
110 | for x in range(4):
111 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
112 | for x in range(4, 9):
113 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
114 | for x in range(9, 16):
115 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
116 | for x in range(16, 23):
117 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
118 | for x in range(23, 30):
119 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
120 | if not requires_grad:
121 | for param in self.parameters():
122 | param.requires_grad = False
123 |
124 | def forward(self, X):
125 | h = self.slice1(X)
126 | h_relu1_2 = h
127 | h = self.slice2(h)
128 | h_relu2_2 = h
129 | h = self.slice3(h)
130 | h_relu3_3 = h
131 | h = self.slice4(h)
132 | h_relu4_3 = h
133 | h = self.slice5(h)
134 | h_relu5_3 = h
135 | vgg_outputs = namedtuple(
136 | "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
137 | )
138 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
139 | return out
140 |
141 |
142 | def normalize_tensor(x, eps=1e-10):
143 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
144 | return x / (norm_factor + eps)
145 |
146 |
147 | def spatial_average(x, keepdim=True):
148 | return x.mean([2, 3], keepdim=keepdim)
149 |
--------------------------------------------------------------------------------
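
A small usage sketch for the LPIPS metric above, assuming the package is importable as stable_diffusion.losses.lpips and that get_ckpt_path can fetch the vgg_lpips checkpoint; inputs are 3-channel images scaled to roughly [-1, 1]:

import torch

from stable_diffusion.losses.lpips import LPIPS  # assumed import path

lpips = LPIPS().eval()                # loads the pretrained vgg_lpips weights
x = torch.rand(1, 3, 64, 64) * 2 - 1  # toy images in [-1, 1]
y = torch.rand(1, 3, 64, 64) * 2 - 1
with torch.no_grad():
    d = lpips(x, y)                   # shape (1, 1, 1, 1): sum of per-layer spatial averages
print(d.item())                       # larger value = perceptually more different
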
/polyffusion/stable_diffusion/sampler/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | ---
3 | title: Sampling algorithms for stable diffusion
4 | summary: >
5 | Annotated PyTorch implementation/tutorial of
6 | sampling algorithms
7 | for stable diffusion model.
8 | ---
9 |
10 | # Sampling algorithms for [stable diffusion](../index.html)
11 |
12 | We have implemented the following [sampling algorithms](sampler/index.html):
13 |
14 | * [Denoising Diffusion Probabilistic Models (DDPM) Sampling](ddpm.html)
15 | * [Denoising Diffusion Implicit Models (DDIM) Sampling](ddim.html)
16 | """
17 |
18 | from typing import List, Optional
19 |
20 | import torch
21 |
22 | from ..latent_diffusion import LatentDiffusion
23 |
24 |
25 | class DiffusionSampler:
26 | """
27 | ## Base class for sampling algorithms
28 | """
29 |
30 | model: LatentDiffusion
31 |
32 | def __init__(self, model: LatentDiffusion):
33 | """
34 | :param model: is the model to predict noise $\epsilon_\text{cond}(x_t, c)$
35 | """
36 | super().__init__()
37 | # Set the model $\epsilon_\text{cond}(x_t, c)$
38 | self.model = model
39 | # Get number of steps the model was trained with $T$
40 | self.n_steps = model.n_steps
41 |
42 | def get_eps(
43 | self,
44 | x: torch.Tensor,
45 | t: torch.Tensor,
46 | c: torch.Tensor,
47 | *,
48 | uncond_scale: float,
49 | uncond_cond: Optional[torch.Tensor],
50 | ):
51 | """
52 | ## Get $\epsilon(x_t, c)$
53 |
54 | :param x: is $x_t$ of shape `[batch_size, channels, height, width]`
55 | :param t: is $t$ of shape `[batch_size]`
56 | :param c: is the conditional embeddings $c$ of shape `[batch_size, emb_size]`
57 | :param uncond_scale: is the unconditional guidance scale $s$. This is used for
102 |             $\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) - (s - 1)\epsilon_\text{cond}(x_t, c_u)$
59 | :param uncond_cond: is the conditional embedding for empty prompt $c_u$
60 | """
61 | # When the scale $s = 1$
62 | # $$\epsilon_\theta(x_t, c) = \epsilon_\text{cond}(x_t, c)$$
63 | if uncond_cond is None or uncond_scale == 1.0:
64 | return self.model(x, t, c)
65 | elif uncond_scale == 0.0: # unconditional
66 | return self.model(x, t, uncond_cond)
67 |
68 | # Duplicate $x_t$ and $t$
69 | x_in = torch.cat([x] * 2)
70 | t_in = torch.cat([t] * 2)
71 | # Concatenated $c$ and $c_u$
72 | c_in = torch.cat([uncond_cond, c])
73 | # Get $\epsilon_\text{cond}(x_t, c)$ and $\epsilon_\text{cond}(x_t, c_u)$
74 | e_t_uncond, e_t_cond = self.model(x_in, t_in, c_in).chunk(2)
75 | # Calculate
76 |         # $$\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) - (s - 1)\epsilon_\text{cond}(x_t, c_u)$$
77 | e_t = e_t_uncond + uncond_scale * (e_t_cond - e_t_uncond)
78 |
79 | #
80 | return e_t
81 |
82 | def sample(
83 | self,
84 | shape: List[int],
85 | cond: torch.Tensor,
86 | repeat_noise: bool = False,
87 | temperature: float = 1.0,
88 | x_last: Optional[torch.Tensor] = None,
89 | uncond_scale: float = 1.0,
90 | uncond_cond: Optional[torch.Tensor] = None,
91 | skip_steps: int = 0,
92 | ):
93 | """
94 | ### Sampling Loop
95 |
96 | :param shape: is the shape of the generated images in the
97 | form `[batch_size, channels, height, width]`
98 | :param cond: is the conditional embeddings $c$
99 | :param temperature: is the noise temperature (random noise gets multiplied by this)
100 | :param x_last: is $x_T$. If not provided random noise will be used.
101 | :param uncond_scale: is the unconditional guidance scale $s$. This is used for
102 |             $\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) - (s - 1)\epsilon_\text{cond}(x_t, c_u)$
103 | :param uncond_cond: is the conditional embedding for empty prompt $c_u$
104 | :param skip_steps: is the number of time steps to skip.
105 | """
106 | raise NotImplementedError()
107 |
108 | def paint(
109 | self,
110 | x: torch.Tensor,
111 | cond: torch.Tensor,
112 | t_start: int,
113 | *,
114 | orig: Optional[torch.Tensor] = None,
115 | mask: Optional[torch.Tensor] = None,
116 | orig_noise: Optional[torch.Tensor] = None,
117 | uncond_scale: float = 1.0,
118 | uncond_cond: Optional[torch.Tensor] = None,
119 | ):
120 | """
121 | ### Painting Loop
122 |
123 | :param x: is $x_{T'}$ of shape `[batch_size, channels, height, width]`
124 | :param cond: is the conditional embeddings $c$
125 | :param t_start: is the sampling step to start from, $T'$
126 |         :param orig: is the original image in latent space which we are inpainting.
127 | :param mask: is the mask to keep the original image.
128 | :param orig_noise: is fixed noise to be added to the original image.
129 | :param uncond_scale: is the unconditional guidance scale $s$. This is used for
130 |             $\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) - (s - 1)\epsilon_\text{cond}(x_t, c_u)$
131 | :param uncond_cond: is the conditional embedding for empty prompt $c_u$
132 | """
133 | raise NotImplementedError()
134 |
135 | def q_sample(
136 | self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None
137 | ):
138 | """
139 | ### Sample from $q(x_t|x_0)$
140 |
141 | :param x0: is $x_0$ of shape `[batch_size, channels, height, width]`
142 | :param index: is the time step $t$ index
143 | :param noise: is the noise, $\epsilon$
144 | """
145 | raise NotImplementedError()
146 |
--------------------------------------------------------------------------------
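
The classifier-free guidance combination inside DiffusionSampler.get_eps reduces to one line of tensor arithmetic; a toy sketch with random tensors (shapes are illustrative only):

import torch

# Toy noise predictions for a batch of 2 two-channel 128x128 latents.
e_uncond = torch.randn(2, 2, 128, 128)  # eps(x_t, c_u), empty-prompt condition
e_cond = torch.randn(2, 2, 128, 128)    # eps(x_t, c), actual condition
s = 5.0                                 # uncond_scale

# Same combination as get_eps:
e_theta = e_uncond + s * (e_cond - e_uncond)
# Algebraically e_theta = s * e_cond - (s - 1) * e_uncond,
# and with s == 1 it collapses to e_cond.
assert torch.allclose(e_theta, s * e_cond - (s - 1) * e_uncond)
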
/polyffusion/data/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import torch
5 | from torch.utils.data import DataLoader
6 |
7 | from data.dataset import PianoOrchDataset
8 | from utils import (
9 | chd_pitch_shift,
10 | chd_to_midi_file,
11 | chd_to_onehot,
12 | estx_to_midi_file,
13 | pianotree_pitch_shift,
14 | pr_mat_pitch_shift,
15 | prmat2c_to_midi_file,
16 | prmat_to_midi_file,
17 | )
18 |
19 | # SEED = 7890
20 | # torch.manual_seed(SEED)
21 | # np.random.seed(SEED)
22 | # random.seed(SEED)
23 |
24 |
25 | def collate_fn(batch, shift):
26 | def sample_shift():
27 | return np.random.choice(np.arange(-6, 6), 1)[0]
28 |
29 | prmat2c = []
30 | pnotree = []
31 | chord = []
32 | prmat = []
33 | song_fn = []
34 | for b in batch:
35 |         # b: (seg_prmat2c, seg_pnotree, seg_chord, seg_prmat[, song_fn])
36 | seg_prmat2c = b[0]
37 | seg_pnotree = b[1]
38 | seg_chord = b[2]
39 | seg_prmat = b[3]
40 |
41 | if shift:
42 | shift_pitch = sample_shift()
43 | seg_prmat2c = pr_mat_pitch_shift(seg_prmat2c, shift_pitch)
44 | seg_pnotree = pianotree_pitch_shift(seg_pnotree, shift_pitch)
45 | seg_chord = chd_pitch_shift(seg_chord, shift_pitch)
46 | seg_prmat = pr_mat_pitch_shift(seg_prmat, shift_pitch)
47 |
48 | seg_chord = chd_to_onehot(seg_chord)
49 |
50 | prmat2c.append(seg_prmat2c)
51 | pnotree.append(seg_pnotree)
52 | chord.append(seg_chord)
53 | prmat.append(seg_prmat)
54 |
55 | if len(b) > 4:
56 | song_fn.append(b[4])
57 |
58 | prmat2c = torch.Tensor(np.array(prmat2c, np.float32)).float()
59 | pnotree = torch.Tensor(np.array(pnotree, np.int64)).long()
60 | chord = torch.Tensor(np.array(chord, np.float32)).float()
61 | prmat = torch.Tensor(np.array(prmat, np.float32)).float()
62 | # prmat = prmat.unsqueeze(1) # (B, 1, 128, 128)
63 | if len(song_fn) > 0:
64 | return prmat2c, pnotree, chord, prmat, song_fn
65 | else:
66 | return prmat2c, pnotree, chord, prmat
67 |
68 |
69 | def get_custom_train_val_dataloaders(
70 | batch_size,
71 | data_dir,
72 | num_workers=0,
73 | pin_memory=False,
74 | debug=False,
75 | train_ratio=0.9,
76 | ):
77 | all_data = next(os.walk(data_dir))[2]
78 | train_num = int(len(all_data) * train_ratio)
79 | train_files = all_data[:train_num]
80 | val_files = all_data[train_num:]
81 |
82 | train_dataset = PianoOrchDataset.load_with_song_paths(
83 | song_paths=train_files, data_dir=data_dir
84 | )
85 | val_dataset = PianoOrchDataset.load_with_song_paths(
86 | song_paths=val_files, data_dir=data_dir
87 | )
88 |
89 | train_dl = DataLoader(
90 | train_dataset,
91 | batch_size,
92 | True,
93 | collate_fn=lambda x: collate_fn(x, shift=True),
94 | num_workers=num_workers,
95 | pin_memory=pin_memory,
96 | )
97 | val_dl = DataLoader(
98 | val_dataset,
99 | batch_size,
100 | False,
101 | collate_fn=lambda x: collate_fn(x, shift=False),
102 | num_workers=num_workers,
103 | pin_memory=pin_memory,
104 | )
105 | print(
106 | f"Dataloader ready: batch_size={batch_size}, num_workers={num_workers}, pin_memory={pin_memory}, train_segments={len(train_dataset)}, val_segments={len(val_dataset)}"
107 | )
108 | return train_dl, val_dl
109 |
110 |
111 | def get_train_val_dataloaders(
112 | batch_size, num_workers=0, pin_memory=False, debug=False, **kwargs
113 | ):
114 | train_dataset, val_dataset = PianoOrchDataset.load_train_and_valid_sets(
115 | debug=debug, **kwargs
116 | )
117 | train_dl = DataLoader(
118 | train_dataset,
119 | batch_size,
120 | True,
121 | collate_fn=lambda x: collate_fn(x, shift=True),
122 | num_workers=num_workers,
123 | pin_memory=pin_memory,
124 | )
125 | val_dl = DataLoader(
126 | val_dataset,
127 | batch_size,
128 | False,
129 | collate_fn=lambda x: collate_fn(x, shift=False),
130 | num_workers=num_workers,
131 | pin_memory=pin_memory,
132 | )
133 | print(
134 | f"Dataloader ready: batch_size={batch_size}, num_workers={num_workers}, pin_memory={pin_memory}, train_segments={len(train_dataset)}, val_segments={len(val_dataset)} {kwargs}"
135 | )
136 | return train_dl, val_dl
137 |
138 |
139 | def get_val_dataloader(
140 | batch_size, num_workers=0, pin_memory=False, debug=False, **kwargs
141 | ):
142 | val_dataset = PianoOrchDataset.load_valid_set(debug, **kwargs)
143 | val_dl = DataLoader(
144 | val_dataset,
145 | batch_size,
146 | False,
147 | collate_fn=lambda x: collate_fn(x, shift=False),
148 | num_workers=num_workers,
149 | pin_memory=pin_memory,
150 | )
151 | print(
152 | f"Dataloader ready: batch_size={batch_size}, num_workers={num_workers}, pin_memory={pin_memory}, {kwargs}"
153 | )
154 | return val_dl
155 |
156 |
157 | if __name__ == "__main__":
158 | train_dl, val_dl = get_train_val_dataloaders(16)
159 | print(len(train_dl))
160 | for batch in train_dl:
161 | print(len(batch))
162 | prmat2c, pnotree, chord, prmat = batch
163 | print(prmat2c.shape)
164 | print(pnotree.shape)
165 | print(chord.shape)
166 | print(prmat.shape)
167 | prmat2c = prmat2c.cpu().numpy()
168 | pnotree = pnotree.cpu().numpy()
169 | chord = chord.cpu().numpy()
170 | prmat = prmat.cpu().numpy()
171 | # chord = [onehot_to_chd(onehot) for onehot in chord]
172 | prmat2c_to_midi_file(prmat2c, "exp/dl_prmat2c.mid")
173 | estx_to_midi_file(pnotree, "exp/dl_pnotree.mid")
174 | chd_to_midi_file(chord, "exp/dl_chord.mid")
175 | prmat_to_midi_file(prmat, "exp/dl_prmat.mid")
176 | exit(0)
177 |
--------------------------------------------------------------------------------
/polyffusion/mir_eval/tempo.py:
--------------------------------------------------------------------------------
1 | '''
2 | The goal of a tempo estimation algorithm is to automatically detect the tempo
3 | of a piece of music, measured in beats per minute (BPM).
4 |
5 | See http://www.music-ir.org/mirex/wiki/2014:Audio_Tempo_Estimation for a
6 | description of the task and evaluation criteria.
7 |
8 | Conventions
9 | -----------
10 |
11 | Reference and estimated tempi should be positive, and provided in ascending
12 | order as a numpy array of length 2.
13 |
14 | The weighting value from the reference must be a float in the range [0, 1].
15 |
16 | Metrics
17 | -------
18 | * :func:`mir_eval.tempo.detection`: Relative error, hits, and weighted
19 | precision of tempo estimation.
20 |
21 | '''
22 |
23 | import warnings
24 | import numpy as np
25 | import collections
26 | from . import util
27 |
28 |
29 | def validate_tempi(tempi, reference=True):
30 | """Checks that there are two non-negative tempi.
31 | For a reference value, at least one tempo has to be greater than zero.
32 |
33 | Parameters
34 | ----------
35 | tempi : np.ndarray
36 | length-2 array of tempo, in bpm
37 |
38 | reference : bool
39 | indicates a reference value
40 |
41 | """
42 |
43 | if tempi.size != 2:
44 | raise ValueError('tempi must have exactly two values')
45 |
46 | if not np.all(np.isfinite(tempi)) or np.any(tempi < 0):
47 | raise ValueError('tempi={} must be non-negative numbers'.format(tempi))
48 |
49 | if reference and np.all(tempi == 0):
50 | raise ValueError('reference tempi={} must have one'
51 | ' value greater than zero'.format(tempi))
52 |
53 |
54 | def validate(reference_tempi, reference_weight, estimated_tempi):
55 | """Checks that the input annotations to a metric look like valid tempo
56 | annotations.
57 |
58 | Parameters
59 | ----------
60 | reference_tempi : np.ndarray
61 | reference tempo values, in bpm
62 |
63 | reference_weight : float
64 | perceptual weight of slow vs fast in reference
65 |
66 | estimated_tempi : np.ndarray
67 | estimated tempo values, in bpm
68 |
69 | """
70 | validate_tempi(reference_tempi, reference=True)
71 | validate_tempi(estimated_tempi, reference=False)
72 |
73 | if reference_weight < 0 or reference_weight > 1:
74 | raise ValueError('Reference weight must lie in range [0, 1]')
75 |
76 |
77 | def detection(reference_tempi, reference_weight, estimated_tempi, tol=0.08):
78 | """Compute the tempo detection accuracy metric.
79 |
80 | Parameters
81 | ----------
82 | reference_tempi : np.ndarray, shape=(2,)
83 | Two non-negative reference tempi
84 |
85 |     reference_weight : float in [0, 1]
86 | The relative strength of ``reference_tempi[0]`` vs
87 | ``reference_tempi[1]``.
88 |
89 | estimated_tempi : np.ndarray, shape=(2,)
90 | Two non-negative estimated tempi.
91 |
92 | tol : float in [0, 1]:
93 | The maximum allowable deviation from a reference tempo to
94 | count as a hit.
95 | ``|est_t - ref_t| <= tol * ref_t``
96 | (Default value = 0.08)
97 |
98 | Returns
99 | -------
100 | p_score : float in [0, 1]
101 | Weighted average of recalls:
102 | ``reference_weight * hits[0] + (1 - reference_weight) * hits[1]``
103 |
104 | one_correct : bool
105 | True if at least one reference tempo was correctly estimated
106 |
107 | both_correct : bool
108 | True if both reference tempi were correctly estimated
109 |
110 | Raises
111 | ------
112 | ValueError
113 | If the input tempi are ill-formed
114 |
115 | If the reference weight is not in the range [0, 1]
116 |
117 | If ``tol < 0`` or ``tol > 1``.
118 | """
119 |
120 | validate(reference_tempi, reference_weight, estimated_tempi)
121 |
122 | if tol < 0 or tol > 1:
123 | raise ValueError('invalid tolerance {}: must lie in the range '
124 | '[0, 1]'.format(tol))
125 | if tol == 0.:
126 | warnings.warn('A tolerance of 0.0 may not '
127 | 'lead to the results you expect.')
128 |
129 | hits = [False, False]
130 |
131 | for i, ref_t in enumerate(reference_tempi):
132 | if ref_t > 0:
133 | # Compute the relative error for this reference tempo
134 | f_ref_t = float(ref_t)
135 | relative_error = np.min(np.abs(ref_t - estimated_tempi) / f_ref_t)
136 |
137 | # Count the hits
138 | hits[i] = relative_error <= tol
139 |
140 | p_score = reference_weight * hits[0] + (1.0-reference_weight) * hits[1]
141 |
142 | one_correct = bool(np.max(hits))
143 | both_correct = bool(np.min(hits))
144 |
145 | return p_score, one_correct, both_correct
146 |
147 |
148 | def evaluate(reference_tempi, reference_weight, estimated_tempi, **kwargs):
149 | """Compute all metrics for the given reference and estimated annotations.
150 |
151 | Parameters
152 | ----------
153 | reference_tempi : np.ndarray, shape=(2,)
154 | Two non-negative reference tempi
155 |
156 |     reference_weight : float in [0, 1]
157 | The relative strength of ``reference_tempi[0]`` vs
158 | ``reference_tempi[1]``.
159 |
160 | estimated_tempi : np.ndarray, shape=(2,)
161 | Two non-negative estimated tempi.
162 |
163 | kwargs
164 | Additional keyword arguments which will be passed to the
165 | appropriate metric or preprocessing functions.
166 |
167 | Returns
168 | -------
169 | scores : dict
170 | Dictionary of scores, where the key is the metric name (str) and
171 | the value is the (float) score achieved.
172 | """
173 | # Compute all metrics
174 | scores = collections.OrderedDict()
175 |
176 | (scores['P-score'],
177 | scores['One-correct'],
178 | scores['Both-correct']) = util.filter_kwargs(detection, reference_tempi,
179 | reference_weight,
180 | estimated_tempi,
181 | **kwargs)
182 |
183 | return scores
184 |
--------------------------------------------------------------------------------
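
A worked example of tempo.detection with concrete numbers (the tempo arrays are made up for illustration; the module is assumed importable as mir_eval inside polyffusion):

import numpy as np

from mir_eval import tempo

reference_tempi = np.array([60.0, 120.0])  # ascending, in BPM
estimated_tempi = np.array([61.0, 179.0])

p_score, one_correct, both_correct = tempo.detection(
    reference_tempi, 0.5, estimated_tempi, tol=0.08
)
# 60 BPM: closest estimate is 61, relative error 1/60 ~= 0.017 <= 0.08 -> hit
# 120 BPM: closest estimate is 59 BPM away, relative error ~0.49 > 0.08 -> miss
# p_score = 0.5 * 1 + 0.5 * 0 = 0.5; one_correct == True, both_correct == False
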
/polyffusion/chord_extractor/mir/extractors/vamp_extractor.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 |
3 | import numpy as np
4 | from mir.cache import hasher
5 | from mir.common import PACKAGE_PATH, SONIC_ANNOTATOR_PATH, WORKING_PATH
6 | from mir.extractors.extractor_base import *
7 |
8 |
9 | def rewrite_extract_n3(entry, inputfilename, outputfilename):
10 | f = open(inputfilename, "r")
11 | content = f.read()
12 | f.close()
13 | content = content.replace("[__SR__]", str(entry.prop.sr))
14 | content = content.replace("[__WIN_SHIFT__]", str(entry.prop.hop_length))
15 | content = content.replace("[__WIN_SIZE__]", str(entry.prop.win_size))
16 | if not os.path.isdir(os.path.dirname(outputfilename)):
17 | os.makedirs(os.path.dirname(outputfilename))
18 | f = open(outputfilename, "w")
19 | f.write(content)
20 | f.close()
21 |
22 |
23 | class NNLSChroma(ExtractorBase):
24 | def get_feature_class(self):
25 | return io.ChromaIO
26 |
27 | def extract(self, entry, **kwargs):
28 | print("NNLSChroma working on entry " + entry.name)
29 | if "margin" in kwargs:
30 | if kwargs["margin"] > 0:
31 | music = entry.music_h
32 | else:
33 | raise Exception("Error margin")
34 |
35 | else:
36 | music = entry.music
37 | music_io = io.MusicIO()
38 | temp_path = os.path.join(
39 | WORKING_PATH, "temp/nnlschroma_extractor_%s.wav" % hasher(entry.name)
40 | )
41 | temp_n3_path = temp_path + ".n3"
42 | rewrite_extract_n3(
43 | entry, os.path.join(PACKAGE_PATH, "data/bothchroma.n3"), temp_n3_path
44 | )
45 | music_io.write(music, temp_path, entry)
46 | proc = subprocess.Popen(
47 | [
48 | SONIC_ANNOTATOR_PATH,
49 | "-t",
50 | temp_n3_path,
51 | temp_path,
52 | "-w",
53 | "lab",
54 | "--lab-stdout",
55 | ],
56 | stdout=subprocess.PIPE,
57 | stderr=subprocess.DEVNULL,
58 | )
59 | # print('Begin processing')
60 | result = np.zeros((0, 24))
61 | for line in proc.stdout:
62 | # the real code does filtering here
63 | line = bytes.decode(line)
64 | if line.endswith("\r\n"):
65 | line = line[: len(line) - 2]
66 | if line.endswith("\r"):
67 | line = line[: len(line) - 1]
68 | arr = np.array(list(map(float, line.split("\t")))[1:])
69 | arr = arr.reshape((2, 12))[::-1].T
70 | arr = np.roll(arr, -3, axis=0).reshape((1, 24))
71 | result = np.append(result, arr, axis=0)
72 | try:
73 | os.unlink(temp_path)
74 | os.unlink(temp_n3_path)
75 |         except OSError:
76 | pass
77 | if result.shape[0] == 0:
78 | raise Exception("Empty response")
79 | return result
80 |
81 |
82 | class TunedLogSpectrogram(ExtractorBase):
83 | def get_feature_class(self):
84 | return io.SpectrogramIO
85 |
86 | def extract(self, entry, **kwargs):
87 | print("TunedLogSpectrogram working on entry " + entry.name)
88 | music_io = io.MusicIO()
89 | temp_path = os.path.join(
90 | WORKING_PATH,
91 | "temp/tunedlogspectrogram_extractor_%s.wav" % hasher(entry.name),
92 | )
93 | temp_n3_path = temp_path + ".n3"
94 | rewrite_extract_n3(
95 | entry, os.path.join(PACKAGE_PATH, "data/tunedlogfreqspec.n3"), temp_n3_path
96 | )
97 | music_io.write(entry.music, temp_path, entry)
98 | proc = subprocess.Popen(
99 | [
100 | SONIC_ANNOTATOR_PATH,
101 | "-t",
102 | temp_n3_path,
103 | temp_path,
104 | "-w",
105 | "lab",
106 | "--lab-stdout",
107 | ],
108 | stdout=subprocess.PIPE,
109 | stderr=subprocess.DEVNULL,
110 | )
111 | # print('Begin processing')
112 | result = np.zeros((0, 256))
113 | for line in proc.stdout:
114 | # the real code does filtering here
115 | line = bytes.decode(line)
116 | if line.endswith("\r\n"):
117 | line = line[: len(line) - 2]
118 | if line.endswith("\r"):
119 | line = line[: len(line) - 1]
120 | arr = np.array(list(map(float, line.split("\t")))[1:])
121 | arr = arr.reshape((1, -1))
122 | result = np.append(result, arr, axis=0)
123 | try:
124 | os.unlink(temp_path)
125 | os.unlink(temp_n3_path)
126 |         except OSError:
127 | pass
128 | if result.shape[0] == 0:
129 | raise Exception("Empty response")
130 | return result
131 |
132 |
133 | class GlobalTuning(ExtractorBase):
134 | def get_feature_class(self):
135 | return io.FloatIO
136 |
137 | def extract(self, entry, **kwargs):
138 | music_io = io.MusicIO()
139 | temp_path = os.path.join(
140 | WORKING_PATH, "temp/tuning_%s.wav" % hasher(entry.name)
141 | )
142 | temp_n3_path = temp_path + ".n3"
143 | rewrite_extract_n3(
144 | entry, os.path.join(PACKAGE_PATH, "data/tuning.n3"), temp_n3_path
145 | )
146 | if "source" in kwargs:
147 | music = entry.dict[kwargs["source"]].get(entry)
148 | else:
149 | music = entry.music
150 | music_io.write(music, temp_path, entry)
151 | proc = subprocess.Popen(
152 | [
153 | SONIC_ANNOTATOR_PATH,
154 | "-t",
155 | temp_n3_path,
156 | temp_path,
157 | "-w",
158 | "lab",
159 | "--lab-stdout",
160 | ],
161 | stdout=subprocess.PIPE,
162 | stderr=subprocess.DEVNULL,
163 | )
164 | # print('Begin processing')
165 | output = proc.stdout.readlines()
166 | result = (
167 | np.log2(np.float64(output[0].decode().split("\t")[2])) - np.log2(440)
168 | ) * 12
169 | try:
170 | os.unlink(temp_path)
171 | os.unlink(temp_n3_path)
172 |         except OSError:
173 | pass
174 | return result
175 |
--------------------------------------------------------------------------------
/polyffusion/stable_diffusion/losses/contperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | # from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
5 | from .vqperceptual import *  # provides LPIPS, NLayerDiscriminator, weights_init, hinge_d_loss, vanilla_d_loss, adopt_weight
6 |
7 |
8 | class LPIPSWithDiscriminator(nn.Module):
9 | def __init__(
10 | self,
11 | disc_start,
12 | logvar_init=0.0,
13 | kl_weight=1.0,
14 | pixelloss_weight=1.0,
15 | disc_num_layers=3,
16 | disc_in_channels=3,
17 | disc_factor=1.0,
18 | disc_weight=1.0,
19 | perceptual_weight=1.0,
20 | use_actnorm=False,
21 | disc_conditional=False,
22 | disc_loss="hinge",
23 | ):
24 | super().__init__()
25 | assert disc_loss in ["hinge", "vanilla"]
26 | self.kl_weight = kl_weight
27 | self.pixel_weight = pixelloss_weight
28 | self.perceptual_loss = LPIPS().eval()
29 | self.perceptual_weight = perceptual_weight
30 | # output log variance
31 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
32 |
33 | self.discriminator = NLayerDiscriminator(
34 | input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm
35 | ).apply(weights_init)
36 | self.discriminator_iter_start = disc_start
37 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
38 | self.disc_factor = disc_factor
39 | self.discriminator_weight = disc_weight
40 | self.disc_conditional = disc_conditional
41 |
42 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
43 | if last_layer is not None:
44 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
45 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
46 | else:
47 | nll_grads = torch.autograd.grad(
48 | nll_loss, self.last_layer[0], retain_graph=True
49 | )[0]
50 | g_grads = torch.autograd.grad(
51 | g_loss, self.last_layer[0], retain_graph=True
52 | )[0]
53 |
54 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
55 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
56 | d_weight = d_weight * self.discriminator_weight
57 | return d_weight
58 |
59 | def forward(
60 | self,
61 | inputs,
62 | reconstructions,
63 | posteriors,
64 | optimizer_idx,
65 | global_step,
66 | last_layer=None,
67 | cond=None,
68 | split="train",
69 | weights=None,
70 | ):
71 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
72 | if self.perceptual_weight > 0:
73 | p_loss = self.perceptual_loss(
74 | inputs.contiguous(), reconstructions.contiguous()
75 | )
76 | rec_loss = rec_loss + self.perceptual_weight * p_loss
77 |
78 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
79 | weighted_nll_loss = nll_loss
80 | if weights is not None:
81 | weighted_nll_loss = weights * nll_loss
82 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
83 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
84 | kl_loss = posteriors.kl()
85 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
86 |
87 | # now the GAN part
88 | if optimizer_idx == 0:
89 | # generator update
90 | if cond is None:
91 | assert not self.disc_conditional
92 | logits_fake = self.discriminator(reconstructions.contiguous())
93 | else:
94 | assert self.disc_conditional
95 | logits_fake = self.discriminator(
96 | torch.cat((reconstructions.contiguous(), cond), dim=1)
97 | )
98 | g_loss = -torch.mean(logits_fake)
99 |
100 | if self.disc_factor > 0.0:
101 | try:
102 | d_weight = self.calculate_adaptive_weight(
103 | nll_loss, g_loss, last_layer=last_layer
104 | )
105 | except RuntimeError:
106 | assert not self.training
107 | d_weight = torch.tensor(0.0)
108 | else:
109 | d_weight = torch.tensor(0.0)
110 |
111 | disc_factor = adopt_weight(
112 | self.disc_factor, global_step, threshold=self.discriminator_iter_start
113 | )
114 | loss = (
115 | weighted_nll_loss
116 | + self.kl_weight * kl_loss
117 | + d_weight * disc_factor * g_loss
118 | )
119 |
120 | log = {
121 | "{}/total_loss".format(split): loss.clone().detach().mean(),
122 | "{}/logvar".format(split): self.logvar.detach(),
123 | "{}/kl_loss".format(split): kl_loss.detach().mean(),
124 | "{}/nll_loss".format(split): nll_loss.detach().mean(),
125 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
126 | "{}/d_weight".format(split): d_weight.detach(),
127 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
128 | "{}/g_loss".format(split): g_loss.detach().mean(),
129 | }
130 | return loss, log
131 |
132 | if optimizer_idx == 1:
133 | # second pass for discriminator update
134 | if cond is None:
135 | logits_real = self.discriminator(inputs.contiguous().detach())
136 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
137 | else:
138 | logits_real = self.discriminator(
139 | torch.cat((inputs.contiguous().detach(), cond), dim=1)
140 | )
141 | logits_fake = self.discriminator(
142 | torch.cat((reconstructions.contiguous().detach(), cond), dim=1)
143 | )
144 |
145 | disc_factor = adopt_weight(
146 | self.disc_factor, global_step, threshold=self.discriminator_iter_start
147 | )
148 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
149 |
150 | log = {
151 | "{}/disc_loss".format(split): d_loss.clone().detach().mean(),
152 | "{}/logits_real".format(split): logits_real.detach().mean(),
153 | "{}/logits_fake".format(split): logits_fake.detach().mean(),
154 | }
155 | return d_loss, log
156 |
--------------------------------------------------------------------------------
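
The adaptive discriminator weight in calculate_adaptive_weight balances the generator loss against the reconstruction loss by their gradient norms with respect to the decoder's last layer. A self-contained toy sketch of the same arithmetic (the two losses and the "last layer" tensor here are stand-ins):

import torch

last_layer = torch.randn(8, requires_grad=True)  # stand-in for the decoder's last weights
nll_loss = (last_layer ** 2).sum()               # stand-in reconstruction / NLL loss
g_loss = last_layer.abs().sum()                  # stand-in generator (GAN) loss

nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()  # same clamp as in the loss class
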