├── requirements.txt
├── .gitignore
├── crybrain_config_utils.py
├── README.md
├── .pre-commit-config.yaml
├── hparams
│   └── ecapa_voxceleb_basic.yaml
├── crybrain.py
├── LICENSE
├── train.ipynb
└── evaluate.ipynb
/requirements.txt:
--------------------------------------------------------------------------------
1 | speechbrain
2 |
--------------------------------------------------------------------------------
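Install the single dependency with `pip install -r requirements.txt` (SpeechBrain pulls in torch and the rest of the training stack).
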
/.gitignore:
--------------------------------------------------------------------------------
1 | audio
2 | .ipynb_checkpoints
3 | concatenated_audio_train
4 | *.zip
5 | *.json
6 | *.csv
7 | __pycache__/
8 | experiments/
9 | *.sw*
10 |
--------------------------------------------------------------------------------
/crybrain_config_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | import numpy as np
5 | import torch
6 | from speechbrain.nnet.schedulers import CyclicLRScheduler, ReduceLROnPlateau
7 |
8 |
9 | def choose_lrsched(lrsched_name, **kwargs):
10 | print(f"lrsched_name: {lrsched_name}")
11 | if lrsched_name == "onplateau":
12 | return ReduceLROnPlateau(
13 | lr_min=kwargs["lr_min"],
14 | factor=kwargs["factor"],
15 | patience=kwargs["patience"],
16 | dont_halve_until_epoch=kwargs["dont_halve_until_epoch"],
17 | )
18 | elif lrsched_name == "cyclic":
19 | return CyclicLRScheduler(
20 | base_lr=kwargs["base_lr"],
21 | max_lr=kwargs["max_lr"],
22 | step_size=kwargs["step_size"],
23 | mode=kwargs["mode"],
24 | gamma=kwargs["gamma"],
25 | scale_fn=kwargs["scale_fn"],
26 | scale_mode=kwargs["scale_mode"],
27 |         )
28 |     raise ValueError(f"unknown lrsched_name: {lrsched_name}")
29 |
30 | def set_seed(seed):
31 | """Set seed in every way possible."""
32 | print(f"setting seeds to {seed}")
33 | random.seed(seed)
34 | np.random.seed(seed)
35 | torch.manual_seed(seed)
36 | os.environ["PYTHONHASHSEED"] = str(seed)
37 | if torch.cuda.is_available():
38 | print(f"setting cuda seeds to {seed}")
39 | torch.cuda.manual_seed(seed)
40 | torch.cuda.manual_seed_all(seed)
41 |
42 |
43 | def test_cuda_seed():
44 | """Print some random results from various libraries."""
45 | print(f"python random float: {random.random()}")
46 | print(f"numpy random int: {np.random.randint(100)}")
47 | print(f"torch random tensor (cpu): {torch.FloatTensor(100).uniform_()}")
48 | print(f"torch random tensor (cuda): {torch.cuda.FloatTensor(100).uniform_()}")
49 |
--------------------------------------------------------------------------------
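For reference, a minimal sketch of how these utilities are driven; the scheduler values below are illustrative, mirroring `hparams/ecapa_voxceleb_basic.yaml`:

```python
# Sketch only: hyperparameter values mirror hparams/ecapa_voxceleb_basic.yaml.
from crybrain_config_utils import choose_lrsched, set_seed

set_seed(3011)  # seeds python, numpy, torch (and CUDA, if available)
scheduler = choose_lrsched(
    "cyclic",
    base_lr=1e-8,  # lowest learning rate in the cycle
    max_lr=1e-4,  # peak learning rate
    step_size=100,  # iterations per half-cycle
    mode="triangular",
    gamma=1.0,
    scale_fn=None,
    scale_mode="cycle",
)
```
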
/README.md:
--------------------------------------------------------------------------------
1 | ## CryCeleb2023 SpeechBrain fine-tuning
2 |
3 | All you need to start fine-tuning [SpeechBrain](https://speechbrain.readthedocs.io/) models using the [Ubenwa CryCeleb dataset](https://huggingface.co/datasets/Ubenwa/CryCeleb2023)!
4 |
5 | This code was used for the [CryCeleb2023 HuggingFace challenge](https://huggingface.co/spaces/competitions/CryCeleb2023).
6 |
7 | It reproduces the training of the [official baseline](https://huggingface.co/Ubenwa/ecapa-voxceleb-ft2-cryceleb) model.
8 |
9 | `train.ipynb`
10 |
11 | - main code for data preparation and fine-tuning with various configs
12 |
13 | `evaluate.ipynb`
14 |
15 | - example of scoring with the pre-trained and fine-tuned models
16 |
17 | Note that the default configurations are optimized for speed and simplicity rather than accuracy.
18 |
19 | ## Cite
20 |
21 | ### CryCeleb (accepted to ICASSP 2024!)
22 |
23 | ```bibtex
24 | @article{ubenwa2023cryceleb,
25 | title={CryCeleb: A Speaker Verification Dataset Based on Infant Cry Sounds},
26 | author={David Budaghyan and Charles C. Onu and Arsenii Gorin and Cem Subakan and Doina Precup},
27 | year={2023},
28 |   journal={arXiv preprint arXiv:2305.00969},
29 | }
30 | ```
31 |
32 | ### SpeechBrain
33 |
34 | ```bibtex
35 | @misc{speechbrain,
36 | title={{SpeechBrain}: A General-Purpose Speech Toolkit},
37 | author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},
38 | year={2021},
39 | eprint={2106.04624},
40 | archivePrefix={arXiv},
41 | primaryClass={eess.AS},
42 | note={arXiv:2106.04624}
43 | }
44 | ```
45 |
--------------------------------------------------------------------------------
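As a quick sanity check before opening the notebooks, a minimal scoring sketch in the spirit of `evaluate.ipynb` (the two wav paths are placeholders):

```python
# Minimal sketch in the spirit of evaluate.ipynb; the wav paths are placeholders.
from speechbrain.pretrained import SpeakerRecognition

encoder = SpeakerRecognition.from_hparams(
    source="speechbrain/spkrec-ecapa-voxceleb",
    savedir="spkrec-ecapa-voxceleb",
)
# Cosine similarity between two cry recordings (higher = more similar)
score, _ = encoder.verify_files("cry_birth.wav", "cry_discharge.wav")
print(f"similarity: {score.item():.3f}")
```
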
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | default_language_version:
2 | python: python3
3 |
4 | repos:
5 | - repo: https://github.com/pre-commit/pre-commit-hooks
6 | rev: v4.3.0
7 | hooks:
8 | # list of supported hooks: https://pre-commit.com/hooks.html
9 | - id: trailing-whitespace
10 | - id: end-of-file-fixer
11 | - id: check-docstring-first
12 | # - id: check-yaml
13 | - id: debug-statements
14 | - id: detect-private-key
15 | - id: check-executables-have-shebangs
16 | - id: check-toml
17 | - id: check-case-conflict
18 | - id: check-added-large-files
19 |
20 | # python code formatting
21 | - repo: https://github.com/psf/black
22 | rev: 22.6.0
23 | hooks:
24 | - id: black
25 | args: [--line-length, "120"]
26 |
27 | # python import sorting
28 | - repo: https://github.com/PyCQA/isort
29 | rev: 5.12.0
30 | hooks:
31 | - id: isort
32 | args: ["--profile", "black", "--filter-files"]
33 |
34 | # python upgrading syntax to newer version
35 | - repo: https://github.com/asottile/pyupgrade
36 | rev: v2.32.1
37 | hooks:
38 | - id: pyupgrade
39 | args: [--py38-plus]
40 |
41 | # python docstring formatting
42 | - repo: https://github.com/myint/docformatter
43 | rev: v1.4
44 | hooks:
45 | - id: docformatter
46 | args: [--in-place, --wrap-summaries=99, --wrap-descriptions=99]
47 |
48 | # python check (PEP8), programming errors and code complexity
49 | - repo: https://github.com/PyCQA/flake8
50 | rev: 4.0.1
51 | hooks:
52 | - id: flake8
53 | args:
54 | [
55 | "--ignore",
56 | "E501,F401,F841,W504,E203,W503",
57 | "--exclude",
58 | "logs/*,data/*",
59 | ]
60 |
61 | # yaml formatting
62 | - repo: https://github.com/pre-commit/mirrors-prettier
63 | rev: v2.7.1
64 | hooks:
65 | - id: prettier
66 | types: [yaml]
67 |
68 | # jupyter notebook cell output clearing
69 | - repo: https://github.com/kynan/nbstripout
70 | rev: 0.5.0
71 | hooks:
72 | - id: nbstripout
73 |
74 | # jupyter notebook linting
75 | - repo: https://github.com/nbQA-dev/nbQA
76 | rev: 1.4.0
77 | hooks:
78 | - id: nbqa-black
79 | args: ["--line-length=99"]
80 | - id: nbqa-isort
81 | args: ["--profile=black"]
82 | - id: nbqa-flake8
83 | args:
84 | [
85 | "--extend-ignore=E203,E402,E501,F401,F841",
86 | "--exclude=logs/*,data/*",
87 | ]
88 |
89 | # md formatting
90 | - repo: https://github.com/executablebooks/mdformat
91 | rev: 0.7.14
92 | hooks:
93 | - id: mdformat
94 | args: ["--number"]
95 | additional_dependencies:
96 | - mdformat-gfm
97 | - mdformat-tables
98 | - mdformat_frontmatter
99 | # - mdformat-toc
100 | # - mdformat-black
101 |
102 | # word spelling linter
103 | - repo: https://github.com/codespell-project/codespell
104 | rev: v2.1.0
105 | hooks:
106 | - id: codespell
107 | args:
108 | - --skip=logs/**,data/**
109 | - --skip=notebooks/*
110 | # - --ignore-words-list=abc,def
111 |
--------------------------------------------------------------------------------
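The standard pre-commit workflow applies, nothing repo-specific: `pip install pre-commit`, then `pre-commit install` to run the hooks on every commit, or `pre-commit run --all-files` to lint and format the whole tree once.
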
/hparams/ecapa_voxceleb_basic.yaml:
--------------------------------------------------------------------------------
1 | # ################################
2 | # Model: Speaker identification with ECAPA for CryCeleb
3 | # Authors: David Budaghyan
4 | # ################################
5 |
6 | ckpt_interval_minutes: 15 # save checkpoint every N min
7 |
8 | ##### SEED
9 | seed: !PLACEHOLDER
10 | __set_seed: !apply:crybrain_config_utils.set_seed [!ref <seed>]
11 |
12 | # DataLoader
13 | bs: 16
14 | train_dataloader_options:
15 |   batch_size: !ref <bs>
16 | shuffle: True
17 | val_dataloader_options:
18 | batch_size: 2
19 | shuffle: False
20 |
21 | ##### ESTIMATOR COMPONENTS
22 | # Fbank (feature extractor)
23 | n_mels: 80
24 | left_frames: 0
25 | right_frames: 0
26 | deltas: False
27 | compute_features: !new:speechbrain.lobes.features.Fbank
28 |   n_mels: !ref <n_mels>
29 |   left_frames: !ref <left_frames>
30 |   right_frames: !ref <right_frames>
31 |   deltas: !ref <deltas>
32 |
33 | # ECAPA
34 | emb_dim: 192
35 | embedding_model: !new:speechbrain.lobes.models.ECAPA_TDNN.ECAPA_TDNN
36 |   input_size: !ref <n_mels>
37 | channels: [1024, 1024, 1024, 1024, 3072]
38 | kernel_sizes: [5, 3, 3, 3, 1]
39 | dilations: [1, 2, 3, 4, 1]
40 | groups: [1, 1, 1, 1, 1]
41 | attention_channels: 128
42 |   lin_neurons: !ref <emb_dim>
43 |
44 | # If you do not want to use the pretrained encoder, simply delete the pretrained_embedding_model field.
45 | pretrained_model_name: spkrec-ecapa-voxceleb
46 | pretrained_embedding_model_path: !ref speechbrain/<pretrained_model_name>/embedding_model.ckpt
47 | pretrained_embedding_model: !new:speechbrain.utils.parameter_transfer.Pretrainer
48 |   collect_in: !ref <experiment_dir>/ckpts
49 |   loadables:
50 |     model: !ref <embedding_model>
51 |   paths:
52 |     model: !ref <pretrained_embedding_model_path>
53 |
54 | # CLASSIFIER
55 | n_classes: !PLACEHOLDER # check-yaml disable
56 |
57 |
58 | classifier: !new:speechbrain.lobes.models.ECAPA_TDNN.Classifier
59 |   input_size: !ref <emb_dim>
60 |   out_neurons: !ref <n_classes>
61 |
62 | ##### EPOCH COUNTER
63 | n_epochs: 1000
64 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
65 |   limit: !ref <n_epochs>
66 |
67 | ##### OPTIMIZER
68 | start_lr: 0.0001
69 | opt_class: !name:torch.optim.Adam
70 |   lr: !ref <start_lr>
71 | weight_decay: 0.000002
72 |
73 | ##### LEARNING RATE SCHEDULERS
74 | lrsched_name: cyclic
75 | # one of:
76 | # onplateau
77 | # cyclic
78 | lr_min: 0.0000000001
79 | lr_scheduler: !apply:crybrain_config_utils.choose_lrsched
80 |   lrsched_name: !ref <lrsched_name>
81 |   # below are kwargs; only the ones relevant to the chosen scheduler type
82 |   # are used for initialization in `choose_lrsched`
83 |
84 |   # onplateau (ReduceLROnPlateau)
85 |   lr_min: !ref <lr_min>
86 |   factor: 0.4
87 |   patience: 10
88 |   dont_halve_until_epoch: 35
89 |   # cyclic (CyclicLRScheduler)
90 |   base_lr: 0.00000001
91 |   max_lr: !ref <start_lr>
92 | step_size: 100
93 | mode: triangular
94 | gamma: 1.0
95 | scale_fn: null
96 | scale_mode: cycle
97 |
98 | sample_rate: 16000
99 | mean_var_norm: !new:speechbrain.processing.features.InputNormalization
100 | norm_type: sentence
101 | std_norm: False
102 |
103 | modules:
104 |   compute_features: !ref <compute_features>
105 |   embedding_model: !ref <embedding_model>
106 |   classifier: !ref <classifier>
107 |   mean_var_norm: !ref <mean_var_norm>
108 |
109 | compute_cost: !new:speechbrain.nnet.losses.LogSoftmaxWrapper
110 | loss_fn: !new:speechbrain.nnet.losses.AdditiveAngularMargin
111 | margin: 0.2
112 | scale: 30
113 |
114 | classification_stats: !name:speechbrain.utils.metric_stats.ClassificationStats
115 | ###################################################################
116 | ### OUTPUT PATHS ###
117 |
118 |
119 | experiment_name:
120 | !PLACEHOLDER # must run from the directory which contains "experiments"
121 |
122 |
123 | experiment_dir: !ref ./experiments/<experiment_name>
124 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
125 |   save_file: !ref <experiment_dir>/train_log.txt
126 |
127 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
128 |   checkpoints_dir: !ref <experiment_dir>/ckpts
129 | recoverables:
130 |     embedding_model: !ref <embedding_model>
131 |     classifier: !ref <classifier>
132 |     normalizer: !ref <mean_var_norm>
133 |     counter: !ref <epoch_counter>
134 |     lr_scheduler: !ref <lr_scheduler>
135 |
--------------------------------------------------------------------------------
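This file is consumed with HyperPyYAML, as in `train.ipynb`; the three `!PLACEHOLDER` fields (`seed`, `n_classes`, `experiment_name`) must be supplied as overrides. A sketch with illustrative override values:

```python
# Sketch mirroring train.ipynb; the override values are illustrative placeholders.
from hyperpyyaml import load_hyperpyyaml

overrides = {"seed": 3011, "n_classes": 348, "experiment_name": "my_run"}
with open("hparams/ecapa_voxceleb_basic.yaml") as fin:
    hparams = load_hyperpyyaml(fin, overrides)
print(hparams["experiment_dir"])  # ./experiments/my_run
```
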
/crybrain.py:
--------------------------------------------------------------------------------
1 | """Utils for working with cryceleb2023 in SpeechBrain.
2 |
3 | Author
4 | * David Budaghyan 2023
5 | * Mirco Ravanelli 2020
6 | """
7 |
8 |
9 | import os
10 | import pickle
11 | import zipfile
12 |
13 | import speechbrain as sb
14 | import torch
15 | from huggingface_hub import hf_hub_download
16 |
17 |
18 | def download_data(dest="data"):
19 |
20 | if os.path.exists(os.path.join(dest, "audio", "train")):
21 | print(
22 |             f"It appears that data is already downloaded.\nIf you think it should be re-downloaded, remove the {dest}/ directory and re-run."
23 | )
24 | return
25 |
26 | # download data from Huggingface
27 | for file_name in ["metadata.csv", "audio.zip", "dev_pairs.csv", "test_pairs.csv", "sample_submission.csv"]:
28 |
29 | hf_hub_download(
30 | repo_id="Ubenwa/CryCeleb2023",
31 | filename=file_name,
32 | local_dir=dest,
33 | repo_type="dataset",
34 | )
35 |
36 | with zipfile.ZipFile(os.path.join(dest, "audio.zip"), "r") as zip_ref:
37 | zip_ref.extractall(dest)
38 |
39 |     print(f"Data downloaded to {dest}/ directory")
40 |
41 |
42 | class CryBrain(sb.core.Brain):
43 |     """Class for speaker embedding training."""
44 |
45 | def compute_forward(self, batch, stage):
46 |         """Computation pipeline based on an encoder + speaker classifier.
47 |
48 | Data augmentation and environmental corruption are applied to the input speech.
49 | """
50 | batch = batch.to(self.device)
51 | wavs, lens = batch.sig
52 |
53 | if stage == sb.Stage.TRAIN and hasattr(self.hparams, "augment_pipeline"):
54 |
55 | # Applying the augmentation pipeline
56 | wavs_aug_tot = []
57 | wavs_aug_tot.append(wavs)
58 | for count, augment in enumerate(self.hparams.augment_pipeline):
59 | # Apply augment
60 | wavs_aug = augment(wavs, lens)
61 |
62 | # Managing speed change
63 | if wavs_aug.shape[1] > wavs.shape[1]:
64 | wavs_aug = wavs_aug[:, 0 : wavs.shape[1]]
65 | else:
66 | zero_sig = torch.zeros_like(wavs)
67 | zero_sig[:, 0 : wavs_aug.shape[1]] = wavs_aug
68 | wavs_aug = zero_sig
69 |
70 | if self.hparams.concat_augment:
71 | wavs_aug_tot.append(wavs_aug)
72 | else:
73 | wavs = wavs_aug
74 | wavs_aug_tot[0] = wavs
75 |
76 | wavs = torch.cat(wavs_aug_tot, dim=0)
77 | self.n_augment = len(wavs_aug_tot)
78 | lens = torch.cat([lens] * self.n_augment)
79 |
80 | # Feature extraction and normalization
81 | feats = self.modules.compute_features(wavs)
82 | feats = self.modules.mean_var_norm(feats, lens)
83 |
84 | # Embeddings + speaker classifier
85 | embeddings = self.modules.embedding_model(feats)
86 | classifier_outputs = self.modules.classifier(embeddings)
87 | return classifier_outputs, lens
88 |
89 | def compute_objectives(self, compute_forward_return, batch, stage):
90 | """Computes the loss using speaker-id as label."""
91 | classifier_outputs, lens = compute_forward_return
92 | uttid = batch.id
93 | labels, _ = batch.baby_id_encoded
94 |
95 | # Concatenate labels (due to data augmentations)
96 | if stage == sb.Stage.TRAIN and hasattr(self.hparams, "augment_pipeline"):
97 | labels = torch.cat([labels] * self.n_augment, dim=0)
98 | uttid = [f"{u}_{i}" for i in range(self.n_augment) for u in uttid]
99 |
100 | loss = self.hparams.compute_cost(classifier_outputs, labels, lens)
101 |
102 | if stage == sb.Stage.TRAIN and hasattr(self.hparams.lr_scheduler, "on_batch_end"):
103 | self.hparams.lr_scheduler.on_batch_end(self.optimizer)
104 |
105 |         predictions = [str(pred.item()) for pred in classifier_outputs.squeeze(1).argmax(1)]
106 |         targets = [str(t.item()) for t in labels.squeeze(1)]
107 |
108 |         # append the classification stats (identical for train and val)
109 |         if stage in (sb.Stage.TRAIN, sb.Stage.VALID):
110 |             self.classification_stats.append(
111 |                 ids=uttid, predictions=predictions, targets=targets
112 |             )
113 | return loss
114 |
115 | def on_stage_start(self, stage, epoch=None):
116 | """Gets called at the beginning of an epoch."""
117 |
118 | self.classification_stats = self.hparams.classification_stats()
119 |
120 | def _write_stats(self, stage, epoch):
121 |         """Write stats to f"{experiment_dir}/stats/{epoch}/classwise_{stage}.txt" and instancewise_{stage}.pkl.
122 |         Arguments
123 |         ---------
124 |         epoch: int
125 |             the epoch number
126 |         stage: sb.Stage
127 |             the stage ("train", "valid", or "test")
128 |         """
129 | # do this to extract "train", "val", "test" to put in the file path
130 | stage = stage.__str__().split(".")[-1].lower()
131 | output_dir = os.path.join(
132 | self.hparams.experiment_dir,
133 | "stats",
134 | str(epoch),
135 | )
136 | # create dir if doesn't exist
137 | if not os.path.exists(output_dir):
138 | os.makedirs(output_dir)
139 | # write classwise stats and confusion stats using speechbrain's classification_stats
140 | classwise_file_path = os.path.join(output_dir, f"classwise_{stage}.txt")
141 | with open(classwise_file_path, "w") as w:
142 | self.classification_stats.write_stats(w)
143 |         # logger.info("classwise stats written to file: %s", classwise_file_path)
144 |
145 | # write instancewise stats
146 | instancewise_stats = [["instance_id", "prediction", "target"]]
147 | instancewise_stats.extend(
148 | [
149 | [instance_id, prediction, target]
150 | for instance_id, prediction, target in zip(
151 | self.classification_stats.ids,
152 | self.classification_stats.predictions,
153 | self.classification_stats.targets,
154 | )
155 | ]
156 | )
157 | instancewise_file_path = os.path.join(output_dir, f"instancewise_{stage}.pkl")
158 | with open(instancewise_file_path, "wb") as pkl_file:
159 | pickle.dump(instancewise_stats, pkl_file)
160 |
161 | def on_stage_end(self, stage, stage_loss, epoch=None):
162 | """Gets called at the end of an epoch."""
163 |
164 | log_stats = {
165 | "loss": stage_loss,
166 | # .summarize("accuracy") computes many statistics
167 | # but only returns the accuracy. The entire dictionary of stats is written to the stats dir.
168 | "acc": self.classification_stats.summarize("accuracy") * 100,
169 | }
170 | if stage == sb.Stage.TRAIN:
171 | # this is to save the log_stats, so we
172 | # write it along the validation log_stats after validation stage
173 | self.train_log_stats = log_stats
174 | # Perform end-of-iteration things, like annealing, logging, etc.
175 | if stage == sb.Stage.VALID:
176 |             if self.hparams.lrsched_name == "onplateau":
177 |                 old_lr, new_lr = self.hparams.lr_scheduler(
178 |                     [self.optimizer], current_epoch=epoch, current_loss=stage_loss
179 |                 )
180 | else:
181 | old_lr, new_lr = self.hparams.lr_scheduler(epoch)
182 | sb.nnet.schedulers.update_learning_rate(self.optimizer, new_lr)
183 | # LOGGING
184 | self.hparams.train_logger.log_stats(
185 | {"Epoch": epoch, "lr": old_lr},
186 | # train_stats={"loss": self.train_loss}, #do this if only keeping loss during training
187 | train_stats=self.train_log_stats,
188 | valid_stats=log_stats,
189 | )
190 |         # Save the current checkpoint and keep only the 4 best (by validation accuracy)
191 | self.checkpointer.save_and_keep_only(
192 | name=f"epoch-{epoch}_valacc-{log_stats['acc']:.2f}", meta=log_stats, num_to_keep=4, max_keys=["acc"]
193 | )
194 |
195 | self._write_stats(stage=stage, epoch=epoch)
196 |
--------------------------------------------------------------------------------
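A small sketch of reading back the per-instance predictions that `_write_stats` pickles each epoch; the experiment name and epoch number below are hypothetical:

```python
# Sketch: the path pattern follows CryBrain._write_stats above;
# "my_run" and epoch 1 are hypothetical values.
import pickle

with open("experiments/my_run/stats/1/instancewise_valid.pkl", "rb") as f:
    rows = pickle.load(f)
header, records = rows[0], rows[1:]
print(header)  # ['instance_id', 'prediction', 'target']
print(records[:3])
```
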
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
--------------------------------------------------------------------------------
/train.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "fvP3aM0oMKvT"
7 | },
8 | "source": [
9 | "# Fine-tuning ECAPA-TDNN on [CryCeleb2023](https://huggingface.co/spaces/competitions/CryCeleb2023) using [SpeechBrain](https://speechbrain.readthedocs.io)\n",
10 | "\n",
11 |     "This notebook should help you get started training your own models for the CryCeleb2023 challenge.\n",
12 | "\n",
13 |     "Note that it provides a basic example for simplicity and speed.\n",
14 | "\n",
15 | "Author: David Budaghyan (Ubenwa)\n"
16 | ]
17 | },
18 | {
19 | "cell_type": "markdown",
20 | "metadata": {
21 | "id": "xLoVtmTDGVby"
22 | },
23 | "source": [
24 | "### Imports"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": null,
30 | "metadata": {},
31 | "outputs": [],
32 | "source": [
33 | "# For Colab - uncomment and run the following to set up the repo\n",
34 | "# !pip install speechbrain\n",
35 | "# !git clone https://github.com/Ubenwa/cryceleb2023.git\n",
36 | "# %cd cryceleb2023"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": null,
42 | "metadata": {
43 | "executionInfo": {
44 | "elapsed": 4595,
45 | "status": "ok",
46 | "timestamp": 1683904816148,
47 | "user": {
48 | "displayName": "David Budaghyan",
49 | "userId": "04476631309743002593"
50 | },
51 | "user_tz": 240
52 | },
53 | "id": "5k7MiTJsK2Ba"
54 | },
55 | "outputs": [],
56 | "source": [
57 | "%%capture\n",
58 | "%load_ext autoreload\n",
59 | "%autoreload 2\n",
60 | "\n",
61 | "import pathlib\n",
62 | "import random\n",
63 | "\n",
64 | "import numpy as np\n",
65 | "import pandas as pd\n",
66 | "import seaborn as sns\n",
67 | "import speechbrain as sb\n",
68 | "import torch\n",
69 | "from huggingface_hub import hf_hub_download\n",
70 | "from hyperpyyaml import load_hyperpyyaml\n",
71 | "from IPython.display import display\n",
72 | "from speechbrain.dataio.dataio import read_audio, write_audio\n",
73 | "from speechbrain.dataio.dataset import DynamicItemDataset\n",
74 | "from speechbrain.dataio.encoder import CategoricalEncoder\n",
75 | "\n",
76 | "from crybrain import CryBrain, download_data\n",
77 | "\n",
78 | "dataset_path = \"data\""
79 | ]
80 | },
81 | {
82 | "cell_type": "markdown",
83 | "metadata": {
84 | "id": "oBhuztrjHQX8"
85 | },
86 | "source": [
87 | "### Download data\n",
88 | "\n",
89 | "You need to log in to HuggingFace to be able to download the dataset"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": null,
95 | "metadata": {
96 | "executionInfo": {
97 | "elapsed": 72,
98 | "status": "ok",
99 | "timestamp": 1683904816150,
100 | "user": {
101 | "displayName": "David Budaghyan",
102 | "userId": "04476631309743002593"
103 | },
104 | "user_tz": 240
105 | },
106 | "id": "Du5UrdEgKx7a"
107 | },
108 | "outputs": [],
109 | "source": [
110 | "from huggingface_hub import notebook_login\n",
111 | "\n",
112 | "notebook_login()"
113 | ]
114 | },
115 | {
116 | "cell_type": "code",
117 | "execution_count": null,
118 | "metadata": {},
119 | "outputs": [],
120 | "source": [
121 | "download_data(dataset_path)"
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": null,
127 | "metadata": {
128 | "colab": {
129 | "base_uri": "https://localhost:8080/",
130 | "height": 513
131 | },
132 | "executionInfo": {
133 | "elapsed": 73,
134 | "status": "ok",
135 | "timestamp": 1683904900224,
136 | "user": {
137 | "displayName": "David Budaghyan",
138 | "userId": "04476631309743002593"
139 | },
140 | "user_tz": 240
141 | },
142 | "id": "JrT1EvuiLjlr",
143 | "outputId": "dbbd3292-ed66-4d58-c10e-299d1bb4b453"
144 | },
145 | "outputs": [],
146 | "source": [
147 | "# read metadata\n",
148 | "metadata = pd.read_csv(\n",
149 | " f\"{dataset_path}/metadata.csv\", dtype={\"baby_id\": str, \"chronological_index\": str}\n",
150 | ")\n",
151 | "train_metadata = metadata.loc[metadata[\"split\"] == \"train\"].copy()\n",
152 | "display(\n",
153 | " train_metadata.head()\n",
154 | " .style.set_caption(\"train_metadata\")\n",
155 | " .set_table_styles([{\"selector\": \"caption\", \"props\": [(\"font-size\", \"20px\")]}])\n",
156 | ")\n",
157 | "display(train_metadata.describe())"
158 | ]
159 | },
160 | {
161 | "cell_type": "markdown",
162 | "metadata": {
163 | "id": "nX6q3zCxtaQa"
164 | },
165 | "source": [
166 | "### Concatenate cry sounds\n",
167 | "\n",
168 | "We are given short cry sounds for each baby. Here we simply concatenate them. "
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": null,
174 | "metadata": {
175 | "colab": {
176 | "base_uri": "https://localhost:8080/",
177 | "height": 930
178 | },
179 | "executionInfo": {
180 | "elapsed": 24464,
181 | "status": "ok",
182 | "timestamp": 1683904924639,
183 | "user": {
184 | "displayName": "David Budaghyan",
185 | "userId": "04476631309743002593"
186 | },
187 | "user_tz": 240
188 | },
189 | "id": "bUTe9QqwRqE-",
190 | "outputId": "d4872617-95de-4aa8-db79-bb41fe879408"
191 | },
192 | "outputs": [],
193 | "source": [
194 | "# read the segments\n",
195 | "train_metadata[\"cry\"] = train_metadata.apply(\n",
196 | " lambda row: read_audio(f'{dataset_path}/{row[\"file_name\"]}').numpy(), axis=1\n",
197 | ")\n",
198 | "# concatenate all segments for each (baby_id, period) group\n",
199 | "manifest_df = pd.DataFrame(\n",
200 | " train_metadata.groupby([\"baby_id\", \"period\"])[\"cry\"].agg(lambda x: np.concatenate(x.values)),\n",
201 | " columns=[\"cry\"],\n",
202 | ").reset_index()\n",
203 | "# all files have 16000 sampling rate\n",
204 | "manifest_df[\"duration\"] = manifest_df[\"cry\"].apply(len) / 16000\n",
205 | "pathlib.Path(f\"{dataset_path}/concatenated_audio_train\").mkdir(exist_ok=True)\n",
206 | "manifest_df[\"file_path\"] = manifest_df.apply(\n",
207 | " lambda row: f\"{dataset_path}/concatenated_audio_train/{row['baby_id']}_{row['period']}.wav\",\n",
208 | " axis=1,\n",
209 | ")\n",
210 | "manifest_df.apply(\n",
211 | " lambda row: write_audio(\n",
212 | " filepath=f'{row[\"file_path\"]}', audio=torch.tensor(row[\"cry\"]), samplerate=16000\n",
213 | " ),\n",
214 | " axis=1,\n",
215 | ")\n",
216 | "manifest_df = manifest_df.drop(columns=[\"cry\"])\n",
217 | "display(manifest_df)\n",
218 | "ax = sns.histplot(manifest_df, x=\"duration\")\n",
219 | "ax.set_title(\"Histogram of Concatenated Cry Sound Lengths\")"
220 | ]
221 | },
222 | {
223 | "cell_type": "markdown",
224 | "metadata": {},
225 | "source": [
226 |     "During training, we will extract random cuts of 3-5 seconds from the concatenated audio."
227 | ]
228 | },
229 | {
230 | "cell_type": "code",
231 | "execution_count": null,
232 | "metadata": {
233 | "executionInfo": {
234 | "elapsed": 114,
235 | "status": "ok",
236 | "timestamp": 1683904924641,
237 | "user": {
238 | "displayName": "David Budaghyan",
239 | "userId": "04476631309743002593"
240 | },
241 | "user_tz": 240
242 | },
243 | "id": "8fqkEktdRRXf"
244 | },
245 | "outputs": [],
246 | "source": [
247 | "def create_cut_length_interval(row, cut_length_interval):\n",
248 | " \"\"\"cut_length_interval is a tuple indicating the range of lengths we want our chunks to be.\n",
249 | " this function computes the valid range of chunk lengths for each audio file\n",
250 | " \"\"\"\n",
251 | " # the lengths are in seconds, convert them to frames\n",
252 | " cut_length_interval = [round(length * 16000) for length in cut_length_interval]\n",
253 | " cry_length = round(row[\"duration\"] * 16000)\n",
254 | " # make the interval valid for the specific sound file\n",
255 | " min_cut_length, max_cut_length = cut_length_interval\n",
256 | " # if min_cut_length is greater than length of cry, don't cut\n",
257 | " if min_cut_length >= cry_length:\n",
258 | " cut_length_interval = (cry_length, cry_length)\n",
259 | " # if max_cut_length is greater than length of cry, take a cut of length between min_cut_length and full length of cry\n",
260 | " elif max_cut_length >= cry_length:\n",
261 | " cut_length_interval = (min_cut_length, cry_length)\n",
262 | " return cut_length_interval\n",
263 | "\n",
264 | "\n",
265 | "cut_length_interval = (3, 5)\n",
266 | "manifest_df[\"cut_length_interval_in_frames\"] = manifest_df.apply(\n",
267 | " lambda row: create_cut_length_interval(row, cut_length_interval=cut_length_interval), axis=1\n",
268 | ")"
269 | ]
270 | },
271 | {
272 | "cell_type": "markdown",
273 | "metadata": {
274 | "id": "VS6R0uA2tpAJ"
275 | },
276 | "source": [
277 | "### Split into train and val\n",
278 | "\n",
279 |     "For training a classifier, we can split the data into train/val in any way, as long as val does not contain new classes.\n",
280 | "\n",
281 |     "One way is to split by period: train on birth recordings and validate on discharge recordings."
282 | ]
283 | },
284 | {
285 | "cell_type": "code",
286 | "execution_count": null,
287 | "metadata": {
288 | "colab": {
289 | "base_uri": "https://localhost:8080/",
290 | "height": 493
291 | },
292 | "executionInfo": {
293 | "elapsed": 115,
294 | "status": "ok",
295 | "timestamp": 1683904924645,
296 | "user": {
297 | "displayName": "David Budaghyan",
298 | "userId": "04476631309743002593"
299 | },
300 | "user_tz": 240
301 | },
302 | "id": "KceQZp_pga34",
303 | "outputId": "a1eafd79-d192-4f7f-f447-36cb05b33d87"
304 | },
305 | "outputs": [],
306 | "source": [
307 | "# we can train on any subset of babies (e.g. to reduce the number of classes, only keep babies with long enough cries, etc)\n",
308 | "def get_babies_with_both_recordings(manifest_df):\n",
309 | " count_of_periods_per_baby = manifest_df.groupby(\"baby_id\")[\"period\"].count()\n",
310 | " baby_ids_with_recording_from_both_periods = count_of_periods_per_baby[\n",
311 | " count_of_periods_per_baby == 2\n",
312 | " ].index\n",
313 | " return baby_ids_with_recording_from_both_periods\n",
314 | "\n",
315 | "\n",
316 | "# def get_babies_with_a_birth_recording(manifest_df):\n",
317 | "# bool_series = manifest_df.groupby('baby_id')['period'].unique().apply(set(['B']).issubset)\n",
318 | "# baby_ids_with_a_recordings_from_birth = bool_series[bool_series].index\n",
319 | "# return baby_ids_with_a_recordings_from_birth\n",
320 | "\n",
321 | "\n",
322 | "def split_by_period(row, included_baby_ids):\n",
323 | " if row[\"baby_id\"] in included_baby_ids:\n",
324 | " if row[\"period\"] == \"B\":\n",
325 | " return \"train\"\n",
326 | " else:\n",
327 | " return \"val\"\n",
328 | " else:\n",
329 | " return \"not_used\"\n",
330 | "\n",
331 | "\n",
332 | "babies_with_both_recordings = get_babies_with_both_recordings(manifest_df)\n",
333 | "manifest_df[\"split\"] = manifest_df.apply(\n",
334 | " lambda row: split_by_period(row, included_baby_ids=babies_with_both_recordings), axis=1\n",
335 | ")\n",
336 | "\n",
337 | "# each instance will be identified with a unique id\n",
338 | "manifest_df[\"id\"] = manifest_df[\"baby_id\"] + \"_\" + manifest_df[\"period\"]\n",
339 | "display(manifest_df)\n",
340 | "display(\n",
341 | " manifest_df[\"split\"]\n",
342 | " .value_counts()\n",
343 | " .rename(\"use_babies_with_both_recordings_and_split_by_period\")\n",
344 | ")\n",
345 | "manifest_df.set_index(\"id\").to_json(\"manifest.json\", orient=\"index\")"
346 | ]
347 | },
348 | {
349 | "cell_type": "markdown",
350 | "metadata": {
351 | "id": "N2X-8Zs2-Mhm"
352 | },
353 | "source": [
354 | "### Create dynamic datasets\n",
355 | "\n",
356 |     "See the SpeechBrain documentation for details."
357 | ]
358 | },
359 | {
360 | "cell_type": "code",
361 | "execution_count": null,
362 | "metadata": {
363 | "colab": {
364 | "base_uri": "https://localhost:8080/"
365 | },
366 | "executionInfo": {
367 | "elapsed": 107,
368 | "status": "ok",
369 | "timestamp": 1683904924648,
370 | "user": {
371 | "displayName": "David Budaghyan",
372 | "userId": "04476631309743002593"
373 | },
374 | "user_tz": 240
375 | },
376 | "id": "1NxmeLy_fj99",
377 | "outputId": "6ca4659d-4083-4a79-ccdd-241d6501a7df"
378 | },
379 | "outputs": [],
380 | "source": [
381 | "# create a dynamic dataset from the csv, only used to create train and val datasets\n",
382 | "dataset = DynamicItemDataset.from_json(\"manifest.json\")\n",
383 | "baby_id_encoder = CategoricalEncoder()\n",
384 | "datasets = {}\n",
385 | "# create a dataset for each split\n",
386 | "for split in [\"train\", \"val\"]:\n",
387 | " # retrieve the desired slice (train or val) and sort by length to minimize amount of padding\n",
388 | " datasets[split] = dataset.filtered_sorted(\n",
389 | " key_test={\"split\": lambda value: value == split}, sort_key=\"duration\"\n",
390 | " ) # select_n=100\n",
391 | " # create the baby_id_encoded field\n",
392 | " datasets[split].add_dynamic_item(\n",
393 | " baby_id_encoder.encode_label_torch, takes=\"baby_id\", provides=\"baby_id_encoded\"\n",
394 | " )\n",
395 | " # set visible fields\n",
396 | " datasets[split].set_output_keys([\"id\", \"baby_id\", \"baby_id_encoded\", \"sig\"])\n",
397 | "\n",
398 | "\n",
399 | "# create the signal field for the val split (no chunking)\n",
400 | "datasets[\"val\"].add_dynamic_item(sb.dataio.dataio.read_audio, takes=\"file_path\", provides=\"sig\")\n",
401 | "\n",
402 | "# the label encoder will map the baby_ids to target classes 0, 1, 2, ...\n",
403 | "# only use the classes which appear in `train`,\n",
404 | "baby_id_encoder.update_from_didataset(datasets[\"train\"], \"baby_id\")\n",
405 | "\n",
406 | "\n",
407 | "# for reading the train split, we add chunking\n",
408 | "def audio_pipeline(file_path, cut_length_interval_in_frames):\n",
409 |     "    \"\"\"Load the signal and, if a cut interval is given, extract a random chunk.\n",
410 |     "    This is done on the CPU in the `collate_fn`.\"\"\"\n",
411 | " sig = sb.dataio.dataio.read_audio(file_path)\n",
412 | " if cut_length_interval_in_frames is not None:\n",
413 | " cut_length = random.randint(*cut_length_interval_in_frames)\n",
414 | " # pick the start index of the cut\n",
415 | " left_index = random.randint(0, len(sig) - cut_length)\n",
416 | " # cut the signal\n",
417 | " sig = sig[left_index : left_index + cut_length]\n",
418 | " return sig\n",
419 | "\n",
420 | "\n",
421 | "# create the signal field (with chunking)\n",
422 | "datasets[\"train\"].add_dynamic_item(\n",
423 | " audio_pipeline, takes=[\"file_path\", \"cut_length_interval_in_frames\"], provides=\"sig\"\n",
424 | ")\n",
425 | "\n",
426 | "print(datasets[\"train\"][0])"
427 | ]
428 | },
429 | {
430 | "cell_type": "markdown",
431 | "metadata": {
432 | "id": "bbwH78P8_TOd"
433 | },
434 | "source": [
435 | "### Fine-tune the classifier\n",
436 | "\n",
437 | "Here we use a very basic example that just trains for 5 epochs"
438 | ]
439 | },
440 | {
441 | "cell_type": "code",
442 | "execution_count": null,
443 | "metadata": {
444 | "colab": {
445 | "base_uri": "https://localhost:8080/"
446 | },
447 | "executionInfo": {
448 | "elapsed": 94,
449 | "status": "ok",
450 | "timestamp": 1683904924651,
451 | "user": {
452 | "displayName": "David Budaghyan",
453 | "userId": "04476631309743002593"
454 | },
455 | "user_tz": 240
456 | },
457 | "id": "ixojL5uH5y1V",
458 | "outputId": "de735a57-42da-41ab-8860-1665874edf15"
459 | },
460 | "outputs": [],
461 | "source": [
462 | "config_filename = \"hparams/ecapa_voxceleb_basic.yaml\"\n",
463 | "overrides = {\n",
464 | " \"seed\": 3011,\n",
465 | " \"n_classes\": len(baby_id_encoder),\n",
466 | " \"experiment_name\": \"ecapa_voxceleb_ft_basic\",\n",
467 | " \"bs\": 32,\n",
468 | " \"n_epochs\": 5,\n",
469 | "}\n",
470 | "device = \"cuda\"\n",
471 | "run_opts = {\"device\": device}\n",
472 | "###########################################\n",
473 | "# Load hyperparameters file with command-line overrides.\n",
474 | "with open(config_filename) as fin:\n",
475 | " hparams = load_hyperpyyaml(fin, overrides)\n",
476 | "# Create experiment directory\n",
477 | "sb.create_experiment_directory(\n",
478 | " experiment_directory=hparams[\"experiment_dir\"],\n",
479 | " hyperparams_to_save=config_filename,\n",
480 | " overrides=overrides,\n",
481 | ")\n",
482 | "\n",
483 | "# Initialize the Brain object to prepare for training.\n",
484 | "crybrain = CryBrain(\n",
485 | " modules=hparams[\"modules\"],\n",
486 | " opt_class=hparams[\"opt_class\"],\n",
487 | " hparams=hparams,\n",
488 | " run_opts=run_opts,\n",
489 | " checkpointer=hparams[\"checkpointer\"],\n",
490 | ")\n",
491 | "\n",
492 | "# if a pretrained model is specified, load it\n",
493 | "if \"pretrained_embedding_model\" in hparams:\n",
494 | " sb.utils.distributed.run_on_main(hparams[\"pretrained_embedding_model\"].collect_files)\n",
495 | " hparams[\"pretrained_embedding_model\"].load_collected(device=device)\n",
496 | "\n",
497 | "crybrain.fit(\n",
498 | " epoch_counter=crybrain.hparams.epoch_counter,\n",
499 | " train_set=datasets[\"train\"],\n",
500 | " valid_set=datasets[\"val\"],\n",
501 | " train_loader_kwargs=hparams[\"train_dataloader_options\"],\n",
502 | " valid_loader_kwargs=hparams[\"val_dataloader_options\"],\n",
503 | ")"
504 | ]
505 | },
506 | {
507 | "cell_type": "markdown",
508 | "metadata": {
509 | "id": "cVDparawc1YF"
510 | },
511 | "source": [
512 |     "You can now take the embedding_model.ckpt produced by this recipe and use it in evaluate.ipynb to verify pairs of cries and submit your results!"
513 | ]
514 | }
515 | ],
516 | "metadata": {
517 | "accelerator": "GPU",
518 | "colab": {
519 | "authorship_tag": "ABX9TyN+3g6dMDslVfuhfJgsJIst",
520 | "gpuType": "T4",
521 | "provenance": [],
522 | "toc_visible": true
523 | },
524 | "gpuClass": "standard",
525 | "kernelspec": {
526 | "display_name": "Python 3 (ipykernel)",
527 | "language": "python",
528 | "name": "python3"
529 | },
530 | "language_info": {
531 | "codemirror_mode": {
532 | "name": "ipython",
533 | "version": 3
534 | },
535 | "file_extension": ".py",
536 | "mimetype": "text/x-python",
537 | "name": "python",
538 | "nbconvert_exporter": "python",
539 | "pygments_lexer": "ipython3",
540 | "version": "3.8.16"
541 | }
542 | },
543 | "nbformat": 4,
544 | "nbformat_minor": 1
545 | }
546 |
--------------------------------------------------------------------------------
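After training, the checkpointer keeps the four best checkpoints (ranked by validation accuracy) under `experiments/<experiment_name>/ckpts/`. A sketch for locating the best `embedding_model.ckpt`; the experiment name is whatever you set in the overrides:

```python
# Sketch: checkpoint directory names encode validation accuracy
# (see CryBrain.on_stage_end); the experiment name here is illustrative.
from pathlib import Path

ckpt_root = Path("experiments/ecapa_voxceleb_ft_basic/ckpts")
best = max(
    ckpt_root.glob("CKPT+epoch-*_valacc-*"),
    key=lambda p: float(p.name.split("valacc-")[-1]),
)
print(best / "embedding_model.ckpt")  # plug this into evaluate.ipynb
```
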
/evaluate.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "7gd-5ANmkKSu"
7 | },
8 | "source": [
9 | "# Evaluation notebook for [CryCeleb2023 challenge](https://huggingface.co/spaces/competitions/CryCeleb2023)\n",
10 | "\n",
11 | "## This notebook does the following:\n",
12 | "- Download the Cryceleb data from Hugging Face.\n",
13 | "- Download a pretrained SpeechBrain model from Hugging Face.\n",
14 | "- Compute embeddings.\n",
15 | "- Compute similarity scores for pairs of embeddings.\n",
16 | "- Compute the equal error rate of the scores and visualize results.\n",
17 |     "- Produce my_solution.csv, which can be uploaded to the competition platform."
18 | ]
19 | },
20 | {
21 | "cell_type": "markdown",
22 | "metadata": {
23 | "id": "fwKCGXetl_Jd"
24 | },
25 | "source": [
26 | "### Imports"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": null,
32 | "metadata": {},
33 | "outputs": [],
34 | "source": [
35 | "# For Colab - uncomment and run the following to set up the repo\n",
36 | "# !pip install speechbrain\n",
37 | "# !git clone https://github.com/Ubenwa/cryceleb2023.git\n",
38 | "# %cd cryceleb2023"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": null,
44 | "metadata": {
45 | "id": "ZZqSCpv_lUIa"
46 | },
47 | "outputs": [],
48 | "source": [
49 | "%%capture\n",
50 | "\n",
51 | "import matplotlib.pyplot as plt\n",
52 | "import numpy as np\n",
53 | "import pandas as pd\n",
54 | "import seaborn as sns\n",
55 | "import speechbrain as sb\n",
56 | "import torch\n",
57 | "from huggingface_hub import hf_hub_download\n",
58 | "from IPython.display import display\n",
59 | "from speechbrain.dataio.dataio import read_audio\n",
60 | "from speechbrain.pretrained import EncoderClassifier, SpeakerRecognition\n",
61 | "from speechbrain.utils.metric_stats import EER\n",
62 | "from tqdm.notebook import tqdm\n",
63 | "\n",
64 | "from crybrain import download_data\n",
65 | "\n",
66 | "dataset_path = \"data\""
67 | ]
68 | },
69 | {
70 | "cell_type": "markdown",
71 | "metadata": {
72 | "id": "AUNVrz16mNZH"
73 | },
74 | "source": [
75 | "### Data"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": null,
81 | "metadata": {},
82 | "outputs": [],
83 | "source": [
84 | "from huggingface_hub import notebook_login\n",
85 | "\n",
86 | "notebook_login()"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": null,
92 | "metadata": {
93 | "colab": {
94 | "base_uri": "https://localhost:8080/",
95 | "height": 177,
96 | "referenced_widgets": [
97 | "48a9413b4a9f44ab9b133f75757ba1b3",
98 | "bde64b1a9a2a44a68ce376638850faf1",
99 | "c688ebb7d042446db9688c0a8c5686cb",
100 | "31f9d67b4f5d41fc8761bc1f7bcc3a88",
101 | "6faf6b7594064b78b38122e45c69229c",
102 | "8a390c2647714f06982ede802e708857",
103 | "87c274b22fc7460188a51d8d5689ec04",
104 | "69e01654e825416783101f44f800ba37",
105 | "85c75c5f031844f58f61c2217819201c",
106 | "56ac688d29664728a5921d5b6132da75",
107 | "aae703425f7245a8a54dd95870e98345",
108 | "d25726ee8d404d30b151aa04950aec53",
109 | "ee675ef5bdcc4b01b5f5d63737ef4e4a",
110 | "01f99069fb0648639f57a2fcc1625302",
111 | "688abf05d24644e385cee5294d63eaff",
112 | "adf5285fdc5e47cd8bd89acec431c519",
113 | "afe24b4ae3df482f90633ffd2600254c",
114 | "a445688a0d254b0abb42951544c5de29",
115 | "4cae57f7658c46168272be6ab8ed6361",
116 | "e3a7af1279464532b725177f0a0d50c2",
117 | "493f3c83c220430ab95687b724a7ce36",
118 | "3ccf5f36399f4dd9b5a1b39034147011",
119 | "92a1d2e7cfba4139bd30f33c9952ad0b",
120 | "446adc6e4b49402295141543a06d7022",
121 | "0e96610caf1747dc99d14ba7d4717d1e",
122 | "608197a5662e440db4309350a2f19148",
123 | "47b28b83da3048a0880294d898ed25a8",
124 | "38a29275901949afae2e4bf373463c19",
125 | "2976157384ca4771ac36d66d54f4eebc",
126 | "9c1a48b5361d4b4bb2a33fe44f59b4a9",
127 | "bb807e9c793d415f8e0992856b2bda96",
128 | "7b4a03db3e0d4bd5906e20333ce5358f",
129 | "3efc045306d14584a02a06cfa49f8d69",
130 | "3a07bc2787cc43a598d0a2b189e5fc36",
131 | "cb23d7bbcb014798964fde1f90b76496",
132 | "0c76014da6b44182a2c3561c473e1656",
133 | "d49d39ccb6b0439497375a9f37e961cb",
134 | "2885010222bc494e903b0ef93b8d27b2",
135 | "052b0a77e57a45d0be587a7564e18dd9",
136 | "42eadb95424e44399a1916768ae31002",
137 | "1785f38eb7494ba3a28870cc4792363f",
138 | "036f560d958f4e4a8d2367e4a1f5ab2f",
139 | "4a271e56fb17495c8efdf47c0f5b1582",
140 | "a227c472debb467bb952ef172afbda37",
141 | "74575ce9235f449c8cdab90ff650ad0e",
142 | "a8e3cdce46304042882a1121d075668e",
143 | "3ddf4f47a405499e93634ae846855d48",
144 | "351363b8f8d040b19745f5201db56809",
145 | "6ce593721e034de682cef0645cfefc31",
146 | "5d55ed42d44b417fa9aa8ea2bc5128f5",
147 | "39e121daa40b43b280d7bae156847071",
148 | "afb60ef0d33a4a7dba0396602d1c014e",
149 | "36722281b9f243fab567adffc6c6dc3b",
150 | "d74ccc6ec2f24ef291d98cadab769bcc",
151 | "afd5310507ac4800b254af99d9b89ed0"
152 | ]
153 | },
154 | "id": "zqXn1mT8nRwp",
155 | "outputId": "59280321-e953-44c9-9658-ebc32f9a7941"
156 | },
157 | "outputs": [],
158 | "source": [
159 | "download_data(dataset_path)"
160 | ]
161 | },
162 | {
163 | "cell_type": "code",
164 | "execution_count": null,
165 | "metadata": {
166 | "colab": {
167 | "base_uri": "https://localhost:8080/",
168 | "height": 871
169 | },
170 | "id": "CenUmoY_mMqw",
171 | "outputId": "1bb9f7b1-fae0-4f58-8a6a-c50c8780379a"
172 | },
173 | "outputs": [],
174 | "source": [
175 | "# read metadata\n",
176 | "metadata = pd.read_csv(\n",
177 | " f\"{dataset_path}/metadata.csv\", dtype={\"baby_id\": str, \"chronological_index\": str}\n",
178 | ")\n",
179 | "dev_metadata = metadata.loc[metadata[\"split\"] == \"dev\"].copy()\n",
180 | "# read sample submission\n",
181 | "sample_submission = pd.read_csv(\n",
182 | " f\"{dataset_path}/sample_submission.csv\"\n",
183 |     ")  # scores are uniform random\n",
184 | "# read verification pairs\n",
185 | "dev_pairs = pd.read_csv(\n",
186 | " f\"{dataset_path}/dev_pairs.csv\", dtype={\"baby_id_B\": str, \"baby_id_D\": str}\n",
187 | ")\n",
188 | "test_pairs = pd.read_csv(f\"{dataset_path}/test_pairs.csv\")\n",
189 | "\n",
190 | "display(\n",
191 | " metadata.head()\n",
192 | " .style.set_caption(\"metadata\")\n",
193 | " .set_table_styles([{\"selector\": \"caption\", \"props\": [(\"font-size\", \"20px\")]}])\n",
194 | ")\n",
195 | "display(\n",
196 | " dev_pairs.head()\n",
197 | " .style.set_caption(\"dev_pairs\")\n",
198 | " .set_table_styles([{\"selector\": \"caption\", \"props\": [(\"font-size\", \"20px\")]}])\n",
199 | ")\n",
200 | "display(\n",
201 | " test_pairs.head()\n",
202 | " .style.set_caption(\"test_pairs\")\n",
203 | " .set_table_styles([{\"selector\": \"caption\", \"props\": [(\"font-size\", \"20px\")]}])\n",
204 | ")\n",
205 | "display(\n",
206 | " sample_submission.head()\n",
207 | " .style.set_caption(\"sample_submission\")\n",
208 | " .set_table_styles([{\"selector\": \"caption\", \"props\": [(\"font-size\", \"20px\")]}])\n",
209 | ")"
210 | ]
211 | },
212 | {
213 | "cell_type": "markdown",
214 | "metadata": {
215 | "id": "i7qn0lFdmOlF"
216 | },
217 | "source": [
218 | "### Initialize encoder"
219 | ]
220 | },
221 | {
222 | "cell_type": "markdown",
223 | "metadata": {
224 | "id": "Rtgd7qlfmUWC"
225 | },
226 | "source": [
227 |     "One way to verify whether the two recordings in a pair come from the same baby is to concatenate all the segments of each recording, compute an embedding for each concatenated cry, and measure the cosine similarity between the two embeddings.\n",
228 | "\n",
229 | "Let's load the model"
230 | ]
231 | },
232 | {
233 | "cell_type": "code",
234 | "execution_count": null,
235 | "metadata": {
236 | "colab": {
237 | "base_uri": "https://localhost:8080/",
238 | "height": 81,
239 | "referenced_widgets": [
240 | "a5a6c8050ce444fe9e5e3835af2c89f7",
241 | "ffe1c871271b4eeb8566f8c3e9535679",
242 | "fe22550070ad45c4a23564099beafadd",
243 | "8657df05011a4f5f8cf0492aed78e819",
244 | "7d7bf74077214e65b3727346d8e4485b",
245 | "d31bd375beac4b308abc9751e88d6762",
246 | "7359bf51f90846999bc7e6fedf1ab4f1",
247 | "635dd9fb9e844c99b60812acd3b8cefc",
248 | "9bcebcb7cb224ff9ba92e3693adcf04e",
249 | "4a8dffe277524f48aa5e2cdb76d22926",
250 | "dab9e9c96c344b78abf24353bbbeebb9",
251 | "785545f1963f4f6d849fb562f31a3c0d",
252 | "2c7f0724170f4e0d9d74677aad6e18a3",
253 | "f876ee90477b41e3972374a5f874647c",
254 | "da7ba01d01c84602a5066475ddba8195",
255 | "cd5315c7c4ed475f837285339fca2fcc",
256 | "0adc32cfd65441f4becabf682041611f",
257 | "d9c373060ef242cba7b669582645ef69",
258 | "732fe933fcb14a1f8b32988dc5ab41a3",
259 | "2b4b55f6867443e0b10ae8caccef9fc2",
260 | "4c950b1378d34aeea6d87f53f6e5484e",
261 | "9760ab6f43f94981986f37a79cf0bdd7"
262 | ]
263 | },
264 | "id": "CsE0Z6JCmSBv",
265 | "outputId": "6c22b5b6-a996-4312-b33e-78e89de6cc33"
266 | },
267 | "outputs": [],
268 | "source": [
269 | "!rm -rf spkrec-ecapa-voxceleb\n",
270 | "encoder = SpeakerRecognition.from_hparams(\n",
271 | " source=\"speechbrain/spkrec-ecapa-voxceleb\",\n",
272 | " savedir=\"spkrec-ecapa-voxceleb\",\n",
273 | " run_opts={\"device\": \"cuda\"}, # comment out if no GPU available\n",
274 | ")"
275 | ]
276 | },
277 | {
278 | "cell_type": "code",
279 | "execution_count": null,
280 | "metadata": {},
281 | "outputs": [],
282 | "source": [
283 | "# you can also plug in your encoder weights if you fine-tuned this model locally\n",
284 | "# !rm spkrec-ecapa-voxceleb/embedding_model.ckpt\n",
285 | "# !cp experiments/ecapa_voxceleb_ft_basic/ckpts/CKPT+epoch-4_valacc-0.57/embedding_model.ckpt spkrec-ecapa-voxceleb\n",
286 | "\n",
287 | "# encoder = SpeakerRecognition.from_hparams(\n",
288 | "# source=\"speechbrain/spkrec-ecapa-voxceleb\",\n",
289 | "# savedir=\"spkrec-ecapa-voxceleb\",\n",
290 | "# run_opts={\"device\": \"cuda\"}, # comment out if no GPU available\n",
291 | "# )"
292 | ]
293 | },
294 | {
295 | "cell_type": "markdown",
296 | "metadata": {
297 | "id": "vUP0tjNImYT4"
298 | },
299 | "source": [
300 | "#### Compute Encodings"
301 | ]
302 | },
303 | {
304 | "cell_type": "markdown",
305 | "metadata": {
306 | "id": "mFTLJ74s6I3_"
307 | },
308 | "source": [
309 | "Change runtime type to GPU if using Colab"
310 | ]
311 | },
312 | {
313 | "cell_type": "code",
314 | "execution_count": null,
315 | "metadata": {
316 | "colab": {
317 | "base_uri": "https://localhost:8080/",
318 | "height": 159,
319 | "referenced_widgets": [
320 | "780d053589614f1981ce9c3bc1bf2633",
321 | "d7869f0445404dfa98a36832af202d75",
322 | "36248ef425d84a31a5f2b4bf0935d1ce",
323 | "dcf4695b9f694a55a9fa0691f7c0b236",
324 | "bbb7c274e05c4feba27a1968b0566621",
325 | "790b64fac2d84d10b2fb6a1d1ba30b7f",
326 | "9d7d86758d46432085a7f1eeef6ee589",
327 | "75ef1cc7d24b47fe8dbd1b3e9b8c974d",
328 | "e9b6cee88af34f968c4d09d232b62120",
329 | "366a1031361647ec8aad6b599a2746a1",
330 | "a1ff4f5afe694b4b86e1aa80fa0169e5"
331 | ]
332 | },
333 | "id": "rgMZ8rrBmWc3",
334 | "outputId": "7894ce23-2972-4554-ccd1-d84f52230bdd"
335 | },
336 | "outputs": [],
337 | "source": [
338 | "%%time\n",
339 | "# read the segments\n",
340 | "dev_metadata[\"cry\"] = dev_metadata.apply(\n",
341 | " lambda row: read_audio(f'{dataset_path}/{row[\"file_name\"]}').numpy(), axis=1\n",
342 | ")\n",
343 | "# concatenate all segments for each (baby_id, period) group\n",
344 | "cry_dict = pd.DataFrame(\n",
345 | " dev_metadata.groupby([\"baby_id\", \"period\"])[\"cry\"].agg(lambda x: np.concatenate(x.values)),\n",
346 | " columns=[\"cry\"],\n",
347 | ").to_dict(orient=\"index\")\n",
348 | "# encode the concatenated cries\n",
349 | "for (baby_id, period), d in tqdm(cry_dict.items()):\n",
350 | " d[\"cry_encoded\"] = encoder.encode_batch(torch.tensor(d[\"cry\"]), normalize=False)"
351 | ]
352 | },
353 | {
354 | "cell_type": "markdown",
355 | "metadata": {
356 | "id": "PF4Sa3BnmcLA"
357 | },
358 | "source": [
359 | "#### Compute Similarity Between Encodings"
360 | ]
361 | },
362 | {
363 | "cell_type": "code",
364 | "execution_count": null,
365 | "metadata": {
366 | "colab": {
367 | "base_uri": "https://localhost:8080/",
368 | "height": 206
369 | },
370 | "id": "lgaScJgImcaO",
371 | "outputId": "691c5120-65d9-4eae-8d4f-c8d79a502254"
372 | },
373 | "outputs": [],
374 | "source": [
375 | "def compute_cosine_similarity_score(row, cry_dict):\n",
376 | " cos = torch.nn.CosineSimilarity(dim=-1)\n",
377 | " similarity_score = cos(\n",
378 | " cry_dict[(row[\"baby_id_B\"], \"B\")][\"cry_encoded\"],\n",
379 | " cry_dict[(row[\"baby_id_D\"], \"D\")][\"cry_encoded\"],\n",
380 | " )\n",
381 | " return similarity_score.item()\n",
382 | "\n",
383 | "\n",
384 | "dev_pairs[\"score\"] = dev_pairs.apply(\n",
385 | " lambda row: compute_cosine_similarity_score(row=row, cry_dict=cry_dict), axis=1\n",
386 | ")\n",
387 | "display(dev_pairs.head())"
388 | ]
389 | },
390 | {
391 | "cell_type": "code",
392 | "execution_count": null,
393 | "metadata": {
394 | "colab": {
395 | "base_uri": "https://localhost:8080/",
396 | "height": 472
397 | },
398 | "id": "3tBfFZ1OmeaG",
399 | "outputId": "9eedd68d-69f2-44da-d156-6db7acce7e4d"
400 | },
401 | "outputs": [],
402 | "source": [
403 | "def compute_eer_and_plot_verification_scores(pairs_df):\n",
404 | " \"\"\"pairs_df must have 'score' and 'label' columns\"\"\"\n",
405 | " positive_scores = pairs_df.loc[pairs_df[\"label\"] == 1][\"score\"].values\n",
406 | " negative_scores = pairs_df.loc[pairs_df[\"label\"] == 0][\"score\"].values\n",
407 | " eer, threshold = EER(torch.tensor(positive_scores), torch.tensor(negative_scores))\n",
408 | " ax = sns.histplot(pairs_df, x=\"score\", hue=\"label\", stat=\"percent\", common_norm=False)\n",
409 | " ax.set_title(f\"EER={round(eer, 4)} - Thresh={round(threshold, 4)}\")\n",
410 | " plt.axvline(x=[threshold], color=\"red\", ls=\"--\")\n",
411 | " return eer, threshold\n",
412 | "\n",
413 | "\n",
414 | "eer, threshold = compute_eer_and_plot_verification_scores(pairs_df=dev_pairs)"
415 | ]
416 | },
417 | {
418 | "cell_type": "markdown",
419 | "metadata": {
420 | "id": "LEjkMjYN17rf"
421 | },
422 | "source": [
423 | "The above plot displays the histogram of scores for +ive (same baby) and -ive (different baby) dev_pairs.\n",
424 | "\n",
425 | "A perfect verifier would attribute a higher score to all +ive pairs than any -ive pair.\\\n",
426 | "Your task is to come up with a scoring system which maximizes the separation between the two distributions, as measured by the EER.\\\n",
427 | "You can change the encoder module, the aggregation of cry segments, the similarity metric, or come up with a completely different process! \\\n",
428 | "\n",
429 | "\n"
430 | ]
431 | },
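432 | {
433 | "cell_type": "markdown",
434 | "metadata": {},
435 | "source": [
436 | "For instance, instead of encoding the concatenated audio, you could embed each segment separately and average the embeddings per `(baby_id, period)`. The cell below is a minimal sketch of that variation (it is not the baseline method) and assumes `dev_metadata` (with the `cry` column computed above), `encoder`, `dev_pairs`, and the helper functions defined earlier are in scope."
437 | ]
438 | },
439 | {
440 | "cell_type": "code",
441 | "execution_count": null,
442 | "metadata": {},
443 | "outputs": [],
444 | "source": [
445 | "# sketch: average per-segment embeddings instead of encoding concatenated audio\n",
446 | "cry_dict_avg = {}\n",
447 | "for (baby_id, period), group in dev_metadata.groupby([\"baby_id\", \"period\"]):\n",
448 | "    segment_embeddings = [\n",
449 | "        encoder.encode_batch(torch.tensor(cry), normalize=False) for cry in group[\"cry\"]\n",
450 | "    ]\n",
451 | "    cry_dict_avg[(baby_id, period)] = {\n",
452 | "        \"cry_encoded\": torch.stack(segment_embeddings).mean(dim=0)\n",
453 | "    }\n",
454 | "\n",
455 | "# rescore the dev pairs and compare the EER to the concatenation baseline\n",
456 | "dev_pairs_avg = dev_pairs.copy()\n",
457 | "dev_pairs_avg[\"score\"] = dev_pairs_avg.apply(\n",
458 | "    lambda row: compute_cosine_similarity_score(row=row, cry_dict=cry_dict_avg), axis=1\n",
459 | ")\n",
460 | "eer_avg, threshold_avg = compute_eer_and_plot_verification_scores(pairs_df=dev_pairs_avg)"
461 | ]
462 | },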
432 | {
433 | "cell_type": "code",
434 | "execution_count": null,
435 | "metadata": {},
436 | "outputs": [],
437 | "source": [
438 | "# same for test set that was hidden during evaluation\n",
439 | "eer, threshold = compute_eer_and_plot_verification_scores(pairs_df=test_pairs)"
440 | ]
441 | },
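442 | {
443 | "cell_type": "markdown",
444 | "metadata": {},
445 | "source": [
446 | "For reference, the EER is the operating point at which the false-positive rate equals the false-negative rate. The cell below is an illustrative recomputation on the scored `dev_pairs`; it assumes `scikit-learn` is installed, which is not listed in `requirements.txt`."
447 | ]
448 | },
449 | {
450 | "cell_type": "code",
451 | "execution_count": null,
452 | "metadata": {},
453 | "outputs": [],
454 | "source": [
455 | "# illustrative only: approximate the EER from an ROC curve\n",
456 | "# (assumes scikit-learn is installed and dev_pairs was scored above)\n",
457 | "from sklearn.metrics import roc_curve\n",
458 | "\n",
459 | "fpr, tpr, thresholds = roc_curve(dev_pairs[\"label\"], dev_pairs[\"score\"])\n",
460 | "fnr = 1 - tpr\n",
461 | "idx = np.nanargmin(np.abs(fnr - fpr))  # point where FPR and FNR cross\n",
462 | "print(f\"approx. EER={(fpr[idx] + fnr[idx]) / 2:.4f} at threshold={thresholds[idx]:.4f}\")"
463 | ]
464 | },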
442 | {
443 | "cell_type": "markdown",
444 | "metadata": {
445 | "id": "yINubTKWImh1"
446 | },
447 | "source": [
448 | "You can also create example submission file for the challenge using code below. \n",
449 | "\n",
450 | "It is no more relevant as we have access to eval labels now and do scoring above."
451 | ]
452 | },
453 | {
454 | "cell_type": "code",
455 | "execution_count": null,
456 | "metadata": {
457 | "colab": {
458 | "base_uri": "https://localhost:8080/",
459 | "height": 274,
460 | "referenced_widgets": [
461 | "fe70b8c96fdc4f5695438719d0a538f0",
462 | "60d9353fa67d47b08d6f79fe9edf14ff",
463 | "6d6f9ae639ea4b39943fa5abc84f2aaf",
464 | "10fb5219f51c455bb1039fe19513175d",
465 | "fe030cd60b6a49f9a5e22c53df2c1fbb",
466 | "7203a181918148d1a5e0fadd8d33344d",
467 | "751cbbff056f476aba01b860f14a1b52",
468 | "39c3334153a74fb8b9c470e55dac6e4c",
469 | "c4077ea694104e908743bd647fb228b7",
470 | "31f214f5206f4939a9e7eafff0fdb63c",
471 | "e67f032a44a1490e9eefa81886a86705"
472 | ]
473 | },
474 | "id": "q7LrpPFVIGir",
475 | "outputId": "43ce30ba-debd-49a0-f0e8-d6b82c7653b7"
476 | },
477 | "outputs": [],
478 | "source": [
479 | "%%time\n",
480 | "test_metadata = metadata.loc[metadata[\"split\"] == \"test\"].copy()\n",
481 | "# read the segments\n",
482 | "test_metadata[\"cry\"] = test_metadata.apply(\n",
483 | " lambda row: read_audio(f'{dataset_path}/{row[\"file_name\"]}').numpy(), axis=1\n",
484 | ")\n",
485 | "# concatenate all segments for each (baby_id, period) group\n",
486 | "cry_dict_test = pd.DataFrame(\n",
487 | " test_metadata.groupby([\"baby_id\", \"period\"])[\"cry\"].agg(lambda x: np.concatenate(x.values)),\n",
488 | " columns=[\"cry\"],\n",
489 | ").to_dict(orient=\"index\")\n",
490 | "# encode the concatenated cries\n",
491 | "for (baby_id, period), d in tqdm(cry_dict_test.items()):\n",
492 | " d[\"cry_encoded\"] = encoder.encode_batch(torch.tensor(d[\"cry\"]), normalize=False)\n",
493 | "\n",
494 | "# compute cosine similarity between all pairs\n",
495 | "test_pairs[\"score\"] = test_pairs.apply(\n",
496 | " lambda row: compute_cosine_similarity_score(row=row, cry_dict=cry_dict_test), axis=1\n",
497 | ")\n",
498 | "display(test_pairs.head())"
499 | ]
500 | },
501 | {
502 | "cell_type": "code",
503 | "execution_count": null,
504 | "metadata": {
505 | "colab": {
506 | "base_uri": "https://localhost:8080/",
507 | "height": 206
508 | },
509 | "id": "4ZYkMxOrKRQ3",
510 | "outputId": "c58ed78b-1295-4209-c474-384e421cc779"
511 | },
512 | "outputs": [],
513 | "source": [
514 | "# submission must match the 'sample_submission.csv' format exactly\n",
515 | "my_submission = test_pairs[[\"id\", \"score\"]]\n",
516 | "my_submission.to_csv(\"my_submission.csv\", index=False)\n",
517 | "display(my_submission.head())"
518 | ]
519 | },
520 | {
521 | "cell_type": "markdown",
522 | "metadata": {
523 | "id": "S3dgbBwUfpzU"
524 | },
525 | "source": [
526 | "You can now download `my_submission.csv` and submit it to the challenge!"
527 | ]
528 | }
529 | ],
530 | "metadata": {
531 | "accelerator": "GPU",
532 | "colab": {
533 | "provenance": []
534 | },
535 | "gpuClass": "standard",
536 | "kernelspec": {
537 | "display_name": "Python 3 (ipykernel)",
538 | "language": "python",
539 | "name": "python3"
540 | },
541 | "language_info": {
542 | "codemirror_mode": {
543 | "name": "ipython",
544 | "version": 3
545 | },
546 | "file_extension": ".py",
547 | "mimetype": "text/x-python",
548 | "name": "python",
549 | "nbconvert_exporter": "python",
550 | "pygments_lexer": "ipython3",
551 | "version": "3.11.5"
552 | }
553 | },
554 | "nbformat": 4,
555 | "nbformat_minor": 1
556 | }
557 |
--------------------------------------------------------------------------------