├── src ├── __init__.py ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── audio.cpython-310.pyc │ │ ├── audio.cpython-38.pyc │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── datasets.cpython-310.pyc │ │ └── datasets.cpython-38.pyc │ ├── audio.py │ └── datasets.py ├── ddp │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── distrib.cpython-310.pyc │ │ ├── distrib.cpython-38.pyc │ │ ├── executor.cpython-38.pyc │ │ ├── __init__.cpython-310.pyc │ │ └── executor.cpython-310.pyc │ ├── executor.py │ └── distrib.py ├── models │ ├── __init__.py │ ├── __pycache__ │ │ ├── aero.cpython-310.pyc │ │ ├── aero.cpython-38.pyc │ │ ├── snake.cpython-38.pyc │ │ ├── spec.cpython-310.pyc │ │ ├── spec.cpython-38.pyc │ │ ├── utils.cpython-38.pyc │ │ ├── demucs.cpython-310.pyc │ │ ├── modules.cpython-38.pyc │ │ ├── seanet.cpython-310.pyc │ │ ├── seanet.cpython-38.pyc │ │ ├── snake.cpython-310.pyc │ │ ├── utils.cpython-310.pyc │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── modules.cpython-310.pyc │ │ ├── stft_loss.cpython-310.pyc │ │ ├── stft_loss.cpython-38.pyc │ │ ├── modelFactory.cpython-38.pyc │ │ ├── discriminators.cpython-310.pyc │ │ ├── discriminators.cpython-38.pyc │ │ └── modelFactory.cpython-310.pyc │ ├── test_demucs.ipynb │ ├── modelFactory.py │ ├── spec.py │ ├── utils.py │ ├── snake.py │ ├── stft_loss.py │ ├── seanet.py │ ├── discriminators.py │ └── modules.py ├── __pycache__ │ ├── enhance.cpython-38.pyc │ ├── metrics.cpython-38.pyc │ ├── solver.cpython-310.pyc │ ├── solver.cpython-38.pyc │ ├── utils.cpython-310.pyc │ ├── utils.cpython-38.pyc │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-38.pyc │ ├── enhance.cpython-310.pyc │ ├── evaluate.cpython-310.pyc │ ├── evaluate.cpython-38.pyc │ ├── metrics.cpython-310.pyc │ ├── wandb_logger.cpython-38.pyc │ ├── wandb_logger.cpython-310.pyc │ ├── model_serializer.cpython-310.pyc │ └── model_serializer.cpython-38.pyc ├── model_serializer.py ├── enhance.py 
├── metrics.py ├── evaluate.py ├── wandb_logger.py └── utils.py ├── outputs ├── chopin-11-44 │ └── aeromamba │ │ ├── trainer.log │ │ └── .hydra │ │ ├── overrides.yaml │ │ ├── config.yaml │ │ └── hydra.yaml └── musdb-mixture-11-44 │ └── aeromamba │ ├── wandb │ ├── run-20241029_105349-0ilrdhbg │ │ ├── run-0ilrdhbg.wandb │ │ ├── logs │ │ │ ├── debug-internal.log │ │ │ ├── debug-core.log │ │ │ └── debug.log │ │ └── files │ │ │ ├── wandb-metadata.json │ │ │ ├── output.log │ │ │ └── requirements.txt │ ├── run-20241029_105445-p0dnbwer │ │ ├── files │ │ │ ├── wandb-summary.json │ │ │ ├── wandb-metadata.json │ │ │ ├── config.yaml │ │ │ ├── requirements.txt │ │ │ └── output.log │ │ ├── run-p0dnbwer.wandb │ │ └── logs │ │ │ ├── debug-core.log │ │ │ ├── debug-internal.log │ │ │ └── debug.log │ ├── debug-internal.log │ └── debug.log │ ├── .hydra │ ├── overrides.yaml │ ├── config.yaml │ └── hydra.yaml │ └── trainer.log ├── conf ├── dset │ ├── chopin-11-44.yaml │ ├── chopin-11-44-HQ.yaml │ ├── chopin-11-44-one.yaml │ └── musdb-mixture-11-44.yaml ├── experiment │ ├── aero.yaml │ └── aeromamba.yaml └── main_config.yaml ├── requirements.txt ├── predict_batch.sh ├── data_prep ├── resample_data.py └── create_meta_files.py ├── test.py ├── predict.py ├── predict_batch_with_ola.py ├── README.md ├── train.py └── LICENSE /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/ddp/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /outputs/chopin-11-44/aeromamba/trainer.log: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /outputs/musdb-mixture-11-44/aeromamba/wandb/run-20241029_105349-0ilrdhbg/run-0ilrdhbg.wandb: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /outputs/musdb-mixture-11-44/aeromamba/wandb/run-20241029_105445-p0dnbwer/files/wandb-summary.json: -------------------------------------------------------------------------------- 1 | {"_wandb":{"runtime":43}} -------------------------------------------------------------------------------- /conf/dset/chopin-11-44.yaml: -------------------------------------------------------------------------------- 1 | # @package dset 2 | name: chopin-11-44 3 | train: egs/chopin-11-44/tr 4 | valid: 5 | test: egs/chopin-11-44/tt -------------------------------------------------------------------------------- /src/__pycache__/enhance.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/__pycache__/enhance.cpython-38.pyc -------------------------------------------------------------------------------- /src/__pycache__/metrics.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/__pycache__/metrics.cpython-38.pyc -------------------------------------------------------------------------------- /src/__pycache__/solver.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/__pycache__/solver.cpython-310.pyc -------------------------------------------------------------------------------- /src/__pycache__/solver.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/__pycache__/solver.cpython-38.pyc -------------------------------------------------------------------------------- /src/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /src/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /conf/dset/chopin-11-44-HQ.yaml: -------------------------------------------------------------------------------- 1 | # @package dset 2 | name: chopin-11-44-HQ 3 | train: egs/chopin-11-44-HQ/tr 4 | valid: 5 | test: egs/chopin-11-44-HQ/tt -------------------------------------------------------------------------------- /src/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /src/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/__pycache__/enhance.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/__pycache__/enhance.cpython-310.pyc -------------------------------------------------------------------------------- /src/__pycache__/evaluate.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/__pycache__/evaluate.cpython-310.pyc -------------------------------------------------------------------------------- /src/__pycache__/evaluate.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/__pycache__/evaluate.cpython-38.pyc -------------------------------------------------------------------------------- /src/__pycache__/metrics.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/__pycache__/metrics.cpython-310.pyc -------------------------------------------------------------------------------- /conf/dset/chopin-11-44-one.yaml: -------------------------------------------------------------------------------- 1 | # @package dset 2 | name: chopin-11-44-one 3 | train: egs/chopin-11-44-one/tr 4 | valid: 5 | test: egs/chopin-11-44-one/tt -------------------------------------------------------------------------------- /src/__pycache__/wandb_logger.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/__pycache__/wandb_logger.cpython-38.pyc -------------------------------------------------------------------------------- /src/data/__pycache__/audio.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/data/__pycache__/audio.cpython-310.pyc -------------------------------------------------------------------------------- /src/data/__pycache__/audio.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/data/__pycache__/audio.cpython-38.pyc -------------------------------------------------------------------------------- /src/ddp/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/ddp/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/ddp/__pycache__/distrib.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/ddp/__pycache__/distrib.cpython-310.pyc -------------------------------------------------------------------------------- /src/ddp/__pycache__/distrib.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/ddp/__pycache__/distrib.cpython-38.pyc -------------------------------------------------------------------------------- /src/ddp/__pycache__/executor.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/ddp/__pycache__/executor.cpython-38.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/aero.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/models/__pycache__/aero.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/aero.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/models/__pycache__/aero.cpython-38.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/snake.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/models/__pycache__/snake.cpython-38.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/spec.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/models/__pycache__/spec.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/spec.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/models/__pycache__/spec.cpython-38.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/models/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /src/__pycache__/wandb_logger.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/__pycache__/wandb_logger.cpython-310.pyc -------------------------------------------------------------------------------- /src/data/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/data/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /src/data/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/data/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/data/__pycache__/datasets.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/data/__pycache__/datasets.cpython-310.pyc -------------------------------------------------------------------------------- /src/data/__pycache__/datasets.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/data/__pycache__/datasets.cpython-38.pyc -------------------------------------------------------------------------------- /src/ddp/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/ddp/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /src/ddp/__pycache__/executor.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/ddp/__pycache__/executor.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/demucs.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/models/__pycache__/demucs.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/modules.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/models/__pycache__/modules.cpython-38.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/seanet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/models/__pycache__/seanet.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/seanet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/models/__pycache__/seanet.cpython-38.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/snake.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/models/__pycache__/snake.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/models/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /src/__pycache__/model_serializer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/__pycache__/model_serializer.cpython-310.pyc -------------------------------------------------------------------------------- /src/__pycache__/model_serializer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/__pycache__/model_serializer.cpython-38.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/models/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/modules.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/models/__pycache__/modules.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/stft_loss.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/models/__pycache__/stft_loss.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/stft_loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/models/__pycache__/stft_loss.cpython-38.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/modelFactory.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/models/__pycache__/modelFactory.cpython-38.pyc -------------------------------------------------------------------------------- /conf/dset/musdb-mixture-11-44.yaml: -------------------------------------------------------------------------------- 1 | # @package dset 2 | name: musdb-mixture-11-44 3 | train: egs/musdb-mixture-11-44/tr 4 | valid: 5 | test: egs/musdb-mixture-11-44/tt 6 | 7 | -------------------------------------------------------------------------------- /outputs/musdb-mixture-11-44/aeromamba/.hydra/overrides.yaml: -------------------------------------------------------------------------------- 1 | - checkpoint_file=/home/wallace.abreu/Mestrado/aeromamba-lamir/outputs/chopin-11-44-one/aeromamba/checkpoint.th 2 | -------------------------------------------------------------------------------- 
/src/models/__pycache__/discriminators.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/models/__pycache__/discriminators.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/discriminators.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/models/__pycache__/discriminators.cpython-38.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/modelFactory.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/src/models/__pycache__/modelFactory.cpython-310.pyc -------------------------------------------------------------------------------- /outputs/musdb-mixture-11-44/aeromamba/wandb/run-20241029_105445-p0dnbwer/run-p0dnbwer.wandb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aeromamba-super-resolution/aeromamba/HEAD/outputs/musdb-mixture-11-44/aeromamba/wandb/run-20241029_105445-p0dnbwer/run-p0dnbwer.wandb -------------------------------------------------------------------------------- /outputs/chopin-11-44/aeromamba/.hydra/overrides.yaml: -------------------------------------------------------------------------------- 1 | - dset=chopin-11-44 2 | - experiment=aeromamba 3 | - +filename=/home/wallace.abreu/Mestrado/aeromamba-lamir/ds_datasets/chopin/test/V.ASHKENAZYTrack11.wav 4 | - +output=/home/wallace.abreu/Mestrado/aeromamba-lamir/test 5 | -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | colorlog==5.0.1 2 | hydra-colorlog==1.1.0 3 | hydra-core==1.1.1 4 | hyperlink==17.3.1 5 | HyperPyYAML==1.0.0 6 | matplotlib 7 | numpy==1.26.4 8 | opencv-python==4.9.0.80 9 | python==3.10.0 10 | soundfile 11 | sox 12 | torch==1.12.1+cu113 13 | torchvision==0.13.1+cu113 14 | torchaudio==0.12.1 15 | tqdm 16 | wandb -------------------------------------------------------------------------------- /src/models/test_demucs.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [] 9 | } 10 | ], 11 | "metadata": { 12 | "kernelspec": { 13 | "display_name": "mamba", 14 | "language": "python", 15 | "name": "python3" 16 | }, 17 | "language_info": { 18 | "name": "python", 19 | "version": "3.10.0" 20 | } 21 | }, 22 | "nbformat": 4, 23 | "nbformat_minor": 2 24 | } 25 | -------------------------------------------------------------------------------- /outputs/musdb-mixture-11-44/aeromamba/trainer.log: -------------------------------------------------------------------------------- 1 | [2024-10-29 10:54:45,319][__main__][INFO] - For logs, checkpoints and samples check /nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/outputs/musdb-mixture-11-44/aeromamba 2 | [2024-10-29 10:54:45,320][src.wandb_logger][INFO] - current path: /nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/outputs/musdb-mixture-11-44/aeromamba, rank: None 3 | [2024-10-29 10:54:55,262][__main__][INFO] - Loading model aero from last state. 
4 | -------------------------------------------------------------------------------- /outputs/musdb-mixture-11-44/aeromamba/wandb/run-20241029_105349-0ilrdhbg/logs/debug-internal.log: -------------------------------------------------------------------------------- 1 | {"time":"2024-10-29T10:53:49.374481967-03:00","level":"INFO","msg":"using version","core version":"0.18.5"} 2 | {"time":"2024-10-29T10:53:49.374506239-03:00","level":"INFO","msg":"created symlink","path":"/nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/outputs/musdb-mixture-11-44/aeromamba/wandb/run-20241029_105349-0ilrdhbg/logs/debug-core.log"} 3 | {"time":"2024-10-29T10:53:49.595564208-03:00","level":"INFO","msg":"created new stream","id":"0ilrdhbg"} 4 | {"time":"2024-10-29T10:53:49.595638337-03:00","level":"INFO","msg":"stream: started","id":"0ilrdhbg"} 5 | {"time":"2024-10-29T10:53:49.59583599-03:00","level":"INFO","msg":"sender: started","stream_id":"0ilrdhbg"} 6 | {"time":"2024-10-29T10:53:49.595748251-03:00","level":"INFO","msg":"writer: Do: started","stream_id":{"value":"0ilrdhbg"}} 7 | {"time":"2024-10-29T10:53:49.59583588-03:00","level":"INFO","msg":"handler: started","stream_id":{"value":"0ilrdhbg"}} 8 | {"time":"2024-10-29T10:53:50.274006895-03:00","level":"INFO","msg":"Starting system monitor"} 9 | -------------------------------------------------------------------------------- /outputs/musdb-mixture-11-44/aeromamba/wandb/run-20241029_105349-0ilrdhbg/logs/debug-core.log: -------------------------------------------------------------------------------- 1 | {"time":"2024-10-29T10:53:48.239410953-03:00","level":"INFO","msg":"started logging, with flags","port-filename":"/tmp/tmpqozzfhjx/port-296384.txt","pid":296384,"debug":false,"disable-analytics":false} 2 | {"time":"2024-10-29T10:53:48.239440376-03:00","level":"INFO","msg":"FeatureState","shutdownOnParentExitEnabled":false} 3 | {"time":"2024-10-29T10:53:48.24054031-03:00","level":"INFO","msg":"Will exit if parent process 
dies.","ppid":296384} 4 | {"time":"2024-10-29T10:53:48.240534395-03:00","level":"INFO","msg":"server is running","addr":{"IP":"127.0.0.1","Port":38075,"Zone":""}} 5 | {"time":"2024-10-29T10:53:48.39099901-03:00","level":"INFO","msg":"connection: ManageConnectionData: new connection created","id":"127.0.0.1:34394"} 6 | {"time":"2024-10-29T10:53:49.359181719-03:00","level":"INFO","msg":"handleInformInit: received","streamId":"0ilrdhbg","id":"127.0.0.1:34394"} 7 | {"time":"2024-10-29T10:53:49.595661072-03:00","level":"INFO","msg":"handleInformInit: stream started","streamId":"0ilrdhbg","id":"127.0.0.1:34394"} 8 | {"time":"2024-10-29T10:53:52.759559195-03:00","level":"INFO","msg":"Parent process exited, terminating service process."} 9 | -------------------------------------------------------------------------------- /outputs/musdb-mixture-11-44/aeromamba/wandb/run-20241029_105349-0ilrdhbg/files/wandb-metadata.json: -------------------------------------------------------------------------------- 1 | { 2 | "os": "Linux-6.0.12-100.fc35.x86_64-x86_64-with-glibc2.34", 3 | "python": "3.10.0", 4 | "startedAt": "2024-10-29T13:53:49.354098Z", 5 | "program": "/nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/test.py", 6 | "email": "abreu.engcb@poli.ufrj.br", 7 | "root": "/nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/outputs/musdb-mixture-11-44/aeromamba", 8 | "host": "zermatt.smt.ufrj.br", 9 | "username": "wallace.abreu", 10 | "executable": "/home/wallace.abreu/miniconda3/envs/lamir_test/bin/python", 11 | "cpu_count": 4, 12 | "cpu_count_logical": 8, 13 | "gpu": "NVIDIA GeForce RTX 2080 Ti", 14 | "gpu_count": 1, 15 | "disk": { 16 | "/": { 17 | "total": "78675447808", 18 | "used": "59385651200" 19 | } 20 | }, 21 | "memory": { 22 | "total": "33512456192" 23 | }, 24 | "cpu": { 25 | "count": 4, 26 | "countLogical": 8 27 | }, 28 | "gpu_nvidia": [ 29 | { 30 | "name": "NVIDIA GeForce RTX 2080 Ti", 31 | "memoryTotal": "11811160064", 32 | "cudaCores": 4352, 33 | "architecture": 
"Turing" 34 | } 35 | ], 36 | "cudaVersion": "11.8" 37 | } -------------------------------------------------------------------------------- /outputs/musdb-mixture-11-44/aeromamba/wandb/run-20241029_105349-0ilrdhbg/files/output.log: -------------------------------------------------------------------------------- 1 | [2024-10-29 10:53:52,672][__main__][ERROR] - Some error happened 2 | Traceback (most recent call last): 3 | File "/nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/test.py", line 79, in main 4 | _main(args) 5 | File "/nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/test.py", line 72, in _main 6 | run(args) 7 | File "/nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/test.py", line 47, in run 8 | model = _load_model(args) 9 | File "/nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/test.py", line 30, in _load_model 10 | package = torch.load(checkpoint_file, 'cpu') 11 | File "/home/wallace.abreu/miniconda3/envs/lamir_test/lib/python3.10/site-packages/torch/serialization.py", line 699, in load 12 | with _open_file_like(f, 'rb') as opened_file: 13 | File "/home/wallace.abreu/miniconda3/envs/lamir_test/lib/python3.10/site-packages/torch/serialization.py", line 230, in _open_file_like 14 | return _open_file(name_or_buffer, mode) 15 | File "/home/wallace.abreu/miniconda3/envs/lamir_test/lib/python3.10/site-packages/torch/serialization.py", line 211, in __init__ 16 | super(_open_file, self).__init__(open(name, mode)) 17 | FileNotFoundError: [Errno 2] No such file or directory: 'checkpoint.th' 18 | -------------------------------------------------------------------------------- /predict_batch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Check for the number of command-line arguments 4 | if [ "$#" -ne 2 ]; then 5 | echo "Usage: $0 " 6 | exit 1 7 | fi 8 | 9 | # Set the input and output folder paths from command-line arguments 10 | input_folder="$1" 11 | output_folder="$2" 12 | 13 | # Ensure the input 
folder exists 14 | if [ ! -d "$input_folder" ]; then 15 | echo "Input folder does not exist: $input_folder" 16 | exit 1 17 | fi 18 | 19 | # Create the output folder if it doesn't exist 20 | if [ ! -d "$output_folder" ]; then 21 | mkdir -p "$output_folder" 22 | fi 23 | 24 | # Loop through files in the input folder 25 | for file in "$input_folder"/*; do 26 | if [ -f "$file" ]; then 27 | # Extract the base name of the input file (without extension) 28 | base_name=$(basename "$file" .wav) 29 | 30 | # Construct the expected output file name 31 | output_file="${output_folder}/${base_name}_ola.wav" 32 | 33 | # Check if the output file already exists 34 | if [ ! -f "$output_file" ]; then 35 | # Run the Python command with the complete file path (including extension) 36 | python predict.py dset=chopin-11-44-one experiment=aeromamba +filename="$file" +output="$output_folder" 37 | else 38 | echo "Output file already exists: $output_file" 39 | fi 40 | fi 41 | done 42 | -------------------------------------------------------------------------------- /outputs/musdb-mixture-11-44/aeromamba/wandb/run-20241029_105445-p0dnbwer/files/wandb-metadata.json: -------------------------------------------------------------------------------- 1 | { 2 | "os": "Linux-6.0.12-100.fc35.x86_64-x86_64-with-glibc2.34", 3 | "python": "3.10.0", 4 | "startedAt": "2024-10-29T13:54:46.001594Z", 5 | "args": [ 6 | "checkpoint_file=/home/wallace.abreu/Mestrado/aeromamba-lamir/outputs/chopin-11-44-one/aeromamba/checkpoint.th" 7 | ], 8 | "program": "/nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/test.py", 9 | "email": "abreu.engcb@poli.ufrj.br", 10 | "root": "/nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/outputs/musdb-mixture-11-44/aeromamba", 11 | "host": "zermatt.smt.ufrj.br", 12 | "username": "wallace.abreu", 13 | "executable": "/home/wallace.abreu/miniconda3/envs/lamir_test/bin/python", 14 | "cpu_count": 4, 15 | "cpu_count_logical": 8, 16 | "gpu": "NVIDIA GeForce RTX 2080 Ti", 17 | "gpu_count": 1, 18 
| "disk": { 19 | "/": { 20 | "total": "78675447808", 21 | "used": "59385651200" 22 | } 23 | }, 24 | "memory": { 25 | "total": "33512456192" 26 | }, 27 | "cpu": { 28 | "count": 4, 29 | "countLogical": 8 30 | }, 31 | "gpu_nvidia": [ 32 | { 33 | "name": "NVIDIA GeForce RTX 2080 Ti", 34 | "memoryTotal": "11811160064", 35 | "cudaCores": 4352, 36 | "architecture": "Turing" 37 | } 38 | ], 39 | "cudaVersion": "11.8" 40 | } -------------------------------------------------------------------------------- /src/models/modelFactory.py: -------------------------------------------------------------------------------- 1 | from src.models.aero import Aero 2 | from src.models.seanet import Seanet 3 | from src.models.discriminators import Discriminator, MultiPeriodDiscriminator, MultiScaleDiscriminator 4 | 5 | 6 | def get_model(args): 7 | if args.experiment.model == 'aero': 8 | generator = Aero(**args.experiment.aero) 9 | elif args.experiment.model == 'seanet': 10 | generator = Seanet(**args.experiment.seanet) 11 | 12 | models = {'generator': generator} 13 | 14 | if 'adversarial' in args.experiment and args.experiment.adversarial: 15 | if 'msd_melgan' in args.experiment.discriminator_models: 16 | discriminator = Discriminator(**args.experiment.melgan_discriminator) 17 | models.update({'msd_melgan': discriminator}) 18 | if 'msd' in args.experiment.discriminator_models: 19 | msd = MultiScaleDiscriminator(**args.experiment.msd) 20 | models.update({'msd': msd}) 21 | if 'mpd' in args.experiment.discriminator_models: 22 | mpd = MultiPeriodDiscriminator(**args.experiment.mpd) 23 | models.update({'mpd': mpd}) 24 | if 'hifi' in args.experiment.discriminator_models: 25 | mpd = MultiPeriodDiscriminator(**args.experiment.mpd) 26 | msd = MultiScaleDiscriminator(**args.experiment.msd) 27 | models.update({'mpd': mpd, 'msd': msd}) 28 | 29 | return models -------------------------------------------------------------------------------- /src/models/spec.py: 
"""
This code is based on Facebook's HDemucs code: https://github.com/facebookresearch/demucs
"""
"""Convenience wrappers to perform STFT and iSTFT."""

import torch as th


def spectro(x, n_fft=512, hop_length=None, pad=0, win_length=None):
    """Forward STFT of `x`, flattening all leading dims before the transform.

    Args:
        x: real tensor of shape [*, T]; leading dims are restored on output.
        n_fft: base FFT size; the effective size is ``n_fft * (1 + pad)``.
        hop_length: hop between frames, defaults to ``n_fft // 4``.
        pad: zero-padding factor applied to the FFT size.
        win_length: Hann window length, defaults to ``n_fft``.

    Returns:
        complex tensor of shape [*, freqs, frames].
    """
    *other, length = x.shape
    x = x.reshape(-1, length)
    # Bug fix: resolve the default *before* building the window — the original
    # passed the raw argument to th.hann_window, so omitting win_length crashed
    # with th.hann_window(None).
    win_length = win_length or n_fft
    z = th.stft(x,
                n_fft * (1 + pad),
                hop_length or n_fft // 4,
                window=th.hann_window(win_length).to(x),
                win_length=win_length,
                normalized=True,
                center=True,
                return_complex=True,
                pad_mode='reflect')
    _, freqs, frame = z.shape
    return z.view(*other, freqs, frame)


def ispectro(z, hop_length=None, length=None, pad=0, win_length=None):
    """Inverse STFT of `z`, flattening all leading dims before the transform.

    Args:
        z: complex tensor of shape [*, freqs, frames]; the FFT size is
            recovered as ``2 * freqs - 2``.
        hop_length: hop between frames, defaults to ``n_fft // 2``.
        length: if given, trim/pad the output signal to this many samples.
        pad: padding factor used when the spectrogram was produced.
        win_length: Hann window length, defaults to ``n_fft // (1 + pad)``.

    Returns:
        real tensor of shape [*, T].
    """
    *other, freqs, frames = z.shape
    n_fft = 2 * freqs - 2
    z = z.view(-1, freqs, frames)
    win_length = win_length or n_fft // (1 + pad)
    x = th.istft(z,
                 n_fft,
                 hop_length or n_fft // 2,
                 window=th.hann_window(win_length).to(z.real),
                 win_length=win_length,
                 normalized=True,
                 length=length,
                 center=True)
    _, length = x.shape
    return x.view(*other, length)
def unfold(a, kernel_size, stride):
    """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
    with K the kernel size, by extracting frames with the given stride.
    This will pad the input so that `F = ceil(T / K)`.
    see https://github.com/pytorch/pytorch/issues/60466
    """
    *batch_dims, n_samples = a.shape
    n_frames = math.ceil(n_samples / stride)
    # Pad on the right so the final frame is complete.
    padded_len = (n_frames - 1) * stride + kernel_size
    a = F.pad(a, (0, padded_len - n_samples))
    *outer_strides, last_stride = a.stride()
    assert last_stride == 1, 'data should be contiguous'
    # Overlapping frames are exposed as a zero-copy strided view: step
    # `stride` elements between frames, 1 element within a frame.
    frame_strides = [*outer_strides, stride, 1]
    return a.as_strided([*batch_dims, n_frames, kernel_size], frame_strides)
Should be included in training cfg 7 | segment: 4 8 | stride: 4 # in seconds, how much to stride between training examples 9 | pad: true # if training sample is too short, pad it 10 | upsample: false 11 | batch_size: 1 12 | nfft: 512 13 | hop_length: 256 14 | fixed_n_examples: null 15 | power_threshold: 1e-3 16 | 17 | # models related 18 | model: aero 19 | aero: # see aero.py for a detailed description: 20 | in_channels: 1 21 | out_channels: 1 22 | # Channels 23 | channels: 48 24 | growth: 2 25 | # STFT 26 | nfft: ${experiment.nfft} 27 | hop_length: ${experiment.hop_length} 28 | end_iters: 0 29 | cac: true 30 | # Main structure 31 | rewrite: true 32 | hybrid: false 33 | hybrid_old: false 34 | # Frequency Branch 35 | freq_emb: 0.2 36 | emb_scale: 10 37 | emb_smooth: true 38 | # Convolutions 39 | kernel_size: 8 40 | strides: [ 4,4,2,2 ] 41 | context: 1 42 | context_enc: 0 43 | freq_ends: 4 44 | enc_freq_attn: 0 45 | # normalization 46 | norm_starts: 2 47 | norm_groups: 4 48 | # DConv residual branch 49 | dconv_mode: 1 50 | dconv_depth: 2 51 | dconv_comp: 4 52 | dconv_time_attn: 2 53 | dconv_lstm: 2 54 | dconv_init: 1e-3 55 | # Weight init 56 | rescale: 0.1 57 | lr_sr: ${experiment.lr_sr} 58 | hr_sr: ${experiment.hr_sr} 59 | spec_upsample: true 60 | act_func: snake 61 | debug: false 62 | 63 | adversarial: True 64 | features_loss_lambda: 100 65 | only_features_loss: False 66 | only_adversarial_loss: False 67 | discriminator_models: [ msd_melgan ] #msd_melgan/msd_hifi/mpd/hifi 68 | melgan_discriminator: 69 | n_layers: 4 70 | num_D: 3 71 | downsampling_factor: 4 72 | ndf: 16 73 | -------------------------------------------------------------------------------- /conf/experiment/aeromamba.yaml: -------------------------------------------------------------------------------- 1 | # @package experiment 2 | name: aeromamba 3 | 4 | #Dataset related 5 | lr_sr: 11025 # low resolution sample rate, added to support BWE. 
Should be included in training cfg 6 | hr_sr: 44100 # high resolution sample rate. Should be included in training cfg 7 | segment: 4 8 | stride: 4 # in seconds, how much to stride between training examples 9 | pad: true # if training sample is too short, pad it 10 | upsample: false 11 | batch_size: 4 12 | nfft: 512 13 | hop_length: 256 14 | fixed_n_examples: null 15 | power_threshold: 1e-3 16 | 17 | # models related 18 | model: aero 19 | aero: # see aero.py for a detailed description: 20 | in_channels: 1 21 | out_channels: 1 22 | # Channels 23 | channels: 48 24 | growth: 2 25 | # STFT 26 | nfft: ${experiment.nfft} 27 | hop_length: ${experiment.hop_length} 28 | end_iters: 0 29 | cac: true 30 | # Main structure 31 | rewrite: true 32 | hybrid: false 33 | hybrid_old: false 34 | # Frequency Branch 35 | freq_emb: 0.2 36 | emb_scale: 10 37 | emb_smooth: true 38 | # Convolutions 39 | kernel_size: 8 40 | strides: [ 4,4,2,2 ] 41 | context: 1 42 | context_enc: 0 43 | freq_ends: 4 44 | enc_freq_attn: 0 45 | # normalization 46 | norm_starts: 2 47 | norm_groups: 4 48 | # DConv residual branch 49 | dconv_mode: 1 50 | dconv_depth: 2 51 | dconv_comp: 4 52 | dconv_time_attn: 4 53 | dconv_lstm: 4 54 | dconv_mamba: 0 55 | dconv_init: 1e-3 56 | # Weight init 57 | rescale: 0.1 58 | lr_sr: ${experiment.lr_sr} 59 | hr_sr: ${experiment.hr_sr} 60 | spec_upsample: true 61 | act_func: snake 62 | debug: false 63 | 64 | adversarial: True 65 | features_loss_lambda: 100 66 | only_features_loss: False 67 | only_adversarial_loss: False 68 | discriminator_models: [ msd_melgan ] #msd_melgan/msd_hifi/mpd/hifi 69 | melgan_discriminator: 70 | n_layers: 4 71 | num_D: 3 72 | downsampling_factor: 4 73 | ndf: 16 74 | -------------------------------------------------------------------------------- /outputs/musdb-mixture-11-44/aeromamba/wandb/run-20241029_105445-p0dnbwer/logs/debug-core.log: -------------------------------------------------------------------------------- 1 | 
{"time":"2024-10-29T10:54:45.392375178-03:00","level":"INFO","msg":"started logging, with flags","port-filename":"/tmp/tmp8aam3u_7/port-296628.txt","pid":296628,"debug":false,"disable-analytics":false} 2 | {"time":"2024-10-29T10:54:45.392395984-03:00","level":"INFO","msg":"FeatureState","shutdownOnParentExitEnabled":false} 3 | {"time":"2024-10-29T10:54:45.392545361-03:00","level":"INFO","msg":"Will exit if parent process dies.","ppid":296628} 4 | {"time":"2024-10-29T10:54:45.392525385-03:00","level":"INFO","msg":"server is running","addr":{"IP":"127.0.0.1","Port":32981,"Zone":""}} 5 | {"time":"2024-10-29T10:54:45.586407105-03:00","level":"INFO","msg":"connection: ManageConnectionData: new connection created","id":"127.0.0.1:51356"} 6 | {"time":"2024-10-29T10:54:46.000239122-03:00","level":"INFO","msg":"handleInformInit: received","streamId":"p0dnbwer","id":"127.0.0.1:51356"} 7 | {"time":"2024-10-29T10:54:46.113683275-03:00","level":"INFO","msg":"handleInformInit: stream started","streamId":"p0dnbwer","id":"127.0.0.1:51356"} 8 | {"time":"2024-10-29T10:55:29.034701583-03:00","level":"INFO","msg":"handleInformTeardown: server teardown initiated","id":"127.0.0.1:51356"} 9 | {"time":"2024-10-29T10:55:29.037452976-03:00","level":"INFO","msg":"server is shutting down"} 10 | {"time":"2024-10-29T10:55:29.06756904-03:00","level":"INFO","msg":"connection: Close: initiating connection closure","id":"127.0.0.1:51356"} 11 | {"time":"2024-10-29T10:55:29.067676885-03:00","level":"INFO","msg":"connection: Close: connection successfully closed","id":"127.0.0.1:51356"} 12 | {"time":"2024-10-29T10:55:30.317618856-03:00","level":"INFO","msg":"handleInformTeardown: server shutdown complete","id":"127.0.0.1:51356"} 13 | {"time":"2024-10-29T10:55:30.317630442-03:00","level":"INFO","msg":"connection: ManageConnectionData: connection closed","id":"127.0.0.1:51356"} 14 | {"time":"2024-10-29T10:55:30.317636542-03:00","level":"INFO","msg":"server is closed"} 15 | 
def parse_args():
    """Parse command-line options for the resampling script.

    Returns an argparse.Namespace with `data_dir`, `out_dir` and
    `target_sr` (int) attributes.
    """
    cli = argparse.ArgumentParser(description='Resample data.')
    cli.add_argument('--data_dir', help='directory containing source files')
    cli.add_argument('--out_dir', help='directory to write target files')
    cli.add_argument('--target_sr', type=int, help='target sample rate')
    return cli.parse_args()
'__main__': 51 | main() -------------------------------------------------------------------------------- /outputs/musdb-mixture-11-44/aeromamba/wandb/debug-internal.log: -------------------------------------------------------------------------------- 1 | {"time":"2024-10-29T10:54:46.007098196-03:00","level":"INFO","msg":"using version","core version":"0.18.5"} 2 | {"time":"2024-10-29T10:54:46.007125601-03:00","level":"INFO","msg":"created symlink","path":"/nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/outputs/musdb-mixture-11-44/aeromamba/wandb/run-20241029_105445-p0dnbwer/logs/debug-core.log"} 3 | {"time":"2024-10-29T10:54:46.113606895-03:00","level":"INFO","msg":"created new stream","id":"p0dnbwer"} 4 | {"time":"2024-10-29T10:54:46.113668138-03:00","level":"INFO","msg":"stream: started","id":"p0dnbwer"} 5 | {"time":"2024-10-29T10:54:46.113758884-03:00","level":"INFO","msg":"sender: started","stream_id":"p0dnbwer"} 6 | {"time":"2024-10-29T10:54:46.113731211-03:00","level":"INFO","msg":"writer: Do: started","stream_id":{"value":"p0dnbwer"}} 7 | {"time":"2024-10-29T10:54:46.113772978-03:00","level":"INFO","msg":"handler: started","stream_id":{"value":"p0dnbwer"}} 8 | {"time":"2024-10-29T10:54:46.502880952-03:00","level":"INFO","msg":"Starting system monitor"} 9 | {"time":"2024-10-29T10:55:29.037405842-03:00","level":"INFO","msg":"stream: closing","id":"p0dnbwer"} 10 | {"time":"2024-10-29T10:55:29.03753781-03:00","level":"INFO","msg":"Stopping system monitor"} 11 | {"time":"2024-10-29T10:55:29.041068641-03:00","level":"INFO","msg":"Stopped system monitor"} 12 | {"time":"2024-10-29T10:55:29.265813191-03:00","level":"WARN","msg":"No job ingredients found, not creating job artifact"} 13 | {"time":"2024-10-29T10:55:29.265867751-03:00","level":"WARN","msg":"No source type found, not creating job artifact"} 14 | {"time":"2024-10-29T10:55:29.26588668-03:00","level":"INFO","msg":"sender: sendDefer: no job artifact to save"} 15 | 
{"time":"2024-10-29T10:55:30.05801325-03:00","level":"INFO","msg":"fileTransfer: Close: file transfer manager closed"} 16 | {"time":"2024-10-29T10:55:30.316637743-03:00","level":"INFO","msg":"handler: closed","stream_id":{"value":"p0dnbwer"}} 17 | {"time":"2024-10-29T10:55:30.316670476-03:00","level":"INFO","msg":"writer: Close: closed","stream_id":{"value":"p0dnbwer"}} 18 | {"time":"2024-10-29T10:55:30.316697936-03:00","level":"INFO","msg":"sender: closed","stream_id":"p0dnbwer"} 19 | {"time":"2024-10-29T10:55:30.317570109-03:00","level":"INFO","msg":"stream: closed","id":"p0dnbwer"} 20 | -------------------------------------------------------------------------------- /outputs/musdb-mixture-11-44/aeromamba/wandb/run-20241029_105445-p0dnbwer/logs/debug-internal.log: -------------------------------------------------------------------------------- 1 | {"time":"2024-10-29T10:54:46.007098196-03:00","level":"INFO","msg":"using version","core version":"0.18.5"} 2 | {"time":"2024-10-29T10:54:46.007125601-03:00","level":"INFO","msg":"created symlink","path":"/nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/outputs/musdb-mixture-11-44/aeromamba/wandb/run-20241029_105445-p0dnbwer/logs/debug-core.log"} 3 | {"time":"2024-10-29T10:54:46.113606895-03:00","level":"INFO","msg":"created new stream","id":"p0dnbwer"} 4 | {"time":"2024-10-29T10:54:46.113668138-03:00","level":"INFO","msg":"stream: started","id":"p0dnbwer"} 5 | {"time":"2024-10-29T10:54:46.113758884-03:00","level":"INFO","msg":"sender: started","stream_id":"p0dnbwer"} 6 | {"time":"2024-10-29T10:54:46.113731211-03:00","level":"INFO","msg":"writer: Do: started","stream_id":{"value":"p0dnbwer"}} 7 | {"time":"2024-10-29T10:54:46.113772978-03:00","level":"INFO","msg":"handler: started","stream_id":{"value":"p0dnbwer"}} 8 | {"time":"2024-10-29T10:54:46.502880952-03:00","level":"INFO","msg":"Starting system monitor"} 9 | {"time":"2024-10-29T10:55:29.037405842-03:00","level":"INFO","msg":"stream: closing","id":"p0dnbwer"} 10 | 
{"time":"2024-10-29T10:55:29.03753781-03:00","level":"INFO","msg":"Stopping system monitor"} 11 | {"time":"2024-10-29T10:55:29.041068641-03:00","level":"INFO","msg":"Stopped system monitor"} 12 | {"time":"2024-10-29T10:55:29.265813191-03:00","level":"WARN","msg":"No job ingredients found, not creating job artifact"} 13 | {"time":"2024-10-29T10:55:29.265867751-03:00","level":"WARN","msg":"No source type found, not creating job artifact"} 14 | {"time":"2024-10-29T10:55:29.26588668-03:00","level":"INFO","msg":"sender: sendDefer: no job artifact to save"} 15 | {"time":"2024-10-29T10:55:30.05801325-03:00","level":"INFO","msg":"fileTransfer: Close: file transfer manager closed"} 16 | {"time":"2024-10-29T10:55:30.316637743-03:00","level":"INFO","msg":"handler: closed","stream_id":{"value":"p0dnbwer"}} 17 | {"time":"2024-10-29T10:55:30.316670476-03:00","level":"INFO","msg":"writer: Close: closed","stream_id":{"value":"p0dnbwer"}} 18 | {"time":"2024-10-29T10:55:30.316697936-03:00","level":"INFO","msg":"sender: closed","stream_id":"p0dnbwer"} 19 | {"time":"2024-10-29T10:55:30.317570109-03:00","level":"INFO","msg":"stream: closed","id":"p0dnbwer"} 20 | -------------------------------------------------------------------------------- /src/model_serializer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import logging 4 | 5 | import torch 6 | 7 | from src.utils import copy_state 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | SERIALIZE_KEY_MODELS = 'models' 12 | SERIALIZE_KEY_OPTIMIZERS = 'optimizers' 13 | SERIALIZE_KEY_HISTORY = 'history' 14 | SERIALIZE_KEY_STATE = 'state' 15 | SERIALIZE_KEY_BEST_STATES = 'best_states' 16 | SERIALIZE_KEY_ARGS = 'args' 17 | 18 | 19 | def serialize_model(model): 20 | args, kwargs = model._init_args_kwargs 21 | state = copy_state(model.state_dict()) 22 | return {"class": model.__class__, "args": args, "kwargs": kwargs, "state": state} 23 | 24 | 25 | def 
_serialize_models(models): 26 | serialized_models = {} 27 | for name, model in models.items(): 28 | serialized_models[name] = serialize_model(model) 29 | return serialized_models 30 | 31 | 32 | def _serialize_optimizers(optimizers): 33 | serialized_optimizers = {} 34 | for name, optimizer in optimizers.items(): 35 | serialized_optimizers[name] = optimizer.state_dict() 36 | return serialized_optimizers 37 | 38 | 39 | def serialize(models, optimizers, history, best_states, args): 40 | checkpoint_file = Path(args.checkpoint_file) 41 | best_file = Path(args.best_file) 42 | 43 | package = {} 44 | package[SERIALIZE_KEY_MODELS] = _serialize_models(models) 45 | package[SERIALIZE_KEY_OPTIMIZERS] = _serialize_optimizers(optimizers) 46 | package[SERIALIZE_KEY_HISTORY] = history 47 | package[SERIALIZE_KEY_BEST_STATES] = best_states 48 | package[SERIALIZE_KEY_ARGS] = args 49 | tmp_path = str(checkpoint_file) + ".tmp" 50 | torch.save(package, tmp_path) 51 | # renaming is sort of atomic on UNIX (not really true on NFS) 52 | # but still less chances of leaving a half written checkpoint behind. 53 | os.rename(tmp_path, checkpoint_file) 54 | 55 | # Saving only the latest best model. 
class Snake(nn.Module):
    '''
    Implementation of the serpentine-like sine-based periodic activation function:
    .. math::
        Snake_a := x + \frac{1}{a} sin^2(ax) = x - \frac{1}{2a}cos{2ax} + \frac{1}{2a}
    This activation function is able to better extrapolate to previously unseen data,
    especially in the case of learning periodic functions

    Shape:
        - Input: (N, *) where * means, any number of additional
          dimensions
        - Output: (N, *), same shape as the input

    Parameters:
        - a - trainable parameter

    References:
        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195

    Examples:
        >>> a1 = Snake(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    '''

    def __init__(self, in_features, a=None, trainable=True):
        '''
        Initialization.
        Args:
            in_features: shape of the input
            a: trainable parameter
            trainable: sets `a` as a trainable parameter

            `a` is initialized to 1 by default, higher values = higher-frequency,
            5-50 is a good starting point if you already think your data is periodic,
            consider starting lower e.g. 0.5 if you think not, but don't worry,
            `a` will be trained along with the rest of your model
        '''
        super(Snake, self).__init__()
        self.in_features = in_features if isinstance(in_features, list) else [in_features]

        # Initialize `a`
        if a is not None:
            self.a = Parameter(torch.ones(self.in_features) * a)  # create a tensor out of alpha
        else:
            m = Exponential(torch.tensor([0.1]))
            self.a = Parameter((m.rsample(self.in_features)).squeeze())  # random init = mix of frequencies

        # Bug fix: the original wrote `self.a.requiresGrad = trainable`, which
        # only sets an arbitrary attribute PyTorch never reads — so
        # `trainable=False` silently had no effect. `requires_grad` is the
        # attribute autograd actually honors.
        self.a.requires_grad = trainable

    def extra_repr(self) -> str:
        return 'in_features={}'.format(self.in_features)

    def forward(self, x):
        '''
        Forward pass of the function.
        Applies the function to the input elementwise.
        Snake := x + 1/a * sin^2 (xa)
        '''
        return x + (1.0 / self.a) * pow(sin(x * self.a), 2)
'{''n_layers'': 4, ''num_D'': 3, ''downsampling_factor'': 4, ''ndf'': 16}' 62 | model: 63 | value: aero 64 | name: 65 | value: aeromamba 66 | nfft: 67 | value: 512 68 | only_adversarial_loss: 69 | value: false 70 | only_features_loss: 71 | value: false 72 | optim: 73 | value: adam 74 | pad: 75 | value: true 76 | power_threshold: 77 | value: 0.001 78 | segment: 79 | value: 4 80 | stride: 81 | value: 4 82 | test: 83 | value: /home/wallace.abreu/Mestrado/aero_vanilla/egs/musdb-mixture-11-44/tt 84 | train: 85 | value: /home/wallace.abreu/Mestrado/aero_vanilla/egs/musdb-mixture-11-44/tr 86 | upsample: 87 | value: false 88 | -------------------------------------------------------------------------------- /outputs/musdb-mixture-11-44/aeromamba/wandb/run-20241029_105349-0ilrdhbg/files/requirements.txt: -------------------------------------------------------------------------------- 1 | scipy==1.14.0 2 | platformdirs==4.2.2 3 | llvmlite==0.43.0 4 | typing_extensions==4.12.2 5 | tqdm==4.66.4 6 | texttable==1.7.0 7 | soxr==0.4.0 8 | soundfile==0.12.1 9 | pyzstd==0.16.0 10 | pyppmd==1.1.0 11 | pycryptodomex==3.20.0 12 | pybcj==1.0.2 13 | portalocker==2.10.1 14 | pooch==1.8.2 15 | numba==0.60.0 16 | multivolumefile==0.2.3 17 | msgpack==1.0.8 18 | mir-eval==0.7 19 | lazy_loader==0.4 20 | inflate64==1.0.0 21 | Brotli==1.1.0 22 | audioread==3.0.1 23 | yacs==0.1.8 24 | termcolor==2.4.0 25 | pybind11==2.13.1 26 | py7zr==0.21.1 27 | netCDF4==1.7.1.post1 28 | librosa==0.10.2.post1 29 | jams==0.3.4 30 | iopath==0.1.10 31 | h5py==3.11.0 32 | soundata==1.0.1 33 | pysofaconventions==0.1.5 34 | pyroomacoustics==0.7.5 35 | mat73==0.65 36 | fvcore==0.1.5.post20221221 37 | spatialscaper==0.1.5 38 | setuptools==75.1.0 39 | wheel==0.44.0 40 | pip==24.2 41 | numpy==1.26.4 42 | numpy==1.24.3 43 | opencv-python==4.9.0.80 44 | Brotli==1.0.9 45 | certifi==2024.8.30 46 | charset-normalizer==3.3.2 47 | idna==3.7 48 | pillow==10.4.0 49 | PySocks==1.7.1 50 | six==1.16.0 51 | typing_extensions==4.11.0 52 
| mkl-service==2.4.0 53 | torch==1.12.1 54 | urllib3==2.2.3 55 | requests==2.32.3 56 | mkl-fft==1.3.1 57 | mkl-random==1.2.2 58 | torchaudio==0.12.1 59 | torchvision==0.13.1 60 | sox==1.5.0 61 | smmap==5.0.1 62 | setproctitle==1.3.3 63 | sentry-sdk==2.17.0 64 | PyYAML==6.0.2 65 | psutil==6.1.0 66 | protobuf==5.28.3 67 | docker-pycreds==0.4.0 68 | click==8.1.7 69 | gitdb==4.0.11 70 | GitPython==3.1.43 71 | wandb==0.18.5 72 | ninja==1.11.1.1 73 | packaging==24.1 74 | safetensors==0.4.5 75 | regex==2024.9.11 76 | fsspec==2024.10.0 77 | filelock==3.16.1 78 | einops==0.8.0 79 | triton==3.1.0 80 | huggingface-hub==0.26.2 81 | tokenizers==0.20.1 82 | transformers==4.46.0 83 | mamba-ssm==1.1.3.post1 84 | colorlog==5.0.1 85 | antlr4-python3-runtime==4.8 86 | omegaconf==2.1.2 87 | hydra-core==1.1.1 88 | hydra-colorlog==1.1.0 89 | hyperlink==17.3.1 90 | ruamel.yaml.clib==0.2.12 91 | ruamel.yaml==0.18.6 92 | HyperPyYAML==1.0.0 93 | pycparser==2.22 94 | cffi==1.17.1 95 | redo==3.0.0 96 | docopt==0.6.2 97 | argparse==1.4.0 98 | zope.interface==7.1.1 99 | tomli==2.0.2 100 | simplejson==3.19.3 101 | python-dateutil==2.9.0.post0 102 | orderedmultidict==1.0.1 103 | MarkupSafe==3.0.2 104 | greenlet==3.1.1 105 | constantly==23.10.4 106 | Automat==24.8.1 107 | attrs==24.2.0 108 | SQLAlchemy==2.0.36 109 | Jinja2==3.1.4 110 | incremental==24.7.2 111 | furl==2.1.3 112 | Twisted==24.10.0 113 | buildtools==1.0.6 114 | pyparsing==3.2.0 115 | kiwisolver==1.4.7 116 | causal-conv1d==1.1.2.post1 117 | fonttools==4.54.1 118 | cycler==0.12.1 119 | contourpy==1.3.0 120 | matplotlib==3.9.2 121 | -------------------------------------------------------------------------------- /outputs/musdb-mixture-11-44/aeromamba/wandb/run-20241029_105445-p0dnbwer/files/requirements.txt: -------------------------------------------------------------------------------- 1 | scipy==1.14.0 2 | platformdirs==4.2.2 3 | llvmlite==0.43.0 4 | typing_extensions==4.12.2 5 | tqdm==4.66.4 6 | texttable==1.7.0 7 | soxr==0.4.0 8 | 
soundfile==0.12.1 9 | pyzstd==0.16.0 10 | pyppmd==1.1.0 11 | pycryptodomex==3.20.0 12 | pybcj==1.0.2 13 | portalocker==2.10.1 14 | pooch==1.8.2 15 | numba==0.60.0 16 | multivolumefile==0.2.3 17 | msgpack==1.0.8 18 | mir-eval==0.7 19 | lazy_loader==0.4 20 | inflate64==1.0.0 21 | Brotli==1.1.0 22 | audioread==3.0.1 23 | yacs==0.1.8 24 | termcolor==2.4.0 25 | pybind11==2.13.1 26 | py7zr==0.21.1 27 | netCDF4==1.7.1.post1 28 | librosa==0.10.2.post1 29 | jams==0.3.4 30 | iopath==0.1.10 31 | h5py==3.11.0 32 | soundata==1.0.1 33 | pysofaconventions==0.1.5 34 | pyroomacoustics==0.7.5 35 | mat73==0.65 36 | fvcore==0.1.5.post20221221 37 | spatialscaper==0.1.5 38 | setuptools==75.1.0 39 | wheel==0.44.0 40 | pip==24.2 41 | numpy==1.26.4 42 | numpy==1.24.3 43 | opencv-python==4.9.0.80 44 | Brotli==1.0.9 45 | certifi==2024.8.30 46 | charset-normalizer==3.3.2 47 | idna==3.7 48 | pillow==10.4.0 49 | PySocks==1.7.1 50 | six==1.16.0 51 | typing_extensions==4.11.0 52 | mkl-service==2.4.0 53 | torch==1.12.1 54 | urllib3==2.2.3 55 | requests==2.32.3 56 | mkl-fft==1.3.1 57 | mkl-random==1.2.2 58 | torchaudio==0.12.1 59 | torchvision==0.13.1 60 | sox==1.5.0 61 | smmap==5.0.1 62 | setproctitle==1.3.3 63 | sentry-sdk==2.17.0 64 | PyYAML==6.0.2 65 | psutil==6.1.0 66 | protobuf==5.28.3 67 | docker-pycreds==0.4.0 68 | click==8.1.7 69 | gitdb==4.0.11 70 | GitPython==3.1.43 71 | wandb==0.18.5 72 | ninja==1.11.1.1 73 | packaging==24.1 74 | safetensors==0.4.5 75 | regex==2024.9.11 76 | fsspec==2024.10.0 77 | filelock==3.16.1 78 | einops==0.8.0 79 | triton==3.1.0 80 | huggingface-hub==0.26.2 81 | tokenizers==0.20.1 82 | transformers==4.46.0 83 | mamba-ssm==1.1.3.post1 84 | colorlog==5.0.1 85 | antlr4-python3-runtime==4.8 86 | omegaconf==2.1.2 87 | hydra-core==1.1.1 88 | hydra-colorlog==1.1.0 89 | hyperlink==17.3.1 90 | ruamel.yaml.clib==0.2.12 91 | ruamel.yaml==0.18.6 92 | HyperPyYAML==1.0.0 93 | pycparser==2.22 94 | cffi==1.17.1 95 | redo==3.0.0 96 | docopt==0.6.2 97 | argparse==1.4.0 98 | 
class ChildrenManager:
    """Context manager that tracks spawned worker processes.

    Workers are registered with :meth:`add`; on scope exit the manager polls
    them until all finish, flags failure if any exits non-zero (or an
    exception / Ctrl-C interrupts startup), and terminates any survivors.
    """

    def __init__(self):
        self.children = []
        self.failed = False

    def add(self, child):
        # Ranks follow insertion order so workers can be identified in logs.
        child.rank = len(self.children)
        self.children.append(child)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_value is not None:
            logger.error("An exception happened while starting workers %r", exc_value)
            self.failed = True
        try:
            self._poll_until_done()
        except KeyboardInterrupt:
            logger.error("Received keyboard interrupt, trying to kill all workers.")
            self.failed = True
        # Kill any worker still running (only non-empty after a failure).
        for proc in self.children:
            proc.terminate()
        if not self.failed:
            logger.info("All workers completed successfully")

    def _poll_until_done(self):
        # Poll each pending worker with a short timeout until all have
        # finished, stopping early as soon as one fails.
        while self.children and not self.failed:
            for proc in self.children[:]:
                try:
                    code = proc.wait(0.1)
                except sp.TimeoutExpired:
                    continue
                self.children.remove(proc)
                if code:
                    logger.error(f"Worker {proc.rank} died, killing all workers")
                    self.failed = True
def run(args):
    """Evaluate the trained generator on the test set and log the metrics."""
    exp = args.experiment
    # Full 10-second, non-overlapping segments so whole files are covered.
    test_set = LrHrSet(args.dset.test, exp.lr_sr, exp.hr_sr,
                       stride=10, segment=10, with_path=True, upsample=exp.upsample)
    test_loader = distrib.loader(test_set, batch_size=1, shuffle=False,
                                 num_workers=args.num_workers)

    generator = _load_model(args)
    generator.cuda()

    lsd, visqol, enhanced_filenames = evaluate(args, test_loader, 0, generator)
    logger.info(f'Done evaluation.')
    logger.info(f'LSD={lsd} , VISQOL={visqol}')
work 83 | os._exit(1) 84 | 85 | 86 | if __name__ == "__main__": 87 | main() 88 | -------------------------------------------------------------------------------- /outputs/chopin-11-44/aeromamba/.hydra/config.yaml: -------------------------------------------------------------------------------- 1 | experiment: 2 | name: aeromamba 3 | lr_sr: 11025 4 | hr_sr: 44100 5 | segment: 4 6 | stride: 4 7 | pad: true 8 | upsample: false 9 | batch_size: 4 10 | nfft: 512 11 | hop_length: 256 12 | fixed_n_examples: null 13 | power_threshold: 0.001 14 | model: aero 15 | aero: 16 | in_channels: 1 17 | out_channels: 1 18 | channels: 48 19 | growth: 2 20 | nfft: ${experiment.nfft} 21 | hop_length: ${experiment.hop_length} 22 | end_iters: 0 23 | cac: true 24 | rewrite: true 25 | hybrid: false 26 | hybrid_old: false 27 | freq_emb: 0.2 28 | emb_scale: 10 29 | emb_smooth: true 30 | kernel_size: 8 31 | strides: 32 | - 4 33 | - 4 34 | - 2 35 | - 2 36 | context: 1 37 | context_enc: 0 38 | freq_ends: 4 39 | enc_freq_attn: 0 40 | norm_starts: 2 41 | norm_groups: 4 42 | dconv_mode: 1 43 | dconv_depth: 2 44 | dconv_comp: 4 45 | dconv_time_attn: 4 46 | dconv_lstm: 4 47 | dconv_mamba: 0 48 | dconv_init: 0.001 49 | rescale: 0.1 50 | lr_sr: ${experiment.lr_sr} 51 | hr_sr: ${experiment.hr_sr} 52 | spec_upsample: true 53 | act_func: snake 54 | debug: false 55 | adversarial: true 56 | features_loss_lambda: 100 57 | only_features_loss: false 58 | only_adversarial_loss: false 59 | discriminator_models: 60 | - msd_melgan 61 | melgan_discriminator: 62 | n_layers: 4 63 | num_D: 3 64 | downsampling_factor: 4 65 | ndf: 16 66 | dset: 67 | name: chopin-11-44 68 | train: egs/chopin-11-44/tr 69 | valid: null 70 | test: egs/chopin-11-44/tt 71 | num_prints: 10 72 | device: cuda 73 | num_workers: 8 74 | verbose: 0 75 | show: 0 76 | log_results: true 77 | checkpoint: true 78 | continue_from: '' 79 | continue_best: false 80 | restart: false 81 | checkpoint_file: checkpoint.th 82 | best_file: best.th 83 | history_file: 
history.json 84 | test_results_file: test_results.json 85 | samples_dir: samples/ 86 | keep_history: true 87 | seed: 2036 88 | dummy: '' 89 | visqol: false 90 | visqol_path: null 91 | eval_every: 3 92 | enhance_samples_limit: -1 93 | valid_equals_test: null 94 | cross_valid: false 95 | cross_valid_every: 1 96 | joint_evaluate_and_enhance: false 97 | evaluate_on_best: false 98 | wandb: 99 | project_name: AEROMamba MUSDB 100 | entity: null 101 | mode: online 102 | log: all 103 | log_freq: 10 104 | n_files_to_log: 5 105 | n_files_to_log_to_table: 1 106 | tags: [] 107 | resume: false 108 | optim: adam 109 | lr: 0.0003 110 | beta1: 0.8 111 | beta2: 0.999 112 | losses: 113 | - stft 114 | stft_sc_factor: 0.5 115 | stft_mag_factor: 0.5 116 | epochs: 696 117 | ddp: false 118 | ddp_backend: nccl 119 | rendezvous_file: ./rendezvous 120 | rank: null 121 | world_size: null 122 | filename: /home/wallace.abreu/Mestrado/aeromamba-lamir/ds_datasets/chopin/test/V.ASHKENAZYTrack11.wav 123 | output: /home/wallace.abreu/Mestrado/aeromamba-lamir/test 124 | -------------------------------------------------------------------------------- /outputs/musdb-mixture-11-44/aeromamba/.hydra/config.yaml: -------------------------------------------------------------------------------- 1 | experiment: 2 | name: aeromamba 3 | lr_sr: 11025 4 | hr_sr: 44100 5 | segment: 4 6 | stride: 4 7 | pad: true 8 | upsample: false 9 | batch_size: 4 10 | nfft: 512 11 | hop_length: 256 12 | fixed_n_examples: null 13 | power_threshold: 0.001 14 | model: aero 15 | aero: 16 | in_channels: 1 17 | out_channels: 1 18 | channels: 48 19 | growth: 2 20 | nfft: ${experiment.nfft} 21 | hop_length: ${experiment.hop_length} 22 | end_iters: 0 23 | cac: true 24 | rewrite: true 25 | hybrid: false 26 | hybrid_old: false 27 | freq_emb: 0.2 28 | emb_scale: 10 29 | emb_smooth: true 30 | kernel_size: 8 31 | strides: 32 | - 4 33 | - 4 34 | - 2 35 | - 2 36 | context: 1 37 | context_enc: 0 38 | freq_ends: 4 39 | enc_freq_attn: 0 40 | 
norm_starts: 2 41 | norm_groups: 4 42 | dconv_mode: 1 43 | dconv_depth: 2 44 | dconv_comp: 4 45 | dconv_time_attn: 4 46 | dconv_lstm: 4 47 | dconv_mamba: 0 48 | dconv_init: 0.001 49 | rescale: 0.1 50 | lr_sr: ${experiment.lr_sr} 51 | hr_sr: ${experiment.hr_sr} 52 | spec_upsample: true 53 | act_func: snake 54 | debug: false 55 | adversarial: true 56 | features_loss_lambda: 100 57 | only_features_loss: false 58 | only_adversarial_loss: false 59 | discriminator_models: 60 | - msd_melgan 61 | melgan_discriminator: 62 | n_layers: 4 63 | num_D: 3 64 | downsampling_factor: 4 65 | ndf: 16 66 | dset: 67 | name: musdb-mixture-11-44 68 | train: /home/wallace.abreu/Mestrado/aero_vanilla/egs/musdb-mixture-11-44/tr 69 | valid: null 70 | test: /home/wallace.abreu/Mestrado/aero_vanilla/egs/musdb-mixture-11-44/tt 71 | num_prints: 10 72 | device: cuda 73 | num_workers: 8 74 | verbose: 0 75 | show: 0 76 | log_results: true 77 | checkpoint: true 78 | continue_from: '' 79 | continue_best: false 80 | restart: false 81 | checkpoint_file: /home/wallace.abreu/Mestrado/aeromamba-lamir/outputs/chopin-11-44-one/aeromamba/checkpoint.th 82 | best_file: best.th 83 | history_file: history.json 84 | test_results_file: test_results.json 85 | samples_dir: samples/ 86 | keep_history: true 87 | seed: 2036 88 | dummy: '' 89 | visqol: false 90 | visqol_path: null 91 | eval_every: 3 92 | enhance_samples_limit: -1 93 | valid_equals_test: null 94 | cross_valid: false 95 | cross_valid_every: 1 96 | joint_evaluate_and_enhance: false 97 | evaluate_on_best: false 98 | wandb: 99 | project_name: AEROMamba MUSDB 100 | entity: null 101 | mode: online 102 | log: all 103 | log_freq: 10 104 | n_files_to_log: 5 105 | n_files_to_log_to_table: 1 106 | tags: [] 107 | resume: false 108 | optim: adam 109 | lr: 0.0003 110 | beta1: 0.8 111 | beta2: 0.999 112 | losses: 113 | - stft 114 | stft_sc_factor: 0.5 115 | stft_mag_factor: 0.5 116 | epochs: 696 117 | ddp: false 118 | ddp_backend: nccl 119 | rendezvous_file: 
./rendezvous 120 | rank: null 121 | world_size: null 122 | -------------------------------------------------------------------------------- /outputs/musdb-mixture-11-44/aeromamba/wandb/run-20241029_105445-p0dnbwer/files/output.log: -------------------------------------------------------------------------------- 1 | [2024-10-29 10:54:55,262][__main__][INFO] - Loading model aero from last state. 2 | Traceback (most recent call last): 3 | File "/nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/test.py", line 87, in 4 | main() 5 | File "/home/wallace.abreu/miniconda3/envs/lamir_test/lib/python3.10/site-packages/hydra/main.py", line 48, in decorated_main 6 | _run_hydra( 7 | File "/home/wallace.abreu/miniconda3/envs/lamir_test/lib/python3.10/site-packages/hydra/_internal/utils.py", line 377, in _run_hydra 8 | run_and_report( 9 | File "/home/wallace.abreu/miniconda3/envs/lamir_test/lib/python3.10/site-packages/hydra/_internal/utils.py", line 211, in run_and_report 10 | return func() 11 | File "/home/wallace.abreu/miniconda3/envs/lamir_test/lib/python3.10/site-packages/hydra/_internal/utils.py", line 378, in 12 | lambda: hydra.run( 13 | File "/home/wallace.abreu/miniconda3/envs/lamir_test/lib/python3.10/site-packages/hydra/_internal/hydra.py", line 98, in run 14 | ret = run_job( 15 | File "/home/wallace.abreu/miniconda3/envs/lamir_test/lib/python3.10/site-packages/hydra/core/utils.py", line 160, in run_job 16 | ret.return_value = task_function(task_cfg) 17 | File "/nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/test.py", line 79, in main 18 | _main(args) 19 | File "/nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/test.py", line 72, in _main 20 | run(args) 21 | File "/nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/test.py", line 50, in run 22 | lsd, visqol, enhanced_filenames = evaluate(args, tt_loader, 0, model) 23 | File "/nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/src/evaluate.py", line 161, in evaluate 24 | metrics_i = evaluate_lr_hr_data(data, model, 
wandb_n_files_to_log, files_to_log, epoch, args) 25 | File "/nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/src/evaluate.py", line 65, in evaluate_lr_hr_data 26 | pr, pr_spec, lr_spec = model(lr, return_spec=True, return_lr_spec=True) 27 | File "/home/wallace.abreu/miniconda3/envs/lamir_test/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl 28 | return forward_call(*input, **kwargs) 29 | File "/nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/src/models/aero.py", line 486, in forward 30 | x = encode(x, inject) 31 | File "/home/wallace.abreu/miniconda3/envs/lamir_test/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl 32 | return forward_call(*input, **kwargs) 33 | File "/nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/src/models/aero.py", line 120, in forward 34 | x = self.pre_conv(x) 35 | File "/home/wallace.abreu/miniconda3/envs/lamir_test/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl 36 | return forward_call(*input, **kwargs) 37 | File "/home/wallace.abreu/miniconda3/envs/lamir_test/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 457, in forward 38 | return self._conv_forward(input, self.weight, self.bias) 39 | File "/home/wallace.abreu/miniconda3/envs/lamir_test/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 453, in _conv_forward 40 | return F.conv2d(input, weight, bias, self.stride, 41 | KeyboardInterrupt 42 | -------------------------------------------------------------------------------- /src/enhance.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import numpy as np 4 | import torch 5 | import torchaudio 6 | from PIL import Image 7 | from src.utils import LogProgress, convert_spectrogram_to_heatmap 8 | logger = logging.getLogger(__name__) 9 | 10 | def get_estimate(model, lr_sig): 11 | torch.set_num_threads(1) 12 | with torch.no_grad(): 13 | out = 
def convert_float_to_int16(signal):
    """Convert a float waveform in [-1., 1.] to 16-bit PCM samples.

    Accepts either a torch tensor (CPU or CUDA) or a numpy array.
    Bug fix: the original called `.numpy()` on the result of `np.clip`,
    which is already an `np.ndarray` and has no such method, so the
    function raised AttributeError for every input.

    :param signal: float waveform (torch.Tensor or np.ndarray).
    :return: np.int16 array scaled to the 16-bit range.
    """
    if torch.is_tensor(signal):
        signal = signal.detach().cpu().numpy()
    signal = np.clip(signal, -1.0, 1.0)
    return (signal * 32767).astype(np.int16)
total_filenames = [] 65 | 66 | iterator = LogProgress(logger, dataloader, name="Generate enhanced files") 67 | 68 | for i, data in enumerate(iterator): 69 | # Get batch data 70 | (lr_sigs, lr_paths), (hr_sigs, hr_paths) = data 71 | lr_sigs = lr_sigs.to(args.device) 72 | hr_sigs = hr_sigs.to(args.device) 73 | filenames = [os.path.join(args.samples_dir, os.path.basename(path).rsplit(".", 1)[0] + '_' + str(i)) for path in lr_paths] 74 | total_filenames += [os.path.basename(path).rsplit(".", 1)[0] + '_' + str(i) for path in lr_paths] 75 | 76 | estimates = get_estimate(model, lr_sigs) 77 | 78 | save_wavs(estimates, lr_sigs, hr_sigs, filenames, lr_sr, hr_sr) 79 | 80 | if i == args.enhance_samples_limit: 81 | break 82 | model.train() 83 | return total_filenames 84 | 85 | -------------------------------------------------------------------------------- /src/data/audio.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is based on Facebook's HDemucs code: https://github.com/facebookresearch/demucs 3 | """ 4 | import math 5 | import torchaudio 6 | import torch 7 | import random 8 | from torch.nn import functional as F 9 | 10 | 11 | class Audioset: 12 | def __init__(self, files=None, length=None, stride=None, 13 | pad=True, with_path=False, sample_rate=None, 14 | channels=None, fixed_n_examples=None): 15 | """ 16 | files should be a list [(file, length)] 17 | """ 18 | self.files = files 19 | self.num_examples = [] 20 | self.length = length 21 | self.stride = stride or length 22 | self.with_path = with_path 23 | self.sample_rate = sample_rate 24 | self.channels = channels 25 | self.fixed_n_examples = fixed_n_examples 26 | 27 | for file, file_length in self.files: 28 | if length is None: 29 | examples = 1 30 | elif file_length < length: 31 | examples = 1 if pad else 0 32 | elif pad: 33 | examples = int(math.ceil((file_length - self.length) / self.stride) + 1) 34 | if self.fixed_n_examples is not None: 35 | if examples > 
self.fixed_n_examples: 36 | examples = self.fixed_n_examples 37 | else: 38 | examples = (file_length - self.length) // self.stride + 1 39 | if self.fixed_n_examples is not None: 40 | if examples > self.fixed_n_examples: 41 | examples = self.fixed_n_examples 42 | self.num_examples.append(examples) 43 | 44 | def __len__(self): 45 | return sum(self.num_examples) 46 | 47 | def __getitem__(self, index): 48 | for (file, file_samples), examples in zip(self.files, self.num_examples): 49 | if index >= examples: 50 | index -= examples 51 | continue 52 | num_frames = 0 53 | offset = 0 54 | if self.length is not None: 55 | num_frames = self.length 56 | offset = self.stride * index 57 | 58 | if torchaudio.get_audio_backend() in ['soundfile', 'sox_io']: 59 | out, sr = torchaudio.load(str(file), 60 | frame_offset=offset, 61 | num_frames=num_frames or -1) 62 | else: 63 | out, sr = torchaudio.load(str(file), offset=offset, num_frames=num_frames) 64 | 65 | 66 | if sr != self.sample_rate: 67 | raise RuntimeError(f"Expected {file} to have sample rate of " 68 | f"{self.sample_rate}, but got {sr}") 69 | if out.shape[0] != self.channels: 70 | #raise RuntimeError(f"Expected {file} to have shape of " 71 | # f"{self.channels}, but got {out.shape[0]}") 72 | #print("Normalizing stereo file") 73 | out = torch.mean(out, dim=0, keepdim=True) 74 | if num_frames: 75 | out = F.pad(out, (0, num_frames - out.shape[-1])) 76 | if self.with_path: 77 | return out, file 78 | else: 79 | return out 80 | -------------------------------------------------------------------------------- /src/ddp/distrib.py: -------------------------------------------------------------------------------- 1 | # taken from https://github.com/facebookresearch/denoiser 2 | 3 | import logging 4 | import os 5 | 6 | import torch 7 | from torch.utils.data.distributed import DistributedSampler 8 | from torch.utils.data import DataLoader, Subset 9 | from torch.nn.parallel.distributed import DistributedDataParallel 10 | 11 | logger = 
def loader(dataset, *args, shuffle=False, klass=DataLoader, **kwargs):
    """Create a dataloader properly in case of distributed training.

    If a gradient is going to be computed you must set `shuffle=True`.

    :param dataset: the dataset to be parallelized
    :param args: relevant args for the loader
    :param shuffle: shuffle examples
    :param klass: loader class
    :param kwargs: relevant args
    """
    if world_size == 1:
        return klass(dataset, *args, shuffle=shuffle, **kwargs)

    if shuffle:
        # train means we will compute backward, we use DistributedSampler.
        # We ignore `shuffle`: DistributedSampler already shuffles.
        return klass(dataset, *args, **kwargs, sampler=DistributedSampler(dataset))
    # We make a manual shard, as DistributedSampler otherwise replicates some examples.
    dataset = Subset(dataset, list(range(rank, len(dataset), world_size)))
    # Bug fix: kwargs (e.g. num_workers, batch_size) were silently dropped
    # on this branch, so sharded evaluation ignored the caller's settings.
    return klass(dataset, *args, shuffle=shuffle, **kwargs)
@hydra.main(config_path="conf", config_name="main_config")  # for latest version of hydra=1.0
def main(args):
    """Upsample a single audio file with a trained generator.

    Loads `args.filename`, downmixes to mono, splits the signal into
    `SEGMENT_DURATION_SEC`-second chunks (bounding GPU memory), runs each
    chunk through the model and writes the concatenated result under
    `args.output`.
    """
    global __file__
    __file__ = hydra.utils.to_absolute_path(__file__)

    print(args)
    model = _load_model(args)
    device = torch.device('cuda')
    model.cuda()
    filename = args.filename
    file_basename = Path(filename).stem
    output_dir = args.output
    lr_sig, sr = torchaudio.load(str(filename))
    # Bug fix: torchaudio.load returns (channels, frames); the channel count
    # is shape[0], not shape[1] (shape[1] is the sample count and is
    # practically always > 1, so the old test downmixed unconditionally —
    # harmless for mono input, but semantically wrong).
    if lr_sig.shape[0] > 1:
        lr_sig = torch.mean(lr_sig, dim=0, keepdim=True)

    if args.experiment.upsample:
        # Model expects input already at the high-res rate in this mode.
        lr_sig = resample(lr_sig, sr, args.experiment.hr_sr)
        sr = args.experiment.hr_sr

    logger.info(f'lr wav shape: {lr_sig.shape}')

    segment_duration_samples = sr * SEGMENT_DURATION_SEC
    n_chunks = math.ceil(lr_sig.shape[-1] / segment_duration_samples)
    logger.info(f'number of chunks: {n_chunks}')

    lr_chunks = []
    for i in range(n_chunks):
        start = i * segment_duration_samples
        # Last chunk may be shorter than a full segment.
        end = min((i + 1) * segment_duration_samples, lr_sig.shape[-1])
        lr_chunks.append(lr_sig[:, start:end])

    pr_chunks = []

    model.eval()
    pred_start = time.time()
    with torch.no_grad():
        for i, lr_chunk in enumerate(lr_chunks):
            pr_chunk = model(lr_chunk.unsqueeze(0).to(device)).squeeze(0)
            logger.info(f'lr chunk {i} shape: {lr_chunk.shape}')
            logger.info(f'pr chunk {i} shape: {pr_chunk.shape}')
            pr_chunks.append(pr_chunk.cpu())

    pred_duration = time.time() - pred_start
    logger.info(f'prediction duration: {pred_duration}')

    pr = torch.concat(pr_chunks, dim=-1)

    logger.info(f'pr wav shape: {pr.shape}')

    out_filename = os.path.join(output_dir, file_basename + '.wav')
    os.makedirs(output_dir, exist_ok=True)

    logger.info(f'saving to: {out_filename}, with sample_rate: {args.experiment.hr_sr}')

    write(pr, out_filename, args.experiment.hr_sr)


"""
Need to add filename and output to args.
Usage: python predict.py +filename= +output=
"""
if __name__ == "__main__":
    main()
37 | visqol_path: # *INSERT ABSOLUTE PATH TO VISQOL HERE* 38 | eval_every: 3 # compute test metrics every so epochs 39 | enhance_samples_limit: -1 40 | valid_equals_test: # whether valid_dset == test_dset, set in train.py script 41 | cross_valid: False 42 | cross_valid_every: 1 43 | joint_evaluate_and_enhance: False 44 | evaluate_on_best: False 45 | 46 | #wand_b 47 | wandb: 48 | project_name: 'AEROMamba Chopin' 49 | entity: #optional, must exist beforehand in wandb account 50 | mode: online # online/offline/disabled 51 | log: all # gradients/parameters/all/None 52 | log_freq: 10 53 | n_files_to_log: 5 # number or -1 for all files 54 | n_files_to_log_to_table: 1 # this is for the results table at the end of run 55 | tags: [ ] 56 | resume: false 57 | 58 | # Optimization related 59 | optim: adam 60 | lr: 3e-4 61 | beta1: 0.8 62 | beta2: 0.999 63 | losses: [ stft ] 64 | stft_sc_factor: .5 65 | stft_mag_factor: .5 66 | epochs: 696 67 | 68 | # Experiment launching, distributed 69 | ddp: false 70 | ddp_backend: nccl 71 | rendezvous_file: ./rendezvous 72 | 73 | # Internal config, don't set manually 74 | rank: 75 | world_size: 76 | 77 | # Hydra config 78 | hydra: 79 | sweep: 80 | dir: ./outputs/${dset.name}/${experiment.name} 81 | subdir: ${hydra.job.num} 82 | run: 83 | dir: ./outputs/${dset.name}/${experiment.name} 84 | job: 85 | config: 86 | # configuration for the ${hydra.job.override_dirname} runtime variable 87 | override_dirname: 88 | kv_sep: '=' 89 | item_sep: ',' 90 | # Remove all paths, as the / in them would mess up things 91 | # Remove params that would not impact the training itself 92 | # Remove all slurm and submit params. 93 | # This is ugly I know... 
94 | exclude_keys: [ 95 | 'hydra.job_logging.handlers.file.filename', 96 | 'dset.train', 'dset.valid', 'dset.test', 97 | 'num_prints', 'continue_from', 98 | 'device', 'num_workers', 'print_freq', 'restart', 'verbose', 99 | 'log' ] 100 | job_logging: 101 | handlers: 102 | file: 103 | class: logging.FileHandler 104 | mode: w 105 | formatter: colorlog 106 | filename: trainer.log 107 | console: 108 | class: logging.StreamHandler 109 | formatter: colorlog 110 | stream: ext://sys.stderr 111 | 112 | hydra_logging: 113 | handlers: 114 | console: 115 | class: logging.StreamHandler 116 | formatter: colorlog 117 | stream: ext://sys.stderr 118 | -------------------------------------------------------------------------------- /data_prep/create_meta_files.py: -------------------------------------------------------------------------------- 1 | import sox 2 | import os 3 | import sys 4 | import argparse 5 | import glob 6 | import torchaudio 7 | from collections import namedtuple 8 | import json 9 | from multiprocessing import Process, Manager 10 | import pathlib 11 | 12 | FILE_PATTERN='*.wav' 13 | TOTAL_N_SPEAKERS=2 14 | TRAIN_N_SPEAKERS=1 15 | TEST_N_SPEAKERS=1 16 | 17 | Info = namedtuple("Info", ["length", "sample_rate", "channels"]) 18 | 19 | 20 | def get_info(path): 21 | info = torchaudio.info(path) 22 | if hasattr(info, 'num_frames'): 23 | # new version of torchaudio 24 | return Info(info.num_frames, info.sample_rate, info.num_channels) 25 | else: 26 | siginfo = info[0] 27 | return Info(siginfo.length // siginfo.channels, siginfo.rate, siginfo.channels) 28 | 29 | 30 | def add_subdir_meta(subdir_path, shared_meta, n_samples_limit): 31 | if n_samples_limit and len(shared_meta) > n_samples_limit: 32 | return 33 | print(f'creating meta for {subdir_path}') 34 | audio_files = glob.glob(os.path.join(subdir_path, FILE_PATTERN)) 35 | for idx, file in enumerate(audio_files): 36 | info = get_info(file) 37 | shared_meta.append((file, info.length)) 38 | 39 | 40 | def 
def parse_args():
    """Parse CLI arguments for meta-file creation.

    :return: namespace with `data_dir`, `target_dir`, `json_filename`
        and optional `n_samples_limit`.
    """
    # Bug fix: the description said 'Resample data.' — copy-pasted from
    # resample_data.py; this script builds the json meta files.
    parser = argparse.ArgumentParser(description='Create meta (json) files for audio datasets.')
    parser.add_argument('data_dir', help='directory containing source files')
    parser.add_argument('target_dir', help='output directory for created json files')
    parser.add_argument('json_filename', help='filename for created json files')
    parser.add_argument('--n_samples_limit', type=int, help='limit number of files')
    return parser.parse_args()
exist_ok=True) 96 | 97 | 98 | train_meta, test_meta = create_meta(args.data_dir, args.n_samples_limit) 99 | 100 | train_json_object = json.dumps(train_meta, indent=4) 101 | test_json_object = json.dumps(test_meta, indent=4) 102 | with open(os.path.join(args.target_dir, 'tr', args.json_filename + '.json'), "w") as train_out: 103 | train_out.write(train_json_object) 104 | with open(os.path.join(args.target_dir, 'tt', args.json_filename + '.json'), "w") as test_out: 105 | test_out.write(test_json_object) 106 | 107 | print(f'Done creating meta for {args.data_dir}.') 108 | 109 | 110 | if __name__ == '__main__': 111 | main() -------------------------------------------------------------------------------- /outputs/musdb-mixture-11-44/aeromamba/wandb/run-20241029_105349-0ilrdhbg/logs/debug.log: -------------------------------------------------------------------------------- 1 | 2024-10-29 10:53:49,338 INFO MainThread:296384 [wandb_setup.py:_flush():79] Current SDK version is 0.18.5 2 | 2024-10-29 10:53:49,338 INFO MainThread:296384 [wandb_setup.py:_flush():79] Configure stats pid to 296384 3 | 2024-10-29 10:53:49,338 INFO MainThread:296384 [wandb_setup.py:_flush():79] Loading settings from /home/wallace.abreu/.config/wandb/settings 4 | 2024-10-29 10:53:49,338 INFO MainThread:296384 [wandb_setup.py:_flush():79] Loading settings from /nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/outputs/musdb-mixture-11-44/aeromamba/wandb/settings 5 | 2024-10-29 10:53:49,339 INFO MainThread:296384 [wandb_setup.py:_flush():79] Loading settings from environment variables: {} 6 | 2024-10-29 10:53:49,339 INFO MainThread:296384 [wandb_setup.py:_flush():79] Applying setup settings: {'mode': 'online', '_disable_service': None} 7 | 2024-10-29 10:53:49,339 WARNING MainThread:296384 [wandb_setup.py:_flush():79] Could not save program above cwd: /nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/test.py 8 | 2024-10-29 10:53:49,340 INFO MainThread:296384 [wandb_setup.py:_flush():79] Inferring run 
settings from compute environment: {'program_relpath': None, 'program_abspath': '/nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/test.py', 'program': '/nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/test.py'} 9 | 2024-10-29 10:53:49,340 INFO MainThread:296384 [wandb_setup.py:_flush():79] Applying login settings: {} 10 | 2024-10-29 10:53:49,341 INFO MainThread:296384 [wandb_init.py:_log_setup():534] Logging user logs to /nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/outputs/musdb-mixture-11-44/aeromamba/wandb/run-20241029_105349-0ilrdhbg/logs/debug.log 11 | 2024-10-29 10:53:49,341 INFO MainThread:296384 [wandb_init.py:_log_setup():535] Logging internal logs to /nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/outputs/musdb-mixture-11-44/aeromamba/wandb/run-20241029_105349-0ilrdhbg/logs/debug-internal.log 12 | 2024-10-29 10:53:49,343 INFO MainThread:296384 [wandb_init.py:init():621] calling init triggers 13 | 2024-10-29 10:53:49,343 INFO MainThread:296384 [wandb_init.py:init():628] wandb.init called with sweep_config: {} 14 | config: {'eval_every': 3, 'optim': 'adam', 'lr': 0.0003, 'losses': ['stft'], 'epochs': 696, 'name': 'aeromamba', 'lr_sr': 11025, 'hr_sr': 44100, 'segment': 4, 'stride': 4, 'pad': True, 'upsample': False, 'batch_size': 4, 'nfft': 512, 'hop_length': 256, 'fixed_n_examples': None, 'power_threshold': 0.001, 'model': 'aero', 'aero': {'in_channels': 1, 'out_channels': 1, 'channels': 48, 'growth': 2, 'nfft': '${experiment.nfft}', 'hop_length': '${experiment.hop_length}', 'end_iters': 0, 'cac': True, 'rewrite': True, 'hybrid': False, 'hybrid_old': False, 'freq_emb': 0.2, 'emb_scale': 10, 'emb_smooth': True, 'kernel_size': 8, 'strides': [4, 4, 2, 2], 'context': 1, 'context_enc': 0, 'freq_ends': 4, 'enc_freq_attn': 0, 'norm_starts': 2, 'norm_groups': 4, 'dconv_mode': 1, 'dconv_depth': 2, 'dconv_comp': 4, 'dconv_time_attn': 4, 'dconv_lstm': 4, 'dconv_mamba': 0, 'dconv_init': 0.001, 'rescale': 0.1, 'lr_sr': '${experiment.lr_sr}', 'hr_sr': 
'${experiment.hr_sr}', 'spec_upsample': True, 'act_func': 'snake', 'debug': False}, 'adversarial': True, 'features_loss_lambda': 100, 'only_features_loss': False, 'only_adversarial_loss': False, 'discriminator_models': ['msd_melgan'], 'melgan_discriminator': {'n_layers': 4, 'num_D': 3, 'downsampling_factor': 4, 'ndf': 16}, 'train': '/home/wallace.abreu/Mestrado/aero_vanilla/egs/musdb-mixture-11-44/tr', 'test': '/home/wallace.abreu/Mestrado/aero_vanilla/egs/musdb-mixture-11-44/tt'} 15 | 2024-10-29 10:53:49,344 INFO MainThread:296384 [wandb_init.py:init():671] starting backend 16 | 2024-10-29 10:53:49,344 INFO MainThread:296384 [wandb_init.py:init():675] sending inform_init request 17 | 2024-10-29 10:53:49,349 INFO MainThread:296384 [backend.py:_multiprocessing_setup():104] multiprocessing start_methods=fork,spawn,forkserver, using: spawn 18 | 2024-10-29 10:53:49,350 INFO MainThread:296384 [wandb_init.py:init():688] backend started and connected 19 | 2024-10-29 10:53:49,356 INFO MainThread:296384 [wandb_init.py:init():783] updated telemetry 20 | 2024-10-29 10:53:49,357 INFO MainThread:296384 [wandb_init.py:init():816] communicating run to backend with 90.0 second timeout 21 | 2024-10-29 10:53:50,250 INFO MainThread:296384 [wandb_init.py:init():867] starting run threads in backend 22 | 2024-10-29 10:53:51,869 INFO MainThread:296384 [wandb_run.py:_console_start():2463] atexit reg 23 | 2024-10-29 10:53:51,869 INFO MainThread:296384 [wandb_run.py:_redirect():2311] redirect: wrap_raw 24 | 2024-10-29 10:53:51,869 INFO MainThread:296384 [wandb_run.py:_redirect():2376] Wrapping output streams. 25 | 2024-10-29 10:53:51,872 INFO MainThread:296384 [wandb_run.py:_redirect():2401] Redirects installed. 
26 | 2024-10-29 10:53:51,928 INFO MainThread:296384 [wandb_init.py:init():911] run started, returning control to user process 27 | -------------------------------------------------------------------------------- /outputs/musdb-mixture-11-44/aeromamba/wandb/debug.log: -------------------------------------------------------------------------------- 1 | 2024-10-29 10:54:45,978 INFO MainThread:296628 [wandb_setup.py:_flush():79] Current SDK version is 0.18.5 2 | 2024-10-29 10:54:45,978 INFO MainThread:296628 [wandb_setup.py:_flush():79] Configure stats pid to 296628 3 | 2024-10-29 10:54:45,978 INFO MainThread:296628 [wandb_setup.py:_flush():79] Loading settings from /home/wallace.abreu/.config/wandb/settings 4 | 2024-10-29 10:54:45,979 INFO MainThread:296628 [wandb_setup.py:_flush():79] Loading settings from /nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/outputs/musdb-mixture-11-44/aeromamba/wandb/settings 5 | 2024-10-29 10:54:45,980 INFO MainThread:296628 [wandb_setup.py:_flush():79] Loading settings from environment variables: {} 6 | 2024-10-29 10:54:45,980 INFO MainThread:296628 [wandb_setup.py:_flush():79] Applying setup settings: {'mode': 'online', '_disable_service': None} 7 | 2024-10-29 10:54:45,980 WARNING MainThread:296628 [wandb_setup.py:_flush():79] Could not save program above cwd: /nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/test.py 8 | 2024-10-29 10:54:45,984 INFO MainThread:296628 [wandb_setup.py:_flush():79] Inferring run settings from compute environment: {'program_relpath': None, 'program_abspath': '/nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/test.py', 'program': '/nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/test.py'} 9 | 2024-10-29 10:54:45,984 INFO MainThread:296628 [wandb_setup.py:_flush():79] Applying login settings: {} 10 | 2024-10-29 10:54:45,984 INFO MainThread:296628 [wandb_init.py:_log_setup():534] Logging user logs to 
/nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/outputs/musdb-mixture-11-44/aeromamba/wandb/run-20241029_105445-p0dnbwer/logs/debug.log 11 | 2024-10-29 10:54:45,987 INFO MainThread:296628 [wandb_init.py:_log_setup():535] Logging internal logs to /nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/outputs/musdb-mixture-11-44/aeromamba/wandb/run-20241029_105445-p0dnbwer/logs/debug-internal.log 12 | 2024-10-29 10:54:45,990 INFO MainThread:296628 [wandb_init.py:init():621] calling init triggers 13 | 2024-10-29 10:54:45,990 INFO MainThread:296628 [wandb_init.py:init():628] wandb.init called with sweep_config: {} 14 | config: {'eval_every': 3, 'optim': 'adam', 'lr': 0.0003, 'losses': ['stft'], 'epochs': 696, 'name': 'aeromamba', 'lr_sr': 11025, 'hr_sr': 44100, 'segment': 4, 'stride': 4, 'pad': True, 'upsample': False, 'batch_size': 4, 'nfft': 512, 'hop_length': 256, 'fixed_n_examples': None, 'power_threshold': 0.001, 'model': 'aero', 'aero': {'in_channels': 1, 'out_channels': 1, 'channels': 48, 'growth': 2, 'nfft': '${experiment.nfft}', 'hop_length': '${experiment.hop_length}', 'end_iters': 0, 'cac': True, 'rewrite': True, 'hybrid': False, 'hybrid_old': False, 'freq_emb': 0.2, 'emb_scale': 10, 'emb_smooth': True, 'kernel_size': 8, 'strides': [4, 4, 2, 2], 'context': 1, 'context_enc': 0, 'freq_ends': 4, 'enc_freq_attn': 0, 'norm_starts': 2, 'norm_groups': 4, 'dconv_mode': 1, 'dconv_depth': 2, 'dconv_comp': 4, 'dconv_time_attn': 4, 'dconv_lstm': 4, 'dconv_mamba': 0, 'dconv_init': 0.001, 'rescale': 0.1, 'lr_sr': '${experiment.lr_sr}', 'hr_sr': '${experiment.hr_sr}', 'spec_upsample': True, 'act_func': 'snake', 'debug': False}, 'adversarial': True, 'features_loss_lambda': 100, 'only_features_loss': False, 'only_adversarial_loss': False, 'discriminator_models': ['msd_melgan'], 'melgan_discriminator': {'n_layers': 4, 'num_D': 3, 'downsampling_factor': 4, 'ndf': 16}, 'train': '/home/wallace.abreu/Mestrado/aero_vanilla/egs/musdb-mixture-11-44/tr', 'test': 
'/home/wallace.abreu/Mestrado/aero_vanilla/egs/musdb-mixture-11-44/tt'} 15 | 2024-10-29 10:54:45,993 INFO MainThread:296628 [wandb_init.py:init():671] starting backend 16 | 2024-10-29 10:54:45,993 INFO MainThread:296628 [wandb_init.py:init():675] sending inform_init request 17 | 2024-10-29 10:54:45,998 INFO MainThread:296628 [backend.py:_multiprocessing_setup():104] multiprocessing start_methods=fork,spawn,forkserver, using: spawn 18 | 2024-10-29 10:54:45,999 INFO MainThread:296628 [wandb_init.py:init():688] backend started and connected 19 | 2024-10-29 10:54:46,004 INFO MainThread:296628 [wandb_init.py:init():783] updated telemetry 20 | 2024-10-29 10:54:46,005 INFO MainThread:296628 [wandb_init.py:init():816] communicating run to backend with 90.0 second timeout 21 | 2024-10-29 10:54:46,478 INFO MainThread:296628 [wandb_init.py:init():867] starting run threads in backend 22 | 2024-10-29 10:54:46,830 INFO MainThread:296628 [wandb_run.py:_console_start():2463] atexit reg 23 | 2024-10-29 10:54:46,830 INFO MainThread:296628 [wandb_run.py:_redirect():2311] redirect: wrap_raw 24 | 2024-10-29 10:54:46,831 INFO MainThread:296628 [wandb_run.py:_redirect():2376] Wrapping output streams. 25 | 2024-10-29 10:54:46,831 INFO MainThread:296628 [wandb_run.py:_redirect():2401] Redirects installed. 
26 | 2024-10-29 10:54:46,842 INFO MainThread:296628 [wandb_init.py:init():911] run started, returning control to user process 27 | 2024-10-29 10:55:29,067 WARNING MsgRouterThr:296628 [router.py:message_loop():77] message_loop has been closed 28 | -------------------------------------------------------------------------------- /outputs/musdb-mixture-11-44/aeromamba/.hydra/hydra.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: ./outputs/${dset.name}/${experiment.name} 4 | sweep: 5 | dir: ./outputs/${dset.name}/${experiment.name} 6 | subdir: ${hydra.job.num} 7 | launcher: 8 | _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher 9 | sweeper: 10 | _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper 11 | max_batch_size: null 12 | help: 13 | app_name: ${hydra.job.name} 14 | header: '${hydra.help.app_name} is powered by Hydra. 15 | 16 | ' 17 | footer: 'Powered by Hydra (https://hydra.cc) 18 | 19 | Use --hydra-help to view Hydra specific help 20 | 21 | ' 22 | template: '${hydra.help.header} 23 | 24 | == Configuration groups == 25 | 26 | Compose your configuration from those groups (group=option) 27 | 28 | 29 | $APP_CONFIG_GROUPS 30 | 31 | 32 | == Config == 33 | 34 | Override anything in the config (foo.bar=value) 35 | 36 | 37 | $CONFIG 38 | 39 | 40 | ${hydra.help.footer} 41 | 42 | ' 43 | hydra_help: 44 | template: 'Hydra (${hydra.runtime.version}) 45 | 46 | See https://hydra.cc for more info. 47 | 48 | 49 | == Flags == 50 | 51 | $FLAGS_HELP 52 | 53 | 54 | == Configuration groups == 55 | 56 | Compose your configuration from those groups (For example, append hydra/job_logging=disabled 57 | to command line) 58 | 59 | 60 | $HYDRA_CONFIG_GROUPS 61 | 62 | 63 | Use ''--cfg hydra'' to Show the Hydra config. 64 | 65 | ' 66 | hydra_help: ??? 
67 | hydra_logging: 68 | version: 1 69 | formatters: 70 | colorlog: 71 | (): colorlog.ColoredFormatter 72 | format: '[%(cyan)s%(asctime)s%(reset)s][%(purple)sHYDRA%(reset)s] %(message)s' 73 | handlers: 74 | console: 75 | class: logging.StreamHandler 76 | formatter: colorlog 77 | stream: ext://sys.stderr 78 | root: 79 | level: INFO 80 | handlers: 81 | - console 82 | disable_existing_loggers: false 83 | job_logging: 84 | version: 1 85 | formatters: 86 | simple: 87 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' 88 | colorlog: 89 | (): colorlog.ColoredFormatter 90 | format: '[%(cyan)s%(asctime)s%(reset)s][%(blue)s%(name)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] 91 | - %(message)s' 92 | log_colors: 93 | DEBUG: purple 94 | INFO: green 95 | WARNING: yellow 96 | ERROR: red 97 | CRITICAL: red 98 | handlers: 99 | console: 100 | class: logging.StreamHandler 101 | formatter: colorlog 102 | stream: ext://sys.stderr 103 | file: 104 | class: logging.FileHandler 105 | formatter: colorlog 106 | filename: trainer.log 107 | mode: w 108 | root: 109 | level: INFO 110 | handlers: 111 | - console 112 | - file 113 | disable_existing_loggers: false 114 | env: {} 115 | searchpath: [] 116 | callbacks: {} 117 | output_subdir: .hydra 118 | overrides: 119 | hydra: [] 120 | task: 121 | - checkpoint_file=/home/wallace.abreu/Mestrado/aeromamba-lamir/outputs/chopin-11-44-one/aeromamba/checkpoint.th 122 | job: 123 | name: test 124 | override_dirname: checkpoint_file=/home/wallace.abreu/Mestrado/aeromamba-lamir/outputs/chopin-11-44-one/aeromamba/checkpoint.th 125 | id: ??? 126 | num: ??? 
127 | config_name: main_config 128 | env_set: {} 129 | env_copy: [] 130 | config: 131 | override_dirname: 132 | kv_sep: '=' 133 | item_sep: ',' 134 | exclude_keys: 135 | - hydra.job_logging.handles.file.filename 136 | - dset.train 137 | - dset.valid 138 | - dset.test 139 | - num_prints 140 | - continue_from 141 | - device 142 | - num_workers 143 | - print_freq 144 | - restart 145 | - verbose 146 | - log 147 | runtime: 148 | version: 1.1.1 149 | cwd: /nfs/home/wallace.abreu/Mestrado/aeromamba-lamir 150 | config_sources: 151 | - path: hydra.conf 152 | schema: pkg 153 | provider: hydra 154 | - path: /nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/conf 155 | schema: file 156 | provider: main 157 | - path: hydra_plugins.hydra_colorlog.conf 158 | schema: pkg 159 | provider: hydra-colorlog 160 | - path: '' 161 | schema: structured 162 | provider: schema 163 | choices: 164 | dset: musdb-mixture-11-44 165 | experiment: aeromamba 166 | hydra/env: default 167 | hydra/callbacks: null 168 | hydra/job_logging: colorlog 169 | hydra/hydra_logging: colorlog 170 | hydra/hydra_help: default 171 | hydra/help: default 172 | hydra/sweeper: basic 173 | hydra/launcher: basic 174 | hydra/output: default 175 | verbose: false 176 | -------------------------------------------------------------------------------- /outputs/musdb-mixture-11-44/aeromamba/wandb/run-20241029_105445-p0dnbwer/logs/debug.log: -------------------------------------------------------------------------------- 1 | 2024-10-29 10:54:45,978 INFO MainThread:296628 [wandb_setup.py:_flush():79] Current SDK version is 0.18.5 2 | 2024-10-29 10:54:45,978 INFO MainThread:296628 [wandb_setup.py:_flush():79] Configure stats pid to 296628 3 | 2024-10-29 10:54:45,978 INFO MainThread:296628 [wandb_setup.py:_flush():79] Loading settings from /home/wallace.abreu/.config/wandb/settings 4 | 2024-10-29 10:54:45,979 INFO MainThread:296628 [wandb_setup.py:_flush():79] Loading settings from 
/nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/outputs/musdb-mixture-11-44/aeromamba/wandb/settings 5 | 2024-10-29 10:54:45,980 INFO MainThread:296628 [wandb_setup.py:_flush():79] Loading settings from environment variables: {} 6 | 2024-10-29 10:54:45,980 INFO MainThread:296628 [wandb_setup.py:_flush():79] Applying setup settings: {'mode': 'online', '_disable_service': None} 7 | 2024-10-29 10:54:45,980 WARNING MainThread:296628 [wandb_setup.py:_flush():79] Could not save program above cwd: /nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/test.py 8 | 2024-10-29 10:54:45,984 INFO MainThread:296628 [wandb_setup.py:_flush():79] Inferring run settings from compute environment: {'program_relpath': None, 'program_abspath': '/nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/test.py', 'program': '/nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/test.py'} 9 | 2024-10-29 10:54:45,984 INFO MainThread:296628 [wandb_setup.py:_flush():79] Applying login settings: {} 10 | 2024-10-29 10:54:45,984 INFO MainThread:296628 [wandb_init.py:_log_setup():534] Logging user logs to /nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/outputs/musdb-mixture-11-44/aeromamba/wandb/run-20241029_105445-p0dnbwer/logs/debug.log 11 | 2024-10-29 10:54:45,987 INFO MainThread:296628 [wandb_init.py:_log_setup():535] Logging internal logs to /nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/outputs/musdb-mixture-11-44/aeromamba/wandb/run-20241029_105445-p0dnbwer/logs/debug-internal.log 12 | 2024-10-29 10:54:45,990 INFO MainThread:296628 [wandb_init.py:init():621] calling init triggers 13 | 2024-10-29 10:54:45,990 INFO MainThread:296628 [wandb_init.py:init():628] wandb.init called with sweep_config: {} 14 | config: {'eval_every': 3, 'optim': 'adam', 'lr': 0.0003, 'losses': ['stft'], 'epochs': 696, 'name': 'aeromamba', 'lr_sr': 11025, 'hr_sr': 44100, 'segment': 4, 'stride': 4, 'pad': True, 'upsample': False, 'batch_size': 4, 'nfft': 512, 'hop_length': 256, 'fixed_n_examples': None, 'power_threshold': 0.001, 
'model': 'aero', 'aero': {'in_channels': 1, 'out_channels': 1, 'channels': 48, 'growth': 2, 'nfft': '${experiment.nfft}', 'hop_length': '${experiment.hop_length}', 'end_iters': 0, 'cac': True, 'rewrite': True, 'hybrid': False, 'hybrid_old': False, 'freq_emb': 0.2, 'emb_scale': 10, 'emb_smooth': True, 'kernel_size': 8, 'strides': [4, 4, 2, 2], 'context': 1, 'context_enc': 0, 'freq_ends': 4, 'enc_freq_attn': 0, 'norm_starts': 2, 'norm_groups': 4, 'dconv_mode': 1, 'dconv_depth': 2, 'dconv_comp': 4, 'dconv_time_attn': 4, 'dconv_lstm': 4, 'dconv_mamba': 0, 'dconv_init': 0.001, 'rescale': 0.1, 'lr_sr': '${experiment.lr_sr}', 'hr_sr': '${experiment.hr_sr}', 'spec_upsample': True, 'act_func': 'snake', 'debug': False}, 'adversarial': True, 'features_loss_lambda': 100, 'only_features_loss': False, 'only_adversarial_loss': False, 'discriminator_models': ['msd_melgan'], 'melgan_discriminator': {'n_layers': 4, 'num_D': 3, 'downsampling_factor': 4, 'ndf': 16}, 'train': '/home/wallace.abreu/Mestrado/aero_vanilla/egs/musdb-mixture-11-44/tr', 'test': '/home/wallace.abreu/Mestrado/aero_vanilla/egs/musdb-mixture-11-44/tt'} 15 | 2024-10-29 10:54:45,993 INFO MainThread:296628 [wandb_init.py:init():671] starting backend 16 | 2024-10-29 10:54:45,993 INFO MainThread:296628 [wandb_init.py:init():675] sending inform_init request 17 | 2024-10-29 10:54:45,998 INFO MainThread:296628 [backend.py:_multiprocessing_setup():104] multiprocessing start_methods=fork,spawn,forkserver, using: spawn 18 | 2024-10-29 10:54:45,999 INFO MainThread:296628 [wandb_init.py:init():688] backend started and connected 19 | 2024-10-29 10:54:46,004 INFO MainThread:296628 [wandb_init.py:init():783] updated telemetry 20 | 2024-10-29 10:54:46,005 INFO MainThread:296628 [wandb_init.py:init():816] communicating run to backend with 90.0 second timeout 21 | 2024-10-29 10:54:46,478 INFO MainThread:296628 [wandb_init.py:init():867] starting run threads in backend 22 | 2024-10-29 10:54:46,830 INFO MainThread:296628 
[wandb_run.py:_console_start():2463] atexit reg 23 | 2024-10-29 10:54:46,830 INFO MainThread:296628 [wandb_run.py:_redirect():2311] redirect: wrap_raw 24 | 2024-10-29 10:54:46,831 INFO MainThread:296628 [wandb_run.py:_redirect():2376] Wrapping output streams. 25 | 2024-10-29 10:54:46,831 INFO MainThread:296628 [wandb_run.py:_redirect():2401] Redirects installed. 26 | 2024-10-29 10:54:46,842 INFO MainThread:296628 [wandb_init.py:init():911] run started, returning control to user process 27 | 2024-10-29 10:55:29,067 WARNING MsgRouterThr:296628 [router.py:message_loop():77] message_loop has been closed 28 | -------------------------------------------------------------------------------- /outputs/chopin-11-44/aeromamba/.hydra/hydra.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: ./outputs/${dset.name}/${experiment.name} 4 | sweep: 5 | dir: ./outputs/${dset.name}/${experiment.name} 6 | subdir: ${hydra.job.num} 7 | launcher: 8 | _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher 9 | sweeper: 10 | _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper 11 | max_batch_size: null 12 | help: 13 | app_name: ${hydra.job.name} 14 | header: '${hydra.help.app_name} is powered by Hydra. 15 | 16 | ' 17 | footer: 'Powered by Hydra (https://hydra.cc) 18 | 19 | Use --hydra-help to view Hydra specific help 20 | 21 | ' 22 | template: '${hydra.help.header} 23 | 24 | == Configuration groups == 25 | 26 | Compose your configuration from those groups (group=option) 27 | 28 | 29 | $APP_CONFIG_GROUPS 30 | 31 | 32 | == Config == 33 | 34 | Override anything in the config (foo.bar=value) 35 | 36 | 37 | $CONFIG 38 | 39 | 40 | ${hydra.help.footer} 41 | 42 | ' 43 | hydra_help: 44 | template: 'Hydra (${hydra.runtime.version}) 45 | 46 | See https://hydra.cc for more info. 
47 | 48 | 49 | == Flags == 50 | 51 | $FLAGS_HELP 52 | 53 | 54 | == Configuration groups == 55 | 56 | Compose your configuration from those groups (For example, append hydra/job_logging=disabled 57 | to command line) 58 | 59 | 60 | $HYDRA_CONFIG_GROUPS 61 | 62 | 63 | Use ''--cfg hydra'' to Show the Hydra config. 64 | 65 | ' 66 | hydra_help: ??? 67 | hydra_logging: 68 | version: 1 69 | formatters: 70 | colorlog: 71 | (): colorlog.ColoredFormatter 72 | format: '[%(cyan)s%(asctime)s%(reset)s][%(purple)sHYDRA%(reset)s] %(message)s' 73 | handlers: 74 | console: 75 | class: logging.StreamHandler 76 | formatter: colorlog 77 | stream: ext://sys.stderr 78 | root: 79 | level: INFO 80 | handlers: 81 | - console 82 | disable_existing_loggers: false 83 | job_logging: 84 | version: 1 85 | formatters: 86 | simple: 87 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' 88 | colorlog: 89 | (): colorlog.ColoredFormatter 90 | format: '[%(cyan)s%(asctime)s%(reset)s][%(blue)s%(name)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] 91 | - %(message)s' 92 | log_colors: 93 | DEBUG: purple 94 | INFO: green 95 | WARNING: yellow 96 | ERROR: red 97 | CRITICAL: red 98 | handlers: 99 | console: 100 | class: logging.StreamHandler 101 | formatter: colorlog 102 | stream: ext://sys.stderr 103 | file: 104 | class: logging.FileHandler 105 | formatter: colorlog 106 | filename: trainer.log 107 | mode: w 108 | root: 109 | level: INFO 110 | handlers: 111 | - console 112 | - file 113 | disable_existing_loggers: false 114 | env: {} 115 | searchpath: [] 116 | callbacks: {} 117 | output_subdir: .hydra 118 | overrides: 119 | hydra: [] 120 | task: 121 | - dset=chopin-11-44 122 | - experiment=aeromamba 123 | - +filename=/home/wallace.abreu/Mestrado/aeromamba-lamir/ds_datasets/chopin/test/V.ASHKENAZYTrack11.wav 124 | - +output=/home/wallace.abreu/Mestrado/aeromamba-lamir/test 125 | job: 126 | name: predict_with_ola 127 | override_dirname: 
+filename=/home/wallace.abreu/Mestrado/aeromamba-lamir/ds_datasets/chopin/test/V.ASHKENAZYTrack11.wav,+output=/home/wallace.abreu/Mestrado/aeromamba-lamir/test,dset=chopin-11-44,experiment=aeromamba 128 | id: ??? 129 | num: ??? 130 | config_name: main_config 131 | env_set: {} 132 | env_copy: [] 133 | config: 134 | override_dirname: 135 | kv_sep: '=' 136 | item_sep: ',' 137 | exclude_keys: 138 | - hydra.job_logging.handles.file.filename 139 | - dset.train 140 | - dset.valid 141 | - dset.test 142 | - num_prints 143 | - continue_from 144 | - device 145 | - num_workers 146 | - print_freq 147 | - restart 148 | - verbose 149 | - log 150 | runtime: 151 | version: 1.1.1 152 | cwd: /nfs/home/wallace.abreu/Mestrado/aeromamba-lamir 153 | config_sources: 154 | - path: hydra.conf 155 | schema: pkg 156 | provider: hydra 157 | - path: /nfs/home/wallace.abreu/Mestrado/aeromamba-lamir/conf 158 | schema: file 159 | provider: main 160 | - path: hydra_plugins.hydra_colorlog.conf 161 | schema: pkg 162 | provider: hydra-colorlog 163 | - path: '' 164 | schema: structured 165 | provider: schema 166 | choices: 167 | dset: chopin-11-44 168 | experiment: aeromamba 169 | hydra/env: default 170 | hydra/callbacks: null 171 | hydra/job_logging: colorlog 172 | hydra/hydra_logging: colorlog 173 | hydra/hydra_help: default 174 | hydra/help: default 175 | hydra/sweeper: basic 176 | hydra/launcher: basic 177 | hydra/output: default 178 | verbose: false 179 | -------------------------------------------------------------------------------- /predict_batch_with_ola.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | 4 | import time 5 | 6 | import hydra 7 | import torch 8 | import logging 9 | from pathlib import Path 10 | import numpy as np 11 | import torchaudio 12 | from torchaudio.functional import resample 13 | 14 | from src.enhance import write 15 | from src.models import modelFactory 16 | from src.model_serializer import 
logger = logging.getLogger(__name__)


def overlap_and_add(chunks, overlap=256, window_len=1024):
    """Stitch consecutive chunks into one signal with a Bartlett cross-fade.

    Each chunk (any array-like with ``.reshape``) is flattened, tapered with a
    Bartlett-windowed fade on the sides that overlap its neighbours, and summed
    into the running output.  The first chunk fades only on its right edge,
    the last only on its left; a short final chunk is zero-padded to
    `window_len` first.  Returns a 1-D numpy array.
    """
    W = window_len
    fade = np.bartlett(2 * overlap)
    fade_in, fade_out = fade[:overlap], fade[overlap:]
    # full window: fade in, flat middle, fade out
    full_win = np.concatenate((fade_in, np.ones(W - 2 * overlap), fade_out))
    # boundary windows: flat on the outer edge, faded on the inner one
    first_win = np.concatenate((np.ones(W - overlap), fade_out))
    last_win = np.concatenate((fade_in, np.ones(W - overlap)))

    total = len(chunks)
    for idx, chunk in enumerate(chunks):
        flat = chunk.reshape(-1,)
        if idx == 0:
            out = flat * first_win
            continue
        if len(flat) < W or idx == total - 1:
            # final (possibly short) chunk: pad to W, taper left edge only
            flat = np.pad(flat, (0, W - len(flat)), 'constant', constant_values=0)
            faded = flat * last_win
        else:
            faded = flat * full_win
        # grow the output by the hop (W - overlap), then align the new chunk
        # so its left fade lands on the existing right fade before summing
        out = np.pad(out, (0, W - overlap), 'constant', constant_values=0)
        faded = np.pad(faded, (len(out) - len(faded), 0), 'constant', constant_values=0)
        out = out + faded
    return out


SEGMENT_DURATION_SEC = 1


def _load_model(args):
    """Build the generator via the model factory and restore its weights.

    Loads `args.checkpoint_file` on CPU and picks the best or last saved state
    depending on `args.continue_best`.
    """
    model_name = args.experiment.model
    checkpoint_file = Path(args.checkpoint_file)
    model = modelFactory.get_model(args)['generator']
    package = torch.load(checkpoint_file, 'cpu')
    if args.continue_best:
        logger.info(bold(f'Loading model {model_name} from best state.'))
        state = package[SERIALIZE_KEY_BEST_STATES][SERIALIZE_KEY_MODELS]['generator'][SERIALIZE_KEY_STATE]
    else:
        logger.info(bold(f'Loading model {model_name} from last state.'))
        state = package[SERIALIZE_KEY_MODELS]['generator'][SERIALIZE_KEY_STATE]
    model.load_state_dict(state)

    return model
version of hydra=1.0 66 | def main(args): 67 | global __file__ 68 | __file__ = hydra.utils.to_absolute_path(__file__) 69 | 70 | print(args) 71 | model = _load_model(args) 72 | device = torch.device('cuda') 73 | model.cuda() 74 | folder_path = args.folder_path 75 | for files in os.listdir(folder_path): 76 | filename = os.path.join(folder_path, files) 77 | file_basename = Path(filename).stem 78 | output_dir = args.output 79 | lr_sig, sr = torchaudio.load(str(filename)) 80 | if lr_sig.shape[1] > 1: 81 | lr_sig = torch.mean(lr_sig, dim=0, keepdim=True) 82 | if args.experiment.upsample: 83 | lr_sig = resample(lr_sig, sr, args.experiment.hr_sr) 84 | sr = args.experiment.hr_sr 85 | 86 | logger.info(f'lr wav shape: {lr_sig.shape}') 87 | 88 | segment_duration_samples = sr * SEGMENT_DURATION_SEC 89 | W_hr = 44095 # 44100 samples minus the edge effect samples 90 | W_lr = 11025 91 | overlap_hr = 900 #heuristic value 92 | overlap_lr = overlap_hr // 4 93 | n_chunks = math.ceil(lr_sig.shape[-1] / (W_lr - overlap_lr)) 94 | logger.info(f'number of chunks: {n_chunks}') 95 | 96 | 97 | lr_chunks = [] 98 | for i in range(n_chunks): 99 | start = i * (W_lr - overlap_lr) 100 | end = min(start + W_lr, lr_sig.shape[-1]) 101 | lr_chunks.append(lr_sig[:, start:end]) 102 | pr_chunks = [] 103 | 104 | model.eval() 105 | pred_start = time.time() 106 | 107 | with torch.no_grad(): 108 | for i, lr_chunk in enumerate(lr_chunks): 109 | pr_chunk = model(lr_chunk.unsqueeze(0).to(device)).squeeze(0) 110 | #remove edge effect samples (only the 4 final samples are distorted) 111 | pr_chunk = pr_chunk[:, :-5] 112 | pr_chunks.append(pr_chunk.cpu()) 113 | 114 | pred_duration = time.time() - pred_start 115 | logger.info(f'prediction duration: {pred_duration}') 116 | 117 | pr_ola = overlap_and_add(pr_chunks, overlap=overlap_hr, window_len=W_hr) 118 | logger.info(f'pr wav shape: {pr_ola.shape}') 119 | 120 | out_filename_ola = os.path.join(output_dir, file_basename + '.wav') 121 | os.makedirs(output_dir, 
exist_ok=True) 122 | 123 | logger.info(f'saving to: {out_filename_ola}, with sample_rate: {args.experiment.hr_sr}') 124 | 125 | sf.write(out_filename_ola, pr_ola, args.experiment.hr_sr) 126 | 127 | 128 | """ 129 | Need to add filename and output to args. 130 | Usage: python predict.py +folder_path= +output= 131 | """ 132 | if __name__ == "__main__": 133 | main() -------------------------------------------------------------------------------- /src/metrics.py: -------------------------------------------------------------------------------- 1 | import errno 2 | import os 3 | import subprocess 4 | import logging 5 | import time 6 | 7 | import numpy as np 8 | import sox 9 | import torch 10 | import torch.nn as nn 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | SLEEP_DURATION = 0.1 15 | VISQOL_MIN_DURATION = 0.48 16 | 17 | def run_metrics(clean, estimate, args, filename): 18 | hr_sr = args.experiment.hr_sr if 'experiment' in args else args.hr_sr 19 | speech_mode = args.experiment.speech_mode if 'speech_mode' in args.experiment else True 20 | lsd, visqol = get_metrics(clean, estimate, hr_sr, filename, speech_mode,args) 21 | return lsd, visqol 22 | 23 | 24 | def get_metrics(clean, estimate, sr, filename, speech_mode, args): 25 | calc_visqol = args.visqol and args.visqol_path 26 | visqol_path = args.visqol_path 27 | clean = clean.squeeze(dim=1) 28 | estimate = estimate.squeeze(dim=1) 29 | estimate_numpy = estimate.numpy() 30 | clean_numpy = clean.numpy() 31 | 32 | lsd = get_lsd(clean, estimate).item() 33 | visqol = get_visqol(clean_numpy, estimate_numpy, filename, sr, speech_mode, visqol_path) if calc_visqol else 0 34 | return lsd, visqol 35 | 36 | 37 | class STFTMag(nn.Module): 38 | def __init__(self, 39 | nfft=1024, 40 | hop=256): 41 | super().__init__() 42 | self.nfft = nfft 43 | self.hop = hop 44 | self.register_buffer('window', torch.hann_window(nfft), False) 45 | 46 | # x: [B,T] or [T] 47 | @torch.no_grad() 48 | def forward(self, x): 49 | T = x.shape[-1] 50 | 
# based on: https://github.com/eagomez2/upf-smc-speech-enhancement-thesis/blob/main/src/utils/evaluation_process.py
def get_visqol(ref_sig, out_sig, filename, sr, speech_mode, visqol_path):
    """Compute the ViSQOL score of ``out_sig`` against ``ref_sig``.

    Both signals are written as 16-bit wav files (resampled to the rate the
    ViSQOL binary expects: 16 kHz in speech mode, 48 kHz otherwise), then the
    bazel-built ``visqol`` executable is invoked and its stdout parsed.

    Args:
        ref_sig: reference signal as a numpy array [B, T].
        out_sig: estimated signal as a numpy array [B, T].
        filename: base name for the temporary wav files and log messages.
        sr: sample rate of both input signals.
        speech_mode: if True, run ViSQOL with ``--use_speech_mode``.
        visqol_path: directory containing the compiled visqol binary.

    Returns:
        ViSQOL score as a float, or 0 on any failure (callers treat 0 as
        "skip this file in the average").
    """
    # Fix: embed `filename` in the temp names so evaluations of different
    # files (or concurrent workers) do not overwrite each other's wavs; the
    # parameter was previously unused when building these paths.
    tmp_reference = f"{filename}_ref.wav"
    tmp_estimation = f"{filename}_est.wav"

    reference_abs_path = os.path.abspath(tmp_reference)
    estimation_abs_path = os.path.abspath(tmp_estimation)

    # ViSQOL accepts only 16 kHz (speech) or 48 kHz (audio); None tells sox
    # "already at the target rate, no resampling needed".
    if speech_mode:
        target_sr = 16000 if sr != 16000 else None
    else:
        target_sr = 48000 if sr != 48000 else None

    tfm = sox.Transformer()
    tfm.convert(bitdepth=16, samplerate=target_sr)
    # sox expects (T, channels) layout
    ref_sig = np.transpose(ref_sig)
    out_sig = np.transpose(out_sig)

    try:
        tfm.build_file(input_array=ref_sig, sample_rate_in=sr, output_filepath=reference_abs_path)
        tfm.build_file(input_array=out_sig, sample_rate_in=sr, output_filepath=estimation_abs_path)
        # Fix: wait until BOTH files exist. The original used `and`, which
        # stopped polling as soon as either file appeared.
        while not os.path.exists(reference_abs_path) or not os.path.exists(estimation_abs_path):
            time.sleep(SLEEP_DURATION)

        if not os.path.isfile(reference_abs_path):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), reference_abs_path)
        if not os.path.isfile(estimation_abs_path):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), estimation_abs_path)

        ref_duration = sox.file_info.duration(reference_abs_path)
        est_duration = sox.file_info.duration(estimation_abs_path)

        # duration() returns None for unreadable files; treat that as too short
        # instead of raising a TypeError on the comparison below.
        if (ref_duration is None or est_duration is None
                or ref_duration < VISQOL_MIN_DURATION or est_duration < VISQOL_MIN_DURATION):
            raise ValueError('File duration is too small.')

        visqol_cmd = ("cd " + visqol_path + "; " +
                      "./bazel-bin/visqol "
                      f"--reference_file {reference_abs_path} "
                      f"--degraded_file {estimation_abs_path} ")

        if speech_mode:
            visqol_cmd += "--use_speech_mode"

        visqol = subprocess.run(visqol_cmd, shell=True,
                                stdout=subprocess.PIPE, stderr=subprocess.STDOUT)

        # parse stdout to get the current float value: the score is the last
        # tab-separated token on the binary's output
        visqol = visqol.stdout.decode("utf-8").split("\t")[-1].replace("\n", "")
        visqol = float(visqol)

    except FileNotFoundError as e:
        logger.info(f'visqol: failed to create {filename}')
        logger.info(str(e))
        visqol = 0

    except Exception as e:
        logger.info(f'failed to get visqol of {filename}')
        logger.info(str(e))
        visqol = 0

    else:
        # remove files to avoid filling space storage
        os.remove(reference_abs_path)
        os.remove(estimation_abs_path)

    return visqol
class SpectralConvergengeLoss(torch.nn.Module):
    """Spectral convergence loss: ||Y - X||_F / ||Y||_F.

    The class name keeps the original (misspelled) spelling because other
    modules in this file refer to it.
    """

    def __init__(self):
        """Set up the (stateless) loss module."""
        super().__init__()

    def forward(self, x_mag, y_mag):
        """Compute the spectral convergence between two magnitude spectrograms.

        Args:
            x_mag (Tensor): magnitude spectrogram of the predicted signal
                (B, #frames, #freq_bins).
            y_mag (Tensor): magnitude spectrogram of the groundtruth signal
                (B, #frames, #freq_bins).

        Returns:
            Tensor: scalar spectral convergence loss value.
        """
        residual_energy = torch.norm(y_mag - x_mag, p="fro")
        reference_energy = torch.norm(y_mag, p="fro")
        return residual_energy / reference_energy
class STFTLoss(torch.nn.Module):
    """Single-resolution STFT loss (spectral convergence + log magnitude)."""

    def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window"):
        """Build the loss for one STFT configuration.

        Args:
            fft_size (int): FFT size of the underlying STFT.
            shift_size (int): hop size between frames.
            win_length (int): analysis window length.
            window (str): name of a torch window factory (e.g. "hann_window").
        """
        super().__init__()
        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length
        # Registered as a buffer (not a parameter): it follows .to(device)
        # but is never trained.
        self.register_buffer("window", getattr(torch, window)(win_length))
        self.spectral_convergenge_loss = SpectralConvergengeLoss()
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()

    def forward(self, x, y):
        """Return both loss terms for a predicted/groundtruth pair.

        Args:
            x (Tensor): predicted signal (B, T).
            y (Tensor): groundtruth signal (B, T).

        Returns:
            Tuple of (spectral convergence loss, log STFT magnitude loss).
        """
        x_mag, y_mag = (
            stft(sig, self.fft_size, self.shift_size, self.win_length, self.window)
            for sig in (x, y)
        )
        sc = self.spectral_convergenge_loss(x_mag, y_mag)
        mag = self.log_stft_magnitude_loss(x_mag, y_mag)
        return sc, mag
111 | """ 112 | super(MultiResolutionSTFTLoss, self).__init__() 113 | assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) 114 | self.stft_losses = torch.nn.ModuleList() 115 | for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths): 116 | self.stft_losses += [STFTLoss(fs, ss, wl, window)] 117 | self.factor_sc = factor_sc 118 | self.factor_mag = factor_mag 119 | 120 | def forward(self, x, y): 121 | """Calculate forward propagation. 122 | Args: 123 | x (Tensor): Predicted signal (B, T). 124 | y (Tensor): Groundtruth signal (B, T). 125 | Returns: 126 | Tensor: Multi resolution spectral convergence loss value. 127 | Tensor: Multi resolution log STFT magnitude loss value. 128 | """ 129 | sc_loss = 0.0 130 | mag_loss = 0.0 131 | for f in self.stft_losses: 132 | sc_l, mag_l = f(x, y) 133 | sc_loss += sc_l 134 | mag_loss += mag_l 135 | sc_loss /= len(self.stft_losses) 136 | mag_loss /= len(self.stft_losses) 137 | 138 | return self.factor_sc*sc_loss, self.factor_mag*mag_loss 139 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AEROMamba 2 | 3 | ## About 4 | 5 | Official PyTorch implementation of 6 | 7 | **AEROMamba: An efficient architecture for audio super-resolution using generative adversarial networks and state space models** 8 | 9 | whose demo is available in our [Webpage](https://aeromamba-super-resolution.github.io/). Our model is closely related to [AERO](https://github.com/slp-rl/aero) and [Mamba](https://github.com/state-spaces/mamba), so make sure to check them out if any questions arise regarding these modules. 
10 | 11 | ## Installation 12 | 13 | Requirements: 14 | - Python 3.10.0 15 | - Pytorch 1.12.1 16 | - CUDA 11.3 17 | 18 | Instructions: 19 | - Create a conda environment `conda create -n <env_name> python=3.10` or venv with python==3.10.0 20 | - Run `pip install -r requirements.txt` 21 | 22 | If there is any error in the previous step, make sure to install manually the required libs. For PyTorch/CUDA and Mamba, manual installation is done through 23 | 24 | - `CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE CAUSAL_CONV1D_FORCE_CXX11_ABI=TRUE pip install causal_conv1d==1.1.2.post1` 25 | - `CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE CAUSAL_CONV1D_FORCE_CXX11_ABI=TRUE pip install mamba-ssm==1.1.3.post1` 26 | - `conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch` 27 | 28 | Also, make sure to unzip the contents of [Mamba](https://github.com/state-spaces/mamba/archive/refs/tags/v1.1.3.post1.zip) (the mamba folder) inside aeromamba/src/models/ . 29 | 30 | ### ViSQOL 31 | 32 | We did not use ViSQOL for training and validation, but if you want to, see [AERO](https://github.com/slp-rl/aero) for instructions. 33 | 34 | ## Datasets 35 | 36 | ### Download data 37 | 38 | For popular music we use the mixture tracks of [MUSDB18-HQ](https://sigsep.github.io/datasets/musdb.html#musdb18-hq-uncompressed-wav) dataset. 39 | 40 | For piano music, we collected a private dataset from CDs whose metadata are described in our [Webpage](https://aeromamba-super-resolution.github.io/). 41 | 42 | ### Resample data 43 | 44 | Data are a collection of high/low resolution pairs. Corresponding high and low resolution signals should be in different folders, e.g.: hr_dataset and lr_dataset. 45 | 46 | In order to create each folder, one should run `resample_data` a total of 5 times, 47 | to include all source/target pairs. 48 | 49 | We downsample once to a target 11.025 kHz, from the original 44.1 kHz. 50 | 51 | e.g.
for 11.025 and 44.1 kHz: \ 52 | `python data_prep/resample_data.py --data_dir --out_dir --target_sr 11025` 53 | 54 | ### Create egs files 55 | 56 | For each low and high resolution pair, one should create "egs files" twice: for low and high resolution. 57 | `create_meta_files.py` creates a pair of train and val "egs files", each under its respective folder. 58 | Each "egs file" contains meta information about the signals: paths and signal lengths. 59 | 60 | `python data_prep/create_meta_files.py egs/musdb/ lr` 61 | 62 | `python data_prep/create_meta_files.py egs/musdb/ hr` 63 | 64 | ## Train 65 | 66 | Run `train.py` with `dset` and `experiment` parameters, or set the default values in main_config.yaml file. 67 | 68 | ` 69 | python train.py dset= experiment= 70 | ` 71 | 72 | To train with multiple GPUs, run with parameter `ddp=true`. e.g. 73 | ` 74 | python train.py dset= experiment= ddp=true 75 | ` 76 | 77 | ## Test (on whole dataset) 78 | 79 | ` 80 | python test.py dset= experiment= 81 | ` 82 | 83 | ## Inference 84 | 85 | ### Single sample 86 | 87 | ` 88 | python predict.py dset= experiment= +filename= +output= 89 | ` 90 | 91 | ### Multiple samples 92 | 93 | ` 94 | bash predict_batch.sh 95 | ` 96 | 97 | We also provide predict_with_ola.py to predict large files that do not fit in the GPU, without the need for segmentation, using Overlap-and-Add. The original predict.py is also capable of joining predicted segments, but its naïve method causes clicks. 98 | 99 | ` 100 | python predict_with_ola.py dset= experiment= +folder_path= +output= 101 | ` 102 | ### Checkpoints 103 | 104 | To use pre-trained models for MUSDB18-HQ or PianoEval data, one can download checkpoints from [here](https://poliufrjbr-my.sharepoint.com/:f:/g/personal/abreu_engcb_poli_ufrj_br/EhqOtFGTmeZNr-WNv976Jw8BLfpgBYisodrRb2uTGvrFsg?e=5j1nx4). 105 | 106 | To link to checkpoint when testing or predicting, override/set path under `checkpoint_file:` in `conf/main_config.yaml.` e.g. 
def run(args):
    """Set up (optionally distributed) training and run the full train loop.

    Builds datasets/loaders, models and optimizers (generator + optional
    discriminators) and hands everything to the Solver.

    Torch and project modules are imported inside the function so that DDP
    worker processes can configure their environment before CUDA is touched.
    """
    import torch

    from src.ddp import distrib
    from src.data.datasets import LrHrSet
    from src.solver import Solver
    logger.info('calling distrib.init')
    distrib.init(args)
    torch.autograd.set_detect_anomaly(True)

    _init_wandb_run(args)

    # Only the main process creates the samples directory.
    if distrib.rank == 0:
        if not os.path.exists(args.samples_dir):
            os.makedirs(args.samples_dir)

    # torch also initialize cuda seed if available
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    torch.backends.cudnn.benchmark = True
    models = modelFactory.get_model(args)
    for model_name, model in models.items():
        print_network(model_name, model, logger)
    wandb.watch(tuple(models.values()), log=args.wandb.log, log_freq=args.wandb.log_freq)

    if args.show:
        logger.info(models)
        # Fix: `models` is a dict of modules, not a single nn.Module, so it
        # has no .parameters(); sum the parameter counts of every model.
        n_params = sum(p.numel() for model in models.values() for p in model.parameters())
        mb = n_params * 4 / 2 ** 20
        logger.info('Size: %.1f MB', mb)
        return

    # Split the global batch size evenly across DDP workers.
    assert args.experiment.batch_size % distrib.world_size == 0
    args.experiment.batch_size //= distrib.world_size

    # Building datasets and loaders
    tr_dataset = LrHrSet(args.dset.train, args.experiment.lr_sr, args.experiment.hr_sr,
                         args.experiment.stride, args.experiment.segment,
                         upsample=args.experiment.upsample,
                         fixed_n_examples=args.experiment.fixed_n_examples)

    # Filter items based on the threshold (for silent and low-power segments).
    # Threshold value can be tuned.
    if args.experiment.power_threshold > 0:
        filtered_lr_set = []
        filtered_hr_set = []
        for i in range(len(tr_dataset.lr_set)):
            lr_signal = tr_dataset.lr_set[i]
            hr_signal = tr_dataset.hr_set[i]
            if hr_signal is not None:
                hr_power = torch.square(hr_signal).sum() / args.experiment.hr_sr
                if hr_power >= args.experiment.power_threshold:
                    filtered_lr_set.append(lr_signal)
                    filtered_hr_set.append(hr_signal)

        # Replace lr_set and hr_set with filtered lists
        tr_dataset.lr_set = filtered_lr_set
        tr_dataset.hr_set = filtered_hr_set

    tr_loader = distrib.loader(tr_dataset, batch_size=args.experiment.batch_size, shuffle=True,
                               num_workers=args.num_workers)

    # Validation and Test batch size, segments and strides can be set
    # differently by the user depending on GPU resources.
    # (The two original `if args.dset.valid:` blocks are merged here.)
    if args.dset.valid:
        args.valid_equals_test = args.dset.valid == args.dset.test
        cv_dataset = LrHrSet(args.dset.valid, args.experiment.lr_sr, args.experiment.hr_sr,
                             args.experiment.stride, args.experiment.segment,
                             upsample=args.experiment.upsample,
                             fixed_n_examples=args.experiment.fixed_n_examples)
        cv_loader = distrib.loader(cv_dataset, batch_size=1, shuffle=False, num_workers=args.num_workers)
    else:
        cv_loader = None

    if args.dset.test:
        tt_dataset = LrHrSet(args.dset.test, args.experiment.lr_sr, args.experiment.hr_sr,
                             stride=40, segment=10, with_path=True, upsample=args.experiment.upsample)
        tt_loader = distrib.loader(tt_dataset, batch_size=1, shuffle=False, num_workers=args.num_workers)
    else:
        tt_loader = None
    data = {"tr_loader": tr_loader, "cv_loader": cv_loader, "tt_loader": tt_loader}

    if torch.cuda.is_available() and args.device == 'cuda':
        for model in models.values():
            model.cuda()

    # optimizer (generator)
    if args.optim == "adam":
        optimizer = torch.optim.Adam(models['generator'].parameters(), lr=args.lr, betas=(0.9, args.beta2))
    else:
        logger.fatal('Invalid optimizer %s', args.optim)
        os._exit(1)

    optimizers = {'optimizer': optimizer}

    # One shared optimizer over all discriminators when adversarial training
    # is enabled.
    if 'adversarial' in args.experiment and args.experiment.adversarial:
        disc_optimizer = torch.optim.Adam(
            itertools.chain(*[models[disc_name].parameters() for disc_name in
                              args.experiment.discriminator_models]),
            args.lr, betas=(0.9, args.beta2))
        optimizers.update({'disc_optimizer': disc_optimizer})

    # Construct Solver
    solver = Solver(data, models, optimizers, args)
    solver.train()

    distrib.close()
class ResnetBlock(nn.Module):
    """Dilated residual block: a two-conv branch added to a 1x1-conv shortcut."""

    def __init__(self, dim, dilation=1):
        super().__init__()
        branch_layers = [
            nn.LeakyReLU(0.2),
            nn.ReflectionPad1d(dilation),
            WNConv1d(dim, dim, kernel_size=3, dilation=dilation),
            nn.LeakyReLU(0.2),
            WNConv1d(dim, dim, kernel_size=1),
        ]
        self.block = nn.Sequential(*branch_layers)
        # 1x1 convolution on the skip path (a learned shortcut, not identity).
        self.shortcut = WNConv1d(dim, dim, kernel_size=1)

    def forward(self, x):
        skip = self.shortcut(x)
        residual = self.block(x)
        return skip + residual
lr_sr=16000, 39 | hr_sr=16000, 40 | upsample=True): 41 | super().__init__() 42 | 43 | self.resample = resample 44 | self.normalize = normalize 45 | self.floor = floor 46 | self.lr_sr = lr_sr 47 | self.hr_sr = hr_sr 48 | self.scale_factor = int(self.hr_sr / self.lr_sr) 49 | self.upsample = upsample 50 | 51 | self.encoder = nn.ModuleList() 52 | self.decoder = nn.ModuleList() 53 | 54 | self.ratios = ratios 55 | mult = int(2 ** len(ratios)) 56 | 57 | decoder_wrapper_conv_layer = [ 58 | nn.LeakyReLU(0.2), 59 | nn.ReflectionPad1d(3), 60 | WNConv1d(latent_space_size, mult * ngf, kernel_size=7, padding=0), 61 | ] 62 | 63 | encoder_wrapper_conv_layer = [ 64 | nn.LeakyReLU(0.2), 65 | nn.ReflectionPad1d(3), 66 | WNConv1d(mult * ngf, latent_space_size, kernel_size=7, padding=0) 67 | ] 68 | 69 | self.encoder.insert(0, nn.Sequential(*encoder_wrapper_conv_layer)) 70 | self.decoder.append(nn.Sequential(*decoder_wrapper_conv_layer)) 71 | 72 | for i, r in enumerate(ratios): 73 | encoder_block = [ 74 | nn.LeakyReLU(0.2), 75 | WNConv1d(mult * ngf // 2, 76 | mult * ngf, 77 | kernel_size=r * 2, 78 | stride=r, 79 | padding=r // 2 + r % 2, 80 | ), 81 | ] 82 | 83 | decoder_block = [ 84 | nn.LeakyReLU(0.2), 85 | WNConvTranspose1d( 86 | mult * ngf, 87 | mult * ngf // 2, 88 | kernel_size=r * 2, 89 | stride=r, 90 | padding=r // 2 + r % 2, 91 | output_padding=r % 2, 92 | ), 93 | ] 94 | 95 | for j in range(n_residual_layers - 1, -1, -1): 96 | encoder_block = [ResnetBlock(mult * ngf // 2, dilation=3 ** j)] + encoder_block 97 | 98 | for j in range(n_residual_layers): 99 | decoder_block += [ResnetBlock(mult * ngf // 2, dilation=3 ** j)] 100 | 101 | mult //= 2 102 | 103 | self.encoder.insert(0, nn.Sequential(*encoder_block)) 104 | self.decoder.append(nn.Sequential(*decoder_block)) 105 | 106 | encoder_wrapper_conv_layer = [ 107 | nn.ReflectionPad1d(3), 108 | WNConv1d(in_channels, ngf, kernel_size=7, padding=0), 109 | nn.Tanh(), 110 | ] 111 | self.encoder.insert(0, 
nn.Sequential(*encoder_wrapper_conv_layer)) 112 | 113 | decoder_wrapper_conv_layer = [ 114 | nn.LeakyReLU(0.2), 115 | nn.ReflectionPad1d(3), 116 | WNConv1d(ngf, out_channels, kernel_size=7, padding=0), 117 | nn.Tanh(), 118 | ] 119 | self.decoder.append(nn.Sequential(*decoder_wrapper_conv_layer)) 120 | 121 | self.apply(weights_init) 122 | 123 | def estimate_output_length(self, length): 124 | """ 125 | Return the nearest valid length to use with the model so that 126 | there is no time steps left over in a convolutions, e.g. for all 127 | layers, size of the input - kernel_size % stride = 0. 128 | 129 | If the mixture has a valid length, the estimated sources 130 | will have exactly the same length. 131 | """ 132 | depth = len(self.ratios) 133 | for idx in range(depth - 1, -1, -1): 134 | stride = self.ratios[idx] 135 | kernel_size = 2 * stride 136 | padding = stride // 2 + stride % 2 137 | length = math.ceil((length - kernel_size + 2 * padding) / stride) + 1 138 | length = max(length, 1) 139 | for idx in range(depth): 140 | stride = self.ratios[idx] 141 | kernel_size = 2 * stride 142 | padding = stride // 2 + stride % 2 143 | output_padding = stride % 2 144 | length = (length - 1) * stride + kernel_size - 2 * padding + output_padding 145 | return int(length) 146 | 147 | def pad_to_valid_length(self, signal): 148 | valid_length = self.estimate_output_length(signal.shape[-1]) 149 | padding_len = valid_length - signal.shape[-1] 150 | signal = F.pad(signal, (0, padding_len)) 151 | return signal, padding_len 152 | 153 | def forward(self, signal): 154 | 155 | target_len = signal.shape[-1] 156 | if self.upsample: 157 | target_len *= self.scale_factor 158 | if self.normalize: 159 | mono = signal.mean(dim=1, keepdim=True) 160 | std = mono.std(dim=-1, keepdim=True) 161 | signal = signal / (self.floor + std) 162 | else: 163 | std = 1 164 | x = signal 165 | if self.upsample: 166 | x = resample(x, self.lr_sr, self.hr_sr) 167 | 168 | x, padding_len = self.pad_to_valid_length(x) 
def evaluate_lr_hr_pr_data(data, wandb_n_files_to_log, files_to_log, epoch, args):
    """Evaluate one saved (lr, hr, pr) triple: compute metrics, log to wandb.

    Args:
        data: (lr, hr, pr, filename) batch from the saved-estimates loader
            (batch size 1; filename is a 1-element sequence).
        wandb_n_files_to_log: max number of files to log to wandb (-1 = all).
        files_to_log: mutable list of filenames already selected for logging.
        epoch: current epoch index, forwarded to the wandb logger.
        args: config; sample rates are read from ``args.experiment`` when
            that section exists, otherwise from ``args`` itself.

    Returns:
        dict with keys 'lsd', 'visqol' and 'filename'.
    """
    lr, hr, pr, filename = data
    filename = filename[0]
    # Fix: guard all `args.experiment` accesses consistently. The original
    # read `args.experiment.upsample` unconditionally, which crashed for
    # configs without an `experiment` section even though `hr_sr` was guarded.
    if 'experiment' in args:
        hr_sr = args.experiment.hr_sr
        lr_sr = hr_sr if args.experiment.upsample else args.experiment.lr_sr
    else:
        hr_sr = args.hr_sr
        lr_sr = args.lr_sr

    if wandb_n_files_to_log == -1 or len(files_to_log) < wandb_n_files_to_log:
        files_to_log.append(filename)

    if args.device != 'cpu':
        hr = hr.cpu()
        pr = pr.cpu()

    def _load_spec(suffix):
        # Spectrogram images are optional; they exist only if a previous
        # enhance step saved them next to the samples.
        path = os.path.join(args.samples_dir, filename + suffix)
        return PIL.Image.open(path) if os.path.exists(path) else None

    hr_spec = _load_spec('_hr_spec.png')
    pr_spec = _load_spec('_pr_spec.png')
    lr_spec = _load_spec('_lr_spec.png')

    lsd_i, visqol_i = run_metrics(hr, pr, args, filename)
    if filename in files_to_log:
        log_data_to_wandb(pr, hr, lr, lsd_i, visqol_i,
                          filename, epoch, lr_sr, hr_sr, lr_spec, pr_spec, hr_spec)

    return {'lsd': lsd_i, 'visqol': visqol_i, 'filename': filename}
def evaluate_on_saved_data(args, data_loader, epoch):
    """Run LSD/ViSQOL evaluation over estimates already saved to disk.

    Returns:
        (avg_lsd, avg_visqol), each averaged only over files whose metric is
        non-zero (0 marks a skipped or failed computation for that file).
    """
    total_lsd = 0
    total_visqol = 0
    lsd_count = 0
    visqol_count = 0
    total_cnt = 0

    files_to_log = []
    wandb_n_files_to_log = args.wandb.n_files_to_log if 'wandb' in args else args.wandb_n_files_to_log

    with torch.no_grad():
        progress = LogProgress(logger, data_loader, name="Eval estimates")
        for batch in progress:
            metrics_i = evaluate_lr_hr_pr_data(batch, wandb_n_files_to_log, files_to_log, epoch, args)

            total_lsd += metrics_i['lsd']
            total_visqol += metrics_i['visqol']

            # A zero metric means "not computed"; keep it out of the average.
            if metrics_i['lsd'] != 0:
                lsd_count += 1
            if metrics_i['visqol'] != 0:
                visqol_count += 1

            total_cnt += 1

    avg_lsd = total_lsd / lsd_count if lsd_count != 0 else 0
    avg_visqol = total_visqol / visqol_count if visqol_count != 0 else 0

    logger.info(bold(
        f'{args.experiment.name}, {args.experiment.lr_sr}->{args.experiment.hr_sr}. Test set performance:'
        f'LSD={avg_lsd} ({lsd_count}/{total_cnt}), VISQOL={avg_visqol} ({visqol_count}/{total_cnt}).'))

    return avg_lsd, avg_visqol
DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN 7 | ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS 8 | INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES 9 | REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS 10 | PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM 11 | THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED 12 | HEREUNDER. 13 | 14 | Statement of Purpose 15 | 16 | The laws of most jurisdictions throughout the world automatically confer 17 | exclusive Copyright and Related Rights (defined below) upon the creator 18 | and subsequent owner(s) (each and all, an "owner") of an original work of 19 | authorship and/or a database (each, a "Work"). 20 | 21 | Certain owners wish to permanently relinquish those rights to a Work for 22 | the purpose of contributing to a commons of creative, cultural and 23 | scientific works ("Commons") that the public can reliably and without fear 24 | of later claims of infringement build upon, modify, incorporate in other 25 | works, reuse and redistribute as freely as possible in any form whatsoever 26 | and for any purposes, including without limitation commercial purposes. 27 | These owners may contribute to the Commons to promote the ideal of a free 28 | culture and the further production of creative, cultural and scientific 29 | works, or to gain reputation or greater distribution for their Work in 30 | part through the use and efforts of others. 
31 | 32 | For these and/or other purposes and motivations, and without any 33 | expectation of additional consideration or compensation, the person 34 | associating CC0 with a Work (the "Affirmer"), to the extent that he or she 35 | is an owner of Copyright and Related Rights in the Work, voluntarily 36 | elects to apply CC0 to the Work and publicly distribute the Work under its 37 | terms, with knowledge of his or her Copyright and Related Rights in the 38 | Work and the meaning and intended legal effect of CC0 on those rights. 39 | 40 | 1. Copyright and Related Rights. A Work made available under CC0 may be 41 | protected by copyright and related or neighboring rights ("Copyright and 42 | Related Rights"). Copyright and Related Rights include, but are not 43 | limited to, the following: 44 | 45 | i. the right to reproduce, adapt, distribute, perform, display, 46 | communicate, and translate a Work; 47 | ii. moral rights retained by the original author(s) and/or performer(s); 48 | iii. publicity and privacy rights pertaining to a person's image or 49 | likeness depicted in a Work; 50 | iv. rights protecting against unfair competition in regards to a Work, 51 | subject to the limitations in paragraph 4(a), below; 52 | v. rights protecting the extraction, dissemination, use and reuse of data 53 | in a Work; 54 | vi. database rights (such as those arising under Directive 96/9/EC of the 55 | European Parliament and of the Council of 11 March 1996 on the legal 56 | protection of databases, and under any national implementation 57 | thereof, including any amended or successor version of such 58 | directive); and 59 | vii. other similar, equivalent or corresponding rights throughout the 60 | world based on applicable law or treaty, and any national 61 | implementations thereof. 62 | 63 | 2. Waiver. 
To the greatest extent permitted by, but not in contravention 64 | of, applicable law, Affirmer hereby overtly, fully, permanently, 65 | irrevocably and unconditionally waives, abandons, and surrenders all of 66 | Affirmer's Copyright and Related Rights and associated claims and causes 67 | of action, whether now known or unknown (including existing as well as 68 | future claims and causes of action), in the Work (i) in all territories 69 | worldwide, (ii) for the maximum duration provided by applicable law or 70 | treaty (including future time extensions), (iii) in any current or future 71 | medium and for any number of copies, and (iv) for any purpose whatsoever, 72 | including without limitation commercial, advertising or promotional 73 | purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each 74 | member of the public at large and to the detriment of Affirmer's heirs and 75 | successors, fully intending that such Waiver shall not be subject to 76 | revocation, rescission, cancellation, termination, or any other legal or 77 | equitable action to disrupt the quiet enjoyment of the Work by the public 78 | as contemplated by Affirmer's express Statement of Purpose. 79 | 80 | 3. Public License Fallback. Should any part of the Waiver for any reason 81 | be judged legally invalid or ineffective under applicable law, then the 82 | Waiver shall be preserved to the maximum extent permitted taking into 83 | account Affirmer's express Statement of Purpose. 
In addition, to the 84 | extent the Waiver is so judged Affirmer hereby grants to each affected 85 | person a royalty-free, non transferable, non sublicensable, non exclusive, 86 | irrevocable and unconditional license to exercise Affirmer's Copyright and 87 | Related Rights in the Work (i) in all territories worldwide, (ii) for the 88 | maximum duration provided by applicable law or treaty (including future 89 | time extensions), (iii) in any current or future medium and for any number 90 | of copies, and (iv) for any purpose whatsoever, including without 91 | limitation commercial, advertising or promotional purposes (the 92 | "License"). The License shall be deemed effective as of the date CC0 was 93 | applied by Affirmer to the Work. Should any part of the License for any 94 | reason be judged legally invalid or ineffective under applicable law, such 95 | partial invalidity or ineffectiveness shall not invalidate the remainder 96 | of the License, and in such case Affirmer hereby affirms that he or she 97 | will not (i) exercise any of his or her remaining Copyright and Related 98 | Rights in the Work or (ii) assert any associated claims and causes of 99 | action with respect to the Work, in either case contrary to Affirmer's 100 | express Statement of Purpose. 101 | 102 | 4. Limitations and Disclaimers. 103 | 104 | a. No trademark or patent rights held by Affirmer are waived, abandoned, 105 | surrendered, licensed or otherwise affected by this document. 106 | b. Affirmer offers the Work as-is and makes no representations or 107 | warranties of any kind concerning the Work, express, implied, 108 | statutory or otherwise, including without limitation warranties of 109 | title, merchantability, fitness for a particular purpose, non 110 | infringement, or the absence of latent or other defects, accuracy, or 111 | the present or absence of errors, whether or not discoverable, all to 112 | the greatest extent permissible under applicable law. 113 | c. 
Affirmer disclaims responsibility for clearing rights of other persons 114 | that may apply to the Work or any use thereof, including without 115 | limitation any person's Copyright and Related Rights in the Work. 116 | Further, Affirmer disclaims responsibility for obtaining any necessary 117 | consents, permissions or other rights required for any use of the 118 | Work. 119 | d. Affirmer understands and acknowledges that Creative Commons is not a 120 | party to this document and has no duty or obligation with respect to 121 | this CC0 or use of the Work. 122 | -------------------------------------------------------------------------------- /src/data/datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is based on Facebook's HDemucs code: https://github.com/facebookresearch/demucs 3 | """ 4 | 5 | import json 6 | import logging 7 | import os 8 | import random 9 | 10 | import torch 11 | from tqdm import tqdm 12 | import torchaudio 13 | 14 | from torch.nn import functional as F 15 | from torchaudio.functional import resample 16 | from torch.utils.data import Dataset 17 | from torchaudio.transforms import Spectrogram 18 | 19 | from src.data.audio import Audioset 20 | from src.utils import match_signal 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | def match_files(lr, hr): 26 | """match_files. 27 | Sort files to match lr and hr filenames. 
28 | :param lr: list of the low-resolution filenames 29 | :param hr: list of the high-resolution filenames 30 | """ 31 | lr.sort() 32 | hr.sort() 33 | 34 | 35 | def assert_sets(lr_set, hr_set): 36 | n_samples = len(lr_set) 37 | for i in tqdm(range(n_samples)): 38 | assert lr_set[i].shape == hr_set[i].shape, f"file {i} shape is not the same, lr: {lr_set[i].shape}, hr: {hr_set[i].shape}" 39 | 40 | 41 | def match_source_to_target_length(source_sig, target_sig): 42 | target_len = target_sig.shape[-1] 43 | source_len = source_sig.shape[-1] 44 | if target_len < source_len: 45 | source_sig = source_sig[..., :target_len] 46 | elif target_len > source_len: 47 | source_sig = F.pad(source_sig, (0, target_len - source_len)) 48 | return source_sig 49 | 50 | 51 | class PrHrSet(Dataset): 52 | def __init__(self, samples_dir, filenames=None): 53 | self.samples_dir = samples_dir 54 | if filenames is not None: 55 | files = [i for i in os.listdir(samples_dir) if any(i for j in filenames if j in i)] 56 | else: 57 | files = os.listdir(samples_dir) 58 | 59 | self.hr_filenames = list(sorted(filter(lambda x: x.endswith('_hr.wav'), files))) 60 | self.lr_filenames = list(sorted(filter(lambda x: x.endswith('_lr.wav'), files))) 61 | self.pr_filenames = list(sorted(filter(lambda x: x.endswith('_pr.wav'), files))) 62 | 63 | def __len__(self): 64 | return len(self.hr_filenames) 65 | 66 | def __getitem__(self, i): 67 | lr_i, lr_sr = torchaudio.load(os.path.join(self.samples_dir, self.lr_filenames[i])) 68 | hr_i, hr_sr = torchaudio.load(os.path.join(self.samples_dir, self.hr_filenames[i])) 69 | pr_i, pr_sr = torchaudio.load(os.path.join(self.samples_dir, self.pr_filenames[i])) 70 | pr_i = match_signal(pr_i, hr_i.shape[-1]) 71 | assert hr_i.shape == pr_i.shape 72 | lr_filename = self.lr_filenames[i] 73 | lr_filename = lr_filename[:lr_filename.index('_lr.wav')] 74 | hr_filename = self.hr_filenames[i] 75 | hr_filename = hr_filename[:hr_filename.index('_hr.wav')] 76 | pr_filename = self.pr_filenames[i] 
77 | pr_filename = pr_filename[:pr_filename.index('_pr.wav')] 78 | assert lr_filename == hr_filename == pr_filename 79 | 80 | return lr_i, hr_i, pr_i, lr_filename 81 | 82 | 83 | class LrHrSet(Dataset): 84 | def __init__(self, json_dir, lr_sr, hr_sr, stride=None, segment=None, 85 | pad=True, with_path=False, stft=False, win_len=64, hop_len=16, n_fft=4096, complex_as_channels=True, 86 | upsample=True, fixed_n_examples=None): 87 | """__init__. 88 | :param json_dir: directory containing both hr.json and lr.json 89 | :param stride: the stride used for splitting audio sequences in seconds 90 | :param segment: the segment length used for splitting audio sequences in seconds 91 | :param pad: pad the end of the sequence with zeros 92 | :param sample_rate: the signals sampling rate 93 | :param with_path: whether to return tensors with filepath 94 | :param stft: convert to spectrogram 95 | :param win_len: stft window length in seconds 96 | :param hop_len: stft hop length in seconds 97 | :param n_fft: stft number of frequency bins 98 | :param complex_as_channels: True - move complex dimension to channel dimension. 
output is [2, Fr, T] 99 | False - last dimension is complex channels, output is [1, Fr, T, 2] 100 | """ 101 | 102 | self.lr_sr = lr_sr 103 | self.hr_sr = hr_sr 104 | self.stft = stft 105 | self.with_path = with_path 106 | self.upsample = upsample 107 | self.fixed_n_examples = fixed_n_examples 108 | 109 | if self.stft: 110 | self.window_length = int(self.hr_sr / 1000 * win_len) # 64 ms 111 | self.hop_length = int(self.hr_sr / 1000 * hop_len) # 16 ms 112 | self.window = torch.hann_window(self.window_length) 113 | self.n_fft = n_fft 114 | self.complex_as_channels = complex_as_channels 115 | self.spectrogram = Spectrogram(n_fft=n_fft, win_length=self.window_length, hop_length=self.hop_length, 116 | power=None) 117 | 118 | lr_json = os.path.join(json_dir, 'lr.json') 119 | hr_json = os.path.join(json_dir, 'hr.json') 120 | 121 | with open(lr_json, 'r') as f: 122 | lr = json.load(f) 123 | with open(hr_json, 'r') as f: 124 | hr = json.load(f) 125 | 126 | lr_stride = stride * lr_sr if stride else None 127 | hr_stride = stride * hr_sr if stride else None 128 | lr_length = segment * lr_sr if segment else None 129 | hr_length = segment * hr_sr if segment else None 130 | 131 | match_files(lr, hr) 132 | self.lr_set = Audioset(lr, sample_rate=lr_sr, length=lr_length, stride=lr_stride, pad=pad, channels=1, 133 | with_path=with_path, fixed_n_examples=self.fixed_n_examples) 134 | self.hr_set = Audioset(hr, sample_rate=hr_sr, length=hr_length, stride=hr_stride, pad=pad, channels=1, 135 | with_path=with_path, fixed_n_examples=self.fixed_n_examples) 136 | assert len(self.hr_set) == len(self.lr_set), f"hr: {len(self.hr_set)}, lr: {len(self.lr_set)}" 137 | 138 | 139 | 140 | #if self.fixed_n_examples is not None: 141 | # self.list_of_indexes = random.sample(range(len(self.hr_set)), self.fixed_n_examples) 142 | def __getitem__(self, index): 143 | if self.fixed_n_examples is not None: 144 | index = random.sample(range(len(self.hr_set)), 1)[0] 145 | if self.with_path: 146 | hr_sig, hr_path = 
self.hr_set[index] 147 | lr_sig, lr_path = self.lr_set[index] 148 | else: 149 | hr_sig = self.hr_set[index] 150 | lr_sig = self.lr_set[index] 151 | if self.upsample: 152 | lr_sig = resample(lr_sig, self.lr_sr, self.hr_sr) 153 | lr_sig = match_signal(lr_sig, hr_sig.shape[-1]) 154 | 155 | if self.stft: 156 | hr_sig = torch.view_as_real(self.spectrogram(hr_sig)) 157 | lr_sig = torch.view_as_real(self.spectrogram(lr_sig)) 158 | if self.complex_as_channels: 159 | Ch, Fr, T, _ = hr_sig.shape 160 | hr_sig = hr_sig.reshape(2 * Ch, Fr, T) 161 | lr_sig = lr_sig.reshape(2 * Ch, Fr, T) 162 | 163 | if self.with_path: 164 | return (lr_sig, lr_path), (hr_sig, hr_path) 165 | else: 166 | return lr_sig, hr_sig 167 | 168 | def __len__(self): 169 | return len(self.lr_set) 170 | 171 | 172 | if __name__ == "__main__": 173 | json_dir = 'egs/chopin-11-44-tiny/tr' 174 | lr_sr = 11025 175 | hr_sr = 44100 176 | pad = True 177 | stride_sec = 2 178 | segment_sec = 2 179 | 180 | data_set = LrHrSet(json_dir, lr_sr, hr_sr, stride_sec, segment_sec, fixed_n_examples=15) 181 | print(data_set) 182 | -------------------------------------------------------------------------------- /src/wandb_logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import PIL 4 | import wandb 5 | import logging 6 | 7 | from torchaudio.functional import resample 8 | from torchaudio.transforms import Spectrogram 9 | 10 | from src.metrics import run_metrics 11 | from src.utils import convert_spectrogram_to_heatmap 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | SPECTOGRAM_EPSILON = 1e-13 16 | 17 | 18 | def _get_wandb_config(args): 19 | included_keys = ['eval_every', 'optim', 'lr', 'losses', 'epochs'] 20 | wandb_config = {k: args[k] for k in included_keys} 21 | wandb_config.update(**args.experiment) 22 | wandb_config.update({'train': args.dset.train, 'test': args.dset.test}) 23 | return wandb_config 24 | 25 | 26 | def _init_wandb_run(args, train=True): 27 | 
def log_data_to_wandb(pr_signal, hr_signal, lr_signal, lsd, visqol, filename, epoch, lr_sr, hr_sr, lr_spec=None, pr_spec=None, hr_spec=None):
    """Log one evaluated sample to wandb: the predicted (PR) audio, its
    spectrogram heatmap and metrics; HR/LR references are logged only for
    the first 10 epochs (they never change across epochs).

    NOTE(review): the literal "(unknown)" inside the wandb keys below looks
    like a redaction/placeholder for an f-string field (presumably
    {filename}) — confirm against the upstream repository before relying on
    these key names.

    :param pr_signal: predicted waveform tensor, shape [C, T] — assumed; TODO confirm.
    :param hr_signal / lr_signal: reference high/low-resolution waveforms.
    :param lsd / visqol: per-file metric values to log.
    :param filename: sample identifier (currently unused in the keys, see NOTE).
    :param epoch: wandb step; also gates the HR/LR reference logging.
    :param lr_sr / hr_sr: sample rates used for the wandb.Audio payloads.
    :param lr_spec / pr_spec / hr_spec: optional precomputed complex
        spectrograms (or already-rendered PIL images).
    """
    spectrogram_transform = Spectrogram()
    # log-power spectrogram of the prediction, first channel only
    enhanced_spectrogram = spectrogram_transform(pr_signal).log2()[0, :, :].numpy()
    enhanced_spectrogram_wandb_image = wandb.Image(convert_spectrogram_to_heatmap(enhanced_spectrogram),
                                                  caption='PR')
    enhanced_wandb_audio = wandb.Audio(pr_signal.squeeze().numpy(), sample_rate=hr_sr, caption='PR')

    wandb_dict = {f'test samples/(unknown)/lsd': lsd,
                  f'test samples/(unknown)/visqol': visqol,
                  f'test samples/(unknown)/spectrogram': enhanced_spectrogram_wandb_image,
                  f'test samples/(unknown)/audio': enhanced_wandb_audio}

    # precomputed spectrograms: log them all, or none
    if pr_spec is not None and hr_spec is not None and lr_spec is not None:
        if not isinstance(pr_spec, PIL.Image.Image):
            # complex spectrogram -> power -> log2 -> heatmap image
            pr_spec = pr_spec.abs().pow(2).log2()[0,:,:].numpy()
            pr_spec = convert_spectrogram_to_heatmap(pr_spec)
        enhanced_pr_spectrogram_wandb_image = wandb.Image(pr_spec, caption='PR spec')
        wandb_dict.update({f'test samples/(unknown)/pr_spec': enhanced_pr_spectrogram_wandb_image})

        # HR/LR references are static; only log them during early epochs
        if epoch <= 10:
            if not isinstance(hr_spec, PIL.Image.Image):
                hr_spec = hr_spec.abs().pow(2).log2()[0, :, :].numpy()
                hr_spec = convert_spectrogram_to_heatmap(hr_spec)
            enhanced_hr_spectrogram_wandb_image = wandb.Image(hr_spec, caption='HR spec')
            wandb_dict.update({f'test samples/(unknown)/hr_spec': enhanced_hr_spectrogram_wandb_image})

            if not isinstance(lr_spec, PIL.Image.Image):
                lr_spec = lr_spec.abs().pow(2).log2()[0, :, :].numpy()
                lr_spec = convert_spectrogram_to_heatmap(lr_spec)
            enhanced_lr_spectrogram_wandb_image = wandb.Image(lr_spec, caption='LR spec')
            wandb_dict.update({f'test samples/(unknown)/lr_spec': enhanced_lr_spectrogram_wandb_image})

    # same early-epoch gating for the raw HR/LR audio + computed spectrograms
    if epoch <= 10:
        hr_name = f'(unknown)_hr'
        hr_enhanced_spectrogram = spectrogram_transform(hr_signal).log2()[0, :, :].numpy()
        hr_enhanced_spectrogram_wandb_image = wandb.Image(convert_spectrogram_to_heatmap(hr_enhanced_spectrogram),
                                                          caption='HR')
        hr_enhanced_wandb_audio = wandb.Audio(hr_signal.squeeze().numpy(), sample_rate=hr_sr, caption='HR')
        wandb_dict.update({f'test samples/(unknown)/{hr_name}_spectrogram': hr_enhanced_spectrogram_wandb_image,
                           f'test samples/(unknown)/{hr_name}_audio': hr_enhanced_wandb_audio})

        lr_name = f'(unknown)_lr'
        lr_enhanced_spectrogram = spectrogram_transform(lr_signal).log2()[0, :, :].numpy()
        lr_enhanced_spectrogram_wandb_image = wandb.Image(convert_spectrogram_to_heatmap(lr_enhanced_spectrogram),
                                                          caption='LR')
        lr_enhanced_wandb_audio = wandb.Audio(lr_signal.squeeze().numpy(), sample_rate=lr_sr, caption='LR')
        wandb_dict.update({f'test samples/(unknown)/{lr_name}_spectrogram': lr_enhanced_spectrogram_wandb_image,
                           f'test samples/(unknown)/{lr_name}_audio': lr_enhanced_wandb_audio})

    # one wandb step per epoch
    wandb.log(wandb_dict,
              step=epoch)
class NLayerDiscriminator(nn.Module):
    """MelGAN-style 1-D waveform discriminator.

    A stack of strided, grouped weight-normalized convolutions; `forward`
    returns the list of every layer's output (the final entry holds the
    1-channel logits map), so callers can use the intermediates as feature
    maps for a feature-matching loss.
    """

    def __init__(self, ndf, n_layers, downsampling_factor):
        super().__init__()
        layers = nn.ModuleDict()
        layers["layer_0"] = nn.Sequential(
            nn.ReflectionPad1d(7),
            WNConv1d(1, ndf, kernel_size=15),
            nn.LeakyReLU(0.2, True),
        )

        stride = downsampling_factor
        # channel count grows by `stride` each layer, capped at this value
        max_channels = (stride ** (n_layers - 1)) * ndf
        channels = ndf
        for layer_idx in range(1, n_layers + 1):
            in_channels = channels
            channels = min(in_channels * stride, max_channels)
            layers["layer_%d" % layer_idx] = nn.Sequential(
                WNConv1d(
                    in_channels,
                    channels,
                    kernel_size=stride * 10 + 1,
                    stride=stride,
                    padding=stride * 5,
                    groups=in_channels // 4,
                ),
                nn.LeakyReLU(0.2, True),
            )
        channels = min(channels * 2, max_channels)
        # note: mirrors the original — in_channels from the last loop
        # iteration equals the previous layer's output once the cap is hit
        layers["layer_%d" % (n_layers + 1)] = nn.Sequential(
            WNConv1d(in_channels, channels, kernel_size=5, stride=1, padding=2),
            nn.LeakyReLU(0.2, True),
        )
        layers["layer_%d" % (n_layers + 2)] = WNConv1d(
            channels, 1, kernel_size=3, stride=1, padding=1
        )
        self.model = layers

    def forward(self, x):
        """Run all layers, collecting every intermediate output."""
        feature_maps = []
        out = x
        for _, layer in self.model.items():
            out = layer(out)
            feature_maps.append(out)
        return feature_maps
class MultiPeriodDiscriminator(torch.nn.Module):
    """HiFi-GAN multi-period discriminator: one DiscriminatorP per period,
    each viewing the waveform folded at a different prime period."""

    @capture_init
    def __init__(self, hidden=32, periods=[2, 3, 5, 7, 11]):
        super(MultiPeriodDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList(
            [DiscriminatorP(p, hidden=hidden) for p in periods]
        )

    def forward(self, y, y_hat):
        """Score real (y) and generated (y_hat) waveforms with every
        sub-discriminator; returns (real_scores, fake_scores,
        real_feature_maps, fake_feature_maps), each a list per period."""
        real_scores, fake_scores = [], []
        real_fmaps, fake_fmaps = [], []
        for disc in self.discriminators:
            score_r, fmap_r = disc(y)
            score_g, fmap_g = disc(y_hat)
            real_scores.append(score_r)
            real_fmaps.append(fmap_r)
            fake_scores.append(score_g)
            fake_fmaps.append(fmap_g)
        return real_scores, fake_scores, real_fmaps, fake_fmaps
def feature_loss(fmap_r, fmap_g):
    """Feature-matching loss: mean L1 distance between corresponding
    real/generated discriminator feature maps, averaged over every layer
    of every sub-discriminator.

    :param fmap_r: list (per discriminator) of lists of real feature maps.
    :param fmap_g: matching structure for generated feature maps.
    :return: scalar tensor with the averaged loss.
    """
    layer_losses = [
        torch.mean(torch.abs(real - fake))
        for disc_r, disc_g in zip(fmap_r, fmap_g)
        for real, fake in zip(disc_r, disc_g)
    ]
    return sum(layer_losses) / len(layer_losses)
def unfold(a, kernel_size, stride):
    """Extract overlapping frames of length `kernel_size` every `stride`
    samples from the last axis.

    Given input of shape [*OT, T], returns a zero-copy view of shape
    [*OT, F, K] with F = ceil(T / stride); the input is zero-padded at the
    end so the final frame is complete.
    See https://github.com/pytorch/pytorch/issues/60466
    """
    *batch_dims, length = a.shape
    n_frames = math.ceil(length / stride)
    padded_length = (n_frames - 1) * stride + kernel_size
    padded = F.pad(a, (0, padded_length - length))

    elem_strides = list(padded.stride())
    # as_strided requires a unit stride on the last axis
    assert elem_strides[-1] == 1, 'data should be contiguous'
    frame_strides = elem_strides[:-1] + [stride, 1]
    return padded.as_strided([*batch_dims, n_frames, kernel_size], frame_strides)
81 | """ 82 | 83 | def __init__(self, 84 | logger, 85 | iterable, 86 | updates=5, 87 | total=None, 88 | name="LogProgress", 89 | level=logging.INFO): 90 | self.iterable = iterable 91 | self.total = total or len(iterable) 92 | self.updates = updates 93 | self.name = name 94 | self.logger = logger 95 | self.level = level 96 | 97 | def update(self, **infos): 98 | self._infos = infos 99 | 100 | def __iter__(self): 101 | self._iterator = iter(self.iterable) 102 | self._index = -1 103 | self._infos = {} 104 | self._begin = time.time() 105 | return self 106 | 107 | def __next__(self): 108 | self._index += 1 109 | try: 110 | value = next(self._iterator) 111 | except StopIteration: 112 | raise 113 | else: 114 | return value 115 | finally: 116 | log_every = max(1, self.total // self.updates) 117 | # logging is delayed by 1 it, in order to have the metrics from update 118 | if self._index >= 1 and self._index % log_every == 0: 119 | self._log() 120 | 121 | def _log(self): 122 | self._speed = (1 + self._index) / (time.time() - self._begin) 123 | infos = " | ".join(f"{k.capitalize()} {v}" for k, v in self._infos.items()) 124 | if self._speed < 1e-4: 125 | speed = "oo sec/it" 126 | elif self._speed < 0.1: 127 | speed = f"{1 / self._speed:.1f} sec/it" 128 | else: 129 | speed = f"{self._speed:.1f} it/sec" 130 | out = f"{self.name} | {self._index}/{self.total} | {speed}" 131 | if infos: 132 | out += " | " + infos 133 | self.logger.log(self.level, out) 134 | 135 | 136 | def scale_minmax(X, min=0.0, max=1.0): 137 | isnan = np.isnan(X).any() 138 | isinf = np.isinf(X).any() 139 | if isinf: 140 | X[X == np.inf] = 1e9 141 | X[X == -np.inf] = 1e-9 142 | if isnan: 143 | X[X == np.nan] = 1e-9 144 | # logger.info(f'isnan: {isnan}, isinf: {isinf}, max: {X.max()}, min: {X.min()}') 145 | 146 | X_std = (X - X.min()) / (X.max() - X.min()) 147 | X_scaled = X_std * (max - min) + min 148 | return X_scaled 149 | 150 | 151 | def convert_spectrogram_to_heatmap(spectrogram): 152 | spectrogram += 1e-9 
153 | spectrogram = scale_minmax(spectrogram, 0, 255).astype(np.uint8).squeeze() 154 | spectrogram = np.flip(spectrogram, axis=0) 155 | spectrogram = 255 - spectrogram 156 | # spectrogram = (255 * (spectrogram - np.min(spectrogram)) / np.ptp(spectrogram)).astype(np.uint8).squeeze()[::-1,:] 157 | heatmap = cv2.applyColorMap(spectrogram, cv2.COLORMAP_INFERNO) 158 | heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) 159 | return heatmap 160 | 161 | 162 | def colorize(text, color): 163 | """ 164 | Display text with some ANSI color in the terminal. 165 | """ 166 | code = f"\033[{color}m" 167 | restore = "\033[0m" 168 | return "".join([code, text, restore]) 169 | 170 | 171 | def bold(text): 172 | """ 173 | Display text in bold in the terminal. 174 | """ 175 | return colorize(text, "1") 176 | 177 | 178 | def copy_state(state): 179 | return {k: v.cpu().clone() for k, v in state.items()} 180 | 181 | 182 | def serialize_model(model): 183 | args, kwargs = model._init_args_kwargs 184 | state = copy_state(model.state_dict()) 185 | return {"class": model.__class__, "args": args, "kwargs": kwargs, "state": state} 186 | 187 | 188 | @contextmanager 189 | def swap_state(model, state): 190 | """ 191 | Context manager that swaps the state of a model, e.g: 192 | 193 | # model is in old state 194 | with swap_state(model, new_state): 195 | # model in new state 196 | # model back to old state 197 | """ 198 | old_state = copy_state(model.state_dict()) 199 | model.load_state_dict(state) 200 | try: 201 | yield 202 | finally: 203 | model.load_state_dict(old_state) 204 | 205 | 206 | def pull_metric(history, name): 207 | out = [] 208 | for metrics in history: 209 | if name in metrics: 210 | out.append(metrics[name]) 211 | return out 212 | 213 | 214 | def match_signal(signal, ref_len): 215 | sig_len = signal.shape[-1] 216 | if sig_len < ref_len: 217 | signal = F.pad(signal, (0, ref_len - sig_len)) 218 | elif sig_len > ref_len: 219 | signal = signal[..., :ref_len] 220 | return signal 221 | 222 | 
import subprocess 223 | 224 | def run_ssh_command(hostname, username, password, command): 225 | # Construct the SSH command with changing directory 226 | ssh_command = f'sshpass -p {password} ssh {username}@{hostname} "cd /home/wallace.abreu/Mestrado/aero_vanilla/ && {command}"' 227 | 228 | try: 229 | # Execute the SSH command 230 | subprocess.run(ssh_command, shell=True, check=True) 231 | 232 | except subprocess.CalledProcessError as e: 233 | print(f"An error occurred: {e}") 234 | 235 | def save_spectrograms(audio_batch, fs, folder_path, N=2048, H=1024): 236 | """ 237 | Save spectrograms of a batch of audio tensors as images. 238 | 239 | Parameters 240 | ---------- 241 | audio_batch : torch.Tensor 242 | Batch of audio tensors with shape (batch_size, audio_length) 243 | sample_rate : int 244 | Sampling rate of the audio tensors 245 | folder_path : str 246 | Path to the folder where spectrogram images will be saved 247 | """ 248 | 249 | # Create the folder if it doesn't exist 250 | if not os.path.exists(folder_path): 251 | os.makedirs(folder_path) 252 | eps = 1e-9 253 | # Iterate over each audio tensor in the batch 254 | for i, x in enumerate(audio_batch): 255 | # Calculate the Short-Time Fourier Transform (STFT) 256 | specgram = torch.stft(x, n_fft=N, hop_length=H, win_length=N, window=torch.hamming_window(N).cuda()).pow(2).sum(-1).sqrt().squeeze() 257 | # Convert frequencies to Hz 258 | freqs = torch.linspace(0, fs // 2, specgram.size(1)) 259 | 260 | # Convert frames to time in seconds 261 | times = torch.linspace(0, len(x.T) / fs, specgram.size(0)) 262 | # Plot the spectrogram 263 | plt.figure(figsize=(10, 6)) 264 | plt.imshow( eps + 10 * specgram.log10().detach().cpu().numpy() , aspect='auto', origin='lower', cmap='inferno', extent=[times[0], times[-1], freqs[0], freqs[-1]]) 265 | plt.colorbar(label='Intensity (log scale)') 266 | plt.ylabel('Frequency [Hz]') 267 | plt.xlabel('Time [s]') 268 | plt.ylim(0, 22050) 269 | plt.title('spectrogram_{i+1}') 270 | 
plt.tight_layout() 271 | 272 | # Save the spectrogram plot as an image 273 | save_path = os.path.join(folder_path, f'spectrogram_{i+1}.png') 274 | plt.savefig(save_path) 275 | plt.close() 276 | 277 | print(f'Saved spectrogram {i+1} to {save_path}') -------------------------------------------------------------------------------- /src/models/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn.utils import weight_norm 4 | 5 | from src.models.snake import Snake 6 | from src.models.utils import unfold 7 | from src.models.mamba.mamba_ssm import Mamba 8 | 9 | import typing as tp 10 | 11 | def WNConv1d(*args, **kwargs): 12 | return weight_norm(nn.Conv1d(*args, **kwargs)) 13 | 14 | 15 | def WNConvTranspose1d(*args, **kwargs): 16 | return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) 17 | 18 | class BLSTM(nn.Module): 19 | """ 20 | BiLSTM with same hidden units as input dim. 21 | If `max_steps` is not None, input will be splitting in overlapping 22 | chunks and the LSTM applied separately on each chunk. 
23 | """ 24 | 25 | def __init__(self, dim, layers=1, max_steps=None, skip=False): 26 | super().__init__() 27 | assert max_steps is None or max_steps % 4 == 0 28 | self.max_steps = max_steps 29 | self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim) 30 | self.linear = nn.Linear(2 * dim, dim) 31 | self.skip = skip 32 | 33 | def forward(self, x): 34 | B, C, T = x.shape 35 | y = x 36 | framed = False 37 | if self.max_steps is not None and T > self.max_steps: 38 | width = self.max_steps 39 | stride = width // 2 40 | frames = unfold(x, width, stride) 41 | nframes = frames.shape[2] 42 | framed = True 43 | x = frames.permute(0, 2, 1, 3).reshape(-1, C, width) 44 | 45 | x = x.permute(2, 0, 1) 46 | 47 | x = self.lstm(x)[0] 48 | x = self.linear(x) 49 | x = x.permute(1, 2, 0) 50 | if framed: 51 | out = [] 52 | frames = x.reshape(B, -1, C, width) 53 | limit = stride // 2 54 | for k in range(nframes): 55 | if k == 0: 56 | out.append(frames[:, k, :, :-limit]) 57 | elif k == nframes - 1: 58 | out.append(frames[:, k, :, limit:]) 59 | else: 60 | out.append(frames[:, k, :, limit:-limit]) 61 | out = torch.cat(out, -1) 62 | out = out[..., :T] 63 | x = out 64 | if self.skip: 65 | x = x + y 66 | return x 67 | 68 | 69 | class LocalState(nn.Module): 70 | """Local state allows to have attention based only on data (no positional embedding), 71 | but while setting a constraint on the time window (e.g. decaying penalty term). 72 | Also a failed experiments with trying to provide some frequency based attention. 
    """

    def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4):
        super().__init__()
        assert channels % heads == 0, (channels, heads)
        self.heads = heads
        self.nfreqs = nfreqs
        self.ndecay = ndecay
        # 1x1 convolutions play the role of the content/query/key projections.
        self.content = nn.Conv1d(channels, channels, 1)
        self.query = nn.Conv1d(channels, channels, 1)
        self.key = nn.Conv1d(channels, channels, 1)
        if nfreqs:
            self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1)
        if ndecay:
            self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
            # Initialize decay close to zero (there is a sigmoid), for maximum initial window.
            self.query_decay.weight.data *= 0.01
            assert self.query_decay.bias is not None  # stupid type checker
            self.query_decay.bias.data[:] = -2
        # self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1)
        self.proj = nn.Conv1d(channels, channels, 1)

    def forward(self, x):
        # x: (B, C, T) — 1-D feature sequence.
        B, C, T = x.shape
        heads = self.heads
        indexes = torch.arange(T, device=x.device, dtype=x.dtype)
        # left index are keys, right index are queries
        delta = indexes[:, None] - indexes[None, :]

        queries = self.query(x).view(B, heads, -1, T)
        keys = self.key(x).view(B, heads, -1, T)
        # t are keys, s are queries
        dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
        # Scaled dot-product: normalize by sqrt of the per-head channel dim.
        dots /= keys.shape[2] ** 0.5
        if self.nfreqs:
            # Frequency-based positional bias: cosines of the key/query offset.
            # NOTE(review): requires `import math` at the top of this module —
            # it is absent from the visible import header; verify, otherwise
            # nfreqs > 0 raises NameError here.
            periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype)
            freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1))
            freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs ** 0.5
            tmp = torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q)
            dots += tmp
        if self.ndecay:
            # Learnt penalty growing with |t - s|, shrinking the attention window.
            decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
            decay_q = self.query_decay(x).view(B, heads, -1, T)
            decay_q = torch.sigmoid(decay_q) / 2
            decay_kernel = - decays.view(-1, 1, 1) * delta.abs() / self.ndecay ** 0.5
            dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)

        # Kill self reference.
        dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100)
        weights = torch.softmax(dots, dim=2)

        content = self.content(x).view(B, heads, -1, T)
        result = torch.einsum("bhts,bhct->bhcs", weights, content)

        result = result.reshape(B, -1, T)
        # Residual connection around the attention output projection.
        return x + self.proj(result)


class LayerScale(nn.Module):
    """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
    This rescales diagonaly residual outputs close to 0 initially, then learnt.
    """

    def __init__(self, channels: int, init: float = 0):
        super().__init__()
        # One learnable scale per channel, broadcast over time in forward.
        self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
        self.scale.data[:] = init

    def forward(self, x):
        return self.scale[:, None] * x


class DConv(nn.Module):
    """
    New residual branches in each encoder layer.
    This alternates dilated convolutions, potentially with LSTMs and attention.
    Also before entering each residual branch, dimension is projected on a smaller subspace,
    e.g. of dim `channels // compress`.
    """

    def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4,
                 norm=True, time_attn=False, heads=4, ndecay=4, lstm=False,
                 act_func='gelu', freq_dim=None, reshape=False,
                 kernel=3, dilate=True, mamba=False, d_state=16, d_conv=4, expand=2):
        """
        Args:
            channels: input/output channels for residual branch.
            compress: amount of channel compression inside the branch.
            depth: number of layers in the residual branch. Each layer has its own
                projection, and potentially LSTM and attention.
                A negative depth disables dilation (see note on `dilate` below).
            init: initial scale for LayerScale.
            norm: use GroupNorm.
            time_attn: use LocalAttention.
            heads: number of heads for the LocalAttention.
            ndecay: number of decay controls in the LocalAttention.
            lstm: use LSTM.
            act_func: activation function: 'gelu', 'snake' or anything else for ReLU.
            freq_dim: frequency dimension, only used by the 'snake' activation.
            reshape: if true, forward expects 4-D (B, C, Fr, T) input and folds
                the frequency axis into the batch.
            kernel: kernel size for the (dilated) convolutions.
            dilate: if true, use dilation, increasing with the depth.
                NOTE(review): this argument is ignored — the body overwrites it
                with `depth > 0` (the sign of `depth` controls dilation); verify
                whether that is intentional.
            mamba: add a Mamba SSM block per layer (see note in `forward`).
            d_state / d_conv / expand: Mamba hyper-parameters.
        """

        super().__init__()
        assert kernel % 2 == 1
        self.channels = channels
        self.compress = compress
        self.depth = abs(depth)
        # Sign of `depth` decides dilation; the `dilate` parameter is shadowed.
        dilate = depth > 0

        self.time_attn = time_attn
        self.lstm = lstm
        self.mamba = mamba
        self.reshape = reshape
        self.act_func = act_func
        self.freq_dim = freq_dim

        norm_fn: tp.Callable[[int], nn.Module]
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(1, d)  # noqa

        # Bottleneck width of the residual branch.
        self.hidden = int(channels / compress)

        self.d_model = channels
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand

        act: tp.Type[nn.Module]
        if act_func == 'gelu':
            act = nn.GELU
        elif act_func == 'snake':
            act = Snake
        else:
            act = nn.ReLU

        self.layers = nn.ModuleList([])
        for d in range(self.depth):
            layer = nn.ModuleDict()
            # Exponentially increasing dilation keeps the output length fixed
            # thanks to the matching "same" padding.
            dilation = 2 ** d if dilate else 1
            padding = dilation * (kernel // 2)
            # conv1: compress channels -> hidden; conv2: expand back through a
            # GLU (hence 2 * channels) and rescale with LayerScale.
            conv1 = nn.ModuleList([nn.Conv1d(channels, self.hidden, kernel, dilation=dilation, padding=padding),
                                   norm_fn(self.hidden)])
            # Snake needs the frequency dimension; other activations take no args.
            act_layer = act(freq_dim) if act_func == 'snake' else act()
            conv2 = nn.ModuleList([nn.Conv1d(self.hidden, 2 * channels, 1),
                                   norm_fn(2 * channels), nn.GLU(1),
                                   LayerScale(channels, init)])

            layer.update({'conv1': nn.Sequential(*conv1), 'act': act_layer, 'conv2': nn.Sequential(*conv2)})
            if mamba:
                # NOTE(review): this Mamba block is registered (so it has
                # parameters) but never invoked in `forward` below — confirm
                # whether it is applied elsewhere or simply dead weight.
                layer.update({'mamba': Mamba(channels, d_state, d_conv, expand)})
            if lstm:
                layer.update({'lstm': BLSTM(self.hidden, layers=2, max_steps=200, skip=True)})
            if time_attn:
                layer.update({'time_attn': LocalState(self.hidden, heads=heads, ndecay=ndecay)})

            self.layers.append(layer)

    def forward(self, x):
        # With reshape: x is (B, C, Fr, T); fold frequency into the batch so
        # every 1-D sub-module operates along time only.
        if self.reshape:
            B, C, Fr, T = x.shape
            x = x.permute(0, 2, 1, 3).reshape(-1, C, T)

        for layer in self.layers:
            skip = x

            x = layer['conv1'](x)

            # Snake is parameterized per frequency, so temporarily restore the
            # frequency axis around the activation.
            if self.act_func == 'snake' and self.reshape:
                x = x.view(B, Fr, self.hidden, T).permute(0, 2, 3, 1)
            x = layer['act'](x)
            if self.act_func == 'snake' and self.reshape:
                x = x.permute(0, 3, 1, 2).reshape(-1, self.hidden, T)

            if self.lstm:
                x = layer['lstm'](x)
            if self.time_attn:
                x = layer['time_attn'](x)

            x = layer['conv2'](x)
            # Residual connection around the whole branch.
            x = skip + x

        if self.reshape:
            x = x.view(B, Fr, C, T).permute(0, 2, 1, 3)

        return x


class ScaledEmbedding(nn.Module):
    """
    Boost learning rate for embeddings (with `scale`).
    Also, can make embeddings continuous with `smooth`.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int,
                 scale: float = 10., smooth=False):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        if smooth:
            # Cumulative sum makes neighbouring embeddings correlated.
            weight = torch.cumsum(self.embedding.weight.data, dim=0)
            # when summing gaussian, overscale raises as sqrt(n), so we nornalize by that.
            weight = weight / torch.arange(1, num_embeddings + 1).to(weight).sqrt()[:, None]
            self.embedding.weight.data[:] = weight
        # Stored weights are divided by `scale` and multiplied back in forward,
        # which effectively boosts this layer's learning rate by `scale`.
        self.embedding.weight.data /= scale
        self.scale = scale

    @property
    def weight(self):
        # Expose the effective (rescaled) weight.
        return self.embedding.weight * self.scale

    def forward(self, x):
        out = self.embedding(x) * self.scale
        return out


class FTB(nn.Module):
    """Frequency Transformation Block: time-frequency attention followed by a
    learnt linear map along the frequency axis, concatenated with the input."""

    def __init__(self, input_dim=257, in_channel=9, r_channel=5):
        super(FTB, self).__init__()
        self.input_dim = input_dim
        self.in_channel = in_channel
        # 1x1 conv compressing channels for the attention branch.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, r_channel, kernel_size=[1, 1]),
            nn.BatchNorm2d(r_channel),
            nn.ReLU()
        )

        # Collapses (channels x freq) into a per-time attention over channels.
        self.conv1d = nn.Sequential(
            nn.Conv1d(r_channel * input_dim, in_channel, kernel_size=9, padding=4),
            nn.BatchNorm1d(in_channel),
            nn.ReLU()
        )
        # Learnt frequency-to-frequency transformation.
        self.freq_fc = nn.Linear(input_dim, input_dim, bias=False)

        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channel * 2, in_channel, kernel_size=[1, 1]),
            nn.BatchNorm2d(in_channel),
            nn.ReLU()
        )

    def forward(self, inputs):
        '''
        inputs should be [Batch, Ca, Dim, Time]
        '''
        # T-F attention
        conv1_out = self.conv1(inputs)
        B, C, D, T = conv1_out.size()
        reshape1_out = torch.reshape(conv1_out, [B, C * D, T])
        conv1d_out = self.conv1d(reshape1_out)
        conv1d_out = torch.reshape(conv1d_out, [B, self.in_channel, 1, T])

        # now is also [B,C,D,T]
        att_out = conv1d_out * inputs

        # tranpose to [B,C,T,D]
        att_out = torch.transpose(att_out, 2, 3)
        freqfc_out = self.freq_fc(att_out)
        att_out = torch.transpose(freqfc_out, 2, 3)

        cat_out = torch.cat([att_out, inputs], 1)
        outputs = self.conv2(cat_out)
        return outputs