├── .github └── workflows │ ├── ci.yml │ ├── deploy_docs.yml │ └── pypi_release.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yml ├── AUTHORS.md ├── CITATION.cff ├── LICENSE.md ├── README.md ├── TODOs.md ├── docs ├── Makefile ├── images │ ├── logo_rectangle.png │ ├── logo_rectangle.svg │ ├── logo_rectangle_dark.png │ ├── logo_rectangle_dark.svg │ ├── logo_square.png │ ├── logo_square.svg │ ├── logo_square_dark.png │ ├── logo_square_dark.svg │ ├── mmm_tokens.png │ └── musicaiz_general.png ├── make.bat └── source │ ├── _static │ └── css │ │ └── custom.css │ ├── algorithms.rst │ ├── conf.py │ ├── converters.rst │ ├── datasets.rst │ ├── evaluation.rst │ ├── examples │ └── README.txt │ ├── features.rst │ ├── genindex.rst │ ├── glossary.rst │ ├── harmony.rst │ ├── implementations.rst │ ├── index.rst │ ├── install.rst │ ├── introduction.rst │ ├── loaders.rst │ ├── models.rst │ ├── notebooks │ └── 3-plot.ipynb │ ├── plotters.rst │ ├── rhythm.rst │ ├── structure.rst │ └── tokenizers.rst ├── environment.yml ├── musicaiz ├── __init__.py ├── algorithms │ ├── __init__.py │ ├── chord_prediction.py │ ├── harmonic_shift.py │ └── key_profiles.py ├── converters │ ├── __init__.py │ ├── musa_json.py │ ├── musa_protobuf.py │ ├── pretty_midi_musa.py │ └── protobuf │ │ ├── README.md │ │ ├── __init__.py │ │ ├── musicaiz.proto │ │ ├── musicaiz_pb2.py │ │ └── musicaiz_pb2.pyi ├── datasets │ ├── __init__.py │ ├── bps_fh.py │ ├── configs.py │ ├── jsbchorales.py │ ├── lmd.py │ ├── maestro.py │ └── utils.py ├── eval.py ├── features │ ├── __init__.py │ ├── graphs.py │ ├── harmony.py │ ├── pitch.py │ ├── predict_midi.py │ ├── rhythm.py │ ├── self_similarity.py │ └── structure.py ├── harmony │ ├── __init__.py │ ├── chords.py │ ├── intervals.py │ └── keys.py ├── loaders.py ├── models │ ├── __init__.py │ └── transformer_composers │ │ ├── __init__.py │ │ ├── configs.py │ │ ├── dataset.py │ │ ├── generate.py │ │ ├── train.py │ │ └── transformers │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── gpt.py │ │ └── layers.py ├── plotters │ ├── __init__.py │ └── pianorolls.py ├── rhythm │ ├── __init__.py │ ├── quantizer.py │ └── timing.py ├── structure │ ├── __init__.py │ ├── bars.py │ ├── instruments.py │ └── notes.py ├── tokenizers │ ├── __init__.py │ ├── cpword.py │ ├── encoder.py │ ├── mmm.py │ ├── one_hot.py │ └── remi.py ├── utils.py ├── version.py └── wrappers.py ├── pyproject.toml ├── requirements-docs.txt ├── requirements.txt ├── setup.cfg ├── setup.py └── tests ├── __init__.py ├── conftest.py ├── fixtures ├── datasets │ ├── jsbchorales │ │ ├── test │ │ │ ├── 14.mid │ │ │ ├── 27.mid │ │ │ └── 9.mid │ │ ├── train │ │ │ ├── 2.mid │ │ │ ├── 3.mid │ │ │ └── 4.mid │ │ └── valid │ │ │ └── 1.mid │ ├── lmd │ │ └── ABBA │ │ │ └── Andante, Andante.mid │ └── maestro │ │ ├── maestro-v2.0.0.csv │ │ └── maestro-v2.0.0 │ │ ├── 2018 │ │ └── MIDI-Unprocessed_Chamber3_MID--AUDIO_10_R3_2018_wav--1.midi │ │ └── maestro-v2.0.0.csv ├── midis │ ├── midi_changes.mid │ └── midi_data.mid └── tokenizers │ ├── cpword_tokens.txt │ ├── mmm_multiple_tokens.txt │ ├── mmm_tokens.mid │ ├── mmm_tokens.txt │ └── remi_tokens.txt └── unit ├── __init__.py └── musicaiz ├── __init__.py ├── algorithms ├── __init__.py ├── test_chord_prediction.py ├── test_harmonic_shift.py └── test_key_profiles.py ├── converters ├── __init__.py ├── test_musa_json.py ├── test_musa_to_protobuf.py └── test_pretty_midi_musa.py ├── datasets ├── __init__.py ├── asserts.py ├── test_bps_fh.py ├── test_jsbchorales.py ├── test_lmd.py └── test_maestro.py ├── eval ├── 
__init__.py └── test_eval.py ├── features ├── __init__.py ├── test_graphs.py ├── test_harmony.py ├── test_pitch.py ├── test_rhythm.py └── test_structure.py ├── harmony ├── __init__.py ├── test_chords.py ├── test_intervals.py └── test_keys.py ├── loaders ├── __init__.py └── test_loaders.py ├── plotters ├── __init__.py └── test_pianorolls.py ├── rhythm ├── __init__.py ├── test_quantizer.py └── test_timing.py ├── structure ├── __init__.py ├── test_instruments.py └── test_notes.py └── tokenizers ├── __init__.py ├── test_cpword.py ├── test_mmm.py └── test_remi.py /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | concurrency: 12 | group: ${{ github.workflow }}-${{ github.ref }} 13 | cancel-in-progress: True 14 | 15 | jobs: 16 | test: 17 | name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}" 18 | runs-on: ${{ matrix.os }} 19 | 20 | strategy: 21 | fail-fast: false 22 | matrix: 23 | include: 24 | - os: ubuntu-latest 25 | python-version: "3.8" 26 | channel-priority: "strict" 27 | envfile: "environment.yml" 28 | 29 | - os: ubuntu-latest 30 | python-version: "3.9" 31 | channel-priority: "strict" 32 | envfile: "environment.yml" 33 | 34 | - os: macos-latest 35 | python-version: "3.9" 36 | channel-priority: "strict" 37 | envfile: "environment.yml" 38 | 39 | - os: windows-latest 40 | python-version: "3.9" 41 | channel-priority: "strict" 42 | envfile: "environment.yml" 43 | 44 | - os: ubuntu-latest 45 | python-version: "3.10" 46 | channel-priority: "strict" 47 | envfile: "environment.yml" 48 | 49 | steps: 50 | - uses: actions/checkout@v2 51 | with: 52 | submodules: true 53 | 54 | 55 | - name: Cache conda 56 | uses: actions/cache@v2 57 | env: 58 | CACHE_NUMBER: 1 59 | with: 60 | path: ~/conda_pkgs_dir 61 | key: ${{ runner.os }}-${{ matrix.python-version }}-conda-${{ env.CACHE_NUMBER }}-${{ hashFiles( matrix.envfile ) }} 62 | 63 | - name: Install Conda environment 64 | uses: conda-incubator/setup-miniconda@v2 65 | with: 66 | auto-update-conda: true 67 | python-version: ${{ matrix.python-version }} 68 | add-pip-as-python-dependency: true 69 | auto-activate-base: false 70 | activate-environment: test 71 | channel-priority: ${{ matrix.channel-priority }} 72 | environment-file: ${{ matrix.envfile }} 73 | use-only-tar-bz2: true 74 | 75 | - name: Conda info 76 | shell: bash -l {0} 77 | run: | 78 | conda info -a 79 | conda list 80 | - name: Install musicaiz 81 | shell: bash -l {0} 82 | run: python -m pip install --upgrade-strategy only-if-needed -e .[tests] 83 | 84 | - name: Run pytest 85 | shell: bash -l {0} 86 | run: pytest 87 | 88 | - name: Upload coverage to Codecov 89 | uses: codecov/codecov-action@v3 90 | with: 91 | token: ${{ secrets.CODECOV_TOKEN }} 92 | files: ./coverage.xml 93 | directory: ./coverage/reports/ 94 | flags: unittests 95 | env_vars: OS,PYTHON 96 | name: codecov-umbrella 97 | fail_ci_if_error: true 98 | verbose: true -------------------------------------------------------------------------------- /.github/workflows/deploy_docs.yml: -------------------------------------------------------------------------------- 1 | name: Build and deploy documentation 2 | 3 | on: 4 | release: 5 | types: [published] 6 | branches: 7 | - main 8 | 9 | jobs: 10 | build-and-deploy: 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - name: Checkout code 15 | uses: actions/checkout@v2 16 | 17 | - name: Setup Python 18 | uses: 
actions/setup-python@v2 19 | with: 20 | python-version: '3.x' 21 | 22 | - name: Install dependencies 23 | run: | 24 | pip install -r requirements.txt 25 | pip install -r requirements-docs.txt 26 | 27 | - name: Build documentation 28 | run: | 29 | make -C docs html 30 | 31 | - name: Configure Git 32 | run: | 33 | git config --local user.email "action@github.com" 34 | git config --local user.name "GitHub Action" 35 | 36 | - name: Deploy documentation 37 | uses: peaceiris/actions-gh-pages@v3 38 | with: 39 | personal_token: ${{ secrets.GITHUB_TOKEN }} 40 | publish_dir: ./docs/build/html 41 | publish_branch: gh-pages -------------------------------------------------------------------------------- /.github/workflows/pypi_release.yml: -------------------------------------------------------------------------------- 1 | name: Publish to PyPI 2 | 3 | on: 4 | push: 5 | tags: 6 | - '*' 7 | release: 8 | types: [published] 9 | 10 | jobs: 11 | build-n-publish: 12 | name: Build and publish to PyPI 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - name: Checkout source 17 | uses: actions/checkout@v2 18 | 19 | - name: Set up Python 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: "3.x" 23 | 24 | - name: Build source and wheel distributions 25 | run: | 26 | python -m pip install --upgrade build twine 27 | python -m build 28 | twine check --strict dist/* 29 | - name: Publish distribution to PyPI 30 | uses: pypa/gh-action-pypi-publish@main 31 | with: 32 | user: __token__ 33 | password: ${{ secrets.PYPI_API_TOKEN }} 34 | 35 | - name: Create GitHub Release 36 | id: create_release 37 | uses: actions/create-release@v1 38 | env: 39 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 40 | with: 41 | tag_name: ${{ github.ref }} 42 | release_name: ${{ github.ref }} 43 | draft: false 44 | prerelease: false 45 | 46 | - name: Get Asset name 47 | run: | 48 | export PKG=$(ls dist/ | grep tar) 49 | set -- $PKG 50 | echo "name=$1" >> $GITHUB_ENV 51 | - name: Upload Release Asset (sdist) to GitHub 52 | id: upload-release-asset 53 | uses: actions/upload-release-asset@v1 54 | env: 55 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 56 | with: 57 | upload_url: ${{ steps.create_release.outputs.upload_url }} 58 | asset_path: dist/${{ env.name }} 59 | asset_name: ${{ env.name }} 60 | asset_content_type: application/gzip -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.ipynb_checkpoints/ 4 | *.idea/ 5 | *.py[cod] 6 | 7 | # Distribution / packaging 8 | .Python 9 | env/ 10 | build/ 11 | develop-eggs/ 12 | dist/ 13 | downloads/ 14 | eggs/ 15 | .eggs/ 16 | lib/ 17 | lib64/ 18 | parts/ 19 | sdist/ 20 | var/ 21 | *.egg-info/ 22 | .installed.cfg 23 | *.egg 24 | 25 | # PyInstaller 26 | # Usually these files are written by a python script from a template 27 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
28 | *.manifest 29 | *.spec 30 | 31 | # Sphinx documentation 32 | docs/_build/ 33 | docs/build/ 34 | docs/source/generated/ 35 | docs/source/auto_examples/ 36 | 37 | # PyBuilder 38 | target/ 39 | 40 | # Coverage 41 | htmlcov/ 42 | .coverage 43 | 44 | # Generated files 45 | .DS_Store 46 | .vscode/ 47 | 48 | # direnv 49 | .env 50 | 51 | # vim 52 | *.swp 53 | 54 | # binary data 55 | binaries/ 56 | results/ 57 | 58 | # notebooks 59 | notebooks/ 60 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 20.8b1 4 | hooks: 5 | - id: black 6 | name: black 7 | entry: black musicaiz/ 8 | language: python 9 | stages: [commit] 10 | always_run: true 11 | - repo: https://gitlab.com/PyCQA/flake8 12 | rev: 3.8.3 13 | hooks: 14 | - id: flake8 15 | stages: [commit] 16 | - repo: local 17 | hooks: 18 | - id: tests 19 | name: tests 20 | entry: pytest 21 | language: python 22 | verbose: true 23 | pass_filenames: false 24 | always_run: true 25 | stages: [push] 26 | - id: typing 27 | name: typing 28 | entry: mypy musicaiz/ 29 | verbose: true 30 | always_run: true 31 | language: python 32 | stages: [push] 33 | pass_filenames: false 34 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | build: 9 | os: ubuntu-20.04 10 | tools: 11 | python: "3.9" 12 | 13 | python: 14 | install: 15 | - requirements: requirements-docs.txt 16 | - requirements: requirements.txt 17 | 18 | formats: 19 | - pdf 20 | - epub -------------------------------------------------------------------------------- /AUTHORS.md: -------------------------------------------------------------------------------- 1 | Authors 2 | ============ 3 | 4 | * Carlos Hernandez-Olivan [web](https://carlosholivan.github.io) 5 | 6 | Contributors 7 | ============ 8 | * Ignacio Zay Pinilla, University of Zaragoza: quantization algorithm 9 | * Sonia Rubio Llamas, University of Zaragoza: symbolic music structure analysis experiments 10 | * Sergio Ferraz Laplana, University of Zaragoza: Compound Word tokenizer -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as shown below." 3 | title: musicaiz 4 | authors: 5 | - family-names: Hernandez-Olivan 6 | given-names: Carlos 7 | preferred-citation: 8 | type: article 9 | authors: 10 | - family-names: Hernandez-Olivan 11 | given-names: Carlos 12 | orcid: "https://orcid.org/0000-0002-0235-2267" 13 | - family-names: Beltran 14 | given-names: Jose R.
15 | orcid: "https://orcid.org/0000-0002-7500-4650" 16 | title: "musicaiz: A Python Library for Symbolic Music Generation, Analysis and Visualization" 17 | journal: arXiv 18 | year: 2022 19 | license: AGPL-3.0 20 | date-released: 2022-09-15 21 | url: "https://carlosholivan.github.io/musicaiz/" 22 | repository-code: "https://github.com/carlosholivan/musicaiz" -------------------------------------------------------------------------------- /TODOs.md: -------------------------------------------------------------------------------- 1 | ## Improvements 2 | 3 | ### Tokenizers/Models 4 | 5 | - Vocabulary as a dictionary (it is currently a str) 6 | 7 | ### Converters 8 | - [ ] Add MusicXML 9 | - [ ] Add ABC notation. 10 | - [ ] JSON to musicaiz objects. 11 | 12 | ### Plotters 13 | - [ ] Adjust plotters. Plot in secs or ticks and be careful with tick labels in plots that have too much data, 14 | since numbers can overlap and the plot won't be clean. 15 | 16 | ### Harmony 17 | - [ ] Measure just the correct interval (and not all the possible intervals based on the pitch) if the note name is known (now it measures all the possible intervals given pitch, but if we do know the note name the interval is just one). 18 | - [ ] Support key changes in the middle of a piece when loading with the ``loaders.Musa`` object. 19 | - [ ] Initialize note names correctly if the key or tonality is known (now the note name initialization is arbitrary; it can be the enharmonic or not) 20 | 21 | ### Features 22 | - [ ] Function to compute: Polyphonic rate 23 | - [ ] Function to compute: Polyphony 24 | 25 | ### Tokenizers 26 | - [ ] MusicTransformer 27 | - [ ] Octuple 28 | - [ ] Compound Word 29 | 30 | ### Synthesis 31 | - [ ] Add function to synthesize a ``loaders.Musa`` object (can be inherited from ``pretty_midi``). 32 | 33 | ### Other TODOs 34 | - [ ] Harmony: cadences 35 | - [ ] Rhythm: syncopation 36 | - [ ] Synthesize notes to be able to play chords, intervals, scales... (this might end up being a plugin for composition assistance). -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
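# For example, `make html` expands to `sphinx-build -M html source build`,
# leaving the rendered pages in build/html.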
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/images/logo_rectangle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carlosholivan/musicaiz/70f95854a3777b0323ed47f5a0822cf71eb96a70/docs/images/logo_rectangle.png -------------------------------------------------------------------------------- /docs/images/logo_rectangle_dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carlosholivan/musicaiz/70f95854a3777b0323ed47f5a0822cf71eb96a70/docs/images/logo_rectangle_dark.png -------------------------------------------------------------------------------- /docs/images/logo_square.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carlosholivan/musicaiz/70f95854a3777b0323ed47f5a0822cf71eb96a70/docs/images/logo_square.png -------------------------------------------------------------------------------- /docs/images/logo_square_dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carlosholivan/musicaiz/70f95854a3777b0323ed47f5a0822cf71eb96a70/docs/images/logo_square_dark.png -------------------------------------------------------------------------------- /docs/images/mmm_tokens.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carlosholivan/musicaiz/70f95854a3777b0323ed47f5a0822cf71eb96a70/docs/images/mmm_tokens.png -------------------------------------------------------------------------------- /docs/images/musicaiz_general.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carlosholivan/musicaiz/70f95854a3777b0323ed47f5a0822cf71eb96a70/docs/images/musicaiz_general.png -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.https://www.sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | div.wy-side-nav-search .version { 2 | color: #404040; 3 | font-weight: bold; 4 | } 5 | 6 | nav.wy-nav-top { 7 | background: #6048cc; 8 | } 9 | 10 | div.wy-nav-content { 11 | max-width: 1000px; 12 | } 13 | 14 | span.caption-text { 15 | color: #9580f0; 16 | } 17 | 18 | .highlight { 19 | background-color: rgb(242, 255, 234); 20 | } 21 | 22 | 23 | /*Extends the docstring signature box.*/ 24 | .rst-content dl:not(.docutils) dt { 25 | display: block; 26 | padding: 10px; 27 | word-wrap: break-word; 28 | padding-right: 100px; 29 | } 30 | /*Lists in an admonition note do not have awkward whitespace below.*/ 31 | .rst-content .admonition-note .section ul { 32 | margin-bottom: 0px; 33 | } 34 | /*Properties become blue (classmethod, staticmethod, property)*/ 35 | .rst-content dl dt em.property { 36 | color: #2980b9; 37 | text-transform: uppercase; 38 | } 39 | 40 | .rst-content .section ol p, 41 | .rst-content .section ul p { 42 | margin-bottom: 0px; 43 | } 44 | 45 | /* Adjustment to Sphinx Book Theme */ 46 | .table td { 47 | /* Remove row spacing */ 48 | padding: 0; 49 | } 50 | 51 | table { 52 | /* Force full width for all tables */ 53 | width: 136% !important; 54 | } 55 | 56 | img.inline-figure { 57 | /* Override the display: block for img */ 58 | display: inherit !important; 59 | } 60 | 61 | #version-warning-banner { 62 | /* Make version warning clickable */ 63 | z-index: 1; 64 | margin-left: 0; 65 | /* 20% is for ToC rightbar */ 66 | /* 2 * 1.5625em is for horizontal margins */ 67 | width: calc(100% - 20% - 2 * 1.5625em); 68 | } 69 | 70 | span.rst-current-version > span.fa.fa-book { 71 | /* Move the book icon away from the top right 72 | * corner of the version flyout menu */ 73 | margin: 10px 0px 0px 5px; 74 | } 75 | 76 | /* Adjustment to Version block */ 77 | .rst-versions { 78 | z-index: 1200 !important; 79 | } 80 | 81 | dt:target, span.highlighted { 82 | background-color: #fbe54e; 83 | } 84 | 85 | /* allow scrollable images */ 86 | .figure { 87 | max-width: 100%; 88 | overflow-x: auto; 89 | } 90 | img.horizontal-scroll { 91 | max-width: none; 92 | } 93 | 94 | .clear-both { 95 | clear: both; 96 | min-height: 100px; 97 | margin-top: 15px; 98 | } 99 | 100 | .buttons-float-left { 101 | width: 150px; 102 | float: left; 103 | } 104 | 105 | .buttons-float-right { 106 | width: 150px; 107 | float: right; 108 | } 109 | 110 | /* Wrap code blocks instead of horizontal scrolling. */ 111 | pre { 112 | white-space: pre-wrap; 113 | } 114 | 115 | /* notebook formatting */ 116 | .cell .cell_output { 117 | max-height: 250px; 118 | overflow-y: auto; 119 | } 120 | 121 | /* Yellow doesn't render well on light background */ 122 | .cell .cell_output pre .-Color-Yellow { 123 | color: #785840; 124 | } -------------------------------------------------------------------------------- /docs/source/algorithms.rst: -------------------------------------------------------------------------------- 1 | .. _algorithms: 2 | 3 | ..
automodule:: musicaiz.algorithms -------------------------------------------------------------------------------- /docs/source/converters.rst: -------------------------------------------------------------------------------- 1 | .. _converters: 2 | 3 | .. automodule:: musicaiz.converters -------------------------------------------------------------------------------- /docs/source/datasets.rst: -------------------------------------------------------------------------------- 1 | .. _datasets: 2 | 3 | .. automodule:: musicaiz.datasets -------------------------------------------------------------------------------- /docs/source/evaluation.rst: -------------------------------------------------------------------------------- 1 | .. _evaluation: 2 | 3 | .. automodule:: musicaiz.eval -------------------------------------------------------------------------------- /docs/source/examples/README.txt: -------------------------------------------------------------------------------- 1 | ----------------- 2 | Advanced examples 3 | ----------------- -------------------------------------------------------------------------------- /docs/source/features.rst: -------------------------------------------------------------------------------- 1 | .. _features: 2 | 3 | .. automodule:: musicaiz.features -------------------------------------------------------------------------------- /docs/source/genindex.rst: -------------------------------------------------------------------------------- 1 | Index 2 | ===== 3 | 4 | :ref:`genindex` -------------------------------------------------------------------------------- /docs/source/glossary.rst: -------------------------------------------------------------------------------- 1 | Glossary 2 | ======== 3 | 4 | Notes 5 | ------- 6 | .. glossary:: 7 | 8 | note 9 | ... 10 | pitch 11 | ... 12 | 13 | Harmony 14 | ------- 15 | .. glossary:: 16 | 17 | interval 18 | ... 19 | interval quality 20 | ... 21 | chord 22 | ... -------------------------------------------------------------------------------- /docs/source/harmony.rst: -------------------------------------------------------------------------------- 1 | .. _harmony: 2 | 3 | .. automodule:: musicaiz.harmony -------------------------------------------------------------------------------- /docs/source/implementations.rst: -------------------------------------------------------------------------------- 1 | Implementation Details 2 | ====================== 3 | 4 | Rhythm Features 5 | --------------- 6 | 7 | [1] Roig, C., Tardón, L. J., Barbancho, I., & Barbancho, A. M. (2014). Automatic melody composition based on a probabilistic model of music style and harmonic rules. Knowledge-Based Systems, 71, 419-434. http://dx.doi.org/10.1016/j.knosys.2014.08.018 8 | 9 | 10 | Structure Features 11 | ------------------ 12 | 13 | [2] Louie, W. MusicPlot: Interactive Self-Similarity Matrix for Music Structure Visualization. https://wlouie1.github.io/MusicPlot/musicplot_paper.pdf 14 | 15 | 16 | Quantization 17 | ------------ 18 | 19 | [3] https://www.fransabsil.nl/archpdf/advquant.pdf 20 | 21 | 22 | Evaluation 23 | ---------- 24 | 25 | [4] Yang, L. C., & Lerch, A. (2020). On the evaluation of generative models in music. Neural Computing and Applications, 32(9), 4773-4784. https://doi.org/10.1007/s00521-018-3849-7 26 | 27 | Tokenizers 28 | ---------- 29 | 30 | [5] Ens, J., & Pasquier, P. (2020). Flexible generation with the multi-track music machine. In Proceedings of the 21st International Society for Music Information Retrieval Conference, ISMIR.
31 | 32 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | ******** 2 | MusicAIz 3 | ******** 4 | 5 | `musicaiz` is a Python package for symbolic music generation. It provides the building 6 | blocks necessary to create and evaluate symbolic discrete music. 7 | 8 | Musicaiz can be used for generating, analyzing or evaluating symbolic music data. 9 | 10 | .. image:: ../images/musicaiz_general.png 11 | 12 | .. toctree:: 13 | :maxdepth: 1 14 | :caption: Getting Started: 15 | 16 | install 17 | introduction 18 | 19 | .. toctree:: 20 | :maxdepth: 2 21 | :caption: Musicaiz Documentation: 22 | 23 | loaders 24 | structure 25 | harmony 26 | rhythm 27 | features 28 | algorithms 29 | tokenizers 30 | evaluation 31 | plotters 32 | converters 33 | datasets 34 | models 35 | 36 | 37 | .. toctree:: 38 | :maxdepth: 2 39 | :caption: Tutorials: 40 | 41 | notebooks/2-write_midi.ipynb 42 | notebooks/3-plot.ipynb 43 | 44 | 45 | .. toctree:: 46 | :maxdepth: 1 47 | :caption: References: 48 | 49 | implementations 50 | genindex 51 | glossary 52 | -------------------------------------------------------------------------------- /docs/source/install.rst: -------------------------------------------------------------------------------- 1 | Installation instructions 2 | ^^^^^^^^^^^^^^^^^^^^^^^^^ 3 | 4 | pypi 5 | ~~~~ 6 | The simplest way to install *musicaiz* is through the Python Package Index (PyPI). 7 | This can be achieved by executing the following command:: 8 | 9 | pip install musicaiz 10 | 11 | or:: 12 | 13 | sudo pip install musicaiz 14 | 15 | 16 | 17 | Source 18 | ~~~~~~ 19 | 20 | If you've downloaded the archive manually from the `releases 21 | `_ page, you can install using the 22 | `setuptools` script:: 23 | 24 | tar xzf musicaiz-VERSION.tar.gz 25 | cd musicaiz-VERSION/ 26 | python setup.py install 27 | 28 | If you intend to develop musicaiz or make changes to the source code, you can 29 | install with `pip install -e` to link to your actively developed source tree:: 30 | 31 | tar xzf musicaiz-VERSION.tar.gz 32 | cd musicaiz-VERSION/ 33 | pip install -e . 34 | 35 | Alternatively, the latest development version can be installed via pip:: 36 | 37 | pip install git+https://github.com/carlosholivan/musicaiz 38 | -------------------------------------------------------------------------------- /docs/source/introduction.rst: -------------------------------------------------------------------------------- 1 | Introduction 2 | ^^^^^^^^^^^^ 3 | 4 | This section covers the fundamental usage of *musicaiz*, including 5 | a package overview, basic and advanced usage. 6 | 7 | Quickstart 8 | ~~~~~~~~~~ 9 | 10 | The goal of musicaiz is to provide a framework for symbolic music generation. 11 | Musicaiz contains 3 basic modules that contain the definitions of basic music principles: harmony, structure and rhythm. 12 | There are other modules that use these definitions to plot, tokenize, evaluate and generate symbolic music data. 13 | 14 | 15 | Analyze data 16 | ------------ 17 | 18 | Load a MIDI file: 19 | 20 | .. code-block:: python 21 | 22 | from musicaiz import loaders 23 | 24 | # load file 25 | midi = loaders.Musa("../files/mozart.mid") 26 | 27 | # get instruments 28 | instruments = midi.instruments 29 | 30 | 31 | Create data 32 | ----------- 33 | 34 | Obtain notes that belong to a chord: 35 | 36 | ..
code-block:: python 37 | 38 | from musicaiz import harmony 39 | 40 | # Triad chord built from the 1st degree of C major 41 | harmony.Tonality.get_chord_notes_from_degree(tonality="C_MAJOR", degree="I", scale="MAJOR") 42 | 43 | # Triad chord built from the 1st degree of C major in the Dorian scale 44 | harmony.Tonality.get_chord_notes_from_degree(tonality="C_MAJOR", degree="I", scale="DORIAN") 45 | 46 | 47 | 48 | Tokenize 49 | -------- 50 | 51 | We can encode MIDI data as tokens with musicaiz. The current 52 | implementation supports the following tokenizers: 53 | 54 | `MMM`: Ens, J., & Pasquier, P. (2020). Mmm: Exploring conditional multi-track music generation with the transformer. 55 | 56 | .. image:: ../images/mmm_tokens.png 57 | :width: 75% 58 | :align: center 59 | 60 | -------------------------------------------------------------------------------- /docs/source/loaders.rst: -------------------------------------------------------------------------------- 1 | .. _loaders: 2 | 3 | .. automodule:: musicaiz.loaders -------------------------------------------------------------------------------- /docs/source/models.rst: -------------------------------------------------------------------------------- 1 | .. _models: 2 | 3 | .. automodule:: musicaiz.models -------------------------------------------------------------------------------- /docs/source/plotters.rst: -------------------------------------------------------------------------------- 1 | .. _plotters: 2 | 3 | .. automodule:: musicaiz.plotters -------------------------------------------------------------------------------- /docs/source/rhythm.rst: -------------------------------------------------------------------------------- 1 | .. _rhythm: 2 | 3 | .. automodule:: musicaiz.rhythm -------------------------------------------------------------------------------- /docs/source/structure.rst: -------------------------------------------------------------------------------- 1 | .. _structure: 2 | 3 | .. automodule:: musicaiz.structure -------------------------------------------------------------------------------- /docs/source/tokenizers.rst: -------------------------------------------------------------------------------- 1 | .. _tokenizers: 2 | 3 | .. automodule:: musicaiz.tokenizers -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: musicaiz 2 | 3 | channels: 4 | - defaults 5 | - conda-forge 6 | 7 | dependencies: 8 | - jupyter 9 | - pip 10 | - coverage 11 | - pytest-mpl 12 | - pytest-cov 13 | - pytest 14 | 15 | - pip: 16 | - -r requirements.txt 17 | -------------------------------------------------------------------------------- /musicaiz/__init__.py: -------------------------------------------------------------------------------- 1 | # All the musicaiz sub-modules 2 | from .structure import * 3 | from .rhythm import * 4 | from .harmony import * 5 | from .features import * 6 | from .plotters import * 7 | from .tokenizers import * 8 | from .datasets import * 9 | from .converters import * 10 | from .algorithms import * 11 | from .models import * -------------------------------------------------------------------------------- /musicaiz/algorithms/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Algorithms 3 | ========== 4 | 5 | This module allows extracting, modifying and predicting different aspects of a music piece.
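A minimal usage sketch (the MIDI filename here is illustrative, not a file
shipped with the package):

>>> from musicaiz.loaders import Musa
>>> from musicaiz.algorithms import predict_chords
>>> midi = Musa("mozart.mid")
>>> chords = predict_chords(midi)  # one list of chord candidates per beat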
6 | 7 | 8 | Harmonic Shifting 9 | ----------------- 10 | 11 | .. autosummary:: 12 | :toctree: generated/ 13 | 14 | harmonic_shifting 15 | 16 | Key-Profiles 17 | ------------ 18 | 19 | State-of-the-art algorithms for key finding in symbolic music. 20 | The key-profile weights in the library are the following: 21 | 22 | - KRUMHANSL_KESSLER 23 | 24 | [1] Krumhansl, C. L., & Kessler, E. J. (1982). 25 | Tracing the dynamic changes in perceived tonal organization in a spatial representation of musical keys. 26 | Psychological review, 89(4), 334. 27 | 28 | - TEMPERLEY 29 | 30 | [2] Temperley, D. (1999). 31 | What's key for key? The Krumhansl-Schmuckler key-finding algorithm reconsidered. 32 | Music Perception, 17(1), 65-100. 33 | 34 | - ALBRETCH_SHANAHAN 35 | 36 | [3] Albrecht, J., & Shanahan, D. (2013). 37 | The use of large corpora to train a new type of key-finding algorithm: An improved treatment of the minor mode. 38 | Music Perception: An Interdisciplinary Journal, 31(1), 59-67. 39 | 40 | - SIGNATURE_FIFTHS 41 | 42 | 43 | .. autosummary:: 44 | :toctree: generated/ 45 | 46 | KeyDetectionAlgorithms 47 | KrumhanslKessler 48 | Temperley 49 | AlbrechtShanahan 50 | key_detection 51 | 52 | Chord Prediction 53 | ---------------- 54 | 55 | Predict chords at the beat level. 56 | 57 | .. autosummary:: 58 | :toctree: generated/ 59 | 60 | predict_chords 61 | get_chords 62 | get_chords_candidates 63 | compute_chord_notes_dist 64 | _notes_to_onehot 65 | """ 66 | 67 | from .harmonic_shift import ( 68 | harmonic_shifting, 69 | scale_change 70 | ) 71 | 72 | from .key_profiles import ( 73 | KeyDetectionAlgorithms, 74 | KrumhanslKessler, 75 | Temperley, 76 | AlbrechtShanahan, 77 | signature_fifths, 78 | _signature_fifths_keys, 79 | _right_left_notes, 80 | _correlation, 81 | _keys_correlations, 82 | signature_fifths_profiles, 83 | _eights_per_pitch_class, 84 | key_detection 85 | ) 86 | 87 | from .chord_prediction import ( 88 | predict_chords, 89 | get_chords, 90 | get_chords_candidates, 91 | compute_chord_notes_dist, 92 | _notes_to_onehot, 93 | ) 94 | 95 | 96 | __all__ = [ 97 | "harmonic_shifting", 98 | "scale_change", 99 | "KeyDetectionAlgorithms", 100 | "KrumhanslKessler", 101 | "Temperley", 102 | "AlbrechtShanahan", 103 | "signature_fifths", 104 | "_signature_fifths_keys", 105 | "_right_left_notes", 106 | "_correlation", 107 | "_keys_correlations", 108 | "signature_fifths_profiles", 109 | "_eights_per_pitch_class", 110 | "key_detection", 111 | "predict_chords", 112 | "get_chords", 113 | "get_chords_candidates", 114 | "compute_chord_notes_dist", 115 | "_notes_to_onehot", 116 | ] 117 | -------------------------------------------------------------------------------- /musicaiz/algorithms/chord_prediction.py: -------------------------------------------------------------------------------- 1 | import pretty_midi as pm 2 | import numpy as np 3 | from typing import List, Dict 4 | 5 | 6 | from musicaiz.structure import Note 7 | from musicaiz.rhythm import NoteLengths 8 | from musicaiz.harmony import Chord 9 | 10 | 11 | def predict_chords(musa_obj): 12 | notes_beats = [] 13 | for i in range(len(musa_obj.beats)): 14 | nts = musa_obj.get_notes_in_beat(i) 15 | nts = [n for n in nts if not n.is_drum] 16 | if nts is not None and len(nts) != 0: 17 | notes_beats.append(nts) 18 | notes_pitches_segments = [_notes_to_onehot(note) for note in notes_beats] 19 | # Convert chord labels to onehot 20 | chords_onehot = Chord.chords_to_onehot() 21 | # step 1: Compute the distance between all the chord vectors and the notes vectors
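    # The distance used below is the Euclidean (L2) norm between each chord's
    # 12-dim one-hot vector and the beat's 12-dim pitch-class vector, so lower
    # values mean more pitch classes shared with the chord.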
22 | all_dists = [compute_chord_notes_dist(chords_onehot, segment) for segment in notes_pitches_segments] 23 | # step 2: get the chord candidates per step whose distance is the lowest 24 | chord_segments = get_chords_candidates(all_dists) 25 | # step 3: clean chord candidates 26 | chords = get_chords(chord_segments, chords_onehot) 27 | return chords 28 | 29 | 30 | def get_chords( 31 | chord_segments: List[List[str]], 32 | chords_onehot: Dict[str, List[int]], 33 | ) -> List[List[str]]: 34 | """ 35 | Clean the predicted chords that are extracted with the get_chords_candidates method 36 | by comparing each chord in a step with the chords in the previous and next steps. 37 | The output chords are the ones whose distances are the lowest. 38 | 39 | Parameters 40 | ---------- 41 | 42 | chord_segments: List[List[str]] 43 | The chord candidates extracted with the get_chords_candidates method. 44 | 45 | Returns 46 | ------- 47 | 48 | chords: List[List[str]] 49 | """ 50 | chords = [] 51 | for i, _ in enumerate(chord_segments): 52 | cross_dists = {} 53 | for j, _ in enumerate(chord_segments[i]): 54 | if i == 0: 55 | for item in range(len(chord_segments[i + 1])): 56 | dist = np.linalg.norm(np.array(chords_onehot[chord_segments[i][j]]) - np.array(chords_onehot[chord_segments[i+1][item]])) 57 | cross_dists.update( 58 | { 59 | chord_segments[i][j] + " " + chord_segments[i+1][item]: dist 60 | } 61 | ) 62 | if i != 0: 63 | for item in range(len(chord_segments[i - 1])): 64 | dist = np.linalg.norm(np.array(chords_onehot[chord_segments[i][j]]) - np.array(chords_onehot[chord_segments[i-1][item]])) 65 | cross_dists.update( 66 | { 67 | chord_segments[i][j] + " " + chord_segments[i-1][item]: dist 68 | } 69 | ) 70 | #print("--------") 71 | #print(cross_dists) 72 | chords_list = [(i.split(" ")[0], cross_dists[i]) for i in cross_dists if cross_dists[i]==min(cross_dists.values())] 73 | chords_dict = {} 74 | chords_dict.update(chords_list) 75 | #print(chords_dict) 76 | # Penalize (add 0.5 to the distance of) chords whose name or tonic does not appear in the previous step 77 | for chord, dist in chords_dict.items(): 78 | if i != 0: 79 | prev_chords = [c for c in chords[i - 1]] 80 | tonics = [c.split("-")[0] for c in prev_chords] 81 | tonic = chord.split("-")[0] 82 | if chord not in prev_chords or tonic not in tonics: 83 | chords_dict[chord] = dist + 0.5 84 | #print(chords_dict) 85 | new_chords_list = [i for i in chords_dict if chords_dict[i]==min(chords_dict.values())] 86 | #print(new_chords_list) 87 | chords.append(new_chords_list) 88 | # If a 7th chord is predicted at a time step and the same chord triad is at 89 | # the prev and next steps, we'll substitute the triad chord for the 7th chord 90 | #for step in chords: 91 | # chord_names = "/".join(step) 92 | # if "SEVENTH" in chord_names: 93 | return chords 94 | 95 | 96 | def get_chords_candidates(dists: List[Dict[str, float]]) -> List[List[str]]: 97 | """ 98 | Gets the chords with the minimum distance from a list of dictionaries, 99 | where each element of the list is a step (beat) and each dictionary maps 100 | chord names (keys) to distances (values). 101 | 102 | Parameters 103 | ---------- 104 | 105 | dists: List[Dict[str, float]] 106 | The list of distances between chord and note vectors as dictionaries per step. 107 | 108 | Returns 109 | ------- 110 | 111 | chord_segments: List[List[str]] 112 | A list with all the chords predicted per step.
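    A minimal illustration (the chord labels are just placeholders):

    >>> dists = [{"C-MAJOR": 0.1, "A-MINOR": 0.9}, {"F-MAJOR": 0.4, "D-MINOR": 0.4}]
    >>> get_chords_candidates(dists)
    [['C-MAJOR'], ['F-MAJOR', 'D-MINOR']]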
113 | """ 114 | chord_segments = [] 115 | for dists_dict in dists: 116 | chord_segments.append([i for i in dists_dict if dists_dict[i]==min(dists_dict.values())]) 117 | return chord_segments 118 | 119 | 120 | def compute_chord_notes_dist( 121 | chords_onehot: Dict[str, List[int]], 122 | notes_onehot: Dict[str, List[int]], 123 | ) -> Dict[str, float]: 124 | """ 125 | Compute the distance between each chord and a single notes vector. 126 | The outpput is given as a dictionary with the chord name (key) and the distance (val.). 127 | 128 | Parameters 129 | ---------- 130 | 131 | chords_onehot: Dict[str, List[int]] 132 | 133 | notes_onehot: Dict[str, List[int]] 134 | 135 | Returns 136 | ------- 137 | 138 | dists: Dict[str, float] 139 | """ 140 | dists = {} 141 | for chord, chord_vec in chords_onehot.items(): 142 | dist = np.linalg.norm(np.array(notes_onehot)-np.array(chord_vec)) 143 | dists.update({chord: dist}) 144 | return dists 145 | 146 | 147 | def _notes_to_onehot(notes: List[Note]) -> List[int]: 148 | """ 149 | Converts a list of notes into a list of 0s and 1s. 150 | The output list will have 12 elements corresponding to 151 | the notes in the chromatic scale from C to B. 152 | If the note C is in the input list, the index corresponding 153 | to that note in the output list will be 1, otherwise it'll be 0. 154 | 155 | Parameters 156 | ---------- 157 | notes: List[Note]) 158 | 159 | Returns 160 | ------- 161 | pitches_onehot: List[int] 162 | """ 163 | pitches = [pm.note_name_to_number(note.note_name + "-1") for note in notes] 164 | pitches = list(dict.fromkeys(pitches)) 165 | pitches_onehot = [1 if i in pitches else 0 for i in range(0, 12)] 166 | return pitches_onehot 167 | -------------------------------------------------------------------------------- /musicaiz/converters/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Converters 3 | ========== 4 | 5 | This module allows to export symbolic music in different formats. 6 | 7 | 8 | JSON 9 | ---- 10 | 11 | .. autosummary:: 12 | :toctree: generated/ 13 | 14 | MusaJSON 15 | 16 | 17 | Pretty MIDI 18 | ----------- 19 | 20 | .. autosummary:: 21 | :toctree: generated/ 22 | 23 | prettymidi_note_to_musicaiz 24 | musicaiz_note_to_prettymidi 25 | 26 | 27 | Protobufs 28 | ---------- 29 | 30 | .. autosummary:: 31 | :toctree: generated/ 32 | 33 | musa_to_proto 34 | proto_to_musa 35 | """ 36 | 37 | from .musa_json import ( 38 | MusaJSON, 39 | BarJSON, 40 | InstrumentJSON, 41 | NoteJSON, 42 | ) 43 | 44 | from .pretty_midi_musa import ( 45 | prettymidi_note_to_musicaiz, 46 | musicaiz_note_to_prettymidi, 47 | musa_to_prettymidi, 48 | ) 49 | 50 | from .musa_protobuf import ( 51 | musa_to_proto, 52 | proto_to_musa 53 | ) 54 | 55 | from . 
import protobuf 56 | 57 | __all__ = [ 58 | "MusaJSON", 59 | "BarJSON", 60 | "InstrumentJSON", 61 | "NoteJSON", 62 | "prettymidi_note_to_musicaiz", 63 | "musicaiz_note_to_prettymidi", 64 | "musa_to_prettymidi", 65 | "protobuf", 66 | "musa_to_proto", 67 | "proto_to_musa" 68 | ] 69 | -------------------------------------------------------------------------------- /musicaiz/converters/musa_json.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List 2 | from pathlib import Path 3 | import json 4 | from dataclasses import dataclass 5 | 6 | 7 | @dataclass 8 | class NoteJSON: 9 | start: int # ticks 10 | end: int # ticks 11 | pitch: int 12 | velocity: int 13 | bar_idx: int 14 | beat_idx: int 15 | subbeat_idx: int 16 | instrument_idx: int 17 | instrument_prog: int 18 | 19 | 20 | @dataclass 21 | class BarJSON: 22 | time_sig: str 23 | bpm: int 24 | start: int # ticks 25 | end: int # ticks 26 | 27 | 28 | @dataclass 29 | class InstrumentJSON: 30 | is_drum: bool 31 | name: str 32 | n_prog: int 33 | 34 | 35 | @dataclass 36 | class JSON: 37 | tonality: str 38 | time_sig: str 39 | 40 | 41 | class MusaJSON: 42 | 43 | """ 44 | This class converts a `musicaiz` :func:`~musicaiz.loaders.Musa` object 45 | into a JSON format. 46 | Note that this conversion is different from the ``.json`` method of the Musa class: 47 | that method encodes musicaiz objects, whereas the JSON produced here 48 | can be encoded and decoded with other software since it does not contain 49 | musicaiz objects. 50 | 51 | Examples 52 | -------- 53 | 54 | >>> file = Path("../0.mid") 55 | >>> midi = Musa(file, structure="bars", absolute_timing=True) 56 | >>> musa_json = MusaJSON(midi) 57 | 58 | To add a field inside an instrument: 59 | 60 | >>> musa_json.add_instrument_field( 61 | n_program=0, 62 | field="hello", 63 | value=2 64 | ) 65 | 66 | Save the json to disk: 67 | 68 | >>> musa_json.save("filename") 69 | """ 70 | 71 | def __init__( 72 | self, 73 | musa_obj, # An initialized Musa object 74 | ): 75 | self.midi = musa_obj 76 | self.json = self.to_json(musa_obj=self.midi) 77 | 78 | def save(self, filename: str, path: Union[str, Path] = ""): 79 | """Saves the JSON to disk.""" 80 | with open(Path(path, filename + ".json"), "w") as write_file: 81 | json.dump(self.json, write_file) 82 | 83 | @staticmethod 84 | def to_json(musa_obj): 85 | composition = {} 86 | 87 | # headers 88 | composition["tonality"] = musa_obj.tonality 89 | composition["resolution"] = musa_obj.resolution 90 | composition["instruments"] = [] 91 | composition["bars"] = [] 92 | composition["notes"] = [] 93 | 94 | composition["instruments"] = [] 95 | for _, instr in enumerate(musa_obj.instruments): 96 | composition["instruments"].append( 97 | { 98 | "is_drum": instr.is_drum, 99 | "name": instr.name, 100 | "n_prog": int(instr.program), 101 | } 102 | ) 103 | for _, bar in enumerate(musa_obj.bars): 104 | composition["bars"].append( 105 | { 106 | "time_sig": bar.time_sig.time_sig, 107 | "start": bar.start_ticks, 108 | "end": bar.end_ticks, 109 | "bpm": bar.bpm, 110 | } 111 | ) 112 | for _, note in enumerate(musa_obj.notes): 113 | composition["notes"].append( 114 | { 115 | "start": note.start_ticks, 116 | "end": note.end_ticks, 117 | "pitch": note.pitch, 118 | "velocity": note.velocity, 119 | "bar_idx": note.bar_idx, 120 | "beat_idx": note.beat_idx, 121 | "subbeat_idx": note.subbeat_idx, 122 | "instrument_idx": note.instrument_idx, 123 | "instrument_prog": note.instrument_prog, 124 | 125 | } 126 | ) 127 | return
composition 128 | 129 | def add_instrument_field(self, n_program: int, field: str, value: Union[str, int, float]): 130 | """ 131 | Adds a new key-value pair to the instrument whose n_program is equal to the 132 | input ``n_program``. 133 | 134 | Parameters 135 | ---------- 136 | 137 | n_program: int 138 | 139 | field: str 140 | 141 | value: Union[str, int, float] 142 | """ 143 | self.__check_n_progr(n_program) 144 | for instr in self.json["instruments"]: 145 | if n_program != instr["n_prog"]: 146 | continue 147 | instr.update({str(field): value}) 148 | 149 | 150 | def delete_instrument_field(): 151 | pass 152 | 153 | def __check_n_progr(self, n_program: int): 154 | """ 155 | Checks if the input ``n_program`` is in the current json. 156 | 157 | Parameters 158 | ---------- 159 | 160 | n_program: int 161 | The program number corresponding to the instrument. 162 | 163 | Raises 164 | ------ 165 | 166 | ValueError: If the input ``n_program`` is not in the current json. 167 | """ 168 | progrs = [instr["n_prog"] for instr in self.json["instruments"]] 169 | # check if n_prog exists in the current json 170 | if n_program not in progrs: 171 | raise ValueError(f"The input n_program {n_program} is not in the current json. The n_programs of the instruments in the current json are {progrs}.") 172 | 173 | def add_bar_field(): 174 | raise NotImplementedError 175 | 176 | def delete_bar_field(): 177 | raise NotImplementedError 178 | 179 | def add_note_field(): 180 | raise NotImplementedError 181 | 182 | def delete_note_field(): 183 | raise NotImplementedError 184 | 185 | def add_header_field(): 186 | raise NotImplementedError 187 | 188 | def delete_header_field(): 189 | raise NotImplementedError 190 | 191 | 192 | class JSONMusa: 193 | pass 194 | -------------------------------------------------------------------------------- /musicaiz/converters/pretty_midi_musa.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import pretty_midi as pm 3 | 4 | 5 | from musicaiz.structure import NoteClassBase 6 | 7 | 8 | def prettymidi_note_to_musicaiz(note: str) -> Tuple[str, int]: 9 | octave = int("".join(filter(str.isdigit, note))) 10 | # Get the note name without the octave 11 | note_name = note.replace(str(octave), "") 12 | musa_note_name = NoteClassBase.get_note_with_name(note_name) 13 | return musa_note_name.name, octave 14 | 15 | 16 | def musicaiz_note_to_prettymidi( 17 | note: str, 18 | octave: int 19 | ) -> str: 20 | """ 21 | >>> note = "F_SHARP" 22 | >>> octave = 3 23 | >>> pm_note = musicaiz_note_to_prettymidi(note, octave) 24 | 'F#3' 25 | """ 26 | note_name = note.replace("SHARP", "#") 27 | note_name = note_name.replace("FLAT", "b") 28 | note_name = note_name.replace("_", "") 29 | pm_note = note_name + str(octave) 30 | return pm_note 31 | 32 | 33 | def musa_to_prettymidi(musa_obj): 34 | """ 35 | Converts a Musa object into a PrettyMIDI object. 36 | 37 | Returns 38 | ------- 39 | 40 | midi: PrettyMIDI 41 | The pretty_midi object. 42 | """ 43 | # TODO: Write also metadata in PrettyMIDI object: pitch bends..
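    # pretty_midi expects event times in seconds; the time-signature changes
    # below are converted from the Musa object's millisecond timestamps (ms / 1000).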
44 | midi = pm.PrettyMIDI( 45 | resolution=musa_obj.resolution, 46 | initial_tempo=musa_obj.tempo_changes[0]["tempo"] 47 | ) 48 | midi.time_signature_changes = [] 49 | for ts in musa_obj.time_signature_changes: 50 | midi.time_signature_changes.append( 51 | pm.TimeSignature( 52 | numerator=ts["time_sig"].num, 53 | denominator=ts["time_sig"].denom, 54 | time=ts["ms"] / 1000 55 | ) 56 | ) 57 | # TODO: Get ticks for each event (see Mido) 58 | midi._tick_scales = [ 59 | (0, 60.0 / (musa_obj.tempo_changes[0]["tempo"] * midi.resolution)) 60 | ] 61 | 62 | for i, inst in enumerate(musa_obj.instruments): 63 | midi.instruments.append( 64 | pm.Instrument( 65 | program=inst.program, 66 | is_drum=inst.is_drum, 67 | name=inst.name 68 | ) 69 | ) 70 | notes = musa_obj.get_notes_in_bars( 71 | bar_start=0, bar_end=musa_obj.total_bars, 72 | program=int(inst.program) 73 | ) 74 | for note in notes: 75 | midi.instruments[i].notes.append( 76 | pm.Note( 77 | velocity=note.velocity, 78 | pitch=note.pitch, 79 | start=note.start_sec, 80 | end=note.end_sec 81 | ) 82 | ) 83 | return midi 84 | -------------------------------------------------------------------------------- /musicaiz/converters/protobuf/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Installing protobuf and `protoc` command 3 | ## Windows 4 | 5 | 1. Go to [protobuf releases](https://github.com/protocolbuffers/protobuf/releases/) and download the `...winXX.zip` 6 | 2. Add the path to the `protoc.exe` to the env vars (or `set PATH=`) 7 | 8 | ## Linux 9 | 10 | 11 | ## Generate the .py files from .proto files 12 | 13 | From the root path of this library run: 14 | ``` 15 | protoc musicaiz/converters/protobuf/musicaiz.proto --python_out=. --pyi_out=. 16 | ``` 17 | 18 | ## Try it out 19 | 20 | ```` 21 | import musicaiz 22 | midi = musicaiz.loaders.Musa("tests/fixtures/midis/mz_332_1.mid", structure="bars") 23 | midi.to_proto() 24 | ```` -------------------------------------------------------------------------------- /musicaiz/converters/protobuf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carlosholivan/musicaiz/70f95854a3777b0323ed47f5a0822cf71eb96a70/musicaiz/converters/protobuf/__init__.py -------------------------------------------------------------------------------- /musicaiz/converters/protobuf/musicaiz.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package musicaiz; 4 | 5 | // protobuf of Musa object initialized with the arg `structure="bars"` 6 | // protoc musicaiz/converters/protobuf/musicaiz.proto --python_out=. --pyi_out=.
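// Notes are stored flat rather than nested inside bars: each Note carries
// bar_idx, beat_idx and subbeat_idx fields that index into the bars, beats
// and subbeats lists.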
7 | 8 | message Musa { 9 | 10 | repeated TimeSignatureChanges time_signature_changes = 5; 11 | repeated SubdivisionNote subdivision_note = 6; 12 | repeated File file = 7; 13 | repeated TotalBars total_bars = 8; 14 | repeated Tonality tonality = 9; 15 | repeated Resolution resolution = 10; 16 | repeated IsQuantized is_quantized = 11; 17 | repeated QuantizerArgs quantizer_args = 12; 18 | repeated AbsoluteTiming absolute_timing = 13; 19 | repeated CutNotes cut_notes = 14; 20 | repeated TempoChanges tempo_changes = 15; 21 | repeated InstrumentsProgs instruments_progs = 16; 22 | repeated Instrument instruments = 17; 23 | repeated Bar bars = 18; 24 | repeated Note notes = 19; 25 | repeated Beat beats = 20; 26 | repeated Subbeat subbeats = 21; 27 | 28 | message TimeSignatureChanges {} 29 | message SubdivisionNote {} 30 | message File {} 31 | message TotalBars {} 32 | message Tonality {} 33 | message Resolution {} 34 | message IsQuantized {} 35 | message QuantizerArgs {} 36 | message AbsoluteTiming {} 37 | message CutNotes {} 38 | message TempoChanges {} 39 | message InstrumentsProgs {} 40 | 41 | message Instrument { 42 | // Instrument index. 43 | int32 instrument = 1; 44 | // The program number of the instrument. 45 | int32 program = 2; 46 | // The name of the instrument. 47 | string name = 3; 48 | // The instrument's family. 49 | string family = 4; 50 | bool is_drum = 5; 51 | } 52 | 53 | message Note { 54 | int32 pitch = 1; 55 | string pitch_name = 2; 56 | string note_name = 3; 57 | string octave = 4; 58 | bool ligated = 5; 59 | 60 | // Timing info of the Note 61 | int32 start_ticks = 6; 62 | int32 end_ticks = 7; 63 | float start_sec = 8; 64 | float end_sec = 9; 65 | string symbolic = 10; 66 | 67 | int32 velocity = 11; 68 | 69 | int32 bar_idx = 12; 70 | int32 beat_idx = 13; 71 | int32 subbeat_idx = 14; 72 | 73 | int32 instrument_idx = 15; 74 | int32 instrument_prog = 16; 75 | } 76 | 77 | message Bar { 78 | float bpm = 1; 79 | string time_sig = 2; 80 | int32 resolution = 3; 81 | bool absolute_timing = 4; 82 | 83 | // Timing info of the Bar 84 | int32 note_density = 5; 85 | int32 harmonic_density = 6; 86 | int32 start_ticks = 7; 87 | int32 end_ticks = 8; 88 | float start_sec = 9; 89 | float end_sec = 10; 90 | } 91 | 92 | message Beat { 93 | float bpm = 1; 94 | string time_sig = 2; 95 | int32 resolution = 3; 96 | bool absolute_timing = 4; 97 | 98 | // Timing 99 | int32 start_ticks = 7; 100 | int32 end_ticks = 8; 101 | float start_sec = 9; 102 | float end_sec = 10; 103 | 104 | int32 global_idx = 11; 105 | int32 bar_idx = 12; 106 | } 107 | 108 | message Subbeat { 109 | float bpm = 1; 110 | string time_sig = 2; 111 | int32 resolution = 3; 112 | bool absolute_timing = 4; 113 | 114 | // Timing info of the Subbeat 115 | int32 note_density = 5; 116 | int32 harmonic_density = 6; 117 | int32 start_ticks = 7; 118 | int32 end_ticks = 8; 119 | float start_sec = 9; 120 | float end_sec = 10; 121 | 122 | int32 global_idx = 11; 123 | int32 bar_idx = 12; 124 | int32 beat_idx = 13; 125 | } 126 | } -------------------------------------------------------------------------------- /musicaiz/converters/protobuf/musicaiz_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT!
3 | # source: musicaiz/converters/protobuf/musicaiz.proto 4 | """Generated protocol buffer code.""" 5 | from google.protobuf.internal import builder as _builder 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import descriptor_pool as _descriptor_pool 8 | from google.protobuf import symbol_database as _symbol_database 9 | # @@protoc_insertion_point(imports) 10 | 11 | _sym_db = _symbol_database.Default() 12 | 13 | 14 | 15 | 16 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n+musicaiz/converters/protobuf/musicaiz.proto\x12\x08musicaiz\"\xa8\x10\n\x04Musa\x12\x43\n\x16time_signature_changes\x18\x05 \x03(\x0b\x32#.musicaiz.Musa.TimeSignatureChanges\x12\x38\n\x10subdivision_note\x18\x06 \x03(\x0b\x32\x1e.musicaiz.Musa.SubdivisionNote\x12!\n\x04\x66ile\x18\x07 \x03(\x0b\x32\x13.musicaiz.Musa.File\x12,\n\ntotal_bars\x18\x08 \x03(\x0b\x32\x18.musicaiz.Musa.TotalBars\x12)\n\x08tonality\x18\t \x03(\x0b\x32\x17.musicaiz.Musa.Tonality\x12-\n\nresolution\x18\n \x03(\x0b\x32\x19.musicaiz.Musa.Resolution\x12\x30\n\x0cis_quantized\x18\x0b \x03(\x0b\x32\x1a.musicaiz.Musa.IsQuantized\x12\x34\n\x0equantizer_args\x18\x0c \x03(\x0b\x32\x1c.musicaiz.Musa.QuantizerArgs\x12\x36\n\x0f\x61\x62solute_timing\x18\r \x03(\x0b\x32\x1d.musicaiz.Musa.AbsoluteTiming\x12*\n\tcut_notes\x18\x0e \x03(\x0b\x32\x17.musicaiz.Musa.CutNotes\x12\x32\n\rtempo_changes\x18\x0f \x03(\x0b\x32\x1b.musicaiz.Musa.TempoChanges\x12:\n\x11instruments_progs\x18\x10 \x03(\x0b\x32\x1f.musicaiz.Musa.InstrumentsProgs\x12.\n\x0binstruments\x18\x11 \x03(\x0b\x32\x19.musicaiz.Musa.Instrument\x12 \n\x04\x62\x61rs\x18\x12 \x03(\x0b\x32\x12.musicaiz.Musa.Bar\x12\"\n\x05notes\x18\x13 \x03(\x0b\x32\x13.musicaiz.Musa.Note\x12\"\n\x05\x62\x65\x61ts\x18\x14 \x03(\x0b\x32\x13.musicaiz.Musa.Beat\x12(\n\x08subbeats\x18\x15 \x03(\x0b\x32\x16.musicaiz.Musa.Subbeat\x1a\x16\n\x14TimeSignatureChanges\x1a\x11\n\x0fSubdivisionNote\x1a\x06\n\x04\x46ile\x1a\x0b\n\tTotalBars\x1a\n\n\x08Tonality\x1a\x0c\n\nResolution\x1a\r\n\x0bIsQuantized\x1a\x0f\n\rQuantizerArgs\x1a\x10\n\x0e\x41\x62soluteTiming\x1a\n\n\x08\x43utNotes\x1a\x0e\n\x0cTempoChanges\x1a\x12\n\x10InstrumentsProgs\x1a`\n\nInstrument\x12\x12\n\ninstrument\x18\x01 \x01(\x05\x12\x0f\n\x07program\x18\x02 \x01(\x05\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\x0e\n\x06\x66\x61mily\x18\x04 \x01(\t\x12\x0f\n\x07is_drum\x18\x05 \x01(\x08\x1a\xb6\x02\n\x04Note\x12\r\n\x05pitch\x18\x01 \x01(\x05\x12\x12\n\npitch_name\x18\x02 \x01(\t\x12\x11\n\tnote_name\x18\x03 \x01(\t\x12\x0e\n\x06octave\x18\x04 \x01(\t\x12\x0f\n\x07ligated\x18\x05 \x01(\x08\x12\x13\n\x0bstart_ticks\x18\x06 \x01(\x05\x12\x11\n\tend_ticks\x18\x07 \x01(\x05\x12\x11\n\tstart_sec\x18\x08 \x01(\x02\x12\x0f\n\x07\x65nd_sec\x18\t \x01(\x02\x12\x10\n\x08symbolic\x18\n \x01(\t\x12\x10\n\x08velocity\x18\x0b \x01(\x05\x12\x0f\n\x07\x62\x61r_idx\x18\x0c \x01(\x05\x12\x10\n\x08\x62\x65\x61t_idx\x18\r \x01(\x05\x12\x13\n\x0bsubbeat_idx\x18\x0e \x01(\x05\x12\x16\n\x0einstrument_idx\x18\x0f \x01(\x05\x12\x17\n\x0finstrument_prog\x18\x10 \x01(\x05\x1a\xcd\x01\n\x03\x42\x61r\x12\x0b\n\x03\x62pm\x18\x01 \x01(\x02\x12\x10\n\x08time_sig\x18\x02 \x01(\t\x12\x12\n\nresolution\x18\x03 \x01(\x05\x12\x17\n\x0f\x61\x62solute_timing\x18\x04 \x01(\x08\x12\x14\n\x0cnote_density\x18\x05 \x01(\x05\x12\x18\n\x10harmonic_density\x18\x06 \x01(\x05\x12\x13\n\x0bstart_ticks\x18\x07 \x01(\x05\x12\x11\n\tend_ticks\x18\x08 \x01(\x05\x12\x11\n\tstart_sec\x18\t \x01(\x02\x12\x0f\n\x07\x65nd_sec\x18\n 
\x01(\x02\x1a\xc3\x01\n\x04\x42\x65\x61t\x12\x0b\n\x03\x62pm\x18\x01 \x01(\x02\x12\x10\n\x08time_sig\x18\x02 \x01(\t\x12\x12\n\nresolution\x18\x03 \x01(\x05\x12\x17\n\x0f\x61\x62solute_timing\x18\x04 \x01(\x08\x12\x13\n\x0bstart_ticks\x18\x07 \x01(\x05\x12\x11\n\tend_ticks\x18\x08 \x01(\x05\x12\x11\n\tstart_sec\x18\t \x01(\x02\x12\x0f\n\x07\x65nd_sec\x18\n \x01(\x02\x12\x12\n\nglobal_idx\x18\x0b \x01(\x05\x12\x0f\n\x07\x62\x61r_idx\x18\x0c \x01(\x05\x1a\x88\x02\n\x07Subbeat\x12\x0b\n\x03\x62pm\x18\x01 \x01(\x02\x12\x10\n\x08time_sig\x18\x02 \x01(\t\x12\x12\n\nresolution\x18\x03 \x01(\x05\x12\x17\n\x0f\x61\x62solute_timing\x18\x04 \x01(\x08\x12\x14\n\x0cnote_density\x18\x05 \x01(\x05\x12\x18\n\x10harmonic_density\x18\x06 \x01(\x05\x12\x13\n\x0bstart_ticks\x18\x07 \x01(\x05\x12\x11\n\tend_ticks\x18\x08 \x01(\x05\x12\x11\n\tstart_sec\x18\t \x01(\x02\x12\x0f\n\x07\x65nd_sec\x18\n \x01(\x02\x12\x12\n\nglobal_idx\x18\x0b \x01(\x05\x12\x0f\n\x07\x62\x61r_idx\x18\x0c \x01(\x05\x12\x10\n\x08\x62\x65\x61t_idx\x18\r \x01(\x05\x62\x06proto3') 17 | 18 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) 19 | _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'musicaiz.converters.protobuf.musicaiz_pb2', globals()) 20 | if _descriptor._USE_C_DESCRIPTORS == False: 21 | 22 | DESCRIPTOR._options = None 23 | _MUSA._serialized_start=58 24 | _MUSA._serialized_end=2146 25 | _MUSA_TIMESIGNATURECHANGES._serialized_start=876 26 | _MUSA_TIMESIGNATURECHANGES._serialized_end=898 27 | _MUSA_SUBDIVISIONNOTE._serialized_start=900 28 | _MUSA_SUBDIVISIONNOTE._serialized_end=917 29 | _MUSA_FILE._serialized_start=919 30 | _MUSA_FILE._serialized_end=925 31 | _MUSA_TOTALBARS._serialized_start=927 32 | _MUSA_TOTALBARS._serialized_end=938 33 | _MUSA_TONALITY._serialized_start=940 34 | _MUSA_TONALITY._serialized_end=950 35 | _MUSA_RESOLUTION._serialized_start=952 36 | _MUSA_RESOLUTION._serialized_end=964 37 | _MUSA_ISQUANTIZED._serialized_start=966 38 | _MUSA_ISQUANTIZED._serialized_end=979 39 | _MUSA_QUANTIZERARGS._serialized_start=981 40 | _MUSA_QUANTIZERARGS._serialized_end=996 41 | _MUSA_ABSOLUTETIMING._serialized_start=998 42 | _MUSA_ABSOLUTETIMING._serialized_end=1014 43 | _MUSA_CUTNOTES._serialized_start=1016 44 | _MUSA_CUTNOTES._serialized_end=1026 45 | _MUSA_TEMPOCHANGES._serialized_start=1028 46 | _MUSA_TEMPOCHANGES._serialized_end=1042 47 | _MUSA_INSTRUMENTSPROGS._serialized_start=1044 48 | _MUSA_INSTRUMENTSPROGS._serialized_end=1062 49 | _MUSA_INSTRUMENT._serialized_start=1064 50 | _MUSA_INSTRUMENT._serialized_end=1160 51 | _MUSA_NOTE._serialized_start=1163 52 | _MUSA_NOTE._serialized_end=1473 53 | _MUSA_BAR._serialized_start=1476 54 | _MUSA_BAR._serialized_end=1681 55 | _MUSA_BEAT._serialized_start=1684 56 | _MUSA_BEAT._serialized_end=1879 57 | _MUSA_SUBBEAT._serialized_start=1882 58 | _MUSA_SUBBEAT._serialized_end=2146 59 | # @@protoc_insertion_point(module_scope) 60 | -------------------------------------------------------------------------------- /musicaiz/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Datasets 3 | ======== 4 | 5 | This submodule presents helper functions to works and process MIR datasets for 6 | different tasks such as: 7 | 8 | - Automatic Chord Recognition: ACR 9 | - Music Generation or Composition 10 | 11 | .. note:: 12 | Not all the datasets shown here are included in `musicaiz` souce code. 13 | Some of these datasets have their own GitHub repository with their corresponding helper functions. 
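    All the helpers included here share the same high-level workflow: instantiate
    the dataset class and call its ``tokenize`` method. A minimal sketch of that
    workflow (the paths are placeholders, ``Maestro`` stands for any of the helper
    classes listed below, and the tokenizer arguments mirror the examples in each
    class's docs):

    >>> from musicaiz.datasets import Maestro
    >>> from musicaiz.tokenizers import MMMTokenizerArguments
    >>> args = MMMTokenizerArguments(
    >>>     prev_tokens="",
    >>>     windowing=True,
    >>>     time_unit="HUNDRED_TWENTY_EIGHT",
    >>>     num_programs=None,
    >>>     shuffle_tracks=True,
    >>>     track_density=False,
    >>>     window_size=32,
    >>>     hop_length=16,
    >>>     time_sig=True,
    >>>     velocity=True,
    >>> )
    >>> dataset = Maestro()
    >>> dataset.tokenize(
    >>>     dataset_path="path/to/dataset",
    >>>     output_path="output/path",
    >>>     tokenize_split="all",
    >>>     args=args,
    >>> )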
14 | 15 | 16 | 17 | Music Composition or Generation 18 | ------------------------------- 19 | 20 | - JS Bach Chorales 21 | 22 | Boulanger-Lewandowski, N., Bengio, Y., & Vincent, P. (2012). 23 | Modeling temporal dependencies in high-dimensional sequences: 24 | Application to polyphonic music generation and transcription. 25 | arXiv preprint arXiv:1206.6392. 26 | 27 | | Download: http://www-ens.iro.umontreal.ca/~boulanni/icml2012 28 | | Paper: https://arxiv.org/abs/1206.6392 29 | 30 | 31 | - JS Bach Fakes 32 | 33 | Peracha, O. (2021). 34 | Js fake chorales: a synthetic dataset of polyphonic music with human annotation. 35 | arXiv preprint arXiv:2107.10388. 36 | 37 | | Paper: https://arxiv.org/abs/2107.10388 38 | | Repository & Download: https://github.com/omarperacha/js-fakes 39 | 40 | 41 | - Lakh MIDI Dataset (LMD) 42 | 43 | Raffel, C. (2016). 44 | Learning-based methods for comparing sequences, with applications to audio-to-midi alignment and matching. 45 | Columbia University. 46 | 47 | | Thesis: http://colinraffel.com/publications/thesis.pdf 48 | | Download: https://colinraffel.com/projects/lmd/ 49 | 50 | 51 | - MAESTRO 52 | 53 | Hawthorne, C., Stasyuk, A., Roberts, A., Simon, I., Huang, C. Z. A., Dieleman, S., ... & Eck, D. (2018). 54 | Enabling factorized piano music modeling and generation with the MAESTRO dataset. arXiv preprint arXiv:1810.12247. 55 | 56 | | Download: https://magenta.tensorflow.org/datasets/maestro 57 | | Paper: https://arxiv.org/abs/1810.12247 58 | 59 | 60 | - Slakh2100 61 | 62 | Manilow, E., Wichern, G., Seetharaman, P., & Le Roux, J. (2019). 63 | Cutting music source separation some Slakh: A dataset to study the impact of training data quality and quantity. 64 | In 2019 IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (WASPAA) (pp. 45-49). IEEE. 65 | 66 | | Download link: https://zenodo.org/record/4599666#.YpD8ZO7P1PY 67 | | Paper: https://ieeexplore.ieee.org/abstract/document/8937170 68 | | Repository: https://github.com/ethman/Slakh 69 | 70 | 71 | - Meta-MIDI Dataset 72 | 73 | Ens, J., & Pasquier, P. (2021). 74 | Building the metamidi dataset: Linking symbolic and audio musical data. 75 | In Proceedings of 22st International Conference on Music Information Retrieval, ISMIR. 76 | 77 | | Download link: https://zenodo.org/record/5142664#.YpD8he7P1PY 78 | | Paper: https://archives.ismir.net/ismir2021/paper/000022.pdf 79 | | Repository: https://github.com/jeffreyjohnens/MetaMIDIDataset 80 | 81 | 82 | Automatic Chord Recognition 83 | --------------------------- 84 | 85 | - Schubert Winterreise Dataset 86 | 87 | Christof Weiß, Frank Zalkow, Vlora Arifi-Müller, Meinard Müller, Hendrik Vincent Koops, Anja Volk, & Harald G. Grohganz. 88 | (2020). Schubert Winterreise Dataset [Data set]. In ACM Journal on Computing and Cultural Heritage (1.0). 89 | 90 | | Download: https://doi.org/10.5281/zenodo.3968389 91 | | Paper: https://dl.acm.org/doi/10.1145/3429743 92 | 93 | .. autosummary:: 94 | :toctree: generated/ 95 | 96 | SWDPathsConfig 97 | SWD_FILES 98 | shubert_winterreise 99 | 100 | 101 | - BPS-FH Dataset 102 | 103 | Chen, T. P., & Su, L. (2018). 104 | Functional Harmony Recognition of Symbolic Music Data with Multi-task Recurrent Neural Networks. 105 | In Proceedings of 19th International Conference on Music Information Retrieval, ISMIR. pp. 90-97. 106 | 107 | | Repository & Download: https://github.com/Tsung-Ping/functional-harmony 108 | | Paper: http://ismir2018.ircam.fr/doc/pdfs/178_Paper.pdf 109 | 110 | .. 
autosummary:: 111 | :toctree: generated/ 112 | 113 | BPSFHPathsConfig 114 | 115 | """ 116 | 117 | from .configs import ( 118 | MusicGenerationDataset, 119 | ) 120 | from .jsbchorales import JSBChorales 121 | from .lmd import LakhMIDI 122 | from .maestro import Maestro 123 | from .bps_fh import BPSFH 124 | 125 | 126 | __all__ = [ 127 | "MusicGenerationDataset", 128 | "JSBChorales", 129 | "LakhMIDI", 130 | "Maestro", 131 | "BPSFH" 132 | ] 133 | -------------------------------------------------------------------------------- /musicaiz/datasets/bps_fh.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union, Dict, Any 2 | from pathlib import Path 3 | import pandas as pd 4 | import math 5 | import numpy as np 6 | 7 | from musicaiz.rhythm import TimeSignature 8 | from musicaiz.harmony import Tonality, AllChords 9 | from musicaiz.structure import NoteClassBase 10 | 11 | 12 | class BPSFH: 13 | 14 | TIME_SIGS = { 15 | "1": "4/4", 16 | "2": "2/4", 17 | "3": "4/4", 18 | "4": "6/8", 19 | "5": "3/4", 20 | "6": "2/4", 21 | "7": "4/4", 22 | "8": "4/4", 23 | "9": "4/4", 24 | "10": "2/4", 25 | "11": "4/4", 26 | "12": "3/8", 27 | "13": "4/4", 28 | "14": "4/4", 29 | "15": "3/4", 30 | "16": "2/4", 31 | "17": "4/4", 32 | "18": "3/4", 33 | "19": "2/4", 34 | "20": "4/4", 35 | "21": "4/4", 36 | "22": "3/4", 37 | "23": "12/8", 38 | "24": "2/4", 39 | "25": "3/4", 40 | "26": "2/4", 41 | "27": "3/4", 42 | "28": "6/8", 43 | "29": "4/4", 44 | "30": "2/4", 45 | "31": "3/4", 46 | "32": "4/4", 47 | } 48 | 49 | def __init__(self, path: Union[str, Path]): 50 | self.path = Path(path) 51 | 52 | def parse_anns( 53 | self, 54 | anns: str = "high" 55 | ) -> Dict[str, pd.DataFrame]: 56 | """ 57 | Converts the bar index float annotations in 8th notes. 
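        The returned dict maps each MIDI file stem to a ``pandas.DataFrame`` with
        one row per annotated section at the chosen level and three columns:
        start, end (both in the adjusted annotation time units) and section label.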
58 | """ 59 | table = {} 60 | 61 | for file in list(self.path.glob("*/*.mid")): 62 | filename = file.name.split(".")[0] 63 | table[file.name.split(".")[0]] = [] 64 | anns_path = Path(self.path, file.name.split(".")[0]) 65 | 66 | gt = pd.read_excel( 67 | Path(anns_path, "phrases.xlsx"), 68 | header=None 69 | ) 70 | 71 | # Read file with musanalysis to get tempo 72 | time_sig = TimeSignature(self.TIME_SIGS[filename]) 73 | 74 | # Loop in rows 75 | prev_sec = "" 76 | rows_ans = [] 77 | j = 0 78 | for i, row in gt.iterrows(): 79 | if anns == "high": 80 | sec_name = row[2] 81 | elif anns == "mid": 82 | sec_name = row[3] 83 | elif anns == "low": 84 | sec_name = row[4] 85 | if i == 0: 86 | if row[0] < 0: 87 | dec, quarters = math.modf(row[0]) 88 | if quarters == 0.0: 89 | summator = 1 90 | elif quarters == -1.0: 91 | summator = time_sig.num 92 | else: 93 | summator = 0 94 | if anns == "high": 95 | rows_ans.append([ann for k, ann in enumerate(row) if k <= 2]) 96 | elif anns == "mid": 97 | rows_ans.append([ann for k, ann in enumerate(row) if (k <= 1 or k == 3)]) 98 | elif anns == "low": 99 | rows_ans.append([ann for k, ann in enumerate(row) if (k <= 1 or k == 4)]) 100 | rows_ans[-1][0] += summator 101 | rows_ans[-1][1] += summator 102 | prev_sec = sec_name 103 | j += 1 104 | end_time = row[1] 105 | continue 106 | if sec_name != prev_sec: 107 | rows_ans[j - 1][1] = end_time + summator 108 | if anns == "high": 109 | rows_ans.append([ann for k, ann in enumerate(row) if k <= 2]) 110 | elif anns == "mid": 111 | rows_ans.append([ann for k, ann in enumerate(row) if (k <= 1 or k == 3)]) 112 | elif anns == "low": 113 | rows_ans.append([ann for k, ann in enumerate(row) if (k <= 1 or k == 4)]) 114 | rows_ans[-1][0] += summator 115 | rows_ans[-1][1] += summator 116 | prev_sec = sec_name 117 | j += 1 118 | if i == len(gt) - 1: 119 | rows_ans[-1][1] = row[1] + summator 120 | end_time = row[1] 121 | 122 | new_df = pd.DataFrame(columns=np.arange(3)) 123 | for i, r in enumerate(rows_ans): 124 | new_df.loc[i] = r 125 | 126 | table[file.name.split(".")[0]] = new_df 127 | 128 | return table 129 | 130 | @classmethod 131 | def bpsfh_key_to_musicaiz(cls, note: str) -> Tonality: 132 | alt = None 133 | if "-" in note: 134 | alt = "FLAT" 135 | note = note.split("-")[0] 136 | elif "+" in note: 137 | alt = "SHARP" 138 | note = note.split("+")[0] 139 | if note.isupper(): 140 | mode = "MAJOR" 141 | else: 142 | mode = "MINOR" 143 | note = note.capitalize() 144 | if alt is None: 145 | tonality = Tonality[note + "_" + mode] 146 | else: 147 | tonality = Tonality[note + "_" + alt + "_" + mode] 148 | return tonality 149 | 150 | @classmethod 151 | def bpsfh_chord_quality_to_musicaiz(cls, quality: str) -> AllChords: 152 | if quality == "M": 153 | q = "MAJOR_TRIAD" 154 | elif quality == "m": 155 | q = "MINOR_TRIAD" 156 | elif quality == "M7": 157 | q = "MAJOR_SEVENTH" 158 | elif quality == "m7": 159 | q = "MINOR_SEVENTH" 160 | elif quality == "D7": 161 | q = "DOMINANT_SEVENTH" 162 | elif quality == "a": 163 | q = "AUGMENTED_TRIAD" 164 | return AllChords[q] 165 | 166 | @classmethod 167 | def bpsfh_chord_to_musicaiz( 168 | cls, 169 | note: str, 170 | degree: int, 171 | quality: str, 172 | ) -> Tuple[NoteClassBase, AllChords]: 173 | tonality = cls.bpsfh_key_to_musicaiz(note) 174 | qt = cls.bpsfh_chord_quality_to_musicaiz(quality) 175 | notes = tonality.notes 176 | return notes[degree - 1], qt 177 | -------------------------------------------------------------------------------- /musicaiz/datasets/configs.py: 
-------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from enum import Enum 3 | from typing import List, Dict, Type 4 | from abc import ABCMeta 5 | 6 | from musicaiz.tokenizers import TokenizerArguments 7 | from musicaiz.datasets.utils import tokenize_path 8 | 9 | 10 | TOKENIZE_VALID_SPLITS = [ 11 | "train", 12 | "validation", 13 | "test", 14 | "all", 15 | ] 16 | 17 | 18 | class MusicGenerationDatasetNames(Enum): 19 | MAESTRO = ["maestro"] 20 | LAKH_MIDI = ["lakh_midi_dataset", "lakh_midi", "lmd"] 21 | JSB_CHORALES = ["jsbchorales", "jsb_chorales", "bach_chorales"] 22 | 23 | @classmethod 24 | def all_values(cls) -> List[str]: 25 | all = [] 26 | for n in cls.__members__.values(): 27 | for name in n.value: 28 | all.append(name) 29 | return all 30 | 31 | 32 | class MusicGenerationDataset(metaclass=ABCMeta): 33 | 34 | def _prepare_tokenize( 35 | self, 36 | dataset_path: str, 37 | output_path: str, 38 | output_file: str, 39 | metadata: Dict[str, str], 40 | tokenize_split: str, 41 | args: Type[TokenizerArguments], 42 | dirs_splitted: bool, 43 | ) -> None: 44 | 45 | """Depending on the args that are passed to this method, the 46 | tokenizer selected will be the one of the available tokenizers that 47 | matches the args object. 48 | The currently available tokenizers are: :const:`~musicaiz.tokenizers.constants.TOKENIZER_ARGUMENTS` 49 | """ 50 | 51 | # make the dirs to store the token sequences separated in 52 | # train, validation and test 53 | dest_train_path = Path(output_path, "train") 54 | dest_train_path.mkdir(parents=True, exist_ok=True) 55 | 56 | dest_val_path = Path(output_path, "validation") 57 | dest_val_path.mkdir(parents=True, exist_ok=True) 58 | 59 | dest_test_path = Path(output_path, "test") 60 | dest_test_path.mkdir(parents=True, exist_ok=True) 61 | 62 | if metadata is not None: 63 | # Split metadata in train, validation and test 64 | train_metadata, val_metadata, test_metadata = {}, {}, {} 65 | for key, val in metadata.items(): 66 | if val["split"] == "train": 67 | train_metadata.update({key: val}) 68 | elif val["split"] == "validation": 69 | val_metadata.update({key: val}) 70 | elif val["split"] == "test": 71 | test_metadata.update({key: val}) 72 | else: 73 | continue 74 | else: 75 | train_metadata, val_metadata, test_metadata = None, None, None 76 | 77 | if dirs_splitted: 78 | data_train_path = Path(dataset_path, "train") 79 | data_val_path = Path(dataset_path, "valid") 80 | data_test_path = Path(dataset_path, "test") 81 | else: 82 | data_train_path = dataset_path 83 | data_val_path = dataset_path 84 | data_test_path = dataset_path 85 | 86 | if tokenize_split not in TOKENIZE_VALID_SPLITS: 87 | raise ValueError(f"tokenize_split must be one of the following: {[f for f in TOKENIZE_VALID_SPLITS]}") 88 | if tokenize_split == "train" or tokenize_split == "all": 89 | tokenize_path(data_train_path, dest_train_path, train_metadata, output_file, args) 90 | if tokenize_split == "validation" or tokenize_split == "all": 91 | tokenize_path(data_val_path, dest_val_path, val_metadata, output_file, args) 92 | if tokenize_split == "test" or tokenize_split == "all": 93 | tokenize_path(data_test_path, dest_test_path, test_metadata, output_file, args) 94 | 95 | # save configs json 96 | TokenizerArguments.save(args, output_path) 97 | -------------------------------------------------------------------------------- /musicaiz/datasets/jsbchorales.py: -------------------------------------------------------------------------------- 1 | from pathlib
import Path 2 | from typing import List, Union 3 | 4 | from musicaiz.datasets.configs import ( 5 | MusicGenerationDataset, 6 | MusicGenerationDatasetNames 7 | ) 8 | from musicaiz.tokenizers import ( 9 | MMMTokenizer, 10 | MMMTokenizerArguments, 11 | ) 12 | 13 | 14 | class JSBChorales(MusicGenerationDataset): 15 | """ 16 | """ 17 | def __init__(self): 18 | self.name = MusicGenerationDatasetNames.JSB_CHORALES.name.lower() 19 | 20 | def tokenize( 21 | self, 22 | dataset_path: str, 23 | output_path: str, 24 | tokenize_split: str, 25 | args: MMMTokenizerArguments, 26 | output_file: str = "token-sequences", 27 | ) -> None: 28 | """ 29 | 30 | Parameters 31 | ---------- 32 | 33 | dataset_path (str): _description_ 34 | 35 | output_path (str): _description_ 36 | 37 | tokenize_split (str): _description_ 38 | 39 | args (Type[TokenizerArguments]): _description_ 40 | 41 | output_file (str, optional): _description_. Defaults to "token-sequences". 42 | 43 | Examples 44 | -------- 45 | 46 | >>> # initialize tokenizer args 47 | >>> args = MMMTokenizerArguments( 48 | >>> prev_tokens="", 49 | >>> windowing=True, 50 | >>> time_unit="HUNDRED_TWENTY_EIGHT", 51 | >>> num_programs=None, 52 | >>> shuffle_tracks=True, 53 | >>> track_density=False, 54 | >>> window_size=32, 55 | >>> hop_length=16, 56 | >>> time_sig=True, 57 | >>> velocity=True, 58 | >>> ) 59 | >>> # initialize dataset 60 | >>> dataset = JSBChorales() 61 | >>> dataset.tokenize( 62 | >>> dataset_path="path/to/dataset", 63 | >>> output_path="output/path", 64 | >>> output_file="token-sequences", 65 | >>> args=args, 66 | >>> tokenize_split="test" 67 | >>> ) 68 | >>> # get vocabulary and save it in `dataset_path` 69 | >>> vocab = MMMTokenizer.get_vocabulary( 70 | >>> dataset_path="output/path" 71 | >>> ) 72 | """ 73 | 74 | self._prepare_tokenize( 75 | dataset_path, 76 | output_path, 77 | output_file, 78 | None, 79 | tokenize_split, 80 | args, 81 | True 82 | ) 83 | 84 | 85 | # TODO: args parsing here 86 | if __name__ == "__main__": 87 | args = MMMTokenizerArguments( 88 | prev_tokens="", 89 | windowing=True, 90 | time_unit="HUNDRED_TWENTY_EIGHT", 91 | num_programs=None, 92 | shuffle_tracks=True, 93 | track_density=False, 94 | window_size=32, 95 | hop_length=16, 96 | time_sig=True, 97 | velocity=True, 98 | ) 99 | dataset = JSBChorales() 100 | dataset.tokenize( 101 | dataset_path="H:/INVESTIGACION/Datasets/JSB Chorales/", 102 | output_path="H:/GitHub/musanalysis-datasets/jsbchorales/mmm/32_bars_166", 103 | output_file="token-sequences", 104 | args=args, 105 | tokenize_split="validation" 106 | ) 107 | vocab = MMMTokenizer.get_vocabulary( 108 | dataset_path="H:/GitHub/musanalysis-datasets/jsbchorales/mmm/32_bars_166" 109 | ) 110 | -------------------------------------------------------------------------------- /musicaiz/datasets/lmd.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Union, Dict, Type 3 | 4 | from musicaiz.datasets.configs import ( 5 | MusicGenerationDataset, 6 | MusicGenerationDatasetNames 7 | ) 8 | from musicaiz.tokenizers import ( 9 | MMMTokenizer, 10 | MMMTokenizerArguments, 11 | TokenizerArguments 12 | ) 13 | 14 | 15 | class LakhMIDI(MusicGenerationDataset): 16 | """ 17 | """ 18 | 19 | def __init__(self): 20 | self.name = MusicGenerationDatasetNames.LAKH_MIDI.name.lower() 21 | 22 | def tokenize( 23 | self, 24 | dataset_path: Union[Path, str], 25 | output_path: Union[Path, str], 26 | tokenize_split: str, 27 | args: Type[TokenizerArguments], 28 | output_file: 
str = "token-sequences", 29 | train_split: float = 0.7, 30 | test_split: float = 0.2 31 | ) -> None: 32 | """ 33 | 34 | Parameters 35 | ---------- 36 | 37 | dataset_path (str): _description_ 38 | 39 | output_path (str): _description_ 40 | 41 | tokenize_split (str): _description_ 42 | 43 | args (Type[TokenizerArguments]): _description_ 44 | 45 | output_file (str, optional): _description_. Defaults to "token-sequences". 46 | 47 | Examples 48 | -------- 49 | 50 | >>> # initialize tokenizer args 51 | >>> args = MMMTokenizerArguments( 52 | >>> prev_tokens="", 53 | >>> windowing=True, 54 | >>> time_unit="HUNDRED_TWENTY_EIGHT", 55 | >>> num_programs=None, 56 | >>> shuffle_tracks=True, 57 | >>> track_density=False, 58 | >>> window_size=32, 59 | >>> hop_length=16, 60 | >>> time_sig=True, 61 | >>> velocity=True, 62 | >>> ) 63 | >>> # initialize dataset 64 | >>> dataset = LakhMIDI() 65 | >>> dataset.tokenize( 66 | >>> dataset_path="path/to/dataset", 67 | >>> output_path="output/path", 68 | >>> output_file="token-sequences", 69 | >>> args=args, 70 | >>> tokenize_split="all" 71 | >>> ) 72 | >>> # get vocabulary and save it in `dataset_path` 73 | >>> vocab = MMMTokenizer.get_vocabulary( 74 | >>> dataset_path="output/path" 75 | >>> ) 76 | """ 77 | metadata = self.get_metadata( 78 | dataset_path, 79 | train_split, 80 | test_split 81 | ) 82 | self._prepare_tokenize( 83 | dataset_path, 84 | output_path, 85 | output_file, 86 | metadata, 87 | tokenize_split, 88 | args, 89 | False 90 | ) 91 | 92 | @staticmethod 93 | def get_metadata( 94 | dataset_path: str, 95 | train_split: float = 0.7, 96 | test_split: float = 0.2 97 | ) -> Dict[str, str]: 98 | """ 99 | 100 | Args: 101 | dataset_path (str): _description_ 102 | train_split (float): _description_ 103 | test_split (float): _description_ 104 | 105 | validation split is automatically calculated as: 106 | 1 - train_split - test_split 107 | 108 | Returns: 109 | _type_: _description_ 110 | """ 111 | 112 | if isinstance(dataset_path, str): 113 | dataset_path = Path(dataset_path) 114 | 115 | composers_json = {} 116 | # iterate over subdirs which are different artists 117 | for composer_path in dataset_path.glob("*/"): 118 | # 1. Process composer 119 | composer = composer_path.stem 120 | # Some composers are written with 2 different composers separated by "/" 121 | # we'll only consider the 1st one 122 | composer = composer.replace(" ", "_") 123 | composer = composer.upper() 124 | 125 | # iterate over songs of an artist 126 | songs = [f for f in composer_path.glob("*/")] 127 | n_songs = len(songs) 128 | 129 | train_idxs = int(round(n_songs * train_split)) 130 | val_idxs = int(n_songs * test_split) 131 | 132 | train_seqs = songs[:train_idxs] 133 | val_seqs = songs[train_idxs:val_idxs+train_idxs] 134 | 135 | 136 | # split in train, validation and test 137 | # we do this here to ensure that every artist is at least in 138 | # the training and test sets (if n_songs > 1) 139 | for song in songs: 140 | #----------------- 141 | period = "" 142 | 143 | # 3. 
Process canonical genre 144 | genre = "" 145 | 146 | split = "" 147 | 148 | if song in train_seqs: 149 | split = "train" 150 | elif song in val_seqs: 151 | split = "validation" 152 | else: 153 | split = "test" 154 | 155 | composers_json.update( 156 | { 157 | composer_path.stem + "/" + song.name: { 158 | "composer": composer, 159 | "period": period, 160 | "genre": genre, 161 | "split": split 162 | } 163 | } 164 | ) 165 | return composers_json 166 | 167 | 168 | # TODO: args parsing here 169 | if __name__ == "__main__": 170 | args = MMMTokenizerArguments( 171 | prev_tokens="", 172 | windowing=True, 173 | time_unit="HUNDRED_TWENTY_EIGHT", 174 | num_programs=None, 175 | shuffle_tracks=True, 176 | track_density=False, 177 | window_size=32, 178 | hop_length=16, 179 | time_sig=True, 180 | velocity=True, 181 | ) 182 | dataset = LakhMIDI() 183 | dataset.tokenize( 184 | dataset_path="H:/INVESTIGACION/Datasets/LMD/clean_midi", 185 | output_path="H:/GitHub/musanalysis-datasets/lmd/mmm/32_bars_166", 186 | output_file="token-sequences", 187 | args=args, 188 | tokenize_split="validation" 189 | ) 190 | vocab = MMMTokenizer.get_vocabulary( 191 | dataset_path="H:/GitHub/musanalysis-datasets/lmd/mmm/32_bars_166" 192 | ) 193 | -------------------------------------------------------------------------------- /musicaiz/datasets/utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import multiprocessing as mp 3 | from tqdm import tqdm 4 | from typing import List, Optional, Dict, Type 5 | from rich.console import Console 6 | 7 | 8 | from musicaiz.tokenizers import ( 9 | MMMTokenizer, 10 | MMMTokenizerArguments, 11 | REMITokenizer, 12 | REMITokenizerArguments, 13 | TOKENIZER_ARGUMENTS, 14 | TokenizerArguments, 15 | ) 16 | 17 | 18 | def tokenize_path( 19 | dataset_path: str, 20 | dest_path: str, 21 | metadata: Optional[Dict], 22 | output_file: str, 23 | args: Type[TokenizerArguments], 24 | ) -> None: 25 | 26 | text_file = open(Path(dest_path, output_file + ".txt"), "w") 27 | 28 | n_jobs = mp.cpu_count() 29 | pool = mp.Pool(n_jobs) 30 | results = [] 31 | 32 | if metadata is not None: 33 | elements = metadata.keys() 34 | total = len(list(metadata.keys())) 35 | else: 36 | elements = dataset_path.rglob("*.mid") 37 | elements = [f.name for f in dataset_path.rglob("*.mid")] 38 | total = len(elements) 39 | 40 | for el in elements: 41 | # Some files in LMD have errors (OSError: data byte must be in range 0..127), 42 | # so we avoid parsing those files 43 | results.append( 44 | { 45 | "result": pool.apply_async( 46 | _processer, 47 | args=( 48 | metadata, 49 | el, 50 | dataset_path, 51 | args 52 | ) 53 | ), 54 | "file": el 55 | } 56 | ) 57 | 58 | pbar = tqdm( 59 | results, 60 | total=total, 61 | bar_format="{l_bar}{bar:10}{r_bar}", 62 | colour="GREEN" 63 | ) 64 | for result in pbar: 65 | console = Console() 66 | console.print(f'Processing file [bold orchid1]{result["file"]}[/bold orchid1]') 67 | res = result["result"].get() 68 | if res is not None: 69 | text_file.write(res) 70 | 71 | pool.close() 72 | pool.join() 73 | 74 | 75 | def _processer( 76 | metadata: Optional[Dict], 77 | data_piece: Path, 78 | dataset_path: Path, 79 | args: Type[TokenizerArguments], 80 | ) -> List[str]: 81 | 82 | """ 83 | 84 | Parameters 85 | ---------- 86 | 87 | data_piece: pathlib.Path 88 | The path to the midi file. 89 | 90 | dataset_path: pathlib.Path 91 | The parent path where the midi file is. 
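    metadata: Optional[Dict]
        The dataset metadata, if any; composer, period and genre entries
        are prepended to the token sequence as COMPOSER, PERIOD and GENRE tokens.

    args: Type[TokenizerArguments]
        The tokenizer arguments; their concrete type selects the tokenizer
        (MMM or REMI) that is applied to the file.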
92 | 93 | 94 | Raises: 95 | ValueError: _description_ 96 | 97 | Returns: 98 | _type_: _description_ 99 | """ 100 | 101 | try: 102 | prev_tokens = "" 103 | # Tokenize 104 | file = Path(dataset_path, data_piece) 105 | if metadata is not None: 106 | if "composer" in metadata[data_piece].keys(): 107 | prev_tokens = f"COMPOSER={metadata[data_piece]['composer']} " 108 | if "period" in metadata[data_piece].keys(): 109 | prev_tokens += f"PERIOD={metadata[data_piece]['period']} " 110 | if "genre" in metadata[data_piece].keys(): 111 | prev_tokens += f"GENRE={metadata[data_piece]['genre']}" 112 | 113 | if type(args) not in TOKENIZER_ARGUMENTS: 114 | raise ValueError("Non valid tokenizer args object.") 115 | if isinstance(args, MMMTokenizerArguments): 116 | args.prev_tokens = prev_tokens 117 | tokenizer = MMMTokenizer(file, args) 118 | piece_tokens = tokenizer.tokenize_file() 119 | elif isinstance(args, REMITokenizerArguments): 120 | args.prev_tokens = prev_tokens 121 | tokenizer = REMITokenizer(file, args) 122 | piece_tokens = tokenizer.tokenize_file() 123 | else: 124 | raise ValueError("Non valid tokenizer.") 125 | piece_tokens += "\n" 126 | 127 | return piece_tokens 128 | except: 129 | pass 130 | -------------------------------------------------------------------------------- /musicaiz/features/graphs.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import networkx as nx 3 | 4 | 5 | def musa_to_graph(musa_object) -> nx.graph: 6 | """Converts a Musa object into a Graph where nodes are 7 | the notes and edges are connections between notes. 8 | 9 | A similar symbolic music graph representation was introduced in: 10 | 11 | Jeong, D., Kwon, T., Kim, Y., & Nam, J. (2019, May). 12 | Graph neural network for music score data and modeling expressive piano performance. 13 | In International Conference on Machine Learning (pp. 3060-3070). PMLR. 14 | 15 | Parameters 16 | ---------- 17 | musa_object 18 | 19 | Returns 20 | ------- 21 | _type_: _description_ 22 | """ 23 | g = nx.Graph() 24 | for i, note in enumerate(musa_object.notes): 25 | g.add_node(i, pitch=note.pitch, velocity=note.velocity, start=note.start_ticks, end=note.end_ticks) 26 | nodes = list(g.nodes(data=True)) 27 | 28 | # Add edges 29 | for i, node in enumerate(nodes): 30 | for j, next_node in enumerate(nodes): 31 | # if note has already finished it's not in the current subdivision 32 | # TODO: Check these conditions 33 | if i >= j: 34 | continue 35 | if node[1]["start"] >= next_node[1]["start"] and next_node[1]["end"] <= node[1]["end"]: 36 | g.add_edge(i, j, weight=5, color="violet") 37 | elif node[1]["start"] <= next_node[1]["start"] and next_node[1]["end"] <= node[1]["end"]: 38 | g.add_edge(i, j, weight=5, color="violet") 39 | if (j - i == 1) and (not g.has_edge(i, j)): 40 | g.add_edge(i, j, weight=5, color="red") 41 | if g.has_edge(i, i): 42 | g.remove_edge(i, i) 43 | return g 44 | 45 | 46 | def plot_graph(graph: nx.graph, show: bool = False): 47 | """Plots a graph with matplotlib. 48 | 49 | Args: 50 | graph: nx.graph 51 | """ 52 | plt.figure(figsize=(50, 10), dpi=100) 53 | "Plots a networkx graph." 
54 | pos = {i: (data["start"], data["pitch"]) for i, data in list(graph.nodes(data=True))} 55 | if nx.get_edge_attributes(graph, 'color') == {}: 56 | colors = ["violet" for _ in range(len(graph.edges()))] 57 | else: 58 | colors = nx.get_edge_attributes(graph, 'color').values() 59 | if nx.get_edge_attributes(graph, 'weight') == {}: 60 | weights = [1 for _ in range(len(graph.edges()))] 61 | else: 62 | weights = nx.get_edge_attributes(graph, 'weight').values() 63 | nx.draw( 64 | graph, 65 | pos, 66 | with_labels=True, 67 | edge_color=colors, 68 | width=list(weights), 69 | node_color='lightblue' 70 | ) 71 | if show: 72 | plt.show() 73 | -------------------------------------------------------------------------------- /musicaiz/features/predict_midi.py: -------------------------------------------------------------------------------- 1 | # This module uses the functions defined in the 2 | # other `features` submodules to predict features 3 | # from a midi files. 4 | 5 | 6 | from typing import Union, TextIO, List 7 | 8 | 9 | from musicaiz.loaders import Musa 10 | from musicaiz.rhythm import get_subdivisions 11 | from .rhythm import ( 12 | get_start_sec, 13 | get_ioi, 14 | get_labeled_beat_vector, 15 | compute_all_rmss 16 | ) 17 | from musicaiz.structure import Note 18 | 19 | 20 | def _concatenate_notes_from_different_files( 21 | files: List[Union[str, TextIO]] 22 | ) -> List[Note]: 23 | # load midi file 24 | file_notes = [] 25 | for i, file in enumerate(files): 26 | midi_object = Musa(file, structure="instruments") 27 | # extract notes from all instruments that are not drums 28 | file_notes += [instr.notes for instr in midi_object.instruments if not instr.is_drum] 29 | # Concatenate all lists into one 30 | all_notes = sum(file_notes, []) 31 | 32 | # extract subdivisions 33 | subdivisions = get_subdivisions( 34 | total_bars=midi_object.total_bars, 35 | subdivision="eight", 36 | time_sig=midi_object.time_sig, 37 | ) 38 | return all_notes, subdivisions 39 | 40 | 41 | def predic_time_sig_numerator(files: List[Union[str, TextIO]]): 42 | """Uses `features.rhythm` functions.""" 43 | # load midi files and get all notes 44 | all_notes, subdivisions = _concatenate_notes_from_different_files(files) 45 | # 1. Get iois 46 | note_ons = get_start_sec(all_notes) 47 | iois = get_ioi(note_ons) 48 | # 2. Get labeled beat vector 49 | ioi_prime = get_labeled_beat_vector(iois) 50 | # 3. Get all rmss matrices 51 | all_rmss = compute_all_rmss(ioi_prime) 52 | # 4. 
Select rmss with more bar repeated instances 53 | # TODO 54 | pass 55 | return all_rmss 56 | -------------------------------------------------------------------------------- /musicaiz/features/structure.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import ruptures as rpt 3 | import numpy as np 4 | import networkx as nx 5 | from dataclasses import dataclass 6 | from typing import List, Union, Optional 7 | from pathlib import Path 8 | from enum import Enum 9 | 10 | from musicaiz.loaders import Musa 11 | from musicaiz.features import ( 12 | get_novelty_func, 13 | musa_to_graph 14 | ) 15 | 16 | 17 | LEVELS = ["high", "mid", "low"] 18 | DATASETS = ["BPS", "SWD"] 19 | 20 | 21 | @dataclass 22 | class PeltArgs: 23 | penalty: int 24 | alpha: int 25 | betha: int 26 | level: str 27 | 28 | class LevelsBPS(Enum): 29 | 30 | HIGH = PeltArgs( 31 | penalty = 4, 32 | alpha = 2.3, 33 | betha = 1.5, 34 | level = "high" 35 | ) 36 | 37 | MID = PeltArgs( 38 | penalty = 0.5, 39 | alpha = 1, 40 | betha = 0.01, 41 | level = "mid" 42 | ) 43 | 44 | LOW = PeltArgs( 45 | penalty = 0.1, 46 | alpha = 0.1, 47 | betha = 0.15, 48 | level = "low" 49 | ) 50 | 51 | class LevelsSWD(Enum): 52 | 53 | MID = PeltArgs( 54 | penalty = 0.7, 55 | alpha = 0.6, 56 | betha = 0.15, 57 | level = "mid" 58 | ) 59 | 60 | class StructurePrediction: 61 | 62 | def __init__( 63 | self, 64 | file: Optional[Union[str, Path]] = None, 65 | ): 66 | 67 | # Convert file into a Musa object to be processed 68 | if file is not None: 69 | self.midi_object = Musa( 70 | file=file, 71 | ) 72 | else: 73 | self.midi_object = Musa(file=None) 74 | 75 | def notes(self, level: str, dataset: str) -> List[int]: 76 | return self._get_structure_boundaries(level, dataset) 77 | 78 | def beats(self, level: str, dataset: str) -> List[int]: 79 | result = self._get_structure_boundaries(level, dataset) 80 | return [self.midi_object.notes[n].beat_idx for n in result] 81 | 82 | def bars(self, level: str, dataset: str) -> List[int]: 83 | result = self._get_structure_boundaries(level, dataset) 84 | return [self.midi_object.notes[n].bar_idx for n in result] 85 | 86 | def ms(self, level: str, dataset: str) -> List[float]: 87 | result = self._get_structure_boundaries(level, dataset) 88 | return [self.midi_object.notes[n].start_sec * 1000 for n in result] 89 | 90 | def _get_structure_boundaries( 91 | self, 92 | level: str, 93 | dataset: str 94 | ): 95 | """ 96 | Get the note indexes where a section ends. 
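        ``level`` must be one of "high", "mid" or "low", and ``dataset`` one of
        "BPS" or "SWD"; each valid combination selects the Pelt hyperparameters
        (penalty, alpha, betha) tuned for that annotation granularity.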
97 | """ 98 | if level not in LEVELS: 99 | raise ValueError(f"Level {level} not supported.") 100 | if dataset not in DATASETS: 101 | raise ValueError(f"Dataset {dataset} not supported.") 102 | if level == "high" and dataset == "BPS": 103 | pelt_args = LevelsBPS.HIGH.value 104 | elif level == "mid" and dataset == "BPS": 105 | pelt_args = LevelsBPS.MID.value 106 | elif level == "low" and dataset == "BPS": 107 | pelt_args = LevelsBPS.LOW.value 108 | elif level == "mid" and dataset == "SWD": 109 | pelt_args = LevelsSWD.MID.value 110 | 111 | g = musa_to_graph(self.midi_object) 112 | mat = nx.attr_matrix(g)[0] 113 | n = get_novelty_func(mat) 114 | nn = np.reshape(n, (n.size, 1)) 115 | # detection 116 | try: 117 | algo = rpt.Pelt( 118 | model="rbf", 119 | min_size=pelt_args.alpha*(len(self.midi_object.notes)/15), 120 | jump=int(pelt_args.betha*pelt_args.alpha*(len(self.midi_object.notes)/15)), 121 | ).fit(nn) 122 | result = algo.predict(pen=pelt_args.penalty) 123 | except: 124 | warnings.warn("No structure found.") 125 | result = [0, len(self.midi_object.notes)-1] 126 | 127 | return result 128 | 129 | 130 | 131 | -------------------------------------------------------------------------------- /musicaiz/harmony/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Harmony 3 | ======== 4 | 5 | This submodule contains objects that are related to harmony elements. 6 | 7 | The basic harmonic elements are: 8 | 9 | - Key: A tonic and mode (additionally it can have a chord progresison as attribute) 10 | 11 | - Chord Progression: A list of chords 12 | 13 | - Chord: List of 2 intervals (triad chords), 3 intervals (7ths), etc. 14 | 15 | - Interval: List of 2 notes. 16 | 17 | Intervals 18 | --------- 19 | 20 | .. autosummary:: 21 | :toctree: generated/ 22 | 23 | IntervalClass 24 | IntervalQuality 25 | IntervalSemitones 26 | IntervalComplexity 27 | Interval 28 | 29 | Chords 30 | ------ 31 | 32 | .. autosummary:: 33 | :toctree: generated/ 34 | 35 | ChordQualities 36 | ChordType 37 | AllChords 38 | Chord 39 | 40 | Keys 41 | ---- 42 | 43 | .. 
autosummary:: 44 | :toctree: generated/ 45 | 46 | DegreesQualities 47 | DegreesRoman 48 | Degrees 49 | MajorTriadDegrees 50 | MinorNaturalTriadDegrees 51 | MinorHarmonicTriadDegrees 52 | MinorMelodicTriadDegrees 53 | DorianTriadDegrees 54 | PhrygianTriadDegrees 55 | LydianTriadDegrees 56 | MixolydianTriadDegrees 57 | LocrianTriadDegrees 58 | TriadsModes 59 | MajorSeventhDegrees 60 | MinorNaturalSeventhDegrees 61 | MinorHarmonicSeventhDegrees 62 | MinorMelodicSeventhDegrees 63 | DorianSeventhDegrees 64 | PhrygianSeventhDegrees 65 | LydianSeventhDegrees 66 | MixolydianSeventhDegrees 67 | LocrianSeventhDegrees 68 | SeventhsModes 69 | AccidentalNotes 70 | AccidentalDegrees 71 | ModeConstructors 72 | Scales 73 | Tonality 74 | """ 75 | 76 | from .intervals import ( 77 | Interval, 78 | IntervalClass, 79 | IntervalQuality, 80 | IntervalSemitones, 81 | IntervalComplexity, 82 | ) 83 | from .chords import ( 84 | AllChords, 85 | ChordQualities, 86 | ChordType, 87 | Chord, 88 | ) 89 | from .keys import ( 90 | DegreesQualities, 91 | DegreesRoman, 92 | Degrees, 93 | MajorTriadDegrees, 94 | MinorNaturalTriadDegrees, 95 | MinorHarmonicTriadDegrees, 96 | MinorMelodicTriadDegrees, 97 | DorianTriadDegrees, 98 | PhrygianTriadDegrees, 99 | LydianTriadDegrees, 100 | MixolydianTriadDegrees, 101 | LocrianTriadDegrees, 102 | TriadsModes, 103 | MajorSeventhDegrees, 104 | MinorNaturalSeventhDegrees, 105 | MinorHarmonicSeventhDegrees, 106 | MinorMelodicSeventhDegrees, 107 | DorianSeventhDegrees, 108 | PhrygianSeventhDegrees, 109 | LydianSeventhDegrees, 110 | MixolydianSeventhDegrees, 111 | LocrianSeventhDegrees, 112 | SeventhsModes, 113 | AccidentalNotes, 114 | AccidentalDegrees, 115 | ModeConstructors, 116 | Scales, 117 | Tonality, 118 | ) 119 | 120 | __all__ = [ 121 | "IntervalClass", 122 | "IntervalQuality", 123 | "IntervalSemitones", 124 | "IntervalComplexity", 125 | "Interval", 126 | "ChordQualities", 127 | "ChordType", 128 | "AllChords", 129 | "Chord", 130 | "DegreesQualities", 131 | "DegreesRoman", 132 | "Degrees", 133 | "MajorTriadDegrees", 134 | "MinorNaturalTriadDegrees", 135 | "MinorHarmonicTriadDegrees", 136 | "MinorMelodicTriadDegrees", 137 | "DorianTriadDegrees", 138 | "PhrygianTriadDegrees", 139 | "LydianTriadDegrees", 140 | "MixolydianTriadDegrees", 141 | "LocrianTriadDegrees", 142 | "TriadsModes", 143 | "MajorSeventhDegrees", 144 | "MinorNaturalSeventhDegrees", 145 | "MinorHarmonicSeventhDegrees", 146 | "MinorMelodicSeventhDegrees", 147 | "DorianSeventhDegrees", 148 | "PhrygianSeventhDegrees", 149 | "LydianSeventhDegrees", 150 | "MixolydianSeventhDegrees", 151 | "LocrianSeventhDegrees", 152 | "SeventhsModes", 153 | "AccidentalNotes", 154 | "AccidentalDegrees", 155 | "ModeConstructors", 156 | "Scales", 157 | "Tonality", 158 | ] 159 | -------------------------------------------------------------------------------- /musicaiz/models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Models 3 | ====== 4 | 5 | This module provides baseline models for symbolic music generation. 6 | 7 | The submodule is divided in: 8 | 9 | - Transformer Composers: Transformer-based models. 10 | 11 | 12 | Transformer Composers 13 | --------------------- 14 | 15 | Contains a GPT model that can be trained to generate symbolic music. 16 | 17 | .. autosummary:: 18 | :toctree: generated/ 19 | 20 | transformer_composers 21 | 22 | """ 23 | 24 | 25 | from . 
import ( 26 | transformer_composers, 27 | ) 28 | 29 | __all__ = [ 30 | "transformer_composers", 31 | ] -------------------------------------------------------------------------------- /musicaiz/models/transformer_composers/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Transformer Composers 3 | ===================== 4 | 5 | This submodule presents a GPT2 model that generates music. 6 | 7 | The tokenization is previously done with `musanalysis` 8 | :func:`~musanalysis.tokenizers.MMMTokenizer` class. 9 | 10 | 11 | Installation 12 | ------------ 13 | 14 | To train these models you should install torch with cuda. We recommend torch version 1.11.0 15 | with cuda 113: 16 | 17 | >>> pip3 install torch==1.11.0 --extra-index-url https://download.pytorch.org/whl/cu113 18 | 19 | Apart from that, `apex` is also necessary. To install it properly, follow the instructions in: 20 | https://github.com/NVIDIA/apex 21 | 22 | 23 | Configurations 24 | -------------- 25 | 26 | .. autosummary:: 27 | :toctree: generated/ 28 | 29 | GPTConfigs 30 | TrainConfigs 31 | 32 | 33 | Dataloaders 34 | ----------- 35 | 36 | .. autosummary:: 37 | :toctree: generated/ 38 | 39 | build_torch_loaders 40 | get_vocabulary 41 | 42 | 43 | Model 44 | ----- 45 | 46 | .. autosummary:: 47 | :toctree: generated/ 48 | 49 | self_attention 50 | MultiheadAttention 51 | PositionalEncoding 52 | Embedding 53 | ResidualConnection 54 | FeedForward 55 | Decoder 56 | GPT2 57 | 58 | 59 | Train 60 | ----- 61 | 62 | .. autosummary:: 63 | :toctree: generated/ 64 | 65 | train 66 | 67 | Generation 68 | ---------- 69 | 70 | .. autosummary:: 71 | :toctree: generated/ 72 | 73 | sample_sequence 74 | 75 | 76 | Gradio App 77 | ---------- 78 | 79 | There's a simple app for this model built with Gradio. 80 | To try the demo locally run: 81 | 82 | >>> python models/transformer_composers/app.py 83 | 84 | 85 | Examples 86 | -------- 87 | 88 | Train model: 89 | 90 | >>> python models/transformer_composers/train.py --dataset_path="..." --is_splitted True 91 | 92 | Generate Sequence: 93 | 94 | >>> python models/transformer_composers/generate.py --dataset_path H:/GitHub/musicaiz-datasets/jsbchorales/mmm/all_bars --dataset_name jsbchorales --save_midi True --file_path ../midi 95 | """ 96 | 97 | from .configs import ( 98 | GPTConfigs, 99 | TrainConfigs 100 | ) 101 | from .dataset import ( 102 | build_torch_loaders, 103 | get_vocabulary 104 | ) 105 | from .transformers import ( 106 | self_attention, 107 | MultiheadAttention, 108 | PositionalEncoding, 109 | Embedding, 110 | ResidualConnection, 111 | FeedForward, 112 | Decoder, 113 | GPT2, 114 | ) 115 | from .train import train 116 | from .generate import sample_sequence 117 | 118 | 119 | __all__ = [ 120 | "GPTConfigs", 121 | "TrainConfigs", 122 | "build_torch_loaders", 123 | "get_vocabulary", 124 | "self_attention", 125 | "MultiheadAttention", 126 | "PositionalEncoding", 127 | "Embedding", 128 | "ResidualConnection", 129 | "FeedForward", 130 | "Decoder", 131 | "GPT2", 132 | "train", 133 | "sample_sequence", 134 | ] -------------------------------------------------------------------------------- /musicaiz/models/transformer_composers/configs.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from datetime import datetime 3 | 4 | 5 | class GPTConfigs: 6 | """ 7 | ... 
8 | The vocabulary size is given by the ``vocabulary.txt`` file that 9 | must be placed in the dataset path (this file is generated when 10 | tokenizing). 11 | """ 12 | N_DECODERS = 2 13 | SEQ_LEN = 512 14 | EMBED_DIM = 32 # also d_model 15 | N_HEADS = 4 # must be divisor of embed dim 16 | DROPOUT = 0.1 17 | 18 | 19 | class TrainConfigs: 20 | TRAIN_SPLIT = 0.8 21 | IS_SPLITTED = False # if dataset is not already splitted in 2 dirs: train and validation 22 | CHECKPOINT_PATH = Path("results", str(datetime.now().strftime("%Y-%m-%d_%H-%M"))) 23 | MODEL_NAME = "gpt" 24 | WEIGHT_DECAY = 0.01 25 | EPOCHS = 250 26 | BATCH_SIZE = 64 27 | LR = 5e-3 28 | ADAM_EPSILON = 1e-6 29 | CKPT_STEPS = 100 # steps to save checkpoint 30 | LOG_STEPS = 1 31 | LOG_DIR = Path("results", str(datetime.now().strftime("%Y-%m-%d_%H-%M"))) 32 | GRAD_ACUM_STEPS = 1 33 | FP16 = True 34 | FP16_OPT_LEVEL = "O2" 35 | -------------------------------------------------------------------------------- /musicaiz/models/transformer_composers/generate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import logging 3 | import torch 4 | import torch.nn.functional as F 5 | from pathlib import Path 6 | from tqdm import trange 7 | import json 8 | from typing import Union, Optional 9 | import argparse 10 | 11 | from musicaiz.models.transformer_composers.dataset import get_vocabulary 12 | from musicaiz.models.transformer_composers.train import initialize_model 13 | from musicaiz.models.transformer_composers.configs import GPTConfigs 14 | from musicaiz.tokenizers.mmm import MMMTokenizer 15 | 16 | 17 | def indices_to_text(indices, vocabulary): 18 | return " ".join([vocabulary[index] for index in indices]) 19 | 20 | 21 | def top_k_logits(logits, k): 22 | if k == 0: 23 | return logits 24 | values, _ = torch.topk(logits, k) 25 | min_values = values[:, -1] 26 | return torch.where(logits < min_values, torch.ones_like(logits, dtype=logits.dtype) * -1e10, logits) 27 | 28 | 29 | def sample_sequence( 30 | dataset_path: Union[Path, str], 31 | checkpoint_path: Union[Path, str] = Path("results"), 32 | model_name: str = "gpt", 33 | dataset_name: str = "maestro", 34 | start_token: Optional[int] = None, #vocabulary.index("PIECE_START") 35 | batch_size: int = 1, 36 | context: str = "PIECE_START", 37 | temperature: float = 1.0, 38 | top_k: int = 0, 39 | device: str = 'cpu', 40 | sample: bool = True, 41 | seq_len: int = 512, 42 | save_midi: bool = False, 43 | file_path: str = "" 44 | ) -> str: 45 | 46 | """ 47 | This function generates a sequence from a pretrained model. The condition to generate the sequence is 48 | to store the pretrained model as `model_name.pth` in the directory `modelname_datasetname`. 49 | The dataset path must be provided to look at the `vocabulary.txt` file that allows to convert the token 50 | indexes to the token names. 51 | Another thing to consider is that, when saving the midi file the `time_unit` must be the same of the ones that 52 | the training dataset so the midi will be generated with the correct timing information. 
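    As an example, assuming a checkpoint trained on JSB Chorales and stored as
    ``results/gpt_jsbchorales/gpt.pth`` (with its ``gpt_configs.json`` next to it),
    a call like the following generates 512 tokens and writes them as a MIDI file
    (paths are placeholders):

    >>> tokens = sample_sequence(
    >>>     dataset_path="path/to/tokenized/jsbchorales",
    >>>     dataset_name="jsbchorales",
    >>>     seq_len=512,
    >>>     save_midi=True,
    >>>     file_path="generated",
    >>> )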
53 | 54 | Parameters 55 | ---------- 56 | 57 | dataset_path: Union[Path, str] 58 | 59 | checkpoint_path: Union[Path, str] = Path("results") 60 | 61 | model_name: str = "gpt" 62 | 63 | dataset_name: str = "maestro" 64 | 65 | start_token: Optional[int] = None 66 | 67 | batch_size: int = 1 68 | 69 | context: str = "PIECE_START" 70 | 71 | temperature: float = 1.0 72 | 73 | top_k: int = 0 74 | 75 | device: str = 'cpu' 76 | 77 | sample: bool = True 78 | 79 | seq_len: int = 512 80 | 81 | save_midi: bool = False 82 | 83 | file_path: str = "" 84 | 85 | Returns 86 | ------- 87 | 88 | token_seq: str 89 | The generated token sequence. 90 | """ 91 | # TODO: STOP generation if a PAD is generated 92 | vocabulary = get_vocabulary(dataset_path) 93 | if " " in context: 94 | context = context.split(" ") 95 | else: 96 | context = list(context) 97 | context = [vocabulary.index(c) for c in context] 98 | # Get configs file 99 | model_path = Path(checkpoint_path, model_name + "_" + dataset_name) 100 | configs_file = Path(model_path, model_name + "_configs.json") 101 | with open(configs_file) as file: 102 | configs = json.load(file) 103 | 104 | model = initialize_model(model_name, configs, device) 105 | model.to(device) 106 | model.load_state_dict( 107 | torch.load( 108 | Path(model_path, model_name + ".pth"), 109 | map_location=device 110 | )["model_state_dict"]), 111 | 112 | model.eval() 113 | 114 | if start_token is None: 115 | assert context is not None, 'You must give the start_token or the context' 116 | context = torch.tensor(context, device=device, dtype=torch.float16).unsqueeze(0).repeat(batch_size, 1) 117 | else: 118 | assert context is None, 'You must give the start_token or the context' 119 | context = torch.full((batch_size, 1), start_token, device=device, dtype=torch.float16) 120 | 121 | prev = context 122 | output = context 123 | with torch.no_grad(): 124 | for i in trange(seq_len): #trange(configs["model_configs"]["SEQ_LEN"]): 125 | logits = model(output) 126 | logits = logits[:, -1, :] / temperature 127 | logits = top_k_logits(logits, k=top_k) 128 | log_probs = F.softmax(logits, dim=-1) 129 | if sample: 130 | prev = torch.multinomial(log_probs, num_samples=1) 131 | else: 132 | _, prev = torch.topk(log_probs, k=1, dim=-1) 133 | output = torch.cat((output, prev), dim=1) 134 | output = output.to(torch.int) 135 | 136 | logging.info(f"Generated token seq indexes: {output.tolist()[0]}") 137 | 138 | token_seq = indices_to_text(output.tolist()[0], vocabulary) 139 | logging.info(f"Generated token seq is: {token_seq}") 140 | 141 | if save_midi: 142 | data = list(token_seq.split(" ")) 143 | # TODO: be careful with the time unit in which TIME_DELTA tokens are represented 144 | midi_obj = MMMTokenizer.tokens_to_musa(data, absolute_timing=True, time_unit="HUNDRED_TWENTY_EIGHT") 145 | midi_obj.write_midi(file_path) 146 | logging.info(f"Saved midi file to: {token_seq}") 147 | 148 | return token_seq 149 | 150 | def parse_args(): 151 | parser = argparse.ArgumentParser() 152 | 153 | parser.add_argument( 154 | "--dataset_path", 155 | type=str, 156 | help="", 157 | required=True, 158 | ) 159 | parser.add_argument( 160 | "--dataset_name", 161 | type=str, 162 | help="", 163 | required=True, 164 | ) 165 | parser.add_argument( 166 | "--sequence_length", 167 | type=int, 168 | help="", 169 | required=False, 170 | default=GPTConfigs.SEQ_LEN, 171 | ) 172 | parser.add_argument( 173 | "--save_midi", 174 | type=bool, 175 | help="", 176 | required=False, 177 | default=False, 178 | ) 179 | parser.add_argument( 180 | "--file_path", 
181 | type=str, 182 | help="", 183 | required=False, 184 | default="", 185 | ) 186 | return parser.parse_args() 187 | 188 | 189 | if __name__ == "__main__": 190 | args = parse_args() 191 | sample_sequence( 192 | dataset_path=args.dataset_path, 193 | dataset_name=args.dataset_name, 194 | seq_len=args.sequence_length, 195 | save_midi=args.save_midi, 196 | file_path=args.file_path, 197 | ) 198 | -------------------------------------------------------------------------------- /musicaiz/models/transformer_composers/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | from .gpt import * 2 | from .attention import * 3 | from .layers import * -------------------------------------------------------------------------------- /musicaiz/models/transformer_composers/transformers/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | 7 | class MultiheadAttention(nn.Module): 8 | def __init__( 9 | self, 10 | n_heads: int, 11 | embed_dim: int, 12 | device: str, 13 | dropout: float = 0.1, 14 | causal: bool = False 15 | ): 16 | super(MultiheadAttention, self).__init__() 17 | 18 | self.n_head = n_heads 19 | self.d_model = embed_dim 20 | self.d_k = self.d_v = embed_dim // n_heads 21 | self.causal = causal 22 | self.device = device 23 | 24 | self.w_q = nn.Linear(embed_dim, embed_dim) 25 | self.w_k = nn.Linear(embed_dim, embed_dim) 26 | self.w_v = nn.Linear(embed_dim, embed_dim) 27 | self.w_o = nn.Linear(embed_dim, embed_dim) 28 | 29 | self.self_attention = self_attention 30 | self.dropout = nn.Dropout(p=dropout) 31 | 32 | def forward(self, query, key, value, mask = True): 33 | batch_num = query.size(0) 34 | 35 | query = self.w_q(query).view(batch_num, -1, self.n_head, self.d_k).transpose(1, 2) 36 | key = self.w_k(key).view(batch_num, -1, self.n_head, self.d_k).transpose(1, 2) 37 | value = self.w_v(value).view(batch_num, -1, self.n_head, self.d_k).transpose(1, 2) 38 | 39 | attention_result, attention_score = self.self_attention(query, key, value, self.device, mask, self.causal) 40 | attention_result = attention_result.transpose(1,2).contiguous().view(batch_num, -1, self.n_head * self.d_k) 41 | attn_output = self.w_o(attention_result) 42 | 43 | return attn_output 44 | 45 | 46 | def self_attention(query, key, value, device: str, mask=True, causal=False): 47 | key_transpose = torch.transpose(key, -2, -1) 48 | matmul_result = torch.matmul(query, key_transpose) 49 | d_k = query.size()[-1] 50 | attention_score = matmul_result/math.sqrt(d_k) 51 | 52 | if mask: 53 | # hide future positions in the scores with a lower-triangular mask 54 | mask = (torch.triu(torch.ones((query.size()[2], query.size()[2]))) == 1) 55 | mask = mask.transpose(0, 1).to(attention_score.device) 56 | attention_score = attention_score.masked_fill(mask == 0, -1e20) 57 | 58 | if causal: 59 | query_len = query.size()[2] 60 | i, j = torch.triu_indices(query_len, query_len, 1) 61 | attention_score[i, j] = -1e4 62 | 63 | softmax_attention_score = F.softmax(attention_score, dim=-1).to(device) 64 | 65 | # When training with fp16, `softmax_attention_score` must be float16 but not when generating 66 | try: 67 | result = torch.matmul(softmax_attention_score, value) 68 | except: 69 | result = torch.matmul(softmax_attention_score.to(torch.float16), value) 70 | 71 | return result, softmax_attention_score 72 | --------------------------------------------------------------------------------
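A quick way to sanity-check the attention block above is a standalone shape test; this is a minimal sketch assuming the package layout shown in this repository:

import torch
from musicaiz.models.transformer_composers.transformers.attention import MultiheadAttention

# batch of 2 sequences, 16 tokens each, embedding size 32, 4 heads
mha = MultiheadAttention(n_heads=4, embed_dim=32, device="cpu")
x = torch.rand(2, 16, 32)  # (batch, seq_len, embed_dim)
out = mha(x, x, x)  # self-attention: query, key and value are the same tensor
assert out.shape == (2, 16, 32)  # the output keeps the input shape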
/musicaiz/models/transformer_composers/transformers/gpt.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from musicaiz.models.transformer_composers.transformers.attention import MultiheadAttention 4 | from musicaiz.models.transformer_composers.transformers.layers import ( 5 | ResidualConnection, 6 | FeedForward, 7 | Embedding 8 | ) 9 | 10 | 11 | class GPT2(nn.Module): 12 | def __init__( 13 | self, 14 | vocab_size, 15 | embedding_dim, 16 | n_decoders, 17 | sequence_len, 18 | n_heads, 19 | device: str, 20 | causal=False 21 | ): 22 | super(GPT2, self).__init__() 23 | 24 | self.vocab_size = vocab_size 25 | self.embedding_dim = embedding_dim 26 | self.n_decoders = n_decoders 27 | self.sequence_len = sequence_len 28 | self.n_heads = n_heads 29 | self.causal = causal 30 | self.device = device 31 | 32 | # Embedding 33 | self.embedding = Embedding( 34 | vocab_size=vocab_size, embedding_dim=embedding_dim, device=device 35 | ) 36 | 37 | self.decoders = nn.Sequential( 38 | *[Decoder( 39 | d_model=embedding_dim, n_head=n_heads, device=self.device, dropout=0.1 40 | ) for _ in range(n_decoders)] 41 | ) 42 | 43 | self.lm_head = nn.Linear(embedding_dim, vocab_size, bias=False) 44 | 45 | def forward(self, input_ids): 46 | x = self.embedding(input_ids.long()) 47 | x = self.decoders(x) 48 | lm_logits = self.lm_head(x) 49 | 50 | return lm_logits 51 | 52 | 53 | class Decoder(nn.Module): 54 | def __init__(self, d_model, n_head, dropout, device, causal=True): 55 | super(Decoder,self).__init__() 56 | 57 | self.masked_multi_head_attention = MultiheadAttention( 58 | embed_dim=d_model, n_heads=n_head, causal=causal, device=device 59 | ) 60 | self.residual_1 = ResidualConnection(d_model, dropout=dropout) 61 | self.feed_forward= FeedForward(d_model) 62 | self.residual_2 = ResidualConnection(d_model, dropout=dropout) 63 | 64 | 65 | def forward(self, x): 66 | x = self.residual_1(x, lambda x: self.masked_multi_head_attention(x, x, x)) 67 | x = self.residual_2(x, self.feed_forward) 68 | 69 | return x -------------------------------------------------------------------------------- /musicaiz/models/transformer_composers/transformers/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | 7 | class PositionalEncoding(nn.Module): 8 | 9 | def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000): 10 | super().__init__() 11 | self.dropout = nn.Dropout(p=dropout) 12 | 13 | position = torch.arange(max_len).unsqueeze(1) 14 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) 15 | pe = torch.zeros(max_len, 1, d_model) 16 | pe[:, 0, 0::2] = torch.sin(position * div_term) 17 | pe[:, 0, 1::2] = torch.cos(position * div_term) 18 | self.register_buffer('pe', pe) 19 | 20 | def forward(self, x): 21 | x = x + self.pe[:x.size(0)] 22 | return self.dropout(x) 23 | 24 | 25 | class Embedding(nn.Module): 26 | def __init__(self, vocab_size, embedding_dim, device: str): 27 | super().__init__() 28 | self.vocab_size = vocab_size 29 | self.device = device 30 | self.token_emb = nn.Embedding( 31 | num_embeddings=vocab_size, embedding_dim=embedding_dim 32 | ) 33 | self.position_emb = PositionalEncoding( 34 | d_model=embedding_dim 35 | ) 36 | 37 | def forward(self, input_ids): 38 | x = self.token_emb(input_ids) * math.sqrt(self.vocab_size) 39 | x = 
self.position_emb(x.type(torch.IntTensor).to(self.device)) 40 | return x 41 | 42 | 43 | class ResidualConnection(nn.Module): 44 | def __init__(self, size, dropout): 45 | super(ResidualConnection,self).__init__() 46 | self.norm = nn.LayerNorm(size) 47 | self.dropout = nn.Dropout(dropout) 48 | 49 | def forward(self, x, sublayer): 50 | return x + self.dropout(sublayer(self.norm(x))) 51 | 52 | 53 | class FeedForward(nn.Module): 54 | def __init__(self, d_model, dropout=0.1, activation='gelu'): 55 | super(FeedForward, self).__init__() 56 | self.w_1 = nn.Linear(d_model, d_model * 4) 57 | self.w_2 = nn.Linear(d_model * 4, d_model) 58 | self.dropout = nn.Dropout(p=dropout) 59 | 60 | if activation == 'gelu': 61 | self.activation = F.gelu 62 | elif activation == 'relu': 63 | self.activation = F.relu 64 | 65 | def forward(self, x): 66 | return self.w_2(self.dropout(self.activation(self.w_1(x)))) 67 | -------------------------------------------------------------------------------- /musicaiz/plotters/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Plotters 3 | ======== 4 | 5 | This module allows to plot Musa objects. 6 | 7 | 8 | Pianoroll 9 | --------- 10 | 11 | .. autosummary:: 12 | :toctree: generated/ 13 | 14 | Pianoroll 15 | PianorollHTML 16 | 17 | """ 18 | 19 | from .pianorolls import Pianoroll, PianorollHTML 20 | 21 | __all__ = [ 22 | "Pianoroll", 23 | "PianorollHTML" 24 | ] 25 | -------------------------------------------------------------------------------- /musicaiz/rhythm/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Rhythm 3 | ====== 4 | 5 | This module provides objects and methods that define and deal with time events. 6 | 7 | The submodule is divided in: 8 | 9 | - Key: A tonic and mode (additionally it can have a chord progresison as attribute) 10 | 11 | - Chord Progression: A list of chords 12 | 13 | - Chord: List of 2 intervals (triad chords), 3 intervals (7ths), etc. 14 | 15 | - Interval: List of 2 notes. 16 | 17 | Timing 18 | ------ 19 | 20 | Defines and contains helper functions to deal with Time Signatures, time events, etc. 21 | 22 | .. autosummary:: 23 | :toctree: generated/ 24 | 25 | TimingConsts 26 | NoteLengths 27 | ms_per_tick 28 | _bar_str_to_tuple 29 | ticks_per_bar 30 | ms_per_note 31 | get_subdivisions 32 | TimeSignature 33 | 34 | 35 | Quantizer 36 | --------- 37 | 38 | Quantizes symbolic music as it is done in Logic Pro by following the steps 39 | described in: 40 | 41 | [1] https://www.fransabsil.nl/archpdf/advquant.pdf 42 | 43 | .. 
"""

from .timing import (
    TimingConsts,
    NoteLengths,
    SymbolicNoteLengths,
    TimeSignature,
    ms_per_tick,
    _bar_str_to_tuple,
    ticks_per_bar,
    ms_per_note,
    ms_per_bar,
    get_subdivisions,
    get_symbolic_duration,
    Timing,
    Beat,
    Subdivision,
)

from .quantizer import (
    QuantizerConfig,
    basic_quantizer,
    advanced_quantizer,
    get_ticks_from_subdivision,
    _find_nearest
)

__all__ = [
    "TimingConsts",
    "NoteLengths",
    "ms_per_tick",
    "_bar_str_to_tuple",
    "ticks_per_bar",
    "ms_per_note",
    "ms_per_bar",
    "get_subdivisions",
    "QuantizerConfig",
    "basic_quantizer",
    "advanced_quantizer",
    "get_ticks_from_subdivision",
    "_find_nearest",
    "SymbolicNoteLengths",
    "get_symbolic_duration",
    "TimeSignature",
    "Timing",
    "Beat",
    "Subdivision",
]
--------------------------------------------------------------------------------
/musicaiz/rhythm/quantizer.py:
--------------------------------------------------------------------------------
from typing import List, Dict, Union, Optional
import numpy as np
from dataclasses import dataclass

from musicaiz.rhythm import TimingConsts, ms_per_tick


@dataclass
class QuantizerConfig:
    """
    Basic quantizer arguments.

    Parameters
    ----------

    strength: float
        Quantization strength, between 0 and 1 (1 = 100 %).
        Example: with GRID = [0, 24, 48] and START_TICKS = [3, 21, 40], the
        note starts are shifted by
        [(3 - 0) * strength, (21 - 24) * strength, (40 - 48) * strength].

    delta_qr: int
        Quantization range (Q_r) in ticks.

    type_q: Optional[str]
        Type of quantization. If "negative", only differences between
        start_tick and grid greater than Q_r are taken into account for the
        quantization. If "positive", only differences between start_tick and
        grid smaller than Q_r are taken into account. If None, every
        start_tick is quantized based on the strength (it works similarly to
        basic quantization but adds the strength parameter).
    """
    delta_qr: int = 12
    strength: float = 1.0  # 100 %
    type_q: Optional[str] = None


def _find_nearest(
    array: List[Union[int, float]], value: Union[int, float]
) -> Union[int, float]:
    """Find the array element closest to a given value.
    Example: [3, 6, 9, 12], 5 --> 6"""
    array = np.asarray(array)
    idx = (np.abs(array - value)).argmin()
    return array[idx]


def get_ticks_from_subdivision(
    subdivisions: List[Dict[str, Union[int, float]]]
) -> List[int]:
    """Extract the grid array in ticks from a subdivision"""
    v_grid = []
    for i in range(len(subdivisions)):
        v_grid.append(subdivisions[i].get("ticks"))

    return v_grid


def basic_quantizer(
    input_notes,
    grid: List[int],
    bpm: int = TimingConsts.DEFAULT_BPM.value
):
    for i in range(len(input_notes)):
        start_tick = input_notes[i].start_ticks
        end_tick = input_notes[i].end_ticks
        start_tick_quantized = _find_nearest(grid, start_tick)

        delta_tick = start_tick - start_tick_quantized

        if delta_tick > 0:
            input_notes[i].start_ticks = start_tick - delta_tick
            input_notes[i].end_ticks = end_tick - delta_tick

        elif delta_tick < 0:
            input_notes[i].start_ticks = start_tick + abs(delta_tick)
            input_notes[i].end_ticks = end_tick + abs(delta_tick)

        # ms_per_tick returns milliseconds, so divide by 1000 to store
        # seconds, as advanced_quantizer does below.
        input_notes[i].start_sec = input_notes[i].start_ticks * ms_per_tick(bpm) / 1000
        input_notes[i].end_sec = input_notes[i].end_ticks * ms_per_tick(bpm) / 1000
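

# Example (sketch): quantizing the notes of a Musa object to a 16th-note
# grid with the advanced quantizer defined below. The file path is
# hypothetical and the `.notes` attribute is assumed to hold the note list.
#
#     from musicaiz.loaders import Musa
#     from musicaiz.rhythm import (
#         get_subdivisions, get_ticks_from_subdivision,
#         advanced_quantizer, QuantizerConfig,
#     )
#
#     midi = Musa("song.mid")
#     subdivisions = get_subdivisions(
#         total_bars=4, subdivision="sixteenth", time_sig="4/4",
#         bpm=120, resolution=960, absolute_timing=True,
#     )
#     grid = get_ticks_from_subdivision(subdivisions)
#     config = QuantizerConfig(delta_qr=12, strength=0.8, type_q="negative")
#     advanced_quantizer(midi.notes, grid, config=config, bpm=120)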


def advanced_quantizer(
    input_notes,
    grid: List[int],
    config: QuantizerConfig,
    bpm: int = TimingConsts.DEFAULT_BPM.value,
    resolution: int = TimingConsts.RESOLUTION.value,
):
    """
    This function quantizes a list of notes given a grid.

    Parameters
    ----------

    input_notes: the notes to quantize.

    grid: array of ints in ticks.

    config: QuantizerConfig
        The quantization arguments.
    """

    Aq = config.strength

    for i in range(len(input_notes)):

        start_tick = input_notes[i].start_ticks
        end_tick = input_notes[i].end_ticks
        start_tick_quantized = _find_nearest(grid, start_tick)
        delta_tick = start_tick - start_tick_quantized
        delta_tick_q = int(delta_tick * Aq)

        if config.type_q == "negative" and (abs(delta_tick) > config.delta_qr):
            if delta_tick > 0:
                input_notes[i].start_ticks = start_tick - delta_tick_q
                input_notes[i].end_ticks = end_tick - delta_tick_q
            elif delta_tick < 0:
                input_notes[i].start_ticks = start_tick + abs(delta_tick_q)
                input_notes[i].end_ticks = end_tick + abs(delta_tick_q)

        elif config.type_q == "positive" and (abs(delta_tick) < config.delta_qr):
            if delta_tick > 0:
                input_notes[i].start_ticks = input_notes[i].start_ticks - delta_tick_q
                input_notes[i].end_ticks = input_notes[i].end_ticks - delta_tick_q
            elif delta_tick < 0:
                input_notes[i].start_ticks = input_notes[i].start_ticks + abs(delta_tick_q)
                input_notes[i].end_ticks = input_notes[i].end_ticks + abs(delta_tick_q)

        elif config.type_q is None:
            if delta_tick > 0:
                input_notes[i].start_ticks = input_notes[i].start_ticks - delta_tick_q
                input_notes[i].end_ticks = input_notes[i].end_ticks - delta_tick_q
            elif delta_tick < 0:
                input_notes[i].start_ticks = input_notes[i].start_ticks + abs(delta_tick_q)
                input_notes[i].end_ticks = input_notes[i].end_ticks + abs(delta_tick_q)

        input_notes[i].start_sec = input_notes[i].start_ticks * ms_per_tick(bpm, resolution) / 1000
        input_notes[i].end_sec = input_notes[i].end_ticks * ms_per_tick(bpm, resolution) / 1000
--------------------------------------------------------------------------------
/musicaiz/structure/__init__.py:
--------------------------------------------------------------------------------
"""
Structure
=========

This module provides objects and methods that allow creating and analyzing the
structural elements of symbolic music.

The basic structure elements are:

- Piece: The whole piece or MIDI file that contains lists of instruments, bars, notes...
It can also contain harmonic attributes like Key, Chord Progressions, etc., depending on
whether we want to predict or generate them.

- Instruments: A list of bars or, directly, notes (depending on whether we want to
distribute the notes in bars or not).

- Bar: List of notes (it can also contain a list of Chords).

- Notes: The basic element in music in both the time and harmonic axes.

Notes
------

.. autosummary::
    :toctree: generated/

    AccidentalsNames
    AccidentalsValues
    NoteClassNames
    NoteClassBase
    NoteValue
    NoteTiming
    Note

Instruments
-----------

.. autosummary::
    :toctree: generated/

    InstrumentMidiPrograms
    InstrumentMidiFamilies
    Instrument

Bars
----

.. autosummary::
    :toctree: generated/

    Bar
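
Examples
--------
Creating a note (a sketch; `start` and `end` are given in ticks, matching how
notes are built in this repo's tests):

>>> from musicaiz.structure import Note
>>> note = Note(pitch=60, start=0, end=96, velocity=100)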

"""

from .notes import (
    AccidentalsNames,
    AccidentalsValues,
    NoteClassNames,
    NoteClassBase,
    NoteValue,
    NoteTiming,
    Note,
)
from .instruments import (
    InstrumentMidiPrograms,
    InstrumentMidiFamilies,
    Instrument,
)
from .bars import Bar

__all__ = [
    "AccidentalsNames",
    "AccidentalsValues",
    "NoteClassNames",
    "NoteClassBase",
    "NoteValue",
    "NoteTiming",
    "Note",
    "InstrumentMidiPrograms",
    "InstrumentMidiFamilies",
    "Instrument",
    "Bar",
]
--------------------------------------------------------------------------------
/musicaiz/structure/bars.py:
--------------------------------------------------------------------------------
from typing import List, Union, Optional
import numpy as np


from musicaiz.rhythm import (
    TimingConsts,
    ms_per_bar,
    ms_per_tick,
    Timing,
)
from musicaiz.structure import Note


class Bar:

    """Defines a class to group notes in bars.

    Attributes
    ----------

    time_sig: str
        If we know the time signature in advance, we can initialize the Musa object
        with it. This will assume that the whole MIDI has the same time signature.

    bpm: int
        The tempo or bpm of the MIDI file. If this parameter is not initialized we
        assume 120 bpm with a resolution (sequencer ticks) of 960 ticks per quarter
        note, which means that a quarter note lasts 500 ms.

    resolution: int
        The pulses or ticks per quarter note (PPQ or TPQN). If this parameter is not
        initialized we assume a resolution (sequencer ticks) of 960 ticks.

    absolute_timing: bool
        Selects how note timing attributes are initialized when reading a MIDI file.
        If `absolute_timing` is True, notes will be written in absolute times.
        If `absolute_timing` is False, times will be relative to the bar start.
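
    Examples
    --------
    Creating an empty 4/4 bar (a sketch; `start` and `end` are assumed to be
    in the units expected by `Timing._initialize_timing_attributes`):

    >>> bar = Bar(start=0, end=3840, time_sig="4/4", bpm=120)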
    """

    def __init__(
        self,
        start: Optional[Union[int, float]] = None,
        end: Optional[Union[int, float]] = None,
        time_sig: str = TimingConsts.DEFAULT_TIME_SIGNATURE.value,
        bpm: int = TimingConsts.DEFAULT_BPM.value,
        resolution: int = TimingConsts.RESOLUTION.value,
        absolute_timing: bool = True,
    ):
        self.bpm = bpm
        self.time_sig = time_sig
        self.resolution = resolution
        self.absolute_timing = absolute_timing

        # The following attributes are set when loading a MIDI file
        # with Musa class
        self.note_density = None
        self.harmonic_density = None

        self.ms_tick = ms_per_tick(bpm, resolution)

        if start is not None and end is not None:
            timings = Timing._initialize_timing_attributes(start, end, self.ms_tick)

            self.start_ticks = timings["start_ticks"]
            self.end_ticks = timings["end_ticks"]
            self.start_sec = timings["start_sec"]
            self.end_sec = timings["end_sec"]
        else:
            self.start_ticks = None
            self.end_ticks = None
            self.start_sec = None
            self.end_sec = None

    def relative_notes_timing(self, bar_start: float):
        """The bar start is the value in ticks where the bar starts"""
        ms_tick = ms_per_tick(self.bpm, self.resolution)
        for note in self.notes:
            note.start_ticks = note.start_ticks - bar_start
            note.end_ticks = note.end_ticks - bar_start
            note.start_sec = note.start_ticks * ms_tick / 1000
            note.end_sec = note.end_ticks * ms_tick / 1000

    @staticmethod
    def get_last_note(note_seq: List[Note]) -> Note:
        """Get the note with the latest end time in note_seq."""
        last_note = None
        end_secs = 0
        for note in note_seq:
            if note.end_sec > end_secs:
                # Keep track of the latest end time seen so far.
                end_secs = note.end_sec
                last_note = note
        return last_note

    def get_bars_durations(
        self,
        note_seq: List[Note]
    ) -> List[float]:
        """
        Build array of bar durations.
        We suppose that the note_seq is in the same time signature.
        """
        last_note = self.get_last_note(note_seq)
        end_secs = last_note.end_sec
        # ms_per_bar returns milliseconds, so divide to get seconds.
        sec_measure = ms_per_bar(self.time_sig, self.bpm) / 1000
        bar_durations = np.arange(0, end_secs, sec_measure).tolist()
        if end_secs % sec_measure != 0:
            bar_durations.append(bar_durations[-1] + sec_measure)
        return bar_durations

    def get_total_bars(self, note_seq: List[Note]) -> int:
        # get_bars_durations needs the instance's time_sig and bpm, so this
        # helper (and group_notes_in_bars) are instance methods.
        return len(self.get_bars_durations(note_seq))

    def group_notes_in_bars(self, note_seq: List[Note]) -> List[List[Note]]:
        bars_durations = self.get_bars_durations(note_seq)
        bars = []
        prev_bar_sec = 0
        for bar_sec in bars_durations:
            for note in note_seq:
                if bar_sec >= note.end_sec and prev_bar_sec < note.end_sec:
                    bars.append(note)
            prev_bar_sec = bar_sec
        return bars

    def __repr__(self):
        if self.start_sec is not None:
            start_sec = round(self.start_sec, 2)
        else:
            start_sec = self.start_sec
        if self.end_sec is not None:
            end_sec = round(self.end_sec, 2)
        else:
            end_sec = self.end_sec

        return "Bar(time_signature={}, " \
               "note_density={}, " \
               "harmonic_density={} " \
               "start_ticks={} " \
               "end_ticks={} " \
               "start_sec={} " \
               "end_sec={})".format(
                   self.time_sig,
                   self.note_density,
                   self.harmonic_density,
                   self.start_ticks,
                   self.end_ticks,
                   start_sec,
                   end_sec
               )
--------------------------------------------------------------------------------
/musicaiz/tokenizers/__init__.py:
--------------------------------------------------------------------------------
"""
Tokenizers
==========

This module provides methods to encode symbolic music in order to train Sequence models.

The parent class is the EncodeBase class. The tokenizers are:

- MMMTokenizer: Multi-Track Music Machine tokenizer.

- REMITokenizer: REMI and REMI+ tokenizer.

- CPWordTokenizer: Compound Word tokenizer.


Basic Encoding
--------------

.. autosummary::
    :toctree: generated/

    EncodeBase
    TokenizerArguments


Multi-Track Music Machine
-------------------------

This submodule contains the implementation of the MMM encoding:

[1] Ens, J., & Pasquier, P. (2020).
Flexible generation with the multi-track music machine.
In Proceedings of the 21st International Society for Music Information Retrieval Conference, ISMIR.

.. autosummary::
    :toctree: generated/

    MMMTokenizerArguments
    MMMTokenizer


REMI and REMI+
--------------

This submodule contains the implementation of the REMI encoding:

[2] Huang, Y. S., & Yang, Y. H. (2020).
Pop music transformer: Beat-based modeling and generation of expressive pop piano compositions.
In Proceedings of the 28th ACM International Conference on Multimedia (pp. 1180-1188).

And REMI+ encoding:

[3] von Rutte, D., Biggio, L., Kilcher, Y. & Hofmann, T. (2023).
FIGARO: Controllable Music Generation using Learned and Expert Features.
ICLR 2023.

.. autosummary::
    :toctree: generated/

    REMITokenizerArguments
    REMITokenizer

Compound Word (CPWord)
----------------------

This submodule contains the implementation of the CPWord encoding:

[4] Hsiao, W. Y., Liu, J. Y., Yeh, Y. C & Yang, Y. H. (2021).
Compound Word Transformer: Learning to compose full-song music over dynamic directed hypergraphs.
In Proceedings of the AAAI Conference on Artificial Intelligence (Vol. 35, No. 1, pp. 178-186).

.. autosummary::
    :toctree: generated/

    CPWordTokenizerArguments
    CPWordTokenizer
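
Examples
--------
Building the arguments for the MMM tokenizer and extracting a vocabulary
(a sketch; only fields and methods used elsewhere in this repo appear, and
the dataset path is hypothetical):

>>> from musicaiz.tokenizers import MMMTokenizer, MMMTokenizerArguments
>>> args = MMMTokenizerArguments(
...     prev_tokens="",
...     windowing=False,
...     time_unit="SIXTEENTH",
...     num_programs=None,
...     shuffle_tracks=True,
...     track_density=False,
...     time_sig=True,
...     velocity=True,
...     tempo=False,
... )
>>> vocab = MMMTokenizer.get_vocabulary(dataset_path="datasets/tokenized")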
"""

from enum import Enum

from .encoder import (
    EncodeBase,
    TokenizerArguments,
)
from .mmm import (
    MMMTokenizer,
    MMMTokenizerArguments,
)
from .remi import (
    REMITokenizer,
    REMITokenizerArguments,
)
from .cpword import (
    CPWordTokenizerArguments,
    CPWordTokenizer,
)
from .one_hot import (
    OneHot,
)


TOKENIZER_ARGUMENTS = [
    MMMTokenizerArguments,
    REMITokenizerArguments,
    CPWordTokenizerArguments
]


class Tokenizers(Enum):
    MULTI_TRACK_MUSIC_MACHINE = ("MMM", MMMTokenizerArguments)
    REMI = ("REMI", REMITokenizerArguments)
    CPWORD = ("CPWORD", CPWordTokenizerArguments)

    @property
    def name(self):
        return self.value[0]

    @property
    def arg(self):
        return self.value[1]

    @staticmethod
    def names():
        return [t.value[0] for t in Tokenizers.__members__.values()]

    @staticmethod
    def args():
        return [t.value[1] for t in Tokenizers.__members__.values()]


__all__ = [
    "EncodeBase",
    "TokenizerArguments",
    "TOKENIZER_ARGUMENTS",
    "Tokenizers",
    "MMMTokenizerArguments",
    "MMMTokenizer",
    "REMITokenizerArguments",
    "REMITokenizer",
    "CPWordTokenizerArguments",
    "CPWordTokenizer",
    "OneHot"
]
--------------------------------------------------------------------------------
/musicaiz/tokenizers/one_hot.py:
--------------------------------------------------------------------------------
"""

OneHot.one_hot

"""
import numpy as np
from typing import List
import matplotlib.pyplot as plt


# Our modules
from musicaiz.structure import Note


class OneHot:

    # TODO: Initialize with time axis ticks or secs?
    """
    This class converts lists of notes into one hot (piano roll like) tensors.

    """

    @staticmethod
    def one_hot(
        notes: List[Note],
        min_pitch: int = 0,
        max_pitch: int = 127,
        time_axis: str = "ticks",
        step: int = 10,
        vel_one_hot: bool = True,
    ) -> np.ndarray:

        """
        Each pitch from `min_pitch` to `max_pitch` has a value of 0
        (if the pitch value is not being played in the current time step) or 1
        (if the pitch value is being played in the current time step).

        Parameters
        ----------

        notes: List[Note]
            The notes to convert.

        min_pitch: int
            The lowest pitch of the output tensor.

        max_pitch: int
            The highest pitch of the output tensor.

        time_axis: str
            The time units of the grid, "ticks" or "seconds".

        step: int
            The size of a time step.

        vel_one_hot: bool
            If True, active cells store the note velocity instead of 1.

        Raises
        ------
        ValueError
            if `time_axis` is not "ticks" or "seconds", or if `step` is larger
            than the total number of time steps.

        Examples
        --------
        Build a one hot tensor from a short note list (a sketch):

        >>> from musicaiz.structure import Note
        >>> notes = [Note(pitch=60, start=0, end=960, velocity=100)]
        >>> tensor = OneHot.one_hot(notes, time_axis="ticks", step=10)

        """

        if time_axis == "ticks":
            max_time = int(notes[-1].end_ticks / step)
            loop_step = step
        elif time_axis == "seconds":
            max_time = int(notes[-1].end_sec / step)
            loop_step = int(step * 1000)  # step in for loop must be in ms
        else:
            raise ValueError("Not a valid axis. Axis must be 'seconds' or 'ticks'.")

        if step > max_time:
            raise ValueError(f"Step value {step} must be smaller than the total time steps {max_time}.")

        # Initialize one hot array
        pitch_range = max_pitch - min_pitch
        one_hot = np.zeros((pitch_range, max_time))
        for time_step in range(0, max_time, loop_step):
            for note in notes:
                if time_axis == "ticks":
                    note_start = note.start_ticks
                    note_end = note.end_ticks
                elif time_axis == "seconds":
                    note_start = note.start_sec * 1000
                    note_end = note.end_sec * 1000

                if (note_start <= time_step <= note_end) and (min_pitch <= note.pitch <= max_pitch):
                    if vel_one_hot:
                        one_hot[note.pitch - min_pitch, time_step] = note.velocity
                    else:
                        one_hot[note.pitch - min_pitch, time_step] = 1
        return one_hot

    @staticmethod
    def notes_to_one_hot():
        """Same as before but with each instrument in the 3rd axis."""
        pass

    # TODO: Move this to plotters module?
    @staticmethod
    def plot(one_hot_tensor: np.ndarray):
        plt.subplots(figsize=(20, 5))
        aspect = int(one_hot_tensor.shape[1] / one_hot_tensor.shape[0] / 10)
        if aspect <= 0:
            aspect = 1
        plt.imshow(one_hot_tensor, origin="lower", aspect=aspect)
        plt.xlabel("Time")
        plt.ylabel("Pitch")
--------------------------------------------------------------------------------
/musicaiz/utils.py:
--------------------------------------------------------------------------------
from typing import List, Union, TextIO, Optional
from pathlib import Path


from musicaiz.structure import Note
from musicaiz.loaders import Musa
from musicaiz.rhythm import get_subdivisions, TimingConsts


def get_list_files_path(path: Union[Path, str]) -> List[str]:
    """Returns a list of all MIDI files in a directory as strings."""
    if isinstance(path, str):
        path = Path(path)
    files = list(path.rglob('*.mid'))
    files.extend(list(path.rglob('*.midi')))

    files_str = [str(f) for f in files]
    return files_str


def sort_notes(note_seq: List[Note]) -> List[Note]:
    """Sorts a list of notes by the start_ticks notes attribute."""
    note_seq.sort(key=lambda x: x.start_ticks, reverse=False)
    return note_seq


def group_notes_in_subdivisions_bars(musa_obj: Musa, subdiv: str) -> List[List[List[Note]]]:
    """This function groups notes in the selected subdivision.
    The result is a list whose elements are lists that represent the bars,
    and inside them, lists that represent the notes in each subdivision.

    Parameters
    ----------
    musa_obj: Musa
        A Musa object initialized with the argument `structure="bars"`.

    Returns
    -------
    all_subdiv_notes: List[List[List[Note]]]
        A list of bars in which each element is a subdivision, which is a list of the
        notes that are in the subdivision.
        Ex.: For 4 bars at 4/4 and an 8th note as subdivision, we'll have a list of 4
        items, being each item a list of 8 elements which are the 8th notes that are
        in each 4/4 bar. Inside the subdivisions list, we'll find the notes that
        belong to the subdivision.
    """
    # Group notes in bars (no instruments)
    bars = Musa.group_instrument_bar_notes(musa_obj)
    # 1. Sort midi notes in all the bars
    sorted_bars = [sort_notes(b) for b in bars]
    # Retain the highest note at a time frame (1 16th note)
    grid = get_subdivisions(
        total_bars=len(sorted_bars),
        subdivision=subdiv,
        time_sig=musa_obj.time_sig.time_sig,
        bpm=musa_obj.bpm,
        resolution=musa_obj.resolution,
        absolute_timing=musa_obj.absolute_timing
    )
    step_ticks = grid[1]["ticks"]
    bar_grid = [g for g in grid if g["bar"] == 1]

    # Group notes in subdivisions
    all_subdiv_notes = []  # List with N items = N bars
    new_step_ticks = 0
    for b, bar in enumerate(sorted_bars):
        bar_notes = []  # List with N items = N subdivisions per bar
        for s, subd in enumerate(bar_grid):
            subdiv_notes = []  # List of notes in the current subdivision
            start = new_step_ticks
            end = new_step_ticks + step_ticks
            for i, note in enumerate(bar):
                if note.start_ticks <= start and note.end_ticks >= end:
                    subdiv_notes.append(note)
            new_step_ticks += step_ticks
            bar_notes.append(subdiv_notes)
        all_subdiv_notes.append(bar_notes)
    return all_subdiv_notes


def get_highest_subdivision_bars_notes(
    all_subdiv_notes: List[List[List[Note]]]
) -> List[List[Note]]:
    """Extracts the highest note in each subdivision.

    Parameters
    ----------
    all_subdiv_notes: List[List[List[Note]]]
        A list of bars in which each element is a subdivision, which is a list of the
        notes that are in the subdivision.

    Returns
    -------
    bar_highest_subdiv_notes: List[List[Note]]
        A list of bars in which each element in the bar corresponds to the note with
        the highest pitch in the subdivision.
    """
    # Retrieve the note with the highest pitch in each subdivision of every bar
    bar_highest_subdiv_notes = []
    for b, bar in enumerate(all_subdiv_notes):
        highest_subdiv_notes = []  # List of N items = each item is the highest note in the subdiv
        for s, subdiv in enumerate(bar):
            # If there are no notes in a subdivision, fill it with a note whose pitch is 0
            if len(subdiv) == 0:
                highest_subdiv_notes.append(Note(pitch=0, start=0, end=1, velocity=0))
                continue
            # Keep track of the highest pitch seen so far and append only the
            # single highest note of the subdivision.
            prev_pitch = 0
            highest_note = None
            for n, note in enumerate(subdiv):
                if note.pitch > prev_pitch:
                    highest_note = note
                    prev_pitch = note.pitch
            highest_subdiv_notes.append(highest_note)
        bar_highest_subdiv_notes.append(highest_subdiv_notes)
    return bar_highest_subdiv_notes


def __initialization(
    file: Union[Musa, str, TextIO, Path],
    structure: str = "bars",
    quantize: bool = True,
    quantize_note: Optional[str] = "sixteenth",
    tonality: Optional[str] = None,
    time_sig: str = TimingConsts.DEFAULT_TIME_SIGNATURE.value,
    bpm: int = TimingConsts.DEFAULT_BPM.value,
    resolution: int = TimingConsts.RESOLUTION.value,
    absolute_timing: bool = True
) -> Musa:
    """
    If a Musa object is given as `file`, it is returned as is; otherwise the
    file is loaded into a new Musa object with the given arguments.
    """
    if isinstance(file, Musa):
        musa_obj = file
    elif isinstance(file, str) or isinstance(file, Path) or isinstance(file, TextIO):
        musa_obj = Musa(
            file=file,
            structure=structure,
            quantize=quantize,
            quantize_note=quantize_note,
            tonality=tonality,
            time_sig=time_sig,
            bpm=bpm,
            resolution=resolution,
            absolute_timing=absolute_timing
        )
    else:
        raise ValueError("You must pass a Musa object or a file to initialize this class.")
    return musa_obj
--------------------------------------------------------------------------------
/musicaiz/version.py:
--------------------------------------------------------------------------------
MAJOR = 0
MINOR = 1
PATCH = 2

__version__ = "%d.%d.%d" % (MAJOR, MINOR, PATCH)
--------------------------------------------------------------------------------
/musicaiz/wrappers.py:
--------------------------------------------------------------------------------
import multiprocessing
from itertools import repeat
import time
from typing import Union, List, Any
from pathlib import Path


from musicaiz import utils


def multiprocess_path(
    func,
    path: Union[List[str], Path],
    args: Union[List[Any], None] = None,
    n_jobs=None
) -> List[Any]:
    if n_jobs is None:
        n_jobs = multiprocessing.cpu_count()
    pool = multiprocessing.Pool(n_jobs)

    if isinstance(path, list):
        files_str = path
    else:
        files_str = utils.get_list_files_path(path)

    if args is not None:
        # Unpack the repeated args so each worker call receives
        # (file, arg_1, ..., arg_n) instead of (file, repeat_object).
        args_it = [repeat(arg) for arg in args]
        results = pool.starmap(
            func,
            zip(files_str, *args_it)
        )
    else:
        results = pool.starmap(
            func,
            zip(files_str)
        )
    pool.close()
    pool.join()
    return results


def timeis(func):

    def wrap(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        end = time.time()

        print(f"Processing time for method {func.__name__}: {end-start} sec")
        return result
    return wrap
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[build-system]
requires = [
    "setuptools >= 48",
    "wheel >= 0.29.0",
]
build-backend = 'setuptools.build_meta'
--------------------------------------------------------------------------------
/requirements-docs.txt:
--------------------------------------------------------------------------------
sphinx==4.3.2
sphinx-gallery==0.10.1
numpydoc==1.1.0
sphinxcontrib-svg2pdfconverter==1.2.0
sphinx-multiversion==0.2.4
sphinx-rtd-theme==1.0.0
sphinx-design==0.1.0
sphinx_panels==0.6.0
nbsphinx==0.8.8
pandoc==2.2
matplotlib>=3.4.3
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
pretty_midi==0.2.9
pytest==6.2.4
pytest-cov
pytest-mpl
numpy==1.21.2
pandas==1.4.4
matplotlib>=3.4.3
seaborn
coverage==6.2
pre-commit
tqdm
plotly==5.8.0
mypy==0.960
protobuf==4.21.3
rich==12.6.0
ruptures==1.1.7

# For models submodule
networkx==2.8.6
sklearn==0.0
gradio==3.0.15
torchsummary==1.5.1
prettytable==3.3.0
torch==1.11.0 --extra-index-url 
https://download.pytorch.org/whl/cu113 25 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [pycodestyle] 2 | max-line-length=100 3 | ignore = E203,W503 4 | 5 | [tool:pytest] 6 | addopts = --cov-report term-missing --cov musicaiz --cov-report=xml --disable-pytest-warnings --mpl --mpl-baseline-path=tests/baseline_images/test_display 7 | xfail_strict = true 8 | 9 | [metadata] 10 | name = musicaiz 11 | version = attr: musicaiz.version.__version__ 12 | description = A python framework for symbolic music generation, evaluation and analysis 13 | long_description = file: README.md 14 | long_description_content_type = text/markdown; charset=UTF-8 15 | url = https://carlosholivan.github.io/musicaiz 16 | author = Carlos Hernandez-Olivan 17 | author_email = carlosher@unizar.es 18 | license = ISC 19 | license_file = LICENSE.md 20 | license_file_content_type = text/markdown; charset=UTF-8 21 | project_urls = 22 | Documentation = https://carlosholivan.github.io/musicaiz/docs 23 | Download = https://github.com/carlosholivan/musicaiz/releases 24 | Source = https://github.com/carlosholivan/musicaiz 25 | Tracker = https://github.com/carlosholivan/musicaiz/issues 26 | #Discussion forum = https://groups.google.com/g/musicaiz 27 | classifiers = 28 | License :: OSI Approved :: GNU Affero General Public License v3 29 | Programming Language :: Python 30 | Development Status :: 3 - Alpha 31 | Intended Audience :: Developers 32 | Topic :: Multimedia :: Sound/Audio :: Analysis 33 | Framework :: Matplotlib 34 | Programming Language :: Python :: 3 35 | Programming Language :: Python :: 3.8 36 | Programming Language :: Python :: 3.9 37 | 38 | [options] 39 | packages = find: 40 | include_package_data = True 41 | install_requires = 42 | pretty_midi==0.2.9 43 | numpy==1.21.2 44 | pandas==1.4.4 45 | pytest==6.2.4 46 | matplotlib>=3.4.3 47 | plotly==5.8.0 48 | mypy==0.960 49 | seaborn==0.11.2 50 | pre-commit==2.19.0 51 | tqdm==4.64.0 52 | networkx==2.8.6 53 | sklearn==0.0 54 | gradio==3.0.15 55 | torchsummary==1.5.1 56 | prettytable==3.3.0 57 | torch==1.11.0 58 | protobuf==4.21.3 59 | rich==12.6.0 60 | ruptures==1.1.7 61 | python_requires >=3.8 62 | 63 | [options.extras_require] 64 | docs = 65 | sphinx_rtd_theme==0.5.* 66 | sphinx-design==0.1.0 67 | sphinx_panels==0.6.0 68 | nbsphinx==0.8.8 69 | sphinx != 1.3.1 70 | numpydoc==1.1.0 71 | matplotlib >= 3.3.0 72 | sphinx-multiversion >= 0.2.3 73 | sphinx-gallery >= 0.7 74 | ipython >= 7.0 75 | sphinxcontrib-svg2pdfconverter 76 | presets 77 | sphinx_book_theme==0.3.2 78 | nbconvert 79 | tests = 80 | pytest-mpl 81 | pytest-cov 82 | pytest 83 | display = 84 | matplotlib >= 3.3.0 85 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | if __name__ == '__main__': 4 | setup() 5 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carlosholivan/musicaiz/70f95854a3777b0323ed47f5a0822cf71eb96a70/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import pytest 
3 | 4 | 5 | @pytest.fixture 6 | def fixture_dir(): 7 | yield Path("./tests/fixtures/") 8 | -------------------------------------------------------------------------------- /tests/fixtures/datasets/jsbchorales/test/14.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carlosholivan/musicaiz/70f95854a3777b0323ed47f5a0822cf71eb96a70/tests/fixtures/datasets/jsbchorales/test/14.mid -------------------------------------------------------------------------------- /tests/fixtures/datasets/jsbchorales/test/27.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carlosholivan/musicaiz/70f95854a3777b0323ed47f5a0822cf71eb96a70/tests/fixtures/datasets/jsbchorales/test/27.mid -------------------------------------------------------------------------------- /tests/fixtures/datasets/jsbchorales/test/9.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carlosholivan/musicaiz/70f95854a3777b0323ed47f5a0822cf71eb96a70/tests/fixtures/datasets/jsbchorales/test/9.mid -------------------------------------------------------------------------------- /tests/fixtures/datasets/jsbchorales/train/2.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carlosholivan/musicaiz/70f95854a3777b0323ed47f5a0822cf71eb96a70/tests/fixtures/datasets/jsbchorales/train/2.mid -------------------------------------------------------------------------------- /tests/fixtures/datasets/jsbchorales/train/3.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carlosholivan/musicaiz/70f95854a3777b0323ed47f5a0822cf71eb96a70/tests/fixtures/datasets/jsbchorales/train/3.mid -------------------------------------------------------------------------------- /tests/fixtures/datasets/jsbchorales/train/4.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carlosholivan/musicaiz/70f95854a3777b0323ed47f5a0822cf71eb96a70/tests/fixtures/datasets/jsbchorales/train/4.mid -------------------------------------------------------------------------------- /tests/fixtures/datasets/jsbchorales/valid/1.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carlosholivan/musicaiz/70f95854a3777b0323ed47f5a0822cf71eb96a70/tests/fixtures/datasets/jsbchorales/valid/1.mid -------------------------------------------------------------------------------- /tests/fixtures/datasets/lmd/ABBA/Andante, Andante.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carlosholivan/musicaiz/70f95854a3777b0323ed47f5a0822cf71eb96a70/tests/fixtures/datasets/lmd/ABBA/Andante, Andante.mid -------------------------------------------------------------------------------- /tests/fixtures/datasets/maestro/maestro-v2.0.0.csv: -------------------------------------------------------------------------------- 1 | canonical_composer,canonical_title,split,year,midi_filename,audio_filename,duration 2 | Alban Berg,Sonata Op. 
1,train,2018,2018/MIDI-Unprocessed_Chamber3_MID--AUDIO_10_R3_2018_wav--1.midi,2018/MIDI-Unprocessed_Chamber3_MID--AUDIO_10_R3_2018_wav--1.wav,698.661160312 3 | -------------------------------------------------------------------------------- /tests/fixtures/datasets/maestro/maestro-v2.0.0/2018/MIDI-Unprocessed_Chamber3_MID--AUDIO_10_R3_2018_wav--1.midi: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carlosholivan/musicaiz/70f95854a3777b0323ed47f5a0822cf71eb96a70/tests/fixtures/datasets/maestro/maestro-v2.0.0/2018/MIDI-Unprocessed_Chamber3_MID--AUDIO_10_R3_2018_wav--1.midi -------------------------------------------------------------------------------- /tests/fixtures/datasets/maestro/maestro-v2.0.0/maestro-v2.0.0.csv: -------------------------------------------------------------------------------- 1 | canonical_composer,canonical_title,split,year,midi_filename,audio_filename,duration 2 | Alban Berg,Sonata Op. 1,train,2018,2018/MIDI-Unprocessed_Chamber3_MID--AUDIO_10_R3_2018_wav--1.midi,2018/MIDI-Unprocessed_Chamber3_MID--AUDIO_10_R3_2018_wav--1.wav,698.661160312 3 | -------------------------------------------------------------------------------- /tests/fixtures/midis/midi_changes.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carlosholivan/musicaiz/70f95854a3777b0323ed47f5a0822cf71eb96a70/tests/fixtures/midis/midi_changes.mid -------------------------------------------------------------------------------- /tests/fixtures/midis/midi_data.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carlosholivan/musicaiz/70f95854a3777b0323ed47f5a0822cf71eb96a70/tests/fixtures/midis/midi_data.mid -------------------------------------------------------------------------------- /tests/fixtures/tokenizers/cpword_tokens.txt: -------------------------------------------------------------------------------- 1 | FAMILY=METRIC BAR=0 PITCH=NONE VELOCITY=NONE DURATION=NONE PROGRAM=NONE TEMPO=NONE TIME_SIG=4/4 FAMILY=METRIC POSITION=4 PITCH=NONE VELOCITY=NONE DURATION=NONE PROGRAM=NONE TEMPO=120 TIME_SIG=NONE FAMILY=NOTE POSITION=NONE PITCH=69 VELOCITY=127 DURATION=4 PROGRAM=30 TEMPO=NONE TIME_SIG=NONE FAMILY=NOTE POSITION=NONE PITCH=64 VELOCITY=127 DURATION=8 PROGRAM=30 TEMPO=NONE TIME_SIG=NONE FAMILY=METRIC POSITION=8 PITCH=NONE VELOCITY=NONE DURATION=NONE PROGRAM=NONE TEMPO=120 TIME_SIG=NONE FAMILY=NOTE POSITION=NONE PITCH=67 VELOCITY=127 DURATION=4 PROGRAM=30 TEMPO=NONE TIME_SIG=NONE FAMILY=METRIC POSITION=12 PITCH=NONE VELOCITY=NONE DURATION=NONE PROGRAM=NONE TEMPO=120 TIME_SIG=NONE FAMILY=NOTE POSITION=NONE PITCH=64 VELOCITY=127 DURATION=4 PROGRAM=30 TEMPO=NONE TIME_SIG=NONE FAMILY=METRIC BAR=2 PITCH=NONE VELOCITY=NONE DURATION=NONE PROGRAM=NONE TEMPO=NONE TIME_SIG=4/4 FAMILY=METRIC POSITION=32 PITCH=NONE VELOCITY=NONE DURATION=NONE PROGRAM=NONE TEMPO=120 TIME_SIG=NONE FAMILY=NOTE POSITION=NONE PITCH=72 VELOCITY=127 DURATION=4 PROGRAM=30 TEMPO=NONE TIME_SIG=NONE FAMILY=METRIC POSITION=36 PITCH=NONE VELOCITY=NONE DURATION=NONE PROGRAM=NONE TEMPO=120 TIME_SIG=NONE FAMILY=NOTE POSITION=NONE PITCH=69 VELOCITY=127 DURATION=4 PROGRAM=30 TEMPO=NONE TIME_SIG=NONE FAMILY=METRIC POSITION=44 PITCH=NONE VELOCITY=NONE DURATION=NONE PROGRAM=NONE TEMPO=120 TIME_SIG=NONE FAMILY=NOTE POSITION=NONE PITCH=67 VELOCITY=127 DURATION=2 PROGRAM=30 TEMPO=NONE TIME_SIG=NONE 
-------------------------------------------------------------------------------- /tests/fixtures/tokenizers/mmm_tokens.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carlosholivan/musicaiz/70f95854a3777b0323ed47f5a0822cf71eb96a70/tests/fixtures/tokenizers/mmm_tokens.mid -------------------------------------------------------------------------------- /tests/fixtures/tokenizers/mmm_tokens.txt: -------------------------------------------------------------------------------- 1 | PIECE_START TRACK_START INST=30 BAR_START TIME_DELTA=4 NOTE_ON=69 NOTE_ON=64 TIME_DELTA=4 NOTE_OFF=69 NOTE_ON=67 TIME_DELTA=4 NOTE_OFF=64 NOTE_OFF=67 NOTE_ON=64 TIME_DELTA=4 NOTE_OFF=64 BAR_END BAR_START TIME_DELTA=16 BAR_END BAR_START NOTE_ON=72 TIME_DELTA=4 NOTE_OFF=72 NOTE_ON=69 TIME_DELTA=4 NOTE_OFF=69 TIME_DELTA=4 NOTE_ON=67 TIME_DELTA=2 NOTE_OFF=67 TIME_DELTA=2 BAR_END TRACK_END -------------------------------------------------------------------------------- /tests/fixtures/tokenizers/remi_tokens.txt: -------------------------------------------------------------------------------- 1 | BAR=0 TIME_SIG=4/4 SUB_BEAT=4 TEMPO=120 INST=30 PITCH=69 DUR=4 VELOCITY=127 PITCH=64 DUR=8 VELOCITY=127 SUB_BEAT=8 PITCH=67 DUR=4 VELOCITY=127 SUB_BEAT=12 PITCH=64 DUR=4 VELOCITY=127 BAR=2 TIME_SIG=4/4 SUB_BEAT=0 PITCH=72 DUR=4 VELOCITY=127 SUB_BEAT=4 PITCH=69 DUR=4 VELOCITY=127 SUB_BEAT=12 PITCH=67 DUR=2 VELOCITY=127 -------------------------------------------------------------------------------- /tests/unit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carlosholivan/musicaiz/70f95854a3777b0323ed47f5a0822cf71eb96a70/tests/unit/__init__.py -------------------------------------------------------------------------------- /tests/unit/musicaiz/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carlosholivan/musicaiz/70f95854a3777b0323ed47f5a0822cf71eb96a70/tests/unit/musicaiz/__init__.py -------------------------------------------------------------------------------- /tests/unit/musicaiz/algorithms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carlosholivan/musicaiz/70f95854a3777b0323ed47f5a0822cf71eb96a70/tests/unit/musicaiz/algorithms/__init__.py -------------------------------------------------------------------------------- /tests/unit/musicaiz/algorithms/test_chord_prediction.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from musicaiz.loaders import Musa 4 | from musicaiz.algorithms import predict_chords 5 | 6 | 7 | @pytest.fixture 8 | def midi_sample(fixture_dir): 9 | return fixture_dir / "midis" / "midi_data.mid" 10 | 11 | 12 | def test_predict_chords(midi_sample): 13 | # Import MIDI file 14 | midi = Musa(midi_sample) 15 | 16 | got = predict_chords(midi) 17 | assert len(got) == len(midi.beats) 18 | for i in got: 19 | assert len(i) != 0 20 | -------------------------------------------------------------------------------- /tests/unit/musicaiz/algorithms/test_harmonic_shift.py: -------------------------------------------------------------------------------- 1 | from musicaiz.structure import Note 2 | from musicaiz.algorithms import scale_change 3 | 4 | def test_scale_change_a(): 5 | # Test case: All the origin_notes belong to the origin_tonality and scale 6 | 
origin_notes = [
        Note(pitch=43, start=0, end=96, velocity=82),  # G
        Note(pitch=59, start=96, end=96*2, velocity=82),  # B
        Note(pitch=60, start=96*2, end=96*3, velocity=82),  # C
        Note(pitch=72, start=96*4, end=96*5, velocity=44)  # C
    ]
    origin_tonality = "C_MAJOR"
    origin_scale = "MAJOR"
    target_tonality = "G_MAJOR"
    target_scale = "MAJOR"
    correction = True

    expected_notes = [
        Note(pitch=38, start=0, end=96, velocity=82),  # D
        Note(pitch=54, start=96, end=96*2, velocity=82),  # F#
        Note(pitch=67, start=96*2, end=96*3, velocity=82),  # G
        Note(pitch=79, start=96*4, end=96*5, velocity=44)  # G
    ]

    got_notes = scale_change(
        origin_notes,
        origin_tonality,
        origin_scale,
        target_tonality,
        target_scale,
        correction
    )

    for i, note in enumerate(expected_notes):
        assert expected_notes[i].pitch == got_notes[i].pitch


def test_scale_change_b():
    # Test case: Some origin_notes do not belong to the origin_tonality and scale.
    # Correction applied.
    origin_notes = [
        Note(pitch=44, start=0, end=96, velocity=82),  # G#
    ]
    origin_tonality = "C_MAJOR"
    origin_scale = "MAJOR"
    target_tonality = "G_MAJOR"
    target_scale = "MAJOR"
    correction = True

    expected_notes = [
        Note(pitch=40, start=0, end=96, velocity=82),  # E
    ]

    got_notes = scale_change(
        origin_notes,
        origin_tonality,
        origin_scale,
        target_tonality,
        target_scale,
        correction
    )

    for i, note in enumerate(expected_notes):
        assert expected_notes[i].pitch == got_notes[i].pitch


def test_scale_change_c():
    # Test case: Some origin_notes do not belong to the origin_tonality and scale.
    # No correction applied.
    origin_notes = [
        Note(pitch=44, start=0, end=96, velocity=82),  # G#
    ]
    origin_tonality = "C_MAJOR"
    origin_scale = "MAJOR"
    target_tonality = "G_MAJOR"
    target_scale = "MAJOR"
    correction = False

    expected_notes = [
        Note(pitch=39, start=0, end=96, velocity=82),  # D#
    ]

    got_notes = scale_change(
        origin_notes,
        origin_tonality,
        origin_scale,
        target_tonality,
        target_scale,
        correction
    )

    for i, note in enumerate(expected_notes):
        assert expected_notes[i].pitch == got_notes[i].pitch
--------------------------------------------------------------------------------
/tests/unit/musicaiz/algorithms/test_key_profiles.py:
--------------------------------------------------------------------------------
import pytest
import math

from musicaiz.structure import Note
from musicaiz.algorithms import (
    _keys_correlations,
    key_detection
)


@pytest.fixture
def durations():
    return {
        "C": 432,
        "C_SHARP": 231,
        "D": 0,
        "D_SHARP": 405,
        "E": 12,
        "F": 316,
        "F_SHARP": 4,
        "G": 126,
        "G_SHARP": 612,
        "A": 0,
        "A_SHARP": 191,
        "B": 1,
    }


@pytest.fixture
def expected_k_k_corr():
    # from http://rnhart.net/articles/key-finding/
    return {
        "C_MAJOR": -0.00009,
        "C_MINOR": 0.622,
        "C_SHARP_MAJOR": 0.538,
        "C_SHARP_MINOR": 0.094,
        "D_MAJOR": -0.741,
        "D_MINOR": -0.313,
        "D_SHARP_MAJOR": 0.579,
        "D_SHARP_MINOR": 0.152,
        "E_MAJOR": -0.269,
        "E_MINOR": -0.4786,
        "F_MAJOR": 0.101,
        "F_MINOR": 0.775,
        "F_SHARP_MAJOR": -0.043,
        "F_SHARP_MINOR": -0.469,
        "G_MAJOR": -0.464,
        "G_MINOR": -0.127,
        "G_SHARP_MAJOR": 0.970,
        "G_SHARP_MINOR": 0.391,
        "A_MAJOR": -0.582,
        "A_MINOR": -0.176,
        "A_SHARP_MAJOR": 0.113,
        "A_SHARP_MINOR": 0.250,
        "B_MAJOR": -0.201,
        "B_MINOR": -0.721,
    }


@pytest.fixture
def notes():
    return [
        Note(pitch=18, start=0.0, end=1.0, velocity=127),  # F#
        Note(pitch=16, start=1.0, end=2.0, velocity=127),  # E
        Note(pitch=18, start=2.0, end=4.0, velocity=127),  # F#
        Note(pitch=18, start=2.0, end=4.0, velocity=127),  # F#
    ]


def test_keys_correlations(durations, expected_k_k_corr):
    corr = _keys_correlations(durations, "k-k")
    assert len(corr.keys()) == len(expected_k_k_corr.keys())
    for got_corr_v, expected_corr_v in zip(corr.values(), expected_k_k_corr.values()):
        assert math.isclose(got_corr_v, expected_corr_v, abs_tol=0.001)


def test_key_detection_krumhansl(notes):
    # Test KrumhanslKessler with different num of notes
    expected_key_1_note = "F_SHARP_MAJOR"
    got_key_1_note = key_detection(notes[0:1], "k-k")
    assert expected_key_1_note == got_key_1_note

    expected_key_2_notes = "E_MAJOR"
    got_key_2_notes = key_detection(notes[0:2], "k-k")
    assert expected_key_2_notes == got_key_2_notes

    expected_key_all_notes = "F_SHARP_MINOR"
    got_key_all_notes = key_detection(notes, "k-k")
    assert expected_key_all_notes == got_key_all_notes


def test_key_detection_temperley(notes):
    # Test Temperley with different num of notes
    expected_key_1_note = "B_MINOR"
    got_key_1_note = key_detection(notes[0:1], "temperley")
    assert expected_key_1_note == got_key_1_note

    expected_key_2_notes = "B_MINOR"
    got_key_2_notes = 
key_detection(notes[0:2], "temperley")
    assert expected_key_2_notes == got_key_2_notes

    expected_key_all_notes = "B_MINOR"
    got_key_all_notes = key_detection(notes, "temperley")
    assert expected_key_all_notes == got_key_all_notes


def test_key_detection_albrecht_shanahan(notes):
    # Test AlbrechtShanahan with different num of notes
    expected_key_1_note = "F_SHARP_MAJOR"
    got_key_1_note = key_detection(notes[0:1], "a-s")
    assert expected_key_1_note == got_key_1_note

    expected_key_2_notes = "E_MAJOR"
    got_key_2_notes = key_detection(notes[0:2], "a-s")
    assert expected_key_2_notes == got_key_2_notes

    expected_key_all_notes = "B_MINOR"
    got_key_all_notes = key_detection(notes, "a-s")
    assert expected_key_all_notes == got_key_all_notes


def test_key_detection_signature_fifths(notes):
    # Test SignatureFifths with different num of notes
    expected_key_1_note = "D_MAJOR"
    got_key_1_note = key_detection(notes[0:1], "5ths")
    assert expected_key_1_note == got_key_1_note

    expected_key_2_notes = "D_MAJOR"
    got_key_2_notes = key_detection(notes[0:2], "5ths")
    assert expected_key_2_notes == got_key_2_notes

    expected_key_all_notes = "D_MAJOR"
    got_key_all_notes = key_detection(notes, "5ths")
    assert expected_key_all_notes == got_key_all_notes
--------------------------------------------------------------------------------
/tests/unit/musicaiz/converters/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/carlosholivan/musicaiz/70f95854a3777b0323ed47f5a0822cf71eb96a70/tests/unit/musicaiz/converters/__init__.py
--------------------------------------------------------------------------------
/tests/unit/musicaiz/converters/test_musa_json.py:
--------------------------------------------------------------------------------
from musicaiz.converters import (
    MusaJSON,
    BarJSON,
    InstrumentJSON,
    NoteJSON
)
from musicaiz.loaders import Musa
from .test_musa_to_protobuf import midi_sample


def test_MusaJSON(midi_sample):
    midi = Musa(
        midi_sample,
    )
    got = MusaJSON.to_json(midi)

    assert got["tonality"] is None
    assert got["resolution"] == 480
    assert len(got["instruments"]) == 2
    assert len(got["bars"]) == 3
    assert len(got["notes"]) == 37

    for inst in got["instruments"]:
        assert set(inst.keys()) == set(InstrumentJSON.__dataclass_fields__.keys())
    for bar in got["bars"]:
        assert set(bar.keys()) == set(BarJSON.__dataclass_fields__.keys())
    for note in got["notes"]:
        assert set(note.keys()) == set(NoteJSON.__dataclass_fields__.keys())

# TODO
#def test_JSONMusa(midi_sample, midi_data):
--------------------------------------------------------------------------------
/tests/unit/musicaiz/converters/test_musa_to_protobuf.py:
--------------------------------------------------------------------------------
import pytest

from musicaiz.converters import musa_to_proto, proto_to_musa
from musicaiz.loaders import Musa


@pytest.fixture
def midi_sample(fixture_dir):
    return fixture_dir / "midis" / "midi_data.mid"


@pytest.fixture
def midi_data():
    return {
        "expected_instruments": 2,
        "expected_instrument_name_1": "Piano right",
        "expected_instrument_name_2": "Piano 
left", 18 | } 19 | 20 | 21 | def _assert_midi_valid_instr_obj(midi_data, instruments): 22 | # check instrs 23 | assert midi_data["expected_instruments"] == len(instruments) 24 | # check instrs names 25 | assert midi_data["expected_instrument_name_1"] == instruments[0].name 26 | assert midi_data["expected_instrument_name_2"] == instruments[1].name 27 | # check instrs is_drum 28 | assert instruments[0].is_drum is False 29 | assert instruments[1].is_drum is False 30 | 31 | 32 | def _assert_valid_note_obj(note): 33 | assert 0 <= note.pitch <= 128 34 | assert 0 <= note.velocity <= 128 35 | assert note.pitch_name != "" 36 | assert note.note_name != "" 37 | assert note.octave != "" 38 | assert note.symbolic != "" 39 | assert note.start_ticks >= 0 40 | assert note.end_ticks >= 0 41 | assert note.start_sec >= 0.0 42 | assert note.end_sec >= 0.0 43 | 44 | 45 | def test_musa_to_proto(midi_sample, midi_data): 46 | midi = Musa(midi_sample) 47 | got = musa_to_proto(midi) 48 | 49 | _assert_midi_valid_instr_obj(midi_data, got.instruments) 50 | 51 | # check bars 52 | assert len(got.instruments) != 0 53 | assert len(got.bars) != 0 54 | 55 | # check every bar attributes are not empty 56 | for i, bar in enumerate(got.bars): 57 | # check only the first 5 bars since the midi file is large 58 | if i < 5: 59 | assert bar.start_ticks >= 0 60 | assert bar.end_ticks >= 0 61 | assert bar.start_sec >= 0.0 62 | assert bar.end_sec >= 0.0 63 | for note in got.notes: 64 | _assert_valid_note_obj(note) 65 | 66 | 67 | def test_proto_to_musa(midi_sample, midi_data): 68 | midi = Musa(midi_sample) 69 | proto = musa_to_proto(midi) 70 | got = proto_to_musa(proto) 71 | 72 | _assert_midi_valid_instr_obj(midi_data, got.instruments) 73 | 74 | # check bars 75 | assert len(got.instruments) != 0 76 | 77 | # check every bar attributes are not empty 78 | for i, bar in enumerate(got.bars): 79 | # check only the first 5 bars since the midi file is large 80 | if i < 5: 81 | assert bar.start_ticks >= 0 82 | assert bar.end_ticks >= 0 83 | assert bar.start_sec >= 0.0 84 | assert bar.end_sec >= 0.0 85 | # check every note 86 | for note in got.notes: 87 | _assert_valid_note_obj(note) 88 | -------------------------------------------------------------------------------- /tests/unit/musicaiz/converters/test_pretty_midi_musa.py: -------------------------------------------------------------------------------- 1 | from musicaiz.converters import ( 2 | prettymidi_note_to_musicaiz, 3 | musicaiz_note_to_prettymidi, 4 | musa_to_prettymidi, 5 | ) 6 | from musicaiz.loaders import Musa 7 | from .test_musa_to_protobuf import midi_sample 8 | 9 | 10 | def test_prettymidi_note_to_musicaiz(): 11 | note = "G#4" 12 | expected_name = "G_SHARP" 13 | expected_octave = 4 14 | 15 | got_name, got_octave = prettymidi_note_to_musicaiz(note) 16 | 17 | assert got_name == expected_name 18 | assert got_octave == expected_octave 19 | 20 | 21 | def test_musicaiz_note_to_prettymidi(): 22 | note = "G_SHARP" 23 | octave = 4 24 | expected = "G#4" 25 | 26 | got = musicaiz_note_to_prettymidi(note, octave) 27 | 28 | assert got == expected 29 | 30 | 31 | def test_musa_to_prettymidi(midi_sample): 32 | midi = Musa(midi_sample) 33 | got = musa_to_prettymidi(midi) 34 | 35 | assert len(got.instruments) == 2 36 | 37 | for inst in got.instruments: 38 | assert len(inst.notes) != 0 39 | -------------------------------------------------------------------------------- /tests/unit/musicaiz/datasets/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/carlosholivan/musicaiz/70f95854a3777b0323ed47f5a0822cf71eb96a70/tests/unit/musicaiz/datasets/__init__.py
--------------------------------------------------------------------------------
/tests/unit/musicaiz/datasets/asserts.py:
--------------------------------------------------------------------------------
from pathlib import Path
import os
import tempfile
from musicaiz.tokenizers import MMMTokenizer


def _assert_tokenize(dataset_path, dataset, args):
    # create a temp output dir that will be deleted after the testing
    with tempfile.TemporaryDirectory() as output_path:
        # tokenize
        output_file = "token-sequences"
        dataset.tokenize(
            dataset_path=dataset_path,
            output_path=output_path,
            output_file=output_file,
            args=args,
            tokenize_split="all"
        )
        # save configs
        assert Path(output_path, "configs.json").is_file()

        # check that train, validation and test paths exist and contain a txt
        assert Path(output_path, "train", output_file + ".txt").is_file()
        assert Path(output_path, "validation", output_file + ".txt").is_file()
        assert Path(output_path, "test", output_file + ".txt").is_file()

        # check that the txt in the train path is not empty.
        # We don't check all 3 files (train, valid and test) since the fixture
        # datasets (especially maestro) do not contain all the files but only the train ones.
        assert os.path.getsize(Path(output_path, "train", output_file + ".txt")) > 0

        # get vocabulary and save it in `dataset_path`
        vocab = MMMTokenizer.get_vocabulary(
            dataset_path=output_path
        )
        assert len(vocab) != 0
--------------------------------------------------------------------------------
/tests/unit/musicaiz/datasets/test_bps_fh.py:
--------------------------------------------------------------------------------
from musicaiz.harmony import Tonality, AllChords
from musicaiz.datasets import BPSFH
from musicaiz.structure import NoteClassBase


def test_bpsfh_key_to_musicaiz():

    note = "A-"
    expected = Tonality.A_FLAT_MAJOR

    got = BPSFH.bpsfh_key_to_musicaiz(note)
    assert got == expected


def test_bpsfh_chord_quality_to_musicaiz():

    quality = "a"
    expected = AllChords.AUGMENTED_TRIAD

    got = BPSFH.bpsfh_chord_quality_to_musicaiz(quality)
    assert got == expected


def test_bpsfh_chord_to_musicaiz():

    note = "f"
    quality = "D7"
    degree = 5
    expected = (NoteClassBase.C, AllChords.DOMINANT_SEVENTH)

    got = BPSFH.bpsfh_chord_to_musicaiz(note, degree, quality)
    assert got == expected
--------------------------------------------------------------------------------
/tests/unit/musicaiz/datasets/test_jsbchorales.py:
--------------------------------------------------------------------------------
import pytest

from .asserts import _assert_tokenize
from musicaiz.datasets import JSBChorales
from musicaiz.tokenizers import MMMTokenizerArguments


@pytest.fixture
def dataset_path(fixture_dir):
    return fixture_dir / "datasets" / "jsbchorales"


def test_JSBChorales_tokenize(dataset_path):
    # initialize tokenizer args
    args = MMMTokenizerArguments(
        prev_tokens="",
        windowing=False,
        time_unit="SIXTEENTH",
        num_programs=None,
        shuffle_tracks=True,
        track_density=False,
        time_sig=True,
        velocity=True,
        tempo=False
25 | ) 26 | # initialize dataset 27 | dataset = JSBChorales() 28 | assert dataset.name == "jsb_chorales" 29 | _assert_tokenize(dataset_path, dataset, args) 30 | -------------------------------------------------------------------------------- /tests/unit/musicaiz/datasets/test_lmd.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from .asserts import _assert_tokenize 4 | from musicaiz.datasets import LakhMIDI 5 | from musicaiz.tokenizers import MMMTokenizerArguments 6 | 7 | 8 | @pytest.fixture 9 | def dataset_path(fixture_dir): 10 | return fixture_dir / "datasets" / "lmd" 11 | 12 | 13 | def test_LakhMIDI_get_metadata(dataset_path): 14 | 15 | expected = { 16 | "ABBA/Andante, Andante.mid": { 17 | "composer": "ABBA", 18 | "period": "", 19 | "genre": "", 20 | "split": "train", 21 | } 22 | } 23 | 24 | dataset = LakhMIDI() 25 | got = dataset.get_metadata(dataset_path) 26 | 27 | assert got.keys() == expected.keys() 28 | for got_v, exp_v in zip(got.values(), expected.values()): 29 | assert set(got_v.values()) == set(exp_v.values()) 30 | 31 | 32 | def test_LakhMIDI_tokenize(dataset_path): 33 | # initialize tokenizer args 34 | args = MMMTokenizerArguments( 35 | prev_tokens="", 36 | windowing=False, 37 | time_unit="SIXTEENTH", 38 | num_programs=None, 39 | shuffle_tracks=True, 40 | track_density=False, 41 | time_sig=True, 42 | velocity=True, 43 | tempo=False 44 | ) 45 | # initialize dataset 46 | dataset = LakhMIDI() 47 | assert dataset.name == "lakh_midi" 48 | _assert_tokenize(dataset_path, dataset, args) 49 | -------------------------------------------------------------------------------- /tests/unit/musicaiz/datasets/test_maestro.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from .asserts import _assert_tokenize 4 | from musicaiz.datasets import Maestro 5 | from musicaiz.tokenizers import MMMTokenizerArguments 6 | 7 | 8 | @pytest.fixture 9 | def dataset_path(fixture_dir): 10 | return fixture_dir / "datasets" / "maestro" 11 | 12 | 13 | def test_Maestro_get_metadata(dataset_path): 14 | 15 | expected = { 16 | "2018/MIDI-Unprocessed_Chamber3_MID--AUDIO_10_R3_2018_wav--1.midi": { 17 | "composer": "ALBAN_BERG", 18 | "period": "ROMANTICISM", 19 | "genre": "SONATA", 20 | "split": "train", 21 | } 22 | } 23 | 24 | dataset = Maestro() 25 | got = dataset.get_metadata(dataset_path) 26 | 27 | assert got.keys() == expected.keys() 28 | for got_v, exp_v in zip(got.values(), expected.values()): 29 | assert set(got_v.values()) == set(exp_v.values()) 30 | 31 | 32 | def test_Maestro_tokenize(dataset_path): 33 | # initialize tokenizer args 34 | args = MMMTokenizerArguments( 35 | prev_tokens="", 36 | windowing=False, 37 | time_unit="HUNDRED_TWENTY_EIGHT", 38 | num_programs=None, 39 | shuffle_tracks=True, 40 | track_density=False, 41 | time_sig=True, 42 | velocity=True, 43 | tempo=False 44 | ) 45 | # initialize dataset 46 | dataset = Maestro() 47 | assert dataset.name == "maestro" 48 | _assert_tokenize(dataset_path, dataset, args) 49 | -------------------------------------------------------------------------------- /tests/unit/musicaiz/eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carlosholivan/musicaiz/70f95854a3777b0323ed47f5a0822cf71eb96a70/tests/unit/musicaiz/eval/__init__.py -------------------------------------------------------------------------------- /tests/unit/musicaiz/eval/test_eval.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | import matplotlib.pyplot as plt 3 | 4 | from musicaiz import eval 5 | from musicaiz.eval import _DEFAULT_MEASURES 6 | 7 | 8 | @pytest.fixture 9 | def dataset_path(fixture_dir): 10 | return fixture_dir / "datasets" / "jsbchorales" / "train" 11 | 12 | 13 | @pytest.fixture 14 | def dataset2_path(fixture_dir): 15 | return fixture_dir / "datasets" / "jsbchorales" / "test" 16 | 17 | 18 | def test_get_all_dataset_measures(dataset_path): 19 | measures = eval.get_all_dataset_measures( 20 | dataset_path 21 | ) 22 | assert len(measures) != 0 23 | 24 | 25 | def test_get_average_dataset_measures(dataset_path): 26 | avgs = eval.get_average_dataset_measures( 27 | dataset_path 28 | ) 29 | assert set(avgs.keys()) == set(_DEFAULT_MEASURES) 30 | 31 | 32 | def test_get_distribution_all(dataset_path, dataset2_path): 33 | measures_1 = eval.get_all_dataset_measures( 34 | dataset_path 35 | ) 36 | measures_2 = eval.get_all_dataset_measures( 37 | dataset2_path 38 | ) 39 | 40 | # Compute the distances 41 | dataset_measures_dist = eval.euclidean_distance(measures_1) 42 | dataset2_measures_dist = eval.euclidean_distance(measures_2) 43 | inter_measures_dist = eval.euclidean_distance(measures_1, measures_2) 44 | 45 | keys = set(["0-1", "1-0", "1-2", "2-1", "0-2", "2-0"]) 46 | assert set(dataset_measures_dist.keys()) == keys 47 | assert set(dataset2_measures_dist.keys()) == keys 48 | assert set(inter_measures_dist.keys()) == keys 49 | 50 | for v in dataset_measures_dist.values(): 51 | assert set(v.keys()) == set(_DEFAULT_MEASURES) 52 | for v in dataset2_measures_dist.values(): 53 | assert set(v.keys()) == set(_DEFAULT_MEASURES) 54 | for v in inter_measures_dist.values(): 55 | assert set(v.keys()) == set(_DEFAULT_MEASURES) 56 | 57 | # Plot the distributions 58 | eval.get_distribution( 59 | (dataset_measures_dist, "Dataset 1 intra"), 60 | (dataset2_measures_dist, "Dataset 2 intra"), 61 | (inter_measures_dist, "Dataset 1 - Dataset 2 inter"), 62 | measure="all", 63 | show=False 64 | ) 65 | plt.close('all') 66 | 67 | eval.model_features_violinplot( 68 | dataset_measures_dist, dataset2_measures_dist, 69 | "Dataset 1", "Dataset 2", 70 | show=False 71 | ) 72 | plt.close('all') 73 | 74 | eval.plot_measures( 75 | dataset_measures_dist, dataset2_measures_dist, 76 | "Dataset 1", "Dataset 2", 77 | show=False 78 | ) 79 | plt.close('all') 80 | 81 | # Compute the overlapping area between the distributions 82 | ov_area = eval.compute_overlapped_area( 83 | dataset_measures_dist, dataset2_measures_dist, "PR" 84 | ) 85 | 86 | assert ov_area <= 1.0 87 | 88 | kld = eval.compute_kld( 89 | dataset_measures_dist, dataset2_measures_dist, "PR" 90 | ) 91 | # KL divergence should be non-negative 92 | assert kld >= 0.0 93 | -------------------------------------------------------------------------------- /tests/unit/musicaiz/features/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carlosholivan/musicaiz/70f95854a3777b0323ed47f5a0822cf71eb96a70/tests/unit/musicaiz/features/__init__.py -------------------------------------------------------------------------------- /tests/unit/musicaiz/features/test_graphs.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import matplotlib.pyplot as plt 4 | import networkx as nx 5 | 6 | from musicaiz.loaders import Musa 7 | from musicaiz.features import ( 8 | musa_to_graph, 9 | plot_graph, 10 | ) 11 | 12 | 13 | @pytest.fixture 14 | def midi_sample_2(fixture_dir): 15
| return fixture_dir / "midis" / "midi_data.mid" 16 | 17 | 18 | def test_musa_to_graph(midi_sample_2): 19 | musa_obj = Musa(midi_sample_2) 20 | graph = musa_to_graph(musa_obj) 21 | 22 | # n notes must be equal to n nodes 23 | assert len(musa_obj.notes) == len(graph.nodes) 24 | 25 | # adjacency matrix 26 | mat = nx.attr_matrix(graph)[0] 27 | 28 | # n notes must be equal to n nodes 29 | assert len(musa_obj.notes) == mat.shape[0] 30 | 31 | 32 | def test_plot_graph(midi_sample_2): 33 | musa_obj = Musa(midi_sample_2) 34 | graph = musa_to_graph(musa_obj) 35 | 36 | plot_graph(graph, show=False) 37 | plt.close("all") 38 | -------------------------------------------------------------------------------- /tests/unit/musicaiz/features/test_pitch.py: -------------------------------------------------------------------------------- 1 | # Our modules 2 | from musicaiz.features import ( 3 | get_highest_lowest_pitches, 4 | get_pitch_range, 5 | get_pitch_classes, 6 | get_note_density, 7 | ) 8 | from musicaiz.structure import ( 9 | Note 10 | ) 11 | 12 | 13 | # ===============PitchStatistics class Tests=========== 14 | # ===================================================== 15 | def test_get_highest_lowest_pitches_a(): 16 | notes = [ 17 | Note(pitch=75, start=0.0, end=1.0, velocity=127), 18 | Note(pitch=9, start=1.0, end=2.0, velocity=127), 19 | Note(pitch=127, start=1.2, end=1.6, velocity=127) 20 | ] 21 | expected_highest = 127 22 | expected_lowest = 9 23 | got_highest, got_lowest = get_highest_lowest_pitches(notes) 24 | assert got_highest == expected_highest 25 | assert got_lowest == expected_lowest 26 | 27 | 28 | def test_get_pitch_range(): 29 | notes = [ 30 | Note(pitch=75, start=0.0, end=1.0, velocity=127), 31 | Note(pitch=127, start=1.2, end=1.6, velocity=127), 32 | Note(pitch=17, start=1.9, end=2.0, velocity=127), 33 | Note(pitch=127, start=2.0, end=3.0, velocity=127) 34 | ] 35 | expected = 110 36 | got = get_pitch_range(notes) 37 | assert expected == got 38 | 39 | 40 | def test_get_pitch_classes(): 41 | notes = [ 42 | Note(pitch=75, start=0.0, end=1.0, velocity=127), 43 | Note(pitch=9, start=1.0, end=2.0, velocity=127), 44 | Note(pitch=127, start=1.2, end=1.6, velocity=127), 45 | Note(pitch=127, start=1.9, end=2.0, velocity=127), 46 | Note(pitch=127, start=2.0, end=3.0, velocity=127) 47 | ] 48 | expected = { 49 | "127": 3, 50 | "9": 1, 51 | "75": 1 52 | } 53 | got = get_pitch_classes(notes) 54 | assert expected.keys() == got.keys() 55 | for k in expected.keys(): 56 | assert expected[k] == got[k] 57 | 58 | 59 | def test_get_note_density(): 60 | notes = [ 61 | Note(pitch=75, start=0.0, end=1.0, velocity=127), 62 | Note(pitch=9, start=1.0, end=2.0, velocity=127), 63 | Note(pitch=127, start=1.2, end=1.6, velocity=127), 64 | Note(pitch=127, start=1.9, end=2.0, velocity=127), 65 | Note(pitch=127, start=2.0, end=3.0, velocity=127) 66 | ] 67 | expected = 5 68 | got = get_note_density(notes) 69 | assert expected == got 70 | -------------------------------------------------------------------------------- /tests/unit/musicaiz/features/test_rhythm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | from musicaiz.structure import ( 5 | Note, 6 | ) 7 | from musicaiz.features import ( 8 | get_ioi, 9 | get_start_sec, 10 | get_labeled_beat_vector, 11 | _delete_duplicates, 12 | _split_labeled_beat_vector, 13 | compute_rhythm_self_similarity_matrix, 14 | ) 15 | 16 | 17 | def test_get_start_sec(): 18 | notes = [ 19 | Note(pitch=55, start=15.25, end=15.32, 
velocity=127), 20 | Note(pitch=79, start=16.75, end=16.78, velocity=127), 21 | Note(pitch=74, start=18.75, end=18.78, velocity=127), 22 | Note(pitch=55, start=18.77, end=18.79, velocity=127), 23 | ] 24 | 25 | expected = [15.25, 16.75, 18.75, 18.77] 26 | 27 | got = get_start_sec(notes) 28 | assert set(got) == set(expected) 29 | 30 | 31 | def test_delete_duplicates(): 32 | all_note_on = [0, 3, 4, 3, 3, 2, 2] 33 | expected = [0, 3, 4, 2] 34 | 35 | got = _delete_duplicates(all_note_on) 36 | assert set(got) == set(expected) 37 | 38 | 39 | def test_get_ioi_a(): 40 | 41 | all_note_on = [0.0, 1.0, 1.375] 42 | delete_overlap = True 43 | expected = [1.0, 0.375] 44 | 45 | got = get_ioi(all_note_on, delete_overlap) 46 | 47 | assert len(got) == len(expected) 48 | assert set(got) == set(expected) 49 | 50 | 51 | def test_get_ioi_b(): 52 | 53 | all_note_on = [0.0, 1.0, 1.0, 1.375] 54 | delete_overlap = False 55 | expected = [1.0, 0.0, 0.375] 56 | 57 | got = get_ioi(all_note_on, delete_overlap) 58 | 59 | assert len(got) == len(expected) 60 | assert set(got) == set(expected) 61 | 62 | 63 | def test_get_labeled_beat_vector_a(): 64 | # Test case: Paper example 65 | iois = [0.5, 0.375, 0.125] 66 | expected = [4, 4, 4, 4, 3, 3, 3, 1] 67 | 68 | got = get_labeled_beat_vector(iois) 69 | assert set(got) == set(expected) 70 | 71 | 72 | def test_get_labeled_beat_vector_b(): 73 | # Test case: Different IOI length 74 | iois = [1, 0.25] 75 | expected = [8, 8, 8, 8, 8, 8, 8, 8, 2, 2] 76 | 77 | got = get_labeled_beat_vector(iois) 78 | assert set(got) == set(expected) 79 | 80 | 81 | def test_split_labeled_beat_vector_a(): 82 | labeled_beat_vector = [8, 8, 8, 8, 8, 8, 8, 8, 2, 2] 83 | beat_value = 2 84 | 85 | expected = [ 86 | [8, 8], 87 | [8, 8], 88 | [8, 8], 89 | [8, 8], 90 | [2, 2] 91 | ] 92 | 93 | got = _split_labeled_beat_vector(labeled_beat_vector, beat_value) 94 | for i in range(len(expected)): 95 | assert set(got[i]) == set(expected[i]) 96 | 97 | 98 | def test_split_labeled_beat_vector_b(): 99 | labeled_beat_vector = [4, 4, 4, 4] 100 | beat_value = 4 101 | 102 | expected = [[4, 4, 4, 4]] 103 | 104 | got = _split_labeled_beat_vector(labeled_beat_vector, beat_value) 105 | for i in range(len(expected)): 106 | assert set(got[i]) == set(expected[i]) 107 | 108 | 109 | def test_split_labeled_beat_vector_c(): 110 | # Test case: labeled beat vector length < beat value 111 | labeled_beat_vector = [1] 112 | beat_value = 12 113 | 114 | expected = [[1]] 115 | 116 | got = _split_labeled_beat_vector(labeled_beat_vector, beat_value) 117 | for i in range(len(expected)): 118 | assert got[i] == expected[i][0] 119 | 120 | 121 | def test_compute_rhythm_self_similarity_matrix(): 122 | splitted_beat_vector = [ 123 | [8, 8], 124 | [8, 8], 125 | [8, 8], 126 | [8, 8], 127 | [2, 2] 128 | ] 129 | 130 | expected = np.array([ 131 | [0, 0, 0, 0, 1], 132 | [0, 0, 0, 0, 1], 133 | [0, 0, 0, 0, 1], 134 | [0, 0, 0, 0, 1], 135 | [1, 1, 1, 1, 0], 136 | ]) 137 | 138 | got = compute_rhythm_self_similarity_matrix(splitted_beat_vector) 139 | comparison = got == expected 140 | assert comparison.all() 141 | 142 | 143 | 144 | 145 | -------------------------------------------------------------------------------- /tests/unit/musicaiz/features/test_structure.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from musicaiz.features import StructurePrediction 4 | 5 | 6 | @pytest.fixture 7 | def midi_sample(fixture_dir): 8 | return fixture_dir / "midis" / "midi_changes.mid" 9 | 10 | 11 | def 
test_StructurePrediction_notes(midi_sample): 12 | sp = StructurePrediction(midi_sample) 13 | 14 | dataset = "BPS" 15 | 16 | level = "low" 17 | got = sp.notes(level, dataset) 18 | assert len(got) != 0 19 | 20 | level = "mid" 21 | got = sp.notes(level, dataset) 22 | assert len(got) != 0 23 | 24 | level = "high" 25 | got = sp.notes(level, dataset) 26 | assert len(got) != 0 27 | 28 | 29 | def test_StructurePrediction_beats(midi_sample): 30 | sp = StructurePrediction(midi_sample) 31 | 32 | dataset = "BPS" 33 | 34 | level = "low" 35 | got = sp.beats(level, dataset) 36 | assert len(got) != 0 37 | 38 | level = "mid" 39 | got = sp.beats(level, dataset) 40 | assert len(got) != 0 41 | 42 | level = "high" 43 | got = sp.beats(level, dataset) 44 | assert len(got) != 0 45 | -------------------------------------------------------------------------------- /tests/unit/musicaiz/harmony/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carlosholivan/musicaiz/70f95854a3777b0323ed47f5a0822cf71eb96a70/tests/unit/musicaiz/harmony/__init__.py -------------------------------------------------------------------------------- /tests/unit/musicaiz/harmony/test_chords.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | # Our modules 5 | from musicaiz.harmony import ( 6 | AllChords, 7 | Chord, 8 | ) 9 | 10 | 11 | # ===============AllChords class Tests================= 12 | # ===================================================== 13 | def test_Chord_split_chord_name_a(): 14 | # Test case: Valid chord 15 | chord_name = "C#mb5" 16 | expected_note = "C#" 17 | expected_quality = "mb5" 18 | got_note, got_quality = Chord.split_chord_name(chord_name) 19 | assert expected_note == got_note 20 | assert expected_quality == got_quality 21 | 22 | 23 | def test_Chord_split_chord_name_b(): 24 | # Test case: Note with double sharps (not valid) 25 | chord_name = "C##mb5" 26 | with pytest.raises(ValueError): 27 | Chord.split_chord_name(chord_name) 28 | 29 | 30 | def test_Chord_split_chord_name_c(): 31 | # Test case: Bad character 32 | chord_name = "---" 33 | with pytest.raises(ValueError): 34 | Chord.split_chord_name(chord_name) 35 | 36 | 37 | def test_Chord_split_chord_name_d(): 38 | # Test case: Invalid chord quality (valid note) 39 | chord_name = "Cmb55" 40 | with pytest.raises(ValueError): 41 | Chord.split_chord_name(chord_name) 42 | 43 | 44 | def test_Chord_get_chord_from_name_a(): 45 | # Test case: Valid chord name 46 | chord_name = "Cmb5" 47 | expected = AllChords.DIMINISHED_SEVENTH 48 | got = AllChords.get_chord_from_name(chord_name) 49 | assert expected == got 50 | 51 | 52 | def test_AllChords_a(): 53 | # Test case: Initialize with valid chord 54 | chord_name = "Cm7b5" 55 | got = Chord(chord_name) 56 | assert got.chord == AllChords.HALF_DIMINISHED_SEVENTH 57 | assert got.quality == "m7b5" 58 | assert got.root_note == "C" 59 | 60 | 61 | def test_AllChords_b(): 62 | # Test case: Initialize with no input chord 63 | got = Chord() 64 | assert got.chord is None 65 | assert got.quality is None 66 | assert got.root_note is None 67 | 68 | 69 | @pytest.mark.skip("Fix this when it's implemented") 70 | def test_AllChords_get_notes_a(): 71 | # Test case: Initialize with valid chord 72 | chord_name = "Gm7b5" 73 | expected = ["G", "Bb", "Db", "F"] 74 | chord = Chord(chord_name) 75 | got = chord.get_notes() 76 | assert set(expected) == set(got) 77 | 78 | 79 | 
@pytest.mark.skip("Fix this when it's implemented") 80 | def test_AllChords_get_notes_b(): 81 | # Test case: Initialize with valid chord 82 | chord_name = "G#M7" 83 | expected = ["G#", "B#", "D#", "F##"] 84 | chord = Chord(chord_name) 85 | got = chord.get_notes() 86 | assert set(expected) == set(got) 87 | 88 | 89 | @pytest.mark.skip("Fix this when it's implemented") 90 | def test_AllChords_get_notes_c(): 91 | # Test case: Inversion other than 0 92 | chord_name = "G#M7" 93 | inversion = 2 94 | expected = ["Eb", "G", "G#", "C"] 95 | chord = Chord(chord_name) 96 | got = chord.get_notes(inversion) 97 | assert set(expected) == set(got) 98 | 99 | 100 | def test_AllChords_get_notes_d(): 101 | # Test case: Invalid inversion 102 | chord_name = "G#M7" 103 | inversion = 8 104 | chord = Chord(chord_name) 105 | with pytest.raises(ValueError): 106 | chord.get_notes(inversion) 107 | -------------------------------------------------------------------------------- /tests/unit/musicaiz/loaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carlosholivan/musicaiz/70f95854a3777b0323ed47f5a0822cf71eb96a70/tests/unit/musicaiz/loaders/__init__.py -------------------------------------------------------------------------------- /tests/unit/musicaiz/plotters/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carlosholivan/musicaiz/70f95854a3777b0323ed47f5a0822cf71eb96a70/tests/unit/musicaiz/plotters/__init__.py -------------------------------------------------------------------------------- /tests/unit/musicaiz/plotters/test_pianorolls.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import matplotlib.pyplot as plt 3 | 4 | from musicaiz.plotters import Pianoroll, PianorollHTML 5 | from musicaiz.loaders import Musa 6 | 7 | 8 | @pytest.fixture 9 | def midi_sample(fixture_dir): 10 | return fixture_dir / "tokenizers" / "mmm_tokens.mid" 11 | 12 | @pytest.fixture 13 | def midi_multiinstr(fixture_dir): 14 | return fixture_dir / "midis" / "midi_changes.mid" 15 | 16 | 17 | def test_Pianoroll_plot_instrument(midi_sample): 18 | # Test case: plot one instrument 19 | musa_obj = Musa(midi_sample) 20 | plot = Pianoroll(musa_obj) 21 | plot.plot_instruments( 22 | program=30, 23 | bar_start=0, 24 | bar_end=4, 25 | print_measure_data=True, 26 | show_bar_labels=False, 27 | show_grid=False, 28 | show=False, 29 | ) 30 | plt.close("all") 31 | 32 | 33 | def test_Pianoroll_plot_instruments(midi_multiinstr): 34 | # Test case: plot multiple instruments 35 | musa_obj = Musa(midi_multiinstr) 36 | plot = Pianoroll(musa_obj) 37 | plot.plot_instruments( 38 | program=[48, 45, 74, 49, 49, 42, 25, 48, 21, 46, 0, 15, 72, 44], 39 | bar_start=0, 40 | bar_end=4, 41 | print_measure_data=True, 42 | show_bar_labels=False, 43 | show_grid=True, 44 | show=False, 45 | ) 46 | plt.close("all") 47 | 48 | 49 | def test_PianorollHTML_plot_instrument(midi_sample): 50 | musa_obj = Musa(midi_sample) 51 | plot = PianorollHTML(musa_obj) 52 | plot.plot_instruments( 53 | program=30, 54 | bar_start=0, 55 | bar_end=2, 56 | show_grid=False, 57 | show=False 58 | ) 59 | plt.close("all") 60 | 61 | 62 | def test_PianorollHTML_plot_instruments(midi_multiinstr): 63 | musa_obj = Musa(midi_multiinstr) 64 | plot = PianorollHTML(musa_obj) 65 | plot.plot_instruments( 66 | program=[48, 45, 74, 49, 49, 42, 25, 48, 21, 46, 0, 15, 72, 44], 67 | bar_start=0, 68 | 
bar_end=4, 69 | show_grid=False, 70 | show=False 71 | ) 72 | plt.close("all") 73 | -------------------------------------------------------------------------------- /tests/unit/musicaiz/rhythm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carlosholivan/musicaiz/70f95854a3777b0323ed47f5a0822cf71eb96a70/tests/unit/musicaiz/rhythm/__init__.py -------------------------------------------------------------------------------- /tests/unit/musicaiz/rhythm/test_quantizer.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | 4 | from musicaiz import rhythm 5 | from musicaiz.structure import Note 6 | from musicaiz.rhythm.quantizer import ( 7 | QuantizerConfig, 8 | basic_quantizer, 9 | get_ticks_from_subdivision, 10 | advanced_quantizer, 11 | _find_nearest, 12 | ) 13 | 14 | 15 | @pytest.fixture 16 | def grid_16(): 17 | grid = rhythm.get_subdivisions( 18 | total_bars=1, subdivision="sixteenth", time_sig="4/4", bpm=120, resolution=96 19 | ) 20 | v_grid = get_ticks_from_subdivision(grid) 21 | 22 | return v_grid 23 | 24 | 25 | @pytest.fixture 26 | def grid_8(): 27 | grid = rhythm.get_subdivisions( 28 | total_bars=1, subdivision="eight", time_sig="4/4", bpm=120, resolution=96 29 | ) 30 | 31 | v_grid = get_ticks_from_subdivision(grid) 32 | 33 | return v_grid 34 | 35 | 36 | def test_find_nearest_a(): 37 | input = [1, 2, 3, 4, 5, 6, 7, 8] 38 | value = 2.2 39 | expected = 2 40 | 41 | got = _find_nearest(input, value) 42 | 43 | assert got == expected 44 | 45 | 46 | def test_find_nearest_b(): 47 | input = np.array([1, 2, 3, 4, 5, 6, 7, 8]) 48 | value = 6.51 49 | expected = 7 50 | 51 | got = _find_nearest(input, value) 52 | 53 | assert got == expected 54 | 55 | 56 | def test_basic_quantizer(grid_16): 57 | 58 | notes_bar1 = [ 59 | Note(pitch=69, start=1, end=24, velocity=127), 60 | Note(pitch=64, start=12, end=24, velocity=127), 61 | Note(pitch=67, start=121, end=250, velocity=127), 62 | Note(pitch=64, start=0, end=162, velocity=127), 63 | ] 64 | 65 | basic_quantizer(notes_bar1, grid_16) 66 | 67 | expected = [ 68 | Note(pitch=69, start=0, end=23, velocity=127), 69 | Note(pitch=64, start=0, end=12, velocity=127), 70 | Note(pitch=67, start=120, end=249, velocity=127), 71 | Note(pitch=64, start=0, end=162, velocity=127), 72 | ] 73 | 74 | for i in range(len(notes_bar1)): 75 | assert notes_bar1[i].start_ticks == expected[i].start_ticks 76 | assert notes_bar1[i].end_ticks == expected[i].end_ticks 77 | 78 | 79 | def test_basic_quantizer_2(grid_8): 80 | 81 | notes_bar1 = [ 82 | Note(pitch=69, start=1, end=24, velocity=127), 83 | Note(pitch=64, start=12, end=24, velocity=127), 84 | Note(pitch=67, start=121, end=250, velocity=127), 85 | Note(pitch=64, start=0, end=162, velocity=127), 86 | ] 87 | 88 | basic_quantizer(notes_bar1, grid_8) 89 | 90 | expected = [ 91 | Note(pitch=69, start=0, end=23, velocity=127), 92 | Note(pitch=64, start=0, end=12, velocity=127), 93 | Note(pitch=67, start=144, end=273, velocity=127), 94 | Note(pitch=64, start=0, end=162, velocity=127), 95 | ] 96 | 97 | for i in range(len(notes_bar1)): 98 | assert notes_bar1[i].start_ticks == expected[i].start_ticks 99 | assert notes_bar1[i].end_ticks == expected[i].end_ticks 100 | 101 | 102 | def test_advanced_quantizer_1(grid_16): 103 | 104 | config = QuantizerConfig( 105 | strength=1, 106 | delta_qr=12, 107 | type_q="positive", 108 | ) 109 | 110 | notes_bar1 = [ 111 | Note(pitch=69, start=1, end=24, 
velocity=127), 112 | Note(pitch=64, start=12, end=24, velocity=127), 113 | Note(pitch=67, start=121, end=250, velocity=127), 114 | Note(pitch=64, start=13, end=18, velocity=127), 115 | ] 116 | 117 | advanced_quantizer(notes_bar1, grid_16, config, 120, 96) 118 | 119 | expected = [ 120 | Note(pitch=69, start=0, end=23, velocity=127), 121 | Note(pitch=64, start=12, end=24, velocity=127), 122 | Note(pitch=67, start=120, end=249, velocity=127), 123 | Note(pitch=64, start=24, end=29, velocity=127), 124 | ] 125 | 126 | for i in range(len(notes_bar1)): 127 | assert notes_bar1[i].start_ticks == expected[i].start_ticks 128 | assert notes_bar1[i].end_ticks == expected[i].end_ticks 129 | 130 | 131 | def test_advanced_quantizer_2(grid_16): 132 | 133 | config = QuantizerConfig( 134 | strength=1, 135 | delta_qr=12, 136 | type_q=None, 137 | ) 138 | 139 | notes_bar1 = [ 140 | Note(pitch=69, start=1, end=24, velocity=127), 141 | Note(pitch=64, start=12, end=24, velocity=127), 142 | Note(pitch=67, start=121, end=250, velocity=127), 143 | Note(pitch=64, start=13, end=18, velocity=127), 144 | ] 145 | 146 | advanced_quantizer(notes_bar1, grid_16, config, 120, 96) 147 | 148 | expected = [ 149 | Note(pitch=69, start=0, end=23, velocity=127), 150 | Note(pitch=64, start=0, end=12, velocity=127), 151 | Note(pitch=67, start=120, end=249, velocity=127), 152 | Note(pitch=64, start=24, end=29, velocity=127), 153 | ] 154 | 155 | for i in range(len(notes_bar1)): 156 | assert notes_bar1[i].start_ticks == expected[i].start_ticks 157 | assert notes_bar1[i].end_ticks == expected[i].end_ticks 158 | 159 | 160 | def test_advanced_quantizer_3(grid_16): 161 | 162 | config = QuantizerConfig( 163 | strength=1, 164 | delta_qr=12, 165 | type_q=None, 166 | ) 167 | 168 | notes_bar1 = [ # note: the quantizer modifies these Note objects in place 169 | Note(pitch=69, start=1, end=24, velocity=127), 170 | Note(pitch=64, start=12, end=24, velocity=127), 171 | Note(pitch=67, start=121, end=250, velocity=127), 172 | Note(pitch=64, start=13, end=18, velocity=127), 173 | ] 174 | 175 | advanced_quantizer(notes_bar1, grid_16, config, 120, 96) 176 | 177 | expected = [ 178 | Note(pitch=69, start=0, end=23, velocity=127), 179 | Note(pitch=64, start=0, end=12, velocity=127), 180 | Note(pitch=67, start=120, end=249, velocity=127), 181 | Note(pitch=64, start=24, end=29, velocity=127), 182 | ] 183 | 184 | for i in range(len(notes_bar1)): 185 | assert notes_bar1[i].start_ticks == expected[i].start_ticks 186 | assert notes_bar1[i].end_ticks == expected[i].end_ticks 187 | 188 | 189 | def test_advanced_quantizer_4(grid_16): 190 | 191 | config = QuantizerConfig( 192 | strength=0.75, 193 | delta_qr=12, 194 | type_q=None, 195 | ) 196 | 197 | notes_bar1 = [ 198 | Note(pitch=69, start=1, end=24, velocity=127), 199 | Note(pitch=64, start=12, end=24, velocity=127), 200 | Note(pitch=67, start=121, end=250, velocity=127), 201 | Note(pitch=64, start=30, end=50, velocity=127), 202 | ] 203 | 204 | advanced_quantizer(notes_bar1, grid_16, config, 120, 96) 205 | 206 | expected = [ 207 | Note(pitch=69, start=1, end=24, velocity=127), 208 | Note(pitch=64, start=3, end=15, velocity=127), 209 | Note(pitch=67, start=121, end=250, velocity=127), 210 | Note(pitch=64, start=26, end=46, velocity=127), 211 | ] 212 | 213 | for i in range(len(notes_bar1)): 214 | assert notes_bar1[i].start_ticks == expected[i].start_ticks 215 | assert notes_bar1[i].end_ticks == expected[i].end_ticks 216 | --------------------------------------------------------------------------------
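
For reference, a minimal usage sketch of the quantizer API exercised by the tests above. This is not a file from the repository, just an illustration that reuses the signatures shown in the fixtures (`get_subdivisions`, `get_ticks_from_subdivision`, `basic_quantizer`) and the values asserted in `test_basic_quantizer`:

from musicaiz import rhythm
from musicaiz.structure import Note
from musicaiz.rhythm.quantizer import basic_quantizer, get_ticks_from_subdivision

# Build a one-bar sixteenth-note grid in 4/4 at 120 bpm with a resolution of
# 96 ticks per beat, mirroring the grid_16 fixture above.
grid = rhythm.get_subdivisions(
    total_bars=1, subdivision="sixteenth", time_sig="4/4", bpm=120, resolution=96
)
v_grid = get_ticks_from_subdivision(grid)

# basic_quantizer mutates the notes in place, snapping their start and end
# ticks to the nearest grid tick.
notes = [Note(pitch=69, start=1, end=24, velocity=127)]
basic_quantizer(notes, v_grid)
# Per test_basic_quantizer, this note should end up with start_ticks=0
# and end_ticks=23.
print(notes[0].start_ticks, notes[0].end_ticks)
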
/tests/unit/musicaiz/structure/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carlosholivan/musicaiz/70f95854a3777b0323ed47f5a0822cf71eb96a70/tests/unit/musicaiz/structure/__init__.py -------------------------------------------------------------------------------- /tests/unit/musicaiz/structure/test_instruments.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | # Our modules 5 | from musicaiz.structure import ( 6 | InstrumentMidiPrograms, 7 | InstrumentMidiFamilies, 8 | Instrument, 9 | ) 10 | 11 | 12 | # ===============InstrumentMidiPrograms class Tests==== 13 | # ===================================================== 14 | def test_InstrumentMidiPrograms_get_possible_names_a(): 15 | got = InstrumentMidiPrograms.ACOUSTIC_GRAND_PIANO.possible_names 16 | expected = [ 17 | "ACOUSTIC_GRAND_PIANO", 18 | "ACOUSTIC GRAND PIANO", 19 | "acoustic_grand_piano", 20 | "acoustic grand piano" 21 | ] 22 | assert set(got) == set(expected) 23 | 24 | 25 | def test_InstrumentMidiPrograms_get_all_instrument_names_b(): 26 | got = InstrumentMidiPrograms.get_all_instrument_names() 27 | assert len(got) != 0 28 | 29 | 30 | def test_InstrumentMidiPrograms_get_all_possible_names(): 31 | got = InstrumentMidiPrograms.get_all_possible_names() 32 | assert len(got) != 0 33 | 34 | 35 | def test_InstrumentMidiPrograms_check_name(): 36 | name = "violin" 37 | got = InstrumentMidiPrograms._check_name(name) 38 | assert got 39 | 40 | 41 | def test_InstrumentMidiPrograms_map_name(): 42 | name = "acoustic grand piano" 43 | expected = InstrumentMidiPrograms.ACOUSTIC_GRAND_PIANO 44 | got = InstrumentMidiPrograms.map_name(name) 45 | assert expected == got 46 | 47 | 48 | def test_InstrumentMidiPrograms_get_name_from_program(): 49 | program = 1 50 | expected = InstrumentMidiPrograms.BRIGHT_ACOUSTIC_PIANO 51 | got = InstrumentMidiPrograms.get_name_from_program(program) 52 | assert expected == got 53 | 54 | 55 | # ===============InstrumentMidiFamilies class Tests==== 56 | # ===================================================== 57 | def test_InstrumentMidiFamilies_get_family_from_instrument_name_a(): 58 | # Test case: Non valid instrument name 59 | instrument_name = "piano" 60 | with pytest.raises(ValueError): 61 | InstrumentMidiFamilies.get_family_from_instrument_name(instrument_name) 62 | 63 | 64 | def test_InstrumentMidiFamilies_get_family_from_instrument_name_b(): 65 | # Test case: Valid instrument name 66 | instrument_name = "acoustic grand piano" 67 | expected = InstrumentMidiFamilies.PIANO 68 | got = InstrumentMidiFamilies.get_family_from_instrument_name(instrument_name) 69 | assert got == expected 70 | 71 | 72 | # ===============Instrument class Tests================ 73 | # ===================================================== 74 | def test_Instrument_a(): 75 | # Test case: Initializing with program and name 76 | program = 0 77 | name = "acoustic grand piano" 78 | instrument = Instrument(program, name) 79 | assert instrument.family == "PIANO" 80 | assert instrument.name == "ACOUSTIC_GRAND_PIANO" 81 | assert instrument.is_drum is False 82 | 83 | 84 | def test_Instrument_b(): 85 | # Test case: Initializing with program, name and custom is_drum 86 | program = 0 87 | name = "acoustic grand piano" 88 | is_drum = True 89 | instrument = Instrument(program, name, is_drum) 90 | assert instrument.family == "PIANO" 91 | assert instrument.name == "ACOUSTIC_GRAND_PIANO" 92 | assert instrument.is_drum is 
False 93 | 94 | 95 | def test_Instrument_c(): 96 | # Test case: Initializing with program but no name 97 | program = 3 98 | instrument = Instrument(program=program) 99 | assert instrument.family == "PIANO" 100 | assert instrument.name == "HONKY_TONK_PIANO" 101 | assert instrument.is_drum is False 102 | 103 | 104 | def test_Instrument_d(): 105 | # Test case: Initializing with valid name but no program 106 | name = "viola" 107 | instrument = Instrument(name=name) 108 | assert instrument.family == "STRINGS" 109 | assert instrument.name == "VIOLA" 110 | assert instrument.program == 41 111 | assert instrument.is_drum is False 112 | 113 | 114 | # TODO: Add a test where is_drum is True, e.g. instr = "DRUMS_1" 115 | 116 | 117 | def test_Instrument_e(): 118 | # Test case: Initializing with invalid name and no program 119 | name = "yamaha piano" 120 | with pytest.raises(ValueError): 121 | Instrument(name=name) 122 | 123 | 124 | def test_Instrument_f(): 125 | # Test case: Initializing with neither name nor program (invalid) 126 | with pytest.raises(ValueError): 127 | Instrument() 128 | -------------------------------------------------------------------------------- /tests/unit/musicaiz/tokenizers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carlosholivan/musicaiz/70f95854a3777b0323ed47f5a0822cf71eb96a70/tests/unit/musicaiz/tokenizers/__init__.py -------------------------------------------------------------------------------- /tests/unit/musicaiz/tokenizers/test_cpword.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import os 3 | from pathlib import Path 4 | import tempfile 5 | 6 | from musicaiz.tokenizers import ( 7 | CPWordTokenizer, 8 | CPWordTokenizerArguments 9 | ) 10 | 11 | 12 | @pytest.fixture 13 | def cpword_tokens(fixture_dir): 14 | tokens_path = fixture_dir / "tokenizers" / "cpword_tokens.txt" 15 | # read the whole file to a string; the file is closed on fixture teardown 16 | with open(tokens_path, "r") as text_file: 17 | yield text_file.read() 18 | 19 | 20 | @pytest.fixture 21 | def midi_sample(fixture_dir): 22 | return fixture_dir / "tokenizers" / "mmm_tokens.mid" 23 | 24 | 25 | def test_CPWordTokenizer_tokenize(midi_sample, cpword_tokens): 26 | args = CPWordTokenizerArguments(sub_beat="SIXTEENTH") 27 | tokenizer = CPWordTokenizer(midi_sample, args) 28 | tokens = tokenizer.tokenize_file() 29 | assert tokens == cpword_tokens 30 | 31 | # write midi 32 | midi = CPWordTokenizer.tokens_to_musa(tokens) 33 | with tempfile.TemporaryDirectory() as output_path: 34 | path = os.path.join(output_path, 'midi.mid') 35 | midi.writemidi(path) 36 | assert Path(path).is_file() 37 | -------------------------------------------------------------------------------- /tests/unit/musicaiz/tokenizers/test_remi.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | from musicaiz.tokenizers import ( 5 | REMITokenizer, 6 | REMITokenizerArguments, 7 | ) 8 | from .test_mmm import ( 9 | _assert_valid_musa_obj, 10 | musa_obj_tokens, 11 | musa_obj_abs, 12 | midi_sample, 13 | ) 14 | 15 | 16 | @pytest.fixture 17 | def remi_tokens(fixture_dir): 18 | tokens_path = fixture_dir / "tokenizers" / "remi_tokens.txt" 19 | # read the whole file to a string; the file is closed on fixture teardown 20 | with open(tokens_path, "r") as text_file: 21 | yield text_file.read() 22 | 23 | 24 | def test_REMITokenizer_split_tokens_by_bar(remi_tokens): 25 | tokens = remi_tokens.split(" ") 26 | expected_bar_1 = [ 27 | [ 28 | "BAR=0", 29 | "TIME_SIG=4/4", 30 | 
"SUB_BEAT=4", 31 | "TEMPO=120", 32 | "INST=30", 33 | "PITCH=69", 34 | "DUR=4", 35 | "VELOCITY=127", 36 | "PITCH=64", 37 | "DUR=8", 38 | "VELOCITY=127", 39 | "SUB_BEAT=8", 40 | "PITCH=67", 41 | "DUR=4", 42 | "VELOCITY=127", 43 | "SUB_BEAT=12", 44 | "PITCH=64", 45 | "DUR=4", 46 | "VELOCITY=127", 47 | ] 48 | ] 49 | got = REMITokenizer.split_tokens_by_bar(tokens) 50 | assert set(expected_bar_1[0]) == set(got[0]) 51 | 52 | 53 | def test_REMITokenizer_split_tokens_by_subbeat(remi_tokens): 54 | tokens = remi_tokens.split(" ") 55 | expected_subbeats_bar_1 = [ 56 | [ 57 | "BAR=0", 58 | "TIME_SIG=4/4", 59 | ], 60 | [ 61 | "SUB_BEAT=4", 62 | "TEMPO=120", 63 | "INST=30", 64 | "PITCH=69", 65 | "DUR=4", 66 | "VELOCITY=127", 67 | "PITCH=64", 68 | "DUR=8", 69 | "VELOCITY=127", 70 | ], 71 | [ 72 | "SUB_BEAT=8", 73 | "PITCH=67", 74 | "DUR=4", 75 | "VELOCITY=127", 76 | ], 77 | [ 78 | "SUB_BEAT=12", 79 | "PITCH=64", 80 | "DUR=4", 81 | "VELOCITY=127", 82 | ] 83 | ] 84 | got = REMITokenizer.split_tokens_by_subbeat(tokens) 85 | for i in range(len(expected_subbeats_bar_1)): 86 | assert set(expected_subbeats_bar_1[i]) == set(got[i]) 87 | 88 | 89 | def test_REMITokenizer_tokens_to_musa_a(remi_tokens, musa_obj_abs): 90 | # Test case: 1 polyphonic instrument, absolute timings 91 | got = REMITokenizer.tokens_to_musa( 92 | tokens=remi_tokens, 93 | sub_beat="SIXTEENTH" 94 | ) 95 | _assert_valid_musa_obj(got, musa_obj_abs) 96 | 97 | 98 | def test_REMITokenizer_get_tokens_analytics(remi_tokens): 99 | got = REMITokenizer.get_tokens_analytics(remi_tokens) 100 | expected_total_tokens = 33 101 | expected_unique_tokens = 16 102 | expected_total_notes = 7 103 | expected_unique_notes = 4 104 | expected_total_bars = 2 105 | expected_total_instruments = 1 106 | expected_total_pieces = 1 107 | 108 | assert expected_total_pieces == got["total_pieces"] 109 | assert expected_total_tokens == got["total_tokens"] 110 | assert expected_unique_tokens == got["unique_tokens"] 111 | assert expected_total_notes == got["total_notes"] 112 | assert expected_unique_notes == got["unique_notes"] 113 | assert expected_total_bars == got["total_bars"] 114 | assert expected_total_instruments == got["total_instruments"] 115 | 116 | 117 | def test_REMITokenizer_tokenize_bars(midi_sample, remi_tokens): 118 | 119 | expected = remi_tokens 120 | 121 | args = REMITokenizerArguments(sub_beat="SIXTEENTH") 122 | tokenizer = REMITokenizer( 123 | midi_sample, 124 | args=args 125 | ) 126 | got = tokenizer.tokenize_bars() 127 | assert got == expected 128 | 129 | 130 | def test_REMITokenizer_tokenize_file(midi_sample): 131 | args = REMITokenizerArguments(sub_beat="SIXTEENTH") 132 | tokenizer = REMITokenizer(midi_sample, args) 133 | got = tokenizer.tokenize_file() 134 | assert got != "" 135 | --------------------------------------------------------------------------------