├── .gitignore ├── README.md ├── encodecmae ├── Dockerfile ├── __init__.py ├── configs │ ├── base │ │ ├── create_st_dataset.gin │ │ ├── self-train.gin │ │ └── unsupervised_pretrain.gin │ ├── datasets │ │ ├── audioset-unbalanced-24k.gin │ │ ├── fma-large-24k.gin │ │ └── librilight-6k-24k.gin │ ├── features │ │ ├── mel.gin │ │ ├── spec.gin │ │ ├── st.gin │ │ └── wav_only.gin │ ├── imports │ └── models │ │ └── encodecmae.gin ├── docker-compose.yml ├── heareval_model │ ├── __init__.py │ └── encodecmae.py ├── hub.py ├── models │ ├── __init__.py │ ├── encoders │ │ └── __init__.py │ ├── heads │ │ └── __init__.py │ ├── losses │ │ └── __init__.py │ ├── mae.py │ ├── masks │ │ └── __init__.py │ ├── targets │ │ └── __init__.py │ └── transformers │ │ └── __init__.py ├── scripts │ └── run_pretraining.sh └── tasks │ ├── __init__.py │ ├── data │ └── __init__.py │ ├── features │ └── __init__.py │ └── utils │ └── __init__.py ├── requirements.txt ├── setup.cfg ├── setup.py └── start_docker.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

EnCodecMAE: Leveraging neural codecs for universal audio representation learning

2 | 3 |

4 | 5 | [read the paper](https://arxiv.org/abs/2309.07391) 6 | 7 | [run in colab](https://colab.research.google.com/drive/123Zn6h0DRVcjsLFp8Xl4j0PZlZ-7VsK2?usp=sharing) 8 | 9 | 10 | Cite - Bibtex 11 |

12 | 13 | This is EnCodecMAE, an audio feature extractor pretrained with masked language modelling to predict discrete targets generated by EnCodec, a neural audio codec. 14 | For more details about the architecture and pretraining procedure, read the [paper](https://arxiv.org/abs/2309.07391). 15 | 16 | ## Updates: 17 | - 2024/5/23 Updated paper on arXiv. New models with better performance across all downstream tasks are available for feature extraction. Code for the older version is [here](https://github.com/habla-liaa/encodecmae/tree/v.1.0.0). 18 | - 2024/2/29 [New code](https://github.com/mrpep/encodecmae-to-wav) to go from EnCodecMAE to the waveform domain, with pretrained generative audio models from [this paper](https://arxiv.org/abs/2402.09318). 19 | - 2024/2/14 [Leveraging Pre-Trained Autoencoders for Interpretable Prototype Learning of Music Audio](https://arxiv.org/abs/2402.09318) was accepted to the ICASSP 2024 XAI Workshop. 20 | - 2023/10/23 [Prompting for audio generation](https://mrpep.github.io/myblog/posts/audio-lm/). 21 | 22 | ## Usage 23 | 24 | ### Feature extraction using pretrained models 25 | 26 | #### Try our example [Colab notebook](https://colab.research.google.com/drive/123Zn6h0DRVcjsLFp8Xl4j0PZlZ-7VsK2?usp=sharing) or: 27 | 28 | #### 1) Clone the [EnCodecMAE library](https://github.com/habla-liaa/encodecmae): 29 | ``` 30 | git clone https://github.com/habla-liaa/encodecmae.git 31 | ``` 32 | 33 | #### 2) Install it: 34 | 35 | ``` 36 | cd encodecmae 37 | pip install -e . 38 | ``` 39 | 40 | #### 3) Extract embeddings in Python: 41 | 42 | ``` python 43 | from encodecmae import load_model 44 | 45 | model = load_model('mel256-ec-base_st', device='cuda:0') 46 | features = model.extract_features_from_file('gsc/bed/00176480_nohash_0.wav') 47 | ``` 48 | 49 | ### Pretrain your models 50 | 51 | #### 1) Install docker and docker-compose on your system. You'll also need to install the NVIDIA Container Toolkit to access GPUs from a docker container. 52 | #### 2) Execute the start_docker.sh script 53 | 54 | First, docker-compose.yml has to be modified: in the volumes section, change the paths to the ones on your system. You'll need a folder called datasets with the following subfolders: 55 | - audioset_24k/unbalanced_train 56 | - fma_large_24k 57 | - librilight_med_24k 58 | 59 | All the audio files need to be converted to a 24 kHz sampling rate. 60 | 61 | You might also need to modify device_ids if you have a different number of GPUs. 62 | 63 | Then, run: 64 | ``` 65 | chmod +x start_docker.sh 66 | ./start_docker.sh 67 | ``` 68 | This will build the encodecmae image, start a container using docker compose, and attach to it. 69 | 70 | #### 3) Install the encodecmae package inside the container 71 | ``` 72 | cd workspace/encodecmae 73 | pip install -e . 74 | ``` 75 | 76 | #### 4) Run the training script 77 | ``` 78 | chmod +x scripts/run_pretraining.sh 79 | scripts/run_pretraining.sh 80 | ``` 81 | 82 | The training script uses ginpipe, my own library for executing pipelines configured with gin. By modifying the config files (with the .gin extension), you can control aspects of the training and the model configuration. I plan to explain my approach to ML pipelines, and how to use gin and ginpipe, in a future blog article. Stay tuned!
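Going back to feature extraction: besides `extract_features_from_file`, the pretrained models also expose `extract_features_from_array` (defined in `encodecmae/models/mae.py`), which is convenient when the audio is already decoded in memory. A minimal sketch, assuming the same checkpoint as in the example above, a CUDA device, and a placeholder audio path:

``` python
import librosa
from encodecmae import load_model

# Load a pretrained checkpoint (encodecmae.hub.get_available_models() lists the published names).
model = load_model('mel256-ec-base_st', device='cuda:0')

# Decode the audio at the model's sampling rate (24 kHz) before extraction.
wav, _ = librosa.load('gsc/bed/00176480_nohash_0.wav', sr=model.wav_encoder.fs)

# layer=-1 keeps the last encoder layer; layer='all' returns activations from every layer.
features = model.extract_features_from_array(wav, layer=-1, return_type='numpy')
print(features.shape)  # (1, num_frames, model_dim), at roughly 75 frames per second
```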
83 | -------------------------------------------------------------------------------- /encodecmae/Dockerfile: -------------------------------------------------------------------------------- 1 | from nvidia/cuda:11.7.1-cudnn8-devel-ubuntu20.04 2 | 3 | ENV TZ=America/Argentina/Buenos_Aires 4 | RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone 5 | 6 | COPY requirements.txt . 7 | RUN apt-get --fix-missing update && apt-get update && apt-get install -y \ 8 | python3.9 \ 9 | python3-pip \ 10 | git \ 11 | sox 12 | RUN pip3 install -r requirements.txt 13 | 14 | -------------------------------------------------------------------------------- /encodecmae/__init__.py: -------------------------------------------------------------------------------- 1 | from .hub import load_model -------------------------------------------------------------------------------- /encodecmae/configs/base/create_st_dataset.gin: -------------------------------------------------------------------------------- 1 | DEVICE='cuda:0' 2 | TRAIN_BATCH_SIZE=32 3 | TRAIN_DATALOADER_NUM_WORKERS=8 4 | MAX_AUDIO_DURATION=4 5 | FILTER_AUDIO_LENGTH=10000 6 | TEACHER_LAYER=-1 7 | POSTNORM_LAST_ACTIVATION=True 8 | 9 | $keys_not_saved+=['datasets','dataloaders'] 10 | 11 | execute_pipeline: 12 | tasks = [@encodecmae.tasks.utils.set_seed, 13 | @encodecmae.tasks.load_model, 14 | @encodecmae.tasks.data.load_dataset, 15 | @encodecmae.tasks.data.get_dataloaders, 16 | @encodecmae.tasks.create_st_dataset] 17 | 18 | encodecmae.tasks.load_model: 19 | model = %TEACHER_MODEL 20 | 21 | encodecmae.tasks.data.get_dataloaders.dataset_cls={'train': @train/encodecmae.tasks.data.DictDataset} 22 | encodecmae.tasks.data.get_dataloaders.dataloader_cls={'train': @train/torch.utils.data.DataLoader} 23 | 24 | train/torch.utils.data.DataLoader: 25 | shuffle=True 26 | batch_size=%TRAIN_BATCH_SIZE 27 | num_workers=%TRAIN_DATALOADER_NUM_WORKERS 28 | collate_fn=@tasks.data.dynamic_pad_batch 29 | 30 | encodecmae.tasks.data.DictDataset.index_mapper=@encodecmae.tasks.data.compensate_lengths 31 | encodecmae.tasks.data.compensate_lengths.chunk_length=%MAX_AUDIO_DURATION #This will sample long audios multiple times during one epoch (duration//compensate_framing times) 32 | 33 | encodecmae.tasks.data.load_dataset.postprocessors=[@encodecmae.tasks.data.remove_long_audios] 34 | encodecmae.tasks.data.remove_long_audios.limit=%FILTER_AUDIO_LENGTH 35 | 36 | encodecmae.tasks.create_st_dataset: 37 | device=%DEVICE 38 | layer=%TEACHER_LAYER 39 | postnorm_last_activation=%POSTNORM_LAST_ACTIVATION -------------------------------------------------------------------------------- /encodecmae/configs/base/self-train.gin: -------------------------------------------------------------------------------- 1 | SEED=42 2 | TRAIN_BATCH_SIZE=128 3 | VAL_BATCH_SIZE=8 4 | TRAIN_DATALOADER_NUM_WORKERS=8 5 | VAL_DATALOADER_NUM_WORKERS=4 6 | MAX_AUDIO_DURATION=4 7 | MAX_LR=0.0001 8 | GRAD_ACC=1 9 | TOTAL_PRETRAIN_STEPS=150000 10 | CHECKPOINT_INTERVAL=50000 11 | SELFTRAIN_CHECKPOINT='last' 12 | VAL_SET_SIZE=200 13 | FILTER_AUDIO_LENGTH=10000 #Some files might be too long according to mediainfo so they are discarded. 
14 | DEVICE=[0] 15 | PRECISION=16 16 | NUM_TOTAL_TARGETS=1 17 | 18 | $keys_not_saved=['datasets','dataloaders'] 19 | 20 | execute_pipeline: 21 | tasks = [@encodecmae.tasks.utils.set_seed, 22 | @encodecmae.tasks.data.load_dataset, 23 | @encodecmae.tasks.data.get_dataloaders, 24 | @encodecmae.tasks.fit_model] 25 | execution_order = 'sequential' 26 | 27 | train/torch.utils.data.DataLoader: 28 | shuffle=True 29 | batch_size=%TRAIN_BATCH_SIZE 30 | num_workers=%TRAIN_DATALOADER_NUM_WORKERS 31 | collate_fn=@encodecmae.tasks.data.dynamic_pad_batch 32 | 33 | val/torch.utils.data.DataLoader: 34 | shuffle=False 35 | batch_size=%VAL_BATCH_SIZE 36 | num_workers=%VAL_DATALOADER_NUM_WORKERS 37 | collate_fn=@encodecmae.tasks.data.dynamic_pad_batch 38 | 39 | encodecmae.tasks.data.get_dataloaders.split_function=@encodecmae.tasks.data.dataset_random_split 40 | encodecmae.tasks.data.get_dataloaders.dataset_cls={'train': @train/encodecmae.tasks.data.DictDataset, 'validation': @val/encodecmae.tasks.data.DictDataset} 41 | encodecmae.tasks.data.get_dataloaders.dataloader_cls={'train': @train/torch.utils.data.DataLoader, 'validation': @val/torch.utils.data.DataLoader} 42 | 43 | encodecmae.tasks.data.load_dataset.reader_fns+=[@encodecmae.tasks.data.read_st_dataset] 44 | encodecmae.tasks.data.read_st_dataset.dataset_path=%ST_DATASET_DIR 45 | 46 | encodecmae.tasks.fit_model: 47 | trainer_cls=@pl.Trainer 48 | from_checkpoint=%ST_CHECKPOINT 49 | checkpoint_folder='pretrain_checkpoints' 50 | 51 | pl.Trainer: 52 | logger=@pl.loggers.CSVLogger() 53 | devices=%DEVICE 54 | callbacks=[@pl.callbacks.ModelCheckpoint(), @pl.callbacks.LearningRateMonitor()] 55 | max_steps=%TOTAL_PRETRAIN_STEPS 56 | accelerator='gpu' 57 | accumulate_grad_batches=%GRAD_ACC 58 | num_sanity_val_steps=1 59 | val_check_interval=%CHECKPOINT_INTERVAL 60 | precision=%PRECISION 61 | check_val_every_n_epoch=None 62 | 63 | pl.callbacks.ModelCheckpoint: 64 | dirpath=%OUTPUT_DIR 65 | every_n_train_steps=%CHECKPOINT_INTERVAL 66 | save_top_k=-1 #Keep all the checkpoints 67 | 68 | pl.loggers.CSVLogger: 69 | save_dir=%OUTPUT_DIR 70 | name='pretrain_logs' 71 | 72 | encodecmae.tasks.data.dataset_random_split: 73 | proportions={'train':-1,'validation':%VAL_SET_SIZE} -------------------------------------------------------------------------------- /encodecmae/configs/base/unsupervised_pretrain.gin: -------------------------------------------------------------------------------- 1 | SEED=42 2 | TRAIN_BATCH_SIZE=128 3 | VAL_BATCH_SIZE=8 4 | TRAIN_DATALOADER_NUM_WORKERS=8 5 | VAL_DATALOADER_NUM_WORKERS=4 6 | MAX_AUDIO_DURATION=4 7 | MAX_LR=0.0001 8 | GRAD_ACC=1 9 | TOTAL_PRETRAIN_STEPS=500000 10 | CHECKPOINT_INTERVAL=50000 11 | INITIAL_CHECKPOINT='last' 12 | VAL_SET_SIZE=200 13 | FILTER_AUDIO_LENGTH=10000 #Some files might be too long according to mediainfo so they are discarded. 
14 | DEVICE=[0] 15 | PRECISION=16 16 | 17 | $keys_not_saved=['datasets','dataloaders'] 18 | 19 | execute_pipeline: 20 | tasks = [@tasks.utils.set_seed, 21 | @tasks.data.load_dataset, 22 | @tasks.data.get_dataloaders, 23 | @tasks.fit_model] 24 | execution_order = 'sequential' 25 | 26 | tasks.utils.set_seed.seed=%SEED 27 | train/torch.utils.data.DataLoader: 28 | shuffle=True 29 | batch_size=%TRAIN_BATCH_SIZE 30 | num_workers=%TRAIN_DATALOADER_NUM_WORKERS 31 | collate_fn=@tasks.data.dynamic_pad_batch 32 | 33 | val/torch.utils.data.DataLoader: 34 | shuffle=False 35 | batch_size=%VAL_BATCH_SIZE 36 | num_workers=%VAL_DATALOADER_NUM_WORKERS 37 | collate_fn=@tasks.data.dynamic_pad_batch 38 | 39 | tasks.fit_model: 40 | trainer_cls=@pl.Trainer 41 | from_checkpoint=%INITIAL_CHECKPOINT 42 | checkpoint_folder='pretrain_checkpoints' 43 | 44 | pl.Trainer: 45 | logger=@pl.loggers.CSVLogger() 46 | devices=%DEVICE 47 | callbacks=[@pl.callbacks.ModelCheckpoint(), @pl.callbacks.LearningRateMonitor()] 48 | max_steps=%TOTAL_PRETRAIN_STEPS 49 | accelerator='gpu' 50 | accumulate_grad_batches=%GRAD_ACC 51 | num_sanity_val_steps=1 52 | val_check_interval=%CHECKPOINT_INTERVAL 53 | precision=%PRECISION 54 | check_val_every_n_epoch=None 55 | 56 | pl.callbacks.ModelCheckpoint: 57 | dirpath=%OUTPUT_DIR 58 | every_n_train_steps=%CHECKPOINT_INTERVAL 59 | save_top_k=-1 #Keep all the checkpoints 60 | 61 | pl.loggers.CSVLogger: 62 | save_dir=%OUTPUT_DIR 63 | name='pretrain_logs' 64 | 65 | tasks.data.get_dataloaders.split_function=@tasks.data.dataset_random_split 66 | tasks.data.get_dataloaders.dataset_cls={'train': @train/tasks.data.DictDataset, 'validation': @val/tasks.data.DictDataset} 67 | tasks.data.get_dataloaders.dataloader_cls={'train': @train/torch.utils.data.DataLoader, 'validation': @val/torch.utils.data.DataLoader} 68 | 69 | tasks.data.dataset_random_split: 70 | proportions={'train':-1,'validation':%VAL_SET_SIZE} 71 | 72 | tasks.data.DictDataset.index_mapper=@tasks.data.compensate_lengths 73 | tasks.data.compensate_lengths.chunk_length=%MAX_AUDIO_DURATION #This will sample long audios multiple times during one epoch (duration//compensate_framing times) 74 | 75 | tasks.data.load_dataset.postprocessors=[@tasks.data.remove_long_audios] 76 | tasks.data.remove_long_audios.limit=%FILTER_AUDIO_LENGTH 77 | -------------------------------------------------------------------------------- /encodecmae/configs/datasets/audioset-unbalanced-24k.gin: -------------------------------------------------------------------------------- 1 | load_dataset.reader_fns+=[@audioset_unbalanced_24k/read_audiodir] 2 | audioset_unbalanced_24k/read_audiodir.dataset_path='/workspace/datasets/audioset_24k/unbalanced_train' 3 | audioset_unbalanced_24k/read_audiodir.dataset='audioset' -------------------------------------------------------------------------------- /encodecmae/configs/datasets/fma-large-24k.gin: -------------------------------------------------------------------------------- 1 | load_dataset.reader_fns+=[@fma_large_24k/read_audiodir] 2 | fma_large_24k/read_audiodir.dataset_path='/workspace/datasets/fma_large_24k' 3 | fma_large_24k/read_audiodir.dataset='fma' 4 | -------------------------------------------------------------------------------- /encodecmae/configs/datasets/librilight-6k-24k.gin: -------------------------------------------------------------------------------- 1 | load_dataset.reader_fns+=[@librilight_6k_24k/read_audiodir] 2 | librilight_6k_24k/read_audiodir.dataset_path='/workspace/datasets/librilight_med_24k' 3 | 
librilight_6k_24k/read_audiodir.dataset='librilight' -------------------------------------------------------------------------------- /encodecmae/configs/features/mel.gin: -------------------------------------------------------------------------------- 1 | NUM_MEL_BINS=256 2 | 3 | encodecmae.tasks.data.DictDataset: 4 | out_cols=['wav','wav_features'] 5 | preprocessor=@encodecmae.tasks.features.SequentialProcessor 6 | #Processor: 7 | encodecmae.tasks.features.SequentialProcessor: 8 | processors=[@encodecmae.tasks.features.ReadAudioProcessor, @encodecmae.tasks.features.MelspectrogramProcessor] 9 | encodecmae.tasks.features.ReadAudioProcessor: 10 | key_in = 'filename' 11 | key_out = 'wav' 12 | max_length = %MAX_AUDIO_DURATION 13 | encodecmae.tasks.features.MelspectrogramProcessor: 14 | key_in = 'wav' 15 | key_out = 'wav_features' 16 | sample_frequency=24000 17 | frame_shift=13.28 18 | frame_length=26.56 19 | htk_compat=True 20 | use_energy=False 21 | window_type='hanning' 22 | num_mel_bins=%NUM_MEL_BINS 23 | dither=0.0 24 | norm_stats=[-6.12, 4.82] 25 | 26 | #Wav encoder: 27 | encodecmae.models.encoders.WavEncoder: 28 | encoder = @torch.nn.Identity 29 | post_net = @wav_encoder_proj/torch.nn.Linear 30 | hop_length = 320 31 | fs = 24000 32 | key_in = 'wav_features' 33 | key_out = 'wav_features' 34 | wav_encoder_proj/torch.nn.Linear: 35 | in_features = %NUM_MEL_BINS 36 | out_features = %MODEL_DIM 37 | 38 | #Target: 39 | encodecmae.models.targets.EncodecQuantizer: 40 | key_in = 'wav' 41 | use_encodec_encoder = True 42 | 43 | 44 | 45 | -------------------------------------------------------------------------------- /encodecmae/configs/features/spec.gin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/habla-liaa/encodecmae/2dd6759c595bf20c9d836914fef5cde9a6386843/encodecmae/configs/features/spec.gin -------------------------------------------------------------------------------- /encodecmae/configs/features/st.gin: -------------------------------------------------------------------------------- 1 | NUM_TOTAL_TARGETS=1 2 | 3 | encodecmae.tasks.data.DictDataset.out_cols+=['targets'] 4 | 5 | encodecmae.tasks.features.SequentialProcessor.processors+=[@encodecmae.tasks.features.LoadNumpyProcessor] 6 | encodecmae.tasks.features.ReadAudioProcessor.key_in = 'filename_audio' 7 | encodecmae.tasks.features.LoadNumpyProcessor: 8 | key_in = 'filename_targets' 9 | key_out = 'targets' 10 | 11 | encodecmae.models.EncodecMAE.target_encoder=@encodecmae.models.targets.IdentityTarget 12 | encodecmae.models.targets.IdentityTarget: 13 | key_in = 'targets' 14 | key_out = 'targets' 15 | encodecmae.models.losses.EnCodecMAEClassificationLoss.quantizer_weights=[1.] 
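A note on the feature configs above: mel.gin parameterizes a Kaldi-style 256-bin log-mel frontend at 24 kHz, with a ~13.3 ms hop chosen to match EnCodec's 75 Hz token rate, and fixed normalization stats. The sketch below reproduces those settings with torchaudio's Kaldi-compatible fbank; it only illustrates the configured parameters (the repo's actual `encodecmae.tasks.features.MelspectrogramProcessor` is not shown in this listing), the helper name is made up, and applying `norm_stats` as (x - mean) / std is an assumption.

``` python
import torch
import torchaudio.compliance.kaldi as kaldi

def mel256_features(wav: torch.Tensor) -> torch.Tensor:
    """Kaldi-style log-mel features using the parameters from configs/features/mel.gin.

    wav: 1-D waveform tensor sampled at 24 kHz.
    Returns a (num_frames, 256) tensor normalized with the fixed stats from the config.
    """
    feats = kaldi.fbank(
        wav.unsqueeze(0),      # kaldi.fbank expects (channels, samples)
        sample_frequency=24000,
        frame_shift=13.28,     # milliseconds; ~75 frames per second
        frame_length=26.56,    # milliseconds
        htk_compat=True,
        use_energy=False,
        window_type='hanning',
        num_mel_bins=256,
        dither=0.0,
    )
    mean, std = -6.12, 4.82    # norm_stats in mel.gin; mean/std application is an assumption
    return (feats - mean) / std
```

In mel.gin these 256-dimensional frames are then projected to %MODEL_DIM (768 in configs/models/encodecmae.gin) by the `wav_encoder_proj/torch.nn.Linear` post_net of the `WavEncoder`.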
-------------------------------------------------------------------------------- /encodecmae/configs/features/wav_only.gin: -------------------------------------------------------------------------------- 1 | encodecmae.tasks.data.DictDataset: 2 | out_cols=['wav'] 3 | preprocessor=@encodecmae.tasks.features.SequentialProcessor 4 | 5 | encodecmae.tasks.features.SequentialProcessor: 6 | processors=[@encodecmae.tasks.features.ReadAudioProcessor] 7 | 8 | encodecmae.tasks.features.ReadAudioProcessor: 9 | key_in = 'filename' 10 | key_out = 'wav' 11 | max_length = %MAX_AUDIO_DURATION -------------------------------------------------------------------------------- /encodecmae/configs/imports: -------------------------------------------------------------------------------- 1 | encodecmae.models.encoders: encodecmae.models.encoders 2 | encodecmae.models.heads: encodecmae.models.heads 3 | encodecmae.models.losses: encodecmae.models.losses 4 | encodecmae.models.masks: encodecmae.models.masks 5 | encodecmae.models.targets: encodecmae.models.targets 6 | encodecmae.models.transformers: encodecmae.models.transformers 7 | encodecmae.models: encodecmae.models 8 | encodecmae.tasks: encodecmae.tasks 9 | encodecmae.tasks.utils: encodecmae.tasks.utils 10 | encodecmae.tasks.data: encodecmae.tasks.data 11 | encodecmae.tasks.features: encodecmae.tasks.features 12 | pytorch_lightning: pl 13 | pytorch_lightning.Trainer: pl.Trainer 14 | pytorch_lightning.loggers: pl.loggers 15 | pytorch_lightning.callbacks: pl.callbacks 16 | torch.utils.data: torch.utils.data 17 | torch.optim: torch.optim 18 | torchmetrics: torchmetrics 19 | torch.nn: torch.nn -------------------------------------------------------------------------------- /encodecmae/configs/models/encodecmae.gin: -------------------------------------------------------------------------------- 1 | NUM_ENCODEC_TARGETS=8 2 | NUM_TOTAL_TARGETS=8 3 | NUM_TARGET_TOKENS=1024 4 | MASK_GAP_SIZE=15 5 | MASK_PROP=0.5 6 | MODEL_DIM=768 7 | NUM_ENCODER_LAYERS=10 8 | NUM_ENCODER_HEADS=12 9 | NUM_DECODER_LAYERS=2 10 | NUM_DECODER_HEADS=12 11 | MASKED_LOSS_WEIGHT=0.9 12 | WAV_FEATURE_DIM=128 13 | QUANTIZER_WEIGHTS=[0.22407463, 0.1759858 , 0.14499009, 0.12150037, 0.10315603, 0.08831368, 0.07608274, 0.06589669] 14 | RETURN_ONLY_LAST_Q=False 15 | 16 | #Global settings: 17 | encodecmae.tasks.fit_model.model_cls=@encodecmae.models.EncodecMAE 18 | encodecmae.models.EncodecMAE: 19 | wav_encoder = @encodecmae.models.encoders.WavEncoder 20 | target_encoder = @encodecmae.models.targets.EncodecQuantizer 21 | masker = @encodecmae.models.masks.PatchoutMask 22 | visible_encoder = @encoder/encodecmae.models.transformers.TransformerEncoder 23 | decoder = @decoder/encodecmae.models.transformers.TransformerEncoder 24 | head = @encodecmae.models.heads.FrameLevelClassificationHead 25 | loss = @encodecmae.models.losses.EnCodecMAEClassificationLoss 26 | optimizer=@torch.optim.AdamW 27 | 28 | #Wav encoder: 29 | encodecmae.models.encoders.WavEncoder: 30 | encoder = @encodecmae.models.encoders.EncodecEncoder 31 | post_net = @wav_encoder_proj/torch.nn.Linear 32 | wav_encoder_proj/torch.nn.Linear: 33 | in_features = %WAV_FEATURE_DIM 34 | out_features = %MODEL_DIM 35 | 36 | #Masking: 37 | encodecmae.models.masks.PatchoutMask: 38 | masker = @encodecmae.models.masks.TimeGapMask 39 | positional_encoder = @encodecmae.models.transformers.SinusoidalPositionalEmbeddings 40 | encodecmae.models.masks.TimeGapMask: 41 | gap_size = %MASK_GAP_SIZE 42 | p_mask = %MASK_PROP 43 | 44 | #Visible encoder: 45 | 
encoder/encodecmae.models.transformers.TransformerEncoder: 46 | model_dim=%MODEL_DIM 47 | num_layers=%NUM_ENCODER_LAYERS 48 | attention_layer=@encoder/encodecmae.models.transformers.MultiHeadAttention 49 | compile=False 50 | key_in='visible_tokens' 51 | key_padding_mask='visible_padding_mask' 52 | key_out='decoder_in' 53 | key_transformer_in=None 54 | key_transformer_out='visible_embeddings' 55 | post_net=@decoder_proj/torch.nn.Linear 56 | encoder/encodecmae.models.transformers.MultiHeadAttention: 57 | model_dim=%MODEL_DIM 58 | num_heads=%NUM_ENCODER_HEADS 59 | 60 | #Decoder: 61 | decoder_proj/torch.nn.Linear: 62 | in_features=%MODEL_DIM 63 | out_features=%MODEL_DIM 64 | decoder/encodecmae.models.transformers.TransformerEncoder: 65 | model_dim=%MODEL_DIM 66 | num_layers=%NUM_DECODER_LAYERS 67 | attention_layer=@decoder/encodecmae.models.transformers.MultiHeadAttention 68 | compile=False 69 | key_in='decoder_in' 70 | key_padding_mask='feature_padding_mask' 71 | key_out='decoder_out' 72 | positional_encoder = @encodecmae.models.transformers.SinusoidalPositionalEmbeddings 73 | decoder/encodecmae.models.transformers.MultiHeadAttention: 74 | model_dim=%MODEL_DIM 75 | num_heads=%NUM_DECODER_HEADS 76 | 77 | encodecmae.models.transformers.SinusoidalPositionalEmbeddings.embedding_dim = %MODEL_DIM 78 | 79 | #Head: 80 | encodecmae.models.heads.FrameLevelClassificationHead: 81 | model_dim=%MODEL_DIM 82 | num_tokens=%NUM_TARGET_TOKENS 83 | num_streams=%NUM_TOTAL_TARGETS 84 | #Target: 85 | encodecmae.models.targets.EncodecQuantizer: 86 | n = %NUM_ENCODEC_TARGETS 87 | key_in = 'wav_features_encoder_out' 88 | return_only_last = %RETURN_ONLY_LAST_Q 89 | #Loss: 90 | encodecmae.models.losses.EnCodecMAEClassificationLoss: 91 | masked_weight=%MASKED_LOSS_WEIGHT 92 | quantizer_weights=%QUANTIZER_WEIGHTS 93 | #Optimizer: 94 | torch.optim.AdamW: 95 | lr=%MAX_LR 96 | betas=(0.9,0.95) 97 | weight_decay=0.05 -------------------------------------------------------------------------------- /encodecmae/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3.9" 2 | services: 3 | encodecmae: 4 | image: encodecmae 5 | container_name: encodecmae-train 6 | volumes: 7 | - /home/lpepino/encodecmae:/workspace/encodecmae 8 | - /mnt/ssd4T/datasets:/workspace/datasets 9 | ipc: host 10 | stdin_open: true 11 | tty: true 12 | deploy: 13 | resources: 14 | reservations: 15 | devices: 16 | - driver: nvidia 17 | device_ids: ['0','1'] 18 | capabilities: [gpu] 19 | -------------------------------------------------------------------------------- /encodecmae/heareval_model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/habla-liaa/encodecmae/2dd6759c595bf20c9d836914fef5cde9a6386843/encodecmae/heareval_model/__init__.py -------------------------------------------------------------------------------- /encodecmae/heareval_model/encodecmae.py: -------------------------------------------------------------------------------- 1 | from encodecmae import load_model as load_encodecmae_model 2 | from torch import Tensor 3 | import torch 4 | import sys 5 | from huggingface_hub import hf_hub_download 6 | 7 | def load_model(file_path, huggingface_ckpt=None, layer=-1): 8 | model = load_encodecmae_model(file_path) 9 | if huggingface_ckpt is not None: 10 | print('Loading checkpoint from {}'.format(huggingface_ckpt)) 11 | repo_id = '/'.join(huggingface_ckpt.split('/')[:2]) 12 | filename = 
'/'.join(huggingface_ckpt.split('/')[2:]) 13 | ckpt_file = hf_hub_download(repo_id=repo_id,filename=filename) 14 | ckpt = torch.load(ckpt_file, map_location='cpu') 15 | model.load_state_dict(ckpt['state_dict'], strict=False) 16 | 17 | model.sample_rate = 24000 18 | model.embedding_rate=75 19 | model.visible_encoder.compile=False 20 | model.head = None 21 | model.extraction_layer = layer 22 | 23 | del model.optimizer 24 | return model 25 | 26 | def get_timestamp_embeddings( 27 | audio: Tensor, 28 | model: torch.nn.Module, 29 | hop_size: float = 13, 30 | ) -> Tensor: 31 | 32 | with torch.no_grad(): 33 | model_device = next(model.parameters()).device 34 | embeddings = model.extract_features_from_array(audio, layer=model.extraction_layer) 35 | embeddings = torch.from_numpy(embeddings).to(model_device) 36 | if (model.extraction_layer == 'all') and embeddings.ndim==3: 37 | embeddings = embeddings.unsqueeze(1) 38 | timestamps = torch.arange(0,embeddings.shape[-2])/model.embedding_rate + (0.5/model.embedding_rate) 39 | timestamps = torch.tile(timestamps[None,:],[embeddings.shape[-3],1]) 40 | return embeddings, timestamps.to(model_device, dtype=torch.float32) 41 | 42 | def get_scene_embeddings( 43 | audio: Tensor, 44 | model: torch.nn.Module, 45 | ) -> Tensor: 46 | 47 | y, t = get_timestamp_embeddings(audio, model) 48 | out = torch.mean(y,axis=-2) 49 | 50 | return out 51 | -------------------------------------------------------------------------------- /encodecmae/hub.py: -------------------------------------------------------------------------------- 1 | from huggingface_hub import HfFileSystem, hf_hub_download 2 | from .models import EncodecMAE 3 | from ginpipe.core import gin_configure_externals 4 | import gin 5 | import torch 6 | from pathlib import Path 7 | import copy 8 | 9 | def traverse_dir(dir, result, fs): 10 | for x in fs.ls(dir, refresh=True): 11 | if x['name'].endswith('.pt'): 12 | result.append('/'.join(x['name'].split('/')[2:])) 13 | else: 14 | if x['type'] == 'dir': 15 | traverse_dir(x['name'], result, fs) 16 | 17 | def get_available_models(): 18 | fs = HfFileSystem() 19 | available_models = [] 20 | traverse_dir('lpepino/encodecmae-v2', available_models, fs) 21 | #available_models = [x['name'].split('/')[-1] for x in fs.ls('lpepino/encodecmae-pretrained/upstreams', refresh=True)] 22 | return [x.split('.')[0] for x in available_models] 23 | 24 | @gin.configurable 25 | def get_model(model, processor): 26 | return model, processor 27 | 28 | def load_model(model, mode='eval',device='cuda:0'): 29 | #Get model files 30 | config_str = gin.config_str() 31 | registers = copy.deepcopy(gin.config._REGISTRY) 32 | gin.clear_config() 33 | registry_clear_keys = [k for k,v in gin.config._REGISTRY.items() if k not in ['gin.macro', 'gin.constant', 'gin.singleton', 'ginpipe.core.execute_pipeline', 'encodecmae.hub.get_model']] 34 | for k in registry_clear_keys: 35 | gin.config._REGISTRY.pop(k) 36 | available_models = get_available_models() 37 | if model in available_models: 38 | model_file = hf_hub_download(repo_id='lpepino/encodecmae-v2', filename='{}.pt'.format(model)) 39 | else: 40 | raise Exception("Available models are: {}".format(available_models)) 41 | 42 | model_state = torch.load(model_file, map_location='cpu') 43 | 44 | flag = {'module_list_str': model_state['imports']} 45 | gin_configure_externals(flag) 46 | gin.parse_config(model_state['config']) 47 | model, processor = get_model() 48 | model = model() 49 | processor = processor() 50 | model.load_state_dict(model_state['state_dict']) 51 | 
gin.clear_config() 52 | gin.config._REGISTRY.clear() 53 | gin.config._REGISTRY = registers 54 | gin.parse_config(config_str) 55 | if mode=='eval': 56 | model.eval() 57 | model.to(device) 58 | model.processor = processor 59 | 60 | #To avoid dynamic batch problems: 61 | if hasattr(model.visible_encoder, 'compile'): 62 | model.visible_encoder.compile=False 63 | return model 64 | -------------------------------------------------------------------------------- /encodecmae/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .mae import EncodecMAE -------------------------------------------------------------------------------- /encodecmae/models/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from encodec import EncodecModel 3 | import encodec 4 | import torch 5 | 6 | class WavEncoder(torch.nn.Module): 7 | def __init__(self, encoder, pre_net=None, post_net=None, 8 | key_in='wav', key_out='wav_features', key_lens='wav_lens', 9 | make_padding_mask=True, hop_length=None,fs=None): 10 | super().__init__() 11 | self.encoder = encoder() 12 | self.pre_net = pre_net() if pre_net is not None else None 13 | self.post_net = post_net() if post_net is not None else None 14 | self.key_in = key_in 15 | self.key_out = key_out 16 | self.key_wav_lens = key_lens 17 | self.make_padding_mask = make_padding_mask 18 | if hop_length is None: 19 | self.hop_length=self.encoder.hop_length 20 | else: 21 | self.hop_length=hop_length 22 | if fs is None: 23 | self.fs = self.encoder.fs 24 | else: 25 | self.fs = fs 26 | 27 | def forward(self,x): 28 | #Encode wav 29 | y = x[self.key_in] 30 | if self.pre_net is not None: 31 | y = self.pre_net(y) 32 | x[key_out+'_pre_net_output'] = y 33 | y = self.encoder(y) 34 | if self.post_net is not None: 35 | x[self.key_out+'_encoder_out'] = y 36 | y = self.post_net(y) 37 | x[self.key_out] = y 38 | #Make padding masks 39 | if self.make_padding_mask: 40 | x['features_len'] = (x['wav_lens']//self.hop_length).to(y.device) 41 | x['feature_padding_mask'] = x['features_len'][:,None] <= torch.arange(0,y.shape[1],device=y.device)[None,:] 42 | 43 | class EncodecEncoder(torch.nn.Module): 44 | def __init__(self, frozen: bool = True, scale: float = 1.0, pretrained: bool = True) -> None: 45 | """Initialize Encodec Encoder model. 46 | 47 | Args: 48 | frozen (bool, optional): Whether the model is frozen or not. Defaults to True. 49 | scale (float, optional): Scaling factor. Defaults to 1.0. 50 | pretrained (bool, optional): Whether to load a pretrained checkpoint or train from scratch. Defaults to True. 51 | """ 52 | super().__init__() 53 | self.model = EncodecModel.encodec_model_24khz(pretrained=pretrained).encoder 54 | self.hop_length = self.model.hop_length 55 | self.frozen = frozen 56 | if self.frozen: 57 | self.model.eval() 58 | else: 59 | self.model.train() 60 | self.out_dim = 128 61 | self.fs = 24000 62 | self.scale = scale 63 | 64 | def forward(self, x: torch.Tensor) -> torch.Tensor: 65 | """Forward pass. 66 | 67 | Args: 68 | x (torch.Tensor): Input tensor corresponding to waveforms with shape [B, T]. 
69 | 70 | Returns: 71 | torch.Tensor: Output from EnCodec encoder with shape [B, T, D] 72 | """ 73 | x = x.unsqueeze(1) 74 | if self.frozen: 75 | with torch.no_grad(): 76 | y = self.model(x) 77 | else: 78 | y = self.model(x) 79 | y = torch.permute(y,(0,2,1))*self.scale 80 | return y 81 | 82 | class SEANetEncoder(torch.nn.Module): 83 | def __init__(self, frozen=False, lstm=False, pretrained=True): 84 | super().__init__() 85 | self.model = encodec.modules.SEANetEncoder(lstm=lstm) 86 | if pretrained: 87 | ckpt_path = 'https://dl.fbaipublicfiles.com/encodec/v0/encodec_24khz-d7cc33bc.th' 88 | sd = torch.hub.load_state_dict_from_url(ckpt_path, map_location='cpu', check_hash=True) 89 | sd = {k.replace('encoder.',''):v for k,v in sd.items()} 90 | if not lstm: 91 | sd = {k.replace('.15','.14'): v for k,v in sd.items()} 92 | self.model.load_state_dict(sd, strict=False) 93 | self.hop_length = self.model.hop_length 94 | self.frozen = frozen 95 | self.fs = 24000 96 | if self.frozen: 97 | self.model.eval() 98 | else: 99 | self.model.train() 100 | self.out_dim = self.model.dimension 101 | 102 | def forward(self, x): 103 | x = x.unsqueeze(1) 104 | if self.frozen: 105 | with torch.no_grad(): 106 | y = self.model(x) 107 | else: 108 | y = self.model(x) 109 | y = torch.permute(y,(0,2,1)) 110 | return y -------------------------------------------------------------------------------- /encodecmae/models/heads/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | class FrameLevelClassificationHead(nn.Module): 5 | def __init__(self, model_dim: int, num_tokens: int, num_streams: int = 8, key_in: str = 'decoder_out', key_out: str = 'predicted_tokens') -> None: 6 | """Initialize FrameLevelClassificationHead. 7 | 8 | Args: 9 | model_dim (int): Dimensionality of the input. 10 | num_tokens (int): Number of tokens. 11 | num_streams (int, optional): Number of streams. Defaults to 8. 12 | key_in (str, optional): Key for input data. Defaults to 'decoder_out'. 13 | key_out (str, optional): Key for output data. Defaults to 'predicted_tokens'. 14 | """ 15 | super().__init__() 16 | self.layer = nn.Linear(model_dim,num_tokens*num_streams) 17 | self.num_tokens = num_tokens 18 | self.num_streams = num_streams 19 | self.ar_sampling = False 20 | self.key_in=key_in 21 | self.key_out=key_out 22 | 23 | def forward(self, x: dict) -> dict: 24 | """Forward pass. 25 | 26 | Args: 27 | x (dict): Input data dictionary. 28 | 29 | Returns: 30 | dict: Output data dictionary. 31 | """ 32 | xin = x[self.key_in] 33 | probs = self.layer(xin).view(xin.shape[0],xin.shape[1],self.num_streams,self.num_tokens) 34 | x[self.key_out] = probs 35 | return x 36 | 37 | class SegmentLevelClassificationHead(nn.Module): 38 | def __init__(self, model_dim: int, num_classes: int, num_streams: int) -> None: 39 | """Initialize SegmentLevelClassificationHead. 40 | 41 | Args: 42 | model_dim (int): Dimensionality of the input. 43 | num_classes (int): Number of classes. 44 | num_streams (int): Number of streams. 45 | """ 46 | super().__init__() 47 | self.layer = nn.Linear(model_dim,num_classes*num_streams) 48 | 49 | def forward(self, x: torch.Tensor) -> torch.Tensor: 50 | """Forward pass. 51 | 52 | Args: 53 | x (torch.Tensor): Input tensor. 54 | 55 | Returns: 56 | torch.Tensor: Output tensor. 
57 | """ 58 | probs = self.layer(torch.mean(x,axis=1)) 59 | return probs -------------------------------------------------------------------------------- /encodecmae/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from datetime import datetime 3 | 4 | class EnCodecMAEClassificationLoss(torch.nn.Module): 5 | def __init__(self, masked_weight=0.9, quantizer_weights=None): 6 | super().__init__() 7 | self.masked_weight = masked_weight 8 | self.quantizer_weights = torch.tensor(quantizer_weights) 9 | 10 | def forward(self, x): 11 | all_losses = {} 12 | #Sometimes because of padding/different frame rates, there might be a small difference in length between input and target representations. 13 | #We fix it by cropping the longest vector: 14 | if x['predicted_tokens'].shape[1] < x['targets'].shape[-1]: 15 | x['targets'] = x['targets'][:,:,:x['predicted_tokens'].shape[1]] 16 | elif x['predicted_tokens'].shape[1] > x['targets'].shape[-1]: 17 | x['predicted_tokens'] = x['predicted_tokens'][:,:x['targets'].shape[-1],:,:] 18 | x['non_visible_mask'] = x['non_visible_mask'][:,:x['targets'].shape[-1]] 19 | x['visible_mask'] = x['visible_mask'][:,:x['targets'].shape[-1]] 20 | else: 21 | pass 22 | loss = torch.nn.functional.cross_entropy(torch.permute(x['predicted_tokens'],(0,3,1,2)), torch.permute(x['targets'],(1,2,0)).to(dtype=torch.long), reduction='none') 23 | if 'quantizer_loss_weights' not in x: 24 | x['quantizer_loss_weights'] = torch.tensor([[1.0]], device=loss.device, dtype=loss.dtype) 25 | loss = loss*(x['quantizer_loss_weights'].unsqueeze(1))*(self.quantizer_weights.to(x['predicted_tokens'].device).unsqueeze(0).unsqueeze(0)) 26 | masked_loss = loss*(x['non_visible_mask'].unsqueeze(-1)) 27 | unmasked_loss = loss*(x['visible_mask'].unsqueeze(-1)) 28 | 29 | num_non_visible = max(1,torch.sum(x['non_visible_mask'])) 30 | num_visible = max(1,torch.sum(x['visible_mask'])) 31 | for q in range(masked_loss.shape[-1]): 32 | all_losses[f'masked_loss_q{q}'] = torch.sum(masked_loss[:,:,q])/num_non_visible 33 | all_losses[f'unmasked_loss_q{q}'] = torch.sum(unmasked_loss[:,:,q])/num_visible 34 | 35 | masked_loss = torch.sum(masked_loss)/num_non_visible 36 | unmasked_loss = torch.sum(unmasked_loss)/num_visible 37 | all_losses['masked_loss'] = masked_loss 38 | all_losses['unmasked_loss'] = unmasked_loss 39 | loss = self.masked_weight*masked_loss + (1-self.masked_weight)*unmasked_loss 40 | all_losses['loss'] = loss 41 | all_losses['time'] = int(datetime.now().strftime('%y%m%d%H%M%S')) 42 | 43 | return all_losses 44 | -------------------------------------------------------------------------------- /encodecmae/models/mae.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | import torch.nn.functional as F 4 | from datetime import datetime 5 | import gin 6 | import librosa 7 | 8 | import pytorch_lightning as pl 9 | import torch 10 | import torch.nn.functional as F 11 | from datetime import datetime 12 | import gin 13 | import librosa 14 | 15 | import numpy as np 16 | 17 | class EncodecMAE(pl.LightningModule): 18 | """EncodecMAE. 19 | 20 | Args: 21 | wav_encoder (torch.nn.Module): Module that takes a batch of waveforms (BxT) and returns a representation (BxTxDf). 22 | target_encoder (torch.nn.Module): Module that generates targets from unmasked inputs. 
23 | visible_encoder (torch.nn.Module): Module that encodes the visible tokens returning visible embeddings (BxTvxD) 24 | decoder (torch.nn.Module): Module that takes expanded input (with mask tokens) and generate embeddings for the whole sequence (BxTxD) 25 | masker (torch.nn.Module): Module that takes wav_encoder outputs and mask them. Has 2 methods: mask() that masks the embeddings, and unmask() that expands visible embeddings to original length. 26 | head (torch.nn.Module): Module that takes decoder output and generates predictions of the targets. 27 | loss (torch.nn.Module): Module that receives prediction and targets and returns model losses. 28 | optimizer (torch.optim.Optimizer): Torch optimizer to use during training. 29 | lr_scheduler (torch.lr_scheduler.LRScheduler): Torch learning rate scheduler used during training. 30 | """ 31 | def __init__(self, 32 | wav_encoder: torch.nn.Module, 33 | target_encoder: torch.nn.Module, 34 | visible_encoder: torch.nn.Module, 35 | decoder: torch.nn.Module, 36 | masker: torch.nn.Module, 37 | head: torch.nn.Module, 38 | loss: torch.nn.Module = None, 39 | optimizer: torch.optim.Optimizer = None, 40 | lr_scheduler = None 41 | ): 42 | super().__init__() 43 | self.wav_encoder = wav_encoder() 44 | self.target_encoder = target_encoder() 45 | self.masker = masker() 46 | self.visible_encoder = visible_encoder() 47 | self.head = head() 48 | self.decoder = decoder() 49 | self.optimizer = optimizer 50 | self.lr_scheduler = lr_scheduler 51 | self.loss = loss() if loss is not None else None 52 | 53 | self.apply(self._init_weights) 54 | 55 | def forward(self, x): 56 | self.wav_encoder(x) 57 | self.masker.mask(x) 58 | self.visible_encoder(x) 59 | self.masker.unmask(x) 60 | self.decoder(x) 61 | if hasattr(self.target_encoder,'requires_model') and (self.target_encoder.requires_model): 62 | self.target_encoder(x, self) 63 | else: 64 | self.target_encoder(x) 65 | self.head(x) 66 | 67 | return x 68 | 69 | def forward_finetune(self,x): 70 | self.encode_wav(x) 71 | self.mask(x) 72 | self.encode_visible(x) 73 | self.predict_tokens(x, key_in='visible_embeddings') 74 | 75 | return x 76 | 77 | def extract_activations(self,x, 78 | detach=True, 79 | postnorm_last_activation=True, 80 | extract_decoder=False, 81 | residual_branch=0): 82 | if detach: 83 | with torch.no_grad(): 84 | self.wav_encoder(x) 85 | self.masker.mask(x,ignore_mask=True) 86 | self.visible_encoder(x, return_activations=True, padding_mask=x['visible_padding_mask'],postnorm_last_activation=postnorm_last_activation,residual_branch=residual_branch) 87 | if extract_decoder: 88 | self.masker.unmask(x) 89 | self.decoder(x, return_activations=True,postnorm_last_activation=postnorm_last_activation,residual_branch=residual_branch) 90 | return x 91 | 92 | def training_step(self,x, batch_idx): 93 | x = self(x) 94 | losses = self.loss(x) 95 | self.log_results(x,losses,'train') 96 | 97 | return losses['loss'] 98 | 99 | def validation_step(self,x, batch_idx): 100 | x = self(x) 101 | losses = self.loss(x) 102 | self.log_results(x,losses,'val') 103 | 104 | def log_results(self,x,losses,prefix): 105 | self.log_dict({'{}_{}'.format(prefix,k): v for k,v in losses.items()}) 106 | 107 | def set_optimizer_state(self, state): 108 | self.opt_state = state 109 | 110 | def configure_optimizers(self): 111 | opt = self.optimizer(self.trainer.model.parameters()) 112 | if self.lr_scheduler is not None: 113 | if self.lr_scheduler.__name__ == 'SequentialLR': 114 | binds = gin.get_bindings('torch.optim.lr_scheduler.SequentialLR') 115 | 
lr_scheduler = self.lr_scheduler(opt, schedulers=[s(opt) for s in binds['schedulers']]) 116 | else: 117 | lr_scheduler = self.lr_scheduler(opt) if self.lr_scheduler is not None else None 118 | else: 119 | lr_scheduler = None 120 | del self.optimizer 121 | del self.lr_scheduler 122 | opt_config = {'optimizer': opt} 123 | if lr_scheduler is not None: 124 | opt_config['lr_scheduler'] = {'scheduler': lr_scheduler, 125 | 'interval': 'step', 126 | 'frequency': 1} 127 | return opt_config 128 | 129 | def _init_weights(self, m): 130 | if isinstance(m, torch.nn.Linear): 131 | # we use xavier_uniform following official JAX ViT: 132 | torch.nn.init.xavier_uniform_(m.weight) 133 | if isinstance(m, torch.nn.Linear) and m.bias is not None: 134 | torch.nn.init.constant_(m.bias, 0) 135 | elif isinstance(m, torch.nn.LayerNorm): 136 | torch.nn.init.constant_(m.bias, 0) 137 | torch.nn.init.constant_(m.weight, 1.0) 138 | 139 | def extract_features_from_file(self, 140 | filename, 141 | chunk_size=None, 142 | overlap=0, 143 | start=None, 144 | end=None, 145 | layer=-1, 146 | extract_decoder=False, 147 | postnorm_last_activation=True, 148 | residual_branch=0): 149 | fs = self.wav_encoder.fs 150 | start = start/fs if start is not None else None 151 | end = end/fs if end is not None else None 152 | duration = end - start if (end is not None and start is not None) else None 153 | x, fs = librosa.load(filename, sr=fs, offset=start, duration=duration) 154 | if chunk_size is None: 155 | chunk_size=int(fs*self.processor._processors[0].max_length) 156 | features = self.extract_features_from_array(x, chunk_size=chunk_size, layer=layer, overlap=overlap, extract_decoder=extract_decoder, postnorm_last_activation=postnorm_last_activation) 157 | if (features.ndim == 3) and (features.shape[0]==1): 158 | return features[0] 159 | else: 160 | return features 161 | 162 | def apply_processors(self, xin, from_idx=1): 163 | def to_torch(x): 164 | if isinstance(x, np.ndarray): 165 | return torch.from_numpy(x) 166 | else: 167 | return torch.tensor(x) 168 | 169 | batch_processed = [] 170 | for k in range(xin['wav'].shape[0]): 171 | xin_i = {f:v[k] for f,v in xin.items()} 172 | for p in self.processor._processors[1:]: 173 | xin_i = p(xin_i) 174 | batch_processed.append(xin_i) 175 | xin = {k: torch.stack([to_torch(b[k]) for b in batch_processed]).to(self.device, self.dtype) for k in batch_processed[0].keys()} 176 | return xin 177 | 178 | def extract_features_from_array(self, 179 | audio, 180 | wav_lens=None, 181 | chunk_size=None, 182 | overlap=0, 183 | return_type='numpy', 184 | layer=-1, 185 | min_length=2048, 186 | extract_decoder=False, 187 | postnorm_last_activation=True, 188 | residual_branch=0): 189 | if chunk_size is None: 190 | fs = self.wav_encoder.fs 191 | chunk_size=int(fs*self.processor._processors[0].max_length) 192 | 193 | hop_size = chunk_size*(1-overlap) 194 | if not isinstance(audio, torch.Tensor): 195 | audio = torch.tensor(audio, device=self.device, dtype=torch.float32) 196 | 197 | if audio.ndim == 1: 198 | batch_size = 1 199 | audio = audio[None,:] 200 | else: 201 | batch_size = audio.shape[0] 202 | 203 | if wav_lens is None: 204 | wav_lens = [audio.shape[1]]*batch_size 205 | 206 | with torch.no_grad(): 207 | acts = [] 208 | for mi, i in enumerate(range(0,audio.shape[-1],hop_size)): 209 | if audio[:,i:i+chunk_size].shape[1]>min_length: 210 | wav_lens_i = [min(max(wav_lens[i] - i,0),chunk_size) for i in range(batch_size)] 211 | audio_i = audio[:,i:i+chunk_size] 212 | xin = {'wav': audio_i.cpu().numpy(), 'wav_lens': 
np.array(wav_lens_i)} 213 | xin = self.apply_processors(xin) 214 | if extract_decoder: 215 | self.decoder.key_transformer_out='decoder_transformer' 216 | out_i = self.extract_activations(xin, extract_decoder=extract_decoder, postnorm_last_activation=postnorm_last_activation, residual_branch=residual_branch) 217 | activations = torch.stack(out_i['visible_embeddings_activations']).squeeze(axis=1) 218 | if extract_decoder: 219 | decoder_acts = torch.stack(out_i['decoder_transformer_activations']).squeeze(axis=1) 220 | activations = torch.cat([activations, decoder_acts],axis=0) 221 | if layer != 'all': 222 | activations = activations[layer] 223 | if activations.ndim == 2: 224 | activations = activations.unsqueeze(0) 225 | acts.append(activations) 226 | if acts[0].ndim == 3: 227 | xi = torch.cat(acts,axis=1) 228 | elif acts[0].ndim == 4: 229 | xi = torch.cat(acts,axis=2) 230 | else: 231 | raise Exception('Wrong shape of activations') 232 | if return_type == 'numpy': 233 | return xi.detach().cpu().numpy() 234 | elif return_type == 'torch': 235 | return xi.detach() 236 | else: 237 | raise Exception('Unrecognized return type') 238 | -------------------------------------------------------------------------------- /encodecmae/models/masks/__init__.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import numpy as np 4 | 5 | class PatchoutMask(torch.nn.Module): 6 | def __init__(self, masker, positional_encoder): 7 | super().__init__() 8 | self.positional_encoder = positional_encoder() 9 | self.masker = masker() 10 | 11 | def mask(self, x, ignore_mask=False): 12 | if self.training and not ignore_mask: 13 | to_mask = self.positional_encoder(x['wav_features']) 14 | x['non_visible_mask'] = self.masker(x) 15 | x['visible_mask'] = ~x['non_visible_mask'] 16 | 17 | visibles = [] 18 | for i in range(to_mask.shape[0]): 19 | visibles.append(to_mask[i,(x['visible_mask'][i] & ~x['feature_padding_mask'][i])]) 20 | visible_lens = [xi.shape[0] for xi in visibles] 21 | maxlen = max(visible_lens) 22 | visible = torch.stack([torch.nn.functional.pad(xi,(0,0,0,maxlen-xi.shape[0])) for xi in visibles]) 23 | visible_padding_mask = (torch.arange(0,visible.shape[1],device=visible.device).unsqueeze(0)) >= torch.tensor(visible_lens, device=visible.device).unsqueeze(1) 24 | x['visible_tokens'] = visible 25 | x['visible_lens'] = visible_lens 26 | x['visible_padding_mask'] = visible_padding_mask 27 | else: 28 | x['visible_tokens'] = self.positional_encoder(x['wav_features']) 29 | x['visible_mask'] = torch.ones(x['wav_features'].shape[:2]).to(dtype=torch.bool, device=x['visible_tokens'].device) 30 | x['non_visible_mask'] = torch.zeros((x['wav_features'].shape[0],x['wav_features'].shape[1])).to(dtype=torch.bool, device=x['visible_tokens'].device) 31 | x['visible_padding_mask'] = x['feature_padding_mask'] 32 | x['visible_lens'] = x['features_len'] 33 | return x 34 | 35 | def unmask(self, x, key_in='decoder_in', key_out='decoder_in'): 36 | mask = x['non_visible_mask'] 37 | xin = x[key_in] 38 | feature_padding_mask = x['feature_padding_mask'] 39 | visible_padding_mask = x['visible_padding_mask'] 40 | unmasked = torch.zeros((mask.shape[0],mask.shape[1],xin.shape[-1]), dtype=xin.dtype, device=xin.device) 41 | unmasked[(~mask & ~feature_padding_mask)] = xin[~visible_padding_mask] 42 | x[key_out] = unmasked 43 | return x 44 | 45 | class TimeGapMask(torch.nn.Module): 46 | def __init__(self, p_mask, gap_size): 47 | super().__init__() 48 | self.p_mask = p_mask 49 | 
self.gap_size = gap_size 50 | 51 | def forward(self, x): 52 | x_lengths = x['features_len'] 53 | mask = np.zeros((len(x_lengths),max(x_lengths)), dtype=bool) 54 | ml = (x_lengths*self.p_mask).to(dtype=torch.int64) 55 | for i in range(len(x_lengths)): 56 | n_masked_i = 0 57 | start_idxs = [] 58 | while n_masked_i < ml[i]: 59 | start_idx = random.randint(0,max(0,x_lengths[i]-self.gap_size)) 60 | start_idxs.append(start_idx) 61 | mask[i,start_idx:min(start_idx+self.gap_size,x_lengths[i])]=1 62 | n_masked_i = mask[i].sum() 63 | if n_masked_i > ml[i]: 64 | valid_start_idxs = [idx for idx in start_idxs if idx < mask.shape[1]-self.gap_size] 65 | mask[i,valid_start_idxs[0]:valid_start_idxs[0]+n_masked_i-ml[i]]=0 66 | return torch.from_numpy(mask).to(x['wav_features'].device) -------------------------------------------------------------------------------- /encodecmae/models/targets/__init__.py: -------------------------------------------------------------------------------- 1 | from encodec import EncodecModel 2 | import torch 3 | import numpy as np 4 | import random 5 | from sklearn.cluster import MiniBatchKMeans 6 | import copy 7 | 8 | class EncodecQuantizer(torch.nn.Module): 9 | def __init__(self, n=8, frozen=True, scale=1.0, key_in='wav_features', key_out='targets', use_encodec_encoder=False, return_only_last=False): 10 | super().__init__() 11 | model = EncodecModel.encodec_model_24khz() 12 | self.model = model.quantizer 13 | self.scale = scale 14 | #Modify state dict to scale quantizer weights 15 | sd = self.model.state_dict() 16 | sd = {k: v*scale if 'embed' in k else v for k,v in sd.items()} 17 | self.model.load_state_dict(sd) 18 | self.frozen = frozen 19 | self.bandwidth = (10*75*n)/1000 20 | if self.frozen: 21 | self.model.eval() 22 | self.key_in = key_in 23 | self.key_out = key_out 24 | self.return_only_last = return_only_last 25 | if use_encodec_encoder: 26 | self.encodec_encoder = model.encoder 27 | self.encodec_encoder.eval() 28 | else: 29 | self.encodec_encoder = None 30 | 31 | def forward(self, xin): 32 | x = xin[self.key_in] 33 | with torch.no_grad(): 34 | if self.encodec_encoder is not None: 35 | x = x.unsqueeze(1) 36 | x = self.encodec_encoder(x) 37 | x = torch.transpose(x,1,2) 38 | x = torch.transpose(x,1,2) 39 | result = self.model(x,sample_rate=75,bandwidth=self.bandwidth) 40 | y = result.codes 41 | if self.return_only_last: 42 | y = y[-1,:,:].unsqueeze(0) 43 | xin[self.key_out] = y 44 | return xin 45 | 46 | class IdentityTarget(torch.nn.Module): 47 | def __init__(self, key_in='targets', key_out='targets'): 48 | super().__init__() 49 | self.key_in = key_in 50 | self.key_out = key_out 51 | self.requires_model=False 52 | 53 | def forward(self, x): 54 | xin = x[self.key_in] 55 | if xin.ndim==2: 56 | xin = xin.unsqueeze(0) 57 | x[self.key_out] = xin -------------------------------------------------------------------------------- /encodecmae/models/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import math 4 | import gin 5 | 6 | ###Code from TIMM: 7 | def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): 8 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 9 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 10 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 
11 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 12 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 13 | 'survival rate' as the argument. 14 | """ 15 | if drop_prob == 0. or not training: 16 | return x 17 | keep_prob = 1 - drop_prob 18 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 19 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 20 | if keep_prob > 0.0 and scale_by_keep: 21 | random_tensor.div_(keep_prob) 22 | return x * random_tensor 23 | 24 | class DropPath(nn.Module): 25 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 26 | """ 27 | def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): 28 | super(DropPath, self).__init__() 29 | self.drop_prob = drop_prob 30 | self.scale_by_keep = scale_by_keep 31 | 32 | def forward(self, x): 33 | return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) 34 | 35 | def extra_repr(self): 36 | return f'drop_prob={round(self.drop_prob,3):0.3f}' 37 | 38 | class LayerScale(nn.Module): 39 | def __init__(self, dim, init_values=1e-5, inplace=False): 40 | super().__init__() 41 | self.inplace = inplace 42 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 43 | 44 | def forward(self, x): 45 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 46 | 47 | class Transpose(nn.Module): 48 | def __init__(self, axis0, axis1): 49 | super().__init__() 50 | self.axis0=axis0 51 | self.axis1=axis1 52 | 53 | def forward(self,x): 54 | return torch.transpose(x,self.axis0,self.axis1) 55 | 56 | class SinusoidalPositionalEmbeddings(nn.Module): 57 | def __init__(self, embedding_dim,learn_scale=True,scale=1.0): 58 | super().__init__() 59 | self.embedding_dim = embedding_dim 60 | self.scale = torch.nn.Parameter(data=torch.tensor(scale),requires_grad=learn_scale) 61 | 62 | def forward(self,x): 63 | pe = torch.zeros_like(x) 64 | position = torch.arange(0, pe.shape[1]).unsqueeze(1).unsqueeze(0) 65 | div_term = torch.exp((torch.arange(0, pe.shape[2], 2, dtype=torch.float) * 66 | -(math.log(10000.0) / pe.shape[2]))) 67 | pe[:,:, 0::2] = torch.sin(position.float() * div_term) 68 | pe[:,:, 1::2] = torch.cos(position.float() * div_term) 69 | return x+self.scale*pe 70 | 71 | ### Transformer implementation with ResiDual normalization. See 'ResiDual: Transformer with Dual Residual Connections' - Xie et al. 72 | class TransformerLayer(nn.Module): 73 | def __init__(self, model_dim, attention_layer, ff_layer, norm_layer, 74 | norm_type='ResiDual', cross_attention_layer=None, drop_path=0, init_values=None): 75 | super().__init__() 76 | self.att_layer = attention_layer 77 | self.ff_layer = ff_layer 78 | self.norm1 = norm_layer(model_dim) 79 | self.norm2 = norm_layer(model_dim) 80 | self.norm_type = norm_type 81 | self.xatt_layer = cross_attention_layer 82 | self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() 83 | self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() 84 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 85 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 86 | if cross_attention_layer is not None: 87 | self.ls3 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 88 | self.norm3 = norm_layer(model_dim) 89 | self.drop_path3 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 90 | 91 | def forward(self, x, 92 | key_mask=None, att_mask=None, 93 | att_bias=None, k_mem=None, v_mem=None, 94 | key_mem_mask=None, mem_att_mask=None, mem_att_bias=None): 95 | 96 | if self.norm_type=='prenorm': 97 | xnorm = self.norm1(x) 98 | att_out, att_matrix = self.att_layer(xnorm, xnorm, xnorm, key_mask=key_mask, att_mask=att_mask, att_bias=att_bias) 99 | x = x + self.drop_path1(self.ls1(att_out)) 100 | if self.xatt_layer is not None: 101 | xnorm = self.norm3(x) 102 | xatt_out, xatt_matrix = self.xatt_layer(xnorm, k_mem, v_mem, key_mask=key_mem_mask, att_mask=mem_att_mask, att_bias=mem_att_bias) 103 | x = x + self.drop_path3(self.ls3(xatt_out)) 104 | else: 105 | xatt_matrix = None 106 | x = x + self.drop_path2(self.ls2(self.ff_layer(self.norm2(x)))) 107 | 108 | elif self.norm_type=='postnorm': 109 | att_out, att_matrix = self.att_layer(x,x,x,key_mask=key_mask, att_mask=att_mask, att_bias=att_bias) 110 | x = self.norm1(x+self.drop_path1(self.ls1(att_out))) 111 | if self.xatt_layer is not None: 112 | xatt_out, xatt_matrix = self.xatt_layer(x, k_mem, v_mem, key_mask=key_mem_mask, att_mask=mem_att_mask, att_bias=mem_att_bias) 113 | x = self.norm3(x + self.drop_path3(self.ls3(xatt_out))) 114 | else: 115 | xatt_matrix = None 116 | x = self.norm2(x+self.drop_path2(self.ls2(self.ff_layer(x)))) 117 | 118 | elif self.norm_type=='ResiDual': 119 | #Here 2 IO are managed. x[0]: ln_out x[1]: unnormalized 120 | #See https://arxiv.org/pdf/2304.14802.pdf 121 | 122 | if not isinstance(x,tuple): 123 | x = (x,x) 124 | att_out, att_matrix = self.att_layer(x[0],x[0],x[0],key_mask=key_mask, att_mask=att_mask, att_bias=att_bias) 125 | x0 = self.norm1(x[0]+self.drop_path1(self.ls1(att_out))) 126 | x1 = x[1]+self.drop_path1(self.ls1(att_out)) 127 | if self.xatt_layer is not None: 128 | xatt_out, xatt_matrix = self.xatt_layer(x0, k_mem, v_mem, key_mask=key_mem_mask, att_mask=mem_att_mask, att_bias=mem_att_bias) 129 | x0 = self.norm3(x0+self.drop_path3(self.ls3(xatt_out))) 130 | x1 = x1+self.drop_path1(self.ls3(xatt_out)) 131 | else: 132 | xatt_matrix=None 133 | ff_out = self.ff_layer(x0) 134 | x0 = self.norm2(x0+self.drop_path2(self.ls2(ff_out))) 135 | x1 = x1+self.drop_path2(self.ls2(ff_out)) 136 | x=(x0,x1) 137 | 138 | return x, att_matrix, xatt_matrix 139 | 140 | class TransformerEncoder(nn.Module): 141 | def __init__(self, model_dim, 142 | num_layers, 143 | attention_layer, 144 | ff_layer=None, 145 | ff_dim=4096, 146 | norm_layer=torch.nn.LayerNorm, 147 | norm_type='ResiDual', 148 | cross_attention_layer=None, 149 | drop_path=0, 150 | init_values=None, 151 | positional_encoder=None, 152 | compile=True, 153 | return_activations=False, 154 | key_in=None, 155 | key_padding_mask=None, 156 | key_out=None, 157 | key_transformer_in=None, 158 | key_transformer_out=None, 159 | pre_net=None, 160 | post_net=None): 161 | 162 | super().__init__() 163 | if ff_layer is None: 164 | ff_layer = torch.nn.Sequential(torch.nn.Linear(model_dim, ff_dim), 165 | torch.nn.ReLU(), 166 | torch.nn.Linear(ff_dim, model_dim)) 167 | else: 168 | ff_layer = ff_layer() 169 | if cross_attention_layer is not None: 170 | cross_attention_layer = cross_attention_layer() 171 | self.encoder_layers = torch.nn.ModuleList([TransformerLayer(model_dim, 172 | attention_layer(), 173 | ff_layer, 174 | norm_layer, 175 | norm_type, 176 | cross_attention_layer, 177 | drop_path, 178 | init_values) for i in range(num_layers)]) 179 | if positional_encoder is not None: 180 | self.positional_encoder = positional_encoder(model_dim) 181 | else: 
182 | self.positional_encoder = positional_encoder 183 | if norm_type in ['ResiDual','prenorm']: 184 | self.final_norm = norm_layer(model_dim) 185 | elif norm_type == 'postnorm': 186 | self.final_norm = torch.nn.Identity() 187 | self.model_dim = model_dim 188 | self.compile = compile 189 | #self.compile = False 190 | self.norm_type = norm_type 191 | self.return_activations = return_activations 192 | if not self.return_activations: 193 | self.return_activations = [] 194 | elif self.return_activations == 'all': 195 | self.return_activations = [i for i in range(num_layers)] 196 | self.key_in = key_in 197 | self.key_padding_mask = key_padding_mask 198 | self.key_out = key_out 199 | self.key_transformer_in = key_transformer_in 200 | self.key_transformer_out = key_transformer_out 201 | self.pre_net = pre_net() if pre_net is not None else None 202 | self.post_net = post_net() if post_net is not None else None 203 | self.residual_branch = 0 204 | 205 | def run_through_encoder(self,x): 206 | x, layers, padding_mask = x 207 | for l in layers: 208 | x,_,_=l(x,key_mask=padding_mask) 209 | return x 210 | 211 | def extract_activations(self,x): 212 | activations = [] 213 | x, layers, padding_mask = x 214 | for l in layers: 215 | x,_,_=l(x,key_mask=padding_mask) 216 | activations.append(x[self.residual_branch]) 217 | return activations, x 218 | 219 | @torch.compile 220 | def compiled_run_through_encoder(self,x): 221 | return self.run_through_encoder(x) 222 | 223 | @torch.compile 224 | def compiled_get_activations(self,x): 225 | return self.extract_activations(x) 226 | 227 | def forward(self,xin, 228 | padding_mask=None, 229 | return_activations=False, 230 | postnorm_last_activation=True, 231 | residual_branch=0): 232 | 233 | self.residual_branch = residual_branch 234 | if isinstance(xin, dict): 235 | padding_mask = xin[self.key_padding_mask] 236 | x = xin[self.key_in] 237 | else: 238 | x = xin 239 | 240 | if self.pre_net is not None: 241 | x = self.pre_net(x) 242 | 243 | if self.positional_encoder is not None: 244 | x = self.positional_encoder(x) 245 | 246 | if isinstance(xin, dict) and (self.key_transformer_in is not None): 247 | xin[self.key_transformer_in] = x 248 | 249 | if self.compile: 250 | if return_activations: 251 | acts, x = self.compiled_get_activations((x, self.encoder_layers, padding_mask)) 252 | xin[self.key_transformer_out+'_activations'] = acts 253 | else: 254 | x = self.compiled_run_through_encoder((x,self.encoder_layers,padding_mask)) 255 | else: 256 | if return_activations: 257 | acts, x = self.extract_activations((x, self.encoder_layers, padding_mask)) 258 | xin[self.key_transformer_out+'_activations'] = acts 259 | else: 260 | x = self.run_through_encoder((x,self.encoder_layers,padding_mask)) 261 | 262 | if self.norm_type == 'ResiDual': 263 | x = x[0] + self.final_norm(x[1]) 264 | if return_activations and postnorm_last_activation: 265 | xin[self.key_transformer_out+'_activations'][-1] = x 266 | elif self.norm_type == 'prenorm': 267 | x = self.final_norm(x) 268 | else: 269 | pass 270 | 271 | if isinstance(xin, dict) and (self.key_transformer_out is not None): 272 | xin[self.key_transformer_out] = x 273 | 274 | if self.post_net is not None: 275 | x = self.post_net(x) 276 | 277 | if isinstance(xin,dict): 278 | xin[self.key_out] = x 279 | else: 280 | xin = x 281 | return xin 282 | 283 | class MultiHeadAttention(nn.Module): 284 | def __init__(self, 285 | model_dim, 286 | num_heads, 287 | qk_proj_dim=None, 288 | v_proj_dim=None, 289 | q_bias=False, 290 | k_bias=False, 291 | v_bias=False, 
292 | out_bias=False, 293 | att_drop=0, 294 | proj_drop=0, 295 | att_scale=None, 296 | mask_value=-1e9): 297 | 298 | super().__init__() 299 | if qk_proj_dim is None: 300 | qk_proj_dim = model_dim//num_heads 301 | if v_proj_dim is None: 302 | v_proj_dim = qk_proj_dim 303 | 304 | self.qk_proj_dim = qk_proj_dim 305 | self.num_heads = num_heads 306 | self.v_proj_dim = v_proj_dim 307 | 308 | self.wq = nn.Linear(model_dim, qk_proj_dim*num_heads, bias=q_bias) 309 | self.wk = nn.Linear(model_dim, qk_proj_dim*num_heads, bias=k_bias) 310 | self.wv = nn.Linear(model_dim, v_proj_dim*num_heads, bias=v_bias) 311 | 312 | if att_scale is None: 313 | self.att_scale = qk_proj_dim ** -0.5 314 | else: 315 | self.att_scale = att_scale 316 | self.mask_value = mask_value 317 | 318 | self.att_drop = nn.Dropout(att_drop) 319 | self.wo = nn.Linear(v_proj_dim*num_heads, model_dim, bias=out_bias) 320 | self.proj_drop = nn.Dropout(proj_drop) 321 | 322 | def forward(self, q, k, v, key_mask=None, att_mask=None, att_bias=None): 323 | N,Tq,C = q.shape 324 | N,Tk,C = k.shape 325 | Q = self.wq(q) #NxTxD.H 326 | K = self.wk(k) #NxTxD.H 327 | V = self.wv(v) #NxTxD.H 328 | 329 | Q = Q.view(N,Tq,self.qk_proj_dim,self.num_heads).permute(0,3,1,2).reshape(N*self.num_heads,Tq,self.qk_proj_dim) #N.HxTqxDq 330 | K = K.view(N,Tk,self.qk_proj_dim,self.num_heads).permute(0,3,1,2).reshape(N*self.num_heads,Tk,self.qk_proj_dim) #N.HxTkxDk 331 | V = V.view(N,Tk,self.v_proj_dim,self.num_heads).permute(0,3,1,2).reshape(N*self.num_heads,Tk,self.v_proj_dim) #N.HxTkxDv 332 | 333 | kv_mask = torch.zeros((N*self.num_heads,Tq,Tk), dtype=torch.bool, device=Q.device) 334 | if key_mask is not None: 335 | key_mask = torch.tile(key_mask[:,None,None,:],(1,self.num_heads,Tq,1)).reshape(N*self.num_heads,Tq,Tk).to(dtype=torch.bool) 336 | kv_mask += key_mask 337 | if att_mask is not None: 338 | kv_mask += att_mask 339 | 340 | att = Q @ K.transpose(-2, -1) * self.att_scale #N.HxTqxTk 341 | if att_bias is not None: 342 | att += att_bias 343 | att += kv_mask*self.mask_value 344 | att = att.softmax(dim=-1) #N.HxTqxTk 345 | att = self.att_drop(att) #N.HxTqxTk 346 | 347 | x = att @ V #N.HxTqxDv 348 | x = x.view(N,self.num_heads,Tq,self.v_proj_dim).permute(0,2,3,1).reshape(N,Tq,self.v_proj_dim*self.num_heads) #NxTqxDvxH 349 | O = self.wo(x) 350 | O = self.proj_drop(O) 351 | 352 | return O, att 353 | -------------------------------------------------------------------------------- /encodecmae/scripts/run_pretraining.sh: -------------------------------------------------------------------------------- 1 | #-----------------------------------------First Iteration Models------------------------------------------------------- 2 | #EC-EC Small model 3 | ginpipe configs/base/unsupervised_pretrain.gin \ 4 | configs/models/encodecmae.gin \ 5 | configs/datasets/audioset-unbalanced-24k.gin \ 6 | configs/datasets/fma-large-24k.gin \ 7 | configs/datasets/librilight-6k-24k.gin \ 8 | configs/features/wav_only.gin \ 9 | --module_list configs/imports \ 10 | --project_name ec-ec-small_model \ 11 | --experiment_name upstream_model \ 12 | --mods NUM_ENCODER_LAYERS=5 13 | 14 | #EC-EC Base model 15 | ginpipe configs/base/unsupervised_pretrain.gin \ 16 | configs/models/encodecmae.gin \ 17 | configs/datasets/audioset-unbalanced-24k.gin \ 18 | configs/datasets/fma-large-24k.gin \ 19 | configs/datasets/librilight-6k-24k.gin \ 20 | configs/features/wav_only.gin \ 21 | --module_list configs/imports \ 22 | --project_name ec-ec-base_model \ 23 | --experiment_name upstream_model 24 | 25 | #EC-EC 
Large model 26 | ginpipe configs/base/unsupervised_pretrain.gin \ 27 | configs/models/encodecmae.gin \ 28 | configs/datasets/audioset-unbalanced-24k.gin \ 29 | configs/datasets/fma-large-24k.gin \ 30 | configs/datasets/librilight-6k-24k.gin \ 31 | configs/features/wav_only.gin \ 32 | --module_list configs/imports \ 33 | --project_name ec-ec-large_model \ 34 | --experiment_name upstream_model \ 35 | --mods NUM_ENCODER_LAYERS=20 DEVICE="[0,1]" MODEL_DIM=1024 TRAIN_BATCH_SIZE=64 "pl.Trainer.strategy='ddp_find_unused_parameters_true'" 36 | 37 | #Mel256-EC Small model 38 | ginpipe configs/base/unsupervised_pretrain.gin \ 39 | configs/models/encodecmae.gin \ 40 | configs/datasets/audioset-unbalanced-24k.gin \ 41 | configs/datasets/fma-large-24k.gin \ 42 | configs/datasets/librilight-6k-24k.gin \ 43 | configs/features/mel.gin \ 44 | --module_list configs/imports \ 45 | --project_name mel256-ec-small_model \ 46 | --experiment_name upstream_model \ 47 | --mods NUM_ENCODER_LAYERS=5 48 | 49 | #Mel256-EC Base model 50 | ginpipe configs/base/unsupervised_pretrain.gin \ 51 | configs/models/encodecmae.gin \ 52 | configs/datasets/audioset-unbalanced-24k.gin \ 53 | configs/datasets/fma-large-24k.gin \ 54 | configs/datasets/librilight-6k-24k.gin \ 55 | configs/features/mel.gin \ 56 | --module_list configs/imports \ 57 | --project_name mel256-ec-base_model \ 58 | --experiment_name upstream_model 59 | 60 | #Mel256-EC Base model (Audioset) 61 | ginpipe configs/base/unsupervised_pretrain.gin \ 62 | configs/models/encodecmae.gin \ 63 | configs/datasets/audioset-unbalanced-24k.gin \ 64 | configs/features/mel.gin \ 65 | --module_list configs/imports \ 66 | --project_name mel256-ec-base_model-as \ 67 | --experiment_name upstream_model 68 | 69 | #Mel256-EC Base model (Librilight) 70 | ginpipe configs/base/unsupervised_pretrain.gin \ 71 | configs/models/encodecmae.gin \ 72 | configs/datasets/librilight-6k-24k.gin \ 73 | configs/features/mel.gin \ 74 | --module_list configs/imports \ 75 | --project_name mel256-ec-base_model-ll \ 76 | --experiment_name upstream_model 77 | 78 | #Mel256-EC Base model (FMA) 79 | ginpipe configs/base/unsupervised_pretrain.gin \ 80 | configs/models/encodecmae.gin \ 81 | configs/datasets/fma-large-24k.gin \ 82 | configs/features/mel.gin \ 83 | --module_list configs/imports \ 84 | --project_name mel256-ec-base_model-fma \ 85 | --experiment_name upstream_model 86 | 87 | #Mel256-EC Large model 88 | ginpipe configs/base/unsupervised_pretrain.gin \ 89 | configs/models/encodecmae.gin \ 90 | configs/datasets/audioset-unbalanced-24k.gin \ 91 | configs/datasets/fma-large-24k.gin \ 92 | configs/datasets/librilight-6k-24k.gin \ 93 | configs/features/mel.gin \ 94 | --module_list configs/imports \ 95 | --project_name mel256-ec-large_model \ 96 | --experiment_name upstream_model \ 97 | --mods NUM_ENCODER_LAYERS=20 DEVICE="[0,1]" MODEL_DIM=1024 TRAIN_BATCH_SIZE=64 "pl.Trainer.strategy='ddp_find_unused_parameters_true'" 98 | 99 | #Mel256-EC Large model - Audioset 100 | ginpipe configs/base/unsupervised_pretrain.gin \ 101 | configs/models/encodecmae.gin \ 102 | configs/datasets/audioset-unbalanced-24k.gin \ 103 | configs/features/mel.gin \ 104 | --module_list configs/imports \ 105 | --project_name mel256-ec-large_model-as \ 106 | --experiment_name upstream_model \ 107 | --mods NUM_ENCODER_LAYERS=20 DEVICE="[0,1]" MODEL_DIM=1024 TRAIN_BATCH_SIZE=64 "pl.Trainer.strategy='ddp_find_unused_parameters_true'" 108 | 109 | #-----------------------------------------------ST 
Models-------------------------------------------------------------- 110 | 111 | #Mel256->EC Base - NoPN ST 112 | ginpipe configs/base/create_st_dataset.gin \ 113 | configs/datasets/audioset-unbalanced-24k.gin \ 114 | configs/datasets/fma-large-24k.gin \ 115 | configs/datasets/librilight-6k-24k.gin \ 116 | configs/features/mel.gin \ 117 | --module_list configs/imports \ 118 | --project_name mel256-ec-base_model \ 119 | --experiment_name self_training_dataset_no-pn \ 120 | --mods "TEACHER_MODEL='mel256-ec-base'" \ 121 | "POSTNORM_LAST_ACTIVATION=False" \ 122 | "DEVICE='cuda:0'" 123 | 124 | ginpipe configs/base/self-train.gin \ 125 | configs/models/encodecmae.gin \ 126 | configs/features/mel.gin \ 127 | configs/features/st.gin \ 128 | --module_list configs/imports \ 129 | --project_name mel256-ec-base_model \ 130 | --experiment_name st_model-no-pn \ 131 | --mods "ST_CHECKPOINT='huggingface:lpepino/encodecmae-pretrained/upstreams/mel256-ec-base.pt'" \ 132 | "ST_DATASET_DIR='/workspace/encodecmae/encodecmae/experiments/mel256-ec-base_model/self_training_dataset_no-pn'" \ 133 | "DEVICE=[0]" 134 | 135 | #Mel256->EC Base - NoPN ST (AS) 136 | ginpipe configs/base/create_st_dataset.gin \ 137 | configs/datasets/audioset-unbalanced-24k.gin \ 138 | configs/features/mel.gin \ 139 | --module_list configs/imports \ 140 | --project_name mel256-ec-base_model-as \ 141 | --experiment_name self_training_dataset_no-pn \ 142 | --mods "TEACHER_MODEL='mel256-ec-base-as'" \ 143 | "POSTNORM_LAST_ACTIVATION=False" \ 144 | "DEVICE='cuda:0'" 145 | 146 | ginpipe configs/base/self-train.gin \ 147 | configs/models/encodecmae.gin \ 148 | configs/features/mel.gin \ 149 | configs/features/st.gin \ 150 | --module_list configs/imports \ 151 | --project_name mel256-ec-base_model-as \ 152 | --experiment_name st_model-no-pn \ 153 | --mods "ST_CHECKPOINT='huggingface:lpepino/encodecmae-pretrained/upstreams/mel256-ec-base-as.pt'" \ 154 | "ST_DATASET_DIR='/workspace/encodecmae/encodecmae/experiments/mel256-ec-base_model-as/self_training_dataset_no-pn'" \ 155 | "DEVICE=[0]" \ 156 | "pl.Trainer.check_val_every_n_epoch=None" 157 | 158 | #Mel256->EC Base - NoPN ST (LL) 159 | ginpipe configs/base/create_st_dataset.gin \ 160 | configs/datasets/librilight-6k-24k.gin \ 161 | configs/features/mel.gin \ 162 | --module_list configs/imports \ 163 | --project_name mel256-ec-base_model-ll \ 164 | --experiment_name self_training_dataset_no-pn \ 165 | --mods "TEACHER_MODEL='mel256-ec-base-ll'" \ 166 | "POSTNORM_LAST_ACTIVATION=False" \ 167 | "DEVICE='cuda:1'" 168 | 169 | ginpipe configs/base/self-train.gin \ 170 | configs/models/encodecmae.gin \ 171 | configs/features/mel.gin \ 172 | configs/features/st.gin \ 173 | --module_list configs/imports \ 174 | --project_name mel256-ec-base_model-ll \ 175 | --experiment_name st_model-no-pn \ 176 | --mods "ST_CHECKPOINT='huggingface:lpepino/encodecmae-pretrained/upstreams/mel256-ec-base-ll.pt'" \ 177 | "ST_DATASET_DIR='/workspace/encodecmae/encodecmae/experiments/mel256-ec-base_model-ll/self_training_dataset_no-pn'" \ 178 | "DEVICE=[1]" \ 179 | "pl.Trainer.check_val_every_n_epoch=None" 180 | 181 | #Mel256->EC Base - NoPN ST (FMA) 182 | ginpipe configs/base/create_st_dataset.gin \ 183 | configs/datasets/fma-large-24k.gin \ 184 | configs/features/mel.gin \ 185 | --module_list configs/imports \ 186 | --project_name mel256-ec-base_model-fma \ 187 | --experiment_name self_training_dataset_no-pn \ 188 | --mods "TEACHER_MODEL='mel256-ec-base-fma'" \ 189 | "POSTNORM_LAST_ACTIVATION=False" \ 190 | "DEVICE='cuda:1'" 
191 | 192 | ginpipe configs/base/self-train.gin \ 193 | configs/models/encodecmae.gin \ 194 | configs/features/mel.gin \ 195 | configs/features/st.gin \ 196 | --module_list configs/imports \ 197 | --project_name mel256-ec-base_model-fma \ 198 | --experiment_name st_model-no-pn \ 199 | --mods "ST_CHECKPOINT='huggingface:lpepino/encodecmae-pretrained/upstreams/mel256-ec-base-fma.pt'" \ 200 | "ST_DATASET_DIR='/workspace/encodecmae/encodecmae/experiments/mel256-ec-base_model-fma/self_training_dataset_no-pn'" \ 201 | "DEVICE=[1]" \ 202 | "pl.Trainer.check_val_every_n_epoch=None" 203 | 204 | #Mel256-EC Large - NoPN ST 205 | ginpipe configs/base/create_st_dataset.gin \ 206 | configs/datasets/audioset-unbalanced-24k.gin \ 207 | configs/datasets/fma-large-24k.gin \ 208 | configs/datasets/librilight-6k-24k.gin \ 209 | configs/features/mel.gin \ 210 | --module_list configs/imports \ 211 | --project_name mel256-ec-large_model \ 212 | --experiment_name self_training_dataset_no-pn \ 213 | --mods "TEACHER_MODEL='mel256-ec-base_st-nopn'" \ 214 | "POSTNORM_LAST_ACTIVATION=False" \ 215 | "DEVICE='cuda:0'" 216 | 217 | ginpipe configs/base/self-train.gin \ 218 | configs/models/encodecmae.gin \ 219 | configs/features/mel.gin \ 220 | configs/features/st.gin \ 221 | --module_list configs/imports \ 222 | --project_name mel256-ec-large_model \ 223 | --experiment_name st_model-no-pn \ 224 | --mods "ST_CHECKPOINT='huggingface:lpepino/encodecmae-pretrained/upstreams/mel256-ec-large.pt'" \ 225 | "ST_DATASET_DIR='/workspace/encodecmae/encodecmae/experiments/mel256-ec-large_model/self_training_dataset_no-pn'" \ 226 | "DEVICE=[0,1]" \ 227 | NUM_ENCODER_LAYERS=20 \ 228 | MODEL_DIM=1024 \ 229 | TRAIN_BATCH_SIZE=64 \ 230 | "pl.Trainer.strategy='ddp_find_unused_parameters_true'" 231 | 232 | #Mel256-EC Large - NoPN ST AS 233 | ginpipe configs/base/create_st_dataset.gin \ 234 | configs/datasets/audioset-unbalanced-24k.gin \ 235 | configs/features/mel.gin \ 236 | --module_list configs/imports \ 237 | --project_name mel256-ec-large_model-as \ 238 | --experiment_name self_training_dataset_no-pn \ 239 | --mods "TEACHER_MODEL='mel256-ec-base_st-as-nopn'" \ 240 | "POSTNORM_LAST_ACTIVATION=False" \ 241 | "DEVICE='cuda:1'" 242 | 243 | ginpipe configs/base/self-train.gin \ 244 | configs/models/encodecmae.gin \ 245 | configs/features/mel.gin \ 246 | configs/features/st.gin \ 247 | --module_list configs/imports \ 248 | --project_name mel256-ec-large_model-as \ 249 | --experiment_name st_model-no-pn \ 250 | --mods "ST_CHECKPOINT='huggingface:lpepino/encodecmae-pretrained/upstreams/mel256-ec-large-as.pt'" \ 251 | "ST_DATASET_DIR='/workspace/encodecmae/encodecmae/experiments/mel256-ec-large_model-as/self_training_dataset_no-pn'" \ 252 | "DEVICE=[0,1]" \ 253 | NUM_ENCODER_LAYERS=20 \ 254 | MODEL_DIM=1024 \ 255 | TRAIN_BATCH_SIZE=64 \ 256 | "pl.Trainer.strategy='ddp_find_unused_parameters_true'" 257 | -------------------------------------------------------------------------------- /encodecmae/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pathlib import Path 3 | from loguru import logger 4 | import torchinfo 5 | import inspect 6 | from tqdm import tqdm 7 | import joblib 8 | import numpy as np 9 | from encodecmae.hub import load_model as load_encodecmae_model 10 | from huggingface_hub import hf_hub_download 11 | from threadpoolctl import threadpool_limits 12 | 13 | def load_model(state, model): 14 | model = load_encodecmae_model(model) 15 | state['model'] = model 16 
| return state 17 | 18 | def create_st_dataset(state, layer=-1, kmeans_samples=10000, device='cuda:0', postnorm_last_activation=True,kmeans_jobs=8): 19 | from sklearn.cluster import KMeans 20 | 21 | model = state['model'].to(device) 22 | model.visible_encoder.compile=False 23 | #First learn kmeans: 24 | if ('tokenizer' in state): 25 | cluster_model = state['tokenizer'] 26 | elif Path(state['output_dir'],'tokenizer.pkl').exists(): 27 | cluster_model = joblib.load(Path(state['output_dir'],'tokenizer.pkl')) 28 | else: 29 | kmeans_sample_idxs = np.random.choice(np.arange(0,len(state['datasets']['train'])),size=kmeans_samples,replace=False) 30 | kmeans_dataset = [] 31 | for i in tqdm(kmeans_sample_idxs): 32 | x = state['datasets']['train'][i] 33 | x = state['dataloaders']['train'].collate_fn([x]) 34 | x = {k: v.to(model.device) for k,v in x.items()} 35 | out = model.extract_activations(x, postnorm_last_activation=postnorm_last_activation) 36 | kmeans_dataset.append(out['visible_embeddings_activations'][layer].cpu().numpy()) 37 | kmeans_dataset = np.concatenate(kmeans_dataset,axis=1)[0] 38 | n_tokens = model.head.num_tokens 39 | with threadpool_limits(user_api="blas", limits=kmeans_jobs): 40 | cluster_model = KMeans(n_clusters=n_tokens) 41 | kmeans_sample_idxs = np.random.choice(np.arange(0,kmeans_dataset.shape[0]),size=kmeans_samples,replace=False) 42 | kmeans_dataset = kmeans_dataset[kmeans_sample_idxs] 43 | cluster_model.fit(kmeans_dataset) 44 | state['tokenizer']=cluster_model 45 | joblib.dump(cluster_model, Path(state['output_dir'],'tokenizer.pkl')) 46 | state['datasets']['train']._out_cols+=['start','stop','filename'] 47 | state['dataloaders']['train'] = torch.utils.data.DataLoader(state['dataloaders']['train'].dataset, shuffle=False, num_workers=4, collate_fn=state['dataloaders']['train'].collate_fn, batch_size=state['dataloaders']['train'].batch_size) 48 | #Find last saved index before starting this loop: 49 | data_out_path = Path(state['output_dir'],'self_training_dataset') 50 | if not data_out_path.exists(): 51 | data_out_path.mkdir(parents=True) 52 | for batch_idx, x in enumerate(tqdm(state['dataloaders']['train'])): 53 | #out = model.extract_activations({'wav': x['wav'].to(device=model.device, dtype=model.dtype), 54 | # 'wav_lens': x['wav_lens'].to(device=model.device)}, postnorm_last_activation=postnorm_last_activation) 55 | x = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k,v in x.items()} 56 | out = model.extract_activations(x, postnorm_last_activation=postnorm_last_activation) 57 | out = out['visible_embeddings_activations'][layer].cpu().numpy() 58 | for i in range(out.shape[0]): 59 | indexs = cluster_model.predict(out[i]) 60 | filename = '{}_{}_{}.npy'.format(Path(x['filename'][i]).stem,x['start'][i],x['stop'][i]) 61 | file_out = Path(state['output_dir'],'self_training_dataset',filename[:3],filename) 62 | file_out.parent.mkdir(parents=True, exist_ok=True) 63 | try: 64 | np.save(file_out,indexs) 65 | except: 66 | print('Failed {}'.format(file_out)) 67 | with open(Path(state['output_dir'],'metadata_selftrain_dataset.csv'),'a') as f: 68 | f.write('{},{},{},{}\n'.format(str(x['filename'][i]),x['start'][i],x['stop'][i],str(file_out.resolve()))) 69 | state['model'] = state['model'].to('cpu') 70 | return state 71 | 72 | def fit_model(state, trainer_cls=None, model_cls=None, from_checkpoint=None, 73 | cpu_threads=8,dataloaders_key='dataloaders', 74 | checkpoint_folder='checkpoints', 75 | key_out='model', 76 | cache_model=True, 77 | keep_optimizer_state=False, 78 | 
non_strict_checkpoint=True): 79 | 80 | if (from_checkpoint is not None) and from_checkpoint.startswith('huggingface:'): 81 | hf_path = from_checkpoint.split('huggingface:')[-1] 82 | hf_repo = '/'.join(hf_path.split('/')[:2]) 83 | path_in_repo = '/'.join(hf_path.split('/')[2:]) 84 | from_checkpoint = hf_hub_download(repo_id=hf_repo, filename=path_in_repo) 85 | 86 | if not ((key_out in state) and (cache_model)): 87 | torch.set_num_threads(cpu_threads) 88 | torch.set_float32_matmul_precision('medium') 89 | kwargs = {} 90 | if 'state' in inspect.signature(model_cls.__init__).parameters.keys(): 91 | kwargs['state'] = state 92 | 93 | if model_cls is None: 94 | if Path(from_checkpoint).stem == 'state': 95 | model = joblib.load(from_checkpoint)['model'] 96 | from_checkpoint=None 97 | else: 98 | model = model_cls(**kwargs) 99 | trainer = trainer_cls() 100 | trainer.checkpoint_callback.dirpath = trainer.checkpoint_callback.dirpath + '/{}'.format(checkpoint_folder) 101 | base_dir = trainer.checkpoint_callback.dirpath 102 | 103 | #Find last checkpoint 104 | if from_checkpoint == 'last': 105 | ckpts = list(Path(base_dir).glob('*.ckpt')) 106 | if 'last' in [x.stem for x in ckpts]: 107 | from_checkpoint=Path(base_dir, 'last.ckpt') 108 | else: 109 | ckpt_epoch = [int(c.stem.split('epoch=')[-1].split('-')[0]) for c in ckpts] 110 | if len(ckpt_epoch) > 0: 111 | last_epoch = max(ckpt_epoch) 112 | from_checkpoint = ckpts[ckpt_epoch.index(last_epoch)] 113 | else: 114 | logger.info('No checkpoints found in {}. Training from scratch.'.format(base_dir)) 115 | from_checkpoint = None 116 | 117 | logger.info(torchinfo.summary(model)) 118 | 119 | if (from_checkpoint is not None) and (non_strict_checkpoint): 120 | ckpt_data = torch.load(from_checkpoint) 121 | model_sd = model.state_dict() 122 | ckpt_sd = {} 123 | for k,v in ckpt_data['state_dict'].items(): 124 | if (k in model_sd) and (v.shape == model_sd[k].shape): 125 | ckpt_sd[k] = v 126 | else: 127 | print("Couldn't load {} from checkpoint".format(k)) 128 | 129 | model.load_state_dict(ckpt_sd, strict=False) 130 | from_checkpoint=None 131 | 132 | trainer.fit(model, 133 | state[dataloaders_key]['train'], 134 | state[dataloaders_key]['validation'], 135 | ckpt_path=from_checkpoint) 136 | trainer.save_checkpoint(Path(base_dir,'last.ckpt')) 137 | state[key_out+'_checkpoint_dir'] = trainer.checkpoint_callback.dirpath 138 | best_model_path = model.trainer.checkpoint_callback.best_model_path 139 | if (best_model_path is not None) and (best_model_path != ''): 140 | model.load_state_dict(torch.load(best_model_path)['state_dict']) 141 | state[key_out] = model 142 | else: 143 | logger.info('Model is already in state. Skipping task.') 144 | 145 | return state 146 | -------------------------------------------------------------------------------- /encodecmae/tasks/data/__init__.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from loguru import logger 3 | from pathlib import Path 4 | import numpy as np 5 | from tqdm import tqdm 6 | import soundfile as sf 7 | from torch.utils.data import Dataset 8 | import copy 9 | import torch 10 | import re 11 | import joblib 12 | 13 | from typing import Dict, List, Callable, Any, Optional, Union, Type 14 | 15 | def compensate_lengths(df: pd.DataFrame, chunk_length: Optional[float] = None) -> List[int]: 16 | """ 17 | Compensates for varying lengths of elements in a DataFrame. 18 | 19 | Args: 20 | df (pandas.DataFrame): The DataFrame containing elements with varying lengths.
21 | chunk_length (float, optional): The length of each chunk in seconds. 22 | If provided, elements will be divided into chunks of approximately equal length. 23 | Defaults to None. 24 | 25 | Returns: 26 | list: A list of indices corresponding to elements in the DataFrame, accounting for varying lengths. 27 | 28 | Note: 29 | If chunk_length is not provided, each element is represented by a single index. 30 | If chunk_length is provided, elements are divided into chunks, and each chunk is represented by its element's index. 31 | """ 32 | if chunk_length is not None: 33 | map_idx = [] 34 | for i, (idx, row) in enumerate(df.iterrows()): 35 | map_idx.extend([i]*int(max(1,row['duration']//chunk_length))) 36 | return map_idx 37 | else: 38 | return list(range(len(df))) 39 | 40 | def dataset_random_split(df: pd.DataFrame, 41 | proportions: Dict[str, float] = {}) -> Dict[str, pd.DataFrame]: 42 | """ 43 | Splits a DataFrame into partitions randomly based on given proportions. 44 | 45 | Args: 46 | df (pandas.DataFrame): The DataFrame to be split. 47 | proportions (dict, optional): Dictionary containing proportions of split for each partition. 48 | If value is greater than 1, it's treated as the number of samples to include in the partition. 49 | If value is between 0 and 1, it's treated as the proportion of the DataFrame to include in the partition. 50 | If -1 is provided for any partition, remaining samples will be assigned to this partition. 51 | Defaults to an empty dictionary. 52 | 53 | Returns: 54 | dict: Dictionary containing partitions as DataFrames. 55 | 56 | Raises: 57 | Exception: If -1 is used in more than one entry in the proportions dictionary. 58 | """ 59 | idxs = df.index 60 | prop_type = [v for k,v in proportions.items() if v>1] 61 | if len(prop_type)>0: 62 | prop_type = 'n' 63 | else: 64 | prop_type = 'prop' 65 | remainder_k = [k for k,v in proportions.items() if v==-1] 66 | if len(remainder_k) > 1: 67 | raise Exception("-1 can't be used in more than one entry") 68 | elif len(remainder_k) == 1: 69 | remainder_k = remainder_k[0] 70 | else: 71 | remainder_k = None 72 | partitions = {} 73 | for k,v in proportions.items(): 74 | if k != remainder_k: 75 | if prop_type == 'prop': 76 | v = int(len(df)*v) 77 | sampled_idxs = np.random.choice(idxs, v, replace=False) 78 | idxs = [i for i in idxs if i not in sampled_idxs] 79 | partitions[k] = df.loc[sampled_idxs] 80 | if remainder_k is not None: 81 | partitions[remainder_k] = df.loc[idxs] 82 | return partitions 83 | 84 | class DictDataset(Dataset): 85 | """ 86 | Dataset class to handle data stored in a dictionary-like format. 87 | 88 | Args: 89 | metadata (pandas.DataFrame): DataFrame containing metadata of the dataset. 90 | state (dict): Dictionary containing additional state information. 91 | out_cols (list): List of columns to be included in the output. 92 | preprocessor (optional): Callable to apply to a dataframe row before returning the item. Defaults to None. 93 | index_mapper (callable, optional): A function to map indices of metadata. Defaults to None. 94 | state_keys (list, optional): List of keys from the state dictionary to be included in the dataset's state. Defaults to None. 
95 | """ 96 | def __init__(self, 97 | metadata: pd.DataFrame, 98 | state: Dict[str, Any], 99 | out_cols: List[str], 100 | preprocessor: Callable[[Any, Dict[str, Any]], Any] = None, 101 | index_mapper: Optional[Callable[[pd.DataFrame], List[int]]] = None, 102 | state_keys: Optional[List[str]] = None): 103 | 104 | self._metadata = metadata 105 | self._out_cols = out_cols 106 | self._state = {} 107 | self._state['metadata'] = metadata 108 | if 'classID' in state['dataset_metadata'].columns: 109 | if isinstance(state['dataset_metadata'].iloc[0]['classID'], np.ndarray): 110 | self._state['num_classes'] = len(state['dataset_metadata'].iloc[0]['classID']) 111 | else: 112 | self._state['num_classes'] = state['dataset_metadata']['classID'].max() + 1 113 | if state_keys is not None: 114 | for k in state_keys: 115 | if k in state: 116 | self._state[k] = state[k] 117 | self._preprocessor = preprocessor() 118 | if index_mapper is not None: 119 | self._idx_map = index_mapper(self._metadata) 120 | else: 121 | self._idx_map = list(range(len(self._metadata))) 122 | 123 | def __getitem__(self, idx): 124 | row = copy.deepcopy(self._metadata.iloc[self._idx_map[idx]]) 125 | if self._preprocessor is not None: 126 | row = self._preprocessor(row) 127 | out = {k: row[k] for k in self._out_cols} 128 | return out 129 | 130 | def __len__(self): 131 | return len(self._idx_map) 132 | 133 | def dynamic_pad_batch(x: Union[list, Dict[str, Any]]) -> Dict[str, torch.Tensor]: 134 | """ 135 | Dynamically pads a batch of sequences with variable lengths and converts them to PyTorch tensors. 136 | 137 | Args: 138 | x (Union[list, dict]): List or dictionary containing sequences to be padded. 139 | 140 | Returns: 141 | dict: Dictionary containing padded sequences converted to PyTorch tensors. 
142 | """ 143 | def not_discarded(x): 144 | if x is None: 145 | return False 146 | else: 147 | return not any([xi is None for xi in x.values()]) 148 | 149 | def get_len(x): 150 | if x.ndim == 0: 151 | return 1 152 | else: 153 | return x.shape[0] 154 | 155 | def pad_to_len(x, max_len): 156 | if x.ndim == 0: 157 | return x 158 | else: 159 | pad_spec = ((0,max_len-x.shape[0]),) + ((0,0),)*(x.ndim - 1) 160 | return np.pad(x,pad_spec) 161 | 162 | def to_torch(x): 163 | if isinstance(x, torch.Tensor): 164 | return x 165 | else: 166 | if x.dtype in [np.float64, np.float32, np.float16, 167 | np.complex64, np.complex128, 168 | np.int64, np.int32, np.int16, np.int8, 169 | np.uint8, np.bool]: 170 | 171 | return torch.from_numpy(x) 172 | else: 173 | return x 174 | 175 | x_ = x 176 | x = [xi for xi in x if not_discarded(xi)] 177 | 178 | batch = {k: [np.array(xi[k]) for xi in x] for k in x[0]} 179 | batch_lens = {k: [get_len(x) for x in batch[k]] for k in batch.keys()} 180 | batch_max_lens = {k: max(v) for k,v in batch_lens.items()} 181 | batch = {k: np.stack([pad_to_len(x, batch_max_lens[k]) for x in batch[k]]) for k in batch.keys()} 182 | batch_lens = {k+'_lens': np.array(v) for k,v in batch_lens.items()} 183 | batch.update(batch_lens) 184 | batch = {k: to_torch(v) for k,v in batch.items()} 185 | 186 | return batch 187 | 188 | def get_dataloaders(state: Dict[str, Any], 189 | split_function: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, 190 | dataset_cls: Optional[Type[Any]] = None, 191 | dataloader_cls: Optional[Type[Any]] = None, 192 | dataset_key_in: str = 'dataset_metadata', 193 | dataset_key_out: str = 'datasets', 194 | partitions_key_out: str = 'partitions', 195 | dataloaders_key_out: str = 'dataloaders') -> Dict[str, Any]: 196 | """ 197 | Constructs dataloaders from the given state and configurations. 198 | 199 | Args: 200 | state (dict): The dictionary representing the current state. 201 | split_function (callable, optional): A function to split the dataset. Defaults to None. 202 | dataset_cls (type, optional): The class to instantiate datasets. Defaults to None. 203 | dataloader_cls (type, optional): The class to instantiate dataloaders. Defaults to None. 204 | dataset_key_in (str, optional): The key in the state to get the dataset metadata. Defaults to 'dataset_metadata'. 205 | dataset_key_out (str, optional): The key to use for storing datasets in the state. Defaults to 'datasets'. 206 | partitions_key_out (str, optional): The key to use for storing dataset partitions in the state. Defaults to 'partitions'. 207 | dataloaders_key_out (str, optional): The key to use for storing dataloaders in the state. Defaults to 'dataloaders'. 208 | 209 | Returns: 210 | dict: The updated state dictionary with datasets, partitions, and dataloaders. 211 | 212 | Raises: 213 | TypeError: If split_function, dataset_cls, or dataloader_cls is not callable. 
214 | """ 215 | 216 | if split_function is not None: 217 | partitions = split_function(state[dataset_key_in]) 218 | else: 219 | partitions = {'train': state[dataset_key_in]} 220 | 221 | datasets = {k: dataset_cls[k](v, state) for k,v in partitions.items() if k in dataset_cls} 222 | dataloaders = {k: dataloader_cls[k](v) for k,v in datasets.items() if k in dataloader_cls} 223 | 224 | state[partitions_key_out] = partitions 225 | state[dataset_key_out] = datasets 226 | state[dataloaders_key_out] = dataloaders 227 | 228 | return state 229 | 230 | def load_dataset(state: Dict[str, Any], 231 | reader_fns: List[Callable[[], pd.DataFrame]], 232 | cache: bool = True, 233 | postprocessors: List[Callable[[pd.DataFrame], pd.DataFrame]] = [], 234 | key_out: str = 'dataset_metadata', 235 | rename: List[Dict[str, Any]] = None) -> Dict[str, Any]: 236 | """ 237 | Loads dataset into the state. 238 | 239 | Args: 240 | state (dict): The dictionary representing the current state. 241 | reader_fns (list): List of functions that return DataFrames when called. 242 | cache (bool, optional): Whether to cache the dataset in the state. Defaults to True. 243 | postprocessors (list, optional): List of functions to apply to the resulting dataframe. Each function should accept and return a DataFrame. Defaults to []. 244 | key_out (str, optional): The key to use for storing the dataset in the state. Defaults to 'dataset_metadata'. 245 | rename (list, optional): List of dictionaries specifying renaming rules. Each dictionary should contain keys 'column', 'value', and 'new_value' for renaming. Defaults to None. 246 | 247 | Returns: 248 | dict: The updated state dictionary with the loaded dataset. 249 | 250 | Raises: 251 | TypeError: If reader_fns is not a list. 252 | """ 253 | 254 | if not (cache and key_out in state): 255 | if not isinstance(reader_fns, list): 256 | raise TypeError("reader_fns must be a list of reader functions.") 257 | elif len(reader_fns) == 0: 258 | raise Exception("reader_fns is empty. Supply at least one reader function") 259 | dfs = [fn() for fn in reader_fns] 260 | df = pd.concat(dfs).reset_index() 261 | state[key_out] = df 262 | else: 263 | logger.info('Caching dataset metadata from state') 264 | 265 | for f in postprocessors: 266 | state[key_out] = f(state[key_out]) 267 | 268 | if rename is not None: 269 | for r in rename: 270 | state[key_out][r['column']] = state[key_out][r['column']].apply(lambda x: r['new_value'] if x == r['value'] else x) 271 | 272 | return state 273 | 274 | def read_audiodir(dataset_path: List[str], 275 | subsample: Optional[int] = None, 276 | dataset: Optional[str] = None, 277 | regex_groups: Optional[str] = None, 278 | filter_list: Optional[str] = None, 279 | partition_lists: Optional[Dict[str, Optional[str]]] = None, 280 | filter_mode: str = 'include') -> pd.DataFrame: 281 | """ 282 | Reads audio files from directories and generates metadata DataFrame. 283 | 284 | Args: 285 | dataset_path (list): List of paths to directories containing audio files. 286 | subsample (int, optional): Number of files to subsample. Defaults to None. 287 | dataset (str, optional): Name of the dataset. Defaults to None. 288 | regex_groups (str, optional): Regular expression to extract metadata from filenames. Defaults to None. 289 | filter_list (str, optional): Path to a file containing a list of filenames to filter. Defaults to None. 290 | partition_lists (dict, optional): Dictionary mapping partitions to filenames. Defaults to None. 
291 | filter_mode (str, optional): Filtering mode, either 'include' or 'discard'. Defaults to 'include'. 292 | 293 | Returns: 294 | pandas.DataFrame: Metadata DataFrame containing information about audio files. 295 | 296 | Raises: 297 | Exception: If an unrecognized filter mode is provided. 298 | """ 299 | 300 | if not isinstance(dataset_path, list): 301 | dataset_path = [dataset_path] 302 | 303 | all_files = [] 304 | for p in dataset_path: 305 | all_files_i = list(Path(p).rglob('*.wav')) + list(Path(p).rglob('*.flac')) 306 | all_files.extend(all_files_i) 307 | 308 | if filter_list is not None: 309 | with open(filter_list, 'r') as f: 310 | keep_values = set(f.read().splitlines()) 311 | n_slashes = len(next(iter(keep_values)).split('/')) - 1 312 | stem_to_f = {'/'.join(v.parts[-n_slashes-1:]): v for v in all_files} 313 | if filter_mode == 'include': 314 | all_files = [stem_to_f[k] for k in keep_values] 315 | elif filter_mode == 'discard': 316 | all_files = [v for k,v in stem_to_f.items() if k not in keep_values] 317 | else: 318 | raise Exception("Unrecognized filter_mode {}".format(filter_mode)) 319 | 320 | if subsample is not None: 321 | subsample_idx = np.random.choice(np.arange(len(all_files)),size=subsample,replace=False) 322 | all_files = np.array(all_files)[subsample_idx] 323 | 324 | rows = [] 325 | for f in tqdm(all_files): 326 | try: 327 | finfo = sf.info(f) 328 | metadata = {'filename': str(f.resolve()), 329 | 'sr': finfo.samplerate, 330 | 'channels': finfo.channels, 331 | 'frames': finfo.frames, 332 | 'duration': finfo.duration, 333 | 'rel_path': str(f.relative_to(p))} 334 | if regex_groups is not None: 335 | regex_data = re.match(regex_groups,str(f.relative_to(dataset_path[0]))).groupdict() 336 | metadata.update(regex_data) 337 | rows.append(metadata) 338 | except Exception as e: 339 | print(f'Failed reading {f}. {e}') 340 | df = pd.DataFrame(rows) 341 | 342 | if dataset is not None: 343 | df['dataset'] = dataset 344 | 345 | # df['rel_path'] = df['filename'].apply(lambda x: str(Path(x).relative_to(Path(dataset_path[0])))) 346 | 347 | if partition_lists is not None: 348 | remainder = None 349 | map_to_partitions={} 350 | for k,v in partition_lists.items(): 351 | if v is not None: 352 | list_path = Path(dataset_path[0],v) 353 | with open(list_path,'r') as f: 354 | list_files = f.read().splitlines() 355 | for l in list_files: 356 | map_to_partitions[str(l)] = k 357 | else: 358 | remainder = k 359 | 360 | df['partition'] = df['rel_path'].apply(lambda x: map_to_partitions[x] if x in map_to_partitions else remainder) 361 | df = df.drop('rel_path', axis=1) 362 | 363 | return df 364 | 365 | def read_st_dataset(dataset_path): 366 | df = pd.read_csv(Path(dataset_path, 'metadata_selftrain_dataset.csv'), names=['start','stop','filename']) 367 | df = df.reset_index() 368 | df = df.rename({'index':'filename_audio','filename':'filename_targets'},axis=1) 369 | return df 370 | 371 | def remove_long_audios(df: pd.DataFrame, limit: int = 10000) -> pd.DataFrame: 372 | """ 373 | Removes rows from a DataFrame where the duration of audio files exceeds a specified limit. 374 | 375 | Args: 376 | df (pandas.DataFrame): The DataFrame containing audio metadata. 377 | limit (int, optional): The maximum duration in milliseconds. Defaults to 10000. 378 | 379 | Returns: 380 | pandas.DataFrame: DataFrame with rows removed where the duration exceeds the limit. 
381 | """ 382 | df = df.loc[df['duration'] None: 13 | pass 14 | 15 | def __call__(self, x: Dict[str, Any]) -> Dict[str, Any]: 16 | """Process input and return output. 17 | 18 | Args: 19 | x (Dict[str, Any]): Input data. 20 | 21 | Returns: 22 | Dict[str, Any]: Processed data. 23 | """ 24 | raise NotImplementedError 25 | 26 | class SequentialProcessor(Processor): 27 | """Sequential processor that applies a list of processors sequentially.""" 28 | 29 | def __init__(self, processors: List[Processor]) -> None: 30 | """Initialize SequentialProcessor. 31 | 32 | Args: 33 | processors (List[Processor]): List of processors. 34 | """ 35 | self._processors = [p() for p in processors] 36 | 37 | def __call__(self, x: Dict[str, Any]) -> Dict[str, Any]: 38 | """Process input by applying each processor sequentially. 39 | 40 | Args: 41 | x (Dict[str, Any]): Input data. 42 | 43 | Returns: 44 | Dict[str, Any]: Processed data. 45 | """ 46 | for p in self._processors: 47 | x = p(x) 48 | return x 49 | 50 | class ReadAudioProcessor(Processor): 51 | """Processor to read audio files.""" 52 | 53 | def __init__(self, key_in: str, key_out: str, max_length: Union[float, None] = None, mono: bool = True) -> None: 54 | """Initialize ReadAudioProcessor. 55 | 56 | Args: 57 | key_in (str): Key for input audio. 58 | key_out (str): Key for output audio. 59 | max_length (Union[float, None], optional): Maximum length of audio in seconds. Defaults to None. 60 | mono (bool, optional): Whether to convert stereo audio to mono. Defaults to True. 61 | """ 62 | super().__init__() 63 | self.key_in, self.key_out, self.max_length, self.mono = key_in, key_out, max_length, mono 64 | 65 | def __call__(self, x: Dict[str, Any]) -> Dict[str, Any]: 66 | """Read audio file and process it. 67 | 68 | Args: 69 | x (Dict[str, Any]): Input data. 70 | 71 | Returns: 72 | Dict[str, Any]: Processed data. 73 | """ 74 | try: 75 | if self.max_length is not None: 76 | audio_info = sf.info(x[self.key_in]) 77 | desired_frames = int(self.max_length*audio_info.samplerate) 78 | total_frames = audio_info.frames 79 | if total_frames > desired_frames: 80 | start = random.randint(0,total_frames - desired_frames) 81 | stop = start + desired_frames 82 | else: 83 | start = 0 84 | stop = None 85 | else: 86 | start = 0 87 | stop = None 88 | if 'start' in x: 89 | start = x['start'] 90 | if 'stop' in x: 91 | stop = x['stop'] 92 | x['start'] = start 93 | x['stop'] = stop 94 | wav, fs = sf.read(x[self.key_in], start=start, stop=stop, dtype=np.float32) 95 | if (wav.ndim == 2) and self.mono: 96 | wav = np.mean(wav,axis=-1) 97 | except Exception as e: 98 | logger.warning('Failed reading {}'.format(x[self.key_in])) 99 | wav = None 100 | x[self.key_out] = wav 101 | return x 102 | 103 | class LoadNumpyProcessor(Processor): 104 | """Processor to load numpy arrays.""" 105 | 106 | def __init__(self, key_in: str, key_out: str) -> None: 107 | """Initialize LoadNumpyProcessor. 108 | 109 | Args: 110 | key_in (str): Key for input numpy array file. 111 | key_out (str): Key for output numpy array. 112 | """ 113 | super().__init__() 114 | self.key_in = key_in 115 | self.key_out = key_out 116 | 117 | def __call__(self, x: Dict[str, Any]) -> Dict[str, Any]: 118 | """Load numpy array and process it. 119 | 120 | Args: 121 | x (Dict[str, Any]): Input data. 122 | 123 | Returns: 124 | Dict[str, Any]: Processed data. 
125 | """ 126 | x[self.key_out] = np.load(x[self.key_in]) 127 | return x 128 | 129 | class MelspectrogramProcessor(Processor): 130 | """Processor to calculate melspectrograms from waveforms. 131 | Internally, torchaudio.compliance.kaldi.fbank is used and the same kwargs are accepted. 132 | Additionally, norm_stats can be supplied with [mean,std] to normalize the resulting melspectrogram. 133 | key_in and key_out are strings indicating the key containing the waveform 134 | and the key where the result will be stored.""" 135 | 136 | def __init__(self, key_in='wav', key_out='wav_features', 137 | frame_length=25, 138 | frame_shift=10, 139 | high_freq=0, 140 | htk_compat=False, 141 | low_freq=20, 142 | num_mel_bins=23, 143 | sample_frequency=16000, 144 | window_type='povey', 145 | dither=0.0, 146 | use_energy=False, norm_stats=[0,1]): 147 | super().__init__() 148 | self.mel_kwargs = dict(frame_length=frame_length, 149 | frame_shift=frame_shift, high_freq=high_freq, 150 | htk_compat=htk_compat, low_freq=low_freq, num_mel_bins=num_mel_bins, 151 | sample_frequency=sample_frequency, window_type=window_type, use_energy=use_energy) 152 | self.norm_stats = norm_stats 153 | self.key_in, self.key_out = key_in, key_out 154 | 155 | def __call__(self, x: Dict[str, Any]) -> Dict[str, Any]: 156 | """Calculate melspectrogram and normalize it. 157 | 158 | Args: 159 | x (Dict[str, Any]): Input data dictionary containing waveform. 160 | 161 | Returns: 162 | Dict[str, Any]: Output data dictionary containing melspectrogram. 163 | 164 | """ 165 | if x[self.key_in] is not None: 166 | mel = torchaudio.compliance.kaldi.fbank(torch.from_numpy(x[self.key_in]).unsqueeze(0), **self.mel_kwargs).numpy() 167 | mel = (mel-self.norm_stats[0])/self.norm_stats[1] 168 | else: 169 | mel = None 170 | x[self.key_out] = mel 171 | return x 172 | 173 | class SpectrogramProcessor(Processor): 174 | """Processor to calculate spectrograms from waveforms. 175 | Internally, torchaudio.compliance.kaldi.spectrogram is used and the same kwargs are accepted. 176 | Additionally, norm_stats can be supplied with [mean,std] to normalize the resulting spectrogram. 177 | key_in and key_out are strings indicating the key containing the waveform 178 | and the key where the result will be stored.""" 179 | def __init__(self, key_in='wav', key_out='wav_features', 180 | frame_length=25, 181 | frame_shift=10, 182 | sample_frequency=16000, 183 | window_type='povey', 184 | dither=0.0, 185 | norm_stats=[0,1]): 186 | self.spec_kwargs = dict(frame_length=frame_length, 187 | frame_shift=frame_shift, 188 | sample_frequency=sample_frequency, 189 | window_type=window_type, 190 | dither=dither) 191 | self.norm_stats = norm_stats 192 | self.key_in, self.key_out = key_in, key_out 193 | 194 | def __call__(self, x: Dict[str, Any]) -> Dict[str, Any]: 195 | """Calculate spectrogram and normalize it. 196 | 197 | Args: 198 | x (Dict[str, Any]): Input data dictionary containing waveform. 199 | 200 | Returns: 201 | Dict[str, Any]: Output data dictionary containing spectrogram. 
202 | 203 | """ 204 | spec = torchaudio.compliance.kaldi.spectrogram(torch.from_numpy(x[self.key_in]).unsqueeze(0), **self.spec_kwargs).numpy() 205 | spec = (spec-self.norm_stats[0])/self.norm_stats[1] 206 | x[self.key_out] = spec 207 | return x 208 | -------------------------------------------------------------------------------- /encodecmae/tasks/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | 5 | def set_seed(state, seed=42): 6 | random.seed(seed) 7 | np.random.seed(seed) 8 | torch.manual_seed(seed) 9 | state.seed = seed 10 | return state -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | huggingface_hub 2 | encodec==0.1.1 3 | ginpipe==0.0.2 4 | librosa>=0.10.1 5 | numpy>=1.24.1 6 | pytorch_lightning>=2.0.0 7 | torch>=2.0.0 8 | torchinfo 9 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = encodecmae 3 | version = 0.0.1 4 | author = Leonardo Pepino 5 | author_email = lpepino@dc.uba.ar 6 | description = EnCodecMAE 7 | long_description = file: README.md 8 | long_description_content_type = text/markdown 9 | classifiers = 10 | Programming Language :: Python :: 3 11 | Operating System :: OS Independent 12 | [options] 13 | packages = encodecmae 14 | python_requires = >=3.6 15 | install_requires = 16 | huggingface_hub 17 | encodec==0.1.1 18 | torch>=2.0.0 19 | pytorch_lightning>=2.0.0 20 | ginpipe 21 | librosa 22 | torchinfo 23 | pandas 24 | 25 | 26 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | setup() -------------------------------------------------------------------------------- /start_docker.sh: -------------------------------------------------------------------------------- 1 | cd encodecmae 2 | docker build -t encodecmae:latest . 3 | docker compose up -d 4 | docker attach encodecmae-train --------------------------------------------------------------------------------
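As a usage illustration of the data pipeline above, the sketch below feeds two variable-length items through `dynamic_pad_batch` from `encodecmae/tasks/data/__init__.py`. The dictionaries, key names, and array sizes here are made up for the example; the function pads each key to the longest item in the batch, adds a `<key>_lens` entry with the original lengths, and converts numeric arrays to PyTorch tensors, which is what makes it usable as a `collate_fn` for a `torch.utils.data.DataLoader`.

```python
# Minimal sketch (assumes this repository is installed as the `encodecmae` package).
import numpy as np
from encodecmae.tasks.data import dynamic_pad_batch

# Two items with waveforms of different lengths, as DictDataset.__getitem__ would yield them.
items = [
    {'wav': np.random.randn(24000).astype(np.float32), 'label': 0},
    {'wav': np.random.randn(16000).astype(np.float32), 'label': 1},
]

batch = dynamic_pad_batch(items)
print(batch['wav'].shape)   # torch.Size([2, 24000]) -> padded to the longest waveform
print(batch['wav_lens'])    # tensor([24000, 16000]) -> original length of each item
print(batch['label'])       # tensor([0, 1])
```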