├── requirements.txt ├── .gitignore ├── crybrain_config_utils.py ├── README.md ├── .pre-commit-config.yaml ├── hparams └── ecapa_voxceleb_basic.yaml ├── crybrain.py ├── LICENSE ├── train.ipynb └── evaluate.ipynb /requirements.txt: -------------------------------------------------------------------------------- 1 | speechbrain 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | audio 2 | .ipynb_checkpoints 3 | concatenated_audio_train 4 | *.zip 5 | *.json 6 | *.csv 7 | __pycache__/ 8 | experiments/ 9 | *.sw* 10 | -------------------------------------------------------------------------------- /crybrain_config_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import numpy as np 5 | import torch 6 | from speechbrain.nnet.schedulers import CyclicLRScheduler, ReduceLROnPlateau 7 | 8 | 9 | def choose_lrsched(lrsched_name, **kwargs): 10 | print(f"lrsched_name: {lrsched_name}") 11 | if lrsched_name == "onplateau": 12 | return ReduceLROnPlateau( 13 | lr_min=kwargs["lr_min"], 14 | factor=kwargs["factor"], 15 | patience=kwargs["patience"], 16 | dont_halve_until_epoch=kwargs["dont_halve_until_epoch"], 17 | ) 18 | elif lrsched_name == "cyclic": 19 | return CyclicLRScheduler( 20 | base_lr=kwargs["base_lr"], 21 | max_lr=kwargs["max_lr"], 22 | step_size=kwargs["step_size"], 23 | mode=kwargs["mode"], 24 | gamma=kwargs["gamma"], 25 | scale_fn=kwargs["scale_fn"], 26 | scale_mode=kwargs["scale_mode"], 27 | ) 28 | 29 | 30 | def set_seed(seed): 31 | """Set seed in every way possible.""" 32 | print(f"setting seeds to {seed}") 33 | random.seed(seed) 34 | np.random.seed(seed) 35 | torch.manual_seed(seed) 36 | os.environ["PYTHONHASHSEED"] = str(seed) 37 | if torch.cuda.is_available(): 38 | print(f"setting cuda seeds to {seed}") 39 | torch.cuda.manual_seed(seed) 40 | 
torch.cuda.manual_seed_all(seed) 41 | 42 | 43 | def test_cuda_seed(): 44 | """Print some random results from various libraries.""" 45 | print(f"python random float: {random.random()}") 46 | print(f"numpy random int: {np.random.randint(100)}") 47 | print(f"torch random tensor (cpu): {torch.FloatTensor(100).uniform_()}") 48 | print(f"torch random tensor (cuda): {torch.cuda.FloatTensor(100).uniform_()}") 49 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## CryCeleb2023 SpeechBrain fine-tuning 2 | 3 | All you need to start fine-tuning [SpeechBrain](https://speechbrain.readthedocs.io/) models using the [Ubenwa CryCeleb dataset](https://huggingface.co/datasets/Ubenwa/CryCeleb2023)! 4 | 5 | This code was used for [CryCeleb2023 HuggingFace challenge](https://huggingface.co/spaces/competitions/CryCeleb2023) 6 | 7 | It reproduces the [official baseline](https://huggingface.co/Ubenwa/ecapa-voxceleb-ft2-cryceleb) model training 8 | 9 | `train.ipynb` 10 | Open In Colab 11 | - main code for data preparation and fine-tuning with various configs 12 | 13 | `evaluate.ipynb` 14 | Open In Colab 15 | - example of scoring with pre-trained and fine-tuned model 16 | 17 | Note that default configurations are optimized for speed and simplicity rather than accuracy 18 | 19 | ## Cite 20 | 21 | ### CryCeleb (accepted to ICASSP 2024!) 22 | 23 | ```bibtex 24 | @article{ubenwa2023cryceleb, 25 | title={CryCeleb: A Speaker Verification Dataset Based on Infant Cry Sounds}, 26 | author={David Budaghyan and Charles C. 
Onu and Arsenii Gorin and Cem Subakan and Doina Precup}, 27 | year={2023}, 28 | journal={preprint arXiv:2305.00969}, 29 | } 30 | ``` 31 | 32 | ### SpeechBrain 33 | 34 | ```bibtex 35 | @misc{speechbrain, 36 | title={{SpeechBrain}: A General-Purpose Speech Toolkit}, 37 | author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio}, 38 | year={2021}, 39 | eprint={2106.04624}, 40 | archivePrefix={arXiv}, 41 | primaryClass={eess.AS}, 42 | note={arXiv:2106.04624} 43 | } 44 | ``` 45 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v4.3.0 7 | hooks: 8 | # list of supported hooks: https://pre-commit.com/hooks.html 9 | - id: trailing-whitespace 10 | - id: end-of-file-fixer 11 | - id: check-docstring-first 12 | # - id: check-yaml 13 | - id: debug-statements 14 | - id: detect-private-key 15 | - id: check-executables-have-shebangs 16 | - id: check-toml 17 | - id: check-case-conflict 18 | - id: check-added-large-files 19 | 20 | # python code formatting 21 | - repo: https://github.com/psf/black 22 | rev: 22.6.0 23 | hooks: 24 | - id: black 25 | args: [--line-length, "120"] 26 | 27 | # python import sorting 28 | - repo: https://github.com/PyCQA/isort 29 | rev: 5.12.0 30 | hooks: 31 | - id: isort 32 | args: ["--profile", "black", "--filter-files"] 33 | 34 | # python upgrading syntax to newer version 35 | - repo: https://github.com/asottile/pyupgrade 36 | rev: v2.32.1 37 | hooks: 38 | - id: 
pyupgrade 39 | args: [--py38-plus] 40 | 41 | # python docstring formatting 42 | - repo: https://github.com/myint/docformatter 43 | rev: v1.4 44 | hooks: 45 | - id: docformatter 46 | args: [--in-place, --wrap-summaries=99, --wrap-descriptions=99] 47 | 48 | # python check (PEP8), programming errors and code complexity 49 | - repo: https://github.com/PyCQA/flake8 50 | rev: 4.0.1 51 | hooks: 52 | - id: flake8 53 | args: 54 | [ 55 | "--ignore", 56 | "E501,F401,F841,W504,E203,W503", 57 | "--exclude", 58 | "logs/*,data/*", 59 | ] 60 | 61 | # yaml formatting 62 | - repo: https://github.com/pre-commit/mirrors-prettier 63 | rev: v2.7.1 64 | hooks: 65 | - id: prettier 66 | types: [yaml] 67 | 68 | # jupyter notebook cell output clearing 69 | - repo: https://github.com/kynan/nbstripout 70 | rev: 0.5.0 71 | hooks: 72 | - id: nbstripout 73 | 74 | # jupyter notebook linting 75 | - repo: https://github.com/nbQA-dev/nbQA 76 | rev: 1.4.0 77 | hooks: 78 | - id: nbqa-black 79 | args: ["--line-length=99"] 80 | - id: nbqa-isort 81 | args: ["--profile=black"] 82 | - id: nbqa-flake8 83 | args: 84 | [ 85 | "--extend-ignore=E203,E402,E501,F401,F841", 86 | "--exclude=logs/*,data/*", 87 | ] 88 | 89 | # md formatting 90 | - repo: https://github.com/executablebooks/mdformat 91 | rev: 0.7.14 92 | hooks: 93 | - id: mdformat 94 | args: ["--number"] 95 | additional_dependencies: 96 | - mdformat-gfm 97 | - mdformat-tables 98 | - mdformat_frontmatter 99 | # - mdformat-toc 100 | # - mdformat-black 101 | 102 | # word spelling linter 103 | - repo: https://github.com/codespell-project/codespell 104 | rev: v2.1.0 105 | hooks: 106 | - id: codespell 107 | args: 108 | - --skip=logs/**,data/** 109 | - --skip=notebooks/* 110 | # - --ignore-words-list=abc,def 111 | -------------------------------------------------------------------------------- /hparams/ecapa_voxceleb_basic.yaml: -------------------------------------------------------------------------------- 1 | # ################################ 2 | # Model: 
Speaker identification with ECAPA for CryCeleb 3 | # Authors: David Budaghyan 4 | # ################################ 5 | 6 | ckpt_interval_minutes: 15 # save checkpoint every N min 7 | 8 | ##### SEED 9 | seed: !PLACEHOLDER 10 | __set_seed: !apply:crybrain_config_utils.set_seed [!ref <seed>] 11 | 12 | # DataLoader 13 | bs: 16 14 | train_dataloader_options: 15 | batch_size: !ref <bs> 16 | shuffle: True 17 | val_dataloader_options: 18 | batch_size: 2 19 | shuffle: False 20 | 21 | ##### ESTIMATOR COMPONENTS 22 | # Fbank (feature extractor) 23 | n_mels: 80 24 | left_frames: 0 25 | right_frames: 0 26 | deltas: False 27 | compute_features: !new:speechbrain.lobes.features.Fbank 28 | n_mels: !ref <n_mels> 29 | left_frames: !ref <left_frames> 30 | right_frames: !ref <right_frames> 31 | deltas: !ref <deltas> 32 | 33 | # ECAPA 34 | emb_dim: 192 35 | embedding_model: !new:speechbrain.lobes.models.ECAPA_TDNN.ECAPA_TDNN 36 | input_size: !ref <n_mels> 37 | channels: [1024, 1024, 1024, 1024, 3072] 38 | kernel_sizes: [5, 3, 3, 3, 1] 39 | dilations: [1, 2, 3, 4, 1] 40 | groups: [1, 1, 1, 1, 1] 41 | attention_channels: 128 42 | lin_neurons: !ref <emb_dim> 43 | 44 | # If you do not want to use the pretrained encoder, you can simply delete the pretrained_embedding_model field.
45 | pretrained_model_name: spkrec-ecapa-voxceleb 46 | pretrained_embedding_model_path: !ref speechbrain/<pretrained_model_name>/embedding_model.ckpt 47 | pretrained_embedding_model: !new:speechbrain.utils.parameter_transfer.Pretrainer 48 | collect_in: !ref <experiment_dir>/ckpts 49 | loadables: 50 | model: !ref <embedding_model> 51 | paths: 52 | model: !ref <pretrained_embedding_model_path> 53 | 54 | # CLASSIFIER 55 | n_classes: !PLACEHOLDER # check-yaml disable 56 | 57 | 58 | classifier: !new:speechbrain.lobes.models.ECAPA_TDNN.Classifier 59 | input_size: !ref <emb_dim> 60 | out_neurons: !ref <n_classes> 61 | 62 | ##### EPOCH COUNTER 63 | n_epochs: 1000 64 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter 65 | limit: !ref <n_epochs> 66 | 67 | ##### OPTIMIZER 68 | start_lr: 0.0001 69 | opt_class: !name:torch.optim.Adam 70 | lr: !ref <start_lr> 71 | weight_decay: 0.000002 72 | 73 | ##### LEARNING RATE SCHEDULERS 74 | lrsched_name: cyclic 75 | # one of: 76 | # onplateau 77 | # cyclic 78 | lr_min: 0.0000000001 79 | lr_scheduler: !apply:crybrain_config_utils.choose_lrsched 80 | lrsched_name: !ref <lrsched_name> 81 | #below are kwargs, only the ones relevant to the type of scheduler will be 82 | #used for initialization in `choose_lrsched` 83 | 84 | #onplateau (ReduceLROnPlateau) 85 | lr_min: !ref <lr_min> 86 | factor: 0.4 87 | patience: 10 88 | dont_halve_until_epoch: 35 89 | #cyclic (CyclicLRScheduler) 90 | base_lr: 0.00000001 91 | max_lr: !ref <start_lr> 92 | step_size: 100 93 | mode: triangular 94 | gamma: 1.0 95 | scale_fn: null 96 | scale_mode: cycle 97 | 98 | sample_rate: 16000 99 | mean_var_norm: !new:speechbrain.processing.features.InputNormalization 100 | norm_type: sentence 101 | std_norm: False 102 | 103 | modules: 104 | compute_features: !ref <compute_features> 105 | embedding_model: !ref <embedding_model> 106 | classifier: !ref <classifier> 107 | mean_var_norm: !ref <mean_var_norm> 108 | 109 | compute_cost: !new:speechbrain.nnet.losses.LogSoftmaxWrapper 110 | loss_fn: !new:speechbrain.nnet.losses.AdditiveAngularMargin 111 | margin: 0.2 112 | scale: 30 113 | 114 | classification_stats: !name:speechbrain.utils.metric_stats.ClassificationStats 115 | 
################################################################### 116 | ### OUTPUT PATHS ### 117 | 118 | 119 | experiment_name: 120 | !PLACEHOLDER # must run from the directory which contains "experiments" 121 | 122 | 123 | experiment_dir: !ref ./experiments/ 124 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger 125 | save_file: !ref /train_log.txt 126 | 127 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer 128 | checkpoints_dir: !ref /ckpts 129 | recoverables: 130 | embedding_model: !ref 131 | classifier: !ref 132 | normalizer: !ref 133 | counter: !ref 134 | lr_scheduler: !ref 135 | -------------------------------------------------------------------------------- /crybrain.py: -------------------------------------------------------------------------------- 1 | """Utils for working with cryceleb2023 in SpeechBrain. 2 | 3 | Author 4 | * David Budaghyan 2023 5 | * Mirco Ravanelli 2020 6 | """ 7 | 8 | 9 | import os 10 | import pickle 11 | import zipfile 12 | 13 | import speechbrain as sb 14 | import torch 15 | from huggingface_hub import hf_hub_download 16 | 17 | 18 | def download_data(dest="data"): 19 | 20 | if os.path.exists(os.path.join(dest, "audio", "train")): 21 | print( 22 | f"It appears that data is already downloaded. 
\nIf you think it should be re-downloaded, remove {dest}/ directory and re-run" 23 | ) 24 | return 25 | 26 | # download data from Huggingface 27 | for file_name in ["metadata.csv", "audio.zip", "dev_pairs.csv", "test_pairs.csv", "sample_submission.csv"]: 28 | 29 | hf_hub_download( 30 | repo_id="Ubenwa/CryCeleb2023", 31 | filename=file_name, 32 | local_dir=dest, 33 | repo_type="dataset", 34 | ) 35 | 36 | with zipfile.ZipFile(os.path.join(dest, "audio.zip"), "r") as zip_ref: 37 | zip_ref.extractall(dest) 38 | 39 | print("Data downloaded to {dest}/ directory") 40 | 41 | 42 | class CryBrain(sb.core.Brain): 43 | """Class for speaker embedding training".""" 44 | 45 | def compute_forward(self, batch, stage): 46 | """Computation pipeline based on a encoder + speaker classifier. 47 | 48 | Data augmentation and environmental corruption are applied to the input speech. 49 | """ 50 | batch = batch.to(self.device) 51 | wavs, lens = batch.sig 52 | 53 | if stage == sb.Stage.TRAIN and hasattr(self.hparams, "augment_pipeline"): 54 | 55 | # Applying the augmentation pipeline 56 | wavs_aug_tot = [] 57 | wavs_aug_tot.append(wavs) 58 | for count, augment in enumerate(self.hparams.augment_pipeline): 59 | # Apply augment 60 | wavs_aug = augment(wavs, lens) 61 | 62 | # Managing speed change 63 | if wavs_aug.shape[1] > wavs.shape[1]: 64 | wavs_aug = wavs_aug[:, 0 : wavs.shape[1]] 65 | else: 66 | zero_sig = torch.zeros_like(wavs) 67 | zero_sig[:, 0 : wavs_aug.shape[1]] = wavs_aug 68 | wavs_aug = zero_sig 69 | 70 | if self.hparams.concat_augment: 71 | wavs_aug_tot.append(wavs_aug) 72 | else: 73 | wavs = wavs_aug 74 | wavs_aug_tot[0] = wavs 75 | 76 | wavs = torch.cat(wavs_aug_tot, dim=0) 77 | self.n_augment = len(wavs_aug_tot) 78 | lens = torch.cat([lens] * self.n_augment) 79 | 80 | # Feature extraction and normalization 81 | feats = self.modules.compute_features(wavs) 82 | feats = self.modules.mean_var_norm(feats, lens) 83 | 84 | # Embeddings + speaker classifier 85 | embeddings = 
self.modules.embedding_model(feats) 86 | classifier_outputs = self.modules.classifier(embeddings) 87 | return classifier_outputs, lens 88 | 89 | def compute_objectives(self, compute_forward_return, batch, stage): 90 | """Computes the loss using speaker-id as label.""" 91 | classifier_outputs, lens = compute_forward_return 92 | uttid = batch.id 93 | labels, _ = batch.baby_id_encoded 94 | 95 | # Concatenate labels (due to data augmentations) 96 | if stage == sb.Stage.TRAIN and hasattr(self.hparams, "augment_pipeline"): 97 | labels = torch.cat([labels] * self.n_augment, dim=0) 98 | uttid = [f"{u}_{i}" for i in range(self.n_augment) for u in uttid] 99 | 100 | loss = self.hparams.compute_cost(classifier_outputs, labels, lens) 101 | 102 | if stage == sb.Stage.TRAIN and hasattr(self.hparams.lr_scheduler, "on_batch_end"): 103 | self.hparams.lr_scheduler.on_batch_end(self.optimizer) 104 | 105 | predictions = [str(pred.item()) for pred in classifier_outputs.squeeze().argmax(1)] 106 | targets = [str(t.item()) for t in labels.squeeze()] 107 | 108 | # append the stats, if val, we also append the log_probs 109 | if stage == sb.Stage.TRAIN: 110 | self.classification_stats.append(ids=uttid, predictions=predictions, targets=targets) 111 | elif stage == sb.Stage.VALID: 112 | self.classification_stats.append(ids=uttid, predictions=predictions, targets=targets) 113 | return loss 114 | 115 | def on_stage_start(self, stage, epoch=None): 116 | """Gets called at the beginning of an epoch.""" 117 | 118 | self.classification_stats = self.hparams.classification_stats() 119 | 120 | def _write_stats(self, stage, epoch): 121 | """Wrties stats to f"{self.experiment_dir}/stats/{epoch}/{stage}.txt 122 | Arguments 123 | --------- 124 | epoch: int 125 | the epoch number 126 | stage: str 127 | "train" or "test" 128 | """ 129 | # do this to extract "train", "val", "test" to put in the file path 130 | stage = stage.__str__().split(".")[-1].lower() 131 | output_dir = os.path.join( 132 | 
self.hparams.experiment_dir, 133 | "stats", 134 | str(epoch), 135 | ) 136 | # create dir if doesn't exist 137 | if not os.path.exists(output_dir): 138 | os.makedirs(output_dir) 139 | # write classwise stats and confusion stats using speechbrain's classification_stats 140 | classwise_file_path = os.path.join(output_dir, f"classwise_{stage}.txt") 141 | with open(classwise_file_path, "w") as w: 142 | self.classification_stats.write_stats(w) 143 | # logger.info("classwise_statsvwritten to file: %s", classwise_file_path) 144 | 145 | # write instancewise stats 146 | instancewise_stats = [["instance_id", "prediction", "target"]] 147 | instancewise_stats.extend( 148 | [ 149 | [instance_id, prediction, target] 150 | for instance_id, prediction, target in zip( 151 | self.classification_stats.ids, 152 | self.classification_stats.predictions, 153 | self.classification_stats.targets, 154 | ) 155 | ] 156 | ) 157 | instancewise_file_path = os.path.join(output_dir, f"instancewise_{stage}.pkl") 158 | with open(instancewise_file_path, "wb") as pkl_file: 159 | pickle.dump(instancewise_stats, pkl_file) 160 | 161 | def on_stage_end(self, stage, stage_loss, epoch=None): 162 | """Gets called at the end of an epoch.""" 163 | 164 | log_stats = { 165 | "loss": stage_loss, 166 | # .summarize("accuracy") computes many statistics 167 | # but only returns the accuracy. The entire dictionary of stats is written to the stats dir. 168 | "acc": self.classification_stats.summarize("accuracy") * 100, 169 | } 170 | if stage == sb.Stage.TRAIN: 171 | # this is to save the log_stats, so we 172 | # write it along the validation log_stats after validation stage 173 | self.train_log_stats = log_stats 174 | # Perform end-of-iteration things, like annealing, logging, etc. 
175 | if stage == sb.Stage.VALID: 176 | if self.hparams.lrsched_name == 'onplateau': 177 | old_lr, new_lr = self.hparams.lr_scheduler([self.optimizer], 178 | current_epoch=epoch, 179 | current_loss=stage_loss) 180 | else: 181 | old_lr, new_lr = self.hparams.lr_scheduler(epoch) 182 | sb.nnet.schedulers.update_learning_rate(self.optimizer, new_lr) 183 | # LOGGING 184 | self.hparams.train_logger.log_stats( 185 | {"Epoch": epoch, "lr": old_lr}, 186 | # train_stats={"loss": self.train_loss}, #do this if only keeping loss during training 187 | train_stats=self.train_log_stats, 188 | valid_stats=log_stats, 189 | ) 190 | # Save the current checkpoint and delete previous checkpoints, 191 | self.checkpointer.save_and_keep_only( 192 | name=f"epoch-{epoch}_valacc-{log_stats['acc']:.2f}", meta=log_stats, num_to_keep=4, max_keys=["acc"] 193 | ) 194 | 195 | self._write_stats(stage=stage, epoch=epoch) 196 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 
175 | 176 | END OF TERMS AND CONDITIONS 177 | -------------------------------------------------------------------------------- /train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "fvP3aM0oMKvT" 7 | }, 8 | "source": [ 9 | "# Fine-tuning ECAPA-TDNN on [CryCeleb2023](https://huggingface.co/spaces/competitions/CryCeleb2023) using [SpeechBrain](https://speechbrain.readthedocs.io)\n", 10 | "\n", 11 | "This notebook should help you get started training your own models for the CryCeleb2023 challenge.\n", 12 | "\n", 13 | "Note that it provides a basic example, optimized for simplicity and speed.\n", 14 | "\n", 15 | "Author: David Budaghyan (Ubenwa)\n" 16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "metadata": { 21 | "id": "xLoVtmTDGVby" 22 | }, 23 | "source": [ 24 | "### Imports" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "# For Colab - uncomment and run the following to set up the repo\n", 34 | "# !pip install speechbrain\n", 35 | "# !git clone https://github.com/Ubenwa/cryceleb2023.git\n", 36 | "# %cd cryceleb2023" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": { 43 | "executionInfo": { 44 | "elapsed": 4595, 45 | "status": "ok", 46 | "timestamp": 1683904816148, 47 | "user": { 48 | "displayName": "David Budaghyan", 49 | "userId": "04476631309743002593" 50 | }, 51 | "user_tz": 240 52 | }, 53 | "id": "5k7MiTJsK2Ba" 54 | }, 55 | "outputs": [], 56 | "source": [ 57 | "%%capture\n", 58 | "%load_ext autoreload\n", 59 | "%autoreload 2\n", 60 | "\n", 61 | "import pathlib\n", 62 | "import random\n", 63 | "\n", 64 | "import numpy as np\n", 65 | "import pandas as pd\n", 66 | "import seaborn as sns\n", 67 | "import speechbrain as sb\n", 68 | "import torch\n", 69 | "from huggingface_hub import hf_hub_download\n", 70 | 
"from hyperpyyaml import load_hyperpyyaml\n", 71 | "from IPython.display import display\n", 72 | "from speechbrain.dataio.dataio import read_audio, write_audio\n", 73 | "from speechbrain.dataio.dataset import DynamicItemDataset\n", 74 | "from speechbrain.dataio.encoder import CategoricalEncoder\n", 75 | "\n", 76 | "from crybrain import CryBrain, download_data\n", 77 | "\n", 78 | "dataset_path = \"data\"" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": { 84 | "id": "oBhuztrjHQX8" 85 | }, 86 | "source": [ 87 | "### Download data\n", 88 | "\n", 89 | "You need to log in to HuggingFace to be able to download the dataset" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": { 96 | "executionInfo": { 97 | "elapsed": 72, 98 | "status": "ok", 99 | "timestamp": 1683904816150, 100 | "user": { 101 | "displayName": "David Budaghyan", 102 | "userId": "04476631309743002593" 103 | }, 104 | "user_tz": 240 105 | }, 106 | "id": "Du5UrdEgKx7a" 107 | }, 108 | "outputs": [], 109 | "source": [ 110 | "from huggingface_hub import notebook_login\n", 111 | "\n", 112 | "notebook_login()" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "download_data(dataset_path)" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "metadata": { 128 | "colab": { 129 | "base_uri": "https://localhost:8080/", 130 | "height": 513 131 | }, 132 | "executionInfo": { 133 | "elapsed": 73, 134 | "status": "ok", 135 | "timestamp": 1683904900224, 136 | "user": { 137 | "displayName": "David Budaghyan", 138 | "userId": "04476631309743002593" 139 | }, 140 | "user_tz": 240 141 | }, 142 | "id": "JrT1EvuiLjlr", 143 | "outputId": "dbbd3292-ed66-4d58-c10e-299d1bb4b453" 144 | }, 145 | "outputs": [], 146 | "source": [ 147 | "# read metadata\n", 148 | "metadata = pd.read_csv(\n", 149 | " f\"{dataset_path}/metadata.csv\", 
dtype={\"baby_id\": str, \"chronological_index\": str}\n", 150 | ")\n", 151 | "train_metadata = metadata.loc[metadata[\"split\"] == \"train\"].copy()\n", 152 | "display(\n", 153 | " train_metadata.head()\n", 154 | " .style.set_caption(\"train_metadata\")\n", 155 | " .set_table_styles([{\"selector\": \"caption\", \"props\": [(\"font-size\", \"20px\")]}])\n", 156 | ")\n", 157 | "display(train_metadata.describe())" 158 | ] 159 | }, 160 | { 161 | "cell_type": "markdown", 162 | "metadata": { 163 | "id": "nX6q3zCxtaQa" 164 | }, 165 | "source": [ 166 | "### Concatenate cry sounds\n", 167 | "\n", 168 | "We are given short cry sounds for each baby. Here we simply concatenate them. " 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": null, 174 | "metadata": { 175 | "colab": { 176 | "base_uri": "https://localhost:8080/", 177 | "height": 930 178 | }, 179 | "executionInfo": { 180 | "elapsed": 24464, 181 | "status": "ok", 182 | "timestamp": 1683904924639, 183 | "user": { 184 | "displayName": "David Budaghyan", 185 | "userId": "04476631309743002593" 186 | }, 187 | "user_tz": 240 188 | }, 189 | "id": "bUTe9QqwRqE-", 190 | "outputId": "d4872617-95de-4aa8-db79-bb41fe879408" 191 | }, 192 | "outputs": [], 193 | "source": [ 194 | "# read the segments\n", 195 | "train_metadata[\"cry\"] = train_metadata.apply(\n", 196 | " lambda row: read_audio(f'{dataset_path}/{row[\"file_name\"]}').numpy(), axis=1\n", 197 | ")\n", 198 | "# concatenate all segments for each (baby_id, period) group\n", 199 | "manifest_df = pd.DataFrame(\n", 200 | " train_metadata.groupby([\"baby_id\", \"period\"])[\"cry\"].agg(lambda x: np.concatenate(x.values)),\n", 201 | " columns=[\"cry\"],\n", 202 | ").reset_index()\n", 203 | "# all files have 16000 sampling rate\n", 204 | "manifest_df[\"duration\"] = manifest_df[\"cry\"].apply(len) / 16000\n", 205 | "pathlib.Path(f\"{dataset_path}/concatenated_audio_train\").mkdir(exist_ok=True)\n", 206 | "manifest_df[\"file_path\"] = manifest_df.apply(\n", 
207 | " lambda row: f\"{dataset_path}/concatenated_audio_train/{row['baby_id']}_{row['period']}.wav\",\n", 208 | " axis=1,\n", 209 | ")\n", 210 | "manifest_df.apply(\n", 211 | " lambda row: write_audio(\n", 212 | " filepath=f'{row[\"file_path\"]}', audio=torch.tensor(row[\"cry\"]), samplerate=16000\n", 213 | " ),\n", 214 | " axis=1,\n", 215 | ")\n", 216 | "manifest_df = manifest_df.drop(columns=[\"cry\"])\n", 217 | "display(manifest_df)\n", 218 | "ax = sns.histplot(manifest_df, x=\"duration\")\n", 219 | "ax.set_title(\"Histogram of Concatenated Cry Sound Lengths\")" 220 | ] 221 | }, 222 | { 223 | "cell_type": "markdown", 224 | "metadata": {}, 225 | "source": [ 226 | "During training, we will extract random cuts of 3-5 seconds from concatenated audio" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": null, 232 | "metadata": { 233 | "executionInfo": { 234 | "elapsed": 114, 235 | "status": "ok", 236 | "timestamp": 1683904924641, 237 | "user": { 238 | "displayName": "David Budaghyan", 239 | "userId": "04476631309743002593" 240 | }, 241 | "user_tz": 240 242 | }, 243 | "id": "8fqkEktdRRXf" 244 | }, 245 | "outputs": [], 246 | "source": [ 247 | "def create_cut_length_interval(row, cut_length_interval):\n", 248 | " \"\"\"cut_length_interval is a tuple indicating the range of lengths we want our chunks to be.\n", 249 | " this function computes the valid range of chunk lengths for each audio file\n", 250 | " \"\"\"\n", 251 | " # the lengths are in seconds, convert them to frames\n", 252 | " cut_length_interval = [round(length * 16000) for length in cut_length_interval]\n", 253 | " cry_length = round(row[\"duration\"] * 16000)\n", 254 | " # make the interval valid for the specific sound file\n", 255 | " min_cut_length, max_cut_length = cut_length_interval\n", 256 | " # if min_cut_length is greater than length of cry, don't cut\n", 257 | " if min_cut_length >= cry_length:\n", 258 | " cut_length_interval = (cry_length, cry_length)\n", 259 | " # if 
max_cut_length is greater than length of cry, take a cut of length between min_cut_length and full length of cry\n", 260 | " elif max_cut_length >= cry_length:\n", 261 | " cut_length_interval = (min_cut_length, cry_length)\n", 262 | " return cut_length_interval\n", 263 | "\n", 264 | "\n", 265 | "cut_length_interval = (3, 5)\n", 266 | "manifest_df[\"cut_length_interval_in_frames\"] = manifest_df.apply(\n", 267 | " lambda row: create_cut_length_interval(row, cut_length_interval=cut_length_interval), axis=1\n", 268 | ")" 269 | ] 270 | }, 271 | { 272 | "cell_type": "markdown", 273 | "metadata": { 274 | "id": "VS6R0uA2tpAJ" 275 | }, 276 | "source": [ 277 | "### Split into train and val\n", 278 | "\n", 279 | "For training a classifier, we can split the data into train/val in any way, as long as val does not contain new classes.\n", 280 | "\n", 281 | "One way to split is by period: train on birth recordings and validate on discharge recordings." 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": null, 287 | "metadata": { 288 | "colab": { 289 | "base_uri": "https://localhost:8080/", 290 | "height": 493 291 | }, 292 | "executionInfo": { 293 | "elapsed": 115, 294 | "status": "ok", 295 | "timestamp": 1683904924645, 296 | "user": { 297 | "displayName": "David Budaghyan", 298 | "userId": "04476631309743002593" 299 | }, 300 | "user_tz": 240 301 | }, 302 | "id": "KceQZp_pga34", 303 | "outputId": "a1eafd79-d192-4f7f-f447-36cb05b33d87" 304 | }, 305 | "outputs": [], 306 | "source": [ 307 | "# we can train on any subset of babies (e.g.
to reduce the number of classes, only keep babies with long enough cries, etc)\n", 308 | "def get_babies_with_both_recordings(manifest_df):\n", 309 | " count_of_periods_per_baby = manifest_df.groupby(\"baby_id\")[\"period\"].count()\n", 310 | " baby_ids_with_recording_from_both_periods = count_of_periods_per_baby[\n", 311 | " count_of_periods_per_baby == 2\n", 312 | " ].index\n", 313 | " return baby_ids_with_recording_from_both_periods\n", 314 | "\n", 315 | "\n", 316 | "# def get_babies_with_a_birth_recording(manifest_df):\n", 317 | "# bool_series = manifest_df.groupby('baby_id')['period'].unique().apply(set(['B']).issubset)\n", 318 | "# baby_ids_with_a_recordings_from_birth = bool_series[bool_series].index\n", 319 | "# return baby_ids_with_a_recordings_from_birth\n", 320 | "\n", 321 | "\n", 322 | "def split_by_period(row, included_baby_ids):\n", 323 | " if row[\"baby_id\"] in included_baby_ids:\n", 324 | " if row[\"period\"] == \"B\":\n", 325 | " return \"train\"\n", 326 | " else:\n", 327 | " return \"val\"\n", 328 | " else:\n", 329 | " return \"not_used\"\n", 330 | "\n", 331 | "\n", 332 | "babies_with_both_recordings = get_babies_with_both_recordings(manifest_df)\n", 333 | "manifest_df[\"split\"] = manifest_df.apply(\n", 334 | " lambda row: split_by_period(row, included_baby_ids=babies_with_both_recordings), axis=1\n", 335 | ")\n", 336 | "\n", 337 | "# each instance will be identified with a unique id\n", 338 | "manifest_df[\"id\"] = manifest_df[\"baby_id\"] + \"_\" + manifest_df[\"period\"]\n", 339 | "display(manifest_df)\n", 340 | "display(\n", 341 | " manifest_df[\"split\"]\n", 342 | " .value_counts()\n", 343 | " .rename(\"use_babies_with_both_recordings_and_split_by_period\")\n", 344 | ")\n", 345 | "manifest_df.set_index(\"id\").to_json(\"manifest.json\", orient=\"index\")" 346 | ] 347 | }, 348 | { 349 | "cell_type": "markdown", 350 | "metadata": { 351 | "id": "N2X-8Zs2-Mhm" 352 | }, 353 | "source": [ 354 | "### Create dynamic datasets\n", 355 | "\n", 356 | 
"See SpeechBrain documentation to understand details" 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": null, 362 | "metadata": { 363 | "colab": { 364 | "base_uri": "https://localhost:8080/" 365 | }, 366 | "executionInfo": { 367 | "elapsed": 107, 368 | "status": "ok", 369 | "timestamp": 1683904924648, 370 | "user": { 371 | "displayName": "David Budaghyan", 372 | "userId": "04476631309743002593" 373 | }, 374 | "user_tz": 240 375 | }, 376 | "id": "1NxmeLy_fj99", 377 | "outputId": "6ca4659d-4083-4a79-ccdd-241d6501a7df" 378 | }, 379 | "outputs": [], 380 | "source": [ 381 | "# create a dynamic dataset from the csv, only used to create train and val datasets\n", 382 | "dataset = DynamicItemDataset.from_json(\"manifest.json\")\n", 383 | "baby_id_encoder = CategoricalEncoder()\n", 384 | "datasets = {}\n", 385 | "# create a dataset for each split\n", 386 | "for split in [\"train\", \"val\"]:\n", 387 | " # retrieve the desired slice (train or val) and sort by length to minimize amount of padding\n", 388 | " datasets[split] = dataset.filtered_sorted(\n", 389 | " key_test={\"split\": lambda value: value == split}, sort_key=\"duration\"\n", 390 | " ) # select_n=100\n", 391 | " # create the baby_id_encoded field\n", 392 | " datasets[split].add_dynamic_item(\n", 393 | " baby_id_encoder.encode_label_torch, takes=\"baby_id\", provides=\"baby_id_encoded\"\n", 394 | " )\n", 395 | " # set visible fields\n", 396 | " datasets[split].set_output_keys([\"id\", \"baby_id\", \"baby_id_encoded\", \"sig\"])\n", 397 | "\n", 398 | "\n", 399 | "# create the signal field for the val split (no chunking)\n", 400 | "datasets[\"val\"].add_dynamic_item(sb.dataio.dataio.read_audio, takes=\"file_path\", provides=\"sig\")\n", 401 | "\n", 402 | "# the label encoder will map the baby_ids to target classes 0, 1, 2, ...\n", 403 | "# only use the classes which appear in `train`,\n", 404 | "baby_id_encoder.update_from_didataset(datasets[\"train\"], \"baby_id\")\n", 405 | "\n", 406 | 
"\n", 407 | "# for reading the train split, we add chunking\n", 408 | "def audio_pipeline(file_path, cut_length_interval_in_frames):\n", 409 | " \"\"\"Load the signal, and pass it and its length to the corruption class.\n", 410 | " This is done on the CPU in the `collate_fn`.\"\"\"\n", 411 | " sig = sb.dataio.dataio.read_audio(file_path)\n", 412 | " if cut_length_interval_in_frames is not None:\n", 413 | " cut_length = random.randint(*cut_length_interval_in_frames)\n", 414 | " # pick the start index of the cut\n", 415 | " left_index = random.randint(0, len(sig) - cut_length)\n", 416 | " # cut the signal\n", 417 | " sig = sig[left_index : left_index + cut_length]\n", 418 | " return sig\n", 419 | "\n", 420 | "\n", 421 | "# create the signal field (with chunking)\n", 422 | "datasets[\"train\"].add_dynamic_item(\n", 423 | " audio_pipeline, takes=[\"file_path\", \"cut_length_interval_in_frames\"], provides=\"sig\"\n", 424 | ")\n", 425 | "\n", 426 | "print(datasets[\"train\"][0])" 427 | ] 428 | }, 429 | { 430 | "cell_type": "markdown", 431 | "metadata": { 432 | "id": "bbwH78P8_TOd" 433 | }, 434 | "source": [ 435 | "### Fine-tune the classifier\n", 436 | "\n", 437 | "Here we use a very basic example that just trains for 5 epochs" 438 | ] 439 | }, 440 | { 441 | "cell_type": "code", 442 | "execution_count": null, 443 | "metadata": { 444 | "colab": { 445 | "base_uri": "https://localhost:8080/" 446 | }, 447 | "executionInfo": { 448 | "elapsed": 94, 449 | "status": "ok", 450 | "timestamp": 1683904924651, 451 | "user": { 452 | "displayName": "David Budaghyan", 453 | "userId": "04476631309743002593" 454 | }, 455 | "user_tz": 240 456 | }, 457 | "id": "ixojL5uH5y1V", 458 | "outputId": "de735a57-42da-41ab-8860-1665874edf15" 459 | }, 460 | "outputs": [], 461 | "source": [ 462 | "config_filename = \"hparams/ecapa_voxceleb_basic.yaml\"\n", 463 | "overrides = {\n", 464 | " \"seed\": 3011,\n", 465 | " \"n_classes\": len(baby_id_encoder),\n", 466 | " \"experiment_name\": 
\"ecapa_voxceleb_ft_basic\",\n", 467 | " \"bs\": 32,\n", 468 | " \"n_epochs\": 5,\n", 469 | "}\n", 470 | "device = \"cuda\"\n", 471 | "run_opts = {\"device\": device}\n", 472 | "###########################################\n", 473 | "# Load hyperparameters file with command-line overrides.\n", 474 | "with open(config_filename) as fin:\n", 475 | " hparams = load_hyperpyyaml(fin, overrides)\n", 476 | "# Create experiment directory\n", 477 | "sb.create_experiment_directory(\n", 478 | " experiment_directory=hparams[\"experiment_dir\"],\n", 479 | " hyperparams_to_save=config_filename,\n", 480 | " overrides=overrides,\n", 481 | ")\n", 482 | "\n", 483 | "# Initialize the Brain object to prepare for training.\n", 484 | "crybrain = CryBrain(\n", 485 | " modules=hparams[\"modules\"],\n", 486 | " opt_class=hparams[\"opt_class\"],\n", 487 | " hparams=hparams,\n", 488 | " run_opts=run_opts,\n", 489 | " checkpointer=hparams[\"checkpointer\"],\n", 490 | ")\n", 491 | "\n", 492 | "# if a pretrained model is specified, load it\n", 493 | "if \"pretrained_embedding_model\" in hparams:\n", 494 | " sb.utils.distributed.run_on_main(hparams[\"pretrained_embedding_model\"].collect_files)\n", 495 | " hparams[\"pretrained_embedding_model\"].load_collected(device=device)\n", 496 | "\n", 497 | "crybrain.fit(\n", 498 | " epoch_counter=crybrain.hparams.epoch_counter,\n", 499 | " train_set=datasets[\"train\"],\n", 500 | " valid_set=datasets[\"val\"],\n", 501 | " train_loader_kwargs=hparams[\"train_dataloader_options\"],\n", 502 | " valid_loader_kwargs=hparams[\"val_dataloader_options\"],\n", 503 | ")" 504 | ] 505 | }, 506 | { 507 | "cell_type": "markdown", 508 | "metadata": { 509 | "id": "cVDparawc1YF" 510 | }, 511 | "source": [ 512 | "You can now use embedding_model.ckpt from this recipe and use it in evaluate.ipynb to verify pairs of cries and submit your results!" 
513 | ] 514 | } 515 | ], 516 | "metadata": { 517 | "accelerator": "GPU", 518 | "colab": { 519 | "authorship_tag": "ABX9TyN+3g6dMDslVfuhfJgsJIst", 520 | "gpuType": "T4", 521 | "provenance": [], 522 | "toc_visible": true 523 | }, 524 | "gpuClass": "standard", 525 | "kernelspec": { 526 | "display_name": "Python 3 (ipykernel)", 527 | "language": "python", 528 | "name": "python3" 529 | }, 530 | "language_info": { 531 | "codemirror_mode": { 532 | "name": "ipython", 533 | "version": 3 534 | }, 535 | "file_extension": ".py", 536 | "mimetype": "text/x-python", 537 | "name": "python", 538 | "nbconvert_exporter": "python", 539 | "pygments_lexer": "ipython3", 540 | "version": "3.8.16" 541 | } 542 | }, 543 | "nbformat": 4, 544 | "nbformat_minor": 1 545 | } 546 | -------------------------------------------------------------------------------- /evaluate.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "7gd-5ANmkKSu" 7 | }, 8 | "source": [ 9 | "# Evaluation notebook for [CryCeleb2023 challenge](https://huggingface.co/spaces/competitions/CryCeleb2023)\n", 10 | "\n", 11 | "## This notebook does the following:\n", 12 | "- Download the CryCeleb data from Hugging Face.\n", 13 | "- Download a pretrained SpeechBrain model from Hugging Face.\n", 14 | "- Compute embeddings.\n", 15 | "- Compute similarity scores for pairs of embeddings.\n", 16 | "- Compute the equal error rate of the scores and visualize results.\n", 17 | "- Produce my_submission.csv that can be uploaded to the competition platform."
18 | ] 19 | }, 20 | { 21 | "cell_type": "markdown", 22 | "metadata": { 23 | "id": "fwKCGXetl_Jd" 24 | }, 25 | "source": [ 26 | "### Imports" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "# For Colab - uncomment and run the following to set up the repo\n", 36 | "# !pip install speechbrain\n", 37 | "# !git clone https://github.com/Ubenwa/cryceleb2023.git\n", 38 | "# %cd cryceleb2023" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": { 45 | "id": "ZZqSCpv_lUIa" 46 | }, 47 | "outputs": [], 48 | "source": [ 49 | "%%capture\n", 50 | "\n", 51 | "import matplotlib.pyplot as plt\n", 52 | "import numpy as np\n", 53 | "import pandas as pd\n", 54 | "import seaborn as sns\n", 55 | "import speechbrain as sb\n", 56 | "import torch\n", 57 | "from huggingface_hub import hf_hub_download\n", 58 | "from IPython.display import display\n", 59 | "from speechbrain.dataio.dataio import read_audio\n", 60 | "from speechbrain.pretrained import EncoderClassifier, SpeakerRecognition\n", 61 | "from speechbrain.utils.metric_stats import EER\n", 62 | "from tqdm.notebook import tqdm\n", 63 | "\n", 64 | "from crybrain import download_data\n", 65 | "\n", 66 | "dataset_path = \"data\"" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": { 72 | "id": "AUNVrz16mNZH" 73 | }, 74 | "source": [ 75 | "### Data" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "from huggingface_hub import notebook_login\n", 85 | "\n", 86 | "notebook_login()" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": { 93 | "colab": { 94 | "base_uri": "https://localhost:8080/", 95 | "height": 177, 96 | "referenced_widgets": [ 97 | "48a9413b4a9f44ab9b133f75757ba1b3", 98 | "bde64b1a9a2a44a68ce376638850faf1", 99 | "c688ebb7d042446db9688c0a8c5686cb", 
100 | "31f9d67b4f5d41fc8761bc1f7bcc3a88", 101 | "6faf6b7594064b78b38122e45c69229c", 102 | "8a390c2647714f06982ede802e708857", 103 | "87c274b22fc7460188a51d8d5689ec04", 104 | "69e01654e825416783101f44f800ba37", 105 | "85c75c5f031844f58f61c2217819201c", 106 | "56ac688d29664728a5921d5b6132da75", 107 | "aae703425f7245a8a54dd95870e98345", 108 | "d25726ee8d404d30b151aa04950aec53", 109 | "ee675ef5bdcc4b01b5f5d63737ef4e4a", 110 | "01f99069fb0648639f57a2fcc1625302", 111 | "688abf05d24644e385cee5294d63eaff", 112 | "adf5285fdc5e47cd8bd89acec431c519", 113 | "afe24b4ae3df482f90633ffd2600254c", 114 | "a445688a0d254b0abb42951544c5de29", 115 | "4cae57f7658c46168272be6ab8ed6361", 116 | "e3a7af1279464532b725177f0a0d50c2", 117 | "493f3c83c220430ab95687b724a7ce36", 118 | "3ccf5f36399f4dd9b5a1b39034147011", 119 | "92a1d2e7cfba4139bd30f33c9952ad0b", 120 | "446adc6e4b49402295141543a06d7022", 121 | "0e96610caf1747dc99d14ba7d4717d1e", 122 | "608197a5662e440db4309350a2f19148", 123 | "47b28b83da3048a0880294d898ed25a8", 124 | "38a29275901949afae2e4bf373463c19", 125 | "2976157384ca4771ac36d66d54f4eebc", 126 | "9c1a48b5361d4b4bb2a33fe44f59b4a9", 127 | "bb807e9c793d415f8e0992856b2bda96", 128 | "7b4a03db3e0d4bd5906e20333ce5358f", 129 | "3efc045306d14584a02a06cfa49f8d69", 130 | "3a07bc2787cc43a598d0a2b189e5fc36", 131 | "cb23d7bbcb014798964fde1f90b76496", 132 | "0c76014da6b44182a2c3561c473e1656", 133 | "d49d39ccb6b0439497375a9f37e961cb", 134 | "2885010222bc494e903b0ef93b8d27b2", 135 | "052b0a77e57a45d0be587a7564e18dd9", 136 | "42eadb95424e44399a1916768ae31002", 137 | "1785f38eb7494ba3a28870cc4792363f", 138 | "036f560d958f4e4a8d2367e4a1f5ab2f", 139 | "4a271e56fb17495c8efdf47c0f5b1582", 140 | "a227c472debb467bb952ef172afbda37", 141 | "74575ce9235f449c8cdab90ff650ad0e", 142 | "a8e3cdce46304042882a1121d075668e", 143 | "3ddf4f47a405499e93634ae846855d48", 144 | "351363b8f8d040b19745f5201db56809", 145 | "6ce593721e034de682cef0645cfefc31", 146 | "5d55ed42d44b417fa9aa8ea2bc5128f5", 147 | 
"39e121daa40b43b280d7bae156847071", 148 | "afb60ef0d33a4a7dba0396602d1c014e", 149 | "36722281b9f243fab567adffc6c6dc3b", 150 | "d74ccc6ec2f24ef291d98cadab769bcc", 151 | "afd5310507ac4800b254af99d9b89ed0" 152 | ] 153 | }, 154 | "id": "zqXn1mT8nRwp", 155 | "outputId": "59280321-e953-44c9-9658-ebc32f9a7941" 156 | }, 157 | "outputs": [], 158 | "source": [ 159 | "download_data(dataset_path)" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "metadata": { 166 | "colab": { 167 | "base_uri": "https://localhost:8080/", 168 | "height": 871 169 | }, 170 | "id": "CenUmoY_mMqw", 171 | "outputId": "1bb9f7b1-fae0-4f58-8a6a-c50c8780379a" 172 | }, 173 | "outputs": [], 174 | "source": [ 175 | "# read metadata\n", 176 | "metadata = pd.read_csv(\n", 177 | " f\"{dataset_path}/metadata.csv\", dtype={\"baby_id\": str, \"chronological_index\": str}\n", 178 | ")\n", 179 | "dev_metadata = metadata.loc[metadata[\"split\"] == \"dev\"].copy()\n", 180 | "# read sample submission\n", 181 | "sample_submission = pd.read_csv(\n", 182 | " f\"{dataset_path}/sample_submission.csv\"\n", 183 | ") # scores are unfiorm random\n", 184 | "# read verification pairs\n", 185 | "dev_pairs = pd.read_csv(\n", 186 | " f\"{dataset_path}/dev_pairs.csv\", dtype={\"baby_id_B\": str, \"baby_id_D\": str}\n", 187 | ")\n", 188 | "test_pairs = pd.read_csv(f\"{dataset_path}/test_pairs.csv\")\n", 189 | "\n", 190 | "display(\n", 191 | " metadata.head()\n", 192 | " .style.set_caption(\"metadata\")\n", 193 | " .set_table_styles([{\"selector\": \"caption\", \"props\": [(\"font-size\", \"20px\")]}])\n", 194 | ")\n", 195 | "display(\n", 196 | " dev_pairs.head()\n", 197 | " .style.set_caption(\"dev_pairs\")\n", 198 | " .set_table_styles([{\"selector\": \"caption\", \"props\": [(\"font-size\", \"20px\")]}])\n", 199 | ")\n", 200 | "display(\n", 201 | " test_pairs.head()\n", 202 | " .style.set_caption(\"test_pairs\")\n", 203 | " .set_table_styles([{\"selector\": \"caption\", \"props\": 
[(\"font-size\", \"20px\")]}])\n", 204 | ")\n", 205 | "display(\n", 206 | " sample_submission.head()\n", 207 | " .style.set_caption(\"sample_submission\")\n", 208 | " .set_table_styles([{\"selector\": \"caption\", \"props\": [(\"font-size\", \"20px\")]}])\n", 209 | ")" 210 | ] 211 | }, 212 | { 213 | "cell_type": "markdown", 214 | "metadata": { 215 | "id": "i7qn0lFdmOlF" 216 | }, 217 | "source": [ 218 | "### Initialize encoder" 219 | ] 220 | }, 221 | { 222 | "cell_type": "markdown", 223 | "metadata": { 224 | "id": "Rtgd7qlfmUWC" 225 | }, 226 | "source": [ 227 | "One way to verify if both pairs come from the same baby is to concatenate all the segments for each pair, compute the embedding of the concatenated cry, and compute the cosine similarity between the embeddings.\n", 228 | "\n", 229 | "Let's load the model" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": null, 235 | "metadata": { 236 | "colab": { 237 | "base_uri": "https://localhost:8080/", 238 | "height": 81, 239 | "referenced_widgets": [ 240 | "a5a6c8050ce444fe9e5e3835af2c89f7", 241 | "ffe1c871271b4eeb8566f8c3e9535679", 242 | "fe22550070ad45c4a23564099beafadd", 243 | "8657df05011a4f5f8cf0492aed78e819", 244 | "7d7bf74077214e65b3727346d8e4485b", 245 | "d31bd375beac4b308abc9751e88d6762", 246 | "7359bf51f90846999bc7e6fedf1ab4f1", 247 | "635dd9fb9e844c99b60812acd3b8cefc", 248 | "9bcebcb7cb224ff9ba92e3693adcf04e", 249 | "4a8dffe277524f48aa5e2cdb76d22926", 250 | "dab9e9c96c344b78abf24353bbbeebb9", 251 | "785545f1963f4f6d849fb562f31a3c0d", 252 | "2c7f0724170f4e0d9d74677aad6e18a3", 253 | "f876ee90477b41e3972374a5f874647c", 254 | "da7ba01d01c84602a5066475ddba8195", 255 | "cd5315c7c4ed475f837285339fca2fcc", 256 | "0adc32cfd65441f4becabf682041611f", 257 | "d9c373060ef242cba7b669582645ef69", 258 | "732fe933fcb14a1f8b32988dc5ab41a3", 259 | "2b4b55f6867443e0b10ae8caccef9fc2", 260 | "4c950b1378d34aeea6d87f53f6e5484e", 261 | "9760ab6f43f94981986f37a79cf0bdd7" 262 | ] 263 | }, 264 | "id": 
"CsE0Z6JCmSBv", 265 | "outputId": "6c22b5b6-a996-4312-b33e-78e89de6cc33" 266 | }, 267 | "outputs": [], 268 | "source": [ 269 | "!rm -rf spkrec-ecapa-voxceleb\n", 270 | "encoder = SpeakerRecognition.from_hparams(\n", 271 | " source=\"speechbrain/spkrec-ecapa-voxceleb\",\n", 272 | " savedir=\"spkrec-ecapa-voxceleb\",\n", 273 | " run_opts={\"device\": \"cuda\"}, # comment out if no GPU available\n", 274 | ")" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": null, 280 | "metadata": {}, 281 | "outputs": [], 282 | "source": [ 283 | "# you can also plug in your encoder weights if you fine-tuned this model locally\n", 284 | "# !rm spkrec-ecapa-voxceleb/embedding_model.ckpt\n", 285 | "# !cp experiments/ecapa_voxceleb_ft_basic/ckpts/CKPT+epoch-4_valacc-0.57/embedding_model.ckpt spkrec-ecapa-voxceleb\n", 286 | "\n", 287 | "# encoder = SpeakerRecognition.from_hparams(\n", 288 | "# source=\"speechbrain/spkrec-ecapa-voxceleb\",\n", 289 | "# savedir=\"spkrec-ecapa-voxceleb\",\n", 290 | "# run_opts={\"device\": \"cuda\"}, # comment out if no GPU available\n", 291 | "# )" 292 | ] 293 | }, 294 | { 295 | "cell_type": "markdown", 296 | "metadata": { 297 | "id": "vUP0tjNImYT4" 298 | }, 299 | "source": [ 300 | "#### Compute Encodings" 301 | ] 302 | }, 303 | { 304 | "cell_type": "markdown", 305 | "metadata": { 306 | "id": "mFTLJ74s6I3_" 307 | }, 308 | "source": [ 309 | "Change runtime type to GPU if using Colab" 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": null, 315 | "metadata": { 316 | "colab": { 317 | "base_uri": "https://localhost:8080/", 318 | "height": 159, 319 | "referenced_widgets": [ 320 | "780d053589614f1981ce9c3bc1bf2633", 321 | "d7869f0445404dfa98a36832af202d75", 322 | "36248ef425d84a31a5f2b4bf0935d1ce", 323 | "dcf4695b9f694a55a9fa0691f7c0b236", 324 | "bbb7c274e05c4feba27a1968b0566621", 325 | "790b64fac2d84d10b2fb6a1d1ba30b7f", 326 | "9d7d86758d46432085a7f1eeef6ee589", 327 | "75ef1cc7d24b47fe8dbd1b3e9b8c974d", 328 | 
"e9b6cee88af34f968c4d09d232b62120", 329 | "366a1031361647ec8aad6b599a2746a1", 330 | "a1ff4f5afe694b4b86e1aa80fa0169e5" 331 | ] 332 | }, 333 | "id": "rgMZ8rrBmWc3", 334 | "outputId": "7894ce23-2972-4554-ccd1-d84f52230bdd" 335 | }, 336 | "outputs": [], 337 | "source": [ 338 | "%%time\n", 339 | "# read the segments\n", 340 | "dev_metadata[\"cry\"] = dev_metadata.apply(\n", 341 | " lambda row: read_audio(f'{dataset_path}/{row[\"file_name\"]}').numpy(), axis=1\n", 342 | ")\n", 343 | "# concatenate all segments for each (baby_id, period) group\n", 344 | "cry_dict = pd.DataFrame(\n", 345 | " dev_metadata.groupby([\"baby_id\", \"period\"])[\"cry\"].agg(lambda x: np.concatenate(x.values)),\n", 346 | " columns=[\"cry\"],\n", 347 | ").to_dict(orient=\"index\")\n", 348 | "# encode the concatenated cries\n", 349 | "for (baby_id, period), d in tqdm(cry_dict.items()):\n", 350 | " d[\"cry_encoded\"] = encoder.encode_batch(torch.tensor(d[\"cry\"]), normalize=False)" 351 | ] 352 | }, 353 | { 354 | "cell_type": "markdown", 355 | "metadata": { 356 | "id": "PF4Sa3BnmcLA" 357 | }, 358 | "source": [ 359 | "#### Compute Similarity Between Encodings" 360 | ] 361 | }, 362 | { 363 | "cell_type": "code", 364 | "execution_count": null, 365 | "metadata": { 366 | "colab": { 367 | "base_uri": "https://localhost:8080/", 368 | "height": 206 369 | }, 370 | "id": "lgaScJgImcaO", 371 | "outputId": "691c5120-65d9-4eae-8d4f-c8d79a502254" 372 | }, 373 | "outputs": [], 374 | "source": [ 375 | "def compute_cosine_similarity_score(row, cry_dict):\n", 376 | " cos = torch.nn.CosineSimilarity(dim=-1)\n", 377 | " similarity_score = cos(\n", 378 | " cry_dict[(row[\"baby_id_B\"], \"B\")][\"cry_encoded\"],\n", 379 | " cry_dict[(row[\"baby_id_D\"], \"D\")][\"cry_encoded\"],\n", 380 | " )\n", 381 | " return similarity_score.item()\n", 382 | "\n", 383 | "\n", 384 | "dev_pairs[\"score\"] = dev_pairs.apply(\n", 385 | " lambda row: compute_cosine_similarity_score(row=row, cry_dict=cry_dict), axis=1\n", 386 | ")\n", 387 
| "display(dev_pairs.head())" 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": null, 393 | "metadata": { 394 | "colab": { 395 | "base_uri": "https://localhost:8080/", 396 | "height": 472 397 | }, 398 | "id": "3tBfFZ1OmeaG", 399 | "outputId": "9eedd68d-69f2-44da-d156-6db7acce7e4d" 400 | }, 401 | "outputs": [], 402 | "source": [ 403 | "def compute_eer_and_plot_verification_scores(pairs_df):\n", 404 | " \"\"\"pairs_df must have 'score' and 'label' columns\"\"\"\n", 405 | " positive_scores = pairs_df.loc[pairs_df[\"label\"] == 1][\"score\"].values\n", 406 | " negative_scores = pairs_df.loc[pairs_df[\"label\"] == 0][\"score\"].values\n", 407 | " eer, threshold = EER(torch.tensor(positive_scores), torch.tensor(negative_scores))\n", 408 | " ax = sns.histplot(pairs_df, x=\"score\", hue=\"label\", stat=\"percent\", common_norm=False)\n", 409 | " ax.set_title(f\"EER={round(eer, 4)} - Thresh={round(threshold, 4)}\")\n", 410 | " plt.axvline(x=[threshold], color=\"red\", ls=\"--\")\n", 411 | " return eer, threshold\n", 412 | "\n", 413 | "\n", 414 | "eer, threshold = compute_eer_and_plot_verification_scores(pairs_df=dev_pairs)" 415 | ] 416 | }, 417 | { 418 | "cell_type": "markdown", 419 | "metadata": { 420 | "id": "LEjkMjYN17rf" 421 | }, 422 | "source": [ 423 | "The above plot displays the histogram of scores for +ive (same baby) and -ive (different baby) dev_pairs.\n", 424 | "\n", 425 | "A perfect verifier would attribute a higher score to all +ive pairs than any -ive pair.\\\n", 426 | "Your task is to come up with a scoring system which maximizes the separation between the two distributions, as measured by the EER.\\\n", 427 | "You can change the encoder module, the aggregation of cry segments, the similarity metric, or come up with a completely different process! 
\\\n", 428 | "\n", 429 | "\n" 430 | ] 431 | }, 432 | { 433 | "cell_type": "code", 434 | "execution_count": null, 435 | "metadata": {}, 436 | "outputs": [], 437 | "source": [ 438 | "# same for test set that was hidden during evaluation\n", 439 | "eer, threshold = compute_eer_and_plot_verification_scores(pairs_df=test_pairs)" 440 | ] 441 | }, 442 | { 443 | "cell_type": "markdown", 444 | "metadata": { 445 | "id": "yINubTKWImh1" 446 | }, 447 | "source": [ 448 | "You can also create an example submission file for the challenge using the code below. It also computes the `score` column of `test_pairs`, which the test-set EER cell above reads, so run it first if `test_pairs` has no scores yet.\n", 449 | "\n", 450 | "The submission file is no longer needed, since the eval labels are now available and the scoring can be done directly in this notebook." 451 | ] 452 | }, 453 | { 454 | "cell_type": "code", 455 | "execution_count": null, 456 | "metadata": { 457 | "colab": { 458 | "base_uri": "https://localhost:8080/", 459 | "height": 274, 460 | "referenced_widgets": [ 461 | "fe70b8c96fdc4f5695438719d0a538f0", 462 | "60d9353fa67d47b08d6f79fe9edf14ff", 463 | "6d6f9ae639ea4b39943fa5abc84f2aaf", 464 | "10fb5219f51c455bb1039fe19513175d", 465 | "fe030cd60b6a49f9a5e22c53df2c1fbb", 466 | "7203a181918148d1a5e0fadd8d33344d", 467 | "751cbbff056f476aba01b860f14a1b52", 468 | "39c3334153a74fb8b9c470e55dac6e4c", 469 | "c4077ea694104e908743bd647fb228b7", 470 | "31f214f5206f4939a9e7eafff0fdb63c", 471 | "e67f032a44a1490e9eefa81886a86705" 472 | ] 473 | }, 474 | "id": "q7LrpPFVIGir", 475 | "outputId": "43ce30ba-debd-49a0-f0e8-d6b82c7653b7" 476 | }, 477 | "outputs": [], 478 | "source": [ 479 | "%%time\n", 480 | "test_metadata = metadata.loc[metadata[\"split\"] == \"test\"].copy()\n", 481 | "# read the segments\n", 482 | "test_metadata[\"cry\"] = test_metadata.apply(\n", 483 | " lambda row: read_audio(f'{dataset_path}/{row[\"file_name\"]}').numpy(), axis=1\n", 484 | ")\n", 485 | "# concatenate all segments for each (baby_id, period) group\n", 486 | "cry_dict_test = pd.DataFrame(\n", 487 | " test_metadata.groupby([\"baby_id\", \"period\"])[\"cry\"].agg(lambda x: np.concatenate(x.values)),\n", 488 |
columns=[\"cry\"],\n", 489 | ").to_dict(orient=\"index\")\n", 490 | "# encode the concatenated cries\n", 491 | "for (baby_id, period), d in tqdm(cry_dict_test.items()):\n", 492 | " d[\"cry_encoded\"] = encoder.encode_batch(torch.tensor(d[\"cry\"]), normalize=False)\n", 493 | "\n", 494 | "# compute cosine similarity between all pairs\n", 495 | "test_pairs[\"score\"] = test_pairs.apply(\n", 496 | " lambda row: compute_cosine_similarity_score(row=row, cry_dict=cry_dict_test), axis=1\n", 497 | ")\n", 498 | "display(test_pairs.head())" 499 | ] 500 | }, 501 | { 502 | "cell_type": "code", 503 | "execution_count": null, 504 | "metadata": { 505 | "colab": { 506 | "base_uri": "https://localhost:8080/", 507 | "height": 206 508 | }, 509 | "id": "4ZYkMxOrKRQ3", 510 | "outputId": "c58ed78b-1295-4209-c474-384e421cc779" 511 | }, 512 | "outputs": [], 513 | "source": [ 514 | "# submission must match the 'sample_submission.csv' format exactly\n", 515 | "my_submission = test_pairs[[\"id\", \"score\"]]\n", 516 | "my_submission.to_csv(\"my_submission.csv\", index=False)\n", 517 | "display(my_submission.head())" 518 | ] 519 | }, 520 | { 521 | "cell_type": "markdown", 522 | "metadata": { 523 | "id": "S3dgbBwUfpzU" 524 | }, 525 | "source": [ 526 | "You can now download `my_submission.csv` and submit it to the challenge!" 
527 | ] 528 | } 529 | ], 530 | "metadata": { 531 | "accelerator": "GPU", 532 | "colab": { 533 | "provenance": [] 534 | }, 535 | "gpuClass": "standard", 536 | "kernelspec": { 537 | "display_name": "Python 3 (ipykernel)", 538 | "language": "python", 539 | "name": "python3" 540 | }, 541 | "language_info": { 542 | "codemirror_mode": { 543 | "name": "ipython", 544 | "version": 3 545 | }, 546 | "file_extension": ".py", 547 | "mimetype": "text/x-python", 548 | "name": "python", 549 | "nbconvert_exporter": "python", 550 | "pygments_lexer": "ipython3", 551 | "version": "3.11.5" 552 | } 553 | }, 554 | "nbformat": 4, 555 | "nbformat_minor": 1 556 | } 557 | --------------------------------------------------------------------------------