├── losses
│   ├── __init__.py
│   ├── discriminator_loss.py
│   ├── generator_loss.py
│   └── basic_loss.py
├── models
│   ├── __init__.py
│   ├── soundstream.py
│   ├── soundstream2.py
│   ├── soundstream_semantic.py
│   └── msstftd.py
├── utils
│   ├── __init__.py
│   └── hifigan_mel.py
├── dataloaders
│   ├── __init__.py
│   └── base_dataloader.py
├── distributed
│   ├── __init__.py
│   ├── launch.py
│   └── distributed.py
├── modules
│   ├── commons
│   │   ├── __init__.py
│   │   └── pqmf.py
│   ├── discriminators
│   │   ├── __init__.py
│   │   └── frequency_discriminator.py
│   ├── __init__.py
│   ├── lstm.py
│   ├── norm.py
│   ├── transformer.py
│   └── loss.py
├── descriptaudiocodec
│   ├── .gitattributes
│   ├── tests
│   │   ├── __init__.py
│   │   ├── test_cli.py
│   │   └── test_train.py
│   ├── dac
│   │   ├── compare
│   │   │   ├── __init__.py
│   │   │   └── encodec.py
│   │   ├── nn
│   │   │   ├── __init__.py
│   │   │   └── layers.py
│   │   ├── model
│   │   │   ├── __init__.py
│   │   │   └── discriminator.py
│   │   ├── __init__.py
│   │   ├── __main__.py
│   │   └── utils
│   │       ├── decode.py
│   │       ├── encode.py
│   │       └── __init__.py
│   ├── .dockerignore
│   ├── conf
│   │   ├── ablations
│   │   │   ├── baseline.yml
│   │   │   ├── no-adv.yml
│   │   │   ├── no-mb.yml
│   │   │   ├── no-mpd.yml
│   │   │   ├── no-low-hop.yml
│   │   │   ├── no-mpd-msd.yml
│   │   │   ├── diff-mb.yml
│   │   │   ├── equal-mb.yml
│   │   │   ├── only-speech.yml
│   │   │   └── no-data-balance.yml
│   │   ├── quantizer
│   │   │   ├── 2d.yml
│   │   │   ├── 4d.yml
│   │   │   ├── 24kbps.yml
│   │   │   ├── 256d.yml
│   │   │   ├── 32d.yml
│   │   │   ├── 512d.yml
│   │   │   ├── dropout-0.0.yml
│   │   │   ├── dropout-0.5.yml
│   │   │   └── dropout-0.25.yml
│   │   ├── size
│   │   │   ├── medium.yml
│   │   │   └── small.yml
│   │   ├── 1gpu.yml
│   │   ├── downsampling
│   │   │   ├── 128x.yml
│   │   │   ├── 1024x.yml
│   │   │   ├── 1536x.yml
│   │   │   └── 768x.yml
│   │   ├── base.yml
│   │   └── final
│   │       ├── 16khz.yml
│   │       ├── 24khz.yml
│   │       ├── 44khz.yml
│   │       └── 44khz-16kbps.yml
│   ├── requirements.txt
│   ├── Dockerfile
│   ├── Dockerfile.dev
│   ├── .pre-commit-config.yaml
│   ├── docker-compose.yml
│   ├── LICENSE
│   ├── scripts
│   │   ├── compute_entropy.py
│   │   ├── save_test_set.py
│   │   ├── organize_daps.py
│   │   ├── get_samples.py
│   │   ├── evaluate.py
│   │   └── mushra.py
│   ├── setup.py
│   ├── .gitignore
│   └── README.md
├── exp.png
├── fig1.png
├── test_audio
│   ├── music.wav
│   ├── speech_cn.wav
│   └── speech_en.flac
├── quantization
│   ├── __init__.py
│   ├── distrib.py
│   └── vq.py
├── test_audio_reconstruction
│   └── speech_en_nq_1.wav
├── requirements.txt
├── LICENSE
├── inference.py
├── config
│   └── codec_16k_6kbps_v3_vqdp.yaml
└── README.md
/losses/__init__.py:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/dataloaders/__init__.py:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/distributed/__init__.py:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/modules/commons/__init__.py:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/descriptaudiocodec/.gitattributes:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/descriptaudiocodec/tests/__init__.py:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/descriptaudiocodec/dac/compare/__init__.py:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/exp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhenye234/xcodec/HEAD/exp.png
--------------------------------------------------------------------------------
/fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhenye234/xcodec/HEAD/fig1.png
--------------------------------------------------------------------------------
/descriptaudiocodec/.dockerignore:
--------------------------------------------------------------------------------
 1 | *.pth
 2 | *.wav
 3 | *.dac
 4 | tests/
 5 | runs/
--------------------------------------------------------------------------------
/test_audio/music.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhenye234/xcodec/HEAD/test_audio/music.wav
--------------------------------------------------------------------------------
/modules/discriminators/__init__.py:
--------------------------------------------------------------------------------
 1 | from .frequency_discriminator import MultiFrequencyDiscriminator
--------------------------------------------------------------------------------
/test_audio/speech_cn.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhenye234/xcodec/HEAD/test_audio/speech_cn.wav
--------------------------------------------------------------------------------
/test_audio/speech_en.flac:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhenye234/xcodec/HEAD/test_audio/speech_en.flac
--------------------------------------------------------------------------------
/descriptaudiocodec/conf/ablations/baseline.yml:
--------------------------------------------------------------------------------
 1 | $include:
 2 |   - conf/base.yml
 3 |   - conf/1gpu.yml
 4 | 
--------------------------------------------------------------------------------
/descriptaudiocodec/dac/nn/__init__.py:
--------------------------------------------------------------------------------
 1 | from . import layers
 2 | from . import loss
 3 | from . import quantize
 4 | 
--------------------------------------------------------------------------------
/descriptaudiocodec/conf/quantizer/2d.yml:
--------------------------------------------------------------------------------
 1 | $include:
 2 |   - conf/base.yml
 3 |   - conf/1gpu.yml
 4 | 
 5 | DAC.codebook_dim: 2
 6 | 
--------------------------------------------------------------------------------
/descriptaudiocodec/conf/quantizer/4d.yml:
--------------------------------------------------------------------------------
 1 | $include:
 2 |   - conf/base.yml
 3 |   - conf/1gpu.yml
 4 | 
 5 | DAC.codebook_dim: 4
 6 | 
--------------------------------------------------------------------------------
/descriptaudiocodec/conf/size/medium.yml:
--------------------------------------------------------------------------------
 1 | $include:
 2 |   - conf/base.yml
 3 |   - conf/1gpu.yml
 4 | 
 5 | DAC.decoder_dim: 1024
 6 | 
--------------------------------------------------------------------------------
/descriptaudiocodec/conf/size/small.yml:
--------------------------------------------------------------------------------
 1 | $include:
 2 |   - conf/base.yml
 3 |   - conf/1gpu.yml
 4 | 
 5 | DAC.decoder_dim: 512
 6 | 
--------------------------------------------------------------------------------
/descriptaudiocodec/conf/quantizer/24kbps.yml:
--------------------------------------------------------------------------------
 1 | $include:
 2 |   - conf/base.yml
 3 |   - conf/1gpu.yml
 4 | 
 5 | DAC.n_codebooks: 28
 6 | 
--------------------------------------------------------------------------------
/descriptaudiocodec/conf/quantizer/256d.yml:
--------------------------------------------------------------------------------
 1 | $include:
 2 |   - conf/base.yml
 3 |   - conf/1gpu.yml
 4 | 
 5 | DAC.codebook_dim: 256
 6 | 
--------------------------------------------------------------------------------
/descriptaudiocodec/conf/quantizer/32d.yml:
--------------------------------------------------------------------------------
 1 | $include:
 2 |   - conf/base.yml
 3 |   - conf/1gpu.yml
 4 | 
 5 | DAC.codebook_dim: 32
 6 | 
--------------------------------------------------------------------------------
/descriptaudiocodec/conf/quantizer/512d.yml:
--------------------------------------------------------------------------------
 1 | $include:
 2 |   - conf/base.yml
 3 |   - conf/1gpu.yml
 4 | 
 5 | DAC.codebook_dim: 512
 6 | 
--------------------------------------------------------------------------------
/descriptaudiocodec/conf/1gpu.yml:
--------------------------------------------------------------------------------
 1 | $include:
 2 |   - conf/base.yml
 3 | 
 4 | batch_size: 12
 5 | val_batch_size: 12
 6 | num_workers: 4
 7 | 
--------------------------------------------------------------------------------
/descriptaudiocodec/conf/quantizer/dropout-0.0.yml:
--------------------------------------------------------------------------------
 1 | $include:
 2 |   - conf/base.yml
 3 |   - conf/1gpu.yml
 4 | 
 5 | DAC.quantizer_dropout: 0.0
 6 | 
--------------------------------------------------------------------------------
/descriptaudiocodec/conf/quantizer/dropout-0.5.yml:
--------------------------------------------------------------------------------
 1 | $include:
 2 |   - conf/base.yml
 3 |   - conf/1gpu.yml
 4 | 
 5 | DAC.quantizer_dropout: 0.5
 6 | 
--------------------------------------------------------------------------------
/test_audio_reconstruction/speech_en_nq_1.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhenye234/xcodec/HEAD/test_audio_reconstruction/speech_en_nq_1.wav
--------------------------------------------------------------------------------
/descriptaudiocodec/conf/quantizer/dropout-0.25.yml:
--------------------------------------------------------------------------------
 1 | $include:
 2 |   - conf/base.yml
 3 |   - conf/1gpu.yml
 4 | 
 5 | DAC.quantizer_dropout: 0.25
 6 | 
--------------------------------------------------------------------------------
/descriptaudiocodec/dac/model/__init__.py:
--------------------------------------------------------------------------------
 1 | from .base import CodecMixin
 2 | from .base import DACFile
 3 | from .dac import DAC
 4 | from .discriminator import Discriminator
 5 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
 1 | torch
 2 | omegaconf
 3 | torchaudio
 4 | einops
 5 | numpy
 6 | transformers
 7 | tqdm
 8 | tensorboard
 9 | descript-audiotools>=0.7.2
10 | descript-audio-codec
11 | scipy==1.10.1
--------------------------------------------------------------------------------
/descriptaudiocodec/requirements.txt:
--------------------------------------------------------------------------------
 1 | argbind>=0.3.7
 2 | descript-audiotools>=0.7.2
 3 | einops
 4 | numpy
 5 | torch
 6 | torchaudio
 7 | tqdm
 8 | tensorboard
 9 | numba>=0.5.7
10 | jupyterlab
--------------------------------------------------------------------------------
/descriptaudiocodec/conf/ablations/no-adv.yml:
--------------------------------------------------------------------------------
 1 | $include:
 2 |   - conf/base.yml
 3 |   - conf/1gpu.yml
 4 | 
 5 | lambdas:
 6 |   mel/loss: 1.0
 7 |   waveform/loss: 1.0
 8 |   vq/commitment_loss: 0.25
 9 |   vq/codebook_loss: 1.0
10 | 
--------------------------------------------------------------------------------
/descriptaudiocodec/Dockerfile:
--------------------------------------------------------------------------------
 1 | FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime
 2 | 
 3 | COPY . /app
 4 | WORKDIR /app
 5 | 
 6 | RUN apt update && apt install -y git
 7 | # install the package
 8 | RUN pip install .
 9 | 
10 | # cache the model
11 | RUN python3 -m dac download
12 | 
--------------------------------------------------------------------------------
/quantization/__init__.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | # All rights reserved.
 3 | #
 4 | # This source code is licensed under the license found in the
 5 | # LICENSE file in the root directory of this source tree.
6 | 7 | # flake8: noqa 8 | from .vq import QuantizedResult, ResidualVectorQuantizer 9 | -------------------------------------------------------------------------------- /descriptaudiocodec/Dockerfile.dev: -------------------------------------------------------------------------------- 1 | ARG IMAGE=pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime 2 | ARG GITHUB_TOKEN=none 3 | 4 | FROM $IMAGE 5 | 6 | RUN echo machine github.com login ${GITHUB_TOKEN} > ~/.netrc 7 | 8 | COPY requirements.txt /requirements.txt 9 | 10 | RUN apt update && apt install -y git 11 | 12 | # install the package 13 | RUN pip install --upgrade -r /requirements.txt 14 | -------------------------------------------------------------------------------- /descriptaudiocodec/dac/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.0" 2 | 3 | # preserved here for legacy reasons 4 | __model_version__ = "latest" 5 | 6 | import audiotools 7 | 8 | audiotools.ml.BaseModel.INTERN += ["dac.**"] 9 | audiotools.ml.BaseModel.EXTERN += ["einops"] 10 | 11 | 12 | from . import nn 13 | from . import model 14 | from . import utils 15 | from .model import DAC 16 | from .model import DACFile 17 | -------------------------------------------------------------------------------- /descriptaudiocodec/conf/downsampling/128x.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | 5 | # Model setup 6 | DAC.sample_rate: 44100 7 | DAC.encoder_dim: 64 8 | DAC.encoder_rates: [2, 4, 4, 4] 9 | DAC.decoder_dim: 1536 10 | DAC.decoder_rates: [4, 4, 2, 2, 2, 1] 11 | 12 | # Quantization 13 | DAC.n_codebooks: 2 14 | DAC.codebook_size: 1024 15 | DAC.codebook_dim: 8 16 | DAC.quantizer_dropout: 1.0 17 | -------------------------------------------------------------------------------- /descriptaudiocodec/conf/downsampling/1024x.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | 5 | # Model setup 6 | DAC.sample_rate: 44100 7 | DAC.encoder_dim: 64 8 | DAC.encoder_rates: [2, 8, 8, 8] 9 | DAC.decoder_dim: 1536 10 | DAC.decoder_rates: [8, 4, 4, 2, 2, 2] 11 | 12 | # Quantization 13 | DAC.n_codebooks: 19 14 | DAC.codebook_size: 1024 15 | DAC.codebook_dim: 8 16 | DAC.quantizer_dropout: 1.0 17 | -------------------------------------------------------------------------------- /descriptaudiocodec/conf/downsampling/1536x.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | 5 | # Model setup 6 | DAC.sample_rate: 44100 7 | DAC.encoder_dim: 96 8 | DAC.encoder_rates: [2, 8, 8, 12] 9 | DAC.decoder_dim: 1536 10 | DAC.decoder_rates: [12, 4, 4, 2, 2, 2] 11 | 12 | # Quantization 13 | DAC.n_codebooks: 28 14 | DAC.codebook_size: 1024 15 | DAC.codebook_dim: 8 16 | DAC.quantizer_dropout: 1.0 17 | -------------------------------------------------------------------------------- /descriptaudiocodec/conf/downsampling/768x.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | 5 | # Model setup 6 | DAC.sample_rate: 44100 7 | DAC.encoder_dim: 64 8 | DAC.encoder_rates: [2, 6, 8, 8] 9 | DAC.decoder_dim: 1536 10 | DAC.decoder_rates: [6, 4, 4, 2, 2, 2] 11 | 12 | # Quantization 13 | DAC.n_codebooks: 14 14 | DAC.codebook_size: 1024 15 | DAC.codebook_dim: 8 16 | DAC.quantizer_dropout: 1.0 17 
| -------------------------------------------------------------------------------- /descriptaudiocodec/conf/ablations/no-mb.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | 5 | Discriminator.sample_rate: 44100 6 | Discriminator.fft_sizes: [2048, 1024, 512] 7 | Discriminator.bands: 8 | - [0.0, 1.0] 9 | 10 | # re-weight lambdas to make up for 11 | # lost discriminators vs baseline 12 | lambdas: 13 | mel/loss: 15.0 14 | adv/feat_loss: 5.0 15 | adv/gen_loss: 1.0 16 | vq/commitment_loss: 0.25 17 | vq/codebook_loss: 1.0 18 | -------------------------------------------------------------------------------- /descriptaudiocodec/.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/asottile/reorder_python_imports 3 | rev: v2.5.0 4 | hooks: 5 | - id: reorder-python-imports 6 | - repo: https://github.com/psf/black 7 | rev: 23.1.0 8 | hooks: 9 | - id: black 10 | language_version: python3 11 | - repo: https://github.com/pre-commit/pre-commit-hooks 12 | rev: v4.0.1 13 | hooks: 14 | - id: end-of-file-fixer 15 | - id: trailing-whitespace 16 | -------------------------------------------------------------------------------- /descriptaudiocodec/conf/ablations/no-mpd.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | 5 | Discriminator.sample_rate: 44100 6 | Discriminator.rates: [1] 7 | Discriminator.periods: [] 8 | Discriminator.fft_sizes: [2048, 1024, 512] 9 | Discriminator.bands: 10 | - [0.0, 0.1] 11 | - [0.1, 0.25] 12 | - [0.25, 0.5] 13 | - [0.5, 0.75] 14 | - [0.75, 1.0] 15 | 16 | lambdas: 17 | mel/loss: 15.0 18 | adv/feat_loss: 2.5 19 | adv/gen_loss: 1.0 20 | vq/commitment_loss: 0.25 21 | vq/codebook_loss: 1.0 22 | -------------------------------------------------------------------------------- /descriptaudiocodec/conf/ablations/no-low-hop.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | 5 | MelSpectrogramLoss.n_mels: [80] 6 | MelSpectrogramLoss.window_lengths: [512] 7 | MelSpectrogramLoss.mel_fmin: [0] 8 | MelSpectrogramLoss.mel_fmax: [null] 9 | MelSpectrogramLoss.pow: 1.0 10 | MelSpectrogramLoss.clamp_eps: 1.0e-5 11 | MelSpectrogramLoss.mag_weight: 0.0 12 | 13 | lambdas: 14 | mel/loss: 100.0 15 | adv/feat_loss: 2.0 16 | adv/gen_loss: 1.0 17 | vq/commitment_loss: 0.25 18 | vq/codebook_loss: 1.0 19 | -------------------------------------------------------------------------------- /descriptaudiocodec/conf/ablations/no-mpd-msd.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | 5 | Discriminator.sample_rate: 44100 6 | Discriminator.rates: [] 7 | Discriminator.periods: [] 8 | Discriminator.fft_sizes: [2048, 1024, 512] 9 | Discriminator.bands: 10 | - [0.0, 0.1] 11 | - [0.1, 0.25] 12 | - [0.25, 0.5] 13 | - [0.5, 0.75] 14 | - [0.75, 1.0] 15 | 16 | lambdas: 17 | mel/loss: 15.0 18 | adv/feat_loss: 2.66 19 | adv/gen_loss: 1.0 20 | vq/commitment_loss: 0.25 21 | vq/codebook_loss: 1.0 22 | -------------------------------------------------------------------------------- /descriptaudiocodec/conf/ablations/diff-mb.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | 5 | 
Discriminator.sample_rate: 44100 6 | Discriminator.fft_sizes: [2048, 1024, 512] 7 | Discriminator.bands: 8 | - [0.0, 0.05] 9 | - [0.05, 0.1] 10 | - [0.1, 0.25] 11 | - [0.25, 0.5] 12 | - [0.5, 1.0] 13 | 14 | 15 | # re-weight lambdas to make up for 16 | # lost discriminators vs baseline 17 | lambdas: 18 | mel/loss: 15.0 19 | adv/feat_loss: 5.0 20 | adv/gen_loss: 1.0 21 | vq/commitment_loss: 0.25 22 | vq/codebook_loss: 1.0 23 | -------------------------------------------------------------------------------- /descriptaudiocodec/conf/ablations/equal-mb.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | 5 | Discriminator.sample_rate: 44100 6 | Discriminator.fft_sizes: [2048, 1024, 512] 7 | Discriminator.bands: 8 | - [0.0, 0.2] 9 | - [0.2, 0.4] 10 | - [0.4, 0.6] 11 | - [0.6, 0.8] 12 | - [0.8, 1.0] 13 | 14 | 15 | # re-weight lambdas to make up for 16 | # lost discriminators vs baseline 17 | lambdas: 18 | mel/loss: 15.0 19 | adv/feat_loss: 5.0 20 | adv/gen_loss: 1.0 21 | vq/commitment_loss: 0.25 22 | vq/codebook_loss: 1.0 23 | -------------------------------------------------------------------------------- /descriptaudiocodec/conf/ablations/only-speech.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | 5 | train/build_dataset.folders: 6 | speech_fb: 7 | - /data/daps/train 8 | speech_hq: 9 | - /data/vctk 10 | - /data/vocalset 11 | - /data/read_speech 12 | - /data/french_speech 13 | speech_uq: 14 | - /data/emotional_speech/ 15 | - /data/common_voice/ 16 | - /data/german_speech/ 17 | - /data/russian_speech/ 18 | - /data/spanish_speech/ 19 | 20 | val/build_dataset.folders: 21 | speech_hq: 22 | - /data/daps/val 23 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | """Torch modules.""" 8 | 9 | # flake8: noqa 10 | from .conv import ( 11 | pad1d, 12 | unpad1d, 13 | NormConv1d, 14 | NormConvTranspose1d, 15 | NormConv2d, 16 | NormConvTranspose2d, 17 | SConv1d, 18 | SConvTranspose1d, 19 | ) 20 | from .lstm import SLSTM 21 | from .seanet import SEANetEncoder, SEANetDecoder 22 | from .transformer import StreamingTransformerEncoder 23 | -------------------------------------------------------------------------------- /descriptaudiocodec/conf/ablations/no-data-balance.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | 5 | train/build_dataset.folders: 6 | speech: 7 | - /data/daps/train 8 | - /data/vctk 9 | - /data/vocalset 10 | - /data/read_speech 11 | - /data/french_speech 12 | - /data/emotional_speech/ 13 | - /data/common_voice/ 14 | - /data/german_speech/ 15 | - /data/russian_speech/ 16 | - /data/spanish_speech/ 17 | music: 18 | - /data/musdb/train 19 | - /data/jamendo 20 | general: 21 | - /data/audioset/data/unbalanced_train_segments/ 22 | - /data/audioset/data/balanced_train_segments/ 23 | -------------------------------------------------------------------------------- /descriptaudiocodec/dac/__main__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import argbind 4 | 5 | from dac.utils import download 6 | from dac.utils.decode import decode 7 | from dac.utils.encode import encode 8 | 9 | STAGES = ["encode", "decode", "download"] 10 | 11 | 12 | def run(stage: str): 13 | """Run stages. 14 | 15 | Parameters 16 | ---------- 17 | stage : str 18 | Stage to run 19 | """ 20 | if stage not in STAGES: 21 | raise ValueError(f"Unknown command: {stage}. Allowed commands are {STAGES}") 22 | stage_fn = globals()[stage] 23 | 24 | if stage == "download": 25 | stage_fn() 26 | return 27 | 28 | stage_fn() 29 | 30 | 31 | if __name__ == "__main__": 32 | group = sys.argv.pop(1) 33 | args = argbind.parse_args(group=group) 34 | 35 | with argbind.scope(args): 36 | run(group) 37 | -------------------------------------------------------------------------------- /modules/lstm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """LSTM layers module.""" 8 | 9 | from torch import nn 10 | 11 | 12 | class SLSTM(nn.Module): 13 | """ 14 | LSTM without worrying about the hidden state, nor the layout of the data. 15 | Expects input as convolutional layout. 16 | """ 17 | def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True): 18 | super().__init__() 19 | self.skip = skip 20 | self.lstm = nn.LSTM(dimension, dimension, num_layers) 21 | 22 | def forward(self, x): 23 | x = x.permute(2, 0, 1) 24 | y, _ = self.lstm(x) 25 | if self.skip: 26 | y = y + x 27 | y = y.permute(1, 2, 0) 28 | return y 29 | -------------------------------------------------------------------------------- /modules/norm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | """Normalization modules.""" 8 | 9 | import typing as tp 10 | 11 | import einops 12 | import torch 13 | from torch import nn 14 | 15 | 16 | class ConvLayerNorm(nn.LayerNorm): 17 | """ 18 | Convolution-friendly LayerNorm that moves channels to last dimensions 19 | before running the normalization and moves them back to original position right after. 20 | """ 21 | def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs): 22 | super().__init__(normalized_shape, **kwargs) 23 | 24 | def forward(self, x): 25 | x = einops.rearrange(x, 'b ... t -> b t ...') 26 | x = super().forward(x) 27 | x = einops.rearrange(x, 'b t ... -> b ... t') 28 | return 29 | -------------------------------------------------------------------------------- /descriptaudiocodec/dac/nn/layers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from einops import rearrange 6 | from torch.nn.utils import weight_norm 7 | 8 | 9 | def WNConv1d(*args, **kwargs): 10 | return weight_norm(nn.Conv1d(*args, **kwargs)) 11 | 12 | 13 | def WNConvTranspose1d(*args, **kwargs): 14 | return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) 15 | 16 | 17 | # Scripting this brings model speed up 1.4x 18 | @torch.jit.script 19 | def snake(x, alpha): 20 | shape = x.shape 21 | x = x.reshape(shape[0], shape[1], -1) 22 | x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) 23 | x = x.reshape(shape) 24 | return x 25 | 26 | 27 | class Snake1d(nn.Module): 28 | def __init__(self, channels): 29 | super().__init__() 30 | self.alpha = nn.Parameter(torch.ones(1, channels, 1)) 31 | 32 | def forward(self, x): 33 | return snake(x, self.alpha) 34 | -------------------------------------------------------------------------------- /descriptaudiocodec/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3.5" 2 | services: 3 | base: 4 | build: 5 | context: . 
6 | dockerfile: ./Dockerfile.dev 7 | args: 8 | GITHUB_TOKEN: ${GITHUB_TOKEN} 9 | IMAGE: ${IMAGE} 10 | volumes: 11 | - .:/u/home/src 12 | - ${PATH_TO_DATA}:/data 13 | - ${PATH_TO_RUNS}:/runs 14 | - ~/.config/gcloud:/u/home/.config/gcloud 15 | - ~/.zsh_history:/u/home/.zsh_history 16 | environment: 17 | - GITHUB_TOKEN 18 | - HOST_USER_ID 19 | - HOST_USER_GID 20 | - JUPYTER_TOKEN=password 21 | - PATH_TO_DATA=/data 22 | - PATH_TO_RUNS=/runs 23 | - MPLCONFIGDIR=/u/home/.mplconfig 24 | shm_size: 32G 25 | working_dir: /u/home/src 26 | deploy: 27 | resources: 28 | reservations: 29 | devices: 30 | - driver: nvidia 31 | capabilities: [gpu] 32 | dev: 33 | extends: base 34 | profiles: 35 | - interactive 36 | stdin_open: true 37 | tty: true 38 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 YE Zhen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /descriptaudiocodec/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023-present, Descript 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /descriptaudiocodec/scripts/compute_entropy.py: -------------------------------------------------------------------------------- 1 | import argbind 2 | import audiotools as at 3 | import numpy as np 4 | import torch 5 | import tqdm 6 | 7 | import dac 8 | 9 | 10 | @argbind.bind(without_prefix=True, positional=True) 11 | def main( 12 | folder: str, 13 | model_path: str, 14 | n_samples: int = 1024, 15 | device: str = "cuda", 16 | ): 17 | files = at.util.find_audio(folder)[:n_samples] 18 | signals = [ 19 | at.AudioSignal.salient_excerpt(f, loudness_cutoff=-20, duration=1.0) 20 | for f in files 21 | ] 22 | 23 | with torch.no_grad(): 24 | model = dac.model.DAC.load(model_path).to(device) 25 | model.eval() 26 | 27 | codes = [] 28 | for x in tqdm.tqdm(signals): 29 | x = x.to(model.device) 30 | o = model.encode(x.audio_data, x.sample_rate) 31 | codes.append(o["codes"].cpu()) 32 | 33 | codes = torch.cat(codes, dim=-1) 34 | entropy = [] 35 | 36 | for i in range(codes.shape[1]): 37 | codes_ = codes[0, i, :] 38 | counts = torch.bincount(codes_) 39 | counts = (counts / counts.sum()).clamp(1e-10) 40 | entropy.append(-(counts * counts.log()).sum().item() * np.log2(np.e)) 41 | 42 | pct = sum(entropy) / (10 * len(entropy)) 43 | print(f"Entropy for each codebook: {entropy}") 44 | print(f"Effective percentage: {pct * 100}%") 45 | 46 | 47 | if __name__ == "__main__": 48 | args = argbind.parse_args() 49 | with argbind.scope(args): 50 | main() 51 | -------------------------------------------------------------------------------- /descriptaudiocodec/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages 2 | from setuptools import setup 3 | 4 | with open("README.md") as f: 5 | long_description = f.read() 6 | 7 | setup( 8 | name="descript-audio-codec", 9 | version="1.0.0", 10 | classifiers=[ 11 | "Intended Audience :: Developers", 12 | "Natural Language :: English", 13 | "Programming Language :: Python :: 3.7", 14 | "Topic :: Artistic Software", 15 | "Topic :: Multimedia", 16 | "Topic :: Multimedia :: Sound/Audio", 17 | "Topic :: Multimedia :: Sound/Audio :: Editors", 18 | "Topic :: Software Development :: Libraries", 19 | ], 20 | description="A high-quality general neural audio codec.", 21 | long_description=long_description, 22 | long_description_content_type="text/markdown", 23 | author="Prem Seetharaman, Rithesh Kumar", 24 | author_email="prem@descript.com", 25 | url="https://github.com/descriptinc/descript-audio-codec", 26 | license="MIT", 27 | packages=find_packages(), 28 | keywords=["audio", "compression", "machine learning"], 29 | install_requires=[ 30 | "argbind>=0.3.7", 31 | "descript-audiotools>=0.7.2", 32 | "einops", 33 | "numpy", 34 | "torch", 35 | "torchaudio", 36 | "tqdm", 37 | ], 38 | extras_require={ 39 | "dev": [ 40 | "pytest", 41 | "pytest-cov", 42 | "pynvml", 43 | "psutil", 44 | "pandas", 45 | "onnx", 46 | "onnx-simplifier", 47 | "seaborn", 48 | "jupyterlab", 49 | "pandas", 50 | "watchdog", 51 | "pesq", 52 | "tabulate", 53 | "encodec", 54 | ], 55 | }, 56 | ) 57 | -------------------------------------------------------------------------------- /descriptaudiocodec/dac/compare/encodec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from audiotools import AudioSignal 3 | from audiotools.ml import BaseModel 4 | from encodec import EncodecModel 5 | 6 | 7 | class Encodec(BaseModel): 8 | 
def __init__(self, sample_rate: int = 24000, bandwidth: float = 24.0): 9 | super().__init__() 10 | 11 | if sample_rate == 24000: 12 | self.model = EncodecModel.encodec_model_24khz() 13 | else: 14 | self.model = EncodecModel.encodec_model_48khz() 15 | self.model.set_target_bandwidth(bandwidth) 16 | self.sample_rate = 44100 17 | 18 | def forward( 19 | self, 20 | audio_data: torch.Tensor, 21 | sample_rate: int = 44100, 22 | n_quantizers: int = None, 23 | ): 24 | signal = AudioSignal(audio_data, sample_rate) 25 | signal.resample(self.model.sample_rate) 26 | recons = self.model(signal.audio_data) 27 | recons = AudioSignal(recons, self.model.sample_rate) 28 | recons.resample(sample_rate) 29 | return {"audio": recons.audio_data} 30 | 31 | 32 | if __name__ == "__main__": 33 | import numpy as np 34 | from functools import partial 35 | 36 | model = Encodec() 37 | 38 | for n, m in model.named_modules(): 39 | o = m.extra_repr() 40 | p = sum([np.prod(p.size()) for p in m.parameters()]) 41 | fn = lambda o, p: o + f" {p/1e6:<.3f}M params." 42 | setattr(m, "extra_repr", partial(fn, o=o, p=p)) 43 | print(model) 44 | print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()])) 45 | 46 | length = 88200 * 2 47 | x = torch.randn(1, 1, length).to(model.device) 48 | x.requires_grad_(True) 49 | x.retain_grad() 50 | 51 | # Make a forward pass 52 | out = model(x)["audio"] 53 | 54 | print(x.shape, out.shape) 55 | -------------------------------------------------------------------------------- /descriptaudiocodec/scripts/save_test_set.py: -------------------------------------------------------------------------------- 1 | import csv 2 | from pathlib import Path 3 | 4 | import argbind 5 | import torch 6 | from audiotools.core import util 7 | from audiotools.ml.decorators import Tracker 8 | from train import Accelerator 9 | 10 | import scripts.train as train 11 | 12 | 13 | @torch.no_grad() 14 | def process(batch, accel, test_data): 15 | batch = util.prepare_batch(batch, accel.device) 16 | signal = test_data.transform(batch["signal"].clone(), **batch["transform_args"]) 17 | return signal.cpu() 18 | 19 | 20 | @argbind.bind(without_prefix=True) 21 | @torch.no_grad() 22 | def save_test_set(args, accel, sample_rate: int = 44100, output: str = "samples/input"): 23 | tracker = Tracker() 24 | with argbind.scope(args, "test"): 25 | test_data = train.build_dataset(sample_rate) 26 | 27 | global process 28 | process = tracker.track("process", len(test_data))(process) 29 | 30 | output = Path(output) 31 | output.mkdir(parents=True, exist_ok=True) 32 | (output.parent / "input").mkdir(parents=True, exist_ok=True) 33 | with open(output / "metadata.csv", "w") as csvfile: 34 | keys = ["path", "original"] 35 | writer = csv.DictWriter(csvfile, fieldnames=keys) 36 | writer.writeheader() 37 | 38 | with tracker.live: 39 | for i in range(len(test_data)): 40 | signal = process(test_data[i], accel, test_data) 41 | input_path = output.parent / "input" / f"sample_{i}.wav" 42 | metadata = { 43 | "path": str(input_path), 44 | "original": str(signal.path_to_input_file), 45 | } 46 | writer.writerow(metadata) 47 | signal.write(input_path) 48 | tracker.done("test", f"N={len(test_data)}") 49 | 50 | 51 | if __name__ == "__main__": 52 | args = argbind.parse_args() 53 | with argbind.scope(args): 54 | with Accelerator() as accel: 55 | save_test_set(args, accel) 56 | -------------------------------------------------------------------------------- /dataloaders/base_dataloader.py: 
-------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 | import torchaudio 6 | from torchaudio.transforms import Resample 7 | 8 | 9 | class WaveDataset(Dataset): 10 | def __init__( 11 | self, 12 | flist_file, 13 | segment_size, 14 | sampling_rate, 15 | split=True, # whether or not to get a segment of an audio sample to form the batch 16 | shuffle=False, 17 | audio_norm_scale: float = 1.0, 18 | ): 19 | self.file_list = self.get_filelist(flist_file) 20 | if shuffle: 21 | random.shuffle(self.file_list) 22 | self.segment_size = segment_size 23 | self.sampling_rate = sampling_rate 24 | self.split = split 25 | self.audio_norm_scale = audio_norm_scale 26 | 27 | def get_filelist(self, fpath): 28 | with open(fpath, 'r') as f: 29 | flist = [l.strip() for l in f if l.strip()] 30 | return flist 31 | 32 | def __getitem__(self, index): 33 | fname = self.file_list[index] 34 | audio, sr = torchaudio.load(fname) 35 | if sr != self.sampling_rate: 36 | audio = Resample(sr, self.sampling_rate)(audio) 37 | if self.audio_norm_scale < 1.0: 38 | audio = audio * self.audio_norm_scale 39 | 40 | if self.split: 41 | if audio.size(1) >= self.segment_size: 42 | max_audio_start = audio.size(1) - self.segment_size 43 | audio_start = random.randint(0, max_audio_start) 44 | audio = audio[:, audio_start:audio_start+self.segment_size] 45 | else: 46 | audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant') 47 | # in case, audio clip is too short in validation set 48 | if audio.size(1) < self.segment_size: 49 | audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant') 50 | 51 | return audio 52 | 53 | def __len__(self): 54 | return len(self.file_list) 55 | -------------------------------------------------------------------------------- /utils/hifigan_mel.py: -------------------------------------------------------------------------------- 1 | """Adapted from https://github.com/jik876/hifi-gan""" 2 | 3 | import torch 4 | import numpy as np 5 | from scipy.io.wavfile import read 6 | from librosa.filters import mel as librosa_mel_fn 7 | 8 | 9 | def load_wav(full_path): 10 | sampling_rate, data = read(full_path) 11 | return data, sampling_rate 12 | 13 | 14 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 15 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) 16 | 17 | 18 | def dynamic_range_decompression(x, C=1): 19 | return np.exp(x) / C 20 | 21 | 22 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 23 | return torch.log(torch.clamp(x, min=clip_val) * C) 24 | 25 | 26 | def dynamic_range_decompression_torch(x, C=1): 27 | return torch.exp(x) / C 28 | 29 | 30 | def spectral_normalize_torch(magnitudes): 31 | output = dynamic_range_compression_torch(magnitudes) 32 | return output 33 | 34 | 35 | def spectral_de_normalize_torch(magnitudes): 36 | output = dynamic_range_decompression_torch(magnitudes) 37 | return output 38 | 39 | 40 | mel_basis = {} 41 | hann_window = {} 42 | 43 | def mel_spectrogram( 44 | y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax=None, center=False): 45 | if torch.min(y) < -1.: 46 | print('min value is ', torch.min(y)) 47 | if torch.max(y) > 1.: 48 | print('max value is ', torch.max(y)) 49 | 50 | global mel_basis, hann_window 51 | if fmax not in mel_basis: 52 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) 53 | mel_basis[str(fmax)+'_'+str(y.device)] = 
torch.from_numpy(mel).float().to(y.device) 54 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) 55 | 56 | y = torch.nn.functional.pad( 57 | y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') 58 | y = y.squeeze(1) 59 | 60 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)], 61 | center=center, pad_mode='reflect', normalized=False, onesided=True,return_complex=False) 62 | 63 | spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) 64 | 65 | spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec) 66 | spec = spectral_normalize_torch(spec) 67 | 68 | return spec 69 | -------------------------------------------------------------------------------- /losses/discriminator_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from losses.basic_loss import MSEDLoss 5 | 6 | 7 | class BasicDiscriminatorLoss(nn.Module): 8 | """Least-square GAN loss.""" 9 | 10 | def __init__(self, config=None): 11 | super(BasicDiscriminatorLoss, self).__init__() 12 | 13 | def forward(self, real_outputs, fake_outputs): 14 | loss = 0 15 | real_losses = [] 16 | fake_losses = [] 17 | for dr, dg in zip(real_outputs, fake_outputs): 18 | dr = dr.float() 19 | dg = dg.float() 20 | real_loss = torch.mean((1-dr)**2) 21 | fake_loss = torch.mean(dg**2) 22 | loss += (real_loss + fake_loss) 23 | real_losses.append(real_loss.item()) 24 | fake_losses.append(fake_loss.item()) 25 | 26 | return loss 27 | 28 | 29 | class MSEDiscriminatorLoss(BasicDiscriminatorLoss): 30 | def __init__(self, config=None): 31 | super().__init__(config) 32 | self.mse_loss = MSEDLoss() 33 | 34 | def apply_d_loss(self, scores_fake, scores_real, loss_func): 35 | total_loss = 0 36 | total_real_loss = 0 37 | total_fake_loss = 0 38 | if isinstance(scores_fake, list): 39 | # multi-scale loss 40 | for score_fake, score_real in zip(scores_fake, scores_real): 41 | loss, real_loss, fake_loss = loss_func(score_fake=score_fake, score_real=score_real) 42 | total_loss = total_loss + loss 43 | total_real_loss = total_real_loss + real_loss 44 | total_fake_loss = total_fake_loss + fake_loss 45 | # normalize loss values with number of scales 46 | total_loss /= len(scores_fake) 47 | total_real_loss /= len(scores_real) 48 | total_fake_loss /= len(scores_fake) 49 | else: 50 | # single scale loss 51 | total_loss, total_real_loss, total_fake_loss = loss_func(scores_fake, scores_real) 52 | return total_loss, total_real_loss, total_fake_loss 53 | 54 | def forward(self, real_scores, fake_scores): 55 | mse_D_loss, mse_D_real_loss, mse_D_fake_loss = self.apply_d_loss( 56 | scores_fake=fake_scores, 57 | scores_real=real_scores, 58 | loss_func=self.mse_loss) 59 | return mse_D_loss 60 | -------------------------------------------------------------------------------- /descriptaudiocodec/tests/test_cli.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for CLI. 
3 | """ 4 | import subprocess 5 | from pathlib import Path 6 | 7 | import argbind 8 | import numpy as np 9 | import pytest 10 | import torch 11 | from audiotools import AudioSignal 12 | 13 | from dac.__main__ import run 14 | 15 | 16 | def setup_module(module): 17 | data_dir = Path(__file__).parent / "assets" 18 | data_dir.mkdir(exist_ok=True, parents=True) 19 | input_dir = data_dir / "input" 20 | input_dir.mkdir(exist_ok=True, parents=True) 21 | 22 | for i in range(5): 23 | signal = AudioSignal(np.random.randn(1000), 44_100) 24 | signal.write(input_dir / f"sample_{i}.wav") 25 | return input_dir 26 | 27 | 28 | def teardown_module(module): 29 | repo_root = Path(__file__).parent.parent 30 | subprocess.check_output(["rm", "-rf", f"{repo_root}/tests/assets"]) 31 | 32 | 33 | @pytest.mark.parametrize("model_type", ["44khz", "24khz", "16khz"]) 34 | def test_reconstruction(model_type): 35 | # Test encoding 36 | input_dir = Path(__file__).parent / "assets" / "input" 37 | output_dir = input_dir.parent / model_type / "encoded_output" 38 | args = { 39 | "input": str(input_dir), 40 | "output": str(output_dir), 41 | "device": "cuda" if torch.cuda.is_available() else "cpu", 42 | "model_type": model_type, 43 | } 44 | with argbind.scope(args): 45 | run("encode") 46 | 47 | # Test decoding 48 | input_dir = output_dir 49 | output_dir = input_dir.parent / model_type / "decoded_output" 50 | args = { 51 | "input": str(input_dir), 52 | "output": str(output_dir), 53 | "model_type": model_type, 54 | } 55 | with argbind.scope(args): 56 | run("decode") 57 | 58 | 59 | def test_compression(): 60 | # Test encoding 61 | input_dir = Path(__file__).parent / "assets" / "input" 62 | output_dir = input_dir.parent / "encoded_output_quantizers" 63 | args = { 64 | "input": str(input_dir), 65 | "output": str(output_dir), 66 | "n_quantizers": 3, 67 | "device": "cuda" if torch.cuda.is_available() else "cpu", 68 | } 69 | with argbind.scope(args): 70 | run("encode") 71 | 72 | # Open .dac file 73 | dac_file = output_dir / "sample_0.dac" 74 | artifacts = np.load(dac_file, allow_pickle=True)[()] 75 | codes = artifacts["codes"] 76 | 77 | # Ensure that the number of quantizers is correct 78 | assert codes.shape[1] == 3 79 | 80 | # Ensure that dtype of compression is uint16 81 | assert codes.dtype == np.uint16 82 | 83 | 84 | # CUDA_VISIBLE_DEVICES=0 python -m pytest tests/test_cli.py -s 85 | -------------------------------------------------------------------------------- /descriptaudiocodec/scripts/organize_daps.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import shutil 4 | from collections import defaultdict 5 | from typing import Tuple 6 | 7 | import argbind 8 | import numpy as np 9 | import tqdm 10 | from audiotools import util 11 | 12 | 13 | @argbind.bind() 14 | def split( 15 | audio_files, ratio: Tuple[float, float, float] = (0.8, 0.1, 0.1), seed: int = 0 16 | ): 17 | assert sum(ratio) == 1.0 18 | util.seed(seed) 19 | 20 | idx = np.arange(len(audio_files)) 21 | np.random.shuffle(idx) 22 | 23 | b = np.cumsum([0] + list(ratio)) * len(idx) 24 | b = [int(_b) for _b in b] 25 | train_idx = idx[b[0] : b[1]] 26 | val_idx = idx[b[1] : b[2]] 27 | test_idx = idx[b[2] :] 28 | 29 | audio_files = np.array(audio_files) 30 | train_files = audio_files[train_idx] 31 | val_files = audio_files[val_idx] 32 | test_files = audio_files[test_idx] 33 | 34 | return train_files, val_files, test_files 35 | 36 | 37 | def assign(val_split, test_split): 38 | def _assign(value): 39 | if value 
in val_split: 40 | return "val" 41 | if value in test_split: 42 | return "test" 43 | return "train" 44 | 45 | return _assign 46 | 47 | 48 | DAPS_VAL = ["f2", "m2"] 49 | DAPS_TEST = ["f10", "m10"] 50 | 51 | 52 | @argbind.bind(without_prefix=True) 53 | def process( 54 | dataset: str = "daps", 55 | daps_subset: str = "", 56 | ): 57 | get_split = None 58 | get_value = lambda path: path 59 | 60 | data_path = pathlib.Path("/data") 61 | dataset_path = data_path / dataset 62 | audio_files = util.find_audio(dataset_path) 63 | 64 | if dataset == "daps": 65 | get_split = assign(DAPS_VAL, DAPS_TEST) 66 | get_value = lambda path: (str(path).split("/")[-1].split("_", maxsplit=4)[0]) 67 | audio_files = [ 68 | x 69 | for x in util.find_audio(dataset_path) 70 | if daps_subset in str(x) and "breaths" not in str(x) 71 | ] 72 | 73 | if get_split is None: 74 | _, val, test = split(audio_files) 75 | get_split = assign(val, test) 76 | 77 | splits = defaultdict(list) 78 | for x in audio_files: 79 | _split = get_split(get_value(x)) 80 | splits[_split].append(x) 81 | 82 | with util.chdir(dataset_path): 83 | for k, v in splits.items(): 84 | v = sorted(v) 85 | print(f"Processing {k} in {dataset_path} of length {len(v)}") 86 | for _v in tqdm.tqdm(v): 87 | tgt_path = pathlib.Path( 88 | str(_v).replace(str(dataset_path), str(dataset_path / k)) 89 | ) 90 | tgt_path.parent.mkdir(parents=True, exist_ok=True) 91 | shutil.copyfile(_v, tgt_path) 92 | 93 | 94 | if __name__ == "__main__": 95 | args = argbind.parse_args() 96 | with argbind.scope(args): 97 | process() 98 | -------------------------------------------------------------------------------- /descriptaudiocodec/scripts/get_samples.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import argbind 4 | import torch 5 | from audiotools import AudioSignal 6 | from audiotools.core import util 7 | from audiotools.ml.decorators import Tracker 8 | from train import Accelerator 9 | from train import DAC 10 | 11 | from dac.compare.encodec import Encodec 12 | 13 | Encodec = argbind.bind(Encodec) 14 | 15 | 16 | def load_state( 17 | accel: Accelerator, 18 | tracker: Tracker, 19 | save_path: str, 20 | tag: str = "latest", 21 | load_weights: bool = False, 22 | model_type: str = "dac", 23 | bandwidth: float = 24.0, 24 | ): 25 | kwargs = { 26 | "folder": f"{save_path}/{tag}", 27 | "map_location": "cpu", 28 | "package": not load_weights, 29 | } 30 | tracker.print(f"Resuming from {str(Path('.').absolute())}/{kwargs['folder']}") 31 | 32 | if model_type == "dac": 33 | generator, _ = DAC.load_from_folder(**kwargs) 34 | elif model_type == "encodec": 35 | generator = Encodec(bandwidth=bandwidth) 36 | 37 | generator = accel.prepare_model(generator) 38 | return generator 39 | 40 | 41 | @torch.no_grad() 42 | def process(signal, accel, generator, **kwargs): 43 | signal = signal.to(accel.device) 44 | recons = generator(signal.audio_data, signal.sample_rate, **kwargs)["audio"] 45 | recons = AudioSignal(recons, signal.sample_rate) 46 | recons = recons.normalize(signal.loudness()) 47 | return recons.cpu() 48 | 49 | 50 | @argbind.bind(without_prefix=True) 51 | @torch.no_grad() 52 | def get_samples( 53 | accel, 54 | path: str = "ckpt", 55 | input: str = "samples/input", 56 | output: str = "samples/output", 57 | model_type: str = "dac", 58 | model_tag: str = "latest", 59 | bandwidth: float = 24.0, 60 | n_quantizers: int = None, 61 | ): 62 | tracker = Tracker(log_file=f"{path}/eval.txt", rank=accel.local_rank) 63 | generator = 
load_state( 64 | accel, 65 | tracker, 66 | save_path=path, 67 | model_type=model_type, 68 | bandwidth=bandwidth, 69 | tag=model_tag, 70 | ) 71 | generator.eval() 72 | kwargs = {"n_quantizers": n_quantizers} if model_type == "dac" else {} 73 | 74 | audio_files = util.find_audio(input) 75 | 76 | global process 77 | process = tracker.track("process", len(audio_files))(process) 78 | 79 | output = Path(output) 80 | output.mkdir(parents=True, exist_ok=True) 81 | 82 | with tracker.live: 83 | for i in range(len(audio_files)): 84 | signal = AudioSignal(audio_files[i]) 85 | recons = process(signal, accel, generator, **kwargs) 86 | recons.write(output / audio_files[i].name) 87 | 88 | tracker.done("test", f"N={len(audio_files)}") 89 | 90 | 91 | if __name__ == "__main__": 92 | args = argbind.parse_args() 93 | with argbind.scope(args): 94 | with Accelerator() as accel: 95 | get_samples(accel) 96 | -------------------------------------------------------------------------------- /descriptaudiocodec/tests/test_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for CLI. 3 | """ 4 | import os 5 | import shlex 6 | import subprocess 7 | from pathlib import Path 8 | 9 | import argbind 10 | import numpy as np 11 | from audiotools import AudioSignal 12 | 13 | from dac.__main__ import run 14 | 15 | 16 | def make_fake_data(data_dir=Path(__file__).parent / "assets"): 17 | data_dir.mkdir(exist_ok=True, parents=True) 18 | input_dir = data_dir / "input" 19 | input_dir.mkdir(exist_ok=True, parents=True) 20 | 21 | for i in range(100): 22 | signal = AudioSignal(np.random.randn(44_100 * 5), 44_100) 23 | signal.write(input_dir / f"sample_{i}.wav") 24 | return input_dir 25 | 26 | 27 | def make_fake_data_tree(): 28 | data_dir = Path(__file__).parent / "assets" 29 | 30 | for relative_dir in [ 31 | "train/speech", 32 | "train/music", 33 | "train/env", 34 | "val/speech", 35 | "val/music", 36 | "val/env", 37 | "test/speech", 38 | "test/music", 39 | "test/env", 40 | ]: 41 | leaf_dir = data_dir / relative_dir 42 | leaf_dir.mkdir(exist_ok=True, parents=True) 43 | make_fake_data(leaf_dir) 44 | return { 45 | split: { 46 | key: [str(data_dir / f"{split}/{key}")] 47 | for key in ["speech", "music", "env"] 48 | } 49 | for split in ["train", "val", "test"] 50 | } 51 | 52 | 53 | def setup_module(module): 54 | # Make fake dataset dir 55 | input_datasets = make_fake_data_tree() 56 | repo_root = Path(__file__).parent.parent 57 | 58 | # Load baseline conf and modify it for testing 59 | conf = argbind.load_args(repo_root / "conf" / "ablations" / "baseline.yml") 60 | 61 | for key in ["train", "val", "test"]: 62 | conf[f"{key}/build_dataset.folders"] = input_datasets[key] 63 | conf["num_iters"] = 1 64 | conf["val/AudioDataset.n_examples"] = 1 65 | conf["val_idx"] = [0] 66 | conf["val_batch_size"] = 1 67 | 68 | argbind.dump_args(conf, Path(__file__).parent / "assets" / "conf.yml") 69 | 70 | 71 | def teardown_module(module): 72 | repo_root = Path(__file__).parent.parent 73 | # Remove fake dataset dir 74 | subprocess.check_output(["rm", "-rf", f"{repo_root}/tests/assets"]) 75 | subprocess.check_output(["rm", "-rf", f"{repo_root}/tests/runs"]) 76 | 77 | 78 | def test_single_gpu_train(): 79 | env = os.environ.copy() 80 | env["CUDA_VISIBLE_DEVICES"] = "0" 81 | repo_root = Path(__file__).parent.parent 82 | args = shlex.split( 83 | f"python {repo_root}/scripts/train.py --args.load {repo_root}/tests/assets/conf.yml --save_path {repo_root}/tests/runs/baseline" 84 | ) 85 | 
subprocess.check_output(args, env=env) 86 | 87 | 88 | def test_multi_gpu_train(): 89 | env = os.environ.copy() 90 | env["CUDA_VISIBLE_DEVICES"] = "0,1" 91 | repo_root = Path(__file__).parent.parent 92 | args = shlex.split( 93 | f"torchrun --nproc_per_node gpu {repo_root}/scripts/train.py --args.load {repo_root}/tests/assets/conf.yml --save_path {repo_root}/tests/runs/baseline_multigpu" 94 | ) 95 | subprocess.check_output(args, env=env) 96 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from pathlib import Path 4 | import sys 5 | import torchaudio 6 | 7 | import torch 8 | import typing as tp 9 | from omegaconf import OmegaConf 10 | 11 | from models.soundstream_semantic import SoundStream 12 | import torch.nn.functional as F 13 | 14 | 15 | def build_codec_model(config): 16 | model = eval(config.generator.name)(**config.generator.config) 17 | return model 18 | 19 | def save_audio(wav: torch.Tensor, path: tp.Union[Path, str], sample_rate: int, rescale: bool = False): 20 | limit = 0.99 21 | mx = wav.abs().max() 22 | if rescale: 23 | wav = wav * min(limit / mx, 1) 24 | else: 25 | wav = wav.clamp(-limit, limit) 26 | 27 | path = str(Path(path).with_suffix('.wav')) 28 | torchaudio.save(path, wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16) 29 | 30 | 31 | 32 | def process_audio(input_file, output_file, rescale, args, config, soundstream): 33 | # Loading audio 34 | wav, sr = torchaudio.load(input_file) 35 | if wav.size(0) > 1: 36 | wav = wav.mean(0, keepdim=True) # Convert to mono 37 | if sr != soundstream.sample_rate: 38 | wav = torchaudio.transforms.Resample(sr, soundstream.sample_rate)(wav) 39 | if config.audio_norm_scale < 1.0: 40 | wav = wav * config.audio_norm_scale 41 | 42 | 43 | wav = wav.unsqueeze(1).cuda() 44 | compressed = soundstream.encode(wav, target_bw=args.bw) 45 | print(f"Compressed shape: {compressed.shape}") 46 | # Decode and save 47 | out = soundstream.decode(compressed) 48 | out = out.detach().cpu().squeeze(0) 49 | 50 | save_audio(out, output_file, 16000, rescale=rescale) 51 | print(f"Processed and saved: {output_file}") 52 | 53 | def main(): 54 | parser = argparse.ArgumentParser( 55 | description='High fidelity neural audio codec for a single file.') 56 | parser.add_argument('--input', type=Path, default='test_audio/speech_en.flac', help='Input audio file.') 57 | parser.add_argument('--output', type=Path, default='test_audio_reconstruction/speech_en_nq_1.wav', help='Output audio file.') 58 | parser.add_argument('--resume_path', type=str, default='speech_ckpt/hubert_1k_data/xcodec_speech_hubert.pth', help='Path to model checkpoint.') 59 | parser.add_argument('-r', '--rescale', action='store_true', help='Rescale output to avoid clipping.') 60 | #bw 0.5-> nq 1; 1->nq 2; 2->nq 4; 4->nq 8 61 | parser.add_argument('-b', '--bw', type=str, default=0.5, help='Target bandwidth.') 62 | args = parser.parse_args() 63 | 64 | args.bw = float(args.bw) 65 | 66 | if not args.input.exists(): 67 | sys.exit(f"Input file {args.input} does not exist.") 68 | 69 | config_path = os.path.join(os.path.dirname(args.resume_path), 'config.yaml') 70 | if not os.path.isfile(config_path): 71 | sys.exit(f"{config_path} file does not exist.") 72 | 73 | config = OmegaConf.load(config_path) 74 | soundstream = build_codec_model(config) 75 | parameter_dict = torch.load(args.resume_path) 76 | soundstream.load_state_dict(parameter_dict 
) # Load model 77 | soundstream = soundstream.cuda() 78 | soundstream.eval() 79 | 80 | process_audio(args.input, args.output, args.rescale, args, config, soundstream) 81 | 82 | if __name__ == '__main__': 83 | main() 84 | -------------------------------------------------------------------------------- /descriptaudiocodec/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/env.sh 108 | venv/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | 130 | # Files created by experiments 131 | output/ 132 | snapshot/ 133 | *.m4a 134 | *.wav 135 | notebooks/scratch.ipynb 136 | notebooks/inspect.ipynb 137 | notebooks/effects.ipynb 138 | notebooks/*.ipynb 139 | notebooks/*.gif 140 | notebooks/*.wav 141 | notebooks/*.mp4 142 | *runs/ 143 | boards/ 144 | samples/ 145 | *.ipynb 146 | 147 | results.json 148 | metrics.csv 149 | mprofile_* 150 | mem.png 151 | 152 | results/ 153 | mprofile* 154 | *.png 155 | # do not ignore the test wav file 156 | !tests/audio/short_test_audio.wav 157 | !tests/audio/output.wav 158 | */.DS_Store 159 | .DS_Store 160 | env.sh 161 | _codebraid/ 162 | **/*.html 163 | **/*.exec.md 164 | flagged/ 165 | log.txt 166 | ckpt/ 167 | .syncthing* 168 | tests/assets/ 169 | archived/ 170 | 171 | *_remote_module_* 172 | *.zip 173 | *.pth 174 | encoded_out/ 175 | recon/ 176 | recons/ 177 | -------------------------------------------------------------------------------- /descriptaudiocodec/conf/base.yml: -------------------------------------------------------------------------------- 1 | # Model setup 2 | DAC.sample_rate: 44100 3 | DAC.encoder_dim: 64 4 | DAC.encoder_rates: [2, 4, 8, 8] 5 | DAC.decoder_dim: 1536 6 | DAC.decoder_rates: [8, 8, 4, 2] 7 | 8 | # Quantization 9 | DAC.n_codebooks: 9 10 | DAC.codebook_size: 1024 11 | DAC.codebook_dim: 8 12 | DAC.quantizer_dropout: 1.0 13 | 14 | # Discriminator 15 | Discriminator.sample_rate: 44100 16 | Discriminator.rates: [] 17 | Discriminator.periods: [2, 3, 5, 7, 11] 18 | Discriminator.fft_sizes: [2048, 1024, 512] 19 | Discriminator.bands: 20 | - [0.0, 0.1] 21 | - [0.1, 0.25] 22 | - [0.25, 0.5] 23 | - [0.5, 0.75] 24 | - [0.75, 1.0] 25 | 26 | # Optimization 27 | AdamW.betas: [0.8, 0.99] 28 | AdamW.lr: 0.0001 29 | ExponentialLR.gamma: 0.999996 30 | 31 | amp: false 32 | val_batch_size: 100 33 | device: cuda 34 | num_iters: 250000 35 | save_iters: [10000, 50000, 100000, 200000] 36 | valid_freq: 1000 37 | sample_freq: 10000 38 | num_workers: 32 39 | val_idx: [0, 1, 2, 3, 4, 5, 6, 7] 40 | seed: 0 41 | lambdas: 42 | mel/loss: 15.0 43 | adv/feat_loss: 2.0 44 | adv/gen_loss: 1.0 45 | vq/commitment_loss: 0.25 46 | vq/codebook_loss: 1.0 47 | 48 | VolumeNorm.db: [const, -16] 49 | 50 | # Transforms 51 | build_transform.preprocess: 52 | - Identity 53 | build_transform.augment_prob: 0.0 54 | build_transform.augment: 55 | - Identity 56 | build_transform.postprocess: 57 | - VolumeNorm 58 | - RescaleAudio 59 | - ShiftPhase 60 | 61 | # Loss setup 62 | MultiScaleSTFTLoss.window_lengths: [2048, 512] 63 | MelSpectrogramLoss.n_mels: [5, 10, 20, 40, 80, 160, 320] 64 | MelSpectrogramLoss.window_lengths: [32, 64, 128, 256, 512, 1024, 2048] 65 | MelSpectrogramLoss.mel_fmin: [0, 0, 0, 0, 0, 0, 0] 66 | MelSpectrogramLoss.mel_fmax: [null, null, null, null, null, null, null] 67 | MelSpectrogramLoss.pow: 1.0 68 | MelSpectrogramLoss.clamp_eps: 1.0e-5 69 | MelSpectrogramLoss.mag_weight: 0.0 70 | 71 | # Data 72 | batch_size: 72 73 | 
train/AudioDataset.duration: 0.38 74 | train/AudioDataset.n_examples: 10000000 75 | 76 | val/AudioDataset.duration: 5.0 77 | val/build_transform.augment_prob: 1.0 78 | val/AudioDataset.n_examples: 250 79 | 80 | test/AudioDataset.duration: 10.0 81 | test/build_transform.augment_prob: 1.0 82 | test/AudioDataset.n_examples: 1000 83 | 84 | AudioLoader.shuffle: true 85 | AudioDataset.without_replacement: true 86 | 87 | train/build_dataset.folders: 88 | speech_fb: 89 | - /data/daps/train 90 | speech_hq: 91 | - /data/vctk 92 | - /data/vocalset 93 | - /data/read_speech 94 | - /data/french_speech 95 | speech_uq: 96 | - /data/emotional_speech/ 97 | - /data/common_voice/ 98 | - /data/german_speech/ 99 | - /data/russian_speech/ 100 | - /data/spanish_speech/ 101 | music_hq: 102 | - /data/musdb/train 103 | music_uq: 104 | - /data/jamendo 105 | general: 106 | - /data/audioset/data/unbalanced_train_segments/ 107 | - /data/audioset/data/balanced_train_segments/ 108 | 109 | val/build_dataset.folders: 110 | speech_hq: 111 | - /data/daps/val 112 | music_hq: 113 | - /data/musdb/test 114 | general: 115 | - /data/audioset/data/eval_segments/ 116 | 117 | test/build_dataset.folders: 118 | speech_hq: 119 | - /data/daps/test 120 | music_hq: 121 | - /data/musdb/test 122 | general: 123 | - /data/audioset/data/eval_segments/ 124 | -------------------------------------------------------------------------------- /descriptaudiocodec/conf/final/16khz.yml: -------------------------------------------------------------------------------- 1 | # Model setup 2 | DAC.sample_rate: 16000 3 | DAC.encoder_dim: 64 4 | DAC.encoder_rates: [2, 4, 5, 8] 5 | DAC.decoder_dim: 1536 6 | DAC.decoder_rates: [8, 5, 4, 2] 7 | 8 | # Quantization 9 | DAC.n_codebooks: 12 10 | DAC.codebook_size: 1024 11 | DAC.codebook_dim: 8 12 | DAC.quantizer_dropout: 0.5 13 | 14 | # Discriminator 15 | Discriminator.sample_rate: 16000 16 | Discriminator.rates: [] 17 | Discriminator.periods: [2, 3, 5, 7, 11] 18 | Discriminator.fft_sizes: [2048, 1024, 512] 19 | Discriminator.bands: 20 | - [0.0, 0.1] 21 | - [0.1, 0.25] 22 | - [0.25, 0.5] 23 | - [0.5, 0.75] 24 | - [0.75, 1.0] 25 | 26 | # Optimization 27 | AdamW.betas: [0.8, 0.99] 28 | AdamW.lr: 0.0001 29 | ExponentialLR.gamma: 0.999996 30 | 31 | amp: false 32 | val_batch_size: 100 33 | device: cuda 34 | num_iters: 400000 35 | save_iters: [10000, 50000, 100000, 200000] 36 | valid_freq: 1000 37 | sample_freq: 10000 38 | num_workers: 32 39 | val_idx: [0, 1, 2, 3, 4, 5, 6, 7] 40 | seed: 0 41 | lambdas: 42 | mel/loss: 15.0 43 | adv/feat_loss: 2.0 44 | adv/gen_loss: 1.0 45 | vq/commitment_loss: 0.25 46 | vq/codebook_loss: 1.0 47 | 48 | VolumeNorm.db: [const, -16] 49 | 50 | # Transforms 51 | build_transform.preprocess: 52 | - Identity 53 | build_transform.augment_prob: 0.0 54 | build_transform.augment: 55 | - Identity 56 | build_transform.postprocess: 57 | - VolumeNorm 58 | - RescaleAudio 59 | - ShiftPhase 60 | 61 | # Loss setup 62 | MultiScaleSTFTLoss.window_lengths: [2048, 512] 63 | MelSpectrogramLoss.n_mels: [5, 10, 20, 40, 80, 160, 320] 64 | MelSpectrogramLoss.window_lengths: [32, 64, 128, 256, 512, 1024, 2048] 65 | MelSpectrogramLoss.mel_fmin: [0, 0, 0, 0, 0, 0, 0] 66 | MelSpectrogramLoss.mel_fmax: [null, null, null, null, null, null, null] 67 | MelSpectrogramLoss.pow: 1.0 68 | MelSpectrogramLoss.clamp_eps: 1.0e-5 69 | MelSpectrogramLoss.mag_weight: 0.0 70 | 71 | # Data 72 | batch_size: 72 73 | train/AudioDataset.duration: 0.38 74 | train/AudioDataset.n_examples: 10000000 75 | 76 | val/AudioDataset.duration: 5.0 77 | 
val/build_transform.augment_prob: 1.0 78 | val/AudioDataset.n_examples: 250 79 | 80 | test/AudioDataset.duration: 10.0 81 | test/build_transform.augment_prob: 1.0 82 | test/AudioDataset.n_examples: 1000 83 | 84 | AudioLoader.shuffle: true 85 | AudioDataset.without_replacement: true 86 | 87 | train/build_dataset.folders: 88 | speech_fb: 89 | - /data/daps/train 90 | speech_hq: 91 | - /data/vctk 92 | - /data/vocalset 93 | - /data/read_speech 94 | - /data/french_speech 95 | speech_uq: 96 | - /data/emotional_speech/ 97 | - /data/common_voice/ 98 | - /data/german_speech/ 99 | - /data/russian_speech/ 100 | - /data/spanish_speech/ 101 | music_hq: 102 | - /data/musdb/train 103 | music_uq: 104 | - /data/jamendo 105 | general: 106 | - /data/audioset/data/unbalanced_train_segments/ 107 | - /data/audioset/data/balanced_train_segments/ 108 | 109 | val/build_dataset.folders: 110 | speech_hq: 111 | - /data/daps/val 112 | music_hq: 113 | - /data/musdb/test 114 | general: 115 | - /data/audioset/data/eval_segments/ 116 | 117 | test/build_dataset.folders: 118 | speech_hq: 119 | - /data/daps/test 120 | music_hq: 121 | - /data/musdb/test 122 | general: 123 | - /data/audioset/data/eval_segments/ 124 | -------------------------------------------------------------------------------- /descriptaudiocodec/conf/final/24khz.yml: -------------------------------------------------------------------------------- 1 | # Model setup 2 | DAC.sample_rate: 24000 3 | DAC.encoder_dim: 64 4 | DAC.encoder_rates: [2, 4, 5, 8] 5 | DAC.decoder_dim: 1536 6 | DAC.decoder_rates: [8, 5, 4, 2] 7 | 8 | # Quantization 9 | DAC.n_codebooks: 32 10 | DAC.codebook_size: 1024 11 | DAC.codebook_dim: 8 12 | DAC.quantizer_dropout: 0.5 13 | 14 | # Discriminator 15 | Discriminator.sample_rate: 24000 16 | Discriminator.rates: [] 17 | Discriminator.periods: [2, 3, 5, 7, 11] 18 | Discriminator.fft_sizes: [2048, 1024, 512] 19 | Discriminator.bands: 20 | - [0.0, 0.1] 21 | - [0.1, 0.25] 22 | - [0.25, 0.5] 23 | - [0.5, 0.75] 24 | - [0.75, 1.0] 25 | 26 | # Optimization 27 | AdamW.betas: [0.8, 0.99] 28 | AdamW.lr: 0.0001 29 | ExponentialLR.gamma: 0.999996 30 | 31 | amp: false 32 | val_batch_size: 100 33 | device: cuda 34 | num_iters: 400000 35 | save_iters: [10000, 50000, 100000, 200000] 36 | valid_freq: 1000 37 | sample_freq: 10000 38 | num_workers: 32 39 | val_idx: [0, 1, 2, 3, 4, 5, 6, 7] 40 | seed: 0 41 | lambdas: 42 | mel/loss: 15.0 43 | adv/feat_loss: 2.0 44 | adv/gen_loss: 1.0 45 | vq/commitment_loss: 0.25 46 | vq/codebook_loss: 1.0 47 | 48 | VolumeNorm.db: [const, -16] 49 | 50 | # Transforms 51 | build_transform.preprocess: 52 | - Identity 53 | build_transform.augment_prob: 0.0 54 | build_transform.augment: 55 | - Identity 56 | build_transform.postprocess: 57 | - VolumeNorm 58 | - RescaleAudio 59 | - ShiftPhase 60 | 61 | # Loss setup 62 | MultiScaleSTFTLoss.window_lengths: [2048, 512] 63 | MelSpectrogramLoss.n_mels: [5, 10, 20, 40, 80, 160, 320] 64 | MelSpectrogramLoss.window_lengths: [32, 64, 128, 256, 512, 1024, 2048] 65 | MelSpectrogramLoss.mel_fmin: [0, 0, 0, 0, 0, 0, 0] 66 | MelSpectrogramLoss.mel_fmax: [null, null, null, null, null, null, null] 67 | MelSpectrogramLoss.pow: 1.0 68 | MelSpectrogramLoss.clamp_eps: 1.0e-5 69 | MelSpectrogramLoss.mag_weight: 0.0 70 | 71 | # Data 72 | batch_size: 72 73 | train/AudioDataset.duration: 0.38 74 | train/AudioDataset.n_examples: 10000000 75 | 76 | val/AudioDataset.duration: 5.0 77 | val/build_transform.augment_prob: 1.0 78 | val/AudioDataset.n_examples: 250 79 | 80 | test/AudioDataset.duration: 10.0 81 | 
test/build_transform.augment_prob: 1.0 82 | test/AudioDataset.n_examples: 1000 83 | 84 | AudioLoader.shuffle: true 85 | AudioDataset.without_replacement: true 86 | 87 | train/build_dataset.folders: 88 | speech_fb: 89 | - /data/daps/train 90 | speech_hq: 91 | - /data/vctk 92 | - /data/vocalset 93 | - /data/read_speech 94 | - /data/french_speech 95 | speech_uq: 96 | - /data/emotional_speech/ 97 | - /data/common_voice/ 98 | - /data/german_speech/ 99 | - /data/russian_speech/ 100 | - /data/spanish_speech/ 101 | music_hq: 102 | - /data/musdb/train 103 | music_uq: 104 | - /data/jamendo 105 | general: 106 | - /data/audioset/data/unbalanced_train_segments/ 107 | - /data/audioset/data/balanced_train_segments/ 108 | 109 | val/build_dataset.folders: 110 | speech_hq: 111 | - /data/daps/val 112 | music_hq: 113 | - /data/musdb/test 114 | general: 115 | - /data/audioset/data/eval_segments/ 116 | 117 | test/build_dataset.folders: 118 | speech_hq: 119 | - /data/daps/test 120 | music_hq: 121 | - /data/musdb/test 122 | general: 123 | - /data/audioset/data/eval_segments/ 124 | -------------------------------------------------------------------------------- /descriptaudiocodec/conf/final/44khz.yml: -------------------------------------------------------------------------------- 1 | # Model setup 2 | DAC.sample_rate: 44100 3 | DAC.encoder_dim: 64 4 | DAC.encoder_rates: [2, 4, 8, 8] 5 | DAC.decoder_dim: 1536 6 | DAC.decoder_rates: [8, 8, 4, 2] 7 | 8 | # Quantization 9 | DAC.n_codebooks: 9 10 | DAC.codebook_size: 1024 11 | DAC.codebook_dim: 8 12 | DAC.quantizer_dropout: 0.5 13 | 14 | # Discriminator 15 | Discriminator.sample_rate: 44100 16 | Discriminator.rates: [] 17 | Discriminator.periods: [2, 3, 5, 7, 11] 18 | Discriminator.fft_sizes: [2048, 1024, 512] 19 | Discriminator.bands: 20 | - [0.0, 0.1] 21 | - [0.1, 0.25] 22 | - [0.25, 0.5] 23 | - [0.5, 0.75] 24 | - [0.75, 1.0] 25 | 26 | # Optimization 27 | AdamW.betas: [0.8, 0.99] 28 | AdamW.lr: 0.0001 29 | ExponentialLR.gamma: 0.999996 30 | 31 | amp: false 32 | val_batch_size: 100 33 | device: cuda 34 | num_iters: 400000 35 | save_iters: [10000, 50000, 100000, 200000] 36 | valid_freq: 1000 37 | sample_freq: 10000 38 | num_workers: 32 39 | val_idx: [0, 1, 2, 3, 4, 5, 6, 7] 40 | seed: 0 41 | lambdas: 42 | mel/loss: 15.0 43 | adv/feat_loss: 2.0 44 | adv/gen_loss: 1.0 45 | vq/commitment_loss: 0.25 46 | vq/codebook_loss: 1.0 47 | 48 | VolumeNorm.db: [const, -16] 49 | 50 | # Transforms 51 | build_transform.preprocess: 52 | - Identity 53 | build_transform.augment_prob: 0.0 54 | build_transform.augment: 55 | - Identity 56 | build_transform.postprocess: 57 | - VolumeNorm 58 | - RescaleAudio 59 | - ShiftPhase 60 | 61 | # Loss setup 62 | MultiScaleSTFTLoss.window_lengths: [2048, 512] 63 | MelSpectrogramLoss.n_mels: [5, 10, 20, 40, 80, 160, 320] 64 | MelSpectrogramLoss.window_lengths: [32, 64, 128, 256, 512, 1024, 2048] 65 | MelSpectrogramLoss.mel_fmin: [0, 0, 0, 0, 0, 0, 0] 66 | MelSpectrogramLoss.mel_fmax: [null, null, null, null, null, null, null] 67 | MelSpectrogramLoss.pow: 1.0 68 | MelSpectrogramLoss.clamp_eps: 1.0e-5 69 | MelSpectrogramLoss.mag_weight: 0.0 70 | 71 | # Data 72 | batch_size: 72 73 | train/AudioDataset.duration: 0.38 74 | train/AudioDataset.n_examples: 10000000 75 | 76 | val/AudioDataset.duration: 5.0 77 | val/build_transform.augment_prob: 1.0 78 | val/AudioDataset.n_examples: 250 79 | 80 | test/AudioDataset.duration: 10.0 81 | test/build_transform.augment_prob: 1.0 82 | test/AudioDataset.n_examples: 1000 83 | 84 | AudioLoader.shuffle: true 85 | 
AudioDataset.without_replacement: true 86 | 87 | train/build_dataset.folders: 88 | speech_fb: 89 | - /data/daps/train 90 | speech_hq: 91 | - /data/vctk 92 | - /data/vocalset 93 | - /data/read_speech 94 | - /data/french_speech 95 | speech_uq: 96 | - /data/emotional_speech/ 97 | - /data/common_voice/ 98 | - /data/german_speech/ 99 | - /data/russian_speech/ 100 | - /data/spanish_speech/ 101 | music_hq: 102 | - /data/musdb/train 103 | music_uq: 104 | - /data/jamendo 105 | general: 106 | - /data/audioset/data/unbalanced_train_segments/ 107 | - /data/audioset/data/balanced_train_segments/ 108 | 109 | val/build_dataset.folders: 110 | speech_hq: 111 | - /data/daps/val 112 | music_hq: 113 | - /data/musdb/test 114 | general: 115 | - /data/audioset/data/eval_segments/ 116 | 117 | test/build_dataset.folders: 118 | speech_hq: 119 | - /data/daps/test 120 | music_hq: 121 | - /data/musdb/test 122 | general: 123 | - /data/audioset/data/eval_segments/ 124 | -------------------------------------------------------------------------------- /descriptaudiocodec/dac/utils/decode.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from pathlib import Path 3 | 4 | import argbind 5 | import numpy as np 6 | import torch 7 | from audiotools import AudioSignal 8 | from tqdm import tqdm 9 | 10 | from dac import DACFile 11 | from dac.utils import load_model 12 | 13 | warnings.filterwarnings("ignore", category=UserWarning) 14 | 15 | 16 | @argbind.bind(group="decode", positional=True, without_prefix=True) 17 | @torch.inference_mode() 18 | @torch.no_grad() 19 | def decode( 20 | input: str, 21 | output: str = "", 22 | weights_path: str = "", 23 | model_tag: str = "latest", 24 | model_bitrate: str = "8kbps", 25 | device: str = "cuda", 26 | model_type: str = "44khz", 27 | verbose: bool = False, 28 | ): 29 | """Decode audio from codes. 30 | 31 | Parameters 32 | ---------- 33 | input : str 34 | Path to input directory or file 35 | output : str, optional 36 | Path to output directory, by default "". 37 | If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`. 38 | weights_path : str, optional 39 | Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the 40 | model_tag and model_type. 41 | model_tag : str, optional 42 | Tag of the model to use, by default "latest". Ignored if `weights_path` is specified. 43 | model_bitrate: str 44 | Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". 45 | device : str, optional 46 | Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU. 47 | model_type : str, optional 48 | The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified. 
49 | """ 50 | generator = load_model( 51 | model_type=model_type, 52 | model_bitrate=model_bitrate, 53 | tag=model_tag, 54 | load_path=weights_path, 55 | ) 56 | generator.to(device) 57 | generator.eval() 58 | 59 | # Find all .dac files in input directory 60 | _input = Path(input) 61 | input_files = list(_input.glob("**/*.dac")) 62 | 63 | # If input is a .dac file, add it to the list 64 | if _input.suffix == ".dac": 65 | input_files.append(_input) 66 | 67 | # Create output directory 68 | output = Path(output) 69 | output.mkdir(parents=True, exist_ok=True) 70 | 71 | for i in tqdm(range(len(input_files)), desc=f"Decoding files"): 72 | # Load file 73 | artifact = DACFile.load(input_files[i]) 74 | 75 | # Reconstruct audio from codes 76 | recons = generator.decompress(artifact, verbose=verbose) 77 | 78 | # Compute output path 79 | relative_path = input_files[i].relative_to(input) 80 | output_dir = output / relative_path.parent 81 | if not relative_path.name: 82 | output_dir = output 83 | relative_path = input_files[i] 84 | output_name = relative_path.with_suffix(".wav").name 85 | output_path = output_dir / output_name 86 | output_path.parent.mkdir(parents=True, exist_ok=True) 87 | 88 | # Write to file 89 | recons.write(output_path) 90 | 91 | 92 | if __name__ == "__main__": 93 | args = argbind.parse_args() 94 | with argbind.scope(args): 95 | decode() 96 | -------------------------------------------------------------------------------- /descriptaudiocodec/conf/final/44khz-16kbps.yml: -------------------------------------------------------------------------------- 1 | # Model setup 2 | DAC.sample_rate: 44100 3 | DAC.encoder_dim: 64 4 | DAC.encoder_rates: [2, 4, 8, 8] 5 | DAC.latent_dim: 128 6 | DAC.decoder_dim: 1536 7 | DAC.decoder_rates: [8, 8, 4, 2] 8 | 9 | # Quantization 10 | DAC.n_codebooks: 18 # Max bitrate of 16kbps 11 | DAC.codebook_size: 1024 12 | DAC.codebook_dim: 8 13 | DAC.quantizer_dropout: 0.5 14 | 15 | # Discriminator 16 | Discriminator.sample_rate: 44100 17 | Discriminator.rates: [] 18 | Discriminator.periods: [2, 3, 5, 7, 11] 19 | Discriminator.fft_sizes: [2048, 1024, 512] 20 | Discriminator.bands: 21 | - [0.0, 0.1] 22 | - [0.1, 0.25] 23 | - [0.25, 0.5] 24 | - [0.5, 0.75] 25 | - [0.75, 1.0] 26 | 27 | # Optimization 28 | AdamW.betas: [0.8, 0.99] 29 | AdamW.lr: 0.0001 30 | ExponentialLR.gamma: 0.999996 31 | 32 | amp: false 33 | val_batch_size: 100 34 | device: cuda 35 | num_iters: 400000 36 | save_iters: [10000, 50000, 100000, 200000] 37 | valid_freq: 1000 38 | sample_freq: 10000 39 | num_workers: 32 40 | val_idx: [0, 1, 2, 3, 4, 5, 6, 7] 41 | seed: 0 42 | lambdas: 43 | mel/loss: 15.0 44 | adv/feat_loss: 2.0 45 | adv/gen_loss: 1.0 46 | vq/commitment_loss: 0.25 47 | vq/codebook_loss: 1.0 48 | 49 | VolumeNorm.db: [const, -16] 50 | 51 | # Transforms 52 | build_transform.preprocess: 53 | - Identity 54 | build_transform.augment_prob: 0.0 55 | build_transform.augment: 56 | - Identity 57 | build_transform.postprocess: 58 | - VolumeNorm 59 | - RescaleAudio 60 | - ShiftPhase 61 | 62 | # Loss setup 63 | MultiScaleSTFTLoss.window_lengths: [2048, 512] 64 | MelSpectrogramLoss.n_mels: [5, 10, 20, 40, 80, 160, 320] 65 | MelSpectrogramLoss.window_lengths: [32, 64, 128, 256, 512, 1024, 2048] 66 | MelSpectrogramLoss.mel_fmin: [0, 0, 0, 0, 0, 0, 0] 67 | MelSpectrogramLoss.mel_fmax: [null, null, null, null, null, null, null] 68 | MelSpectrogramLoss.pow: 1.0 69 | MelSpectrogramLoss.clamp_eps: 1.0e-5 70 | MelSpectrogramLoss.mag_weight: 0.0 71 | 72 | # Data 73 | batch_size: 72 74 | 
train/AudioDataset.duration: 0.38 75 | train/AudioDataset.n_examples: 10000000 76 | 77 | val/AudioDataset.duration: 5.0 78 | val/build_transform.augment_prob: 1.0 79 | val/AudioDataset.n_examples: 250 80 | 81 | test/AudioDataset.duration: 10.0 82 | test/build_transform.augment_prob: 1.0 83 | test/AudioDataset.n_examples: 1000 84 | 85 | AudioLoader.shuffle: true 86 | AudioDataset.without_replacement: true 87 | 88 | train/build_dataset.folders: 89 | speech_fb: 90 | - /data/daps/train 91 | speech_hq: 92 | - /data/vctk 93 | - /data/vocalset 94 | - /data/read_speech 95 | - /data/french_speech 96 | speech_uq: 97 | - /data/emotional_speech/ 98 | - /data/common_voice/ 99 | - /data/german_speech/ 100 | - /data/russian_speech/ 101 | - /data/spanish_speech/ 102 | music_hq: 103 | - /data/musdb/train 104 | music_uq: 105 | - /data/jamendo 106 | general: 107 | - /data/audioset/data/unbalanced_train_segments/ 108 | - /data/audioset/data/balanced_train_segments/ 109 | 110 | val/build_dataset.folders: 111 | speech_hq: 112 | - /data/daps/val 113 | music_hq: 114 | - /data/musdb/test 115 | general: 116 | - /data/audioset/data/eval_segments/ 117 | 118 | test/build_dataset.folders: 119 | speech_hq: 120 | - /data/daps/test 121 | music_hq: 122 | - /data/musdb/test 123 | general: 124 | - /data/audioset/data/eval_segments/ 125 | -------------------------------------------------------------------------------- /descriptaudiocodec/dac/utils/encode.py: -------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | from pathlib import Path 4 | 5 | import argbind 6 | import numpy as np 7 | import torch 8 | from audiotools import AudioSignal 9 | from audiotools.core import util 10 | from tqdm import tqdm 11 | 12 | from dac.utils import load_model 13 | 14 | warnings.filterwarnings("ignore", category=UserWarning) 15 | 16 | 17 | @argbind.bind(group="encode", positional=True, without_prefix=True) 18 | @torch.inference_mode() 19 | @torch.no_grad() 20 | def encode( 21 | input: str, 22 | output: str = "", 23 | weights_path: str = "", 24 | model_tag: str = "latest", 25 | model_bitrate: str = "8kbps", 26 | n_quantizers: int = None, 27 | device: str = "cuda", 28 | model_type: str = "44khz", 29 | win_duration: float = 5.0, 30 | verbose: bool = False, 31 | ): 32 | """Encode audio files in input path to .dac format. 33 | 34 | Parameters 35 | ---------- 36 | input : str 37 | Path to input audio file or directory 38 | output : str, optional 39 | Path to output directory, by default "". If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`. 40 | weights_path : str, optional 41 | Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the 42 | model_tag and model_type. 43 | model_tag : str, optional 44 | Tag of the model to use, by default "latest". Ignored if `weights_path` is specified. 45 | model_bitrate: str 46 | Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". 47 | n_quantizers : int, optional 48 | Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate. 49 | device : str, optional 50 | Device to use, by default "cuda" 51 | model_type : str, optional 52 | The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified. 
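    win_duration : float, optional
        Window duration in seconds handed to the model's `compress` call, by default 5.0.
    verbose : bool, optional
        Passed through to the model's `compress` call, by default False.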
53 | """ 54 | generator = load_model( 55 | model_type=model_type, 56 | model_bitrate=model_bitrate, 57 | tag=model_tag, 58 | load_path=weights_path, 59 | ) 60 | generator.to(device) 61 | generator.eval() 62 | kwargs = {"n_quantizers": n_quantizers} 63 | 64 | # Find all audio files in input path 65 | input = Path(input) 66 | audio_files = util.find_audio(input) 67 | 68 | output = Path(output) 69 | output.mkdir(parents=True, exist_ok=True) 70 | 71 | for i in tqdm(range(len(audio_files)), desc="Encoding files"): 72 | # Load file 73 | signal = AudioSignal(audio_files[i]) 74 | 75 | # Encode audio to .dac format 76 | artifact = generator.compress(signal, win_duration, verbose=verbose, **kwargs) 77 | 78 | # Compute output path 79 | relative_path = audio_files[i].relative_to(input) 80 | output_dir = output / relative_path.parent 81 | if not relative_path.name: 82 | output_dir = output 83 | relative_path = audio_files[i] 84 | output_name = relative_path.with_suffix(".dac").name 85 | output_path = output_dir / output_name 86 | output_path.parent.mkdir(parents=True, exist_ok=True) 87 | 88 | artifact.save(output_path) 89 | 90 | 91 | if __name__ == "__main__": 92 | args = argbind.parse_args() 93 | with argbind.scope(args): 94 | encode() 95 | -------------------------------------------------------------------------------- /descriptaudiocodec/scripts/evaluate.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import multiprocessing as mp 3 | from concurrent.futures import ProcessPoolExecutor 4 | from dataclasses import dataclass 5 | from pathlib import Path 6 | 7 | import argbind 8 | import torch 9 | from audiotools import AudioSignal 10 | from audiotools import metrics 11 | from audiotools.core import util 12 | from audiotools.ml.decorators import Tracker 13 | from train import losses 14 | 15 | 16 | @dataclass 17 | class State: 18 | stft_loss: losses.MultiScaleSTFTLoss 19 | mel_loss: losses.MelSpectrogramLoss 20 | waveform_loss: losses.L1Loss 21 | sisdr_loss: losses.SISDRLoss 22 | 23 | 24 | def get_metrics(signal_path, recons_path, state): 25 | output = {} 26 | signal = AudioSignal(signal_path) 27 | recons = AudioSignal(recons_path) 28 | for sr in [22050, 44100]: 29 | x = signal.clone().resample(sr) 30 | y = recons.clone().resample(sr) 31 | k = "22k" if sr == 22050 else "44k" 32 | output.update( 33 | { 34 | f"mel-{k}": state.mel_loss(x, y), 35 | f"stft-{k}": state.stft_loss(x, y), 36 | f"waveform-{k}": state.waveform_loss(x, y), 37 | f"sisdr-{k}": state.sisdr_loss(x, y), 38 | f"visqol-audio-{k}": metrics.quality.visqol(x, y), 39 | f"visqol-speech-{k}": metrics.quality.visqol(x, y, "speech"), 40 | } 41 | ) 42 | output["path"] = signal.path_to_file 43 | output.update(signal.metadata) 44 | return output 45 | 46 | 47 | @argbind.bind(without_prefix=True) 48 | @torch.no_grad() 49 | def evaluate( 50 | input: str = "samples/input", 51 | output: str = "samples/output", 52 | n_proc: int = 50, 53 | ): 54 | tracker = Tracker() 55 | 56 | waveform_loss = losses.L1Loss() 57 | stft_loss = losses.MultiScaleSTFTLoss() 58 | mel_loss = losses.MelSpectrogramLoss() 59 | sisdr_loss = losses.SISDRLoss() 60 | 61 | state = State( 62 | waveform_loss=waveform_loss, 63 | stft_loss=stft_loss, 64 | mel_loss=mel_loss, 65 | sisdr_loss=sisdr_loss, 66 | ) 67 | 68 | audio_files = util.find_audio(input) 69 | output = Path(output) 70 | output.mkdir(parents=True, exist_ok=True) 71 | 72 | @tracker.track("metrics", len(audio_files)) 73 | def record(future, writer): 74 | o = future.result() 
75 | for k, v in o.items(): 76 | if torch.is_tensor(v): 77 | o[k] = v.item() 78 | writer.writerow(o) 79 | o.pop("path") 80 | return o 81 | 82 | futures = [] 83 | with tracker.live: 84 | with open(output / "metrics.csv", "w") as csvfile: 85 | with ProcessPoolExecutor(n_proc, mp.get_context("fork")) as pool: 86 | for i in range(len(audio_files)): 87 | future = pool.submit( 88 | get_metrics, audio_files[i], output / audio_files[i].name, state 89 | ) 90 | futures.append(future) 91 | 92 | keys = list(futures[0].result().keys()) 93 | writer = csv.DictWriter(csvfile, fieldnames=keys) 94 | writer.writeheader() 95 | 96 | for future in futures: 97 | record(future, writer) 98 | 99 | tracker.done("test", f"N={len(audio_files)}") 100 | 101 | 102 | if __name__ == "__main__": 103 | args = argbind.parse_args() 104 | with argbind.scope(args): 105 | evaluate() 106 | -------------------------------------------------------------------------------- /distributed/launch.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------ 2 | # Diffsound 3 | # code based https://github.com/cientgu/VQ-Diffusion 4 | # ------------------------------------------ 5 | import os 6 | 7 | import torch 8 | from torch import distributed as dist 9 | from torch import multiprocessing as mp 10 | 11 | # import distributed as dist_fn 12 | import distributed.distributed as dist_fn 13 | 14 | 15 | def find_free_port(): 16 | import socket 17 | 18 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 19 | 20 | sock.bind(("", 0)) 21 | port = sock.getsockname()[1] 22 | sock.close() 23 | 24 | return port 25 | 26 | 27 | def launch(fn, n_gpu_per_machine, n_machine=1, machine_rank=0, dist_url=None, args=()): 28 | world_size = n_machine * n_gpu_per_machine 29 | 30 | if world_size > 1: 31 | # if "OMP_NUM_THREADS" not in os.environ: 32 | # os.environ["OMP_NUM_THREADS"] = "1" 33 | if dist_url == "auto": 34 | if n_machine != 1: 35 | raise ValueError('dist_url="auto" not supported in multi-machine jobs') 36 | port = find_free_port() 37 | dist_url = f"tcp://127.0.0.1:{port}" 38 | print('dist_url ', dist_url) 39 | print('n_machine ', n_machine) 40 | print('args ', args) 41 | print('world_size ', world_size) 42 | print('machine_rank ', machine_rank) 43 | if n_machine > 1 and dist_url.startswith("file://"): 44 | raise ValueError( 45 | "file:// is not a reliable init method in multi-machine jobs. Prefer tcp://" 46 | ) 47 | 48 | mp.spawn( 49 | distributed_worker, 50 | nprocs=n_gpu_per_machine, 51 | args=(fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args), 52 | daemon=False, 53 | ) 54 | # n_machine ? world_size 55 | else: 56 | local_rank = 0 57 | fn(local_rank, *args) 58 | 59 | 60 | def distributed_worker( 61 | local_rank, fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args 62 | ): 63 | if not torch.cuda.is_available(): 64 | raise OSError("CUDA is not available. 
Please check your environments") 65 | 66 | global_rank = machine_rank * n_gpu_per_machine + local_rank 67 | print('local_rank ',local_rank) 68 | print('global_rank ',global_rank) 69 | try: 70 | dist.init_process_group( 71 | backend="NCCL", 72 | init_method=dist_url, 73 | world_size=world_size, 74 | rank=global_rank, 75 | ) 76 | 77 | except Exception: 78 | raise OSError("failed to initialize NCCL groups") 79 | 80 | # changed 81 | dist_fn.synchronize() 82 | 83 | if n_gpu_per_machine > torch.cuda.device_count(): 84 | raise ValueError( 85 | f"specified n_gpu_per_machine larger than available device ({torch.cuda.device_count()})" 86 | ) 87 | 88 | torch.cuda.set_device(local_rank) 89 | 90 | if dist_fn.LOCAL_PROCESS_GROUP is not None: 91 | raise ValueError("torch.distributed.LOCAL_PROCESS_GROUP is not None") 92 | 93 | # change paert 94 | 95 | n_machine = world_size // n_gpu_per_machine 96 | for i in range(n_machine): 97 | ranks_on_i = list(range(i * n_gpu_per_machine, (i + 1) * n_gpu_per_machine)) 98 | pg = dist.new_group(ranks_on_i) 99 | 100 | if i == machine_rank: 101 | dist_fn.LOCAL_PROCESS_GROUP = pg 102 | 103 | fn(local_rank, *args) 104 | -------------------------------------------------------------------------------- /config/codec_16k_6kbps_v3_vqdp.yaml: -------------------------------------------------------------------------------- 1 | 2 | ########### model config ########### 3 | generator: 4 | name: SoundStream 5 | config: 6 | n_filters: 32 7 | D: 256 8 | target_bandwidths: [0.5, 1, 1.5, 2, 4] 9 | ratios: [8, 5, 4, 2] # downsampling by 320 10 | sample_rate: 16000 11 | bins: 1024 12 | semantic_techer: hubert_base 13 | # Discriminator list 14 | #d_list: ['mpd', 'msd', 'mfd'] 15 | d_list: ['mfd'] 16 | 17 | mfd: 18 | name: MultiFrequencyDiscriminator 19 | config: 20 | hop_lengths: [32, 64, 128, 256, 512, 1024] 21 | hidden_channels: [64, 128, 256, 512, 512, 512] 22 | domain: double 23 | mel_scale: true 24 | sample_rate: 16000 25 | 26 | mpd: 27 | name: MultiPeriodDiscriminator 28 | config: 29 | period_sizes: [2, 3, 5, 7, 11] 30 | period_kernel_size: 5 31 | 32 | msd: 33 | name: MultiScaleDiscriminator 34 | config: 35 | num_scales: 3 36 | pool_kernel_size: 4 37 | pool_stride: 2 38 | 39 | ########### optimizer config ########### 40 | optimizer: 41 | g: 42 | name: AdamW 43 | config: 44 | lr: 2e-4 45 | betas: [0.8, 0.99] 46 | eps: 1.0e-6 47 | 48 | d: 49 | name: AdamW 50 | config: 51 | lr: 2e-4 52 | betas: [0.8, 0.99] 53 | eps: 1.0e-6 54 | 55 | lr_scheduler: 56 | g: 57 | name: ExponentialLR 58 | config: 59 | gamma: 0.999 60 | d: 61 | name: ExponentialLR 62 | config: 63 | gamma: 0.999 64 | 65 | ########### criterion config ########### 66 | criterion: 67 | g_criterion: 68 | name: losses.generator_loss.GeneratorSTFTLoss 69 | config: 70 | use_mel_loss: false 71 | #adv_criterion: LeastDLoss 72 | adv_criterion: MSEGLoss 73 | mel_loss_weight: 45 74 | use_feature_match: true 75 | feat_match_loss_weight: 20 76 | use_full_stft_loss: true # Magnitude 77 | use_sub_stft_loss: true # PQMF loss 78 | full_stft_loss_weight: 1 79 | sub_stft_loss_weight: 1 80 | mel_scale_loss: 81 | sampling_rate: 16000 82 | n_fft: 1024 83 | num_mels: 80 84 | hop_size: 160 85 | win_size: 800 86 | fmin: 0 87 | full_multi_scale_stft_loss: # Full-band multi-scale STFT loss. 88 | fft_sizes: [512, 1024, 2048] 89 | win_sizes: [480, 960, 1200] 90 | hop_sizes: [120, 240, 300] 91 | sub_multi_scale_stft_loss: # Sub-band multi-scale STFT loss. 
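      # num_bands sets how many PQMF sub-bands the waveform is split into before this
      # multi-resolution STFT loss is applied (see GeneratorSTFTLoss in losses/generator_loss.py);
      # the fft/win/hop sizes below are the STFT resolutions used on those sub-band signals.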
92 | num_bands: 6 93 | fft_sizes: [128, 256, 256] 94 | win_sizes: [80, 120, 200] 95 | hop_sizes: [20, 40, 50] 96 | 97 | d_criterion: 98 | name: losses.discriminator_loss.MSEDiscriminatorLoss 99 | config: null 100 | 101 | commit_loss_weight: 1. #1000 102 | semantic_loss_weight: 100 103 | ########### training and data config ########### 104 | 105 | training_file: "/aifs4su/data/zheny/fairseq/vae_v2/codec_final/list_librispeech/train.txt" 106 | validation_file: "/aifs4su/data/zheny/fairseq/vae_v2/codec_final/list_librispeech/valid.txt" 107 | 108 | seed: 2333 109 | cudnn_deterministic: false 110 | tensorboard: true # whether to use tensorboard 111 | #checkpoint_interval: 5 112 | #summary_interval: 10 113 | #validation_interval: 10 114 | 115 | checkpoint_interval: 5000 116 | summary_interval: 100 117 | validation_interval: 5000 118 | 119 | num_epoches: 5000 120 | print_freq: 10 121 | discriminator_iter_start: 0 # start step after which we update discriminators 122 | num_ckpt_keep: 10 123 | 124 | segment_size: 48000 125 | audio_norm_scale: 0.95 126 | batch_size: 8 127 | num_workers: 8 128 | num_plots: 8 129 | -------------------------------------------------------------------------------- /descriptaudiocodec/scripts/mushra.py: -------------------------------------------------------------------------------- 1 | import string 2 | from dataclasses import dataclass 3 | from pathlib import Path 4 | from typing import List 5 | 6 | import argbind 7 | import gradio as gr 8 | from audiotools import preference as pr 9 | 10 | 11 | @argbind.bind(without_prefix=True) 12 | @dataclass 13 | class Config: 14 | folder: str = None 15 | save_path: str = "results.csv" 16 | conditions: List[str] = None 17 | reference: str = None 18 | seed: int = 0 19 | share: bool = False 20 | n_samples: int = 10 21 | 22 | 23 | def get_text(wav_file: str): 24 | txt_file = Path(wav_file).with_suffix(".txt") 25 | if Path(txt_file).exists(): 26 | with open(txt_file, "r") as f: 27 | txt = f.read() 28 | else: 29 | txt = "" 30 | return f"""
{txt}
""" 31 | 32 | 33 | def main(config: Config): 34 | with gr.Blocks() as app: 35 | save_path = config.save_path 36 | samples = gr.State(pr.Samples(config.folder, n_samples=config.n_samples)) 37 | 38 | reference = config.reference 39 | conditions = config.conditions 40 | 41 | player = pr.Player(app) 42 | player.create() 43 | if reference is not None: 44 | player.add("Play Reference") 45 | 46 | user = pr.create_tracker(app) 47 | ratings = [] 48 | 49 | with gr.Row(): 50 | txt = gr.HTML("") 51 | 52 | with gr.Row(): 53 | gr.Button("Rate audio quality", interactive=False) 54 | with gr.Column(scale=8): 55 | gr.HTML(pr.slider_mushra) 56 | 57 | for i in range(len(conditions)): 58 | with gr.Row().style(equal_height=True): 59 | x = string.ascii_uppercase[i] 60 | player.add(f"Play {x}") 61 | with gr.Column(scale=9): 62 | ratings.append(gr.Slider(value=50, interactive=True)) 63 | 64 | def build(user, samples, *ratings): 65 | # Filter out samples user has done already, by looking in the CSV. 66 | samples.filter_completed(user, save_path) 67 | 68 | # Write results to CSV 69 | if samples.current > 0: 70 | start_idx = 1 if reference is not None else 0 71 | name = samples.names[samples.current - 1] 72 | result = {"sample": name, "user": user} 73 | for k, r in zip(samples.order[start_idx:], ratings): 74 | result[k] = r 75 | pr.save_result(result, save_path) 76 | 77 | updates, done, pbar = samples.get_next_sample(reference, conditions) 78 | wav_file = updates[0]["value"] 79 | 80 | txt_update = gr.update(value=get_text(wav_file)) 81 | 82 | return ( 83 | updates 84 | + [gr.update(value=50) for _ in ratings] 85 | + [done, samples, pbar, txt_update] 86 | ) 87 | 88 | progress = gr.HTML() 89 | begin = gr.Button("Submit", elem_id="start-survey") 90 | begin.click( 91 | fn=build, 92 | inputs=[user, samples] + ratings, 93 | outputs=player.to_list() + ratings + [begin, samples, progress, txt], 94 | ).then(None, _js=pr.reset_player) 95 | 96 | # Comment this back in to actually launch the script. 
97 | app.launch(share=config.share) 98 | 99 | 100 | if __name__ == "__main__": 101 | args = argbind.parse_args() 102 | with argbind.scope(args): 103 | config = Config() 104 | main(config) 105 | -------------------------------------------------------------------------------- /descriptaudiocodec/dac/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import argbind 4 | from audiotools import ml 5 | 6 | import dac 7 | 8 | DAC = dac.model.DAC 9 | Accelerator = ml.Accelerator 10 | 11 | __MODEL_LATEST_TAGS__ = { 12 | ("44khz", "8kbps"): "0.0.1", 13 | ("24khz", "8kbps"): "0.0.4", 14 | ("16khz", "8kbps"): "0.0.5", 15 | ("44khz", "16kbps"): "1.0.0", 16 | } 17 | 18 | __MODEL_URLS__ = { 19 | ( 20 | "44khz", 21 | "0.0.1", 22 | "8kbps", 23 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth", 24 | ( 25 | "24khz", 26 | "0.0.4", 27 | "8kbps", 28 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth", 29 | ( 30 | "16khz", 31 | "0.0.5", 32 | "8kbps", 33 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth", 34 | ( 35 | "44khz", 36 | "1.0.0", 37 | "16kbps", 38 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth", 39 | } 40 | 41 | 42 | @argbind.bind(group="download", positional=True, without_prefix=True) 43 | def download( 44 | model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest" 45 | ): 46 | """ 47 | Function that downloads the weights file from URL if a local cache is not found. 48 | 49 | Parameters 50 | ---------- 51 | model_type : str 52 | The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". 53 | model_bitrate: str 54 | Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". 55 | Only 44khz model supports 16kbps. 56 | tag : str 57 | The tag of the model to download. Defaults to "latest". 58 | 59 | Returns 60 | ------- 61 | Path 62 | Directory path required to load model via audiotools. 63 | """ 64 | model_type = model_type.lower() 65 | tag = tag.lower() 66 | 67 | assert model_type in [ 68 | "44khz", 69 | "24khz", 70 | "16khz", 71 | ], "model_type must be one of '44khz', '24khz', or '16khz'" 72 | 73 | assert model_bitrate in [ 74 | "8kbps", 75 | "16kbps", 76 | ], "model_bitrate must be one of '8kbps', or '16kbps'" 77 | 78 | if tag == "latest": 79 | tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)] 80 | 81 | download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None) 82 | 83 | if download_link is None: 84 | raise ValueError( 85 | f"Could not find model with tag {tag} and model type {model_type}" 86 | ) 87 | 88 | local_path = ( 89 | Path.home() 90 | / ".cache" 91 | / "descript" 92 | / "dac" 93 | / f"weights_{model_type}_{model_bitrate}_{tag}.pth" 94 | ) 95 | if not local_path.exists(): 96 | local_path.parent.mkdir(parents=True, exist_ok=True) 97 | 98 | # Download the model 99 | import requests 100 | 101 | response = requests.get(download_link) 102 | 103 | if response.status_code != 200: 104 | raise ValueError( 105 | f"Could not download model. 
Received response code {response.status_code}" 106 | ) 107 | local_path.write_bytes(response.content) 108 | 109 | return local_path 110 | 111 | 112 | def load_model( 113 | model_type: str = "44khz", 114 | model_bitrate: str = "8kbps", 115 | tag: str = "latest", 116 | load_path: str = None, 117 | ): 118 | if not load_path: 119 | load_path = download( 120 | model_type=model_type, model_bitrate=model_bitrate, tag=tag 121 | ) 122 | generator = DAC.load(load_path) 123 | return generator 124 | -------------------------------------------------------------------------------- /models/soundstream.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence, Optional, Union 2 | 3 | import math 4 | import random 5 | 6 | import torch 7 | import torch.nn as nn 8 | import numpy as np 9 | 10 | from modules.seanet import SEANetEncoder, SEANetDecoder 11 | from quantization import ResidualVectorQuantizer 12 | 13 | 14 | class SoundStream(nn.Module): 15 | """ SoundStream model or EnCodec model. 16 | 17 | Args: 18 | n_filters (int): n_filters (int): Base width for the model. 19 | D (int): Intermediate representation dimension. 20 | target_bandwidths (Sequence[int]): Target bandwidths in K-bits/second. 21 | ratios (Sequence[int]): downsampling factors, whose multiplication is the hop size. 22 | sample_rate (int): wave sampling rate. 23 | bins (int): number of code words in a codebook. 24 | normalize (bool): audio normalization. 25 | 26 | """ 27 | def __init__( 28 | self, 29 | n_filters: int = 32, 30 | D: int = 128, 31 | # target_bandwidths: Sequence[Union[int, float]] = [0.5, 1, 1.5, 2, 4, 6], 32 | target_bandwidths: Sequence[Union[int, float]] = [1, 1.5, 2, 4, 6], 33 | ratios: Sequence[int] = [8, 5, 4, 2], # downsampling by 320 34 | sample_rate: int = 16000, 35 | bins: int = 1024, 36 | normalize: bool = False, 37 | causal: bool = False, 38 | ): 39 | super().__init__() 40 | self.hop_length = np.prod(ratios) 41 | # total nb of codebooks, e.g., 6Kb/s, sr=16000 and hop_length=320 => nq = 12 42 | n_q = int(1000 * target_bandwidths[-1] // (math.ceil(sample_rate / self.hop_length) * 10)) 43 | self.frame_rate = math.ceil(sample_rate / np.prod(ratios)) # 50 Hz 44 | self.bits_per_codebook = int(math.log2(bins)) # 1024 => 10 45 | self.target_bandwidths = target_bandwidths 46 | self.n_q = n_q 47 | self.sample_rate = sample_rate 48 | 49 | # Encoder model 50 | self.encoder = SEANetEncoder(n_filters=n_filters, dimension=D, ratios=ratios, causal=causal) 51 | # RVQ model 52 | self.quantizer = ResidualVectorQuantizer(dimension=D, n_q=n_q, bins=bins) 53 | # Decoder model 54 | self.decoder = SEANetDecoder(n_filters= n_filters, dimension=D, ratios=ratios, causal=causal) 55 | 56 | def get_last_layer(self): 57 | return self.decoder.layers[-1].weight 58 | 59 | def forward(self, x: torch.Tensor, bw: int): 60 | e = self.encoder(x) 61 | # randomly select a band-width during training 62 | # bw = self.target_bandwidths[random.randint(0, len(self.target_bandwidths) - 1)] # [0, len(target_bandwidths) - 1], both included 63 | quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw) 64 | # print('quantized ', quantized.shape) 65 | # print('codes ', codes.shape) 66 | # print('commit_loss ', commit_loss) 67 | # print('bandwidth ', bandwidth) 68 | # assert 1==2 69 | #quantized = quantized.permute(0,2,1) 70 | o = self.decoder(quantized) 71 | # print('o ', o.shape) 72 | # assert 1==2 73 | return o, commit_loss, None 74 | 75 | def encode(self, x: torch.Tensor, 
target_bw: Optional[int] = None) -> torch.Tensor: 76 | e = self.encoder(x) 77 | if target_bw is None: 78 | bw = self.target_bandwidths[-1] 79 | else: 80 | bw = target_bw 81 | codes = self.quantizer.encode(e, self.frame_rate, bw) 82 | return codes 83 | 84 | def decode(self, codes: torch.Tensor) -> torch.Tensor: 85 | quantized = self.quantizer.decode(codes) 86 | o = self.decoder(quantized) 87 | return o 88 | 89 | 90 | # test 91 | if __name__ == '__main__': 92 | soundstream = SoundStream(n_filters=32, D=256) 93 | for i in range(10): 94 | print(f"Iter {i}: ") 95 | x = torch.rand(1, 1, 16000) 96 | o, _, _ = soundstream(x) 97 | print('output', o.shape) 98 | -------------------------------------------------------------------------------- /models/soundstream2.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence, Optional, Union 2 | 3 | import math 4 | import random 5 | 6 | import torch 7 | import torch.nn as nn 8 | import numpy as np 9 | 10 | from modules.seanet import SEANetEncoder, SEANetDecoder 11 | from quantization import ResidualVectorQuantizer 12 | 13 | 14 | class SoundStream(nn.Module): 15 | """ SoundStream model or EnCodec model. 16 | 17 | Args: 18 | n_filters (int): n_filters (int): Base width for the model. 19 | D (int): Intermediate representation dimension. 20 | target_bandwidths (Sequence[int]): Target bandwidths in K-bits/second. 21 | ratios (Sequence[int]): downsampling factors, whose multiplication is the hop size. 22 | sample_rate (int): wave sampling rate. 23 | bins (int): number of code words in a codebook. 24 | normalize (bool): audio normalization. 25 | 26 | """ 27 | def __init__( 28 | self, 29 | n_filters: int = 32, 30 | D: int = 128, 31 | # target_bandwidths: Sequence[Union[int, float]] = [0.5, 1, 1.5, 2, 4, 6], 32 | target_bandwidths: Sequence[Union[int, float]] = [1, 1.5, 2, 4, 6], 33 | ratios: Sequence[int] = [8, 5, 4, 2], # downsampling by 320 34 | sample_rate: int = 16000, 35 | bins: int = 1024, 36 | normalize: bool = False, 37 | causal: bool = False, 38 | ): 39 | super().__init__() 40 | self.hop_length = np.prod(ratios) 41 | # total nb of codebooks, e.g., 6Kb/s, sr=16000 and hop_length=320 => nq = 12 42 | n_q = int(1000 * target_bandwidths[-1] // (math.ceil(sample_rate / self.hop_length) * 10)) 43 | self.frame_rate = math.ceil(sample_rate / np.prod(ratios)) # 50 Hz 44 | self.bits_per_codebook = int(math.log2(bins)) # 1024 => 10 45 | self.target_bandwidths = target_bandwidths 46 | self.n_q = n_q 47 | self.sample_rate = sample_rate 48 | 49 | # Encoder model 50 | self.encoder = SEANetEncoder(n_filters=n_filters, dimension=D, ratios=ratios, causal=causal) 51 | # RVQ model 52 | self.quantizer = ResidualVectorQuantizer(dimension=D, n_q=n_q, bins=bins) 53 | # Decoder model 54 | self.decoder = SEANetDecoder(n_filters= n_filters, dimension=D, ratios=ratios, causal=causal) 55 | 56 | def get_last_layer(self): 57 | return self.decoder.layers[-1].weight 58 | 59 | def forward(self, x: torch.Tensor, bw: int): 60 | e = self.encoder(x) 61 | # randomly select a band-width during training 62 | # bw = self.target_bandwidths[random.randint(0, len(self.target_bandwidths) - 1)] # [0, len(target_bandwidths) - 1], both included 63 | quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw) 64 | # print('quantized ', quantized.shape) 65 | # print('codes ', codes.shape) 66 | # print('commit_loss ', commit_loss) 67 | # print('bandwidth ', bandwidth) 68 | # assert 1==2 69 | #quantized = 
quantized.permute(0,2,1) 70 | o = self.decoder(quantized) 71 | # print('o ', o.shape) 72 | # assert 1==2 73 | return o, commit_loss, None 74 | 75 | def encode(self, x: torch.Tensor, target_bw: Optional[int] = None) -> torch.Tensor: 76 | e = self.encoder(x) 77 | if target_bw is None: 78 | bw = self.target_bandwidths[-1] 79 | else: 80 | bw = target_bw 81 | codes = self.quantizer.encode(e, self.frame_rate, bw) 82 | return codes 83 | 84 | def decode(self, codes: torch.Tensor) -> torch.Tensor: 85 | quantized = self.quantizer.decode(codes) 86 | o = self.decoder(quantized) 87 | return o 88 | 89 | 90 | # test 91 | if __name__ == '__main__': 92 | soundstream = SoundStream(n_filters=32, D=256) 93 | for i in range(10): 94 | print(f"Iter {i}: ") 95 | x = torch.rand(1, 1, 16000) 96 | o, _, _ = soundstream(x) 97 | print('output', o.shape) 98 | -------------------------------------------------------------------------------- /distributed/distributed.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------ 2 | # Diffsound 3 | # code based https://github.com/cientgu/VQ-Diffusion 4 | # ------------------------------------------ 5 | import math 6 | import pickle 7 | 8 | import torch 9 | from torch import distributed as dist 10 | from torch.utils import data 11 | 12 | 13 | LOCAL_PROCESS_GROUP = None 14 | 15 | 16 | def is_primary(): 17 | return get_rank() == 0 18 | 19 | 20 | def get_rank(): 21 | if not dist.is_available(): 22 | return 0 23 | 24 | if not dist.is_initialized(): 25 | return 0 26 | 27 | return dist.get_rank() 28 | 29 | 30 | def get_local_rank(): 31 | if not dist.is_available(): 32 | return 0 33 | 34 | if not dist.is_initialized(): 35 | return 0 36 | 37 | if LOCAL_PROCESS_GROUP is None: 38 | raise ValueError("tensorfn.distributed.LOCAL_PROCESS_GROUP is None") 39 | 40 | return dist.get_rank(group=LOCAL_PROCESS_GROUP) 41 | 42 | 43 | def synchronize(): 44 | if not dist.is_available(): 45 | return 46 | 47 | if not dist.is_initialized(): 48 | return 49 | 50 | world_size = dist.get_world_size() 51 | 52 | if world_size == 1: 53 | return 54 | 55 | dist.barrier() 56 | 57 | 58 | def get_world_size(): 59 | if not dist.is_available(): 60 | return 1 61 | 62 | if not dist.is_initialized(): 63 | return 1 64 | 65 | return dist.get_world_size() 66 | 67 | 68 | def is_distributed(): 69 | raise RuntimeError('Please debug this function!') 70 | return get_world_size() > 1 71 | 72 | 73 | def all_reduce(tensor, op=dist.ReduceOp.SUM, async_op=False): 74 | world_size = get_world_size() 75 | 76 | if world_size == 1: 77 | return tensor 78 | dist.all_reduce(tensor, op=op, async_op=async_op) 79 | 80 | return tensor 81 | 82 | 83 | def all_gather(data): 84 | world_size = get_world_size() 85 | 86 | if world_size == 1: 87 | return [data] 88 | 89 | buffer = pickle.dumps(data) 90 | storage = torch.ByteStorage.from_buffer(buffer) 91 | tensor = torch.ByteTensor(storage).to("cuda") 92 | 93 | local_size = torch.IntTensor([tensor.numel()]).to("cuda") 94 | size_list = [torch.IntTensor([1]).to("cuda") for _ in range(world_size)] 95 | dist.all_gather(size_list, local_size) 96 | size_list = [int(size.item()) for size in size_list] 97 | max_size = max(size_list) 98 | 99 | tensor_list = [] 100 | for _ in size_list: 101 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda")) 102 | 103 | if local_size != max_size: 104 | padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda") 105 | tensor = torch.cat((tensor, padding), 0) 106 | 107 | 
dist.all_gather(tensor_list, tensor) 108 | 109 | data_list = [] 110 | 111 | for size, tensor in zip(size_list, tensor_list): 112 | buffer = tensor.cpu().numpy().tobytes()[:size] 113 | data_list.append(pickle.loads(buffer)) 114 | 115 | return data_list 116 | 117 | 118 | def reduce_dict(input_dict, average=True): 119 | world_size = get_world_size() 120 | 121 | if world_size < 2: 122 | return input_dict 123 | 124 | with torch.no_grad(): 125 | keys = [] 126 | values = [] 127 | 128 | for k in sorted(input_dict.keys()): 129 | keys.append(k) 130 | values.append(input_dict[k]) 131 | 132 | values = torch.stack(values, 0) 133 | dist.reduce(values, dst=0) 134 | 135 | if dist.get_rank() == 0 and average: 136 | values /= world_size 137 | 138 | reduced_dict = {k: v for k, v in zip(keys, values)} 139 | 140 | return reduced_dict 141 | 142 | 143 | def data_sampler(dataset, shuffle, distributed): 144 | if distributed: 145 | return data.distributed.DistributedSampler(dataset, shuffle=shuffle) 146 | 147 | if shuffle: 148 | return data.RandomSampler(dataset) 149 | 150 | else: 151 | return data.SequentialSampler(dataset) 152 | -------------------------------------------------------------------------------- /quantization/distrib.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """Torch distributed utilities.""" 8 | 9 | import typing as tp 10 | 11 | import torch 12 | 13 | 14 | def rank(): 15 | if torch.distributed.is_initialized(): 16 | return torch.distributed.get_rank() 17 | else: 18 | return 0 19 | 20 | 21 | def world_size(): 22 | if torch.distributed.is_initialized(): 23 | return torch.distributed.get_world_size() 24 | else: 25 | return 1 26 | 27 | 28 | def is_distributed(): 29 | return world_size() > 1 30 | 31 | 32 | def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM): 33 | if is_distributed(): 34 | return torch.distributed.all_reduce(tensor, op) 35 | 36 | 37 | def _is_complex_or_float(tensor): 38 | return torch.is_floating_point(tensor) or torch.is_complex(tensor) 39 | 40 | 41 | def _check_number_of_params(params: tp.List[torch.Tensor]): 42 | # utility function to check that the number of params in all workers is the same, 43 | # and thus avoid a deadlock with distributed all reduce. 44 | if not is_distributed() or not params: 45 | return 46 | #print('params[0].device ', params[0].device) 47 | tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long) 48 | all_reduce(tensor) 49 | if tensor.item() != len(params) * world_size(): 50 | # If not all the workers have the same number, for at least one of them, 51 | # this inequality will be verified. 52 | raise RuntimeError(f"Mismatch in number of params: ours is {len(params)}, " 53 | "at least one worker has a different one.") 54 | 55 | 56 | def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0): 57 | """Broadcast the tensors from the given parameters to all workers. 58 | This can be used to ensure that all workers have the same model to start with. 
59 | """ 60 | if not is_distributed(): 61 | return 62 | tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)] 63 | _check_number_of_params(tensors) 64 | handles = [] 65 | for tensor in tensors: 66 | handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True) 67 | handles.append(handle) 68 | for handle in handles: 69 | handle.wait() 70 | 71 | 72 | def sync_buffer(buffers, average=True): 73 | """ 74 | Sync grad for buffers. If average is False, broadcast instead of averaging. 75 | """ 76 | if not is_distributed(): 77 | return 78 | handles = [] 79 | for buffer in buffers: 80 | if torch.is_floating_point(buffer.data): 81 | if average: 82 | handle = torch.distributed.all_reduce( 83 | buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True) 84 | else: 85 | handle = torch.distributed.broadcast( 86 | buffer.data, src=0, async_op=True) 87 | handles.append((buffer, handle)) 88 | for buffer, handle in handles: 89 | handle.wait() 90 | if average: 91 | buffer.data /= world_size 92 | 93 | 94 | def sync_grad(params): 95 | """ 96 | Simpler alternative to DistributedDataParallel, that doesn't rely 97 | on any black magic. For simple models it can also be as fast. 98 | Just call this on your model parameters after the call to backward! 99 | """ 100 | if not is_distributed(): 101 | return 102 | handles = [] 103 | for p in params: 104 | if p.grad is not None: 105 | handle = torch.distributed.all_reduce( 106 | p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True) 107 | handles.append((p, handle)) 108 | for p, handle in handles: 109 | handle.wait() 110 | p.grad.data /= world_size() 111 | 112 | 113 | def average_metrics(metrics: tp.Dict[str, float], count=1.): 114 | """Average a dictionary of metrics across all workers, using the optional 115 | `count` as unormalized weight. 116 | """ 117 | if not is_distributed(): 118 | return metrics 119 | keys, values = zip(*metrics.items()) 120 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 121 | tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32) 122 | tensor *= count 123 | all_reduce(tensor) 124 | averaged = (tensor[:-1] / tensor[-1]).cpu().tolist() 125 | return dict(zip(keys, averaged)) 126 | -------------------------------------------------------------------------------- /quantization/vq.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """Residual vector quantizer implementation.""" 8 | 9 | from dataclasses import dataclass, field 10 | import math 11 | import typing as tp 12 | 13 | import torch 14 | from torch import nn 15 | 16 | # from .core_vq import ResidualVectorQuantization 17 | from .core_vq_lsx_version import ResidualVectorQuantization 18 | 19 | 20 | @dataclass 21 | class QuantizedResult: 22 | quantized: torch.Tensor 23 | codes: torch.Tensor 24 | bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item. 25 | penalty: tp.Optional[torch.Tensor] = None 26 | metrics: dict = field(default_factory=dict) 27 | 28 | 29 | class ResidualVectorQuantizer(nn.Module): 30 | """Residual Vector Quantizer. 31 | Args: 32 | dimension (int): Dimension of the codebooks. 33 | n_q (int): Number of residual vector quantizers used. 34 | bins (int): Codebook size. 35 | decay (float): Decay for exponential moving average over the codebooks. 
36 | kmeans_init (bool): Whether to use kmeans to initialize the codebooks. 37 | kmeans_iters (int): Number of iterations used for kmeans initialization. 38 | threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes 39 | that have an exponential moving average cluster size less than the specified threshold with 40 | randomly selected vector from the current batch. 41 | """ 42 | def __init__( 43 | self, 44 | dimension: int = 256, 45 | n_q: int = 8, 46 | bins: int = 1024, 47 | decay: float = 0.99, 48 | kmeans_init: bool = True, 49 | kmeans_iters: int = 50, 50 | threshold_ema_dead_code: int = 2, 51 | ): 52 | super().__init__() 53 | self.n_q = n_q 54 | self.dimension = dimension 55 | self.bins = bins 56 | self.decay = decay 57 | self.kmeans_init = kmeans_init 58 | self.kmeans_iters = kmeans_iters 59 | self.threshold_ema_dead_code = threshold_ema_dead_code 60 | self.vq = ResidualVectorQuantization( 61 | dim=self.dimension, 62 | codebook_size=self.bins, 63 | num_quantizers=self.n_q, 64 | decay=self.decay, 65 | kmeans_init=self.kmeans_init, 66 | kmeans_iters=self.kmeans_iters, 67 | threshold_ema_dead_code=self.threshold_ema_dead_code, 68 | ) 69 | 70 | def forward(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None): # -> QuantizedResult: 71 | """Residual vector quantization on the given input tensor. 72 | Args: 73 | x (torch.Tensor): Input tensor. 74 | sample_rate (int): Sample rate of the input tensor. 75 | bandwidth (float): Target bandwidth. 76 | Returns: 77 | QuantizedResult: 78 | The quantized (or approximately quantized) representation with 79 | the associated bandwidth and any penalty term for the loss. 80 | """ 81 | bw_per_q = self.get_bandwidth_per_quantizer(sample_rate) 82 | n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth) 83 | quantized, codes, commit_loss = self.vq(x, n_q=n_q) 84 | bw = torch.tensor(n_q * bw_per_q).to(x) 85 | return quantized, codes, bw, torch.mean(commit_loss) 86 | #return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss)) 87 | 88 | def get_num_quantizers_for_bandwidth(self, sample_rate: int, bandwidth: tp.Optional[float] = None) -> int: 89 | """Return n_q based on specified target bandwidth. 90 | """ 91 | bw_per_q = self.get_bandwidth_per_quantizer(sample_rate) 92 | n_q = self.n_q 93 | if bandwidth and bandwidth > 0.: 94 | n_q = int(max(1, math.floor(bandwidth / bw_per_q))) 95 | return n_q 96 | 97 | def get_bandwidth_per_quantizer(self, sample_rate: int): 98 | """Return bandwidth per quantizer for a given input sample rate. 99 | """ 100 | return math.log2(self.bins) * sample_rate / 1000 101 | 102 | def encode(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None) -> torch.Tensor: 103 | """Encode a given input tensor with the specified sample rate at the given bandwidth. 104 | The RVQ encode method sets the appropriate number of quantizer to use 105 | and returns indices for each quantizer. 106 | """ 107 | n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth) 108 | codes = self.vq.encode(x, n_q=n_q) 109 | return codes 110 | 111 | def decode(self, codes: torch.Tensor) -> torch.Tensor: 112 | """Decode the given codes to the quantized representation. 
113 | """ 114 | quantized = self.vq.decode(codes) 115 | return quantized 116 | -------------------------------------------------------------------------------- /modules/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """A streamable transformer.""" 8 | 9 | import typing as tp 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | 16 | def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000): 17 | """Create time embedding for the given positions, target dimension `dim`. 18 | """ 19 | # We aim for BTC format 20 | assert dim % 2 == 0 21 | half_dim = dim // 2 22 | adim = torch.arange(half_dim, device=positions.device).view(1, 1, -1) 23 | phase = positions / (max_period ** (adim / (half_dim - 1))) 24 | return torch.cat([ 25 | torch.cos(phase), 26 | torch.sin(phase), 27 | ], dim=-1) 28 | 29 | 30 | class StreamingTransformerEncoderLayer(nn.TransformerEncoderLayer): 31 | def forward(self, x: torch.Tensor, x_past: torch.Tensor, past_context: int): # type: ignore 32 | if self.norm_first: 33 | sa_input = self.norm1(x) 34 | x = x + self._sa_block(sa_input, x_past, past_context) 35 | x = x + self._ff_block(self.norm2(x)) 36 | else: 37 | sa_input = x 38 | x = self.norm1(x + self._sa_block(sa_input, x_past, past_context)) 39 | x = self.norm2(x + self._ff_block(x)) 40 | 41 | return x, sa_input 42 | 43 | # self-attention block 44 | def _sa_block(self, x: torch.Tensor, x_past: torch.Tensor, past_context: int): # type: ignore 45 | _, T, _ = x.shape 46 | _, H, _ = x_past.shape 47 | 48 | queries = x 49 | keys = torch.cat([x_past, x], dim=1) 50 | values = keys 51 | 52 | queries_pos = torch.arange(H, T + H, device=x.device).view(-1, 1) 53 | keys_pos = torch.arange(T + H, device=x.device).view(1, -1) 54 | delta = queries_pos - keys_pos 55 | valid_access = (delta >= 0) & (delta <= past_context) 56 | x = self.self_attn(queries, keys, values, 57 | attn_mask=~valid_access, 58 | need_weights=False)[0] 59 | return self.dropout1(x) 60 | 61 | 62 | class StreamingTransformerEncoder(nn.Module): 63 | """TransformerEncoder with streaming support. 64 | 65 | Args: 66 | dim (int): dimension of the data. 67 | hidden_scale (int): intermediate dimension of FF module is this times the dimension. 68 | num_heads (int): number of heads. 69 | num_layers (int): number of layers. 70 | max_period (float): maxium period of cosines in the positional embedding. 71 | past_context (int or None): receptive field for the causal mask, infinite if None. 72 | gelu (bool): if true uses GeLUs, otherwise use ReLUs. 73 | norm_in (bool): normalize the input. 74 | dropout (float): dropout probability. 75 | **kwargs: See `nn.TransformerEncoderLayer`. 
76 | """ 77 | def __init__(self, dim, hidden_scale: float = 4., num_heads: int = 8, num_layers: int = 5, 78 | max_period: float = 10000, past_context: int = 1000, gelu: bool = True, 79 | norm_in: bool = True, dropout: float = 0., **kwargs): 80 | super().__init__() 81 | assert dim % num_heads == 0 82 | hidden_dim = int(dim * hidden_scale) 83 | 84 | self.max_period = max_period 85 | self.past_context = past_context 86 | activation: tp.Any = F.gelu if gelu else F.relu 87 | 88 | self.norm_in: nn.Module 89 | if norm_in: 90 | self.norm_in = nn.LayerNorm(dim) 91 | else: 92 | self.norm_in = nn.Identity() 93 | 94 | self.layers = nn.ModuleList() 95 | for idx in range(num_layers): 96 | self.layers.append( 97 | StreamingTransformerEncoderLayer( 98 | dim, num_heads, hidden_dim, 99 | activation=activation, batch_first=True, dropout=dropout, **kwargs)) 100 | 101 | def forward(self, x: torch.Tensor, 102 | states: tp.Optional[tp.List[torch.Tensor]] = None, 103 | offset: tp.Union[int, torch.Tensor] = 0): 104 | B, T, C = x.shape 105 | if states is None: 106 | states = [torch.zeros_like(x[:, :1]) for _ in range(1 + len(self.layers))] 107 | 108 | positions = torch.arange(T, device=x.device).view(1, -1, 1) + offset 109 | pos_emb = create_sin_embedding(positions, C, max_period=self.max_period) 110 | 111 | new_state: tp.List[torch.Tensor] = [] 112 | x = self.norm_in(x) 113 | x = x + pos_emb 114 | 115 | for layer_state, layer in zip(states, self.layers): 116 | x, new_layer_state = layer(x, layer_state, self.past_context) 117 | new_layer_state = torch.cat([layer_state, new_layer_state], dim=1) 118 | new_state.append(new_layer_state[:, -self.past_context:, :]) 119 | return x, new_state, offset + T 120 | -------------------------------------------------------------------------------- /losses/generator_loss.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Dict 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from modules.commons.pqmf import PQMF 8 | from losses.basic_loss import FeatureMatchLoss, MultiResolutionSTFTLoss, LeastDLoss, MSEGLoss, MSEDLoss 9 | from utils.hifigan_mel import mel_spectrogram 10 | 11 | 12 | class BasicGeneratorLoss(nn.Module): 13 | def __init__(self, config): 14 | super(BasicGeneratorLoss, self).__init__() 15 | self.config = config 16 | self.adv_criterion = eval(config.adv_criterion)() 17 | if self.config.use_feature_match: 18 | self.feature_match_criterion = FeatureMatchLoss() 19 | 20 | def forward( 21 | self, 22 | targets: torch.Tensor, 23 | outputs: torch.Tensor, 24 | output_real: Dict[str, torch.Tensor], 25 | output_fake: Dict[str, torch.Tensor], 26 | fmap_real: Optional[Dict[str, torch.Tensor]] = None, 27 | fmap_fake: Optional[Dict[str, torch.Tensor]] = None, 28 | use_adv_loss: bool = True, 29 | ): 30 | """ 31 | Args: 32 | targets: ground-truth waveforms. 33 | outputs: generated waveforms. 34 | output_real: logits from discriminators on real waveforms. 35 | output_fake: logits from discriminators on generated/fake waveforms. 36 | fmap_real: feature mappings of real waveforms. 37 | fmap_fake: feature mappings of generated/fake waveforms. 
38 | """ 39 | g_loss = 0 40 | g_loss_items = {} 41 | 42 | if use_adv_loss: 43 | for key in output_fake.keys(): 44 | adv_loss_item = self.adv_criterion(output_fake[key]) 45 | g_loss += adv_loss_item 46 | g_loss_items[f"Train/G_adv_{key}"] = adv_loss_item.item() 47 | 48 | if self.config.use_feature_match: 49 | assert fmap_real is not None and fmap_fake is not None 50 | fmap_loss_item = self.feature_match_criterion( 51 | fmap_real[key], fmap_fake[key]) * self.config.feat_match_loss_weight 52 | g_loss += fmap_loss_item 53 | g_loss_items[f"Train/G_fm_{key}"] = fmap_loss_item.item() / self.config.feat_match_loss_weight 54 | 55 | if self.config.use_mel_loss: 56 | hps_mel_scale_loss = self.config.mel_scale_loss if isinstance(self.config.mel_scale_loss, list) \ 57 | else [self.config.mel_scale_loss] 58 | 59 | for i, _hps_mel_scale_loss in enumerate(hps_mel_scale_loss): 60 | outputs_mel = mel_spectrogram(outputs.squeeze(1), **_hps_mel_scale_loss) 61 | target_mel = mel_spectrogram(targets.squeeze(1), **_hps_mel_scale_loss) 62 | mel_loss = F.l1_loss(outputs_mel, target_mel.detach()) * self.config.mel_loss_weight 63 | g_loss += mel_loss 64 | g_loss_items[f"Train/G_mel_loss_{i}"] = mel_loss.item() / self.config.mel_loss_weight 65 | 66 | return g_loss, g_loss_items 67 | 68 | 69 | class GeneratorSTFTLoss(BasicGeneratorLoss): 70 | def __init__(self, config): 71 | super().__init__(config) 72 | if self.config.use_full_stft_loss: 73 | self.stft_full_criterion = MultiResolutionSTFTLoss( 74 | **self.config.full_multi_scale_stft_loss) 75 | 76 | if self.config.use_sub_stft_loss: 77 | self.pqmf = PQMF(self.config.sub_multi_scale_stft_loss.num_bands) 78 | self.stft_sub_criterion = MultiResolutionSTFTLoss( 79 | **self.config.sub_multi_scale_stft_loss) 80 | 81 | def forward( 82 | self, targets, outputs, output_real, output_fake, fmap_real, fmap_fake, 83 | use_adv_loss: bool = True 84 | ): 85 | g_loss, g_loss_items = super().forward( 86 | targets, outputs, output_real, output_fake, fmap_real, fmap_fake, use_adv_loss=use_adv_loss) 87 | 88 | # Optional: full-band STFT Loss 89 | if self.config.use_full_stft_loss: 90 | sc_full_loss, mg_full_loss = \ 91 | self.stft_full_criterion(outputs.squeeze(1), targets.squeeze(1)) 92 | g_loss = g_loss + self.config.full_stft_loss_weight * (sc_full_loss + mg_full_loss) 93 | g_loss_items["Train/G_sc_full"] = sc_full_loss.item() 94 | g_loss_items["Train/G_mg_full"] = mg_full_loss.item() 95 | 96 | # Optional: sub-band STFT Loss 97 | if self.config.use_sub_stft_loss: 98 | targets_sub = self.pqmf.analysis(targets) 99 | outputs_sub = self.pqmf.analysis(outputs) 100 | size = outputs_sub.size(-1) 101 | outputs_sub_view = outputs_sub.view(-1, size) 102 | targets_sub_view = targets_sub.view(-1, size) 103 | 104 | sc_sub_loss, mg_sub_loss = \ 105 | self.stft_sub_criterion(outputs_sub_view, targets_sub_view) 106 | g_loss = g_loss + self.config.sub_stft_loss_weight * (sc_sub_loss + mg_sub_loss) 107 | g_loss_items["Train/G_sc_sub"] = sc_sub_loss.item() 108 | g_loss_items["Train/G_mg_sub"] = mg_sub_loss.item() 109 | 110 | return g_loss, g_loss_items 111 | 112 | -------------------------------------------------------------------------------- /modules/discriminators/frequency_discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from omegaconf import DictConfig 4 | 5 | from modules.commons.torch_stft import TorchSTFT 6 | 7 | 8 | class MultiFrequencyDiscriminator(nn.Module): 9 | def __init__(self, config: 
DictConfig): 10 | super().__init__() 11 | 12 | self.stfts = nn.ModuleList([ 13 | TorchSTFT( 14 | fft_size=x * 4, 15 | hop_size=x, 16 | win_size=x * 4, 17 | normalized=True, # returns the normalized STFT results, i.e., multiplied by frame_length^{-0.5} 18 | domain=config.domain, 19 | mel_scale=config.mel_scale, 20 | sample_rate=config.sample_rate, 21 | ) for x in config.hop_lengths 22 | ]) 23 | 24 | self.domain = config.domain 25 | if self.domain == 'double': 26 | self.discriminators = nn.ModuleList([ 27 | FrequenceDiscriminator(2, c) 28 | for x, c in zip(config.hop_lengths, config.hidden_channels)]) 29 | else: 30 | self.discriminators = nn.ModuleList([ 31 | FrequenceDiscriminator(1, c) 32 | for x, c in zip(config.hop_lengths, config.hidden_channels)]) 33 | 34 | def forward(self, y, y_hat, **kwargs): 35 | if y.ndim == 3: 36 | y = y.view(-1, y.shape[-1]) 37 | 38 | if y_hat.ndim == 3: 39 | y_hat = y_hat.view(-1, y_hat.shape[-1]) 40 | 41 | real_outputs = [] 42 | fake_outputs = [] 43 | real_feature_maps = [] 44 | fake_feature_maps = [] 45 | 46 | for stft, layer in zip(self.stfts, self.discriminators): 47 | mag, phase = stft.transform(y.squeeze(1)) 48 | fake_mag, fake_phase = stft.transform(y_hat.squeeze(1)) 49 | if self.domain == 'double': 50 | mag = torch.stack(torch.chunk(mag, 2, dim=1), dim=1) 51 | fake_mag = torch.stack(torch.chunk(fake_mag, 2, dim=1), dim=1) 52 | else: 53 | mag = mag.unsqueeze(1) 54 | fake_mag = fake_mag.unsqueeze(1) 55 | 56 | real_out, real_feat_map = layer(mag) 57 | fake_out, fake_feat_map = layer(fake_mag) 58 | real_outputs.append(real_out) 59 | fake_outputs.append(fake_out) 60 | real_feature_maps.append(real_feat_map) 61 | fake_feature_maps.append(fake_feat_map) 62 | 63 | return real_outputs, fake_outputs, real_feature_maps, fake_feature_maps 64 | 65 | 66 | class FrequenceDiscriminator(nn.Module): 67 | def __init__(self, in_channels, hidden_channels=512): 68 | super(FrequenceDiscriminator, self).__init__() 69 | 70 | self.discriminator = nn.ModuleList() 71 | self.discriminator += [ 72 | nn.Sequential( 73 | nn.ReflectionPad2d((1, 1, 1, 1)), 74 | nn.utils.weight_norm(nn.Conv2d( 75 | in_channels, hidden_channels // 32, 76 | kernel_size=(3, 3), stride=(1, 1))) 77 | ), 78 | nn.Sequential( 79 | nn.LeakyReLU(0.2, True), 80 | nn.ReflectionPad2d((1, 1, 1, 1)), 81 | nn.utils.weight_norm(nn.Conv2d( 82 | hidden_channels // 32, hidden_channels // 16, 83 | kernel_size=(3, 3), stride=(2, 2))) 84 | ), 85 | nn.Sequential( 86 | nn.LeakyReLU(0.2, True), 87 | nn.ReflectionPad2d((1, 1, 1, 1)), 88 | nn.utils.weight_norm(nn.Conv2d( 89 | hidden_channels // 16, hidden_channels // 8, 90 | kernel_size=(3, 3), stride=(1, 1))) 91 | ), 92 | nn.Sequential( 93 | nn.LeakyReLU(0.2, True), 94 | nn.ReflectionPad2d((1, 1, 1, 1)), 95 | nn.utils.weight_norm(nn.Conv2d( 96 | hidden_channels // 8, hidden_channels // 4, 97 | kernel_size=(3, 3), stride=(2, 2))) 98 | ), 99 | nn.Sequential( 100 | nn.LeakyReLU(0.2, True), 101 | nn.ReflectionPad2d((1, 1, 1, 1)), 102 | nn.utils.weight_norm(nn.Conv2d( 103 | hidden_channels // 4, hidden_channels // 2, 104 | kernel_size=(3, 3), stride=(1, 1))) 105 | ), 106 | nn.Sequential( 107 | nn.LeakyReLU(0.2, True), 108 | nn.ReflectionPad2d((1, 1, 1, 1)), 109 | nn.utils.weight_norm(nn.Conv2d( 110 | hidden_channels // 2, hidden_channels, 111 | kernel_size=(3, 3), stride=(2, 2))) 112 | ), 113 | nn.Sequential( 114 | nn.LeakyReLU(0.2, True), 115 | nn.ReflectionPad2d((1, 1, 1, 1)), 116 | nn.utils.weight_norm(nn.Conv2d( 117 | hidden_channels, 1, 118 | kernel_size=(3, 3), stride=(1, 1))) 
119 | ) 120 | ] 121 | 122 | def forward(self, x): 123 | hiddens = [] 124 | for layer in self.discriminator: 125 | x = layer(x) 126 | hiddens.append(x) 127 | return x, hiddens[:-1] 128 | -------------------------------------------------------------------------------- /models/soundstream_semantic.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from typing import Sequence, Optional, Union 4 | import sys 5 | 6 | import math 7 | import random 8 | 9 | import torch 10 | import torch.nn as nn 11 | import numpy as np 12 | import torch.nn.functional as F 13 | 14 | import descriptaudiocodec.dac.model.dac as dac2 15 | 16 | 17 | from quantization import ResidualVectorQuantizer 18 | from transformers import AutoModel 19 | 20 | from modules.semantic_module import Encoder,Decoder 21 | 22 | 23 | 24 | class SoundStream(nn.Module): 25 | def __init__( 26 | self, 27 | n_filters: int = 32, 28 | D: int = 128, 29 | target_bandwidths: Sequence[Union[int, float]] = [1, 1.5, 2, 4, 6], 30 | ratios: Sequence[int] = [8, 5, 4, 2], # downsampling by 320 31 | sample_rate: int = 16000, 32 | bins: int = 1024, 33 | normalize: bool = False, 34 | causal: bool = False, 35 | semantic_techer: str = 'hubert_base_general' 36 | ): 37 | super().__init__() 38 | self.hop_length = np.prod(ratios) 39 | 40 | n_q = int(1000 * target_bandwidths[-1] // (math.ceil(sample_rate / self.hop_length) * 10)) 41 | self.frame_rate = math.ceil(sample_rate / np.prod(ratios)) # 50 Hz 42 | self.bits_per_codebook = int(math.log2(bins)) # 1024 => 10 43 | self.target_bandwidths = target_bandwidths 44 | self.n_q = n_q 45 | self.sample_rate = sample_rate 46 | self.encoder = dac2.Encoder(64,ratios,D) 47 | 48 | self.encoder_semantic = Encoder(input_channels=768,encode_channels=768) 49 | self.decoder_semantic = Decoder(code_dim=768,output_channels=768,decode_channels=768) 50 | # out_D=D+768 51 | self.quantizer = ResidualVectorQuantizer(dimension=D+768, n_q=n_q, bins=bins) 52 | 53 | self.decoder_2 = dac2.Decoder(D,1024,ratios,) 54 | 55 | if semantic_techer=='hubert_base': 56 | self.semantic_model = AutoModel.from_pretrained("facebook/hubert-base-ls960") 57 | elif semantic_techer=='wavlm_base_plus': 58 | self.semantic_model = AutoModel.from_pretrained("microsoft/wavlm-base-plus") 59 | elif semantic_techer=='hubert_base_general': 60 | 61 | self.semantic_model = AutoModel.from_pretrained("ZhenYe234/hubert_base_general_audio") 62 | self.semantic_model.eval() 63 | 64 | self.fc_prior = nn.Linear(D+768, D+768 ) 65 | 66 | self.fc_post1= nn.Linear( D+768, 768 ) 67 | self.fc_post2= nn.Linear( D+768, D) 68 | 69 | def get_last_layer(self): 70 | return self.decoder.layers[-1].weight 71 | 72 | def calculate_rec_loss(self, rec, target): 73 | 74 | target = target / target.norm(dim=-1, keepdim=True) 75 | rec = rec / rec.norm(dim=-1, keepdim=True) 76 | rec_loss = (1 - (target * rec).sum(-1)).mean() 77 | 78 | return rec_loss 79 | 80 | @torch.no_grad() 81 | def get_regress_target(self, x ): 82 | x= x[:,0,:] 83 | x = F.pad(x, (160, 160)) 84 | target = self.semantic_model(x, output_hidden_states=True) .hidden_states 85 | target = torch.stack(target, dim=1)#.transpose(-1, -2)#.flatten(start_dim=1, end_dim=2) 86 | 87 | # average for all layers 88 | target = target.mean(1) 89 | # target = target[9] 90 | return target 91 | 92 | 93 | 94 | def forward(self, x: torch.Tensor, bw: int): 95 | 96 | e_semantic_input = self.get_regress_target(x).detach() 97 | 98 | e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2)) 99 | e_acoustic = 
self.encoder(x) 100 | 101 | 102 | e= torch.cat([e_acoustic, e_semantic], dim=1) 103 | 104 | e = self.fc_prior(e.transpose(1, 2)).transpose(1, 2) 105 | 106 | 107 | quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw) 108 | 109 | quantized_semantic = self.fc_post1(quantized.transpose(1, 2)).transpose(1, 2) 110 | quantized_acoustic = self.fc_post2(quantized.transpose(1, 2)).transpose(1, 2) 111 | 112 | o = self.decoder_2(quantized_acoustic) 113 | 114 | o_semantic = self.decoder_semantic(quantized_semantic ) 115 | semantic_recon_loss = F.mse_loss(e_semantic_input.transpose(1, 2).detach(),o_semantic) 116 | 117 | return o, commit_loss, semantic_recon_loss,None 118 | 119 | def encode(self, x: torch.Tensor,target_bw: Optional[int] = None) -> torch.Tensor: 120 | 121 | bw = target_bw 122 | 123 | e_semantic_input = self.get_regress_target(x).detach() 124 | 125 | e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2)) 126 | e_acoustic = self.encoder(x) 127 | 128 | 129 | if e_acoustic.shape[2] != e_semantic.shape[2]: 130 | e_acoustic = self.encoder(F.pad(x[:,0,:], (160, 160)).unsqueeze(0)) 131 | 132 | e= torch.cat([e_acoustic, e_semantic], dim=1) 133 | 134 | e = self.fc_prior(e.transpose(1, 2)).transpose(1, 2) 135 | 136 | 137 | quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw) 138 | return codes 139 | 140 | def decode(self, codes: torch.Tensor) -> torch.Tensor: 141 | quantized = self.quantizer.decode(codes) 142 | quantized_acoustic = self.fc_post2(quantized.transpose(1, 2)).transpose(1, 2) 143 | 144 | o = self.decoder_2(quantized_acoustic) 145 | return o 146 | 147 | # test 148 | if __name__ == '__main__': 149 | soundstream = SoundStream(n_filters=32, D=256) 150 | 151 | for i in range(10): 152 | print(f"Iter {i}: ") 153 | x = torch.rand(1, 1, 16000) 154 | o, commit_loss, distill_loss,_= soundstream(x,soundstream.target_bandwidths[-1]) 155 | print('output', o.shape) 156 | -------------------------------------------------------------------------------- /losses/basic_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FeatureMatchLoss(nn.Module): 7 | def __init__(self): 8 | super(FeatureMatchLoss, self).__init__() 9 | 10 | def forward(self, real_features, fake_features): 11 | loss = 0 12 | num_items = 0 13 | for (fake_feature, real_feature) in zip(fake_features, real_features): 14 | if isinstance(fake_feature, list): 15 | for (_fake_feature, _real_feature) in zip(fake_feature, real_feature): 16 | loss = loss + F.l1_loss(_fake_feature.float(), _real_feature.float().detach()) 17 | num_items += 1 18 | else: 19 | loss = loss + F.l1_loss(fake_feature.float(), real_feature.float().detach()) 20 | num_items += 1 21 | loss /= num_items 22 | return loss 23 | 24 | 25 | class LeastDLoss(nn.Module): 26 | def __init__(self): 27 | super(LeastDLoss, self).__init__() 28 | 29 | def forward(self, disc_outputs): 30 | loss = 0 31 | for dg in disc_outputs: 32 | dg = dg.float() 33 | l = torch.mean((1-dg)**2) 34 | loss += l 35 | return loss 36 | 37 | 38 | class MSEDLoss(nn.Module): 39 | def __init__(self): 40 | super(MSEDLoss, self).__init__() 41 | self.loss_func = nn.MSELoss() 42 | 43 | def forward(self, score_fake, score_real): 44 | loss_real = self.loss_func(score_real, score_real.new_ones(score_real.shape)) 45 | loss_fake = self.loss_func(score_fake, score_fake.new_zeros(score_fake.shape)) 46 | loss_d = loss_real + 
loss_fake 47 | return loss_d, loss_real, loss_fake 48 | 49 | 50 | class HingeDLoss(nn.Module): 51 | def __init__(self): 52 | super(HingeDLoss, self).__init__() 53 | 54 | def forward(self, score_fake, score_real): 55 | loss_real = torch.mean(F.relu(1. - score_real)) 56 | loss_fake = torch.mean(F.relu(1. + score_fake)) 57 | loss_d = loss_real + loss_fake 58 | return loss_d, loss_real, loss_fake 59 | 60 | 61 | class MSEGLoss(nn.Module): 62 | def __init__(self): 63 | super(MSEGLoss, self).__init__() 64 | 65 | def forward(self, scores): 66 | loss_fake = 0 67 | num_items = 0 68 | if isinstance(scores, list): 69 | for score in scores: 70 | loss_fake = loss_fake + F.mse_loss(score, score.new_ones(score.shape)) 71 | num_items += 1 72 | else: 73 | loss_fake = F.mse_loss(scores, scores.new_ones(scores.shape)) 74 | num_items += 1 75 | return loss_fake / num_items 76 | 77 | 78 | class HingeGLoss(nn.Module): 79 | def __init__(self): 80 | super(HingeGLoss, self).__init__() 81 | 82 | def forward(self, score_real): 83 | loss_fake = torch.mean(F.relu(1. - score_real)) 84 | return loss_fake 85 | 86 | 87 | def stft(x, fft_size, hop_size, win_size, window): 88 | x_stft = torch.stft(x, fft_size, hop_size, win_size, window,return_complex=False) 89 | real = x_stft[..., 0] 90 | imag = x_stft[..., 1] 91 | outputs = torch.clamp(real ** 2 + imag ** 2, min=1e-7).transpose(2, 1) 92 | outputs = torch.sqrt(outputs) 93 | 94 | return outputs 95 | 96 | 97 | class SpectralConvergence(nn.Module): 98 | def __init__(self): 99 | super(SpectralConvergence, self).__init__() 100 | 101 | def forward(self, predicts_mag, targets_mag): 102 | x = torch.norm(targets_mag - predicts_mag, p='fro') 103 | y = torch.norm(targets_mag, p='fro') 104 | 105 | return x / y 106 | 107 | 108 | class LogSTFTMagnitude(nn.Module): 109 | def __init__(self): 110 | super(LogSTFTMagnitude, self).__init__() 111 | 112 | def forward(self, predicts_mag, targets_mag): 113 | log_predicts_mag = torch.log(predicts_mag) 114 | log_targets_mag = torch.log(targets_mag) 115 | outputs = F.l1_loss(log_predicts_mag, log_targets_mag) 116 | 117 | return outputs 118 | 119 | 120 | class STFTLoss(nn.Module): 121 | def __init__( 122 | self, 123 | fft_size=1024, 124 | hop_size=120, 125 | win_size=600, 126 | ): 127 | super(STFTLoss, self).__init__() 128 | 129 | self.fft_size = fft_size 130 | self.hop_size = hop_size 131 | self.win_size = win_size 132 | self.register_buffer('window', torch.hann_window(win_size)) 133 | self.sc_loss = SpectralConvergence() 134 | self.mag = LogSTFTMagnitude() 135 | 136 | def forward(self, predicts, targets): 137 | predicts_mag = stft(predicts, self.fft_size, self.hop_size, self.win_size, self.window) 138 | targets_mag = stft(targets, self.fft_size, self.hop_size, self.win_size, self.window) 139 | 140 | sc_loss = self.sc_loss(predicts_mag, targets_mag) 141 | mag_loss = self.mag(predicts_mag, targets_mag) 142 | 143 | return sc_loss, mag_loss 144 | 145 | 146 | class MultiResolutionSTFTLoss(nn.Module): 147 | def __init__( 148 | self, 149 | fft_sizes=[1024, 2048, 512], 150 | win_sizes=[600, 1200, 240], 151 | hop_sizes=[120, 240, 50], 152 | **kwargs 153 | ): 154 | super(MultiResolutionSTFTLoss, self).__init__() 155 | self.loss_layers = torch.nn.ModuleList() 156 | for (fft_size, win_size, hop_size) in zip(fft_sizes, win_sizes, hop_sizes): 157 | self.loss_layers.append(STFTLoss(fft_size, hop_size, win_size)) 158 | 159 | def forward(self, fake_signals, true_signals): 160 | sc_losses, mag_losses = [], [] 161 | for layer in self.loss_layers: 162 | sc_loss, 
mag_loss = layer(fake_signals, true_signals) 163 | sc_losses.append(sc_loss) 164 | mag_losses.append(mag_loss) 165 | 166 | sc_loss = sum(sc_losses) / len(sc_losses) 167 | mag_loss = sum(mag_losses) / len(mag_losses) 168 | 169 | return sc_loss, mag_loss 170 | -------------------------------------------------------------------------------- /modules/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | from torchaudio.transforms import MelSpectrogram 6 | 7 | 8 | def adversarial_g_loss(y_disc_gen): 9 | loss = 0.0 10 | for i in range(len(y_disc_gen)): 11 | #print(y_disc_gen[i].shape) 12 | # assert 1==2 13 | stft_loss = F.relu(1-y_disc_gen[i]).mean().squeeze() 14 | loss += stft_loss 15 | return loss/len(y_disc_gen) 16 | 17 | 18 | def feature_loss(fmap_r, fmap_gen): 19 | loss = 0.0 20 | for i in range(len(fmap_r)): 21 | for j in range(len(fmap_r[i])): 22 | stft_loss = ((fmap_r[i][j]-fmap_gen[i][j]).abs()/(fmap_r[i][j].abs().mean())).mean() 23 | loss += stft_loss 24 | return loss/(len(fmap_r)*len(fmap_r[0])) 25 | 26 | 27 | def sim_loss(y_disc_r, y_disc_gen): 28 | loss = 0.0 29 | for i in range(len(y_disc_r)): 30 | loss += F.mse_loss(y_disc_r[i], y_disc_gen[i]) 31 | return loss/len(y_disc_r) 32 | 33 | 34 | def sisnr_loss(x, s, eps=1e-8): 35 | """ 36 | calculate training loss 37 | input: 38 | x: separated signal, N x S tensor, estimate value 39 | s: reference signal, N x S tensor, True value 40 | Return: 41 | sisnr: N tensor 42 | """ 43 | if x.shape != s.shape: 44 | if x.shape[-1] > s.shape[-1]: 45 | x = x[:, :s.shape[-1]] 46 | else: 47 | s = s[:, :x.shape[-1]] 48 | def l2norm(mat, keepdim=False): 49 | return torch.norm(mat, dim=-1, keepdim=keepdim) 50 | if x.shape != s.shape: 51 | raise RuntimeError( 52 | "Dimention mismatch when calculate si-snr, {} vs {}".format( 53 | x.shape, s.shape)) 54 | x_zm = x - torch.mean(x, dim=-1, keepdim=True) 55 | s_zm = s - torch.mean(s, dim=-1, keepdim=True) 56 | t = torch.sum( 57 | x_zm * s_zm, dim=-1, 58 | keepdim=True) * s_zm / (l2norm(s_zm, keepdim=True)**2 + eps) 59 | loss = -20. 
* torch.log10(eps + l2norm(t) / (l2norm(x_zm - t) + eps)) 60 | return torch.sum(loss) / x.shape[0] 61 | 62 | 63 | def reconstruction_loss(x, G_x, args, eps=1e-7): 64 | L = 100*F.mse_loss(x, G_x) # wav L1 loss 65 | #loss_sisnr = sisnr_loss(G_x, x) # 66 | #L += 0.01*loss_sisnr 67 | # print('L0 ', L) 68 | # print('loss_sisnr ', 0.01*loss_sisnr) 69 | # print('L0 ', L) 70 | for i in range(6,11): 71 | s = 2**i 72 | melspec = MelSpectrogram(sample_rate=args.sr, n_fft=s, hop_length=s//4, n_mels=64, wkwargs={"device": args.device}).to(args.device) 73 | S_x = melspec(x) 74 | S_G_x = melspec(G_x) 75 | loss = ((S_x-S_G_x).abs().mean() + (((torch.log(S_x.abs()+eps)-torch.log(S_G_x.abs()+eps))**2).mean(dim=-2)**0.5).mean())/(i) 76 | L += loss 77 | #print('i ,loss ', i, loss) 78 | #assert 1==2 79 | return L 80 | 81 | 82 | def criterion_d(y_disc_r, y_disc_gen, fmap_r_det, fmap_gen_det): 83 | loss = 0.0 84 | loss_f = feature_loss(fmap_r_det, fmap_gen_det) 85 | for i in range(len(y_disc_r)): 86 | loss += F.relu(1-y_disc_r[i]).mean() + F.relu(1+y_disc_gen[i]).mean() 87 | return loss/len(y_disc_gen) + 0.0*loss_f 88 | 89 | 90 | def criterion_g(commit_loss, x, G_x, fmap_r, fmap_gen, y_disc_r, y_disc_gen, args): 91 | adv_g_loss = adversarial_g_loss(y_disc_gen) 92 | feat_loss = feature_loss(fmap_r, fmap_gen) + sim_loss(y_disc_r, y_disc_gen) # 预测结果也应该尽可能相似 93 | rec_loss = reconstruction_loss(x.contiguous(), G_x.contiguous(), args) 94 | total_loss = args.LAMBDA_COM * commit_loss + args.LAMBDA_ADV*adv_g_loss + \ 95 | args.LAMBDA_FEAT*feat_loss + args.LAMBDA_REC*rec_loss 96 | return total_loss, adv_g_loss, feat_loss, rec_loss 97 | 98 | 99 | def adopt_weight(weight, global_step, threshold=0, value=0.): 100 | if global_step < threshold: 101 | weight = value 102 | return weight 103 | 104 | 105 | def adopt_dis_weight(weight, global_step, threshold=0, value=0.): 106 | if global_step % 3 == 0: # 0,3,6,9,13....这些时间步,不更新dis 107 | weight = value 108 | return weight 109 | 110 | 111 | def calculate_adaptive_weight(nll_loss, g_loss, last_layer, args): 112 | if last_layer is not None: 113 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 114 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 115 | else: 116 | print('last_layer cannot be none') 117 | assert 1==2 118 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 119 | d_weight = torch.clamp(d_weight, 1.0, 1.0).detach() 120 | d_weight = d_weight * args.LAMBDA_ADV 121 | return d_weight 122 | 123 | 124 | def loss_g(codebook_loss, inputs, reconstructions, fmap_r, fmap_gen, 125 | y_disc_r, y_disc_gen, global_step, last_layer=None, is_training=True, args=None): 126 | rec_loss = reconstruction_loss(inputs.contiguous(), reconstructions.contiguous(), args) 127 | adv_g_loss = adversarial_g_loss(y_disc_gen) 128 | feat_loss = feature_loss(fmap_r, fmap_gen) + sim_loss(y_disc_r, y_disc_gen) # 129 | d_weight = torch.tensor(1.0) 130 | # try: 131 | # d_weight = calculate_adaptive_weight(rec_loss, adv_g_loss, last_layer, args) # 动态调整重构损失和对抗损失 132 | # except RuntimeError: 133 | # assert not is_training 134 | # d_weight = torch.tensor(0.0) 135 | disc_factor = adopt_weight(args.LAMBDA_ADV, global_step, threshold=args.discriminator_iter_start) 136 | #feat_factor = adopt_weight(args.LAMBDA_FEAT, global_step, threshold=args.discriminator_iter_start) 137 | loss = rec_loss + d_weight * disc_factor * adv_g_loss + \ 138 | args.LAMBDA_FEAT*feat_loss + args.LAMBDA_COM * codebook_loss 139 | return loss, rec_loss, adv_g_loss, feat_loss, 
d_weight 140 | 141 | 142 | def loss_dis(y_disc_r_det, y_disc_gen_det, fmap_r_det, fmap_gen_det, global_step, args): 143 | disc_factor = adopt_weight(args.LAMBDA_ADV, global_step, threshold=args.discriminator_iter_start) 144 | d_loss = disc_factor * criterion_d(y_disc_r_det, y_disc_gen_det, fmap_r_det, fmap_gen_det) 145 | return d_loss 146 | -------------------------------------------------------------------------------- /descriptaudiocodec/README.md: -------------------------------------------------------------------------------- 1 | # Descript Audio Codec (.dac): High-Fidelity Audio Compression with Improved RVQGAN 2 | 3 | This repository contains training and inference scripts 4 | for the Descript Audio Codec (.dac), a high fidelity general 5 | neural audio codec, introduced in the paper titled **High-Fidelity Audio Compression with Improved RVQGAN**. 6 | 7 | ![](https://static.arxiv.org/static/browse/0.3.4/images/icons/favicon-16x16.png) [arXiv Paper: High-Fidelity Audio Compression with Improved RVQGAN 8 | ](http://arxiv.org/abs/2306.06546)
9 | 📈 [Demo Site](https://descript.notion.site/Descript-Audio-Codec-11389fce0ce2419891d6591a68f814d5)
10 | ⚙ [Model Weights](https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth) 11 | 12 | 👉 With Descript Audio Codec, you can compress **44.1 kHz audio** into discrete codes at a **low 8 kbps bitrate**.
13 | 🤌 That's approximately **90x compression** while maintaining exceptional fidelity and minimizing artifacts.
14 | 💪 Our universal model works on all domains (speech, environment, music, etc.), making it widely applicable to generative modeling of all audio.
15 | 👌 It can be used as a drop-in replacement for EnCodec for all audio language modeling applications (such as AudioLMs, MusicLMs, MusicGen, etc.)
16 | 17 |

18 | Comparison of compression approaches. Our model achieves a higher compression factor than all baseline methods: roughly 90x, compared to 32x for EnCodec and 64x for SoundStream. Note that we operate at a target bitrate of 8 kbps, whereas EnCodec operates at 24 kbps and SoundStream at 6 kbps. We also operate at 44.1 kHz, whereas EnCodec operates at 48 kHz and SoundStream operates at 24 kHz.
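These factors follow directly from dividing the raw PCM bitrate by the coded bitrate; the sketch below is a back-of-the-envelope check that assumes 16-bit mono input (an assumption, since bit depth is not stated above):

```py
# Rough compression factors, assuming 16-bit mono PCM input (not stated above).
for name, sr_hz, coded_kbps in [("DAC", 44100, 8), ("EnCodec", 48000, 24), ("SoundStream", 24000, 6)]:
    raw_kbps = sr_hz * 16 / 1000                      # e.g. 44.1 kHz * 16 bit = 705.6 kbps
    print(f"{name}: ~{raw_kbps / coded_kbps:.0f}x")   # DAC ~88x, EnCodec 32x, SoundStream 64x
```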

19 | 20 | 21 | ## Usage 22 | 23 | ### Installation 24 | ``` 25 | pip install descript-audio-codec 26 | ``` 27 | OR 28 | 29 | ``` 30 | pip install git+https://github.com/descriptinc/descript-audio-codec 31 | ``` 32 | 33 | ### Weights 34 | Weights are released as part of this repo under MIT license. 35 | We release weights for models that can natively support 16 kHz, 24kHz, and 44.1kHz sampling rates. 36 | Weights are automatically downloaded when you first run `encode` or `decode` command. You can cache them using one of the following commands 37 | ```bash 38 | python3 -m dac download # downloads the default 44kHz variant 39 | python3 -m dac download --model_type 44khz # downloads the 44kHz variant 40 | python3 -m dac download --model_type 24khz # downloads the 24kHz variant 41 | python3 -m dac download --model_type 16khz # downloads the 16kHz variant 42 | ``` 43 | We provide a Dockerfile that installs all required dependencies for encoding and decoding. The build process caches the default model weights inside the image. This allows the image to be used without an internet connection. [Please refer to instructions below.](#docker-image) 44 | 45 | 46 | ### Compress audio 47 | ``` 48 | python3 -m dac encode /path/to/input --output /path/to/output/codes 49 | ``` 50 | 51 | This command will create `.dac` files with the same name as the input files. 52 | It will also preserve the directory structure relative to input root and 53 | re-create it in the output directory. Please use `python -m dac encode --help` 54 | for more options. 55 | 56 | ### Reconstruct audio from compressed codes 57 | ``` 58 | python3 -m dac decode /path/to/output/codes --output /path/to/reconstructed_input 59 | ``` 60 | 61 | This command will create `.wav` files with the same name as the input files. 62 | It will also preserve the directory structure relative to input root and 63 | re-create it in the output directory. Please use `python -m dac decode --help` 64 | for more options. 65 | 66 | ### Programmatic Usage 67 | ```py 68 | import dac 69 | from audiotools import AudioSignal 70 | 71 | # Download a model 72 | model_path = dac.utils.download(model_type="44khz") 73 | model = dac.DAC.load(model_path) 74 | 75 | model.to('cuda') 76 | 77 | # Load audio signal file 78 | signal = AudioSignal('input.wav') 79 | 80 | # Encode audio signal as one long file 81 | # (may run out of GPU memory on long files) 82 | signal.to(model.device) 83 | 84 | x = model.preprocess(signal.audio_data, signal.sample_rate) 85 | z, codes, latents, _, _ = model.encode(x) 86 | 87 | # Decode audio signal 88 | y = model.decode(z) 89 | 90 | # Alternatively, use the `compress` and `decompress` functions 91 | # to compress long files. 92 | 93 | signal = signal.cpu() 94 | x = model.compress(signal) 95 | 96 | # Save and load to and from disk 97 | x.save("compressed.dac") 98 | x = dac.DACFile.load("compressed.dac") 99 | 100 | # Decompress it back to an AudioSignal 101 | y = model.decompress(x) 102 | 103 | # Write to file 104 | y.write('output.wav') 105 | ``` 106 | 107 | ### Docker image 108 | We provide a dockerfile to build a docker image with all the necessary 109 | dependencies. 110 | 1. Building the image. 111 | ``` 112 | docker build -t dac . 113 | ``` 114 | 2. Using the image. 115 | 116 | Usage on CPU: 117 | ``` 118 | docker run dac 119 | ``` 120 | 121 | Usage on GPU: 122 | ``` 123 | docker run --gpus=all dac 124 | ``` 125 | 126 | `` can be one of the compression and reconstruction commands listed 127 | above. 
For example, if you want to run compression, 128 | 129 | ``` 130 | docker run --gpus=all dac python3 -m dac encode ... 131 | ``` 132 | 133 | 134 | ## Training 135 | The baseline model configuration can be trained using the following commands. 136 | 137 | ### Pre-requisites 138 | Please install the correct dependencies 139 | ``` 140 | pip install -e ".[dev]" 141 | ``` 142 | 143 | ## Environment setup 144 | 145 | We have provided a Dockerfile and docker compose setup that makes running experiments easy. 146 | 147 | To build the docker image do: 148 | 149 | ``` 150 | docker compose build 151 | ``` 152 | 153 | Then, to launch a container, do: 154 | 155 | ``` 156 | docker compose run -p 8888:8888 -p 6006:6006 dev 157 | ``` 158 | 159 | The port arguments (`-p`) are optional, but useful if you want to launch a Jupyter and Tensorboard instances within the container. The 160 | default password for Jupyter is `password`, and the current directory 161 | is mounted to `/u/home/src`, which also becomes the working directory. 162 | 163 | Then, run your training command. 164 | 165 | 166 | ### Single GPU training 167 | ``` 168 | export CUDA_VISIBLE_DEVICES=0 169 | python scripts/train.py --args.load conf/ablations/baseline.yml --save_path runs/baseline/ 170 | ``` 171 | 172 | ### Multi GPU training 173 | ``` 174 | export CUDA_VISIBLE_DEVICES=0,1 175 | torchrun --nproc_per_node gpu scripts/train.py --args.load conf/ablations/baseline.yml --save_path runs/baseline/ 176 | ``` 177 | 178 | ## Testing 179 | We provide two test scripts to test CLI + training functionality. Please 180 | make sure that the trainig pre-requisites are satisfied before launching these 181 | tests. To launch these tests please run 182 | ``` 183 | python -m pytest tests 184 | ``` 185 | 186 | ## Results 187 | 188 |

189 |

190 | -------------------------------------------------------------------------------- /modules/commons/pqmf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2020 Tomoki Hayashi 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """Pseudo QMF modules.""" 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | from scipy.signal import kaiser 13 | 14 | 15 | def design_prototype_filter(taps=62, cutoff_ratio=0.142, beta=9.0): 16 | """Design prototype filter for PQMF. 17 | This method is based on `A Kaiser window approach for the design of prototype 18 | filters of cosine modulated filterbanks`_. 19 | Args: 20 | taps (int): The number of filter taps. 21 | cutoff_ratio (float): Cut-off frequency ratio. 22 | beta (float): Beta coefficient for kaiser window. 23 | Returns: 24 | ndarray: Impluse response of prototype filter (taps + 1,). 25 | .. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`: 26 | https://ieeexplore.ieee.org/abstract/document/681427 27 | """ 28 | # check the arguments are valid 29 | assert taps % 2 == 0, "The number of taps mush be even number." 30 | assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0." 31 | 32 | # make initial filter 33 | omega_c = np.pi * cutoff_ratio 34 | with np.errstate(invalid='ignore'): 35 | h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) \ 36 | / (np.pi * (np.arange(taps + 1) - 0.5 * taps)) 37 | h_i[taps // 2] = np.cos(0) * cutoff_ratio # fix nan due to indeterminate form 38 | 39 | # apply kaiser window 40 | w = kaiser(taps + 1, beta) 41 | h = h_i * w 42 | 43 | return h 44 | 45 | 46 | class PQMF(torch.nn.Module): 47 | """PQMF module. 48 | This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_. 49 | .. _`Near-perfect-reconstruction pseudo-QMF banks`: 50 | https://ieeexplore.ieee.org/document/258122 51 | """ 52 | 53 | def __init__(self, subbands=4, taps=62, cutoff_ratio=0.142, beta=9.0): 54 | """Initilize PQMF module. 55 | The cutoff_ratio and beta parameters are optimized for #subbands = 4. 56 | See dicussion in https://github.com/kan-bayashi/ParallelWaveGAN/issues/195. 57 | Args: 58 | subbands (int): The number of subbands. 59 | taps (int): The number of filter taps. 60 | cutoff_ratio (float): Cut-off frequency ratio. 61 | beta (float): Beta coefficient for kaiser window. 
62 | """ 63 | super(PQMF, self).__init__() 64 | 65 | if subbands == 8: 66 | cutoff_ratio = 0.07949452 67 | elif subbands == 6: 68 | cutoff_ratio = 0.10032791 69 | elif subbands == 4: 70 | cutoff_ratio = 0.13 71 | elif subbands == 2: 72 | cutoff_ratio = 0.25 73 | 74 | # build analysis & synthesis filter coefficients 75 | h_proto = design_prototype_filter(taps, cutoff_ratio, beta) 76 | h_analysis = np.zeros((subbands, len(h_proto))) 77 | h_synthesis = np.zeros((subbands, len(h_proto))) 78 | for k in range(subbands): 79 | h_analysis[k] = 2 * h_proto * np.cos( 80 | (2 * k + 1) * (np.pi / (2 * subbands)) * 81 | (np.arange(taps + 1) - (taps / 2)) + 82 | (-1) ** k * np.pi / 4) 83 | h_synthesis[k] = 2 * h_proto * np.cos( 84 | (2 * k + 1) * (np.pi / (2 * subbands)) * 85 | (np.arange(taps + 1) - (taps / 2)) - 86 | (-1) ** k * np.pi / 4) 87 | 88 | # convert to tensor 89 | analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1) 90 | synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0) 91 | 92 | # register coefficients as beffer 93 | self.register_buffer("analysis_filter", analysis_filter) 94 | self.register_buffer("synthesis_filter", synthesis_filter) 95 | 96 | # filter for downsampling & upsampling 97 | updown_filter = torch.zeros((subbands, subbands, subbands)).float() 98 | for k in range(subbands): 99 | updown_filter[k, k, 0] = 1.0 100 | self.register_buffer("updown_filter", updown_filter) 101 | self.subbands = subbands 102 | 103 | # keep padding info 104 | self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0) 105 | 106 | def analysis(self, x): 107 | """Analysis with PQMF. 108 | Args: 109 | x (Tensor): Input tensor (B, 1, T). 110 | Returns: 111 | Tensor: Output tensor (B, subbands, T // subbands). 112 | """ 113 | x = F.conv1d(self.pad_fn(x), self.analysis_filter) 114 | return F.conv1d(x, self.updown_filter, stride=self.subbands) 115 | 116 | def synthesis(self, x): 117 | """Synthesis with PQMF. 118 | Args: 119 | x (Tensor): Input tensor (B, subbands, T // subbands). 120 | Returns: 121 | Tensor: Output tensor (B, 1, T). 122 | """ 123 | # NOTE(kan-bayashi): Power will be dreased so here multipy by # subbands. 124 | # Not sure this is the correct way, it is better to check again. 125 | # TODO(kan-bayashi): Understand the reconstruction procedure 126 | x = F.conv_transpose1d(x, self.updown_filter * self.subbands, stride=self.subbands) 127 | return F.conv1d(self.pad_fn(x), self.synthesis_filter) 128 | 129 | 130 | def _objective(cutoff_ratio): 131 | h_proto = design_prototype_filter(num_taps, cutoff_ratio, beta) 132 | conv_h_proto = np.convolve(h_proto, h_proto[::-1], mode='full') 133 | length_conv_h = conv_h_proto.shape[0] 134 | half_length = length_conv_h // 2 135 | 136 | check_steps = np.arange((half_length) // (2 * num_subbands)) * 2 * num_subbands 137 | _phi_new = conv_h_proto[half_length:][check_steps] 138 | phi_new = np.abs(_phi_new[1:]).max() 139 | # Since phi_new is not convex, This value should also be considered. 
140 | diff_zero_coef = np.abs(_phi_new[0] - 1 / (2 * num_subbands)) 141 | 142 | return phi_new + diff_zero_coef 143 | 144 | if __name__ == "__main__": 145 | model = PQMF(4) 146 | import numpy as np 147 | import scipy.optimize as optimize 148 | 149 | x = np.load('data/train/audio/010000.npy') 150 | x = torch.FloatTensor(x).unsqueeze(0).unsqueeze(0) 151 | out = model.analysis(x) 152 | print(out.shape) 153 | x_hat = model.synthesis(out) 154 | loss = torch.nn.functional.mse_loss( 155 | x[..., :x_hat.shape[-1]], 156 | x_hat[..., :x_hat.shape[-1]], 157 | reduction="sum" 158 | ) 159 | print(loss) 160 | from scipy.io.wavfile import write 161 | audio = x_hat.squeeze().numpy() 162 | write('a.wav', 24000, audio) 163 | 164 | model = PQMF(6) 165 | out = model.analysis(x) 166 | print(out.shape) 167 | x_hat = model.synthesis(out) 168 | loss = torch.nn.functional.mse_loss( 169 | x[..., :x_hat.shape[-1]], 170 | x_hat[..., :x_hat.shape[-1]], 171 | reduction="sum" 172 | ) 173 | print(loss) 174 | audio = x_hat.squeeze().numpy() 175 | write('b.wav', 24000, audio) 176 | 177 | num_subbands = 6 178 | num_taps = 62 179 | beta = 9.0 180 | 181 | ret = optimize.minimize(_objective, np.array([0.01]), 182 | bounds=optimize.Bounds(0.01, 0.99)) 183 | opt_cutoff_ratio = ret.x[0] 184 | print(f"optimized cutoff ratio = {opt_cutoff_ratio:.08f}") 185 | -------------------------------------------------------------------------------- /models/msstftd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """MS-STFT discriminator, provided here for reference.""" 8 | 9 | import typing as tp 10 | 11 | import torchaudio 12 | import torch 13 | from torch import nn 14 | from einops import rearrange 15 | 16 | from modules import NormConv2d 17 | 18 | 19 | FeatureMapType = tp.List[torch.Tensor] 20 | LogitsType = torch.Tensor 21 | DiscriminatorOutput = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]] 22 | 23 | 24 | def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)): 25 | return (((kernel_size[0] - 1) * dilation[0]) // 2, ((kernel_size[1] - 1) * dilation[1]) // 2) 26 | 27 | 28 | class DiscriminatorSTFT(nn.Module): 29 | """STFT sub-discriminator. 30 | Args: 31 | filters (int): Number of filters in convolutions 32 | in_channels (int): Number of input channels. Default: 1 33 | out_channels (int): Number of output channels. Default: 1 34 | n_fft (int): Size of FFT for each scale. Default: 1024 35 | hop_length (int): Length of hop between STFT windows for each scale. Default: 256 36 | kernel_size (tuple of int): Inner Conv2d kernel sizes. Default: ``(3, 9)`` 37 | stride (tuple of int): Inner Conv2d strides. Default: ``(1, 2)`` 38 | dilations (list of int): Inner Conv2d dilation on the time dimension. Default: ``[1, 2, 4]`` 39 | win_length (int): Window size for each scale. Default: 1024 40 | normalized (bool): Whether to normalize by magnitude after stft. Default: True 41 | norm (str): Normalization method. Default: `'weight_norm'` 42 | activation (str): Activation function. Default: `'LeakyReLU'` 43 | activation_params (dict): Parameters to provide to the activation function. 44 | growth (int): Growth factor for the filters. 
Default: 1 45 | """ 46 | def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, 47 | n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, max_filters: int = 1024, 48 | filters_scale: int = 1, kernel_size: tp.Tuple[int, int] = (3, 9), dilations: tp.List = [1, 2, 4], 49 | stride: tp.Tuple[int, int] = (1, 2), normalized: bool = True, norm: str = 'weight_norm', 50 | activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}): 51 | super().__init__() 52 | assert len(kernel_size) == 2 53 | assert len(stride) == 2 54 | self.filters = filters 55 | self.in_channels = in_channels 56 | self.out_channels = out_channels 57 | self.n_fft = n_fft 58 | self.hop_length = hop_length 59 | self.win_length = win_length 60 | self.normalized = normalized 61 | self.activation = getattr(torch.nn, activation)(**activation_params) 62 | self.spec_transform = torchaudio.transforms.Spectrogram( 63 | n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window_fn=torch.hann_window, 64 | normalized=self.normalized, center=False, pad_mode=None, power=None) 65 | spec_channels = 2 * self.in_channels 66 | self.convs = nn.ModuleList() 67 | self.convs.append( 68 | NormConv2d(spec_channels, self.filters, kernel_size=kernel_size, padding=get_2d_padding(kernel_size)) 69 | ) 70 | in_chs = min(filters_scale * self.filters, max_filters) 71 | for i, dilation in enumerate(dilations): 72 | out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters) 73 | self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride, 74 | dilation=(dilation, 1), padding=get_2d_padding(kernel_size, (dilation, 1)), 75 | norm=norm)) 76 | in_chs = out_chs 77 | out_chs = min((filters_scale ** (len(dilations) + 1)) * self.filters, max_filters) 78 | self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_size[0], kernel_size[0]), 79 | padding=get_2d_padding((kernel_size[0], kernel_size[0])), 80 | norm=norm)) 81 | self.conv_post = NormConv2d(out_chs, self.out_channels, 82 | kernel_size=(kernel_size[0], kernel_size[0]), 83 | padding=get_2d_padding((kernel_size[0], kernel_size[0])), 84 | norm=norm) 85 | 86 | def forward(self, x: torch.Tensor): 87 | fmap = [] 88 | # print('x ', x.shape) 89 | z = self.spec_transform(x) # [B, 2, Freq, Frames, 2] 90 | # print('z ', z.shape) 91 | z = torch.cat([z.real, z.imag], dim=1) 92 | # print('cat_z ', z.shape) 93 | z = rearrange(z, 'b c w t -> b c t w') 94 | for i, layer in enumerate(self.convs): 95 | z = layer(z) 96 | z = self.activation(z) 97 | # print('z i', i, z.shape) 98 | fmap.append(z) 99 | z = self.conv_post(z) 100 | # print('logit ', z.shape) 101 | return z, fmap 102 | 103 | 104 | class MultiScaleSTFTDiscriminator(nn.Module): 105 | """Multi-Scale STFT (MS-STFT) discriminator. 106 | Args: 107 | filters (int): Number of filters in convolutions 108 | in_channels (int): Number of input channels. Default: 1 109 | out_channels (int): Number of output channels. 
Default: 1 110 | n_ffts (Sequence[int]): Size of FFT for each scale 111 | hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale 112 | win_lengths (Sequence[int]): Window size for each scale 113 | **kwargs: additional args for STFTDiscriminator 114 | """ 115 | def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, 116 | n_ffts: tp.List[int] = [1024, 2048, 512, 256, 128], hop_lengths: tp.List[int] = [256, 512, 128, 64, 32], 117 | win_lengths: tp.List[int] = [1024, 2048, 512, 256, 128], **kwargs): 118 | super().__init__() 119 | assert len(n_ffts) == len(hop_lengths) == len(win_lengths) 120 | self.discriminators = nn.ModuleList([ 121 | DiscriminatorSTFT(filters, in_channels=in_channels, out_channels=out_channels, 122 | n_fft=n_ffts[i], win_length=win_lengths[i], hop_length=hop_lengths[i], **kwargs) 123 | for i in range(len(n_ffts)) 124 | ]) 125 | self.num_discriminators = len(self.discriminators) 126 | 127 | def forward(self, x: torch.Tensor) -> DiscriminatorOutput: 128 | logits = [] 129 | fmaps = [] 130 | for disc in self.discriminators: 131 | logit, fmap = disc(x) 132 | logits.append(logit) 133 | fmaps.append(fmap) 134 | return logits, fmaps 135 | 136 | 137 | def test(): 138 | disc = MultiScaleSTFTDiscriminator(filters=32) 139 | y = torch.randn(1, 1, 24000) 140 | y_hat = torch.randn(1, 1, 24000) 141 | 142 | y_disc_r, fmap_r = disc(y) 143 | #print('y_disc_r ', len(y_disc_r)) 144 | # print('fmap_r ', len(fmap_r)) 145 | y_disc_gen, fmap_gen = disc(y_hat) 146 | # print('y_disc_gen ', y_disc_gen.shape) 147 | # print('fmap_gen ', len(fmap_gen)) 148 | assert len(y_disc_r) == len(y_disc_gen) == len(fmap_r) == len(fmap_gen) == disc.num_discriminators 149 | 150 | assert all([len(fm) == 5 for fm in fmap_r + fmap_gen]) 151 | assert all([list(f.shape)[:2] == [1, 32] for fm in fmap_r + fmap_gen for f in fm]) 152 | assert all([len(logits.shape) == 4 for logits in y_disc_r + y_disc_gen]) 153 | 154 | 155 | if __name__ == '__main__': 156 | test() 157 | -------------------------------------------------------------------------------- /descriptaudiocodec/dac/model/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from audiotools import AudioSignal 5 | from audiotools import ml 6 | from audiotools import STFTParams 7 | from einops import rearrange 8 | from torch.nn.utils import weight_norm 9 | 10 | 11 | def WNConv1d(*args, **kwargs): 12 | act = kwargs.pop("act", True) 13 | conv = weight_norm(nn.Conv1d(*args, **kwargs)) 14 | if not act: 15 | return conv 16 | return nn.Sequential(conv, nn.LeakyReLU(0.1)) 17 | 18 | 19 | def WNConv2d(*args, **kwargs): 20 | act = kwargs.pop("act", True) 21 | conv = weight_norm(nn.Conv2d(*args, **kwargs)) 22 | if not act: 23 | return conv 24 | return nn.Sequential(conv, nn.LeakyReLU(0.1)) 25 | 26 | 27 | class MPD(nn.Module): 28 | def __init__(self, period): 29 | super().__init__() 30 | self.period = period 31 | self.convs = nn.ModuleList( 32 | [ 33 | WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)), 34 | WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)), 35 | WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)), 36 | WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)), 37 | WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)), 38 | ] 39 | ) 40 | self.conv_post = WNConv2d( 41 | 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False 42 | ) 43 | 44 | def pad_to_period(self, x): 45 | t = x.shape[-1] 46 | x = F.pad(x, (0, 
self.period - t % self.period), mode="reflect") 47 | return x 48 | 49 | def forward(self, x): 50 | fmap = [] 51 | 52 | x = self.pad_to_period(x) 53 | x = rearrange(x, "b c (l p) -> b c l p", p=self.period) 54 | 55 | for layer in self.convs: 56 | x = layer(x) 57 | fmap.append(x) 58 | 59 | x = self.conv_post(x) 60 | fmap.append(x) 61 | 62 | return fmap 63 | 64 | 65 | class MSD(nn.Module): 66 | def __init__(self, rate: int = 1, sample_rate: int = 44100): 67 | super().__init__() 68 | self.convs = nn.ModuleList( 69 | [ 70 | WNConv1d(1, 16, 15, 1, padding=7), 71 | WNConv1d(16, 64, 41, 4, groups=4, padding=20), 72 | WNConv1d(64, 256, 41, 4, groups=16, padding=20), 73 | WNConv1d(256, 1024, 41, 4, groups=64, padding=20), 74 | WNConv1d(1024, 1024, 41, 4, groups=256, padding=20), 75 | WNConv1d(1024, 1024, 5, 1, padding=2), 76 | ] 77 | ) 78 | self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False) 79 | self.sample_rate = sample_rate 80 | self.rate = rate 81 | 82 | def forward(self, x): 83 | x = AudioSignal(x, self.sample_rate) 84 | x.resample(self.sample_rate // self.rate) 85 | x = x.audio_data 86 | 87 | fmap = [] 88 | 89 | for l in self.convs: 90 | x = l(x) 91 | fmap.append(x) 92 | x = self.conv_post(x) 93 | fmap.append(x) 94 | 95 | return fmap 96 | 97 | 98 | BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)] 99 | 100 | 101 | class MRD(nn.Module): 102 | def __init__( 103 | self, 104 | window_length: int, 105 | hop_factor: float = 0.25, 106 | sample_rate: int = 44100, 107 | bands: list = BANDS, 108 | ): 109 | """Complex multi-band spectrogram discriminator. 110 | Parameters 111 | ---------- 112 | window_length : int 113 | Window length of STFT. 114 | hop_factor : float, optional 115 | Hop factor of the STFT, defaults to ``0.25 * window_length``. 116 | sample_rate : int, optional 117 | Sampling rate of audio in Hz, by default 44100 118 | bands : list, optional 119 | Bands to run discriminator over. 
120 | """ 121 | super().__init__() 122 | 123 | self.window_length = window_length 124 | self.hop_factor = hop_factor 125 | self.sample_rate = sample_rate 126 | self.stft_params = STFTParams( 127 | window_length=window_length, 128 | hop_length=int(window_length * hop_factor), 129 | match_stride=True, 130 | ) 131 | 132 | n_fft = window_length // 2 + 1 133 | bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] 134 | self.bands = bands 135 | 136 | ch = 32 137 | convs = lambda: nn.ModuleList( 138 | [ 139 | WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)), 140 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), 141 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), 142 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), 143 | WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)), 144 | ] 145 | ) 146 | self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) 147 | self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False) 148 | 149 | def spectrogram(self, x): 150 | x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params) 151 | x = torch.view_as_real(x.stft()) 152 | x = rearrange(x, "b 1 f t c -> (b 1) c t f") 153 | # Split into bands 154 | x_bands = [x[..., b[0] : b[1]] for b in self.bands] 155 | return x_bands 156 | 157 | def forward(self, x): 158 | x_bands = self.spectrogram(x) 159 | fmap = [] 160 | 161 | x = [] 162 | for band, stack in zip(x_bands, self.band_convs): 163 | for layer in stack: 164 | band = layer(band) 165 | fmap.append(band) 166 | x.append(band) 167 | 168 | x = torch.cat(x, dim=-1) 169 | x = self.conv_post(x) 170 | fmap.append(x) 171 | 172 | return fmap 173 | 174 | 175 | class Discriminator(ml.BaseModel): 176 | def __init__( 177 | self, 178 | rates: list = [], 179 | periods: list = [2, 3, 5, 7, 11], 180 | fft_sizes: list = [2048, 1024, 512], 181 | sample_rate: int = 44100, 182 | bands: list = BANDS, 183 | ): 184 | """Discriminator that combines multiple discriminators. 185 | 186 | Parameters 187 | ---------- 188 | rates : list, optional 189 | sampling rates (in Hz) to run MSD at, by default [] 190 | If empty, MSD is not used. 
191 | periods : list, optional 192 | periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11] 193 | fft_sizes : list, optional 194 | Window sizes of the FFT to run MRD at, by default [2048, 1024, 512] 195 | sample_rate : int, optional 196 | Sampling rate of audio in Hz, by default 44100 197 | bands : list, optional 198 | Bands to run MRD at, by default `BANDS` 199 | """ 200 | super().__init__() 201 | discs = [] 202 | discs += [MPD(p) for p in periods] 203 | discs += [MSD(r, sample_rate=sample_rate) for r in rates] 204 | discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes] 205 | self.discriminators = nn.ModuleList(discs) 206 | 207 | def preprocess(self, y): 208 | # Remove DC offset 209 | y = y - y.mean(dim=-1, keepdims=True) 210 | # Peak normalize the volume of input audio 211 | y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9) 212 | return y 213 | 214 | def forward(self, x): 215 | x = self.preprocess(x) 216 | fmaps = [d(x) for d in self.discriminators] 217 | return fmaps 218 | 219 | 220 | if __name__ == "__main__": 221 | disc = Discriminator() 222 | x = torch.zeros(1, 1, 44100) 223 | results = disc(x) 224 | for i, result in enumerate(results): 225 | print(f"disc{i}") 226 | for i, r in enumerate(result): 227 | print(r.shape, r.mean(), r.min(), r.max()) 228 | print() 229 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | [![arXiv](https://img.shields.io/badge/arXiv-2408.17175-brightgreen.svg?style=flat-square)](https://arxiv.org/pdf/2408.17175) 3 | # X-Codec 4 | 5 | Unified Semantic and Acoustic Codec for Audio Language Model. 6 | 7 | # X-Codec-2.0 released! 8 | 9 | 10 | # Paper 11 | 12 | 13 | **Title**: Codec Does Matter: Exploring the Semantic Shortcoming of Codec for Audio Language Model (AAAI 2025) 14 | 15 | **Authors**: Zhen Ye, Peiwen Sun, Jiahe Lei, Hongzhan Lin, Xu Tan, Zheqi Dai, Qiuqiang Kong, Jianyi Chen, Jiahao Pan, Qifeng Liu, Yike Guo*, Wei Xue* 16 | 17 | Overview 18 | 19 | # Experiments on VALL-E 20 | Exp 21 | 22 | 23 | 24 | 27 | 28 | # Highlight 29 | 30 | You can easily apply our approach to enhance any existing acoustic codec: 31 | 32 | For example 33 | 34 | ```python 35 | class Codec(): 36 | def __init__(self): 37 | # Acoustic codec components 38 | self.encoder = Encoder(...) # Acoustic encoder 39 | self.decoder = Decoder(...) # Acoustic decoder 40 | self.quantizer = RVQ(...) # Residual Vector Quantizer (RVQ) 41 | 42 | # Adding the semantic module 43 | self.semantic_model = AutoModel.from_pretrained(...) # e.g., Hubert, WavLM 44 | 45 | # Adding Projector 46 | self.fc_prior = nn.Linear(...) 47 | self.fc_post1 = nn.Linear(...) 48 | self.fc_post2 = nn.Linear(...) 
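        # For reference, models/soundstream_semantic.py in this repo instantiates these
        # projections with concrete shapes (D = acoustic latent dim, 768 = HuBERT/WavLM hidden size):
        #   self.fc_prior = nn.Linear(D + 768, D + 768)  # mixes the concatenated acoustic+semantic features
        #   self.fc_post1 = nn.Linear(D + 768, 768)      # recovers the semantic branch after quantization
        #   self.fc_post2 = nn.Linear(D + 768, D)        # recovers the acoustic branch fed to the decoder
        # and the RVQ there quantizes the joint (D + 768)-dimensional feature:
        #   ResidualVectorQuantizer(dimension=D + 768, n_q=n_q, bins=bins)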
49 | 50 | def forward(self, x, bw): 51 | # Encode the input acoustically and semantically 52 | e_acoustic = self.encoder(x) 53 | e_semantic = self.semantic_model(x) 54 | 55 | # Combine acoustic and semantic features 56 | combined_features = torch.cat([e_acoustic, e_semantic]) 57 | 58 | # Apply prior transformation 59 | transformed_features = self.fc_prior(combined_features) 60 | 61 | # Quantize the unified semantic and acoustic features 62 | quantized, codes, bandwidth, commit_loss = self.quantizer(transformed_features, bw) 63 | 64 | # Post-process the quantized features 65 | quantized_semantic = self.fc_post1(quantized) 66 | quantized_acoustic = self.fc_post2(quantized) 67 | 68 | # Decode the quantized acoustic features 69 | output = self.decoder(quantized_acoustic) 70 | 71 | 72 | 73 | def semantic_loss(self,semantic,quantized_semantic): 74 | return F.mse_loss(semantic,quantized_semantic) 75 | ``` 76 | For more details, please refer to our code. 77 | 78 | # Available models 79 | 80 | X-Codec is part of Hugging Face's Transformers library (see [model documentation](https://huggingface.co/docs/transformers/en/model_doc/xcodec)). 81 | 82 | Below are the Transformers-compatible checkpoints on the Hugging Face Hub 🤗 83 | 84 | | Model checkpoint | Semantic Model | Domain | Training Data | 85 | |--------------------------------------------|-----------------------------------------------------------------------|---------------|-------------------------------| 86 | | [xcodec-hubert-librispeech](https://huggingface.co/hf-audio/xcodec-hubert-librispeech) | [facebook/hubert-base-ls960](https://huggingface.co/facebook/hubert-base-ls960) | Speech | Librispeech | 87 | | [xcodec-wavlm-mls](https://huggingface.co/hf-audio/xcodec-wavlm-mls) (not mentioned in paper) | [microsoft/wavlm-base-plus](https://huggingface.co/microsoft/wavlm-base-plus)| Speech | MLS English | 88 | | [xcodec-wavlm-more-data](https://huggingface.co/hf-audio/xcodec-wavlm-more-data) (not mentioned in paper) | [microsoft/wavlm-base-plus](https://huggingface.co/microsoft/wavlm-base-plus)| Speech | MLS English + Internal data | 89 | | [xcodec-hubert-general](https://huggingface.co/hf-audio/xcodec-hubert-general) | [ZhenYe234/hubert_base_general_audio](https://huggingface.co/ZhenYe234/hubert_base_general_audio) | General audio | 200k hours internal data | 90 | | [xcodec-hubert-general-balanced](https://huggingface.co/hf-audio/xcodec-hubert-general-balanced) (not mentioned in paper) | [ZhenYe234/hubert_base_general_audio](https://huggingface.co/ZhenYe234/hubert_base_general_audio) | General audio | More balanced data | 91 | 92 | 93 | Below are the original checkpoints. 
94 | 95 | | Model name | Hugging Face | Config | Semantic Model | Domain | Training Data | 96 | |---------------------------------------------|--------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------|---------------|-------------------------------| 97 | | xcodec_hubert_librispeech | [🤗](https://huggingface.co/ZhenYe234/xcodec/blob/main/xcodec_speech_hubert_librispeech.pth) | [🤗](https://huggingface.co/ZhenYe234/xcodec/blob/main/config_hubert.yaml) | [🤗 Hubert-base](https://huggingface.co/facebook/hubert-base-ls960) | Speech | Librispeech | 98 | | xcodec_wavlm_mls (not mentioned in paper) | [🤗](https://huggingface.co/ZhenYe234/xcodec/blob/main/xcodec_speech_wavlm_mls.pth) | [🤗](https://huggingface.co/ZhenYe234/xcodec/blob/main/config_wavlm.yaml) | [🤗 Wavlm-base-plus](https://huggingface.co/microsoft/wavlm-base-plus) | Speech | MLS English | 99 | | xcodec_wavlm_more_data (not mentioned in paper) | [🤗](https://huggingface.co/ZhenYe234/xcodec/blob/main/xcodec_speech_wavlm_more_data.pth) | [🤗](https://huggingface.co/ZhenYe234/xcodec/blob/main/config_wavlm.yaml) | [🤗 Wavlm-base-plus](https://huggingface.co/microsoft/wavlm-base-plus) | Speech | MLS English + Internal data | 100 | | xcodec_hubert_general_audio | [🤗](https://huggingface.co/ZhenYe234/xcodec/blob/main/xcodec_hubert_general_audio.pth) | [🤗](https://huggingface.co/ZhenYe234/xcodec/blob/main/config_hubert_general.yaml) | [🤗Hubert-base-general-audio](https://huggingface.co/ZhenYe234/hubert_base_general_audio) | General audio | 200k hours internal data | 101 | | xcodec_hubert_general_audio_more_data (not mentioned in paper) | [🤗](https://huggingface.co/ZhenYe234/xcodec/blob/main/xcodec_hubert_general_audio_v2.pth) | [🤗](https://huggingface.co/ZhenYe234/xcodec/blob/main/config_hubert_general.yaml) | [🤗Hubert-base-general-audio](https://huggingface.co/ZhenYe234/hubert_base_general_audio) | General audio | More balanced data | 102 | 103 | 104 | 105 | 106 | 107 | # Inference 108 | 109 | To run inference, first download the model and config from hugging face. 110 | 111 | ```bash 112 | python inference.py 113 | ``` 114 | 115 | # Training 116 | Prepare the training_file and validation_file in config. The file should list the paths to your audio files: 117 | ```bash 118 | /path/to/your/xxx.wav 119 | /path/to/your/yyy.wav 120 | ... 121 | ``` 122 | Then: 123 | 124 | ```bash 125 | torchrun --nnodes=1 --nproc-per-node=8 main_launch_vqdp.py 126 | ``` 127 | 128 | ## Acknowledgement 129 | I would like to extend a special thanks to authors of Uniaudio and DAC, since our code base is mainly borrowed from [Uniaudio](https://github.com/yangdongchao/UniAudio/tree/main/codec) and [DAC](https://github.com/descriptinc/descript-audio-codec). 
130 | 131 | ## Citation 132 | If you find this repo helpful, please consider citing in the following format: 133 | 134 | ```bibtex 135 | @article{ye2024codecdoesmatterexploring, 136 | title={Codec Does Matter: Exploring the Semantic Shortcoming of Codec for Audio Language Model}, 137 | author={Zhen Ye and Peiwen Sun and Jiahe Lei and Hongzhan Lin and Xu Tan and Zheqi Dai and Qiuqiang Kong and Jianyi Chen and Jiahao Pan and Qifeng Liu and Yike Guo and Wei Xue}, 138 | journal={arXiv preprint arXiv:2408.17175}, 139 | year={2024}, 140 | } 141 | ``` 142 | --------------------------------------------------------------------------------