├── .dockerignore ├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── __init__.py ├── batch_infer.py ├── binarize.py ├── configs ├── base.yaml ├── continuous.yaml ├── discrete.yaml ├── midi_conformer.yaml ├── quant_two_head_model.yaml └── two_head_model.yaml ├── data └── .gitkeep ├── deployment ├── __init__.py ├── base_onnx_module.py ├── me_onnx_module.py └── me_quant_onnx_module.py ├── experiments └── .gitkeep ├── export.py ├── infer.py ├── inference ├── __init__.py ├── base_infer.py ├── me_infer.py └── me_quant_infer.py ├── lr_scheduler └── scheduler.py ├── modules ├── __init__.py ├── attention │ ├── __init__.py │ └── base_attention.py ├── commons │ ├── __init__.py │ └── tts_modules.py ├── conform │ ├── Gconform.py │ └── __init__.py ├── contentvec │ └── __init__.py ├── conv │ └── base_conv.py ├── losses │ ├── __init__.py │ └── bound_loss.py ├── metrics │ ├── __init__.py │ └── midi_acc.py ├── model │ ├── Gmidi_conform.py │ └── __init__.py └── rmvpe │ ├── __init__.py │ ├── constants.py │ ├── deepunet.py │ ├── inference.py │ ├── model.py │ ├── seq.py │ ├── spec.py │ └── utils.py ├── preprocessing ├── __init__.py ├── base_binarizer.py ├── me_binarizer.py └── me_quant_binarizer.py ├── pretrained └── .gitkeep ├── requirements.txt ├── simplify.py ├── train.py ├── training ├── __init__.py ├── base_task.py ├── me_quant_task.py └── me_task.py ├── utils ├── __init__.py ├── binarizer_utils.py ├── config_utils.py ├── indexed_datasets.py ├── infer_utils.py ├── multiprocess_utils.py ├── pitch_utils.py ├── plot.py ├── slicer2.py └── training_utils.py └── webui.py /.dockerignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | .idea/ 161 | 162 | # Data and model directories 163 | /data/ 164 | /experiments/ 165 | /pretrained/ 166 | *.mid 167 | *.wav 168 | 169 | # git 170 | .git 171 | .github 172 | .gitignore 173 | Dockerfile 174 | .dockerignore 175 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | .idea/ 161 | 162 | # Data and model directories 163 | /data/ 164 | /experiments/ 165 | /pretrained/ 166 | *.mid 167 | *.wav 168 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM busybox as downloader 2 | WORKDIR /home 3 | LABEL org.opencontainers.image.source=https://github.com/openvpi/SOME 4 | LABEL org.opencontainers.image.description="SOME: Singing-Oriented MIDI Extractor." 5 | LABEL org.opencontainers.image.licenses=MIT 6 | RUN wget -O- https://github.com/openvpi/SOME/releases/latest/download/0918_continuous256_clean_3spk_fixmel.zip|unzip - 7 | FROM pytorch/pytorch:2.1.0-cuda11.8-cudnn8-devel 8 | COPY . /opt/app 9 | WORKDIR /opt/app 10 | RUN pip3 install -r requirements.txt gradio==3.47.1 11 | COPY --from=downloader /home experiments 12 | EXPOSE 7860 13 | CMD [ "python", "webui.py", "--addr=0.0.0.0" ] 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Team OpenVPI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SOME 2 | SOME: Singing-Oriented MIDI Extractor. 3 | 4 | > WARNING 5 | > 6 | > This project is currently in beta. No backward compatibility is guaranteed. 7 | 8 | ## Overview 9 | 10 | SOME is a MIDI extractor that converts singing voice into a MIDI sequence, with the following advantages: 11 | 12 | 1. Speed: 9x faster than real-time on an i5 12400 CPU, and 300x on a 3080Ti GPU. 13 | 2. Low resource dependency: SOME can be trained on a custom dataset and achieves good results with only 3 hours of training data. 14 | 3. Functionality: SOME can produce non-integer MIDI values, which is especially suitable for DiffSinger variance labeling. 15 | 16 | ## Getting Started 17 | 18 | > 中文教程 / Chinese Tutorials: [Text](https://openvpi-docs.feishu.cn/wiki/RaHSwdMQvisdcKkRFpqclhM7ndc), [Video](https://www.bilibili.com/video/BV1my4y1N7VR) 19 | 20 | ### Installation 21 | 22 | SOME requires Python 3.8 or later. We strongly recommend that you create a virtual environment via Conda or venv before installing dependencies. 23 | 24 | 1. 
Install PyTorch 2.1 or later following the [official instructions](https://pytorch.org/get-started/locally/) according to your OS and hardware. 25 | 26 | 2. Install other dependencies via the following command: 27 | 28 | ```bash 29 | pip install -r requirements.txt 30 | ``` 31 | 32 | 3. (Optional) For better pitch extraction results, please download the RMVPE pretrained model from [here](https://github.com/yxlllc/RMVPE/releases) and extract it into the `pretrained/` directory. 33 | 34 | ### Inference via pretrained model (MIDI files) 35 | 36 | Download a pretrained SOME model from [releases](https://github.com/openvpi/SOME/releases) and extract it somewhere. 37 | 38 | To infer with the CLI, run the following command: 39 | 40 | ```bash 41 | python infer.py --model CKPT_PATH --wav WAV_PATH 42 | ``` 43 | 44 | This will load the model at CKPT_PATH, extract MIDI from the audio file at WAV_PATH, and save a MIDI file. A programmatic sketch of the same flow is shown in the appendix at the end of this README. For more useful options, run 45 | 46 | ```bash 47 | python infer.py --help 48 | ``` 49 | 50 | To infer with the Web UI, run the following command: 51 | 52 | ```bash 53 | python webui.py --work_dir WORK_DIR 54 | ``` 55 | 56 | Then you can open the Gradio interface in your browser and use the models under WORK_DIR following the instructions on the web page. For more useful options, run 57 | 58 | ```bash 59 | python webui.py --help 60 | ``` 61 | 62 | ### Inference via pretrained model (DiffSinger dataset) 63 | 64 | Download a pretrained SOME model from [releases](https://github.com/openvpi/SOME/releases) and extract it somewhere. 65 | 66 | To use SOME for an existing DiffSinger dataset, you should have a transcriptions.csv containing the `name`, `ph_seq`, `ph_dur` and `ph_num` columns. Run the following command: 67 | 68 | ```bash 69 | python batch_infer.py --model CKPT_PATH --dataset RAW_DATA_DIR --overwrite 70 | ``` 71 | 72 | This will use the model to extract all MIDI sequences (with floating-point pitch values) from the recordings in the dataset and **OVERWRITE** its transcriptions.csv with `note_seq` and `note_dur` added or replaced. Please be careful and back up your files if necessary. 73 | 74 | For more useful options, run 75 | 76 | ```bash 77 | python batch_infer.py --help 78 | ``` 79 | 80 | ### Training from scratch 81 | 82 | _Training scripts are uploaded but may not be well-organized yet. For the best compatibility, we suggest training your own model only after a future stable release._ 83 | 84 | 85 | ## Disclaimer 86 | 87 | Any organization or individual is prohibited from using recordings obtained without the provider's consent as training data. If you do not comply with this term, you could be in violation of copyright laws or software EULAs. 88 | 89 | ## License 90 | 91 | SOME is licensed under the [MIT License](LICENSE).
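## Appendix: Programmatic inference (sketch)

The CLI flow above can also be driven from Python. The snippet below is a minimal sketch that mirrors what `infer.py` does internally (slice the audio at silences, run per-chunk inference, assemble a MIDI file); the checkpoint and audio paths are placeholders you should replace with your own.

```python
import importlib
import pathlib

import librosa
import yaml

import inference
from utils.infer_utils import build_midi_file
from utils.slicer2 import Slicer

# Placeholder paths -- point these at your extracted checkpoint and input audio
model_path = pathlib.Path('experiments/some_model/model.ckpt')
wav_path = pathlib.Path('vocals.wav')

# The config is expected to sit next to the checkpoint, as in the released models
with open(model_path.with_name('config.yaml'), 'r', encoding='utf8') as f:
    config = yaml.safe_load(f)

# Resolve and build the inference class registered for the checkpoint's task
infer_cls_path = inference.task_inference_mapping[config['task_cls']]
pkg, cls_name = infer_cls_path.rsplit('.', 1)
infer_cls = getattr(importlib.import_module(pkg), cls_name)
infer_ins = infer_cls(config=config, model_path=model_path)

# Load audio at the model's sample rate, split at silences, and infer per chunk
waveform, _ = librosa.load(wav_path, sr=config['audio_sample_rate'], mono=True)
chunks = Slicer(sr=config['audio_sample_rate'], max_sil_kept=1000).slice(waveform)
midis = infer_ins.infer([c['waveform'] for c in chunks])

# Re-align the chunks by their offsets and save a single MIDI file
midi_file = build_midi_file([c['offset'] for c in chunks], midis, tempo=120)
midi_file.save(wav_path.with_suffix('.mid'))
```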
92 | 93 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openvpi/SOME/dc9916fb35350748351ddcdfc128fe908ef53f5d/__init__.py -------------------------------------------------------------------------------- /batch_infer.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import pathlib 3 | from csv import DictReader, DictWriter 4 | from typing import List 5 | 6 | import click 7 | import librosa 8 | import tqdm 9 | import yaml 10 | 11 | import inference 12 | from utils.config_utils import print_config 13 | from utils.slicer2 import Slicer 14 | 15 | task_inference_mapping = { 16 | 'training.MIDIExtractionTask': 'inference.MIDIExtractionInference', 17 | 'training.QuantizedMIDIExtractionTask': 'inference.QuantizedMIDIExtractionInference', 18 | } 19 | 20 | 21 | def model_init(model_path): 22 | model_path = pathlib.Path(model_path) 23 | with open(model_path.with_name('config.yaml'), 'r', encoding='utf8') as f: 24 | config = yaml.safe_load(f) 25 | print_config(config) 26 | infer_cls = task_inference_mapping[config['task_cls']] 27 | 28 | pkg = ".".join(infer_cls.split(".")[:-1]) 29 | cls_name = infer_cls.split(".")[-1] 30 | infer_cls = getattr(importlib.import_module(pkg), cls_name) 31 | assert issubclass(infer_cls, inference.BaseInference), \ 32 | f'Binarizer class {infer_cls} is not a subclass of {inference.BaseInference}.' 33 | model = infer_cls(config=config, model_path=model_path) 34 | return model, config 35 | 36 | 37 | def calc_seq(note_midi, note_rest): 38 | midi_num = round(note_midi, 0) 39 | cent = int(round(note_midi - midi_num, 2) * 100) 40 | if cent > 0: 41 | cent = f"+{cent}" 42 | elif cent == 0: 43 | cent = "" 44 | 45 | seq = f"{librosa.midi_to_note(midi_num, unicode=False)}{cent}" 46 | return seq if not note_rest else 'rest' 47 | 48 | 49 | def infer(wav, infer_ins, config): 50 | wav_path = pathlib.Path(wav) 51 | waveform, _ = librosa.load(wav_path, sr=config['audio_sample_rate'], mono=True) 52 | slicer = Slicer(sr=config['audio_sample_rate'], max_sil_kept=1000) 53 | chunks = slicer.slice(waveform) 54 | midis = infer_ins.infer([c['waveform'] for c in chunks]) 55 | 56 | res: list = [] 57 | for offset, segment in zip([c['offset'] for c in chunks], midis): 58 | offset = round(offset, 6) 59 | note_midi = segment['note_midi'].tolist() 60 | # tempo = 120 61 | note_dur = segment['note_dur'].tolist() 62 | note_rest = segment['note_rest'].tolist() 63 | assert len(note_midi) == len(note_dur) == len(note_rest) 64 | 65 | last_time = 0 66 | for mid, dur, rest in zip(note_midi, note_dur, note_rest): 67 | dur = round(dur, 6) 68 | last_time = round(last_time, 6) 69 | seq = calc_seq(mid, rest) 70 | midi_info: dict = { 71 | 'start_time': round(offset + last_time, 6), 72 | 'end_time': round(offset + last_time + dur, 6), 73 | 'note_seq': seq 74 | } 75 | if res: 76 | if midi_info['start_time'] < res[-1]['end_time']: 77 | midi_info['start_time'] = res[-1]['end_time'] 78 | midi_info['note_dur'] = round(midi_info['end_time'] - midi_info['start_time'], 6) 79 | res.append(midi_info) 80 | last_time += dur 81 | return res 82 | 83 | 84 | def get_word_durs(ph_durs, ph_nums): 85 | res = [] 86 | cur = 0 87 | s_time = 0 88 | for num_phonemes in ph_nums: 89 | word_dur = round(sum(ph_durs[cur:cur + num_phonemes]), 6) 90 | ed_time = s_time + word_dur 91 | res.append((round(s_time, 6), round(ed_time, 6))) 
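        # advance past this word's phonemes and move the running start time to the next word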
92 | cur += num_phonemes 93 | s_time += word_dur 94 | return res 95 | 96 | 97 | def midi_align(midi_res, midi_durs, tolerance=0.05): 98 | res = [] 99 | bound = [x[0] for x in midi_durs] + [midi_durs[-1][1]] 100 | 101 | for mid in midi_res: 102 | for i in range(len(bound)): 103 | if bound[i] - tolerance <= mid['start_time'] <= bound[i] + tolerance: 104 | mid['start_time'] = bound[i] 105 | if bound[i] - tolerance <= mid['end_time'] <= bound[i] + tolerance: 106 | mid['end_time'] = bound[i] 107 | mid['note_dur'] = round(mid['end_time'] - mid['start_time'], 6) 108 | if mid['note_dur'] > 0: 109 | res.append(mid) 110 | return res 111 | 112 | 113 | def get_all_overlap_midis(interval, segments): 114 | res = [] 115 | for segment in segments: 116 | if interval[0] < segment['start_time'] < interval[1]: 117 | res.append(segment) 118 | elif interval[0] < segment['end_time'] < interval[1]: 119 | res.append(segment) 120 | elif segment['start_time'] <= interval[0] and interval[1] <= segment['end_time']: 121 | res.append(segment) 122 | return res 123 | 124 | 125 | def get_max_overlap_midi(interval, segments): 126 | matching_segment = 'rest' 127 | max_overlap = 0 128 | 129 | for segment in segments: 130 | overlap = max(0, min(interval[1], segment['end_time']) - max(interval[0], segment['start_time'])) 131 | if overlap > max_overlap: 132 | max_overlap = overlap 133 | matching_segment = segment['note_seq'] 134 | return matching_segment 135 | 136 | 137 | @click.command(help='Batch inference on existing DiffSinger dataset.') 138 | @click.option( 139 | '--dataset', required=True, metavar='RAW_DATA_DIR', 140 | help='Path to the dataset directory. Equivalent to \'raw_data_dir\' in DiffSinger configuration files.' 141 | ) 142 | @click.option('--model', required=True, metavar='CKPT_PATH', help='Path to the model checkpoint (*.ckpt)') 143 | @click.option('--round_midi', is_flag=True, help='Round MIDI values to integers') 144 | @click.option( 145 | '--csv', required=False, metavar='CSV_PATH', 146 | help='Path to the output transcriptions.csv file (default to the same file in the dataset)' 147 | ) 148 | @click.option('--overwrite', is_flag=True, help='Overwrite the existing transcriptions.csv file') 149 | def batch_infer(dataset, model, round_midi, csv, overwrite): 150 | data_path = pathlib.Path(dataset) 151 | model_path = pathlib.Path(model) 152 | csv_path = pathlib.Path(csv) if csv is not None else data_path / 'transcriptions.csv' 153 | if csv_path.exists() and not overwrite: 154 | raise FileExistsError(f'The CSV path \'{csv_path}\' already exists. 
Please re-try with --overwrite option.') 155 | infer_ins, config = model_init(model_path) 156 | 157 | # count = 0 158 | csv_data: List[dict] = [] 159 | with open(f'{data_path}/transcriptions.csv', 'r', encoding='utf8', newline='') as f: 160 | reader = DictReader(f) 161 | for row in reader: 162 | csv_data.append(row) 163 | 164 | for row in tqdm.tqdm(csv_data): 165 | audio_path = data_path / 'wavs' / f"{row['name']}.wav" 166 | if not audio_path.exists(): 167 | print(f'WARNING: audio file does not exist: \'{audio_path}\'') 168 | continue 169 | # print(f"\r\n{audio_path}: start") 170 | result = infer(audio_path, infer_ins, config) 171 | 172 | ph_dur = [round(float(x), 6) for x in row['ph_dur'].split(" ")] 173 | ph_num = [int(x) for x in row['ph_num'].split(" ")] 174 | note_seq = [] 175 | note_dur = [] 176 | 177 | midi_dur_list = get_word_durs(ph_dur, ph_num) 178 | result = midi_align(result, midi_dur_list) 179 | 180 | for (start_time, end_time) in midi_dur_list: 181 | word_duration = round(end_time - start_time, 6) 182 | if round_midi: 183 | match_seq = get_max_overlap_midi((start_time, end_time), result) 184 | note_seq.append(match_seq) 185 | note_dur.append(word_duration) 186 | else: 187 | temp_seq = [] 188 | temp_dur = [] 189 | match_midi = get_all_overlap_midis((start_time, end_time), result) 190 | 191 | for midi in match_midi: 192 | if midi['start_time'] <= start_time: 193 | temp_seq.append(midi['note_seq']) 194 | midi_dur = round(min(end_time, midi['end_time']) - start_time, 6) 195 | elif midi['end_time'] >= end_time: 196 | temp_seq.append(midi['note_seq']) 197 | midi_dur = round(end_time - max(start_time, midi['start_time']), 6) 198 | elif midi['start_time'] <= start_time and midi['end_time'] >= end_time: 199 | temp_seq.append(midi['note_seq']) 200 | midi_dur = word_duration 201 | else: 202 | temp_seq.append(midi['note_seq']) 203 | midi_dur = round(midi['note_dur'], 6) 204 | temp_dur.append(midi_dur) 205 | 206 | if not match_midi: 207 | temp_seq.append('rest') 208 | temp_dur.append(word_duration) 209 | 210 | if round(sum(temp_dur), 6) < word_duration: 211 | temp_seq.append('rest') 212 | temp_dur.append(word_duration - round(sum(temp_dur), 6)) 213 | 214 | note_seq.extend(temp_seq) 215 | note_dur.extend(temp_dur) 216 | 217 | assert len(note_seq) == len(note_dur) 218 | row['note_seq'] = " ".join([str(x) for x in note_seq]) 219 | row['note_dur'] = " ".join([str(round(x, 6)) for x in note_dur]) 220 | # print(f" {audio_path}:\r\nnote_seq: {note_seq}\r\nnote_dur: {note_dur}") 221 | # count += 1 222 | 223 | with open(csv_path, 'w', encoding='utf8', newline='') as f: 224 | writer = DictWriter(f, fieldnames=['name', 'ph_seq', 'ph_dur', 'ph_num', 'note_seq', 'note_dur']) 225 | writer.writeheader() 226 | writer.writerows(csv_data) 227 | 228 | 229 | if __name__ == "__main__": 230 | batch_infer() 231 | -------------------------------------------------------------------------------- /binarize.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import pathlib 3 | 4 | import click 5 | 6 | import preprocessing 7 | from utils.config_utils import read_full_config, print_config 8 | 9 | 10 | @click.command(help='Process the raw dataset into binary dataset') 11 | @click.option('--config', required=True, metavar='FILE', help='Path to the configuration file') 12 | def binarize(config): 13 | config = pathlib.Path(config) 14 | config = read_full_config(config) 15 | print_config(config) 16 | binarizer_cls = config['binarizer_cls'] 17 | pkg = 
".".join(binarizer_cls.split(".")[:-1]) 18 | cls_name = binarizer_cls.split(".")[-1] 19 | binarizer_cls = getattr(importlib.import_module(pkg), cls_name) 20 | assert issubclass(binarizer_cls, preprocessing.BaseBinarizer), \ 21 | f'Binarizer class {binarizer_cls} is not a subclass of {preprocessing.BaseBinarizer}.' 22 | print("| Binarizer: ", binarizer_cls) 23 | binarizer_cls(config=config).process() 24 | 25 | 26 | if __name__ == '__main__': 27 | binarize() 28 | -------------------------------------------------------------------------------- /configs/base.yaml: -------------------------------------------------------------------------------- 1 | # preprocessing 2 | binarizer_cls: preprocessing.BaseBinarizer 3 | raw_data_dir: [] 4 | binary_data_dir: null 5 | binarization_args: 6 | num_workers: 8 7 | shuffle: true 8 | valid_set_name: valid 9 | train_set_name: train 10 | 11 | hop_size: 512 12 | win_size: 2048 13 | audio_sample_rate: 44100 14 | fmin: 40 15 | fmax: 8000 16 | test_prefixes: [] 17 | units_encoder: mel # contentvec768l12 18 | units_encoder_ckpt: pretrained/contentvec/checkpoint_best_legacy_500.pt 19 | pe: rmvpe 20 | pe_ckpt: pretrained/rmvpe/model.pt 21 | 22 | # global constants 23 | midi_min: 0 24 | midi_max: 127 25 | 26 | # neural networks 27 | units_dim: 80 # 768 28 | midi_num_bins: 128 29 | model_cls: null 30 | midi_extractor_args: {} 31 | 32 | # training 33 | use_midi_loss: true 34 | use_bound_loss: true 35 | task_cls: training.BaseTask 36 | sort_by_len: true 37 | optimizer_args: 38 | optimizer_cls: torch.optim.AdamW 39 | lr: 0.0001 40 | beta1: 0.9 41 | beta2: 0.98 42 | weight_decay: 0 43 | 44 | lr_scheduler_args: 45 | scheduler_cls: lr_scheduler.scheduler.WarmupLR 46 | warmup_steps: 5000 47 | min_lr: 0.00001 48 | 49 | clip_grad_norm: 1 50 | accumulate_grad_batches: 1 51 | sampler_frame_count_grid: 6 52 | ds_workers: 4 53 | dataloader_prefetch_factor: 2 54 | 55 | max_batch_size: 8 56 | max_batch_frames: 80000 57 | max_val_batch_size: 1 58 | max_val_batch_frames: 10000 59 | num_valid_plots: 100 60 | log_interval: 100 61 | num_sanity_val_steps: 1 # steps of validation at the beginning 62 | val_check_interval: 1000 63 | num_ckpt_keep: 5 64 | max_updates: 100000 65 | permanent_ckpt_start: 200000 66 | permanent_ckpt_interval: 40000 67 | 68 | ########### 69 | # pytorch lightning 70 | # Read https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api for possible values 71 | ########### 72 | pl_trainer_accelerator: 'auto' 73 | pl_trainer_devices: 'auto' 74 | pl_trainer_precision: '32-true' 75 | pl_trainer_num_nodes: 1 76 | pl_trainer_strategy: 77 | name: auto 78 | process_group_backend: nccl 79 | find_unused_parameters: false 80 | nccl_p2p: true 81 | seed: 114514 82 | 83 | ########### 84 | # finetune 85 | ########### 86 | 87 | finetune_enabled: false 88 | finetune_ckpt_path: null 89 | finetune_ignored_params: [] 90 | finetune_strict_shapes: true 91 | 92 | freezing_enabled: false 93 | frozen_params: [] 94 | -------------------------------------------------------------------------------- /configs/continuous.yaml: -------------------------------------------------------------------------------- 1 | base_config: 2 | configs/base.yaml 3 | 4 | # preprocessing 5 | binarizer_cls: preprocessing.MIDIExtractionBinarizer 6 | raw_data_dir: [] 7 | binary_data_dir: null 8 | binarization_args: 9 | num_workers: 0 10 | skip_glide: true # skip data with glide 11 | merge_rest: true # merge continuous rest notes 12 | merge_slur: true # merge slurs with the similar pitch 13 | round_midi: 
false # round midi value 14 | slur_tolerance: 0.5 # maximum allowed value of pitch change of a slur to be merged 15 | 16 | key_shift_factor: 8 17 | key_shift_range: [-12, 12] 18 | test_prefixes: [] 19 | units_encoder: mel # contentvec768l12 20 | units_encoder_ckpt: pretrained/contentvec/checkpoint_best_legacy_500.pt 21 | pe: rmvpe 22 | pe_ckpt: pretrained/rmvpe/model.pt 23 | 24 | # global constants 25 | midi_prob_deviation: 1.0 26 | rest_threshold: 0.1 27 | 28 | # neural networks 29 | units_dim: 80 # 768 30 | midi_num_bins: 128 # 256 31 | model_cls: modules.model.Gmidi_conform.midi_conforms 32 | midi_extractor_args: 33 | lay: 8 34 | dim: 512 35 | use_lay_skip: true 36 | kernel_size: 31 37 | conv_drop: 0.1 38 | ffn_latent_drop: 0.1 39 | ffn_out_drop: 0.1 40 | attention_drop: 0.1 41 | attention_heads: 8 42 | attention_heads_dim: 64 43 | 44 | # training 45 | task_cls: training.MIDIExtractionTask 46 | optimizer_args: 47 | optimizer_cls: torch.optim.AdamW 48 | lr: 0.0001 49 | beta1: 0.9 50 | beta2: 0.98 51 | weight_decay: 0 52 | 53 | lr_scheduler_args: 54 | scheduler_cls: lr_scheduler.scheduler.WarmupLR 55 | warmup_steps: 5000 56 | min_lr: 0.00001 57 | 58 | max_batch_size: 8 59 | max_batch_frames: 80000 60 | num_valid_plots: 10 61 | val_check_interval: 1000 62 | num_ckpt_keep: 5 63 | max_updates: 100000 64 | permanent_ckpt_start: 60000 65 | permanent_ckpt_interval: 8000 66 | -------------------------------------------------------------------------------- /configs/discrete.yaml: -------------------------------------------------------------------------------- 1 | base_config: 2 | - configs/base.yaml 3 | 4 | # preprocessing 5 | binarizer_cls: preprocessing.QuantizedMIDIExtractionBinarizer 6 | raw_data_dir: [] 7 | binary_data_dir: data/some_ds_quant_spk4_aug8/binary 8 | binarization_args: 9 | num_workers: 0 10 | shuffle: true 11 | skip_glide: true # skip data with glide 12 | merge_rest: true # merge continuous rest notes 13 | merge_slur: true # merge slurs with the similar pitch 14 | use_bound_loss: true 15 | use_midi_loss: true 16 | 17 | test_prefixes: [] 18 | units_encoder: mel # contentvec768l12 19 | units_encoder_ckpt: pretrained/contentvec/checkpoint_best_legacy_500.pt 20 | pe: rmvpe 21 | pe_ckpt: pretrained/rmvpe/model.pt 22 | 23 | # global constants 24 | key_shift_range: [-12, 12] 25 | key_shift_factor: 8 26 | 27 | # neural networks 28 | units_dim: 80 # 768 29 | midi_num_bins: 129 # rest = 128 30 | model_cls: modules.model.Gmidi_conform.midi_conforms 31 | midi_extractor_args: 32 | lay: 3 33 | dim: 512 34 | use_lay_skip: true 35 | kernel_size: 31 36 | conv_drop: 0.1 37 | ffn_latent_drop: 0.1 38 | ffn_out_drop: 0.1 39 | attention_drop: 0.1 40 | attention_heads: 8 41 | attention_heads_dim: 64 42 | 43 | # training 44 | task_cls: training.QuantizedMIDIExtractionTask 45 | optimizer_args: 46 | optimizer_cls: torch.optim.AdamW 47 | lr: 0.0001 48 | beta1: 0.9 49 | beta2: 0.98 50 | weight_decay: 0 51 | 52 | lr_scheduler_args: 53 | scheduler_cls: lr_scheduler.scheduler.WarmupLR 54 | warmup_steps: 10000 55 | min_lr: 0.00001 56 | 57 | max_batch_size: 8 58 | max_batch_frames: 80000 59 | num_valid_plots: 10 60 | val_check_interval: 1000 61 | num_ckpt_keep: 5 62 | max_updates: 100000 63 | permanent_ckpt_start: 200000 64 | permanent_ckpt_interval: 40000 65 | -------------------------------------------------------------------------------- /configs/midi_conformer.yaml: -------------------------------------------------------------------------------- 1 | base_config: 2 | - configs/base.yaml 3 | 4 | 5 | 
model_cls: modules.model.Gmidi_conform.midi_conforms 6 | task_cls: training.MIDIExtractionTask 7 | binary_data_dir: data/some_ds_roundmidi_spk3_aug8/binary 8 | 9 | num_valid_plots: 100 10 | log_interval: 100 11 | num_sanity_val_steps: 1 # steps of validation at the beginning 12 | val_check_interval: 5000 13 | num_ckpt_keep: 6 14 | max_updates: 300000 15 | 16 | midi_prob_deviation: 1.0 17 | midi_shift_proportion: 0.0 18 | midi_shift_range: [-12, 12] 19 | rest_threshold: 0.1 20 | 21 | 22 | midi_extractor_args: 23 | lay: 8 24 | dim: 512 25 | 26 | use_lay_skip: true 27 | kernel_size: 31 28 | conv_drop: 0.1 29 | ffn_latent_drop: 0.1 30 | ffn_out_drop: 0.1 31 | attention_drop: 0.1 32 | attention_heads: 8 33 | attention_heads_dim: 64 34 | 35 | pl_trainer_precision: 'bf16' -------------------------------------------------------------------------------- /configs/quant_two_head_model.yaml: -------------------------------------------------------------------------------- 1 | base_config: 2 | - configs/discrete.yaml 3 | 4 | binary_data_dir: data/some_ds_quant_spk4_aug8/binary 5 | 6 | # neural networks 7 | units_dim: 80 # 768 8 | midi_num_bins: 129 # rest = 128 9 | model_cls: modules.model.Gmidi_conform.midi_conforms 10 | midi_extractor_args: 11 | lay: 3 12 | dim: 512 13 | use_lay_skip: true 14 | kernel_size: 31 15 | conv_drop: 0.1 16 | ffn_latent_drop: 0.1 17 | ffn_out_drop: 0.1 18 | attention_drop: 0.1 19 | attention_heads: 8 20 | attention_heads_dim: 64 21 | 22 | # training 23 | task_cls: training.QuantizedMIDIExtractionTask 24 | optimizer_args: 25 | optimizer_cls: torch.optim.AdamW 26 | lr: 0.0001 27 | beta1: 0.9 28 | beta2: 0.98 29 | weight_decay: 0 30 | 31 | lr_scheduler_args: 32 | scheduler_cls: lr_scheduler.scheduler.WarmupLR 33 | warmup_steps: 10000 34 | min_lr: 0.00001 35 | 36 | max_batch_size: 8 37 | max_batch_frames: 80000 38 | num_valid_plots: 10 39 | val_check_interval: 1000 40 | num_ckpt_keep: 5 41 | max_updates: 100000 42 | permanent_ckpt_start: 200000 43 | permanent_ckpt_interval: 40000 44 | -------------------------------------------------------------------------------- /configs/two_head_model.yaml: -------------------------------------------------------------------------------- 1 | base_config: 2 | configs/continuous.yaml 3 | 4 | # preprocessing 5 | binarizer_cls: preprocessing.MIDIExtractionBinarizer 6 | raw_data_dir: [] 7 | binary_data_dir: data/some_dataset_mel/binary 8 | binarization_args: 9 | num_workers: 0 10 | shuffle: true 11 | 12 | test_prefixes: [] 13 | units_encoder: mel # contentvec768l12 14 | units_encoder_ckpt: pretrained/contentvec/checkpoint_best_legacy_500.pt 15 | pe: rmvpe 16 | pe_ckpt: pretrained/rmvpe/model.pt 17 | 18 | # global constants 19 | midi_prob_deviation: 1.0 20 | rest_threshold: 0.1 21 | 22 | # neural networks 23 | units_dim: 80 # 768 24 | midi_num_bins: 128 25 | model_cls: modules.model.Gmidi_conform.midi_conforms 26 | midi_extractor_args: 27 | lay: 3 28 | dim: 512 29 | use_lay_skip: true 30 | kernel_size: 31 31 | conv_drop: 0.1 32 | ffn_latent_drop: 0.1 33 | ffn_out_drop: 0.1 34 | attention_drop: 0.1 35 | attention_heads: 8 36 | attention_heads_dim: 64 37 | 38 | # training 39 | task_cls: training.MIDIExtractionTask 40 | use_bound_loss: true 41 | use_midi_loss: true 42 | optimizer_args: 43 | optimizer_cls: torch.optim.AdamW 44 | lr: 0.0001 45 | beta1: 0.9 46 | beta2: 0.98 47 | weight_decay: 0 48 | 49 | lr_scheduler_args: 50 | scheduler_cls: lr_scheduler.scheduler.WarmupLR 51 | warmup_steps: 5000 52 | min_lr: 0.00001 53 | 54 | max_batch_size: 8 55 | 
max_batch_frames: 80000 56 | num_valid_plots: 10 57 | val_check_interval: 1000 58 | num_ckpt_keep: 5 59 | max_updates: 100000 60 | permanent_ckpt_start: 200000 61 | permanent_ckpt_interval: 40000 62 | -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openvpi/SOME/dc9916fb35350748351ddcdfc128fe908ef53f5d/data/.gitkeep -------------------------------------------------------------------------------- /deployment/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_onnx_module import BaseONNXModule 2 | from .me_onnx_module import MIDIExtractionONNXModule 3 | from .me_quant_onnx_module import QuantizedMIDIExtractionONNXModule 4 | 5 | task_module_mapping = { 6 | 'training.MIDIExtractionTask': 'deployment.MIDIExtractionONNXModule', 7 | 'training.QuantizedMIDIExtractionTask': 'deployment.QuantizedMIDIExtractionONNXModule', 8 | } 9 | -------------------------------------------------------------------------------- /deployment/base_onnx_module.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from collections import OrderedDict 3 | 4 | from librosa.filters import mel 5 | import torch 6 | from torch import nn 7 | 8 | from utils import build_object_from_class_name 9 | 10 | 11 | class BaseONNXModule(nn.Module): 12 | def __init__(self, config: dict, model_path: pathlib.Path, device=None): 13 | super().__init__() 14 | if device is None: 15 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 16 | self.config = config 17 | self.model_path = model_path 18 | self.device = device 19 | self.timestep = self.config['hop_size'] / self.config['audio_sample_rate'] 20 | self.model: torch.nn.Module = self.build_model() 21 | 22 | def build_model(self) -> nn.Module: 23 | model: nn.Module = build_object_from_class_name( 24 | self.config['model_cls'], nn.Module, config=self.config 25 | ).eval().to(self.device) 26 | state_dict = torch.load(self.model_path, map_location=self.device)['state_dict'] 27 | prefix_in_ckpt = 'model' 28 | state_dict = OrderedDict({ 29 | k[len(prefix_in_ckpt) + 1:]: v 30 | for k, v in state_dict.items() if k.startswith(f'{prefix_in_ckpt}.') 31 | }) 32 | model.load_state_dict(state_dict, strict=True) 33 | print(f'| load \'{prefix_in_ckpt}\' from \'{self.model_path}\'.') 34 | return model 35 | 36 | 37 | class MelSpectrogram_ONNX(nn.Module): 38 | def __init__( 39 | self, 40 | n_mel_channels, 41 | sampling_rate, 42 | win_length, 43 | hop_length, 44 | n_fft=None, 45 | mel_fmin=0, 46 | mel_fmax=None, 47 | clamp=1e-5 48 | ): 49 | super().__init__() 50 | n_fft = win_length if n_fft is None else n_fft 51 | mel_basis = mel( 52 | sr=sampling_rate, 53 | n_fft=n_fft, 54 | n_mels=n_mel_channels, 55 | fmin=mel_fmin, 56 | fmax=mel_fmax, 57 | htk=True) 58 | mel_basis = torch.from_numpy(mel_basis).float() 59 | self.register_buffer("mel_basis", mel_basis) 60 | self.n_fft = win_length if n_fft is None else n_fft 61 | self.hop_length = hop_length 62 | self.win_length = win_length 63 | self.sampling_rate = sampling_rate 64 | self.n_mel_channels = n_mel_channels 65 | self.clamp = clamp 66 | 67 | def forward(self, audio, center=True): 68 | fft = torch.stft( 69 | audio, 70 | n_fft=self.n_fft, 71 | hop_length=self.hop_length, 72 | win_length=self.win_length, 73 | window=torch.hann_window(self.win_length, device=audio.device), 74 | center=center, 75 
| return_complex=False 76 | ) 77 | magnitude = torch.sqrt(torch.sum(fft ** 2, dim=-1)) 78 | mel_output = torch.matmul(self.mel_basis, magnitude) 79 | log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp)) 80 | return log_mel_spec 81 | -------------------------------------------------------------------------------- /deployment/me_onnx_module.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | import torch 4 | 5 | from utils.infer_utils import decode_bounds_to_alignment, decode_gaussian_blurred_probs, decode_note_sequence 6 | from .base_onnx_module import BaseONNXModule, MelSpectrogram_ONNX 7 | 8 | 9 | class MIDIExtractionONNXModule(BaseONNXModule): 10 | def __init__(self, config: dict, model_path: pathlib.Path, device=None): 11 | super().__init__(config, model_path, device=device) 12 | self.mel_extractor = MelSpectrogram_ONNX( 13 | n_mel_channels=self.config['units_dim'], sampling_rate=self.config['audio_sample_rate'], 14 | win_length=self.config['win_size'], hop_length=self.config['hop_size'], 15 | mel_fmin=self.config['fmin'], mel_fmax=self.config['fmax'] 16 | ).to(self.device) 17 | self.rmvpe = None 18 | self.midi_min = self.config['midi_min'] 19 | self.midi_max = self.config['midi_max'] 20 | self.midi_deviation = self.config['midi_prob_deviation'] 21 | self.rest_threshold = self.config['rest_threshold'] 22 | 23 | def forward(self, waveform: torch.Tensor): 24 | units = self.mel_extractor(waveform).transpose(1, 2) 25 | pitch = torch.zeros(units.shape[:2], dtype=torch.float32, device=self.device) 26 | masks = torch.ones_like(pitch, dtype=torch.bool) 27 | probs, bounds = self.model(x=units, f0=pitch, mask=masks, sig=True) 28 | probs *= masks[..., None] 29 | bounds *= masks 30 | unit2note_pred = decode_bounds_to_alignment(bounds, use_diff=False) * masks 31 | midi_pred, rest_pred = decode_gaussian_blurred_probs( 32 | probs, vmin=self.midi_min, vmax=self.midi_max, 33 | deviation=self.midi_deviation, threshold=self.rest_threshold 34 | ) 35 | note_midi_pred, note_dur_pred, note_mask_pred = decode_note_sequence( 36 | unit2note_pred, midi_pred, ~rest_pred & masks 37 | ) 38 | note_rest_pred = ~note_mask_pred 39 | return note_midi_pred, note_rest_pred, note_dur_pred * self.timestep 40 | -------------------------------------------------------------------------------- /deployment/me_quant_onnx_module.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | import torch 4 | 5 | from utils.infer_utils import decode_bounds_to_alignment, decode_note_sequence 6 | from .base_onnx_module import BaseONNXModule, MelSpectrogram_ONNX 7 | 8 | 9 | class QuantizedMIDIExtractionONNXModule(BaseONNXModule): 10 | def __init__(self, config: dict, model_path: pathlib.Path, device=None): 11 | super().__init__(config, model_path, device=device) 12 | self.mel_extractor = MelSpectrogram_ONNX( 13 | n_mel_channels=self.config['units_dim'], sampling_rate=self.config['audio_sample_rate'], 14 | win_length=self.config['win_size'], hop_length=self.config['hop_size'], 15 | mel_fmin=self.config['fmin'], mel_fmax=self.config['fmax'] 16 | ).to(self.device) 17 | self.rmvpe = None 18 | 19 | def forward(self, waveform: torch.Tensor): 20 | units = self.mel_extractor(waveform).transpose(1, 2) 21 | pitch = torch.zeros(units.shape[:2], dtype=torch.float32, device=self.device) 22 | masks = torch.ones_like(pitch, dtype=torch.bool) 23 | probs, bounds = self.model(x=units, f0=pitch, mask=masks, sig=True) 24 | probs *= 
masks[..., None] 25 | bounds *= masks 26 | unit2note_pred = decode_bounds_to_alignment(bounds) * masks 27 | midi_pred = probs.argmax(dim=-1) 28 | rest_pred = midi_pred == 128 29 | note_midi_pred, note_dur_pred, note_mask_pred = decode_note_sequence( 30 | unit2note_pred, midi_pred.clip(min=0, max=127), ~rest_pred & masks 31 | ) 32 | note_rest_pred = ~note_mask_pred 33 | return note_midi_pred, note_rest_pred, note_dur_pred * self.timestep 34 | -------------------------------------------------------------------------------- /experiments/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openvpi/SOME/dc9916fb35350748351ddcdfc128fe908ef53f5d/experiments/.gitkeep -------------------------------------------------------------------------------- /export.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import pathlib 3 | from typing import Dict, Tuple, Union 4 | 5 | import click 6 | import onnx 7 | import onnxsim 8 | import torch 9 | import yaml 10 | 11 | import deployment 12 | from utils.config_utils import print_config 13 | 14 | 15 | def onnx_override_io_shapes( 16 | model, # ModelProto 17 | input_shapes: Dict[str, Tuple[Union[str, int]]] = None, 18 | output_shapes: Dict[str, Tuple[Union[str, int]]] = None, 19 | ): 20 | """ 21 | Override the shapes of inputs/outputs of the model graph (in-place operation). 22 | :param model: model to perform the operation on 23 | :param input_shapes: a dict with keys as input/output names and values as shape tuples 24 | :param output_shapes: the same as input_shapes 25 | """ 26 | def _override_shapes( 27 | shape_list_old, # RepeatedCompositeFieldContainer[ValueInfoProto] 28 | shape_dict_new: Dict[str, Tuple[Union[str, int]]]): 29 | for value_info in shape_list_old: 30 | if value_info.name in shape_dict_new: 31 | name = value_info.name 32 | dims = value_info.type.tensor_type.shape.dim 33 | assert len(shape_dict_new[name]) == len(dims), \ 34 | f'Number of given and existing dimensions mismatch: {name}' 35 | for i, dim in enumerate(shape_dict_new[name]): 36 | if isinstance(dim, int): 37 | dims[i].dim_param = '' 38 | dims[i].dim_value = dim 39 | else: 40 | dims[i].dim_value = 0 41 | dims[i].dim_param = dim 42 | 43 | if input_shapes is not None: 44 | _override_shapes(model.graph.input, input_shapes) 45 | if output_shapes is not None: 46 | _override_shapes(model.graph.output, output_shapes) 47 | 48 | 49 | @click.command(help='Run inference with a trained model') 50 | @click.option('--model', required=True, metavar='CKPT_PATH', help='Path to the model checkpoint (*.ckpt)') 51 | @click.option('--out', required=False, metavar='ONNX_PATH', help='Path to the output model (*.onnx)') 52 | def export(model, out): 53 | model_path = pathlib.Path(model) 54 | with open(model_path.with_name('config.yaml'), 'r', encoding='utf8') as f: 55 | config = yaml.safe_load(f) 56 | print_config(config) 57 | module_cls = deployment.task_module_mapping[config['task_cls']] 58 | 59 | pkg = ".".join(module_cls.split(".")[:-1]) 60 | cls_name = module_cls.split(".")[-1] 61 | module_cls = getattr(importlib.import_module(pkg), cls_name) 62 | assert issubclass(module_cls, deployment.BaseONNXModule), \ 63 | f'Module class {module_cls} is not a subclass of {deployment.BaseONNXModule}.' 
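    # Build the ONNX wrapper for this task, trace it with a dummy mono waveform, and export
    # with named I/O and dynamic axes (variable sample count in, variable note count out);
    # the graph is then shape-annotated and simplified with onnxsim before saving.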
64 | module_ins = module_cls(config=config, model_path=model_path) 65 | 66 | waveform = torch.randn((1, 114514), dtype=torch.float32, device=module_ins.device) 67 | out_path = pathlib.Path(out) if out is not None else model_path.with_suffix('.onnx') 68 | torch.onnx.export( 69 | module_ins, 70 | waveform, 71 | out_path, 72 | input_names=['waveform'], 73 | output_names=[ 74 | 'note_midi', 75 | 'note_rest', 76 | 'note_dur' 77 | ], 78 | dynamic_axes={ 79 | 'waveform': { 80 | 1: 'n_samples' 81 | }, 82 | 'note_midi': { 83 | 1: 'n_notes' 84 | }, 85 | 'note_rest': { 86 | 1: 'n_notes' 87 | }, 88 | 'note_dur': { 89 | 1: 'n_notes' 90 | }, 91 | }, 92 | opset_version=17 93 | ) 94 | onnx_model = onnx.load(out_path.as_posix()) 95 | onnx_override_io_shapes(onnx_model, output_shapes={ 96 | 'note_midi': (1, 'n_notes'), 97 | 'note_rest': (1, 'n_notes'), 98 | 'note_dur': (1, 'n_notes'), 99 | }) 100 | print('Running ONNX Simplifier...') 101 | onnx_model, check = onnxsim.simplify( 102 | onnx_model, 103 | include_subgraph=True 104 | ) 105 | assert check, 'Simplified ONNX model could not be validated' 106 | onnx.save(onnx_model, out_path) 107 | 108 | 109 | if __name__ == '__main__': 110 | export() 111 | -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import pathlib 3 | 4 | import click 5 | import librosa 6 | import yaml 7 | 8 | import inference 9 | from utils.config_utils import print_config 10 | from utils.infer_utils import build_midi_file 11 | from utils.slicer2 import Slicer 12 | 13 | 14 | @click.command(help='Run inference with a trained model') 15 | @click.option('--model', required=True, metavar='CKPT_PATH', help='Path to the model checkpoint (*.ckpt)') 16 | @click.option('--wav', required=True, metavar='WAV_PATH', help='Path to the input wav file (*.wav)') 17 | @click.option('--midi', required=False, metavar='MIDI_PATH', help='Path to the output MIDI file (*.mid)') 18 | @click.option('--tempo', required=False, type=float, default=120, metavar='TEMPO', help='Specify tempo in the output MIDI') 19 | def infer(model, wav, midi, tempo): 20 | model_path = pathlib.Path(model) 21 | with open(model_path.with_name('config.yaml'), 'r', encoding='utf8') as f: 22 | config = yaml.safe_load(f) 23 | print_config(config) 24 | infer_cls = inference.task_inference_mapping[config['task_cls']] 25 | 26 | pkg = ".".join(infer_cls.split(".")[:-1]) 27 | cls_name = infer_cls.split(".")[-1] 28 | infer_cls = getattr(importlib.import_module(pkg), cls_name) 29 | assert issubclass(infer_cls, inference.BaseInference), \ 30 | f'Inference class {infer_cls} is not a subclass of {inference.BaseInference}.' 
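    # Build the inference wrapper, load the audio at the model's sample rate, slice it at
    # silences, run per-chunk inference, then merge the chunk results into one MIDI file
    # at the requested tempo (saved next to the input wav unless --midi is given).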
31 | infer_ins = infer_cls(config=config, model_path=model_path) 32 | 33 | wav_path = pathlib.Path(wav) 34 | waveform, _ = librosa.load(wav_path, sr=config['audio_sample_rate'], mono=True) 35 | slicer = Slicer(sr=config['audio_sample_rate'], max_sil_kept=1000) 36 | chunks = slicer.slice(waveform) 37 | midis = infer_ins.infer([c['waveform'] for c in chunks]) 38 | 39 | midi_file = build_midi_file([c['offset'] for c in chunks], midis, tempo=tempo) 40 | 41 | midi_path = pathlib.Path(midi) if midi is not None else wav_path.with_suffix('.mid') 42 | midi_file.save(midi_path) 43 | print(f'MIDI file saved at: \'{midi_path}\'') 44 | 45 | 46 | if __name__ == '__main__': 47 | infer() 48 | -------------------------------------------------------------------------------- /inference/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_infer import BaseInference 2 | from .me_infer import MIDIExtractionInference 3 | from .me_quant_infer import QuantizedMIDIExtractionInference 4 | 5 | task_inference_mapping = { 6 | 'training.MIDIExtractionTask': 'inference.MIDIExtractionInference', 7 | 'training.QuantizedMIDIExtractionTask': 'inference.QuantizedMIDIExtractionInference', 8 | } 9 | -------------------------------------------------------------------------------- /inference/base_infer.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from collections import OrderedDict 3 | from typing import Dict, List 4 | 5 | import numpy as np 6 | import torch 7 | import tqdm 8 | from torch import nn 9 | 10 | from utils import build_object_from_class_name 11 | 12 | 13 | class BaseInference: 14 | def __init__(self, config: dict, model_path: pathlib.Path, device=None): 15 | if device is None: 16 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 17 | self.config = config 18 | self.model_path = model_path 19 | self.device = device 20 | self.timestep = self.config['hop_size'] / self.config['audio_sample_rate'] 21 | self.model: torch.nn.Module = self.build_model() 22 | 23 | def build_model(self) -> nn.Module: 24 | model: nn.Module = build_object_from_class_name( 25 | self.config['model_cls'], nn.Module, config=self.config 26 | ).eval().to(self.device) 27 | state_dict = torch.load(self.model_path, map_location=self.device)['state_dict'] 28 | prefix_in_ckpt = 'model' 29 | state_dict = OrderedDict({ 30 | k[len(prefix_in_ckpt) + 1:]: v 31 | for k, v in state_dict.items() if k.startswith(f'{prefix_in_ckpt}.') 32 | }) 33 | model.load_state_dict(state_dict, strict=True) 34 | print(f'| load \'{prefix_in_ckpt}\' from \'{self.model_path}\'.') 35 | return model 36 | 37 | def preprocess(self, waveform: np.ndarray) -> Dict[str, torch.Tensor]: 38 | raise NotImplementedError() 39 | 40 | def forward_model(self, sample: Dict[str, torch.Tensor]): 41 | raise NotImplementedError() 42 | 43 | def postprocess(self, results: Dict[str, torch.Tensor]) -> List[Dict[str, np.ndarray]]: 44 | raise NotImplementedError() 45 | 46 | def infer(self, waveforms: List[np.ndarray]) -> List[Dict[str, np.ndarray]]: 47 | results = [] 48 | for w in tqdm.tqdm(waveforms): 49 | model_in = self.preprocess(w) 50 | model_out = self.forward_model(model_in) 51 | res = self.postprocess(model_out) 52 | results.append(res) 53 | return results 54 | -------------------------------------------------------------------------------- /inference/me_infer.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from typing 
import Dict, List 3 | 4 | import librosa 5 | import numpy as np 6 | import torch 7 | 8 | import modules.rmvpe 9 | from utils.binarizer_utils import get_pitch_parselmouth 10 | from utils.infer_utils import decode_bounds_to_alignment, decode_gaussian_blurred_probs, decode_note_sequence 11 | from utils.pitch_utils import resample_align_curve 12 | from .base_infer import BaseInference 13 | 14 | 15 | class MIDIExtractionInference(BaseInference): 16 | def __init__(self, config: dict, model_path: pathlib.Path, device=None): 17 | super().__init__(config, model_path, device=device) 18 | self.mel_spec = modules.rmvpe.MelSpectrogram( 19 | n_mel_channels=self.config['units_dim'], sampling_rate=self.config['audio_sample_rate'], 20 | win_length=self.config['win_size'], hop_length=self.config['hop_size'], 21 | mel_fmin=self.config['fmin'], mel_fmax=self.config['fmax'] 22 | ).to(self.device) 23 | self.rmvpe = None 24 | self.midi_min = self.config['midi_min'] 25 | self.midi_max = self.config['midi_max'] 26 | self.midi_deviation = self.config['midi_prob_deviation'] 27 | self.rest_threshold = self.config['rest_threshold'] 28 | 29 | def preprocess(self, waveform: np.ndarray) -> Dict[str, torch.Tensor]: 30 | wav_tensor = torch.from_numpy(waveform).unsqueeze(0).to(self.device) 31 | units = self.mel_spec(wav_tensor).transpose(1, 2) 32 | length = units.shape[1] 33 | # f0_algo = self.config['pe'] 34 | # if f0_algo == 'parselmouth': 35 | # f0, _ = get_pitch_parselmouth( 36 | # waveform, sample_rate=self.config['audio_sample_rate'], 37 | # hop_size=self.config['hop_size'], length=length, interp_uv=True 38 | # ) 39 | # elif f0_algo == 'rmvpe': 40 | # if self.rmvpe is None: 41 | # self.rmvpe = modules.rmvpe.RMVPE(self.config['pe_ckpt'], device=self.device) 42 | # f0, _ = self.rmvpe.get_pitch( 43 | # waveform, sample_rate=self.config['audio_sample_rate'], 44 | # hop_size=self.rmvpe.mel_extractor.hop_length, 45 | # length=(waveform.shape[0] + self.rmvpe.mel_extractor.hop_length - 1) // self.rmvpe.mel_extractor.hop_length, 46 | # interp_uv=True 47 | # ) 48 | # f0 = resample_align_curve( 49 | # f0, 50 | # original_timestep=self.rmvpe.mel_extractor.hop_length / self.config['audio_sample_rate'], 51 | # target_timestep=self.config['hop_size'] / self.config['audio_sample_rate'], 52 | # align_length=length 53 | # ) 54 | # else: 55 | # raise NotImplementedError(f'Invalid pitch extractor: {f0_algo}') 56 | # pitch = librosa.hz_to_midi(f0) 57 | # pitch = torch.from_numpy(pitch).unsqueeze(0).to(self.device) 58 | pitch = torch.zeros(units.shape[:2], dtype=torch.float32, device=self.device) 59 | return { 60 | 'units': units, 61 | 'pitch': pitch, 62 | 'masks': torch.ones_like(pitch, dtype=torch.bool) 63 | } 64 | 65 | @torch.no_grad() 66 | def forward_model(self, sample: Dict[str, torch.Tensor]): 67 | 68 | 69 | 70 | probs, bounds = self.model(x=sample['units'], f0=sample['pitch'], mask=sample['masks'],sig=True) 71 | 72 | return { 73 | 'probs': probs, 74 | 'bounds': bounds, 75 | 'masks': sample['masks'], 76 | } 77 | 78 | def postprocess(self, results: Dict[str, torch.Tensor]) -> List[Dict[str, np.ndarray]]: 79 | probs = results['probs'] 80 | bounds = results['bounds'] 81 | masks = results['masks'] 82 | probs *= masks[..., None] 83 | bounds *= masks 84 | unit2note_pred = decode_bounds_to_alignment(bounds) * masks 85 | midi_pred, rest_pred = decode_gaussian_blurred_probs( 86 | probs, vmin=self.midi_min, vmax=self.midi_max, 87 | deviation=self.midi_deviation, threshold=self.rest_threshold 88 | ) 89 | note_midi_pred, note_dur_pred, 
note_mask_pred = decode_note_sequence( 90 | unit2note_pred, midi_pred, ~rest_pred & masks 91 | ) 92 | note_rest_pred = ~note_mask_pred 93 | return { 94 | 'note_midi': note_midi_pred.squeeze(0).cpu().numpy(), 95 | 'note_dur': note_dur_pred.squeeze(0).cpu().numpy() * self.timestep, 96 | 'note_rest': note_rest_pred.squeeze(0).cpu().numpy() 97 | } 98 | -------------------------------------------------------------------------------- /inference/me_quant_infer.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from utils.infer_utils import decode_bounds_to_alignment, decode_note_sequence 7 | from .me_infer import MIDIExtractionInference 8 | 9 | 10 | class QuantizedMIDIExtractionInference(MIDIExtractionInference): 11 | @torch.no_grad() 12 | def forward_model(self, sample: Dict[str, torch.Tensor]): 13 | probs, bounds = self.model(x=sample['units'], f0=sample['pitch'], mask=sample['masks'], softmax=True) 14 | 15 | return { 16 | 'probs': probs, 17 | 'bounds': bounds, 18 | 'masks': sample['masks'], 19 | } 20 | 21 | def postprocess(self, results: Dict[str, torch.Tensor]) -> List[Dict[str, np.ndarray]]: 22 | probs = results['probs'] 23 | bounds = results['bounds'] 24 | masks = results['masks'] 25 | probs *= masks[..., None] 26 | bounds *= masks 27 | unit2note_pred = decode_bounds_to_alignment(bounds) * masks 28 | midi_pred = probs.argmax(dim=-1) 29 | rest_pred = midi_pred == 128 30 | note_midi_pred, note_dur_pred, note_mask_pred = decode_note_sequence( 31 | unit2note_pred, midi_pred.clip(min=0, max=127), ~rest_pred & masks 32 | ) 33 | note_rest_pred = ~note_mask_pred 34 | return { 35 | 'note_midi': note_midi_pred.squeeze(0).cpu().numpy(), 36 | 'note_dur': note_dur_pred.squeeze(0).cpu().numpy() * self.timestep, 37 | 'note_rest': note_rest_pred.squeeze(0).cpu().numpy() 38 | } 39 | -------------------------------------------------------------------------------- /lr_scheduler/scheduler.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import numpy as np 4 | import torch 5 | from torch.optim.lr_scheduler import _LRScheduler 6 | 7 | # from typeguard import check_argument_types 8 | 9 | 10 | class WarmupLR(_LRScheduler): 11 | """The WarmupLR scheduler 12 | 13 | This scheduler is almost same as NoamLR Scheduler except for following 14 | difference: 15 | 16 | NoamLR: 17 | lr = optimizer.lr * model_size ** -0.5 18 | * min(step ** -0.5, step * warmup_step ** -1.5) 19 | WarmupLR: 20 | lr = optimizer.lr * warmup_step ** 0.5 21 | * min(step ** -0.5, step * warmup_step ** -1.5) 22 | 23 | Note that the maximum lr equals to optimizer.lr in this scheduler. 
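    For example, with lr = 1e-4 (as in the provided configs) and warmup_steps = 5000,
    the scaling factor warmup_steps ** 0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
    rises linearly to 1.0 over the first 5000 steps (lr reaches 1e-4 at step 5000) and then
    decays as the inverse square root of the step, e.g. lr = 0.5e-4 at step 20000.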
24 | 25 | """ 26 | 27 | def __init__( 28 | self, 29 | optimizer: torch.optim.Optimizer, 30 | warmup_steps: Union[int, float] = 5000, 31 | min_lr=2e-5, 32 | last_epoch: int = -1, 33 | ): 34 | # assert check_argument_types() 35 | self.warmup_steps = warmup_steps 36 | self.min_lr = min_lr 37 | super().__init__(optimizer, last_epoch) 38 | 39 | def __repr__(self): 40 | return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps}, lr={self.base_lr}, min_lr={self.min_lr}, last_epoch={self.last_epoch})" 41 | 42 | def get_lr(self): 43 | step_num = self.last_epoch + 1 44 | if self.warmup_steps == 0: 45 | lrs = [] 46 | for lr in self.base_lrs: 47 | lr = lr * step_num ** -0.5 48 | if lr < self.min_lr: 49 | lr = self.min_lr 50 | lrs.append(lr) 51 | return lrs 52 | else: 53 | lrs = [] 54 | for lr in self.base_lrs: 55 | lr = lr * self.warmup_steps ** 0.5 * min(step_num ** -0.5, step_num * self.warmup_steps ** -1.5) 56 | if lr < self.min_lr and step_num > self.warmup_steps: 57 | lr = self.min_lr 58 | lrs.append(lr) 59 | return lrs 60 | 61 | def set_step(self, step: int): 62 | self.last_epoch = step 63 | 64 | class SGDRLR(_LRScheduler): 65 | """The WarmupLR scheduler 66 | 67 | This scheduler is almost same as NoamLR Scheduler except for following 68 | difference: 69 | 70 | NoamLR: 71 | lr = optimizer.lr * model_size ** -0.5 72 | * min(step ** -0.5, step * warmup_step ** -1.5) 73 | WarmupLR: 74 | lr = optimizer.lr * warmup_step ** 0.5 75 | * min(step ** -0.5, step * warmup_step ** -1.5) 76 | 77 | Note that the maximum lr equals to optimizer.lr in this scheduler. 78 | 79 | """ 80 | 81 | def __init__( 82 | self, 83 | optimizer: torch.optim.Optimizer, 84 | warmup_steps: Union[int, float] = 25000, 85 | min_lr=1e-5, 86 | last_epoch: int = -1, T_0=1500, eta_max=0.1, eta_min=0.,T_mul=2,T_mult=2 87 | ): 88 | # assert check_argument_types() 89 | self.warmup_steps = warmup_steps 90 | self.min_lr = min_lr 91 | self.eta_min = eta_min 92 | self.T_0 = T_0 93 | self.eta_max = eta_max 94 | self.T_mul = T_mul 95 | self.T_mult = T_mult 96 | 97 | super().__init__(optimizer, last_epoch) 98 | 99 | def __repr__(self): 100 | return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps}, lr={self.base_lr}, min_lr={self.min_lr}, last_epoch={self.last_epoch})" 101 | 102 | def adjust_lr(self,): 103 | step_num = self.last_epoch + 1 104 | if self.T_mul == 2: 105 | i = np.log2(step_num / self.T_0 + 1).astype(np.int32) 106 | T_cur = step_num - self.T_0 * (self.T_mult ** (i) - 1) 107 | T_i = (self.T_0 * self.T_mult ** i) 108 | elif self.T_mul == 1: 109 | T_cur = step_num % self.T_0 110 | T_i = self.T_0 111 | cur_lr = self.eta_min + 0.5 * (self.eta_max - self.eta_min) * (1 + np.cos(np.pi * T_cur / T_i)) 112 | return cur_lr 113 | 114 | 115 | def get_lr(self): 116 | # step_num = self.last_epoch + 1 117 | if self.warmup_steps == 0: 118 | lrs = [] 119 | for lr in self.base_lrs: 120 | lr = self.adjust_lr() 121 | lrs.append(lr) 122 | return lrs 123 | else: 124 | lrs = [] 125 | for lr in self.base_lrs: 126 | lr = self.adjust_lr() 127 | lrs.append(lr) 128 | return lrs 129 | 130 | def set_step(self, step: int): 131 | self.last_epoch = step 132 | class LSGDRLR(_LRScheduler): 133 | """The WarmupLR scheduler 134 | 135 | This scheduler is almost same as NoamLR Scheduler except for following 136 | difference: 137 | 138 | NoamLR: 139 | lr = optimizer.lr * model_size ** -0.5 140 | * min(step ** -0.5, step * warmup_step ** -1.5) 141 | WarmupLR: 142 | lr = optimizer.lr * warmup_step ** 0.5 143 | * min(step ** -0.5, step * warmup_step ** -1.5) 
144 | 145 | Note that the maximum lr equals to optimizer.lr in this scheduler. 146 | 147 | """ 148 | 149 | def __init__( 150 | self, 151 | optimizer: torch.optim.Optimizer, 152 | warmup_steps: Union[int, float] = 25000, 153 | min_lr=1e-5, 154 | last_epoch: int = -1, T_0=1500, eta_max=0.1, eta_min=0.,T_mul=2,T_mult=0.9999 155 | ): 156 | # assert check_argument_types() 157 | self.warmup_steps = warmup_steps 158 | self.min_lr = min_lr 159 | self.eta_min = eta_min 160 | self.T_0 = T_0 161 | self.eta_max = eta_max 162 | self.T_mul = T_mul 163 | self.T_mult = T_mult 164 | 165 | super().__init__(optimizer, last_epoch) 166 | 167 | def __repr__(self): 168 | return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps}, lr={self.base_lr}, min_lr={self.min_lr}, last_epoch={self.last_epoch})" 169 | 170 | def adjust_lr(self,): 171 | step_num = self.last_epoch + 1 172 | 173 | cur_lr = self.eta_min* self.T_mult ** step_num + np.cos(np.pi * step_num / self.T_0 ) 174 | return cur_lr 175 | 176 | 177 | def get_lr(self): 178 | # step_num = self.last_epoch + 1 179 | if self.warmup_steps == 0: 180 | lrs = [] 181 | for lr in self.base_lrs: 182 | lr = self.adjust_lr() 183 | lrs.append(lr) 184 | return lrs 185 | else: 186 | lrs = [] 187 | for lr in self.base_lrs: 188 | lr = self.adjust_lr() 189 | lrs.append(lr) 190 | return lrs 191 | 192 | def set_step(self, step: int): 193 | self.last_epoch = step 194 | 195 | class V2LSGDRLR(_LRScheduler): 196 | """The WarmupLR scheduler 197 | 198 | This scheduler is almost same as NoamLR Scheduler except for following 199 | difference: 200 | 201 | NoamLR: 202 | lr = optimizer.lr * model_size ** -0.5 203 | * min(step ** -0.5, step * warmup_step ** -1.5) 204 | WarmupLR: 205 | lr = optimizer.lr * warmup_step ** 0.5 206 | * min(step ** -0.5, step * warmup_step ** -1.5) 207 | 208 | Note that the maximum lr equals to optimizer.lr in this scheduler. 
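    (The text above is carried over from WarmupLR. Judging from ctxadjust_lr
    below, this variant actually produces a cosine-annealing curve with period
    T_0 whose peak decays by a factor of tmctx each cycle, after a linear
    warmup of ws steps; unlike the other schedulers in this file, it defines
    no get_lr override.)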
209 | 210 | """ 211 | 212 | def __init__( 213 | self, 214 | optimizer: torch.optim.Optimizer, 215 | warmup_steps: Union[int, float] = 25000, 216 | min_lr=1e-5, 217 | last_epoch: int = -1, T_0=1500, eta_max=0.1, eta_min=0.,T_mul=2,T_mult=0.9999 218 | ): 219 | # assert check_argument_types() 220 | self.warmup_steps = warmup_steps 221 | self.min_lr = min_lr 222 | self.eta_min = eta_min 223 | self.T_0 = T_0 224 | self.eta_max = eta_max 225 | self.T_mul = T_mul 226 | self.T_mult = T_mult 227 | 228 | super().__init__(optimizer, last_epoch) 229 | 230 | def __repr__(self): 231 | return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps}, lr={self.base_lr}, min_lr={self.min_lr}, last_epoch={self.last_epoch})" 232 | 233 | def ctxadjust_lr(self,T_mul = 1,T_0=15000,T_mult=1.5,eta_min=0.0000001,eta_max=0.00006,tmctx=0.99,ws=8000): 234 | step_num = self.last_epoch+1 235 | if T_mul == 2: 236 | i = np.log2(step_num / T_0 + 1).astype(np.int32) 237 | T_cur = step_num - T_0 * (T_mult ** (i) - 1) 238 | T_i = (T_0 * T_mult ** i) 239 | elif T_mul == 1: 240 | T_cur = (step_num + ws) % T_0 241 | T_i = T_0 242 | T_curX = (step_num + ws) // T_0 243 | 244 | 245 | cur_lr = eta_min + 0.5 * (eta_max *(tmctx**T_curX)- eta_min*(tmctx**T_curX)) * (1 + np.cos(np.pi * T_cur / T_i)) 246 | if ws>step_num: 247 | cur_lr=step_num*(eta_max/ws) 248 | 249 | return cur_lr 250 | class V3LSGDRLR(_LRScheduler): 251 | """The WarmupLR schedulerA 252 | This scheduler is almost same as NoamLR Scheduler except for following 253 | difference: 254 | NoamLR: 255 | lr = optimizer.lr * model_size ** -0.5 256 | * min(step ** -0.5, step * warmup_step ** -1.5) 257 | WarmupLR: 258 | lr = optimizer.lr * warmup_step ** 0.5 259 | * min(step ** -0.5, step * warmup_step ** -1.5) 260 | Note that the maximum lr equals to optimizer.lr in this scheduler. 
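    (Likewise copied from WarmupLR. ctxadjust_lr below implements a cosine
    schedule with period T_0 in which both eta_max and eta_min decay by a
    factor of tmctx per cycle, preceded by a linear warmup of ws steps.)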
261 | """ 262 | def __init__(self,optimizer: torch.optim.Optimizer,warmup_steps: Union[int, float] = 25000,min_lr=1e-5,last_epoch: int = -1, T_0=1500, eta_max=0.1, eta_min=0., T_mul=2, T_mult=0.9999): 263 | # assert check_argument_types() 264 | self.warmup_steps = warmup_steps 265 | self.min_lr = min_lr 266 | self.eta_min = eta_min 267 | self.T_0 = T_0 268 | self.eta_max = eta_max 269 | self.T_mul = T_mul 270 | self.T_mult = T_mult 271 | super().__init__(optimizer, last_epoch) 272 | def __repr__(self): 273 | return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps}, lr={self.base_lr}, min_lr={self.min_lr}, last_epoch={self.last_epoch})" 274 | def ctxadjust_lr(self, T_0=15000, eta_min=0.00006, eta_max=0.00009, tmctx=0.98, ws=5000): 275 | step_num = self.last_epoch + 1 #+360000 276 | T_cur = (step_num + ws) % T_0 277 | T_i = T_0 278 | T_curX = (step_num + ws) // T_0 279 | cur_lr = eta_min * (tmctx ** T_curX) + 0.5 * (eta_max * (tmctx ** T_curX) - eta_min * (tmctx ** T_curX)) * ( 280 | 1 + np.cos(np.pi * T_cur / T_i)) 281 | if ws > step_num: 282 | cur_lr = step_num * (eta_max / ws) 283 | return cur_lr 284 | 285 | 286 | 287 | def get_lr(self): 288 | # step_num = self.last_epoch + 1 289 | if self.warmup_steps == 0: 290 | lrs = [] 291 | for lr in self.base_lrs: 292 | lr = self.ctxadjust_lr() 293 | lrs.append(lr) 294 | return lrs 295 | else: 296 | lrs = [] 297 | for lr in self.base_lrs: 298 | lr = self.ctxadjust_lr() 299 | lrs.append(lr) 300 | return lrs 301 | 302 | def set_step(self, step: int): 303 | self.last_epoch = step 304 | 305 | 306 | 307 | class NoamHoldAnnealing(_LRScheduler): 308 | def __init__(self, optimizer, max_steps=175680, warmup_steps=None, warmup_ratio=0.2, hold_steps=None, 309 | hold_ratio=0.3, decay_rate=1.0, min_lr=1.e-5, last_epoch=-1): 310 | """ 311 | From Nemo: 312 | Implementation of the Noam Hold Annealing policy from the SqueezeFormer paper. 313 | 314 | Unlike NoamAnnealing, the peak learning rate can be explicitly set for this scheduler. 315 | The schedule first performs linear warmup, 316 | then holds the peak LR, then decays with some schedule for 317 | the remainder of the steps. 318 | Therefore the min-lr is still dependent on the hyper parameters selected. 319 | 320 | It's schedule is determined by three factors- 321 | 322 | Warmup Steps: Initial stage, where linear warmup 323 | occurs uptil the peak LR is reached. Unlike NoamAnnealing, 324 | the peak LR is explicitly stated here instead of a scaling factor. 325 | 326 | Hold Steps: Intermediate stage, where the peak LR 327 | is maintained for some number of steps. In this region, 328 | the high peak LR allows the model to converge faster 329 | if training is stable. However the high LR 330 | may also cause instability during training. 331 | Should usually be a significant fraction of training 332 | steps (around 30-40% of the entire training steps). 333 | 334 | Decay Steps: Final stage, where the LR rapidly decays 335 | with some scaling rate (set by decay rate). 336 | To attain Noam decay, use 0.5, 337 | for Squeezeformer recommended decay, use 1.0. 338 | The fast decay after prolonged high LR during 339 | hold phase allows for rapid convergence. 340 | 341 | References: 342 | - [Squeezeformer: 343 | An Efficient Transformer for Automatic Speech Recognition] 344 | (https://arxiv.org/abs/2206.00888) 345 | 346 | Args: 347 | optimizer: Pytorch compatible Optimizer object. 
348 | warmup_steps: Number of training steps in warmup stage 349 | warmup_ratio: Ratio of warmup steps to total steps 350 | hold_steps: Number of training steps to hold the learning rate after warm up 351 | hold_ratio: Ratio of hold steps to total steps 352 | max_steps: Total number of steps while training or `None` for infinite training 353 | decay_rate: Float value describing the polynomial decay 354 | after the hold period. Default value of 0.5 corresponds to Noam decay. 355 | min_lr: Minimum learning rate. 356 | """ 357 | self.decay_rate = decay_rate 358 | self.min_lr = min_lr 359 | self._last_warmup_lr = 0.0 360 | 361 | # Necessary to duplicate as class attributes are hidden in inner class 362 | self.max_steps = max_steps 363 | if warmup_steps is not None: 364 | self.warmup_steps = warmup_steps 365 | elif warmup_ratio is not None: 366 | self.warmup_steps = int(warmup_ratio * max_steps) 367 | else: 368 | self.warmup_steps = 0 369 | 370 | if hold_steps is not None: 371 | self.hold_steps = hold_steps + self.warmup_steps 372 | elif hold_ratio is not None: 373 | self.hold_steps = int(hold_ratio * max_steps) + self.warmup_steps 374 | else: 375 | self.hold_steps = 0 376 | 377 | super().__init__(optimizer, last_epoch) 378 | 379 | def _get_warmup_lr(self, step): 380 | lr_val = (step + 1) / (self.warmup_steps + 1) 381 | return [initial_lr * lr_val for initial_lr in self.base_lrs] 382 | 383 | def get_lr(self): 384 | step = self.last_epoch 385 | 386 | # Warmup phase 387 | if step <= self.warmup_steps and self.warmup_steps > 0: 388 | return self._get_warmup_lr(step) 389 | 390 | # Hold phase 391 | if (step >= self.warmup_steps) and (step < self.hold_steps): 392 | return self.base_lrs 393 | 394 | if step > self.max_steps: 395 | return [self.min_lr for _ in self.base_lrs] 396 | 397 | return self._get_lr(step) 398 | 399 | @staticmethod 400 | def _noam_hold_annealing(initial_lr, step, warmup_steps, hold_steps, decay_rate, min_lr): 401 | # hold_steps = total number of steps 402 | # to hold the LR, not the warmup + hold steps. 
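        # Worked example (illustrative numbers, not from the original source):
        #   _noam_hold_annealing(initial_lr=1e-3, step=40000, warmup_steps=10000,
        #                        hold_steps=20000, decay_rate=1.0, min_lr=1e-5)
        # returns 1e-3 * 10000**1.0 / (40000 - 20000)**1.0 = 5e-4, and the result
        # is never allowed to fall below min_lr.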
403 | T_warmup_decay = max(1, warmup_steps ** decay_rate) 404 | T_hold_decay = max(1, (step - hold_steps) ** decay_rate) 405 | lr = (initial_lr * T_warmup_decay) / T_hold_decay 406 | lr = max(lr, min_lr) 407 | return lr 408 | 409 | def _get_lr(self, step): 410 | if self.warmup_steps is None or self.warmup_steps == 0: 411 | raise ValueError("Noam scheduler cannot be used without warmup steps") 412 | 413 | if self.hold_steps > 0: 414 | hold_steps = self.hold_steps - self.warmup_steps 415 | else: 416 | hold_steps = 0 417 | 418 | new_lrs = [ 419 | self._noam_hold_annealing(initial_lr=initial_lr, 420 | step=step, 421 | warmup_steps=self.warmup_steps, 422 | hold_steps=hold_steps, 423 | decay_rate=self.decay_rate, 424 | min_lr=self.min_lr) 425 | for initial_lr in self.base_lrs 426 | ] 427 | return new_lrs 428 | 429 | def set_step(self, step: int): 430 | self.last_epoch = step 431 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openvpi/SOME/dc9916fb35350748351ddcdfc128fe908ef53f5d/modules/__init__.py -------------------------------------------------------------------------------- /modules/attention/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openvpi/SOME/dc9916fb35350748351ddcdfc128fe908ef53f5d/modules/attention/__init__.py -------------------------------------------------------------------------------- /modules/attention/base_attention.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from einops import rearrange 6 | 7 | 8 | class Attention(nn.Module): 9 | def __init__(self, dim, heads=4, dim_head=32, conditiondim=None): 10 | super().__init__() 11 | if conditiondim is None: 12 | conditiondim = dim 13 | 14 | self.scale = dim_head ** -0.5 15 | self.heads = heads 16 | hidden_dim = dim_head * heads 17 | self.to_q = nn.Linear(dim, hidden_dim, bias=False) 18 | self.to_kv = nn.Linear(conditiondim, hidden_dim * 2, bias=False) 19 | 20 | self.to_out = nn.Sequential(nn.Linear(hidden_dim, dim, ), 21 | ) 22 | 23 | def forward(self, q, kv=None, mask=None): 24 | # b, c, h, w = x.shape 25 | if kv is None: 26 | kv = q 27 | # q, kv = map( 28 | # lambda t: rearrange(t, "b c t -> b t c", ), (q, kv) 29 | # ) 30 | 31 | q = self.to_q(q) 32 | k, v = self.to_kv(kv).chunk(2, dim=2) 33 | 34 | q, k, v = map( 35 | lambda t: rearrange(t, "b t (h c) -> b h t c", h=self.heads), (q, k, v) 36 | ) 37 | 38 | if mask is not None: 39 | mask = mask.unsqueeze(1).unsqueeze(1) 40 | 41 | with torch.backends.cuda.sdp_kernel(enable_math=False 42 | ): 43 | out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) 44 | 45 | out = rearrange(out, "b h t c -> b t (h c) ", h=self.heads, ) 46 | return self.to_out(out) 47 | -------------------------------------------------------------------------------- /modules/commons/__init__.py: -------------------------------------------------------------------------------- 1 | from .tts_modules import LengthRegulator 2 | -------------------------------------------------------------------------------- /modules/commons/tts_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | class LengthRegulator(torch.nn.Module): 6 | # noinspection 
PyMethodMayBeStatic 7 | def forward(self, dur, dur_padding=None, alpha=None): 8 | """ 9 | Example (no batch dim version): 10 | 1. dur = [2,2,3] 11 | 2. token_idx = [[1],[2],[3]], dur_cumsum = [2,4,7], dur_cumsum_prev = [0,2,4] 12 | 3. token_mask = [[1,1,0,0,0,0,0], 13 | [0,0,1,1,0,0,0], 14 | [0,0,0,0,1,1,1]] 15 | 4. token_idx * token_mask = [[1,1,0,0,0,0,0], 16 | [0,0,2,2,0,0,0], 17 | [0,0,0,0,3,3,3]] 18 | 5. (token_idx * token_mask).sum(0) = [1,1,2,2,3,3,3] 19 | 20 | :param dur: Batch of durations of each frame (B, T_txt) 21 | :param dur_padding: Batch of padding of each frame (B, T_txt) 22 | :param alpha: duration rescale coefficient 23 | :return: 24 | mel2ph (B, T_speech) 25 | """ 26 | assert alpha is None or alpha > 0 27 | if alpha is not None: 28 | dur = torch.round(dur.float() * alpha).long() 29 | if dur_padding is not None: 30 | dur = dur * (1 - dur_padding.long()) 31 | token_idx = torch.arange(1, dur.shape[1] + 1)[None, :, None].to(dur.device) 32 | dur_cumsum = torch.cumsum(dur, 1) 33 | dur_cumsum_prev = F.pad(dur_cumsum, [1, -1], mode='constant', value=0) 34 | 35 | pos_idx = torch.arange(dur.sum(-1).max())[None, None].to(dur.device) 36 | token_mask = (pos_idx >= dur_cumsum_prev[:, :, None]) & (pos_idx < dur_cumsum[:, :, None]) 37 | mel2ph = (token_idx * token_mask.long()).sum(1) 38 | return mel2ph 39 | -------------------------------------------------------------------------------- /modules/conform/Gconform.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from einops import rearrange 7 | 8 | from modules.attention.base_attention import Attention 9 | from modules.conv.base_conv import conform_conv 10 | class GLU(nn.Module): 11 | def __init__(self, dim): 12 | super().__init__() 13 | self.dim = dim 14 | 15 | def forward(self, x): 16 | out, gate = x.chunk(2, dim=self.dim) 17 | 18 | return out * gate.sigmoid() 19 | 20 | class conform_ffn(nn.Module): 21 | def __init__(self, dim, DropoutL1: float = 0.1, DropoutL2: float = 0.1): 22 | super().__init__() 23 | self.ln1 = nn.Linear(dim, dim * 4) 24 | self.ln2 = nn.Linear(dim * 4, dim) 25 | self.drop1 = nn.Dropout(DropoutL1) if DropoutL1 > 0. else nn.Identity() 26 | self.drop2 = nn.Dropout(DropoutL2) if DropoutL2 > 0. else nn.Identity() 27 | self.act = nn.SiLU() 28 | 29 | def forward(self, x): 30 | x = self.ln1(x) 31 | x = self.act(x) 32 | x = self.drop1(x) 33 | x = self.ln2(x) 34 | return self.drop2(x) 35 | 36 | 37 | class conform_blocke(nn.Module): 38 | def __init__(self, dim: int, kernel_size: int = 31, conv_drop: float = 0.1, ffn_latent_drop: float = 0.1, 39 | ffn_out_drop: float = 0.1, attention_drop: float = 0.1, attention_heads: int = 4, 40 | attention_heads_dim: int = 64): 41 | super().__init__() 42 | self.ffn1 = conform_ffn(dim, ffn_latent_drop, ffn_out_drop) 43 | self.ffn2 = conform_ffn(dim, ffn_latent_drop, ffn_out_drop) 44 | self.att = Attention(dim, heads=attention_heads, dim_head=attention_heads_dim) 45 | self.attdrop = nn.Dropout(attention_drop) if attention_drop > 0. 
else nn.Identity() 46 | self.conv = conform_conv(dim, kernel_size=kernel_size, 47 | 48 | DropoutL=conv_drop, ) 49 | self.norm1 = nn.LayerNorm(dim) 50 | self.norm2 = nn.LayerNorm(dim) 51 | self.norm3 = nn.LayerNorm(dim) 52 | self.norm4 = nn.LayerNorm(dim) 53 | self.norm5 = nn.LayerNorm(dim) 54 | 55 | 56 | def forward(self, x, mask=None,): 57 | x = self.ffn1(self.norm1(x)) * 0.5 + x 58 | 59 | 60 | x = self.attdrop(self.att(self.norm2(x), mask=mask)) + x 61 | x = self.conv(self.norm3(x)) + x 62 | x = self.ffn2(self.norm4(x)) * 0.5 + x 63 | return self.norm5(x) 64 | 65 | # return x 66 | 67 | 68 | class Gcf(nn.Module): 69 | def __init__(self,dim: int, kernel_size: int = 31, conv_drop: float = 0.1, ffn_latent_drop: float = 0.1, 70 | ffn_out_drop: float = 0.1, attention_drop: float = 0.1, attention_heads: int = 4, 71 | attention_heads_dim: int = 64): 72 | super().__init__() 73 | self.att1=conform_blocke(dim=dim, kernel_size=kernel_size, conv_drop=conv_drop, ffn_latent_drop=ffn_latent_drop, 74 | ffn_out_drop=ffn_out_drop, attention_drop=attention_drop, attention_heads=attention_heads, 75 | attention_heads_dim=attention_heads_dim) 76 | self.att2 = conform_blocke(dim=dim, kernel_size=kernel_size, conv_drop=conv_drop, ffn_latent_drop=ffn_latent_drop, 77 | ffn_out_drop=ffn_out_drop, attention_drop=attention_drop, attention_heads=attention_heads, 78 | attention_heads_dim=attention_heads_dim) 79 | self.glu1=nn.Sequential(nn.Linear(dim, dim*2),GLU(2) ) 80 | self.glu2 = nn.Sequential(nn.Linear(dim, dim * 2), GLU(2)) 81 | 82 | def forward(self, midi,bound): 83 | midi=self.att1(midi) 84 | bound=self.att2(bound) 85 | midis=self.glu1(midi) 86 | bounds=self.glu2(bound) 87 | return midi+bounds,bound+midis 88 | 89 | 90 | 91 | 92 | class Gmidi_conform(nn.Module): 93 | def __init__(self, lay: int, dim: int, indim: int, outdim: int, use_lay_skip: bool, kernel_size: int = 31, 94 | conv_drop: float = 0.1, 95 | ffn_latent_drop: float = 0.1, 96 | ffn_out_drop: float = 0.1, attention_drop: float = 0.1, attention_heads: int = 4, 97 | attention_heads_dim: int = 64): 98 | super().__init__() 99 | 100 | self.inln = nn.Linear(indim, dim) 101 | self.inln1 = nn.Linear(indim, dim) 102 | self.outln = nn.Linear(dim, outdim) 103 | self.cutheard = nn.Linear(dim, 1) 104 | # self.cutheard = nn.Linear(dim, outdim) 105 | self.lay = lay 106 | self.use_lay_skip = use_lay_skip 107 | self.cf_lay = nn.ModuleList( 108 | [Gcf(dim=dim, kernel_size=kernel_size, conv_drop=conv_drop, ffn_latent_drop=ffn_latent_drop, 109 | ffn_out_drop=ffn_out_drop, attention_drop=attention_drop, attention_heads=attention_heads, 110 | attention_heads_dim=attention_heads_dim) for _ in range(lay)]) 111 | self.att1=conform_blocke(dim=dim, kernel_size=kernel_size, conv_drop=conv_drop, ffn_latent_drop=ffn_latent_drop, 112 | ffn_out_drop=ffn_out_drop, attention_drop=attention_drop, attention_heads=attention_heads, 113 | attention_heads_dim=attention_heads_dim) 114 | self.att2 = conform_blocke(dim=dim, kernel_size=kernel_size, conv_drop=conv_drop, ffn_latent_drop=ffn_latent_drop, 115 | ffn_out_drop=ffn_out_drop, attention_drop=attention_drop, attention_heads=attention_heads, 116 | attention_heads_dim=attention_heads_dim) 117 | 118 | 119 | def forward(self, x, pitch, mask=None): 120 | 121 | # torch.masked_fill() 122 | x1=x.clone() 123 | 124 | x = self.inln(x ) 125 | x1=self.inln1(x1) 126 | if mask is not None: 127 | x = x.masked_fill(~mask.unsqueeze(-1), 0) 128 | for idx, i in enumerate(self.cf_lay): 129 | x,x1 = i(x,x1) 130 | 131 | if mask is not None: 132 | x = 
x.masked_fill(~mask.unsqueeze(-1), 0) 133 | x,x1=self.att1(x),self.att2(x1) 134 | 135 | cutprp = self.cutheard(x1) 136 | midiout = self.outln(x) 137 | cutprp = torch.sigmoid(cutprp) 138 | cutprp = torch.squeeze(cutprp, -1) 139 | 140 | return midiout, cutprp 141 | -------------------------------------------------------------------------------- /modules/conform/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openvpi/SOME/dc9916fb35350748351ddcdfc128fe908ef53f5d/modules/conform/__init__.py -------------------------------------------------------------------------------- /modules/contentvec/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from fairseq import checkpoint_utils 3 | 4 | 5 | class ContentVec768L12(torch.nn.Module): 6 | def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu'): 7 | super().__init__() 8 | self.device = device 9 | models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task([path], suffix="") 10 | self.hubert = models[0].to(self.device).eval() 11 | 12 | def forward(self, waveform): # B, T 13 | feats = waveform.view(1, -1) 14 | padding_mask = torch.BoolTensor(feats.shape).fill_(False) 15 | inputs = { 16 | "source": feats.to(waveform.device), 17 | "padding_mask": padding_mask.to(waveform.device), 18 | "output_layer": 9, # layer 9 19 | } 20 | with torch.no_grad(): 21 | logits = self.hubert.extract_features(**inputs) 22 | feats = logits[0] 23 | units = feats # .transpose(2, 1) 24 | return units 25 | -------------------------------------------------------------------------------- /modules/conv/base_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops import rearrange 5 | 6 | 7 | class GLU(nn.Module): 8 | def __init__(self, dim): 9 | super().__init__() 10 | self.dim = dim 11 | 12 | def forward(self, x): 13 | out, gate = x.chunk(2, dim=self.dim) 14 | 15 | return out * gate.sigmoid() 16 | 17 | 18 | class conform_conv(nn.Module): 19 | def __init__(self, channels: int, 20 | kernel_size: int = 31, 21 | 22 | DropoutL=0.1, 23 | 24 | bias: bool = True): 25 | super().__init__() 26 | self.act2 = nn.SiLU() 27 | self.act1 = GLU(1) 28 | 29 | self.pointwise_conv1 = nn.Conv1d( 30 | channels, 31 | 2 * channels, 32 | kernel_size=1, 33 | stride=1, 34 | padding=0, 35 | bias=bias) 36 | 37 | # self.lorder is used to distinguish if it's a causal convolution, 38 | # if self.lorder > 0: 39 | # it's a causal convolution, the input will be padded with 40 | # `self.lorder` frames on the left in forward (causal conv impl). 41 | # else: it's a symmetrical convolution 42 | 43 | assert (kernel_size - 1) % 2 == 0 44 | padding = (kernel_size - 1) // 2 45 | 46 | self.depthwise_conv = nn.Conv1d(channels, channels, kernel_size, 47 | stride=1, 48 | padding=padding, 49 | groups=channels, 50 | bias=bias) 51 | 52 | 53 | self.norm = nn.BatchNorm1d(channels) 54 | 55 | 56 | self.pointwise_conv2 = nn.Conv1d(channels, 57 | channels, 58 | kernel_size=1, 59 | stride=1, 60 | padding=0, 61 | bias=bias) 62 | self.drop=nn.Dropout(DropoutL) if DropoutL>0. 
else nn.Identity() 63 | def forward(self,x): 64 | x=x.transpose(1,2) 65 | x=self.act1(self.pointwise_conv1(x)) 66 | x=self.depthwise_conv (x) 67 | x=self.norm(x) 68 | x=self.act2(x) 69 | x=self.pointwise_conv2(x) 70 | return self.drop(x).transpose(1,2) 71 | 72 | -------------------------------------------------------------------------------- /modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .bound_loss import BinaryEMDLoss, BoundaryLoss 2 | -------------------------------------------------------------------------------- /modules/losses/bound_loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch.nn 4 | 5 | 6 | class BinaryEMDLoss(torch.nn.Module): 7 | def __init__(self, bidirectional=False): 8 | super().__init__() 9 | self.loss = torch.nn.L1Loss() 10 | self.bidirectional = bidirectional 11 | 12 | def forward(self, pred, gt): 13 | # pred, gt: [B, T] 14 | scale = math.sqrt(gt.shape[1]) 15 | loss = self.loss(pred.cumsum(dim=1) / scale, gt.cumsum(dim=1) / scale) 16 | if self.bidirectional: 17 | loss += self.loss(pred.flip(1).cumsum(dim=1) / scale, gt.flip(1).cumsum(dim=1) / scale) 18 | loss /= 2 19 | return loss 20 | 21 | 22 | class BoundaryLoss(torch.nn.Module): 23 | def __init__(self, lambda_bce=0.1): 24 | super().__init__() 25 | self.emd = BinaryEMDLoss(bidirectional=False) 26 | self.bce = torch.nn.BCELoss() 27 | self.lambda_bce = lambda_bce 28 | 29 | def forward(self, pred, gt): 30 | # pred, gt: [B, T] 31 | emd_loss = self.emd(pred, gt) 32 | bce_loss = self.bce(pred, gt) 33 | return emd_loss + self.lambda_bce * bce_loss 34 | -------------------------------------------------------------------------------- /modules/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .midi_acc import MIDIAccuracy 2 | -------------------------------------------------------------------------------- /modules/metrics/midi_acc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchmetrics 3 | from torch import Tensor 4 | 5 | 6 | class MIDIAccuracy(torchmetrics.Metric): 7 | def __init__(self, *, tolerance, **kwargs): 8 | super().__init__(**kwargs) 9 | self.tolerance = tolerance 10 | self.add_state('correct', default=torch.tensor(0, dtype=torch.int), dist_reduce_fx='sum') 11 | self.add_state('total', default=torch.tensor(0, dtype=torch.int), dist_reduce_fx='sum') 12 | 13 | def update(self, midi_pred: Tensor, rest_pred: Tensor, midi_gt: Tensor, rest_gt: Tensor, mask=None) -> None: 14 | """ 15 | 16 | :param midi_pred: predicted MIDI 17 | :param rest_pred: predict rest flags 18 | :param midi_gt: reference MIDI 19 | :param rest_gt: reference rest flags 20 | :param mask: valid or non-padding mask 21 | """ 22 | assert midi_gt.shape == rest_gt.shape == midi_pred.shape == rest_pred.shape, \ 23 | (f'shapes of pred and gt mismatch: ' 24 | f'{midi_pred.shape}, {rest_pred.shape}, {midi_gt.shape}, {rest_gt.shape}') 25 | if mask is not None: 26 | assert midi_gt.shape == mask.shape, \ 27 | f'shapes of pred, target and mask mismatch: {midi_pred.shape}, {rest_pred.shape}, {mask.shape}' 28 | midi_close = ~rest_pred & ~rest_gt & (torch.abs(midi_pred - midi_gt) <= self.tolerance) 29 | rest_correct = rest_pred == rest_gt 30 | overall = midi_close & rest_correct 31 | if mask is not None: 32 | overall &= mask 33 | 34 | self.correct += overall.sum() 35 | self.total += 
midi_gt.numel() if mask is None else mask.sum() 36 | 37 | def compute(self) -> Tensor: 38 | return self.correct / self.total 39 | -------------------------------------------------------------------------------- /modules/model/Gmidi_conform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from modules.conform.Gconform import Gmidi_conform 5 | 6 | 7 | 8 | class midi_loss(nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | self.loss = nn.BCELoss() 12 | 13 | def forward(self, x, target): 14 | midiout, cutp = x 15 | midi_target, cutp_target = target 16 | 17 | cutploss = self.loss(cutp, cutp_target) 18 | midiloss = self.loss(midiout, midi_target) 19 | return midiloss, cutploss 20 | 21 | 22 | class midi_conforms(nn.Module): 23 | def __init__(self, config): 24 | super().__init__() 25 | 26 | cfg = config['midi_extractor_args'] 27 | cfg.update({'indim': config['units_dim'], 'outdim': config['midi_num_bins']}) 28 | self.model = Gmidi_conform(**cfg) 29 | 30 | def forward(self, x, f0, mask=None,softmax=False,sig=False): 31 | 32 | midi,bound=self.model(x, f0, mask) 33 | if sig: 34 | midi = torch.sigmoid(midi) 35 | 36 | if softmax: 37 | midi=F.softmax(midi,dim=2) 38 | 39 | 40 | return midi,bound 41 | 42 | def get_loss(self): 43 | return midi_loss() 44 | -------------------------------------------------------------------------------- /modules/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openvpi/SOME/dc9916fb35350748351ddcdfc128fe908ef53f5d/modules/model/__init__.py -------------------------------------------------------------------------------- /modules/rmvpe/__init__.py: -------------------------------------------------------------------------------- 1 | from .constants import * 2 | from .model import E2E0 3 | from .utils import to_local_average_f0, to_viterbi_f0 4 | from .inference import RMVPE 5 | from .spec import MelSpectrogram 6 | -------------------------------------------------------------------------------- /modules/rmvpe/constants.py: -------------------------------------------------------------------------------- 1 | SAMPLE_RATE = 16000 2 | 3 | N_CLASS = 360 4 | 5 | N_MELS = 128 6 | MEL_FMIN = 30 7 | MEL_FMAX = 8000 8 | WINDOW_LENGTH = 1024 9 | CONST = 1997.3794084376191 10 | -------------------------------------------------------------------------------- /modules/rmvpe/deepunet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .constants import N_MELS 4 | 5 | 6 | class ConvBlockRes(nn.Module): 7 | def __init__(self, in_channels, out_channels, momentum=0.01): 8 | super(ConvBlockRes, self).__init__() 9 | self.conv = nn.Sequential( 10 | nn.Conv2d(in_channels=in_channels, 11 | out_channels=out_channels, 12 | kernel_size=(3, 3), 13 | stride=(1, 1), 14 | padding=(1, 1), 15 | bias=False), 16 | nn.BatchNorm2d(out_channels, momentum=momentum), 17 | nn.ReLU(), 18 | 19 | nn.Conv2d(in_channels=out_channels, 20 | out_channels=out_channels, 21 | kernel_size=(3, 3), 22 | stride=(1, 1), 23 | padding=(1, 1), 24 | bias=False), 25 | nn.BatchNorm2d(out_channels, momentum=momentum), 26 | nn.ReLU(), 27 | ) 28 | if in_channels != out_channels: 29 | self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1)) 30 | self.is_shortcut = True 31 | else: 32 | self.is_shortcut = False 33 | 34 | def forward(self, x): 35 | 
if self.is_shortcut: 36 | return self.conv(x) + self.shortcut(x) 37 | else: 38 | return self.conv(x) + x 39 | 40 | 41 | class ResEncoderBlock(nn.Module): 42 | def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01): 43 | super(ResEncoderBlock, self).__init__() 44 | self.n_blocks = n_blocks 45 | self.conv = nn.ModuleList() 46 | self.conv.append(ConvBlockRes(in_channels, out_channels, momentum)) 47 | for i in range(n_blocks - 1): 48 | self.conv.append(ConvBlockRes(out_channels, out_channels, momentum)) 49 | self.kernel_size = kernel_size 50 | if self.kernel_size is not None: 51 | self.pool = nn.AvgPool2d(kernel_size=kernel_size) 52 | 53 | def forward(self, x): 54 | for i in range(self.n_blocks): 55 | x = self.conv[i](x) 56 | if self.kernel_size is not None: 57 | return x, self.pool(x) 58 | else: 59 | return x 60 | 61 | 62 | class ResDecoderBlock(nn.Module): 63 | def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01): 64 | super(ResDecoderBlock, self).__init__() 65 | out_padding = (0, 1) if stride == (1, 2) else (1, 1) 66 | self.n_blocks = n_blocks 67 | self.conv1 = nn.Sequential( 68 | nn.ConvTranspose2d(in_channels=in_channels, 69 | out_channels=out_channels, 70 | kernel_size=(3, 3), 71 | stride=stride, 72 | padding=(1, 1), 73 | output_padding=out_padding, 74 | bias=False), 75 | nn.BatchNorm2d(out_channels, momentum=momentum), 76 | nn.ReLU(), 77 | ) 78 | self.conv2 = nn.ModuleList() 79 | self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum)) 80 | for i in range(n_blocks-1): 81 | self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum)) 82 | 83 | def forward(self, x, concat_tensor): 84 | x = self.conv1(x) 85 | x = torch.cat((x, concat_tensor), dim=1) 86 | for i in range(self.n_blocks): 87 | x = self.conv2[i](x) 88 | return x 89 | 90 | 91 | class Encoder(nn.Module): 92 | def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01): 93 | super(Encoder, self).__init__() 94 | self.n_encoders = n_encoders 95 | self.bn = nn.BatchNorm2d(in_channels, momentum=momentum) 96 | self.layers = nn.ModuleList() 97 | self.latent_channels = [] 98 | for i in range(self.n_encoders): 99 | self.layers.append(ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum=momentum)) 100 | self.latent_channels.append([out_channels, in_size]) 101 | in_channels = out_channels 102 | out_channels *= 2 103 | in_size //= 2 104 | self.out_size = in_size 105 | self.out_channel = out_channels 106 | 107 | def forward(self, x): 108 | concat_tensors = [] 109 | x = self.bn(x) 110 | for i in range(self.n_encoders): 111 | _, x = self.layers[i](x) 112 | concat_tensors.append(_) 113 | return x, concat_tensors 114 | 115 | 116 | class Intermediate(nn.Module): 117 | def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01): 118 | super(Intermediate, self).__init__() 119 | self.n_inters = n_inters 120 | self.layers = nn.ModuleList() 121 | self.layers.append(ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum)) 122 | for i in range(self.n_inters-1): 123 | self.layers.append(ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum)) 124 | 125 | def forward(self, x): 126 | for i in range(self.n_inters): 127 | x = self.layers[i](x) 128 | return x 129 | 130 | 131 | class Decoder(nn.Module): 132 | def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01): 133 | super(Decoder, self).__init__() 134 | self.layers = nn.ModuleList() 135 
| self.n_decoders = n_decoders 136 | for i in range(self.n_decoders): 137 | out_channels = in_channels // 2 138 | self.layers.append(ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum)) 139 | in_channels = out_channels 140 | 141 | def forward(self, x, concat_tensors): 142 | for i in range(self.n_decoders): 143 | x = self.layers[i](x, concat_tensors[-1-i]) 144 | return x 145 | 146 | 147 | class TimbreFilter(nn.Module): 148 | def __init__(self, latent_rep_channels): 149 | super(TimbreFilter, self).__init__() 150 | self.layers = nn.ModuleList() 151 | for latent_rep in latent_rep_channels: 152 | self.layers.append(ConvBlockRes(latent_rep[0], latent_rep[0])) 153 | 154 | def forward(self, x_tensors): 155 | out_tensors = [] 156 | for i, layer in enumerate(self.layers): 157 | out_tensors.append(layer(x_tensors[i])) 158 | return out_tensors 159 | 160 | 161 | class DeepUnet0(nn.Module): 162 | def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16): 163 | super(DeepUnet0, self).__init__() 164 | self.encoder = Encoder(in_channels, N_MELS, en_de_layers, kernel_size, n_blocks, en_out_channels) 165 | self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks) 166 | self.tf = TimbreFilter(self.encoder.latent_channels) 167 | self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks) 168 | 169 | def forward(self, x): 170 | x, concat_tensors = self.encoder(x) 171 | x = self.intermediate(x) 172 | x = self.decoder(x, concat_tensors) 173 | return x 174 | -------------------------------------------------------------------------------- /modules/rmvpe/inference.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from torchaudio.transforms import Resample 5 | 6 | from utils.pitch_utils import interp_f0, resample_align_curve 7 | from .constants import * 8 | from .model import E2E0 9 | from .spec import MelSpectrogram 10 | from .utils import to_local_average_f0, to_viterbi_f0 11 | 12 | 13 | class RMVPE: 14 | def __init__(self, model_path, hop_length=160, device=None): 15 | self.resample_kernel = {} 16 | if device is None: 17 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 18 | else: 19 | self.device = device 20 | self.model = E2E0(4, 1, (2, 2)).eval().to(self.device) 21 | ckpt = torch.load(model_path, map_location=self.device) 22 | self.model.load_state_dict(ckpt['model'], strict=False) 23 | self.mel_extractor = MelSpectrogram( 24 | N_MELS, SAMPLE_RATE, WINDOW_LENGTH, hop_length, None, MEL_FMIN, MEL_FMAX 25 | ).to(self.device) 26 | 27 | @torch.no_grad() 28 | def mel2hidden(self, mel): 29 | n_frames = mel.shape[-1] 30 | mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode='constant') 31 | hidden = self.model(mel) 32 | return hidden[:, :n_frames] 33 | 34 | def decode(self, hidden, thred=0.03, use_viterbi=False): 35 | if use_viterbi: 36 | f0 = to_viterbi_f0(hidden, thred=thred) 37 | else: 38 | f0 = to_local_average_f0(hidden, thred=thred) 39 | return f0 40 | 41 | def infer_from_audio(self, audio, sample_rate=16000, thred=0.03, use_viterbi=False): 42 | audio = torch.from_numpy(audio).float().unsqueeze(0).to(self.device) 43 | if sample_rate == 16000: 44 | audio_res = audio 45 | else: 46 | key_str = str(sample_rate) 47 | if key_str not in self.resample_kernel: 48 | self.resample_kernel[key_str] = Resample(sample_rate, 16000, 
lowpass_filter_width=128) 49 | self.resample_kernel[key_str] = self.resample_kernel[key_str].to(self.device) 50 | audio_res = self.resample_kernel[key_str](audio) 51 | mel = self.mel_extractor(audio_res, center=True) 52 | hidden = self.mel2hidden(mel) 53 | f0 = self.decode(hidden, thred=thred, use_viterbi=use_viterbi) 54 | return f0 55 | 56 | def get_pitch(self, waveform, sample_rate, hop_size, length, interp_uv=False): 57 | f0 = self.infer_from_audio(waveform, sample_rate=sample_rate) 58 | uv = f0 == 0 59 | f0, uv = interp_f0(f0, uv) 60 | 61 | time_step = hop_size / sample_rate 62 | f0_res = resample_align_curve(f0, 0.01, time_step, length) 63 | uv_res = resample_align_curve(uv.astype(np.float32), 0.01, time_step, length) > 0.5 64 | if not interp_uv: 65 | f0_res[uv_res] = 0 66 | return f0_res, uv_res 67 | -------------------------------------------------------------------------------- /modules/rmvpe/model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from .constants import * 4 | from .deepunet import DeepUnet0 5 | from .seq import BiGRU 6 | 7 | 8 | class E2E0(nn.Module): 9 | def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, 10 | en_out_channels=16): 11 | super(E2E0, self).__init__() 12 | self.unet = DeepUnet0(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels) 13 | self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1)) 14 | if n_gru: 15 | self.fc = nn.Sequential( 16 | BiGRU(3 * N_MELS, 256, n_gru), 17 | nn.Linear(512, N_CLASS), 18 | nn.Dropout(0.25), 19 | nn.Sigmoid() 20 | ) 21 | else: 22 | self.fc = nn.Sequential( 23 | nn.Linear(3 * N_MELS, N_CLASS), 24 | nn.Dropout(0.25), 25 | nn.Sigmoid() 26 | ) 27 | 28 | def forward(self, mel): 29 | mel = mel.transpose(-1, -2).unsqueeze(1) 30 | x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2) 31 | x = self.fc(x) 32 | return x 33 | -------------------------------------------------------------------------------- /modules/rmvpe/seq.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class BiGRU(nn.Module): 5 | def __init__(self, input_features, hidden_features, num_layers): 6 | super(BiGRU, self).__init__() 7 | self.gru = nn.GRU(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True) 8 | 9 | def forward(self, x): 10 | return self.gru(x)[0] 11 | -------------------------------------------------------------------------------- /modules/rmvpe/spec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | from librosa.filters import mel 5 | 6 | 7 | class MelSpectrogram(torch.nn.Module): 8 | def __init__( 9 | self, 10 | n_mel_channels, 11 | sampling_rate, 12 | win_length, 13 | hop_length, 14 | n_fft=None, 15 | mel_fmin=0, 16 | mel_fmax=None, 17 | clamp=1e-5 18 | ): 19 | super().__init__() 20 | n_fft = win_length if n_fft is None else n_fft 21 | self.hann_window = {} 22 | mel_basis = mel( 23 | sr=sampling_rate, 24 | n_fft=n_fft, 25 | n_mels=n_mel_channels, 26 | fmin=mel_fmin, 27 | fmax=mel_fmax, 28 | htk=True) 29 | mel_basis = torch.from_numpy(mel_basis).float() 30 | self.register_buffer("mel_basis", mel_basis) 31 | self.n_fft = win_length if n_fft is None else n_fft 32 | self.hop_length = hop_length 33 | self.win_length = win_length 34 | self.sampling_rate = sampling_rate 35 | 
self.n_mel_channels = n_mel_channels 36 | self.clamp = clamp 37 | 38 | def forward(self, audio, keyshift=0, speed=1, center=True): 39 | factor = 2 ** (keyshift / 12) 40 | n_fft_new = int(np.round(self.n_fft * factor)) 41 | win_length_new = int(np.round(self.win_length * factor)) 42 | hop_length_new = int(np.round(self.hop_length * speed)) 43 | 44 | keyshift_key = str(keyshift) + '_' + str(audio.device) 45 | if keyshift_key not in self.hann_window: 46 | self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device) 47 | if center: 48 | pad_left = win_length_new // 2 49 | pad_right = (win_length_new + 1) // 2 50 | audio = F.pad(audio, (pad_left, pad_right)) 51 | 52 | fft = torch.stft( 53 | audio, 54 | n_fft=n_fft_new, 55 | hop_length=hop_length_new, 56 | win_length=win_length_new, 57 | window=self.hann_window[keyshift_key], 58 | center=False, 59 | return_complex=True 60 | ) 61 | magnitude = fft.abs() 62 | 63 | if keyshift != 0: 64 | size = self.n_fft // 2 + 1 65 | resize = magnitude.size(1) 66 | if resize < size: 67 | magnitude = F.pad(magnitude, (0, 0, 0, size - resize)) 68 | magnitude = magnitude[:, :size, :] * self.win_length / win_length_new 69 | 70 | mel_output = torch.matmul(self.mel_basis, magnitude) 71 | log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp)) 72 | return log_mel_spec 73 | -------------------------------------------------------------------------------- /modules/rmvpe/utils.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import numpy as np 3 | import torch 4 | 5 | from .constants import * 6 | 7 | 8 | def to_local_average_f0(hidden, center=None, thred=0.03): 9 | idx = torch.arange(N_CLASS, device=hidden.device)[None, None, :] # [B=1, T=1, N] 10 | idx_cents = idx * 20 + CONST # [B=1, N] 11 | if center is None: 12 | center = torch.argmax(hidden, dim=2, keepdim=True) # [B, T, 1] 13 | start = torch.clip(center - 4, min=0) # [B, T, 1] 14 | end = torch.clip(center + 5, max=N_CLASS) # [B, T, 1] 15 | idx_mask = (idx >= start) & (idx < end) # [B, T, N] 16 | weights = hidden * idx_mask # [B, T, N] 17 | product_sum = torch.sum(weights * idx_cents, dim=2) # [B, T] 18 | weight_sum = torch.sum(weights, dim=2) # [B, T] 19 | cents = product_sum / (weight_sum + (weight_sum == 0)) # avoid dividing by zero, [B, T] 20 | f0 = 10 * 2 ** (cents / 1200) 21 | uv = hidden.max(dim=2)[0] < thred # [B, T] 22 | f0 = f0 * ~uv 23 | return f0.squeeze(0).cpu().numpy() 24 | 25 | 26 | def to_viterbi_f0(hidden, thred=0.03): 27 | # Create viterbi transition matrix 28 | if not hasattr(to_viterbi_f0, 'transition'): 29 | xx, yy = np.meshgrid(range(N_CLASS), range(N_CLASS)) 30 | transition = np.maximum(30 - abs(xx - yy), 0) 31 | transition = transition / transition.sum(axis=1, keepdims=True) 32 | to_viterbi_f0.transition = transition 33 | 34 | # Convert to probability 35 | prob = hidden.squeeze(0).cpu().numpy() 36 | prob = prob.T 37 | prob = prob / prob.sum(axis=0) 38 | 39 | # Perform viterbi decoding 40 | path = librosa.sequence.viterbi(prob, to_viterbi_f0.transition).astype(np.int64) 41 | center = torch.from_numpy(path).unsqueeze(0).unsqueeze(-1).to(hidden.device) 42 | 43 | return to_local_average_f0(hidden, center=center, thred=thred) 44 | -------------------------------------------------------------------------------- /preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_binarizer import BaseBinarizer 2 | from .me_binarizer import 
MIDIExtractionBinarizer 3 | from .me_quant_binarizer import QuantizedMIDIExtractionBinarizer 4 | -------------------------------------------------------------------------------- /preprocessing/base_binarizer.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import random 3 | import warnings 4 | from copy import deepcopy 5 | 6 | import numpy as np 7 | import torch 8 | from tqdm import tqdm 9 | 10 | from utils.indexed_datasets import IndexedDatasetBuilder 11 | from utils.multiprocess_utils import chunked_multiprocess_run 12 | 13 | 14 | class BinarizationError(Exception): 15 | pass 16 | 17 | 18 | class BaseBinarizer: 19 | """ 20 | Base class for data processing. 21 | 1. *process* and *process_data_split*: 22 | process entire data, generate the train-test split (support parallel processing); 23 | 2. *process_item*: 24 | process singe piece of data; 25 | 3. *get_pitch*: 26 | infer the pitch using some algorithm; 27 | 4. *get_align*: 28 | get the alignment using 'mel2ph' format (see https://arxiv.org/abs/1905.09263). 29 | 5. phoneme encoder, voice encoder, etc. 30 | 31 | Subclasses should define: 32 | 1. *load_metadata*: 33 | how to read multiple datasets from files; 34 | 2. *train_item_names*, *valid_item_names*, *test_item_names*: 35 | how to split the dataset; 36 | 3. load_ph_set: 37 | the phoneme set. 38 | """ 39 | 40 | def __init__(self, config: dict, data_attrs=None): 41 | self.config = config 42 | self.raw_data_dirs = [pathlib.Path(d) for d in config['raw_data_dir']] 43 | self.binary_data_dir = pathlib.Path(self.config['binary_data_dir']) 44 | self.data_attrs = [] if data_attrs is None else data_attrs 45 | 46 | self.binarization_args = self.config['binarization_args'] 47 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 48 | 49 | self.items = {} 50 | self.item_names: list = None 51 | self._train_item_names: list = None 52 | self._valid_item_names: list = None 53 | 54 | self.timestep = self.config['hop_size'] / self.config['audio_sample_rate'] 55 | 56 | def load_meta_data(self, raw_data_dir: pathlib.Path, ds_id): 57 | raise NotImplementedError() 58 | 59 | def split_train_valid_set(self): 60 | """ 61 | Split the dataset into training set and validation set. 
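        (Clarifying note, inferred from the code below: every rule in
        config['test_prefixes'] is tried in four passes — exact item name,
        exact name without the speaker prefix, then prefix matches on each of
        those two forms — and all matched items form the validation set; the
        remaining items form the training set.)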
62 | :return: train_item_names, valid_item_names 63 | """ 64 | prefixes = set([str(pr) for pr in self.config['test_prefixes']]) 65 | valid_item_names = set() 66 | # Add prefixes that specified speaker index and matches exactly item name to test set 67 | for prefix in deepcopy(prefixes): 68 | if prefix in self.item_names: 69 | valid_item_names.add(prefix) 70 | prefixes.remove(prefix) 71 | # Add prefixes that exactly matches item name without speaker id to test set 72 | for prefix in deepcopy(prefixes): 73 | matched = False 74 | for name in self.item_names: 75 | if name.split(':')[-1] == prefix: 76 | valid_item_names.add(name) 77 | matched = True 78 | if matched: 79 | prefixes.remove(prefix) 80 | # Add names with one of the remaining prefixes to test set 81 | for prefix in deepcopy(prefixes): 82 | matched = False 83 | for name in self.item_names: 84 | if name.startswith(prefix): 85 | valid_item_names.add(name) 86 | matched = True 87 | if matched: 88 | prefixes.remove(prefix) 89 | for prefix in deepcopy(prefixes): 90 | matched = False 91 | for name in self.item_names: 92 | if name.split(':')[-1].startswith(prefix): 93 | valid_item_names.add(name) 94 | matched = True 95 | if matched: 96 | prefixes.remove(prefix) 97 | 98 | if len(prefixes) != 0: 99 | warnings.warn( 100 | f'The following rules in test_prefixes have no matching names in the dataset: {sorted(prefixes)}', 101 | category=UserWarning 102 | ) 103 | warnings.filterwarnings('default') 104 | 105 | valid_item_names = sorted(list(valid_item_names)) 106 | assert len(valid_item_names) > 0, 'Validation set is empty!' 107 | train_item_names = [x for x in self.item_names if x not in set(valid_item_names)] 108 | assert len(train_item_names) > 0, 'Training set is empty!' 109 | 110 | return train_item_names, valid_item_names 111 | 112 | @property 113 | def train_item_names(self): 114 | return self._train_item_names 115 | 116 | @property 117 | def valid_item_names(self): 118 | return self._valid_item_names 119 | 120 | def meta_data_iterator(self, prefix): 121 | if prefix == 'train': 122 | item_names = self.train_item_names 123 | else: 124 | item_names = self.valid_item_names 125 | for item_name in item_names: 126 | meta_data = self.items[item_name] 127 | yield item_name, meta_data 128 | 129 | def process(self): 130 | # load each dataset 131 | for ds_id, data_dir in zip(range(len(self.raw_data_dirs)), self.raw_data_dirs): 132 | self.load_meta_data(pathlib.Path(data_dir), ds_id=ds_id) 133 | self.item_names = sorted(list(self.items.keys())) 134 | self._train_item_names, self._valid_item_names = self.split_train_valid_set() 135 | 136 | if self.binarization_args['shuffle']: 137 | random.seed(self.config['seed']) 138 | random.shuffle(self.item_names) 139 | 140 | self.binary_data_dir.mkdir(parents=True, exist_ok=True) 141 | self.check_coverage() 142 | 143 | # Process valid set and train set 144 | try: 145 | self.process_dataset('valid') 146 | self.process_dataset( 147 | 'train', 148 | num_workers=int(self.binarization_args['num_workers']), 149 | apply_augmentation=True 150 | ) 151 | except KeyboardInterrupt: 152 | exit(-1) 153 | 154 | def check_coverage(self): 155 | pass 156 | 157 | def process_dataset(self, prefix, num_workers=0, apply_augmentation=False): 158 | args = [] 159 | builder = IndexedDatasetBuilder(self.binary_data_dir, prefix=prefix, allowed_attr=self.data_attrs) 160 | lengths = [] 161 | total_sec = 0 162 | total_raw_sec = 0 163 | 164 | for item_name, meta_data in self.meta_data_iterator(prefix): 165 | args.append([item_name, meta_data, 
apply_augmentation]) 166 | 167 | def postprocess(_item, _is_raw=True): 168 | nonlocal total_sec, total_raw_sec 169 | if _item is None: 170 | return 171 | builder.add_item(_item) 172 | lengths.append(_item['length']) 173 | total_sec += _item['seconds'] 174 | if _is_raw: 175 | total_raw_sec += _item['seconds'] 176 | 177 | try: 178 | if num_workers > 0: 179 | # code for parallel processing 180 | for items in tqdm( 181 | chunked_multiprocess_run(self.process_item, args, num_workers=num_workers), 182 | total=len(list(self.meta_data_iterator(prefix))) 183 | ): 184 | for i, item in enumerate(items): 185 | postprocess(item, i == 0) 186 | else: 187 | # code for single cpu processing 188 | for a in tqdm(args): 189 | items = self.process_item(*a) 190 | for i, item in enumerate(items): 191 | postprocess(item, i == 0) 192 | except KeyboardInterrupt: 193 | builder.finalize() 194 | raise 195 | 196 | builder.finalize() 197 | with open(self.binary_data_dir / f'{prefix}.lengths', 'wb') as f: 198 | # noinspection PyTypeChecker 199 | np.save(f, lengths) 200 | 201 | if apply_augmentation: 202 | print(f'| {prefix} total duration (before augmentation): {total_raw_sec:.2f}s') 203 | print( 204 | f'| {prefix} total duration (after augmentation): {total_sec:.2f}s ({total_sec / total_raw_sec:.2f}x)') 205 | else: 206 | print(f'| {prefix} total duration: {total_raw_sec:.2f}s') 207 | 208 | def process_item(self, item_name, meta_data, allow_aug=False): 209 | raise NotImplementedError() 210 | -------------------------------------------------------------------------------- /preprocessing/me_binarizer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import csv 3 | import json 4 | import os 5 | import pathlib 6 | import random 7 | 8 | import librosa 9 | import numpy as np 10 | import torch 11 | from scipy import interpolate 12 | 13 | import modules.contentvec 14 | import modules.rmvpe 15 | from modules.commons import LengthRegulator 16 | from utils.binarizer_utils import merge_slurs, merge_rests, get_mel2ph_torch, get_pitch_parselmouth 17 | from utils.pitch_utils import resample_align_curve 18 | from utils.plot import distribution_to_figure 19 | from .base_binarizer import BaseBinarizer 20 | 21 | os.environ["OMP_NUM_THREADS"] = "1" 22 | MIDI_EXTRACTION_ITEM_ATTRIBUTES = [ 23 | 'units', # contentvec units, float32[T_s, 256] 24 | 'pitch', # actual pitch in semitones, float32[T_s,] 25 | 'note_midi', # note-level MIDI pitch, float32[T_n,] 26 | 'note_rest', # flags for rest notes, bool[T_n,] 27 | 'note_dur', # durations of notes, in number of frames, int64[T_n,] 28 | 'unit2note', # mel2ph format for alignment between units and notes 29 | ] 30 | 31 | # These modules are used as global variables due to a PyTorch shared memory bug on Windows platforms. 
32 | # See https://github.com/pytorch/pytorch/issues/100358 33 | contentvec = None 34 | mel_spec = None 35 | rmvpe = None 36 | 37 | 38 | class MIDIExtractionBinarizer(BaseBinarizer): 39 | def __init__(self, config: dict): 40 | super().__init__(config, data_attrs=MIDI_EXTRACTION_ITEM_ATTRIBUTES) 41 | self.lr = LengthRegulator().to(self.device) 42 | self.skip_glide = self.binarization_args['skip_glide'] 43 | self.merge_rest = self.binarization_args['merge_rest'] 44 | self.merge_slur = self.binarization_args['merge_slur'] 45 | self.slur_tolerance = self.binarization_args.get('slur_tolerance') 46 | self.round_midi = self.binarization_args.get('round_midi', False) 47 | self.key_shift_min, self.key_shift_max = self.config['key_shift_range'] 48 | 49 | def load_meta_data(self, raw_data_dir: pathlib.Path, ds_id): 50 | meta_data_dict = {} 51 | if (raw_data_dir / 'transcriptions.csv').exists(): 52 | for utterance_label in csv.DictReader( 53 | open(raw_data_dir / 'transcriptions.csv', 'r', encoding='utf-8') 54 | ): 55 | item_name = utterance_label['name'] 56 | temp_dict = { 57 | 'wav_fn': str(raw_data_dir / 'wavs' / f'{item_name}.wav') 58 | } 59 | ds_path = raw_data_dir / 'wavs' / f'{item_name}.ds' 60 | with open(ds_path, 'r', encoding='utf8') as f: 61 | ds = json.load(f) 62 | if isinstance(ds, list): 63 | ds = ds[0] 64 | if self.skip_glide and ds.get('note_glide') is not None and any( 65 | g != 'none' for g in ds['note_glide'].split() 66 | ): 67 | print(f'Item {ds_id}:{item_name} contains glide notes. Skipping.') 68 | continue 69 | # normalize 70 | note_seq = [ 71 | librosa.midi_to_note( 72 | np.clip( 73 | librosa.note_to_midi(n, round_midi=self.round_midi), 74 | a_min=0, a_max=127 75 | ), 76 | cents=not self.round_midi, unicode=False 77 | ) if n != 'rest' else 'rest' 78 | for n in ds['note_seq'].split() 79 | ] 80 | note_slur = [bool(int(s)) for s in ds['note_slur'].split()] 81 | note_dur = [float(x) for x in ds['note_dur'].split()] 82 | 83 | # if not len(note_seq) == len(note_slur) == len(note_dur): 84 | # continue 85 | assert len(note_seq) == len(note_slur) == len(note_dur), \ 86 | f'Lengths of note_seq, note_slur and note_dur mismatch in \'{item_name}\'.' 87 | assert any([note != 'rest' for note in note_seq]), \ 88 | f'All notes are rest in \'{item_name}\'.' 89 | 90 | if self.merge_slur: 91 | # merge slurs with the same pitch 92 | note_seq, note_dur = merge_slurs(note_seq, note_dur, note_slur, tolerance=self.slur_tolerance) 93 | 94 | if self.merge_rest: 95 | # merge continuous rest notes 96 | note_seq, note_dur = merge_rests(note_seq, note_dur) 97 | 98 | temp_dict['note_seq'] = note_seq 99 | temp_dict['note_dur'] = note_dur 100 | 101 | meta_data_dict[f'{ds_id}:{item_name}'] = temp_dict 102 | else: 103 | raise FileNotFoundError( 104 | f'transcriptions.csv not found in {raw_data_dir}.' 
105 | ) 106 | self.items.update(meta_data_dict) 107 | 108 | def check_coverage(self): 109 | super().check_coverage() 110 | # MIDI pitch distribution summary 111 | midi_map = {} 112 | for item_name in self.items: 113 | for midi in self.items[item_name]['note_seq']: 114 | if midi == 'rest': 115 | continue 116 | midi = librosa.note_to_midi(midi, round_midi=True) 117 | if midi in midi_map: 118 | midi_map[midi] += 1 119 | else: 120 | midi_map[midi] = 1 121 | 122 | print('===== MIDI Pitch Distribution Summary =====') 123 | for i, key in enumerate(sorted(midi_map.keys())): 124 | if i == len(midi_map) - 1: 125 | end = '\n' 126 | elif i % 10 == 9: 127 | end = ',\n' 128 | else: 129 | end = ', ' 130 | print(f'\'{librosa.midi_to_note(key, unicode=False)}\': {midi_map[key]}', end=end) 131 | 132 | # Draw graph. 133 | midis = sorted(midi_map.keys()) 134 | notes = [librosa.midi_to_note(m, unicode=False) for m in range(midis[0], midis[-1] + 1)] 135 | plt = distribution_to_figure( 136 | title='MIDI Pitch Distribution Summary', 137 | x_label='MIDI Key', y_label='Number of occurrences', 138 | items=notes, values=[midi_map.get(m, 0) for m in range(midis[0], midis[-1] + 1)] 139 | ) 140 | filename = self.binary_data_dir / 'midi_distribution.jpg' 141 | plt.savefig(fname=filename, 142 | bbox_inches='tight', 143 | pad_inches=0.25) 144 | print(f'| save summary to \'{filename}\'') 145 | 146 | def _process_item(self, waveform, meta_data, int_midi=False): 147 | wav_tensor = torch.from_numpy(waveform).to(self.device) 148 | units_encoder = self.config['units_encoder'] 149 | if units_encoder == 'contentvec768l12': 150 | global contentvec 151 | if contentvec is None: 152 | contentvec = modules.contentvec.ContentVec768L12(self.config['units_encoder_ckpt'], device=self.device) 153 | units = contentvec(wav_tensor).squeeze(0).cpu().numpy() 154 | elif units_encoder == 'mel': 155 | global mel_spec 156 | if mel_spec is None: 157 | mel_spec = modules.rmvpe.MelSpectrogram( 158 | n_mel_channels=self.config['units_dim'], sampling_rate=self.config['audio_sample_rate'], 159 | win_length=self.config['win_size'], hop_length=self.config['hop_size'], 160 | mel_fmin=self.config['fmin'], mel_fmax=self.config['fmax'] 161 | ).to(self.device) 162 | units = mel_spec(wav_tensor.unsqueeze(0)).transpose(1, 2).squeeze(0).cpu().numpy() 163 | else: 164 | raise NotImplementedError(f'Invalid units encoder: {units_encoder}') 165 | assert len(units.shape) == 2 and units.shape[1] == self.config['units_dim'], \ 166 | f'Shape of units must be [T, units_dim], but is {units.shape}.' 
167 | length = units.shape[0] 168 | seconds = length * self.config['hop_size'] / self.config['audio_sample_rate'] 169 | processed_input = { 170 | 'seconds': seconds, 171 | 'length': length, 172 | 'units': units 173 | } 174 | 175 | f0_algo = self.config['pe'] 176 | if f0_algo == 'parselmouth': 177 | f0, _ = get_pitch_parselmouth( 178 | waveform, sample_rate=self.config['audio_sample_rate'], 179 | hop_size=self.config['hop_size'], length=length, interp_uv=True 180 | ) 181 | elif f0_algo == 'rmvpe': 182 | global rmvpe 183 | if rmvpe is None: 184 | rmvpe = modules.rmvpe.RMVPE(self.config['pe_ckpt'], device=self.device) 185 | f0, _ = rmvpe.get_pitch( 186 | waveform, sample_rate=self.config['audio_sample_rate'], 187 | hop_size=rmvpe.mel_extractor.hop_length, 188 | length=(waveform.shape[0] + rmvpe.mel_extractor.hop_length - 1) // rmvpe.mel_extractor.hop_length, 189 | interp_uv=True 190 | ) 191 | f0 = resample_align_curve( 192 | f0, 193 | original_timestep=rmvpe.mel_extractor.hop_length / self.config['audio_sample_rate'], 194 | target_timestep=self.config['hop_size'] / self.config['audio_sample_rate'], 195 | align_length=length 196 | ) 197 | else: 198 | raise NotImplementedError(f'Invalid pitch extractor: {f0_algo}') 199 | pitch = librosa.hz_to_midi(f0) 200 | processed_input['pitch'] = pitch 201 | 202 | note_midi = np.array( 203 | [(librosa.note_to_midi(n, round_midi=int_midi) if n != 'rest' else -1) for n in meta_data['note_seq']], 204 | dtype=np.int64 if int_midi else np.float32 205 | ) 206 | note_rest = note_midi < 0 207 | interp_func = interpolate.interp1d( 208 | np.where(~note_rest)[0], note_midi[~note_rest], 209 | kind='nearest', fill_value='extrapolate' 210 | ) 211 | note_midi[note_rest] = interp_func(np.where(note_rest)[0]) 212 | processed_input['note_midi'] = note_midi 213 | processed_input['note_rest'] = note_rest 214 | 215 | note_dur_sec = torch.FloatTensor(meta_data['note_dur']).to(self.device) 216 | note_acc = torch.round(torch.cumsum(note_dur_sec, dim=0) / self.timestep + 0.5).long() 217 | note_dur = torch.diff(note_acc, dim=0, prepend=torch.LongTensor([0]).to(self.device)) 218 | processed_input['note_dur'] = note_dur.cpu().numpy() 219 | unit2note = get_mel2ph_torch( 220 | self.lr, note_dur_sec, processed_input['length'], self.timestep, device=self.device 221 | ) 222 | processed_input['unit2note'] = unit2note.cpu().numpy() 223 | return processed_input 224 | 225 | @torch.no_grad() 226 | def process_item(self, item_name, meta_data, allow_aug=False): 227 | waveform, _ = librosa.load(meta_data['wav_fn'], sr=self.config['audio_sample_rate'], mono=True) 228 | 229 | processed_input = self._process_item(waveform, meta_data, int_midi=False) 230 | items = [processed_input] 231 | if not allow_aug: 232 | return items 233 | 234 | wav_tensor = torch.from_numpy(waveform).to(self.device) 235 | for _ in range(self.config['key_shift_factor']): 236 | assert mel_spec is not None, 'Units encoder must be mel if augmentation is applied!' 
237 | key_shift = random.random() * (self.key_shift_max - self.key_shift_min) + self.key_shift_min 238 | if self.round_midi: 239 | key_shift = round(key_shift) 240 | processed_input_aug = copy.deepcopy(processed_input) 241 | assert isinstance(mel_spec, modules.rmvpe.MelSpectrogram) 242 | processed_input_aug['units'] = mel_spec( 243 | wav_tensor.unsqueeze(0), keyshift=key_shift 244 | ).transpose(1, 2).squeeze(0).cpu().numpy() 245 | processed_input_aug['pitch'] += key_shift 246 | processed_input_aug['note_midi'] += key_shift 247 | items.append(processed_input_aug) 248 | 249 | return items 250 | -------------------------------------------------------------------------------- /preprocessing/me_quant_binarizer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import random 4 | 5 | import librosa 6 | import torch 7 | 8 | import modules.contentvec 9 | import modules.rmvpe 10 | from .me_binarizer import MIDIExtractionBinarizer 11 | 12 | os.environ["OMP_NUM_THREADS"] = "1" 13 | QUANTIZED_MIDI_EXTRACTION_ITEM_ATTRIBUTES = [ 14 | 'units', # contentvec units, float32[T_s, 256] 15 | 'pitch', # actual pitch in semitones, float32[T_s,] 16 | 'note_midi', # note-level MIDI pitch (0-127: MIDI, 128: rest) int64[T_n,] 17 | 'note_dur', # durations of notes, in number of frames, int64[T_n,] 18 | 'unit2note', # mel2ph format for alignment between units and notes 19 | ] 20 | 21 | 22 | class QuantizedMIDIExtractionBinarizer(MIDIExtractionBinarizer): 23 | def __init__(self, config: dict): 24 | super().__init__(config) 25 | self.round_midi = True 26 | self.data_attrs = QUANTIZED_MIDI_EXTRACTION_ITEM_ATTRIBUTES 27 | 28 | def process_item(self, item_name, meta_data, allow_aug=False): 29 | waveform, _ = librosa.load(meta_data['wav_fn'], sr=self.config['audio_sample_rate'], mono=True) 30 | 31 | processed_input = self._process_item(waveform, meta_data, int_midi=True) 32 | processed_input['note_midi'][processed_input['note_rest']] = 128 33 | items = [processed_input] 34 | if not allow_aug: 35 | return items 36 | 37 | from .me_binarizer import mel_spec 38 | wav_tensor = torch.from_numpy(waveform).to(self.device) 39 | for _ in range(self.config['key_shift_factor']): 40 | assert mel_spec is not None, 'Units encoder must be mel if augmentation is applied!' 41 | key_shift = random.randint(int(self.key_shift_min), int(self.key_shift_max)) 42 | processed_input_aug = copy.deepcopy(processed_input) 43 | assert isinstance(mel_spec, modules.rmvpe.MelSpectrogram) 44 | processed_input_aug['units'] = mel_spec( 45 | wav_tensor.unsqueeze(0), keyshift=key_shift 46 | ).transpose(1, 2).squeeze(0).cpu().numpy() 47 | processed_input_aug['pitch'] += key_shift 48 | processed_input_aug['note_midi'][~processed_input_aug['note_rest']] += key_shift 49 | items.append(processed_input_aug) 50 | 51 | return items 52 | -------------------------------------------------------------------------------- /pretrained/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openvpi/SOME/dc9916fb35350748351ddcdfc128fe908ef53f5d/pretrained/.gitkeep -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # It is recommended to install PyTorch manually. 2 | # PyTorch >= 2.0 is recommended, but 1.12 and 1.13 is compatible. 
3 | # See instructions at https://pytorch.org/get-started/locally/ 4 | 5 | click 6 | einops==0.6.1 7 | fairseq==0.12.2 8 | gradio==3.47.1 9 | h5py 10 | librosa<0.10.0 11 | lightning>=2.0.0 12 | matplotlib 13 | mido 14 | MonkeyType==23.3.0 15 | numpy # ==1.23.5 16 | onnx==1.14.0 17 | onnxsim==0.4.31 18 | praat-parselmouth==0.4.3 19 | PyYAML 20 | scipy 21 | tensorboard 22 | tensorboardX 23 | torchmetrics 24 | tqdm 25 | -------------------------------------------------------------------------------- /simplify.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | import click 4 | import torch 5 | 6 | 7 | @click.command(help='Simplify a checkpoint file, dropping all useless keys for inference.') 8 | @click.argument('input_ckpt', metavar='INPUT_CKPT') 9 | @click.argument('output_ckpt', metavar='OUTPUT_CKPT') 10 | def simplify(input_ckpt, output_ckpt): 11 | input_ckpt_path = pathlib.Path(input_ckpt) 12 | output_ckpt_path = pathlib.Path(output_ckpt) 13 | ckpt = torch.load(input_ckpt_path) 14 | ckpt = { 15 | 'state_dict': ckpt['state_dict'] 16 | } 17 | torch.save(ckpt, output_ckpt_path) 18 | 19 | 20 | if __name__ == '__main__': 21 | simplify() 22 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import logging 3 | import os 4 | import pathlib 5 | import sys 6 | 7 | import click 8 | import lightning.pytorch as pl 9 | import torch.utils.data 10 | import yaml 11 | from lightning.pytorch.loggers import TensorBoardLogger 12 | 13 | import training.base_task 14 | from utils.config_utils import read_full_config, print_config 15 | from utils.training_utils import ( 16 | DsModelCheckpoint, DsTQDMProgressBar, 17 | get_latest_checkpoint_path, get_strategy 18 | ) 19 | 20 | torch.multiprocessing.set_sharing_strategy(os.getenv('TORCH_SHARE_STRATEGY', 'file_system')) 21 | 22 | log_format = '%(asctime)s %(message)s' 23 | logging.basicConfig(stream=sys.stdout, level=logging.INFO, 24 | format=log_format, datefmt='%m/%d %I:%M:%S %p') 25 | 26 | 27 | @click.command(help='Train a SOME model') 28 | @click.option('--config', required=True, metavar='FILE', help='Path to the configuration file') 29 | @click.option('--exp_name', required=True, metavar='EXP', help='Name of the experiment') 30 | @click.option('--work_dir', required=False, metavar='DIR', help='Directory to save the experiment') 31 | def train(config, exp_name, work_dir): 32 | config = pathlib.Path(config) 33 | config = read_full_config(config) 34 | print_config(config) 35 | if work_dir is None: 36 | work_dir = pathlib.Path(__file__).parent / 'experiments' 37 | else: 38 | work_dir = pathlib.Path(work_dir) 39 | work_dir = work_dir / exp_name 40 | assert not work_dir.exists() or work_dir.is_dir(), f'Path \'{work_dir}\' is not a directory.' 
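# --- Editor's note: illustrative usage sketch, not part of the repository. ---
# Typical invocations of the two CLI scripts above (simplify.py and train.py),
# with a hypothetical config path, experiment name and checkpoint step:
#
#   python train.py --config configs/your_config.yaml --exp_name my_experiment
#   python simplify.py experiments/my_experiment/model_ckpt_steps_100000.ckpt simplified.ckpt
#
# When --work_dir is omitted, train.py resolves it to ./experiments/<exp_name>;
# simplify.py keeps only the 'state_dict' entry of the checkpoint for inference.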
41 | work_dir.mkdir(parents=True, exist_ok=True) 42 | with open(work_dir / 'config.yaml', 'w', encoding='utf8') as f: 43 | yaml.safe_dump(config, f) 44 | config.update({'work_dir': str(work_dir)}) 45 | 46 | if not config['nccl_p2p']: 47 | print("Disabling NCCL P2P") 48 | os.environ['NCCL_P2P_DISABLE'] = '1' 49 | 50 | pl.seed_everything(config['seed'], workers=True) 51 | assert config['task_cls'] != '' 52 | pkg = ".".join(config["task_cls"].split(".")[:-1]) 53 | cls_name = config["task_cls"].split(".")[-1] 54 | task_cls = getattr(importlib.import_module(pkg), cls_name) 55 | assert issubclass(task_cls, training.BaseTask), f'Task class {task_cls} is not a subclass of {training.BaseTask}.' 56 | 57 | task = task_cls(config=config) 58 | 59 | # work_dir = pathlib.Path(config['work_dir']) 60 | trainer = pl.Trainer( 61 | accelerator=config['pl_trainer_accelerator'], 62 | devices=config['pl_trainer_devices'], 63 | num_nodes=config['pl_trainer_num_nodes'], 64 | strategy=get_strategy(config['pl_trainer_strategy']), 65 | precision=config['pl_trainer_precision'], 66 | callbacks=[ 67 | DsModelCheckpoint( 68 | dirpath=work_dir, 69 | filename='model_ckpt_steps_{step}', 70 | auto_insert_metric_name=False, 71 | monitor='step', 72 | mode='max', 73 | save_last=False, 74 | # every_n_train_steps=config['val_check_interval'], 75 | save_top_k=config['num_ckpt_keep'], 76 | permanent_ckpt_start=config['permanent_ckpt_start'], 77 | permanent_ckpt_interval=config['permanent_ckpt_interval'], 78 | verbose=True 79 | ), 80 | # LearningRateMonitor(logging_interval='step'), 81 | DsTQDMProgressBar(), 82 | ], 83 | logger=TensorBoardLogger( 84 | save_dir=str(work_dir), 85 | name='lightning_logs', 86 | version='lastest' 87 | ), 88 | gradient_clip_val=config['clip_grad_norm'], 89 | val_check_interval=config['val_check_interval'] * config['accumulate_grad_batches'], 90 | # so this is global_steps 91 | check_val_every_n_epoch=None, 92 | log_every_n_steps=1, 93 | max_steps=config['max_updates'], 94 | use_distributed_sampler=False, 95 | num_sanity_val_steps=config['num_sanity_val_steps'], 96 | accumulate_grad_batches=config['accumulate_grad_batches'] 97 | ) 98 | trainer.fit(task, ckpt_path=get_latest_checkpoint_path(work_dir)) 99 | 100 | 101 | os.environ['TORCH_CUDNN_V8_API_ENABLED'] = '1' # Prevent unacceptable slowdowns when using 16 precision 102 | 103 | 104 | if __name__ == '__main__': 105 | train() 106 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_task import BaseTask 2 | from .me_task import MIDIExtractionTask 3 | from .me_quant_task import QuantizedMIDIExtractionTask 4 | -------------------------------------------------------------------------------- /training/me_quant_task.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | import modules.losses 6 | import modules.metrics 7 | from utils import build_object_from_class_name, collate_nd 8 | from utils.infer_utils import decode_bounds_to_alignment, decode_note_sequence 9 | from .base_task import BaseDataset 10 | from .me_task import MIDIExtractionTask 11 | 12 | 13 | class QuantizedMIDIExtractionDataset(BaseDataset): 14 | def collater(self, samples): 15 | batch = super().collater(samples) 16 | batch['units'] = collate_nd([s['units'] for s in samples]) # [B, T_s, C] 17 | batch['pitch'] = collate_nd([s['pitch'] 
for s in samples]) # [B, T_s] 18 | batch['note_midi'] = collate_nd([s['note_midi'] for s in samples], pad_value=-1) # [B, T_n] 19 | batch['note_dur'] = collate_nd([s['note_dur'] for s in samples]) # [B, T_n] 20 | unit2note = collate_nd([s['unit2note'] for s in samples]) 21 | batch['unit2note'] = unit2note 22 | batch['midi_idx'] = torch.gather(F.pad(batch['note_midi'], [1, 0], value=-1), 1, unit2note) 23 | bounds = torch.diff( 24 | unit2note, dim=1, prepend=unit2note.new_zeros((batch['size'], 1)) 25 | ) > 0 26 | batch['bounds'] = bounds.float() 27 | return batch 28 | 29 | 30 | class QuantizedMIDIExtractionTask(MIDIExtractionTask): 31 | def __init__(self, config: dict): 32 | super().__init__(config) 33 | self.dataset_cls = QuantizedMIDIExtractionDataset 34 | self.config = config 35 | 36 | # noinspection PyAttributeOutsideInit 37 | def build_model(self): 38 | 39 | model = build_object_from_class_name(self.config['model_cls'], nn.Module, config=self.config) 40 | 41 | return model 42 | 43 | def build_losses_and_metrics(self): 44 | self.midi_loss = nn.CrossEntropyLoss(ignore_index=-1) 45 | self.bound_loss = modules.losses.BinaryEMDLoss(bidirectional=False) 46 | self.register_metric('midi_acc', modules.metrics.MIDIAccuracy(tolerance=0.5)) 47 | 48 | def run_model(self, sample, infer=False): 49 | """ 50 | steps: 51 | 1. run the full model 52 | 2. calculate losses if not infer 53 | """ 54 | spec = sample['units'] # [B, T_ph] 55 | # target = (sample['probs'],sample['bounds']) # [B, T_s, M] 56 | mask = sample['unit2note'] > 0 57 | # mask=None 58 | 59 | f0 = sample['pitch'] 60 | probs, bounds = self.model(x=spec, f0=f0, mask=mask, softmax=infer) 61 | 62 | if infer: 63 | return probs, bounds 64 | else: 65 | losses = {} 66 | 67 | if self.cfg['use_bound_loss']: 68 | bound_loss = self.bound_loss(bounds, sample['bounds']) 69 | 70 | losses['bound_loss'] = bound_loss 71 | if self.cfg['use_midi_loss']: 72 | midi_loss = self.midi_loss(probs.transpose(1, 2), sample['midi_idx']) 73 | 74 | losses['midi_loss'] = midi_loss 75 | 76 | return losses 77 | 78 | def _validation_step(self, sample, batch_idx): 79 | losses = self.run_model(sample, infer=False) 80 | if batch_idx < self.config['num_valid_plots']: 81 | probs, bounds = self.run_model(sample, infer=True) 82 | unit2note_gt = sample['unit2note'] 83 | masks = unit2note_gt > 0 84 | probs *= masks[..., None] 85 | bounds *= masks 86 | # probs: [B, T, 129] => [B, T, 128] 87 | probs_pred = probs[:, :, :-1] 88 | probs_gt = F.one_hot(sample['midi_idx'], num_classes=129)[:, :, :-1] 89 | self.plot_prob(batch_idx, probs_gt, probs_pred) 90 | 91 | unit2note_pred = decode_bounds_to_alignment(bounds) * masks 92 | midi_pred = probs.argmax(dim=-1) 93 | rest_pred = midi_pred == 128 94 | note_midi_pred, note_dur_pred, note_mask_pred = decode_note_sequence( 95 | unit2note_pred, midi_pred.clip(min=0, max=127), ~rest_pred & masks 96 | ) 97 | note_rest_pred = ~note_mask_pred 98 | self.plot_boundary( 99 | batch_idx, bounds_gt=sample['bounds'], bounds_pred=bounds, 100 | dur_gt=sample['note_dur'], dur_pred=note_dur_pred 101 | ) 102 | self.plot_final( 103 | batch_idx, sample['note_midi'], sample['note_dur'], sample['note_midi'] == 128, 104 | note_midi_pred, note_dur_pred, note_rest_pred, sample['pitch'] 105 | ) 106 | 107 | midi_pred = midi_pred.float() 108 | midi_pred[rest_pred] = -torch.inf # rest part is set to -inf 109 | note_midi_gt = sample['note_midi'].float() 110 | note_rest_gt = sample['note_midi'] == 128 111 | note_midi_gt[note_rest_gt] = -torch.inf 112 | midi_gt = 
torch.gather(F.pad(note_midi_gt, [1, 0], value=-torch.inf), 1, unit2note_gt) 113 | self.plot_midi_curve( 114 | batch_idx, midi_gt=midi_gt, midi_pred=midi_pred, pitch=sample['pitch'] 115 | ) 116 | self.midi_acc.update( 117 | midi_pred=midi_pred, rest_pred=rest_pred, midi_gt=midi_gt, rest_gt=midi_gt < 0, mask=masks 118 | ) 119 | 120 | return losses, sample['size'] 121 | -------------------------------------------------------------------------------- /training/me_task.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | import modules.losses 6 | import modules.metrics 7 | from utils import build_object_from_class_name, collate_nd 8 | from utils.infer_utils import decode_gaussian_blurred_probs, decode_bounds_to_alignment, decode_note_sequence 9 | from utils.plot import boundary_to_figure, curve_to_figure, spec_to_figure, pitch_notes_to_figure 10 | from .base_task import BaseDataset, BaseTask 11 | 12 | 13 | class MIDIExtractionDataset(BaseDataset): 14 | def __init__(self, *args, **kwargs): 15 | super().__init__(*args, **kwargs) 16 | self.midi_min = self.config['midi_min'] 17 | self.midi_max = self.config['midi_max'] 18 | self.num_bins = self.config['midi_num_bins'] 19 | self.midi_deviation = self.config['midi_prob_deviation'] 20 | self.interval = (self.midi_max - self.midi_min) / (self.num_bins - 1) # align with centers of bins 21 | self.sigma = self.midi_deviation / self.interval 22 | 23 | def midi_to_bin(self, midi): 24 | return (midi - self.midi_min) / self.interval 25 | 26 | def collater(self, samples): 27 | batch = super().collater(samples) 28 | batch['units'] = collate_nd([s['units'] for s in samples]) # [B, T_s, C] 29 | batch['pitch'] = collate_nd([s['pitch'] for s in samples]) # [B, T_s] 30 | batch['note_midi'] = collate_nd([s['note_midi'] for s in samples]) # [B, T_n] 31 | batch['note_rest'] = collate_nd([s['note_rest'] for s in samples]) # [B, T_n] 32 | batch['note_dur'] = collate_nd([s['note_dur'] for s in samples]) # [B, T_n] 33 | 34 | miu = self.midi_to_bin(batch['note_midi'])[:, :, None] # [B, T_n, 1] 35 | x = torch.arange(self.num_bins).float().reshape(1, 1, -1).to(miu.device) # [1, 1, N] 36 | probs = ((x - miu) / self.sigma).pow(2).div(-2).exp() # gaussian blur, [B, T_n, N] 37 | note_mask = collate_nd([torch.ones_like(s['note_rest']) for s in samples], pad_value=False) 38 | probs *= (note_mask[..., None] & ~batch['note_rest'][..., None]) 39 | 40 | probs = F.pad(probs, [0, 0, 1, 0]) 41 | unit2note = collate_nd([s['unit2note'] for s in samples]) 42 | unit2note_ = unit2note[..., None].repeat([1, 1, self.num_bins]) 43 | probs = torch.gather(probs, 1, unit2note_) 44 | batch['probs'] = probs # [B, T_s, N] 45 | batch['unit2note'] = unit2note 46 | bounds = torch.diff( 47 | unit2note, dim=1, prepend=unit2note.new_zeros((batch['size'], 1)) 48 | ) > 0 49 | batch['bounds'] = bounds.float() # [B, T_s] 50 | 51 | return batch 52 | 53 | 54 | # todo 55 | class MIDIExtractionTask(BaseTask): 56 | def __init__(self, config: dict): 57 | super().__init__(config) 58 | self.midiloss = None 59 | self.dataset_cls = MIDIExtractionDataset 60 | self.midi_min = self.config['midi_min'] 61 | self.midi_max = self.config['midi_max'] 62 | self.midi_deviation = self.config['midi_prob_deviation'] 63 | self.rest_threshold = self.config['rest_threshold'] 64 | self.cfg=config 65 | 66 | def build_model(self): 67 | 68 | model = build_object_from_class_name(self.config['model_cls'], nn.Module, 
config=self.config) 69 | 70 | return model 71 | 72 | def build_losses_and_metrics(self): 73 | 74 | self.midi_loss = nn.BCEWithLogitsLoss() 75 | self.bound_loss = modules.losses.BinaryEMDLoss() 76 | # self.bound_loss = modules.losses.BinaryEMDLoss(bidirectional=True) 77 | self.register_metric('midi_acc', modules.metrics.MIDIAccuracy(tolerance=0.5)) 78 | 79 | def run_model(self, sample, infer=False): 80 | """ 81 | steps: 82 | 1. run the full model 83 | 2. calculate losses if not infer 84 | """ 85 | spec = sample['units'] # [B, T_ph] 86 | # target = (sample['probs'],sample['bounds']) # [B, T_s, M] 87 | mask = sample['unit2note'] > 0 88 | # mask=None 89 | 90 | f0 = sample['pitch'] 91 | 92 | 93 | 94 | 95 | if infer: 96 | probs, bounds = self.model(x=spec, f0=f0, mask=mask, sig=True) 97 | return probs, bounds 98 | else: 99 | losses = {} 100 | probs, bounds = self.model(x=spec, f0=f0, mask=mask, sig=False) 101 | 102 | if self.cfg['use_bound_loss']: 103 | bound_loss = self.bound_loss(bounds, sample['bounds']) 104 | 105 | losses['bound_loss'] = bound_loss 106 | if self.cfg['use_midi_loss']: 107 | midi_loss = self.midi_loss(probs, sample['probs']) 108 | 109 | losses['midi_loss'] = midi_loss 110 | 111 | return losses 112 | 113 | # raise NotImplementedError() 114 | 115 | def _validation_step(self, sample, batch_idx): 116 | losses = self.run_model(sample, infer=False) 117 | if batch_idx < self.config['num_valid_plots']: 118 | probs, bounds = self.run_model(sample, infer=True) 119 | unit2note_gt = sample['unit2note'] 120 | masks = unit2note_gt > 0 121 | probs *= masks[..., None] 122 | bounds *= masks 123 | self.plot_prob(batch_idx, sample['probs'], probs) 124 | 125 | unit2note_pred = decode_bounds_to_alignment(bounds) * masks 126 | midi_pred, rest_pred = decode_gaussian_blurred_probs( 127 | probs, vmin=self.midi_min, vmax=self.midi_max, 128 | deviation=self.midi_deviation, threshold=self.rest_threshold 129 | ) 130 | note_midi_pred, note_dur_pred, note_mask_pred = decode_note_sequence( 131 | unit2note_pred, midi_pred, ~rest_pred & masks 132 | ) 133 | note_rest_pred = ~note_mask_pred 134 | self.plot_boundary( 135 | batch_idx, bounds_gt=sample['bounds'], bounds_pred=bounds, 136 | dur_gt=sample['note_dur'], dur_pred=note_dur_pred 137 | ) 138 | self.plot_final( 139 | batch_idx, sample['note_midi'], sample['note_dur'], sample['note_rest'], 140 | note_midi_pred, note_dur_pred, note_rest_pred, sample['pitch'] 141 | ) 142 | 143 | midi_pred[rest_pred] = -torch.inf # rest part is set to -inf 144 | note_midi_gt = sample['note_midi'].clone() 145 | note_midi_gt[sample['note_rest']] = -torch.inf 146 | midi_gt = torch.gather(F.pad(note_midi_gt, [1, 0], value=-torch.inf), 1, unit2note_gt) 147 | self.plot_midi_curve( 148 | batch_idx, midi_gt=midi_gt, midi_pred=midi_pred, pitch=sample['pitch'] 149 | ) 150 | self.midi_acc.update( 151 | midi_pred=midi_pred, rest_pred=rest_pred, midi_gt=midi_gt, rest_gt=midi_gt < 0, mask=masks 152 | ) 153 | 154 | return losses, sample['size'] 155 | 156 | ############ 157 | # validation plots 158 | ############ 159 | def plot_prob(self, batch_idx, probs_gt, probs_pred): 160 | name = f'prob/{batch_idx}' 161 | vmin, vmax = 0, 1 162 | spec_cat = torch.cat([(probs_pred - probs_gt).abs() + vmin, probs_gt, probs_pred], -1) 163 | self.logger.experiment.add_figure(name, spec_to_figure(spec_cat[0], vmin, vmax), self.global_step) 164 | 165 | def plot_boundary(self, batch_idx, bounds_gt, bounds_pred, dur_gt, dur_pred): 166 | name = f'boundary/{batch_idx}' 167 | bounds_gt = bounds_gt[0].cpu().numpy() 168 
| bounds_pred = bounds_pred[0].cpu().numpy() 169 | dur_gt = dur_gt[0].cpu().numpy() 170 | dur_pred = dur_pred[0].cpu().numpy() 171 | self.logger.experiment.add_figure(name, boundary_to_figure( 172 | bounds_gt, bounds_pred, dur_gt, dur_pred 173 | ), self.global_step) 174 | 175 | def plot_midi_curve(self, batch_idx, midi_gt, midi_pred, pitch): 176 | name = f'midi/{batch_idx}' 177 | midi_gt = midi_gt[0].cpu().numpy() 178 | midi_pred = midi_pred[0].cpu().numpy() 179 | pitch = pitch[0].cpu().numpy() 180 | self.logger.experiment.add_figure(name, curve_to_figure( 181 | midi_gt, midi_pred, curve_base=pitch, grid=1, base_label='pitch' 182 | ), self.global_step) 183 | 184 | def plot_final(self, batch_idx, midi_gt, dur_gt, rest_gt, midi_pred, dur_pred, rest_pred, pitch): 185 | name = f'final/{batch_idx}' 186 | midi_gt = midi_gt[0].cpu().numpy() 187 | midi_pred = midi_pred[0].cpu().numpy() 188 | dur_gt = dur_gt[0].cpu().numpy() 189 | dur_pred = dur_pred[0].cpu().numpy() 190 | rest_gt = rest_gt[0].cpu().numpy() 191 | rest_pred = rest_pred[0].cpu().numpy() 192 | pitch = pitch[0].cpu().numpy() 193 | self.logger.experiment.add_figure(name, pitch_notes_to_figure( 194 | pitch=pitch, note_midi_gt=midi_gt, note_dur_gt=dur_gt, note_rest_gt=rest_gt, 195 | note_midi_pred=midi_pred, note_dur_pred=dur_pred, note_rest_pred=rest_pred 196 | ), self.global_step) 197 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pathlib 4 | import re 5 | import types 6 | from collections import OrderedDict 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from utils.training_utils import get_latest_checkpoint_path 12 | 13 | 14 | def tensors_to_scalars(metrics): 15 | new_metrics = {} 16 | for k, v in metrics.items(): 17 | if isinstance(v, torch.Tensor): 18 | v = v.item() 19 | if type(v) is dict: 20 | v = tensors_to_scalars(v) 21 | new_metrics[k] = v 22 | return new_metrics 23 | 24 | 25 | def collate_nd(values, pad_value=0, max_len=None): 26 | """ 27 | Pad a list of Nd tensors on their first dimension and stack them into a (N+1)d tensor. 28 | """ 29 | size = ((max(v.size(0) for v in values) if max_len is None else max_len), *values[0].shape[1:]) 30 | res = torch.full((len(values), *size), fill_value=pad_value, dtype=values[0].dtype, device=values[0].device) 31 | 32 | for i, v in enumerate(values): 33 | res[i, :len(v), ...] = v 34 | return res 35 | 36 | 37 | def random_continuous_masks(*shape: int, dim: int, device: str | torch.device = 'cpu'): 38 | start, end = torch.sort( 39 | torch.randint( 40 | low=0, high=shape[dim] + 1, size=(*shape[:dim], 2, *((1,) * (len(shape) - dim - 1))), device=device 41 | ).expand(*((-1,) * (dim + 1)), *shape[dim + 1:]), dim=dim 42 | )[0].split(1, dim=dim) 43 | idx = torch.arange( 44 | 0, shape[dim], dtype=torch.long, device=device 45 | ).reshape(*((1,) * dim), shape[dim], *((1,) * (len(shape) - dim - 1))) 46 | masks = (idx >= start) & (idx < end) 47 | return masks 48 | 49 | 50 | def _is_batch_full(batch, num_frames, max_batch_frames, max_batch_size): 51 | if len(batch) == 0: 52 | return 0 53 | if len(batch) == max_batch_size: 54 | return 1 55 | if num_frames > max_batch_frames: 56 | return 1 57 | return 0 58 | 59 | 60 | def batch_by_size( 61 | indices, num_frames_fn, max_batch_frames=80000, max_batch_size=48, 62 | required_batch_size_multiple=1 63 | ): 64 | """ 65 | Yield mini-batches of indices bucketed by size. 
Batches may contain 66 | sequences of different lengths. 67 | 68 | Args: 69 | indices (List[int]): ordered list of dataset indices 70 | num_frames_fn (callable): function that returns the number of frames at 71 | a given index 72 | max_batch_frames (int, optional): max number of frames in each batch 73 | (default: 80000). 74 | max_batch_size (int, optional): max number of sentences in each 75 | batch (default: 48). 76 | required_batch_size_multiple: require the batch size to be multiple 77 | of a given number 78 | """ 79 | bsz_mult = required_batch_size_multiple 80 | 81 | if isinstance(indices, types.GeneratorType): 82 | indices = np.fromiter(indices, dtype=np.int64, count=-1) 83 | 84 | sample_len = 0 85 | sample_lens = [] 86 | batch = [] 87 | batches = [] 88 | for i in range(len(indices)): 89 | idx = indices[i] 90 | num_frames = num_frames_fn(idx) 91 | sample_lens.append(num_frames) 92 | sample_len = max(sample_len, num_frames) 93 | assert sample_len <= max_batch_frames, ( 94 | "sentence at index {} of size {} exceeds max_batch_samples " 95 | "limit of {}!".format(idx, sample_len, max_batch_frames) 96 | ) 97 | num_frames = (len(batch) + 1) * sample_len 98 | 99 | if _is_batch_full(batch, num_frames, max_batch_frames, max_batch_size): 100 | mod_len = max( 101 | bsz_mult * (len(batch) // bsz_mult), 102 | len(batch) % bsz_mult, 103 | ) 104 | batches.append(batch[:mod_len]) 105 | batch = batch[mod_len:] 106 | sample_lens = sample_lens[mod_len:] 107 | sample_len = max(sample_lens) if len(sample_lens) > 0 else 0 108 | batch.append(idx) 109 | if len(batch) > 0: 110 | batches.append(batch) 111 | return batches 112 | 113 | 114 | def unpack_dict_to_list(samples): 115 | samples_ = [] 116 | bsz = samples.get('outputs').size(0) 117 | for i in range(bsz): 118 | res = {} 119 | for k, v in samples.items(): 120 | try: 121 | res[k] = v[i] 122 | except: 123 | pass 124 | samples_.append(res) 125 | return samples_ 126 | 127 | 128 | def filter_kwargs(dict_to_filter, kwarg_obj): 129 | import inspect 130 | 131 | sig = inspect.signature(kwarg_obj) 132 | if any(param.kind == param.VAR_KEYWORD for param in sig.parameters.values()): 133 | # the signature contains definitions like **kwargs, so there is no need to filter 134 | return dict_to_filter.copy() 135 | filter_keys = [ 136 | param.name 137 | for param in sig.parameters.values() 138 | if param.kind == param.POSITIONAL_OR_KEYWORD or param.kind == param.KEYWORD_ONLY 139 | ] 140 | filtered_dict = {filter_key: dict_to_filter[filter_key] for filter_key in filter_keys if 141 | filter_key in dict_to_filter} 142 | return filtered_dict 143 | 144 | 145 | def load_ckpt( 146 | cur_model, ckpt_base_dir, ckpt_steps=None, 147 | prefix_in_ckpt='model', key_in_ckpt='state_dict', 148 | strict=True, device='cpu' 149 | ): 150 | if not isinstance(ckpt_base_dir, pathlib.Path): 151 | ckpt_base_dir = pathlib.Path(ckpt_base_dir) 152 | if ckpt_base_dir.is_file(): 153 | checkpoint_path = [ckpt_base_dir] 154 | elif ckpt_steps is not None: 155 | checkpoint_path = [ckpt_base_dir / f'model_ckpt_steps_{int(ckpt_steps)}.ckpt'] 156 | else: 157 | base_dir = ckpt_base_dir 158 | checkpoint_path = sorted( 159 | [ 160 | ckpt_file 161 | for ckpt_file in base_dir.iterdir() 162 | if ckpt_file.is_file() and re.fullmatch(r'model_ckpt_steps_\d+\.ckpt', ckpt_file.name) 163 | ], 164 | key=lambda x: int(re.search(r'\d+', x.name).group(0)) 165 | ) 166 | assert len(checkpoint_path) > 0, f'| ckpt not found in {ckpt_base_dir}.' 
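# --- Editor's note: illustrative sketch, not part of the repository; kept as
# comments in the style of the example at the bottom of utils/infer_utils.py.
# It shows the padding/bucketing helpers defined earlier in this file with
# made-up tensor shapes and frame budgets.
#
# import torch
# seqs = [torch.zeros(5, 4), torch.ones(3, 4), torch.full((7, 4), 2.0)]
# padded = collate_nd(seqs, pad_value=0)      # padded.shape == (3, 7, 4)
# lengths = [s.shape[0] for s in seqs]
# batches = batch_by_size(list(range(len(seqs))), lambda i: lengths[i],
#                         max_batch_frames=10, max_batch_size=2)
# # e.g. [[0, 1], [2]] -- the frame budget is (batch size) * (longest sample)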
167 | checkpoint_path = checkpoint_path[-1] 168 | ckpt_loaded = torch.load(checkpoint_path, map_location=device) 169 | if key_in_ckpt is None: 170 | state_dict = ckpt_loaded 171 | else: 172 | state_dict = ckpt_loaded[key_in_ckpt] 173 | if prefix_in_ckpt is not None: 174 | state_dict = OrderedDict({ 175 | k[len(prefix_in_ckpt) + 1:]: v 176 | for k, v in state_dict.items() if k.startswith(f'{prefix_in_ckpt}.') 177 | }) 178 | if not strict: 179 | cur_model_state_dict = cur_model.state_dict() 180 | unmatched_keys = [] 181 | for key, param in state_dict.items(): 182 | if key in cur_model_state_dict: 183 | new_param = cur_model_state_dict[key] 184 | if new_param.shape != param.shape: 185 | unmatched_keys.append(key) 186 | print('| Unmatched keys: ', key, new_param.shape, param.shape) 187 | for key in unmatched_keys: 188 | del state_dict[key] 189 | cur_model.load_state_dict(state_dict, strict=strict) 190 | shown_model_name = 'state dict' 191 | if prefix_in_ckpt is not None: 192 | shown_model_name = f'\'{prefix_in_ckpt}\'' 193 | elif key_in_ckpt is not None: 194 | shown_model_name = f'\'{key_in_ckpt}\'' 195 | print(f'| load {shown_model_name} from \'{checkpoint_path}\'.') 196 | 197 | 198 | def remove_padding(x, padding_idx=0): 199 | if x is None: 200 | return None 201 | assert len(x.shape) in [1, 2] 202 | if len(x.shape) == 2: # [T, H] 203 | return x[np.abs(x).sum(-1) != padding_idx] 204 | elif len(x.shape) == 1: # [T] 205 | return x[x != padding_idx] 206 | 207 | 208 | def print_arch(model, model_name='model'): 209 | print(f"| {model_name} Arch: ", model) 210 | # num_params(model, model_name=model_name) 211 | 212 | 213 | def num_params(model, print_out=True, model_name="model"): 214 | parameters = filter(lambda p: p.requires_grad, model.parameters()) 215 | parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 216 | if print_out: 217 | print(f'| {model_name} Trainable Parameters: %.3fM' % parameters) 218 | return parameters 219 | 220 | 221 | def build_object_from_class_name(cls_str, parent_cls, *args, **kwargs): 222 | import importlib 223 | 224 | pkg = ".".join(cls_str.split(".")[:-1]) 225 | cls_name = cls_str.split(".")[-1] 226 | cls_type = getattr(importlib.import_module(pkg), cls_name) 227 | if parent_cls is not None: 228 | assert issubclass(cls_type, parent_cls), f'| {cls_type} is not subclass of {parent_cls}.' 
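# --- Editor's note: illustrative sketch, not part of the repository. The class
# path and the extra keyword below are hypothetical; the point is that
# filter_kwargs() silently drops keywords the target constructor does not
# accept, which is why whole config-derived dicts can be passed in elsewhere.
#
# import torch
# params = [torch.nn.Parameter(torch.zeros(1))]
# optim = build_object_from_class_name(
#     'torch.optim.AdamW', torch.optim.Optimizer,
#     params, lr=1e-4, some_unrelated_key=123  # 'some_unrelated_key' is filtered out
# )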
229 | 230 | return cls_type(*args, **filter_kwargs(kwargs, cls_type)) 231 | 232 | 233 | def build_lr_scheduler_from_config(optimizer, scheduler_args): 234 | try: 235 | # PyTorch 2.0+ 236 | from torch.optim.lr_scheduler import LRScheduler as LRScheduler 237 | except ImportError: 238 | # PyTorch 1.X 239 | from torch.optim.lr_scheduler import _LRScheduler as LRScheduler 240 | 241 | def helper(params): 242 | if isinstance(params, list): 243 | return [helper(s) for s in params] 244 | elif isinstance(params, dict): 245 | resolved = {k: helper(v) for k, v in params.items()} 246 | if 'cls' in resolved: 247 | if ( 248 | resolved["cls"] == "torch.optim.lr_scheduler.ChainedScheduler" 249 | and scheduler_args["scheduler_cls"] == "torch.optim.lr_scheduler.SequentialLR" 250 | ): 251 | raise ValueError(f"ChainedScheduler cannot be part of a SequentialLR.") 252 | resolved['optimizer'] = optimizer 253 | obj = build_object_from_class_name( 254 | resolved['cls'], 255 | LRScheduler, 256 | **resolved 257 | ) 258 | return obj 259 | return resolved 260 | else: 261 | return params 262 | 263 | resolved = helper(scheduler_args) 264 | resolved['optimizer'] = optimizer 265 | return build_object_from_class_name( 266 | scheduler_args['scheduler_cls'], 267 | LRScheduler, 268 | **resolved 269 | ) 270 | 271 | 272 | def simulate_lr_scheduler(optimizer_args, scheduler_args, step_count, num_param_groups=1): 273 | optimizer = build_object_from_class_name( 274 | optimizer_args['optimizer_cls'], 275 | torch.optim.Optimizer, 276 | [{'params': torch.nn.Parameter(), 'initial_lr': optimizer_args['lr']} for _ in range(num_param_groups)], 277 | **optimizer_args 278 | ) 279 | scheduler = build_lr_scheduler_from_config(optimizer, scheduler_args) 280 | scheduler.optimizer._step_count = 1 281 | for _ in range(step_count): 282 | scheduler.step() 283 | return scheduler.state_dict() 284 | 285 | 286 | def remove_suffix(string: str, suffix: str): 287 | # Just for Python 3.8 compatibility, since `str.removesuffix()` API of is available since Python 3.9 288 | if string.endswith(suffix): 289 | string = string[:-len(suffix)] 290 | return string 291 | -------------------------------------------------------------------------------- /utils/binarizer_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import librosa 4 | import numpy as np 5 | import parselmouth 6 | import torch 7 | 8 | from utils.pitch_utils import interp_f0 9 | 10 | 11 | def merge_slurs(note_seq: list, note_dur: list, note_slur: list, tolerance=None) -> Tuple[list, list]: 12 | """ 13 | merge slurs with the similar pitch 14 | """ 15 | note_midi = [librosa.note_to_midi(n, round_midi=False) if n != 'rest' else 'rest' for n in note_seq] 16 | prev_min = prev_max = None 17 | note_midi_merge_slur = [note_midi[0]] 18 | note_dur_merge_slur = [note_dur[0]] 19 | 20 | def can_be_merged(midi): 21 | if tolerance is None or midi == 'rest' or note_midi_merge_slur[-1] == 'rest': 22 | return note_midi_merge_slur[-1] == midi 23 | return ( 24 | abs(midi - note_midi_merge_slur[-1]) <= tolerance 25 | and (prev_min is None or abs(midi - prev_min) <= tolerance) 26 | and (prev_max is None or abs(midi - prev_max) <= tolerance) 27 | ) 28 | 29 | def get_merged_midi(midi1, dur1, midi2, dur2): 30 | if midi1 == midi2: 31 | return midi1 32 | return (midi1 * dur1 + midi2 * dur2) / (dur1 + dur2) 33 | 34 | for i in range(1, len(note_seq)): 35 | if note_slur[i] and can_be_merged(note_midi[i]): 36 | # update min and max 37 | prev_min = 
min(note_midi[i], note_midi_merge_slur[-1]) if prev_min is None else min(prev_min, note_midi[i]) 38 | prev_max = max(note_midi[i], note_midi_merge_slur[-1]) if prev_max is None else max(prev_max, note_midi[i]) 39 | note_midi_merge_slur[-1] = get_merged_midi( 40 | note_midi_merge_slur[-1], note_dur_merge_slur[-1], note_midi[i], note_dur[i] 41 | ) 42 | note_dur_merge_slur[-1] += note_dur[i] 43 | else: 44 | note_midi_merge_slur.append(note_midi[i]) 45 | note_dur_merge_slur.append(note_dur[i]) 46 | prev_min = prev_max = None 47 | note_seq_merge_slur = [ 48 | librosa.midi_to_note(n, cents=True, unicode=False) if n != 'rest' else 'rest' for n in note_midi_merge_slur 49 | ] 50 | return note_seq_merge_slur, note_dur_merge_slur 51 | 52 | 53 | def merge_rests(note_seq: list, note_dur: list) -> Tuple[list, list]: 54 | i = 0 55 | note_seq_merge_rest = [] 56 | note_dur_merge_rest = [] 57 | while i < len(note_seq): 58 | if note_seq[i] != 'rest': 59 | note_seq_merge_rest.append(note_seq[i]) 60 | note_dur_merge_rest.append(note_dur[i]) 61 | i += 1 62 | else: 63 | j = i 64 | rest_dur = 0 65 | while j < len(note_seq) and note_seq[j] == 'rest': 66 | rest_dur += note_dur[j] 67 | j += 1 68 | note_seq_merge_rest.append('rest') 69 | note_dur_merge_rest.append(rest_dur) 70 | i = j 71 | return note_seq_merge_rest, note_dur_merge_rest 72 | 73 | 74 | @torch.no_grad() 75 | def get_mel2ph_torch(lr, durs, length, timestep, device='cpu'): 76 | ph_acc = torch.round(torch.cumsum(durs.to(device), dim=0) / timestep + 0.5).long() 77 | ph_dur = torch.diff(ph_acc, dim=0, prepend=torch.LongTensor([0]).to(device)) 78 | mel2ph = lr(ph_dur[None])[0] 79 | num_frames = mel2ph.shape[0] 80 | if num_frames < length: 81 | mel2ph = torch.cat((mel2ph, torch.full((length - num_frames,), fill_value=mel2ph[-1], device=device)), dim=0) 82 | elif num_frames > length: 83 | mel2ph = mel2ph[:length] 84 | return mel2ph 85 | 86 | 87 | def pad_frames(frames, hop_size, n_samples, n_expect): 88 | n_frames = frames.shape[0] 89 | lpad = (int(n_samples // hop_size) - n_frames + 1) // 2 90 | rpad = n_expect - n_frames - lpad 91 | if rpad < 0: 92 | frames = frames[:rpad] 93 | rpad = 0 94 | if lpad > 0 or rpad > 0: 95 | frames = np.pad(frames, (lpad, rpad), mode='constant', constant_values=(frames[0], frames[-1])) 96 | return frames 97 | 98 | 99 | def get_pitch_parselmouth(waveform, sample_rate, hop_size, length, interp_uv=False): 100 | """ 101 | 102 | :param waveform: [T] 103 | :param hop_size: size of each frame 104 | :param sample_rate: sampling rate of waveform 105 | :param length: Expected number of frames 106 | :param interp_uv: Interpolate unvoiced parts 107 | :return: f0, uv 108 | """ 109 | 110 | time_step = hop_size / sample_rate 111 | f0_min = 65 112 | f0_max = 800 113 | 114 | # noinspection PyArgumentList 115 | f0 = parselmouth.Sound(waveform, sampling_frequency=sample_rate).to_pitch_ac( 116 | time_step=time_step, voicing_threshold=0.6, 117 | pitch_floor=f0_min, pitch_ceiling=f0_max 118 | ).selected_array['frequency'].astype(np.float32) 119 | f0 = pad_frames(f0, hop_size, waveform.shape[0], length) 120 | uv = f0 == 0 121 | if interp_uv: 122 | f0, uv = interp_f0(f0, uv) 123 | return f0, uv 124 | 125 | 126 | class SinusoidalSmoothingConv1d(torch.nn.Conv1d): 127 | def __init__(self, kernel_size): 128 | super().__init__( 129 | in_channels=1, 130 | out_channels=1, 131 | kernel_size=kernel_size, 132 | bias=False, 133 | padding='same', 134 | padding_mode='replicate' 135 | ) 136 | smooth_kernel = torch.sin(torch.from_numpy( 137 | np.linspace(0, 1, 
kernel_size).astype(np.float32) * np.pi 138 | )) 139 | smooth_kernel /= smooth_kernel.sum() 140 | self.weight.data = smooth_kernel[None, None] 141 | -------------------------------------------------------------------------------- /utils/config_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pathlib 4 | 5 | import lightning.pytorch.utilities 6 | import yaml 7 | 8 | loaded_config_files = {} 9 | 10 | 11 | def override_dict(old_config: dict, new_config: dict): 12 | for k, v in new_config.items(): 13 | if isinstance(v, dict) and k in old_config: 14 | override_dict(old_config[k], new_config[k]) 15 | else: 16 | old_config[k] = v 17 | 18 | 19 | def read_full_config(config_path: pathlib.Path) -> dict: 20 | config_path = config_path.resolve() 21 | config_path_str = config_path.as_posix() 22 | if config_path in loaded_config_files: 23 | return loaded_config_files[config_path_str] 24 | 25 | with open(config_path, 'r', encoding='utf8') as f: 26 | config = yaml.safe_load(f) 27 | if 'base_config' not in config: 28 | loaded_config_files[config_path_str] = config 29 | return config 30 | 31 | if not isinstance(config['base_config'], list): 32 | config['base_config'] = [config['base_config']] 33 | squashed_config = {} 34 | for base_config in config['base_config']: 35 | c_path = pathlib.Path(base_config) 36 | full_base_config = read_full_config(c_path) 37 | override_dict(squashed_config, full_base_config) 38 | override_dict(squashed_config, config) 39 | squashed_config.pop('base_config') 40 | loaded_config_files[config_path_str] = squashed_config 41 | return squashed_config 42 | 43 | 44 | @lightning.pytorch.utilities.rank_zero.rank_zero_only 45 | def print_config(config: dict): 46 | for i, (k, v) in enumerate(sorted(config.items())): 47 | print(f"\033[0;33m{k}\033[0m: {v}", end='') 48 | if i < len(config) - 1: 49 | print(", ", end="") 50 | if i % 5 == 4: 51 | print() 52 | print() 53 | -------------------------------------------------------------------------------- /utils/indexed_datasets.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import multiprocessing 3 | from collections import deque 4 | 5 | import h5py 6 | import torch 7 | import numpy as np 8 | 9 | 10 | class IndexedDataset: 11 | def __init__(self, path, prefix, num_cache=0): 12 | super().__init__() 13 | self.path = pathlib.Path(path) / f'{prefix}.data' 14 | if not self.path.exists(): 15 | raise FileNotFoundError(f'IndexedDataset not found: {self.path}') 16 | self.dset = None 17 | self.cache = deque(maxlen=num_cache) 18 | self.num_cache = num_cache 19 | 20 | def check_index(self, i): 21 | if i < 0 or i >= len(self.dset): 22 | raise IndexError('index out of range') 23 | 24 | def __del__(self): 25 | if self.dset: 26 | self.dset.close() 27 | 28 | def __getitem__(self, i): 29 | if self.dset is None: 30 | self.dset = h5py.File(self.path, 'r') 31 | self.check_index(i) 32 | if self.num_cache > 0: 33 | for c in self.cache: 34 | if c[0] == i: 35 | return c[1] 36 | item = {k: v[()].item() if v.shape == () else torch.from_numpy(v[()]) for k, v in self.dset[str(i)].items()} 37 | if self.num_cache > 0: 38 | self.cache.appendleft((i, item)) 39 | return item 40 | 41 | def __len__(self): 42 | if self.dset is None: 43 | self.dset = h5py.File(self.path, 'r') 44 | return len(self.dset) 45 | 46 | 47 | class IndexedDatasetBuilder: 48 | def __init__(self, path, prefix, allowed_attr=None): 49 | self.path = 
pathlib.Path(path) / f'{prefix}.data' 50 | self.prefix = prefix 51 | self.dset = None 52 | self.counter = 0 53 | self.lock = multiprocessing.Lock() 54 | if allowed_attr is not None: 55 | self.allowed_attr = set(allowed_attr) 56 | else: 57 | self.allowed_attr = None 58 | 59 | def add_item(self, item): 60 | if self.dset is None: 61 | self.dset = h5py.File(self.path, 'w') 62 | if self.allowed_attr is not None: 63 | item = { 64 | k: item[k] 65 | for k in self.allowed_attr 66 | if k in item 67 | } 68 | item_no = self.counter 69 | self.counter += 1 70 | for k, v in item.items(): 71 | if v is None: 72 | continue 73 | self.dset.create_dataset(f'{item_no}/{k}', data=v) 74 | 75 | def finalize(self): 76 | if self.dset is not None: 77 | self.dset.close() 78 | 79 | 80 | if __name__ == "__main__": 81 | import random 82 | from tqdm import tqdm 83 | 84 | ds_path = './checkpoints/indexed_ds_example' 85 | size = 100 86 | items = [{"a": np.random.normal(size=[10000, 10]), 87 | "b": np.random.normal(size=[10000, 10])} for i in range(size)] 88 | builder = IndexedDatasetBuilder(ds_path, 'example') 89 | for i in tqdm(range(size)): 90 | builder.add_item(items[i]) 91 | builder.finalize() 92 | ds = IndexedDataset(ds_path, 'example') 93 | for i in tqdm(range(10000)): 94 | idx = random.randint(0, size - 1) 95 | assert (ds[idx]['a'] == items[idx]['a']).all() 96 | -------------------------------------------------------------------------------- /utils/infer_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | 3 | import mido 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | 9 | def decode_gaussian_blurred_probs(probs, vmin, vmax, deviation, threshold): 10 | num_bins = int(probs.shape[-1]) 11 | interval = (vmax - vmin) / (num_bins - 1) 12 | width = int(3 * deviation / interval) # 3 * sigma 13 | idx = torch.arange(num_bins, device=probs.device)[None, None, :] # [1, 1, N] 14 | idx_values = idx * interval + vmin 15 | center = torch.argmax(probs, dim=-1, keepdim=True) # [B, T, 1] 16 | start = torch.clip(center - width, min=0) # [B, T, 1] 17 | end = torch.clip(center + width + 1, max=num_bins) # [B, T, 1] 18 | idx_masks = (idx >= start) & (idx < end) # [B, T, N] 19 | weights = probs * idx_masks # [B, T, N] 20 | product_sum = torch.sum(weights * idx_values, dim=2) # [B, T] 21 | weight_sum = torch.sum(weights, dim=2) # [B, T] 22 | values = product_sum / (weight_sum + (weight_sum == 0)) # avoid dividing by zero, [B, T] 23 | rest = probs.max(dim=-1)[0] < threshold # [B, T] 24 | return values, rest 25 | 26 | 27 | def decode_bounds_to_alignment(bounds, use_diff=True): 28 | bounds_step = bounds.cumsum(dim=1).round().long() 29 | if use_diff: 30 | bounds_inc = torch.diff( 31 | bounds_step, dim=1, prepend=torch.full( 32 | (bounds.shape[0], 1), fill_value=-1, 33 | dtype=bounds_step.dtype, device=bounds_step.device 34 | ) 35 | ) > 0 36 | else: 37 | bounds_inc = F.pad((bounds_step[:, 1:] > bounds_step[:, :-1]), [1, 0], value=True) 38 | frame2item = bounds_inc.long().cumsum(dim=1) 39 | return frame2item 40 | 41 | 42 | def decode_note_sequence(frame2item, values, masks, threshold=0.5): 43 | """ 44 | 45 | :param frame2item: [1, 1, 1, 1, 2, 2, 3, 3, 3] 46 | :param values: 47 | :param masks: 48 | :param threshold: minimum ratio of unmasked frames required to be regarded as an unmasked item 49 | :return: item_values, item_dur, item_masks 50 | """ 51 | b = frame2item.shape[0] 52 | space = frame2item.max() + 1 53 | 54 | item_dur = 
frame2item.new_zeros(b, space, dtype=frame2item.dtype).scatter_add( 55 | 1, frame2item, torch.ones_like(frame2item) 56 | )[:, 1:] 57 | item_unmasked_dur = frame2item.new_zeros(b, space, dtype=frame2item.dtype).scatter_add( 58 | 1, frame2item, masks.long() 59 | )[:, 1:] 60 | item_masks = item_unmasked_dur / item_dur >= threshold 61 | 62 | values_quant = values.round().long() 63 | histogram = frame2item.new_zeros(b, space * 128, dtype=frame2item.dtype).scatter_add( 64 | 1, frame2item * 128 + values_quant, torch.ones_like(frame2item) * masks 65 | ).unflatten(1, [space, 128])[:, 1:, :] 66 | item_values_center = histogram.float().argmax(dim=2).to(dtype=values.dtype) 67 | values_center = torch.gather(F.pad(item_values_center, [1, 0]), 1, frame2item) 68 | values_near_center = masks & (values >= values_center - 0.5) & (values <= values_center + 0.5) 69 | item_valid_dur = frame2item.new_zeros(b, space, dtype=frame2item.dtype).scatter_add( 70 | 1, frame2item, values_near_center.long() 71 | )[:, 1:] 72 | item_values = values.new_zeros(b, space, dtype=values.dtype).scatter_add( 73 | 1, frame2item, values * values_near_center 74 | )[:, 1:] / (item_valid_dur + (item_valid_dur == 0)) 75 | 76 | return item_values, item_dur, item_masks 77 | 78 | 79 | def build_midi_file(offsets: List[float], segments: List[Dict[str, np.ndarray]], tempo=120) -> mido.MidiFile: 80 | midi_file = mido.MidiFile(charset='utf8') 81 | midi_track = mido.MidiTrack() 82 | midi_track.append(mido.MetaMessage('set_tempo', tempo=mido.bpm2tempo(tempo), time=0)) 83 | last_time = 0 84 | offsets = [round(o * tempo * 8) for o in offsets] 85 | for i, (offset, segment) in enumerate(zip(offsets, segments)): 86 | note_midi = np.round(segment['note_midi']).astype(np.int64).tolist() 87 | note_tick = np.diff(np.round(np.cumsum(segment['note_dur']) * tempo * 8).astype(np.int64), prepend=0).tolist() 88 | note_rest = segment['note_rest'].tolist() 89 | start = offset 90 | for j in range(len(note_midi)): 91 | end = start + note_tick[j] 92 | if i < len(offsets) - 1 and end > offsets[i + 1]: 93 | end = offsets[i + 1] 94 | if start < end and not note_rest[j]: 95 | midi_track.append(mido.Message('note_on', note=note_midi[j], time=start - last_time)) 96 | midi_track.append(mido.Message('note_off', note=note_midi[j], time=end - start)) 97 | last_time = end 98 | start = end 99 | midi_file.tracks.append(midi_track) 100 | return midi_file 101 | 102 | 103 | # if __name__ == '__main__': 104 | # frame2item = torch.LongTensor([ 105 | # [1, 1, 1, 1, 2, 2, 3, 3, 3, 0, 0, 0, 0, 0], 106 | # [1, 1, 1, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0, 0] 107 | # ]) 108 | # values = torch.FloatTensor([ 109 | # [60, 61, 60.5, 63, 57, 57, 50, 55, 54, 0, 0, 0, 0, 0], 110 | # [50, 51, 50.5, 53, 47, 47, 40, 45, 44, 38, 38, 0, 0, 0] 111 | # ]) 112 | # masks = frame2item > 0 113 | # decode_note_sequence(frame2item, values, masks) 114 | -------------------------------------------------------------------------------- /utils/multiprocess_utils.py: -------------------------------------------------------------------------------- 1 | import platform 2 | import re 3 | import traceback 4 | 5 | from torch.multiprocessing import Manager, Process, current_process, get_context 6 | 7 | is_main_process = not bool(re.match(r'((.*Process)|(SyncManager)|(.*PoolWorker))-\d+', current_process().name)) 8 | 9 | 10 | def main_process_print(self, *args, sep=' ', end='\n', file=None): 11 | if is_main_process: 12 | print(self, *args, sep=sep, end=end, file=file) 13 | 14 | 15 | def chunked_worker_run(map_func, args, 
results_queue=None): 16 | for a in args: 17 | # noinspection PyBroadException 18 | try: 19 | res = map_func(*a) 20 | results_queue.put(res) 21 | except KeyboardInterrupt: 22 | break 23 | except Exception: 24 | traceback.print_exc() 25 | results_queue.put(None) 26 | 27 | 28 | def chunked_multiprocess_run(map_func, args, num_workers, q_max_size=1000): 29 | num_jobs = len(args) 30 | if num_jobs < num_workers: 31 | num_workers = num_jobs 32 | 33 | queues = [Manager().Queue(maxsize=q_max_size // num_workers) for _ in range(num_workers)] 34 | if platform.system().lower() != 'windows': 35 | process_creation_func = get_context('spawn').Process 36 | else: 37 | process_creation_func = Process 38 | 39 | workers = [] 40 | for i in range(num_workers): 41 | worker = process_creation_func( 42 | target=chunked_worker_run, args=(map_func, args[i::num_workers], queues[i]), daemon=True 43 | ) 44 | workers.append(worker) 45 | worker.start() 46 | 47 | for i in range(num_jobs): 48 | yield queues[i % num_workers].get() 49 | 50 | for worker in workers: 51 | worker.join() 52 | worker.close() 53 | -------------------------------------------------------------------------------- /utils/pitch_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | f0_bin = 256 5 | f0_max = 1100.0 6 | f0_min = 50.0 7 | f0_mel_min = 1127 * np.log(1 + f0_min / 700) 8 | f0_mel_max = 1127 * np.log(1 + f0_max / 700) 9 | 10 | 11 | def f0_to_coarse(f0): 12 | is_torch = isinstance(f0, torch.Tensor) 13 | f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700) 14 | f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1 15 | 16 | f0_mel[f0_mel <= 1] = 1 17 | f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1 18 | f0_coarse = (f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(np.int) 19 | assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (f0_coarse.max(), f0_coarse.min()) 20 | return f0_coarse 21 | 22 | 23 | def norm_f0(f0, uv=None): 24 | if uv is None: 25 | uv = f0 == 0 26 | f0 = np.log2(f0 + uv) # avoid arithmetic error 27 | f0[uv] = -np.inf 28 | return f0 29 | 30 | 31 | def interp_f0(f0, uv=None): 32 | if uv is None: 33 | uv = f0 == 0 34 | f0 = norm_f0(f0, uv) 35 | if uv.any() and not uv.all(): 36 | f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv]) 37 | return denorm_f0(f0, uv=None), uv 38 | 39 | 40 | def denorm_f0(f0, uv, pitch_padding=None): 41 | f0 = 2 ** f0 42 | if uv is not None: 43 | f0[uv > 0] = 0 44 | if pitch_padding is not None: 45 | f0[pitch_padding] = 0 46 | return f0 47 | 48 | 49 | def resample_align_curve(points: np.ndarray, original_timestep: float, target_timestep: float, align_length: int): 50 | t_max = (len(points) - 1) * original_timestep 51 | curve_interp = np.interp( 52 | np.arange(0, t_max, target_timestep), 53 | original_timestep * np.arange(len(points)), 54 | points 55 | ).astype(points.dtype) 56 | delta_l = align_length - len(curve_interp) 57 | if delta_l < 0: 58 | curve_interp = curve_interp[:align_length] 59 | elif delta_l > 0: 60 | curve_interp = np.concatenate((curve_interp, np.full(delta_l, fill_value=curve_interp[-1])), axis=0) 61 | return curve_interp 62 | -------------------------------------------------------------------------------- /utils/plot.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import torch 6 | from 
matplotlib.ticker import MultipleLocator 7 | 8 | LINE_COLORS = ['w', 'r', 'y', 'cyan', 'm', 'b', 'lime'] 9 | 10 | 11 | def spec_to_figure(spec, vmin=None, vmax=None): 12 | if isinstance(spec, torch.Tensor): 13 | spec = spec.cpu().numpy() 14 | fig = plt.figure(figsize=(12, 15)) 15 | plt.pcolor(spec.T, vmin=vmin, vmax=vmax) 16 | plt.tight_layout() 17 | return fig 18 | 19 | 20 | def dur_to_figure(dur_gt, dur_pred, txt): 21 | if isinstance(dur_gt, torch.Tensor): 22 | dur_gt = dur_gt.cpu().numpy() 23 | if isinstance(dur_pred, torch.Tensor): 24 | dur_pred = dur_pred.cpu().numpy() 25 | dur_gt = dur_gt.astype(np.int64) 26 | dur_pred = dur_pred.astype(np.int64) 27 | dur_gt = np.cumsum(dur_gt) 28 | dur_pred = np.cumsum(dur_pred) 29 | width = max(12, min(48, len(txt) // 2)) 30 | fig = plt.figure(figsize=(width, 8)) 31 | plt.vlines(dur_pred, 12, 22, colors='r', label='pred') 32 | plt.vlines(dur_gt, 0, 10, colors='b', label='gt') 33 | for i in range(len(txt)): 34 | shift = (i % 8) + 1 35 | plt.text((dur_pred[i - 1] + dur_pred[i]) / 2 if i > 0 else dur_pred[i] / 2, 12 + shift, txt[i], 36 | size=16, horizontalalignment='center') 37 | plt.text((dur_gt[i - 1] + dur_gt[i]) / 2 if i > 0 else dur_gt[i] / 2, shift, txt[i], 38 | size=16, horizontalalignment='center') 39 | plt.plot([dur_pred[i], dur_gt[i]], [12, 10], color='black', linewidth=2, linestyle=':') 40 | plt.yticks([]) 41 | plt.xlim(0, max(dur_pred[-1], dur_gt[-1])) 42 | fig.legend() 43 | fig.tight_layout() 44 | return fig 45 | 46 | 47 | def boundary_to_figure( 48 | bounds_gt: np.ndarray, bounds_pred: np.ndarray, 49 | dur_gt: np.ndarray = None, dur_pred: np.ndarray = None 50 | ): 51 | fig = plt.figure(figsize=(12, 6)) 52 | bounds_acc_gt = np.cumsum(bounds_gt) 53 | bounds_acc_pred = np.cumsum(bounds_pred) 54 | plt.plot(bounds_acc_gt, color='b', label='gt') 55 | plt.plot(bounds_acc_pred, color='r', label='pred') 56 | if dur_gt is not None and dur_pred is not None: 57 | height = math.ceil(max(bounds_acc_gt[-1], bounds_acc_pred[-1])) 58 | dur_acc_gt = np.cumsum(dur_gt) 59 | dur_acc_pred = np.cumsum(dur_pred) 60 | plt.vlines(dur_acc_gt[:-1], 0, height / 2, colors='b', linestyles='--') 61 | plt.vlines(dur_acc_pred[:-1], height / 2, height, colors='r', linestyles='--') 62 | plt.gca().yaxis.set_major_locator(MultipleLocator(1)) 63 | plt.grid(axis='y') 64 | plt.legend() 65 | plt.tight_layout() 66 | return fig 67 | 68 | 69 | def pitch_notes_to_figure( 70 | pitch, note_midi_gt, note_dur_gt, note_rest_gt, 71 | note_midi_pred=None, note_dur_pred=None, note_rest_pred=None 72 | ): 73 | fig = plt.figure() 74 | 75 | def draw_notes(note_midi, note_dur, note_rest, color, label): 76 | note_dur_acc = np.cumsum(note_dur) 77 | if note_rest is None: 78 | note_rest = np.zeros_like(note_midi, dtype=np.bool_) 79 | labeled = False 80 | for i in range(len(note_midi)): 81 | if note_rest[i]: 82 | continue 83 | x0 = note_dur_acc[i - 1] if i > 0 else 0 84 | y0 = note_midi[i] - 0.5 85 | rec = plt.Rectangle( 86 | xy=(x0, y0), 87 | width=note_dur[i], height=1, 88 | edgecolor=color, fill=False, 89 | linewidth=1.5, label=label if not labeled else None, 90 | # linestyle='--' if note_rest[i] else '-' 91 | ) 92 | plt.gca().add_patch(rec) 93 | plt.fill_between([x0, x0 + note_dur[i]], y0, y0 + 1, color='none', facecolor=color, alpha=0.2) 94 | labeled = True 95 | 96 | draw_notes(note_midi_gt, note_dur_gt, note_rest_gt, color='b', label='gt') 97 | draw_notes(note_midi_pred, note_dur_pred, note_rest_pred, color='r', label='pred') 98 | plt.plot(pitch, color='grey', label='pitch') 99 | 
plt.gca().yaxis.set_major_locator(MultipleLocator(1)) 100 | plt.grid(axis='y') 101 | plt.legend() 102 | plt.tight_layout() 103 | return fig 104 | 105 | 106 | def curve_to_figure(curve_gt, curve_pred=None, curve_base=None, grid=None, base_label='base'): 107 | if isinstance(curve_gt, torch.Tensor): 108 | curve_gt = curve_gt.cpu().numpy() 109 | if isinstance(curve_pred, torch.Tensor): 110 | curve_pred = curve_pred.cpu().numpy() 111 | if isinstance(curve_base, torch.Tensor): 112 | curve_base = curve_base.cpu().numpy() 113 | fig = plt.figure() 114 | if curve_base is not None: 115 | plt.plot(curve_base, color='grey', label=base_label) 116 | plt.plot(curve_gt, color='b', label='gt') 117 | if curve_pred is not None: 118 | plt.plot(curve_pred, color='r', label='pred') 119 | if grid is not None: 120 | plt.gca().yaxis.set_major_locator(MultipleLocator(grid)) 121 | plt.grid(axis='y') 122 | plt.legend() 123 | plt.tight_layout() 124 | return fig 125 | 126 | 127 | def distribution_to_figure(title, x_label, y_label, items: list, values: list, zoom=0.8): 128 | fig = plt.figure(figsize=(int(len(items) * zoom), 10)) 129 | plt.bar(x=items, height=values) 130 | plt.tick_params(labelsize=15) 131 | plt.xlim(-1, len(items)) 132 | for a, b in zip(items, values): 133 | plt.text(a, b, b, ha='center', va='bottom', fontsize=15) 134 | plt.grid() 135 | plt.title(title, fontsize=30) 136 | plt.xlabel(x_label, fontsize=20) 137 | plt.ylabel(y_label, fontsize=20) 138 | return fig 139 | -------------------------------------------------------------------------------- /utils/slicer2.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | # This function is obtained from librosa. 5 | def get_rms( 6 | y, 7 | *, 8 | frame_length=2048, 9 | hop_length=512, 10 | pad_mode="constant", 11 | ): 12 | padding = (int(frame_length // 2), int(frame_length // 2)) 13 | y = np.pad(y, padding, mode=pad_mode) 14 | 15 | axis = -1 16 | # put our new within-frame axis at the end for now 17 | out_strides = y.strides + tuple([y.strides[axis]]) 18 | # Reduce the shape on the framing axis 19 | x_shape_trimmed = list(y.shape) 20 | x_shape_trimmed[axis] -= frame_length - 1 21 | out_shape = tuple(x_shape_trimmed) + tuple([frame_length]) 22 | xw = np.lib.stride_tricks.as_strided( 23 | y, shape=out_shape, strides=out_strides 24 | ) 25 | if axis < 0: 26 | target_axis = axis - 1 27 | else: 28 | target_axis = axis + 1 29 | xw = np.moveaxis(xw, -1, target_axis) 30 | # Downsample along the target axis 31 | slices = [slice(None)] * xw.ndim 32 | slices[axis] = slice(0, None, hop_length) 33 | x = xw[tuple(slices)] 34 | 35 | # Calculate power 36 | power = np.mean(np.abs(x) ** 2, axis=-2, keepdims=True) 37 | 38 | return np.sqrt(power) 39 | 40 | 41 | class Slicer: 42 | def __init__(self, 43 | sr: int, 44 | threshold: float = -40., 45 | min_length: int = 5000, 46 | min_interval: int = 300, 47 | hop_size: int = 20, 48 | max_sil_kept: int = 5000): 49 | if not min_length >= min_interval >= hop_size: 50 | raise ValueError('The following condition must be satisfied: min_length >= min_interval >= hop_size') 51 | if not max_sil_kept >= hop_size: 52 | raise ValueError('The following condition must be satisfied: max_sil_kept >= hop_size') 53 | min_interval = sr * min_interval / 1000 54 | self.sr = sr 55 | self.threshold = 10 ** (threshold / 20.) 
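# Note: the threshold is given in dBFS and stored as a linear RMS amplitude; e.g. the
# default -40 dB becomes 10 ** (-40 / 20) = 0.01, so frames whose RMS falls below 0.01
# are treated as silence by slice(). The durations below are likewise converted from
# milliseconds into samples or hop-sized frame counts.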
56 | self.hop_size = round(sr * hop_size / 1000) 57 | self.win_size = min(round(min_interval), 4 * self.hop_size) 58 | self.min_length = round(sr * min_length / 1000 / self.hop_size) 59 | self.min_interval = round(min_interval / self.hop_size) 60 | self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size) 61 | 62 | def _apply_slice(self, waveform, begin, end): 63 | chunk = { 64 | 'offset': begin * self.hop_size / self.sr 65 | } 66 | if len(waveform.shape) > 1: 67 | chunk['waveform'] = waveform[:, begin * self.hop_size: min(waveform.shape[1], end * self.hop_size)] 68 | else: 69 | chunk['waveform'] = waveform[begin * self.hop_size: min(waveform.shape[0], end * self.hop_size)] 70 | return chunk 71 | 72 | # @timeit 73 | def slice(self, waveform): 74 | if len(waveform.shape) > 1: 75 | samples = waveform.mean(axis=0) 76 | else: 77 | samples = waveform 78 | if (samples.shape[0] + self.hop_size - 1) // self.hop_size <= self.min_length: 79 | return [{'offset': 0, 'waveform': waveform}] 80 | rms_list = get_rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0) 81 | sil_tags = [] 82 | silence_start = None 83 | clip_start = 0 84 | for i, rms in enumerate(rms_list): 85 | # Keep looping while frame is silent. 86 | if rms < self.threshold: 87 | # Record start of silent frames. 88 | if silence_start is None: 89 | silence_start = i 90 | continue 91 | # Keep looping while frame is not silent and silence start has not been recorded. 92 | if silence_start is None: 93 | continue 94 | # Clear recorded silence start if interval is not enough or clip is too short 95 | is_leading_silence = silence_start == 0 and i > self.max_sil_kept 96 | need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length 97 | if not is_leading_silence and not need_slice_middle: 98 | silence_start = None 99 | continue 100 | # Need slicing. Record the range of silent frames to be removed. 101 | if i - silence_start <= self.max_sil_kept: 102 | pos = rms_list[silence_start: i + 1].argmin() + silence_start 103 | if silence_start == 0: 104 | sil_tags.append((0, pos)) 105 | else: 106 | sil_tags.append((pos, pos)) 107 | clip_start = pos 108 | elif i - silence_start <= self.max_sil_kept * 2: 109 | pos = rms_list[i - self.max_sil_kept: silence_start + self.max_sil_kept + 1].argmin() 110 | pos += i - self.max_sil_kept 111 | pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start 112 | pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept 113 | if silence_start == 0: 114 | sil_tags.append((0, pos_r)) 115 | clip_start = pos_r 116 | else: 117 | sil_tags.append((min(pos_l, pos), max(pos_r, pos))) 118 | clip_start = max(pos_r, pos) 119 | else: 120 | pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start 121 | pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept 122 | if silence_start == 0: 123 | sil_tags.append((0, pos_r)) 124 | else: 125 | sil_tags.append((pos_l, pos_r)) 126 | clip_start = pos_r 127 | silence_start = None 128 | # Deal with trailing silence. 129 | total_frames = rms_list.shape[0] 130 | if silence_start is not None and total_frames - silence_start >= self.min_interval: 131 | silence_end = min(total_frames, silence_start + self.max_sil_kept) 132 | pos = rms_list[silence_start: silence_end + 1].argmin() + silence_start 133 | sil_tags.append((pos, total_frames + 1)) 134 | # Apply and return slices. 
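# At this point each entry in sil_tags is a (start, end) pair of frame indices marking
# silence to cut out; the kept chunks are the spans before, between and after these
# pairs, and _apply_slice() converts frame indices back into sample offsets via hop_size.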
135 | if len(sil_tags) == 0: 136 | return [{'offset': 0, 'waveform': waveform}] 137 | else: 138 | chunks = [] 139 | if sil_tags[0][0] > 0: 140 | chunks.append(self._apply_slice(waveform, 0, sil_tags[0][0])) 141 | for i in range(len(sil_tags) - 1): 142 | chunks.append(self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0])) 143 | if sil_tags[-1][1] < total_frames: 144 | chunks.append(self._apply_slice(waveform, sil_tags[-1][1], total_frames)) 145 | return chunks 146 | -------------------------------------------------------------------------------- /utils/training_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import re 3 | from copy import deepcopy 4 | from pathlib import Path 5 | from typing import Dict 6 | 7 | import lightning.pytorch as pl 8 | import numpy as np 9 | import torch 10 | from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar 11 | from lightning.pytorch.utilities.rank_zero import rank_zero_info 12 | from torch.optim.lr_scheduler import LambdaLR 13 | from torch.utils.data.distributed import Sampler 14 | 15 | import utils 16 | 17 | 18 | # ==========LR schedulers========== 19 | 20 | class WarmupCosineSchedule(LambdaLR): 21 | """ Linear warmup and then cosine decay. 22 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. 23 | Decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps following a cosine curve. 24 | If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup. 25 | `eta_min` (default=0.0) corresponds to the minimum learning rate reached by the scheduler. 26 | """ 27 | 28 | def __init__(self, optimizer, warmup_steps, t_total, eta_min=0.0, cycles=.5, last_epoch=-1): 29 | self.warmup_steps = warmup_steps 30 | self.t_total = t_total 31 | self.eta_min = eta_min 32 | self.cycles = cycles 33 | super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 34 | 35 | def lr_lambda(self, step): 36 | if step < self.warmup_steps: 37 | return step / max(1.0, self.warmup_steps) 38 | # progress after warmup 39 | progress = (step - self.warmup_steps) / max(1, self.t_total - self.warmup_steps) 40 | return max(self.eta_min, 0.5 * (1. 
+ math.cos(math.pi * self.cycles * 2.0 * progress))) 41 | 42 | 43 | # ==========Torch samplers========== 44 | 45 | class DsBatchSampler(Sampler): 46 | def __init__(self, dataset, max_batch_frames, max_batch_size, sub_indices=None, 47 | num_replicas=None, rank=None, frame_count_grid=200, 48 | required_batch_count_multiple=1, batch_by_size=True, sort_by_similar_size=True, 49 | shuffle_sample=False, shuffle_batch=False, seed=0, drop_last=False) -> None: 50 | self.dataset = dataset 51 | self.max_batch_frames = max_batch_frames 52 | self.max_batch_size = max_batch_size 53 | self.sub_indices = sub_indices 54 | self.num_replicas = num_replicas 55 | self.rank = rank 56 | self.frame_count_grid = frame_count_grid 57 | self.required_batch_count_multiple = required_batch_count_multiple 58 | self.batch_by_size = batch_by_size 59 | self.sort_by_similar_size = sort_by_similar_size 60 | self.shuffle_sample = shuffle_sample 61 | self.shuffle_batch = shuffle_batch 62 | self.seed = seed 63 | self.drop_last = drop_last 64 | self.epoch = 0 65 | self.batches = None 66 | self.formed = None 67 | 68 | def __form_batches(self): 69 | if self.formed == self.epoch + self.seed: 70 | return 71 | rng = np.random.default_rng(self.seed + self.epoch) 72 | if self.shuffle_sample: 73 | if self.sub_indices is not None: 74 | rng.shuffle(self.sub_indices) 75 | indices = np.array(self.sub_indices) 76 | else: 77 | indices = rng.permutation(len(self.dataset)) 78 | 79 | if self.sort_by_similar_size: 80 | grid = self.frame_count_grid 81 | assert grid > 0 82 | sizes = (np.round(np.array(self.dataset._sizes)[indices] / grid) * grid).clip(grid, None).astype( 83 | np.int64) 84 | indices = indices[np.argsort(sizes, kind='mergesort')] 85 | 86 | indices = indices.tolist() 87 | else: 88 | indices = self.sub_indices if self.sub_indices is not None else list(range(len(self.dataset))) 89 | 90 | if self.batch_by_size: 91 | batches = utils.batch_by_size( 92 | indices, self.dataset.num_frames, 93 | max_batch_frames=self.max_batch_frames, 94 | max_batch_size=self.max_batch_size 95 | ) 96 | else: 97 | batches = [indices[i:i + self.max_batch_size] for i in range(0, len(indices), self.max_batch_size)] 98 | 99 | floored_total_batch_count = (len(batches) // self.num_replicas) * self.num_replicas 100 | if self.drop_last and len(batches) > floored_total_batch_count: 101 | batches = batches[:floored_total_batch_count] 102 | leftovers = [] 103 | else: 104 | leftovers = (rng.permutation(len(batches) - floored_total_batch_count) + floored_total_batch_count).tolist() 105 | 106 | batch_assignment = rng.permuted( 107 | np.arange(floored_total_batch_count).reshape(-1, self.num_replicas).transpose(), axis=0 108 | )[self.rank].tolist() 109 | floored_batch_count = len(batch_assignment) 110 | ceiled_batch_count = floored_batch_count + (1 if len(leftovers) > 0 else 0) 111 | if self.rank < len(leftovers): 112 | batch_assignment.append(leftovers[self.rank]) 113 | elif len(leftovers) > 0: 114 | batch_assignment.append(batch_assignment[self.epoch % floored_batch_count]) 115 | if self.required_batch_count_multiple > 1 and ceiled_batch_count % self.required_batch_count_multiple != 0: 116 | # batch_assignment = batch_assignment[:((floored_batch_count \ 117 | # // self.required_batch_count_multiple) * self.required_batch_count_multiple)] 118 | ceiled_batch_count = math.ceil( 119 | ceiled_batch_count / self.required_batch_count_multiple) * self.required_batch_count_multiple 120 | for i in range(ceiled_batch_count - len(batch_assignment)): 121 | batch_assignment.append( 122 | 
batch_assignment[(i + self.epoch * self.required_batch_count_multiple) % floored_batch_count]) 123 | 124 | self.batches = [deepcopy(batches[i]) for i in batch_assignment] 125 | 126 | if self.shuffle_batch: 127 | rng.shuffle(self.batches) 128 | 129 | del indices 130 | del batches 131 | del batch_assignment 132 | 133 | def __iter__(self): 134 | self.__form_batches() 135 | return iter(self.batches) 136 | 137 | def __len__(self): 138 | self.__form_batches() 139 | if self.batches is None: 140 | raise RuntimeError("Batches are not initialized. Call __form_batches first.") 141 | return len(self.batches) 142 | 143 | def set_epoch(self, epoch): 144 | self.epoch = epoch 145 | 146 | 147 | class DsEvalBatchSampler(Sampler): 148 | def __init__(self, dataset, max_batch_frames, max_batch_size, rank=None, batch_by_size=True) -> None: 149 | self.dataset = dataset 150 | self.max_batch_frames = max_batch_frames 151 | self.max_batch_size = max_batch_size 152 | self.rank = rank 153 | self.batch_by_size = batch_by_size 154 | self.batches = None 155 | self.batch_size = max_batch_size 156 | self.drop_last = False 157 | 158 | if self.rank == 0: 159 | indices = list(range(len(self.dataset))) 160 | if self.batch_by_size: 161 | self.batches = utils.batch_by_size( 162 | indices, self.dataset.num_frames, 163 | max_batch_frames=self.max_batch_frames, max_batch_size=self.max_batch_size 164 | ) 165 | else: 166 | self.batches = [ 167 | indices[i:i + self.max_batch_size] 168 | for i in range(0, len(indices), self.max_batch_size) 169 | ] 170 | else: 171 | self.batches = [[0]] 172 | 173 | def __iter__(self): 174 | return iter(self.batches) 175 | 176 | def __len__(self): 177 | return len(self.batches) 178 | 179 | 180 | # ==========PL related========== 181 | 182 | class DsModelCheckpoint(ModelCheckpoint): 183 | def __init__( 184 | self, 185 | *args, 186 | permanent_ckpt_start, 187 | permanent_ckpt_interval, 188 | **kwargs 189 | ): 190 | super().__init__(*args, **kwargs) 191 | self.permanent_ckpt_start = permanent_ckpt_start or 0 192 | self.permanent_ckpt_interval = permanent_ckpt_interval or 0 193 | self.enable_permanent_ckpt = self.permanent_ckpt_start > 0 and self.permanent_ckpt_interval > 0 194 | 195 | self._verbose = self.verbose 196 | self.verbose = False 197 | 198 | def state_dict(self): 199 | ret = super().state_dict() 200 | ret.pop('dirpath') 201 | return ret 202 | 203 | def load_state_dict(self, state_dict) -> None: 204 | super().load_state_dict(state_dict) 205 | 206 | def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: 207 | if trainer.lightning_module.skip_immediate_ckpt_save: 208 | trainer.lightning_module.skip_immediate_ckpt_save = False 209 | return 210 | self.last_val_step = trainer.global_step 211 | super().on_validation_end(trainer, pl_module) 212 | 213 | def _update_best_and_save( 214 | self, current: torch.Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, torch.Tensor] 215 | ) -> None: 216 | k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k 217 | 218 | del_filepath = None 219 | _op = max if self.mode == "min" else min 220 | while len(self.best_k_models) > k and k > 0: 221 | self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get) # type: ignore[arg-type] 222 | self.kth_value = self.best_k_models[self.kth_best_model_path] 223 | 224 | del_filepath = self.kth_best_model_path 225 | self.best_k_models.pop(del_filepath) 226 | filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer, 
del_filepath) 227 | if del_filepath is not None and filepath != del_filepath: 228 | self._remove_checkpoint(trainer, del_filepath) 229 | 230 | if len(self.best_k_models) == k and k > 0: 231 | self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get) # type: ignore[arg-type] 232 | self.kth_value = self.best_k_models[self.kth_best_model_path] 233 | 234 | super()._update_best_and_save(current, trainer, monitor_candidates) 235 | 236 | def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None: 237 | filepath = (Path(self.dirpath) / Path(filepath).name).resolve() 238 | super()._save_checkpoint(trainer, str(filepath)) 239 | if self._verbose: 240 | relative_path = filepath.relative_to(Path('.').resolve()) 241 | rank_zero_info(f'Checkpoint {relative_path} saved.') 242 | 243 | def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str): 244 | filepath = (Path(self.dirpath) / Path(filepath).name).resolve() 245 | relative_path = filepath.relative_to(Path('.').resolve()) 246 | search = re.search(r'steps_\d+', relative_path.stem) 247 | if search: 248 | step = int(search.group(0)[6:]) 249 | if self.enable_permanent_ckpt and \ 250 | step >= self.permanent_ckpt_start and \ 251 | (step - self.permanent_ckpt_start) % self.permanent_ckpt_interval == 0: 252 | rank_zero_info(f'Checkpoint {relative_path} is now permanent.') 253 | return 254 | super()._remove_checkpoint(trainer, filepath) 255 | if self._verbose: 256 | rank_zero_info(f'Removed checkpoint {relative_path}.') 257 | 258 | 259 | def get_latest_checkpoint_path(work_dir): 260 | if not isinstance(work_dir, Path): 261 | work_dir = Path(work_dir) 262 | if not work_dir.exists(): 263 | return None 264 | 265 | last_step = -1 266 | last_ckpt_name = None 267 | 268 | for ckpt in work_dir.glob('model_ckpt_steps_*.ckpt'): 269 | search = re.search(r'steps_\d+', ckpt.name) 270 | if search: 271 | step = int(search.group(0)[6:]) 272 | if step > last_step: 273 | last_step = step 274 | last_ckpt_name = str(ckpt) 275 | 276 | return last_ckpt_name if last_ckpt_name is not None else None 277 | 278 | 279 | class DsTQDMProgressBar(TQDMProgressBar): 280 | def __init__(self, refresh_rate: int = 1, process_position: int = 0, show_steps: bool = True): 281 | super().__init__(refresh_rate, process_position) 282 | self.show_steps = show_steps 283 | 284 | def get_metrics(self, trainer, model): 285 | items = super().get_metrics(trainer, model) 286 | if 'batch_size' in items: 287 | items['batch_size'] = int(items['batch_size']) 288 | if self.show_steps: 289 | items['steps'] = str(trainer.global_step) 290 | for k, v in items.items(): 291 | if isinstance(v, float): 292 | if np.isnan(v): 293 | items[k] = 'nan' 294 | elif 0.001 <= v < 10: 295 | items[k] = np.format_float_positional(v, unique=True, precision=5, trim='-') 296 | elif 0.00001 <= v < 0.001: 297 | if len(np.format_float_positional(v, unique=True, precision=8, trim='-')) > 8: 298 | items[k] = np.format_float_scientific(v, precision=3, unique=True, min_digits=2, trim='-') 299 | else: 300 | items[k] = np.format_float_positional(v, unique=True, precision=5, trim='-') 301 | elif v < 0.00001: 302 | items[k] = np.format_float_scientific(v, precision=3, unique=True, min_digits=2, trim='-') 303 | items.pop("v_num", None) 304 | return items 305 | 306 | 307 | def get_strategy(strategy): 308 | if strategy['name'] == 'auto': 309 | return 'auto' 310 | 311 | from lightning.pytorch.strategies import StrategyRegistry 312 | if strategy['name'] not in StrategyRegistry: 313 | available_names = ", 
".join(sorted(StrategyRegistry.keys())) or "none" 314 | raise ValueError(f"Invalid strategy name {strategy['name']}. Available names: {available_names}") 315 | 316 | data = StrategyRegistry[strategy['name']] 317 | params = data['init_params'] 318 | params.update({k: v for k, v in strategy.items() if k != 'name'}) 319 | return data['strategy'](**utils.filter_kwargs(params, data['strategy'])) 320 | -------------------------------------------------------------------------------- /webui.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | import pathlib 4 | import time 5 | from typing import Dict, Tuple 6 | 7 | import click 8 | import gradio as gr 9 | import librosa 10 | import yaml 11 | 12 | import inference 13 | from inference import BaseInference 14 | from utils.infer_utils import build_midi_file 15 | from utils.slicer2 import Slicer 16 | 17 | _work_dir: pathlib.Path = None 18 | _infer_instances: Dict[str, Tuple[BaseInference, dict]] = {} # dict mapping model_rel_path to (infer_ins, config) 19 | 20 | 21 | def infer(model_rel_path, input_audio_path, tempo_value): 22 | if not model_rel_path or not input_audio_path or tempo_value is None: 23 | return None, "Error: required inputs not specified." 24 | if model_rel_path not in _infer_instances: 25 | model_path = _work_dir / model_rel_path 26 | with open(model_path.with_name('config.yaml'), 'r', encoding='utf8') as f: 27 | config = yaml.safe_load(f) 28 | infer_cls = inference.task_inference_mapping[config['task_cls']] 29 | 30 | pkg = ".".join(infer_cls.split(".")[:-1]) 31 | cls_name = infer_cls.split(".")[-1] 32 | infer_cls = getattr(importlib.import_module(pkg), cls_name) 33 | assert issubclass(infer_cls, inference.BaseInference), \ 34 | f'Binarizer class {infer_cls} is not a subclass of {inference.BaseInference}.' 35 | infer_ins = infer_cls(config=config, model_path=model_path) 36 | print(f"Initialized: {infer_ins}") 37 | _infer_instances[model_rel_path] = (infer_ins, config) 38 | else: 39 | infer_ins, config = _infer_instances[model_rel_path] 40 | 41 | input_audio_path = pathlib.Path(input_audio_path) 42 | total_duration = librosa.get_duration(filename=input_audio_path) 43 | if total_duration > 20 * 60: # 20 minutes 44 | return None, f"Error: the input audio is too long (>= 20 minutes)." 
45 | 46 | try: 47 | waveform, _ = librosa.load(input_audio_path, sr=config['audio_sample_rate'], mono=True) 48 | except: 49 | return None, f"Error: unsupported or corrupt file format: {input_audio_path.name}" 50 | 51 | start_time = time.time() 52 | slicer = Slicer(sr=config['audio_sample_rate'], max_sil_kept=1000) 53 | chunks = slicer.slice(waveform) 54 | midis = infer_ins.infer([c['waveform'] for c in chunks]) 55 | infer_time = time.time() - start_time 56 | rtf = infer_time / total_duration 57 | print(f'RTF: {rtf}') 58 | 59 | midi_file = build_midi_file([c['offset'] for c in chunks], midis, tempo=tempo_value) 60 | 61 | output_midi_path = input_audio_path.with_suffix('.mid') 62 | midi_file.save(output_midi_path) 63 | os.remove(input_audio_path) 64 | 65 | return output_midi_path, f"Cost {round(infer_time, 2)} s, RTF: {round(rtf, 3)}" 66 | 67 | 68 | @click.command(help='Launch the web UI for inference') 69 | @click.option('--port', type=int, default=7860, help='Server port') 70 | @click.option('--addr', type=str, required=False, help='Server address') 71 | @click.option('--work_dir', type=str, required=False, help='Directory to read the experiments') 72 | def webui(port, work_dir, addr): 73 | if work_dir is None: 74 | work_dir = pathlib.Path(__file__).with_name('experiments') 75 | else: 76 | work_dir = pathlib.Path(work_dir) 77 | assert work_dir.is_dir(), f'{work_dir} is not a directory.' 78 | global _work_dir 79 | _work_dir = work_dir 80 | choices = [ 81 | p.relative_to(work_dir).as_posix() 82 | for p in work_dir.rglob('*.ckpt') 83 | ] 84 | if len(choices) == 0: 85 | raise FileNotFoundError(f'No checkpoints found in {work_dir}.') 86 | iface = gr.Interface( 87 | title="SOME: Singing-Oriented MIDI Extractor", 88 | description="Submit an audio file and download the extracted MIDI file.", 89 | theme="default", 90 | fn=infer, 91 | inputs=[ 92 | gr.components.Dropdown( 93 | label="Model Checkpoint", choices=choices, value=choices[0], 94 | multiselect=False, allow_custom_value=False 95 | ), 96 | gr.components.Audio(label="Input Audio File", type="filepath"), 97 | gr.components.Number(label='Tempo Value', minimum=20, maximum=200, value=120), 98 | ], 99 | outputs=[ 100 | gr.components.File(label="Output MIDI File", file_types=['.mid']), 101 | gr.components.Label(label="Inference Statistics"), 102 | ] 103 | ) 104 | iface.queue(concurrency_count=10) 105 | iface.launch(server_port=port, server_name=addr) 106 | 107 | 108 | if __name__ == "__main__": 109 | webui() 110 | --------------------------------------------------------------------------------
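The sketch below shows how the components listed above can be combined for headless inference, mirroring the steps of infer() in webui.py: read the config.yaml stored next to a checkpoint, resolve the inference class through inference.task_inference_mapping, slice the audio with Slicer, and merge the per-chunk results with build_midi_file. The checkpoint and audio paths (experiments/my_model/model.ckpt, song.wav) and the fixed tempo of 120 BPM are placeholder assumptions; treat this as a minimal illustration of the call flow, not a replacement for infer.py or batch_infer.py.

import importlib
import pathlib

import librosa
import yaml

import inference
from utils.infer_utils import build_midi_file
from utils.slicer2 import Slicer

# Placeholder paths -- point these at a real checkpoint and audio file.
model_path = pathlib.Path('experiments/my_model/model.ckpt')
audio_path = pathlib.Path('song.wav')

# Each checkpoint is expected to have a config.yaml next to it, as in webui.py.
with open(model_path.with_name('config.yaml'), 'r', encoding='utf8') as f:
    config = yaml.safe_load(f)

# Resolve and instantiate the inference class registered for this task.
module_name, cls_name = inference.task_inference_mapping[config['task_cls']].rsplit('.', 1)
infer_cls = getattr(importlib.import_module(module_name), cls_name)
infer_ins = infer_cls(config=config, model_path=model_path)

# Load the audio at the model's sample rate, cut it at silent regions,
# extract notes from each chunk, then assemble one MIDI file using the
# per-chunk time offsets reported by the slicer.
waveform, _ = librosa.load(audio_path, sr=config['audio_sample_rate'], mono=True)
chunks = Slicer(sr=config['audio_sample_rate'], max_sil_kept=1000).slice(waveform)
midis = infer_ins.infer([c['waveform'] for c in chunks])
midi_file = build_midi_file([c['offset'] for c in chunks], midis, tempo=120)
midi_file.save(audio_path.with_suffix('.mid'))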