├── .gitignore
├── README.md
├── encodecmae
│   ├── Dockerfile
│   ├── __init__.py
│   ├── configs
│   │   ├── base
│   │   │   ├── create_st_dataset.gin
│   │   │   ├── self-train.gin
│   │   │   └── unsupervised_pretrain.gin
│   │   ├── datasets
│   │   │   ├── audioset-unbalanced-24k.gin
│   │   │   ├── fma-large-24k.gin
│   │   │   └── librilight-6k-24k.gin
│   │   ├── features
│   │   │   ├── mel.gin
│   │   │   ├── spec.gin
│   │   │   ├── st.gin
│   │   │   └── wav_only.gin
│   │   ├── imports
│   │   └── models
│   │       └── encodecmae.gin
│   ├── docker-compose.yml
│   ├── heareval_model
│   │   ├── __init__.py
│   │   └── encodecmae.py
│   ├── hub.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── encoders
│   │   │   └── __init__.py
│   │   ├── heads
│   │   │   └── __init__.py
│   │   ├── losses
│   │   │   └── __init__.py
│   │   ├── mae.py
│   │   ├── masks
│   │   │   └── __init__.py
│   │   ├── targets
│   │   │   └── __init__.py
│   │   └── transformers
│   │       └── __init__.py
│   ├── scripts
│   │   └── run_pretraining.sh
│   └── tasks
│       ├── __init__.py
│       ├── data
│       │   └── __init__.py
│       ├── features
│       │   └── __init__.py
│       └── utils
│           └── __init__.py
├── requirements.txt
├── setup.cfg
├── setup.py
└── start_docker.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # EnCodecMAE: Leveraging neural codecs for universal audio representation learning
2 |
13 | This is EnCodecMAE, an audio feature extractor pretrained with masked language modelling to predict discrete targets generated by EnCodec, a neural audio codec.
14 | For more details about the architecture and pretraining procedure, read the [paper](https://arxiv.org/abs/2309.07391).
15 |
16 | ## Updates:
17 | - 2024/5/23 Updated paper on arXiv. New models with better performance across all downstream tasks are available for feature extraction. Code for the older version is [here](https://github.com/habla-liaa/encodecmae/tree/v.1.0.0)
18 | - 2024/2/29 [New code](https://github.com/mrpep/encodecmae-to-wav) to go from encodecmae to the waveform domain, with pretrained generative audio models from [this paper](https://arxiv.org/abs/2402.09318).
19 | - 2024/2/14 [Leveraging Pre-Trained Autoencoders for Interpretable Prototype Learning of Music Audio](https://arxiv.org/abs/2402.09318) was accepted to ICASSP 2024 XAI Workshop.
20 | - 2023/10/23 [Prompting for audio generation](https://mrpep.github.io/myblog/posts/audio-lm/).
21 |
22 | ## Usage
23 |
24 | ### Feature extraction using pretrained models
25 |
26 | #### Try our example [Colab notebook](https://colab.research.google.com/drive/123Zn6h0DRVcjsLFp8Xl4j0PZlZ-7VsK2?usp=sharing) or
27 |
28 | #### 1) Clone the [EnCodecMAE library](https://github.com/habla-liaa/encodecmae):
29 | ```
30 | git clone https://github.com/habla-liaa/encodecmae.git
31 | ```
32 |
33 | #### 2) Install it:
34 |
35 | ```
36 | cd encodecmae
37 | pip install -e .
38 | ```
39 |
40 | #### 3) Extract embeddings in Python:
41 |
42 | ``` python
43 | from encodecmae import load_model
44 |
45 | model = load_model('mel256-ec-base_st', device='cuda:0')
46 | features = model.extract_features_from_file('gsc/bed/00176480_nohash_0.wav')
47 | ```
48 |
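The loaded model also exposes array-based extraction and layer selection (see `extract_features_from_array` in `encodecmae/models/mae.py`). A minimal sketch, reusing the example file from above and assuming librosa is installed:

``` python
import librosa

from encodecmae import load_model
from encodecmae.hub import get_available_models

print(get_available_models())  # checkpoints hosted on the Hugging Face Hub

model = load_model('mel256-ec-base_st', device='cuda:0')
wav, fs = librosa.load('gsc/bed/00176480_nohash_0.wav', sr=24000)
feats = model.extract_features_from_array(wav)                   # numpy array, ~75 frames per second
all_layers = model.extract_features_from_array(wav, layer='all') # activations from every encoder layer
```
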
49 | ### Pretrain your models
50 |
51 | #### 1) Install docker and docker-compose on your system. You'll also need the NVIDIA Container Toolkit to access GPUs from a docker container.
52 | #### 2) Execute the start_docker.sh script
53 |
54 | First, modify docker-compose.yml: in the volumes section, change the paths to the ones in your system. You'll need a folder called datasets with the following subfolders:
55 | - audioset_24k/unbalanced_train
56 | - fma_large_24k
57 | - librilight_med_24k
58 |
59 | All the audio files need to be converted to a 24kHz sampling rate.
60 |
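If your local copies of the datasets are at a different sampling rate, here is a minimal resampling sketch (the source path and glob pattern are placeholders for your setup; assumes librosa and soundfile are installed):

``` python
from pathlib import Path

import librosa
import soundfile as sf

src = Path('/data/fma_large')                  # original dataset (placeholder path)
dst = Path('/workspace/datasets/fma_large_24k')

for f in src.rglob('*.mp3'):                   # adjust the pattern to your source format
    wav, _ = librosa.load(f, sr=24000, mono=True)   # resample to 24 kHz mono
    out = (dst / f.relative_to(src)).with_suffix('.wav')
    out.parent.mkdir(parents=True, exist_ok=True)
    sf.write(out, wav, 24000)
```
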
61 | You might also need to modify the device_ids if you have a different number of GPUs.
62 |
63 | Then, run:
64 | ```
65 | chmod +x start_docker.sh
66 | ./start_docker.sh
67 | ```
68 | This will build the encodecmae image, start a container using docker compose, and attach to it.
69 |
70 | #### 3) Install the encodecmae package inside the container
71 | ```
72 | cd workspace/encodecmae
73 | pip install -e .
74 | ```
75 |
76 | #### 4) Run the training script
77 | ```
78 | chmod +x scripts/run_pretraining.sh
79 | scripts/run_pretraining.sh
80 | ```
81 |
82 | The training script uses ginpipe, my own library for executing pipelines configured with gin. By modifying the config files (those with a .gin extension), you can control aspects of the training and of the model configuration. I plan to explain my approach to ML pipelines, and how to use gin and ginpipe, in a future blog article. Stay tuned!
83 |
--------------------------------------------------------------------------------
/encodecmae/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu20.04
2 |
3 | ENV TZ=America/Argentina/Buenos_Aires
4 | RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
5 |
6 | COPY requirements.txt .
7 | RUN apt-get --fix-missing update && apt-get update && apt-get install -y \
8 | python3.9 \
9 | python3-pip \
10 | git \
11 | sox
12 | RUN pip3 install -r requirements.txt
13 |
14 |
--------------------------------------------------------------------------------
/encodecmae/__init__.py:
--------------------------------------------------------------------------------
1 | from .hub import load_model
--------------------------------------------------------------------------------
/encodecmae/configs/base/create_st_dataset.gin:
--------------------------------------------------------------------------------
1 | DEVICE='cuda:0'
2 | TRAIN_BATCH_SIZE=32
3 | TRAIN_DATALOADER_NUM_WORKERS=8
4 | MAX_AUDIO_DURATION=4
5 | FILTER_AUDIO_LENGTH=10000
6 | TEACHER_LAYER=-1
7 | POSTNORM_LAST_ACTIVATION=True
8 |
9 | $keys_not_saved+=['datasets','dataloaders']
10 |
11 | execute_pipeline:
12 | tasks = [@encodecmae.tasks.utils.set_seed,
13 | @encodecmae.tasks.load_model,
14 | @encodecmae.tasks.data.load_dataset,
15 | @encodecmae.tasks.data.get_dataloaders,
16 | @encodecmae.tasks.create_st_dataset]
17 |
18 | encodecmae.tasks.load_model:
19 | model = %TEACHER_MODEL
20 |
21 | encodecmae.tasks.data.get_dataloaders.dataset_cls={'train': @train/encodecmae.tasks.data.DictDataset}
22 | encodecmae.tasks.data.get_dataloaders.dataloader_cls={'train': @train/torch.utils.data.DataLoader}
23 |
24 | train/torch.utils.data.DataLoader:
25 | shuffle=True
26 | batch_size=%TRAIN_BATCH_SIZE
27 | num_workers=%TRAIN_DATALOADER_NUM_WORKERS
28 | collate_fn=@encodecmae.tasks.data.dynamic_pad_batch
29 |
30 | encodecmae.tasks.data.DictDataset.index_mapper=@encodecmae.tasks.data.compensate_lengths
31 | encodecmae.tasks.data.compensate_lengths.chunk_length=%MAX_AUDIO_DURATION #This will sample long audios multiple times during one epoch (duration//compensate_framing times)
32 |
33 | encodecmae.tasks.data.load_dataset.postprocessors=[@encodecmae.tasks.data.remove_long_audios]
34 | encodecmae.tasks.data.remove_long_audios.limit=%FILTER_AUDIO_LENGTH
35 |
36 | encodecmae.tasks.create_st_dataset:
37 | device=%DEVICE
38 | layer=%TEACHER_LAYER
39 | postnorm_last_activation=%POSTNORM_LAST_ACTIVATION
--------------------------------------------------------------------------------
/encodecmae/configs/base/self-train.gin:
--------------------------------------------------------------------------------
1 | SEED=42
2 | TRAIN_BATCH_SIZE=128
3 | VAL_BATCH_SIZE=8
4 | TRAIN_DATALOADER_NUM_WORKERS=8
5 | VAL_DATALOADER_NUM_WORKERS=4
6 | MAX_AUDIO_DURATION=4
7 | MAX_LR=0.0001
8 | GRAD_ACC=1
9 | TOTAL_PRETRAIN_STEPS=150000
10 | CHECKPOINT_INTERVAL=50000
11 | SELFTRAIN_CHECKPOINT='last'
12 | VAL_SET_SIZE=200
13 | FILTER_AUDIO_LENGTH=10000 #Some files might be too long according to mediainfo so they are discarded.
14 | DEVICE=[0]
15 | PRECISION=16
16 | NUM_TOTAL_TARGETS=1
17 |
18 | $keys_not_saved=['datasets','dataloaders']
19 |
20 | execute_pipeline:
21 | tasks = [@encodecmae.tasks.utils.set_seed,
22 | @encodecmae.tasks.data.load_dataset,
23 | @encodecmae.tasks.data.get_dataloaders,
24 | @encodecmae.tasks.fit_model]
25 | execution_order = 'sequential'
26 |
27 | train/torch.utils.data.DataLoader:
28 | shuffle=True
29 | batch_size=%TRAIN_BATCH_SIZE
30 | num_workers=%TRAIN_DATALOADER_NUM_WORKERS
31 | collate_fn=@encodecmae.tasks.data.dynamic_pad_batch
32 |
33 | val/torch.utils.data.DataLoader:
34 | shuffle=False
35 | batch_size=%VAL_BATCH_SIZE
36 | num_workers=%VAL_DATALOADER_NUM_WORKERS
37 | collate_fn=@encodecmae.tasks.data.dynamic_pad_batch
38 |
39 | encodecmae.tasks.data.get_dataloaders.split_function=@encodecmae.tasks.data.dataset_random_split
40 | encodecmae.tasks.data.get_dataloaders.dataset_cls={'train': @train/encodecmae.tasks.data.DictDataset, 'validation': @val/encodecmae.tasks.data.DictDataset}
41 | encodecmae.tasks.data.get_dataloaders.dataloader_cls={'train': @train/torch.utils.data.DataLoader, 'validation': @val/torch.utils.data.DataLoader}
42 |
43 | encodecmae.tasks.data.load_dataset.reader_fns+=[@encodecmae.tasks.data.read_st_dataset]
44 | encodecmae.tasks.data.read_st_dataset.dataset_path=%ST_DATASET_DIR
45 |
46 | encodecmae.tasks.fit_model:
47 | trainer_cls=@pl.Trainer
48 | from_checkpoint=%ST_CHECKPOINT
49 | checkpoint_folder='pretrain_checkpoints'
50 |
51 | pl.Trainer:
52 | logger=@pl.loggers.CSVLogger()
53 | devices=%DEVICE
54 | callbacks=[@pl.callbacks.ModelCheckpoint(), @pl.callbacks.LearningRateMonitor()]
55 | max_steps=%TOTAL_PRETRAIN_STEPS
56 | accelerator='gpu'
57 | accumulate_grad_batches=%GRAD_ACC
58 | num_sanity_val_steps=1
59 | val_check_interval=%CHECKPOINT_INTERVAL
60 | precision=%PRECISION
61 | check_val_every_n_epoch=None
62 |
63 | pl.callbacks.ModelCheckpoint:
64 | dirpath=%OUTPUT_DIR
65 | every_n_train_steps=%CHECKPOINT_INTERVAL
66 | save_top_k=-1 #Keep all the checkpoints
67 |
68 | pl.loggers.CSVLogger:
69 | save_dir=%OUTPUT_DIR
70 | name='pretrain_logs'
71 |
72 | encodecmae.tasks.data.dataset_random_split:
73 | proportions={'train':-1,'validation':%VAL_SET_SIZE}
--------------------------------------------------------------------------------
/encodecmae/configs/base/unsupervised_pretrain.gin:
--------------------------------------------------------------------------------
1 | SEED=42
2 | TRAIN_BATCH_SIZE=128
3 | VAL_BATCH_SIZE=8
4 | TRAIN_DATALOADER_NUM_WORKERS=8
5 | VAL_DATALOADER_NUM_WORKERS=4
6 | MAX_AUDIO_DURATION=4
7 | MAX_LR=0.0001
8 | GRAD_ACC=1
9 | TOTAL_PRETRAIN_STEPS=500000
10 | CHECKPOINT_INTERVAL=50000
11 | INITIAL_CHECKPOINT='last'
12 | VAL_SET_SIZE=200
13 | FILTER_AUDIO_LENGTH=10000 #Some files might be too long according to mediainfo so they are discarded.
14 | DEVICE=[0]
15 | PRECISION=16
16 |
17 | $keys_not_saved=['datasets','dataloaders']
18 |
19 | execute_pipeline:
20 | tasks = [@tasks.utils.set_seed,
21 | @tasks.data.load_dataset,
22 | @tasks.data.get_dataloaders,
23 | @tasks.fit_model]
24 | execution_order = 'sequential'
25 |
26 | tasks.utils.set_seed.seed=%SEED
27 | train/torch.utils.data.DataLoader:
28 | shuffle=True
29 | batch_size=%TRAIN_BATCH_SIZE
30 | num_workers=%TRAIN_DATALOADER_NUM_WORKERS
31 | collate_fn=@tasks.data.dynamic_pad_batch
32 |
33 | val/torch.utils.data.DataLoader:
34 | shuffle=False
35 | batch_size=%VAL_BATCH_SIZE
36 | num_workers=%VAL_DATALOADER_NUM_WORKERS
37 | collate_fn=@tasks.data.dynamic_pad_batch
38 |
39 | tasks.fit_model:
40 | trainer_cls=@pl.Trainer
41 | from_checkpoint=%INITIAL_CHECKPOINT
42 | checkpoint_folder='pretrain_checkpoints'
43 |
44 | pl.Trainer:
45 | logger=@pl.loggers.CSVLogger()
46 | devices=%DEVICE
47 | callbacks=[@pl.callbacks.ModelCheckpoint(), @pl.callbacks.LearningRateMonitor()]
48 | max_steps=%TOTAL_PRETRAIN_STEPS
49 | accelerator='gpu'
50 | accumulate_grad_batches=%GRAD_ACC
51 | num_sanity_val_steps=1
52 | val_check_interval=%CHECKPOINT_INTERVAL
53 | precision=%PRECISION
54 | check_val_every_n_epoch=None
55 |
56 | pl.callbacks.ModelCheckpoint:
57 | dirpath=%OUTPUT_DIR
58 | every_n_train_steps=%CHECKPOINT_INTERVAL
59 | save_top_k=-1 #Keep all the checkpoints
60 |
61 | pl.loggers.CSVLogger:
62 | save_dir=%OUTPUT_DIR
63 | name='pretrain_logs'
64 |
65 | tasks.data.get_dataloaders.split_function=@tasks.data.dataset_random_split
66 | tasks.data.get_dataloaders.dataset_cls={'train': @train/tasks.data.DictDataset, 'validation': @val/tasks.data.DictDataset}
67 | tasks.data.get_dataloaders.dataloader_cls={'train': @train/torch.utils.data.DataLoader, 'validation': @val/torch.utils.data.DataLoader}
68 |
69 | tasks.data.dataset_random_split:
70 | proportions={'train':-1,'validation':%VAL_SET_SIZE}
71 |
72 | tasks.data.DictDataset.index_mapper=@tasks.data.compensate_lengths
73 | tasks.data.compensate_lengths.chunk_length=%MAX_AUDIO_DURATION #This will sample long audios multiple times during one epoch (duration//compensate_framing times)
74 |
75 | tasks.data.load_dataset.postprocessors=[@tasks.data.remove_long_audios]
76 | tasks.data.remove_long_audios.limit=%FILTER_AUDIO_LENGTH
77 |
--------------------------------------------------------------------------------
/encodecmae/configs/datasets/audioset-unbalanced-24k.gin:
--------------------------------------------------------------------------------
1 | load_dataset.reader_fns+=[@audioset_unbalanced_24k/read_audiodir]
2 | audioset_unbalanced_24k/read_audiodir.dataset_path='/workspace/datasets/audioset_24k/unbalanced_train'
3 | audioset_unbalanced_24k/read_audiodir.dataset='audioset'
--------------------------------------------------------------------------------
/encodecmae/configs/datasets/fma-large-24k.gin:
--------------------------------------------------------------------------------
1 | load_dataset.reader_fns+=[@fma_large_24k/read_audiodir]
2 | fma_large_24k/read_audiodir.dataset_path='/workspace/datasets/fma_large_24k'
3 | fma_large_24k/read_audiodir.dataset='fma'
4 |
--------------------------------------------------------------------------------
/encodecmae/configs/datasets/librilight-6k-24k.gin:
--------------------------------------------------------------------------------
1 | load_dataset.reader_fns+=[@librilight_6k_24k/read_audiodir]
2 | librilight_6k_24k/read_audiodir.dataset_path='/workspace/datasets/librilight_med_24k'
3 | librilight_6k_24k/read_audiodir.dataset='librilight'
--------------------------------------------------------------------------------
/encodecmae/configs/features/mel.gin:
--------------------------------------------------------------------------------
1 | NUM_MEL_BINS=256
2 |
3 | encodecmae.tasks.data.DictDataset:
4 | out_cols=['wav','wav_features']
5 | preprocessor=@encodecmae.tasks.features.SequentialProcessor
6 | #Processor:
7 | encodecmae.tasks.features.SequentialProcessor:
8 | processors=[@encodecmae.tasks.features.ReadAudioProcessor, @encodecmae.tasks.features.MelspectrogramProcessor]
9 | encodecmae.tasks.features.ReadAudioProcessor:
10 | key_in = 'filename'
11 | key_out = 'wav'
12 | max_length = %MAX_AUDIO_DURATION
13 | encodecmae.tasks.features.MelspectrogramProcessor:
14 | key_in = 'wav'
15 | key_out = 'wav_features'
16 | sample_frequency=24000
17 | frame_shift=13.28
18 | frame_length=26.56
19 | htk_compat=True
20 | use_energy=False
21 | window_type='hanning'
22 | num_mel_bins=%NUM_MEL_BINS
23 | dither=0.0
24 | norm_stats=[-6.12, 4.82]
25 |
26 | #Wav encoder:
27 | encodecmae.models.encoders.WavEncoder:
28 | encoder = @torch.nn.Identity
29 | post_net = @wav_encoder_proj/torch.nn.Linear
30 | hop_length = 320
31 | fs = 24000
32 | key_in = 'wav_features'
33 | key_out = 'wav_features'
34 | wav_encoder_proj/torch.nn.Linear:
35 | in_features = %NUM_MEL_BINS
36 | out_features = %MODEL_DIM
37 |
38 | #Target:
39 | encodecmae.models.targets.EncodecQuantizer:
40 | key_in = 'wav'
41 | use_encodec_encoder = True
42 |
43 |
44 |
45 |
--------------------------------------------------------------------------------
/encodecmae/configs/features/spec.gin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/habla-liaa/encodecmae/2dd6759c595bf20c9d836914fef5cde9a6386843/encodecmae/configs/features/spec.gin
--------------------------------------------------------------------------------
/encodecmae/configs/features/st.gin:
--------------------------------------------------------------------------------
1 | NUM_TOTAL_TARGETS=1
2 |
3 | encodecmae.tasks.data.DictDataset.out_cols+=['targets']
4 |
5 | encodecmae.tasks.features.SequentialProcessor.processors+=[@encodecmae.tasks.features.LoadNumpyProcessor]
6 | encodecmae.tasks.features.ReadAudioProcessor.key_in = 'filename_audio'
7 | encodecmae.tasks.features.LoadNumpyProcessor:
8 | key_in = 'filename_targets'
9 | key_out = 'targets'
10 |
11 | encodecmae.models.EncodecMAE.target_encoder=@encodecmae.models.targets.IdentityTarget
12 | encodecmae.models.targets.IdentityTarget:
13 | key_in = 'targets'
14 | key_out = 'targets'
15 | encodecmae.models.losses.EnCodecMAEClassificationLoss.quantizer_weights=[1.]
--------------------------------------------------------------------------------
/encodecmae/configs/features/wav_only.gin:
--------------------------------------------------------------------------------
1 | encodecmae.tasks.data.DictDataset:
2 | out_cols=['wav']
3 | preprocessor=@encodecmae.tasks.features.SequentialProcessor
4 |
5 | encodecmae.tasks.features.SequentialProcessor:
6 | processors=[@encodecmae.tasks.features.ReadAudioProcessor]
7 |
8 | encodecmae.tasks.features.ReadAudioProcessor:
9 | key_in = 'filename'
10 | key_out = 'wav'
11 | max_length = %MAX_AUDIO_DURATION
--------------------------------------------------------------------------------
/encodecmae/configs/imports:
--------------------------------------------------------------------------------
1 | encodecmae.models.encoders: encodecmae.models.encoders
2 | encodecmae.models.heads: encodecmae.models.heads
3 | encodecmae.models.losses: encodecmae.models.losses
4 | encodecmae.models.masks: encodecmae.models.masks
5 | encodecmae.models.targets: encodecmae.models.targets
6 | encodecmae.models.transformers: encodecmae.models.transformers
7 | encodecmae.models: encodecmae.models
8 | encodecmae.tasks: encodecmae.tasks
9 | encodecmae.tasks.utils: encodecmae.tasks.utils
10 | encodecmae.tasks.data: encodecmae.tasks.data
11 | encodecmae.tasks.features: encodecmae.tasks.features
12 | pytorch_lightning: pl
13 | pytorch_lightning.Trainer: pl.Trainer
14 | pytorch_lightning.loggers: pl.loggers
15 | pytorch_lightning.callbacks: pl.callbacks
16 | torch.utils.data: torch.utils.data
17 | torch.optim: torch.optim
18 | torchmetrics: torchmetrics
19 | torch.nn: torch.nn
--------------------------------------------------------------------------------
/encodecmae/configs/models/encodecmae.gin:
--------------------------------------------------------------------------------
1 | NUM_ENCODEC_TARGETS=8
2 | NUM_TOTAL_TARGETS=8
3 | NUM_TARGET_TOKENS=1024
4 | MASK_GAP_SIZE=15
5 | MASK_PROP=0.5
6 | MODEL_DIM=768
7 | NUM_ENCODER_LAYERS=10
8 | NUM_ENCODER_HEADS=12
9 | NUM_DECODER_LAYERS=2
10 | NUM_DECODER_HEADS=12
11 | MASKED_LOSS_WEIGHT=0.9
12 | WAV_FEATURE_DIM=128
13 | QUANTIZER_WEIGHTS=[0.22407463, 0.1759858 , 0.14499009, 0.12150037, 0.10315603, 0.08831368, 0.07608274, 0.06589669]
14 | RETURN_ONLY_LAST_Q=False
15 |
16 | #Global settings:
17 | encodecmae.tasks.fit_model.model_cls=@encodecmae.models.EncodecMAE
18 | encodecmae.models.EncodecMAE:
19 | wav_encoder = @encodecmae.models.encoders.WavEncoder
20 | target_encoder = @encodecmae.models.targets.EncodecQuantizer
21 | masker = @encodecmae.models.masks.PatchoutMask
22 | visible_encoder = @encoder/encodecmae.models.transformers.TransformerEncoder
23 | decoder = @decoder/encodecmae.models.transformers.TransformerEncoder
24 | head = @encodecmae.models.heads.FrameLevelClassificationHead
25 | loss = @encodecmae.models.losses.EnCodecMAEClassificationLoss
26 | optimizer=@torch.optim.AdamW
27 |
28 | #Wav encoder:
29 | encodecmae.models.encoders.WavEncoder:
30 | encoder = @encodecmae.models.encoders.EncodecEncoder
31 | post_net = @wav_encoder_proj/torch.nn.Linear
32 | wav_encoder_proj/torch.nn.Linear:
33 | in_features = %WAV_FEATURE_DIM
34 | out_features = %MODEL_DIM
35 |
36 | #Masking:
37 | encodecmae.models.masks.PatchoutMask:
38 | masker = @encodecmae.models.masks.TimeGapMask
39 | positional_encoder = @encodecmae.models.transformers.SinusoidalPositionalEmbeddings
40 | encodecmae.models.masks.TimeGapMask:
41 | gap_size = %MASK_GAP_SIZE
42 | p_mask = %MASK_PROP
43 |
44 | #Visible encoder:
45 | encoder/encodecmae.models.transformers.TransformerEncoder:
46 | model_dim=%MODEL_DIM
47 | num_layers=%NUM_ENCODER_LAYERS
48 | attention_layer=@encoder/encodecmae.models.transformers.MultiHeadAttention
49 | compile=False
50 | key_in='visible_tokens'
51 | key_padding_mask='visible_padding_mask'
52 | key_out='decoder_in'
53 | key_transformer_in=None
54 | key_transformer_out='visible_embeddings'
55 | post_net=@decoder_proj/torch.nn.Linear
56 | encoder/encodecmae.models.transformers.MultiHeadAttention:
57 | model_dim=%MODEL_DIM
58 | num_heads=%NUM_ENCODER_HEADS
59 |
60 | #Decoder:
61 | decoder_proj/torch.nn.Linear:
62 | in_features=%MODEL_DIM
63 | out_features=%MODEL_DIM
64 | decoder/encodecmae.models.transformers.TransformerEncoder:
65 | model_dim=%MODEL_DIM
66 | num_layers=%NUM_DECODER_LAYERS
67 | attention_layer=@decoder/encodecmae.models.transformers.MultiHeadAttention
68 | compile=False
69 | key_in='decoder_in'
70 | key_padding_mask='feature_padding_mask'
71 | key_out='decoder_out'
72 | positional_encoder = @encodecmae.models.transformers.SinusoidalPositionalEmbeddings
73 | decoder/encodecmae.models.transformers.MultiHeadAttention:
74 | model_dim=%MODEL_DIM
75 | num_heads=%NUM_DECODER_HEADS
76 |
77 | encodecmae.models.transformers.SinusoidalPositionalEmbeddings.embedding_dim = %MODEL_DIM
78 |
79 | #Head:
80 | encodecmae.models.heads.FrameLevelClassificationHead:
81 | model_dim=%MODEL_DIM
82 | num_tokens=%NUM_TARGET_TOKENS
83 | num_streams=%NUM_TOTAL_TARGETS
84 | #Target:
85 | encodecmae.models.targets.EncodecQuantizer:
86 | n = %NUM_ENCODEC_TARGETS
87 | key_in = 'wav_features_encoder_out'
88 | return_only_last = %RETURN_ONLY_LAST_Q
89 | #Loss:
90 | encodecmae.models.losses.EnCodecMAEClassificationLoss:
91 | masked_weight=%MASKED_LOSS_WEIGHT
92 | quantizer_weights=%QUANTIZER_WEIGHTS
93 | #Optimizer:
94 | torch.optim.AdamW:
95 | lr=%MAX_LR
96 | betas=(0.9,0.95)
97 | weight_decay=0.05
--------------------------------------------------------------------------------
/encodecmae/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: "3.9"
2 | services:
3 | encodecmae:
4 | image: encodecmae
5 | container_name: encodecmae-train
6 | volumes:
7 | - /home/lpepino/encodecmae:/workspace/encodecmae
8 | - /mnt/ssd4T/datasets:/workspace/datasets
9 | ipc: host
10 | stdin_open: true
11 | tty: true
12 | deploy:
13 | resources:
14 | reservations:
15 | devices:
16 | - driver: nvidia
17 | device_ids: ['0','1']
18 | capabilities: [gpu]
19 |
--------------------------------------------------------------------------------
/encodecmae/heareval_model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/habla-liaa/encodecmae/2dd6759c595bf20c9d836914fef5cde9a6386843/encodecmae/heareval_model/__init__.py
--------------------------------------------------------------------------------
/encodecmae/heareval_model/encodecmae.py:
--------------------------------------------------------------------------------
1 | from encodecmae import load_model as load_encodecmae_model
2 | from torch import Tensor
3 | import torch
4 | from typing import Tuple
5 | from huggingface_hub import hf_hub_download
6 |
7 | def load_model(file_path, huggingface_ckpt=None, layer=-1):
8 | model = load_encodecmae_model(file_path)
9 | if huggingface_ckpt is not None:
10 | print('Loading checkpoint from {}'.format(huggingface_ckpt))
11 | repo_id = '/'.join(huggingface_ckpt.split('/')[:2])
12 | filename = '/'.join(huggingface_ckpt.split('/')[2:])
13 | ckpt_file = hf_hub_download(repo_id=repo_id,filename=filename)
14 | ckpt = torch.load(ckpt_file, map_location='cpu')
15 | model.load_state_dict(ckpt['state_dict'], strict=False)
16 |
17 | model.sample_rate = 24000
18 | model.embedding_rate=75
19 | model.visible_encoder.compile=False
20 | model.head = None
21 | model.extraction_layer = layer
22 |
23 | del model.optimizer
24 | return model
25 |
26 | def get_timestamp_embeddings(
27 | audio: Tensor,
28 | model: torch.nn.Module,
29 | hop_size: float = 13,
30 | ) -> Tuple[Tensor, Tensor]:
31 |
32 | with torch.no_grad():
33 | model_device = next(model.parameters()).device
34 | embeddings = model.extract_features_from_array(audio, layer=model.extraction_layer)
35 | embeddings = torch.from_numpy(embeddings).to(model_device)
36 | if (model.extraction_layer == 'all') and embeddings.ndim==3:
37 | embeddings = embeddings.unsqueeze(1)
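# Timestamps mark the center of each embedding frame (the model produces 75 embeddings per second).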
38 | timestamps = torch.arange(0,embeddings.shape[-2])/model.embedding_rate + (0.5/model.embedding_rate)
39 | timestamps = torch.tile(timestamps[None,:],[embeddings.shape[-3],1])
40 | return embeddings, timestamps.to(model_device, dtype=torch.float32)
41 |
42 | def get_scene_embeddings(
43 | audio: Tensor,
44 | model: torch.nn.Module,
45 | ) -> Tensor:
46 |
47 | y, t = get_timestamp_embeddings(audio, model)
48 | out = torch.mean(y,axis=-2)
49 |
50 | return out
51 |
--------------------------------------------------------------------------------
/encodecmae/hub.py:
--------------------------------------------------------------------------------
1 | from huggingface_hub import HfFileSystem, hf_hub_download
2 | from .models import EncodecMAE
3 | from ginpipe.core import gin_configure_externals
4 | import gin
5 | import torch
6 | from pathlib import Path
7 | import copy
8 |
9 | def traverse_dir(dir, result, fs):
10 | for x in fs.ls(dir, refresh=True):
11 | if x['name'].endswith('.pt'):
12 | result.append('/'.join(x['name'].split('/')[2:]))
13 | else:
14 | if x['type'] == 'dir':
15 | traverse_dir(x['name'], result, fs)
16 |
17 | def get_available_models():
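# Recursively collect every .pt checkpoint in the lpepino/encodecmae-v2 repo on the
# Hugging Face Hub; model names are the file paths without the extension.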
18 | fs = HfFileSystem()
19 | available_models = []
20 | traverse_dir('lpepino/encodecmae-v2', available_models, fs)
21 | #available_models = [x['name'].split('/')[-1] for x in fs.ls('lpepino/encodecmae-pretrained/upstreams', refresh=True)]
22 | return [x.split('.')[0] for x in available_models]
23 |
24 | @gin.configurable
25 | def get_model(model, processor):
26 | return model, processor
27 |
28 | def load_model(model, mode='eval',device='cuda:0'):
29 | #Get model files
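# Snapshot the caller's gin config and configurable registry so that parsing the
# checkpoint's own gin config below does not clobber existing bindings; both are
# restored once the model has been built.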
30 | config_str = gin.config_str()
31 | registers = copy.deepcopy(gin.config._REGISTRY)
32 | gin.clear_config()
33 | registry_clear_keys = [k for k,v in gin.config._REGISTRY.items() if k not in ['gin.macro', 'gin.constant', 'gin.singleton', 'ginpipe.core.execute_pipeline', 'encodecmae.hub.get_model']]
34 | for k in registry_clear_keys:
35 | gin.config._REGISTRY.pop(k)
36 | available_models = get_available_models()
37 | if model in available_models:
38 | model_file = hf_hub_download(repo_id='lpepino/encodecmae-v2', filename='{}.pt'.format(model))
39 | else:
40 | raise Exception("Available models are: {}".format(available_models))
41 |
42 | model_state = torch.load(model_file, map_location='cpu')
43 |
44 | flag = {'module_list_str': model_state['imports']}
45 | gin_configure_externals(flag)
46 | gin.parse_config(model_state['config'])
47 | model, processor = get_model()
48 | model = model()
49 | processor = processor()
50 | model.load_state_dict(model_state['state_dict'])
51 | gin.clear_config()
52 | gin.config._REGISTRY.clear()
53 | gin.config._REGISTRY = registers
54 | gin.parse_config(config_str)
55 | if mode=='eval':
56 | model.eval()
57 | model.to(device)
58 | model.processor = processor
59 |
60 | #To avoid dynamic batch problems:
61 | if hasattr(model.visible_encoder, 'compile'):
62 | model.visible_encoder.compile=False
63 | return model
64 |
--------------------------------------------------------------------------------
/encodecmae/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .mae import EncodecMAE
--------------------------------------------------------------------------------
/encodecmae/models/encoders/__init__.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from encodec import EncodecModel
3 | import encodec
4 | import torch
5 |
6 | class WavEncoder(torch.nn.Module):
7 | def __init__(self, encoder, pre_net=None, post_net=None,
8 | key_in='wav', key_out='wav_features', key_lens='wav_lens',
9 | make_padding_mask=True, hop_length=None,fs=None):
10 | super().__init__()
11 | self.encoder = encoder()
12 | self.pre_net = pre_net() if pre_net is not None else None
13 | self.post_net = post_net() if post_net is not None else None
14 | self.key_in = key_in
15 | self.key_out = key_out
16 | self.key_wav_lens = key_lens
17 | self.make_padding_mask = make_padding_mask
18 | if hop_length is None:
19 | self.hop_length=self.encoder.hop_length
20 | else:
21 | self.hop_length=hop_length
22 | if fs is None:
23 | self.fs = self.encoder.fs
24 | else:
25 | self.fs = fs
26 |
27 | def forward(self,x):
28 | #Encode wav
29 | y = x[self.key_in]
30 | if self.pre_net is not None:
31 | y = self.pre_net(y)
32 | x[self.key_out+'_pre_net_output'] = y
33 | y = self.encoder(y)
34 | if self.post_net is not None:
35 | x[self.key_out+'_encoder_out'] = y
36 | y = self.post_net(y)
37 | x[self.key_out] = y
38 | #Make padding masks
39 | if self.make_padding_mask:
40 | x['features_len'] = (x['wav_lens']//self.hop_length).to(y.device)
41 | x['feature_padding_mask'] = x['features_len'][:,None] <= torch.arange(0,y.shape[1],device=y.device)[None,:]
42 |
43 | class EncodecEncoder(torch.nn.Module):
44 | def __init__(self, frozen: bool = True, scale: float = 1.0, pretrained: bool = True) -> None:
45 | """Initialize Encodec Encoder model.
46 |
47 | Args:
48 | frozen (bool, optional): Whether the model is frozen or not. Defaults to True.
49 | scale (float, optional): Scaling factor. Defaults to 1.0.
50 | pretrained (bool, optional): Whether to load a pretrained checkpoint or train from scratch. Defaults to True.
51 | """
52 | super().__init__()
53 | self.model = EncodecModel.encodec_model_24khz(pretrained=pretrained).encoder
54 | self.hop_length = self.model.hop_length
55 | self.frozen = frozen
56 | if self.frozen:
57 | self.model.eval()
58 | else:
59 | self.model.train()
60 | self.out_dim = 128
61 | self.fs = 24000
62 | self.scale = scale
63 |
64 | def forward(self, x: torch.Tensor) -> torch.Tensor:
65 | """Forward pass.
66 |
67 | Args:
68 | x (torch.Tensor): Input tensor corresponding to waveforms with shape [B, T].
69 |
70 | Returns:
71 | torch.Tensor: Output from EnCodec encoder with shape [B, T, D]
72 | """
73 | x = x.unsqueeze(1)
74 | if self.frozen:
75 | with torch.no_grad():
76 | y = self.model(x)
77 | else:
78 | y = self.model(x)
79 | y = torch.permute(y,(0,2,1))*self.scale
80 | return y
81 |
82 | class SEANetEncoder(torch.nn.Module):
83 | def __init__(self, frozen=False, lstm=False, pretrained=True):
84 | super().__init__()
85 | self.model = encodec.modules.SEANetEncoder(lstm=lstm)
86 | if pretrained:
87 | ckpt_path = 'https://dl.fbaipublicfiles.com/encodec/v0/encodec_24khz-d7cc33bc.th'
88 | sd = torch.hub.load_state_dict_from_url(ckpt_path, map_location='cpu', check_hash=True)
89 | sd = {k.replace('encoder.',''):v for k,v in sd.items()}
90 | if not lstm:
91 | sd = {k.replace('.15','.14'): v for k,v in sd.items()}
92 | self.model.load_state_dict(sd, strict=False)
93 | self.hop_length = self.model.hop_length
94 | self.frozen = frozen
95 | self.fs = 24000
96 | if self.frozen:
97 | self.model.eval()
98 | else:
99 | self.model.train()
100 | self.out_dim = self.model.dimension
101 |
102 | def forward(self, x):
103 | x = x.unsqueeze(1)
104 | if self.frozen:
105 | with torch.no_grad():
106 | y = self.model(x)
107 | else:
108 | y = self.model(x)
109 | y = torch.permute(y,(0,2,1))
110 | return y
--------------------------------------------------------------------------------
/encodecmae/models/heads/__init__.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 |
4 | class FrameLevelClassificationHead(nn.Module):
5 | def __init__(self, model_dim: int, num_tokens: int, num_streams: int = 8, key_in: str = 'decoder_out', key_out: str = 'predicted_tokens') -> None:
6 | """Initialize FrameLevelClassificationHead.
7 |
8 | Args:
9 | model_dim (int): Dimensionality of the input.
10 | num_tokens (int): Number of tokens.
11 | num_streams (int, optional): Number of streams. Defaults to 8.
12 | key_in (str, optional): Key for input data. Defaults to 'decoder_out'.
13 | key_out (str, optional): Key for output data. Defaults to 'predicted_tokens'.
14 | """
15 | super().__init__()
16 | self.layer = nn.Linear(model_dim,num_tokens*num_streams)
17 | self.num_tokens = num_tokens
18 | self.num_streams = num_streams
19 | self.ar_sampling = False
20 | self.key_in=key_in
21 | self.key_out=key_out
22 |
23 | def forward(self, x: dict) -> dict:
24 | """Forward pass.
25 |
26 | Args:
27 | x (dict): Input data dictionary.
28 |
29 | Returns:
30 | dict: Output data dictionary.
31 | """
32 | xin = x[self.key_in]
33 | probs = self.layer(xin).view(xin.shape[0],xin.shape[1],self.num_streams,self.num_tokens)
34 | x[self.key_out] = probs
35 | return x
36 |
37 | class SegmentLevelClassificationHead(nn.Module):
38 | def __init__(self, model_dim: int, num_classes: int, num_streams: int) -> None:
39 | """Initialize SegmentLevelClassificationHead.
40 |
41 | Args:
42 | model_dim (int): Dimensionality of the input.
43 | num_classes (int): Number of classes.
44 | num_streams (int): Number of streams.
45 | """
46 | super().__init__()
47 | self.layer = nn.Linear(model_dim,num_classes*num_streams)
48 |
49 | def forward(self, x: torch.Tensor) -> torch.Tensor:
50 | """Forward pass.
51 |
52 | Args:
53 | x (torch.Tensor): Input tensor.
54 |
55 | Returns:
56 | torch.Tensor: Output tensor.
57 | """
58 | probs = self.layer(torch.mean(x,axis=1))
59 | return probs
--------------------------------------------------------------------------------
/encodecmae/models/losses/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from datetime import datetime
3 |
4 | class EnCodecMAEClassificationLoss(torch.nn.Module):
5 | def __init__(self, masked_weight=0.9, quantizer_weights=None):
6 | super().__init__()
7 | self.masked_weight = masked_weight
8 | self.quantizer_weights = torch.tensor(quantizer_weights)
9 |
10 | def forward(self, x):
11 | all_losses = {}
12 | #Sometimes because of padding/different frame rates, there might be a small difference in length between input and target representations.
13 | #We fix it by cropping the longest vector:
14 | if x['predicted_tokens'].shape[1] < x['targets'].shape[-1]:
15 | x['targets'] = x['targets'][:,:,:x['predicted_tokens'].shape[1]]
16 | elif x['predicted_tokens'].shape[1] > x['targets'].shape[-1]:
17 | x['predicted_tokens'] = x['predicted_tokens'][:,:x['targets'].shape[-1],:,:]
18 | x['non_visible_mask'] = x['non_visible_mask'][:,:x['targets'].shape[-1]]
19 | x['visible_mask'] = x['visible_mask'][:,:x['targets'].shape[-1]]
20 | else:
21 | pass
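# Shapes: predicted_tokens (B, T, n_q, num_tokens) -> (B, num_tokens, T, n_q);
# targets (n_q, B, T) -> (B, T, n_q); the resulting per-element loss is (B, T, n_q).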
22 | loss = torch.nn.functional.cross_entropy(torch.permute(x['predicted_tokens'],(0,3,1,2)), torch.permute(x['targets'],(1,2,0)).to(dtype=torch.long), reduction='none')
23 | if 'quantizer_loss_weights' not in x:
24 | x['quantizer_loss_weights'] = torch.tensor([[1.0]], device=loss.device, dtype=loss.dtype)
25 | loss = loss*(x['quantizer_loss_weights'].unsqueeze(1))*(self.quantizer_weights.to(x['predicted_tokens'].device).unsqueeze(0).unsqueeze(0))
26 | masked_loss = loss*(x['non_visible_mask'].unsqueeze(-1))
27 | unmasked_loss = loss*(x['visible_mask'].unsqueeze(-1))
28 |
29 | num_non_visible = max(1,torch.sum(x['non_visible_mask']))
30 | num_visible = max(1,torch.sum(x['visible_mask']))
31 | for q in range(masked_loss.shape[-1]):
32 | all_losses[f'masked_loss_q{q}'] = torch.sum(masked_loss[:,:,q])/num_non_visible
33 | all_losses[f'unmasked_loss_q{q}'] = torch.sum(unmasked_loss[:,:,q])/num_visible
34 |
35 | masked_loss = torch.sum(masked_loss)/num_non_visible
36 | unmasked_loss = torch.sum(unmasked_loss)/num_visible
37 | all_losses['masked_loss'] = masked_loss
38 | all_losses['unmasked_loss'] = unmasked_loss
39 | loss = self.masked_weight*masked_loss + (1-self.masked_weight)*unmasked_loss
40 | all_losses['loss'] = loss
41 | all_losses['time'] = int(datetime.now().strftime('%y%m%d%H%M%S'))
42 |
43 | return all_losses
44 |
--------------------------------------------------------------------------------
/encodecmae/models/mae.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | import torch
3 | import torch.nn.functional as F
4 | from datetime import datetime
5 | import gin
6 | import librosa
7 |
15 | import numpy as np
16 |
17 | class EncodecMAE(pl.LightningModule):
18 | """EncodecMAE.
19 |
20 | Args:
21 | wav_encoder (torch.nn.Module): Module that takes a batch of waveforms (BxT) and returns a representation (BxTxDf).
22 | target_encoder (torch.nn.Module): Module that generates targets from unmasked inputs.
23 | visible_encoder (torch.nn.Module): Module that encodes the visible tokens returning visible embeddings (BxTvxD)
24 | decoder (torch.nn.Module): Module that takes expanded input (with mask tokens) and generate embeddings for the whole sequence (BxTxD)
25 | masker (torch.nn.Module): Module that takes wav_encoder outputs and mask them. Has 2 methods: mask() that masks the embeddings, and unmask() that expands visible embeddings to original length.
26 | head (torch.nn.Module): Module that takes decoder output and generates predictions of the targets.
27 | loss (torch.nn.Module): Module that receives prediction and targets and returns model losses.
28 | optimizer (torch.optim.Optimizer): Torch optimizer to use during training.
29 | lr_scheduler (torch.lr_scheduler.LRScheduler): Torch learning rate scheduler used during training.
30 | """
31 | def __init__(self,
32 | wav_encoder: torch.nn.Module,
33 | target_encoder: torch.nn.Module,
34 | visible_encoder: torch.nn.Module,
35 | decoder: torch.nn.Module,
36 | masker: torch.nn.Module,
37 | head: torch.nn.Module,
38 | loss: torch.nn.Module = None,
39 | optimizer: torch.optim.Optimizer = None,
40 | lr_scheduler = None
41 | ):
42 | super().__init__()
43 | self.wav_encoder = wav_encoder()
44 | self.target_encoder = target_encoder()
45 | self.masker = masker()
46 | self.visible_encoder = visible_encoder()
47 | self.head = head()
48 | self.decoder = decoder()
49 | self.optimizer = optimizer
50 | self.lr_scheduler = lr_scheduler
51 | self.loss = loss() if loss is not None else None
52 |
53 | self.apply(self._init_weights)
54 |
55 | def forward(self, x):
56 | self.wav_encoder(x)
57 | self.masker.mask(x)
58 | self.visible_encoder(x)
59 | self.masker.unmask(x)
60 | self.decoder(x)
61 | if hasattr(self.target_encoder,'requires_model') and (self.target_encoder.requires_model):
62 | self.target_encoder(x, self)
63 | else:
64 | self.target_encoder(x)
65 | self.head(x)
66 |
67 | return x
68 |
69 | def forward_finetune(self, x):
70 | # Fine-tuning path: encode without masking and predict directly from the visible embeddings.
71 | self.wav_encoder(x)
72 | self.masker.mask(x, ignore_mask=True)
73 | self.visible_encoder(x)
74 | self.head(x)
75 | return x
76 |
77 | def extract_activations(self,x,
78 | detach=True,
79 | postnorm_last_activation=True,
80 | extract_decoder=False,
81 | residual_branch=0):
82 | if detach:
83 | with torch.no_grad():
84 | self.wav_encoder(x)
85 | self.masker.mask(x,ignore_mask=True)
86 | self.visible_encoder(x, return_activations=True, padding_mask=x['visible_padding_mask'],postnorm_last_activation=postnorm_last_activation,residual_branch=residual_branch)
87 | if extract_decoder:
88 | self.masker.unmask(x)
89 | self.decoder(x, return_activations=True,postnorm_last_activation=postnorm_last_activation,residual_branch=residual_branch)
90 | return x
91 |
92 | def training_step(self,x, batch_idx):
93 | x = self(x)
94 | losses = self.loss(x)
95 | self.log_results(x,losses,'train')
96 |
97 | return losses['loss']
98 |
99 | def validation_step(self,x, batch_idx):
100 | x = self(x)
101 | losses = self.loss(x)
102 | self.log_results(x,losses,'val')
103 |
104 | def log_results(self,x,losses,prefix):
105 | self.log_dict({'{}_{}'.format(prefix,k): v for k,v in losses.items()})
106 |
107 | def set_optimizer_state(self, state):
108 | self.opt_state = state
109 |
110 | def configure_optimizers(self):
111 | opt = self.optimizer(self.trainer.model.parameters())
112 | if self.lr_scheduler is not None:
113 | if self.lr_scheduler.__name__ == 'SequentialLR':
114 | binds = gin.get_bindings('torch.optim.lr_scheduler.SequentialLR')
115 | lr_scheduler = self.lr_scheduler(opt, schedulers=[s(opt) for s in binds['schedulers']])
116 | else:
117 | lr_scheduler = self.lr_scheduler(opt) if self.lr_scheduler is not None else None
118 | else:
119 | lr_scheduler = None
120 | del self.optimizer
121 | del self.lr_scheduler
122 | opt_config = {'optimizer': opt}
123 | if lr_scheduler is not None:
124 | opt_config['lr_scheduler'] = {'scheduler': lr_scheduler,
125 | 'interval': 'step',
126 | 'frequency': 1}
127 | return opt_config
128 |
129 | def _init_weights(self, m):
130 | if isinstance(m, torch.nn.Linear):
131 | # we use xavier_uniform following official JAX ViT:
132 | torch.nn.init.xavier_uniform_(m.weight)
133 | if isinstance(m, torch.nn.Linear) and m.bias is not None:
134 | torch.nn.init.constant_(m.bias, 0)
135 | elif isinstance(m, torch.nn.LayerNorm):
136 | torch.nn.init.constant_(m.bias, 0)
137 | torch.nn.init.constant_(m.weight, 1.0)
138 |
139 | def extract_features_from_file(self,
140 | filename,
141 | chunk_size=None,
142 | overlap=0,
143 | start=None,
144 | end=None,
145 | layer=-1,
146 | extract_decoder=False,
147 | postnorm_last_activation=True,
148 | residual_branch=0):
149 | fs = self.wav_encoder.fs
150 | start = start/fs if start is not None else None
151 | end = end/fs if end is not None else None
152 | duration = end - start if (end is not None and start is not None) else None
153 | x, fs = librosa.load(filename, sr=fs, offset=start, duration=duration)
154 | if chunk_size is None:
155 | chunk_size=int(fs*self.processor._processors[0].max_length)
156 | features = self.extract_features_from_array(x, chunk_size=chunk_size, layer=layer, overlap=overlap, extract_decoder=extract_decoder, postnorm_last_activation=postnorm_last_activation, residual_branch=residual_branch)
157 | if (features.ndim == 3) and (features.shape[0]==1):
158 | return features[0]
159 | else:
160 | return features
161 |
162 | def apply_processors(self, xin, from_idx=1):
163 | def to_torch(x):
164 | if isinstance(x, np.ndarray):
165 | return torch.from_numpy(x)
166 | else:
167 | return torch.tensor(x)
168 |
169 | batch_processed = []
170 | for k in range(xin['wav'].shape[0]):
171 | xin_i = {f:v[k] for f,v in xin.items()}
172 | for p in self.processor._processors[1:]:
173 | xin_i = p(xin_i)
174 | batch_processed.append(xin_i)
175 | xin = {k: torch.stack([to_torch(b[k]) for b in batch_processed]).to(self.device, self.dtype) for k in batch_processed[0].keys()}
176 | return xin
177 |
178 | def extract_features_from_array(self,
179 | audio,
180 | wav_lens=None,
181 | chunk_size=None,
182 | overlap=0,
183 | return_type='numpy',
184 | layer=-1,
185 | min_length=2048,
186 | extract_decoder=False,
187 | postnorm_last_activation=True,
188 | residual_branch=0):
189 | if chunk_size is None:
190 | fs = self.wav_encoder.fs
191 | chunk_size=int(fs*self.processor._processors[0].max_length)
192 |
193 | hop_size = int(chunk_size*(1-overlap))
194 | if not isinstance(audio, torch.Tensor):
195 | audio = torch.tensor(audio, device=self.device, dtype=torch.float32)
196 |
197 | if audio.ndim == 1:
198 | batch_size = 1
199 | audio = audio[None,:]
200 | else:
201 | batch_size = audio.shape[0]
202 |
203 | if wav_lens is None:
204 | wav_lens = [audio.shape[1]]*batch_size
205 |
206 | with torch.no_grad():
207 | acts = []
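# Run the model over (possibly overlapping) chunks of the input and concatenate
# the activations of the requested layer(s) along the time axis.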
208 | for mi, i in enumerate(range(0,audio.shape[-1],hop_size)):
209 | if audio[:,i:i+chunk_size].shape[1]>min_length:
210 | wav_lens_i = [min(max(wav_lens[b] - i,0),chunk_size) for b in range(batch_size)]
211 | audio_i = audio[:,i:i+chunk_size]
212 | xin = {'wav': audio_i.cpu().numpy(), 'wav_lens': np.array(wav_lens_i)}
213 | xin = self.apply_processors(xin)
214 | if extract_decoder:
215 | self.decoder.key_transformer_out='decoder_transformer'
216 | out_i = self.extract_activations(xin, extract_decoder=extract_decoder, postnorm_last_activation=postnorm_last_activation, residual_branch=residual_branch)
217 | activations = torch.stack(out_i['visible_embeddings_activations']).squeeze(axis=1)
218 | if extract_decoder:
219 | decoder_acts = torch.stack(out_i['decoder_transformer_activations']).squeeze(axis=1)
220 | activations = torch.cat([activations, decoder_acts],axis=0)
221 | if layer != 'all':
222 | activations = activations[layer]
223 | if activations.ndim == 2:
224 | activations = activations.unsqueeze(0)
225 | acts.append(activations)
226 | if acts[0].ndim == 3:
227 | xi = torch.cat(acts,axis=1)
228 | elif acts[0].ndim == 4:
229 | xi = torch.cat(acts,axis=2)
230 | else:
231 | raise Exception('Wrong shape of activations')
232 | if return_type == 'numpy':
233 | return xi.detach().cpu().numpy()
234 | elif return_type == 'torch':
235 | return xi.detach()
236 | else:
237 | raise Exception('Unrecognized return type')
238 |
--------------------------------------------------------------------------------
/encodecmae/models/masks/__init__.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | import numpy as np
4 |
5 | class PatchoutMask(torch.nn.Module):
6 | def __init__(self, masker, positional_encoder):
7 | super().__init__()
8 | self.positional_encoder = positional_encoder()
9 | self.masker = masker()
10 |
11 | def mask(self, x, ignore_mask=False):
12 | if self.training and not ignore_mask:
13 | to_mask = self.positional_encoder(x['wav_features'])
14 | x['non_visible_mask'] = self.masker(x)
15 | x['visible_mask'] = ~x['non_visible_mask']
16 |
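# Patchout: keep only frames that are visible and not padding, then re-pad every
# item to the longest visible length in the batch.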
17 | visibles = []
18 | for i in range(to_mask.shape[0]):
19 | visibles.append(to_mask[i,(x['visible_mask'][i] & ~x['feature_padding_mask'][i])])
20 | visible_lens = [xi.shape[0] for xi in visibles]
21 | maxlen = max(visible_lens)
22 | visible = torch.stack([torch.nn.functional.pad(xi,(0,0,0,maxlen-xi.shape[0])) for xi in visibles])
23 | visible_padding_mask = (torch.arange(0,visible.shape[1],device=visible.device).unsqueeze(0)) >= torch.tensor(visible_lens, device=visible.device).unsqueeze(1)
24 | x['visible_tokens'] = visible
25 | x['visible_lens'] = visible_lens
26 | x['visible_padding_mask'] = visible_padding_mask
27 | else:
28 | x['visible_tokens'] = self.positional_encoder(x['wav_features'])
29 | x['visible_mask'] = torch.ones(x['wav_features'].shape[:2]).to(dtype=torch.bool, device=x['visible_tokens'].device)
30 | x['non_visible_mask'] = torch.zeros((x['wav_features'].shape[0],x['wav_features'].shape[1])).to(dtype=torch.bool, device=x['visible_tokens'].device)
31 | x['visible_padding_mask'] = x['feature_padding_mask']
32 | x['visible_lens'] = x['features_len']
33 | return x
34 |
35 | def unmask(self, x, key_in='decoder_in', key_out='decoder_in'):
36 | mask = x['non_visible_mask']
37 | xin = x[key_in]
38 | feature_padding_mask = x['feature_padding_mask']
39 | visible_padding_mask = x['visible_padding_mask']
40 | unmasked = torch.zeros((mask.shape[0],mask.shape[1],xin.shape[-1]), dtype=xin.dtype, device=xin.device)
41 | unmasked[(~mask & ~feature_padding_mask)] = xin[~visible_padding_mask]
42 | x[key_out] = unmasked
43 | return x
44 |
45 | class TimeGapMask(torch.nn.Module):
46 | def __init__(self, p_mask, gap_size):
47 | super().__init__()
48 | self.p_mask = p_mask
49 | self.gap_size = gap_size
50 |
51 | def forward(self, x):
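# Sample random gap start positions and mask gap_size consecutive frames from each,
# until at least p_mask of the sequence is masked; any overshoot is unmasked again
# below so the masked fraction stays close to p_mask.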
52 | x_lengths = x['features_len']
53 | mask = np.zeros((len(x_lengths),max(x_lengths)), dtype=bool)
54 | ml = (x_lengths*self.p_mask).to(dtype=torch.int64)
55 | for i in range(len(x_lengths)):
56 | n_masked_i = 0
57 | start_idxs = []
58 | while n_masked_i < ml[i]:
59 | start_idx = random.randint(0,max(0,x_lengths[i]-self.gap_size))
60 | start_idxs.append(start_idx)
61 | mask[i,start_idx:min(start_idx+self.gap_size,x_lengths[i])]=1
62 | n_masked_i = mask[i].sum()
63 | if n_masked_i > ml[i]:
64 | valid_start_idxs = [idx for idx in start_idxs if idx < mask.shape[1]-self.gap_size]
65 | mask[i,valid_start_idxs[0]:valid_start_idxs[0]+n_masked_i-ml[i]]=0
66 | return torch.from_numpy(mask).to(x['wav_features'].device)
--------------------------------------------------------------------------------
/encodecmae/models/targets/__init__.py:
--------------------------------------------------------------------------------
1 | from encodec import EncodecModel
2 | import torch
3 | import numpy as np
4 | import random
5 | from sklearn.cluster import MiniBatchKMeans
6 | import copy
7 |
8 | class EncodecQuantizer(torch.nn.Module):
9 | def __init__(self, n=8, frozen=True, scale=1.0, key_in='wav_features', key_out='targets', use_encodec_encoder=False, return_only_last=False):
10 | super().__init__()
11 | model = EncodecModel.encodec_model_24khz()
12 | self.model = model.quantizer
13 | self.scale = scale
14 | #Modify state dict to scale quantizer weights
15 | sd = self.model.state_dict()
16 | sd = {k: v*scale if 'embed' in k else v for k,v in sd.items()}
17 | self.model.load_state_dict(sd)
18 | self.frozen = frozen
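# EnCodec 24 kHz produces 75 frames per second with 10-bit codebooks, so keeping
# the first n quantizer streams corresponds to a bandwidth of 0.75*n kbps.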
19 | self.bandwidth = (10*75*n)/1000
20 | if self.frozen:
21 | self.model.eval()
22 | self.key_in = key_in
23 | self.key_out = key_out
24 | self.return_only_last = return_only_last
25 | if use_encodec_encoder:
26 | self.encodec_encoder = model.encoder
27 | self.encodec_encoder.eval()
28 | else:
29 | self.encodec_encoder = None
30 |
31 | def forward(self, xin):
32 | x = xin[self.key_in]
33 | with torch.no_grad():
34 | if self.encodec_encoder is not None:
35 | x = x.unsqueeze(1)
36 | x = self.encodec_encoder(x)
37 | x = torch.transpose(x,1,2)
38 | x = torch.transpose(x,1,2)
39 | result = self.model(x,sample_rate=75,bandwidth=self.bandwidth)
40 | y = result.codes
41 | if self.return_only_last:
42 | y = y[-1,:,:].unsqueeze(0)
43 | xin[self.key_out] = y
44 | return xin
45 |
46 | class IdentityTarget(torch.nn.Module):
47 | def __init__(self, key_in='targets', key_out='targets'):
48 | super().__init__()
49 | self.key_in = key_in
50 | self.key_out = key_out
51 | self.requires_model=False
52 |
53 | def forward(self, x):
54 | xin = x[self.key_in]
55 | if xin.ndim==2:
56 | xin = xin.unsqueeze(0)
57 | x[self.key_out] = xin
--------------------------------------------------------------------------------
/encodecmae/models/transformers/__init__.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import math
4 | import gin
5 |
6 | ###Code from TIMM:
7 | def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
8 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
9 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
10 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
11 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
12 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
13 | 'survival rate' as the argument.
14 | """
15 | if drop_prob == 0. or not training:
16 | return x
17 | keep_prob = 1 - drop_prob
18 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
19 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
20 | if keep_prob > 0.0 and scale_by_keep:
21 | random_tensor.div_(keep_prob)
22 | return x * random_tensor
23 |
24 | class DropPath(nn.Module):
25 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
26 | """
27 | def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
28 | super(DropPath, self).__init__()
29 | self.drop_prob = drop_prob
30 | self.scale_by_keep = scale_by_keep
31 |
32 | def forward(self, x):
33 | return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
34 |
35 | def extra_repr(self):
36 | return f'drop_prob={round(self.drop_prob,3):0.3f}'
37 |
38 | class LayerScale(nn.Module):
39 | def __init__(self, dim, init_values=1e-5, inplace=False):
40 | super().__init__()
41 | self.inplace = inplace
42 | self.gamma = nn.Parameter(init_values * torch.ones(dim))
43 |
44 | def forward(self, x):
45 | return x.mul_(self.gamma) if self.inplace else x * self.gamma
46 |
47 | class Transpose(nn.Module):
48 | def __init__(self, axis0, axis1):
49 | super().__init__()
50 | self.axis0=axis0
51 | self.axis1=axis1
52 |
53 | def forward(self,x):
54 | return torch.transpose(x,self.axis0,self.axis1)
55 |
56 | class SinusoidalPositionalEmbeddings(nn.Module):
57 | def __init__(self, embedding_dim,learn_scale=True,scale=1.0):
58 | super().__init__()
59 | self.embedding_dim = embedding_dim
60 | self.scale = torch.nn.Parameter(data=torch.tensor(scale),requires_grad=learn_scale)
61 |
62 | def forward(self,x):
63 | pe = torch.zeros_like(x)
64 | position = torch.arange(0, pe.shape[1]).unsqueeze(1).unsqueeze(0)
65 | div_term = torch.exp((torch.arange(0, pe.shape[2], 2, dtype=torch.float) *
66 | -(math.log(10000.0) / pe.shape[2])))
67 | pe[:,:, 0::2] = torch.sin(position.float() * div_term)
68 | pe[:,:, 1::2] = torch.cos(position.float() * div_term)
69 | return x+self.scale*pe
70 |
71 | ### Transformer implementation with ResiDual normalization. See 'ResiDual: Transformer with Dual Residual Connections' - Xie et al.
72 | class TransformerLayer(nn.Module):
73 | def __init__(self, model_dim, attention_layer, ff_layer, norm_layer,
74 | norm_type='ResiDual', cross_attention_layer=None, drop_path=0, init_values=None):
75 | super().__init__()
76 | self.att_layer = attention_layer
77 | self.ff_layer = ff_layer
78 | self.norm1 = norm_layer(model_dim)
79 | self.norm2 = norm_layer(model_dim)
80 | self.norm_type = norm_type
81 | self.xatt_layer = cross_attention_layer
82 | self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
83 | self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
84 | self.ls1 = LayerScale(model_dim, init_values=init_values) if init_values else nn.Identity()
85 | self.ls2 = LayerScale(model_dim, init_values=init_values) if init_values else nn.Identity()
86 | if cross_attention_layer is not None:
87 | self.ls3 = LayerScale(model_dim, init_values=init_values) if init_values else nn.Identity()
88 | self.norm3 = norm_layer(model_dim)
89 | self.drop_path3 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
90 |
91 | def forward(self, x,
92 | key_mask=None, att_mask=None,
93 | att_bias=None, k_mem=None, v_mem=None,
94 | key_mem_mask=None, mem_att_mask=None, mem_att_bias=None):
95 |
96 | if self.norm_type=='prenorm':
97 | xnorm = self.norm1(x)
98 | att_out, att_matrix = self.att_layer(xnorm, xnorm, xnorm, key_mask=key_mask, att_mask=att_mask, att_bias=att_bias)
99 | x = x + self.drop_path1(self.ls1(att_out))
100 | if self.xatt_layer is not None:
101 | xnorm = self.norm3(x)
102 | xatt_out, xatt_matrix = self.xatt_layer(xnorm, k_mem, v_mem, key_mask=key_mem_mask, att_mask=mem_att_mask, att_bias=mem_att_bias)
103 | x = x + self.drop_path3(self.ls3(xatt_out))
104 | else:
105 | xatt_matrix = None
106 | x = x + self.drop_path2(self.ls2(self.ff_layer(self.norm2(x))))
107 |
108 | elif self.norm_type=='postnorm':
109 | att_out, att_matrix = self.att_layer(x,x,x,key_mask=key_mask, att_mask=att_mask, att_bias=att_bias)
110 | x = self.norm1(x+self.drop_path1(self.ls1(att_out)))
111 | if self.xatt_layer is not None:
112 | xatt_out, xatt_matrix = self.xatt_layer(x, k_mem, v_mem, key_mask=key_mem_mask, att_mask=mem_att_mask, att_bias=mem_att_bias)
113 | x = self.norm3(x + self.drop_path3(self.ls3(xatt_out)))
114 | else:
115 | xatt_matrix = None
116 | x = self.norm2(x+self.drop_path2(self.ls2(self.ff_layer(x))))
117 |
118 | elif self.norm_type=='ResiDual':
119 | #Here 2 IO are managed. x[0]: ln_out x[1]: unnormalized
120 | #See https://arxiv.org/pdf/2304.14802.pdf
121 |
122 | if not isinstance(x,tuple):
123 | x = (x,x)
124 | att_out, att_matrix = self.att_layer(x[0],x[0],x[0],key_mask=key_mask, att_mask=att_mask, att_bias=att_bias)
125 | x0 = self.norm1(x[0]+self.drop_path1(self.ls1(att_out)))
126 | x1 = x[1]+self.drop_path1(self.ls1(att_out))
127 | if self.xatt_layer is not None:
128 | xatt_out, xatt_matrix = self.xatt_layer(x0, k_mem, v_mem, key_mask=key_mem_mask, att_mask=mem_att_mask, att_bias=mem_att_bias)
129 | x0 = self.norm3(x0+self.drop_path3(self.ls3(xatt_out)))
130 | x1 = x1+self.drop_path3(self.ls3(xatt_out))
131 | else:
132 | xatt_matrix=None
133 | ff_out = self.ff_layer(x0)
134 | x0 = self.norm2(x0+self.drop_path2(self.ls2(ff_out)))
135 | x1 = x1+self.drop_path2(self.ls2(ff_out))
136 | x=(x0,x1)
137 |
138 | return x, att_matrix, xatt_matrix
139 |
140 | class TransformerEncoder(nn.Module):
141 | def __init__(self, model_dim,
142 | num_layers,
143 | attention_layer,
144 | ff_layer=None,
145 | ff_dim=4096,
146 | norm_layer=torch.nn.LayerNorm,
147 | norm_type='ResiDual',
148 | cross_attention_layer=None,
149 | drop_path=0,
150 | init_values=None,
151 | positional_encoder=None,
152 | compile=True,
153 | return_activations=False,
154 | key_in=None,
155 | key_padding_mask=None,
156 | key_out=None,
157 | key_transformer_in=None,
158 | key_transformer_out=None,
159 | pre_net=None,
160 | post_net=None):
161 |
162 | super().__init__()
163 | if ff_layer is None:
164 | ff_layer = torch.nn.Sequential(torch.nn.Linear(model_dim, ff_dim),
165 | torch.nn.ReLU(),
166 | torch.nn.Linear(ff_dim, model_dim))
167 | else:
168 | ff_layer = ff_layer()
169 | if cross_attention_layer is not None:
170 | cross_attention_layer = cross_attention_layer()
171 | self.encoder_layers = torch.nn.ModuleList([TransformerLayer(model_dim,
172 | attention_layer(),
173 | ff_layer,
174 | norm_layer,
175 | norm_type,
176 | cross_attention_layer,
177 | drop_path,
178 | init_values) for i in range(num_layers)])
179 | if positional_encoder is not None:
180 | self.positional_encoder = positional_encoder(model_dim)
181 | else:
182 | self.positional_encoder = positional_encoder
183 | if norm_type in ['ResiDual','prenorm']:
184 | self.final_norm = norm_layer(model_dim)
185 | elif norm_type == 'postnorm':
186 | self.final_norm = torch.nn.Identity()
187 | self.model_dim = model_dim
188 | self.compile = compile
189 | #self.compile = False
190 | self.norm_type = norm_type
191 | self.return_activations = return_activations
192 | if not self.return_activations:
193 | self.return_activations = []
194 | elif self.return_activations == 'all':
195 | self.return_activations = [i for i in range(num_layers)]
196 | self.key_in = key_in
197 | self.key_padding_mask = key_padding_mask
198 | self.key_out = key_out
199 | self.key_transformer_in = key_transformer_in
200 | self.key_transformer_out = key_transformer_out
201 | self.pre_net = pre_net() if pre_net is not None else None
202 | self.post_net = post_net() if post_net is not None else None
203 | self.residual_branch = 0
204 |
205 | def run_through_encoder(self,x):
206 | x, layers, padding_mask = x
207 | for l in layers:
208 | x,_,_=l(x,key_mask=padding_mask)
209 | return x
210 |
211 | def extract_activations(self,x):
212 | activations = []
213 | x, layers, padding_mask = x
214 | for l in layers:
215 | x,_,_=l(x,key_mask=padding_mask)
216 | activations.append(x[self.residual_branch])
217 | return activations, x
218 |
219 | @torch.compile
220 | def compiled_run_through_encoder(self,x):
221 | return self.run_through_encoder(x)
222 |
223 | @torch.compile
224 | def compiled_get_activations(self,x):
225 | return self.extract_activations(x)
226 |
227 | def forward(self,xin,
228 | padding_mask=None,
229 | return_activations=False,
230 | postnorm_last_activation=True,
231 | residual_branch=0):
232 |
233 | self.residual_branch = residual_branch
234 | if isinstance(xin, dict):
235 | padding_mask = xin[self.key_padding_mask]
236 | x = xin[self.key_in]
237 | else:
238 | x = xin
239 |
240 | if self.pre_net is not None:
241 | x = self.pre_net(x)
242 |
243 | if self.positional_encoder is not None:
244 | x = self.positional_encoder(x)
245 |
246 | if isinstance(xin, dict) and (self.key_transformer_in is not None):
247 | xin[self.key_transformer_in] = x
248 |
249 | if self.compile:
250 | if return_activations:
251 | acts, x = self.compiled_get_activations((x, self.encoder_layers, padding_mask))
252 | xin[self.key_transformer_out+'_activations'] = acts
253 | else:
254 | x = self.compiled_run_through_encoder((x,self.encoder_layers,padding_mask))
255 | else:
256 | if return_activations:
257 | acts, x = self.extract_activations((x, self.encoder_layers, padding_mask))
258 | xin[self.key_transformer_out+'_activations'] = acts
259 | else:
260 | x = self.run_through_encoder((x,self.encoder_layers,padding_mask))
261 |
262 | if self.norm_type == 'ResiDual':
263 | x = x[0] + self.final_norm(x[1])
264 | if return_activations and postnorm_last_activation:
265 | xin[self.key_transformer_out+'_activations'][-1] = x
266 | elif self.norm_type == 'prenorm':
267 | x = self.final_norm(x)
268 | else:
269 | pass
270 |
271 | if isinstance(xin, dict) and (self.key_transformer_out is not None):
272 | xin[self.key_transformer_out] = x
273 |
274 | if self.post_net is not None:
275 | x = self.post_net(x)
276 |
277 | if isinstance(xin,dict):
278 | xin[self.key_out] = x
279 | else:
280 | xin = x
281 | return xin
282 |
283 | class MultiHeadAttention(nn.Module):
284 | def __init__(self,
285 | model_dim,
286 | num_heads,
287 | qk_proj_dim=None,
288 | v_proj_dim=None,
289 | q_bias=False,
290 | k_bias=False,
291 | v_bias=False,
292 | out_bias=False,
293 | att_drop=0,
294 | proj_drop=0,
295 | att_scale=None,
296 | mask_value=-1e9):
297 |
298 | super().__init__()
299 | if qk_proj_dim is None:
300 | qk_proj_dim = model_dim//num_heads
301 | if v_proj_dim is None:
302 | v_proj_dim = qk_proj_dim
303 |
304 | self.qk_proj_dim = qk_proj_dim
305 | self.num_heads = num_heads
306 | self.v_proj_dim = v_proj_dim
307 |
308 | self.wq = nn.Linear(model_dim, qk_proj_dim*num_heads, bias=q_bias)
309 | self.wk = nn.Linear(model_dim, qk_proj_dim*num_heads, bias=k_bias)
310 | self.wv = nn.Linear(model_dim, v_proj_dim*num_heads, bias=v_bias)
311 |
312 | if att_scale is None:
313 | self.att_scale = qk_proj_dim ** -0.5
314 | else:
315 | self.att_scale = att_scale
316 | self.mask_value = mask_value
317 |
318 | self.att_drop = nn.Dropout(att_drop)
319 | self.wo = nn.Linear(v_proj_dim*num_heads, model_dim, bias=out_bias)
320 | self.proj_drop = nn.Dropout(proj_drop)
321 |
322 | def forward(self, q, k, v, key_mask=None, att_mask=None, att_bias=None):
323 | N,Tq,C = q.shape
324 | N,Tk,C = k.shape
325 | Q = self.wq(q) #NxTxD.H
326 | K = self.wk(k) #NxTxD.H
327 | V = self.wv(v) #NxTxD.H
328 |
329 | Q = Q.view(N,Tq,self.qk_proj_dim,self.num_heads).permute(0,3,1,2).reshape(N*self.num_heads,Tq,self.qk_proj_dim) #N.HxTqxDq
330 | K = K.view(N,Tk,self.qk_proj_dim,self.num_heads).permute(0,3,1,2).reshape(N*self.num_heads,Tk,self.qk_proj_dim) #N.HxTkxDk
331 | V = V.view(N,Tk,self.v_proj_dim,self.num_heads).permute(0,3,1,2).reshape(N*self.num_heads,Tk,self.v_proj_dim) #N.HxTkxDv
332 |
333 | kv_mask = torch.zeros((N*self.num_heads,Tq,Tk), dtype=torch.bool, device=Q.device)
334 | if key_mask is not None:
335 | key_mask = torch.tile(key_mask[:,None,None,:],(1,self.num_heads,Tq,1)).reshape(N*self.num_heads,Tq,Tk).to(dtype=torch.bool)
336 | kv_mask += key_mask
337 | if att_mask is not None:
338 | kv_mask += att_mask
339 |
340 | att = Q @ K.transpose(-2, -1) * self.att_scale #N.HxTqxTk
341 | if att_bias is not None:
342 | att += att_bias
343 | att += kv_mask*self.mask_value
344 | att = att.softmax(dim=-1) #N.HxTqxTk
345 | att = self.att_drop(att) #N.HxTqxTk
346 |
347 | x = att @ V #N.HxTqxDv
348 | x = x.view(N,self.num_heads,Tq,self.v_proj_dim).permute(0,2,3,1).reshape(N,Tq,self.v_proj_dim*self.num_heads) #NxTqxDvxH
349 | O = self.wo(x)
350 | O = self.proj_drop(O)
351 |
352 | return O, att
353 |
--------------------------------------------------------------------------------
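A minimal sketch, assuming the module above is importable as encodecmae.models.transformers, of how a small ResiDual encoder could be assembled from these classes; the dimensions are illustrative, not the values used in the gin configs:

import torch
from functools import partial
from encodecmae.models.transformers import (TransformerEncoder,        # assumed import path
                                            MultiHeadAttention,
                                            SinusoidalPositionalEmbeddings)

model_dim, num_heads, num_layers = 768, 12, 4
encoder = TransformerEncoder(model_dim=model_dim,
                             num_layers=num_layers,
                             attention_layer=partial(MultiHeadAttention, model_dim, num_heads),
                             ff_dim=4*model_dim,
                             positional_encoder=SinusoidalPositionalEmbeddings,
                             compile=False)                             # skip torch.compile for a quick check
x = torch.randn(2, 150, model_dim)                                      # (batch, time, model_dim)
y = encoder(x)                                                          # plain tensors bypass the dict keys
print(y.shape)                                                          # torch.Size([2, 150, 768])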
/encodecmae/scripts/run_pretraining.sh:
--------------------------------------------------------------------------------
1 | #-----------------------------------------First Iteration Models-------------------------------------------------------
2 | #EC-EC Small model
3 | ginpipe configs/base/unsupervised_pretrain.gin \
4 | configs/models/encodecmae.gin \
5 | configs/datasets/audioset-unbalanced-24k.gin \
6 | configs/datasets/fma-large-24k.gin \
7 | configs/datasets/librilight-6k-24k.gin \
8 | configs/features/wav_only.gin \
9 | --module_list configs/imports \
10 | --project_name ec-ec-small_model \
11 | --experiment_name upstream_model \
12 | --mods NUM_ENCODER_LAYERS=5
13 |
14 | #EC-EC Base model
15 | ginpipe configs/base/unsupervised_pretrain.gin \
16 | configs/models/encodecmae.gin \
17 | configs/datasets/audioset-unbalanced-24k.gin \
18 | configs/datasets/fma-large-24k.gin \
19 | configs/datasets/librilight-6k-24k.gin \
20 | configs/features/wav_only.gin \
21 | --module_list configs/imports \
22 | --project_name ec-ec-base_model \
23 | --experiment_name upstream_model
24 |
25 | #EC-EC Large model
26 | ginpipe configs/base/unsupervised_pretrain.gin \
27 | configs/models/encodecmae.gin \
28 | configs/datasets/audioset-unbalanced-24k.gin \
29 | configs/datasets/fma-large-24k.gin \
30 | configs/datasets/librilight-6k-24k.gin \
31 | configs/features/wav_only.gin \
32 | --module_list configs/imports \
33 | --project_name ec-ec-large_model \
34 | --experiment_name upstream_model \
35 | --mods NUM_ENCODER_LAYERS=20 DEVICE="[0,1]" MODEL_DIM=1024 TRAIN_BATCH_SIZE=64 "pl.Trainer.strategy='ddp_find_unused_parameters_true'"
36 |
37 | #Mel256-EC Small model
38 | ginpipe configs/base/unsupervised_pretrain.gin \
39 | configs/models/encodecmae.gin \
40 | configs/datasets/audioset-unbalanced-24k.gin \
41 | configs/datasets/fma-large-24k.gin \
42 | configs/datasets/librilight-6k-24k.gin \
43 | configs/features/mel.gin \
44 | --module_list configs/imports \
45 | --project_name mel256-ec-small_model \
46 | --experiment_name upstream_model \
47 | --mods NUM_ENCODER_LAYERS=5
48 |
49 | #Mel256-EC Base model
50 | ginpipe configs/base/unsupervised_pretrain.gin \
51 | configs/models/encodecmae.gin \
52 | configs/datasets/audioset-unbalanced-24k.gin \
53 | configs/datasets/fma-large-24k.gin \
54 | configs/datasets/librilight-6k-24k.gin \
55 | configs/features/mel.gin \
56 | --module_list configs/imports \
57 | --project_name mel256-ec-base_model \
58 | --experiment_name upstream_model
59 |
60 | #Mel256-EC Base model (Audioset)
61 | ginpipe configs/base/unsupervised_pretrain.gin \
62 | configs/models/encodecmae.gin \
63 | configs/datasets/audioset-unbalanced-24k.gin \
64 | configs/features/mel.gin \
65 | --module_list configs/imports \
66 | --project_name mel256-ec-base_model-as \
67 | --experiment_name upstream_model
68 |
69 | #Mel256-EC Base model (Librilight)
70 | ginpipe configs/base/unsupervised_pretrain.gin \
71 | configs/models/encodecmae.gin \
72 | configs/datasets/librilight-6k-24k.gin \
73 | configs/features/mel.gin \
74 | --module_list configs/imports \
75 | --project_name mel256-ec-base_model-ll \
76 | --experiment_name upstream_model
77 |
78 | #Mel256-EC Base model (FMA)
79 | ginpipe configs/base/unsupervised_pretrain.gin \
80 | configs/models/encodecmae.gin \
81 | configs/datasets/fma-large-24k.gin \
82 | configs/features/mel.gin \
83 | --module_list configs/imports \
84 | --project_name mel256-ec-base_model-fma \
85 | --experiment_name upstream_model
86 |
87 | #Mel256-EC Large model
88 | ginpipe configs/base/unsupervised_pretrain.gin \
89 | configs/models/encodecmae.gin \
90 | configs/datasets/audioset-unbalanced-24k.gin \
91 | configs/datasets/fma-large-24k.gin \
92 | configs/datasets/librilight-6k-24k.gin \
93 | configs/features/mel.gin \
94 | --module_list configs/imports \
95 | --project_name mel256-ec-large_model \
96 | --experiment_name upstream_model \
97 | --mods NUM_ENCODER_LAYERS=20 DEVICE="[0,1]" MODEL_DIM=1024 TRAIN_BATCH_SIZE=64 "pl.Trainer.strategy='ddp_find_unused_parameters_true'"
98 |
99 | #Mel256-EC Large model - Audioset
100 | ginpipe configs/base/unsupervised_pretrain.gin \
101 | configs/models/encodecmae.gin \
102 | configs/datasets/audioset-unbalanced-24k.gin \
103 | configs/features/mel.gin \
104 | --module_list configs/imports \
105 | --project_name mel256-ec-large_model-as \
106 | --experiment_name upstream_model \
107 | --mods NUM_ENCODER_LAYERS=20 DEVICE="[0,1]" MODEL_DIM=1024 TRAIN_BATCH_SIZE=64 "pl.Trainer.strategy='ddp_find_unused_parameters_true'"
108 |
109 | #-----------------------------------------------ST Models--------------------------------------------------------------
110 |
111 | #Mel256->EC Base - NoPN ST
112 | ginpipe configs/base/create_st_dataset.gin \
113 | configs/datasets/audioset-unbalanced-24k.gin \
114 | configs/datasets/fma-large-24k.gin \
115 | configs/datasets/librilight-6k-24k.gin \
116 | configs/features/mel.gin \
117 | --module_list configs/imports \
118 | --project_name mel256-ec-base_model \
119 | --experiment_name self_training_dataset_no-pn \
120 | --mods "TEACHER_MODEL='mel256-ec-base'" \
121 | "POSTNORM_LAST_ACTIVATION=False" \
122 | "DEVICE='cuda:0'"
123 |
124 | ginpipe configs/base/self-train.gin \
125 | configs/models/encodecmae.gin \
126 | configs/features/mel.gin \
127 | configs/features/st.gin \
128 | --module_list configs/imports \
129 | --project_name mel256-ec-base_model \
130 | --experiment_name st_model-no-pn \
131 | --mods "ST_CHECKPOINT='huggingface:lpepino/encodecmae-pretrained/upstreams/mel256-ec-base.pt'" \
132 | "ST_DATASET_DIR='/workspace/encodecmae/encodecmae/experiments/mel256-ec-base_model/self_training_dataset_no-pn'" \
133 | "DEVICE=[0]"
134 |
135 | #Mel256->EC Base - NoPN ST (AS)
136 | ginpipe configs/base/create_st_dataset.gin \
137 | configs/datasets/audioset-unbalanced-24k.gin \
138 | configs/features/mel.gin \
139 | --module_list configs/imports \
140 | --project_name mel256-ec-base_model-as \
141 | --experiment_name self_training_dataset_no-pn \
142 | --mods "TEACHER_MODEL='mel256-ec-base-as'" \
143 | "POSTNORM_LAST_ACTIVATION=False" \
144 | "DEVICE='cuda:0'"
145 |
146 | ginpipe configs/base/self-train.gin \
147 | configs/models/encodecmae.gin \
148 | configs/features/mel.gin \
149 | configs/features/st.gin \
150 | --module_list configs/imports \
151 | --project_name mel256-ec-base_model-as \
152 | --experiment_name st_model-no-pn \
153 | --mods "ST_CHECKPOINT='huggingface:lpepino/encodecmae-pretrained/upstreams/mel256-ec-base-as.pt'" \
154 | "ST_DATASET_DIR='/workspace/encodecmae/encodecmae/experiments/mel256-ec-base_model-as/self_training_dataset_no-pn'" \
155 | "DEVICE=[0]" \
156 | "pl.Trainer.check_val_every_n_epoch=None"
157 |
158 | #Mel256->EC Base - NoPN ST (LL)
159 | ginpipe configs/base/create_st_dataset.gin \
160 | configs/datasets/librilight-6k-24k.gin \
161 | configs/features/mel.gin \
162 | --module_list configs/imports \
163 | --project_name mel256-ec-base_model-ll \
164 | --experiment_name self_training_dataset_no-pn \
165 | --mods "TEACHER_MODEL='mel256-ec-base-ll'" \
166 | "POSTNORM_LAST_ACTIVATION=False" \
167 | "DEVICE='cuda:1'"
168 |
169 | ginpipe configs/base/self-train.gin \
170 | configs/models/encodecmae.gin \
171 | configs/features/mel.gin \
172 | configs/features/st.gin \
173 | --module_list configs/imports \
174 | --project_name mel256-ec-base_model-ll \
175 | --experiment_name st_model-no-pn \
176 | --mods "ST_CHECKPOINT='huggingface:lpepino/encodecmae-pretrained/upstreams/mel256-ec-base-ll.pt'" \
177 | "ST_DATASET_DIR='/workspace/encodecmae/encodecmae/experiments/mel256-ec-base_model-ll/self_training_dataset_no-pn'" \
178 | "DEVICE=[1]" \
179 | "pl.Trainer.check_val_every_n_epoch=None"
180 |
181 | #Mel256->EC Base - NoPN ST (FMA)
182 | ginpipe configs/base/create_st_dataset.gin \
183 | configs/datasets/fma-large-24k.gin \
184 | configs/features/mel.gin \
185 | --module_list configs/imports \
186 | --project_name mel256-ec-base_model-fma \
187 | --experiment_name self_training_dataset_no-pn \
188 | --mods "TEACHER_MODEL='mel256-ec-base-fma'" \
189 | "POSTNORM_LAST_ACTIVATION=False" \
190 | "DEVICE='cuda:1'"
191 |
192 | ginpipe configs/base/self-train.gin \
193 | configs/models/encodecmae.gin \
194 | configs/features/mel.gin \
195 | configs/features/st.gin \
196 | --module_list configs/imports \
197 | --project_name mel256-ec-base_model-fma \
198 | --experiment_name st_model-no-pn \
199 | --mods "ST_CHECKPOINT='huggingface:lpepino/encodecmae-pretrained/upstreams/mel256-ec-base-fma.pt'" \
200 | "ST_DATASET_DIR='/workspace/encodecmae/encodecmae/experiments/mel256-ec-base_model-fma/self_training_dataset_no-pn'" \
201 | "DEVICE=[1]" \
202 | "pl.Trainer.check_val_every_n_epoch=None"
203 |
204 | #Mel256-EC Large - NoPN ST
205 | ginpipe configs/base/create_st_dataset.gin \
206 | configs/datasets/audioset-unbalanced-24k.gin \
207 | configs/datasets/fma-large-24k.gin \
208 | configs/datasets/librilight-6k-24k.gin \
209 | configs/features/mel.gin \
210 | --module_list configs/imports \
211 | --project_name mel256-ec-large_model \
212 | --experiment_name self_training_dataset_no-pn \
213 | --mods "TEACHER_MODEL='mel256-ec-base_st-nopn'" \
214 | "POSTNORM_LAST_ACTIVATION=False" \
215 | "DEVICE='cuda:0'"
216 |
217 | ginpipe configs/base/self-train.gin \
218 | configs/models/encodecmae.gin \
219 | configs/features/mel.gin \
220 | configs/features/st.gin \
221 | --module_list configs/imports \
222 | --project_name mel256-ec-large_model \
223 | --experiment_name st_model-no-pn \
224 | --mods "ST_CHECKPOINT='huggingface:lpepino/encodecmae-pretrained/upstreams/mel256-ec-large.pt'" \
225 | "ST_DATASET_DIR='/workspace/encodecmae/encodecmae/experiments/mel256-ec-large_model/self_training_dataset_no-pn'" \
226 | "DEVICE=[0,1]" \
227 | NUM_ENCODER_LAYERS=20 \
228 | MODEL_DIM=1024 \
229 | TRAIN_BATCH_SIZE=64 \
230 | "pl.Trainer.strategy='ddp_find_unused_parameters_true'"
231 |
232 | #Mel256-EC Large - NoPN ST AS
233 | ginpipe configs/base/create_st_dataset.gin \
234 | configs/datasets/audioset-unbalanced-24k.gin \
235 | configs/features/mel.gin \
236 | --module_list configs/imports \
237 | --project_name mel256-ec-large_model-as \
238 | --experiment_name self_training_dataset_no-pn \
239 | --mods "TEACHER_MODEL='mel256-ec-base_st-as-nopn'" \
240 | "POSTNORM_LAST_ACTIVATION=False" \
241 | "DEVICE='cuda:1'"
242 |
243 | ginpipe configs/base/self-train.gin \
244 | configs/models/encodecmae.gin \
245 | configs/features/mel.gin \
246 | configs/features/st.gin \
247 | --module_list configs/imports \
248 | --project_name mel256-ec-large_model-as \
249 | --experiment_name st_model-no-pn \
250 | --mods "ST_CHECKPOINT='huggingface:lpepino/encodecmae-pretrained/upstreams/mel256-ec-large-as.pt'" \
251 | "ST_DATASET_DIR='/workspace/encodecmae/encodecmae/experiments/mel256-ec-large_model-as/self_training_dataset_no-pn'" \
252 | "DEVICE=[0,1]" \
253 | NUM_ENCODER_LAYERS=20 \
254 | MODEL_DIM=1024 \
255 | TRAIN_BATCH_SIZE=64 \
256 | "pl.Trainer.strategy='ddp_find_unused_parameters_true'"
257 |
--------------------------------------------------------------------------------
/encodecmae/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pathlib import Path
3 | from loguru import logger
4 | import torchinfo
5 | import inspect
6 | from tqdm import tqdm
7 | import joblib
8 | import numpy as np
9 | from encodecmae.hub import load_model as load_encodecmae_model
10 | from huggingface_hub import hf_hub_download
11 | from threadpoolctl import threadpool_limits
12 |
13 | def load_model(state, model):
14 | model = load_encodecmae_model(model)
15 | state['model'] = model
16 | return state
17 |
18 | def create_st_dataset(state, layer=-1, kmeans_samples=10000, device='cuda:0', postnorm_last_activation=True,kmeans_jobs=8):
19 | from sklearn.cluster import KMeans
20 |
21 | model = state['model'].to(device)
22 | model.visible_encoder.compile=False
23 | #First learn kmeans:
24 | if ('tokenizer' in state):
25 | cluster_model = state['tokenizer']
26 | elif Path(state['output_dir'],'tokenizer.pkl').exists():
27 | cluster_model = joblib.load(Path(state['output_dir'],'tokenizer.pkl'))
28 | else:
29 | kmeans_sample_idxs = np.random.choice(np.arange(0,len(state['datasets']['train'])),size=kmeans_samples,replace=False)
30 | kmeans_dataset = []
31 | for i in tqdm(kmeans_sample_idxs):
32 | x = state['datasets']['train'][i]
33 | x = state['dataloaders']['train'].collate_fn([x])
34 | x = {k: v.to(model.device) for k,v in x.items()}
35 | out = model.extract_activations(x, postnorm_last_activation=postnorm_last_activation)
36 | kmeans_dataset.append(out['visible_embeddings_activations'][layer].cpu().numpy())
37 | kmeans_dataset = np.concatenate(kmeans_dataset,axis=1)[0]
38 | n_tokens = model.head.num_tokens
39 | with threadpool_limits(user_api="blas", limits=kmeans_jobs):
40 | cluster_model = KMeans(n_clusters=n_tokens)
41 | kmeans_sample_idxs = np.random.choice(np.arange(0,kmeans_dataset.shape[0]),size=kmeans_samples,replace=False)
42 | kmeans_dataset = kmeans_dataset[kmeans_sample_idxs]
43 | cluster_model.fit(kmeans_dataset)
44 | state['tokenizer']=cluster_model
45 | joblib.dump(cluster_model, Path(state['output_dir'],'tokenizer.pkl'))
46 | state['datasets']['train']._out_cols+=['start','stop','filename']
47 | state['dataloaders']['train'] = torch.utils.data.DataLoader(state['dataloaders']['train'].dataset, shuffle=False, num_workers=4, collate_fn=state['dataloaders']['train'].collate_fn, batch_size=state['dataloaders']['train'].batch_size)
48 | #Find last saved index before starting this loop:
49 | data_out_path = Path(state['output_dir'],'self_training_dataset')
50 | if not data_out_path.exists():
51 | data_out_path.mkdir(parents=True)
52 | for batch_idx, x in enumerate(tqdm(state['dataloaders']['train'])):
53 | #out = model.extract_activations({'wav': x['wav'].to(device=model.device, dtype=model.dtype),
54 | # 'wav_lens': x['wav_lens'].to(device=model.device)}, postnorm_last_activation=postnorm_last_activation)
55 | x = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k,v in x.items()}
56 | out = model.extract_activations(x, postnorm_last_activation=postnorm_last_activation)
57 | out = out['visible_embeddings_activations'][layer].cpu().numpy()
58 | for i in range(out.shape[0]):
59 | indexs = cluster_model.predict(out[i])
60 | filename = '{}_{}_{}.npy'.format(Path(x['filename'][i]).stem,x['start'][i],x['stop'][i])
61 | file_out = Path(state['output_dir'],'self_training_dataset',filename[:3],filename)
62 | file_out.parent.mkdir(parents=True, exist_ok=True)
63 | try:
64 | np.save(file_out,indexs)
65 | except Exception as e:
66 | print('Failed {}: {}'.format(file_out, e))
67 | with open(Path(state['output_dir'],'metadata_selftrain_dataset.csv'),'a') as f:
68 | f.write('{},{},{},{}\n'.format(str(x['filename'][i]),x['start'][i],x['stop'][i],str(file_out.resolve())))
69 | state['model'] = state['model'].to('cpu')
70 | return state
71 |
72 | def fit_model(state, trainer_cls=None, model_cls=None, from_checkpoint=None,
73 | cpu_threads=8,dataloaders_key='dataloaders',
74 | checkpoint_folder='checkpoints',
75 | key_out='model',
76 | cache_model=True,
77 | keep_optimizer_state=False,
78 | non_strict_checkpoint=True):
79 |
80 | if (from_checkpoint is not None) and from_checkpoint.startswith('huggingface:'):
81 | hf_path = from_checkpoint.split('huggingface:')[-1]
82 | hf_repo = '/'.join(hf_path.split('/')[:2])
83 | path_in_repo = '/'.join(hf_path.split('/')[2:])
84 | from_checkpoint = hf_hub_download(repo_id=hf_repo, filename=path_in_repo)
85 |
86 | if not ((key_out in state) and (cache_model)):
87 | torch.set_num_threads(cpu_threads)
88 | torch.set_float32_matmul_precision('medium')
89 | kwargs = {}
90 | if (model_cls is not None) and ('state' in inspect.signature(model_cls.__init__).parameters.keys()):
91 | kwargs['state'] = state
92 |
93 | if model_cls is None:
94 | if Path(from_checkpoint).stem == 'state':
95 | model = joblib.load(from_checkpoint)['model']
96 | from_checkpoint=None
97 | else:
98 | model = model_cls(**kwargs)
99 | trainer = trainer_cls()
100 | trainer.checkpoint_callback.dirpath = trainer.checkpoint_callback.dirpath + '/{}'.format(checkpoint_folder)
101 | base_dir = trainer.checkpoint_callback.dirpath
102 |
103 | #Find last checkpoint
104 | if from_checkpoint == 'last':
105 | ckpts = list(Path(base_dir).glob('*.ckpt'))
106 | if 'last' in [x.stem for x in ckpts]:
107 | from_checkpoint=Path(base_dir, 'last.ckpt')
108 | else:
109 | ckpt_epoch = [int(c.stem.split('epoch=')[-1].split('-')[0]) for c in ckpts]
110 | if len(ckpt_epoch) > 0:
111 | last_epoch = max(ckpt_epoch)
112 | from_checkpoint = ckpts[ckpt_epoch.index(last_epoch)]
113 | else:
114 | logger.info('No checkpoints found in {}. Training from scratch.'.format(base_dir))
115 | from_checkpoint = None
116 |
117 | logger.info(torchinfo.summary(model))
118 |
119 | if (from_checkpoint is not None) and (non_strict_checkpoint):
120 | ckpt_data = torch.load(from_checkpoint)
121 | model_sd = model.state_dict()
122 | ckpt_sd = {}
123 | for k,v in ckpt_data['state_dict'].items():
124 | if (k in model_sd) and (v.shape == model_sd[k].shape):
125 | ckpt_sd[k] = v
126 | else:
127 | print("Couldn't load {} from checkpoint".format(k))
128 |
129 | model.load_state_dict(ckpt_sd, strict=False)
130 | from_checkpoint=None
131 |
132 | trainer.fit(model,
133 | state[dataloaders_key]['train'],
134 | state[dataloaders_key]['validation'],
135 | ckpt_path=from_checkpoint)
136 | trainer.save_checkpoint(Path(base_dir,'last.ckpt'))
137 | state[key_out+'_checkpoint_dir'] = trainer.checkpoint_callback.dirpath
138 | best_model_path = model.trainer.checkpoint_callback.best_model_path
139 | if (best_model_path is not None) and (best_model_path != ''):
140 | model.load_state_dict(torch.load(best_model_path)['state_dict'])
141 | state[key_out] = model
142 | else:
143 | logger.info('Model is already in state. Skipping task.')
144 |
145 | return state
146 |
--------------------------------------------------------------------------------
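A self-contained sketch of how the 'huggingface:' prefix accepted by fit_model's from_checkpoint argument is resolved to a local file, mirroring lines 80-84 above; the repo id and filename are the ones referenced in run_pretraining.sh and are assumptions about what is published:

from huggingface_hub import hf_hub_download

ckpt = 'huggingface:lpepino/encodecmae-pretrained/upstreams/mel256-ec-base.pt'
hf_path = ckpt.split('huggingface:')[-1]
repo_id = '/'.join(hf_path.split('/')[:2])       # 'lpepino/encodecmae-pretrained'
filename = '/'.join(hf_path.split('/')[2:])      # 'upstreams/mel256-ec-base.pt'
local_ckpt = hf_hub_download(repo_id=repo_id, filename=filename)
print(local_ckpt)                                 # local cache path passed on as from_checkpoint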
/encodecmae/tasks/data/__init__.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from loguru import logger
3 | from pathlib import Path
4 | import numpy as np
5 | from tqdm import tqdm
6 | import soundfile as sf
7 | from torch.utils.data import Dataset
8 | import copy
9 | import torch
10 | import re
11 | import joblib
12 |
13 | from typing import Dict, List, Callable, Any, Optional, Union, Type
14 |
15 | def compensate_lengths(df: pd.DataFrame, chunk_length: Optional[float] = None) -> List[int]:
16 | """
17 | Compensates for varying lengths of elements in a DataFrame.
18 |
19 | Args:
20 | df (pandas.DataFrame): The DataFrame containing elements with varying lengths.
21 | chunk_length (float, optional): The length of each chunk in seconds.
22 | If provided, elements will be divided into chunks of approximately equal length.
23 | Defaults to None.
24 |
25 | Returns:
26 | list: A list of indices corresponding to elements in the DataFrame, accounting for varying lengths.
27 |
28 | Note:
29 | If chunk_length is not provided, each element is represented by a single index.
30 | If chunk_length is provided, elements are divided into chunks, and each chunk is represented by its element's index.
31 | """
32 | if chunk_length is not None:
33 | map_idx = []
34 | for i, (idx, row) in enumerate(df.iterrows()):
35 | map_idx.extend([i]*int(max(1,row['duration']//chunk_length)))
36 | return map_idx
37 | else:
38 | return list(range(len(df)))
39 |
40 | def dataset_random_split(df: pd.DataFrame,
41 | proportions: Dict[str, float] = {}) -> Dict[str, pd.DataFrame]:
42 | """
43 | Splits a DataFrame into partitions randomly based on given proportions.
44 |
45 | Args:
46 | df (pandas.DataFrame): The DataFrame to be split.
47 | proportions (dict, optional): Dictionary containing proportions of split for each partition.
48 | If value is greater than 1, it's treated as the number of samples to include in the partition.
49 | If value is between 0 and 1, it's treated as the proportion of the DataFrame to include in the partition.
50 | If -1 is provided for any partition, remaining samples will be assigned to this partition.
51 | Defaults to an empty dictionary.
52 |
53 | Returns:
54 | dict: Dictionary containing partitions as DataFrames.
55 |
56 | Raises:
57 | Exception: If -1 is used in more than one entry in the proportions dictionary.
58 | """
59 | idxs = df.index
60 | prop_type = [v for k,v in proportions.items() if v>1]
61 | if len(prop_type)>0:
62 | prop_type = 'n'
63 | else:
64 | prop_type = 'prop'
65 | remainder_k = [k for k,v in proportions.items() if v==-1]
66 | if len(remainder_k) > 1:
67 | raise Exception("-1 can't be used in more than one entry")
68 | elif len(remainder_k) == 1:
69 | remainder_k = remainder_k[0]
70 | else:
71 | remainder_k = None
72 | partitions = {}
73 | for k,v in proportions.items():
74 | if k != remainder_k:
75 | if prop_type == 'prop':
76 | v = int(len(df)*v)
77 | sampled_idxs = np.random.choice(idxs, v, replace=False)
78 | idxs = [i for i in idxs if i not in sampled_idxs]
79 | partitions[k] = df.loc[sampled_idxs]
80 | if remainder_k is not None:
81 | partitions[remainder_k] = df.loc[idxs]
82 | return partitions
83 |
84 | class DictDataset(Dataset):
85 | """
86 | Dataset class to handle data stored in a dictionary-like format.
87 |
88 | Args:
89 | metadata (pandas.DataFrame): DataFrame containing metadata of the dataset.
90 | state (dict): Dictionary containing additional state information.
91 | out_cols (list): List of columns to be included in the output.
92 | preprocessor (optional): Callable to apply to a dataframe row before returning the item. Defaults to None.
93 | index_mapper (callable, optional): A function to map indices of metadata. Defaults to None.
94 | state_keys (list, optional): List of keys from the state dictionary to be included in the dataset's state. Defaults to None.
95 | """
96 | def __init__(self,
97 | metadata: pd.DataFrame,
98 | state: Dict[str, Any],
99 | out_cols: List[str],
100 | preprocessor: Callable[[Any, Dict[str, Any]], Any] = None,
101 | index_mapper: Optional[Callable[[pd.DataFrame], List[int]]] = None,
102 | state_keys: Optional[List[str]] = None):
103 |
104 | self._metadata = metadata
105 | self._out_cols = out_cols
106 | self._state = {}
107 | self._state['metadata'] = metadata
108 | if 'classID' in state['dataset_metadata'].columns:
109 | if isinstance(state['dataset_metadata'].iloc[0]['classID'], np.ndarray):
110 | self._state['num_classes'] = len(state['dataset_metadata'].iloc[0]['classID'])
111 | else:
112 | self._state['num_classes'] = state['dataset_metadata']['classID'].max() + 1
113 | if state_keys is not None:
114 | for k in state_keys:
115 | if k in state:
116 | self._state[k] = state[k]
117 | self._preprocessor = preprocessor() if preprocessor is not None else None
118 | if index_mapper is not None:
119 | self._idx_map = index_mapper(self._metadata)
120 | else:
121 | self._idx_map = list(range(len(self._metadata)))
122 |
123 | def __getitem__(self, idx):
124 | row = copy.deepcopy(self._metadata.iloc[self._idx_map[idx]])
125 | if self._preprocessor is not None:
126 | row = self._preprocessor(row)
127 | out = {k: row[k] for k in self._out_cols}
128 | return out
129 |
130 | def __len__(self):
131 | return len(self._idx_map)
132 |
133 | def dynamic_pad_batch(x: Union[list, Dict[str, Any]]) -> Dict[str, torch.Tensor]:
134 | """
135 | Dynamically pads a batch of sequences with variable lengths and converts them to PyTorch tensors.
136 |
137 | Args:
138 | x (Union[list, dict]): List or dictionary containing sequences to be padded.
139 |
140 | Returns:
141 | dict: Dictionary containing padded sequences converted to PyTorch tensors.
142 | """
143 | def not_discarded(x):
144 | if x is None:
145 | return False
146 | else:
147 | return not any([xi is None for xi in x.values()])
148 |
149 | def get_len(x):
150 | if x.ndim == 0:
151 | return 1
152 | else:
153 | return x.shape[0]
154 |
155 | def pad_to_len(x, max_len):
156 | if x.ndim == 0:
157 | return x
158 | else:
159 | pad_spec = ((0,max_len-x.shape[0]),) + ((0,0),)*(x.ndim - 1)
160 | return np.pad(x,pad_spec)
161 |
162 | def to_torch(x):
163 | if isinstance(x, torch.Tensor):
164 | return x
165 | else:
166 | if x.dtype in [np.float64, np.float32, np.float16,
167 | np.complex64, np.complex128,
168 | np.int64, np.int32, np.int16, np.int8,
169 | np.uint8, np.bool_]:
170 |
171 | return torch.from_numpy(x)
172 | else:
173 | return x
174 |
175 | x_ = x
176 | x = [xi for xi in x if not_discarded(xi)]
177 |
178 | batch = {k: [np.array(xi[k]) for xi in x] for k in x[0]}
179 | batch_lens = {k: [get_len(x) for x in batch[k]] for k in batch.keys()}
180 | batch_max_lens = {k: max(v) for k,v in batch_lens.items()}
181 | batch = {k: np.stack([pad_to_len(x, batch_max_lens[k]) for x in batch[k]]) for k in batch.keys()}
182 | batch_lens = {k+'_lens': np.array(v) for k,v in batch_lens.items()}
183 | batch.update(batch_lens)
184 | batch = {k: to_torch(v) for k,v in batch.items()}
185 |
186 | return batch
187 |
188 | def get_dataloaders(state: Dict[str, Any],
189 | split_function: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
190 | dataset_cls: Optional[Type[Any]] = None,
191 | dataloader_cls: Optional[Type[Any]] = None,
192 | dataset_key_in: str = 'dataset_metadata',
193 | dataset_key_out: str = 'datasets',
194 | partitions_key_out: str = 'partitions',
195 | dataloaders_key_out: str = 'dataloaders') -> Dict[str, Any]:
196 | """
197 | Constructs dataloaders from the given state and configurations.
198 |
199 | Args:
200 | state (dict): The dictionary representing the current state.
201 | split_function (callable, optional): A function to split the dataset. Defaults to None.
202 | dataset_cls (type, optional): The class to instantiate datasets. Defaults to None.
203 | dataloader_cls (type, optional): The class to instantiate dataloaders. Defaults to None.
204 | dataset_key_in (str, optional): The key in the state to get the dataset metadata. Defaults to 'dataset_metadata'.
205 | dataset_key_out (str, optional): The key to use for storing datasets in the state. Defaults to 'datasets'.
206 | partitions_key_out (str, optional): The key to use for storing dataset partitions in the state. Defaults to 'partitions'.
207 | dataloaders_key_out (str, optional): The key to use for storing dataloaders in the state. Defaults to 'dataloaders'.
208 |
209 | Returns:
210 | dict: The updated state dictionary with datasets, partitions, and dataloaders.
211 |
212 | Raises:
213 | TypeError: If split_function, dataset_cls, or dataloader_cls is not callable.
214 | """
215 |
216 | if split_function is not None:
217 | partitions = split_function(state[dataset_key_in])
218 | else:
219 | partitions = {'train': state[dataset_key_in]}
220 |
221 | datasets = {k: dataset_cls[k](v, state) for k,v in partitions.items() if k in dataset_cls}
222 | dataloaders = {k: dataloader_cls[k](v) for k,v in datasets.items() if k in dataloader_cls}
223 |
224 | state[partitions_key_out] = partitions
225 | state[dataset_key_out] = datasets
226 | state[dataloaders_key_out] = dataloaders
227 |
228 | return state
229 |
230 | def load_dataset(state: Dict[str, Any],
231 | reader_fns: List[Callable[[], pd.DataFrame]],
232 | cache: bool = True,
233 | postprocessors: List[Callable[[pd.DataFrame], pd.DataFrame]] = [],
234 | key_out: str = 'dataset_metadata',
235 | rename: List[Dict[str, Any]] = None) -> Dict[str, Any]:
236 | """
237 | Loads dataset into the state.
238 |
239 | Args:
240 | state (dict): The dictionary representing the current state.
241 | reader_fns (list): List of functions that return DataFrames when called.
242 | cache (bool, optional): Whether to cache the dataset in the state. Defaults to True.
243 | postprocessors (list, optional): List of functions to apply to the resulting dataframe. Each function should accept and return a DataFrame. Defaults to [].
244 | key_out (str, optional): The key to use for storing the dataset in the state. Defaults to 'dataset_metadata'.
245 | rename (list, optional): List of dictionaries specifying renaming rules. Each dictionary should contain keys 'column', 'value', and 'new_value' for renaming. Defaults to None.
246 |
247 | Returns:
248 | dict: The updated state dictionary with the loaded dataset.
249 |
250 | Raises:
251 | TypeError: If reader_fns is not a list.
252 | """
253 |
254 | if not (cache and key_out in state):
255 | if not isinstance(reader_fns, list):
256 | raise TypeError("reader_fns must be a list of reader functions.")
257 | elif len(reader_fns) == 0:
258 | raise Exception("reader_fns is empty. Supply at least one reader function")
259 | dfs = [fn() for fn in reader_fns]
260 | df = pd.concat(dfs).reset_index()
261 | state[key_out] = df
262 | else:
263 | logger.info('Caching dataset metadata from state')
264 |
265 | for f in postprocessors:
266 | state[key_out] = f(state[key_out])
267 |
268 | if rename is not None:
269 | for r in rename:
270 | state[key_out][r['column']] = state[key_out][r['column']].apply(lambda x: r['new_value'] if x == r['value'] else x)
271 |
272 | return state
273 |
274 | def read_audiodir(dataset_path: List[str],
275 | subsample: Optional[int] = None,
276 | dataset: Optional[str] = None,
277 | regex_groups: Optional[str] = None,
278 | filter_list: Optional[str] = None,
279 | partition_lists: Optional[Dict[str, Optional[str]]] = None,
280 | filter_mode: str = 'include') -> pd.DataFrame:
281 | """
282 | Reads audio files from directories and generates metadata DataFrame.
283 |
284 | Args:
285 | dataset_path (list): List of paths to directories containing audio files.
286 | subsample (int, optional): Number of files to subsample. Defaults to None.
287 | dataset (str, optional): Name of the dataset. Defaults to None.
288 | regex_groups (str, optional): Regular expression to extract metadata from filenames. Defaults to None.
289 | filter_list (str, optional): Path to a file containing a list of filenames to filter. Defaults to None.
290 | partition_lists (dict, optional): Dictionary mapping partitions to filenames. Defaults to None.
291 | filter_mode (str, optional): Filtering mode, either 'include' or 'discard'. Defaults to 'include'.
292 |
293 | Returns:
294 | pandas.DataFrame: Metadata DataFrame containing information about audio files.
295 |
296 | Raises:
297 | Exception: If an unrecognized filter mode is provided.
298 | """
299 |
300 | if not isinstance(dataset_path, list):
301 | dataset_path = [dataset_path]
302 |
303 | all_files = []
304 | for p in dataset_path:
305 | all_files_i = list(Path(p).rglob('*.wav')) + list(Path(p).rglob('*.flac'))
306 | all_files.extend(all_files_i)
307 |
308 | if filter_list is not None:
309 | with open(filter_list, 'r') as f:
310 | keep_values = set(f.read().splitlines())
311 | n_slashes = len(next(iter(keep_values)).split('/')) - 1
312 | stem_to_f = {'/'.join(v.parts[-n_slashes-1:]): v for v in all_files}
313 | if filter_mode == 'include':
314 | all_files = [stem_to_f[k] for k in keep_values]
315 | elif filter_mode == 'discard':
316 | all_files = [v for k,v in stem_to_f.items() if k not in keep_values]
317 | else:
318 | raise Exception("Unrecognized filter_mode {}".format(filter_mode))
319 |
320 | if subsample is not None:
321 | subsample_idx = np.random.choice(np.arange(len(all_files)),size=subsample,replace=False)
322 | all_files = np.array(all_files)[subsample_idx]
323 |
324 | rows = []
325 | for f in tqdm(all_files):
326 | try:
327 | finfo = sf.info(f)
328 | metadata = {'filename': str(f.resolve()),
329 | 'sr': finfo.samplerate,
330 | 'channels': finfo.channels,
331 | 'frames': finfo.frames,
332 | 'duration': finfo.duration,
333 | 'rel_path': str(f.relative_to(p))}
334 | if regex_groups is not None:
335 | regex_data = re.match(regex_groups,str(f.relative_to(dataset_path[0]))).groupdict()
336 | metadata.update(regex_data)
337 | rows.append(metadata)
338 | except Exception as e:
339 | print(f'Failed reading {f}. {e}')
340 | df = pd.DataFrame(rows)
341 |
342 | if dataset is not None:
343 | df['dataset'] = dataset
344 |
345 | # df['rel_path'] = df['filename'].apply(lambda x: str(Path(x).relative_to(Path(dataset_path[0]))))
346 |
347 | if partition_lists is not None:
348 | remainder = None
349 | map_to_partitions={}
350 | for k,v in partition_lists.items():
351 | if v is not None:
352 | list_path = Path(dataset_path[0],v)
353 | with open(list_path,'r') as f:
354 | list_files = f.read().splitlines()
355 | for l in list_files:
356 | map_to_partitions[str(l)] = k
357 | else:
358 | remainder = k
359 |
360 | df['partition'] = df['rel_path'].apply(lambda x: map_to_partitions[x] if x in map_to_partitions else remainder)
361 | df = df.drop('rel_path', axis=1)
362 |
363 | return df
364 |
365 | def read_st_dataset(dataset_path):
366 | df = pd.read_csv(Path(dataset_path, 'metadata_selftrain_dataset.csv'), names=['start','stop','filename'])
367 | df = df.reset_index()
368 | df = df.rename({'index':'filename_audio','filename':'filename_targets'},axis=1)
369 | return df
370 |
371 | def remove_long_audios(df: pd.DataFrame, limit: int = 10000) -> pd.DataFrame:
372 | """
373 | Removes rows from a DataFrame where the duration of audio files exceeds a specified limit.
374 |
375 | Args:
376 | df (pandas.DataFrame): The DataFrame containing audio metadata.
377 | limit (int, optional): The maximum duration in seconds. Defaults to 10000.
378 |
379 | Returns:
380 | pandas.DataFrame: DataFrame with rows removed where the duration exceeds the limit.
381 | """
382 | df = df.loc[df['duration'] < limit]
383 | return df
--------------------------------------------------------------------------------
/encodecmae/tasks/features/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Any, Union
2 | from loguru import logger
3 | import numpy as np
4 | import soundfile as sf
5 | import torch
6 | import torchaudio
7 | import random
8 |
9 | class Processor:
10 | """Base class for data processors."""
11 |
12 | def __init__(self) -> None:
13 | pass
14 |
15 | def __call__(self, x: Dict[str, Any]) -> Dict[str, Any]:
16 | """Process input and return output.
17 |
18 | Args:
19 | x (Dict[str, Any]): Input data.
20 |
21 | Returns:
22 | Dict[str, Any]: Processed data.
23 | """
24 | raise NotImplementedError
25 |
26 | class SequentialProcessor(Processor):
27 | """Sequential processor that applies a list of processors sequentially."""
28 |
29 | def __init__(self, processors: List[Processor]) -> None:
30 | """Initialize SequentialProcessor.
31 |
32 | Args:
33 | processors (List[Processor]): List of processors.
34 | """
35 | self._processors = [p() for p in processors]
36 |
37 | def __call__(self, x: Dict[str, Any]) -> Dict[str, Any]:
38 | """Process input by applying each processor sequentially.
39 |
40 | Args:
41 | x (Dict[str, Any]): Input data.
42 |
43 | Returns:
44 | Dict[str, Any]: Processed data.
45 | """
46 | for p in self._processors:
47 | x = p(x)
48 | return x
49 |
50 | class ReadAudioProcessor(Processor):
51 | """Processor to read audio files."""
52 |
53 | def __init__(self, key_in: str, key_out: str, max_length: Union[float, None] = None, mono: bool = True) -> None:
54 | """Initialize ReadAudioProcessor.
55 |
56 | Args:
57 | key_in (str): Key for input audio.
58 | key_out (str): Key for output audio.
59 | max_length (Union[float, None], optional): Maximum length of audio in seconds. Defaults to None.
60 | mono (bool, optional): Whether to convert stereo audio to mono. Defaults to True.
61 | """
62 | super().__init__()
63 | self.key_in, self.key_out, self.max_length, self.mono = key_in, key_out, max_length, mono
64 |
65 | def __call__(self, x: Dict[str, Any]) -> Dict[str, Any]:
66 | """Read audio file and process it.
67 |
68 | Args:
69 | x (Dict[str, Any]): Input data.
70 |
71 | Returns:
72 | Dict[str, Any]: Processed data.
73 | """
74 | try:
75 | if self.max_length is not None:
76 | audio_info = sf.info(x[self.key_in])
77 | desired_frames = int(self.max_length*audio_info.samplerate)
78 | total_frames = audio_info.frames
79 | if total_frames > desired_frames:
80 | start = random.randint(0,total_frames - desired_frames)
81 | stop = start + desired_frames
82 | else:
83 | start = 0
84 | stop = None
85 | else:
86 | start = 0
87 | stop = None
88 | if 'start' in x:
89 | start = x['start']
90 | if 'stop' in x:
91 | stop = x['stop']
92 | x['start'] = start
93 | x['stop'] = stop
94 | wav, fs = sf.read(x[self.key_in], start=start, stop=stop, dtype=np.float32)
95 | if (wav.ndim == 2) and self.mono:
96 | wav = np.mean(wav,axis=-1)
97 | except Exception as e:
98 | logger.warning('Failed reading {}: {}'.format(x[self.key_in], e))
99 | wav = None
100 | x[self.key_out] = wav
101 | return x
102 |
103 | class LoadNumpyProcessor(Processor):
104 | """Processor to load numpy arrays."""
105 |
106 | def __init__(self, key_in: str, key_out: str) -> None:
107 | """Initialize LoadNumpyProcessor.
108 |
109 | Args:
110 | key_in (str): Key for input numpy array file.
111 | key_out (str): Key for output numpy array.
112 | """
113 | super().__init__()
114 | self.key_in = key_in
115 | self.key_out = key_out
116 |
117 | def __call__(self, x: Dict[str, Any]) -> Dict[str, Any]:
118 | """Load numpy array and process it.
119 |
120 | Args:
121 | x (Dict[str, Any]): Input data.
122 |
123 | Returns:
124 | Dict[str, Any]: Processed data.
125 | """
126 | x[self.key_out] = np.load(x[self.key_in])
127 | return x
128 |
129 | class MelspectrogramProcessor(Processor):
130 | """Processor to calculate melspectrograms from waveforms.
131 | Internally, torchaudio.compliance.kaldi.fbank is used and the same kwargs are accepted.
132 | Additionally, norm_stats can be supplied with [mean,std] to normalize the resulting melspectrogram.
133 | key_in and key_out are strings indicating the key containing the waveform
134 | and the key where the result will be stored."""
135 |
136 | def __init__(self, key_in='wav', key_out='wav_features',
137 | frame_length=25,
138 | frame_shift=10,
139 | high_freq=0,
140 | htk_compat=False,
141 | low_freq=20,
142 | num_mel_bins=23,
143 | sample_frequency=16000,
144 | window_type='povey',
145 | dither=0.0,
146 | use_energy=False, norm_stats=[0,1]):
147 | super().__init__()
148 | self.mel_kwargs = dict(frame_length=frame_length,
149 | frame_shift=frame_shift, high_freq=high_freq,
150 | htk_compat=htk_compat, low_freq=low_freq, num_mel_bins=num_mel_bins,
151 | sample_frequency=sample_frequency, window_type=window_type, dither=dither, use_energy=use_energy)
152 | self.norm_stats = norm_stats
153 | self.key_in, self.key_out = key_in, key_out
154 |
155 | def __call__(self, x: Dict[str, Any]) -> Dict[str, Any]:
156 | """Calculate melspectrogram and normalize it.
157 |
158 | Args:
159 | x (Dict[str, Any]): Input data dictionary containing waveform.
160 |
161 | Returns:
162 | Dict[str, Any]: Output data dictionary containing melspectrogram.
163 |
164 | """
165 | if x[self.key_in] is not None:
166 | mel = torchaudio.compliance.kaldi.fbank(torch.from_numpy(x[self.key_in]).unsqueeze(0), **self.mel_kwargs).numpy()
167 | mel = (mel-self.norm_stats[0])/self.norm_stats[1]
168 | else:
169 | mel = None
170 | x[self.key_out] = mel
171 | return x
172 |
173 | class SpectrogramProcessor(Processor):
174 | """Processor to calculate spectrograms from waveforms.
175 | Internally, torchaudio.compliance.kaldi.spectrogram is used and the same kwargs are accepted.
176 | Additionally, norm_stats can be supplied with [mean,std] to normalize the resulting spectrogram.
177 | key_in and key_out are strings indicating the key containing the waveform
178 | and the key where the result will be stored."""
179 | def __init__(self, key_in='wav', key_out='wav_features',
180 | frame_length=25,
181 | frame_shift=10,
182 | sample_frequency=16000,
183 | window_type='povey',
184 | dither=0.0,
185 | norm_stats=[0,1]):
186 | self.spec_kwargs = dict(frame_length=frame_length,
187 | frame_shift=frame_shift,
188 | sample_frequency=sample_frequency,
189 | window_type=window_type,
190 | dither=dither)
191 | self.norm_stats = norm_stats
192 | self.key_in, self.key_out = key_in, key_out
193 |
194 | def __call__(self, x: Dict[str, Any]) -> Dict[str, Any]:
195 | """Calculate spectrogram and normalize it.
196 |
197 | Args:
198 | x (Dict[str, Any]): Input data dictionary containing waveform.
199 |
200 | Returns:
201 | Dict[str, Any]: Output data dictionary containing spectrogram.
202 |
203 | """
204 | spec = torchaudio.compliance.kaldi.spectrogram(torch.from_numpy(x[self.key_in]).unsqueeze(0), **self.spec_kwargs).numpy()
205 | spec = (spec-self.norm_stats[0])/self.norm_stats[1]
206 | x[self.key_out] = spec
207 | return x
208 |
--------------------------------------------------------------------------------
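A rough end-to-end sketch, assuming the modules are importable as encodecmae.tasks.features and encodecmae.tasks.data, of how these processors can be chained the way the gin feature configs do, and how the resulting items are collated by dynamic_pad_batch; the file name and mel parameters are illustrative:

import numpy as np
import soundfile as sf
from functools import partial
from encodecmae.tasks.features import (SequentialProcessor,            # assumed import paths
                                       ReadAudioProcessor,
                                       MelspectrogramProcessor)
from encodecmae.tasks.data import dynamic_pad_batch

sf.write('example.wav', np.random.randn(24000).astype(np.float32), 24000)   # 1 s of noise at 24 kHz

pipeline = SequentialProcessor([
    partial(ReadAudioProcessor, key_in='filename', key_out='wav', max_length=4.0),
    partial(MelspectrogramProcessor, key_in='wav', key_out='wav_features',
            num_mel_bins=128, sample_frequency=24000),
])
items = [pipeline({'filename': 'example.wav'}) for _ in range(2)]
batch = dynamic_pad_batch([{k: it[k] for k in ('wav', 'wav_features')} for it in items])
print(batch['wav'].shape, batch['wav_features'].shape, batch['wav_lens'])    # padded tensors plus per-key lengths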
/encodecmae/tasks/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 |
5 | def set_seed(state, seed=42):
6 | random.seed(seed)
7 | np.random.seed(seed)
8 | torch.manual_seed(seed)
9 | state.seed = seed
10 | return state
--------------------------------------------------------------------------------
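A tiny sketch of the task-function convention used across encodecmae.tasks: every task takes the pipeline state, updates it, and returns it. SimpleNamespace stands in here for the state object that ginpipe supplies:

from types import SimpleNamespace
from encodecmae.tasks.utils import set_seed   # assumed import path

state = set_seed(SimpleNamespace(), seed=42)
print(state.seed)                              # 42; random, numpy and torch are now seeded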
/requirements.txt:
--------------------------------------------------------------------------------
1 | huggingface_hub
2 | encodec==0.1.1
3 | ginpipe==0.0.2
4 | librosa>=0.10.1
5 | numpy>=1.24.1
6 | pytorch_lightning>=2.0.0
7 | torch>=2.0.0
8 | torchinfo
9 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | name = encodecmae
3 | version = 0.0.1
4 | author = Leonardo Pepino
5 | author_email = lpepino@dc.uba.ar
6 | description = EnCodecMAE
7 | long_description = file: README.md
8 | long_description_content_type = text/markdown
9 | classifiers =
10 | Programming Language :: Python :: 3
11 | Operating System :: OS Independent
12 | [options]
13 | packages = encodecmae
14 | python_requires = >=3.6
15 | install_requires =
16 | huggingface_hub
17 | encodec==0.1.1
18 | torch>=2.0.0
19 | pytorch_lightning>=2.0.0
20 | ginpipe
21 | librosa
22 | torchinfo
23 | pandas
24 |
25 |
26 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | setup()
--------------------------------------------------------------------------------
/start_docker.sh:
--------------------------------------------------------------------------------
1 | cd encodecmae
2 | docker build -t encodecmae:latest .
3 | docker compose up -d
4 | docker attach encodecmae-train
--------------------------------------------------------------------------------