├── CITATION.cff ├── LICENSE ├── README.md ├── pypianoroll ├── __init__.py ├── core.py ├── inputs.py ├── metrics.py ├── multitrack.py ├── outputs.py ├── track.py ├── utils.py ├── version.py └── visualization.py ├── pyproject.toml ├── setup.cfg └── setup.py /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: If you use this software, please cite it using these metadata. 3 | authors: 4 | - family-names: Dong 5 | given-names: Hao-Wen 6 | title: Pypianoroll 7 | preferred-citation: 8 | type: article 9 | authors: 10 | - family-names: Dong 11 | given-names: Hao-Wen 12 | - family-names: Hsiao 13 | given-names: Wen-Yi 14 | - family-names: Yang 15 | given-names: Yi-Hsuan 16 | title: "Pypianoroll: Open Source Python Package for Handling Multitrack Pianorolls" 17 | journal: Late-Breaking Demos of the 19th International Society for Music Information Retrieval Conference (ISMIR) 18 | year: 2018 19 | date-released: 2018-01-27 20 | license: MIT 21 | url: "https://salu133445.github.io/pypianoroll/" 22 | repository-code: "https://github.com/salu133445/pypianoroll" 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Hao-Wen Dong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Pypianoroll 2 | =========== 3 | 4 | [![GitHub workflow](https://img.shields.io/github/workflow/status/salu133445/pypianoroll/Testing)](https://github.com/salu133445/pypianoroll/actions) 5 | [![Codecov](https://img.shields.io/codecov/c/github/salu133445/pypianoroll)](https://codecov.io/gh/salu133445/pypianoroll) 6 | [![GitHub license](https://img.shields.io/github/license/salu133445/pypianoroll)](https://github.com/salu133445/pypianoroll/blob/main/LICENSE) 7 | [![GitHub release](https://img.shields.io/github/v/release/salu133445/pypianoroll)](https://github.com/salu133445/pypianoroll/releases) 8 | 9 | 10 | Pypianoroll is an open source Python library for working with piano rolls. It provides essential tools for handling multitrack piano rolls, including efficient I/O as well as manipulation, visualization and evaluation tools. 
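For readers new to the representation, here is a minimal NumPy-only sketch of what a single piano roll looks like. It is an illustration under stated assumptions, not the library's own code: the helper `make_pianoroll` and the `(pitch, start_beat, duration_in_beats, velocity)` note tuples are hypothetical names introduced here. The array layout (time steps by 128 pitches, `uint8` velocities) and the resolution of 24 time steps per quarter note follow the conventions used in `pypianoroll/inputs.py`.

```python
import numpy as np

RESOLUTION = 24  # time steps per quarter note (assumed to match the package default)


def make_pianoroll(notes, n_beats, resolution=RESOLUTION):
    """Build a (time, 128) uint8 piano roll from hypothetical note tuples.

    Each note is (pitch, start_beat, duration_in_beats, velocity).
    """
    pianoroll = np.zeros((n_beats * resolution, 128), np.uint8)
    for pitch, start, duration, velocity in notes:
        begin = int(round(start * resolution))
        end = int(round((start + duration) * resolution))
        # Rows are time steps, columns are MIDI pitches
        pianoroll[begin:end, pitch] = velocity
    return pianoroll


# A C major triad held for one beat, followed by a single E for one beat
notes = [(60, 0, 1, 100), (64, 0, 1, 100), (67, 0, 1, 100), (64, 1, 1, 80)]
roll = make_pianoroll(notes, n_beats=4)

print(roll.shape)  # (time steps, pitches)
print(np.count_nonzero(roll) / roll.size)  # fraction of nonzero cells
```

Most cells of such an array are zero, which is the sparsity the project description refers to when it mentions space-efficient storage.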
11 | 12 | 13 | Features 14 | -------- 15 | 16 | - Manipulate multitrack piano rolls intuitively 17 | - Visualize multitrack piano rolls beautifully 18 | - Save and load multitrack piano rolls in a space-efficient format 19 | - Parse MIDI files into multitrack piano rolls 20 | - Write multitrack piano rolls into MIDI files 21 | 22 | 23 | Why Pypianoroll 24 | --------------- 25 | 26 | Our aim is to provide convenient classes for piano-roll matrix and MIDI-like track information (program number, track name, drum track indicator). Pypianoroll is also designed to provide efficient I/O for piano rolls, since piano rolls have long been considered an inefficient way to store music data due to the sparse nature. 27 | 28 | 29 | Installation 30 | ------------ 31 | 32 | To install Pypianoroll, please run `pip install pypianoroll`. To build Pypianoroll from source, please download the [source](https://github.com/salu133445/pypianoroll/releases) and run `python setup.py install`. 33 | 34 | 35 | Documentation 36 | ------------- 37 | 38 | Documentation is available [here](https://salu133445.github.io/pypianoroll) and as docstrings with the code. 39 | 40 | 41 | Citing 42 | ------ 43 | 44 | Please cite the following paper if you use the code provided in this repository. 45 | 46 | Hao-Wen Dong, Wen-Yi Hsiao, and Yi-Hsuan Yang, "Pypianoroll: Open Source Python Package for Handling Multitrack Pianorolls," in _Late-Breaking Demos of the 19th International Society for Music Information Retrieval Conference (ISMIR)_, 2018.
47 | [[homepage](https://salu133445.github.io/pypianoroll/)] 48 | [[paper](https://salu133445.github.io/pypianoroll/pdf/pypianoroll_ismir2018_lbd_paper.pdf)] 49 | [[poster](https://salu133445.github.io/pypianoroll/pdf/pypianoroll_ismir2018_lbd_poster.pdf)] 50 | [[code](https://github.com/salu133445/pypianoroll)] 51 | [[documentation](https://salu133445.github.io/pypianoroll/)] 52 | 53 | 54 | Lakh Pianoroll Dataset 55 | ---------------------- 56 | 57 | [Lakh Pianoroll Dataset](https://salu133445.github.io/musegan/dataset) (LPD) is a new multitrack piano roll dataset using Pypianoroll for efficient data I/O and to save space, which is used as the training dataset in our [MuseGAN](https://salu133445.github.io/musegan) project. 58 | -------------------------------------------------------------------------------- /pypianoroll/__init__.py: -------------------------------------------------------------------------------- 1 | """A Python library for handling multitrack pianorolls. 2 | 3 | Pypianoroll is an open source Python library for working with piano 4 | rolls. It provides essential tools for handling multitrack piano rolls, 5 | including efficient I/O as well as manipulation, visualization and 6 | evaluation tools. 7 | 8 | Features 9 | -------- 10 | 11 | - Manipulate multitrack piano rolls intuitively 12 | - Visualize multitrack piano rolls beautifully 13 | - Save and load multitrack piano rolls in a space-efficient format 14 | - Parse MIDI files into multitrack piano rolls 15 | - Write multitrack piano rolls into MIDI files 16 | 17 | """ 18 | from . 
import core, inputs, metrics, multitrack, outputs, track, visualization 19 | from .core import * 20 | from .inputs import * 21 | from .metrics import * 22 | from .multitrack import * 23 | from .outputs import * 24 | from .track import * 25 | from .version import __version__ 26 | from .visualization import * 27 | 28 | __all__ = ["__version__"] 29 | __all__.extend(core.__all__) 30 | __all__.extend(inputs.__all__) 31 | __all__.extend(metrics.__all__) 32 | __all__.extend(multitrack.__all__) 33 | __all__.extend(outputs.__all__) 34 | __all__.extend(track.__all__) 35 | __all__.extend(visualization.__all__) 36 | -------------------------------------------------------------------------------- /pypianoroll/core.py: -------------------------------------------------------------------------------- 1 | """Functions for Pypianoroll objects. 2 | 3 | Functions 4 | --------- 5 | 6 | - binarize 7 | - clip 8 | - pad 9 | - pad_to_multiple 10 | - pad_to_same 11 | - plot 12 | - set_nonzeros 13 | - set_resolution 14 | - transpose 15 | - trim 16 | 17 | 18 | """ 19 | from typing import List, TypeVar, Union, overload 20 | 21 | from matplotlib.axes import Axes 22 | 23 | from .multitrack import Multitrack 24 | from .track import BinaryTrack, StandardTrack, Track 25 | 26 | __all__ = [ 27 | "binarize", 28 | "clip", 29 | "pad", 30 | "pad_to_multiple", 31 | "pad_to_same", 32 | "plot", 33 | "set_nonzeros", 34 | "set_resolution", 35 | "transpose", 36 | "trim", 37 | ] 38 | 39 | MultitrackType = TypeVar("MultitrackType", bound=Multitrack) 40 | MultitrackOrTrackType = TypeVar("MultitrackOrTrackType", Multitrack, Track) 41 | StandardTrackType = TypeVar("StandardTrackType", Multitrack, StandardTrack) 42 | 43 | 44 | @overload 45 | def set_nonzeros(obj: Multitrack, value: int) -> Multitrack: 46 | pass 47 | 48 | 49 | @overload 50 | def set_nonzeros( 51 | obj: Union[StandardTrack, BinaryTrack], value: int 52 | ) -> StandardTrack: 53 | pass 54 | 55 | 56 | def set_nonzeros( 57 | obj: Union[Multitrack, 
StandardTrack, BinaryTrack], value: int 58 | ): 59 | """Assign a constant value to all nonzero entries. 60 | 61 | Parameters 62 | ---------- 63 | obj : :class:`pypianoroll.Multitrack`, \ 64 | :class:`pypianoroll.StandardTrack` or \ 65 | :class:`pypianoroll.BinaryTrack` 66 | Object to modify. 67 | value : int 68 | Value to assign. 69 | 70 | """ 71 | return obj.set_nonzeros(value=value) 72 | 73 | 74 | @overload 75 | def binarize(obj: Multitrack, threshold: int = 0) -> Multitrack: 76 | pass 77 | 78 | 79 | @overload 80 | def binarize(obj: StandardTrack, threshold: int = 0) -> BinaryTrack: 81 | pass 82 | 83 | 84 | def binarize(obj: Union[Multitrack, StandardTrack], threshold: int = 0): 85 | """Binarize the piano roll(s). 86 | 87 | Parameters 88 | ---------- 89 | obj : :class:`pypianoroll.Multitrack` or \ 90 | :class:`pypianoroll.StandardTrack` 91 | Object to binarize. 92 | threshold : int, default: 0 93 | Threshold. 94 | 95 | """ 96 | return obj.binarize(threshold=threshold) 97 | 98 | 99 | def clip( 100 | obj: MultitrackType, lower: int = 0, upper: int = 127 101 | ) -> MultitrackType: 102 | """Clip (limit) the piano roll(s) into [`lower`, `upper`]. 103 | 104 | Parameters 105 | ---------- 106 | obj : :class:`pypianoroll.Multitrack` or \ 107 | :class:`pypianoroll.StandardTrack` 108 | Object to clip. 109 | lower : int, default: 0 110 | Lower bound. 111 | upper : int, default: 127 112 | Upper bound. 113 | 114 | Returns 115 | ------- 116 | Object itself. 117 | 118 | """ 119 | return obj.clip(lower=lower, upper=upper) 120 | 121 | 122 | def set_resolution( 123 | obj: MultitrackType, resolution: int, rounding: str = "round" 124 | ) -> MultitrackType: 125 | """Downsample the piano rolls by a factor. 126 | 127 | Parameters 128 | ---------- 129 | obj : :class:`pypianoroll.Multitrack` 130 | Object to downsample. 131 | resolution : int 132 | Target resolution. 133 | rounding : {'round', 'ceil', 'floor'}, default: 'round' 134 | Rounding mode.
135 | 136 | Returns 137 | ------- 138 | Object itself. 139 | 140 | """ 141 | return obj.set_resolution(resolution=resolution, rounding=rounding) 142 | 143 | 144 | def pad(obj: MultitrackOrTrackType, pad_length: int) -> MultitrackOrTrackType: 145 | """Pad the piano roll(s). 146 | 147 | Notes 148 | ----- 149 | The lengths of the resulting piano rolls are not guaranteed to be 150 | the same. See :meth:`pypianoroll.Multitrack.pad_to_same`. 151 | 152 | Parameters 153 | ---------- 154 | obj : :class:`pypianoroll.Multitrack` or :class:`pypianoroll.Track` 155 | Object to pad. 156 | pad_length : int 157 | Length to pad along the time axis. 158 | 159 | Returns 160 | ------- 161 | Object itself. 162 | 163 | See Also 164 | -------- 165 | :func:`pypianoroll.pad_to_same` : Pad the piano rolls so that they 166 | have the same length. 167 | :func:`pypianoroll.pad_to_multiple` : Pad the piano rolls so that 168 | their lengths are some multiples. 169 | 170 | """ 171 | return obj.pad(pad_length=pad_length) 172 | 173 | 174 | def pad_to_multiple( 175 | obj: MultitrackOrTrackType, factor: int 176 | ) -> MultitrackOrTrackType: 177 | """Pad the piano roll(s) so that their lengths are some multiples. 178 | 179 | Pad the piano rolls at the end along the time axis of the 180 | minimum length that makes the lengths of the resulting piano rolls 181 | multiples of `factor`. 182 | 183 | Parameters 184 | ---------- 185 | obj : :class:`pypianoroll.Multitrack` or :class:`pypianoroll.Track` 186 | Object to pad. 187 | factor : int 188 | The value which the length of the resulting pianoroll(s) will be 189 | a multiple of. 190 | 191 | Returns 192 | ------- 193 | Object itself. 194 | 195 | Notes 196 | ----- 197 | Lengths of the resulting piano rolls are not necessarily the same. 198 | 199 | See Also 200 | -------- 201 | :func:`pypianoroll.pad` : Pad the piano rolls. 202 | :func:`pypianoroll.pad_to_same` : Pad the piano rolls so that they 203 | have the same length.
204 | 205 | """ 206 | return obj.pad_to_multiple(factor=factor) 207 | 208 | 209 | def pad_to_same(obj: MultitrackType) -> MultitrackType: 210 | """Pad the piano rolls so that they have the same length. 211 | 212 | Pad shorter piano rolls at the end along the time axis so that the 213 | resulting piano rolls have the same length. 214 | 215 | Parameters 216 | ---------- 217 | obj : :class:`pypianoroll.Multitrack` 218 | Object to pad. 219 | 220 | Returns 221 | ------- 222 | Object itself. 223 | 224 | See Also 225 | -------- 226 | :func:`pypianoroll.pad` : Pad the piano rolls. 227 | :func:`pypianoroll.pad_to_multiple` : Pad the piano 228 | rolls so that their lengths are some multiples. 229 | 230 | """ 231 | return obj.pad_to_same() 232 | 233 | 234 | def transpose( 235 | obj: MultitrackOrTrackType, semitone: int 236 | ) -> MultitrackOrTrackType: 237 | """Transpose the piano roll(s) by a number of semitones. 238 | 239 | Positive values are for a higher key, while negative values are for 240 | a lower key. Drum tracks are ignored. 241 | 242 | Parameters 243 | ---------- 244 | obj : :class:`pypianoroll.Multitrack` or :class:`pypianoroll.Track` 245 | Object to transpose. 246 | semitone : int 247 | Number of semitones to transpose. A positive value raises the 248 | pitches, while a negative value lowers the pitches. 249 | 250 | Returns 251 | ------- 252 | Object itself. 253 | 254 | """ 255 | return obj.transpose(semitone=semitone) 256 | 257 | 258 | def trim( 259 | obj: MultitrackOrTrackType, start: int = None, end: int = None 260 | ) -> MultitrackOrTrackType: 261 | """Trim the trailing silences of the piano roll(s). 262 | 263 | Parameters 264 | ---------- 265 | obj : :class:`pypianoroll.Multitrack` or :class:`pypianoroll.Track` 266 | Object to trim. 267 | start : int, default: 0 268 | Start time. 269 | end : int, optional 270 | End time. Defaults to active length. 271 | 272 | Returns 273 | ------- 274 | Object itself. 
275 | 276 | """ 277 | return obj.trim(start=start, end=end) 278 | 279 | 280 | def plot(obj: Union[Track, Multitrack], **kwargs) -> Union[List[Axes], Axes]: 281 | """Plot the object. 282 | 283 | See :func:`pypianoroll.plot_multitrack` and 284 | :func:`pypianoroll.plot_track` for full documentation. 285 | 286 | """ 287 | return obj.plot(**kwargs) 288 | -------------------------------------------------------------------------------- /pypianoroll/inputs.py: -------------------------------------------------------------------------------- 1 | """Input interfaces. 2 | 3 | Functions 4 | --------- 5 | 6 | - load 7 | - from_pretty_midi 8 | - read 9 | 10 | """ 11 | import json 12 | from pathlib import Path 13 | from typing import Union 14 | 15 | import numpy as np 16 | from pretty_midi import PrettyMIDI 17 | 18 | from .multitrack import DEFAULT_RESOLUTION, Multitrack 19 | from .track import BinaryTrack, StandardTrack, Track 20 | from .utils import reconstruct_sparse 21 | 22 | __all__ = ["load", "from_pretty_midi", "read"] 23 | 24 | 25 | def load(path: Union[str, Path]) -> Multitrack: 26 | """Load a NPZ file into a Multitrack object. 27 | 28 | Supports only files previously saved by :func:`pypianoroll.save`. 29 | 30 | Parameters 31 | ---------- 32 | path : str or Path 33 | Path to the file to load. 34 | 35 | See Also 36 | -------- 37 | :func:`pypianoroll.save` : Save a Multitrack object to a NPZ file. 38 | :func:`pypianoroll.read` : Read a MIDI file into a Multitrack 39 | object. 
40 | 41 | """ 42 | with np.load(path) as loaded: 43 | if "info.json" not in loaded: 44 | raise RuntimeError("Cannot find `info.json` in the NPZ file.") 45 | 46 | # Load the info dictionary 47 | info_dict = json.loads(loaded["info.json"].decode("utf-8")) 48 | 49 | # Get the resolution 50 | resolution = info_dict.get("resolution") 51 | 52 | # Look for `beat_resolution` for backward compatibility 53 | if resolution is None: 54 | resolution = info_dict.get("beat_resolution") 55 | if resolution is None: 56 | raise RuntimeError( 57 | "Cannot find `resolution` or `beat_resolution` in " 58 | "`info.json`." 59 | ) 60 | 61 | # Load the tracks 62 | idx = 0 63 | tracks = [] 64 | while str(idx) in info_dict: 65 | name = info_dict[str(idx)].get("name") 66 | program = info_dict[str(idx)].get("program") 67 | is_drum = info_dict[str(idx)].get("is_drum") 68 | pianoroll = reconstruct_sparse(loaded, "pianoroll_" + str(idx)) 69 | if pianoroll.dtype == np.bool_: 70 | track: Track = BinaryTrack( 71 | name=name, 72 | program=program, 73 | is_drum=is_drum, 74 | pianoroll=pianoroll, 75 | ) 76 | elif pianoroll.dtype == np.uint8: 77 | track = StandardTrack( 78 | name=name, 79 | program=program, 80 | is_drum=is_drum, 81 | pianoroll=pianoroll, 82 | ) 83 | else: 84 | track = Track( 85 | name=name, 86 | program=program, 87 | is_drum=is_drum, 88 | pianoroll=pianoroll, 89 | ) 90 | tracks.append(track) 91 | idx += 1 92 | 93 | return Multitrack( 94 | name=info_dict["name"], 95 | resolution=resolution, 96 | tempo=loaded.get("tempo"), 97 | beat=loaded.get("beat"), 98 | downbeat=loaded.get("downbeat"), 99 | tracks=tracks, 100 | ) 101 | 102 | 103 | def from_pretty_midi( 104 | midi: PrettyMIDI, 105 | resolution: int = DEFAULT_RESOLUTION, 106 | mode: str = "max", 107 | algorithm: str = "normal", 108 | collect_onsets_only: bool = False, 109 | first_beat_time: float = None, 110 | ) -> Multitrack: 111 | """Return a Multitrack object converted from a PrettyMIDI object. 
112 | 113 | Parse a :class:`pretty_midi.PrettyMIDI` object. The data type of the 114 | resulting piano rolls is automatically determined (int if `mode` is 115 | 'sum' and np.uint8 if `mode` is 'max'). 116 | 117 | Parameters 118 | ---------- 119 | midi : :class:`pretty_midi.PrettyMIDI` 120 | PrettyMIDI object to parse. 121 | resolution : int, default: `pypianoroll.DEFAULT_RESOLUTION` (24) 122 | Time steps per quarter note. 123 | mode : {'max', 'sum'}, default: 'max' 124 | Merging strategy for duplicate notes. 125 | algorithm : {'normal', 'strict', 'custom'}, default: 'normal' 126 | Algorithm for finding the location of the first beat (see 127 | Notes). 128 | collect_onsets_only : bool, default: False 129 | True to collect only the onset of the notes (i.e. note on 130 | events) in all tracks, where the note off and duration 131 | information are discarded. False to parse regular piano rolls. 132 | first_beat_time : float, optional 133 | Location of the first beat, in sec. Required and only 134 | effective when using 'custom' algorithm. 135 | 136 | Returns 137 | ------- 138 | :class:`pypianoroll.Multitrack` 139 | Converted Multitrack object. 140 | 141 | Notes 142 | ----- 143 | There are three algorithms for finding the location of the first 144 | beat: 145 | 146 | - 'normal' : Use the time of the first time signature change if any; 147 | otherwise, estimate it using :meth:`pretty_midi.PrettyMIDI.estimate_beat_start`. 148 | - 'strict' : Set the location of the first beat to the time of the 149 | first time signature change. Raise a RuntimeError if no time 150 | signature change is found. 151 | - 'custom' : Set the location of the first beat to the value of 152 | argument `first_beat_time`. Raise a TypeError if 153 | `first_beat_time` is not given. 154 | 155 | If an incomplete beat before the first beat is found, an additional 156 | beat will be added before the (estimated) beat starting time. 157 | However, notes before the (estimated) beat starting time for more 158 | than one beat are dropped.
159 | 160 | """ 161 | if mode not in ("max", "sum"): 162 | raise ValueError("`mode` must be either 'max' or 'sum'.") 163 | 164 | # Set first_beat_time for 'normal' and 'strict' modes 165 | if algorithm == "normal": 166 | if midi.time_signature_changes: 167 | midi.time_signature_changes.sort(key=lambda x: x.time) 168 | first_beat_time = midi.time_signature_changes[0].time 169 | else: 170 | first_beat_time = midi.estimate_beat_start() 171 | elif algorithm == "strict": 172 | if not midi.time_signature_changes: 173 | raise RuntimeError( 174 | "No time signature change event found. Unable to set beat " 175 | "start time using 'strict' algorithm." 176 | ) 177 | midi.time_signature_changes.sort(key=lambda x: x.time) 178 | first_beat_time = midi.time_signature_changes[0].time 179 | elif algorithm == "custom": 180 | if first_beat_time is None: 181 | raise TypeError( 182 | "`first_beat_time` must be given when using 'custom' " 183 | "algorithm." 184 | ) 185 | if first_beat_time < 0.0: 186 | raise ValueError("`first_beat_time` must be a positive number.") 187 | else: 188 | raise ValueError( 189 | "`algorithm` must be one of 'normal', 'strict' or 'custom'." 
) 191 | 192 | # Get tempo change event times and contents 193 | tc_times, tempi = midi.get_tempo_changes() 194 | arg_sorted = np.argsort(tc_times) 195 | tc_times = tc_times[arg_sorted] 196 | tempi = tempi[arg_sorted] 197 | 198 | beat_times = midi.get_beats(first_beat_time) 199 | if not beat_times.size: 200 | raise ValueError("Cannot get beat timings to quantize the piano roll.") 201 | beat_times.sort() 202 | 203 | n_beats = len(beat_times) 204 | n_time_steps = resolution * n_beats 205 | 206 | # Parse downbeat array 207 | if not midi.time_signature_changes: 208 | # This probably won't happen as pretty_midi always adds a 4/4 time 209 | # signature at time 0 210 | beat = None 211 | downbeat = None 212 | else: 213 | beat = np.zeros((n_time_steps, 1), bool) 214 | downbeat = np.zeros((n_time_steps, 1), bool) 215 | beat[0] = True 216 | downbeat[0] = True 217 | start = 0 218 | end = start 219 | for idx, tsc in enumerate(midi.time_signature_changes): 220 | start_idx = start * resolution 221 | if idx + 1 < len(midi.time_signature_changes): 222 | end += np.searchsorted( 223 | beat_times[end:], midi.time_signature_changes[idx + 1].time 224 | ) 225 | end_idx = end * resolution 226 | else: 227 | end_idx = n_time_steps 228 | beat[start_idx:end_idx:resolution] = True 229 | stride = tsc.numerator * resolution 230 | downbeat[start_idx:end_idx:stride] = True 231 | start = end 232 | 233 | # Build tempo array 234 | one_more_beat = 2 * beat_times[-1] - beat_times[-2] 235 | beat_times_one_more = np.append(beat_times, one_more_beat) 236 | bpm = 60.0 / np.diff(beat_times_one_more) 237 | tempo = np.repeat(bpm, resolution).reshape(-1, 1) 238 | 239 | # Parse the tracks 240 | tracks = [] 241 | for instrument in midi.instruments: 242 | if mode == "max": 243 | pianoroll = np.zeros((n_time_steps, 128), np.uint8) 244 | else: 245 | pianoroll = np.zeros((n_time_steps, 128), int) 246 | 247 | pitches = np.array( 248 | [ 249 | note.pitch 250 | for note in instrument.notes 251 | if note.end >
first_beat_time 252 | ] 253 | ) 254 | note_on_times = np.array( 255 | [ 256 | note.start 257 | for note in instrument.notes 258 | if note.end > first_beat_time 259 | ] 260 | ) 261 | beat_indices = np.searchsorted(beat_times, note_on_times) - 1 262 | remained = note_on_times - beat_times[beat_indices] 263 | ratios = remained / ( 264 | beat_times_one_more[beat_indices + 1] - beat_times[beat_indices] 265 | ) 266 | rounded = np.round((beat_indices + ratios) * resolution) 267 | note_ons = rounded.astype(int) 268 | 269 | if collect_onsets_only: 270 | pianoroll[note_ons, pitches] = True 271 | elif instrument.is_drum: 272 | velocities = [ 273 | note.velocity 274 | for note in instrument.notes 275 | if note.end > first_beat_time 276 | ] 277 | pianoroll[note_ons, pitches] = velocities 278 | else: 279 | note_off_times = np.array( 280 | [ 281 | note.end 282 | for note in instrument.notes 283 | if note.end > first_beat_time 284 | ] 285 | ) 286 | beat_indices = np.searchsorted(beat_times, note_off_times) - 1 287 | remained = note_off_times - beat_times[beat_indices] 288 | ratios = remained / ( 289 | beat_times_one_more[beat_indices + 1] 290 | - beat_times[beat_indices] 291 | ) 292 | note_offs = ((beat_indices + ratios) * resolution).astype(int) 293 | 294 | for idx, start in enumerate(note_ons): 295 | end = note_offs[idx] 296 | velocity = instrument.notes[idx].velocity 297 | 298 | if velocity < 1: 299 | continue 300 | 301 | if 0 < start < n_time_steps: 302 | if pianoroll[start - 1, pitches[idx]]: 303 | pianoroll[start - 1, pitches[idx]] = 0 304 | if end < n_time_steps - 1: 305 | if pianoroll[end, pitches[idx]]: 306 | end -= 1 307 | 308 | if mode == "sum": 309 | pianoroll[start:end, pitches[idx]] += velocity 310 | else: 311 | maximum = np.maximum( 312 | pianoroll[start:end, pitches[idx]], velocity 313 | ) 314 | pianoroll[start:end, pitches[idx]] = maximum 315 | 316 | if mode == "max": 317 | track: Track = StandardTrack( 318 | name=str(instrument.name), 319 |
program=int(instrument.program), 320 | is_drum=bool(instrument.is_drum), 321 | pianoroll=pianoroll, 322 | ) 323 | else: 324 | track = Track( 325 | name=str(instrument.name), 326 | program=int(instrument.program), 327 | is_drum=bool(instrument.is_drum), 328 | pianoroll=pianoroll, 329 | ) 330 | tracks.append(track) 331 | 332 | return Multitrack( 333 | resolution=resolution, 334 | tempo=tempo, 335 | beat=beat, 336 | downbeat=downbeat, 337 | tracks=tracks, 338 | ) 339 | 340 | 341 | def read(path: Union[str, Path], **kwargs) -> Multitrack: 342 | """Read a MIDI file into a Multitrack object. 343 | 344 | Parameters 345 | ---------- 346 | path : str or Path 347 | Path to the file to read. 348 | **kwargs 349 | Keyword arguments to pass to 350 | :func:`pypianoroll.from_pretty_midi`. 351 | 352 | See Also 353 | -------- 354 | :func:`pypianoroll.write` : Write a Multitrack object to a MIDI 355 | file. 356 | :func:`pypianoroll.load` : Load a NPZ file into a Multitrack object. 357 | 358 | """ 359 | return from_pretty_midi(PrettyMIDI(str(path)), **kwargs) 360 | -------------------------------------------------------------------------------- /pypianoroll/metrics.py: -------------------------------------------------------------------------------- 1 | """Objective metrics for piano rolls. 
2 | 3 | Functions 4 | --------- 5 | 6 | - drum_in_pattern_rate 7 | - empty_beat_rate 8 | - in_scale_rate 9 | - n_pitch_classes_used 10 | - n_pitches_used 11 | - pitch_range 12 | - pitch_range_tuple 13 | - polyphonic_rate 14 | - qualified_note_rate 15 | - tonal_distance 16 | 17 | """ 18 | from math import nan 19 | from typing import Sequence, Tuple 20 | 21 | import numpy as np 22 | from numpy import ndarray 23 | 24 | __all__ = [ 25 | "drum_in_pattern_rate", 26 | "empty_beat_rate", 27 | "in_scale_rate", 28 | "n_pitch_classes_used", 29 | "n_pitches_used", 30 | "pitch_range", 31 | "pitch_range_tuple", 32 | "polyphonic_rate", 33 | "qualified_note_rate", 34 | "tonal_distance", 35 | ] 36 | 37 | 38 | def _to_chroma(pianoroll: ndarray) -> ndarray: 39 | """Return the unnormalized chroma features.""" 40 | chroma = pianoroll[:, :120].reshape(-1, 10, 12).sum(1) 41 | chroma[:, :8] += pianoroll[:, 120:] 42 | return chroma 43 | 44 | 45 | def empty_beat_rate(pianoroll: ndarray, resolution: int) -> float: 46 | r"""Return the ratio of empty beats. 47 | 48 | The empty-beat rate is defined as the ratio of the number of empty 49 | beats (where no note is played) to the total number of beats. Return 50 | NaN if song length is zero. 51 | 52 | .. math:: empty\_beat\_rate = \frac{\#(empty\_beats)}{\#(beats)} 53 | 54 | Parameters 55 | ---------- 56 | pianoroll : ndarray 57 | Piano roll to evaluate. 58 | 59 | Returns 60 | ------- 61 | float 62 | Empty-beat rate. 63 | 64 | """ 65 | reshaped = pianoroll.reshape(-1, resolution * pianoroll.shape[1]) 66 | if len(reshaped) < 1: 67 | return nan 68 | n_empty_beats = np.count_nonzero(~reshaped.any(1)) 69 | return n_empty_beats / len(reshaped) 70 | 71 | 72 | def n_pitches_used(pianoroll: ndarray) -> int: 73 | """Return the number of unique pitches used. 74 | 75 | Parameters 76 | ---------- 77 | pianoroll : ndarray 78 | Piano roll to evaluate. 79 | 80 | Returns 81 | ------- 82 | int 83 | Number of unique pitches used.
84 | 85 | See Also 86 | -------- 87 | :func:`pypianoroll.n_pitch_classes_used`: Compute the number of unique 88 | pitch classes used. 89 | 90 | """ 91 | return np.count_nonzero(np.any(pianoroll, 0)) 92 | 93 | 94 | def n_pitch_classes_used(pianoroll: ndarray) -> int: 95 | """Return the number of unique pitch classes used. 96 | 97 | Parameters 98 | ---------- 99 | pianoroll : ndarray 100 | Piano roll to evaluate. 101 | 102 | Returns 103 | ------- 104 | int 105 | Number of unique pitch classes used. 106 | 107 | See Also 108 | -------- 109 | :func:`pypianoroll.n_pitches_used`: Compute the number of unique 110 | pitches used. 111 | 112 | """ 113 | return np.count_nonzero(_to_chroma(pianoroll).any(0)) 114 | 115 | 116 | def pitch_range_tuple(pianoroll) -> Tuple[float, float]: 117 | """Return the pitch range as a tuple `(lowest, highest)`. 118 | 119 | Returns 120 | ------- 121 | int or nan 122 | Lowest active pitch. 123 | int or nan 124 | Highest active pitch. 125 | 126 | See Also 127 | -------- 128 | :func:`pypianoroll.pitch_range`: Compute the pitch range. 129 | 130 | """ 131 | nonzero = pianoroll.any(0).nonzero()[0] 132 | if not nonzero.size: 133 | return nan, nan 134 | return nonzero[0], nonzero[-1] 135 | 136 | 137 | def pitch_range(pianoroll) -> float: 138 | """Return the pitch range. 139 | 140 | Returns 141 | ------- 142 | int or nan 143 | Pitch range (in semitones), i.e., difference between the 144 | highest and the lowest active pitches. 145 | 146 | See Also 147 | -------- 148 | :func:`pypianoroll.pitch_range_tuple`: Return the pitch range as a 149 | tuple. 150 | 151 | """ 152 | lowest, highest = pitch_range_tuple(pianoroll) 153 | return highest - lowest 154 | 155 | 156 | def qualified_note_rate(pianoroll: ndarray, threshold: float = 2) -> float: 157 | r"""Return the ratio of the number of the qualified notes.
158 | 159 | The qualified note rate is defined as the ratio of the number of 160 | qualified notes (notes longer than `threshold`, in time steps) to 161 | the total number of notes. Return NaN if no note is found. 162 | 163 | .. math:: 164 | qualified\_note\_rate = \frac{ 165 | \#(notes\_longer\_than\_the\_threshold) 166 | }{ 167 | \#(notes) 168 | } 169 | 170 | Parameters 171 | ---------- 172 | pianoroll : ndarray 173 | Piano roll to evaluate. 174 | threshold : int 175 | Threshold of note length to count into the numerator. 176 | 177 | Returns 178 | ------- 179 | float 180 | Qualified note rate. 181 | 182 | References 183 | ---------- 184 | 1. Hao-Wen Dong, Wen-Yi Hsiao, Li-Chia Yang, and Yi-Hsuan Yang, 185 | "MuseGAN: Multi-track sequential generative adversarial networks 186 | for symbolic music generation and accompaniment," in Proceedings 187 | of the 32nd AAAI Conference on Artificial Intelligence (AAAI), 188 | 2018. 189 | 190 | """ 191 | if np.issubdtype(pianoroll.dtype, np.bool_): 192 | pianoroll = pianoroll.astype(np.uint8) 193 | padded = np.pad(pianoroll, ((1, 1), (0, 0)), "constant") 194 | diff = np.diff(padded, axis=0).reshape(-1) 195 | onsets = (diff > 0).nonzero()[0] 196 | if len(onsets) < 1: 197 | return nan 198 | offsets = (diff < 0).nonzero()[0] 199 | n_qualified_notes = np.count_nonzero(offsets - onsets >= threshold) 200 | return n_qualified_notes / len(onsets) 201 | 202 | 203 | def polyphonic_rate(pianoroll: ndarray, threshold: float = 2) -> float: 204 | r"""Return the ratio of time steps where multiple pitches are on. 205 | 206 | The polyphony rate is defined as the ratio of the number of time 207 | steps where multiple pitches are on to the total number of time 208 | steps. Drum tracks are ignored. Return NaN if song length is zero. 209 | This metric is used in [1], where it is called *polyphonicity*. 210 | 211 | .. 
math:: 212 | polyphony\_rate = \frac{ 213 | \#(time\_steps\_where\_multiple\_pitches\_are\_on) 214 | }{ 215 | \#(time\_steps) 216 | } 217 | 218 | Parameters 219 | ---------- 220 | pianoroll : ndarray 221 | Piano roll to evaluate. 222 | threshold : int 223 | Number of active pitches a time step must exceed to be counted. 224 | 225 | Returns 226 | ------- 227 | float 228 | Polyphony rate. 229 | 230 | References 231 | ---------- 232 | 1. Hao-Wen Dong, Wen-Yi Hsiao, Li-Chia Yang, and Yi-Hsuan Yang, 233 | "MuseGAN: Multi-track sequential generative adversarial networks 234 | for symbolic music generation and accompaniment," in Proceedings 235 | of the 32nd AAAI Conference on Artificial Intelligence (AAAI), 236 | 2018. 237 | 238 | """ 239 | if len(pianoroll) < 1: 240 | return nan 241 | n_poly = np.count_nonzero(np.count_nonzero(pianoroll, 1) > threshold) 242 | return n_poly / len(pianoroll) 243 | 244 | 245 | def drum_in_pattern_rate( 246 | pianoroll: ndarray, resolution: int, tolerance: float = 0.1 247 | ) -> float: 248 | r"""Return the ratio of drum notes in a certain drum pattern. 249 | 250 | The drum-in-pattern rate is defined as the ratio of the number of 251 | drum notes in a certain drum pattern to the total number of drum 252 | notes. Only drum tracks are considered. Return NaN if no drum note 253 | is found. This metric is used in [1]. 254 | 255 | .. math:: 256 | drum\_in\_pattern\_rate = \frac{ 257 | \#(drum\_notes\_in\_pattern)}{\#(drum\_notes)} 258 | 259 | Parameters 260 | ---------- 261 | pianoroll : ndarray 262 | Piano roll to evaluate. 263 | resolution : int 264 | Time steps per beat. 265 | tolerance : float, default: 0.1 266 | Weight assigned to time steps adjacent to a pattern position. 267 | 268 | Returns 269 | ------- 270 | float 271 | Drum-in-pattern rate. 272 | 273 | References 274 | ---------- 275 | 1. 
Hao-Wen Dong, Wen-Yi Hsiao, Li-Chia Yang, and Yi-Hsuan Yang, 276 | "MuseGAN: Multi-track sequential generative adversarial networks 277 | for symbolic music generation and accompaniment," in Proceedings 278 | of the 32nd AAAI Conference on Artificial Intelligence (AAAI), 279 | 2018. 280 | 281 | """ 282 | if resolution not in (4, 6, 8, 9, 12, 16, 18, 24): 283 | raise ValueError( 284 | "Unsupported beat resolution. Expect 4, 6, 8 ,9, 12, 16, 18 or 24." 285 | ) 286 | 287 | def _drum_pattern_mask(res, tol): 288 | """Return a drum pattern mask with the given tolerance.""" 289 | if res == 24: 290 | drum_pattern_mask = np.tile([1.0, tol, 0.0, 0.0, 0.0, tol], 4) 291 | elif res == 12: 292 | drum_pattern_mask = np.tile([1.0, tol, tol], 4) 293 | elif res == 6: 294 | drum_pattern_mask = np.tile([1.0, tol, tol], 2) 295 | elif res == 18: 296 | drum_pattern_mask = np.tile([1.0, tol, 0.0, 0.0, 0.0, tol], 3) 297 | elif res == 9: 298 | drum_pattern_mask = np.tile([1.0, tol, tol], 3) 299 | elif res == 16: 300 | drum_pattern_mask = np.tile([1.0, tol, 0.0, tol], 4) 301 | elif res == 8: 302 | drum_pattern_mask = np.tile([1.0, tol], 4) 303 | elif res == 4: 304 | drum_pattern_mask = np.tile([1.0, tol], 2) 305 | return drum_pattern_mask 306 | 307 | drum_pattern_mask = _drum_pattern_mask(resolution, tolerance) 308 | n_in_pattern = np.sum(drum_pattern_mask * np.count_nonzero(pianoroll, 1)) 309 | return n_in_pattern / np.count_nonzero(pianoroll) 310 | 311 | 312 | def _get_scale(root: int, mode: str) -> ndarray: 313 | """Return the scale mask for a specific root.""" 314 | if mode == "major": 315 | a_scale_mask = np.array([0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1], bool) 316 | else: 317 | a_scale_mask = np.array([1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1], bool) 318 | return np.roll(a_scale_mask, root) 319 | 320 | 321 | def in_scale_rate( 322 | pianoroll: ndarray, root: int = 3, mode: str = "major" 323 | ) -> float: 324 | r"""Return the ratio of pitches in a certain musical scale. 
325 | 326 | The pitch-in-scale rate is defined as the ratio of the number of 327 | notes in a certain scale to the total number of notes. Drum tracks 328 | are ignored. Return NaN if no note is found. This metric is used in 329 | [1]. 330 | 331 | .. math:: 332 | pitch\_in\_scale\_rate = \frac{\#(notes\_in\_scale)}{\#(notes)} 333 | 334 | Parameters 335 | ---------- 336 | pianoroll : ndarray 337 | Piano roll to evaluate. 338 | root : int 339 | Root of the scale. 340 | mode : str, {'major', 'minor'} 341 | Mode of the scale. 342 | 343 | Returns 344 | ------- 345 | float 346 | Pitch-in-scale rate. 347 | 348 | See Also 349 | -------- 350 | :func:`muspy.scale_consistency`: Compute the largest pitch-in-class 351 | rate. 352 | 353 | References 354 | ---------- 355 | 1. Hao-Wen Dong, Wen-Yi Hsiao, Li-Chia Yang, and Yi-Hsuan Yang, 356 | "MuseGAN: Multi-track sequential generative adversarial networks 357 | for symbolic music generation and accompaniment," in Proceedings 358 | of the 32nd AAAI Conference on Artificial Intelligence (AAAI), 359 | 2018. 
360 | 361 | """ 362 | chroma = _to_chroma(pianoroll) 363 | scale_mask = _get_scale(root, mode) 364 | n_in_scale = np.sum(scale_mask.reshape(-1, 12) * chroma) 365 | return n_in_scale / np.count_nonzero(pianoroll) 366 | 367 | 368 | def _get_tonal_matrix(r1, r2, r3) -> ndarray: # pylint: disable=invalid-name 369 | """Return a tonal matrix for computing the tonal distance.""" 370 | tonal_matrix = np.empty((6, 12)) 371 | tonal_matrix[0] = r1 * np.sin(np.arange(12) * (7.0 / 6.0) * np.pi) 372 | tonal_matrix[1] = r1 * np.cos(np.arange(12) * (7.0 / 6.0) * np.pi) 373 | tonal_matrix[2] = r2 * np.sin(np.arange(12) * (3.0 / 2.0) * np.pi) 374 | tonal_matrix[3] = r2 * np.cos(np.arange(12) * (3.0 / 2.0) * np.pi) 375 | tonal_matrix[4] = r3 * np.sin(np.arange(12) * (2.0 / 3.0) * np.pi) 376 | tonal_matrix[5] = r3 * np.cos(np.arange(12) * (2.0 / 3.0) * np.pi) 377 | return tonal_matrix 378 | 379 | 380 | def _to_tonal_space( 381 | pianoroll: ndarray, resolution: int, tonal_matrix: ndarray 382 | ) -> ndarray: 383 | """Return the tensor in tonal space (chroma normalized per beat).""" 384 | beat_chroma = _to_chroma(pianoroll).reshape((-1, resolution, 12)) 385 | beat_chroma = beat_chroma / beat_chroma.sum(2, keepdims=True) 386 | return np.matmul(tonal_matrix, beat_chroma.T).T 387 | 388 | 389 | def tonal_distance( 390 | pianoroll_1: ndarray, 391 | pianoroll_2: ndarray, 392 | resolution: int, 393 | radii: Sequence[float] = (1.0, 1.0, 0.5), 394 | ) -> float: 395 | """Return the tonal distance [1] between the two input piano rolls. 396 | 397 | Parameters 398 | ---------- 399 | pianoroll_1 : ndarray 400 | First piano roll to evaluate. 401 | pianoroll_2 : ndarray 402 | Second piano roll to evaluate. 403 | resolution : int 404 | Time steps per beat. 405 | radii : tuple of float 406 | Radii of the three tonal circles (see Equation 3 in [1]). 407 | 408 | References 409 | ---------- 410 | 1. 
Christopher Harte, Mark Sandler, and Martin Gasser, "Detecting 411 | harmonic change in musical audio," in Proceedings of the 1st ACM 412 | workshop on Audio and music computing multimedia, 2006. 413 | 414 | """ 415 | assert len(pianoroll_1) == len( 416 | pianoroll_2 417 | ), "Input piano rolls must have the same length." 418 | 419 | r1, r2, r3 = radii # pylint: disable=invalid-name 420 | tonal_matrix = _get_tonal_matrix(r1, r2, r3) 421 | mapped_1 = _to_tonal_space(pianoroll_1, resolution, tonal_matrix) 422 | mapped_2 = _to_tonal_space(pianoroll_2, resolution, tonal_matrix) 423 | return np.linalg.norm(mapped_1 - mapped_2) 424 | -------------------------------------------------------------------------------- /pypianoroll/multitrack.py: -------------------------------------------------------------------------------- 1 | """Class for multitrack piano rolls. 2 | 3 | Class 4 | ----- 5 | 6 | - Multitrack 7 | 8 | Variable 9 | -------- 10 | 11 | - DEFAULT_RESOLUTION 12 | 13 | """ 14 | from typing import List, Sequence, TypeVar 15 | 16 | import numpy as np 17 | from matplotlib.axes import Axes 18 | from numpy import ndarray 19 | 20 | from .outputs import save, to_pretty_midi, write 21 | from .track import BinaryTrack, StandardTrack, Track 22 | from .visualization import plot_multitrack 23 | 24 | __all__ = [ 25 | "Multitrack", 26 | "DEFAULT_RESOLUTION", 27 | ] 28 | 29 | DEFAULT_RESOLUTION = 24 30 | 31 | MultitrackType = TypeVar("MultitrackType", bound="Multitrack") 32 | 33 | 34 | def _round_time(time, factor, rounding): 35 | if rounding == "round": 36 | return np.round(time * factor).astype(int) 37 | if rounding == "ceil": 38 | return np.ceil(time * factor).astype(int) 39 | if rounding == "floor": 40 | return np.floor(time * factor).astype(int) 41 | raise ValueError( 42 | "`rounding` must be one of 'round', 'ceil' or 'floor', " 43 | f"not {rounding}." 44 | ) 45 | 46 | 47 | class Multitrack: 48 | """A container for multitrack piano rolls. 
49 | 50 | This is the core class of Pypianoroll. A Multitrack object can be 51 | constructed in the following ways. 52 | 53 | - :class:`pypianoroll.Multitrack`: Construct by setting values for 54 | attributes 55 | - :func:`pypianoroll.read`: Read from a MIDI file 56 | - :func:`pypianoroll.from_pretty_midi`: Convert from a 57 | :class:`pretty_midi.PrettyMIDI` object 58 | - :func:`pypianoroll.load`: Load from an NPZ file saved by 59 | :func:`pypianoroll.save` 60 | 61 | Attributes 62 | ---------- 63 | name : str, optional 64 | Multitrack name. 65 | resolution : int, default: `pypianoroll.DEFAULT_RESOLUTION` (24) 66 | Time steps per quarter note. 67 | tempo : ndarray, dtype=float, shape=(?, 1), optional 68 | Tempo (in qpm) at each time step. Length is the total number 69 | of time steps. Cast to float if not of float type. 70 | beat : ndarray, dtype=bool, shape=(?, 1), optional 71 | A boolean array that indicates whether the time step contains a 72 | beat. Length is the total number of time steps. Cast to bool if 73 | not of bool type. 74 | downbeat : ndarray, dtype=bool, shape=(?, 1), optional 75 | A boolean array that indicates whether the time step contains a 76 | downbeat, i.e., the first time step of a measure. Length is the 77 | total number of time steps. Cast to bool if not of bool type. 78 | tracks : sequence of :class:`pypianoroll.Track`, default: [] 79 | Music tracks. 
80 | 81 | """ 82 | 83 | def __init__( 84 | self, 85 | name: str = None, 86 | resolution: int = None, 87 | tempo: ndarray = None, 88 | beat: ndarray = None, 89 | downbeat: ndarray = None, 90 | tracks: Sequence[Track] = None, 91 | ): 92 | self.name = name 93 | 94 | if resolution is not None: 95 | self.resolution = resolution 96 | else: 97 | self.resolution = DEFAULT_RESOLUTION 98 | 99 | if tempo is None: 100 | self.tempo = None 101 | elif np.issubdtype(tempo.dtype, np.floating): 102 | self.tempo = tempo 103 | else: 104 | self.tempo = np.asarray(tempo).astype(float) 105 | 106 | if beat is None: 107 | self.beat = None 108 | elif beat.dtype == np.bool_: 109 | self.beat = beat 110 | else: 111 | self.beat = np.asarray(beat).astype(bool) 112 | 113 | if downbeat is None: 114 | self.downbeat = None 115 | elif downbeat.dtype == np.bool_: 116 | self.downbeat = downbeat 117 | else: 118 | self.downbeat = np.asarray(downbeat).astype(bool) 119 | 120 | if tracks is None: 121 | self.tracks = [] 122 | elif isinstance(tracks, list): 123 | self.tracks = tracks 124 | else: 125 | self.tracks = list(tracks) 126 | 127 | def __len__(self) -> int: 128 | return len(self.tracks) 129 | 130 | def __getitem__(self, key: int) -> Track: 131 | return self.tracks[key] 132 | 133 | def __setitem__(self, key: int, value: Track): 134 | self.tracks[key] = value 135 | 136 | def __delitem__(self, key: int): 137 | del self.tracks[key] 138 | 139 | def __repr__(self) -> str: 140 | to_join = [ 141 | f"name={repr(self.name)}", 142 | f"resolution={repr(self.resolution)}", 143 | ] 144 | if self.tempo is not None: 145 | to_join.append( 146 | f"tempo=array(shape={self.tempo.shape}, " 147 | f"dtype={self.tempo.dtype})" 148 | ) 149 | if self.beat is not None: 150 | to_join.append( 151 | f"beat=array(shape={self.beat.shape}, dtype={self.beat.dtype})" 152 | ) 153 | if self.downbeat is not None: 154 | to_join.append( 155 | f"downbeat=array(shape={self.downbeat.shape}, " 156 | f"dtype={self.downbeat.dtype})" 157 | ) 158 | 
to_join.append(f"tracks={repr(self.tracks)}") 159 | return f"Multitrack({', '.join(to_join)})" 160 | 161 | def _validate_type(self, attr): 162 | if getattr(self, attr) is None: 163 | if attr == "resolution": 164 | raise TypeError(f"`{attr}` must not be None.") 165 | return 166 | 167 | if attr == "name": 168 | if not isinstance(self.name, str): 169 | raise TypeError( 170 | "`name` must be of type str, but got type " 171 | f"{type(self.name)}." 172 | ) 173 | elif attr == "resolution": 174 | if not isinstance(self.resolution, int): 175 | raise TypeError( 176 | "`resolution` must be of type int, but got " 177 | f"{type(self.resolution)}." 178 | ) 179 | elif attr == "tempo": 180 | if not isinstance(self.tempo, np.ndarray): 181 | raise TypeError("`tempo` must be a NumPy array.") 182 | if not np.issubdtype(self.tempo.dtype, np.number): 183 | raise TypeError( 184 | "`tempo` must be of data type numpy.number, but got data " 185 | f"type {self.tempo.dtype}." 186 | ) 187 | elif attr == "beat": 188 | if not isinstance(self.beat, np.ndarray): 189 | raise TypeError("`beat` must be a NumPy array.") 190 | if not np.issubdtype(self.beat.dtype, np.bool_): 191 | raise TypeError( 192 | "`beat` must be of data type bool, but got data " 193 | f"type {self.beat.dtype}." 194 | ) 195 | elif attr == "downbeat": 196 | if not isinstance(self.downbeat, np.ndarray): 197 | raise TypeError("`downbeat` must be a NumPy array.") 198 | if not np.issubdtype(self.downbeat.dtype, np.bool_): 199 | raise TypeError( 200 | "`downbeat` must be of data type bool, but got data " 201 | f"type {self.downbeat.dtype}." 202 | ) 203 | elif attr == "tracks": 204 | for i, track in enumerate(self.tracks): 205 | if not isinstance(track, Track): 206 | raise TypeError( 207 | "`tracks` must be a list of type Track, but got type " 208 | f"{type(track)} at index {i}." 209 | ) 210 | 211 | def validate_type(self, attr=None): 212 | """Raise an error if an attribute has an invalid type. 
213 | 214 | Parameters 215 | ---------- 216 | attr : str 217 | Attribute to validate. Defaults to validate all attributes. 218 | 219 | Returns 220 | ------- 221 | Object itself. 222 | 223 | """ 224 | if attr is None: 225 | attributes = ( 226 | "name", 227 | "resolution", 228 | "tempo", 229 | "beat", 230 | "downbeat", 231 | "tracks", 232 | ) 233 | for attribute in attributes: 234 | self._validate_type(attribute) 235 | else: 236 | self._validate_type(attr) 237 | return self 238 | 239 | def _validate(self, attr): 240 | if getattr(self, attr) is None: 241 | if attr == "resolution": 242 | raise TypeError(f"`{attr}` must not be None.") 243 | return 244 | 245 | self._validate_type(attr) 246 | 247 | if attr == "resolution": 248 | if self.resolution < 1: 249 | raise ValueError("`resolution` must be a positive integer.") 250 | elif attr == "tempo": 251 | if self.tempo.ndim != 1: 252 | raise ValueError("`tempo` must be a 1D NumPy array.") 253 | if np.any(self.tempo <= 0.0): 254 | raise ValueError("`tempo` must contain only positive numbers.") 255 | elif attr == "beat": 256 | if self.beat.ndim != 1: 257 | raise ValueError("`beat` must be a 1D NumPy array.") 258 | elif attr == "downbeat": 259 | if self.downbeat.ndim != 1: 260 | raise ValueError("`downbeat` must be a 1D NumPy array.") 261 | elif attr == "tracks": 262 | for track in self.tracks: 263 | track.validate() 264 | 265 | def validate(self: MultitrackType, attr=None) -> MultitrackType: 266 | """Raise an error if an attribute has an invalid type or value. 267 | 268 | Parameters 269 | ---------- 270 | attr : str 271 | Attribute to validate. Defaults to validate all attributes. 272 | 273 | Returns 274 | ------- 275 | Object itself. 
276 | 277 | """ 278 | if attr is None: 279 | attributes = ( 280 | "name", 281 | "resolution", 282 | "tempo", 283 | "beat", 284 | "downbeat", 285 | "tracks", 286 | ) 287 | for attribute in attributes: 288 | self._validate(attribute) 289 | else: 290 | self._validate(attr) 291 | return self 292 | 293 | def is_valid_type(self, attr: str = None) -> bool: 294 | """Return True if an attribute is of a valid type. 295 | 296 | Parameters 297 | ---------- 298 | attr : str 299 | Attribute to validate. Defaults to validate all attributes. 300 | 301 | Returns 302 | ------- 303 | bool 304 | Whether the attribute is of a valid type. 305 | 306 | """ 307 | try: 308 | self.validate_type(attr) 309 | except TypeError: 310 | return False 311 | return True 312 | 313 | def is_valid(self, attr: str = None) -> bool: 314 | """Return True if an attribute is valid. 315 | 316 | Parameters 317 | ---------- 318 | attr : str 319 | Attribute to validate. Defaults to validate all attributes. 320 | 321 | Returns 322 | ------- 323 | bool 324 | Whether the attribute has a valid type and value. 325 | 326 | """ 327 | try: 328 | self.validate(attr) 329 | except (TypeError, ValueError): 330 | return False 331 | return True 332 | 333 | def get_end_time(self) -> int: 334 | """Return the end time of the multitrack. 335 | 336 | Returns 337 | ------- 338 | int 339 | Maximum length (in time steps) of the tempo, beat, downbeat 340 | arrays and all piano rolls. 341 | 342 | """ 343 | end_time = self.get_max_length() 344 | if self.tempo is not None and end_time < self.tempo.shape[0]: 345 | end_time = self.tempo.shape[0] 346 | if self.beat is not None and end_time < self.beat.shape[0]: 347 | end_time = self.beat.shape[0] 348 | if self.downbeat is not None and end_time < self.downbeat.shape[0]: 349 | end_time = self.downbeat.shape[0] 350 | return end_time 351 | 352 | def get_length(self) -> int: 353 | """Return the maximum active length of the piano rolls. 
354 | 355 | Returns 356 | ------- 357 | int 358 | Maximum active length (in time steps) of the piano rolls, 359 | where active length is the length of the piano roll without 360 | trailing silence. 361 | 362 | """ 363 | active_length = 0 364 | for track in self.tracks: 365 | now_length = track.get_length() 366 | if active_length < now_length: 367 | active_length = now_length 368 | return active_length 369 | 370 | def get_max_length(self) -> int: 371 | """Return the maximum length of the piano rolls. 372 | 373 | Returns 374 | ------- 375 | int 376 | Maximum length (in time steps) of the piano rolls. 377 | 378 | """ 379 | max_length = 0 380 | for track in self.tracks: 381 | if max_length < track.pianoroll.shape[0]: 382 | max_length = track.pianoroll.shape[0] 383 | return max_length 384 | 385 | def get_beat_steps(self) -> ndarray: 386 | """Return the indices of time steps that contain beats. 387 | 388 | Returns 389 | ------- 390 | ndarray, dtype=int 391 | Indices of time steps that contain beats. 392 | 393 | """ 394 | if self.beat is None: 395 | return np.array([], dtype=int) 396 | return np.nonzero(self.beat)[0] 397 | 398 | def get_downbeat_steps(self) -> ndarray: 399 | """Return the indices of time steps that contain downbeats. 400 | 401 | Returns 402 | ------- 403 | ndarray, dtype=int 404 | Indices of time steps that contain downbeats. 405 | 406 | """ 407 | if self.downbeat is None: 408 | return np.array([], dtype=int) 409 | return np.nonzero(self.downbeat)[0] 410 | 411 | def set_nonzeros(self: MultitrackType, value: int) -> MultitrackType: 412 | """Assign a constant value to all nonzero entries. 413 | 414 | Parameters 415 | ---------- 416 | value : int 417 | Value to assign. 418 | 419 | Returns 420 | ------- 421 | Object itself. 
422 | 423 | """ 424 | for i, track in enumerate(self.tracks): 425 | if isinstance(track, (StandardTrack, BinaryTrack)): 426 | self.tracks[i] = track.set_nonzeros(value) 427 | return self 428 | 429 | def set_resolution( 430 | self: MultitrackType, resolution: int, rounding: str = "round" 431 | ) -> MultitrackType: 432 | """Set the resolution. 433 | 434 | Parameters 435 | ---------- 436 | resolution : int 437 | Target resolution. 438 | rounding : {'round', 'ceil', 'floor'}, default: 'round' 439 | Rounding mode. 440 | 441 | Returns 442 | ------- 443 | Object itself. 444 | 445 | """ 446 | factor = resolution / self.resolution 447 | # Get the end time 448 | end_time = self.get_end_time() 449 | rounded_end_time = _round_time(end_time, factor, rounding) 450 | # Beat array 451 | beats = self.get_beat_steps() 452 | beats = _round_time(beats, factor, rounding) 453 | self.beat = np.zeros((rounded_end_time + 1, 1), bool) 454 | self.beat[beats] = True 455 | # Downbeat array 456 | downbeats = self.get_downbeat_steps() 457 | downbeats = _round_time(downbeats, factor, rounding) 458 | self.downbeat = np.zeros((rounded_end_time + 1, 1), bool) 459 | self.downbeat[downbeats] = True 460 | # Iterate over each track 461 | for track in self.tracks: 462 | time, pitch = track.pianoroll.nonzero() 463 | if len(time) < 1: 464 | continue 465 | if track.pianoroll.dtype == np.bool_: 466 | value = 1 467 | else: 468 | value = track.pianoroll[time, pitch] 469 | rounded_time = _round_time(time, factor, rounding) 470 | track.pianoroll = np.zeros( 471 | (rounded_end_time + 1, 128), track.pianoroll.dtype 472 | ) 473 | track.pianoroll[rounded_time, pitch] = value 474 | # Set the new resolution 475 | self.resolution = resolution 476 | return self 477 | 478 | def count_beat(self) -> int: 479 | """Return the number of beats. 480 | 481 | Returns 482 | ------- 483 | int 484 | Number of beats. 485 | 486 | Note 487 | ---- 488 | Return value is calculated based only on the attribute `beat`. 
489 | 490 | """ 491 | if self.beat is None: 492 | return 0 493 | return np.count_nonzero(self.beat) 494 | 495 | def count_downbeat(self) -> int: 496 | """Return the number of downbeats. 497 | 498 | Returns 499 | ------- 500 | int 501 | Number of downbeats. 502 | 503 | Note 504 | ---- 505 | Return value is calculated based only on the attribute 506 | `downbeat`. 507 | 508 | """ 509 | if self.downbeat is None: 510 | return 0 511 | return np.count_nonzero(self.downbeat) 512 | 513 | def stack(self) -> ndarray: 514 | """Return the piano rolls stacked as a 3D tensor. 515 | 516 | Returns 517 | ------- 518 | ndarray, shape=(?, ?, 128) 519 | Stacked piano roll, provided as *(track, time, pitch)*. 520 | 521 | """ 522 | max_length = self.get_max_length() 523 | pianorolls = [] 524 | for track in self.tracks: 525 | if track.pianoroll.shape[0] < max_length: 526 | pad_length = max_length - track.pianoroll.shape[0] 527 | padded = np.pad( 528 | track.pianoroll, 529 | ((0, pad_length), (0, 0)), 530 | "constant", 531 | ) 532 | pianorolls.append(padded) 533 | else: 534 | pianorolls.append(track.pianoroll) 535 | return np.stack(pianorolls) 536 | 537 | def blend(self, mode: str = None) -> ndarray: 538 | """Return the blended pianoroll. 539 | 540 | Parameters 541 | ---------- 542 | mode : {'sum', 'max', 'any'}, default: 'sum' 543 | Blending strategy to apply along the track axis. For 'sum' 544 | mode, integer summation is performed for binary piano rolls. 545 | 546 | Returns 547 | ------- 548 | ndarray, shape=(?, 128) 549 | Blended piano roll. 550 | 551 | """ 552 | stacked = self.stack() 553 | if mode is None or mode.lower() == "sum": 554 | return np.sum(stacked, axis=0).clip(0, 127).astype(np.uint8) 555 | if mode.lower() == "any": 556 | return np.any(stacked, axis=0) 557 | if mode.lower() == "max": 558 | return np.max(stacked, axis=0) 559 | raise ValueError("`mode` must be one of 'max', 'sum' or 'any'.") 560 | 561 | def copy(self): 562 | """Return a copy of the multitrack. 
563 | 564 | Returns 565 | ------- 566 | A copy of the object itself. 567 | 568 | Notes 569 | ----- 570 | Arrays are copied using :func:`numpy.copy`. 571 | 572 | """ 573 | return Multitrack( 574 | name=self.name, 575 | resolution=self.resolution, 576 | tempo=None if self.tempo is None else self.tempo.copy(), 577 | beat=None if self.beat is None else self.beat.copy(), 578 | downbeat=None if self.downbeat is None else self.downbeat.copy(), 579 | tracks=[track.copy() for track in self.tracks], 580 | ) 581 | 582 | def append(self: MultitrackType, track: Track) -> MultitrackType: 583 | """Append a Track object to the track list. 584 | 585 | Parameters 586 | ---------- 587 | track : :class:`pypianoroll.Track` 588 | Track to append. 589 | 590 | Returns 591 | ------- 592 | Object itself. 593 | 594 | """ 595 | self.tracks.append(track) 596 | return self 597 | 598 | def binarize(self: MultitrackType, threshold: float = 0) -> MultitrackType: 599 | """Binarize the piano rolls. 600 | 601 | Parameters 602 | ---------- 603 | threshold : int or float, default: 0 604 | Threshold to binarize the piano rolls. 605 | 606 | Returns 607 | ------- 608 | Object itself. 609 | 610 | """ 611 | for i, track in enumerate(self.tracks): 612 | if isinstance(track, StandardTrack): 613 | self.tracks[i] = track.binarize(threshold) 614 | return self 615 | 616 | def clip( 617 | self: MultitrackType, lower: int = 0, upper: int = 127 618 | ) -> MultitrackType: 619 | """Clip (limit) the piano rolls into [`lower`, `upper`]. 620 | 621 | Parameters 622 | ---------- 623 | lower : int, default: 0 624 | Lower bound. 625 | upper : int, default: 127 626 | Upper bound. 627 | 628 | Returns 629 | ------- 630 | Object itself. 631 | 632 | Note 633 | ---- 634 | Only affects StandardTrack instances. 
635 | 636 | """ 637 | for track in self.tracks: 638 | if isinstance(track, StandardTrack): 639 | track.clip(lower, upper) 640 | return self 641 | 642 | def pad(self: MultitrackType, pad_length) -> MultitrackType: 643 | """Pad the piano rolls. 644 | 645 | Notes 646 | ----- 647 | The lengths of the resulting piano rolls are not guaranteed to 648 | be the same. 649 | 650 | Parameters 651 | ---------- 652 | pad_length : int 653 | Length to pad along the time axis. 654 | 655 | Returns 656 | ------- 657 | Object itself. 658 | 659 | See Also 660 | -------- 661 | :meth:`pypianoroll.Multitrack.pad_to_multiple` : Pad the piano 662 | rolls so that their lengths are some multiples. 663 | :meth:`pypianoroll.Multitrack.pad_to_same` : Pad the piano rolls 664 | so that they have the same length. 665 | 666 | """ 667 | for track in self.tracks: 668 | track.pad(pad_length) 669 | return self 670 | 671 | def pad_to_multiple(self: MultitrackType, factor: int) -> MultitrackType: 672 | """Pad the piano rolls so that their lengths are some multiples. 673 | 674 | Pad the piano rolls at the end along the time axis with the 675 | minimum length that makes the lengths of the resulting piano 676 | rolls multiples of `factor`. 677 | 678 | Parameters 679 | ---------- 680 | factor : int 681 | The value which the length of the resulting piano rolls will 682 | be a multiple of. 683 | 684 | Returns 685 | ------- 686 | Object itself. 687 | 688 | Notes 689 | ----- 690 | Lengths of the resulting piano rolls are not necessarily the same. 691 | 692 | See Also 693 | -------- 694 | :meth:`pypianoroll.Multitrack.pad` : Pad the piano rolls. 695 | :meth:`pypianoroll.Multitrack.pad_to_same` : Pad the piano rolls 696 | so that they have the same length. 697 | 698 | """ 699 | for track in self.tracks: 700 | track.pad_to_multiple(factor) 701 | return self 702 | 703 | def pad_to_same(self: MultitrackType) -> MultitrackType: 704 | """Pad the piano rolls so that they have the same length. 
705 | 706 | Pad shorter piano rolls at the end along the time axis so that 707 | the resulting piano rolls have the same length. 708 | 709 | Returns 710 | ------- 711 | Object itself. 712 | 713 | See Also 714 | -------- 715 | :meth:`pypianoroll.Multitrack.pad` : Pad the piano rolls. 716 | :meth:`pypianoroll.Multitrack.pad_to_multiple` : Pad the piano 717 | rolls so that their lengths are some multiples. 718 | 719 | """ 720 | max_length = self.get_max_length() 721 | for track in self.tracks: 722 | if track.pianoroll.shape[0] < max_length: 723 | track.pad(max_length - track.pianoroll.shape[0]) 724 | return self 725 | 726 | def remove_empty(self: MultitrackType) -> MultitrackType: 727 | """Remove tracks with empty pianorolls.""" 728 | self.tracks = [ 729 | track for track in self.tracks if np.any(track.pianoroll) 730 | ] 731 | return self 732 | 733 | def transpose(self: MultitrackType, semitone: int) -> MultitrackType: 734 | """Transpose the piano rolls by a number of semitones. 735 | 736 | Parameters 737 | ---------- 738 | semitone : int 739 | Number of semitones to transpose. A positive value raises 740 | the pitches, while a negative value lowers the pitches. 741 | 742 | Returns 743 | ------- 744 | Object itself. 745 | 746 | Notes 747 | ----- 748 | Drum tracks are skipped. 749 | 750 | """ 751 | for track in self.tracks: 752 | if not track.is_drum: 753 | track.transpose(semitone) 754 | return self 755 | 756 | def trim( 757 | self: MultitrackType, start: int = None, end: int = None 758 | ) -> MultitrackType: 759 | """Trim the trailing silences of the piano rolls. 760 | 761 | Parameters 762 | ---------- 763 | start : int, default: 0 764 | Start time. 765 | end : int, optional 766 | End time. Defaults to active length. 767 | 768 | Returns 769 | ------- 770 | Object itself. 
771 | 772 | """ 773 | if start is None: 774 | start = 0 775 | if end is None: 776 | end = self.get_length() 777 | if self.tempo is not None: 778 | self.tempo = self.tempo[start:end] 779 | if self.beat is not None: 780 | self.beat = self.beat[start:end] 781 | if self.downbeat is not None: 782 | self.downbeat = self.downbeat[start:end] 783 | for track in self.tracks: 784 | track.trim(start=start, end=end) 785 | return self 786 | 787 | def save(self, path: str, compressed: bool = True): 788 | """Save to a NPZ file. 789 | 790 | Refer to :func:`pypianoroll.save` for full documentation. 791 | 792 | """ 793 | save(path, self, compressed=compressed) 794 | 795 | def write(self, path: str): 796 | """Write to a MIDI file. 797 | 798 | Refer to :func:`pypianoroll.write` for full documentation. 799 | 800 | """ 801 | return write(path, self) 802 | 803 | def to_pretty_midi(self, **kwargs): 804 | """Return as a PrettyMIDI object. 805 | 806 | Refer to :func:`pypianoroll.to_pretty_midi` for full 807 | documentation. 808 | 809 | """ 810 | return to_pretty_midi(self, **kwargs) 811 | 812 | def plot(self, axs: Sequence[Axes] = None, **kwargs) -> List[Axes]: 813 | """Plot the multitrack piano roll. 814 | 815 | Refer to :func:`pypianoroll.plot_multitrack` for full 816 | documentation. 817 | 818 | """ 819 | return plot_multitrack(self, axs, **kwargs) 820 | -------------------------------------------------------------------------------- /pypianoroll/outputs.py: -------------------------------------------------------------------------------- 1 | """Output interfaces. 
2 | 3 | Functions 4 | --------- 5 | 6 | - save 7 | - to_pretty_midi 8 | - write 9 | 10 | Variable 11 | -------- 12 | 13 | - DEFAULT_TEMPO 14 | - DEFAULT_VELOCITY 15 | 16 | """ 17 | import json 18 | import zipfile 19 | from copy import deepcopy 20 | from operator import attrgetter 21 | from pathlib import Path 22 | from typing import TYPE_CHECKING, Dict, Union 23 | 24 | import numpy as np 25 | import pretty_midi 26 | import scipy.stats 27 | from pretty_midi import Instrument, PrettyMIDI 28 | 29 | from .track import BinaryTrack, StandardTrack 30 | from .utils import decompose_sparse 31 | 32 | if TYPE_CHECKING: 33 | from .multitrack import Multitrack 34 | 35 | __all__ = [ 36 | "save", 37 | "to_pretty_midi", 38 | "write", 39 | "DEFAULT_TEMPO", 40 | "DEFAULT_VELOCITY", 41 | ] 42 | 43 | DEFAULT_TEMPO = 120 44 | DEFAULT_VELOCITY = 64 45 | 46 | 47 | def save( 48 | path: Union[str, Path], multitrack: "Multitrack", compressed: bool = True 49 | ): 50 | """Save a Multitrack object to a NPZ file. 51 | 52 | Parameters 53 | ---------- 54 | path : str or Path 55 | Path to the NPZ file to save. 56 | multitrack : :class:`pypianoroll.Multitrack` 57 | Multitrack to save. 58 | compressed : bool, default: True 59 | Whether to save to a compressed NPZ file. 60 | 61 | Notes 62 | ----- 63 | To reduce the file size, the piano rolls are first converted to 64 | instances of :class:`scipy.sparse.csc_matrix`. The component 65 | arrays are then collected and saved to a npz file. 66 | 67 | See Also 68 | -------- 69 | :func:`pypianoroll.load` : Load a NPZ file into a Multitrack object. 70 | :func:`pypianoroll.write` : Write a Multitrack object to a MIDI 71 | file. 
72 | 73 | """ 74 | info_dict: Dict = { 75 | "resolution": multitrack.resolution, 76 | "name": multitrack.name, 77 | } 78 | 79 | array_dict = {} 80 | if multitrack.tempo is not None: 81 | array_dict["tempo"] = multitrack.tempo 82 | if multitrack.beat is not None: 83 | array_dict["beat"] = multitrack.beat 84 | if multitrack.downbeat is not None: 85 | array_dict["downbeat"] = multitrack.downbeat 86 | 87 | for idx, track in enumerate(multitrack.tracks): 88 | array_dict.update( 89 | decompose_sparse(track.pianoroll, "pianoroll_" + str(idx)) 90 | ) 91 | info_dict[str(idx)] = { 92 | "program": track.program, 93 | "is_drum": track.is_drum, 94 | "name": track.name, 95 | } 96 | 97 | if compressed: 98 | np.savez_compressed(path, **array_dict) 99 | else: 100 | np.savez(path, **array_dict) 101 | 102 | compression = zipfile.ZIP_DEFLATED if compressed else zipfile.ZIP_STORED 103 | with zipfile.ZipFile(path, "a") as zip_file: 104 | zip_file.writestr("info.json", json.dumps(info_dict), compression) 105 | 106 | 107 | def to_pretty_midi( 108 | multitrack: "Multitrack", 109 | default_tempo: float = None, 110 | default_velocity: int = DEFAULT_VELOCITY, 111 | ) -> PrettyMIDI: 112 | """Return a Multitrack object as a PrettyMIDI object. 113 | 114 | Parameters 115 | ---------- 116 | default_tempo : int, default: `pypianoroll.DEFAULT_TEMPO` (120) 117 | Default tempo to use. If attribute `tempo` is available, use its 118 | first element. 119 | default_velocity : int, default: `pypianoroll.DEFAULT_VELOCITY` (64) 120 | Default velocity to assign to binarized tracks. 121 | 122 | Returns 123 | ------- 124 | :class:`pretty_midi.PrettyMIDI` 125 | Converted PrettyMIDI object. 126 | 127 | Notes 128 | ----- 129 | - Tempo changes are not supported. 130 | - Time signature changes are not supported. 131 | - The velocities of the converted piano rolls will be clipped to 132 | [0, 127]. 
133 | - Adjacent nonzero values of the same pitch will be considered 134 | a single note with their mean as its velocity. 135 | 136 | """ 137 | if default_tempo is not None: 138 | tempo = default_tempo 139 | elif multitrack.tempo is not None: 140 | tempo = float(scipy.stats.hmean(multitrack.tempo)) 141 | else: 142 | tempo = DEFAULT_TEMPO 143 | 144 | # Create a PrettyMIDI instance 145 | midi = PrettyMIDI(initial_tempo=tempo) 146 | 147 | # Compute length of a time step 148 | time_step_length = 60.0 / tempo / multitrack.resolution 149 | 150 | for track in multitrack.tracks: 151 | instrument = Instrument( 152 | program=track.program, is_drum=track.is_drum, name=track.name 153 | ) 154 | if isinstance(track, BinaryTrack): 155 | processed = track.set_nonzeros(default_velocity) 156 | elif isinstance(track, StandardTrack): 157 | copied = deepcopy(track) 158 | processed = copied.clip() 159 | else: 160 | raise ValueError( 161 | f"Expect BinaryTrack or StandardTrack, but got {type(track)}." 162 | ) 163 | clipped = processed.pianoroll.astype(np.uint8) 164 | binarized = clipped > 0 165 | padded = np.pad(binarized, ((1, 1), (0, 0)), "constant") 166 | diff = np.diff(padded.astype(np.int8), axis=0) 167 | 168 | positives = np.nonzero((diff > 0).T) 169 | pitches = positives[0] 170 | note_ons = positives[1] 171 | note_on_times = time_step_length * note_ons 172 | note_offs = np.nonzero((diff < 0).T)[1] 173 | note_off_times = time_step_length * note_offs 174 | 175 | for idx, pitch in enumerate(pitches): 176 | velocity = np.mean(clipped[note_ons[idx] : note_offs[idx], pitch]) 177 | note = pretty_midi.Note( 178 | velocity=int(velocity), 179 | pitch=pitch, 180 | start=note_on_times[idx], 181 | end=note_off_times[idx], 182 | ) 183 | instrument.notes.append(note) 184 | 185 | instrument.notes.sort(key=attrgetter("start")) 186 | midi.instruments.append(instrument) 187 | 188 | return midi 189 | 190 | 191 | def write(path: str, multitrack: "Multitrack"): 192 | """Write a Multitrack object to a 
MIDI file. 193 | 194 | Parameters 195 | ---------- 196 | path : str 197 | Path to write the file. 198 | multitrack : :class:`pypianoroll.Multitrack` 199 | Multitrack to save. 200 | 201 | See Also 202 | -------- 203 | :func:`pypianoroll.read` : Read a MIDI file into a Multitrack 204 | object. 205 | :func:`pypianoroll.save` : Save a Multitrack object to a NPZ file. 206 | 207 | """ 208 | return to_pretty_midi(multitrack).write(str(path)) 209 | -------------------------------------------------------------------------------- /pypianoroll/track.py: -------------------------------------------------------------------------------- 1 | """Classes for single-track piano rolls. 2 | 3 | Classes 4 | ------- 5 | 6 | - BinaryTrack 7 | - StandardTrack 8 | - Track 9 | 10 | Variables 11 | --------- 12 | 13 | - DEFAULT_PROGRAM 14 | - DEFAULT_IS_DRUM 15 | 16 | """ 17 | from typing import Any, TypeVar 18 | 19 | import numpy as np 20 | from matplotlib.axes import Axes 21 | from numpy import ndarray 22 | 23 | from .visualization import plot_track 24 | 25 | __all__ = [ 26 | "BinaryTrack", 27 | "StandardTrack", 28 | "Track", 29 | "DEFAULT_PROGRAM", 30 | "DEFAULT_IS_DRUM", 31 | ] 32 | 33 | DEFAULT_PROGRAM = 0 34 | DEFAULT_IS_DRUM = False 35 | 36 | TrackType = TypeVar("TrackType", bound="Track") 37 | StandardTrackType = TypeVar("StandardTrackType", bound="StandardTrack") 38 | 39 | 40 | class Track: 41 | """A generic container for single-track piano rolls. 42 | 43 | Attributes 44 | ---------- 45 | name : str, optional 46 | Track name. 47 | program : int, 0-127, default: `pypianoroll.DEFAULT_PROGRAM` (0) 48 | Program number according to General MIDI specification [1]. 49 | Defaults to 0 (Acoustic Grand Piano). 50 | is_drum : bool, `pypianoroll.DEFAULT_IS_DRUM` (False) 51 | Whether it is a percussion track. 52 | pianoroll : ndarray, shape=(?, 128), optional 53 | Piano-roll matrix. The first dimension represents time, and the 54 | second dimension represents pitch. 
55 | 56 | References 57 | ---------- 58 | 1. https://www.midi.org/specifications/item/gm-level-1-sound-set 59 | 60 | """ 61 | 62 | def __init__( 63 | self, 64 | name: str = None, 65 | program: int = None, 66 | is_drum: bool = None, 67 | pianoroll: ndarray = None, 68 | ): 69 | self.name = name 70 | self.program = program if program is not None else DEFAULT_PROGRAM 71 | self.is_drum = is_drum if is_drum is not None else DEFAULT_IS_DRUM 72 | if pianoroll is None: 73 | self.pianoroll = np.zeros((0, 128)) 74 | else: 75 | self.pianoroll = np.asarray(pianoroll) 76 | 77 | def __repr__(self) -> str: 78 | to_join = [ 79 | f"name={repr(self.name)}", 80 | f"program={repr(self.program)}", 81 | f"is_drum={repr(self.is_drum)}", 82 | f"pianoroll=array(shape={self.pianoroll.shape}, " 83 | f"dtype={self.pianoroll.dtype})", 84 | ] 85 | return f"Track({', '.join(to_join)})" 86 | 87 | def __len__(self) -> int: 88 | return len(self.pianoroll) 89 | 90 | def __getitem__(self, key) -> ndarray: 91 | return self.pianoroll[key] 92 | 93 | def __setitem__(self, key: int, value: Any): 94 | self.pianoroll[key] = value 95 | 96 | def _validate_type(self, attr): 97 | if getattr(self, attr) is None: 98 | if attr in ("program", "is_drum", "pianoroll"): 99 | raise TypeError(f"`{attr}` must not be None.") 100 | return 101 | 102 | if attr == "program": 103 | if not isinstance(self.program, int): 104 | raise TypeError( 105 | "`program` must be of type int, not " 106 | f"{type(self.program)}." 107 | ) 108 | elif attr == "is_drum": 109 | if not isinstance(self.is_drum, bool): 110 | raise TypeError( 111 | "`is_drum` must be of type bool, not " 112 | f"{type(self.is_drum)}." 113 | ) 114 | elif attr == "name": 115 | if not isinstance(self.name, str): 116 | raise TypeError( 117 | f"`name` must be of type str, not {type(self.name)}." 
118 | ) 119 | elif attr == "pianoroll": 120 | if not isinstance(self.pianoroll, ndarray): 121 | raise TypeError( 122 | "`pianoroll` must be a NumPy array, not " 123 | f"{type(self.pianoroll)}." 124 | ) 125 | 126 | def validate_type(self, attr=None): 127 | """Raise an error if an attribute has an invalid type. 128 | 129 | Parameters 130 | ---------- 131 | attr : str 132 | Attribute to validate. Defaults to validate all attributes. 133 | 134 | Returns 135 | ------- 136 | Object itself. 137 | 138 | """ 139 | if attr is None: 140 | for attribute in ("program", "is_drum", "name", "pianoroll"): 141 | self._validate_type(attribute) 142 | else: 143 | self._validate_type(attr) 144 | return self 145 | 146 | def _validate(self, attr): 147 | if getattr(self, attr) is None: 148 | if attr in ("program", "is_drum", "pianoroll"): 149 | raise TypeError(f"`{attr}` must not be None.") 150 | return 151 | 152 | self._validate_type(attr) 153 | 154 | if attr == "program": 155 | if self.program < 0 or self.program > 127: 156 | raise ValueError("`program` must be in between 0 to 127.") 157 | elif attr == "pianoroll": 158 | if self.pianoroll.ndim != 2: 159 | raise ValueError( 160 | "`pianoroll` must have exactly two dimensions." 161 | ) 162 | if self.pianoroll.shape[1] != 128: 163 | raise ValueError( 164 | "Length of the second axis of `pianoroll` must be 128." 165 | ) 166 | 167 | def validate(self, attr=None): 168 | """Raise an error if an attribute has an invalid type or value. 169 | 170 | Parameters 171 | ---------- 172 | attr : str 173 | Attribute to validate. Defaults to validate all attributes. 174 | 175 | Returns 176 | ------- 177 | Object itself. 178 | 179 | """ 180 | if attr is None: 181 | for attribute in ("program", "is_drum", "name", "pianoroll"): 182 | self._validate(attribute) 183 | else: 184 | self._validate(attr) 185 | return self 186 | 187 | def is_valid_type(self, attr: str = None) -> bool: 188 | """Return True if an attribute is of a valid type. 
189 | 190 | Parameters 191 | ---------- 192 | attr : str 193 | Attribute to validate. Defaults to validate all attributes. 194 | 195 | Returns 196 | ------- 197 | bool 198 | Whether the attribute is of a valid type. 199 | 200 | """ 201 | try: 202 | self.validate_type(attr) 203 | except TypeError: 204 | return False 205 | return True 206 | 207 | def is_valid(self, attr: str = None) -> bool: 208 | """Return True if an attribute is valid. 209 | 210 | Parameters 211 | ---------- 212 | attr : str 213 | Attribute to validate. Defaults to validate all attributes. 214 | 215 | Returns 216 | ------- 217 | bool 218 | Whether the attribute has a valid type and value. 219 | 220 | """ 221 | try: 222 | self.validate(attr) 223 | except (TypeError, ValueError): 224 | return False 225 | return True 226 | 227 | def get_length(self) -> int: 228 | """Return the active length of the piano roll. 229 | 230 | Returns 231 | ------- 232 | int 233 | Length (in time steps) of the piano roll without trailing 234 | silence. 235 | 236 | """ 237 | nonzero_steps = np.any(self.pianoroll, axis=1) 238 | inv_last_nonzero_step = int(np.argmax(np.flip(nonzero_steps, axis=0))) 239 | return self.pianoroll.shape[0] - inv_last_nonzero_step 240 | 241 | def copy(self: "Track") -> "Track": 242 | """Return a copy of the track. 243 | 244 | Returns 245 | ------- 246 | A copy of the object itself. 247 | 248 | Notes 249 | ----- 250 | The piano-roll array is copied using :func:`numpy.copy`. 251 | 252 | """ 253 | return Track( 254 | name=self.name, 255 | program=self.program, 256 | is_drum=self.is_drum, 257 | pianoroll=self.pianoroll.copy(), 258 | ) 259 | 260 | def pad(self: TrackType, pad_length: int) -> TrackType: 261 | """Pad the piano roll. 262 | 263 | Parameters 264 | ---------- 265 | pad_length : int 266 | Length to pad along the time axis. 267 | 268 | Returns 269 | ------- 270 | Object itself. 
271 | 272 | See Also 273 | -------- 274 | :func:`pypianoroll.Track.pad_to_multiple` : Pad the piano 275 | roll so that its length is some multiple. 276 | 277 | """ 278 | self.pianoroll = np.pad( 279 | self.pianoroll, ((0, pad_length), (0, 0)), "constant" 280 | ) 281 | return self 282 | 283 | def pad_to_multiple(self: TrackType, factor: int) -> TrackType: 284 | """Pad the piano roll so that its length is some multiple. 285 | 286 | Pad the piano roll at the end along the time axis of the minimum 287 | length that makes the length of the resulting piano roll a 288 | multiple of `factor`. 289 | 290 | Parameters 291 | ---------- 292 | factor : int 293 | The value which the length of the resulting piano roll will 294 | be a multiple of. 295 | 296 | Returns 297 | ------- 298 | Object itself. 299 | 300 | See Also 301 | -------- 302 | :func:`pypianoroll.Track.pad` : Pad the piano roll. 303 | 304 | """ 305 | remainder = self.pianoroll.shape[0] % factor 306 | if remainder: 307 | pad_width = ((0, (factor - remainder)), (0, 0)) 308 | self.pianoroll = np.pad(self.pianoroll, pad_width, "constant") 309 | return self 310 | 311 | def transpose(self: TrackType, semitone: int) -> TrackType: 312 | """Transpose the piano roll by a number of semitones. 313 | 314 | Parameters 315 | ---------- 316 | semitone : int 317 | Number of semitones to transpose. A positive value raises 318 | the pitches, while a negative value lowers the pitches. 319 | 320 | Returns 321 | ------- 322 | Object itself. 323 | 324 | """ 325 | if 0 < semitone < 128: 326 | self.pianoroll[:, semitone:] = self.pianoroll[ 327 | :, : (128 - semitone) 328 | ] 329 | self.pianoroll[:, :semitone] = 0 330 | elif -128 < semitone < 0: 331 | self.pianoroll[:, : (128 + semitone)] = self.pianoroll[ 332 | :, -semitone: 333 | ] 334 | self.pianoroll[:, (128 + semitone) :] = 0 335 | return self 336 | 337 | def trim(self: TrackType, start: int = None, end: int = None) -> TrackType: 338 | """Trim the piano roll. 
339 | 340 | Parameters 341 | ---------- 342 | start : int, default: 0 343 | Start time. 344 | end : int, optional 345 | End time. Defaults to active length. 346 | 347 | Returns 348 | ------- 349 | Object itself. 350 | 351 | """ 352 | if start is None: 353 | start = 0 354 | elif start < 0: 355 | raise ValueError("`start` must be nonnegative.") 356 | if end is None: 357 | end = self.get_length() 358 | elif end > len(self.pianoroll): 359 | raise ValueError( 360 | "`end` must be shorter than the piano roll length." 361 | ) 362 | self.pianoroll = self.pianoroll[start:end] 363 | return self 364 | 365 | def standardize(self: "Track") -> "StandardTrack": 366 | """Standardize the track. 367 | 368 | This will clip the piano roll to [0, 127] and cast to np.uint8. 369 | 370 | Returns 371 | ------- 372 | Converted StandardTrack object. 373 | 374 | """ 375 | return StandardTrack( 376 | name=self.name, 377 | program=self.program, 378 | is_drum=self.is_drum, 379 | pianoroll=np.clip(self.pianoroll, 0, 127), 380 | ) 381 | 382 | def binarize(self, threshold: float = 0) -> "BinaryTrack": 383 | """Binarize the track. 384 | 385 | This will binarize the piano roll by the given threshold. 386 | 387 | Parameters 388 | ---------- 389 | threshold : int or float, default: 0 390 | Threshold. 391 | 392 | Returns 393 | ------- 394 | Converted BinaryTrack object. 395 | 396 | """ 397 | return BinaryTrack( 398 | program=self.program, 399 | is_drum=self.is_drum, 400 | name=self.name, 401 | pianoroll=(self.pianoroll > threshold), 402 | ) 403 | 404 | def plot(self, ax: Axes = None, **kwargs) -> Axes: 405 | """Plot the piano roll. 406 | 407 | Refer to :func:`pypianoroll.plot_track` for full documentation. 408 | 409 | """ 410 | return plot_track(self, ax, **kwargs) 411 | 412 | 413 | class StandardTrack(Track): 414 | """A container for single-track piano rolls with velocities. 415 | 416 | Attributes 417 | ---------- 418 | name : str, optional 419 | Track name. 
420 | program : int, 0-127, default: `pypianoroll.DEFAULT_PROGRAM` (0) 421 | Program number according to General MIDI specification [1]. 422 | Defaults to 0 (Acoustic Grand Piano). 423 | is_drum : bool, default: `pypianoroll.DEFAULT_IS_DRUM` (False) 424 | Whether it is a percussion track. 425 | pianoroll : ndarray, dtype=uint8, shape=(?, 128), optional 426 | Piano-roll matrix. The first dimension represents time, and the 427 | second dimension represents pitch. Cast to uint8 if not of data 428 | type uint8. 429 | 430 | References 431 | ---------- 432 | 1. https://www.midi.org/specifications/item/gm-level-1-sound-set 433 | 434 | """ 435 | 436 | def __init__( 437 | self, 438 | name: str = None, 439 | program: int = None, 440 | is_drum: bool = None, 441 | pianoroll: ndarray = None, 442 | ): 443 | super().__init__(name, program, is_drum, pianoroll) 444 | if self.pianoroll.dtype != np.uint8: 445 | self.pianoroll = self.pianoroll.astype(np.uint8) 446 | 447 | def __repr__(self): 448 | to_join = [ 449 | f"name={repr(self.name)}", 450 | f"program={repr(self.program)}", 451 | f"is_drum={repr(self.is_drum)}", 452 | f"pianoroll=array(shape={self.pianoroll.shape}, " 453 | f"dtype={self.pianoroll.dtype})", 454 | ] 455 | return f"StandardTrack({', '.join(to_join)})" 456 | 457 | def _validate_type(self, attr): 458 | super()._validate_type(attr) 459 | if attr == "pianoroll" and self.pianoroll.dtype != np.uint8: 460 | raise TypeError( 461 | "`pianoroll` must be of data type uint8, not " 462 | f"{self.pianoroll.dtype}." 463 | ) 464 | 465 | def _validate(self, attr): 466 | super()._validate(attr) 467 | if attr == "pianoroll" and np.any(self.pianoroll > 127): 468 | raise ValueError( 469 | "`pianoroll` must contain only integers between 0 to 127." 470 | ) 471 | 472 | def set_nonzeros(self: StandardTrackType, value: int) -> StandardTrackType: 473 | """Assign a constant value to all nonzeros entries. 474 | 475 | Arguments 476 | --------- 477 | value : int 478 | Value to assign. 
479 | 480 | Returns 481 | ------- 482 | Object itself. 483 | 484 | """ 485 | self.pianoroll[self.pianoroll.nonzero()] = value 486 | return self 487 | 488 | def clip( 489 | self: StandardTrackType, lower: int = 0, upper: int = 127 490 | ) -> StandardTrackType: 491 | """Clip (limit) the piano roll into [`lower`, `upper`]. 492 | 493 | Parameters 494 | ---------- 495 | lower : int, default: 0 496 | Lower bound. 497 | upper : int, default: 127 498 | Upper bound. 499 | 500 | Returns 501 | ------- 502 | Object itself. 503 | 504 | """ 505 | if not isinstance(lower, int): 506 | raise ValueError("`lower` must be of type int.") 507 | if not isinstance(upper, int): 508 | raise ValueError("`upper` must be of type int.") 509 | self.pianoroll = self.pianoroll.clip(lower, upper) 510 | return self 511 | 512 | def copy(self: "StandardTrack") -> "StandardTrack": 513 | """Return a copy of the track. 514 | 515 | Returns 516 | ------- 517 | A copy of the object itself. 518 | 519 | Notes 520 | ----- 521 | The piano-roll array is copied using :func:`numpy.copy`. 522 | 523 | """ 524 | return StandardTrack( 525 | name=self.name, 526 | program=self.program, 527 | is_drum=self.is_drum, 528 | pianoroll=self.pianoroll.copy(), 529 | ) 530 | 531 | 532 | class BinaryTrack(Track): 533 | """A container for single-track, binary piano rolls. 534 | 535 | Attributes 536 | ---------- 537 | name : str, optional 538 | Track name. 539 | program : int, 0-127, default: `pypianoroll.DEFAULT_PROGRAM` (0) 540 | Program number according to General MIDI specification [1]. 541 | Defaults to 0 (Acoustic Grand Piano). 542 | is_drum : bool, default: `pypianoroll.DEFAULT_IS_DRUM` (False) 543 | Whether it is a percussion track. 544 | pianoroll : ndarray, dtype=bool, shape=(?, 128), optional 545 | Piano-roll matrix. The first dimension represents time, and the 546 | second dimension represents pitch. Cast to bool if not of data 547 | type bool. 548 | 549 | References 550 | ---------- 551 | 1.
https://www.midi.org/specifications/item/gm-level-1-sound-set 552 | 553 | """ 554 | 555 | def __init__( 556 | self, 557 | name: str = None, 558 | program: int = None, 559 | is_drum: bool = None, 560 | pianoroll: ndarray = None, 561 | ): 562 | super().__init__(name, program, is_drum, pianoroll) 563 | if self.pianoroll.dtype != np.bool_: 564 | self.pianoroll = self.pianoroll.astype(np.bool_) 565 | 566 | def __repr__(self): 567 | to_join = [ 568 | f"name={repr(self.name)}", 569 | f"program={repr(self.program)}", 570 | f"is_drum={repr(self.is_drum)}", 571 | f"pianoroll=array(shape={self.pianoroll.shape}, " 572 | f"dtype={self.pianoroll.dtype})", 573 | ] 574 | return f"BinaryTrack({', '.join(to_join)})" 575 | 576 | def _validate_type(self, attr): 577 | super()._validate_type(attr) 578 | if attr == "pianoroll" and self.pianoroll.dtype != np.bool_: 579 | raise TypeError( 580 | "`pianoroll` must be of data type bool, not " 581 | f"{self.pianoroll.dtype}." 582 | ) 583 | 584 | def set_nonzeros(self, value: int) -> "StandardTrack": 585 | """Assign a constant value to all nonzeros entries. 586 | 587 | Arguments 588 | --------- 589 | value : int 590 | Value to assign. 591 | 592 | Returns 593 | ------- 594 | Converted StandardTrack object. 595 | 596 | """ 597 | pianoroll = np.zeros(self.pianoroll.shape, np.uint8) 598 | pianoroll[self.pianoroll.nonzero()] = value 599 | return StandardTrack( 600 | name=self.name, 601 | program=self.program, 602 | is_drum=self.is_drum, 603 | pianoroll=pianoroll, 604 | ) 605 | 606 | def copy(self: "BinaryTrack") -> "BinaryTrack": 607 | """Return a copy of the track. 608 | 609 | Returns 610 | ------- 611 | A copy of the object itself. 612 | 613 | Notes 614 | ----- 615 | The piano-roll array is copied using :func:`numpy.copy`. 
616 | 617 | """ 618 | return BinaryTrack( 619 | name=self.name, 620 | program=self.program, 621 | is_drum=self.is_drum, 622 | pianoroll=self.pianoroll.copy(), 623 | ) 624 | -------------------------------------------------------------------------------- /pypianoroll/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions. 2 | 3 | Functions 4 | --------- 5 | 6 | - decompose_sparse 7 | - reconstruct_sparse 8 | 9 | """ 10 | from typing import Dict 11 | 12 | from numpy import ndarray 13 | from scipy.sparse import csc_matrix 14 | 15 | 16 | def decompose_sparse(matrix: ndarray, name: str) -> Dict[str, ndarray]: 17 | """Decompose a matrix to sparse components. 18 | 19 | Convert a matrix to a :class:`scipy.sparse.csc_matrix` object. 20 | Return its component arrays as a dictionary with key as `name` 21 | suffixed with their component types. 22 | 23 | """ 24 | csc = csc_matrix(matrix) 25 | return { 26 | name + "_csc_data": csc.data, 27 | name + "_csc_indices": csc.indices, 28 | name + "_csc_indptr": csc.indptr, 29 | name + "_csc_shape": csc.shape, 30 | } 31 | 32 | 33 | def reconstruct_sparse(data_dict: Dict[str, ndarray], name: str) -> ndarray: 34 | """Reconstruct a matrix from a dictionary.""" 35 | sparse_matrix = csc_matrix( 36 | ( 37 | data_dict[name + "_csc_data"], 38 | data_dict[name + "_csc_indices"], 39 | data_dict[name + "_csc_indptr"], 40 | ), 41 | shape=data_dict[name + "_csc_shape"], 42 | ) 43 | return sparse_matrix.toarray() 44 | -------------------------------------------------------------------------------- /pypianoroll/version.py: -------------------------------------------------------------------------------- 1 | """Pypianoroll library version.""" 2 | __version__ = "1.0.4" 3 | -------------------------------------------------------------------------------- /pypianoroll/visualization.py: -------------------------------------------------------------------------------- 1 | """Visualization tools. 
2 | 3 | Functions 4 | --------- 5 | 6 | - plot_multitrack 7 | - plot_pianoroll 8 | - plot_track 9 | 10 | """ 11 | from typing import TYPE_CHECKING, List, Sequence 12 | 13 | import matplotlib 14 | import numpy as np 15 | from matplotlib import pyplot as plt 16 | from matplotlib.axes import Axes 17 | from matplotlib.patches import Patch 18 | from numpy import ndarray 19 | from pretty_midi import ( 20 | note_number_to_drum_name, 21 | note_number_to_name, 22 | program_to_instrument_class, 23 | program_to_instrument_name, 24 | ) 25 | 26 | if TYPE_CHECKING: 27 | from .multitrack import Multitrack 28 | from .track import Track 29 | 30 | __all__ = ["plot_multitrack", "plot_pianoroll", "plot_track"] 31 | 32 | 33 | def plot_pianoroll( 34 | ax: Axes, 35 | pianoroll: ndarray, 36 | is_drum: bool = False, 37 | resolution: int = None, 38 | beats: ndarray = None, 39 | downbeats: ndarray = None, 40 | preset: str = "full", 41 | cmap: str = "Blues", 42 | xtick: str = "auto", 43 | ytick: str = "octave", 44 | xticklabel: bool = True, 45 | yticklabel: str = "auto", 46 | tick_loc: Sequence[str] = ("bottom", "left"), 47 | tick_direction: str = "in", 48 | label: str = "both", 49 | grid_axis: str = "both", 50 | grid_color: str = "gray", 51 | grid_linestyle: str = ":", 52 | grid_linewidth: float = 0.5, 53 | **kwargs, 54 | ): 55 | """ 56 | Plot a piano roll. 57 | 58 | Parameters 59 | ---------- 60 | ax : :class:`matplotlib.axes.Axes` 61 | Axes to plot the piano roll on. 62 | pianoroll : ndarray, shape=(?, 128), (?, 128, 3) or (?, 128, 4) 63 | Piano roll to plot. For a 3D piano-roll array, the last axis can 64 | be either RGB or RGBA. 65 | is_drum : bool, default: False 66 | Whether it is a percussion track. 67 | resolution : int 68 | Time steps per quarter note. Required if `xtick` is 'beat'. 69 | beats : ndarray, dtype=int, shape=(?, 1), 70 | A boolean array that indicates the time steps that contain a 71 | beat. 
72 | downbeats : ndarray, dtype=int, shape=(?, 1), 73 | A boolean array that indicates the time steps that contain a 74 | downbeat (i.e., the first time step of a bar). 75 | preset : {'full', 'frame', 'plain'}, default: 'full' 76 | Preset theme. For 'full' preset, ticks, grid and labels are on. 77 | For 'frame' preset, ticks and grid are both off. For 'plain' 78 | preset, the x- and y-axis are both off. 79 | cmap : str or :class:`matplotlib.colors.Colormap`, default: 'Blues' 80 | Colormap. Will be passed to :func:`matplotlib.pyplot.imshow`. 81 | Only effective when `pianoroll` is 2D. 82 | xtick : {'auto', 'beat', 'step', 'off'} 83 | Tick format for the x-axis. For 'auto' mode, set to 'beat' if 84 | `beats` is given, otherwise set to 'step'. Defaults to 'auto'. 85 | ytick : {'octave', 'pitch', 'off'}, default: 'octave' 86 | Tick format for the y-axis. 87 | xticklabel : bool 88 | Whether to add tick labels along the x-axis. 89 | yticklabel : {'auto', 'name', 'number', 'off'}, default: 'auto' 90 | Tick label format for the y-axis. For 'name' mode, use pitch 91 | name as tick labels. For 'number' mode, use pitch number. For 92 | 'auto' mode, set to 'name' if `ytick` is 'octave' and 'number' 93 | if `ytick` is 'pitch'. 94 | tick_loc : sequence of {'bottom', 'top', 'left', 'right'} 95 | Tick locations. Defaults to `('bottom', 'left')`. 96 | tick_direction : {'in', 'out', 'inout'}, default: 'in' 97 | Tick direction. 98 | label : {'x', 'y', 'both', 'off'}, default: 'both' 99 | Whether to add labels to x- and y-axes. 100 | grid_axis : {'x', 'y', 'both', 'off'}, default: 'both' 101 | Whether to add grids to the x- and y-axes. 102 | grid_color : str, default: 'gray' 103 | Grid color. Will be passed to :meth:`matplotlib.axes.Axes.grid`. 104 | grid_linestyle : str, default: '-' 105 | Grid line style. Will be passed to 106 | :meth:`matplotlib.axes.Axes.grid`. 107 | grid_linewidth : float, default: 0.5 108 | Grid line width. 
Will be passed to 109 | :meth:`matplotlib.axes.Axes.grid`. 110 | **kwargs 111 | Keyword arguments to be passed to 112 | :meth:`matplotlib.axes.Axes.imshow`. 113 | 114 | """ 115 | # Plot the piano roll 116 | if pianoroll.ndim == 2: 117 | transposed = pianoroll.T 118 | elif pianoroll.ndim == 3: 119 | transposed = pianoroll.transpose(1, 0, 2) 120 | else: 121 | raise ValueError("`pianoroll` must be a 2D or 3D numpy array") 122 | 123 | img = ax.imshow( 124 | transposed, 125 | cmap=cmap, 126 | aspect="auto", 127 | vmin=0, 128 | vmax=1 if pianoroll.dtype == np.bool_ else 127, 129 | origin="lower", 130 | interpolation="none", 131 | **kwargs, 132 | ) 133 | 134 | # Format ticks and labels 135 | if xtick == "auto": 136 | xtick = "beat" if beats is not None else "step" 137 | elif xtick not in ("beat", "step", "off"): 138 | raise ValueError( 139 | "`xtick` must be one of 'auto', 'beat', 'step' or 'off', not " 140 | f"{xtick}." 141 | ) 142 | if yticklabel == "auto": 143 | yticklabel = "name" if ytick == "octave" else "number" 144 | elif yticklabel not in ("name", "number", "off"): 145 | raise ValueError( 146 | "`yticklabel` must be one of 'auto', 'name', 'number' or 'off', " 147 | f"not {yticklabel}."
148 | ) 149 | 150 | if preset == "full": 151 | ax.tick_params( 152 | direction=tick_direction, 153 | bottom=("bottom" in tick_loc), 154 | top=("top" in tick_loc), 155 | left=("left" in tick_loc), 156 | right=("right" in tick_loc), 157 | labelbottom=xticklabel, 158 | labelleft=(yticklabel != "off"), 159 | labeltop=False, 160 | labelright=False, 161 | ) 162 | elif preset == "frame": 163 | ax.tick_params( 164 | direction=tick_direction, 165 | bottom=False, 166 | top=False, 167 | left=False, 168 | right=False, 169 | labelbottom=False, 170 | labeltop=False, 171 | labelleft=False, 172 | labelright=False, 173 | ) 174 | elif preset == "plain": 175 | ax.axis("off") 176 | else: 177 | raise ValueError( 178 | f"`preset` must be one of 'full', 'frame' or 'plain', not {preset}" 179 | ) 180 | 181 | # Format x-axis 182 | if xtick == "beat" and preset != "frame": 183 | if beats is None: 184 | raise RuntimeError( 185 | "Beats must be given when using `beat` for ticks on the " 186 | "x-axis." 187 | ) 188 | if len(beats) < 2: 189 | raise RuntimeError( 190 | "There must be at least two beats given when using `beat` for " 191 | "ticks on the x-axis."
192 | ) 193 | beats_arr = np.append(beats, beats[-1] + (beats[-1] - beats[-2])) 194 | ax.set_xticks(beats_arr[:-1] - 0.5) 195 | ax.set_xticklabels("") 196 | ax.set_xticks((beats_arr[1:] + beats_arr[:-1]) / 2 - 0.5, minor=True) 197 | ax.set_xticklabels(np.arange(1, len(beats) + 1), minor=True) 198 | ax.tick_params(axis="x", which="minor", width=0) 199 | 200 | # Format y-axis 201 | if ytick == "octave": 202 | ax.set_yticks(np.arange(0, 128, 12)) 203 | if yticklabel == "name": 204 | ax.set_yticklabels([f"C{i - 2}" for i in range(11)]) 205 | elif ytick == "pitch": 206 | ax.set_yticks(np.arange(0, 128)) 207 | if yticklabel == "name": 208 | if is_drum: 209 | ax.set_yticklabels( 210 | [note_number_to_drum_name(i) for i in range(128)] 211 | ) 212 | else: 213 | ax.set_yticklabels( 214 | [note_number_to_name(i) for i in range(128)] 215 | ) 216 | elif ytick != "off": 217 | raise ValueError( 218 | f"`ytick` must be one of 'octave', 'pitch' or 'off', not {ytick}." 219 | ) 220 | 221 | # Format axis labels 222 | if label not in ("x", "y", "both", "off"): 223 | raise ValueError( 224 | f"`label` must be one of 'x', 'y', 'both' or 'off', not {label}." 225 | ) 226 | 227 | if label in ("x", "both"): 228 | if xtick == "step" or not xticklabel: 229 | ax.set_xlabel("time (step)") 230 | else: 231 | ax.set_xlabel("time (beat)") 232 | 233 | if label in ("y", "both"): 234 | if is_drum: 235 | ax.set_ylabel("key name") 236 | else: 237 | ax.set_ylabel("pitch") 238 | 239 | # Plot the grid 240 | if grid_axis not in ("x", "y", "both", "off"): 241 | raise ValueError( 242 | "`grid_axis` must be one of 'x', 'y', 'both' or 'off', not " 243 | f"{grid_axis}."
244 | ) 245 | if grid_axis != "off": 246 | ax.grid( 247 | axis=grid_axis, 248 | color=grid_color, 249 | linestyle=grid_linestyle, 250 | linewidth=grid_linewidth, 251 | ) 252 | 253 | # Plot downbeat boundaries 254 | if downbeats is not None: 255 | for downbeat in downbeats: 256 | ax.axvline(x=downbeat, color="k", linewidth=1) 257 | 258 | return img 259 | 260 | 261 | def plot_track(track: "Track", ax: Axes = None, **kwargs) -> Axes: 262 | """ 263 | Plot a track. 264 | 265 | Parameters 266 | ---------- 267 | track : :class:`pypianoroll.Track` 268 | Track to plot. 269 | ax : :class:`matplotlib.axes.Axes`, optional 270 | Axes to plot the piano roll on. Defaults to call `plt.gca()`. 271 | **kwargs 272 | Keyword arguments to pass to :func:`pypianoroll.plot_pianoroll`. 273 | 274 | Returns 275 | ------- 276 | :class:`matplotlib.axes.Axes` 277 | (Created) Axes object. 278 | 279 | """ 280 | if ax is None: 281 | ax = plt.gca() 282 | plot_pianoroll(ax, track.pianoroll, track.is_drum, **kwargs) 283 | return ax 284 | 285 | 286 | def _get_track_label(track_label, track=None): 287 | """Return corresponding track labels.""" 288 | if track_label == "name": 289 | return track.name 290 | if track_label == "program": 291 | return program_to_instrument_name(track.program) 292 | if track_label == "family": 293 | return program_to_instrument_class(track.program) 294 | return track_label 295 | 296 | 297 | def _add_tracklabel(ax, track_label, track=None): 298 | """Add a track label to an axis.""" 299 | if not ax.get_ylabel(): 300 | return 301 | ax.set_ylabel( 302 | f"{_get_track_label(track_label, track)}\n\n{ax.get_ylabel()}" 303 | ) 304 | 305 | 306 | def plot_multitrack( 307 | multitrack: "Multitrack", 308 | axs: Sequence[Axes] = None, 309 | mode: str = "separate", 310 | track_label: str = "name", 311 | preset: str = "full", 312 | cmaps: Sequence[str] = None, 313 | xtick: str = "auto", 314 | ytick: str = "octave", 315 | xticklabel: bool = True, 316 | yticklabel: str = "auto", 317 | 
    tick_loc: Sequence[str] = ("bottom", "left"),
    tick_direction: str = "in",
    label: str = "both",
    grid_axis: str = "both",
    grid_color: str = "gray",
    grid_linestyle: str = "-",
    grid_linewidth: float = 0.5,
    **kwargs,
) -> List[Axes]:
    """
    Plot the multitrack.

    Parameters
    ----------
    multitrack : :class:`pypianoroll.Multitrack`
        Multitrack to plot.
    axs : sequence of :class:`matplotlib.axes.Axes`, optional
        Axes to plot the tracks on.
    mode : {'separate', 'blended', 'hybrid'}, default: 'separate'
        Plotting strategy for visualizing multiple tracks. For
        'separate' mode, plot each track separately. For 'blended'
        mode, blend and plot the piano rolls as a colored image. For
        'hybrid' mode, drum tracks are blended into a 'Drums' track
        and all other tracks are blended into an 'Others' track.
    track_label : {'name', 'program', 'family', 'off'}, default: 'name'
        Track label format. When `mode` is 'hybrid', all options other
        than 'off' will label the two tracks with 'Drums' and 'Others'.
    preset : {'full', 'frame', 'plain'}, default: 'full'
        Preset theme to use. For 'full' preset, ticks, grid and labels
        are on. For 'frame' preset, ticks and grid are both off. For
        'plain' preset, the x- and y-axis are both off.
    cmaps : tuple or list
        Colormaps. Will be passed to :func:`matplotlib.pyplot.imshow`.
        If `mode` is 'separate', defaults to `('Blues', 'Oranges',
        'Greens', 'Reds', 'Purples', 'Greys')`. If `mode` is 'blended',
        defaults to `('hsv',)`. If `mode` is 'hybrid', defaults to
        `('Blues', 'Greens')`.
    **kwargs
        Keyword arguments to pass to :func:`pypianoroll.plot_pianoroll`.

    Returns
    -------
    list of :class:`matplotlib.axes.Axes`
        (Created) list of Axes objects.

    """
    if not multitrack.tracks:
        raise RuntimeError("There is no track to plot.")
    if track_label not in ("name", "program", "family", "off"):
        raise ValueError(
            "`track_label` must be one of 'name', 'program', 'family' or "
            f"'off', not {track_label}."
        )

    if axs is not None and not isinstance(axs, list):
        axs = list(axs)

    # Set default color maps
    if cmaps is None:
        if mode == "separate":
            cmaps = ("Blues", "Oranges", "Greens", "Reds", "Purples", "Greys")
        elif mode == "blended":
            cmaps = ("hsv",)
        else:
            cmaps = ("Blues", "Greens")

    n_tracks = len(multitrack.tracks)
    beats = multitrack.get_beat_steps()
    downbeats = multitrack.get_downbeat_steps()

    if mode == "separate":
        if axs is None:
            if n_tracks > 1:
                fig, axs_ = plt.subplots(n_tracks, sharex=True)
                fig.subplots_adjust(hspace=0)
                axs = axs_.tolist()
            else:
                fig, ax = plt.subplots()
                axs = [ax]

        for idx, track in enumerate(multitrack.tracks):
            # Show x-tick labels only on the bottom (last) axes. The
            # original test `idx < n_tracks` was always true.
            now_xticklabel = xticklabel if idx == n_tracks - 1 else False
            plot_pianoroll(
                ax=axs[idx],
                pianoroll=track.pianoroll,
                is_drum=False,
                resolution=multitrack.resolution,
                beats=beats,
                downbeats=downbeats,
                preset=preset,
                cmap=cmaps[idx % len(cmaps)],
                xtick=xtick,
                ytick=ytick,
                xticklabel=now_xticklabel,
                yticklabel=yticklabel,
                tick_loc=tick_loc,
                tick_direction=tick_direction,
                label=label,
                grid_axis=grid_axis,
                grid_color=grid_color,
                grid_linestyle=grid_linestyle,
                grid_linewidth=grid_linewidth,
                **kwargs,
            )
            # 'off' (not 'none') is the value that disables track labels
            if track_label != "off":
                _add_tracklabel(axs[idx], track_label, track)

    elif mode == "blended":
        is_all_drum = True
        for track in multitrack.tracks:
            if not track.is_drum:
                is_all_drum = False

        if axs is None:
            fig, ax = plt.subplots()
            axs = [ax]

        stacked = multitrack.stack()

        # `plt.get_cmap` is used here because `matplotlib.cm.get_cmap`
        # was removed in Matplotlib 3.9
        colormap = plt.get_cmap(cmaps[0])
        colormatrix = colormap(np.arange(0, 1, 1 / n_tracks))[:, :3]
        recolored = np.clip(
            np.matmul(stacked.reshape(-1, n_tracks), colormatrix), 0, 1
        )
        blended = recolored.reshape(stacked.shape[1:] + (3,))

        plot_pianoroll(
            ax=axs[0],
            pianoroll=blended,
            is_drum=is_all_drum,
            resolution=multitrack.resolution,
            beats=beats,
            downbeats=downbeats,
            preset=preset,
            xtick=xtick,
            ytick=ytick,
            xticklabel=xticklabel,
            yticklabel=yticklabel,
            tick_loc=tick_loc,
            tick_direction=tick_direction,
            label=label,
            grid_axis=grid_axis,
            grid_color=grid_color,
            grid_linestyle=grid_linestyle,
            grid_linewidth=grid_linewidth,
            **kwargs,
        )

        # 'off' (not 'none') is the value that disables track labels
        if track_label != "off":
            patches = [
                Patch(
                    color=colormatrix[idx],
                    label=_get_track_label(track_label, track),
                )
                for idx, track in enumerate(multitrack.tracks)
            ]
            plt.legend(handles=patches)

    elif mode == "hybrid":
        drums = multitrack.copy()
        drums.tracks = [track for track in multitrack.tracks if track.is_drum]
        merged_drums = drums.blend()

        others = multitrack.copy()
        others.tracks = [
            track for track in multitrack.tracks if not track.is_drum
        ]
        merged_others = others.blend()

        if axs is None:
            fig, axs_ = plt.subplots(2, sharex=True, sharey=True)
            axs = axs_.tolist()
        else:
            fig = plt.gcf()

        plot_pianoroll(
            ax=axs[0],
            pianoroll=merged_drums,
            is_drum=True,
            resolution=multitrack.resolution,
            beats=beats,
            downbeats=downbeats,
            preset=preset,
            cmap=cmaps[0],
            xtick=xtick,
            ytick=ytick,
            xticklabel=xticklabel,
            yticklabel=yticklabel,
            tick_loc=tick_loc,
            tick_direction=tick_direction,
            label=label,
            grid_axis=grid_axis,
            grid_color=grid_color,
            grid_linestyle=grid_linestyle,
            grid_linewidth=grid_linewidth,
            **kwargs,
        )
        plot_pianoroll(
            ax=axs[1],
            pianoroll=merged_others,
            is_drum=False,
            resolution=multitrack.resolution,
            beats=beats,
            downbeats=downbeats,
            preset=preset,
            cmap=cmaps[1],
            xtick=xtick,
            ytick=ytick,
            xticklabel=xticklabel,
            yticklabel=yticklabel,
            tick_loc=tick_loc,
            tick_direction=tick_direction,
            label=label,
            grid_axis=grid_axis,
            grid_color=grid_color,
            grid_linestyle=grid_linestyle,
            grid_linewidth=grid_linewidth,
            **kwargs,
        )
        fig.subplots_adjust(hspace=0)

        # 'off' (not 'none') is the value that disables track labels
        if track_label != "off":
            _add_tracklabel(axs[0], "Drums")
            _add_tracklabel(axs[1], "Others")

    else:
        raise ValueError(
            "`mode` must be one of 'separate', 'blended' or 'hybrid', not "
            f"{mode}."
        )

    return axs  # type: ignore

--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[tool.black]
line-length = 79

[tool.isort]
profile = "black"
line_length = 79

[tool.pytest.ini_options]
minversion = "6.0"
addopts = "-ra -q"
testpaths = "tests"

[tool.pylint.master]
ignore-patterns = "test_.*?py"

[tool.pylint.basic]
good-names = "i,j,k,_,a,b,c,x,y,t,n,ax,f,T"

[tool.pylint.messages_control]
disable = "R,C0330,C0326,W0511"

[tool.pylint.format]
max-line-length = 79

[tool.mypy]
ignore_missing_imports = true

--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
[flake8]
ignore = E203,F401,F403,W503,D105
exclude = test_*.py
max-doc-length = 72
docstring-convention = numpy

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
"""Setup script."""
from pathlib import Path

from setuptools import find_packages, setup


def _get_long_description():
    with open(str(Path(__file__).parent / "README.md"), "r") as f:
        return f.read()


def _get_version():
    with open(str(Path(__file__).parent / "pypianoroll/version.py"), "r") as f:
        for line in f:
            if line.startswith("__version__"):
                # Split on whichever quote character the assignment uses
                delimiter = '"' if '"' in line else "'"
                return line.split(delimiter)[1]
    raise RuntimeError("Cannot read version string.")


VERSION = _get_version()

setup(
    name="pypianoroll",
    version=VERSION,
    author="Hao-Wen Dong",
    author_email="salu.hwdong@gmail.com",
    description="A toolkit for working with piano rolls",
    long_description=_get_long_description(),
    long_description_content_type="text/markdown",
    download_url=(
        f"https://github.com/salu133445/pypianoroll/archive/v{VERSION}.tar.gz"
    ),
    project_urls={
        "Documentation": "https://salu133445.github.io/pypianoroll/"
    },
    license="MIT",
    keywords=["music", "audio", "music-information-retrieval"],
    packages=find_packages(
        include=["pypianoroll", "pypianoroll.*"], exclude=["tests"]
    ),
    install_requires=[
        "numpy>=1.12.0",
        "scipy>=1.0.0",
        "pretty_midi>=0.2.8",
        "matplotlib>=1.5",
    ],
    extras_require={
        "test": ["pytest>=6.0", "pytest-cov>=2.0"],
    },
    classifiers=[
        "Development Status :: 5 - Production/Stable",
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3",
        "Topic :: Multimedia :: Sound/Audio",
    ],
    python_requires=">=3.6",
)

--------------------------------------------------------------------------------
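The version lookup in `setup.py` scans `pypianoroll/version.py` for the line that starts with `__version__` and takes the text between the first pair of quotes. The sketch below reproduces that logic on an in-memory file so it can be tried without the repository. The function name `read_version` and the sample version string are assumptions made for illustration and are not part of the package.

```python
import io


def read_version(lines):
    """Return the quoted value from the first `__version__` line."""
    for line in lines:
        if line.startswith("__version__"):
            # Split on whichever quote character the assignment uses.
            delimiter = '"' if '"' in line else "'"
            return line.split(delimiter)[1]
    raise RuntimeError("Cannot read version string.")


# Hypothetical stand-in for the contents of version.py.
sample = io.StringIO('"""Version."""\n__version__ = "1.0.4"\n')
print(read_version(sample))  # prints 1.0.4
```

Because the lookup relies on `startswith`, an indented or otherwise prefixed `__version__` assignment would not be found and the function would raise `RuntimeError` instead.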