├── .dockerignore ├── .gitignore ├── Dockerfile ├── README.md ├── conv_ssl ├── __init__.py ├── augmentations.py ├── callbacks.py ├── conf │ ├── __init__.py │ ├── config.yaml │ ├── data │ │ └── data.yaml │ ├── events │ │ └── events.yaml │ ├── model │ │ ├── comparative.yaml │ │ ├── discrete.yaml │ │ ├── discrete_20hz.yaml │ │ ├── discrete_50hz.yaml │ │ ├── independent.yaml │ │ └── independent40.yaml │ ├── swb_kfolds │ │ ├── 0_fold_train.txt │ │ ├── 0_fold_val.txt │ │ ├── 10_fold_train.txt │ │ ├── 10_fold_val.txt │ │ ├── 1_fold_train.txt │ │ ├── 1_fold_val.txt │ │ ├── 2_fold_train.txt │ │ ├── 2_fold_val.txt │ │ ├── 3_fold_train.txt │ │ ├── 3_fold_val.txt │ │ ├── 4_fold_train.txt │ │ ├── 4_fold_val.txt │ │ ├── 5_fold_train.txt │ │ ├── 5_fold_val.txt │ │ ├── 6_fold_train.txt │ │ ├── 6_fold_val.txt │ │ ├── 7_fold_train.txt │ │ ├── 7_fold_val.txt │ │ ├── 8_fold_train.txt │ │ ├── 8_fold_val.txt │ │ ├── 9_fold_train.txt │ │ └── 9_fold_val.txt │ ├── trainer │ │ └── trainer.yaml │ └── vap │ │ └── vap.yaml ├── datamodule_disk.py ├── dataset_save_samples_to_disk.py ├── evaluation │ ├── README.md │ ├── __init__.py │ ├── anova.py │ ├── duration.py │ ├── evaluate_paper_models.py │ ├── evaluation.py │ ├── evaluation_augmentation.py │ ├── evaluation_phrases.py │ ├── extract_video_data.py │ ├── forced_alignment.bash │ ├── forced_alignment_duration.bash │ ├── phrase_dataset.py │ ├── phrases.json │ ├── phrases_duration_process.py │ ├── prepare_phrases_for_alignment.py │ ├── tts.py │ ├── update_checkpoints.py │ ├── utils.py │ └── vad.py ├── model.py ├── models │ ├── __init__.py │ ├── autoregressive.py │ ├── cnn.py │ ├── cpc_base_model.py │ ├── encoder.py │ ├── multi_head_attention.py │ ├── transformer.py │ └── transformer_old.py ├── plot_utils.py ├── train.py ├── train_disk.py ├── transforms.py └── utils.py ├── docker ├── Dockerfile_base ├── README.md └── prepare_dataset.py ├── example ├── cpc_48_50hz_15gqq5s5.ckpt ├── student_long_female_en-US-Wavenet-G.TextGrid ├── student_long_female_en-US-Wavenet-G.json ├── student_long_female_en-US-Wavenet-G.wav └── vad_list.json ├── frontend └── main.py ├── pytest.ini ├── requirements.txt ├── run.py ├── scripts ├── README.md ├── model_kfold.bash ├── pitch_eval.bash ├── test_augmentation.bash ├── test_future.py ├── test_models_script.bash ├── test_phrases.bash ├── test_regular.bash ├── train_script.bash └── train_vap_new.bash ├── setup.py ├── tests ├── __init__.py ├── test_ar.py ├── test_cpc_causality.py ├── test_encoder.py └── test_main.py └── visualize_run.py /.dockerignore: -------------------------------------------------------------------------------- 1 | assets/ 2 | artifacts/ 3 | runs/ 4 | wandb/ 5 | dataset_units/ 6 | lightning_logs/ 7 | checkpoints/ 8 | swb_dataset/ 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Custom 132 | assets/ 133 | checkpoints/ 134 | swb_dataset/ 135 | runs/ 136 | *.wav 137 | tmp_data/ 138 | assets_local/ 139 | dataset_units/ 140 | wandb/ 141 | artifacts/ 142 | lightning_logs/ 143 | pyrightconfig.json 144 | *.lock 145 | *.mp4 146 | *.avi 147 | .jekyll* 148 | _site/ 149 | .streamlit/ 150 | .sass* 151 | 152 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Conv_ssl 2 | # Combine this with Dockerfile_base when everything is done and working... 3 | FROM vap_base 4 | 5 | # datasets_turntaking 6 | WORKDIR /dependencies 7 | RUN git clone https://github.com/ErikEkstedt/datasets_turntaking.git 8 | WORKDIR /dependencies/datasets_turntaking 9 | RUN pip install -r requirements.txt 10 | RUN pip install -e . 11 | 12 | # vad_turn_taking 13 | WORKDIR /dependencies 14 | RUN git clone https://github.com/ErikEkstedt/vap_turn_taking.git 15 | WORKDIR /dependencies/vap_turn_taking 16 | RUN pip install -r requirements.txt 17 | RUN pip install -e . 18 | 19 | # conv_ssl 20 | WORKDIR /workspace 21 | COPY . . 22 | RUN pip install -r requirements.txt 23 | RUN pip install -e . 
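# Assumed build order (not documented in this Dockerfile): build the base image first with
#   docker build -f docker/Dockerfile_base -t vap_base .
# and then build this image from the repo root so the `FROM vap_base` line above resolves:
#   docker build -t conv_ssl .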
24 | 
25 | # Prepare switchboard (so we don't have to download it all the time)
26 | RUN python docker/prepare_dataset.py
27 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Continuous Conversational SSL
2 | 
3 | THIS IS AN OLD REPO. FOR ACCESS TO THE NEWER STEREO-VERSION OF THIS MODEL, SEE:
4 | 
5 | * [VoiceActivityProjection](https://github.com/ErikEkstedt/VoiceActivityProjection)
6 | 
7 | IT SIMPLIFIES INFERENCE AND AVOIDS DEPENDENCY HELL...
8 | 
9 | 
10 | --------------------------------------------------------------------
11 | 
12 | Model training for
13 | * [Voice Activity Projection: Self-supervised Learning of Turn-taking Events](https://arxiv.org/abs/2205.09812)
14 | * [How Much Does Prosody Help Turn-taking? Investigations using Voice Activity Projection Models]()
15 | 
16 | 
17 | ## Installation
18 | 
19 | * Create conda env: `conda create -n conv_ssl python=3.9`
20 |   - activate env: `conda activate conv_ssl`
21 | * PyTorch: `conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch`
22 | * Dependencies:
23 |   * Install requirements: `pip install -r requirements.txt`
24 |   * **NOTE:** If you have problems, run `pip install cython` manually first and then re-run the `pip install -r requirements.txt` command (automating the install of the [CPC_audio](https://github.com/facebookresearch/CPC_audio) repo is troublesome).
25 |   * [Optional] Manual installation of [CPC_audio](https://github.com/facebookresearch/CPC_audio) (if the note above does not work)
26 |     * `git clone https://github.com/facebookresearch/CPC_audio.git`
27 |     * cd into the repo and install its dependencies (see that repository); you will probably need
28 |       * `pip install cython`
29 |     * run: `python setup.py develop` (again, see the original implementation)
30 | * **VAP**: Voice Activity Projection multi-purpose "head".
31 |   * Install [`vap_turn_taking`](https://github.com/ErikEkstedt/vap_turn_taking)
32 |     * `git clone https://github.com/ErikEkstedt/vap_turn_taking.git`
33 |     * cd into the repo and install dependencies: `pip install -r requirements.txt`
34 |     * Install: `pip install -e .`
35 | * **DATASET**
36 |   * Install [datasets_turntaking](https://github.com/ErikEkstedt/datasets_turntaking)
37 |     * `git clone https://github.com/ErikEkstedt/datasets_turntaking.git`
38 |     * cd into the repo and install dependencies: `pip install -r requirements.txt`
39 |     * Install repo: `pip install -e .`
40 |   * **WARNING:** Requires [Switchboard](https://catalog.ldc.upenn.edu/LDC97S62) and/or [Fisher](https://catalog.ldc.upenn.edu/LDC2004S13) data!
41 | * Install **`conv_ssl`:**
42 |   * cd to the root directory and run: `pip install -e .`
43 | 
44 | ### Train
45 | 
46 | ```bash
47 | python conv_ssl/train.py data.datasets=['switchboard','fisher'] +trainer.val_check_interval=0.5 early_stopping.patience=20
48 | ```
49 | 
50 | ### Evaluate
51 | 
52 | ```bash
53 | python conv_ssl/evaluation/evaluation.py \
54 |     +checkpoint_path=/full/path/checkpoint.ckpt \
55 |     +savepath=assets/vap_fis \
56 |     data.num_workers=4 \
57 |     data.batch_size=16
58 | ```
59 | 
60 | 
61 | ### Run
62 | 
63 | The `run.py` script loads a pretrained model and evaluates it on a sample (a waveform + either a `text_grid_name.TextGrid` or a `vad_list_name.json`). See the `example` folder for the exact formats; a minimal `vad_list.json` sketch is shown below.
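The sketch below is an assumed layout (based on the voice-activity format used in the related VAP repositories): one list of `[start, end]` times in seconds per speaker. Treat `example/vad_list.json` as the authoritative format:

```json
[
  [[0.0, 1.53], [2.10, 3.92]],
  [[1.61, 1.98], [4.10, 5.00]]
]
```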
64 | 
65 | * Using defaults: `python run.py`
66 | * A custom run requires an audio file `sample.wav` and **either** a `text_grid_name.TextGrid` or a `vad_list_name.json` (`-w`: waveform, `-tg`: text grid, `-v`: vad-list, `-o`: output file)
67 | ```bash
68 | python run.py \
69 |     -c example/cpc_48_50hz_15gqq5s5.ckpt \
70 |     -w example/student_long_female_en-US-Wavenet-G.wav \
71 |     -tg example/student_long_female_en-US-Wavenet-G.TextGrid \
72 |     -v example/vad_list.json \
73 |     -o VAP_OUTPUT.json
74 | ```
75 | 
76 | 
77 | ### Paper
78 | 
79 | The paper investigates performance over k-fold splits (see `conv_ssl/conf/swb_kfolds`) across 4 different model architectures ('discrete', 'independent', 'independent-40', 'comparative').
80 | * Save samples to disk: `conv_ssl/dataset_save_samples_to_disk.py`
81 | * Train on the samples on disk: `conv_ssl/train_disk.py`
82 |   * run `scripts/model_kfold.bash`
83 | * We evaluate (find the threshold over the validation set + final evaluation on the test set)
84 |   - see `conv_ssl/evaluation/evaluate_paper_models.py`
85 |   - the ids are the `WandB` run ids.
86 |   - We save all model scores to disk
87 | * In `conv_ssl/evaluation/anova.py` we compare the scores to extract the final values in the paper.
88 | 
89 | ## Experiments
90 | 
91 | * Training uses [WandB](https://wandb.ai) by default.
92 | * The event settings used in the paper are included in `conv_ssl/config/event_settings.json`.
93 |   - See paper Section 3
94 | 
95 | ```python
96 | from conv_ssl.utils import read_json
97 | 
98 | event_settings = read_json("conv_ssl/config/event_settings.json")
99 | hs_kwargs = event_settings['hs']
100 | bc_kwargs = event_settings['bc']
101 | metric_kwargs = event_settings['metric']
102 | ```
103 | 
104 | ```json
105 | {
106 |     "hs": {
107 |         "post_onset_shift": 1,
108 |         "pre_offset_shift": 1,
109 |         "post_onset_hold": 1,
110 |         "pre_offset_hold": 1,
111 |         "non_shift_horizon": 2,
112 |         "metric_pad": 0.05,
113 |         "metric_dur": 0.1,
114 |         "metric_pre_label_dur": 0.5,
115 |         "metric_onset_dur": 0.2
116 |     },
117 |     "bc": {
118 |         "max_duration_frames": 1.0,
119 |         "pre_silence_frames": 1.0,
120 |         "post_silence_frames": 2.0,
121 |         "min_duration_frames": 0.2,
122 |         "metric_dur_frames": 0.2,
123 |         "metric_pre_label_dur": 0.5
124 |     },
125 |     "metric": {
126 |         "pad": 0.05,
127 |         "dur": 0.1,
128 |         "pre_label_dur": 0.5,
129 |         "onset_dur": 0.2,
130 |         "min_context": 3.0
131 |     }
132 | }
133 | ```
134 | 
135 | 
136 | ## Citation
137 | 
138 | ```latex
139 | TBA
140 | ```
141 | 
--------------------------------------------------------------------------------
/conv_ssl/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ErikEkstedt/conv_ssl/c365345afff3df33c791c6fc9d498bc08617ffb7/conv_ssl/__init__.py
--------------------------------------------------------------------------------
/conv_ssl/callbacks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytorch_lightning as pl
3 | import wandb
4 | 
5 | from conv_ssl.augmentations import (
6 |     flatten_pitch_batch,
7 |     shift_pitch_batch,
8 |     low_pass_filter_resample,
9 |     IntensityNeutralizer,
10 | )
11 | 
12 | 
13 | class SymmetricSpeakersCallback(pl.Callback):
14 |     """
15 |     This callback "flips" the speakers so that the evaluation is fair and does not depend on the
16 |     biased speaker order / speaker activity.
17 | 
18 |     The audio is mono, which requires no change.
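    Note that `get_symmetric_batch` concatenates the flipped samples onto the
    originals, so the effective batch size is doubled.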
19 | 
20 |     The only change we apply is to flip the channels in the VAD-tensor and get the corresponding VAD-history,
21 |     which is defined as the ratio of speaker 0 (i.e. vad_history_flipped = 1 - vad_history).
22 |     """
23 | 
24 |     def get_symmetric_batch(self, batch):
25 |         """Appends a flipped version of the batch-samples"""
26 |         for k, v in batch.items():
27 |             if k == "vad":
28 |                 flipped = torch.stack((v[..., 1], v[..., 0]), dim=-1)
29 |             elif k == "vad_history":
30 |                 flipped = 1.0 - v
31 |             else:
32 |                 flipped = v
33 |             if isinstance(v, torch.Tensor):
34 |                 batch[k] = torch.cat((v, flipped))
35 |             else:
36 |                 batch[k] = v + flipped
37 |         return batch
38 | 
39 |     def on_train_batch_start(self, trainer, pl_module, batch, *args, **kwargs):
40 |         batch = self.get_symmetric_batch(batch)
41 | 
42 |     def on_test_batch_start(self, trainer, pl_module, batch, *args, **kwargs):
43 |         batch = self.get_symmetric_batch(batch)
44 | 
45 |     def on_validation_batch_start(self, trainer, pl_module, batch, *args, **kwargs):
46 |         batch = self.get_symmetric_batch(batch)
47 | 
48 | 
49 | class FlattenPitchCallback(pl.Callback):
50 |     """Flattens the pitch (F0) of each waveform in the batch to a constant value."""
51 | 
52 |     def __init__(
53 |         self,
54 |         target_f0: int = -1,
55 |         statistic: str = "mean",
56 |         stats_frame_length: int = 800,
57 |         stats_hop_length: int = 320,
58 |         sample_rate: int = 16000,
59 |         to_mono: bool = True,
60 |     ):
61 |         super().__init__()
62 |         self.statistic = statistic
63 |         self.stats_frame_length = stats_frame_length
64 |         self.stats_hop_length = stats_hop_length
65 |         self.target_f0 = target_f0
66 |         self.sample_rate = sample_rate
67 |         self.to_mono = to_mono
68 | 
69 |     def flatten_pitch(self, batch, device):
70 |         """Flattens the pitch of the batch waveforms"""
71 |         flat_waveform = flatten_pitch_batch(
72 |             waveform=batch["waveform"].cpu(),
73 |             vad=batch["vad"],
74 |             target_f0=self.target_f0,
75 |             statistic=self.statistic,
76 |             stats_frame_length=self.stats_frame_length,
77 |             stats_hop_length=self.stats_hop_length,
78 |             sample_rate=self.sample_rate,
79 |             to_mono=self.to_mono,
80 |         )
81 |         batch["waveform"] = flat_waveform.to(device)
82 |         return batch
83 | 
84 |     def on_test_batch_start(self, trainer, pl_module, batch, *args, **kwargs):
85 |         batch = self.flatten_pitch(batch, device=pl_module.device)
86 | 
87 |     def on_validation_batch_start(self, trainer, pl_module, batch, *args, **kwargs):
88 |         batch = self.flatten_pitch(batch, device=pl_module.device)
89 | 
90 |     def on_train_batch_start(self, trainer, pl_module, batch, *args, **kwargs):
91 |         batch = self.flatten_pitch(batch, device=pl_module.device)
92 | 
93 | 
94 | class NeutralIntensityCallback(pl.Callback):
95 |     """Neutralizes the intensity (loudness) variation of each waveform in the batch."""
96 | 
97 |     def __init__(
98 |         self,
99 |         vad_hz,
100 |         vad_cutoff: float = 0.2,
101 |         hop_time: float = 0.01,
102 |         f0_min: int = 60,
103 |         statistic: str = "mean",
104 |         sample_rate: int = 16000,
105 |         to_mono: bool = True,
106 |     ):
107 |         super().__init__()
108 |         self.hop_time = hop_time
109 |         self.vad_hz = vad_hz
110 |         self.f0_min = f0_min
111 |         self.vad_cutoff = vad_cutoff
112 |         self.statistic = statistic
113 |         self.sample_rate = sample_rate
114 |         self.to_mono = to_mono
115 |         self.neutralizer = IntensityNeutralizer(
116 |             hop_time=hop_time,
117 |             vad_hz=vad_hz,
118 |             f0_min=f0_min,
119 |             vad_cutoff=vad_cutoff,
120 |             scale_stat=statistic,
121 |             sample_rate=sample_rate,
122 |             to_mono=to_mono,
123 |         )
124 | 
125 |     def neutral_batch(self, batch):
126 |         batch_size = batch["waveform"].shape[0]
127 |         n_frames = batch["vad_history"].shape[1]
128 | 
129 |         combine = False
130 | 
131 |         if batch["waveform"].ndim == 3:
132 |             combine = True
133 | 
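        # Neutralize intensity sample-by-sample; stereo input (B, C, n_samples) is
        # averaged to mono ("combine") before being passed to the neutralizer.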
134 |         new_waveform = []
135 |         for b in range(batch_size):
136 |             vad = batch["vad"][b, :n_frames]
137 |             if combine:
138 |                 y_tmp = batch["waveform"][b].mean(0, keepdim=True)
139 |             else:
140 |                 y_tmp = batch["waveform"][b]
141 |             y, _ = self.neutralizer(y_tmp, vad=vad)
142 |             new_waveform.append(y)
143 |         batch["waveform"] = torch.cat(new_waveform)
144 |         return batch
145 | 
146 |     def on_test_batch_start(self, trainer, pl_module, batch, *args, **kwargs):
147 |         batch = self.neutral_batch(batch)
148 | 
149 |     def on_validation_batch_start(self, trainer, pl_module, batch, *args, **kwargs):
150 |         batch = self.neutral_batch(batch)
151 | 
152 |     def on_train_batch_start(self, trainer, pl_module, batch, *args, **kwargs):
153 |         batch = self.neutral_batch(batch)
154 | 
155 | 
156 | class LowPassFilterCallback(pl.Callback):
157 |     """
158 |     Applies a low-pass filter by downsampling and upsampling the signal, based on the Nyquist theorem.
159 |     """
160 | 
161 |     def __init__(
162 |         self,
163 |         cutoff_freq: int = 300,
164 |         sample_rate: int = 16000,
165 |         norm: bool = True,
166 |         to_mono: bool = True,
167 |     ):
168 |         super().__init__()
169 |         self.cutoff_freq = cutoff_freq
170 |         self.sample_rate = sample_rate
171 |         self.norm = norm
172 |         self.to_mono = to_mono
173 | 
174 |     def normalize(self, x):
175 |         assert x.ndim == 2, f"normalization expects (B, n_samples) got {x.shape}"
176 |         xx = x - x.min(-1, keepdim=True).values
177 |         xx = 2 * xx / xx.max(-1, keepdim=True).values  # per-sample max, so each sample spans [-1, 1]
178 |         xx = xx - 1.0
179 |         return xx
180 | 
181 |     def low_pass(self, waveform):
182 |         waveform = low_pass_filter_resample(
183 |             waveform, self.cutoff_freq, self.sample_rate
184 |         )
185 |         if self.to_mono:
186 |             waveform = waveform.mean(1)
187 | 
188 |         if self.norm:
189 |             waveform = self.normalize(waveform)
190 | 
191 |         return waveform
192 | 
193 |     def on_test_batch_start(self, trainer, pl_module, batch, *args, **kwargs):
194 |         batch["waveform"] = self.low_pass(batch["waveform"])
195 | 
196 |     def on_validation_batch_start(self, trainer, pl_module, batch, *args, **kwargs):
197 |         batch["waveform"] = self.low_pass(batch["waveform"])
198 | 
199 |     def on_train_batch_start(self, trainer, pl_module, batch, *args, **kwargs):
200 |         batch["waveform"] = self.low_pass(batch["waveform"])
201 | 
202 | 
203 | class ShiftPitchCallback(pl.Callback):
204 |     def __init__(
205 |         self, factor: float = 0.9, sample_rate: int = 16000, to_mono: bool = True
206 |     ):
207 |         super().__init__()
208 |         self.factor = factor
209 |         self.sample_rate = sample_rate
210 |         self.to_mono = to_mono
211 | 
212 |     def shift_pitch(self, batch, device):
213 |         flat_waveform = shift_pitch_batch(
214 |             waveform=batch["waveform"].cpu(),
215 |             factor=self.factor,
216 |             vad=batch["vad"],
217 |             sample_rate=self.sample_rate,
218 |             to_mono=self.to_mono,
219 |         )
220 |         batch["waveform"] = flat_waveform.to(device)
221 |         return batch
222 | 
223 |     def on_test_batch_start(self, trainer, pl_module, batch, *args, **kwargs):
224 |         batch = self.shift_pitch(batch, device=pl_module.device)
225 | 
226 |     def on_validation_batch_start(self, trainer, pl_module, batch, *args, **kwargs):
227 |         batch = self.shift_pitch(batch, device=pl_module.device)
228 | 
229 |     def on_train_batch_start(self, trainer, pl_module, batch, *args, **kwargs):
230 |         batch = self.shift_pitch(batch, device=pl_module.device)
231 | 
232 | 
233 | class WandbArtifactCallback(pl.Callback):
234 |     def upload(self, trainer):
235 |         run = trainer.logger.experiment
236 |         print(f"Ending run: {run.id}")
237 |         artifact = wandb.Artifact(f"{run.id}_model", type="model")
238 |         for path, val_loss in trainer.checkpoint_callback.best_k_models.items():
239 |             print(f"Adding artifact: {path}")
240 |             artifact.add_file(path)
241 |         run.log_artifact(artifact)
242 | 
243 |     def on_train_end(self, trainer, pl_module):
244 |         print("Training End ---------------- Custom Upload")
245 |         self.upload(trainer)
246 | 
247 |     def on_exception(self, trainer, pl_module, exception):
248 |         if isinstance(exception, KeyboardInterrupt):
249 |             print("Keyboard Interruption ------- Custom Upload")
250 |             self.upload(trainer)
251 | 
252 | 
253 | if __name__ == "__main__":
254 |     from os.path import join, basename
255 |     from conv_ssl.evaluation.evaluation_phrases import load_model_dset
256 |     import sounddevice as sd
257 | 
258 |     ch_root = "assets/PaperB/checkpoints"
259 |     checkpoint = join(ch_root, "cpc_48_50hz_15gqq5s5.ckpt")
260 |     model, dset = load_model_dset(checkpoint)
261 |     checkpoint_name = basename(checkpoint)
262 | 
263 |     batch = dset.get_sample("student", "long", "female", 0)
264 |     batch["waveform"] = batch["waveform"].unsqueeze(1)
265 |     waveform = shift_pitch_batch(batch["waveform"].cpu(), factor=0.8)
266 |     sd.play(waveform[0].cpu(), samplerate=16000)
267 | 
268 |     # test Callbacks
269 |     batch = dset.get_sample("student", "long", "female", 0)
270 |     batch["waveform"] = batch["waveform"].unsqueeze(1)
271 |     # augmentation = "flat_f0"
272 |     # augmentation = 'only_f0'
273 |     augmentation = "shift_f0"
274 |     clb = []
275 |     if augmentation == "flat_f0":
276 |         clb.append(FlattenPitchCallback())
277 |     elif augmentation == "only_f0":
278 |         clb.append(LowPassFilterCallback(cutoff_freq=300))
279 |     elif augmentation == "shift_f0":
280 |         clb.append(ShiftPitchCallback(factor=0.9))
281 |     elif augmentation == "flat_intensity":
282 |         pass
283 |     clb[0].on_test_batch_start(trainer=None, pl_module=model, batch=batch)
284 |     print(batch["waveform"])
285 |     sd.play(batch["waveform"][0].cpu(), samplerate=16000)
286 | 
--------------------------------------------------------------------------------
/conv_ssl/conf/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ErikEkstedt/conv_ssl/c365345afff3df33c791c6fc9d498bc08617ffb7/conv_ssl/conf/__init__.py
--------------------------------------------------------------------------------
/conv_ssl/conf/config.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 |   - model: discrete_50hz
3 |   - events: events
4 | 
5 | seed: 1
6 | verbose: false
7 | 
8 | wandb:
9 |   project: 'VAP'
10 |   dont_log_model: true
11 | 
12 | optimizer:
13 |   alpha: 0.6
14 |   learning_rate: 3.63e-4
15 |   betas: [0.9, 0.999]
16 |   weight_decay: 0.001
17 |   lr_scheduler_interval: "step"
18 |   lr_scheduler_freq: 100
19 |   lr_scheduler_tmax: 2500
20 |   swa_enable: false
21 |   swa_epoch_start: 5
22 |   swa_annealing_epochs: 10
23 |   train_encoder_epoch: -1
24 | 
25 | early_stopping:
26 |   patience: 10
27 |   monitor: 'val_loss'
28 |   mode: 'min'
29 | 
30 | checkpoint:
31 |   monitor: 'val_loss'
32 |   mode: 'min'
33 | 
34 | trainer:
35 |   gpus: -1
36 |   fast_dev_run: 0
37 |   deterministic: true
38 |   max_epochs: 30
39 | 
40 | data:
41 |   datasets: ["switchboard", "fisher"]
42 |   type: "sliding"
43 |   sample_rate: ${model.encoder.sample_rate}
44 |   audio_mono: true
45 |   audio_duration: 10
46 |   audio_normalize: true
47 |   audio_overlap: 1
48 |   # VAD
49 |   vad_hz: ${model.encoder.frame_hz}
50 |   vad_horizon: 2
51 |   vad_history: ${model.va_cond.history}
52 |   vad_history_times: [60, 30, 10, 5]
53 |
# Data 54 | train_files: null 55 | val_files: null 56 | test_files: null 57 | batch_size: 16 58 | num_workers: 24 59 | 60 | hydra: 61 | run: 62 | dir: runs 63 | -------------------------------------------------------------------------------- /conv_ssl/conf/data/data.yaml: -------------------------------------------------------------------------------- 1 | datasets: ["switchboard", "fisher"] 2 | type: "sliding" 3 | sample_rate: ${model.encoder.sample_rate} 4 | audio_mono: true 5 | audio_duration: 10 6 | audio_normalize: true 7 | audio_overlap: 1 8 | # VAD 9 | vad_hz: ${model.encoder.frame_hz} 10 | vad_horizon: 2 11 | vad_history: true 12 | vad_history_times: [60, 30, 10, 5] 13 | # Data 14 | train_files: null 15 | val_files: null 16 | test_files: null 17 | batch_size: 16 18 | num_workers: 24 19 | -------------------------------------------------------------------------------- /conv_ssl/conf/events/events.yaml: -------------------------------------------------------------------------------- 1 | metric: 2 | pad: 0.05 3 | dur: 0.1 4 | pre_label_dur: 0.5 5 | onset_dur: 0.2 6 | min_context: 3.0 7 | 8 | SH: 9 | post_onset_shift: 1 10 | pre_offset_shift: 1 11 | post_onset_hold: 1 12 | pre_offset_hold: 1 13 | non_shift_horizon: 2 14 | metric_pad: ${events.metric.pad} 15 | metric_dur: ${events.metric.dur} 16 | metric_pre_label_dur: ${events.metric.pre_label_dur} 17 | metric_onset_dur: ${events.metric.onset_dur} 18 | 19 | BC: 20 | max_duration_frames: 1.0 21 | pre_silence_frames: 1.0 22 | post_silence_frames: 2.0 23 | min_duration_frames: ${events.metric.onset_dur} 24 | metric_dur_frames: ${events.metric.onset_dur} 25 | metric_pre_label_dur: ${events.metric.pre_label_dur} 26 | 27 | threshold: 28 | SL: 0.5 29 | S_pred: 0.3 30 | BC_pred: 0.1 31 | -------------------------------------------------------------------------------- /conv_ssl/conf/model/comparative.yaml: -------------------------------------------------------------------------------- 1 | encoder: 2 | pretrained: true 3 | output_layer: 1 4 | sample_rate: 16000 5 | name: cpc 6 | frame_hz: 100 7 | 8 | va_cond: 9 | history: true 10 | history_bins: 5 11 | 12 | ar: 13 | type: 'transformer' 14 | dim: 256 15 | num_layers: 4 16 | num_heads: 4 17 | use_pos_emb: 1 18 | abspos: true 19 | sizeSeq: 1024 20 | dff_k: 3 21 | dropout: 0.4 22 | 23 | vap: 24 | bin_times: [.2, .4, .6, .8] 25 | type: 'comparative' 26 | pre_frames: 2 27 | bin_threshold: 0.5 28 | -------------------------------------------------------------------------------- /conv_ssl/conf/model/discrete.yaml: -------------------------------------------------------------------------------- 1 | encoder: 2 | pretrained: true 3 | output_layer: 1 4 | sample_rate: 16000 5 | name: cpc 6 | frame_hz: 100 7 | freeze: True 8 | 9 | va_cond: 10 | history: true 11 | history_bins: 5 12 | 13 | ar: 14 | type: 'transformer' 15 | dim: 256 16 | num_layers: 4 17 | num_heads: 4 18 | use_pos_emb: 1 19 | abspos: true 20 | sizeSeq: 1024 21 | dff_k: 3 22 | dropout: 0.4 23 | 24 | vap: 25 | bin_times: [.2, .4, .6, .8] 26 | type: 'discrete' 27 | pre_frames: 2 28 | bin_threshold: 0.5 29 | -------------------------------------------------------------------------------- /conv_ssl/conf/model/discrete_20hz.yaml: -------------------------------------------------------------------------------- 1 | encoder: 2 | pretrained: true 3 | output_layer: 1 4 | sample_rate: 16000 5 | name: cpc 6 | frame_hz: 20 7 | freeze: True 8 | downsample: 9 | kernel: [11] 10 | stride: [5] 11 | dilation: [1] 12 | dim: 256 13 | activation: "GELU" 14 | 15 | 
va_cond: 16 | history: true 17 | history_bins: 5 18 | 19 | ar: 20 | type: 'gpt' 21 | dim: 256 22 | num_layers: 4 23 | num_heads: 4 24 | dff_k: 3 25 | dropout: 0.4 26 | use_pos_emb: 0 # AliBI 27 | max_context: null # no max context if use_pos_emb=0 28 | abspos: null # deprecated 29 | sizeSeq: null # deprecated 30 | 31 | vap: 32 | bin_times: [.2, .4, .6, .8] 33 | type: 'discrete' 34 | pre_frames: 2 35 | bin_threshold: 0.5 36 | -------------------------------------------------------------------------------- /conv_ssl/conf/model/discrete_50hz.yaml: -------------------------------------------------------------------------------- 1 | encoder: 2 | pretrained: true 3 | output_layer: 1 4 | sample_rate: 16000 5 | name: cpc 6 | frame_hz: 50 7 | freeze: True 8 | downsample: 9 | kernel: [5] 10 | stride: [2] 11 | dilation: [1] 12 | dim: 256 13 | activation: "GELU" 14 | 15 | va_cond: 16 | history: true 17 | history_bins: 5 18 | 19 | ar: 20 | type: 'gpt' 21 | dim: 256 22 | num_layers: 4 23 | num_heads: 4 24 | dff_k: 3 25 | dropout: 0.4 26 | use_pos_emb: 0 # AliBI 27 | max_context: null # no max context if use_pos_emb=0 28 | abspos: null # deprecated 29 | sizeSeq: null # deprecated 30 | 31 | vap: 32 | bin_times: [.2, .4, .6, .8] 33 | type: 'discrete' 34 | pre_frames: 2 35 | bin_threshold: 0.5 36 | -------------------------------------------------------------------------------- /conv_ssl/conf/model/independent.yaml: -------------------------------------------------------------------------------- 1 | encoder: 2 | pretrained: true 3 | output_layer: 1 4 | sample_rate: 16000 5 | name: cpc 6 | frame_hz: 100 7 | 8 | va_cond: 9 | history: true 10 | history_bins: 5 11 | 12 | ar: 13 | type: 'transformer' 14 | dim: 256 15 | num_layers: 4 16 | num_heads: 4 17 | use_pos_emb: 1 18 | abspos: true 19 | sizeSeq: 1024 20 | dff_k: 3 21 | dropout: 0.4 22 | 23 | vap: 24 | bin_times: [.2, .4, .6, .8] 25 | type: 'independent' 26 | pre_frames: 2 27 | bin_threshold: 0.5 28 | -------------------------------------------------------------------------------- /conv_ssl/conf/model/independent40.yaml: -------------------------------------------------------------------------------- 1 | encoder: 2 | pretrained: true 3 | output_layer: 1 4 | sample_rate: 16000 5 | name: cpc 6 | frame_hz: 100 7 | 8 | va_cond: 9 | history: true 10 | history_bins: 5 11 | 12 | ar: 13 | type: 'transformer' 14 | dim: 256 15 | num_layers: 4 16 | num_heads: 4 17 | use_pos_emb: 1 18 | abspos: true 19 | sizeSeq: 1024 20 | dff_k: 3 21 | dropout: 0.4 22 | 23 | vap: 24 | bin_times: [.5, .5, .5, .5, .5, .5, .5, .5, .5, .5, .5, .5, .5, .5, .5, .5, .5, .5, .5, .5, .5, .5, .5, .5, .5, .5, .5, .5, .5, .5, .5, .5, .5, .5, .5, .5, .5, .5, .5, .5] 25 | type: 'independent' 26 | pre_frames: 2 27 | bin_threshold: 0.5 28 | -------------------------------------------------------------------------------- /conv_ssl/conf/swb_kfolds/0_fold_val.txt: -------------------------------------------------------------------------------- 1 | 2001 2 | 2005 3 | 2006 4 | 2007 5 | 2008 6 | 2009 7 | 2010 8 | 2012 9 | 2013 10 | 2014 11 | 2015 12 | 2017 13 | 2018 14 | 2019 15 | 2020 16 | 2023 17 | 2024 18 | 2025 19 | 2027 20 | 2028 21 | 2032 22 | 2035 23 | 2036 24 | 2038 25 | 2039 26 | 2040 27 | 2041 28 | 2044 29 | 2045 30 | 2050 31 | 2051 32 | 2053 33 | 2055 34 | 2056 35 | 2060 36 | 2061 37 | 2062 38 | 2064 39 | 2065 40 | 2071 41 | 2072 42 | 2073 43 | 2079 44 | 2080 45 | 2082 46 | 2083 47 | 2085 48 | 2086 49 | 2087 50 | 2089 51 | 2090 52 | 2091 53 | 2092 54 | 2093 55 | 2094 56 | 2095 57 | 2096 58 | 2101 
59 | 2102 60 | 2104 61 | 2105 62 | 2107 63 | 2108 64 | 2109 65 | 2110 66 | 2111 67 | 2113 68 | 2114 69 | 2116 70 | 2118 71 | 2120 72 | 2121 73 | 2122 74 | 2123 75 | 2124 76 | 2125 77 | 2129 78 | 2130 79 | 2131 80 | 2136 81 | 2137 82 | 2139 83 | 2141 84 | 2145 85 | 2150 86 | 2151 87 | 2152 88 | 2153 89 | 2154 90 | 2155 91 | 2157 92 | 2158 93 | 2160 94 | 2161 95 | 2162 96 | 2163 97 | 2165 98 | 2166 99 | 2167 100 | 2168 101 | 2169 102 | 2171 103 | 2172 104 | 2173 105 | 2174 106 | 2175 107 | 2176 108 | 2177 109 | 2178 110 | 2179 111 | 2180 112 | 2181 113 | 2182 114 | 2184 115 | 2185 116 | 2186 117 | 2187 118 | 2189 119 | 2190 120 | 2191 121 | 2193 122 | 2194 123 | 2195 124 | 2196 125 | 2197 126 | 2198 127 | 2199 128 | 2201 129 | 2202 130 | 2204 131 | 2205 132 | 2206 133 | 2221 134 | 2222 135 | 2223 136 | 2224 137 | 2226 138 | 2227 139 | 2228 140 | 2229 141 | 2230 142 | 2231 143 | 2232 144 | 2233 145 | 2234 146 | 2235 147 | 2236 148 | 2237 149 | 2238 150 | 2239 151 | 2240 152 | 2241 153 | 2242 154 | 2243 155 | 2244 156 | 2245 157 | 2246 158 | 2247 159 | 2248 160 | 2249 161 | 2250 162 | 2251 163 | 2252 164 | 2253 165 | 2254 166 | 2255 167 | 2256 168 | 2257 169 | 2258 170 | 2259 171 | 2261 172 | 2262 173 | 2263 174 | 2265 175 | 2266 176 | 2267 177 | 2268 178 | 2269 179 | 2270 180 | 2271 181 | 2272 182 | 2273 183 | 2274 184 | 2275 185 | 2276 186 | 2278 187 | 2279 188 | 2280 189 | 2281 190 | 2282 191 | 2283 192 | 2284 193 | 2285 194 | 2286 195 | 2287 196 | 2288 197 | 2289 198 | 2290 199 | 2291 200 | 2292 201 | 2293 -------------------------------------------------------------------------------- /conv_ssl/conf/swb_kfolds/10_fold_val.txt: -------------------------------------------------------------------------------- 1 | 4238 2 | 4239 3 | 4240 4 | 4241 5 | 4242 6 | 4243 7 | 4244 8 | 4245 9 | 4246 10 | 4247 11 | 4248 12 | 4249 13 | 4250 14 | 4251 15 | 4252 16 | 4253 17 | 4254 18 | 4255 19 | 4256 20 | 4257 21 | 4258 22 | 4259 23 | 4260 24 | 4261 25 | 4262 26 | 4263 27 | 4264 28 | 4266 29 | 4267 30 | 4268 31 | 4269 32 | 4270 33 | 4271 34 | 4272 35 | 4273 36 | 4274 37 | 4275 38 | 4276 39 | 4277 40 | 4278 41 | 4279 42 | 4280 43 | 4281 44 | 4282 45 | 4283 46 | 4284 47 | 4285 48 | 4286 49 | 4287 50 | 4289 51 | 4340 52 | 4341 53 | 4342 54 | 4343 55 | 4344 56 | 4345 57 | 4346 58 | 4347 59 | 4348 60 | 4349 61 | 4350 62 | 4351 63 | 4352 64 | 4353 65 | 4354 66 | 4355 67 | 4356 68 | 4357 69 | 4358 70 | 4359 71 | 4360 72 | 4361 73 | 4362 74 | 4363 75 | 4364 76 | 4365 77 | 4366 78 | 4367 79 | 4368 80 | 4369 81 | 4370 82 | 4371 83 | 4372 84 | 4373 85 | 4374 86 | 4375 87 | 4376 88 | 4377 89 | 4378 90 | 4379 91 | 4380 92 | 4381 93 | 4382 94 | 4383 95 | 4384 96 | 4387 97 | 4400 98 | 4421 99 | 4433 100 | 4440 101 | 4443 102 | 4445 103 | 4448 104 | 4467 105 | 4474 106 | 4483 107 | 4488 108 | 4493 109 | 4495 110 | 4497 111 | 4501 112 | 4502 113 | 4503 114 | 4512 115 | 4519 116 | 4522 117 | 4523 118 | 4526 119 | 4531 120 | 4540 121 | 4548 122 | 4555 123 | 4559 124 | 4565 125 | 4570 126 | 4572 127 | 4576 128 | 4588 129 | 4590 130 | 4594 131 | 4595 132 | 4603 133 | 4605 134 | 4606 135 | 4607 136 | 4608 137 | 4611 138 | 4612 139 | 4615 140 | 4616 141 | 4617 142 | 4618 143 | 4619 144 | 4622 145 | 4624 146 | 4626 147 | 4628 148 | 4629 149 | 4630 150 | 4633 151 | 4290 152 | 4291 153 | 4292 154 | 4293 155 | 4294 156 | 4295 157 | 4296 158 | 4297 159 | 4298 160 | 4299 161 | 4300 162 | 4301 163 | 4302 164 | 4303 165 | 4304 166 | 4305 167 | 4306 168 | 4307 169 | 4308 170 | 4309 171 | 4310 172 | 4311 173 | 4312 174 | 4313 175 | 4314 
176 | 4315 177 | 4316 178 | 4317 179 | 4318 180 | 4319 181 | 4320 182 | 4321 183 | 4322 184 | 4323 185 | 4324 186 | 4325 187 | 4326 188 | 4327 189 | 4328 190 | 4329 191 | 4330 192 | 4331 193 | 4332 194 | 4333 195 | 4334 196 | 4335 197 | 4336 198 | 4337 199 | 4338 200 | 4339 -------------------------------------------------------------------------------- /conv_ssl/conf/swb_kfolds/1_fold_val.txt: -------------------------------------------------------------------------------- 1 | 2294 2 | 2295 3 | 2296 4 | 2297 5 | 2298 6 | 2300 7 | 2301 8 | 2302 9 | 2303 10 | 2304 11 | 2305 12 | 2306 13 | 2307 14 | 2308 15 | 2309 16 | 2310 17 | 2311 18 | 2312 19 | 2313 20 | 2314 21 | 2315 22 | 2316 23 | 2317 24 | 2318 25 | 2319 26 | 2320 27 | 2323 28 | 2324 29 | 2325 30 | 2326 31 | 2327 32 | 2328 33 | 2329 34 | 2330 35 | 2331 36 | 2332 37 | 2333 38 | 2335 39 | 2336 40 | 2337 41 | 2338 42 | 2339 43 | 2340 44 | 2341 45 | 2342 46 | 2343 47 | 2344 48 | 2345 49 | 2346 50 | 2348 51 | 2349 52 | 2350 53 | 2352 54 | 2353 55 | 2354 56 | 2355 57 | 2356 58 | 2358 59 | 2359 60 | 2360 61 | 2361 62 | 2362 63 | 2363 64 | 2365 65 | 2366 66 | 2367 67 | 2368 68 | 2369 69 | 2370 70 | 2371 71 | 2372 72 | 2373 73 | 2374 74 | 2375 75 | 2376 76 | 2377 77 | 2378 78 | 2379 79 | 2380 80 | 2382 81 | 2383 82 | 2384 83 | 2387 84 | 2388 85 | 2389 86 | 2390 87 | 2392 88 | 2393 89 | 2394 90 | 2395 91 | 2396 92 | 2397 93 | 2398 94 | 2399 95 | 2402 96 | 2403 97 | 2404 98 | 2405 99 | 2406 100 | 2407 101 | 2408 102 | 2409 103 | 2410 104 | 2411 105 | 2413 106 | 2415 107 | 2416 108 | 2417 109 | 2418 110 | 2419 111 | 2421 112 | 2422 113 | 2423 114 | 2424 115 | 2426 116 | 2427 117 | 2428 118 | 2429 119 | 2430 120 | 2431 121 | 2433 122 | 2434 123 | 2435 124 | 2436 125 | 2437 126 | 2438 127 | 2439 128 | 2440 129 | 2441 130 | 2442 131 | 2443 132 | 2444 133 | 2445 134 | 2446 135 | 2448 136 | 2450 137 | 2451 138 | 2452 139 | 2453 140 | 2454 141 | 2455 142 | 2456 143 | 2457 144 | 2458 145 | 2459 146 | 2460 147 | 2461 148 | 2463 149 | 2464 150 | 2465 151 | 2466 152 | 2467 153 | 2468 154 | 2469 155 | 2470 156 | 2471 157 | 2472 158 | 2474 159 | 2476 160 | 2477 161 | 2478 162 | 2479 163 | 2481 164 | 2482 165 | 2483 166 | 2484 167 | 2485 168 | 2486 169 | 2487 170 | 2488 171 | 2489 172 | 2490 173 | 2492 174 | 2493 175 | 2494 176 | 2495 177 | 2496 178 | 2497 179 | 2498 180 | 2499 181 | 2500 182 | 2501 183 | 2502 184 | 2503 185 | 2504 186 | 2505 187 | 2506 188 | 2507 189 | 2508 190 | 2509 191 | 2510 192 | 2511 193 | 2513 194 | 2514 195 | 2515 196 | 2518 197 | 2519 198 | 2520 199 | 2521 200 | 2522 201 | 2523 -------------------------------------------------------------------------------- /conv_ssl/conf/swb_kfolds/2_fold_val.txt: -------------------------------------------------------------------------------- 1 | 2524 2 | 2525 3 | 2526 4 | 2527 5 | 2528 6 | 2529 7 | 2530 8 | 2531 9 | 2532 10 | 2533 11 | 2534 12 | 2535 13 | 2536 14 | 2537 15 | 2538 16 | 2539 17 | 2540 18 | 2543 19 | 2545 20 | 2546 21 | 2547 22 | 2548 23 | 2549 24 | 2551 25 | 2552 26 | 2553 27 | 2554 28 | 2556 29 | 2557 30 | 2558 31 | 2559 32 | 2560 33 | 2561 34 | 2562 35 | 2563 36 | 2564 37 | 2565 38 | 2567 39 | 2568 40 | 2570 41 | 2571 42 | 2572 43 | 2573 44 | 2574 45 | 2575 46 | 2576 47 | 2577 48 | 2578 49 | 2579 50 | 2580 51 | 2582 52 | 2583 53 | 2586 54 | 2587 55 | 2588 56 | 2589 57 | 2591 58 | 2592 59 | 2593 60 | 2594 61 | 2595 62 | 2596 63 | 2597 64 | 2598 65 | 2599 66 | 2600 67 | 2602 68 | 2603 69 | 2604 70 | 2605 71 | 2606 72 | 2608 73 | 2609 74 | 2610 75 | 2611 76 | 2612 77 | 2613 78 | 
2614 79 | 2615 80 | 2616 81 | 2617 82 | 2618 83 | 2619 84 | 2620 85 | 2621 86 | 2622 87 | 2623 88 | 2624 89 | 2625 90 | 2626 91 | 2627 92 | 2628 93 | 2629 94 | 2630 95 | 2631 96 | 2632 97 | 2633 98 | 2634 99 | 2635 100 | 2636 101 | 2637 102 | 2638 103 | 2641 104 | 2642 105 | 2643 106 | 2644 107 | 2645 108 | 2646 109 | 2647 110 | 2648 111 | 2649 112 | 2650 113 | 2651 114 | 2652 115 | 2653 116 | 2654 117 | 2656 118 | 2657 119 | 2658 120 | 2659 121 | 2661 122 | 2663 123 | 2664 124 | 2665 125 | 2666 126 | 2667 127 | 2668 128 | 2669 129 | 2670 130 | 2671 131 | 2672 132 | 2673 133 | 2674 134 | 2675 135 | 2676 136 | 2678 137 | 2679 138 | 2680 139 | 2681 140 | 2682 141 | 2684 142 | 2685 143 | 2686 144 | 2687 145 | 2688 146 | 2690 147 | 2691 148 | 2692 149 | 2693 150 | 2694 151 | 2695 152 | 2696 153 | 2697 154 | 2698 155 | 2699 156 | 2700 157 | 2702 158 | 2703 159 | 2704 160 | 2705 161 | 2706 162 | 2707 163 | 2708 164 | 2710 165 | 2711 166 | 2712 167 | 2713 168 | 2714 169 | 2715 170 | 2716 171 | 2717 172 | 2718 173 | 2719 174 | 2720 175 | 2721 176 | 2722 177 | 2723 178 | 2724 179 | 2725 180 | 2726 181 | 2727 182 | 2728 183 | 2729 184 | 2730 185 | 2731 186 | 2732 187 | 2733 188 | 2734 189 | 2735 190 | 2736 191 | 2737 192 | 2738 193 | 2739 194 | 2740 195 | 2741 196 | 2742 197 | 2743 198 | 2746 199 | 2747 200 | 2748 201 | 2749 -------------------------------------------------------------------------------- /conv_ssl/conf/swb_kfolds/3_fold_val.txt: -------------------------------------------------------------------------------- 1 | 2750 2 | 2751 3 | 2752 4 | 2753 5 | 2754 6 | 2755 7 | 2756 8 | 2757 9 | 2758 10 | 2759 11 | 2760 12 | 2761 13 | 2762 14 | 2763 15 | 2764 16 | 2765 17 | 2766 18 | 2767 19 | 2768 20 | 2769 21 | 2770 22 | 2771 23 | 2772 24 | 2773 25 | 2774 26 | 2775 27 | 2776 28 | 2777 29 | 2778 30 | 2779 31 | 2780 32 | 2781 33 | 2782 34 | 2783 35 | 2784 36 | 2785 37 | 2786 38 | 2787 39 | 2788 40 | 2789 41 | 2790 42 | 2791 43 | 2792 44 | 2793 45 | 2794 46 | 2795 47 | 2796 48 | 2797 49 | 2798 50 | 2799 51 | 2800 52 | 2801 53 | 2802 54 | 2803 55 | 2804 56 | 2805 57 | 2806 58 | 2807 59 | 2808 60 | 2809 61 | 2810 62 | 2811 63 | 2812 64 | 2813 65 | 2814 66 | 2815 67 | 2816 68 | 2817 69 | 2818 70 | 2820 71 | 2821 72 | 2822 73 | 2823 74 | 2824 75 | 2825 76 | 2826 77 | 2827 78 | 2828 79 | 2829 80 | 2830 81 | 2831 82 | 2832 83 | 2833 84 | 2834 85 | 2835 86 | 2836 87 | 2837 88 | 2838 89 | 2839 90 | 2840 91 | 2841 92 | 2842 93 | 2843 94 | 2844 95 | 2845 96 | 2846 97 | 2847 98 | 2848 99 | 2849 100 | 2850 101 | 2851 102 | 2852 103 | 2853 104 | 2854 105 | 2855 106 | 2858 107 | 2859 108 | 2860 109 | 2861 110 | 2862 111 | 2863 112 | 2864 113 | 2865 114 | 2866 115 | 2868 116 | 2869 117 | 2870 118 | 2871 119 | 2872 120 | 2873 121 | 2874 122 | 2875 123 | 2876 124 | 2877 125 | 2878 126 | 2882 127 | 2883 128 | 2884 129 | 2885 130 | 2886 131 | 2887 132 | 2888 133 | 2889 134 | 2890 135 | 2891 136 | 2893 137 | 2895 138 | 2896 139 | 2897 140 | 2898 141 | 2899 142 | 2900 143 | 2901 144 | 2909 145 | 2910 146 | 2911 147 | 2912 148 | 2914 149 | 2915 150 | 2916 151 | 2917 152 | 2918 153 | 2919 154 | 2920 155 | 2921 156 | 2922 157 | 2923 158 | 2925 159 | 2926 160 | 2927 161 | 2928 162 | 2929 163 | 2930 164 | 2931 165 | 2932 166 | 2933 167 | 2934 168 | 2935 169 | 2936 170 | 2937 171 | 2938 172 | 2939 173 | 2940 174 | 2941 175 | 2942 176 | 2943 177 | 2944 178 | 2945 179 | 2950 180 | 2951 181 | 2952 182 | 2954 183 | 2955 184 | 2956 185 | 2957 186 | 2958 187 | 2959 188 | 2961 189 | 2962 190 | 2964 191 | 2965 192 | 2966 193 | 
2967 194 | 2968 195 | 2969 196 | 2970 197 | 2973 198 | 2978 199 | 2979 200 | 2980 201 | 2981 -------------------------------------------------------------------------------- /conv_ssl/conf/swb_kfolds/4_fold_val.txt: -------------------------------------------------------------------------------- 1 | 2982 2 | 2983 3 | 2985 4 | 2986 5 | 2987 6 | 2989 7 | 2990 8 | 2991 9 | 2992 10 | 2993 11 | 2994 12 | 2995 13 | 2996 14 | 2997 15 | 2998 16 | 3000 17 | 3001 18 | 3002 19 | 3003 20 | 3004 21 | 3005 22 | 3006 23 | 3007 24 | 3008 25 | 3009 26 | 3010 27 | 3011 28 | 3012 29 | 3013 30 | 3014 31 | 3015 32 | 3016 33 | 3017 34 | 3018 35 | 3019 36 | 3020 37 | 3021 38 | 3022 39 | 3023 40 | 3024 41 | 3025 42 | 3026 43 | 3027 44 | 3029 45 | 3030 46 | 3031 47 | 3032 48 | 3033 49 | 3034 50 | 3035 51 | 3036 52 | 3037 53 | 3038 54 | 3039 55 | 3040 56 | 3041 57 | 3042 58 | 3043 59 | 3044 60 | 3047 61 | 3048 62 | 3049 63 | 3050 64 | 3052 65 | 3053 66 | 3054 67 | 3055 68 | 3056 69 | 3057 70 | 3058 71 | 3059 72 | 3060 73 | 3061 74 | 3062 75 | 3064 76 | 3065 77 | 3066 78 | 3067 79 | 3069 80 | 3070 81 | 3071 82 | 3072 83 | 3073 84 | 3074 85 | 3075 86 | 3076 87 | 3077 88 | 3078 89 | 3079 90 | 3080 91 | 3082 92 | 3083 93 | 3084 94 | 3085 95 | 3086 96 | 3087 97 | 3088 98 | 3089 99 | 3090 100 | 3091 101 | 3092 102 | 3093 103 | 3094 104 | 3095 105 | 3096 106 | 3097 107 | 3098 108 | 3099 109 | 3100 110 | 3101 111 | 3102 112 | 3103 113 | 3104 114 | 3105 115 | 3106 116 | 3107 117 | 3108 118 | 3110 119 | 3111 120 | 3112 121 | 3113 122 | 3114 123 | 3115 124 | 3116 125 | 3117 126 | 3118 127 | 3119 128 | 3120 129 | 3121 130 | 3123 131 | 3124 132 | 3125 133 | 3126 134 | 3127 135 | 3128 136 | 3129 137 | 3130 138 | 3131 139 | 3132 140 | 3133 141 | 3134 142 | 3135 143 | 3136 144 | 3137 145 | 3138 146 | 3140 147 | 3141 148 | 3142 149 | 3143 150 | 3144 151 | 3146 152 | 3147 153 | 3148 154 | 3149 155 | 3150 156 | 3151 157 | 3152 158 | 3153 159 | 3154 160 | 3155 161 | 3156 162 | 3157 163 | 3158 164 | 3159 165 | 3160 166 | 3161 167 | 3162 168 | 3163 169 | 3164 170 | 3166 171 | 3167 172 | 3168 173 | 3169 174 | 3170 175 | 3171 176 | 3172 177 | 3173 178 | 3174 179 | 3175 180 | 3176 181 | 3177 182 | 3178 183 | 3179 184 | 3180 185 | 3181 186 | 3182 187 | 3183 188 | 3184 189 | 3185 190 | 3186 191 | 3187 192 | 3188 193 | 3189 194 | 3190 195 | 3191 196 | 3192 197 | 3193 198 | 3194 199 | 3195 200 | 3196 201 | 3197 -------------------------------------------------------------------------------- /conv_ssl/conf/swb_kfolds/5_fold_val.txt: -------------------------------------------------------------------------------- 1 | 3198 2 | 3199 3 | 3200 4 | 3201 5 | 3202 6 | 3203 7 | 3204 8 | 3205 9 | 3206 10 | 3207 11 | 3208 12 | 3214 13 | 3215 14 | 3216 15 | 3217 16 | 3218 17 | 3219 18 | 3221 19 | 3222 20 | 3223 21 | 3224 22 | 3225 23 | 3226 24 | 3227 25 | 3228 26 | 3229 27 | 3231 28 | 3232 29 | 3233 30 | 3234 31 | 3235 32 | 3236 33 | 3237 34 | 3238 35 | 3239 36 | 3240 37 | 3241 38 | 3242 39 | 3243 40 | 3244 41 | 3246 42 | 3247 43 | 3248 44 | 3249 45 | 3250 46 | 3251 47 | 3252 48 | 3253 49 | 3254 50 | 3255 51 | 3256 52 | 3257 53 | 3258 54 | 3259 55 | 3260 56 | 3261 57 | 3262 58 | 3263 59 | 3264 60 | 3266 61 | 3267 62 | 3268 63 | 3269 64 | 3270 65 | 3271 66 | 3272 67 | 3273 68 | 3274 69 | 3275 70 | 3276 71 | 3277 72 | 3278 73 | 3279 74 | 3280 75 | 3281 76 | 3282 77 | 3283 78 | 3284 79 | 3286 80 | 3287 81 | 3288 82 | 3289 83 | 3290 84 | 3291 85 | 3292 86 | 3293 87 | 3294 88 | 3295 89 | 3296 90 | 3297 91 | 3298 92 | 3299 93 | 3300 94 | 3301 95 | 3302 96 | 3303 
97 | 3304 98 | 3305 99 | 3306 100 | 3307 101 | 3308 102 | 3309 103 | 3310 104 | 3311 105 | 3312 106 | 3313 107 | 3314 108 | 3315 109 | 3316 110 | 3317 111 | 3318 112 | 3319 113 | 3320 114 | 3321 115 | 3322 116 | 3323 117 | 3325 118 | 3326 119 | 3327 120 | 3328 121 | 3329 122 | 3330 123 | 3331 124 | 3332 125 | 3333 126 | 3334 127 | 3335 128 | 3336 129 | 3337 130 | 3339 131 | 3340 132 | 3341 133 | 3342 134 | 3343 135 | 3344 136 | 3345 137 | 3346 138 | 3347 139 | 3348 140 | 3349 141 | 3351 142 | 3352 143 | 3353 144 | 3354 145 | 3355 146 | 3356 147 | 3358 148 | 3359 149 | 3360 150 | 3361 151 | 3362 152 | 3363 153 | 3365 154 | 3366 155 | 3367 156 | 3368 157 | 3369 158 | 3370 159 | 3371 160 | 3372 161 | 3374 162 | 3375 163 | 3377 164 | 3378 165 | 3379 166 | 3380 167 | 3381 168 | 3382 169 | 3383 170 | 3384 171 | 3385 172 | 3386 173 | 3387 174 | 3388 175 | 3389 176 | 3390 177 | 3391 178 | 3392 179 | 3393 180 | 3394 181 | 3395 182 | 3397 183 | 3398 184 | 3399 185 | 3401 186 | 3402 187 | 3403 188 | 3404 189 | 3405 190 | 3406 191 | 3407 192 | 3408 193 | 3409 194 | 3410 195 | 3411 196 | 3412 197 | 3413 198 | 3414 199 | 3415 200 | 3416 -------------------------------------------------------------------------------- /conv_ssl/conf/swb_kfolds/6_fold_val.txt: -------------------------------------------------------------------------------- 1 | 3417 2 | 3418 3 | 3419 4 | 3420 5 | 3421 6 | 3422 7 | 3423 8 | 3424 9 | 3425 10 | 3426 11 | 3427 12 | 3428 13 | 3429 14 | 3430 15 | 3431 16 | 3432 17 | 3433 18 | 3434 19 | 3435 20 | 3436 21 | 3437 22 | 3438 23 | 3439 24 | 3440 25 | 3441 26 | 3442 27 | 3444 28 | 3445 29 | 3446 30 | 3447 31 | 3448 32 | 3449 33 | 3450 34 | 3451 35 | 3452 36 | 3453 37 | 3454 38 | 3455 39 | 3456 40 | 3457 41 | 3458 42 | 3459 43 | 3460 44 | 3461 45 | 3462 46 | 3463 47 | 3464 48 | 3465 49 | 3466 50 | 3467 51 | 3468 52 | 3469 53 | 3470 54 | 3471 55 | 3472 56 | 3473 57 | 3474 58 | 3475 59 | 3476 60 | 3477 61 | 3478 62 | 3479 63 | 3480 64 | 3481 65 | 3482 66 | 3483 67 | 3484 68 | 3485 69 | 3486 70 | 3487 71 | 3488 72 | 3489 73 | 3491 74 | 3492 75 | 3493 76 | 3494 77 | 3495 78 | 3496 79 | 3497 80 | 3498 81 | 3499 82 | 3500 83 | 3501 84 | 3502 85 | 3503 86 | 3504 87 | 3505 88 | 3506 89 | 3507 90 | 3508 91 | 3509 92 | 3510 93 | 3511 94 | 3512 95 | 3513 96 | 3514 97 | 3515 98 | 3516 99 | 3517 100 | 3518 101 | 3519 102 | 3520 103 | 3521 104 | 3522 105 | 3523 106 | 3524 107 | 3525 108 | 3526 109 | 3527 110 | 3528 111 | 3529 112 | 3530 113 | 3531 114 | 3532 115 | 3533 116 | 3534 117 | 3535 118 | 3536 119 | 3537 120 | 3538 121 | 3540 122 | 3541 123 | 3542 124 | 3543 125 | 3544 126 | 3545 127 | 3546 128 | 3547 129 | 3548 130 | 3550 131 | 3551 132 | 3552 133 | 3553 134 | 3554 135 | 3555 136 | 3556 137 | 3557 138 | 3558 139 | 3559 140 | 3560 141 | 3561 142 | 3562 143 | 3563 144 | 3564 145 | 3565 146 | 3566 147 | 3567 148 | 3568 149 | 3569 150 | 3570 151 | 3571 152 | 3572 153 | 3573 154 | 3574 155 | 3575 156 | 3576 157 | 3577 158 | 3578 159 | 3579 160 | 3580 161 | 3581 162 | 3582 163 | 3583 164 | 3584 165 | 3585 166 | 3586 167 | 3587 168 | 3588 169 | 3589 170 | 3590 171 | 3591 172 | 3592 173 | 3593 174 | 3594 175 | 3595 176 | 3596 177 | 3597 178 | 3598 179 | 3599 180 | 3600 181 | 3601 182 | 3602 183 | 3603 184 | 3604 185 | 3605 186 | 3606 187 | 3607 188 | 3608 189 | 3609 190 | 3610 191 | 3611 192 | 3612 193 | 3614 194 | 3615 195 | 3616 196 | 3617 197 | 3618 198 | 3619 199 | 3620 200 | 3621 -------------------------------------------------------------------------------- 
/conv_ssl/conf/swb_kfolds/7_fold_val.txt: -------------------------------------------------------------------------------- 1 | 3622 2 | 3623 3 | 3624 4 | 3625 5 | 3626 6 | 3627 7 | 3628 8 | 3630 9 | 3631 10 | 3632 11 | 3633 12 | 3635 13 | 3636 14 | 3637 15 | 3638 16 | 3639 17 | 3640 18 | 3641 19 | 3642 20 | 3643 21 | 3644 22 | 3645 23 | 3646 24 | 3647 25 | 3648 26 | 3649 27 | 3650 28 | 3651 29 | 3652 30 | 3653 31 | 3654 32 | 3655 33 | 3656 34 | 3658 35 | 3659 36 | 3660 37 | 3661 38 | 3662 39 | 3663 40 | 3664 41 | 3665 42 | 3666 43 | 3667 44 | 3668 45 | 3669 46 | 3670 47 | 3671 48 | 3672 49 | 3673 50 | 3674 51 | 3675 52 | 3676 53 | 3677 54 | 3678 55 | 3679 56 | 3680 57 | 3681 58 | 3682 59 | 3683 60 | 3684 61 | 3685 62 | 3686 63 | 3687 64 | 3688 65 | 3689 66 | 3690 67 | 3691 68 | 3692 69 | 3693 70 | 3694 71 | 3695 72 | 3696 73 | 3697 74 | 3698 75 | 3699 76 | 3700 77 | 3701 78 | 3702 79 | 3703 80 | 3704 81 | 3705 82 | 3706 83 | 3707 84 | 3708 85 | 3709 86 | 3710 87 | 3711 88 | 3713 89 | 3714 90 | 3715 91 | 3716 92 | 3717 93 | 3718 94 | 3719 95 | 3720 96 | 3721 97 | 3722 98 | 3723 99 | 3724 100 | 3725 101 | 3726 102 | 3727 103 | 3728 104 | 3729 105 | 3730 106 | 3731 107 | 3732 108 | 3733 109 | 3734 110 | 3735 111 | 3736 112 | 3737 113 | 3738 114 | 3739 115 | 3740 116 | 3741 117 | 3742 118 | 3743 119 | 3744 120 | 3745 121 | 3746 122 | 3747 123 | 3748 124 | 3749 125 | 3750 126 | 3751 127 | 3752 128 | 3753 129 | 3754 130 | 3755 131 | 3756 132 | 3757 133 | 3758 134 | 3759 135 | 3760 136 | 3761 137 | 3762 138 | 3763 139 | 3764 140 | 3765 141 | 3766 142 | 3767 143 | 3768 144 | 3769 145 | 3770 146 | 3771 147 | 3772 148 | 3773 149 | 3774 150 | 3775 151 | 3776 152 | 3777 153 | 3778 154 | 3779 155 | 3780 156 | 3781 157 | 3782 158 | 3783 159 | 3784 160 | 3785 161 | 3786 162 | 3787 163 | 3788 164 | 3789 165 | 3790 166 | 3791 167 | 3792 168 | 3793 169 | 3794 170 | 3795 171 | 3796 172 | 3797 173 | 3798 174 | 3799 175 | 3800 176 | 3801 177 | 3802 178 | 3803 179 | 3804 180 | 3805 181 | 3806 182 | 3807 183 | 3808 184 | 3809 185 | 3810 186 | 3811 187 | 3812 188 | 3813 189 | 3814 190 | 3815 191 | 3816 192 | 3817 193 | 3818 194 | 3819 195 | 3820 196 | 3821 197 | 3822 198 | 3823 199 | 3824 200 | 3825 -------------------------------------------------------------------------------- /conv_ssl/conf/swb_kfolds/8_fold_val.txt: -------------------------------------------------------------------------------- 1 | 3826 2 | 3827 3 | 3828 4 | 3829 5 | 3830 6 | 3831 7 | 3832 8 | 3833 9 | 3834 10 | 3836 11 | 3837 12 | 3838 13 | 3839 14 | 3840 15 | 3841 16 | 3842 17 | 3843 18 | 3844 19 | 3845 20 | 3846 21 | 3847 22 | 3848 23 | 3849 24 | 3850 25 | 3851 26 | 3852 27 | 3853 28 | 3854 29 | 3855 30 | 3856 31 | 3857 32 | 3858 33 | 3859 34 | 3860 35 | 3861 36 | 3862 37 | 3863 38 | 3864 39 | 3865 40 | 3866 41 | 3867 42 | 3868 43 | 3869 44 | 3870 45 | 3871 46 | 3872 47 | 3873 48 | 3874 49 | 3875 50 | 3876 51 | 3877 52 | 3878 53 | 3879 54 | 3880 55 | 3881 56 | 3882 57 | 3883 58 | 3884 59 | 3885 60 | 3886 61 | 3887 62 | 3888 63 | 3889 64 | 3890 65 | 3891 66 | 3892 67 | 3893 68 | 3894 69 | 3895 70 | 3896 71 | 3897 72 | 3898 73 | 3899 74 | 3900 75 | 3901 76 | 3902 77 | 3903 78 | 3904 79 | 3905 80 | 3906 81 | 3907 82 | 3908 83 | 3909 84 | 3910 85 | 3911 86 | 3912 87 | 3913 88 | 3914 89 | 3915 90 | 3916 91 | 3917 92 | 3918 93 | 3919 94 | 3920 95 | 3921 96 | 3922 97 | 3923 98 | 3924 99 | 3925 100 | 3926 101 | 3927 102 | 3928 103 | 3929 104 | 3930 105 | 3931 106 | 3932 107 | 3933 108 | 3934 109 | 3935 110 | 3936 111 | 3937 112 | 3938 113 | 3939 
114 | 3940 115 | 3941 116 | 3942 117 | 3943 118 | 3944 119 | 3945 120 | 3946 121 | 3947 122 | 3948 123 | 3949 124 | 3950 125 | 3951 126 | 3952 127 | 3953 128 | 3954 129 | 3955 130 | 3956 131 | 3957 132 | 3958 133 | 3959 134 | 3960 135 | 3961 136 | 3962 137 | 3963 138 | 3964 139 | 3965 140 | 3966 141 | 3967 142 | 3969 143 | 3970 144 | 3971 145 | 3972 146 | 3973 147 | 3974 148 | 3975 149 | 3976 150 | 3977 151 | 3979 152 | 3980 153 | 3981 154 | 3982 155 | 3983 156 | 3984 157 | 3985 158 | 3986 159 | 3987 160 | 3988 161 | 3989 162 | 3990 163 | 3991 164 | 3992 165 | 3993 166 | 3994 167 | 3996 168 | 3998 169 | 3999 170 | 4000 171 | 4001 172 | 4002 173 | 4005 174 | 4006 175 | 4007 176 | 4008 177 | 4009 178 | 4010 179 | 4011 180 | 4012 181 | 4013 182 | 4014 183 | 4015 184 | 4016 185 | 4017 186 | 4018 187 | 4019 188 | 4020 189 | 4021 190 | 4022 191 | 4023 192 | 4024 193 | 4025 194 | 4026 195 | 4027 196 | 4028 197 | 4029 198 | 4030 199 | 4032 200 | 4033 -------------------------------------------------------------------------------- /conv_ssl/conf/swb_kfolds/9_fold_val.txt: -------------------------------------------------------------------------------- 1 | 4034 2 | 4035 3 | 4036 4 | 4037 5 | 4038 6 | 4039 7 | 4040 8 | 4041 9 | 4042 10 | 4043 11 | 4044 12 | 4045 13 | 4046 14 | 4047 15 | 4048 16 | 4049 17 | 4050 18 | 4051 19 | 4052 20 | 4053 21 | 4054 22 | 4055 23 | 4056 24 | 4057 25 | 4058 26 | 4059 27 | 4060 28 | 4062 29 | 4063 30 | 4064 31 | 4065 32 | 4066 33 | 4067 34 | 4068 35 | 4069 36 | 4071 37 | 4072 38 | 4073 39 | 4074 40 | 4075 41 | 4076 42 | 4077 43 | 4078 44 | 4079 45 | 4080 46 | 4081 47 | 4082 48 | 4083 49 | 4084 50 | 4085 51 | 4087 52 | 4088 53 | 4089 54 | 4090 55 | 4091 56 | 4092 57 | 4093 58 | 4094 59 | 4095 60 | 4096 61 | 4097 62 | 4098 63 | 4099 64 | 4100 65 | 4101 66 | 4102 67 | 4103 68 | 4104 69 | 4105 70 | 4106 71 | 4107 72 | 4108 73 | 4109 74 | 4110 75 | 4111 76 | 4112 77 | 4113 78 | 4114 79 | 4115 80 | 4116 81 | 4117 82 | 4118 83 | 4119 84 | 4120 85 | 4121 86 | 4122 87 | 4123 88 | 4124 89 | 4125 90 | 4126 91 | 4127 92 | 4128 93 | 4129 94 | 4130 95 | 4131 96 | 4132 97 | 4133 98 | 4134 99 | 4135 100 | 4136 101 | 4137 102 | 4138 103 | 4139 104 | 4140 105 | 4141 106 | 4142 107 | 4143 108 | 4144 109 | 4145 110 | 4146 111 | 4147 112 | 4148 113 | 4149 114 | 4150 115 | 4151 116 | 4152 117 | 4153 118 | 4154 119 | 4155 120 | 4156 121 | 4157 122 | 4158 123 | 4159 124 | 4160 125 | 4161 126 | 4162 127 | 4163 128 | 4164 129 | 4165 130 | 4166 131 | 4167 132 | 4168 133 | 4169 134 | 4170 135 | 4171 136 | 4172 137 | 4173 138 | 4174 139 | 4175 140 | 4176 141 | 4177 142 | 4178 143 | 4179 144 | 4180 145 | 4181 146 | 4182 147 | 4183 148 | 4184 149 | 4185 150 | 4186 151 | 4187 152 | 4188 153 | 4189 154 | 4190 155 | 4191 156 | 4192 157 | 4193 158 | 4194 159 | 4195 160 | 4196 161 | 4197 162 | 4198 163 | 4199 164 | 4200 165 | 4201 166 | 4202 167 | 4203 168 | 4204 169 | 4205 170 | 4206 171 | 4207 172 | 4208 173 | 4209 174 | 4210 175 | 4211 176 | 4212 177 | 4213 178 | 4214 179 | 4215 180 | 4216 181 | 4217 182 | 4218 183 | 4220 184 | 4221 185 | 4222 186 | 4223 187 | 4224 188 | 4225 189 | 4226 190 | 4227 191 | 4228 192 | 4229 193 | 4230 194 | 4231 195 | 4232 196 | 4233 197 | 4234 198 | 4235 199 | 4236 200 | 4237 -------------------------------------------------------------------------------- /conv_ssl/conf/trainer/trainer.yaml: -------------------------------------------------------------------------------- 1 | gpus: -1 2 | fast_dev_run: 0 3 | deterministic: true 4 | max_epochs: 30 5 | 
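# Standard PyTorch Lightning Trainer flags: `gpus: -1` uses all visible GPUs
# and `fast_dev_run: 0` disables the quick debug pass.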
-------------------------------------------------------------------------------- /conv_ssl/conf/vap/vap.yaml: -------------------------------------------------------------------------------- 1 | bin_times: [.2, .4, .6, .8] 2 | type: 'discrete' 3 | pre_frames: 2 4 | bin_threshold: 0.5 5 | -------------------------------------------------------------------------------- /conv_ssl/datamodule_disk.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | from glob import glob 3 | from os import cpu_count 4 | 5 | import torch 6 | from torch.utils.data import Dataset, DataLoader 7 | import pytorch_lightning as pl 8 | 9 | from conv_ssl.utils import read_txt 10 | 11 | 12 | class DiskDataset(Dataset): 13 | def __init__(self, root) -> None: 14 | super().__init__() 15 | self.root = root 16 | self.sample_paths = glob(join(root, "*.pt")) 17 | 18 | def __len__(self): 19 | return len(self.sample_paths) 20 | 21 | def __getitem__(self, idx): 22 | return torch.load(self.sample_paths[idx]) 23 | 24 | 25 | class DiskDM(pl.LightningDataModule): 26 | def __init__(self, root, batch_size=4, num_workers=0): 27 | super().__init__() 28 | self.root = root 29 | self.train_path = join(root, "train") 30 | self.val_path = join(root, "val") 31 | self.test_path = join(root, "test") 32 | 33 | self.batch_size = batch_size 34 | self.num_workers = num_workers 35 | 36 | def setup(self, stage="fit"): 37 | if stage == "test": 38 | self.test_dset = DiskDataset(self.test_path) 39 | else: 40 | self.train_dset = DiskDataset(self.train_path) 41 | self.val_dset = DiskDataset(self.val_path) 42 | 43 | def train_dataloader(self): 44 | return DataLoader( 45 | self.train_dset, 46 | batch_size=self.batch_size, 47 | num_workers=self.num_workers, 48 | pin_memory=True, 49 | shuffle=True, 50 | ) 51 | 52 | def val_dataloader(self): 53 | return DataLoader( 54 | self.val_dset, 55 | batch_size=self.batch_size, 56 | num_workers=self.num_workers, 57 | pin_memory=True, 58 | shuffle=False, 59 | ) 60 | 61 | def test_dataloader(self): 62 | return DataLoader( 63 | self.test_dset, 64 | batch_size=self.batch_size, 65 | num_workers=self.num_workers, 66 | pin_memory=True, 67 | shuffle=False, 68 | ) 69 | 70 | 71 | class DiskDatasetFiles(Dataset): 72 | def __init__(self, sample_paths) -> None: 73 | super().__init__() 74 | self.sample_paths = sample_paths 75 | 76 | def __len__(self): 77 | return len(self.sample_paths) 78 | 79 | def __getitem__(self, idx): 80 | return torch.load(self.sample_paths[idx]) 81 | 82 | 83 | class DiskDMFiles(pl.LightningDataModule): 84 | def __init__( 85 | self, 86 | root, 87 | train_files=None, 88 | val_files=None, 89 | test_files=None, 90 | batch_size=4, 91 | num_workers=0, 92 | ): 93 | super().__init__() 94 | self.root = root 95 | self.train_files = train_files 96 | self.val_files = val_files 97 | self.test_files = test_files 98 | 99 | self.batch_size = batch_size 100 | self.num_workers = num_workers 101 | 102 | def init_paths(self, files): 103 | sample_paths = [] 104 | 105 | sessions = read_txt(files) 106 | for session_number in sessions: 107 | tmp_paths = glob(join(self.root, f"{session_number}*.pt")) 108 | for p in tmp_paths: 109 | sample_paths.append(p) 110 | return sample_paths 111 | 112 | def setup(self, stage="fit"): 113 | if stage == "test": 114 | self.test_paths = self.init_paths(self.test_files) 115 | self.test_dset = DiskDatasetFiles(self.test_paths) 116 | else: 117 | self.train_paths = self.init_paths(self.train_files) 118 | self.val_paths = 
self.init_paths(self.val_files) 119 | 120 | self.train_dset = DiskDatasetFiles(self.train_paths) 121 | self.val_dset = DiskDatasetFiles(self.val_paths) 122 | 123 | def train_dataloader(self): 124 | return DataLoader( 125 | self.train_dset, 126 | batch_size=self.batch_size, 127 | num_workers=self.num_workers, 128 | pin_memory=True, 129 | shuffle=True, 130 | ) 131 | 132 | def val_dataloader(self): 133 | return DataLoader( 134 | self.val_dset, 135 | batch_size=self.batch_size, 136 | num_workers=self.num_workers, 137 | pin_memory=True, 138 | shuffle=False, 139 | ) 140 | 141 | def test_dataloader(self): 142 | return DataLoader( 143 | self.test_dset, 144 | batch_size=self.batch_size, 145 | num_workers=self.num_workers, 146 | pin_memory=True, 147 | shuffle=False, 148 | ) 149 | 150 | @staticmethod 151 | def add_data_specific_args(parent_parser): 152 | """argparse arguments for SoSIModel (based on yaml-config)""" 153 | parser = parent_parser.add_argument_group("DataModule from disk") 154 | parser.add_argument("--data_root", default="swb_dataset", type=str) 155 | parser.add_argument("--train_files", default=None, type=str) 156 | parser.add_argument("--val_files", default=None, type=str) 157 | parser.add_argument("--test_files", default=None, type=str) 158 | parser.add_argument("--batch_size", default=4, type=int) 159 | parser.add_argument("--num_workers", default=cpu_count(), type=int) 160 | return parent_parser 161 | 162 | 163 | if __name__ == "__main__": 164 | 165 | dm = DiskDM(root="swb_dataset") 166 | dm.setup() 167 | 168 | batch = next(iter(dm.train_dataloader())) 169 | for k, v in batch.items(): 170 | if isinstance(v, torch.Tensor): 171 | print(f"{k}: {tuple(v.shape)}") 172 | else: 173 | print(f"{k}: {v}") 174 | 175 | batch = next(iter(dm.val_dataloader())) 176 | for k, v in batch.items(): 177 | if isinstance(v, torch.Tensor): 178 | print(f"{k}: {tuple(v.shape)}") 179 | else: 180 | print(f"{k}: {v}") 181 | 182 | for batch in dm.train_dataloader(): 183 | pass 184 | 185 | for batch in dm.val_dataloader(): 186 | pass 187 | 188 | conf_root = "/home/erik/projects/conv_ssl/conv_ssl/config/swb_kfolds" 189 | train_files = join(conf_root, "1_fold_train.txt") 190 | val_files = join(conf_root, "1_fold_val.txt") 191 | dm = DiskDMFiles(train_files=train_files, val_files=val_files, root="swb_dataset") 192 | dm.setup() 193 | print("train: ", len(dm.val_dataloader()), len(dm.train_dataloader())) 194 | batch = next(iter(dm.val_dataloader())) 195 | 196 | conf_root = "/home/erik/projects/conv_ssl/conv_ssl/config/swb_kfolds" 197 | train_files = join(conf_root, "3_fold_train.txt") 198 | val_files = join(conf_root, "3_fold_val.txt") 199 | dm = DiskDMFiles(train_files=train_files, val_files=val_files, root="swb_dataset") 200 | dm.setup() 201 | print("train: ", len(dm.val_dataloader()), len(dm.train_dataloader())) 202 | batch = next(iter(dm.val_dataloader())) 203 | -------------------------------------------------------------------------------- /conv_ssl/dataset_save_samples_to_disk.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from os import cpu_count, makedirs 3 | from os.path import join 4 | 5 | import torch 6 | from datasets_turntaking import DialogAudioDM 7 | 8 | from conv_ssl.utils import write_json 9 | 10 | 11 | def save_samples(dm, root, max_batches=-1): 12 | makedirs(root, exist_ok=True) 13 | file_map = {} 14 | for dloader in [dm.val_dataloader(), dm.train_dataloader()]: 15 | for ii, batch in enumerate(dloader): 16 | if 
max_batches > 0 and ii == max_batches: 17 | break 18 | batch_size = batch["waveform"].shape[0] 19 | for i in range(batch_size): 20 | session = batch["session"][i] 21 | if session not in file_map: 22 | file_map[session] = -1 23 | file_map[session] += 1 24 | n = file_map[session] 25 | sample = { 26 | "waveform": batch["waveform"][i], 27 | "vad": batch["vad"][i], 28 | "vad_history": batch["vad_history"][i], 29 | "dset_name": batch["dset_name"][i], 30 | "session": batch["session"][i], 31 | } 32 | name = f"{session}_{n}.pt" 33 | torch.save(sample, join(root, name)) 34 | write_json(file_map, join(root, "file_map.json")) 35 | 36 | 37 | if __name__ == "__main__": 38 | 39 | parser = ArgumentParser() 40 | parser.add_argument("--dirpath", default="swb_dataset", type=str) 41 | parser.add_argument("--hz", default=100, type=int) 42 | parser.add_argument("--duration", default=10, type=float) 43 | parser.add_argument("--sample_rate", default=16000, type=int) 44 | parser.add_argument("--horizon", default=3, type=float) 45 | parser.add_argument("--batch_size", default=20, type=int) 46 | parser.add_argument("--num_workers", default=cpu_count(), type=int) 47 | parser.add_argument("--max_batches", default=-1, type=int) 48 | args = parser.parse_args() 49 | for k, v in vars(args).items(): 50 | print(f"{k}: {v}") 51 | 52 | data_conf = DialogAudioDM.load_config() 53 | DialogAudioDM.print_dm(data_conf, args) 54 | dm = DialogAudioDM( 55 | datasets=data_conf["dataset"]["datasets"], 56 | type=data_conf["dataset"]["type"], 57 | audio_duration=data_conf["dataset"]["audio_duration"], 58 | audio_normalize=data_conf["dataset"]["audio_normalize"], 59 | audio_overlap=data_conf["dataset"]["audio_overlap"], 60 | sample_rate=data_conf["dataset"]["sample_rate"], 61 | vad_hz=100,  # NOTE: hardcoded; the --hz/--horizon args above are not used 62 | vad_horizon=2, 63 | vad_history=data_conf["dataset"]["vad_history"], 64 | vad_history_times=data_conf["dataset"]["vad_history_times"], 65 | batch_size=args.batch_size, 66 | num_workers=args.num_workers, 67 | ) 68 | dm.prepare_data() 69 | dm.setup() 70 | 71 | save_samples(dm, args.dirpath, max_batches=args.max_batches) 72 | -------------------------------------------------------------------------------- /conv_ssl/evaluation/README.md: -------------------------------------------------------------------------------- 1 | # Evaluation 2 | 3 | 4 | 1. Generate Audio: `python conv_ssl/evaluation/tts.py` 5 | 2. Aligner: 6 | * `python conv_ssl/evaluation/prepare_phrases_for_alignment.py` 7 | - writes .txt files with the corresponding words, formatted for the Montreal Forced Aligner 8 | * `bash conv_ssl/evaluation/forced_alignment.bash` 9 | - aligns the files (activate a conda env with montreal-forced-aligner installed) 10 | 3. 
Add VAD: `python conv_ssl/evaluation/vad.py` 11 | -------------------------------------------------------------------------------- /conv_ssl/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ErikEkstedt/conv_ssl/c365345afff3df33c791c6fc9d498bc08617ffb7/conv_ssl/evaluation/__init__.py -------------------------------------------------------------------------------- /conv_ssl/evaluation/anova.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | from glob import glob 3 | 4 | import torch 5 | import scipy.stats as stats 6 | 7 | from conv_ssl.utils import read_json 8 | 9 | 10 | METRIC_2_STATS = { 11 | "f1_hold_shift": "SH", 12 | "f1_short_long": "SL", 13 | "f1_predict_shift": "S-pred", 14 | "f1_bc_prediction": "BC-pred", 15 | } 16 | 17 | METRIC_NAMES = [ 18 | "f1_hold_shift", 19 | "f1_short_long", 20 | "f1_predict_shift", 21 | "f1_bc_prediction", 22 | ] 23 | 24 | STATS_NAMES = ["SH", "SL", "S-pred", "BC-pred"] 25 | 26 | 27 | def load_all_scores(root="assets/paper_evaluation"): 28 | """ 29 | load scores created by `evaluate_paper_models.py` 30 | 31 | -------------------------------- 32 | root/ 33 | └── discrete/ 34 | ├── kfold_0 35 | ├── ... 36 | └── kfold_10 37 | └── independent/ 38 | ├── kfold_0 39 | ├── ... 40 | └── kfold_10 41 | └── independent_baseline/ 42 | ├── kfold_0 43 | ├── ... 44 | └── kfold_10 45 | -------------------------------- 46 | 47 | """ 48 | all_score = {} 49 | for model_type in ["discrete", "independent", "independent_baseline"]: 50 | # model_type_dir = join(root, model_type) 51 | model_dict = {} 52 | for metric_filepath in glob(join(root, model_type, "**/metric.json")): 53 | r = read_json(metric_filepath) 54 | for tmp_metric, val in r.items(): 55 | if "threshold" in tmp_metric: 56 | continue 57 | if "loss" in tmp_metric: 58 | continue 59 | if isinstance(val, dict): 60 | # add shift/hold f1 61 | continue 62 | tmp_stat = METRIC_2_STATS[tmp_metric] 63 | if tmp_stat in model_dict: 64 | model_dict[tmp_stat].append(val) 65 | else: 66 | model_dict[tmp_stat] = [val] 67 | all_score[model_type] = model_dict 68 | return all_score 69 | 70 | 71 | def anova(all_score): 72 | statistics = {name: {} for name in STATS_NAMES} 73 | averages = {} 74 | for stat_name in STATS_NAMES: 75 | # Get score for the different groups 76 | m_discrete = all_score["discrete"][stat_name] 77 | m_ind = all_score["independent"][stat_name] 78 | m_ind_base = all_score["independent_baseline"][stat_name] 79 | 80 | anova_result = stats.f_oneway(m_discrete, m_ind, m_ind_base) 81 | statistics[stat_name] = anova_result.pvalue 82 | averages[stat_name] = { 83 | "discrete": torch.tensor(m_discrete).mean().item(), 84 | "independent": torch.tensor(m_ind).mean().item(), 85 | "independent_baseline": torch.tensor(m_ind_base).mean().item(), 86 | } 87 | 88 | # post-hoc pairwise t-tests 89 | d_vs_i = stats.ttest_ind(m_discrete, m_ind) 90 | d_vs_ib = stats.ttest_ind(m_discrete, m_ind_base) 91 | statistics[f"{stat_name}_t_test_d_vs_i"] = d_vs_i.pvalue 92 | statistics[f"{stat_name}_t_test_d_vs_ib"] = d_vs_ib.pvalue 93 | 94 | return statistics 95 | 96 | 97 | if __name__ == "__main__": 98 | scores = load_all_scores() 99 | # for stat, vals in scores["discrete"].items(): 100 | # vals = torch.tensor(vals) 101 | # scores["independent"][stat] = ( 102 | # vals - 0.02 * torch.rand_like(vals).abs() 103 | # ).tolist() 104 | # scores["independent_baseline"][stat] = ( 105 | # vals - 0.04 * 
torch.rand_like(vals).abs() 106 | # ).tolist() 107 | statistics = anova(scores) # pvalues 108 | 109 | for k, v in statistics.items(): 110 | print(f"{k}: {v}") 111 | -------------------------------------------------------------------------------- /conv_ssl/evaluation/duration.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | 4 | from os.path import join, basename 5 | from pathlib import Path 6 | from tqdm import tqdm 7 | from parselmouth.praat import call 8 | import parselmouth 9 | import matplotlib.pyplot as plt 10 | import textgrids 11 | 12 | from conv_ssl.augmentations import torch_to_praat_sound, praat_to_torch 13 | from conv_ssl.utils import read_txt 14 | 15 | """ 16 | * https://github.com/Legisign/Praat-textgrids 17 | - clone and install `pip install -e .` 18 | - textgrid 19 | * https://github.com/prosegrinder/python-cmudict 20 | - `pip install cmudict` 21 | - syllables 22 | """ 23 | 24 | 25 | EXAMPLE_TO_TARGET_WORD = { 26 | "student": "student", 27 | "psychology": "psychology", 28 | "first_year": "student", 29 | "basketball": "basketball", 30 | "experiment": "before", 31 | "live": "yourself", 32 | "work": "side", 33 | "bike": "bike", 34 | "drive": "here", 35 | } 36 | 37 | # Phones extracted by cmudict 38 | # last syllable decided by me... hope its correct 39 | LAST_SYLLABLE = { 40 | "student": [["D", "AH0", "N", "T"]], 41 | "psychology": [["JH", "IY0"]], 42 | # "year": ["Y", "IH1", "R"]], 43 | "basketball": [["B", "AO2", "L"]], 44 | # "experiments": ["M", "AH0", "N", "T", "S"], 45 | "before": [["B", "IH0", "F", "AO1", "R"], ["B", "IY2", "F", "AO1", "R"]], 46 | # "live": ["L", "IH1", "V"], 47 | "yourself": [ 48 | ["Y", "ER0", "S", "EH1", "L", "F"], 49 | ["Y", "UH0", "R", "S", "EH1", "L", "F"], 50 | ["Y", "AO1", "R", "S", "EH0", "L", "F"], 51 | ], 52 | # "work": ["W", "ER1", "K"], 53 | "side": [["S", "AY1", "D"]], 54 | "bike": [["B", "AY1", "K"]], 55 | # "drive": ["D", "R", "AY1", "V"], 56 | "here": [["HH", "IY1", "R"]], 57 | } 58 | 59 | 60 | # TODO: must match with new `read_text_grid` 61 | def match_duration( 62 | long_waveform, short_phones, long_phones, sample_rate=16000, eps=1e-5, verbose=False 63 | ): 64 | """ 65 | https://www.fon.hum.uva.nl/praat/manual/Intro_8_2__Manipulation_of_duration.html 66 | https://www.fon.hum.uva.nl/praat/manual/DurationTier.html 67 | 68 | One of the types of objects in Praat. A DurationTier object contains a 69 | number of (time, duration) points, where duration is to be interpreted 70 | as a relative duration (e.g. the duration of a manipulated sound as 71 | compared to the duration of the original). For instance, if 72 | your DurationTier contains two points, one with a duration value of 1.5 73 | at a time of 0.5 seconds and one with a duration value of 0.6 at a time 74 | of 1.1 seconds, this is to be interpreted as a relative duration of 1.5 75 | (i.e. a slowing down) for all original times before 0.5 seconds, a 76 | relative duration of 0.6 (i.e. a speeding up) for all original times 77 | after 1.1 seconds, and a linear interpolation between 0.5 and 1.1 78 | seconds (e.g. a relative duration of 1.2 at 0.7 seconds, and of 0.9 at 0.9 seconds). 79 | 80 | 81 | Match the first phoneme duration of "short" in "long". 82 | Some example get different phonemes/alignment (they are pronounced differently naturally) 83 | and if this occurs (3 times in the data) then we simply stop at the last matching phoneme. 
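For example (with illustrative numbers, not values from the data): if a phone spans 0.50-0.62 s in the long variant but lasts only 0.06 s in the short one, the ratio is 0.06 / 0.12 = 0.5, and the function adds `call(dur_tier, "Add point", 0.50, 0.5)` and `call(dur_tier, "Add point", 0.62, 0.5)` for that interval, plus points with value 1.0 just outside the full changed region so the rest of the waveform keeps its original duration.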
84 | 85 | """ 86 | change = [] 87 | for ps, pl in zip(short_phones["intervals"], long_phones["intervals"]): 88 | pps = ps[-1] 89 | ppl = pl[-1] 90 | if pps != ppl: 91 | # if the phonemes no longer match we break 92 | continue 93 | long_dur = pl[1] - pl[0] 94 | short_dur = ps[1] - ps[0] 95 | ratio = short_dur / long_dur 96 | if ratio == 1: 97 | continue 98 | change.append([pl[0], pl[1], ratio]) 99 | 100 | sound = torch_to_praat_sound(long_waveform, sample_rate=sample_rate) 101 | manipulation = call(sound, "To Manipulation", 0.01, 60, 400) 102 | 103 | # add the last chunk to keep duration as is 104 | dur_tier = call( 105 | manipulation, 106 | "Create DurationTier", 107 | "shorten", 108 | sound.start_time, 109 | sound.end_time, 110 | ) 111 | 112 | # before this point duration should be the same 113 | try: 114 | if change[0][0] > 0: 115 | call(dur_tier, "Add point", change[0][0] - eps, 1.0) 116 | if verbose: 117 | print(f'call(dur_tier, "Add point", {change[0][0]-eps}, 1.)') 118 | except: 119 | print(change) 120 | for ps, pl in zip(short_phones["intervals"], long_phones["intervals"]): 121 | print(ps[-1], pl[-1]) 122 | input() 123 | 124 | for s, e, r in change: 125 | call(dur_tier, "Add point", s, r) 126 | call(dur_tier, "Add point", e, r) 127 | if verbose: 128 | print(f'call(dur_tier, "Add point", {s}, {r})') 129 | print(f'call(dur_tier, "Add point", {e}, {r})') 130 | 131 | # After this point duration should be the same 132 | call(dur_tier, "Add point", change[-1][1] + eps, 1.0) 133 | if verbose: 134 | print(f'call(dur_tier, "Add point", {change[-1][1]+eps}, 1.)') 135 | 136 | call([manipulation, dur_tier], "Replace duration tier") 137 | sound_dur = call(manipulation, "Get resynthesis (overlap-add)") 138 | y = sound_dur.as_array().astype("float32") 139 | return torch.from_numpy(y) 140 | 141 | 142 | def read_text_grid(path): 143 | grid = textgrids.TextGrid(path) 144 | data = {"words": [], "phones": []} 145 | for word_phones, vals in grid.items(): 146 | for w in vals: 147 | if w.text == "": 148 | continue 149 | # what about words spoken multiple times? 
150 | # if word_phones == 'words': 151 | # data[word_phones][w.text] = (w.xmin, w.xmax) 152 | data[word_phones].append((w.xmin, w.xmax, w.text)) 153 | return data 154 | 155 | 156 | def get_word_times(target_word, sample): 157 | ret = [] 158 | for start, end, word in sample["words"]: 159 | if word == target_word: 160 | ret.append((start, end, word)) 161 | return ret 162 | 163 | 164 | def find_phones_in_interval(sample, start, end): 165 | phones = [] 166 | for s, e, p in sample["phones"]: 167 | if start <= s < end and start < e <= end: 168 | phones.append((s, e, p)) 169 | return phones 170 | 171 | 172 | def get_last_syllable_duration(sample): 173 | target_word = EXAMPLE_TO_TARGET_WORD[sample["example"]] 174 | syllable_list = LAST_SYLLABLE[target_word] 175 | 176 | # Extract target word from sample 177 | # and take only phones from this word 178 | word_boundaries = get_word_times(target_word, sample) 179 | wstart, wend, _ = word_boundaries[0] # assume only one entry 180 | phones = find_phones_in_interval(sample, wstart, wend) 181 | 182 | start, end = None, None 183 | 184 | iter_again = False # hacky 185 | for syllable in syllable_list: 186 | n_phones = len(syllable) 187 | last_phones = phones[-n_phones:] 188 | start = last_phones[0][0] 189 | end = last_phones[-1][1] 190 | 191 | # Check that phonemes are the same as expected 192 | for s, s_phone in zip(syllable, last_phones): 193 | # assert s == s_phone[-1], f"Not the same phones {syllable} != {last_phones}" 194 | if s != s_phone[-1]: 195 | iter_again = True 196 | break 197 | if not iter_again: 198 | break 199 | 200 | assert start is not None, f"start: {start}, end:{end} not Found" 201 | return start, end, end - start 202 | 203 | 204 | def extract_final_f0_height( 205 | waveform, start, end, sample_rate=16000, hop_time=0.01, f0_min=60, f0_max=400 206 | ): 207 | sound = torch_to_praat_sound(waveform, sample_rate) 208 | pitch = sound.to_pitch( 209 | time_step=hop_time, pitch_floor=f0_min, pitch_ceiling=f0_max 210 | ).selected_array["frequency"] 211 | 212 | # Frame boundaries 213 | s_frame = int(start / hop_time) 214 | e_frame = int(end / hop_time) 215 | 216 | # stats 217 | mean = pitch[pitch > 0].mean() 218 | min_f0 = pitch[pitch > 0].min() 219 | max_f0 = pitch[s_frame : e_frame + 1].max() 220 | return max_f0 / mean, max_f0, min_f0, mean 221 | 222 | 223 | def extract_f0_duration_data(dset): 224 | data = {} 225 | n_skipped = 0 226 | for sample in tqdm(dset, desc="collect f0/dur data"): 227 | example = sample["example"] 228 | short_long = sample["size"] 229 | gender = sample["gender"] 230 | 231 | start, end, dur = get_last_syllable_duration(sample) 232 | r, f0_max, f0_min, f0_mean = extract_final_f0_height( 233 | sample["waveform"], start, end 234 | ) 235 | 236 | # add to data 237 | if example not in data: 238 | data[example] = {} 239 | 240 | if gender not in data[example]: 241 | data[example][gender] = {} 242 | 243 | if short_long not in data[example][gender]: 244 | data[example][gender][short_long] = [] 245 | 246 | data[example][gender][short_long].append([dur, r]) 247 | 248 | print("Skipped: ", n_skipped) 249 | return data 250 | 251 | 252 | def save_all_f0_duration_plots( 253 | data, 254 | text=True, 255 | save=True, 256 | plot=False, 257 | savepath="assets/PaperB/eval_phrases/figs/f0_dur", 258 | ): 259 | Path(savepath).mkdir(parents=True, exist_ok=True) 260 | s = 12 261 | if text: 262 | s = 2 263 | for example, v in data.items(): 264 | fig, ax = plt.subplots(1, 1) 265 | ax.set_title(example) 266 | for gender, vv in v.items(): 267 | alpha = 0.8 
if gender == "female" else 0.4 268 | for dur, r in vv["short"]: 269 | ax.scatter(dur, r, s=s, color="g", alpha=alpha) 270 | if text: 271 | ax.text(dur, r, s="S", color="g", fontweight="bold") 272 | for dur, r in vv["long"]: 273 | ax.scatter(dur, r, s=s, color="b", alpha=alpha) 274 | if text: 275 | ax.text(dur, r, s="L", color="b", fontweight="bold") 276 | ax.set_xlabel("Duration") 277 | ax.set_ylabel("Rel. F0") 278 | if save: 279 | fig.savefig(join(savepath, example + ".png")) 280 | if plot: 281 | plt.show() 282 | else: 283 | plt.close("all") 284 | 285 | 286 | if __name__ == "__main__": 287 | from conv_ssl.evaluation.phrase_dataset import PhraseDataset 288 | 289 | dset = PhraseDataset( 290 | "assets/phrases_beta/phrases.json", 291 | vad_hz=50, 292 | sample_rate=16000, 293 | vad_horizon=2.0, 294 | ) 295 | 296 | data = extract_f0_duration_data(dset) 297 | 298 | save_all_f0_duration_plots(data, save=True, plot=True) 299 | -------------------------------------------------------------------------------- /conv_ssl/evaluation/evaluate_paper_models.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | from os import makedirs, cpu_count 3 | 4 | from conv_ssl.evaluation.evaluation import evaluate 5 | from conv_ssl.evaluation.utils import get_checkpoint, load_paper_versions 6 | from conv_ssl.utils import everything_deterministic 7 | 8 | model_ids = { 9 | "discrete": { 10 | "0": "1h52tpnn", 11 | "1": "3fhjobk0", 12 | "2": "120k8fdv", 13 | "3": "1vx0omkd", 14 | "4": "sbzhz86n", 15 | "5": "1lyezca0", 16 | "6": "2vtd1u1n", 17 | "7": "2ldfo4rg", 18 | "8": "2ca7uxad", 19 | "9": "2fsy74rf", 20 | "10": "3ik6jod6", 21 | }, 22 | "independent": { 23 | "0": "1t7vvo0c", 24 | "1": "24bn5wi6", 25 | "2": "1u7yzji0", 26 | "3": "s5unjaj7", 27 | "4": "10krujrj", 28 | "5": "2rq33fxr", 29 | "6": "3uqpk8e1", 30 | "7": "3mpxa1iy", 31 | "8": "3ulpo767", 32 | "9": "3d952gec", 33 | "10": "2651d3ln", 34 | }, 35 | "independent_baseline": { 36 | "0": "2mme28tm", 37 | "1": "qo9mf26t", 38 | "2": "2rrdm5ma", 39 | "3": "mrzizwex", 40 | "4": "1lximsk1", 41 | "5": "2wyymo7n", 42 | "6": "1nze8m3l", 43 | "7": "1cdhj9yo", 44 | "8": "kamzjel0", 45 | "9": "15ze1p0y", 46 | "10": "2mvwxxar", 47 | }, 48 | "comparative": { 49 | "0": "2kwhi1zi", 50 | "1": "2izpsu6r", 51 | "2": "23mzxhhd", 52 | "3": "1lvk73tr", 53 | "4": "11jlsatj", 54 | "5": "nxvb62j4", 55 | "6": "1pglrfbn", 56 | "7": "1z9qyfh6", 57 | "8": "1kgiwy2m", 58 | "9": "1eluv8de", 59 | "10": "2530040o", 60 | }, 61 | } 62 | 63 | everything_deterministic() 64 | 65 | savepath = "assets/paper_evaluation" 66 | makedirs(savepath, exist_ok=True) 67 | 68 | for model_type, ids in model_ids.items(): 69 | model_root = join(savepath, model_type) 70 | makedirs(model_root, exist_ok=True) 71 | for kfold, id in ids.items(): 72 | instance_root = join(model_root, f"kfold_{kfold}") 73 | makedirs(instance_root, exist_ok=True) 74 | 75 | # get checkpoint from wandb-ID 76 | # change in repo requires slight change of `state_dict` 77 | checkpoint_path = get_checkpoint(run_path=id) 78 | checkpoint_path = load_paper_versions(checkpoint_path) 79 | 80 | # Threshold and extract score 81 | metrics, prediction, curves = evaluate( 82 | checkpoint_path=checkpoint_path, 83 | savepath=instance_root, 84 | batch_size=16, 85 | num_workers=cpu_count(), 86 | ) 87 | -------------------------------------------------------------------------------- /conv_ssl/evaluation/evaluation.py: -------------------------------------------------------------------------------- 1 | 
from pathlib import Path 2 | from os.path import join, basename, dirname as dr 3 | from os import makedirs 4 | from omegaconf import DictConfig, OmegaConf 5 | import hydra 6 | 7 | import torch 8 | from pytorch_lightning import Trainer 9 | from pytorch_lightning.loggers import WandbLogger 10 | 11 | from conv_ssl.callbacks import SymmetricSpeakersCallback 12 | from conv_ssl.model import VPModel 13 | from conv_ssl.utils import ( 14 | everything_deterministic, 15 | write_json, 16 | read_json, 17 | tensor_dict_to_json, 18 | ) 19 | from datasets_turntaking import DialogAudioDM 20 | 21 | # ugly path 22 | 23 | SAVEPATH = join(dr(dr(dr(__file__))), "assets/PaperB/eval") 24 | MIN_THRESH = 0.01 # minimum threshold limit for S/L, S-pred, BC-pred 25 | 26 | everything_deterministic() 27 | 28 | """ 29 | python conv_ssl/evaluation/evaluation.py \ 30 | +checkpoint_path=/FULL/PATH/TO/CHECKPOINT/checkpoint.ckpt \ 31 | data.num_workers=4 \ 32 | data.batch_size=10 33 | """ 34 | 35 | 36 | def load_dm(model, cfg_dict, verbose=False): 37 | data_conf = model.conf["data"] 38 | data_conf["audio_mono"] = False 39 | data_conf["datasets"] = cfg_dict["data"].get("datasets", data_conf["datasets"]) 40 | data_conf["batch_size"] = cfg_dict["data"].get( 41 | "batch_size", data_conf["batch_size"] 42 | ) 43 | data_conf["num_workers"] = cfg_dict["data"].get( 44 | "num_workers", data_conf["num_workers"] 45 | ) 46 | if verbose: 47 | print("Num Workers: ", data_conf["num_workers"]) 48 | print("Batch size: ", data_conf["batch_size"]) 49 | print("Mono: ", data_conf["audio_mono"]) 50 | print("datasets: ", data_conf["datasets"]) 51 | dm = DialogAudioDM(**data_conf) 52 | dm.prepare_data() 53 | dm.setup("test") 54 | return dm 55 | 56 | def test(model, dloader, max_batches=None, project="VPModelTest", online=False): 57 | """ 58 | Iterate over the dataloader to extract metrics.
59 | 60 | * Adds SymmetricSpeakersCallback 61 | - each sample is duplicated with channels reversed 62 | * online = True 63 | - upload to wandb 64 | """ 65 | logger = None 66 | if online: 67 | savedir = "runs/" + project 68 | makedirs(savedir, exist_ok=True) 69 | logger = WandbLogger( 70 | save_dir=savedir, 71 | project=project, 72 | name=model.run_name, 73 | log_model=False, 74 | ) 75 | 76 | # Limit batches 77 | if max_batches is not None: 78 | trainer = Trainer( 79 | gpus=-1, 80 | limit_test_batches=max_batches, 81 | deterministic=True, 82 | logger=logger, 83 | callbacks=[SymmetricSpeakersCallback()], 84 | ) 85 | else: 86 | trainer = Trainer( 87 | gpus=-1, 88 | deterministic=True, 89 | logger=logger, 90 | callbacks=[SymmetricSpeakersCallback()], 91 | ) 92 | 93 | result = trainer.test(model, dataloaders=dloader, verbose=False) 94 | return result 95 | 96 | 97 | def get_curves(preds, target, pos_label=1, thresholds=None, EPS=1e-6): 98 | """ 99 | precision = tp / (tp+fp) 100 | recall = tp / (tp+fn) 101 | 102 | """ 103 | 104 | if thresholds is None: 105 | thresholds = torch.linspace(0, 1, steps=101) 106 | 107 | if pos_label == 0: 108 | raise NotImplementedError("pos_label=0 is not implemented") 109 | 110 | ba, f1 = [], [] 111 | auc0, auc1 = [], [] 112 | prec0, rec0 = [], [] 113 | prec1, rec1 = [], [] 114 | pos_label_idx = torch.where(target == 1) 115 | neg_label_idx = torch.where(target == 0) 116 | 117 | for t in thresholds: 118 | pred_labels = (preds >= t).float() 119 | correct = pred_labels == target 120 | 121 | # POSITIVES 122 | tp = correct[pos_label_idx].sum() 123 | n_p = (target == 1).sum() 124 | fn = n_p - tp 125 | # NEGATIVES 126 | tn = correct[neg_label_idx].sum() 127 | n_n = (target == 0).sum() 128 | fp = n_n - tn 129 | ################################### 130 | # Balanced Accuracy 131 | ################################### 132 | # TPR, TNR 133 | tpr = tp / n_p 134 | tnr = tn / n_n 135 | # BA 136 | ba_tmp = (tpr + tnr) / 2 137 | ba.append(ba_tmp) 138 | ################################### 139 | # F1 140 | ################################### 141 | precision1 = tp / (tp + fp + EPS) 142 | recall1 = tp / (tp + fn + EPS) 143 | f1_1 = 2 * precision1 * recall1 / (precision1 + recall1 + EPS) 144 | prec1.append(precision1) 145 | rec1.append(recall1) 146 | auc1.append(precision1 * recall1) 147 | 148 | precision0 = tn / (tn + fn + EPS) 149 | recall0 = tn / (tn + fp + EPS) 150 | f1_0 = 2 * precision0 * recall0 / (precision0 + recall0 + EPS) 151 | prec0.append(precision0) 152 | rec0.append(recall0) 153 | auc0.append(precision0 * recall0) 154 | 155 | f1w = (f1_0 * n_n + f1_1 * n_p) / (n_n + n_p) 156 | f1.append(f1w) 157 | 158 | return { 159 | "bacc": torch.stack(ba), 160 | "f1": torch.stack(f1), 161 | "prec1": torch.stack(prec1), 162 | "rec1": torch.stack(rec1), 163 | "prec0": torch.stack(prec0), 164 | "rec0": torch.stack(rec0), 165 | "auc0": torch.stack(auc0), 166 | "auc1": torch.stack(auc1), 167 | "thresholds": thresholds, 168 | } 169 | 170 | 171 | def find_threshold(model, dloader, min_thresh=0.01): 172 | """Find the best threshold using PR-curves""" 173 | 174 | def get_best_thresh(curves, metric, measure, min_thresh): 175 | ts = curves[metric]["thresholds"] 176 | over = min_thresh <= ts 177 | under = ts <= (1 - min_thresh) 178 | w = torch.where(torch.logical_and(over, under)) 179 | values = curves[metric][measure][w] 180 | ts = ts[w] 181 | _, best_idx = values.max(0) 182 | return ts[best_idx] 183 | 184 | # Init metric: 185 | model.test_metric = model.init_metric( 186 | bc_pred_pr_curve=True, 187 | 
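# The *_pr_curve=True flags make the test metric accumulate the raw
# (preds, target) pairs, which are collected below and turned into the
# precision/recall curves that get_best_thresh() searches over.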
shift_pred_pr_curve=True, 188 | long_short_pr_curve=True, 189 | ) 190 | 191 | # Find Thresholds 192 | _ = test(model, dloader, online=False) 193 | 194 | ############################################ 195 | # Save predictions 196 | predictions = {} 197 | if hasattr(model.test_metric, "long_short_pr"): 198 | predictions["long_short"] = { 199 | "preds": torch.cat(model.test_metric.long_short_pr.preds), 200 | "target": torch.cat(model.test_metric.long_short_pr.target), 201 | } 202 | if hasattr(model.test_metric, "bc_pred_pr"): 203 | predictions["bc_preds"] = { 204 | "preds": torch.cat(model.test_metric.bc_pred_pr.preds), 205 | "target": torch.cat(model.test_metric.bc_pred_pr.target), 206 | } 207 | if hasattr(model.test_metric, "shift_pred_pr"): 208 | predictions["shift_preds"] = { 209 | "preds": torch.cat(model.test_metric.shift_pred_pr.preds), 210 | "target": torch.cat(model.test_metric.shift_pred_pr.target), 211 | } 212 | 213 | ############################################ 214 | # Curves 215 | curves = {} 216 | for metric in ["bc_preds", "long_short", "shift_preds"]: 217 | curves[metric] = get_curves( 218 | preds=predictions[metric]["preds"], target=predictions[metric]["target"] 219 | ) 220 | 221 | ############################################ 222 | # find best thresh 223 | bc_pred_threshold = None 224 | shift_pred_threshold = None 225 | long_short_threshold = None 226 | if "bc_preds" in curves: 227 | bc_pred_threshold = get_best_thresh(curves, "bc_preds", "f1", min_thresh) 228 | if "shift_preds" in curves: 229 | shift_pred_threshold = get_best_thresh(curves, "shift_preds", "f1", min_thresh) 230 | if "long_short" in curves: 231 | long_short_threshold = get_best_thresh(curves, "long_short", "f1", min_thresh) 232 | 233 | thresholds = { 234 | "pred_shift": shift_pred_threshold, 235 | "pred_bc": bc_pred_threshold, 236 | "short_long": long_short_threshold, 237 | } 238 | return thresholds, predictions, curves 239 | 240 | 241 | @hydra.main(config_path="../conf", config_name="config") 242 | def evaluate(cfg: DictConfig) -> None: 243 | """Evaluate model""" 244 | cfg_dict = OmegaConf.to_object(cfg) 245 | cfg_dict = dict(cfg_dict) 246 | 247 | # Load model 248 | model = VPModel.load_from_checkpoint(cfg.checkpoint_path, strict=False) 249 | model = model.eval() 250 | if torch.cuda.is_available(): 251 | model = model.to("cuda") 252 | 253 | savepath = join(SAVEPATH, basename(cfg.checkpoint_path).replace(".ckpt", "")) 254 | savepath += "_" + "_".join(cfg.data.datasets) 255 | Path(savepath).mkdir(exist_ok=True, parents=True) 256 | 257 | # Load data 258 | print("Num Workers: ", cfg.data.num_workers) 259 | print("Batch size: ", cfg.data.batch_size) 260 | print(cfg.data.datasets) 261 | dm = DialogAudioDM( 262 | datasets=cfg.data.datasets, 263 | type=cfg.data.type, 264 | audio_duration=cfg.data.audio_duration, 265 | audio_normalize=cfg.data.audio_normalize, 266 | audio_overlap=cfg.data.audio_overlap, 267 | sample_rate=cfg.data.sample_rate, 268 | vad_hz=model.frame_hz, 269 | vad_horizon=model.VAP.horizon, 270 | vad_history=cfg.data.vad_history, 271 | vad_history_times=cfg.data.vad_history_times, 272 | flip_channels=False, 273 | batch_size=cfg.data.batch_size, 274 | num_workers=cfg.data.num_workers, 275 | ) 276 | dm.prepare_data() 277 | dm.setup(None) 278 | 279 | # Threshold 280 | # Find the best thresholds (S-pred, BC-pred, S/L) on the validation set 281 | threshold_path = cfg.get("thresholds", None) 282 | if threshold_path is None: 283 | print("#" * 60) 284 | print("Finding Thresholds (val-set)...") 285 | print("#" * 
60) 286 | thresholds, prediction, curves = find_threshold( 287 | model, dm.val_dataloader(), min_thresh=MIN_THRESH 288 | ) 289 | 290 | th = {k: v.item() for k, v in thresholds.items()} 291 | write_json(th, join(savepath, "thresholds.json")) 292 | torch.save(prediction, join(savepath, "predictions.pt")) 293 | torch.save(curves, join(savepath, "curves.pt")) 294 | print("Saved Thresholds -> ", join(savepath, "thresholds.json")) 295 | print("Saved Curves -> ", join(savepath, "curves.pt")) 296 | else: 297 | print("Loading thresholds: ", threshold_path) 298 | thresholds = read_json(threshold_path) 299 | 300 | # Score 301 | print("#" * 60) 302 | print("Final Score (test-set)...") 303 | print("#" * 60) 304 | model.test_metric = model.init_metric( 305 | threshold_pred_shift=thresholds.get("pred_shift", 0.5), 306 | threshold_short_long=thresholds.get("short_long", 0.5), 307 | threshold_bc_pred=thresholds.get("pred_bc", 0.5), 308 | ) 309 | result = test(model, dm.test_dataloader(), online=False)[0] 310 | metrics = model.test_metric.compute() 311 | 312 | metrics["loss"] = result["test_loss"] 313 | metrics["threshold_pred_shift"] = thresholds["pred_shift"] 314 | metrics["threshold_pred_bc"] = thresholds["pred_bc"] 315 | metrics["threshold_short_long"] = thresholds["short_long"] 316 | 317 | metric_json = tensor_dict_to_json(metrics) 318 | write_json(metric_json, join(savepath, "metric.json")) 319 | print("Saved metrics -> ", join(savepath, "metric.json")) 320 | 321 | 322 | if __name__ == "__main__": 323 | evaluate() 324 | -------------------------------------------------------------------------------- /conv_ssl/evaluation/evaluation_augmentation.py: -------------------------------------------------------------------------------- 1 | from os.path import join, basename 2 | from pathlib import Path 3 | from os import makedirs 4 | from omegaconf import DictConfig, OmegaConf 5 | import hydra 6 | 7 | import torch 8 | from pytorch_lightning import Trainer 9 | from pytorch_lightning.loggers import WandbLogger 10 | 11 | from conv_ssl.callbacks import SymmetricSpeakersCallback 12 | from conv_ssl.model import VPModel 13 | import conv_ssl.transforms as CT 14 | from conv_ssl.utils import ( 15 | everything_deterministic, 16 | write_json, 17 | read_json, 18 | tensor_dict_to_json, 19 | ) 20 | from datasets_turntaking import DialogAudioDM 21 | 22 | everything_deterministic() 23 | 24 | SAVEPATH = "/home/erik/projects/CCConv/conv_ssl/assets/PaperB/eval" 25 | 26 | 27 | def load_dm(model, cfg_dict, transforms, verbose=False): 28 | data_conf = model.conf["data"] 29 | data_conf["audio_mono"] = False 30 | data_conf["datasets"] = cfg_dict["data"].get("datasets", data_conf["datasets"]) 31 | data_conf["batch_size"] = cfg_dict["data"].get( 32 | "batch_size", data_conf["batch_size"] 33 | ) 34 | data_conf["num_workers"] = cfg_dict["data"].get( 35 | "num_workers", data_conf["num_workers"] 36 | ) 37 | if verbose: 38 | print("datasets: ", data_conf["datasets"]) 39 | print("duration: ", data_conf["audio_duration"]) 40 | print("Num Workers: ", data_conf["num_workers"]) 41 | print("Batch size: ", data_conf["batch_size"]) 42 | print("Mono: ", data_conf["audio_mono"]) 43 | dm = DialogAudioDM(**data_conf, transforms=transforms) 44 | dm.prepare_data() 45 | dm.setup(None) 46 | return dm 47 | 48 | 49 | def get_augmentation_params(model, cfg, augmentation): 50 | name_suffix = "" 51 | aug_params = {} 52 | transforms = None 53 | if augmentation == "flat_f0": 54 | aug_params = { 55 | "target_f0": cfg.get("target_f0", -1), 56 | "statistic": 
cfg.get("statistic", "mean"), 57 | "stats_frame_length": int(0.05 * model.sample_rate), 58 | "stats_hop_length": int(0.02 * model.sample_rate), 59 | "sample_rate": model.sample_rate, 60 | "to_mono": True, 61 | } 62 | name_suffix = "flat_f0" 63 | transforms = CT.FlatPitch(**aug_params) 64 | elif augmentation == "shift_f0": 65 | aug_params = { 66 | "factor": cfg.get("factor", 0.9), 67 | "sample_rate": model.sample_rate, 68 | "to_mono": True, 69 | } 70 | name_suffix = f"shift_f0_{aug_params['factor']}" 71 | transforms = CT.ShiftPitch(**aug_params) 72 | elif augmentation == "flat_intensity": 73 | aug_params = { 74 | "vad_hz": model.frame_hz, 75 | "vad_cutoff": cfg.get("vad_cutoff", 0.2), 76 | "hop_time": cfg.get("hop_time", 0.01), 77 | "f0_min": cfg.get("f0_min", 60), 78 | "statistic": cfg.get("statistic", "mean"), 79 | "sample_rate": model.sample_rate, 80 | "to_mono": True, 81 | } 82 | name_suffix = "flat_intensity" 83 | transforms = CT.FlatIntensity(**aug_params) 84 | elif augmentation == "only_f0": 85 | aug_params = { 86 | "cutoff_freq": cfg.get("cutoff_freq", 400), 87 | "sample_rate": model.sample_rate, 88 | "norm": True, 89 | "to_mono": True, 90 | } 91 | name_suffix = f"only_f0_{aug_params['cutoff_freq']}" 92 | transforms = CT.LowPass(**aug_params) 93 | return transforms, aug_params, name_suffix 94 | 95 | 96 | def test_augmented( 97 | model, 98 | dloader, 99 | max_batches=None, 100 | project="VAPFlatTest", 101 | online=False, 102 | ): 103 | """ 104 | Iterate over the dataloader to extract metrics. 105 | 106 | Callbacks are done in order! so important to do Flat first... 107 | 108 | * Adds FlattenPitchCallback 109 | - flattens the pitch of each waveform/speaker 110 | * Adds SymmetricSpeakersCallback 111 | - each sample is duplicated with channels reversed 112 | * online = True 113 | - upload to wandb 114 | """ 115 | logger = None 116 | if online: 117 | logger = WandbLogger( 118 | project=project, 119 | name=model.run_name, 120 | log_model=False, 121 | ) 122 | 123 | callbacks = [SymmetricSpeakersCallback()] 124 | # Limit batches 125 | if max_batches is not None: 126 | trainer = Trainer( 127 | gpus=-1, 128 | limit_test_batches=max_batches, 129 | deterministic=True, 130 | logger=logger, 131 | callbacks=callbacks, 132 | ) 133 | else: 134 | trainer = Trainer( 135 | gpus=-1, deterministic=True, logger=logger, callbacks=callbacks 136 | ) 137 | 138 | result = trainer.test(model, dataloaders=dloader, verbose=False) 139 | return result 140 | 141 | 142 | @hydra.main(config_path="../conf", config_name="config") 143 | def evaluate(cfg: DictConfig) -> None: 144 | """Evaluate model""" 145 | cfg_dict = OmegaConf.to_object(cfg) 146 | cfg_dict = dict(cfg_dict) 147 | 148 | assert cfg.get("checkpoint_path", False), "Must provide `+checkpoint_path=/path/to`" 149 | 150 | augmentation = cfg.get("augmentation", None) 151 | assert ( 152 | augmentation is not None 153 | ), f"Please provide `augmentation` by `+augmentation=` and any of ['flat_f0', 'shift_f0', 'flat_intensity', 'only_f0']" 154 | 155 | ################################## 156 | # Load Model 157 | ################################## 158 | model = VPModel.load_from_checkpoint(cfg.checkpoint_path, strict=False) 159 | model = model.eval() 160 | if torch.cuda.is_available(): 161 | model = model.to("cuda") 162 | 163 | # Save directory + Augmentations 164 | savepath = join(SAVEPATH, basename(cfg.checkpoint_path).replace(".ckpt", "")) 165 | savepath += "_" + "_".join(cfg.data.datasets) 166 | Path(savepath).mkdir(exist_ok=True, parents=True) 167 | transforms, 
aug_params, name_suffix = get_augmentation_params( 168 | model, cfg, augmentation 169 | ) 170 | 171 | assert transforms is not None, "NO transformations" 172 | 173 | for k, v in aug_params.items(): 174 | print(f"{k}: {v}") 175 | print("#" * 60) 176 | 177 | # ################################## 178 | # # Load Data 179 | # ################################## 180 | data_conf = model.conf["data"] 181 | dm = load_dm(model, cfg_dict, transforms=transforms, verbose=True) 182 | 183 | print("Thresholds") 184 | threshold_path = join(savepath, "thresholds.json") 185 | print("Loading thresholds: ", threshold_path) 186 | thresholds = read_json(threshold_path) 187 | for k, v in thresholds.items(): 188 | print(f"{k}: {v}") 189 | 190 | print(f"SAVEPATH: ", savepath) 191 | # input("Press Enter to Continue") 192 | 193 | ################################## 194 | # Test 195 | ################################## 196 | print("#" * 60) 197 | print("Flat Score (test-set)...") 198 | print("#" * 60) 199 | model.test_metric = model.init_metric( 200 | threshold_pred_shift=thresholds["pred_shift"], 201 | threshold_short_long=thresholds["short_long"], 202 | threshold_bc_pred=thresholds["pred_bc"], 203 | ) 204 | result = test_augmented( 205 | model, 206 | dm.test_dataloader(), 207 | online=False, 208 | max_batches=cfg_dict.get("max_batches", None), 209 | )[0] 210 | metrics = model.test_metric.compute() 211 | metrics["loss"] = result["test_loss"] 212 | metrics["threshold_pred_shift"] = thresholds["pred_shift"] 213 | metrics["threshold_pred_bc"] = thresholds["pred_bc"] 214 | metrics["threshold_short_long"] = thresholds["short_long"] 215 | 216 | makedirs(savepath, exist_ok=True) 217 | metric_json = tensor_dict_to_json(metrics) 218 | filepath = join(savepath, f"metric_{name_suffix}.json") 219 | write_json(metric_json, filepath) 220 | print("Saved metrics -> ", filepath) 221 | 222 | 223 | if __name__ == "__main__": 224 | evaluate() 225 | -------------------------------------------------------------------------------- /conv_ssl/evaluation/extract_video_data.py: -------------------------------------------------------------------------------- 1 | from os import makedirs 2 | from os.path import join 3 | import torch 4 | from tqdm import tqdm 5 | 6 | from conv_ssl.model import VPModel 7 | from conv_ssl.utils import everything_deterministic 8 | from datasets_turntaking import DialogAudioDM 9 | 10 | everything_deterministic() 11 | 12 | 13 | def video_data_single( 14 | model, 15 | dset, 16 | idx, 17 | audio_duration=10, 18 | audio_overlap=5, 19 | batch_size=8, 20 | savepath="assets/video", 21 | ): 22 | """ 23 | Extract video data from single dialog (defined by `idx`) 24 | """ 25 | 26 | # Extract all data for video 27 | d = dset.get_dialog_sample(idx) 28 | batches = dset.dialog_to_batch( 29 | d, 30 | audio_duration=audio_duration, 31 | audio_overlap=audio_overlap, 32 | batch_size=batch_size, 33 | ) 34 | 35 | # Combine all data 36 | start_frame = int(audio_overlap * dm.vad_hz) 37 | start_sample = int(audio_overlap * dm.sample_rate) 38 | video_data = {"waveform": [], "va": [], "vh": [], "p": [], "p_bc": [], "logits": []} 39 | losses = [] 40 | for i, batch in enumerate(tqdm(batches)): 41 | loss, out, probs, batch = model.output(batch) 42 | losses.append(loss["total"]) 43 | tmp_batch_size = out["logits_vp"].shape[0] 44 | start_batch = 0 45 | if i == 0: 46 | video_data["waveform"].append(batch["waveform"][0].to("cpu")) 47 | video_data["va"].append(batch["vad"][0].to("cpu")) 48 | video_data["vh"].append(batch["vad_history"][0].to("cpu")) 49 | 
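# The first window is kept whole; every later window drops its first
# `audio_overlap` seconds (start_sample for audio, start_frame for the
# frame-rate features) so the concatenated sequences contain each region
# exactly once.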
video_data["p"].append(probs["p"][0].to("cpu")) 50 | video_data["p_bc"].append(probs["bc_prediction"][0].to("cpu")) 51 | video_data["logits"].append(out["logits_vp"][0].to("cpu")) 52 | start_batch = 1 53 | for n in range(start_batch, tmp_batch_size): 54 | video_data["waveform"].append(batch["waveform"][n, start_sample:].to("cpu")) 55 | video_data["va"].append(batch["vad"][n, start_frame:].to("cpu")) 56 | video_data["vh"].append(batch["vad_history"][n, start_frame:].to("cpu")) 57 | video_data["p"].append(probs["p"][n, start_frame:].to("cpu")) 58 | video_data["p_bc"].append(probs["bc_prediction"][n, start_frame:].to("cpu")) 59 | video_data["logits"].append(out["logits_vp"][n, start_frame:].to("cpu")) 60 | 61 | for name, vallist in video_data.items(): 62 | video_data[name] = torch.cat(vallist) 63 | 64 | # Add additional info 65 | video_data["loss"] = torch.stack(losses).mean() 66 | video_data["vap_bins"] = model.VAP.vap_bins.cpu() 67 | video_data["session"] = d["session"][0] 68 | 69 | # save to disk 70 | makedirs(savepath, exist_ok=True) 71 | 72 | filename = join(savepath, f"{d['session'][0]}_video_data.pt") 73 | torch.save(video_data, filename) 74 | print("Saved -> ", filename) 75 | 76 | return video_data 77 | 78 | 79 | if __name__ == "__main__": 80 | from argparse import ArgumentParser 81 | 82 | parser = ArgumentParser() 83 | parser.add_argument("--checkpoint", type=str) 84 | parser.add_argument("--savepath", type=str) 85 | parser.add_argument("--n", type=int, default=3, help="Number of videos") 86 | parser.add_argument("--session", type=str, default=None, help="Specific session") 87 | 88 | args = parser.parse_args() 89 | 90 | # Load model 91 | model = VPModel.load_from_checkpoint(args.checkpoint, strict=False) 92 | model = model.eval() 93 | if torch.cuda.is_available(): 94 | model = model.to("cuda") 95 | 96 | # Load dialog 97 | # Load data 98 | data_conf = DialogAudioDM.load_config() 99 | DialogAudioDM.print_dm(data_conf) 100 | dm = DialogAudioDM( 101 | datasets=data_conf["dataset"]["datasets"], 102 | type=data_conf["dataset"]["type"], 103 | audio_duration=data_conf["dataset"]["audio_duration"], 104 | audio_normalize=data_conf["dataset"]["audio_normalize"], 105 | audio_overlap=data_conf["dataset"]["audio_overlap"], 106 | sample_rate=data_conf["dataset"]["sample_rate"], 107 | vad_hz=model.frame_hz, 108 | vad_horizon=model.VAP.horizon, 109 | vad_history=data_conf["dataset"]["vad_history"], 110 | vad_history_times=data_conf["dataset"]["vad_history_times"], 111 | flip_channels=False, 112 | batch_size=1, 113 | num_workers=0, 114 | ) 115 | dm.prepare_data() 116 | dm.setup(None) 117 | 118 | if args.session is not None: 119 | idx = dm.test_dset.dataset["session"].index(args.session) 120 | video_data = video_data_single( 121 | model, dset=dm.test_dset, idx=idx, savepath=args.savepath 122 | ) 123 | else: 124 | # select 20 first test-videos 125 | for idx in range(args.n): 126 | video_data = video_data_single( 127 | model, dset=dm.test_dset, idx=idx, savepath=args.savepath 128 | ) 129 | -------------------------------------------------------------------------------- /conv_ssl/evaluation/forced_alignment.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # 1. Install 4 | # Create a new conda environment (if you want), source environment and run: 5 | # conda install -c conda-forge montreal-forced-aligner 6 | 7 | # 2. 
Download acoustic-model/dictionary 8 | # mfa model download acoustic english_us_arpa 9 | # mfa model download dictionary english_us_arpa 10 | 11 | # 3. validate corpus and align 12 | corpus="assets/phrases_beta/audio" 13 | alignpath="assets/phrases_beta/alignment" 14 | 15 | # echo "Validating $corpus" 16 | # mfa validate $corpus english_us_arpa english_us_arpa 17 | 18 | echo "Alignment $corpus -> $alignpath" 19 | mfa align $corpus english_us_arpa english_us_arpa $alignpath --clean 20 | -------------------------------------------------------------------------------- /conv_ssl/evaluation/forced_alignment_duration.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # 1. Install 4 | # Create a new conda environment (if you want), source environment and run: 5 | # conda install -c conda-forge montreal-forced-aligner 6 | 7 | # 2. Download acoustic-model/dictionary 8 | # mfa model download acoustic english_us_arpa 9 | # mfa model download dictionary english_us_arpa 10 | 11 | # 3. validate corpus and align 12 | corpus="assets/phrases_beta/duration_audio" 13 | alignpath="assets/phrases_beta/duration_alignment" 14 | 15 | 16 | # echo "Validating $corpus" 17 | # mfa validate $corpus english_us_arpa english_us_arpa 18 | 19 | echo "Alignment $corpus -> $alignpath" 20 | mfa align $corpus english_us_arpa english_us_arpa $alignpath --clean 21 | -------------------------------------------------------------------------------- /conv_ssl/evaluation/phrases.json: -------------------------------------------------------------------------------- 1 | { 2 | "student": { 3 | "short": "Are you a student?", 4 | "long": "Are you a student here at this university?" 5 | }, 6 | "psychology": { 7 | "short": "Do you study psychology?", 8 | "long": "Do you study psychology here at this university?" 9 | }, 10 | "first_year": { 11 | "short": "Are you a first year student?", 12 | "long": "Are you a first year student here at this university?" 13 | }, 14 | "basketball": { 15 | "short": "So do you play basketball?", 16 | "long": "So do you play basketball on Thursdays?" 17 | }, 18 | "experiment": { 19 | "short": "Have you participated in any experiments before?", 20 | "long": "Have you participated in any experiments before here at this university?" 21 | }, 22 | "live": { 23 | "short": "Do you live by yourself?", 24 | "long": "Do you live by yourself or with someone else?" 25 | }, 26 | "work": { 27 | "short": "So you work on the side?", 28 | "long": "So you work on the side in a supermarket in addition to your studies?" 29 | }, 30 | "bike": { 31 | "short": "Did you come here by bike?", 32 | "long": "Did you come here by bike this morning?" 33 | }, 34 | "drive": { 35 | "short": "So did you drive here?", 36 | "long": "So did you drive here this morning?" 
37 | } 38 | } 39 | -------------------------------------------------------------------------------- /conv_ssl/evaluation/phrases_duration_process.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import parselmouth 3 | import numpy as np 4 | import torchaudio 5 | from os.path import join, basename 6 | from pathlib import Path 7 | from parselmouth.praat import call 8 | import string 9 | from tqdm import tqdm 10 | 11 | from conv_ssl.augmentations import torch_to_praat_sound, praat_to_torch 12 | from conv_ssl.evaluation.duration import read_text_grid 13 | from conv_ssl.evaluation.phrase_dataset import PhraseDataset, words_to_vad 14 | 15 | 16 | ROOT = "assets/phrases_beta" 17 | WAV_ROOT = join(ROOT, "duration_audio") 18 | TG_ROOT = join(ROOT, "duration_alignment") 19 | 20 | 21 | """ 22 | 1. Extract audio and save transcripts -> `python conv_ssl/evaluation/phrases_duration_process.py --preprocess` 23 | 2. Align audio and save .TextGrid -> `bash conv_ssl/evaluation/forced_alignment_duration.bash` 24 | - don't forget to source conda env 25 | """ 26 | 27 | 28 | def calculate_average_phone_duration(dset): 29 | phone_durations = {} 30 | for sample in dset: 31 | for s, e, p in sample["phones"]: 32 | d = e - s 33 | if p in phone_durations: 34 | phone_durations[p].append(d) 35 | else: 36 | phone_durations[p] = [d] 37 | for phone, durations in phone_durations.items(): 38 | phone_durations[phone] = np.mean(durations) 39 | return phone_durations 40 | 41 | 42 | class DurationAvg(object): 43 | """ 44 | This is not a transformation but a class to change the duration 45 | on the `phrase dataset`. 46 | 47 | Requires timing of phones 48 | """ 49 | 50 | def __init__(self, phone_durations, sample_rate=16000): 51 | super().__init__() 52 | self.phone_durations = phone_durations 53 | self.sample_rate = sample_rate 54 | 55 | # praat 56 | self.hop_size = 0.01 57 | self.f0_min = 60 58 | self.f0_max = 400 59 | self.eps = 1e-5 60 | 61 | def __call__(self, sample): 62 | sound = torch_to_praat_sound(sample["waveform"], sample_rate=self.sample_rate) 63 | manipulation = call( 64 | sound, "To Manipulation", self.hop_size, self.f0_min, self.f0_max 65 | ) 66 | 67 | # add the last chunk to keep duration as is 68 | dur_tier = call( 69 | manipulation, 70 | "Create DurationTier", 71 | "shorten", 72 | sound.start_time, 73 | sound.end_time, 74 | ) 75 | 76 | # praat interpolates between start -> point -> end 77 | # so we add a non-changing duration point before the first 78 | # phone 79 | first_phone_start, _, _ = sample["phones"][0] 80 | first_phone_start = max(first_phone_start - self.eps, 0) 81 | call(dur_tier, "Add point", first_phone_start, 1.0) 82 | 83 | for start, end, phone in sample["phones"]: 84 | dur = end - start 85 | base = self.phone_durations[phone] 86 | r = dur / base 87 | 88 | # add boundary parts for current phone 89 | # where the end is slightly before actual end 90 | # (next point will start exactly on start) 91 | call(dur_tier, "Add point", start, r) 92 | call(dur_tier, "Add point", end - self.eps, r) 93 | 94 | # Add a final duration to not change remaining part 95 | # of audio signal 96 | _, end, _ = sample["phones"][-1] 97 | call(dur_tier, "Add point", end, 1.0) 98 | 99 | call([manipulation, dur_tier], "Replace duration tier") 100 | sound_dur = call(manipulation, "Get resynthesis (overlap-add)") 101 | return praat_to_torch(sound_dur) 102 | 103 | 104 | def extract_new_audio(dset): 105 | """ 106 | Load sample -> change duration -> save 
.wav and transcript .txt 107 | """ 108 | phone_durations = calculate_average_phone_duration(dset) 109 | duration_modifier = DurationAvg(phone_durations) 110 | Path(WAV_ROOT).mkdir(parents=True, exist_ok=True) 111 | for sample in tqdm(dset, desc="Extract avg-duration audio"): 112 | sample["waveform"] = duration_modifier(sample) 113 | wavfile = join(WAV_ROOT, sample["name"] + ".wav") 114 | textfile = join(WAV_ROOT, sample["name"] + ".txt") 115 | torchaudio.save(wavfile, sample["waveform"], sample_rate=dset.sample_rate) 116 | text = sample["text"].replace("-", " ") 117 | text = text.translate(str.maketrans("", "", string.punctuation)).lower() 118 | with open(textfile, "w") as text_file: 119 | text_file.write(text) 120 | print("Extracted avg-duration waveforms to -> ", WAV_ROOT) 121 | 122 | 123 | def raw_sample_to_sample(sample, dset): 124 | """rewritten `dset.get_sample_data()`""" 125 | sample["audio_path"] = join(WAV_ROOT, basename(sample["audio_path"])) 126 | tg_path = join(TG_ROOT, sample["name"] + ".TextGrid") 127 | tg = read_text_grid(tg_path) 128 | vad_list = words_to_vad(tg["words"]) 129 | # Returns: waveform, dataset_name, vad, vad_history 130 | ret = dset._sample_data(sample, vad_list) 131 | ret["example"] = sample["example"] 132 | ret["words"] = tg["words"] 133 | ret["phones"] = tg["phones"] 134 | ret["size"] = sample["size"] 135 | 136 | # print("ret: ", list(ret.keys())) 137 | # for k, v in sample.items(): 138 | # if k in ["vad", "words", "phones", "waveform"]: 139 | # continue 140 | # ret[k] = v 141 | # print("ret: ", list(ret.keys())) 142 | # input() 143 | return ret 144 | 145 | 146 | def _test(): 147 | import time 148 | 149 | dset = PhraseDataset("assets/phrases_beta/phrases.json") 150 | sample = dset.get_sample("student", "long", "female", 3) 151 | dur_sample = raw_sample_to_sample(sample, dset) 152 | print("sample['waveform']: ", tuple(sample["waveform"].shape)) 153 | print("dur_sample['waveform']: ", tuple(dur_sample["waveform"].shape)) 154 | 155 | for w1, w2 in zip(sample["words"], dur_sample["words"]): 156 | d1 = w1[1] - w1[0] 157 | d2 = w2[1] - w2[0] 158 | # print(w1, d1) 159 | # print(w2, d2) 160 | print(d1 - d2) 161 | print("-" * 30) 162 | 163 | sd.play(sample["waveform"][0], samplerate=16000) 164 | time.sleep(2.5) 165 | sd.play(dur_sample["waveform"][0], samplerate=16000) 166 | 167 | 168 | if __name__ == "__main__": 169 | 170 | import sounddevice as sd 171 | 172 | parser = ArgumentParser() 173 | parser.add_argument("--process", action="store_true") 174 | parser.add_argument( 175 | "--phrases", type=str, default="assets/phrases_beta/phrases.json" 176 | ) 177 | 178 | args = parser.parse_args() 179 | dset = PhraseDataset(args.phrases) 180 | -------------------------------------------------------------------------------- /conv_ssl/evaluation/prepare_phrases_for_alignment.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from os.path import join, basename 3 | from glob import glob 4 | import string 5 | 6 | from conv_ssl.utils import read_json 7 | 8 | 9 | if __name__ == "__main__": 10 | parser = ArgumentParser() 11 | parser.add_argument("--data", type=str, default="assets/phrases_beta") 12 | args = parser.parse_args() 13 | 14 | audio_path = join(args.data, "audio") 15 | anno_path = join(args.data, "annotation") 16 | 17 | wav_paths = glob(join(audio_path, "*.wav")) 18 | wav_paths.sort() 19 | 20 | for wav_path in wav_paths: 21 | name = basename(wav_path).replace(".wav", "") 22 | text = 
read_json(join(anno_path, name + ".json"))["text"] 23 | text = text.replace("-", " ") 24 | text = text.translate(str.maketrans("", "", string.punctuation)).lower() 25 | new_txt_path = join(audio_path, name + ".txt") 26 | with open(new_txt_path, "w") as text_file: 27 | text_file.write(text) 28 | -------------------------------------------------------------------------------- /conv_ssl/evaluation/update_checkpoints.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from os.path import dirname 3 | from os import makedirs 4 | from conv_ssl.evaluation.utils import get_checkpoint, load_paper_versions 5 | 6 | 7 | parser = ArgumentParser() 8 | parser.add_argument("--id", type=str) 9 | parser.add_argument("--savepath", type=str) 10 | 11 | args = parser.parse_args() 12 | ch = get_checkpoint(run_path=args.id) 13 | 14 | makedirs(dirname(args.savepath), exist_ok=True) 15 | new_ch = load_paper_versions(ch, savepath=args.savepath) 16 | -------------------------------------------------------------------------------- /conv_ssl/evaluation/utils.py: -------------------------------------------------------------------------------- 1 | from os.path import basename, dirname, join, exists 2 | import torch 3 | 4 | from conv_ssl.model import VPModel 5 | from datasets_turntaking import DialogAudioDM 6 | 7 | 8 | def run_path_to_project_id(run_path): 9 | id = basename(run_path) # 1xon133f 10 | project = dirname(run_path) # USER_NAME/PROJECT 11 | return project, id 12 | 13 | 14 | def run_path_to_artifact_url(run_path, version="v0"): 15 | """ 16 | run_path: "how_so/ULMProjection/1xon133f" 17 | 18 | artifact_url = "how_so/ULMProjection/model-1xon133f:v1" 19 | """ 20 | project, id = run_path_to_project_id(run_path) 21 | 22 | artifact_path = project + "/" + "model-" + id + ":" + version 23 | return artifact_path 24 | 25 | 26 | def get_checkpoint(run_path, version="v0", artifact_dir="./artifacts"): 27 | """ 28 | On information tab in WandB find 'Run Path' and copy to clipboard 29 | 30 | --------------------------------------------------------- 31 | run_path: how_so/ULMProjection/1tokrds0 32 | --------------------------------------------------------- 33 | project: how_so/ULMProjection 34 | id: 1tokrds0 35 | artifact_url: how_so/ULMProjection/model-1xon133f:v1 36 | checkpoint: ${artifact_dir}/model-3hysqnmt:v1/model.ckpt 37 | --------------------------------------------------------- 38 | """ 39 | import wandb 40 | 41 | # project, id = run_path_to_project_id(run_path) 42 | artifact_url = run_path_to_artifact_url(run_path, version) 43 | model_name = basename(artifact_url) 44 | checkpoint = join(artifact_dir, model_name, "model.ckpt") 45 | 46 | if not exists(checkpoint): 47 | # URL: always '/' 48 | with wandb.init() as run: 49 | artifact = run.use_artifact(artifact_url, type="model") 50 | _ = artifact.download() 51 | return checkpoint 52 | 53 | 54 | def load_metadata(run_path): 55 | import wandb 56 | 57 | if not run_path.startswith("/"): 58 | run_path = "/" + run_path 59 | 60 | api = wandb.Api() 61 | run = api.run(run_path) 62 | return run 63 | 64 | 65 | def load_model(checkpoint_path=None, run_path=None, eval=True, strict=True, **kwargs): 66 | if checkpoint_path is None: 67 | checkpoint_path = get_checkpoint(run_path=run_path, **kwargs) 68 | model = VPModel.load_from_checkpoint(checkpoint_path, strict=strict) 69 | if torch.cuda.is_available(): 70 | model = model.to("cuda") 71 | 72 | if eval: 73 | model = model.eval() 74 | return model 75 | 76 | 77 | 
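# With the paper's bin_times [0.2, 0.4, 0.6, 0.8] (see conf/vap/vap.yaml) the
# model branch below computes horizon = sum(bin_times) = 2.0 s, matching the
# default horizon=2 in the signature.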
def load_dm( 78 | model=None, 79 | vad_hz=100, 80 | horizon=2, 81 | batch_size=4, 82 | num_workers=4, 83 | audio_duration=10, 84 | audio_overlap=1, 85 | ): 86 | data_conf = DialogAudioDM.load_config() 87 | 88 | if model is not None: 89 | horizon = round(sum(model.conf["vad_projection"]["bin_times"]), 2) 90 | vad_hz = model.frame_hz 91 | 92 | dm = DialogAudioDM( 93 | datasets=data_conf["dataset"]["datasets"], 94 | type=data_conf["dataset"]["type"], 95 | # audio_duration=data_conf["dataset"]["audio_duration"], 96 | audio_duration=audio_duration, 97 | audio_normalize=data_conf["dataset"]["audio_normalize"], 98 | audio_overlap=audio_overlap, 99 | sample_rate=data_conf["dataset"]["sample_rate"], 100 | vad_hz=vad_hz, 101 | vad_horizon=horizon, 102 | vad_history=data_conf["dataset"]["vad_history"], 103 | vad_history_times=data_conf["dataset"]["vad_history_times"], 104 | flip_channels=False, # don't flip on evaluation 105 | batch_size=batch_size, 106 | num_workers=num_workers, 107 | ) 108 | dm.prepare_data() 109 | dm.setup(None) 110 | return dm 111 | 112 | 113 | # Temporary 114 | def load_paper_versions(checkpoint_path, savepath=None): 115 | """ 116 | The code was reformatted and simplified, so some parameter names were changed. 117 | 118 | This function loads a checkpoint (from the paper version) and replaces the old parameter names 119 | to create a new state_dict appropriate for the new version. 120 | 121 | WARNING! 122 | The optimizer state is not converted, so continuing training with that optimizer state will probably behave badly. 123 | """ 124 | 125 | print("Old Paper version checkpoint -> new") 126 | 127 | dir = dirname(checkpoint_path) 128 | name = basename(checkpoint_path) 129 | 130 | chpt = torch.load(checkpoint_path) 131 | sd = chpt["state_dict"] 132 | from_to = { 133 | "net.projection_head.weight": "net.vap_head.projection_head.weight", 134 | "net.projection_head.bias": "net.vap_head.projection_head.bias", 135 | } 136 | new_sd = {} 137 | for param, weight in sd.items(): 138 | if param in from_to: 139 | print(param, "->", from_to[param]) 140 | param = from_to[param] 141 | new_sd[param] = weight 142 | chpt["state_dict"] = new_sd 143 | if savepath is None: 144 | new_name = name.replace(".ckpt", "_new.ckpt") 145 | savepath = join(dir, new_name) 146 | torch.save(chpt, savepath) 147 | return savepath 148 | -------------------------------------------------------------------------------- /conv_ssl/evaluation/vad.py: -------------------------------------------------------------------------------- 1 | from datasets_turntaking.utils import load_waveform 2 | from tqdm import tqdm 3 | 4 | from conv_ssl.utils import read_json, write_json 5 | from conv_ssl.evaluation.duration import ( 6 | audio_path_text_grid_path, 7 | read_text_grid, 8 | ) 9 | 10 | try: 11 | from pyannote.audio import Pipeline 12 | except ImportError as e: 13 | print( 14 | """ 15 | Install pyannote 16 | ```bash 17 | conda create -n pyannote python=3.8 18 | conda activate pyannote 19 | conda install pytorch torchaudio -c pytorch 20 | pip install https://github.com/pyannote/pyannote-audio/archive/develop.zip 21 | ``` 22 | """ 23 | ) 24 | 25 | 26 | def text_grid_to_vad_list(tg): 27 | vad_list = [] 28 | for s, e, w in tg["words"]: 29 | vad_list.append((s, e)) 30 | # return as if two channel audio 31 | return [vad_list, []] 32 | 33 | 34 | class VadExtractor: 35 | def __init__(self): 36 | self.vad_pipeline = Pipeline.from_pretrained( 37 | "pyannote/voice-activity-detection" 38 | ) 39 | 40 | def __call__(self, y, sample_rate): 41 | vad_list = [[], []]
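# run the pretrained pyannote pipeline on each channel separately and collect rounded (start, end) segments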
42 | for channel in range(y.shape[0]): 43 | audio_in_memory = { 44 | "waveform": y[channel : channel + 1], 45 | "sample_rate": sample_rate, 46 | } 47 | out = self.vad_pipeline(audio_in_memory) 48 | for segment in out.get_timeline(): 49 | vad_list[channel].append( 50 | [round(segment.start, 2), round(segment.end, 2)] 51 | ) 52 | return vad_list 53 | 54 | 55 | def preprocess_vad_with_model(path="assets/phrases_beta/phrases.json"): 56 | try: 57 | from conv_ssl.evaluation.vad import VadExtractor 58 | except ImportError as e: 59 | print( 60 | "PyAnnote not installed. No preprocessing available... (only required once)" 61 | ) 62 | raise e 63 | 64 | vadder = VadExtractor() 65 | sample_rate = 16000 # pyannote vad 66 | data = read_json(path) 67 | for example, short_long_dict in tqdm(data.items()): 68 | for short_long, gender_dict in short_long_dict.items(): 69 | for gender, sample_list in gender_dict.items(): 70 | for sample in sample_list: 71 | waveform, sr = load_waveform( 72 | sample["audio_path"], sample_rate=sample_rate 73 | ) 74 | vad_list = vadder(waveform, sample_rate=sr) 75 | sample["vad"] = vad_list 76 | write_json(data, path) 77 | 78 | 79 | def _test_vadder(): 80 | wav_path = "assets/phrases/audio/basketball_long_female_en-US-Wavenet-C.wav" 81 | y, sr = load_waveform(wav_path, sample_rate=16000) 82 | vadder = VadExtractor() 83 | vad_list = vadder(y, sample_rate=sr) 84 | print(vad_list) 85 | 86 | 87 | def preprocess_vad_from_align(path): 88 | sample_rate = 16000 # pyannote vad 89 | data = read_json(path) 90 | for example, short_long_dict in tqdm(data.items()): 91 | for short_long, gender_dict in short_long_dict.items(): 92 | for gender, sample_list in gender_dict.items(): 93 | for sample in sample_list: 94 | waveform, sr = load_waveform( 95 | sample["audio_path"], sample_rate=sample_rate 96 | ) 97 | tg = read_text_grid(audio_path_text_grid_path(sample["audio_path"])) 98 | sample["vad"] = text_grid_to_vad_list(tg) 99 | write_json(data, path) 100 | 101 | 102 | if __name__ == "__main__": 103 | # preprocess_vad_with_model("assets/phrases_beta/phrases.json") 104 | preprocess_vad_from_align("assets/phrases_beta/phrases.json") 105 | -------------------------------------------------------------------------------- /conv_ssl/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoder import Encoder 2 | from .autoregressive import AR 3 | -------------------------------------------------------------------------------- /conv_ssl/models/autoregressive.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from conv_ssl.models.transformer import GPT 5 | from conv_ssl.models.transformer_old import CausalTransformer 6 | 7 | 8 | class AR(nn.Module): 9 | """Simplified version of original `CPCAR` module""" 10 | 11 | TYPES = ["gru", "sru", "lstm", "transformer", "gpt"] 12 | 13 | def __init__( 14 | self, 15 | input_dim, 16 | dim, 17 | num_layers, 18 | dropout, 19 | ar="LSTM", 20 | transformer_kwargs=None, 21 | keep_hidden=False, 22 | ): 23 | super().__init__() 24 | self.input_dim = input_dim 25 | self.dim = dim 26 | self.num_layers = num_layers 27 | self.dropout = dropout 28 | 29 | self.ar_type = ar.lower() 30 | self.ar = self._ar(ar, transformer_kwargs) 31 | self.hidden = None 32 | self.keep_hidden = keep_hidden 33 | 34 | def _ar(self, ar, transformer_kwargs): 35 | ar = ar.lower() 36 | assert ar in self.TYPES, 'Please choose ["GRU", "LSTM", "transformer", "gpt"]' 37 | 38 | 
ret = nn.Identity() 39 | if ar == "gru": 40 | ret = nn.GRU( 41 | self.input_dim, self.dim, num_layers=self.num_layers, batch_first=True 42 | ) 43 | elif ar == "lstm": 44 | ret = nn.LSTM( 45 | self.input_dim, self.dim, num_layers=self.num_layers, batch_first=True 46 | ) 47 | elif ar == "sru": 48 | raise NotImplementedError("SRU not implemented!") 49 | # ret = SRU(self.input_dim, self.dim, num_layers=self.num_layers) 50 | 51 | # TODO: input projection if input_dim != dim 52 | elif ar == "transformer": 53 | ret = CausalTransformer( 54 | dim=self.dim, 55 | dff_k=transformer_kwargs["dff_k"], 56 | num_layers=self.num_layers, 57 | num_heads=transformer_kwargs["num_heads"], 58 | dropout=self.dropout, 59 | sizeSeq=transformer_kwargs["sizeSeq"], 60 | abspos=transformer_kwargs["abspos"], 61 | ) 62 | # if not transformer_kwargs["use_pos_emb"] and self.dim != self.input_dim: 63 | if self.dim != self.input_dim: 64 | ret = nn.Sequential( 65 | nn.Linear(self.input_dim, self.dim), nn.LayerNorm(self.dim), ret 66 | ) 67 | elif ar == "gpt": 68 | ret = GPT( 69 | dim=self.dim, 70 | dff_k=transformer_kwargs["dff_k"], 71 | num_layers=self.num_layers, 72 | num_heads=transformer_kwargs["num_heads"], 73 | activation="GELU", 74 | dropout=self.dropout, 75 | use_pos_emb=transformer_kwargs["use_pos_emb"], # False -> Alibi 76 | max_context=transformer_kwargs["max_context"], 77 | ) 78 | 79 | return ret 80 | 81 | def forward(self, x, attention=False): 82 | ret = {} 83 | if self.ar_type == "transformer": 84 | x = self.ar(x) 85 | ret["z"] = x 86 | elif self.ar_type == "gpt": 87 | x = self.ar(x, attention=attention) 88 | if attention: 89 | x, attn = x 90 | ret["attn"] = attn 91 | ret["z"] = x 92 | else: 93 | x, h = self.ar(x) 94 | ret["z"] = x 95 | ret["h"] = x 96 | if self.keep_hidden: 97 | if isinstance(h, tuple): 98 | self.hidden = tuple(x.detach() for x in h) 99 | else: 100 | self.hidden = h.detach() 101 | return ret 102 | 103 | 104 | def _test_ar(config_name): 105 | from conv_ssl.utils import load_hydra_conf 106 | from omegaconf import OmegaConf 107 | 108 | conf = load_hydra_conf(config_name=config_name) 109 | conf = conf["model"] 110 | print(OmegaConf.to_yaml(conf)) 111 | B = 4 112 | N = 100 113 | D = 256 114 | # Autoregressive 115 | model = AR( 116 | input_dim=D, 117 | dim=conf["ar"]["dim"], 118 | num_layers=conf["ar"]["num_layers"], 119 | dropout=conf["ar"]["dropout"], 120 | ar=conf["ar"]["type"], 121 | transformer_kwargs=dict( 122 | num_heads=conf["ar"]["num_heads"], 123 | dff_k=conf["ar"]["dff_k"], 124 | use_pos_emb=conf["ar"]["use_pos_emb"], 125 | max_context=conf["ar"].get("max_context", None), 126 | abspos=conf["ar"].get("abspos", None), 127 | sizeSeq=conf["ar"].get("sizeSeq", None), 128 | ), 129 | ) 130 | # print(model) 131 | x = torch.rand((B, N, D)) 132 | print("x: ", x.shape) 133 | o = model(x) 134 | print(o["z"].shape) 135 | 136 | 137 | if __name__ == "__main__": 138 | _test_ar("model/discrete") 139 | _test_ar("model/discrete_20hz") 140 | -------------------------------------------------------------------------------- /conv_ssl/models/cnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from einops.layers.torch import Rearrange 5 | 6 | 7 | class LayerNorm(nn.Module): 8 | """ 9 | Extending `nn.LayerNorm` by rearranging input dims to normalize over the channel dimension in convnets. 10 | 11 | The original `nn.LayerNorm` + 2 einops Rearrange is faster than a custom norm that calculates the values directly over the channel dimension... 
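A minimal shape sketch (illustrative values, not taken from the repo):

    ln = LayerNorm(dim=256)
    x = torch.rand(4, 256, 100)  # channel-first conv features (B, D, T)
    y = ln(x)                    # rearrange to (B, T, D), normalize over D, rearrange back
    assert y.shape == x.shape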
12 | """ 13 | 14 | def __init__(self, dim: int, rearrange_outputs: bool = True) -> None: 15 | super().__init__() 16 | self.ln = nn.LayerNorm(dim) 17 | self.in_rearrange = Rearrange("b d t -> b t d") 18 | if rearrange_outputs: 19 | self.out_rearrange = Rearrange("b t d -> b d t") 20 | else: 21 | self.out_rearrange = nn.Identity() 22 | 23 | def __repr__(self): 24 | return str(self.ln) 25 | 26 | def forward(self, x): 27 | return self.out_rearrange(self.ln(self.in_rearrange(x))) 28 | 29 | 30 | class CConv1d(nn.Conv1d): 31 | """source: https://github.com/pytorch/pytorch/issues/1333""" 32 | 33 | def __init__( 34 | self, 35 | in_channels, 36 | out_channels, 37 | kernel_size, 38 | stride=1, 39 | dilation=1, 40 | groups=1, 41 | padding_value=0, 42 | bias=True, 43 | **kwargs, 44 | ): 45 | super().__init__( 46 | in_channels, 47 | out_channels, 48 | kernel_size=kernel_size, 49 | stride=stride, 50 | dilation=dilation, 51 | groups=groups, 52 | bias=bias, 53 | **kwargs, 54 | ) 55 | 56 | ks = kernel_size if isinstance(kernel_size, int) else kernel_size[0] 57 | pad_dim1_pre = ks - 1 58 | pad_dim1_post = 0 59 | if dilation > 0: 60 | pad_dim1_pre *= dilation 61 | pad = (pad_dim1_pre, pad_dim1_post) 62 | self.pad = nn.ConstantPad1d(padding=pad, value=padding_value) 63 | 64 | def debug_weights(self, type="sum"): 65 | w = 1.0 66 | if type == "mean": 67 | w = 1.0 / self.kernel_size[0] 68 | 69 | elif type == "range": 70 | k = self.kernel_size[0] 71 | w = torch.arange(1, k + 1).float().pow(2) 72 | w = w.repeat(self.out_channels, self.in_channels, 1) 73 | print("w: ", w.shape) 74 | self.weight.data = self.weight.data = w 75 | if self.bias: 76 | self.bias.data = self.bias.data.fill_(0.0) 77 | return None 78 | 79 | self.weight.data = self.weight.data.fill_(w) 80 | if self.bias: 81 | self.bias.data = self.bias.data.fill_(0.0) 82 | 83 | def forward(self, input): 84 | return super().forward(self.pad(input)) 85 | -------------------------------------------------------------------------------- /conv_ssl/models/cpc_base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import argparse 4 | from os.path import join, exists, dirname 5 | from os import makedirs 6 | 7 | from cpc.model import CPCModel as cpcmodel 8 | from cpc.cpc_default_config import get_default_cpc_config 9 | from cpc.feature_loader import getEncoder, getAR, loadArgs 10 | 11 | from conv_ssl.utils import repo_root 12 | 13 | """ 14 | torch.hub downloads to: 15 | 16 | `~/.cache/torch/hub/checkpoints/` 17 | 18 | Explicit checkpoint path saved manually in "assets/" see CHECKPOINTS below. 
19 | """ 20 | 21 | 22 | CHECKPOINTS = { 23 | "cpc": join(repo_root(), "assets/checkpoints/cpc/60k_epoch4-d0f474de.pt") 24 | } 25 | NAMES = list(CHECKPOINTS.keys()) 26 | 27 | 28 | def load_CPC(): 29 | """ 30 | Contrast predictive learning model for audio data 31 | pretrained: if True, load a model trained on libri-light 60k 32 | (https://arxiv.org/abs/1912.07875) 33 | **kwargs : see cpc/cpc_default_config to get the list of possible arguments 34 | """ 35 | locArgs = get_default_cpc_config() 36 | if exists(CHECKPOINTS["cpc"]): 37 | checkpoint = torch.load(CHECKPOINTS["cpc"], map_location="cpu") 38 | else: 39 | checkpoint_url = "https://dl.fbaipublicfiles.com/librilight/CPC_checkpoints/60k_epoch4-d0f474de.pt" 40 | checkpoint = torch.hub.load_state_dict_from_url( 41 | checkpoint_url, progress=False, map_location="cpu" 42 | ) 43 | makedirs(dirname(CHECKPOINTS["cpc"])) 44 | torch.save(checkpoint, CHECKPOINTS["cpc"]) 45 | loadArgs(locArgs, argparse.Namespace(**checkpoint["config"])) 46 | encoderNet = getEncoder(locArgs) 47 | arNet = getAR(locArgs) 48 | model = cpcmodel(encoderNet, arNet) 49 | 50 | # always load pretrained 51 | model.load_state_dict(checkpoint["weights"], strict=False) 52 | model.name = "cpc" 53 | return model 54 | -------------------------------------------------------------------------------- /conv_ssl/models/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import einops 4 | from einops.layers.torch import Rearrange 5 | 6 | from conv_ssl.models.cpc_base_model import load_CPC 7 | from conv_ssl.models.cnn import CConv1d, LayerNorm 8 | 9 | 10 | def get_cnn_layer(dim, kernel, stride, dilation, activation): 11 | layers = [Rearrange("b t d -> b d t")] 12 | for k, s, d in zip(kernel, stride, dilation): 13 | layers.append(CConv1d(dim, dim, kernel_size=k, stride=s, dilation=d)) 14 | layers.append(LayerNorm(dim)) 15 | layers.append(getattr(torch.nn, activation)()) 16 | layers.append(Rearrange("b d t -> b t d")) 17 | return nn.Sequential(*layers) 18 | 19 | 20 | class Encoder(nn.Module): 21 | """ 22 | Encoder: waveform -> h 23 | pretrained: default='cpc' 24 | 25 | A simpler version of the Encoder 26 | check paper (branch) version to see other encoders... 
27 | """ 28 | 29 | def __init__(self, conf, freeze=True): 30 | super().__init__() 31 | self.conf = conf 32 | self.name = conf["name"] 33 | self.frame_hz = conf["frame_hz"] 34 | self.sample_rate = conf["sample_rate"] 35 | self.encoder_layer = conf["output_layer"] 36 | self.encoder = load_CPC() 37 | self.output_dim = self.encoder.gEncoder.conv4.out_channels 38 | 39 | if conf.get("downsample", False): 40 | down = conf["downsample"] 41 | self.downsample = get_cnn_layer( 42 | dim=self.output_dim, 43 | kernel=down["kernel"], 44 | stride=down["stride"], 45 | dilation=down["dilation"], 46 | activation=down["activation"], 47 | ) 48 | else: 49 | self.downsample = nn.Identity() 50 | 51 | if freeze: 52 | self.freeze() 53 | 54 | def freeze(self): 55 | for p in self.encoder.parameters(): 56 | p.requires_grad_(False) 57 | print(f"Froze {self.__class__.__name__}!") 58 | 59 | def unfreeze(self): 60 | for p in self.encoder.parameters(): 61 | p.requires_grad_(True) 62 | print(f"Trainable {self.__class__.__name__}!") 63 | 64 | def encode(self, waveform): 65 | if waveform.ndim < 3: 66 | waveform = waveform.unsqueeze(1) # channel dim 67 | 68 | # Backwards using only the encoder encounters: 69 | # --------------------------------------------------- 70 | # RuntimeError: one of the variables needed for gradient computation 71 | # has been modified by an inplace operation: 72 | # [torch.FloatTensor [4, 256, 1000]], which is output 0 of ReluBackward0, is at version 1; 73 | # expected version 0 instead. Hint: enable anomaly detection to find 74 | # the operation that failed to compute its gradient, with 75 | # torch.autograd.set_detect_anomaly(True). 76 | z = self.encoder.gEncoder(waveform) # .permute(0, 2, 1) 77 | z = einops.rearrange(z, "b c n -> b n c") 78 | 79 | # However, if we feed through gAR we do not encounter that problem... 80 | if self.encoder_layer > 0: 81 | z = self.encoder.gAR(z) 82 | return z 83 | 84 | def forward(self, waveform): 85 | z = self.encode(waveform) 86 | z = self.downsample(z) 87 | return {"z": z} 88 | 89 | 90 | def _test_encoder(config_name): 91 | from conv_ssl.utils import load_hydra_conf 92 | 93 | conf = load_hydra_conf(config_name=config_name) 94 | econf = conf["model"]["encoder"] 95 | enc = Encoder(econf, freeze=econf["freeze"]) 96 | x = torch.rand((4, econf["sample_rate"])) 97 | out = enc(x) 98 | z = out["z"] 99 | print("Config: ", config_name) 100 | print("x: ", tuple(x.shape)) 101 | print("z: ", tuple(z.shape)) 102 | 103 | 104 | if __name__ == "__main__": 105 | _test_encoder("model/discrete") 106 | _test_encoder("model/discrete_20hz") 107 | _test_encoder("model/discrete_50hz") 108 | -------------------------------------------------------------------------------- /conv_ssl/models/multi_head_attention.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from einops.layers.torch import Rearrange 6 | from typing import Optional 7 | 8 | 9 | def prepare_causal_mask(T, device="cpu"): 10 | mask = torch.tril(torch.ones((T, T), device=device)).view(1, 1, T, T) 11 | mask.requires_grad_(False) 12 | return mask 13 | 14 | 15 | def get_slopes(n): 16 | """ 17 | * aLiBi slopes for heads. 18 | * m in Figure 3. 19 | * Source: 20 | - https://github.com/ofirpress/attention_with_linear_biases/blob/5b327adc6d131e28b40ba58906b30bb469483519/fairseq/models/transformer.py#L742 21 | 22 | Comments: 23 | 24 | In the paper, we only train models that have 2^a heads for some a. 
This function has 25 | some good properties that only occur when the input is a power of 2. 26 | To maintain that even closest_power_of_2 = 2**math.floor(math.log2(n)) 27 | when the number of heads is not a power of 2, we use this workaround. 28 | """ 29 | 30 | def get_slopes_power_of_2(n): 31 | start = 2 ** (-(2 ** -(math.log2(n) - 3))) 32 | ratio = start 33 | return [start * ratio ** i for i in range(n)] 34 | 35 | # In the paper, we only train models that have 2^a heads for some a. This function has 36 | # some good properties that only occur when the input is a power of 2. To maintain that even 37 | # when the number of heads is not a power of 2, we use this workaround. 38 | if math.log2(n).is_integer(): 39 | slopes = get_slopes_power_of_2(n) 40 | else: 41 | closest_power_of_2 = 2 ** math.floor(math.log2(n)) 42 | slopes = ( 43 | get_slopes_power_of_2(closest_power_of_2) 44 | + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] 45 | ) 46 | return slopes 47 | 48 | 49 | def get_relative_bias_matrix(n, num_heads, device="cpu"): 50 | """Relative Bias matrix for aLiBi embeddings""" 51 | return torch.arange(n, device=device).view(1, 1, -1).expand(1, num_heads, -1) 52 | 53 | 54 | class MultiHeadAttention(nn.Module): 55 | """ 56 | A vanilla multi-head masked self-attention layer with a projection at the end. 57 | It is possible to use torch.nn.MultiheadAttention here but I am including an 58 | explicit implementation here to show that there is nothing too scary here. 59 | """ 60 | 61 | def __init__(self, dim: int, num_heads: int, dropout: float, bias: bool = False): 62 | super().__init__() 63 | assert dim % num_heads == 0 64 | self.num_heads = num_heads 65 | self.dim = dim 66 | 67 | # key, query, value projections for all heads 68 | self.key = nn.Linear(dim, dim, bias=bias) 69 | self.query = nn.Linear(dim, dim, bias=bias) 70 | self.value = nn.Linear(dim, dim, bias=bias) 71 | 72 | # head re-shapers 73 | self.unstack_heads = Rearrange("b t (h d) -> b h t d", h=self.num_heads) 74 | self.stack_heads = Rearrange("b h t d -> b t (h d)") 75 | 76 | # regularization 77 | self.attn_drop = nn.Dropout(dropout) 78 | self.resid_drop = nn.Dropout(dropout) 79 | 80 | # output projection 81 | self.proj = nn.Linear(dim, dim, bias=bias) 82 | self.scale = 1.0 / math.sqrt(dim) 83 | 84 | def get_scores(self, q: torch.Tensor, k: torch.Tensor): 85 | """ 86 | Arguments: 87 | q: (B, heads, T, D) 88 | k: (B, heads, T, D) 89 | 90 | Return: 91 | QK: (B, heads, T, T) 92 | """ 93 | return torch.einsum("bhid,bhjd->bhij", q, k) 94 | 95 | def mask_scores(self, qk: torch.Tensor, mask=None): 96 | T = qk.size(-1) 97 | if mask is None: 98 | mask = prepare_causal_mask(T, device=qk.device) 99 | qk = qk.masked_fill(mask == 0, float("-inf")) 100 | return qk 101 | 102 | def forward( 103 | self, 104 | Q: torch.Tensor, 105 | K: torch.Tensor, 106 | V: torch.Tensor, 107 | mask: Optional[torch.Tensor] = None, 108 | ): 109 | # batch size, sequence length, embedding dimensionality (n_embd) 110 | B, T, D = Q.size() 111 | 112 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 113 | k = self.unstack_heads(self.key(K)) # (B, heads, T, D_head) 114 | q = self.unstack_heads(self.query(Q)) # (B, heads, T, D_head) 115 | v = self.unstack_heads(self.value(V)) # (B, heads, T, D_head) 116 | 117 | # QK 118 | att = self.get_scores(q, k) * self.scale # (B, nh, T, T) 119 | att = self.mask_scores(att, mask) 120 | att = F.softmax(att, dim=-1) 121 | 122 | # Softmax, dropout, values 123 | y = 
self.attn_drop(att) @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 124 | 125 | # re-assemble all head outputs side by side 126 | y = self.stack_heads(y) 127 | 128 | # output projection 129 | y = self.resid_drop(self.proj(y)) 130 | return y, att 131 | 132 | 133 | class MultiHeadAttentionAlibi(MultiHeadAttention): 134 | def __init__(self, dim: int, num_heads: int, dropout: float, bias: bool = False): 135 | super().__init__(dim, num_heads, dropout, bias) 136 | self.m = torch.tensor(get_slopes(num_heads)) 137 | self.m.requires_grad_(False) 138 | self.mask = None 139 | 140 | def get_alibi_mask(self, T: int, device="cpu"): 141 | rel_bias_mat = get_relative_bias_matrix(T, self.num_heads, device) 142 | alibi = rel_bias_mat * self.m.unsqueeze(0).unsqueeze(-1).to(device) 143 | 144 | # Causal mask (standard GPT mask) 145 | # lower triangle = 1 146 | # upper triangle = 0 147 | mask = prepare_causal_mask(T, device) # (1, 1, T, T) 148 | # Repeat to get a mask for each head 149 | mask = mask.repeat(1, self.num_heads, 1, 1) # (1, num_heads, T, T) 150 | # fill "future" information with negative infinity 151 | mask.masked_fill_(mask == 0, float("-inf")) 152 | 153 | # Add causality mask to alibi (1, num_heads, T, T) 154 | alibi = alibi.unsqueeze(-2) + mask 155 | alibi.requires_grad_(False) # this should not be trained 156 | return alibi 157 | 158 | def mask_scores(self, qk: torch.Tensor, mask=None): 159 | T = qk.size(-1) 160 | if mask is None: 161 | if self.mask is None or self.mask.shape[-1] < T: 162 | mask = self.get_alibi_mask(T, device=qk.device) 163 | self.mask = mask 164 | else: 165 | mask = self.mask[..., :T, :T] 166 | 167 | # add aLiBi-mask to qk (see Figure 3.) 168 | # Addition/translation does not affect the softmax (over each row), 169 | # as mentioned in the original paper 170 | qk = qk + mask.to(qk.device) 171 | return qk 172 | 173 | 174 | def _test_alibi(): 175 | """https://github.com/ofirpress/attention_with_linear_biases""" 176 | 177 | import matplotlib.pyplot as plt 178 | 179 | N = 20 180 | num_heads = 8 181 | mha = MultiHeadAttentionAlibi(dim=256, num_heads=num_heads, dropout=0) 182 | mask = mha.get_alibi_mask(N) 183 | print("mask: ", tuple(mask.shape)) 184 | 185 | fig, ax = plt.subplots(num_heads, 1, sharex=True, sharey=True, figsize=(6, 12)) 186 | for h in range(num_heads): 187 | ax[h].imshow( 188 | mask[0, h], 189 | aspect="auto", 190 | origin="upper", 191 | interpolation="none", 192 | vmin=0, 193 | vmax=10, 194 | cmap="viridis", 195 | ) 196 | # plt.pause(0.1) 197 | plt.show() 198 | 199 | 200 | if __name__ == "__main__": 201 | _test_alibi() 202 | -------------------------------------------------------------------------------- /conv_ssl/models/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | from typing import Optional, Tuple 6 | 7 | from conv_ssl.models.multi_head_attention import ( 8 | MultiHeadAttentionAlibi, 9 | MultiHeadAttention, 10 | ) 11 | 12 | 13 | class StaticPositionEmbedding(nn.Module): 14 | def __init__(self, seqlen, dmodel): 15 | super(StaticPositionEmbedding, self).__init__() 16 | pos = torch.arange(0.0, seqlen).unsqueeze(1).repeat(1, dmodel) 17 | dim = torch.arange(0.0, dmodel).unsqueeze(0).repeat(seqlen, 1) 18 | div = torch.exp( 19 | -math.log(10000) * (2 * torch.div(dim, 2, rounding_mode="trunc") / dmodel) 20 | ) 21 | pos *= div 22 | pos[:, 0::2] = torch.sin(pos[:, 0::2]) 23 | pos[:, 1::2] = torch.cos(pos[:, 1::2]) 24 | self.register_buffer("pe", 
pos.unsqueeze(0)) 25 | 26 | def forward(self, x: torch.Tensor) -> torch.Tensor: 27 | return x + self.pe[:, : x.size(1), :] 28 | 29 | 30 | def ffn_block( 31 | din: int, 32 | dff: int, 33 | activation: str = "GELU", 34 | dropout: float = 0.0, 35 | bias: bool = False, 36 | ) -> nn.Sequential: 37 | return nn.Sequential( 38 | nn.Linear(din, dff, bias=bias), 39 | getattr(nn, activation)(), 40 | nn.Dropout(p=dropout), 41 | nn.Linear(dff, din, bias=bias), 42 | ) 43 | 44 | 45 | class TransformerLayer(nn.Module): 46 | """ 47 | Transformer Layer 48 | 49 | Using pre-layer-normalization: https://arxiv.org/pdf/2002.04745.pdf 50 | """ 51 | 52 | def __init__( 53 | self, 54 | dim: int = 512, 55 | ffn_dim: int = 1536, 56 | num_heads: int = 8, 57 | ffn_activation: str = "GELU", 58 | dropout: float = 0.1, 59 | position_emb: bool = False, 60 | ): 61 | super().__init__() 62 | self.ln_multihead = nn.LayerNorm(dim) 63 | self.ln_ffnetwork = nn.LayerNorm(dim) 64 | self.dropout = nn.Dropout(p=dropout) 65 | 66 | if position_emb: 67 | self.multihead = MultiHeadAttention( 68 | dim=dim, num_heads=num_heads, dropout=dropout 69 | ) 70 | else: 71 | self.multihead = MultiHeadAttentionAlibi( 72 | dim=dim, num_heads=num_heads, dropout=dropout 73 | ) 74 | self.ffnetwork = ffn_block( 75 | dim, ffn_dim, activation=ffn_activation, dropout=dropout 76 | ) 77 | 78 | def post_layer_norm_forward( 79 | self, x: torch.Tensor, mask: Optional[torch.Tensor] = None 80 | ) -> Tuple[torch.Tensor, torch.Tensor]: 81 | """Not used but kept here for reference""" 82 | h, attn = self.multihead(Q=x, K=x, V=x, mask=mask) 83 | h = self.ln_multihead(x + h) 84 | h = self.ln_ffnetwork(h + self.ffnetwork(h)) 85 | return h, attn 86 | 87 | def forward( 88 | self, x: torch.Tensor, mask: Optional[torch.Tensor] = None 89 | ) -> Tuple[torch.Tensor, torch.Tensor]: 90 | h = self.ln_multihead(x) 91 | h, attn = self.multihead(Q=h, K=h, V=h, mask=mask) 92 | h = x + self.dropout(h) 93 | 94 | h = h + self.dropout(self.ffnetwork(self.ln_ffnetwork(h))) 95 | return h, attn 96 | 97 | 98 | class GPT(nn.Module): 99 | def __init__( 100 | self, 101 | dim: int, 102 | dff_k: int = 3, 103 | num_layers: int = 4, 104 | num_heads: int = 4, 105 | activation: str = "GELU", 106 | dropout: float = 0.1, 107 | use_pos_emb: bool = False, # False -> Alibi 108 | max_context: int = 1024, 109 | ): 110 | super().__init__() 111 | self.dim = dim 112 | self.dff = int(dim * dff_k) 113 | self.num_layers = num_layers 114 | self.num_heads = num_heads 115 | self.activation = activation 116 | self.dropout = dropout 117 | self.use_pos_emb = use_pos_emb 118 | 119 | if self.use_pos_emb: 120 | self.max_context = max_context 121 | self.pos_emb = StaticPositionEmbedding(max_context, self.dim) 122 | else: 123 | self.pos_emb = nn.Identity() 124 | 125 | layers = [] 126 | for _ in range(self.num_layers): 127 | layers.append( 128 | TransformerLayer( 129 | dim=self.dim, 130 | ffn_dim=self.dff, 131 | num_heads=self.num_heads, 132 | ffn_activation=self.activation, 133 | dropout=self.dropout, 134 | position_emb=self.use_pos_emb, 135 | ) 136 | ) 137 | self.layers = nn.ModuleList(layers) 138 | self.apply(self._init_weights) 139 | 140 | def _init_weights(self, module): 141 | if isinstance(module, (nn.Linear, nn.Embedding)): 142 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 143 | if isinstance(module, nn.Linear) and module.bias is not None: 144 | torch.nn.init.zeros_(module.bias) 145 | elif isinstance(module, nn.LayerNorm): 146 | torch.nn.init.zeros_(module.bias) 147 | 
torch.nn.init.ones_(module.weight) 148 | 149 | def forward(self, x, attention=False): 150 | all_attention = [] 151 | 152 | x = self.pos_emb(x) 153 | for layer in self.layers: 154 | x, attn = layer(x) 155 | if attention: 156 | all_attention.append(attn) 157 | 158 | if attention: 159 | attn = torch.stack(all_attention, dim=1) 160 | return x, attn 161 | 162 | return x 163 | 164 | 165 | def _test_gpt(): 166 | import matplotlib.pyplot as plt 167 | 168 | model = GPT(dim=256, dff_k=3, num_layers=4, num_heads=8) 169 | x = torch.rand((4, 20, model.dim)) 170 | with torch.no_grad(): 171 | z, attn = model(x, attention=True) 172 | print("z: ", tuple(z.shape)) 173 | print("attn: ", tuple(attn.shape)) 174 | b = 0 175 | fig, ax = plt.subplots( 176 | model.num_heads, model.num_layers, sharex=True, sharey=True, figsize=(12, 12) 177 | ) 178 | for n_layer in range(model.num_layers): 179 | for n_head in range(model.num_heads): 180 | ax[n_head, n_layer].imshow( 181 | attn[b, n_layer, n_head], 182 | aspect="auto", 183 | origin="upper", 184 | interpolation="none", 185 | vmin=0, 186 | vmax=1, 187 | cmap="viridis", 188 | ) 189 | if n_layer == 0: 190 | ax[n_head, n_layer].set_ylabel(f"Head {n_head}") 191 | if n_head == 0: 192 | ax[n_head, n_layer].set_title(f"Layer {n_layer}") 193 | ax[0, 0].set_xticks([]) 194 | ax[0, 0].set_yticks([]) 195 | plt.tight_layout() 196 | plt.show() 197 | 198 | 199 | if __name__ == "__main__": 200 | 201 | _test_gpt() 202 | -------------------------------------------------------------------------------- /conv_ssl/models/transformer_old.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | # SOURCE: https://github.com/facebookresearch/CPC_audio 6 | import torch 7 | import torch.nn as nn 8 | import math 9 | 10 | 11 | class ScaledDotProductAttention(nn.Module): 12 | def __init__( 13 | self, 14 | sizeSeq, # Size of the input sequence 15 | dk, # Dimension of the input sequence 16 | dropout, # Dropout parameter 17 | relpos=False, 18 | ): # Do we retrieve positional information ? 
19 | super(ScaledDotProductAttention, self).__init__() 20 | 21 | self.drop = nn.Dropout(dropout) 22 | self.softmax = nn.Softmax(dim=2) 23 | self.relpos = relpos 24 | self.sizeSeq = sizeSeq 25 | 26 | if relpos: 27 | self.Krelpos = nn.Parameter(torch.Tensor(dk, sizeSeq)) 28 | self.initmat_(self.Krelpos) 29 | self.register_buffer("z", torch.zeros(1, sizeSeq, 1)) 30 | 31 | # A mask is set so that a node never queries data in the future 32 | mask = torch.tril(torch.ones(sizeSeq, sizeSeq), diagonal=0) 33 | mask = 1 - mask 34 | mask[mask == 1] = -float("inf") 35 | self.register_buffer("mask", mask.unsqueeze(0)) 36 | 37 | def initmat_(self, mat, dim=0): 38 | stdv = 1.0 / math.sqrt(mat.size(dim)) 39 | mat.data.uniform_(-stdv, stdv) 40 | 41 | def forward(self, Q, K, V): 42 | # Input dim : N x sizeSeq x dk 43 | QK = torch.bmm(Q, K.transpose(-2, -1)) 44 | n = Q.shape[1] # used for correct mask 45 | 46 | if self.relpos: 47 | bsz = Q.size(0) 48 | QP = Q.matmul(self.Krelpos) 49 | # This trick with z fills QP's diagonal with zeros 50 | QP = torch.cat((self.z.expand(bsz, -1, -1), QP), 2) 51 | QK += QP.view(bsz, self.sizeSeq + 1, self.sizeSeq)[:, 1:, :] 52 | 53 | A = self.softmax(QK / math.sqrt(K.size(-1)) + self.mask[:, :n, :n]) 54 | return torch.bmm(self.drop(A), V) 55 | 56 | 57 | class MultiHeadAttention(nn.Module): 58 | def __init__( 59 | self, 60 | sizeSeq, # Size of a sequence 61 | dropout, # Dropout parameter 62 | dmodel, # Model's dimension 63 | nheads, # Number of heads in the model 64 | abspos, 65 | ): # Is positional information encoded in the input ? 66 | super(MultiHeadAttention, self).__init__() 67 | self.Wo = nn.Linear(dmodel, dmodel, bias=False) 68 | self.Wk = nn.Linear(dmodel, dmodel, bias=False) 69 | self.Wq = nn.Linear(dmodel, dmodel, bias=False) 70 | self.Wv = nn.Linear(dmodel, dmodel, bias=False) 71 | self.nheads = nheads 72 | self.dk = dmodel // nheads 73 | self.Att = ScaledDotProductAttention( 74 | sizeSeq, self.dk, dropout, relpos=not abspos 75 | ) 76 | 77 | def trans_(self, x): 78 | bsz, bptt, h, dk = x.size(0), x.size(1), self.nheads, self.dk 79 | return ( 80 | x.view(bsz, bptt, h, dk) 81 | .transpose(1, 2) 82 | .contiguous() 83 | .view(bsz * h, bptt, dk) 84 | ) 85 | 86 | def reverse_trans_(self, x): 87 | bsz, bptt, h, dk = x.size(0) // self.nheads, x.size(1), self.nheads, self.dk 88 | return ( 89 | x.view(bsz, h, bptt, dk) 90 | .transpose(1, 2) 91 | .contiguous() 92 | .view(bsz, bptt, h * dk) 93 | ) 94 | 95 | def forward(self, Q, K, V): 96 | q = self.trans_(self.Wq(Q)) 97 | k = self.trans_(self.Wk(K)) 98 | v = self.trans_(self.Wv(V)) 99 | y = self.reverse_trans_(self.Att(q, k, v)) 100 | return self.Wo(y) 101 | 102 | 103 | class FFNetwork(nn.Module): 104 | def __init__(self, din, dout, dff, dropout): 105 | super(FFNetwork, self).__init__() 106 | self.lin1 = nn.Linear(din, dff, bias=True) 107 | self.lin2 = nn.Linear(dff, dout, bias=True) 108 | self.relu = nn.ReLU() 109 | self.drop = nn.Dropout(dropout) 110 | 111 | def forward(self, x): 112 | return self.lin2(self.drop(self.relu(self.lin1(x)))) 113 | 114 | 115 | class TransformerLayer(nn.Module): 116 | def __init__( 117 | self, sizeSeq=32, dmodel=512, dff=2048, dropout=0.1, nheads=8, abspos=False 118 | ): 119 | super(TransformerLayer, self).__init__() 120 | self.multihead = MultiHeadAttention(sizeSeq, dropout, dmodel, nheads, abspos) 121 | self.ln_multihead = nn.LayerNorm(dmodel) 122 | self.ffnetwork = FFNetwork(dmodel, dmodel, dff, dropout) 123 | self.ln_ffnetwork = nn.LayerNorm(dmodel) 124 | 125 | def forward(self, x): 126 | y = 
self.ln_multihead(x + self.multihead(Q=x, K=x, V=x)) 127 | return self.ln_ffnetwork(y + self.ffnetwork(y)) 128 | 129 | 130 | class StaticPositionEmbedding(nn.Module): 131 | def __init__(self, seqlen, dmodel): 132 | super(StaticPositionEmbedding, self).__init__() 133 | pos = torch.arange(0.0, seqlen).unsqueeze(1).repeat(1, dmodel) 134 | dim = torch.arange(0.0, dmodel).unsqueeze(0).repeat(seqlen, 1) 135 | # div = torch.exp(-math.log(10000) * (2 * (dim // 2) / dmodel)) 136 | div = torch.exp( 137 | -math.log(10000) * (2 * torch.div(dim, 2, rounding_mode="trunc") / dmodel) 138 | ) 139 | pos *= div 140 | pos[:, 0::2] = torch.sin(pos[:, 0::2]) 141 | pos[:, 1::2] = torch.cos(pos[:, 1::2]) 142 | self.register_buffer("pe", pos.unsqueeze(0)) 143 | 144 | def forward(self, x): 145 | return x + self.pe[:, : x.size(1), :] 146 | 147 | 148 | class CausalTransformer(nn.Module): 149 | def __init__( 150 | self, 151 | dim, 152 | dff_k=3, 153 | num_layers=4, 154 | num_heads=4, 155 | dropout=0.1, 156 | sizeSeq=1024, 157 | abspos=True, 158 | use_pos_emb=True, 159 | ): 160 | super().__init__() 161 | self.dim = dim 162 | self.dff = int(dim * dff_k) 163 | self.num_layers = num_layers 164 | self.num_heads = num_heads 165 | self.sizeSeq = sizeSeq 166 | self.abspos = abspos 167 | self.dropout = dropout 168 | 169 | self.use_pos_emb = use_pos_emb 170 | 171 | self.model = self._build_model() 172 | 173 | def _build_model(self): 174 | net = [] 175 | if self.use_pos_emb: 176 | net.append(StaticPositionEmbedding(self.sizeSeq, self.dim)) 177 | for _ in range(self.num_layers): 178 | net.append( 179 | TransformerLayer( 180 | sizeSeq=self.sizeSeq, 181 | dmodel=self.dim, 182 | dff=self.dff, 183 | dropout=self.dropout, 184 | nheads=self.num_heads, 185 | abspos=self.abspos, 186 | ) 187 | ) 188 | return nn.Sequential(*net) 189 | 190 | def forward(self, x): 191 | if x.shape[1] > self.sizeSeq: 192 | raise IndexError( 193 | f"input is longer than maximum sequence length! 
x: {x.shape} > {self.sizeSeq}" 194 | ) 195 | return self.model(x) 196 | 197 | 198 | def buildTransformerAR( 199 | dimEncoded, # Output dimension of the encoder 200 | nLayers, # Number of transformer layers 201 | sizeSeq, # Expected size of the input sequence 202 | abspos, 203 | ): 204 | layerSequence = [] 205 | if abspos: 206 | layerSequence += [StaticPositionEmbedding(sizeSeq, dimEncoded)] 207 | 208 | for _ in range(nLayers): 209 | layerSequence += [ 210 | TransformerLayer(sizeSeq=sizeSeq, dmodel=dimEncoded, abspos=abspos) 211 | ] 212 | return nn.Sequential(*layerSequence) 213 | 214 | 215 | if __name__ == "__main__": 216 | 217 | model = CausalTransformer(dim=32, num_layers=1, num_heads=2) 218 | x = torch.randint(0, 256, (4, 599, 32)) 219 | y = model(x) 220 | -------------------------------------------------------------------------------- /conv_ssl/train.py: -------------------------------------------------------------------------------- 1 | from omegaconf import DictConfig, OmegaConf 2 | from os import makedirs, environ 3 | import hydra 4 | import wandb 5 | 6 | import pytorch_lightning as pl 7 | from pytorch_lightning.callbacks import ( 8 | ModelCheckpoint, 9 | EarlyStopping, 10 | LearningRateMonitor, 11 | StochasticWeightAveraging, 12 | ) 13 | from pytorch_lightning.loggers import WandbLogger 14 | from pytorch_lightning.strategies import DDPStrategy 15 | from conv_ssl.callbacks import WandbArtifactCallback 16 | from conv_ssl.model import VPModel 17 | from conv_ssl.utils import everything_deterministic 18 | from datasets_turntaking import DialogAudioDM 19 | 20 | 21 | everything_deterministic() 22 | 23 | 24 | @hydra.main(config_path="conf", config_name="config") 25 | def train(cfg: DictConfig) -> None: 26 | cfg_dict = OmegaConf.to_object(cfg) 27 | cfg_dict = dict(cfg_dict) 28 | 29 | if "debug" in cfg_dict: 30 | environ["WANDB_MODE"] = "offline" 31 | print("DEBUG -> OFFLINE MODE") 32 | 33 | pl.seed_everything(cfg_dict["seed"]) 34 | local_rank = environ.get("LOCAL_RANK", 0) 35 | 36 | model = VPModel(cfg_dict) 37 | 38 | if cfg_dict["verbose"]: 39 | print("DataModule") 40 | for k, v in cfg_dict["data"].items(): 41 | print(f"{k}: {v}") 42 | print("#" * 60) 43 | 44 | dm = DialogAudioDM(**cfg_dict["data"]) 45 | dm.prepare_data() 46 | 47 | if cfg_dict["trainer"]["fast_dev_run"]: 48 | trainer = pl.Trainer(**cfg_dict["trainer"]) 49 | print(cfg_dict["model"]) 50 | print("-" * 40) 51 | print(dm) 52 | trainer.fit(model, datamodule=dm) 53 | else: 54 | # Callbacks & Logger 55 | logger = WandbLogger( 56 | # save_dir=SA, 57 | project=cfg_dict["wandb"]["project"], 58 | name=model.run_name, 59 | log_model=False, 60 | ) 61 | 62 | if local_rank == 0: 63 | print("#" * 40) 64 | print(f"Early stopping (patience={cfg_dict['early_stopping']['patience']})") 65 | print("#" * 40) 66 | 67 | callbacks = [ 68 | ModelCheckpoint( 69 | mode=cfg_dict["checkpoint"]["mode"], 70 | monitor=cfg_dict["checkpoint"]["monitor"], 71 | ), 72 | EarlyStopping( 73 | monitor=cfg_dict["early_stopping"]["monitor"], 74 | mode=cfg_dict["early_stopping"]["mode"], 75 | patience=cfg_dict["early_stopping"]["patience"], 76 | strict=True, # crash if "monitor" is not found in val metrics 77 | verbose=False, 78 | ), 79 | LearningRateMonitor(), 80 | WandbArtifactCallback(), 81 | ] 82 | 83 | if cfg_dict["optimizer"].get("swa_enable", False): 84 | callbacks.append( 85 | StochasticWeightAveraging( 86 | swa_epoch_start=cfg_dict["optimizer"].get("swa_epoch_start", 5), 87 | annealing_epochs=cfg_dict["optimizer"].get( 88 | "swa_annealing_epochs", 10 89 | ), 
90 | ) 91 | ) 92 | 93 | # Find Best Learning Rate 94 | trainer = pl.Trainer(gpus=-1) 95 | lr_finder = trainer.tuner.lr_find(model, dm) 96 | model.learning_rate = lr_finder.suggestion() 97 | print("#" * 40) 98 | print("Initial Learning Rate: ", model.learning_rate) 99 | print("#" * 40) 100 | 101 | # Actual Training 102 | trainer = pl.Trainer( 103 | logger=logger, 104 | callbacks=callbacks, 105 | strategy=DDPStrategy(find_unused_parameters=True), 106 | **cfg_dict["trainer"], 107 | ) 108 | trainer.fit(model, datamodule=dm) 109 | 110 | 111 | def load(): 112 | # from conv_ssl.evaluation.utils import load_model 113 | # checkpoint = "runs/TestHydra/223srezy/checkpoints/epoch=4-step=50.ckpt" 114 | # model = VPModel.load_from_checkpoint(checkpoint) 115 | run = wandb.init() 116 | artifact = run.use_artifact("how_so/TestHydra/3hjsv2z8_model:v0", type="model") 117 | artifact_dir = artifact.download() 118 | checkpoint = artifact_dir + "/model" 119 | print("artifact_dir: ", artifact_dir) 120 | # run_path = "how_so/TestHydra/3hjsv2z8" 121 | # model = load_model(run_path=run_path) 122 | # 123 | # print(model) 124 | 125 | 126 | if __name__ == "__main__": 127 | train() 128 | -------------------------------------------------------------------------------- /conv_ssl/train_disk.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from os import makedirs, environ 3 | 4 | import torch 5 | import pytorch_lightning as pl 6 | from pytorch_lightning.loggers import WandbLogger 7 | from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping 8 | 9 | from datamodule_disk import DiskDMFiles 10 | from conv_ssl.model import VPModel 11 | from conv_ssl.utils import count_parameters, everything_deterministic 12 | 13 | import wandb 14 | 15 | PROJECT = "VPModel" 16 | SAVEDIR = "runs/VPModel" 17 | 18 | 19 | everything_deterministic() 20 | 21 | 22 | class WandbArtifactCallback(pl.Callback): 23 | def upload(self, trainer): 24 | run = trainer.logger.experiment 25 | print(f"Ending run: {run.id}") 26 | artifact = wandb.Artifact(f"{run.id}_model", type="model") 27 | for path, val_loss in trainer.checkpoint_callback.best_k_models.items(): 28 | print(f"Adding artifact: {path}") 29 | artifact.add_file(path) 30 | run.log_artifact(artifact) 31 | 32 | def on_train_end(self, trainer, pl_module): 33 | print("Training End ---------------- Custom Upload") 34 | self.upload(trainer) 35 | 36 | def on_exception(self, trainer, pl_module, exception): 37 | if isinstance(exception, KeyboardInterrupt): 38 | print("Keyboard Interruption ------- Custom Upload") 39 | self.upload(trainer) 40 | 41 | 42 | def train(): 43 | parser = ArgumentParser() 44 | parser = VPModel.add_model_specific_args(parser) 45 | parser = DiskDMFiles.add_data_specific_args(parser) 46 | parser = pl.Trainer.add_argparse_args(parser) 47 | parser.add_argument("--seed", type=int, default=1) 48 | parser.add_argument("--name_info", type=str, default="") 49 | parser.add_argument("--project_info", type=str, default="") 50 | parser.add_argument("--patience", type=int, default=5) 51 | parser.add_argument("--log_gradients", action="store_true") 52 | args = parser.parse_args() 53 | pl.seed_everything(args.seed) 54 | args.deterministic = True 55 | 56 | if args.train_files is None: 57 | raise NotImplementedError('Must provide "--train_files"... Abort') 58 | 59 | if args.val_files is None: 60 | raise NotImplementedError('Must provide "--val_files"... 
Abort') 61 | 62 | local_rank = environ.get("LOCAL_RANK", 0) 63 | 64 | ######### 65 | # Model # 66 | ######### 67 | conf = VPModel.load_config(path=args.conf, args=args) 68 | model = VPModel(conf) 69 | 70 | # print after callbacks/wandb init 71 | if local_rank == 0: 72 | print("-" * 60) 73 | print(model.summary()) 74 | print(f"Model Name: {model.run_name}") 75 | print("Base: ", args.conf) 76 | print("PARAMETERS: ", count_parameters(model)) 77 | print() 78 | print("-" * 60) 79 | 80 | dm = DiskDMFiles( 81 | args.data_root, 82 | train_files=args.train_files, 83 | val_files=args.val_files, 84 | test_files=args.test_files, 85 | batch_size=args.batch_size, 86 | num_workers=args.num_workers, 87 | ) 88 | 89 | # Callbacks & Logger 90 | logger = None 91 | callbacks = [] 92 | 93 | # this should be handled automatically with pytorch_lightning? 94 | if not args.fast_dev_run: 95 | makedirs(SAVEDIR, exist_ok=True) 96 | logger = WandbLogger( 97 | save_dir=SAVEDIR, 98 | project=PROJECT + args.project_info, 99 | name=model.run_name + args.name_info, 100 | log_model=not args.dont_log_model, 101 | # log_model=True, # True: logs after training finish 102 | ) 103 | 104 | callbacks.append( 105 | ModelCheckpoint( 106 | mode="max", 107 | monitor="val_f1_weighted", 108 | # mode="min", 109 | # monitor="val_loss", 110 | ) 111 | ) 112 | callbacks.append(WandbArtifactCallback()) 113 | verbose = False 114 | if local_rank == 0: 115 | print(f"Early stopping (patience={args.patience})") 116 | verbose = True 117 | 118 | callbacks.append( 119 | EarlyStopping( 120 | monitor="val_f1_weighted", 121 | mode="max", 122 | patience=args.patience, 123 | strict=True, # crash if "monitor" is not found in val metrics 124 | verbose=verbose, 125 | ) 126 | ) 127 | 128 | # Trainer 129 | # args.auto_lr_find = True 130 | trainer = pl.Trainer.from_argparse_args( 131 | args=args, logger=logger, callbacks=callbacks 132 | ) 133 | # auto_finder = trainer.tune(model, dm)["lr_find"] 134 | 135 | trainer.fit(model, datamodule=dm) 136 | 137 | 138 | if __name__ == "__main__": 139 | train() 140 | -------------------------------------------------------------------------------- /conv_ssl/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from conv_ssl.augmentations import ( 5 | flatten_pitch_batch, 6 | shift_pitch_batch, 7 | low_pass_filter_resample, 8 | IntensityNeutralizer, 9 | ) 10 | 11 | 12 | class FlatPitch(nn.Module): 13 | def __init__( 14 | self, 15 | target_f0: int = -1, 16 | statistic: str = "mean", 17 | stats_frame_length: int = 800, 18 | stats_hop_length: int = 320, 19 | sample_rate: int = 16000, 20 | to_mono: bool = True, 21 | ): 22 | super().__init__() 23 | self.statistic = statistic 24 | self.stats_frame_length = stats_frame_length 25 | self.stats_hop_length = stats_hop_length 26 | self.target_f0 = target_f0 27 | self.sample_rate = sample_rate 28 | self.to_mono = to_mono 29 | 30 | def forward(self, waveform, vad): 31 | """Appends a flipped version of the batch-samples""" 32 | w = flatten_pitch_batch( 33 | waveform=waveform, 34 | vad=vad, 35 | target_f0=self.target_f0, 36 | statistic=self.statistic, 37 | stats_frame_length=self.stats_frame_length, 38 | stats_hop_length=self.stats_hop_length, 39 | sample_rate=self.sample_rate, 40 | to_mono=self.to_mono, 41 | ) 42 | return w 43 | 44 | 45 | class ShiftPitch(nn.Module): 46 | def __init__( 47 | self, factor: float = 0.9, sample_rate: int = 16000, to_mono: bool = True 48 | ): 49 | super().__init__() 50 | 
self.factor = factor 51 | self.sample_rate = sample_rate 52 | self.to_mono = to_mono 53 | 54 | def forward(self, waveform, vad=None): 55 | return shift_pitch_batch( 56 | waveform=waveform, 57 | factor=self.factor, 58 | vad=vad, 59 | sample_rate=self.sample_rate, 60 | to_mono=self.to_mono, 61 | ) 62 | 63 | 64 | class LowPass(nn.Module): 65 | def __init__( 66 | self, 67 | cutoff_freq: int = 400, 68 | sample_rate: int = 16000, 69 | norm: bool = True, 70 | to_mono: bool = True, 71 | ): 72 | super().__init__() 73 | self.cutoff_freq = cutoff_freq 74 | self.sample_rate = sample_rate 75 | self.norm = norm 76 | self.to_mono = to_mono 77 | # self.gain = AT.Vol(gain=10, gain_type="db") 78 | 79 | def normalize(self, x): 80 | assert x.ndim == 2, f"normalization expects (B, n_samples) got {x.shape}" 81 | xx = x - x.min(-1, keepdim=True).values 82 | 83 | xmax = xx.max(-1, keepdim=True).values 84 | xx = xx / xmax 85 | xx = 2 * xx - 1.0 86 | return xx 87 | 88 | def standardize(self, x, eps=1e-5): 89 | assert x.ndim == 2, f"standardization expects (B, n_samples) got {x.shape}" 90 | m = x.mean(-1, keepdim=True) 91 | s = x.std(-1, keepdim=True) 92 | xx = (x - m) / (s + eps) 93 | return xx 94 | 95 | def forward(self, waveform, *args, **kwargs): 96 | waveform = low_pass_filter_resample( 97 | waveform, self.cutoff_freq, self.sample_rate 98 | ) 99 | if self.to_mono: 100 | waveform = waveform.mean(1) 101 | 102 | if self.norm: 103 | # waveform = self.standardize(waveform) 104 | waveform = self.normalize(waveform) 105 | # waveform = self.gain(waveform) 106 | 107 | return waveform 108 | 109 | 110 | class FlatIntensity(nn.Module): 111 | """ """ 112 | 113 | def __init__( 114 | self, 115 | vad_hz, 116 | vad_cutoff: float = 0.2, 117 | hop_time: float = 0.01, 118 | f0_min: int = 60, 119 | statistic: str = "mean", 120 | sample_rate: int = 16000, 121 | to_mono: bool = True, 122 | ): 123 | super().__init__() 124 | self.hop_time = hop_time 125 | self.vad_hz = vad_hz 126 | self.f0_min = f0_min 127 | self.vad_cutoff = vad_cutoff 128 | self.statistic = statistic 129 | self.sample_rate = sample_rate 130 | self.to_mono = to_mono 131 | self.neutralizer = IntensityNeutralizer( 132 | hop_time=hop_time, 133 | vad_hz=vad_hz, 134 | f0_min=f0_min, 135 | vad_cutoff=vad_cutoff, 136 | scale_stat=statistic, 137 | sample_rate=sample_rate, 138 | to_mono=to_mono, 139 | ) 140 | 141 | def forward(self, waveform, vad): 142 | combine = False 143 | if waveform.ndim == 3: 144 | combine = True 145 | if combine: 146 | y_tmp = waveform.mean(1) 147 | else: 148 | y_tmp = waveform 149 | y, _ = self.neutralizer(y_tmp, vad=vad) 150 | return y 151 | -------------------------------------------------------------------------------- /docker/Dockerfile_base: -------------------------------------------------------------------------------- 1 | # vim:ft=dockerfile 2 | FROM pytorch/pytorch:1.10.0-cuda11.3-cudnn8-runtime 3 | WORKDIR /workspace 4 | RUN apt-get update 5 | RUN apt-get install git g++ sox -y 6 | 7 | # torchaudio does not want to cooperate... 
this takes time but works 8 | RUN pip3 install torch==1.10.1+cu113 torchvision==0.11.2+cu113 torchaudio==0.10.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html 9 | 10 | # CPCAudio 11 | WORKDIR /dependencies 12 | RUN git clone https://github.com/facebookresearch/CPC_audio.git 13 | WORKDIR /dependencies/CPC_audio 14 | RUN git checkout b98a1bdf1fe9ea219816db7a6c28115d404a3510 15 | RUN pip install cython 16 | RUN python setup.py develop -------------------------------------------------------------------------------- /docker/README.md: -------------------------------------------------------------------------------- 1 | # Docker 2 | 3 | **WARNING: this may not be updated (not using docker for development)** 4 | 5 | 6 | * Requires [Nvidia-Docker](https://github.com/NVIDIA/nvidia-docker) for gpu support. 7 | * [Nvidia Docker Github](https://github.com/NVIDIA/nvidia-docker) 8 | * [github.io docs](https://nvidia.github.io/nvidia-docker/) 9 | * [Installation Guide](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html#docker) 10 | * sudo may be required for the default setup. That is, add `sudo` before each of the commands below. 11 | * Build Base (torchaudio was difficult): `docker build -f docker/Dockerfile_base -t vap_base .` 12 | * Build: `docker build . -t vap` 13 | * Run: `docker run --rm -it --gpus all -v=$(pwd)/assets:/workspace/assets -v=$HOME/projects/data:/root/projects/data vap` 14 | * Used during debug + some training: 15 | * Add current directory (if changing code) 16 | * Run: `docker run --rm -it --gpus all -v=$(pwd):/workspace -v=$HOME/projects/data:/root/projects/data vap` 17 | 18 | 19 | ```bash 20 | # takes time but installs torch/torchaudio etc + CPC model repository 21 | docker build -f docker/Dockerfile_base -t vap_base . 22 | 23 | # uses the image above and installs the VAP repos: 24 | # vap_turn_taking 25 | # datasets_turntaking 26 | # This repo 27 | docker build . -t vap 28 | 29 | # start docker (must include path to audio e.g. `$HOME/projects/data` in our case) 30 | docker run --rm -it --gpus all -v=$(pwd)/assets:/workspace/assets -v=$HOME/projects/data:/root/projects/data vap 31 | ``` 32 | 33 | -------------------------------------------------------------------------------- /docker/prepare_dataset.py: -------------------------------------------------------------------------------- 1 | from datasets_turntaking.switchboard import load_switchboard 2 | from datasets_turntaking.fisher import load_fisher 3 | 4 | if __name__ == "__main__": 5 | for split in ["train", "val", "test"]: 6 | dset = load_switchboard(split=split) 7 | 8 | for split in ["train", "val", "test"]: 9 | dset = load_fisher(split=split) -------------------------------------------------------------------------------- /example/cpc_48_50hz_15gqq5s5.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ErikEkstedt/conv_ssl/c365345afff3df33c791c6fc9d498bc08617ffb7/example/cpc_48_50hz_15gqq5s5.ckpt -------------------------------------------------------------------------------- /example/student_long_female_en-US-Wavenet-G.TextGrid: -------------------------------------------------------------------------------- 1 | File type = "ooTextFile" 2 | Object class = "TextGrid" 3 | 4 | xmin = 0 5 | xmax = 2.4888125 6 | tiers? <exists> 
7 | size = 2 8 | item []: 9 | item [1]: 10 | class = "IntervalTier" 11 | name = "words" 12 | xmin = 0 13 | xmax = 2.4888125 14 | intervals: size = 9 15 | intervals [1]: 16 | xmin = 0 17 | xmax = 0.04 18 | text = "are" 19 | intervals [2]: 20 | xmin = 0.04 21 | xmax = 0.28 22 | text = "you" 23 | intervals [3]: 24 | xmin = 0.28 25 | xmax = 0.34 26 | text = "a" 27 | intervals [4]: 28 | xmin = 0.34 29 | xmax = 0.92 30 | text = "student" 31 | intervals [5]: 32 | xmin = 0.92 33 | xmax = 1.15 34 | text = "here" 35 | intervals [6]: 36 | xmin = 1.15 37 | xmax = 1.29 38 | text = "at" 39 | intervals [7]: 40 | xmin = 1.29 41 | xmax = 1.49 42 | text = "this" 43 | intervals [8]: 44 | xmin = 1.49 45 | xmax = 2.12 46 | text = "university" 47 | intervals [9]: 48 | xmin = 2.12 49 | xmax = 2.4888125 50 | text = "" 51 | item [2]: 52 | class = "IntervalTier" 53 | name = "phones" 54 | xmin = 0 55 | xmax = 2.4888125 56 | intervals: size = 30 57 | intervals [1]: 58 | xmin = 0 59 | xmax = 0.04 60 | text = "ER0" 61 | intervals [2]: 62 | xmin = 0.04 63 | xmax = 0.2 64 | text = "Y" 65 | intervals [3]: 66 | xmin = 0.2 67 | xmax = 0.28 68 | text = "UW1" 69 | intervals [4]: 70 | xmin = 0.28 71 | xmax = 0.34 72 | text = "AH0" 73 | intervals [5]: 74 | xmin = 0.34 75 | xmax = 0.44 76 | text = "S" 77 | intervals [6]: 78 | xmin = 0.44 79 | xmax = 0.52 80 | text = "T" 81 | intervals [7]: 82 | xmin = 0.52 83 | xmax = 0.64 84 | text = "UW1" 85 | intervals [8]: 86 | xmin = 0.64 87 | xmax = 0.68 88 | text = "D" 89 | intervals [9]: 90 | xmin = 0.68 91 | xmax = 0.75 92 | text = "AH0" 93 | intervals [10]: 94 | xmin = 0.75 95 | xmax = 0.78 96 | text = "N" 97 | intervals [11]: 98 | xmin = 0.78 99 | xmax = 0.92 100 | text = "T" 101 | intervals [12]: 102 | xmin = 0.92 103 | xmax = 0.93 104 | text = "HH" 105 | intervals [13]: 106 | xmin = 0.93 107 | xmax = 1.08 108 | text = "IY1" 109 | intervals [14]: 110 | xmin = 1.08 111 | xmax = 1.15 112 | text = "R" 113 | intervals [15]: 114 | xmin = 1.15 115 | xmax = 1.22 116 | text = "AE1" 117 | intervals [16]: 118 | xmin = 1.22 119 | xmax = 1.29 120 | text = "T" 121 | intervals [17]: 122 | xmin = 1.29 123 | xmax = 1.3 124 | text = "DH" 125 | intervals [18]: 126 | xmin = 1.3 127 | xmax = 1.41 128 | text = "IH0" 129 | intervals [19]: 130 | xmin = 1.41 131 | xmax = 1.49 132 | text = "S" 133 | intervals [20]: 134 | xmin = 1.49 135 | xmax = 1.54 136 | text = "Y" 137 | intervals [21]: 138 | xmin = 1.54 139 | xmax = 1.57 140 | text = "UW2" 141 | intervals [22]: 142 | xmin = 1.57 143 | xmax = 1.6 144 | text = "N" 145 | intervals [23]: 146 | xmin = 1.6 147 | xmax = 1.66 148 | text = "AH0" 149 | intervals [24]: 150 | xmin = 1.66 151 | xmax = 1.71 152 | text = "V" 153 | intervals [25]: 154 | xmin = 1.71 155 | xmax = 1.8 156 | text = "ER1" 157 | intervals [26]: 158 | xmin = 1.8 159 | xmax = 1.88 160 | text = "S" 161 | intervals [27]: 162 | xmin = 1.88 163 | xmax = 1.91 164 | text = "AH0" 165 | intervals [28]: 166 | xmin = 1.91 167 | xmax = 1.94 168 | text = "T" 169 | intervals [29]: 170 | xmin = 1.94 171 | xmax = 2.12 172 | text = "IY0" 173 | intervals [30]: 174 | xmin = 2.12 175 | xmax = 2.4888125 176 | text = "" 177 | -------------------------------------------------------------------------------- /example/student_long_female_en-US-Wavenet-G.json: -------------------------------------------------------------------------------- 1 | { 2 | "text": "Are you a student here at this university?", 3 | "audio_path": "assets/phrases_beta/audio/student_long_female_en-US-Wavenet-G.wav", 4 | "gender": "female", 5 | 
"words": ["Are", "you", "a", "student", "here", "at", "this", "university?"], 6 | "starts": [ 7 | 0.009999999776482582, 0.08537499606609344, 0.26220834255218506, 8 | 0.3022083342075348, 0.7422499656677246, 1.016708254814148, 9 | 1.1381666660308838, 1.331708312034607 10 | ], 11 | "size": "long", 12 | "tts": "en-US-Wavenet-G", 13 | "name": "student_long_female_en-US-Wavenet-G" 14 | } 15 | -------------------------------------------------------------------------------- /example/student_long_female_en-US-Wavenet-G.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ErikEkstedt/conv_ssl/c365345afff3df33c791c6fc9d498bc08617ffb7/example/student_long_female_en-US-Wavenet-G.wav -------------------------------------------------------------------------------- /example/vad_list.json: -------------------------------------------------------------------------------- 1 | [ 2 | [ 3 | [0.0, 0.04], 4 | [0.04, 0.28], 5 | [0.28, 0.34], 6 | [0.34, 0.92], 7 | [0.92, 1.15], 8 | [1.15, 1.29], 9 | [1.29, 1.49], 10 | [1.49, 2.12] 11 | ], 12 | [] 13 | ] 14 | -------------------------------------------------------------------------------- /frontend/main.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import streamlit as st 3 | import torch 4 | import torchaudio 5 | import json 6 | from scipy.io.wavfile import read 7 | 8 | import textgrids 9 | 10 | from conv_ssl.model import VPModel 11 | from conv_ssl.utils import ( 12 | everything_deterministic, 13 | get_tg_vad_list, 14 | load_waveform, 15 | read_json, 16 | read_txt, 17 | ) 18 | from conv_ssl.evaluation.duration import read_text_grid 19 | from conv_ssl.evaluation.evaluation_phrases import plot_sample 20 | 21 | everything_deterministic() 22 | 23 | 24 | CHECKPOINT = "example/cpc_48_50hz_15gqq5s5.ckpt" 25 | SAMPLE_RATE = 16000 26 | TG_TMP_PATH = "tmp_textgrid.TextGrid" 27 | 28 | 29 | EX_WAV = "example/student_long_female_en-US-Wavenet-G.wav" 30 | EX_TG = "example/student_long_female_en-US-Wavenet-G.TextGrid" 31 | EX_VA = "example/vad_list.json" 32 | 33 | 34 | @st.cache 35 | def load_model(checkpoint=CHECKPOINT): 36 | model = VPModel.load_from_checkpoint(checkpoint) 37 | model = model.eval() 38 | if torch.cuda.is_available(): 39 | _ = model.to("cuda") 40 | return model 41 | 42 | 43 | def load_vad_list(vad_list_data): 44 | vad_list = json.loads(vad_list_data.getvalue().decode("utf-8")) 45 | st.session_state.vad_list = vad_list 46 | 47 | 48 | def load_textgrid(tg_data): 49 | tg_read = tg_data.getvalue().decode("utf-8") 50 | with open(TG_TMP_PATH, "w", encoding="utf-8") as f: 51 | f.write(tg_read) 52 | tg = read_text_grid(TG_TMP_PATH) 53 | vad_list = get_tg_vad_list(tg) 54 | st.session_state.tg = tg 55 | st.session_state.vad_list = vad_list 56 | 57 | 58 | def run_model(): 59 | if "waveform" in st.session_state and "vad_list" in st.session_state: 60 | sample = st.session_state.model.load_sample( 61 | st.session_state.waveform, st.session_state.vad_list 62 | ) 63 | 64 | if "tg" in st.session_state and st.session_state.tg is not None: 65 | if "words" in st.session_state.tg: 66 | sample["words"] = st.session_state.tg["words"] 67 | 68 | if "phones" in st.session_state.tg: 69 | sample["phones"] = st.session_state.tg["phones"] 70 | 71 | loss, out, probs, sample = st.session_state.model.output(sample) 72 | # Save 73 | data = { 74 | "loss": {"vp": loss["vp"].item(), "frames": loss["frames"].tolist()}, 75 | "probs": 
out["logits_vp"].softmax(-1).tolist(), 76 | "labels": out["va_labels"].tolist(), 77 | "p": probs["p"].tolist(), 78 | "p_bc": probs["bc_prediction"].tolist(), 79 | } 80 | st.session_state.output_data = data 81 | 82 | fig, ax = plot_sample( 83 | probs["p"][0, :, 0], 84 | sample, 85 | sample_rate=st.session_state.model.sample_rate, 86 | frame_hz=st.session_state.model.frame_hz, 87 | ) 88 | st.session_state.fig = fig 89 | 90 | 91 | def sample(): 92 | waveform, _ = load_waveform( 93 | EX_WAV, sample_rate=SAMPLE_RATE, normalize=True, mono=True 94 | ) 95 | tg = read_text_grid(EX_TG) 96 | vad_list = get_tg_vad_list(tg) 97 | sample = st.session_state.model.load_sample(waveform, vad_list) 98 | sample["words"] = tg["words"] 99 | sample["phones"] = tg["phones"] 100 | 101 | loss, out, probs, sample = st.session_state.model.output(sample) 102 | # Save 103 | data = { 104 | "loss": {"vp": loss["vp"].item(), "frames": loss["frames"].tolist()}, 105 | "probs": out["logits_vp"].softmax(-1).tolist(), 106 | "labels": out["va_labels"].tolist(), 107 | "p": probs["p"].tolist(), 108 | "p_bc": probs["bc_prediction"].tolist(), 109 | } 110 | st.session_state.output_data = data 111 | fig, ax = plot_sample( 112 | probs["p"][0, :, 0], 113 | sample, 114 | sample_rate=st.session_state.model.sample_rate, 115 | frame_hz=st.session_state.model.frame_hz, 116 | ) 117 | st.session_state.fig = fig 118 | 119 | 120 | def clear(): 121 | st.session_state.fig = None 122 | 123 | 124 | def check_password(): 125 | """Returns `True` if the user had the correct password.""" 126 | 127 | def password_entered(): 128 | """Checks whether a password entered by the user is correct.""" 129 | if st.session_state["password"] == st.secrets["password"]: 130 | st.session_state["password_correct"] = True 131 | del st.session_state["password"] # don't store password 132 | else: 133 | st.session_state["password_correct"] = False 134 | 135 | if "password_correct" not in st.session_state: 136 | # First run, show input for password. 137 | st.text_input( 138 | "Password", type="password", on_change=password_entered, key="password" 139 | ) 140 | return False 141 | elif not st.session_state["password_correct"]: 142 | # Password not correct, show input + error. 143 | st.text_input( 144 | "Password", type="password", on_change=password_entered, key="password" 145 | ) 146 | st.error("😕 Password incorrect") 147 | return False 148 | else: 149 | # Password correct. 
150 | return True 151 | 152 | 153 | if __name__ == "__main__": 154 | 155 | if check_password(): 156 | with st.sidebar: 157 | st.header("Sample") 158 | 159 | with open(EX_WAV, "rb") as f: 160 | st.download_button( 161 | "Download Wav", f, file_name="sample.wav", mime="audio/wav" 162 | ) 163 | 164 | # st.subheader("VA List") 165 | # st.text(read_json(EX_VA), expanded=False) 166 | with open(EX_TG, "rb") as f: 167 | st.download_button( 168 | "Download TextGrid", f, file_name="sample_tg.TextGrid" 169 | ) 170 | 171 | with open(EX_VA, "rb") as f: 172 | st.download_button( 173 | "Download VA list", 174 | f, 175 | file_name="sample_va.json", 176 | mime="application/json", 177 | ) 178 | 179 | st.subheader("Inspect") 180 | with st.expander("TextGrid"): 181 | with open(EX_TG, "r", encoding="utf-8") as f: 182 | st.text(f.read()) 183 | with st.expander("VA List"): 184 | st.json(read_json(EX_VA), expanded=True) 185 | 186 | if "model" not in st.session_state: 187 | st.session_state.model = load_model() 188 | 189 | if "output_data" not in st.session_state: 190 | st.session_state.output_data = None 191 | 192 | with st.container(): 193 | col1, col2, col3 = st.columns(3) 194 | with col1: 195 | audio = st.file_uploader("Audio", type="wav") 196 | if audio is not None: 197 | st.session_state.waveform, _ = load_waveform( 198 | audio, sample_rate=SAMPLE_RATE, normalize=True, mono=True 199 | ) 200 | 201 | with col2: 202 | tg_data = st.file_uploader("TextGrid", type="TextGrid") 203 | if tg_data is not None: 204 | load_textgrid(tg_data) 205 | 206 | with col3: 207 | vad_list_data = st.file_uploader("VA List", type="json") 208 | if vad_list_data is not None: 209 | load_vad_list(vad_list_data) 210 | 211 | st.audio(audio) 212 | with st.container(): 213 | c1, c2, c3, c4 = st.columns(4) 214 | with c1: 215 | st.button(label="Run model", on_click=run_model) 216 | with c2: 217 | st.button(label="Run Sample", on_click=sample) 218 | with c3: 219 | st.button(label="Clear", on_click=clear) 220 | with c4: 221 | st.download_button( 222 | label="Download", 223 | data=json.dumps(st.session_state.output_data), 224 | file_name="vap_output.json", 225 | mime="application/json", 226 | ) 227 | 228 | if "fig" in st.session_state and st.session_state.fig is not None: 229 | st.pyplot(st.session_state.fig) 230 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | markers = 3 | encoder: pretrained encoder 4 | cpc: 5 | main: training loop 6 | ar: autoregressive 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # lightning 2 | pytorch-lightning 3 | numpy 4 | scipy 5 | scikit-learn 6 | einops 7 | matplotlib 8 | wandb 9 | omegaconf 10 | hydra-core 11 | pytest 12 | datasets 13 | 14 | # CPC: does not seem to work directly. One must install cython 'pip install cython' 15 | # prior to 'pip install -r requirements.txt' 16 | cython 17 | soundfile 18 | git+https://github.com/facebookresearch/CPC_audio.git 19 | 20 | 21 | # Our Dependencies 22 | # must download and install via 23 | # 'pip install -r requirements.txt' 24 | # 'pip install -e .' 
25 | # in the respective repos 26 | # git+https://github.com/ErikEkstedt/vap_turn_taking.git 27 | # git+https://github.com/ErikEkstedt/datasets_turntaking.git 28 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from conv_ssl.model import VPModel 3 | from conv_ssl.utils import ( 4 | everything_deterministic, 5 | get_tg_vad_list, 6 | read_json, 7 | write_json, 8 | ) 9 | from conv_ssl.evaluation.duration import read_text_grid 10 | import torch 11 | everything_deterministic() 12 | 13 | 14 | def get_args(): 15 | parser = ArgumentParser() 16 | parser.add_argument( 17 | "-c", "--checkpoint", type=str, default="example/cpc_48_50hz_15gqq5s5.ckpt" 18 | ) 19 | parser.add_argument( 20 | "-w", 21 | "--wav", 22 | type=str, 23 | default="example/student_long_female_en-US-Wavenet-G.wav", 24 | ) 25 | parser.add_argument( 26 | "-tg", 27 | "--text_grid", 28 | type=str, 29 | default="example/student_long_female_en-US-Wavenet-G.TextGrid", 30 | ) 31 | parser.add_argument( 32 | "-v", 33 | "--voice_activity", 34 | type=str, 35 | default=None, # default="example/student_long_female_en-US-Wavenet-G.json", 36 | ) 37 | parser.add_argument( 38 | "-o", 39 | "--output", 40 | type=str, 41 | default="vap_output.json", 42 | ) 43 | args = parser.parse_args() 44 | 45 | assert ( 46 | args.voice_activity is not None or args.text_grid is not None 47 | ), "Must provide --voice_activity or --text_grid" 48 | return args 49 | 50 | 51 | def serialize_sample(sample): 52 | return {"vad": sample["vad"].tolist(), "waveform": sample["waveform"].tolist()} 53 | 54 | 55 | if __name__ == "__main__": 56 | 57 | args = get_args() 58 | 59 | tg = None 60 | if args.voice_activity is not None: 61 | vad_list = read_json(args.voice_activity) 62 | else: 63 | tg = read_text_grid(args.text_grid) 64 | vad_list = get_tg_vad_list(tg) 65 | 66 | print("Load Model: ", args.checkpoint) 67 | model = VPModel.load_from_checkpoint(args.checkpoint) 68 | model = model.eval() 69 | if torch.cuda.is_available(): _ = model.to("cuda")  # GPU is optional 70 | 71 | print("Wavfile: ", args.wav) 72 | print("VA-list: ", args.voice_activity) 73 | print("TextGrid: ", args.text_grid) 74 | 75 | # get sample and process 76 | sample = model.load_sample(args.wav, vad_list) 77 | loss, out, probs, sample = model.output(sample) 78 | 79 | # Save 80 | data = { 81 | "loss": {"vp": loss["vp"].item(), "frames": loss["frames"].tolist()}, 82 | "probs": out["logits_vp"].softmax(-1).tolist(), 83 | "labels": out["va_labels"].tolist(), 84 | "p": probs["p"].tolist(), 85 | "p_bc": probs["bc_prediction"].tolist(), 86 | "model": { 87 | "sample_rate": model.sample_rate, 88 | "frame_hz": model.frame_hz, 89 | "checkpoint": args.checkpoint, 90 | }, 91 | "va": vad_list, 92 | } 93 | 94 | if tg is not None: 95 | data["words"] = tg["words"] 96 | data["phones"] = tg["phones"] 97 | 98 | write_json(data, args.output) 99 | print("Wrote output -> ", args.output) 100 | -------------------------------------------------------------------------------- /scripts/README.md: -------------------------------------------------------------------------------- 1 | # Scripts 2 | 3 | 4 | These scripts were used during training and for various debugging purposes. They are not intended for general use; they mainly serve as a record of experiments and may not work out of the box. 
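Most of them share the same pattern: a base command is stored in a shell variable and then extended with run-specific flags or overrides. A minimal sketch of that pattern, taken from `train_script.bash` (the config paths may be outdated):

```bash
# Base command reused across runs
train="python conv_ssl/train_disk.py --gpus -1 --batch_size 50 --patience 50 --val_check_interval 0.5"

# Extend it with run-specific arguments
$train --conf conv_ssl/config/model.yaml
```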
5 | -------------------------------------------------------------------------------- /scripts/model_kfold.bash: -------------------------------------------------------------------------------- 1 | train="python conv_ssl/train_disk.py --gpus -1 --batch_size 50 --patience 20 --val_check_interval 0.5" 2 | 3 | 4 | # Used in the paper to evaluate models over kfold-splits 5 | 6 | for i in 0 1 2 3 4 5 6 7 8 9 10 7 | do 8 | $train --conf conv_ssl/config/model.yaml \ 9 | --train_files conv_ssl/config/swb_kfolds/${i}_fold_train.txt \ 10 | --val_files conv_ssl/config/swb_kfolds/${i}_fold_val.txt \ 11 | --seed $i \ 12 | --name_info _kfold_${i} 13 | done 14 | 15 | for i in 0 1 2 3 4 5 6 7 8 9 10 16 | do 17 | $train --conf conv_ssl/config/model_independent.yaml \ 18 | --train_files conv_ssl/config/swb_kfolds/${i}_fold_train.txt \ 19 | --val_files conv_ssl/config/swb_kfolds/${i}_fold_val.txt \ 20 | --seed $i \ 21 | --name_info _kfold_${i} 22 | done 23 | 24 | for i in 0 1 2 3 4 5 6 7 8 9 10 25 | do 26 | $train --conf conv_ssl/config/model_independent_baseline.yaml \ 27 | --train_files conv_ssl/config/swb_kfolds/${i}_fold_train.txt \ 28 | --val_files conv_ssl/config/swb_kfolds/${i}_fold_val.txt \ 29 | --seed $i \ 30 | --name_info _kfold_${i} 31 | done 32 | 33 | for i in 0 1 2 3 4 5 6 7 8 9 10 34 | do 35 | $train --conf conv_ssl/config/model_comparative.yaml \ 36 | --train_files conv_ssl/config/swb_kfolds/${i}_fold_train.txt \ 37 | --val_files conv_ssl/config/swb_kfolds/${i}_fold_val.txt \ 38 | --seed $i \ 39 | --name_info _kfold_${i} 40 | done 41 | 42 | # Equal bin-times for discrete 43 | # for i in 0 1 2 3 4 5 6 7 8 9 10 44 | # do 45 | # $train --conf conv_ssl/config/model_equal.yaml \ 46 | # --train_files conv_ssl/config/swb_kfolds/${i}_fold_train.txt \ 47 | # --val_files conv_ssl/config/swb_kfolds/${i}_fold_val.txt \ 48 | # --seed $i \ 49 | # --name_info _kfold_${i} 50 | # done 51 | 52 | # for i in 0 1 2 3 4 5 6 7 8 9 10 53 | # do 54 | # $train --conf conv_ssl/config/model_latent.yaml \ 55 | # --train_files conv_ssl/config/swb_kfolds/${i}_fold_train.txt \ 56 | # --val_files conv_ssl/config/swb_kfolds/${i}_fold_val.txt \ 57 | # --seed $i \ 58 | # --name_info _kfold_${i} 59 | # done 60 | -------------------------------------------------------------------------------- /scripts/pitch_eval.bash: -------------------------------------------------------------------------------- 1 | aug_eval="python conv_ssl/evaluation/evaluation_augmentation.py data.batch_size=4 data.num_workers=4" 2 | swb="data.datasets=['switchboard']" 3 | fisher="data.datasets=['fisher']" 4 | both="data.datasets=['fisher', 'switchboard']" 5 | paperb="/home/erik/projects/CCConv/conv_ssl/assets/PaperB" 6 | 7 | 8 | # Paths 9 | chpath="+checkpoint_path=/home/erik/projects/CCConv/conv_ssl/assets/PaperB/checkpoints/" 10 | thpath="+threshold_path=/home/erik/projects/CCConv/conv_ssl/assets/PaperB/eval/" 11 | 12 | # Model Checkpoints 13 | cpc20=$chpath"cpc_48_20hz_2ucueis8.ckpt" 14 | cpc50=$chpath"cpc_48_50hz_15gqq5s5.ckpt" 15 | cpc100=$chpath"cpc_48_100hz_3mkvq5fk.ckpt" 16 | 17 | # Model Thresholds 18 | # th20_swb 19 | # th20_fisher 20 | # th20_both 21 | 22 | 23 | sp_root="+savepath=PaperB/" 24 | ch_root="+checkpoint_path="$paperb"/checkpoints/" 25 | th_root="+threshold_path="$paperb"/eval/" 26 | # cpc_44_20hz_unfreeze/thresholds.json 27 | 28 | 29 | # savepaths 30 | sp_root="+savepath=PaperB/" 31 | $aug_eval $fisher $ch_root$cpc20u 32 | 33 | 34 | $aug_eval $swb $cpc20 $th20 +augmentation='low_pass' +cutoff_freq=250 35 | $aug_eval $swb $cpc50 $th50 
+augmentation='low_pass' +cutoff_freq=250 36 | $aug_eval $swb $cpc100 $th100 +augmentation='low_pass' +cutoff_freq=250 37 | -------------------------------------------------------------------------------- /scripts/test_augmentation.bash: -------------------------------------------------------------------------------- 1 | aug_eval="python conv_ssl/evaluation/evaluation_augmentation.py data.batch_size=4 data.num_workers=4" 2 | swb="data.datasets=['switchboard']" 3 | fisher="data.datasets=['fisher']" 4 | both="data.datasets=['fisher','switchboard']" 5 | 6 | paperb="/home/erik/projects/CCConv/conv_ssl/assets/PaperB" 7 | chpath="+checkpoint_path=/home/erik/projects/CCConv/conv_ssl/assets/PaperB/checkpoints/" 8 | 9 | cpc20=$chpath"cpc_48_20hz_2ucueis8.ckpt" 10 | cpc50=$chpath"cpc_48_50hz_15gqq5s5.ckpt" 11 | cpc100=$chpath"cpc_48_100hz_3mkvq5fk.ckpt" 12 | 13 | # DEBUG 14 | # $aug_eval $swb $cpc50 +augmentation='flat_f0' +max_batches=10 15 | # $aug_eval $swb $cpc50 +augmentation='shift_f0' +max_batches=10 16 | # $aug_eval $swb $cpc50 +augmentation='flat_intensity' +max_batches=10 17 | # $aug_eval $swb $cpc50 +augmentation='only_f0' +max_batches=10 18 | 19 | 20 | 21 | # both last 22 | $aug_eval $both $cpc20 +augmentation='flat_f0' 23 | $aug_eval $both $cpc20 +augmentation='shift_f0' 24 | $aug_eval $both $cpc20 +augmentation='only_f0' 25 | $aug_eval $both $cpc20 +augmentation='flat_intensity' 26 | 27 | $aug_eval $both $cpc50 +augmentation='flat_f0' 28 | $aug_eval $both $cpc50 +augmentation='shift_f0' 29 | $aug_eval $both $cpc50 +augmentation='only_f0' 30 | $aug_eval $both $cpc50 +augmentation='flat_intensity' 31 | 32 | $aug_eval $both $cpc100 +augmentation='flat_f0' 33 | $aug_eval $both $cpc100 +augmentation='shift_f0' 34 | $aug_eval $both $cpc100 +augmentation='only_f0' 35 | $aug_eval $both $cpc100 +augmentation='flat_intensity' 36 | 37 | # # Switchboard 38 | # $aug_eval $swb $cpc20 +augmentation='flat_f0' 39 | # $aug_eval $swb $cpc20 +augmentation='shift_f0' 40 | # $aug_eval $swb $cpc20 +augmentation='flat_intensity' 41 | # $aug_eval $swb $cpc20 +augmentation='only_f0' 42 | # $aug_eval $swb $cpc50 +augmentation='flat_f0' 43 | # $aug_eval $swb $cpc50 +augmentation='shift_f0' 44 | # $aug_eval $swb $cpc50 +augmentation='flat_intensity' 45 | # $aug_eval $swb $cpc50 +augmentation='only_f0' 46 | # $aug_eval $swb $cpc100 +augmentation='flat_f0' 47 | # $aug_eval $swb $cpc100 +augmentation='shift_f0' 48 | # $aug_eval $swb $cpc100 +augmentation='flat_intensity' 49 | # $aug_eval $swb $cpc100 +augmentation='only_f0' 50 | 51 | 52 | # # FISHER 53 | # $aug_eval $fisher $cpc20 +augmentation='flat_f0' 54 | # $aug_eval $fisher $cpc20 +augmentation='shift_f0' 55 | # $aug_eval $fisher $cpc20 +augmentation='flat_intensity' 56 | # $aug_eval $fisher $cpc20 +augmentation='only_f0' 57 | # $aug_eval $fisher $cpc50 +augmentation='flat_f0' 58 | # $aug_eval $fisher $cpc50 +augmentation='shift_f0' 59 | # $aug_eval $fisher $cpc50 +augmentation='flat_intensity' 60 | # $aug_eval $fisher $cpc50 +augmentation='only_f0' 61 | # $aug_eval $fisher $cpc100 +augmentation='flat_f0' 62 | # $aug_eval $fisher $cpc100 +augmentation='shift_f0' 63 | # $aug_eval $fisher $cpc100 +augmentation='flat_intensity' 64 | # $aug_eval $fisher $cpc100 +augmentation='only_f0' 65 | -------------------------------------------------------------------------------- /scripts/test_future.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | import torch 3 | 4 | import matplotlib.pyplot as plt 5 | 6 | 
from conv_ssl.evaluation.evaluation_phrases import load_model_dset 7 | from conv_ssl.utils import everything_deterministic 8 | 9 | everything_deterministic() 10 | 11 | 12 | if __name__ == "__main__": 13 | 14 | ch_root = "assets/PaperB/checkpoints" 15 | checkpoint = join(ch_root, "cpc_48_50hz_15gqq5s5.ckpt") 16 | model, dset = load_model_dset(checkpoint) 17 | 18 | example = "student" 19 | d = dset.get_sample(example, "long", "female", 0) 20 | loss, out, probs, batch = model.output(d) 21 | 22 | # cutoff 23 | scp = [a for a in d["words"] if a[-1] == example][0] 24 | end_time = scp[1] 25 | t = d["waveform"].shape[-1] / model.sample_rate # seconds 26 | n_samples = int(end_time * model.sample_rate) 27 | n_frames = int(end_time * model.frame_hz) 28 | tsil = 1 29 | n_sil_samples = int(tsil * model.sample_rate) 30 | n_sil_frames = int(tsil * model.frame_hz) 31 | # Voice Activity History 32 | vah = d["vad_history"][:, :n_frames] 33 | vah_sil = vah[:, -1].repeat(1, n_sil_frames, 1) 34 | vah = torch.cat((vah, vah_sil), dim=1) 35 | # Voice Activity 36 | va = d["vad"][:, :n_frames] # , torch.zeros(1, n_sil_frames, 2)), dim=1) 37 | va_sil = torch.zeros(1, n_sil_frames, 2) 38 | va = torch.cat((va, va_sil), dim=1) 39 | va_horizon = torch.zeros(1, model.horizon_frames, 2) 40 | va = torch.cat((va, va_horizon), dim=1) 41 | 42 | d_short = { 43 | "waveform": torch.cat( 44 | (d["waveform"][:, :n_samples], torch.zeros(1, n_sil_samples)), dim=-1 45 | ), 46 | "vad": va, 47 | "vad_history": vah, 48 | } 49 | 50 | sloss, sout, sprobs, sbatch = model.output(d_short) 51 | 52 | # Compare 53 | n = sprobs["p"].shape[1] 54 | sp = sprobs["p"][0, :, 1] 55 | p = probs["p"][0, :n, 1] 56 | 57 | fig, ax = plt.subplots(4, 1) 58 | ax[0].plot(p, label="original", color="b") 59 | ax[1].plot(sp, label="cutoff", color="r") 60 | ax[2].plot(p, label="original", color="b") 61 | ax[2].plot(sp, label="cutoff", color="r") 62 | ax[3].plot(sp - p, label="diff") 63 | ax[3].set_ylim([-0.1, 0.1]) 64 | for a in ax: 65 | a.legend() 66 | a.vlines(n_frames - 1, ymin=0, ymax=1) 67 | plt.show() 68 | -------------------------------------------------------------------------------- /scripts/test_models_script.bash: -------------------------------------------------------------------------------- 1 | train="python conv_ssl/train.py --gpus -1 --batch_size 10 --fast_dev_run 1" 2 | 3 | $train --conf conv_ssl/config/model.yaml 4 | $train --conf conv_ssl/config/model_independent.yaml 5 | $train --conf conv_ssl/config/model_independent_baseline.yaml 6 | $train --conf conv_ssl/config/model_comparative.yaml 7 | -------------------------------------------------------------------------------- /scripts/test_phrases.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | phrase="python conv_ssl/evaluation/evaluation_phrases.py" 3 | 4 | chpath="--checkpoint=/home/erik/projects/CCConv/conv_ssl/assets/PaperB/checkpoints" 5 | cpc20=$chpath"/cpc_48_20hz_2ucueis8.ckpt" 6 | cpc50=$chpath"/cpc_48_50hz_15gqq5s5.ckpt" 7 | cpc100=$chpath"/cpc_48_100hz_3mkvq5fk.ckpt" 8 | 9 | $phrase $cpc20 10 | $phrase $cpc50 11 | $phrase $cpc100 12 | -------------------------------------------------------------------------------- /scripts/test_regular.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Evaluation of models. 
4 | # Using PR-curves and scores to find the best thresholds on validation set 5 | # Use thresholds to get actual TEST-score 6 | 7 | test_="python conv_ssl/evaluation/evaluation.py data.batch_size=10 data.num_workers=4" 8 | swb="+data.datasets=['switchboard']" 9 | fisher="+data.datasets=['fisher']" 10 | both="+data.datasets=['fisher','switchboard']" 11 | chpath="+checkpoint_path=/home/erik/projects/CCConv/conv_ssl/assets/PaperB/checkpoints/" 12 | 13 | # regular 14 | cpc20=$chpath"cpc_48_20hz_2ucueis8.ckpt" 15 | cpc50=$chpath"cpc_48_50hz_15gqq5s5.ckpt" 16 | cpc100=$chpath"cpc_48_100hz_3mkvq5fk.ckpt" 17 | 18 | # 20 Hz 19 | # $test_ $swb $cpc20 20 | # $test_ $fisher $cpc20 21 | # $test_ $both $cpc20 22 | 23 | # 50 Hz 24 | # $test_ $swb $cpc50 25 | # $test_ $fisher $cpc50 26 | # $test_ $both $cpc50 27 | 28 | # 100 Hz 29 | $test_ $swb $cpc100 30 | $test_ $fisher $cpc100 31 | $test_ $both $cpc100 32 | -------------------------------------------------------------------------------- /scripts/train_script.bash: -------------------------------------------------------------------------------- 1 | train="python conv_ssl/train_disk.py --gpus -1 --batch_size 50 --patience 50 --val_check_interval 0.5" 2 | 3 | $train --conf conv_ssl/config/model.yaml 4 | $train --conf conv_ssl/config/model_independent.yaml 5 | $train --conf conv_ssl/config/model_independent_baseline.yaml 6 | $train --conf conv_ssl/config/model_comparative.yaml 7 | -------------------------------------------------------------------------------- /scripts/train_vap_new.bash: -------------------------------------------------------------------------------- 1 | train="python conv_ssl/train_hydra.py +trainer.val_check_interval=0.5" 2 | data="data.datasets=['switchboard','fisher'] data.num_workers=24 data.batch_size=25" 3 | dev="+trainer.limit_train_batches=10 +trainer.limit_val_batches=10" 4 | dur20="data.audio_duration=20" 5 | unfreeze10="optimizer.train_encoder_epoch=10" 6 | unfreeze5="optimizer.train_encoder_epoch=5" 7 | no_hist="data.vad_history=False" 8 | 9 | ################################# 10 | # Train without History 11 | ################################# 12 | # $train $data $dur20 $no_hist $unfreeze model=discrete_50hz model.ar.num_heads=8 13 | # $train $data $dur20 $no_hist $unfreeze model=discrete_20hz model.ar.num_heads=8 14 | 15 | ################################# 16 | # Train Encoder after 10 epochs 17 | ################################# 18 | # $train $data $dur20 $unfreeze model=discrete_50hz model.ar.num_heads=8 19 | # $train $data $dur20 $unfreeze model=discrete_20hz model.ar.num_heads=8 20 | 21 | 22 | $train $data model=discrete model.ar.num_heads=8 23 | 24 | 25 | $train $data $dur20 $unfreeze5 model=discrete_20hz model.ar.num_heads=8 26 | $train $data $dur20 $unfreeze5 model=discrete_50hz model.ar.num_heads=8 27 | $train $data $unfreeze5 model=discrete model.ar.num_heads=8 28 | 29 | # $train $data model=discrete_20hz $dev 30 | # $train $data $unfreeze model=discrete 31 | # $train $data $no_hist $unfreeze model=discrete 32 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup 4 | 5 | setup( 6 | name="conv_ssl", 7 | version="0.0.0", 8 | description="Conversational Self-Supervised Learning", 9 | author="erikekst", 10 | author_email="erikekst@kth.se", 11 | url="https://github.com/ErikEkstedt/conv_ssl", 12 | packages=["conv_ssl", "conv_ssl.models", 
"conv_ssl.evaluation"], 13 | ) 14 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | TEST_ROOT = os.path.realpath(os.path.dirname(__file__)) 4 | PACKAGE_ROOT = os.path.dirname(TEST_ROOT) 5 | ROOT_SEED = 1234 6 | -------------------------------------------------------------------------------- /tests/test_ar.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from conv_ssl.models import AR 4 | from conv_ssl.utils import load_hydra_conf 5 | 6 | 7 | @pytest.mark.ar 8 | @pytest.mark.parametrize( 9 | "config_name", ["model/discrete", "model/discrete_20hz", "model/discrete_50hz"] 10 | ) 11 | def test_autoregressive(config_name): 12 | 13 | D = 256 14 | conf = load_hydra_conf(config_name=config_name)["model"] 15 | model = AR( 16 | input_dim=D, 17 | dim=conf["ar"]["dim"], 18 | num_layers=conf["ar"]["num_layers"], 19 | dropout=conf["ar"]["dropout"], 20 | ar=conf["ar"]["type"], 21 | transfomer_kwargs=dict( 22 | num_heads=conf["ar"]["num_heads"], 23 | dff_k=conf["ar"]["dff_k"], 24 | use_pos_emb=conf["ar"]["use_pos_emb"], 25 | max_context=conf["ar"].get("max_context", None), 26 | abspos=conf["ar"].get("abspos", None), 27 | sizeSeq=conf["ar"].get("sizeSeq", None), 28 | ), 29 | ) 30 | 31 | in_frames = 100 32 | if config_name.endswith("20hz"): 33 | in_frames = 20 34 | elif config_name.endswith("50hz"): 35 | in_frames = 50 36 | 37 | # extract the representation of last layer 38 | x = torch.randn(1, in_frames, D) 39 | 40 | if torch.cuda.is_available(): 41 | x = x.to("cuda") 42 | model.to("cuda") 43 | 44 | z = model(x)["z"] 45 | assert tuple(z.shape) == (1, in_frames, D), "shape mismatch" 46 | -------------------------------------------------------------------------------- /tests/test_cpc_causality.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import torch 4 | from conv_ssl.model import VPModel 5 | from conv_ssl.utils import everything_deterministic, to_device, load_hydra_conf 6 | 7 | everything_deterministic() 8 | 9 | BATCH_SIZE = 2 10 | DURATION = 10 11 | # Offset in samples. 12 | # Not strictly safe/deterministic. 
13 | # We get gradients up to 312/16000 = 0.0195 seconds -> 20ms into the future 14 | PAD = 312 15 | 16 | 17 | def get_sample_batch(sample_rate, frame_hz, frame_horizon): 18 | n_sample = int(sample_rate * DURATION) 19 | n_frames = int(DURATION * frame_hz) 20 | n_frames_horizon = n_frames + frame_horizon 21 | waveform = torch.randn(BATCH_SIZE, n_sample) 22 | vad = torch.randint(0, 2, (BATCH_SIZE, n_frames_horizon, 2), dtype=torch.float) 23 | vad_history = torch.rand((BATCH_SIZE, n_frames, 5)) 24 | return {"waveform": waveform, "vad": vad, "vad_history": vad_history} 25 | 26 | 27 | @pytest.mark.causality 28 | @pytest.mark.parametrize( 29 | ("output_layer", "config_name"), 30 | [ 31 | (0, "model/discrete"), 32 | (1, "model/discrete"), 33 | (0, "model/discrete_20hz"), 34 | (1, "model/discrete_20hz"), 35 | (0, "model/discrete_50hz"), 36 | (1, "model/discrete_50hz"), 37 | ], 38 | ) 39 | def test_causality(output_layer, config_name): 40 | 41 | conf = load_hydra_conf() 42 | conf["model"] = load_hydra_conf(config_name=config_name)["model"] 43 | conf["model"]["encoder"]["output_layer"] = output_layer 44 | model = VPModel(conf) 45 | 46 | if torch.cuda.is_available(): 47 | model.to("cuda") 48 | 49 | sample_rate = conf["model"]["encoder"]["sample_rate"] 50 | frame_hz = conf["model"]["encoder"]["frame_hz"] 51 | frame_horizon = model.VAP.horizon_frames 52 | 53 | batch = get_sample_batch( 54 | sample_rate=sample_rate, frame_hz=frame_hz, frame_horizon=frame_horizon 55 | ) 56 | for k, v in batch.items(): 57 | if isinstance(v, torch.Tensor): 58 | print(f"{k}: {tuple(v.shape)}") 59 | else: 60 | print(f"{k}: {v}") 61 | n_frames = batch["vad_history"].shape[1] 62 | half_frames = n_frames // 2 63 | half_samples = batch["waveform"].shape[-1] // 2 64 | 65 | batch = to_device(batch, model.device) 66 | ############################################## 67 | batch["waveform"].requires_grad = True 68 | loss, _, batch = model.shared_step(batch, reduction="none") 69 | l = loss["frames"] 70 | # backward 71 | l[:, half_frames].sum().backward() 72 | g = batch["waveform"].grad.abs() 73 | # g[:, half_samples + PAD :].sum() 74 | 75 | assert ( 76 | g[:, half_samples + PAD :].sum() == 0 77 | ), f"Non-Zero gradient after STEP. 
EncOut: {output_layer}, config: {config_name}" 78 | -------------------------------------------------------------------------------- /tests/test_encoder.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import torch 4 | from conv_ssl.models import Encoder 5 | from conv_ssl.utils import load_hydra_conf 6 | 7 | 8 | @pytest.mark.encoder 9 | @pytest.mark.parametrize("config_name", ["model/discrete", "model/discrete_20hz"]) 10 | def test_cpc_encoder(config_name): 11 | conf = load_hydra_conf(config_name=config_name) 12 | enc_conf = conf["model"]["encoder"] 13 | model = Encoder(enc_conf) 14 | 15 | # extract the representation of last layer 16 | wav_input_16khz = torch.randn(1, enc_conf["sample_rate"]) 17 | 18 | if torch.cuda.is_available(): 19 | wav_input_16khz = wav_input_16khz.to("cuda") 20 | model.to("cuda") 21 | 22 | z = model.encode(wav_input_16khz) 23 | assert tuple(z.shape) == (1, 100, 256), "shape mismatch" 24 | -------------------------------------------------------------------------------- /tests/test_main.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import pytorch_lightning as pl 3 | 4 | from conv_ssl.model import VPModel 5 | from conv_ssl.utils import everything_deterministic, load_hydra_conf 6 | from datasets_turntaking import DialogAudioDM 7 | 8 | everything_deterministic() 9 | 10 | 11 | @pytest.mark.main 12 | @pytest.mark.parametrize( 13 | "config_name", 14 | [ 15 | "model/discrete", 16 | "model/discrete_20hz", 17 | "model/discrete_50hz", 18 | ], 19 | ) 20 | def test_cpc_train(config_name): 21 | conf = load_hydra_conf() 22 | conf["model"] = load_hydra_conf(config_name=config_name)["model"] 23 | model = VPModel(conf) 24 | 25 | conf["data"]["num_workers"] = 0 26 | conf["data"]["batch_size"] = 4 27 | dm = DialogAudioDM(**conf["data"]) 28 | dm.prepare_data() 29 | 30 | trainer = pl.Trainer( 31 | gpus=-1, fast_dev_run=1, strategy="ddp", deterministic=True, log_every_n_steps=1 32 | ) 33 | trainer.fit(model, datamodule=dm) 34 | -------------------------------------------------------------------------------- /visualize_run.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import torch 3 | import matplotlib.pyplot as plt 4 | 5 | from conv_ssl.evaluation.evaluation_phrases import plot_sample 6 | from conv_ssl.utils import read_json, load_waveform 7 | 8 | 9 | def get_args(): 10 | parser = ArgumentParser() 11 | parser.add_argument( 12 | "-w", 13 | "--wav", 14 | type=str, 15 | default="example/student_long_female_en-US-Wavenet-G.wav", 16 | ) 17 | parser.add_argument( 18 | "-f", 19 | "--file", 20 | type=str, 21 | default="vap_output.json", 22 | ) 23 | args = parser.parse_args() 24 | return args 25 | 26 | 27 | if __name__ == "__main__": 28 | args = get_args() 29 | 30 | data = read_json(args.file) 31 | 32 | sample = { 33 | "waveform": load_waveform( 34 | args.wav, 35 | sample_rate=data["model"]["sample_rate"], 36 | normalize=True, 37 | mono=True, 38 | )[0], 39 | "vad": data["va"], 40 | } 41 | 42 | if "phones" in data: 43 | sample["phones"] = data["phones"] 44 | 45 | if "words" in data: 46 | sample["words"] = data["words"] 47 | 48 | prob_next_speaker = torch.tensor(data["p"])[0, :, 0] 49 | 50 | fig, ax = plot_sample( 51 | prob_next_speaker, 52 | sample=sample, 53 | sample_rate=data["model"]["sample_rate"], 54 | frame_hz=data["model"]["frame_hz"], 55 | ) 56 | plt.show() 57 | 
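# Example end-to-end usage (all defaults match run.py):
#   python run.py            # writes vap_output.json
#   python visualize_run.py  # reads vap_output.json and plots the result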
--------------------------------------------------------------------------------