├── .gitignore ├── LICENSE ├── README.md ├── audio_segment.py ├── emovdb_prepare_concat_dataset.py ├── emovdb_prepare_concat_gender_dataset.py ├── emovdb_prepare_concat_ib_dataset.py ├── iemocap_prepare_concat_dataset.py ├── iemocap_prepare_concat_gender_dataset.py ├── install_dep.sh ├── requirements.txt ├── segment_eval.py ├── speech_lm ├── UnsupSeg │ ├── .gitignore │ ├── LICENSE │ ├── README.md │ ├── __init__.py │ ├── conf │ │ └── config.yaml │ ├── dataloader.py │ ├── main.py │ ├── next_frame_classifier.py │ ├── predict.py │ ├── pretrained_models │ │ ├── buckeye+_pretrained.ckpt │ │ ├── buckeye_pretrained.ckpt │ │ ├── timit+_pretrained.ckpt │ │ └── timit_pretrained.ckpt │ ├── scripts │ │ ├── make_timit.py │ │ └── preprocess_buckeye.py │ ├── solver.py │ └── utils.py ├── __init__.py ├── configs │ ├── README.md │ ├── inference │ │ └── TWIST-1-3B.json │ ├── segmentors │ │ ├── diarization │ │ │ ├── emotion_diarization_config.json │ │ │ └── speaker_diarization_config.json │ │ ├── equal_length │ │ │ ├── equal_length_config_adaptive1.json │ │ │ ├── equal_length_config_adaptive2.json │ │ │ ├── equal_length_config_adaptive3.json │ │ │ ├── equal_length_config_adaptive4.json │ │ │ ├── equal_length_config_constant_10.json │ │ │ ├── equal_length_config_constant_15.json │ │ │ └── equal_length_config_constant_20.json │ │ ├── pmi_adaptive │ │ │ ├── pmi_segmentor_config_adaptive1.json │ │ │ ├── pmi_segmentor_config_adaptive1_350M.json │ │ │ ├── pmi_segmentor_config_adaptive1_7B.json │ │ │ ├── pmi_segmentor_config_adaptive2.json │ │ │ ├── pmi_segmentor_config_adaptive2_350M.json │ │ │ ├── pmi_segmentor_config_adaptive2_7B.json │ │ │ ├── pmi_segmentor_config_adaptive3.json │ │ │ ├── pmi_segmentor_config_adaptive3_350M.json │ │ │ ├── pmi_segmentor_config_adaptive3_7B.json │ │ │ ├── pmi_segmentor_config_adaptive4.json │ │ │ ├── pmi_segmentor_config_adaptive4_350M.json │ │ │ └── pmi_segmentor_config_adaptive4_7B.json │ │ ├── pmi_adaptive_1-5s │ │ │ ├── pmi_segmentor_config_adaptive1_350M.json │ │ │ ├── pmi_segmentor_config_adaptive2_350M.json │ │ │ ├── pmi_segmentor_config_adaptive3_350M.json │ │ │ └── pmi_segmentor_config_adaptive4_350M.json │ │ ├── pmi_adaptive_1s │ │ │ ├── pmi_segmentor_config_adaptive1_350M.json │ │ │ ├── pmi_segmentor_config_adaptive2_350M.json │ │ │ ├── pmi_segmentor_config_adaptive3_350M.json │ │ │ └── pmi_segmentor_config_adaptive4_350M.json │ │ ├── pmi_adaptive_2s │ │ │ ├── pmi_segmentor_config_adaptive1_350M.json │ │ │ ├── pmi_segmentor_config_adaptive2_350M.json │ │ │ ├── pmi_segmentor_config_adaptive3_350M.json │ │ │ └── pmi_segmentor_config_adaptive4_350M.json │ │ ├── pmi_constant │ │ │ ├── pmi_segmentor_config_constant_10_1-3B.json │ │ │ ├── pmi_segmentor_config_constant_10_350M.json │ │ │ ├── pmi_segmentor_config_constant_10_7B.json │ │ │ ├── pmi_segmentor_config_constant_15_1-3B.json │ │ │ ├── pmi_segmentor_config_constant_15_350M.json │ │ │ ├── pmi_segmentor_config_constant_15_7B.json │ │ │ ├── pmi_segmentor_config_constant_20_1-3B.json │ │ │ ├── pmi_segmentor_config_constant_20_350M.json │ │ │ └── pmi_segmentor_config_constant_20_7B.json │ │ ├── pmi_constant_1-5s │ │ │ ├── pmi_segmentor_config_constant_10_350M.json │ │ │ ├── pmi_segmentor_config_constant_15_350M.json │ │ │ └── pmi_segmentor_config_constant_20_350M.json │ │ ├── pmi_constant_1s │ │ │ ├── pmi_segmentor_config_constant_10_350M.json │ │ │ ├── pmi_segmentor_config_constant_15_350M.json │ │ │ └── pmi_segmentor_config_constant_20_350M.json │ │ ├── pmi_constant_2s │ │ │ ├── pmi_segmentor_config_constant_10_350M.json │ 
│ │ ├── pmi_segmentor_config_constant_15_350M.json │ │ │ └── pmi_segmentor_config_constant_20_350M.json │ │ ├── pmi_threshold │ │ │ ├── pmi_segmentor_config_threshold_0_350M.json │ │ │ ├── pmi_segmentor_config_threshold_10.json │ │ │ ├── pmi_segmentor_config_threshold_10_350M.json │ │ │ ├── pmi_segmentor_config_threshold_10_7B.json │ │ │ ├── pmi_segmentor_config_threshold_12-5.json │ │ │ ├── pmi_segmentor_config_threshold_12-5_350M.json │ │ │ ├── pmi_segmentor_config_threshold_12-5_7B.json │ │ │ ├── pmi_segmentor_config_threshold_12.5.json │ │ │ ├── pmi_segmentor_config_threshold_15.json │ │ │ ├── pmi_segmentor_config_threshold_15_350M.json │ │ │ ├── pmi_segmentor_config_threshold_15_7B.json │ │ │ ├── pmi_segmentor_config_threshold_5.json │ │ │ ├── pmi_segmentor_config_threshold_5_350M.json │ │ │ ├── pmi_segmentor_config_threshold_5_7B.json │ │ │ ├── pmi_segmentor_config_threshold_8.json │ │ │ ├── pmi_segmentor_config_threshold_8_350M.json │ │ │ └── pmi_segmentor_config_threshold_8_7B.json │ │ ├── pmi_threshold_1-5s │ │ │ ├── pmi_segmentor_config_threshold_10_350M.json │ │ │ ├── pmi_segmentor_config_threshold_12-5_350M.json │ │ │ ├── pmi_segmentor_config_threshold_15_350M.json │ │ │ ├── pmi_segmentor_config_threshold_5_350M.json │ │ │ └── pmi_segmentor_config_threshold_8_350M.json │ │ ├── pmi_threshold_1s │ │ │ ├── pmi_segmentor_config_threshold_10_350M.json │ │ │ ├── pmi_segmentor_config_threshold_12-5_350M.json │ │ │ ├── pmi_segmentor_config_threshold_15_350M.json │ │ │ ├── pmi_segmentor_config_threshold_5_350M.json │ │ │ └── pmi_segmentor_config_threshold_8_350M.json │ │ ├── pmi_threshold_2s │ │ │ ├── pmi_segmentor_config_threshold_10_350M.json │ │ │ ├── pmi_segmentor_config_threshold_12-5_350M.json │ │ │ ├── pmi_segmentor_config_threshold_15_350M.json │ │ │ ├── pmi_segmentor_config_threshold_5_350M.json │ │ │ └── pmi_segmentor_config_threshold_8_350M.json │ │ └── unsupseg │ │ │ ├── unsupseg_config_constant_10.json │ │ │ ├── unsupseg_config_constant_15.json │ │ │ └── unsupseg_config_constant_20.json │ └── tokenizers │ │ └── tokenizer_500_vocab.json ├── evaluation │ └── segmentation.py ├── inference.py ├── scorers.py ├── segmentors.py ├── spans_selector.py ├── speech_sentencer.py ├── tokenizers.py └── utils.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Avishai Elmakies 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Unsupervised Speech Segmentation: A General Approach Using Speech Language Models 2 | #### This repository is the official implementation for the paper "[Unsupervised Speech Segmentation: A General Approach Using Speech Language Models](https://arxiv.org/abs/2501.03711)" 3 | 4 | ## Installing dependencies 5 | 6 | To use the code, first install the libraries needed to run it. 7 | 8 | You can run the script ```install_dep.sh```, which tries to install all dependencies in one go. 9 | 10 | The script performs the following steps: 11 | 12 | 1. Start by installing [torch](https://pytorch.org/) including torchaudio (preferably with CUDA; the version used is 2.1.1 with cu118). 13 | `pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118` 14 | 15 | 2. `pip install -r requirements.txt` 16 | 17 | 3. `pip install git+https://github.com/pytorch/fairseq@da8fb630880d529ab47e53381c30ddc8ad235216` 18 | 19 | Note: you might get an error because of the omegaconf version; this is OK. 20 | 21 | ## Preparing Datasets 22 | 23 | There are 5 scripts that create the synthetic datasets used in the paper: 24 | 25 | 1. [emovdb_prepare_concat_dataset.py](emovdb_prepare_concat_dataset.py): converts the EmoV-DB dataset into the dataset used for the emotion experiment. 26 | 2. [emovdb_prepare_concat_gender_dataset.py](emovdb_prepare_concat_gender_dataset.py): converts the EmoV-DB dataset into the dataset used for the gender experiment. 27 | 3. [iemocap_prepare_concat_dataset.py](iemocap_prepare_concat_dataset.py): converts the IEMOCAP dataset into the dataset used for the emotion experiment. 28 | 4. [iemocap_prepare_concat_gender_dataset.py](iemocap_prepare_concat_gender_dataset.py): converts the IEMOCAP dataset into the dataset used for the gender experiment. 29 | 5. [emovdb_prepare_concat_ib_dataset.py](emovdb_prepare_concat_ib_dataset.py): converts the EmoV-DB dataset into the dataset used to test the inductive bias hypothesis.
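Each of these scripts also writes a `data.json` index describing the generated files. For orientation, a single entry written by the EmoV-DB emotion script has roughly the following shape (the keys match what the script saves; the concrete values here are purely illustrative): ``` { "speaker_1": { "0": { "wav_path": "/abs/path/to/1/0.wav", "text": "first sentence.second sentence", "segmentation": [ {"emo": "amused", "start": 0.0, "end": 2.31}, {"emo": "neutral", "start": 2.31, "end": 5.02} ], "duration": 5.02 } } } ```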
30 | 31 | ### Running the preparation scripts 32 | Most scripts share the same command-line arguments: 33 | 34 | - ```-i/--input_folder```: input folder for the dataset 35 | - ```-o/--output_folder```: output folder for the results 36 | - ```-s/--seed```: seed to use for randomness; default None 37 | - ```-max/--max_concats```: maximum number of segments per file; default 10 38 | - ```-min/--min_concats```: minimum number of segments per file; default 2 39 | - ```--num_files```: number of files per speaker/pair 40 | - ```--sample_rate```: resample rate; default 16000 41 | - ```--remove_silence```: use VAD to remove silence from the files 42 | 43 | #### EmoV-DB unique arguments 44 | - ```--remove_emotions```: emotions to exclude from the dataset; can be multiple; default None 45 | 46 | Running the EmoV-DB scripts: 47 | 48 | ```python emovdb_prepare_concat_dataset.py -i -o -s 42 -min 4 -max 30 --num_files 500 --remove_emotions sleepiness``` 49 | 50 | ```python emovdb_prepare_concat_gender_dataset.py -i -o -s 42 -min 4 -max 30 --num_files 250 --remove_emotions sleepiness``` 51 | 52 | ```python emovdb_prepare_concat_ib_dataset.py -i -o -s 42 -min 4 -max 30 --num_files 500 --remove_emotions sleepiness``` 53 | 54 | Running the IEMOCAP scripts: 55 | 56 | ```python iemocap_prepare_concat_dataset.py -i -o -s 42 -min 4 -max 30 --num_files 250``` 57 | 58 | ```python iemocap_prepare_concat_gender_dataset.py -i -o -s 42 -min 4 -max 30 --num_files 250``` 59 | 60 | Note: these are the commands used to create the datasets on a Linux machine; results may vary on other operating systems. 61 | 62 | ## Segmentation/Inference 63 | 64 | For inference, use the file [audio_segment.py](audio_segment.py): 65 | 66 | - ```-c/--seg_config```: path to a config file for the segmentor. Examples can be seen in [speech_lm/configs](speech_lm/configs). You can also create/update your own config; please read the instructions in the [configs README](speech_lm/configs/README.md) 67 | 68 | - ```-j/--input_json```: path to the input json that will be used for inference. Each script shown above creates a data.json that can be used for inference. The format of the input is: 69 | 70 | ``` 71 | 72 | { 73 | "key1": { 74 | "subkey1": { 75 | "wav_path": 76 | } 77 | "subkey2": { 78 | "wav_path": 79 | } 80 | } 81 | "key2": { 82 | "subkey1": { 83 | "wav_path": 84 | } 85 | "subkey2": { 86 | "wav_path": 87 | } 88 | } 89 | ... 90 | } 91 | ``` 92 | 93 | - `-o/--output_folder`: path where the output of the script is saved 94 | - `-b/--base_path`: path used for models (or to download the models there) 95 | - `-s/--save_params`: flag to save the seg_config.json and data.json used for inference. 96 | - `--save_audio`: flag to save the segmented audio files that are created. 97 | 98 | This script will create a results.json in the given output folder. 99 | 100 | The format will be: 101 | 102 | ``` 103 | { 104 | "key1": { 105 | "subkey1": { 106 | "segmentation": [ 107 | { 108 | "start": , 109 | "end": 110 | }, 111 | { 112 | "start": , 113 | "end": 114 | }, 115 | ... 116 | ] 117 | } 118 | "subkey2": { 119 | "segmentation": [ 120 | { 121 | "start": , 122 | "end": 123 | }, 124 | { 125 | "start": , 126 | "end": 127 | }, 128 | ... 129 | ] 130 | } 131 | } 132 | ... 133 | } 134 | ``` 135 | If `--save_audio` is used, each segment will also have an "audio_path" pointing to the segment wav that was created. 136 |
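For example, a typical invocation could look like the following (the config file is one of the provided examples; the data and output paths are only illustrative): ```python audio_segment.py -c speech_lm/configs/segmentors/pmi_constant/pmi_segmentor_config_constant_10_350M.json -j /path/to/data.json -o /path/to/output_folder -s```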
137 | ## Evaluation 138 | 139 | To evaluate the inference results you can use the file [segment_eval.py](segment_eval.py). 140 | 141 | - `-re/--reference_path`: a path to a json file representing the reference (ground truth). Format: 142 | ``` 143 | { 144 | "key1": { 145 | "subkey1": { 146 | "segmentation": [ 147 | { 148 | "start": , 149 | "end": 150 | }, 151 | { 152 | "start": , 153 | "end": 154 | }, 155 | ... 156 | ] 157 | } 158 | "subkey2": { 159 | "segmentation": [ 160 | { 161 | "start": , 162 | "end": 163 | }, 164 | { 165 | "start": , 166 | "end": 167 | }, 168 | ... 169 | ] 170 | } 171 | } 172 | ... 173 | } 174 | ``` 175 | - `-hy/--hypothesis_path`: a path to a json file representing the hypothesis (inference). Uses the same format as `reference_path`. 176 | 177 | - `-m/--metrics`: metrics to use. Options are: `["coverage","purity", "cpfmeasure", "recall", "precision","rpfmeasure", "r_value"]`. Multiple options can be given; you can also use `"all"` to use all metrics. 178 | 179 | - `-p/--print_sub`: print the result for each sample. 180 | 181 | - `-o/--output`: optional output path used to save the results. 182 | 183 | - `-ci/--confidence_interval`: flag to also compute confidence intervals. 184 | 185 | This script will print a json with the results for the given metrics (unless you use `--output`, in which case it saves the json to the given file path). 186 | -------------------------------------------------------------------------------- /audio_segment.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from tqdm import tqdm 4 | from speech_lm.segmentors import Segmentor, SegmentorFactory 5 | import torchaudio 6 | import os 7 | import torch 8 | 9 | def main(): 10 | parser = argparse.ArgumentParser("this script segments audio files into smaller chunks.") 11 | parser.add_argument('-c','--seg_config', type=str, help='Path to the config file for segmentation',required=True) 12 | parser.add_argument('-j','--input_json', type=str, help='Path to the json file that contains the wav paths. The keys are the names of the folders. Each key will have multiple sub keys (sub folders).
with a wav path as the value.',required=True) 13 | parser.add_argument('-o','--output_folder', type=str, help='folder to save the segmented audio files',required=True) 14 | parser.add_argument('-b','--base_path', type=str, default='../models/', help='base path for models') 15 | parser.add_argument("-s",'--save_params',action='store_true',help="save the params and data json to reproduce the results if needed") 16 | parser.add_argument('-sa','--save_audio',action='store_true',help='bool do decide if the audio segments should be saved or not') 17 | 18 | args = parser.parse_args() 19 | 20 | 21 | input_json = args.input_json 22 | output_folder = args.output_folder 23 | save_audio = args.save_audio 24 | 25 | with open(args.seg_config) as f: 26 | seg_config = json.load(f) 27 | 28 | segmentor:Segmentor = SegmentorFactory.get_segmentor(seg_config,base_path=args.base_path) 29 | segmentor.eval() 30 | if torch.cuda.is_available(): 31 | segmentor.to('cuda') 32 | segmentor.eval() 33 | 34 | with open(input_json) as f: 35 | data = json.load(f) 36 | 37 | os.makedirs(output_folder,exist_ok=True) 38 | 39 | json_segments = {} 40 | for key, value in tqdm(data.items(),total=len(data)): 41 | json_segments[key] = {} 42 | for sub_key, sub_values in tqdm(value.items(),total=len(value)): 43 | if isinstance(sub_values,str): 44 | wav_path = sub_values 45 | elif isinstance(sub_values,dict): 46 | wav_path = sub_values["wav_path"] 47 | try: 48 | segments, sr = segmentor.segment_path(wav_path) 49 | except RuntimeError as e: 50 | print(f"Error in {key}/{sub_key}: {e}") 51 | continue 52 | wavs = [segment.pop("audio") for segment in segments] 53 | if wavs[0].ndim > 1: 54 | wavs = [wav.mean(dim=0) for wav in wavs] 55 | json_segments[key][sub_key] = {} 56 | if save_audio: 57 | outdir = os.path.join(output_folder,key,sub_key) 58 | os.makedirs(outdir,exist_ok=True) 59 | for i, (wav,segment) in enumerate(zip(wavs,segments)): 60 | outpath = os.path.join(outdir,f'{sub_key}_{i}.wav') 61 | torchaudio.save(outpath,wav.unsqueeze(0),sr) 62 | segment["audio_path"] = os.path.abspath(outpath) 63 | json_segments[key][sub_key]["audio_folder"] = os.path.abspath(outdir) 64 | json_segments[key][sub_key]["segmentation"] = segments 65 | 66 | 67 | with open(os.path.join(output_folder,'results.json'),'w') as f: 68 | json.dump(json_segments,f,indent=4) 69 | 70 | if args.save_params: 71 | with open(os.path.join(output_folder,'seg_config.json'),'w') as f: 72 | json.dump(seg_config,f,indent=4) 73 | with open(os.path.join(output_folder,'data.json'),'w') as f: 74 | json.dump(data,f,indent=4) 75 | 76 | if __name__ == '__main__': 77 | main() -------------------------------------------------------------------------------- /emovdb_prepare_concat_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torchaudio 3 | import os 4 | import random 5 | import glob 6 | import torch 7 | import json 8 | from utils import load_vad,VAD_THRESHOLD 9 | 10 | def main(): 11 | parser = argparse.ArgumentParser("helper script to prepare the EmoV-DB dataset for our experiments") 12 | parser.add_argument('-i','--input_folder',type=str,help='path to the input folder',required=True) 13 | parser.add_argument('-o','--output_folder',type=str,help='path to the output folder',required=True) 14 | parser.add_argument('-s','--seed',type=int,default=None,help='seed for the random number generator') 15 | parser.add_argument('-max','--max_concats',type=int,default=10,help='maximum number of combinations to create for each 
speaker') 16 | parser.add_argument('-min','--min_concats',type=int,default=2,help='minimum number of combinations to create for each speaker') 17 | parser.add_argument('--remove_emotions', default=None, nargs='*', help='list of emotions to remove') 18 | parser.add_argument('--num_files', type=int, help='number of files to sample from each speaker',required=True) 19 | parser.add_argument('--sample_rate', type=int,default=16000, help='resample sr') 20 | parser.add_argument('--remove_silence', action='store_true', help='remove silence from the audio files') 21 | 22 | args = parser.parse_args() 23 | input_folder = args.input_folder 24 | output_folder = args.output_folder 25 | seed = args.seed 26 | if seed is not None: 27 | random.seed(seed) 28 | remove_emotions = args.remove_emotions 29 | num_files = args.num_files 30 | 31 | resample = torchaudio.transforms.Resample(orig_freq=44100, new_freq=args.sample_rate) 32 | 33 | vad_model, utils = None,None 34 | get_speech_timestamps, _, _, _,collect_chunks = None,None,None,None,None 35 | if args.remove_silence: 36 | vad_model, utils = load_vad() 37 | get_speech_timestamps, _, _, _,collect_chunks = utils 38 | 39 | json_data = {} 40 | for i in range(1,5): 41 | speaker_data = {} 42 | speaker_input_folder = os.path.join(input_folder,str(i)) 43 | speaker_output_folder = os.path.join(output_folder,str(i)) 44 | os.makedirs(speaker_output_folder,exist_ok=True) 45 | all_files = glob.glob(os.path.join(speaker_input_folder, '*.wav')) 46 | if remove_emotions is not None: 47 | all_files = [f for f in all_files if not any([emotion in f for emotion in remove_emotions])] 48 | for j in range(num_files): 49 | num_seg = random.randint(args.min_concats, args.max_concats) 50 | files = random.sample(all_files, num_seg) 51 | output_wavs = [] 52 | output_texts = [] 53 | output_emotions = [] 54 | start = 0 55 | for file in files: 56 | wav,sr = torchaudio.load(file) 57 | if sr != args.sample_rate: 58 | wav = resample(wav) 59 | if vad_model is not None: 60 | timestamps = get_speech_timestamps(wav, vad_model, sampling_rate=sr,threshold=VAD_THRESHOLD) 61 | if len(timestamps) > 0: 62 | wav = collect_chunks(timestamps,wav.squeeze(0)).unsqueeze(0) 63 | output_wavs.append(wav) 64 | lab_file = file.replace('.wav','.lab') 65 | with open(lab_file) as f: 66 | text = f.read() 67 | output_texts.append(text) 68 | emotion = os.path.basename(lab_file).split('_')[0] 69 | output_emotions.append({"emo":emotion.lower(),"start":start,"end":start+wav.shape[-1]/args.sample_rate}) 70 | start += wav.shape[-1]/args.sample_rate 71 | output_wav = torch.cat(output_wavs,dim=-1) 72 | output_text = '.'.join(output_texts) 73 | output_wav_path = os.path.join(speaker_output_folder,f'{j}.wav') 74 | torchaudio.save(output_wav_path,output_wav,args.sample_rate) 75 | speaker_data[f"{j}"] = {"wav_path":os.path.abspath(output_wav_path),"text":output_text,"segmentation":output_emotions,"duration":output_wav.shape[-1]/args.sample_rate} 76 | json_data[f"speaker_{i}"] = speaker_data 77 | 78 | with open(os.path.join(output_folder,'data.json'),'w') as f: 79 | json.dump(json_data,f,indent=4) 80 | 81 | if __name__ == "__main__": 82 | main() -------------------------------------------------------------------------------- /emovdb_prepare_concat_gender_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torchaudio 3 | import os 4 | import random 5 | import glob 6 | import torch 7 | import json 8 | from utils import load_vad,VAD_THRESHOLD 9 | 10 | 11 | def 
get_speaker_dict(text_files): 12 | emotion_dict = {} 13 | emotions = set() 14 | for text_file in text_files: 15 | with open(text_file) as f: 16 | text = f.read() 17 | emotion = os.path.basename(text_file).split('_')[0].lower() 18 | if emotion not in emotion_dict: 19 | emotion_dict[emotion] = [] 20 | emotions.add(emotion) 21 | wav_file = text_file.replace('.lab','.wav') 22 | emotion_dict[emotion].append((wav_file,text)) 23 | return emotion_dict 24 | 25 | def get_speaker_list(input_folder,remove_emotions=None): 26 | 27 | speaker_list = [] 28 | for i in range(1,5): 29 | speaker_input_folder = os.path.join(input_folder,str(i)) 30 | all_files = glob.glob(os.path.join(speaker_input_folder, '*.lab')) 31 | if remove_emotions is not None: 32 | all_files = [f for f in all_files if not any([emotion in f for emotion in remove_emotions])] 33 | speaker_list.append(get_speaker_dict(all_files)) 34 | return speaker_list 35 | 36 | FEMALE_IDX = [0,1] 37 | MALE_IDX = [2,3] 38 | 39 | def get_rand_speaker(is_male): 40 | return random.choice(MALE_IDX if is_male else FEMALE_IDX) 41 | 42 | def get_wav_file(wav_file,sample_rate, vad_model=None, get_speech_timestamps=None, collect_chunks=None): 43 | wav,sr = torchaudio.load(wav_file) 44 | if sr != sample_rate: 45 | wav = torchaudio.functional.resample(wav,sr,sample_rate) 46 | if vad_model is not None: 47 | timestamps = get_speech_timestamps(wav, vad_model, sampling_rate=sample_rate,threshold=VAD_THRESHOLD) 48 | if len(timestamps) > 0: 49 | wav = collect_chunks(timestamps,wav.squeeze(0)).unsqueeze(0) 50 | return wav 51 | 52 | def get_segment(speaker,start,wav,sample_rate): 53 | return {"speaker":speaker,"start":start,"end":start+wav.shape[-1]/sample_rate},start+wav.shape[-1]/sample_rate 54 | 55 | 56 | def concat_files(first_list,second_list,first_speaker_idx, 57 | second_speaker_idx,sample_rate,vad_model=None, get_speech_timestamps=None, collect_chunks=None): 58 | output_wavs = [] 59 | output_texts = [] 60 | output_segments = [] 61 | start = 0 62 | for (wav1_file,text1),(wav2_file,text2) in zip(first_list,second_list): 63 | wav1 = get_wav_file(wav1_file,sample_rate,vad_model, get_speech_timestamps, collect_chunks) 64 | wav2 = get_wav_file(wav2_file,sample_rate,vad_model, get_speech_timestamps, collect_chunks) 65 | output_wavs.extend([wav1,wav2]) 66 | output_texts.extend([text1,text2]) 67 | segment1,start = get_segment(first_speaker_idx,start,wav1,sample_rate) 68 | segment2,start = get_segment(second_speaker_idx,start,wav2,sample_rate) 69 | output_segments.extend([segment1,segment2]) 70 | 71 | output_wav = torch.cat(output_wavs,dim=-1) 72 | output_text = '.'.join(output_texts) 73 | return output_wav,output_text,output_segments 74 | 75 | 76 | def main(): 77 | parser = argparse.ArgumentParser("helper script to prepare the EmoV-DB dataset for our experiments") 78 | parser.add_argument('-i','--input_folder',type=str,help='path to the input folder',required=True) 79 | parser.add_argument('-o','--output_folder',type=str,help='path to the output folder',required=True) 80 | parser.add_argument('-s','--seed',type=int,default=None,help='seed for the random number generator') 81 | parser.add_argument('-max','--max_concats',type=int,default=10,help='maximum number of combinations to create for each speaker') 82 | parser.add_argument('-min','--min_concats',type=int,default=2,help='minimum number of combinations to create for each speaker') 83 | parser.add_argument('--remove_emotions', default=None, nargs='*', help='list of emotions to remove') 84 | 
parser.add_argument('--num_files', type=int, help='number of files to sample from each speaker',required=True) 85 | parser.add_argument('--sample_rate', type=int,default=16000, help='resample sr') 86 | parser.add_argument('--remove_silence', action='store_true', help='remove silence from the audio files') 87 | 88 | args = parser.parse_args() 89 | input_folder = args.input_folder 90 | output_folder = args.output_folder 91 | seed = args.seed 92 | if seed is not None: 93 | random.seed(seed) 94 | remove_emotions = args.remove_emotions 95 | num_files = args.num_files 96 | 97 | vad_model, utils = None,None 98 | get_speech_timestamps, _, _, _,collect_chunks = None,None,None,None,None 99 | if args.remove_silence: 100 | vad_model, utils = load_vad() 101 | get_speech_timestamps, _, _, _,collect_chunks = utils 102 | 103 | speakers_list = get_speaker_list(input_folder,remove_emotions) 104 | 105 | json_data = {} 106 | for first_speaker_idx in range(len(speakers_list)): 107 | speaker_data = {} 108 | first_speaker = speakers_list[first_speaker_idx] 109 | male_first = first_speaker_idx in MALE_IDX 110 | output_speaker_folder = os.path.join(output_folder,str(first_speaker_idx+1)) 111 | os.makedirs(output_speaker_folder,exist_ok=True) 112 | for second_speaker_idx in (FEMALE_IDX if male_first else MALE_IDX): 113 | second_speaker = speakers_list[second_speaker_idx] 114 | emotions = list(set(first_speaker.keys()).intersection(second_speaker.keys())) 115 | subkey = f"{'M' if male_first else 'F'}{first_speaker_idx+1}_{'F' if male_first else 'M'}{second_speaker_idx+1}" 116 | for j in range(num_files): 117 | emotion = random.choice(emotions) 118 | first_speaker_emotion_list = first_speaker[emotion] 119 | second_speaker_emotion_list = second_speaker[emotion] 120 | num_segs = random.randint(args.min_concats,args.max_concats) 121 | first_speaker_emotion_list = random.sample(first_speaker_emotion_list,num_segs) 122 | second_speaker_emotion_list = random.sample(second_speaker_emotion_list,num_segs) 123 | output_wav,output_text,output_segments = concat_files(first_speaker_emotion_list,second_speaker_emotion_list, 124 | first_speaker_idx,second_speaker_idx,args.sample_rate, 125 | vad_model, get_speech_timestamps, collect_chunks) 126 | output_wav_file = os.path.join(output_speaker_folder,f"{subkey}_{j}.wav") 127 | torchaudio.save(output_wav_file,output_wav,args.sample_rate) 128 | speaker_data[f"{subkey}_{j}"] = {"wav_path":os.path.abspath(output_wav_file), 129 | "text":output_text, 130 | "segmentation":output_segments, 131 | "emotion":emotion, 132 | "duration":output_wav.shape[-1]/args.sample_rate} 133 | 134 | json_data[f"Speaker{first_speaker_idx+1}"] = speaker_data 135 | 136 | 137 | with open(os.path.join(output_folder,'data.json'),'w') as f: 138 | json.dump(json_data,f,indent=4) 139 | 140 | if __name__ == "__main__": 141 | main() -------------------------------------------------------------------------------- /emovdb_prepare_concat_ib_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torchaudio 3 | import os 4 | import random 5 | import glob 6 | import torch 7 | import json 8 | import librosa 9 | import numpy as np 10 | 11 | 12 | def split_files(all_files): 13 | durations = [librosa.get_duration(path=f) for f in all_files] 14 | mean = np.mean(durations) 15 | std = np.std(durations) 16 | 17 | short_files,avg_files,long_files = [],[],[] 18 | 19 | for f,d in zip(all_files,durations): 20 | if mean - d > 1.8*std: 21 | short_files.append(f) 22 | elif d 
- mean > 1.8*std: 23 | long_files.append(f) 24 | elif abs(d - mean) < std: 25 | avg_files.append(f) 26 | return short_files,avg_files,long_files 27 | 28 | 29 | def sample(short,avg,long,num_sgments): 30 | files = [] 31 | for i in range(num_sgments): # maybe more random? 32 | if i%4 <= 1: 33 | files.append(random.choice(short)) 34 | else: 35 | files.append(random.choice(long)) 36 | return files 37 | 38 | def main(): 39 | parser = argparse.ArgumentParser("helper script to prepare the EmoV-DB dataset for our experiments") 40 | parser.add_argument('-i','--input_folder',type=str,help='path to the input folder',required=True) 41 | parser.add_argument('-o','--output_folder',type=str,help='path to the output folder',required=True) 42 | parser.add_argument('-s','--seed',type=int,default=None,help='seed for the random number generator') 43 | parser.add_argument('-max','--max_concats',type=int,default=10,help='maximum number of combinations to create for each speaker') 44 | parser.add_argument('-min','--min_concats',type=int,default=2,help='minimum number of combinations to create for each speaker') 45 | parser.add_argument('--remove_emotions', default=None, nargs='*', help='list of emotions to remove') 46 | parser.add_argument('--num_files', type=int, help='number of files to sample from each speaker',required=True) 47 | parser.add_argument('--sample_rate', type=int,default=16000, help='resample sr') 48 | 49 | args = parser.parse_args() 50 | input_folder = args.input_folder 51 | output_folder = args.output_folder 52 | seed = args.seed 53 | if seed is not None: 54 | random.seed(seed) 55 | remove_emotions = args.remove_emotions 56 | num_files = args.num_files 57 | 58 | resample = torchaudio.transforms.Resample(orig_freq=44100, new_freq=args.sample_rate) 59 | 60 | json_data = {} 61 | for i in range(1,5): 62 | speaker_data = {} 63 | speaker_input_folder = os.path.join(input_folder,str(i)) 64 | speaker_output_folder = os.path.join(output_folder,str(i)) 65 | os.makedirs(speaker_output_folder,exist_ok=True) 66 | all_files = glob.glob(os.path.join(speaker_input_folder, '*.wav')) 67 | if remove_emotions is not None: 68 | all_files = [f for f in all_files if not any([emotion in f for emotion in remove_emotions])] 69 | short_files,avg_files,long_files = split_files(all_files) 70 | print(f"Speaker {i} has {len(short_files)} short files, {len(avg_files)} average files and {len(long_files)} long files") 71 | for j in range(num_files): 72 | num_seg = random.randint(args.min_concats, args.max_concats) 73 | files = sample(short_files,avg_files,long_files,num_seg) 74 | output_wavs = [] 75 | output_texts = [] 76 | output_emotions = [] 77 | start = 0 78 | for file in files: 79 | wav,sr = torchaudio.load(file) 80 | if sr != args.sample_rate: 81 | wav = resample(wav) 82 | output_wavs.append(wav) 83 | lab_file = file.replace('.wav','.lab') 84 | with open(lab_file) as f: 85 | text = f.read() 86 | output_texts.append(text) 87 | emotion = os.path.basename(lab_file).split('_')[0] 88 | output_emotions.append({"emo":emotion.lower(),"start":start,"end":start+wav.shape[-1]/args.sample_rate}) 89 | start += wav.shape[-1]/args.sample_rate 90 | output_wav = torch.cat(output_wavs,dim=-1) 91 | output_text = '.'.join(output_texts) 92 | output_wav_path = os.path.join(speaker_output_folder,f'{j}.wav') 93 | torchaudio.save(output_wav_path,output_wav,args.sample_rate) 94 | speaker_data[f"{j}"] = {"wav_path":os.path.abspath(output_wav_path),"text":output_text,"segmentation":output_emotions,"duration":output_wav.shape[-1]/args.sample_rate} 95 
| json_data[f"speaker_{i}"] = speaker_data 96 | 97 | with open(os.path.join(output_folder,'data.json'),'w') as f: 98 | json.dump(json_data,f,indent=4) 99 | 100 | if __name__ == "__main__": 101 | main() -------------------------------------------------------------------------------- /iemocap_prepare_concat_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torchaudio 3 | import os 4 | import random 5 | import glob 6 | import torch 7 | import json 8 | from utils import get_emotion_dict 9 | import re 10 | 11 | pattern = r"(Ses\w+_\w+) \[\d+\.\d+-\d+\.\d+\]: (.+)" 12 | 13 | def get_speakers_dicts(text_files): 14 | session_emotion_dict = {} 15 | for text_file in text_files: 16 | session_emotion_dict.update(get_emotion_dict(text_file)) 17 | male = list(filter(lambda x: '_M' in x[0], session_emotion_dict.items())) 18 | female = list(filter(lambda x: '_F' in x[0], session_emotion_dict.items())) 19 | return male,female 20 | 21 | def get_transcriptions_dict(transcriptions_files): 22 | transcriptions_dict = {} 23 | for file in transcriptions_files: 24 | with open(file) as f: 25 | for line in f: 26 | line = line.strip() 27 | if line == '': 28 | continue 29 | match = re.match(pattern,line) 30 | if match is None: 31 | continue 32 | id = match.group(1) 33 | text = match.group(2) 34 | transcriptions_dict[id] = text 35 | return transcriptions_dict 36 | 37 | def get_concatenation(emotion_files,sentences_wav_path,transcription_dict,resmaple_rate): 38 | output_wavs = [] 39 | output_texts = [] 40 | output_emotions = [] 41 | start = 0 42 | for file,emotion in emotion_files: 43 | if "script" in file: 44 | folder = "_".join(file.split('_')[:3]) 45 | else: 46 | folder = "_".join(file.split('_')[:2]) 47 | wav_file = os.path.join(sentences_wav_path,folder,file+'.wav') 48 | wav,sr = torchaudio.load(wav_file) 49 | if sr != resmaple_rate: 50 | wav = torchaudio.functional.resample(wav,sr,resmaple_rate) 51 | output_wavs.append(wav) 52 | output_texts.append(transcription_dict[file].strip()) 53 | output_emotions.append({"emo":emotion.lower(),"start":start,"end":start+wav.shape[-1]/resmaple_rate}) 54 | start += wav.shape[-1]/resmaple_rate 55 | output_wav = torch.cat(output_wavs,dim=-1) 56 | output_text = ' '.join(output_texts) 57 | return output_wav,output_text,output_emotions 58 | 59 | def main(): 60 | parser = argparse.ArgumentParser("helper script to prepare the iemocap dataset for our experiments") 61 | parser.add_argument('-i','--input_folder',type=str,help='path to the input folder',required=True) 62 | parser.add_argument('-o','--output_folder',type=str,help='path to the output folder',required=True) 63 | parser.add_argument('-s','--seed',type=int,default=None,help='seed for the random number generator') 64 | parser.add_argument('-max','--max_concats',type=int,default=10,help='maximum number of combinations to create for each speaker') 65 | parser.add_argument('-min','--min_concats',type=int,default=2,help='minimum number of combinations to create for each speaker') 66 | parser.add_argument('--num_files', type=int, help='number of files to sample from each speaker',required=True) 67 | parser.add_argument('--sample_rate', type=int,default=16000, help='resample sr') 68 | 69 | args = parser.parse_args() 70 | input_folder = args.input_folder 71 | output_folder = args.output_folder 72 | seed = args.seed 73 | if seed is not None: 74 | random.seed(seed) 75 | num_files = args.num_files 76 | sample_rate = args.sample_rate 77 | 78 | 79 | json_data = {} 80 
| for i in range(1,6): 81 | print(f"processing session {i}") 82 | session_data = {} 83 | session_path = os.path.join(input_folder,'Session'+str(i)) 84 | sentences_wav_path = os.path.join(session_path,'sentences/wav') 85 | dialog_emo_path = os.path.join(session_path,'dialog/EmoEvaluation') 86 | text_files = glob.glob(os.path.join(dialog_emo_path, '*.txt')) 87 | transcriptions_files = glob.glob(os.path.join(session_path,'dialog/transcriptions/*.txt')) 88 | transcription_dicts = get_transcriptions_dict(transcriptions_files) 89 | emotion_male,emotion_female = get_speakers_dicts(text_files) 90 | output_session_path = os.path.join(output_folder,'Session'+str(i)) 91 | os.makedirs(output_session_path,exist_ok=True) 92 | for j in range(num_files): 93 | num_seg = random.randint(args.min_concats, args.max_concats) 94 | files = random.sample(emotion_male, num_seg) 95 | wav,text,emotions = get_concatenation(files,sentences_wav_path,transcription_dicts,sample_rate) 96 | output_wav_path = os.path.join(output_session_path,f'Sess{i}_male_{j}.wav') 97 | torchaudio.save(output_wav_path,wav,sample_rate) 98 | session_data[f"Sess{i}_male_{j}"] = {"wav_path":os.path.abspath(output_wav_path),"text":text,"segmentation":emotions,"duration":wav.shape[-1]/sample_rate} 99 | files = random.sample(emotion_female, num_seg) 100 | wav,text,emotions = get_concatenation(files,sentences_wav_path,transcription_dicts,sample_rate) 101 | output_wav_path = os.path.join(output_session_path,f'Sess{i}_female_{j}.wav') 102 | torchaudio.save(output_wav_path,wav,sample_rate) 103 | session_data[f"Sess{i}_female_{j}"] = {"wav_path":os.path.abspath(output_wav_path),"text":text,"segmentation":emotions,"duration":wav.shape[-1]/sample_rate} 104 | json_data[f"Session{i}"] = session_data 105 | 106 | with open(os.path.join(output_folder,'data.json'),'w') as f: 107 | json.dump(json_data,f,indent=4) 108 | 109 | if __name__ == "__main__": 110 | main() -------------------------------------------------------------------------------- /iemocap_prepare_concat_gender_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torchaudio 3 | import os 4 | import random 5 | import glob 6 | import torch 7 | import json 8 | from utils import get_emotion_dict,EMOTION_DICT 9 | import re 10 | 11 | pattern = r"(Ses\w+_\w+) \[\d+\.\d+-\d+\.\d+\]: (.+)" 12 | 13 | 14 | EMOTIONS = list(set(EMOTION_DICT.values())) 15 | 16 | def get_transcriptions_dict(transcriptions_files): 17 | transcriptions_dict = {} 18 | for file in transcriptions_files: 19 | with open(file) as f: 20 | for line in f: 21 | line = line.strip() 22 | if line == '': 23 | continue 24 | match = re.match(pattern,line) 25 | if match is None: 26 | continue 27 | id = match.group(1) 28 | text = match.group(2) 29 | transcriptions_dict[id] = text 30 | return transcriptions_dict 31 | 32 | 33 | def get_speaker_dict(text_files): 34 | male_speaker_dict = {} 35 | female_speaker_dict = {} 36 | for text_file in text_files: 37 | emotion_dict = get_emotion_dict(text_file) 38 | for file,emo in emotion_dict.items(): 39 | if '_M' in file: 40 | if emo not in male_speaker_dict: 41 | male_speaker_dict[emo] = [] 42 | male_speaker_dict[emo].append(file) 43 | elif '_F' in file: 44 | if emo not in female_speaker_dict: 45 | female_speaker_dict[emo] = [] 46 | female_speaker_dict[emo].append(file) 47 | return male_speaker_dict,female_speaker_dict 48 | 49 | get_folder = lambda x: "_".join(x.split('_')[:3]) if "script" in x else "_".join(x.split('_')[:2]) 50 | 51 | def 
get_wav(file,sentences_wav_path,resmaple_rate): 52 | folder = get_folder(file) 53 | wav_file = os.path.join(sentences_wav_path,folder,file+'.wav') 54 | wav,sr = torchaudio.load(wav_file) 55 | if sr != resmaple_rate: 56 | wav = torchaudio.functional.resample(wav,sr,resmaple_rate) 57 | return wav 58 | 59 | def speaker_segment(wav,sample_rate,speaker,start): 60 | return { 61 | "speaker":speaker, 62 | "start":start, 63 | "end":start+wav.shape[-1]/sample_rate 64 | }, start+wav.shape[-1]/sample_rate 65 | 66 | 67 | def concat_files(first_files,second_files,first_speaker,second_speaker, 68 | sentences_wav_path,transcription_dict,resmaple_rate): 69 | output_wavs = [] 70 | output_texts = [] 71 | output_speakers = [] 72 | start = 0 73 | for file1,file2 in zip(first_files,second_files): 74 | wav1 = get_wav(file1,sentences_wav_path,resmaple_rate) 75 | wav2 = get_wav(file2,sentences_wav_path,resmaple_rate) 76 | output_wavs.extend([wav1,wav2]) 77 | output_texts.extend([transcription_dict[file1],transcription_dict[file2]]) 78 | speaker1_seg,start = speaker_segment(wav1,resmaple_rate,first_speaker,start) 79 | speaker2_seg,start = speaker_segment(wav2,resmaple_rate,second_speaker,start) 80 | output_speakers.extend([speaker1_seg,speaker2_seg]) 81 | output_wav = torch.cat(output_wavs,dim=-1) 82 | output_text = ' '.join(output_texts) 83 | return output_wav,output_text,output_speakers 84 | 85 | def get_speakers_dict(input_folder): 86 | """ 87 | get a dict where each key is a speaker, values are another dict with a key for each emotion that will have a list of files 88 | """ 89 | male_speakers_dict = {} 90 | female_speakers_dict = {} 91 | transcription_dict = {} 92 | for i in range(1,6): 93 | session_path = os.path.join(input_folder,'Session'+str(i)) 94 | dialog_emo_path = os.path.join(session_path,'dialog/EmoEvaluation') 95 | text_files = glob.glob(os.path.join(dialog_emo_path, '*.txt')) 96 | male_speaker_dict,female_speaker_dict = get_speaker_dict(text_files) 97 | male_speakers_dict[f"Session{str(i)}_M"] = male_speaker_dict 98 | female_speakers_dict[f"Session{str(i)}_F"] = female_speaker_dict 99 | transcriptions_files = glob.glob(os.path.join(session_path,'dialog/transcriptions/*.txt')) 100 | transcription_dict.update(get_transcriptions_dict(transcriptions_files)) 101 | return male_speakers_dict,female_speakers_dict,transcription_dict 102 | 103 | def main(): 104 | parser = argparse.ArgumentParser("helper script to prepare the iemocap dataset for our experiments") 105 | parser.add_argument('-i','--input_folder',type=str,help='path to the input folder',required=True) 106 | parser.add_argument('-o','--output_folder',type=str,help='path to the output folder',required=True) 107 | parser.add_argument('-s','--seed',type=int,default=None,help='seed for the random number generator') 108 | parser.add_argument('-max','--max_concats',type=int,default=10,help='maximum number of combinations to create for each speaker') 109 | parser.add_argument('-min','--min_concats',type=int,default=2,help='minimum number of combinations to create for each speaker') 110 | parser.add_argument('--num_files', type=int, help='number of files to sample from each speaker',required=True) 111 | parser.add_argument('--sample_rate', type=int,default=16000, help='resample sr') 112 | 113 | args = parser.parse_args() 114 | input_folder = args.input_folder 115 | output_folder = args.output_folder 116 | seed = args.seed 117 | if seed is not None: 118 | random.seed(seed) 119 | num_files = args.num_files 120 | sample_rate = args.sample_rate 121 | 122 | 
123 | json_data = {} 124 | male_speakers_dict,female_speakers_dict,transcription_dict = get_speakers_dict(input_folder) 125 | for i in range(1,6): 126 | print(f"processing session {i}") 127 | sentences_wav_path = os.path.join(input_folder,'Session'+str(i),'sentences/wav') 128 | output_session_folder = os.path.join(output_folder,f"Session_{str(i)}") 129 | os.makedirs(output_session_folder,exist_ok=True) 130 | male_speaker = f"Session{str(i)}_M" 131 | female_speaker = f"Session{str(i)}_F" 132 | session_data = {} 133 | for j in range(num_files): 134 | rand_emo = random.choice(EMOTIONS) 135 | num_segments = random.randint(args.min_concats,args.max_concats) 136 | male_emo_list = male_speakers_dict[male_speaker][rand_emo] 137 | female_emo_list = female_speakers_dict[female_speaker][rand_emo] 138 | wav,text,speaker_segments = concat_files(random.sample(male_emo_list,num_segments),random.sample(female_emo_list,num_segments), 139 | male_speaker,female_speaker,sentences_wav_path,transcription_dict,sample_rate) 140 | output_wav_path = os.path.join(output_session_folder,f"MF_{rand_emo}_{j}.wav") 141 | torchaudio.save(output_wav_path,wav,sample_rate) 142 | session_data[f"MF_{rand_emo}_{j}"] = { 143 | "text":text, 144 | "segmentation":speaker_segments, 145 | "wav_path":os.path.abspath(output_wav_path), 146 | "duration":wav.shape[-1]/sample_rate, 147 | "emotion":rand_emo 148 | } 149 | 150 | rand_emo = random.choice(EMOTIONS) 151 | num_segments = random.randint(args.min_concats,args.max_concats) 152 | male_emo_list = male_speakers_dict[male_speaker][rand_emo] 153 | female_emo_list = female_speakers_dict[female_speaker][rand_emo] 154 | wav,text,speaker_segments = concat_files(random.sample(female_emo_list,num_segments),random.sample(male_emo_list,num_segments), 155 | female_speaker,male_speaker,sentences_wav_path,transcription_dict,sample_rate) 156 | output_wav_path = os.path.join(output_session_folder,f"FM_{rand_emo}_{j}.wav") 157 | torchaudio.save(output_wav_path,wav,sample_rate) 158 | session_data[f"FM_{rand_emo}_{j}"] = { 159 | "text":text, 160 | "segmentation":speaker_segments, 161 | "wav_path":os.path.abspath(output_wav_path), 162 | "duration":wav.shape[-1]/sample_rate, 163 | "emotion":rand_emo 164 | } 165 | json_data[f"Session_{str(i)}"] = session_data 166 | 167 | with open(os.path.join(output_folder,'data.json'),'w') as f: 168 | json.dump(json_data,f,indent=4) 169 | 170 | 171 | if __name__ == "__main__": 172 | main() -------------------------------------------------------------------------------- /install_dep.sh: -------------------------------------------------------------------------------- 1 | pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118 --no-cache-dir 2 | pip install -r requirements.txt --no-cache-dir 3 | pip install git+https://github.com/pytorch/fairseq@da8fb630880d529ab47e53381c30ddc8ad235216 --no-cache-dir -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | librosa==0.10.0 2 | numpy==1.22.0 3 | pyannote.audio==3.0.0 4 | pyannote.core==5.0.0 5 | pyannote.database==5.0.1 6 | pyannote.metrics==3.2.1 7 | pyannote.pipeline==3.0.1 8 | scipy==1.11.4 9 | tqdm==4.66.1 10 | -e git+https://github.com/facebookresearch/textlesslib.git@ba33d669d8284b4f7bfe81e7384e83ab799fe384#egg=textless 11 | transformers==4.36.1 12 | wget==3.2 13 | speechbrain==1.0.0 14 | boltons==20.0.0 15 | 
-------------------------------------------------------------------------------- /segment_eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from speech_lm.evaluation.segmentation import SegmentationMetrics, to_annotation, to_timeline, AVAILABLE_METRICS 3 | import json 4 | 5 | 6 | 7 | 8 | def main(): 9 | parser = argparse.ArgumentParser("this script evaluates segmented audio files segmentation."\ 10 | "should use a json file you should have a dict where each key contains another dict with a key for segmentation" \ 11 | "with a list of segments. each segment should have a start and end time", "multiple files can be evaluated at once if they are in the same json") 12 | parser.add_argument('-re','--reference_path', type=str, help='Path to the json file that contains the reference segmentation',required=True) 13 | parser.add_argument('-hy','--hypothesis_path', type=str, help='Path to the json file that contains the predicted segmentation',required=True) 14 | parser.add_argument('-m','--metrics', type=str,nargs="+", help='metrics to use for evaluation',choices=AVAILABLE_METRICS + ["all"],required=True) 15 | parser.add_argument('-p','--print_sub', action="store_true", help='print the results') 16 | parser.add_argument('-o','--output', type=str,default=None, help='Path to the output file if None will print to stdout') 17 | parser.add_argument('-ci','--confidence_interval', action="store_true", help='add confidence interval') 18 | 19 | args = parser.parse_args() 20 | reference_path = args.reference_path 21 | with open(reference_path) as f: 22 | reference = json.load(f) 23 | hypothesis_path = args.hypothesis_path 24 | with open(hypothesis_path) as f: 25 | hypothesis = json.load(f) 26 | 27 | print_sub = args.print_sub 28 | 29 | m = args.metrics 30 | if "all" in m: 31 | m = AVAILABLE_METRICS 32 | 33 | metrics = SegmentationMetrics(m) 34 | 35 | for key,referene_values in reference.items(): 36 | if key not in hypothesis: 37 | continue 38 | hypothesis_values = hypothesis[key] 39 | for sub_key,reference_segments in referene_values.items(): 40 | if sub_key not in hypothesis_values: 41 | continue 42 | hypothesis_segments = hypothesis_values[sub_key]["segmentation"] 43 | reference_segments = reference_segments["segmentation"] 44 | reference_annotation = to_annotation(reference_segments) 45 | hypothesis_annotation = to_timeline(hypothesis_segments) 46 | results = metrics(reference_annotation,hypothesis_annotation) 47 | if print_sub: 48 | print(f"Results for {key} {sub_key}:") 49 | for metric in results: 50 | print(f"{metric}: {results[metric]}") 51 | print("") 52 | 53 | results = abs(metrics) 54 | if args.confidence_interval: 55 | results["confidence_interval"] = metrics.confidence_interval() 56 | if args.output: 57 | with open(args.output,"w") as f: 58 | json.dump(results,f,indent=4) 59 | else: 60 | print(results) 61 | 62 | 63 | 64 | 65 | if __name__ == '__main__': 66 | main() -------------------------------------------------------------------------------- /speech_lm/UnsupSeg/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 
| *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | -------------------------------------------------------------------------------- /speech_lm/UnsupSeg/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Felix Kreuk 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /speech_lm/UnsupSeg/README.md: -------------------------------------------------------------------------------- 1 | # Self-Supervised Contrastive Learning for Unsupervised Phoneme Segmentation (INTERSPEECH 2020) 2 | 3 | ## Paper 4 | [Self-Supervised Contrastive Learning for Unsupervised Phoneme Segmentation](https://arxiv.org/abs/2007.13465). 5 |
6 | Felix Kreuk, Joseph Keshet, Yossi Adi 7 |
8 | INTERSPEECH 2020 9 | 10 | We propose a self-supervised representation learning model for the task of unsupervised phoneme boundary detection. The model is a convolutional neural network that operates directly on the raw waveform. It is optimized to identify spectral changes in the signal using the Noise-Contrastive Estimation principle. At test time, a peak detection algorithm is applied over the model outputs to produce the final boundaries. As such, the proposed model is trained in a fully unsupervised manner with no manual annotations in the form of target boundaries nor phonetic transcriptions. We compare the proposed approach to several unsupervised baselines using both TIMIT and Buckeye corpora. Results suggest that our approach surpasses the baseline models and reaches state-of-the-art performance on both data sets. Furthermore, we experimented with expanding the training set with additional examples from the Librispeech corpus. We evaluated the resulting model on distributions and languages that were not seen during the training phase (English, Hebrew and German) and showed that utilizing additional untranscribed data is beneficial for model performance. 11 | 12 | If you find this paper and implementation useful, please consider citing our work: 13 | ``` 14 | @article{kreuk2020self, 15 | title={Self-Supervised Contrastive Learning for Unsupervised Phoneme Segmentation}, 16 | author={Kreuk, Felix and Keshet, Joseph and Adi, Yossi}, 17 | journal={arXiv preprint arXiv:2007.13465}, 18 | year={2020} 19 | } 20 | ``` 21 | 22 | ## Clone repository 23 | ``` 24 | git clone https://github.com/felixkreuk/UnsupSeg.git 25 | cd UnsupSeg 26 | ``` 27 | 28 | ## Setup environment 29 | ``` 30 | conda create --name unsup_seg --file requirements.txt 31 | conda activate unsup_seg 32 | ``` 33 | 34 | ## Data structure 35 | The training script assumes that the data is structured as follows: 36 | ``` 37 | timit_directory 38 | │ 39 | └───val 40 | │ │ X.wav 41 | │ └─ X.phn 42 | │ 43 | └───test 44 | │ │ Y.wav 45 | │ └─ Y.phn 46 | │ 47 | └───train 48 | │ Z.wav 49 | └─ Z.phn 50 | ``` 51 | 52 | Where `X.wav` is a raw waveform signal, and `X.phn` contains its corresponding phoneme boundaries, labeled with the following format: 53 | ``` 54 | 0 9640 h# 55 | 9640 11240 sh 56 | 11240 12783 iy 57 | 12783 14078 hv 58 | 14078 16157 ae 59 | 16157 16880 dcl 60 | ... 61 | ``` 62 | Where the two numbers in each line represent the onset and offset of the phoneme (in samples), and the last element represents the phoneme identity. 63 | 64 | ## Usage 65 | ### Data preparation 66 | To convert the default TIMIT file formats to the format required in the `Data structure` section, you should first run the script `scripts/make_timit.py`. 67 | ``` 68 | python scripts/make_timit.py --inpath /path/to/original/timit --outpath /path/to/output/timit 69 | ``` 70 | 71 | ### Configuration 72 | Prior to using our code you should configure the paths to the TIMIT/Buckeye/Librispeech datasets under `conf/config.yaml`. 73 | For TIMIT/Buckeye the path should point to a directory with three sub-directories: train, val and test. 74 | For Librispeech the path should point to a directory that contains the `LibriSpeech` directory downloaded by torchaudio (or by you manually). 75 | 76 | For example: 77 | ``` 78 | ... 79 | buckeye_path: /data/datasets/buckeye 80 | timit_path: /data/datasets/timit 81 | libri_path: /data/datasets # under this dir there's a LibriSpeech dir 82
83 | ``` 84 | 85 | ### Train 86 | To run training with default hyper-parameters, run the following: 87 | ``` 88 | python main.py 89 | ``` 90 | To see further hyper-parameters see `conf/config.yaml`. 91 | More examples: 92 | ``` 93 | python main.py data=timit # train on timit 94 | python main.py data=buckeye # train on buckeye 95 | python main.py data=timit_libri # timit + librispeech 96 | python main.py data=timit_libri libri_percent=0.5 # use only 50% of librispeech 97 | ``` 98 | 99 | ### Test 100 | The following command runs a test epoch on the selected data and reports results in terms of precision, recall, F1 and R-value. 101 | ``` 102 | python main.py ckpt=/absolute/path/to/model.ckpt data=timit # test on timit 103 | python main.py ckpt=/absolute/path/to/model.ckpt data=buckeye # test on buckeye 104 | ``` 105 | 106 | ### Inference on a single wav 107 | The following command runs inference on a single .wav file and outputs the predicted boundaries in seconds. 108 | ``` 109 | python predict.py --ckpt /absolute/path/to/model.ckpt --wav /path/to/audio.wav 110 | python predict.py --ckpt /absolute/path/to/model.ckpt --wav /path/to/audio.wav --prominence 0.05 111 | ``` 112 | The threshold for the peak detection procedure can be adjusted using the `--prominence 0.05` argument. For best results it is advisable to trim silences using a voice activity detector. 113 | -------------------------------------------------------------------------------- /speech_lm/UnsupSeg/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avishaiElmakies/unsupervised_speech_segmentation_using_slm/b0e4fe17bc2b700de6fce31e49c1cde789bad8fd/speech_lm/UnsupSeg/__init__.py -------------------------------------------------------------------------------- /speech_lm/UnsupSeg/conf/config.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: /data/felix/runs/unsupervised_segmentor/${now:%Y-%m-%d_%H-%M-%S}-${exp_name} 4 | 5 | # DATA 6 | libri_path: /data/felix/datasets/librispeech 7 | buckeye_path: /data/felix/datasets/buckeye_processed_by_speaker_balanced 8 | timit_path: /data/felix/datasets/timit 9 | libri_subset: train-clean-100 10 | libri_percent: 1.0 11 | buckeye_percent: 1.0 12 | val_ratio: 0.1 # ratio of validation set 13 | data: timit 14 | dataloader_n_workers: 10 15 | 16 | # MODEL 17 | cosine_coef: 1.0 # cosine similarity coefficient 18 | z_proj: 64 # size of projection 19 | z_proj_linear: true 20 | z_proj_dropout: 0 21 | z_dim: 256 22 | pred_steps: 1 # number of future prediction steps 23 | pred_offset: 0 # offset of future prediction steps 24 | batch_shuffle: false # if 'false' negative samples will be from the same utterance, if 'true' may be from different utterances 25 | latent_dim: 0 # latent dimension of encoder 26 | n_negatives: 1 # number of negative samples for contrastive loss 27 | 28 | # MISC 29 | gpus: 1 30 | tag: default 31 | exp_name: default 32 | project: unsupervised_segmentor 33 | ckpt: null 34 | dev_run: false # fast debug run 35 | val_check_interval: 0.2 # how often a validation epoch is run 36 | overfit_pct: 1 37 | seed: 100 38 | early_stop_metric: val_max_rval 39 | early_stop_mode: max 40 | 41 | # OPTIMIZATION 42 | optimizer: adam 43 | momentum: 0.9 44 | lr: 0.0002 45 | lr_anneal_gamma: 1.0 46 | lr_anneal_step: 1000 47 | epochs: 200 48 | grad_clip: 0.5 49 | batch_size: 8 50 | -------------------------------------------------------------------------------- 
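Since `main.py` is launched through Hydra, the dataset paths and hyper-parameters listed in `conf/config.yaml` above can presumably also be overridden from the command line instead of editing the file, in the same way as the `data=...` and `libri_percent=...` overrides shown in the README. A minimal sketch, assuming standard Hydra override syntax (the path and values here are placeholders, not settings used in the paper):
```
python main.py data=buckeye buckeye_path=/path/to/buckeye_processed batch_size=16 lr=0.0001
```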
/speech_lm/UnsupSeg/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.utils.data import DataLoader, Dataset 5 | torch.multiprocessing.set_sharing_strategy('file_system') 6 | from tqdm import tqdm 7 | import numpy as np 8 | import os 9 | from os.path import join, basename 10 | from boltons.fileutils import iter_find_files 11 | import soundfile as sf 12 | import librosa 13 | import pickle 14 | from multiprocessing import Pool 15 | import random 16 | import torchaudio 17 | import math 18 | from torchaudio.datasets import LIBRISPEECH 19 | 20 | 21 | def collate_fn_padd(batch): 22 | """collate_fn_padd 23 | Padds batch of variable length 24 | 25 | :param batch: 26 | """ 27 | # get sequence lengths 28 | spects = [t[0] for t in batch] 29 | segs = [t[1] for t in batch] 30 | labels = [t[2] for t in batch] 31 | lengths = [t[3] for t in batch] 32 | fnames = [t[4] for t in batch] 33 | 34 | padded_spects = torch.nn.utils.rnn.pad_sequence(spects, batch_first=True) 35 | lengths = torch.LongTensor(lengths) 36 | 37 | return padded_spects, segs, labels, lengths, fnames 38 | 39 | 40 | def spectral_size(wav_len): 41 | layers = [(10,5,0), (8,4,0), (4,2,0), (4,2,0), (4,2,0)] 42 | for kernel, stride, padding in layers: 43 | wav_len = math.floor((wav_len + 2*padding - 1*(kernel-1) - 1)/stride + 1) 44 | return wav_len 45 | 46 | 47 | def get_subset(dataset, percent): 48 | A_split = int(len(dataset) * percent) 49 | B_split = len(dataset) - A_split 50 | dataset, _ = torch.utils.data.random_split(dataset, [A_split, B_split]) 51 | return dataset 52 | 53 | 54 | class WavPhnDataset(Dataset): 55 | def __init__(self, path): 56 | self.path = path 57 | self.data = list(iter_find_files(self.path, "*.wav")) 58 | super(WavPhnDataset, self).__init__() 59 | 60 | @staticmethod 61 | def get_datasets(path): 62 | raise NotImplementedError 63 | 64 | def process_file(self, wav_path): 65 | phn_path = wav_path.replace("wav", "phn") 66 | 67 | # load audio 68 | audio, sr = torchaudio.load(wav_path) 69 | audio = audio[0] 70 | audio_len = len(audio) 71 | spectral_len = spectral_size(audio_len) 72 | len_ratio = (audio_len / spectral_len) 73 | 74 | # load labels -- segmentation and phonemes 75 | with open(phn_path, "r") as f: 76 | lines = f.readlines() 77 | lines = list(map(lambda line: line.split(" "), lines)) 78 | 79 | # get segment times 80 | times = torch.FloatTensor(list(map(lambda line: int(int(line[1]) / len_ratio), lines)))[:-1] # don't count end time as boundary 81 | 82 | # get phonemes in each segment (for K times there should be K+1 phonemes) 83 | phonemes = list(map(lambda line: line[2].strip(), lines)) 84 | 85 | return audio, times.tolist(), phonemes, wav_path 86 | 87 | def __getitem__(self, idx): 88 | audio, seg, phonemes, fname = self.process_file(self.data[idx]) 89 | return audio, seg, phonemes, spectral_size(len(audio)), fname 90 | 91 | def __len__(self): 92 | return len(self.data) 93 | 94 | 95 | class TrainTestDataset(WavPhnDataset): 96 | def __init__(self, path): 97 | super(TrainTestDataset, self).__init__(path) 98 | 99 | @staticmethod 100 | def get_datasets(path, val_ratio=0.1): 101 | train_dataset = TrainTestDataset(join(path, 'train')) 102 | test_dataset = TrainTestDataset(join(path, 'test')) 103 | 104 | train_len = len(train_dataset) 105 | train_split = int(train_len * (1 - val_ratio)) 106 | val_split = train_len - train_split 107 | train_dataset, val_dataset = 
torch.utils.data.random_split(train_dataset, [train_split, val_split]) 108 | 109 | train_dataset.path = join(path, 'train') 110 | val_dataset.path = join(path, 'train') 111 | 112 | return train_dataset, val_dataset, test_dataset 113 | 114 | 115 | class TrainValTestDataset(WavPhnDataset): 116 | def __init__(self, paths): 117 | super(TrainValTestDataset, self).__init__(paths) 118 | 119 | @staticmethod 120 | def get_datasets(path, percent=1.0): 121 | train_dataset = TrainValTestDataset(join(path, 'train')) 122 | if percent != 1.0: 123 | train_dataset = get_subset(train_dataset, percent) 124 | train_dataset.path = join(path, 'train') 125 | val_dataset = TrainValTestDataset(join(path, 'val')) 126 | test_dataset = TrainValTestDataset(join(path, 'test')) 127 | 128 | return train_dataset, val_dataset, test_dataset 129 | 130 | 131 | class LibriSpeechDataset(LIBRISPEECH): 132 | def __init__(self, path, subset, percent): 133 | self.libri_dataset = LIBRISPEECH(path, url=subset, download=False) 134 | if percent != 1.0: 135 | self.libri_dataset = get_subset(self.libri_dataset, percent) 136 | self.path = path 137 | 138 | def __getitem__(self, idx): 139 | wav, sr, utt, spk_id, chp_id, utt_id = self.libri_dataset[idx] 140 | wav = wav[0] 141 | return wav, None, None, spectral_size(len(wav)), None 142 | 143 | def __len__(self): 144 | return len(self.libri_dataset) 145 | 146 | 147 | class MixedDataset(Dataset): 148 | def __init__(self, ds1, ds2): 149 | self.ds1 = ds1 150 | self.ds2 = ds2 151 | self.path = f"{ds1.path}+{ds2.path}" 152 | self.ds1_len, self.ds2_len = len(ds1), len(ds2) 153 | 154 | def __len__(self): 155 | return self.ds1_len + self.ds2_len 156 | 157 | def __getitem__(self, idx): 158 | if idx < self.ds1_len: 159 | return self.ds1[idx] 160 | else: 161 | return self.ds2[idx - self.ds1_len] 162 | -------------------------------------------------------------------------------- /speech_lm/UnsupSeg/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import socket 4 | from argparse import Namespace 5 | from distutils.dir_util import copy_tree 6 | 7 | import hydra 8 | import numpy as np 9 | import torch 10 | from pytorch_lightning import Trainer 11 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint 12 | from torch.backends import cudnn 13 | 14 | from solver import Solver 15 | 16 | torch.autograd.set_detect_anomaly(True) 17 | 18 | 19 | @hydra.main(config_path='conf/config.yaml', strict=False) 20 | def main(cfg): 21 | torch.manual_seed(cfg.seed) 22 | np.random.seed(cfg.seed) 23 | random.seed(cfg.seed) 24 | torch.backends.cudnn.deterministic = True 25 | torch.backends.cudnn.benchmark = False 26 | 27 | print(f"running in: {os.getcwd()}") 28 | cfg.wd = os.getcwd() 29 | cfg.host = socket.gethostname() 30 | cfg.project = "default" if not hasattr(cfg, "project") else cfg.project 31 | cfg = Namespace(**dict(cfg)) 32 | 33 | checkpoint_callback = ModelCheckpoint( 34 | filepath=os.getcwd(), 35 | save_top_k=1, 36 | verbose=True, 37 | monitor=cfg.early_stop_metric, 38 | mode=cfg.early_stop_mode, 39 | prefix='', 40 | ) 41 | 42 | trainer = Trainer( 43 | checkpoint_callback=checkpoint_callback, 44 | early_stop_callback=None, 45 | distributed_backend="dp", 46 | show_progress_bar=True, 47 | num_sanity_val_steps=0, 48 | track_grad_norm=2, 49 | print_nan_grads=True, 50 | gpus=cfg.gpus, 51 | gradient_clip_val=cfg.grad_clip, 52 | val_check_interval=cfg.val_check_interval, 53 | fast_dev_run=cfg.dev_run, 54 | max_epochs=cfg.epochs 55 
| ) 56 | 57 | if cfg.ckpt is not None: 58 | ckpt = cfg.ckpt 59 | else: 60 | solver = Solver(cfg) 61 | trainer.fit(solver) 62 | ckpt = solver.get_ckpt_path() 63 | 64 | print(f"running test on ckpt: {ckpt}") 65 | print(f"testing for {cfg.data.upper()}") 66 | solver = Solver.load_from_checkpoint(ckpt) 67 | 68 | # override checkpoint paths with current conf paths 69 | solver.hp.timit_path = cfg.timit_path 70 | solver.hp.buckeye_path = cfg.buckeye_path 71 | solver.hp.libri_path = cfg.libri_path 72 | solver.hp.data = cfg.data 73 | trainer.test(solver) 74 | 75 | 76 | if __name__ == "__main__": 77 | main() 78 | -------------------------------------------------------------------------------- /speech_lm/UnsupSeg/next_frame_classifier.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import hydra 7 | from .utils import LambdaLayer, PrintShapeLayer, length_to_mask 8 | from .dataloader import TrainTestDataset 9 | from collections import defaultdict 10 | 11 | 12 | class NextFrameClassifier(nn.Module): 13 | def __init__(self, hp): 14 | super(NextFrameClassifier, self).__init__() 15 | self.hp = hp 16 | Z_DIM = hp.z_dim 17 | LS = hp.latent_dim if hp.latent_dim != 0 else Z_DIM 18 | 19 | self.enc = nn.Sequential( 20 | nn.Conv1d(1, LS, kernel_size=10, stride=5, padding=0, bias=False), 21 | nn.BatchNorm1d(LS), 22 | nn.LeakyReLU(), 23 | nn.Conv1d(LS, LS, kernel_size=8, stride=4, padding=0, bias=False), 24 | nn.BatchNorm1d(LS), 25 | nn.LeakyReLU(), 26 | nn.Conv1d(LS, LS, kernel_size=4, stride=2, padding=0, bias=False), 27 | nn.BatchNorm1d(LS), 28 | nn.LeakyReLU(), 29 | nn.Conv1d(LS, LS, kernel_size=4, stride=2, padding=0, bias=False), 30 | nn.BatchNorm1d(LS), 31 | nn.LeakyReLU(), 32 | nn.Conv1d(LS, Z_DIM, kernel_size=4, stride=2, padding=0, bias=False), 33 | LambdaLayer(lambda x: x.transpose(1,2)), 34 | ) 35 | print("learning features from raw wav") 36 | 37 | if self.hp.z_proj != 0: 38 | if self.hp.z_proj_linear: 39 | self.enc.add_module( 40 | "z_proj", 41 | nn.Sequential( 42 | nn.Dropout2d(self.hp.z_proj_dropout), 43 | nn.Linear(Z_DIM, self.hp.z_proj), 44 | ) 45 | ) 46 | else: 47 | self.enc.add_module( 48 | "z_proj", 49 | nn.Sequential( 50 | nn.Dropout2d(self.hp.z_proj_dropout), 51 | nn.Linear(Z_DIM, Z_DIM), nn.LeakyReLU(), 52 | nn.Dropout2d(self.hp.z_proj_dropout), 53 | nn.Linear(Z_DIM, self.hp.z_proj), 54 | ) 55 | ) 56 | 57 | # # similarity estimation projections 58 | self.pred_steps = list(range(1 + self.hp.pred_offset, 1 + self.hp.pred_offset + self.hp.pred_steps)) 59 | print(f"prediction steps: {self.pred_steps}") 60 | 61 | def score(self, f, b): 62 | return F.cosine_similarity(f, b, dim=-1) * self.hp.cosine_coef 63 | 64 | def forward(self, spect): 65 | device = spect.device 66 | 67 | # wav => latent z 68 | z = self.enc(spect.unsqueeze(1)) 69 | 70 | preds = defaultdict(list) 71 | for i, t in enumerate(self.pred_steps): # predict for steps 1...t 72 | pos_pred = self.score(z[:, :-t], z[:, t:]) # score for positive frame 73 | preds[t].append(pos_pred) 74 | 75 | for _ in range(self.hp.n_negatives): 76 | if self.training: 77 | time_reorder = torch.randperm(pos_pred.shape[1]) 78 | batch_reorder = torch.arange(pos_pred.shape[0]) 79 | if self.hp.batch_shuffle: 80 | batch_reorder = torch.randperm(pos_pred.shape[0]) 81 | else: 82 | time_reorder = torch.arange(pos_pred.shape[1]) 83 | batch_reorder = torch.arange(pos_pred.shape[0]) 84 | 85 | neg_pred = self.score(z[:, :-t], 
z[batch_reorder][: , time_reorder]) # score for negative random frame 86 | preds[t].append(neg_pred) 87 | 88 | return preds 89 | 90 | def loss(self, preds, lengths): 91 | loss = 0 92 | for t, t_preds in preds.items(): 93 | mask = length_to_mask(lengths - t) 94 | out = torch.stack(t_preds, dim=-1) 95 | out = F.log_softmax(out, dim=-1) 96 | out = out[...,0] * mask 97 | loss += -out.mean() 98 | return loss 99 | 100 | @hydra.main(config_path='conf/config.yaml', strict=False) 101 | def main(cfg): 102 | ds, _, _ = TrainTestDataset.get_datasets(cfg.timit_path) 103 | spect, seg, phonemes, length, fname = ds[0] 104 | spect = spect.unsqueeze(0) 105 | 106 | model = NextFrameClassifier(cfg) 107 | out = model(spect, length) 108 | 109 | 110 | if __name__ == "__main__": 111 | main() 112 | -------------------------------------------------------------------------------- /speech_lm/UnsupSeg/predict.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import dill 3 | from argparse import Namespace 4 | import torch 5 | import torchaudio 6 | from utils import (detect_peaks, max_min_norm, replicate_first_k_frames) 7 | from next_frame_classifier import NextFrameClassifier 8 | 9 | 10 | def main(wav, ckpt, prominence): 11 | print(f"running inference on: {wav}") 12 | print(f"running inferece using ckpt: {ckpt}") 13 | print("\n\n", 90 * "-") 14 | 15 | ckpt = torch.load(ckpt, map_location=lambda storage, loc: storage) 16 | hp = Namespace(**dict(ckpt["hparams"])) 17 | 18 | # load weights and peak detection params 19 | model = NextFrameClassifier(hp) 20 | weights = ckpt["state_dict"] 21 | weights = {k.replace("NFC.", ""): v for k,v in weights.items()} 22 | model.load_state_dict(weights) 23 | peak_detection_params = dill.loads(ckpt['peak_detection_params'])['cpc_1'] 24 | if prominence is not None: 25 | print(f"overriding prominence with {prominence}") 26 | peak_detection_params["prominence"] = prominence 27 | 28 | # load data 29 | audio, sr = torchaudio.load(wav) 30 | assert sr == 16000, "model was trained with audio sampled at 16khz, please downsample." 
31 | audio = audio[0] 32 | audio = audio.unsqueeze(0) 33 | 34 | # run inference 35 | preds = model(audio) # get scores 36 | preds = preds[1][0] # get scores of positive pairs 37 | preds = replicate_first_k_frames(preds, k=1, dim=1) # padding 38 | preds = 1 - max_min_norm(preds) # normalize scores (good for visualizations) 39 | preds = detect_peaks(x=preds, 40 | lengths=[preds.shape[1]], 41 | prominence=peak_detection_params["prominence"], 42 | width=peak_detection_params["width"], 43 | distance=peak_detection_params["distance"]) # run peak detection on scores 44 | preds = preds[0] * 160 / sr # transform frame indexes to seconds 45 | 46 | print("predicted boundaries (in seconds):") 47 | print(preds) 48 | 49 | if __name__ == "__main__": 50 | parser = argparse.ArgumentParser(description='Unsupervised segmentation inference script') 51 | parser.add_argument('--wav', help='path to wav file') 52 | parser.add_argument('--ckpt', help='path to checkpoint file') 53 | parser.add_argument('--prominence', type=float, default=None, help='prominence for peak detection (default: 0.05)') 54 | args = parser.parse_args() 55 | main(args.wav, args.ckpt, args.prominence) 56 | -------------------------------------------------------------------------------- /speech_lm/UnsupSeg/pretrained_models/buckeye+_pretrained.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avishaiElmakies/unsupervised_speech_segmentation_using_slm/b0e4fe17bc2b700de6fce31e49c1cde789bad8fd/speech_lm/UnsupSeg/pretrained_models/buckeye+_pretrained.ckpt -------------------------------------------------------------------------------- /speech_lm/UnsupSeg/pretrained_models/buckeye_pretrained.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avishaiElmakies/unsupervised_speech_segmentation_using_slm/b0e4fe17bc2b700de6fce31e49c1cde789bad8fd/speech_lm/UnsupSeg/pretrained_models/buckeye_pretrained.ckpt -------------------------------------------------------------------------------- /speech_lm/UnsupSeg/pretrained_models/timit+_pretrained.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avishaiElmakies/unsupervised_speech_segmentation_using_slm/b0e4fe17bc2b700de6fce31e49c1cde789bad8fd/speech_lm/UnsupSeg/pretrained_models/timit+_pretrained.ckpt -------------------------------------------------------------------------------- /speech_lm/UnsupSeg/pretrained_models/timit_pretrained.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avishaiElmakies/unsupervised_speech_segmentation_using_slm/b0e4fe17bc2b700de6fce31e49c1cde789bad8fd/speech_lm/UnsupSeg/pretrained_models/timit_pretrained.ckpt -------------------------------------------------------------------------------- /speech_lm/UnsupSeg/scripts/make_timit.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import os 4 | import shutil 5 | from tqdm import tqdm 6 | 7 | def main(inpath, outpath): 8 | if not os.path.exists(inpath): 9 | print('Error: input path does not exist!!') 10 | return -1 11 | if not os.path.exists(outpath): 12 | os.makedirs(outpath, exist_ok=True) 13 | 14 | for _f in tqdm(os.listdir(inpath)): 15 | parent_f = os.path.join(inpath, _f) 16 | if os.path.isdir(parent_f): 17 | for _ff in os.listdir(parent_f): 18 | parent_ff = os.path.join(parent_f, _ff) 19 | 
if os.path.isdir(parent_ff): 20 | for ex in os.listdir(parent_ff): 21 | if ex.endswith('.phn') or ex.endswith('.wav'): 22 | src_name = os.path.join(parent_ff, ex) 23 | tgt_name = os.path.join(outpath, _f+'_'+_ff+'_'+ex) 24 | shutil.copy(src_name, tgt_name) 25 | 26 | parser = argparse.ArgumentParser(description='Make TIMIT dataset ready for unsupervised segmentation.') 27 | parser.add_argument('--inpath', type=str, required=True, help='the path to the base timit dir.') 28 | parser.add_argument('--outpath', type=str, required=True, help='the path to save the new format of the data.') 29 | 30 | args = parser.parse_args() 31 | 32 | main(args.inpath, args.outpath) 33 | 34 | -------------------------------------------------------------------------------- /speech_lm/UnsupSeg/scripts/preprocess_buckeye.py: -------------------------------------------------------------------------------- 1 | import random 2 | import soundfile as sf 3 | import buckeye 4 | from tqdm import tqdm 5 | import numpy as np 6 | from boltons import fileutils 7 | import os 8 | import os.path as osp 9 | import argparse 10 | 11 | parser = argparse.ArgumentParser(description=__doc__) 12 | parser.add_argument('--spkr', default=False, action='store_true') 13 | parser.add_argument('--source', default=False) 14 | parser.add_argument('--target', default=False) 15 | parser.add_argument('--min_phonemes', type=int) 16 | parser.add_argument('--max_phonemes', type=int) 17 | args = parser.parse_args() 18 | 19 | 20 | DELIMITER = ['VOCNOISE', 'NOISE', 'SIL'] 21 | FORBIDDEN = ['{B_TRANS}', '{E_TRANS}', '', 'LAUGH', 'UNKNOWN', 'IVER-LAUGH', '', 'IVER'] 22 | MIN_PHONEMES = args.min_phonemes 23 | MAX_PHONEMES = args.max_phonemes 24 | NOISE_EDGES = 0.2 25 | is_delim = lambda x: x.seg in DELIMITER 26 | contain_forbidden = lambda phone_list: not set([p.seg for p in phone_list]).isdisjoint(FORBIDDEN) 27 | path = args.source 28 | output_path = args.target 29 | train_path = osp.join(output_path, "train") 30 | val_path = osp.join(output_path, "val") 31 | test_path = osp.join(output_path, "test") 32 | 33 | wavs = list(fileutils.iter_find_files(path, "*.wav")) 34 | files = [] 35 | segments = [] 36 | file_counter = 0 37 | 38 | os.makedirs(output_path, exist_ok=True) 39 | os.makedirs(train_path, exist_ok=True) 40 | os.makedirs(val_path, exist_ok=True) 41 | os.makedirs(test_path, exist_ok=True) 42 | 43 | for wav in tqdm(wavs): 44 | try: 45 | spkr = osp.basename(wav)[:3] 46 | name = wav.replace(".wav", "") 47 | words = wav.replace("wav", "words") 48 | phones = wav.replace("wav", "phones") 49 | log = wav.replace("wav", "log") 50 | txt = wav.replace("wav", "txt") 51 | track = buckeye.Track(name=name, 52 | words=words, 53 | phones=phones, 54 | log=log, 55 | txt=txt, 56 | wav=wav) 57 | phones = track.phones[1:-1] 58 | delim_locations = np.array([i for i, phone in enumerate(phones) if is_delim(phone)]) 59 | loaded_wav, sr = sf.read(wav) 60 | 61 | # in some files the last segment annotation ends after the 62 | # actual wav file, ignore those files. 
63 | if phones[-1].end >= loaded_wav.shape[0] / sr: 64 | print(f"last phone end: {phones[-1].end}") 65 | print(f"len of wav: {loaded_wav.shape[0] / sr}") 66 | print(f"skipping {wav}") 67 | continue 68 | 69 | # iterate over all phone segments inside wav 70 | for start, end in zip(delim_locations[:-1], delim_locations[1:]): 71 | segment = phones[start:end+1] 72 | 73 | # if contains forbidden annotations, or number of segments is 74 | # not in the desired range => ignore 75 | if contain_forbidden(segment) or not (MIN_PHONEMES < end - start < MAX_PHONEMES): 76 | continue 77 | 78 | # make sure that the noise/sil on the edges is less than 79 | # NOISE_EDGES seconds 80 | if segment[0].end - segment[0].beg > NOISE_EDGES: 81 | segment[0]._beg = segment[0].end - NOISE_EDGES 82 | if segment[-1].end - segment[-1].beg > NOISE_EDGES: 83 | segment[-1]._end = segment[-1].beg + NOISE_EDGES 84 | 85 | # get stat and end times 86 | segment_start_time = segment[0].beg 87 | segment_end_time = segment[-1].end 88 | 89 | # trim wav according to start and end 90 | # also, extract from the .phn file the corresponding phonemes 91 | output_wav_file = f"{spkr}_{file_counter}.wav" 92 | output_phn_file = f"{spkr}_{file_counter}.phn" 93 | track.clip_wav(osp.join(output_path, output_wav_file), segment_start_time, segment_end_time) 94 | phn_data = "\n".join([f"{int((p.beg - segment_start_time) * sr)} {int((p.end - segment_start_time) * sr)} {p.seg}" for p in segment]) 95 | with open(osp.join(output_path, output_phn_file), "w") as f: 96 | f.writelines(phn_data) 97 | file_counter += 1 98 | 99 | segments.append(segment) 100 | except UnboundLocalError: 101 | print(f"loading {wav} failed!") 102 | 103 | lens = np.array([len(seg) for seg in segments]) 104 | secs = np.array([seg[-1].end - seg[0].beg for seg in segments]) 105 | print(f"{len(segments)} items") 106 | print(f"avg len: {lens.mean()}") 107 | print(f"min len: {lens.min()}") 108 | print(f"max len: {lens.max()}") 109 | print(f"avg sec: {secs.mean()}") 110 | print(f"min sec: {secs.min()}") 111 | print(f"max sec: {secs.max()}") 112 | print(f"{secs.sum() / (60*60)} hours") 113 | 114 | if args.spkr: 115 | os.chdir(output_path) 116 | test_spkrs = ["s07", "s03", "s31", "s34"] 117 | val_spkrs = ["s40", "s39", "s36", "s25"] 118 | for spkr in test_spkrs: 119 | os.system(f"mv {spkr}* test/") 120 | for spkr in val_spkrs: 121 | os.system(f"mv {spkr}* val/") 122 | os.system(f"mv *.wav train/") 123 | os.system(f"mv *.phn train/") 124 | else: 125 | splits = [0.8, 0.9, 1.0] 126 | wavs = list(fileutils.iter_find_files(output_path, "*.wav")) 127 | random.shuffle(wavs) 128 | for i, wav in enumerate(wavs): 129 | if i < len(wavs) * splits[0]: 130 | target = train_path 131 | elif len(wavs) * splits[0] <= i and i < len(wavs) * splits[1]: 132 | target = val_path 133 | else: 134 | target = test_path 135 | phn = wav.replace("wav", "phn") 136 | os.rename(wav, wav.replace(output_path, target + "/")) 137 | os.rename(phn, phn.replace(output_path, target + "/")) 138 | -------------------------------------------------------------------------------- /speech_lm/UnsupSeg/solver.py: -------------------------------------------------------------------------------- 1 | import glob 2 | from collections import OrderedDict, defaultdict 3 | 4 | import dill 5 | import wandb 6 | 7 | import pytorch_lightning as pl 8 | from pytorch_lightning import LightningModule 9 | import torch_optimizer as optim_extra 10 | 11 | import torch 12 | from torch import optim 13 | from torch.utils.data import ConcatDataset, DataLoader 14 | 
import torchaudio 15 | 16 | from dataloader import (LibriSpeechDataset, MixedDataset, TrainTestDataset, 17 | TrainValTestDataset, collate_fn_padd, spectral_size) 18 | from next_frame_classifier import NextFrameClassifier 19 | from utils import (PrecisionRecallMetric, StatsMeter, 20 | detect_peaks, line, max_min_norm, replicate_first_k_frames) 21 | 22 | 23 | class Solver(LightningModule): 24 | def __init__(self, hparams): 25 | super(Solver, self).__init__() 26 | hp = hparams 27 | self.hp = hp 28 | self.hparams = hp 29 | self.peak_detection_params = defaultdict(lambda: { 30 | "prominence": 0.05, 31 | "width": None, 32 | "distance": None 33 | }) 34 | self.pr = defaultdict(lambda: { 35 | "train": PrecisionRecallMetric(), 36 | "val": PrecisionRecallMetric(), 37 | "test": PrecisionRecallMetric() 38 | }) 39 | self.best_rval = defaultdict(lambda: { 40 | "train": (0, 0), 41 | "val": (0, 0), 42 | "test": (0, 0) 43 | }) 44 | self.overall_best_rval = 0 45 | self.stats = defaultdict(lambda: { 46 | "train": StatsMeter(), 47 | "val": StatsMeter(), 48 | "test": StatsMeter() 49 | }) 50 | 51 | wandb.init(project=self.hp.project, name=hp.exp_name, config=vars(hp), tags=[hp.tag]) 52 | self.build_model() 53 | 54 | def prepare_data(self): 55 | # setup training set 56 | if "timit" in self.hp.data: 57 | train, val, test = TrainTestDataset.get_datasets(path=self.hp.timit_path) 58 | elif "buckeye" in self.hp.data: 59 | train, val, test = TrainValTestDataset.get_datasets(path=self.hp.buckeye_path, percent=self.hp.buckeye_percent) 60 | else: 61 | raise Exception("no such training data!") 62 | 63 | if "libri" in self.hp.data: 64 | libri_train = LibriSpeechDataset(path=self.hp.libri_path, 65 | subset=self.hp.libri_subset, 66 | percent=self.hp.libri_percent) 67 | train = ConcatDataset([train, libri_train]) 68 | train.path = "\n\t+".join([dataset.path for dataset in train.datasets]) 69 | print(f"added libri ({len(libri_train)} examples)") 70 | 71 | self.train_dataset = train 72 | self.valid_dataset = val 73 | self.test_dataset = test 74 | 75 | line() 76 | print("DATA:") 77 | print(f"train: {self.train_dataset.path} ({len(self.train_dataset)})") 78 | print(f"valid: {self.valid_dataset.path} ({len(self.valid_dataset)})") 79 | print(f"test: {self.test_dataset.path} ({len(self.test_dataset)})") 80 | line() 81 | 82 | 83 | @pl.data_loader 84 | def train_dataloader(self): 85 | self.train_loader = DataLoader(self.train_dataset, 86 | batch_size=self.hp.batch_size, 87 | shuffle=True, 88 | collate_fn=collate_fn_padd, 89 | num_workers=self.hp.dataloader_n_workers) 90 | return self.train_loader 91 | 92 | @pl.data_loader 93 | def val_dataloader(self): 94 | self.valid_loader = DataLoader(self.valid_dataset, 95 | batch_size=self.hp.batch_size, 96 | shuffle=False, 97 | collate_fn=collate_fn_padd, 98 | num_workers=self.hp.dataloader_n_workers) 99 | return self.valid_loader 100 | 101 | @pl.data_loader 102 | def test_dataloader(self): 103 | self.test_loader = DataLoader(self.test_dataset, 104 | batch_size=self.hp.batch_size, 105 | shuffle=False, 106 | collate_fn=collate_fn_padd, 107 | num_workers=self.hp.dataloader_n_workers) 108 | return self.test_loader 109 | 110 | def build_model(self): 111 | print("MODEL:") 112 | self.NFC = NextFrameClassifier(self.hp) 113 | line() 114 | 115 | def forward(self, data_batch, batch_i, mode): 116 | loss = 0 117 | 118 | # TRAIN 119 | audio, seg, phonemes, length, fname = data_batch 120 | preds = self.NFC(audio) 121 | NFC_loss = self.NFC.loss(preds, length) 122 | 
self.stats['nfc_loss'][mode].update(NFC_loss.item()) 123 | loss += NFC_loss 124 | 125 | # INFERENCE 126 | if mode == "test" or (mode == "val" and self.hp.early_stop_metric == "val_max_rval"): 127 | positives = 0 128 | for t in self.NFC.pred_steps: 129 | p = preds[t][0] 130 | p = replicate_first_k_frames(p, k=t, dim=1) 131 | positives += p 132 | positives = 1 - max_min_norm(positives) 133 | self.pr[f'cpc_{t}'][mode].update(seg, positives, length) 134 | 135 | loss_key = "loss" if mode == "train" else f"{mode}_loss" 136 | return OrderedDict({ 137 | loss_key: loss 138 | }) 139 | 140 | def generic_eval_end(self, outputs, mode): 141 | metrics = {} 142 | data = self.hp.data 143 | 144 | for k, v in self.stats.items(): 145 | metrics[f"train_{k}"] = self.stats[k]["train"].get_stats() 146 | metrics[f"{mode}_{k}"] = self.stats[k][mode].get_stats() 147 | 148 | epoch = self.current_epoch + 1 149 | metrics['epoch'] = epoch 150 | metrics['current_lr'] = self.opt.param_groups[0]['lr'] 151 | 152 | line() 153 | for pred_type in self.pr.keys(): 154 | if mode == "val": 155 | (precision, recall, f1, rval), (width, prominence, distance) = self.pr[pred_type][mode].get_stats() 156 | if rval > self.best_rval[pred_type][mode][0]: 157 | self.best_rval[pred_type][mode] = rval, self.current_epoch 158 | self.peak_detection_params[pred_type]["width"] = width 159 | self.peak_detection_params[pred_type]["prominence"] = prominence 160 | self.peak_detection_params[pred_type]["distance"] = distance 161 | self.peak_detection_params[pred_type]["epoch"] = self.current_epoch 162 | print(f"saving for test - {pred_type} - {self.peak_detection_params[pred_type]}") 163 | else: 164 | print(f"using pre-defined peak detection values - {pred_type} - {self.peak_detection_params[pred_type]}") 165 | (precision, recall, f1, rval), _ = self.pr[pred_type][mode].get_stats( 166 | width=self.peak_detection_params[pred_type]["width"], 167 | prominence=self.peak_detection_params[pred_type]["prominence"], 168 | distance=self.peak_detection_params[pred_type]["distance"], 169 | ) 170 | # test has only one epoch so set it as best 171 | # this is to get the overall best pred_type later 172 | self.best_rval[pred_type][mode] = rval, self.current_epoch 173 | metrics[f'{data}_{mode}_{pred_type}_f1'] = f1 174 | metrics[f'{data}_{mode}_{pred_type}_precision'] = precision 175 | metrics[f'{data}_{mode}_{pred_type}_recall'] = recall 176 | metrics[f'{data}_{mode}_{pred_type}_rval'] = rval 177 | metrics[f"{data}_{mode}_{pred_type}_max_rval"] = self.best_rval[pred_type][mode][0] 178 | metrics[f"{data}_{mode}_{pred_type}_max_rval_epoch"] = self.best_rval[pred_type][mode][1] 179 | 180 | # get best rval from all rval types and all epochs 181 | best_overall_rval = -float("inf") 182 | for pred_type, rval in self.best_rval.items(): 183 | if rval[mode][0] > best_overall_rval: 184 | best_overall_rval = rval[mode][0] 185 | metrics[f'{mode}_max_rval'] = best_overall_rval 186 | 187 | for k, v in metrics.items(): 188 | print(f"\t{k:<30} -- {v}") 189 | line() 190 | wandb.log(metrics) 191 | 192 | output = OrderedDict({ 193 | 'log': metrics 194 | }) 195 | 196 | return output 197 | 198 | def training_step(self, data_batch, batch_i): 199 | return self.forward(data_batch, batch_i, 'train') 200 | 201 | def validation_step(self, data_batch, batch_i): 202 | return self.forward(data_batch, batch_i, 'val') 203 | 204 | def test_step(self, data_batch, batch_i): 205 | return self.forward(data_batch, batch_i, 'test') 206 | 207 | def validation_end(self, outputs): 208 | return 
self.generic_eval_end(outputs, 'val') 209 | 210 | def test_end(self, outputs): 211 | return self.generic_eval_end(outputs, 'test') 212 | 213 | def configure_optimizers(self): 214 | parameters = filter(lambda p: p.requires_grad, self.parameters()) 215 | if self.hp.optimizer == "sgd": 216 | self.opt = optim.SGD(parameters, lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4) 217 | elif self.hp.optimizer == "adam": 218 | self.opt = optim.Adam(parameters, lr=self.hparams.lr, weight_decay=5e-4) 219 | elif self.hp.optimizer == "ranger": 220 | self.opt = optim_extra.Ranger(parameters, lr=self.hparams.lr, alpha=0.5, k=6, N_sma_threshhold=5, betas=(.95, 0.999), eps=1e-5, weight_decay=0) 221 | else: 222 | raise Exception("unknown optimizer") 223 | print(f"optimizer: {self.opt}") 224 | line() 225 | self.scheduler = optim.lr_scheduler.StepLR(self.opt, 226 | step_size=self.hp.lr_anneal_step, 227 | gamma=self.hp.lr_anneal_gamma) 228 | return [self.opt] 229 | 230 | def on_epoch_end(self): 231 | self.scheduler.step() 232 | 233 | def on_save_checkpoint(self, ckpt): 234 | ckpt['peak_detection_params'] = dill.dumps(self.peak_detection_params) 235 | 236 | def on_load_checkpoint(self, ckpt): 237 | self.peak_detection_params = dill.loads(ckpt['peak_detection_params']) 238 | 239 | def get_ckpt_path(self): 240 | return glob.glob(self.hp.wd + "/*.ckpt")[0] 241 | -------------------------------------------------------------------------------- /speech_lm/UnsupSeg/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | import time 6 | from scipy.signal import find_peaks 7 | from tqdm import tqdm 8 | 9 | 10 | def replicate_first_k_frames(x, k, dim): 11 | return torch.cat([x.index_select(dim=dim, index=torch.LongTensor([0] * k).to(x.device)), x], dim=dim) 12 | 13 | 14 | class LambdaLayer(nn.Module): 15 | def __init__(self, lambd): 16 | super(LambdaLayer, self).__init__() 17 | self.lambd = lambd 18 | def forward(self, x): 19 | return self.lambd(x) 20 | 21 | 22 | class PrintShapeLayer(nn.Module): 23 | def __init__(self): 24 | super(PrintShapeLayer, self).__init__() 25 | def forward(self, x): 26 | print(x.shape) 27 | return x 28 | 29 | 30 | def length_to_mask(length, max_len=None, dtype=None): 31 | """length: B. 32 | return B x max_len. 33 | If max_len is None, then max of length will be used. 34 | """ 35 | assert len(length.shape) == 1, 'Length shape should be 1 dimensional.' 
36 | max_len = max_len or length.max().item() 37 | mask = torch.arange(max_len, device=length.device, 38 | dtype=length.dtype).expand(len(length), max_len) < length.unsqueeze(1) 39 | if dtype is not None: 40 | mask = torch.as_tensor(mask, dtype=dtype, device=length.device) 41 | return mask 42 | 43 | 44 | def detect_peaks(x, lengths, prominence=0.1, width=None, distance=None): 45 | """detect peaks of next_frame_classifier 46 | 47 | Arguments: 48 | x {Tensor} -- batch of confidence per time 49 | """ 50 | out = [] 51 | 52 | for xi, li in zip(x, lengths): 53 | if type(xi) == torch.Tensor: 54 | xi = xi.cpu().detach().numpy() 55 | xi = xi[:li] # shorten to actual length 56 | xmin, xmax = xi.min(), xi.max() 57 | xi = (xi - xmin) / (xmax - xmin) 58 | peaks, _ = find_peaks(xi, prominence=prominence, width=width, distance=distance) 59 | 60 | if len(peaks) == 0: 61 | peaks = np.array([len(xi)-1]) 62 | 63 | out.append(peaks) 64 | 65 | return out 66 | 67 | 68 | class PrecisionRecallMetric: 69 | def __init__(self): 70 | self.precision_counter = 0 71 | self.recall_counter = 0 72 | self.pred_counter = 0 73 | self.gt_counter = 0 74 | self.eps = 1e-5 75 | self.data = [] 76 | self.tolerance = 2 77 | self.prominence_range = np.arange(0, 0.15, 0.01) 78 | self.width_range = [None, 1] 79 | self.distance_range = [None, 1] 80 | 81 | def get_metrics(self, precision_counter, recall_counter, pred_counter, gt_counter): 82 | EPS = 1e-7 83 | 84 | precision = precision_counter / (pred_counter + self.eps) 85 | recall = recall_counter / (gt_counter + self.eps) 86 | f1 = 2 * (precision * recall) / (precision + recall + self.eps) 87 | 88 | os = recall / (precision + EPS) - 1 89 | r1 = np.sqrt((1 - recall) ** 2 + os ** 2) 90 | r2 = (-os + recall - 1) / (np.sqrt(2)) 91 | rval = 1 - (np.abs(r1) + np.abs(r2)) / 2 92 | 93 | return precision, recall, f1, rval 94 | 95 | def zero(self): 96 | self.data = [] 97 | 98 | def update(self, seg, pos_pred, length): 99 | for seg_i, pos_pred_i, length_i in zip(seg, pos_pred, length): 100 | self.data.append((seg_i, pos_pred_i.cpu().detach().numpy(), length_i.item())) 101 | 102 | def get_stats(self, width=None, prominence=None, distance=None): 103 | print(f"calculating metrics using {len(self.data)} entries") 104 | max_rval = -float("inf") 105 | best_params = None 106 | segs = list(map(lambda x: x[0], self.data)) 107 | length = list(map(lambda x: x[2], self.data)) 108 | yhats = list(map(lambda x: x[1], self.data)) 109 | 110 | width_range = self.width_range 111 | distance_range = self.distance_range 112 | prominence_range = self.prominence_range 113 | 114 | # when testing, we would override the search with specific values from validation 115 | if prominence is not None: 116 | width_range = [width] 117 | distance_range = [distance] 118 | prominence_range = [prominence] 119 | 120 | for width in width_range: 121 | for prominence in prominence_range: 122 | for distance in distance_range: 123 | precision_counter = 0 124 | recall_counter = 0 125 | pred_counter = 0 126 | gt_counter = 0 127 | peaks = detect_peaks(yhats, 128 | length, 129 | prominence=prominence, 130 | width=width, 131 | distance=distance) 132 | 133 | for (y, yhat) in zip(segs, peaks): 134 | for yhat_i in yhat: 135 | min_dist = np.abs(y - yhat_i).min() 136 | precision_counter += (min_dist <= self.tolerance) 137 | for y_i in y: 138 | min_dist = np.abs(yhat - y_i).min() 139 | recall_counter += (min_dist <= self.tolerance) 140 | pred_counter += len(yhat) 141 | gt_counter += len(y) 142 | 143 | p, r, f1, rval = 
self.get_metrics(precision_counter, 144 | recall_counter, 145 | pred_counter, 146 | gt_counter) 147 | if rval > max_rval: 148 | max_rval = rval 149 | best_params = width, prominence, distance 150 | out = (p, r, f1, rval) 151 | self.zero() 152 | print(f"best peak detection params: {best_params} (width, prominence, distance)") 153 | return out, best_params 154 | 155 | 156 | class StatsMeter: 157 | def __init__(self): 158 | self.data = [] 159 | 160 | def update(self, item): 161 | if type(item) == list: 162 | self.data.extend(item) 163 | else: 164 | self.data.append(item) 165 | 166 | def get_stats(self): 167 | data = np.array(self.data) 168 | mean = data.mean() 169 | self.zero() 170 | return mean 171 | 172 | def zero(self): 173 | self.data.clear() 174 | assert len(self.data) == 0, "StatsMeter didn't clear" 175 | 176 | 177 | class Timer: 178 | def __init__(self, msg): 179 | self.msg = msg 180 | self.start_time = None 181 | 182 | def __enter__(self): 183 | self.start_time = time.time() 184 | print(f"{self.msg} -- started") 185 | 186 | def __exit__(self, exc_type, exc_value, exc_tb): 187 | print(f"{self.msg} -- done in {(time.time() - self.start_time)} secs") 188 | 189 | 190 | def max_min_norm(x): 191 | x -= x.min(-1, keepdim=True)[0] 192 | x /= x.max(-1, keepdim=True)[0] 193 | return x 194 | 195 | 196 | def line(): 197 | print(90 * "-") 198 | -------------------------------------------------------------------------------- /speech_lm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avishaiElmakies/unsupervised_speech_segmentation_using_slm/b0e4fe17bc2b700de6fce31e49c1cde789bad8fd/speech_lm/__init__.py -------------------------------------------------------------------------------- /speech_lm/configs/README.md: -------------------------------------------------------------------------------- 1 | ## Segmentor Configs Explanation 2 | 3 | This readme explains the format of the config files. 4 | 5 | We use config files to provide an easy and flexible way to choose how the segmentor will work. 6 | 7 | As the paper states, most approaches used in the paper use 3 parts: 8 | 1. sentencer 9 | 2. scorer 10 | 3. span selector 11 | 12 | We create a segmentor by mixing and matching different configs for each of those parts. 13 | 14 | This gives us the following format for a segmentor: 15 | 16 | ``` 17 | { 18 | "sentencer": { 19 | ... 20 | }, 21 | "scorer": { 22 | ... 23 | }, 24 | "sselector":{ 25 | ... 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": 29 | } 30 | 31 | ``` 32 | 33 | If we don't need some part of the pipeline we can omit it (this of course depends on the `type` of the segmentor, e.g. an equal-length segmentor doesn't need a scorer, and unsupseg doesn't need any of them). 34 | 35 | ### Segmentor Config 36 | 37 | This is the full config given to [audio_segment.py](../../audio_segment.py). 38 | 39 | The required keys/arguments needed for the segmentor are: 40 | - `type`: type of segmentor to use. Currently available: (i) `"speech_pmi"`, (ii) `'equal_length'`, (iii) `'next_frame'` 41 | - `default_sample_rate`: default sample rate used. 42 | 43 | ### Sentencer Config 44 | 45 | This is a config used for the sentencer in the pipeline. 46 | 47 | required keys/arguments: 48 | - `type`: type of sentencer. 
Currently the only type available is `"constant"`. 49 | 50 | For more information and arguments for the sentencer, look at [speech_sentencer.py](../speech_sentencer.py) 51 | 52 | ### Scorer Config 53 | 54 | This is a config used for the scorer in the pipeline. 55 | 56 | required keys/arguments: 57 | 58 | - `type`: type of scorer to use. Currently the only one available is `pmi`. 59 | For `speech_pmi` you also need an `inference_model` config with `model_name: "TWIST-(350M/1-3B/7B)"`, `model_type: slm`, and a tokenizer that takes 5 arguments. 60 | 61 | This gives us the following config: 62 | 63 | ``` 64 | "scorer": { 65 | "type": "pmi", 66 | 67 | "inference_model": { 68 | "model_name": "TWIST-7B", 69 | "model_type": "slm", 70 | "tokenizer": { 71 | "dense_model_name": "", 72 | "quantizer_model_name": "", 73 | "encoder_vocab_size": , 74 | "deduplicate": True/False, 75 | "need_f0": True/False 76 | } 77 | } 78 | }, 79 | ``` 80 | 81 | #### tokenizer config 82 | 83 | The tokenizer uses textless-lib. It needs 5 arguments: 84 | 85 | - `dense_model_name`: model used for embedding 86 | - `quantizer_model_name` : quantizer used to convert embeddings into tokens 87 | - `encoder_vocab_size`: vocab size 88 | - `deduplicate`: deduplicate the tokens 89 | - `need_f0`: get f0 as well from the tokenizer 90 | 91 | For more information, look at [textless-lib](https://github.com/facebookresearch/textlesslib) 92 | 93 | 94 | All experiments used in the paper use the same tokenizer config: 95 | 96 | ``` 97 | "tokenizer": { 98 | "dense_model_name": "mhubert-base-25hz", 99 | "quantizer_model_name" : "kmeans", 100 | "encoder_vocab_size": 500, 101 | "deduplicate": true, 102 | "need_f0": false 103 | } 104 | ``` 105 | 106 | 107 | ### Span Selector Config 108 | 109 | This is a config used for the span selector in the pipeline. We use the name `sselector`. 110 | 111 | required keys/arguments: 112 | - `type`: type of span selector. Currently available: (i) constant, (ii) adaptive, (iii) threshold 113 | 114 | Each span selector has its own keys/arguments. 
I recommend looking at the file [span_selector.py](../spans_selector.py) for more info 115 | 116 | 117 | ### All configs files used in the paper are available and can be directly used using the file [audio_segment.py](../../audio_segment.py) 118 | 119 | 120 | -------------------------------------------------------------------------------- /speech_lm/configs/inference/TWIST-1-3B.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name" : "TWIST-1.3B", 3 | "model_type": "slm", 4 | "mean_nll": true, 5 | "tokenizer": { 6 | "dense_model_name": "mhubert-base-25hz", 7 | "quantizer_model_name" : "kmeans", 8 | "encoder_vocab_size": 500, 9 | "deduplicate": true, 10 | "need_f0": false 11 | } 12 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/diarization/emotion_diarization_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "type":"emotion_diarization", 3 | "source":"speechbrain/emotion-diarization-wavlm-large" 4 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/diarization/speaker_diarization_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "type":"speaker_diarization", 3 | "pipeline_name":"pyannote/speaker-diarization-3.0" 4 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/equal_length/equal_length_config_adaptive1.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type" : "length", 4 | "length" : 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "sselector":{ 9 | "type": "adaptive", 10 | "base_segments": 4, 11 | "len_offset": 20, 12 | "sentences_for_segment": 5, 13 | "descending": false 14 | }, 15 | "default_sample_rate": 16000, 16 | "type":"equal_length" 17 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/equal_length/equal_length_config_adaptive2.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type" : "length", 4 | "length" : 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "sselector":{ 9 | "type": "adaptive", 10 | "base_segments": 4, 11 | "len_offset": 20, 12 | "sentences_for_segment": 10, 13 | "descending": false 14 | }, 15 | "default_sample_rate": 16000, 16 | "type":"equal_length" 17 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/equal_length/equal_length_config_adaptive3.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type" : "length", 4 | "length" : 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "sselector":{ 9 | "type": "adaptive", 10 | "base_segments": 4, 11 | "len_offset": 20, 12 | "sentences_for_segment": 15, 13 | "descending": false 14 | }, 15 | "default_sample_rate": 16000, 16 | "type":"equal_length" 17 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/equal_length/equal_length_config_adaptive4.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type" : "length", 4 | "length" : 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | 
}, 8 | "sselector":{ 9 | "type": "adaptive", 10 | "base_segments": 4, 11 | "len_offset": 20, 12 | "sentences_for_segment": 20, 13 | "descending": false 14 | }, 15 | "default_sample_rate": 16000, 16 | "type":"equal_length" 17 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/equal_length/equal_length_config_constant_10.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type" : "length", 4 | "length" : 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "sselector":{ 9 | "type": "constant", 10 | "num_segments":10, 11 | "descending": false 12 | }, 13 | "default_sample_rate": 16000, 14 | "type":"equal_length" 15 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/equal_length/equal_length_config_constant_15.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type" : "length", 4 | "length" : 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "sselector":{ 9 | "type": "constant", 10 | "num_segments":15, 11 | "descending": false 12 | }, 13 | "default_sample_rate": 16000, 14 | "type":"equal_length" 15 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/equal_length/equal_length_config_constant_20.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type" : "length", 4 | "length" : 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "sselector":{ 9 | "type": "constant", 10 | "num_segments":20, 11 | "descending": false 12 | }, 13 | "default_sample_rate": 16000, 14 | "type":"equal_length" 15 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_adaptive/pmi_segmentor_config_adaptive1.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-1.3B", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "adaptive", 24 | "base_segments": 4, 25 | "len_offset": 20, 26 | "sentences_for_segment": 5, 27 | "descending": false 28 | }, 29 | "default_sample_rate": 16000, 30 | "type": "speech_pmi" 31 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_adaptive/pmi_segmentor_config_adaptive1_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "adaptive", 24 | "base_segments": 4, 25 | "len_offset": 20, 26 | 
"sentences_for_segment": 5, 27 | "descending": false 28 | }, 29 | "default_sample_rate": 16000, 30 | "type": "speech_pmi" 31 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_adaptive/pmi_segmentor_config_adaptive1_7B.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "batch_size": 16, 11 | "inference_model": { 12 | "model_name": "TWIST-7B", 13 | "model_type": "slm", 14 | "tokenizer": { 15 | "dense_model_name": "mhubert-base-25hz", 16 | "quantizer_model_name": "kmeans", 17 | "encoder_vocab_size": 500, 18 | "deduplicate": true, 19 | "need_f0": false 20 | } 21 | } 22 | }, 23 | "sselector": { 24 | "type": "adaptive", 25 | "base_segments": 4, 26 | "len_offset": 20, 27 | "sentences_for_segment": 5, 28 | "descending": false 29 | }, 30 | "default_sample_rate": 16000, 31 | "type": "speech_pmi" 32 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_adaptive/pmi_segmentor_config_adaptive2.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type" : "length", 4 | "length" : 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type" : "pmi", 10 | "inference_model": { 11 | "model_name" : "TWIST-350M", 12 | "model_type": "slm", 13 | "mean_nll": false, 14 | "tokenizer": { 15 | "dense_model_name": "mhubert-base-25hz", 16 | "quantizer_model_name" : "kmeans", 17 | "encoder_vocab_size": 500, 18 | "deduplicate": true, 19 | "need_f0": false 20 | } 21 | } 22 | }, 23 | "sselector":{ 24 | "type": "adaptive", 25 | "base_segments": 4, 26 | "len_offset": 20, 27 | "sentences_for_segment": 10, 28 | "descending": false 29 | }, 30 | "default_sample_rate": 16000, 31 | "type":"speech_pmi" 32 | } 33 | -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_adaptive/pmi_segmentor_config_adaptive2_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type" : "length", 4 | "length" : 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type" : "pmi", 10 | "inference_model": { 11 | "model_name" : "TWIST-350M", 12 | "model_type": "slm", 13 | "mean_nll": false, 14 | "tokenizer": { 15 | "dense_model_name": "mhubert-base-25hz", 16 | "quantizer_model_name" : "kmeans", 17 | "encoder_vocab_size": 500, 18 | "deduplicate": true, 19 | "need_f0": false 20 | } 21 | } 22 | }, 23 | "sselector":{ 24 | "type": "adaptive", 25 | "base_segments": 4, 26 | "len_offset": 20, 27 | "sentences_for_segment": 10, 28 | "descending": false 29 | }, 30 | "default_sample_rate": 16000, 31 | "type":"speech_pmi" 32 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_adaptive/pmi_segmentor_config_adaptive2_7B.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "batch_size": 16, 11 | "inference_model": { 12 | "model_name": "TWIST-7B", 13 | "model_type": "slm", 14 | "tokenizer": { 15 | "dense_model_name": "mhubert-base-25hz", 16 | 
"quantizer_model_name": "kmeans", 17 | "encoder_vocab_size": 500, 18 | "deduplicate": true, 19 | "need_f0": false 20 | } 21 | } 22 | }, 23 | "sselector": { 24 | "type": "adaptive", 25 | "base_segments": 4, 26 | "len_offset": 20, 27 | "sentences_for_segment": 10, 28 | "descending": false 29 | }, 30 | "default_sample_rate": 16000, 31 | "type": "speech_pmi" 32 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_adaptive/pmi_segmentor_config_adaptive3.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-1.3B", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "adaptive", 24 | "base_segments": 4, 25 | "len_offset": 20, 26 | "sentences_for_segment": 15, 27 | "descending": false 28 | }, 29 | "default_sample_rate": 16000, 30 | "type": "speech_pmi" 31 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_adaptive/pmi_segmentor_config_adaptive3_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "adaptive", 24 | "base_segments": 4, 25 | "len_offset": 20, 26 | "sentences_for_segment": 15, 27 | "descending": false 28 | }, 29 | "default_sample_rate": 16000, 30 | "type": "speech_pmi" 31 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_adaptive/pmi_segmentor_config_adaptive3_7B.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "batch_size": 16, 11 | "inference_model": { 12 | "model_name": "TWIST-7B", 13 | "model_type": "slm", 14 | "tokenizer": { 15 | "dense_model_name": "mhubert-base-25hz", 16 | "quantizer_model_name": "kmeans", 17 | "encoder_vocab_size": 500, 18 | "deduplicate": true, 19 | "need_f0": false 20 | } 21 | } 22 | }, 23 | "sselector": { 24 | "type": "adaptive", 25 | "base_segments": 4, 26 | "len_offset": 20, 27 | "sentences_for_segment": 15, 28 | "descending": false 29 | }, 30 | "default_sample_rate": 16000, 31 | "type": "speech_pmi" 32 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_adaptive/pmi_segmentor_config_adaptive4.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | 
"inference_model": { 11 | "model_name": "TWIST-1.3B", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "adaptive", 24 | "base_segments": 4, 25 | "len_offset": 20, 26 | "sentences_for_segment": 20, 27 | "descending": false 28 | }, 29 | "default_sample_rate": 16000, 30 | "type": "speech_pmi" 31 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_adaptive/pmi_segmentor_config_adaptive4_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "adaptive", 24 | "base_segments": 4, 25 | "len_offset": 20, 26 | "sentences_for_segment": 20, 27 | "descending": false 28 | }, 29 | "default_sample_rate": 16000, 30 | "type": "speech_pmi" 31 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_adaptive/pmi_segmentor_config_adaptive4_7B.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "batch_size": 16, 11 | "inference_model": { 12 | "model_name": "TWIST-7B", 13 | "model_type": "slm", 14 | "tokenizer": { 15 | "dense_model_name": "mhubert-base-25hz", 16 | "quantizer_model_name": "kmeans", 17 | "encoder_vocab_size": 500, 18 | "deduplicate": true, 19 | "need_f0": false 20 | } 21 | } 22 | }, 23 | "sselector": { 24 | "type": "adaptive", 25 | "base_segments": 4, 26 | "len_offset": 20, 27 | "sentences_for_segment": 20, 28 | "descending": false 29 | }, 30 | "default_sample_rate": 16000, 31 | "type": "speech_pmi" 32 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_adaptive_1-5s/pmi_segmentor_config_adaptive1_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 1.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "adaptive", 24 | "base_segments": 4, 25 | "len_offset": 20, 26 | "sentences_for_segment": 5, 27 | "descending": false 28 | }, 29 | "default_sample_rate": 16000, 30 | "type": "speech_pmi" 31 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_adaptive_1-5s/pmi_segmentor_config_adaptive2_350M.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 1.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "adaptive", 24 | "base_segments": 4, 25 | "len_offset": 20, 26 | "sentences_for_segment": 10, 27 | "descending": false 28 | }, 29 | "default_sample_rate": 16000, 30 | "type": "speech_pmi" 31 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_adaptive_1-5s/pmi_segmentor_config_adaptive3_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 1.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "adaptive", 24 | "base_segments": 4, 25 | "len_offset": 20, 26 | "sentences_for_segment": 15, 27 | "descending": false 28 | }, 29 | "default_sample_rate": 16000, 30 | "type": "speech_pmi" 31 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_adaptive_1-5s/pmi_segmentor_config_adaptive4_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 1.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "adaptive", 24 | "base_segments": 4, 25 | "len_offset": 20, 26 | "sentences_for_segment": 20, 27 | "descending": false 28 | }, 29 | "default_sample_rate": 16000, 30 | "type": "speech_pmi" 31 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_adaptive_1s/pmi_segmentor_config_adaptive1_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 1, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "adaptive", 24 | "base_segments": 4, 25 | "len_offset": 20, 26 | "sentences_for_segment": 5, 27 | "descending": false 28 | }, 29 | "default_sample_rate": 16000, 30 | "type": "speech_pmi" 31 | } 
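The segmentor configs in this listing all share the same three-part layout: a `sentencer` block (how the audio is first cut into short candidate chunks), a `scorer` block (the speech LM used to score adjacent chunks), and an `sselector` block (how boundaries are finally chosen). The README above says these files can be used directly with audio_segment.py; since that script's command line is not reproduced here, the sketch below is only a minimal, hypothetical illustration of reading such a config with the standard json module and inspecting its sections, not the repository's actual loader.

```python
# Minimal, illustrative sketch (NOT the repository's loader): read one of the
# segmentor configs listed in this repo and inspect its three sections.
import json


def load_segmentor_config(path: str) -> dict:
    """Read a segmentor config file and return it as a plain dict."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


if __name__ == "__main__":
    cfg = load_segmentor_config(
        "speech_lm/configs/segmentors/pmi_adaptive_1s/pmi_segmentor_config_adaptive1_350M.json"
    )
    print(cfg["type"])                                      # "speech_pmi"
    print(cfg["sentencer"]["length"])                       # chunk length in seconds (1 for the 1s variants)
    print(cfg["scorer"]["inference_model"]["model_name"])   # "TWIST-350M"
    print(cfg["sselector"]["type"])                         # "adaptive" here; "constant"/"threshold" elsewhere
```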
-------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_adaptive_1s/pmi_segmentor_config_adaptive2_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 1, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "adaptive", 24 | "base_segments": 4, 25 | "len_offset": 20, 26 | "sentences_for_segment": 10, 27 | "descending": false 28 | }, 29 | "default_sample_rate": 16000, 30 | "type": "speech_pmi" 31 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_adaptive_1s/pmi_segmentor_config_adaptive3_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 1, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "adaptive", 24 | "base_segments": 4, 25 | "len_offset": 20, 26 | "sentences_for_segment": 15, 27 | "descending": false 28 | }, 29 | "default_sample_rate": 16000, 30 | "type": "speech_pmi" 31 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_adaptive_1s/pmi_segmentor_config_adaptive4_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 1, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "adaptive", 24 | "base_segments": 4, 25 | "len_offset": 20, 26 | "sentences_for_segment": 20, 27 | "descending": false 28 | }, 29 | "default_sample_rate": 16000, 30 | "type": "speech_pmi" 31 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_adaptive_2s/pmi_segmentor_config_adaptive1_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 2, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "adaptive", 24 | "base_segments": 4, 25 | "len_offset": 
20, 26 | "sentences_for_segment": 5, 27 | "descending": false 28 | }, 29 | "default_sample_rate": 16000, 30 | "type": "speech_pmi" 31 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_adaptive_2s/pmi_segmentor_config_adaptive2_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 2, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "adaptive", 24 | "base_segments": 4, 25 | "len_offset": 20, 26 | "sentences_for_segment": 10, 27 | "descending": false 28 | }, 29 | "default_sample_rate": 16000, 30 | "type": "speech_pmi" 31 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_adaptive_2s/pmi_segmentor_config_adaptive3_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 2, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "adaptive", 24 | "base_segments": 4, 25 | "len_offset": 20, 26 | "sentences_for_segment": 15, 27 | "descending": false 28 | }, 29 | "default_sample_rate": 16000, 30 | "type": "speech_pmi" 31 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_adaptive_2s/pmi_segmentor_config_adaptive4_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 2, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "adaptive", 24 | "base_segments": 4, 25 | "len_offset": 20, 26 | "sentences_for_segment": 20, 27 | "descending": false 28 | }, 29 | "default_sample_rate": 16000, 30 | "type": "speech_pmi" 31 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_constant/pmi_segmentor_config_constant_10_1-3B.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-1.3B", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": 
true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "constant", 24 | "num_segments": 10, 25 | "descending": false 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": "speech_pmi" 29 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_constant/pmi_segmentor_config_constant_10_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "constant", 24 | "num_segments": 10, 25 | "descending": false 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": "speech_pmi" 29 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_constant/pmi_segmentor_config_constant_10_7B.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "batch_size": 16, 11 | "inference_model": { 12 | "model_name": "TWIST-7B", 13 | "model_type": "slm", 14 | "tokenizer": { 15 | "dense_model_name": "mhubert-base-25hz", 16 | "quantizer_model_name": "kmeans", 17 | "encoder_vocab_size": 500, 18 | "deduplicate": true, 19 | "need_f0": false 20 | } 21 | } 22 | }, 23 | "sselector": { 24 | "type": "constant", 25 | "num_segments": 10, 26 | "descending": false 27 | }, 28 | "default_sample_rate": 16000, 29 | "type": "speech_pmi" 30 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_constant/pmi_segmentor_config_constant_15_1-3B.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-1.3B", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "constant", 24 | "num_segments": 15, 25 | "descending": false 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": "speech_pmi" 29 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_constant/pmi_segmentor_config_constant_15_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": 
{ 23 | "type": "constant", 24 | "num_segments": 15, 25 | "descending": false 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": "speech_pmi" 29 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_constant/pmi_segmentor_config_constant_15_7B.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "batch_size": 16, 11 | "inference_model": { 12 | "model_name": "TWIST-7B", 13 | "model_type": "slm", 14 | "tokenizer": { 15 | "dense_model_name": "mhubert-base-25hz", 16 | "quantizer_model_name": "kmeans", 17 | "encoder_vocab_size": 500, 18 | "deduplicate": true, 19 | "need_f0": false 20 | } 21 | } 22 | }, 23 | "sselector": { 24 | "type": "constant", 25 | "num_segments": 15, 26 | "descending": false 27 | }, 28 | "default_sample_rate": 16000, 29 | "type": "speech_pmi" 30 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_constant/pmi_segmentor_config_constant_20_1-3B.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-1.3B", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "constant", 24 | "num_segments": 20, 25 | "descending": false 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": "speech_pmi" 29 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_constant/pmi_segmentor_config_constant_20_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "constant", 24 | "num_segments": 20, 25 | "descending": false 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": "speech_pmi" 29 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_constant/pmi_segmentor_config_constant_20_7B.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "batch_size": 16, 11 | "inference_model": { 12 | "model_name": "TWIST-7B", 13 | "model_type": "slm", 14 | "tokenizer": { 15 | "dense_model_name": "mhubert-base-25hz", 16 | "quantizer_model_name": "kmeans", 17 | "encoder_vocab_size": 500, 18 | "deduplicate": true, 19 | "need_f0": false 20 | } 21 | } 22 | }, 23 | "sselector": { 24 | "type": "constant", 25 | "num_segments": 
20, 26 | "descending": false 27 | }, 28 | "default_sample_rate": 16000, 29 | "type": "speech_pmi" 30 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_constant_1-5s/pmi_segmentor_config_constant_10_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 1.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "constant", 24 | "num_segments": 10, 25 | "descending": false 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": "speech_pmi" 29 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_constant_1-5s/pmi_segmentor_config_constant_15_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 1.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "constant", 24 | "num_segments": 15, 25 | "descending": false 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": "speech_pmi" 29 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_constant_1-5s/pmi_segmentor_config_constant_20_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 1.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "constant", 24 | "num_segments": 20, 25 | "descending": false 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": "speech_pmi" 29 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_constant_1s/pmi_segmentor_config_constant_10_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type" : "length", 4 | "length" : 1, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type" : "speech_pmi", 10 | "model_type" : "TWIST-350M", 11 | "tokenizer": { 12 | "dense_model_name": "mhubert-base-25hz", 13 | "quantizer_model_name" : "kmeans", 14 | "encoder_vocab_size": 500, 15 | "deduplicate": true, 16 | "need_f0": false 17 | } 18 | }, 19 | "sselector":{ 20 | "type": "constant", 21 | "num_segments":10, 22 | "descending": false 23 | }, 24 | "default_sample_rate": 16000, 25 | "type":"speech_pmi" 26 | } 
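For background on the `pmi` scorers used throughout these configs: pointwise mutual information between two adjacent chunks x and y can be written as log p(y | x) - log p(y), i.e. the gain in log-likelihood of y once x is given as context. The snippet below is a generic, hedged illustration of computing that quantity from a language model's negative log-likelihoods; it is not taken from this repository's scorer, and the `nll` callable is a placeholder for whatever NLL interface the speech LM exposes.

```python
# Generic background sketch only -- not the repository's pmi scorer.
# `nll(target, context)` is a hypothetical callable returning -log p(target | context).
from typing import Callable, Sequence


def pmi_score(
    nll: Callable[[Sequence[int], Sequence[int]], float],
    x_tokens: Sequence[int],
    y_tokens: Sequence[int],
) -> float:
    """PMI(x, y) = log p(y | x) - log p(y) = NLL(y) - NLL(y | x)."""
    nll_y = nll(y_tokens, [])                 # -log p(y), no context
    nll_y_given_x = nll(y_tokens, x_tokens)   # -log p(y | x)
    return nll_y - nll_y_given_x

# A strongly negative PMI means y is poorly predicted by x, which is the usual
# cue for a boundary in PMI-based segmentation.
```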
-------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_constant_1s/pmi_segmentor_config_constant_15_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type" : "length", 4 | "length" : 1, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type" : "speech_pmi", 10 | "model_type" : "TWIST-350M", 11 | "tokenizer": { 12 | "dense_model_name": "mhubert-base-25hz", 13 | "quantizer_model_name" : "kmeans", 14 | "encoder_vocab_size": 500, 15 | "deduplicate": true, 16 | "need_f0": false 17 | } 18 | }, 19 | "sselector":{ 20 | "type": "constant", 21 | "num_segments":15, 22 | "descending": false 23 | }, 24 | "default_sample_rate": 16000, 25 | "type":"speech_pmi" 26 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_constant_1s/pmi_segmentor_config_constant_20_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type" : "length", 4 | "length" : 1, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type" : "speech_pmi", 10 | "model_type" : "TWIST-350M", 11 | "tokenizer": { 12 | "dense_model_name": "mhubert-base-25hz", 13 | "quantizer_model_name" : "kmeans", 14 | "encoder_vocab_size": 500, 15 | "deduplicate": true, 16 | "need_f0": false 17 | } 18 | }, 19 | "sselector":{ 20 | "type": "constant", 21 | "num_segments":20, 22 | "descending": false 23 | }, 24 | "default_sample_rate": 16000, 25 | "type":"speech_pmi" 26 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_constant_2s/pmi_segmentor_config_constant_10_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 2, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "constant", 24 | "num_segments": 10, 25 | "descending": false 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": "speech_pmi" 29 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_constant_2s/pmi_segmentor_config_constant_15_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 2, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "constant", 24 | "num_segments": 15, 25 | "descending": false 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": "speech_pmi" 29 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_constant_2s/pmi_segmentor_config_constant_20_350M.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 2, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "constant", 24 | "num_segments": 20, 25 | "descending": false 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": "speech_pmi" 29 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_threshold/pmi_segmentor_config_threshold_0_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type" : "length", 4 | "length" : 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type" : "pmi", 10 | "inference_model": { 11 | "model_name" : "TWIST-350M", 12 | "model_type": "slm", 13 | "mean_nll": false, 14 | "tokenizer": { 15 | "dense_model_name": "mhubert-base-25hz", 16 | "quantizer_model_name" : "kmeans", 17 | "encoder_vocab_size": 500, 18 | "deduplicate": true, 19 | "need_f0": false 20 | } 21 | } 22 | }, 23 | "sselector":{ 24 | "type": "threshold", 25 | "threshold": 0, 26 | "larger_than": false 27 | }, 28 | "default_sample_rate": 16000, 29 | "type":"speech_pmi" 30 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_threshold/pmi_segmentor_config_threshold_10.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-1.3B", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "threshold", 24 | "threshold": -10, 25 | "larger_than": false 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": "speech_pmi" 29 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_threshold/pmi_segmentor_config_threshold_10_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type" : "length", 4 | "length" : 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type" : "pmi", 10 | "inference_model": { 11 | "model_name" : "TWIST-350M", 12 | "model_type": "slm", 13 | "mean_nll": false, 14 | "tokenizer": { 15 | "dense_model_name": "mhubert-base-25hz", 16 | "quantizer_model_name" : "kmeans", 17 | "encoder_vocab_size": 500, 18 | "deduplicate": true, 19 | "need_f0": false 20 | } 21 | } 22 | }, 23 | "sselector":{ 24 | "type": "threshold", 25 | "threshold": -10, 26 | "larger_than": false 27 | }, 28 | "default_sample_rate": 16000, 29 | "type":"speech_pmi" 30 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_threshold/pmi_segmentor_config_threshold_10_7B.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "batch_size": 16, 11 | "inference_model": { 12 | "model_name": "TWIST-7B", 13 | "model_type": "slm", 14 | "tokenizer": { 15 | "dense_model_name": "mhubert-base-25hz", 16 | "quantizer_model_name": "kmeans", 17 | "encoder_vocab_size": 500, 18 | "deduplicate": true, 19 | "need_f0": false 20 | } 21 | } 22 | }, 23 | "sselector": { 24 | "type": "threshold", 25 | "threshold": -10, 26 | "larger_than": false 27 | }, 28 | "default_sample_rate": 16000, 29 | "type": "speech_pmi" 30 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_threshold/pmi_segmentor_config_threshold_12-5.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-1.3B", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "threshold", 24 | "threshold": -12.5, 25 | "larger_than": false 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": "speech_pmi" 29 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_threshold/pmi_segmentor_config_threshold_12-5_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "threshold", 24 | "threshold": -12.5, 25 | "larger_than": false 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": "speech_pmi" 29 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_threshold/pmi_segmentor_config_threshold_12-5_7B.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "batch_size": 16, 11 | "inference_model": { 12 | "model_name": "TWIST-7B", 13 | "model_type": "slm", 14 | "tokenizer": { 15 | "dense_model_name": "mhubert-base-25hz", 16 | "quantizer_model_name": "kmeans", 17 | "encoder_vocab_size": 500, 18 | "deduplicate": true, 19 | "need_f0": false 20 | } 21 | } 22 | }, 23 | "sselector": { 24 | "type": "threshold", 25 | "threshold": -12.5, 26 | "larger_than": false 27 | }, 28 | "default_sample_rate": 16000, 29 | "type": "speech_pmi" 30 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_threshold/pmi_segmentor_config_threshold_12.5.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-1.3B", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "threshold", 24 | "threshold": -12.5, 25 | "larger_than": false 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": "speech_pmi" 29 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_threshold/pmi_segmentor_config_threshold_15.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-1.3B", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "threshold", 24 | "threshold": -15, 25 | "larger_than": false 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": "speech_pmi" 29 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_threshold/pmi_segmentor_config_threshold_15_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "threshold", 24 | "threshold": -15, 25 | "larger_than": false 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": "speech_pmi" 29 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_threshold/pmi_segmentor_config_threshold_15_7B.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "batch_size": 16, 11 | "inference_model": { 12 | "model_name": "TWIST-7B", 13 | "model_type": "slm", 14 | "tokenizer": { 15 | "dense_model_name": "mhubert-base-25hz", 16 | "quantizer_model_name": "kmeans", 17 | "encoder_vocab_size": 500, 18 | "deduplicate": true, 19 | "need_f0": false 20 | } 21 | } 22 | }, 23 | "sselector": { 24 | "type": "threshold", 25 | "threshold": -15, 26 | "larger_than": false 27 | }, 28 | "default_sample_rate": 16000, 29 | "type": "speech_pmi" 30 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_threshold/pmi_segmentor_config_threshold_5.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-1.3B", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "threshold", 24 | "threshold": -5, 25 | "larger_than": false 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": "speech_pmi" 29 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_threshold/pmi_segmentor_config_threshold_5_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type" : "length", 4 | "length" : 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type" : "pmi", 10 | "inference_model": { 11 | "model_name" : "TWIST-350M", 12 | "model_type": "slm", 13 | "mean_nll": false, 14 | "tokenizer": { 15 | "dense_model_name": "mhubert-base-25hz", 16 | "quantizer_model_name" : "kmeans", 17 | "encoder_vocab_size": 500, 18 | "deduplicate": true, 19 | "need_f0": false 20 | } 21 | } 22 | }, 23 | "sselector":{ 24 | "type": "threshold", 25 | "threshold": -5, 26 | "larger_than": false 27 | }, 28 | "default_sample_rate": 16000, 29 | "type":"speech_pmi" 30 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_threshold/pmi_segmentor_config_threshold_5_7B.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "batch_size": 16, 11 | "inference_model": { 12 | "model_name": "TWIST-7B", 13 | "model_type": "slm", 14 | "tokenizer": { 15 | "dense_model_name": "mhubert-base-25hz", 16 | "quantizer_model_name": "kmeans", 17 | "encoder_vocab_size": 500, 18 | "deduplicate": true, 19 | "need_f0": false 20 | } 21 | } 22 | }, 23 | "sselector": { 24 | "type": "threshold", 25 | "threshold": -5, 26 | "larger_than": false 27 | }, 28 | "default_sample_rate": 16000, 29 | "type": "speech_pmi" 30 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_threshold/pmi_segmentor_config_threshold_8.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-1.3B", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "threshold", 24 | "threshold": -8, 25 | "larger_than": false 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": "speech_pmi" 29 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_threshold/pmi_segmentor_config_threshold_8_350M.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "threshold", 24 | "threshold": -8, 25 | "larger_than": false 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": "speech_pmi" 29 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_threshold/pmi_segmentor_config_threshold_8_7B.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 0.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "batch_size": 16, 11 | "inference_model": { 12 | "model_name": "TWIST-7B", 13 | "model_type": "slm", 14 | "tokenizer": { 15 | "dense_model_name": "mhubert-base-25hz", 16 | "quantizer_model_name": "kmeans", 17 | "encoder_vocab_size": 500, 18 | "deduplicate": true, 19 | "need_f0": false 20 | } 21 | } 22 | }, 23 | "sselector": { 24 | "type": "threshold", 25 | "threshold": -8, 26 | "larger_than": false 27 | }, 28 | "default_sample_rate": 16000, 29 | "type": "speech_pmi" 30 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_threshold_1-5s/pmi_segmentor_config_threshold_10_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 1.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "threshold", 24 | "threshold": -10, 25 | "larger_than": false 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": "speech_pmi" 29 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_threshold_1-5s/pmi_segmentor_config_threshold_12-5_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 1.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "threshold", 24 | "threshold": -12.5, 25 | "larger_than": false 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": "speech_pmi" 29 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_threshold_1-5s/pmi_segmentor_config_threshold_15_350M.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 1.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "threshold", 24 | "threshold": -15, 25 | "larger_than": false 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": "speech_pmi" 29 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_threshold_1-5s/pmi_segmentor_config_threshold_5_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 1.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "threshold", 24 | "threshold": -5, 25 | "larger_than": false 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": "speech_pmi" 29 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_threshold_1-5s/pmi_segmentor_config_threshold_8_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 1.5, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "threshold", 24 | "threshold": -8, 25 | "larger_than": false 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": "speech_pmi" 29 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_threshold_1s/pmi_segmentor_config_threshold_10_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 1, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "threshold", 24 | "threshold": -10, 25 | "larger_than": false 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": "speech_pmi" 29 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_threshold_1s/pmi_segmentor_config_threshold_12-5_350M.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 1, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "threshold", 24 | "threshold": -12.5, 25 | "larger_than": false 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": "speech_pmi" 29 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_threshold_1s/pmi_segmentor_config_threshold_15_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 1, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "threshold", 24 | "threshold": -15, 25 | "larger_than": false 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": "speech_pmi" 29 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_threshold_1s/pmi_segmentor_config_threshold_5_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 1, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "threshold", 24 | "threshold": -5, 25 | "larger_than": false 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": "speech_pmi" 29 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_threshold_1s/pmi_segmentor_config_threshold_8_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 1, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "threshold", 24 | "threshold": -8, 25 | "larger_than": false 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": "speech_pmi" 29 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_threshold_2s/pmi_segmentor_config_threshold_10_350M.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 2, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "threshold", 24 | "threshold": -10, 25 | "larger_than": false 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": "speech_pmi" 29 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_threshold_2s/pmi_segmentor_config_threshold_12-5_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 2, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "threshold", 24 | "threshold": -12.5, 25 | "larger_than": false 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": "speech_pmi" 29 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_threshold_2s/pmi_segmentor_config_threshold_15_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 2, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "threshold", 24 | "threshold": -15, 25 | "larger_than": false 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": "speech_pmi" 29 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_threshold_2s/pmi_segmentor_config_threshold_5_350M.json: -------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 2, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "threshold", 24 | "threshold": -5, 25 | "larger_than": false 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": "speech_pmi" 29 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/pmi_threshold_2s/pmi_segmentor_config_threshold_8_350M.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "sentencer": { 3 | "type": "length", 4 | "length": 2, 5 | "min_length": 0.25, 6 | "drop_last": false 7 | }, 8 | "scorer": { 9 | "type": "pmi", 10 | "inference_model": { 11 | "model_name": "TWIST-350M", 12 | "model_type": "slm", 13 | "tokenizer": { 14 | "dense_model_name": "mhubert-base-25hz", 15 | "quantizer_model_name": "kmeans", 16 | "encoder_vocab_size": 500, 17 | "deduplicate": true, 18 | "need_f0": false 19 | } 20 | } 21 | }, 22 | "sselector": { 23 | "type": "threshold", 24 | "threshold": -8, 25 | "larger_than": false 26 | }, 27 | "default_sample_rate": 16000, 28 | "type": "speech_pmi" 29 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/unsupseg/unsupseg_config_constant_10.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_segments":10, 3 | "default_sample_rate": 16000, 4 | "type":"next_frame", 5 | "model_ckpt": "/cs/labs/oabend/avishai.elma/src/speech_lm/UnsupSeg/pretrained_models/buckeye+_pretrained.ckpt" 6 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/unsupseg/unsupseg_config_constant_15.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_segments":15, 3 | "default_sample_rate": 16000, 4 | "type":"next_frame", 5 | "model_ckpt": "/cs/labs/oabend/avishai.elma/src/speech_lm/UnsupSeg/pretrained_models/buckeye+_pretrained.ckpt" 6 | } -------------------------------------------------------------------------------- /speech_lm/configs/segmentors/unsupseg/unsupseg_config_constant_20.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_segments":20, 3 | "default_sample_rate": 16000, 4 | "type":"next_frame", 5 | "model_ckpt": "/cs/labs/oabend/avishai.elma/src/speech_lm/UnsupSeg/pretrained_models/buckeye+_pretrained.ckpt" 6 | } -------------------------------------------------------------------------------- /speech_lm/configs/tokenizers/tokenizer_500_vocab.json: -------------------------------------------------------------------------------- 1 | { 2 | "dense_model_name": "mhubert-base-25hz", 3 | "quantizer_model_name" : "kmeans", 4 | "encoder_vocab_size": 500, 5 | "deduplicate": true, 6 | "need_f0": false 7 | } -------------------------------------------------------------------------------- /speech_lm/evaluation/segmentation.py: -------------------------------------------------------------------------------- 1 | from typing import Any,Mapping,List 2 | import torch 3 | from pyannote.metrics.segmentation import (SegmentationPurityCoverageFMeasure, 4 | SegmentationCoverage, 5 | SegmentationPurity,SegmentationRecall,SegmentationPrecision) 6 | from pyannote.metrics.base import f_measure 7 | from pyannote.core import Segment, Timeline 8 | from math import sqrt 9 | import scipy.stats 10 | 11 | COVERAGE = "coverage" 12 | PURITY = "purity" 13 | CPFMEASURE = "cpfmeasure" 14 | RECALL = "recall" 15 | PRECISION = "precision" 16 | RPFMEASURE = "rpfmeasure" 17 | R_VALUE = "r_value" 18 | 19 | AVAILABLE_METRICS = [COVERAGE,PURITY, CPFMEASURE, RECALL, PRECISION,RPFMEASURE, R_VALUE] 20 | 21 | EPS = 1e-10 22 | 23 | TOL = 0.5 24 | 25 | def r_value(precision,recall): 26 | """ 27 | calc rvalue from precision and recall 28 | """ 29 | os = recall / (precision + EPS) - 1 30 | r1 = sqrt((1 - recall) ** 2 + os ** 2) 31 | r2 = (-os + recall - 1) / 
(sqrt(2)) 32 | rval = 1 - (abs(r1) + abs(r2)) / 2 33 | return rval 34 | 35 | def get_metric(name: str): 36 | """ 37 | get metric by name 38 | """ 39 | if name == COVERAGE: 40 | return SegmentationCoverage() 41 | elif name == PURITY: 42 | return SegmentationPurity() 43 | elif name == CPFMEASURE: 44 | return SegmentationPurityCoverageFMeasure() 45 | elif name == RECALL: 46 | return SegmentationRecall(tolerance=TOL) 47 | elif name == PRECISION: 48 | return SegmentationPrecision(tolerance=TOL) 49 | elif name == RPFMEASURE: 50 | return SegmentationPrecisionRecallFMeasure(tolerance=TOL) 51 | elif name == R_VALUE: 52 | return SegmentationPrecisionRecallRValue(tolerance=TOL) 53 | else: 54 | raise ValueError(f"Metric {name} not available") 55 | 56 | 57 | def to_timeline(segments: List[Mapping[str,float]]): 58 | """ 59 | convert list of segments to timeline 60 | """ 61 | return Timeline([Segment(seg["start"], seg["end"]) for seg in segments]) 62 | 63 | def to_annotation(segments: List[Mapping[str,float]]): 64 | """ 65 | convert list of segments to annotation 66 | """ 67 | return to_timeline(segments).to_annotation() 68 | 69 | 70 | class SegmentationPrecisionRecallRValue: 71 | 72 | """ 73 | Segmentation precision recall r value 74 | """ 75 | 76 | def __init__(self, tolerance: float = 0.5) -> None: 77 | self.tolerance = tolerance 78 | self.precision = SegmentationPrecision(tolerance=self.tolerance) 79 | self.recall = SegmentationRecall(tolerance=self.tolerance) 80 | self.results_ = [] 81 | 82 | def __call__(self, reference,hypothesis,detailed=False) -> Any: 83 | """ 84 | get precision recall r value for single sample 85 | """ 86 | precision = self.precision(reference,hypothesis) 87 | recall = self.recall(reference,hypothesis) 88 | rval = r_value(precision,recall) 89 | self.results_.append(rval) 90 | if detailed: 91 | return {"precision": precision, "recall": recall, "rval": rval} 92 | return rval 93 | 94 | def __abs__(self): 95 | presicion = abs(self.precision) 96 | recall = abs(self.recall) 97 | return r_value(presicion,recall) 98 | 99 | def confidence_interval(self, alpha=0.9): 100 | """ 101 | get confidence interval for segmentation 102 | """ 103 | return scipy.stats.bayes_mvs(self.results_, alpha=alpha)[0] 104 | 105 | 106 | class SegmentationPrecisionRecallFMeasure: 107 | 108 | """ 109 | Segmentation precision recall f measure 110 | """ 111 | 112 | def __init__(self, tolerance: float = 0.5, beta=1) -> None: 113 | self.tolerance = tolerance 114 | self.beta = beta 115 | self.precision = SegmentationPrecision(tolerance=self.tolerance) 116 | self.recall = SegmentationRecall(tolerance=self.tolerance) 117 | self.results_ = [] 118 | 119 | def __call__(self, reference,hypothesis,detailed=False) -> Any: 120 | """ 121 | get precision recall f measure for single sample 122 | """ 123 | results = {} 124 | results["precision"] = self.precision(reference,hypothesis,detailed=detailed) 125 | results["recall"] = self.recall(reference,hypothesis,detailed=detailed) 126 | results["f_measure"] = f_measure(results["precision"],results["recall"],beta=self.beta) 127 | self.results_.append(results["f_measure"]) 128 | if detailed: 129 | return results 130 | return results["f_measure"] 131 | 132 | def __abs__(self): 133 | presicion = abs(self.precision) 134 | recall = abs(self.recall) 135 | return f_measure(presicion,recall,beta=self.beta) 136 | 137 | def confidence_interval(self, alpha=0.9): 138 | """ 139 | get confidence interval for segmentation 140 | """ 141 | return scipy.stats.bayes_mvs(self.results_, alpha=alpha)[0] 
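The r_value() helper above folds precision and recall into the R-value commonly reported for boundary detection: it measures how far the segmenter sits from the ideal operating point (recall = 1, over-segmentation = 0). A self-contained sanity check of the same formula, with made-up precision/recall numbers (illustrative only, not a file from this repository):

from math import sqrt

EPS = 1e-10

def r_value(precision, recall):
    # mirrors speech_lm/evaluation/segmentation.py
    os = recall / (precision + EPS) - 1
    r1 = sqrt((1 - recall) ** 2 + os ** 2)
    r2 = (-os + recall - 1) / sqrt(2)
    return 1 - (abs(r1) + abs(r2)) / 2

print(round(r_value(1.0, 1.0), 3))    # 1.0   -> perfect boundaries
print(round(r_value(0.8, 0.72), 3))   # 0.788 -> good but imperfect segmentation
print(round(r_value(0.3, 0.95), 3))   # negative -> heavy over-segmentation is penalised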
142 | 
143 | 
144 | class SegmentationMetrics:
145 | 
146 |     """
147 |     Segmentation metrics: a wrapper class that handles all requested metrics at once
148 |     """
149 | 
150 |     def __init__(self, metrics) -> None:
151 |         if not all(metric in AVAILABLE_METRICS for metric in metrics):
152 |             raise ValueError(f"Metrics should be one of {AVAILABLE_METRICS}")
153 |         if len(metrics) == 0:
154 |             raise ValueError("At least one metric should be provided")
155 |         if (COVERAGE in metrics or PURITY in metrics) and CPFMEASURE in metrics:
156 |             import warnings; warnings.warn("using CPFMEASURE with COVERAGE or PURITY is not recommended")
157 |         self.metrics = {metric: get_metric(metric) for metric in metrics}
158 | 
159 | 
160 |     def __call__(self, reference, hypothesis,detailed=False) -> Any:
161 |         """
162 |         get results for all configured metrics
163 |         """
164 |         results = {}
165 |         for metric in self.metrics:
166 |             if not detailed:
167 |                 results[metric] = self.metrics[metric](reference,hypothesis)
168 |             else:
169 |                 results.update(self.metrics[metric](reference,hypothesis,detailed=detailed))
170 |         return results
171 | 
172 |     def __abs__(self):
173 |         results = {}
174 |         for metric in self.metrics:
175 |             results[metric] = abs(self.metrics[metric])
176 |         return results
177 | 
178 |     def confidence_interval(self, alpha=0.9):
179 |         """
180 |         get confidence interval for all metrics
181 |         """
182 |         results = {}
183 |         for metric in self.metrics:
184 |             results[metric] = self.metrics[metric].confidence_interval(alpha=alpha)
185 |         return results
186 | 
187 | 
188 | 
--------------------------------------------------------------------------------
/speech_lm/inference.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | import torch
3 | from typing import List, Mapping
4 | from .tokenizers import SpeechTokenizer
5 | from .utils import build_speech_lm,nll
6 | from torch.nn.utils.rnn import pad_sequence
7 | 
8 | class InferenceModel(ABC):
9 | 
10 |     @abstractmethod
11 |     def log_likelihood(self, wavs: List[torch.Tensor]) -> torch.Tensor:
12 |         ...
13 | 
14 |     @abstractmethod
15 |     def to(self, device):
16 |         ...
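Before the inference code continues below, it helps to see how the segmentor configs shown earlier ("sentencer" + "scorer" + "sselector" with "type": "speech_pmi") are meant to be consumed. The segmentor class that does this wiring is not included in this section, so the following sketch assembles the pieces by hand through the factories defined further down in speech_sentencer.py, scorers.py and spans_selector.py. The audio path example.wav, the chosen config file and the ./models checkpoint directory are placeholders, and the repository root is assumed to be on PYTHONPATH with torch, torchaudio, transformers and textless installed.

import json
import torch

from speech_lm.speech_sentencer import SpeechSentencerFactory
from speech_lm.scorers import ScorerFactory
from speech_lm.spans_selector import SpansSelectorFactory

device = "cuda" if torch.cuda.is_available() else "cpu"

with open("speech_lm/configs/segmentors/pmi_threshold_1s/pmi_segmentor_config_threshold_8_350M.json") as f:
    cfg = json.load(f)

sentencer = SpeechSentencerFactory.get_sentencer(cfg["sentencer"])      # 1 s windows
scorer = ScorerFactory.get_scorer(cfg["scorer"], base_path="./models")  # TWIST-350M PMI scorer
scorer.to(device)                                                       # note: to() does not return self
selector = SpansSelectorFactory.get_span_selector(cfg["sselector"])     # cut where PMI < -8

# "example.wav" is a placeholder; the sentencer resamples to the config's default_sample_rate.
sentences, sr = sentencer.sentence_path("example.wav", resample_rate=cfg["default_sample_rate"], device=device)
chunks = [s["audio"] for s in sentences]            # score_consecutive needs at least two windows

scores = scorer.score_consecutive(chunks)           # PMI of every adjacent pair, length len(chunks) - 1
spans = selector.get_spans(chunks, scores)          # boundary indices, e.g. [0, 4, 9, len(chunks)]

segments = [{"start": sentences[a]["start"], "end": sentences[b - 1]["end"]}
            for a, b in zip(spans[:-1], spans[1:])]
print(segments)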
17 | 
18 | 
19 | class InferenceModelFactory:
20 | 
21 |     @staticmethod
22 |     def get_model(config: Mapping,base_path="./") -> InferenceModel:
23 |         if config["model_type"] == "slm":
24 |             return SLMInferenceModel(config,base_path=base_path)
25 | 
26 |         raise ValueError(f"Model type {config['model_type']} not supported")
27 | 
28 | 
29 | class SLMInferenceModel(InferenceModel):
30 | 
31 |     def __init__(self, config,base_path="./"):
32 |         tokenizer_config = config['tokenizer']
33 |         self.tokenizer = SpeechTokenizer(tokenizer_config)
34 |         self.speech_lm = build_speech_lm(config["model_name"], base_path=base_path)
35 |         self.mean_nll = config.get("mean_nll",False)
36 |         self.offset = self.speech_lm.config.offset
37 |         self.padding_value = self.speech_lm.config.pad_token_id
38 | 
39 |     def log_likelihood(self, wavs: List[torch.Tensor]) -> torch.Tensor:
40 |         sentence_tokens = self.tokenizer(wavs,self.offset)
41 |         x = pad_sequence(sentence_tokens,batch_first=True,padding_value=self.padding_value)
42 |         logits = self.speech_lm(input_ids=x).logits
43 |         shifted_x = x[..., 1:]
44 |         shifted_logits = logits[..., :-1, :]
45 | 
46 |         # Create a mask that is True where the tokens are not padding tokens
47 |         mask = (shifted_x != self.padding_value)
48 | 
49 |         # Negate the masked NLL to obtain the log-likelihood of each sequence
50 | 
51 |         return -nll(shifted_logits, shifted_x, mask,self.mean_nll)
52 | 
53 | 
54 |     def to(self, device):
55 |         self.tokenizer.to(device)
56 |         self.speech_lm.to(device)
57 |         return self
58 | 
--------------------------------------------------------------------------------
/speech_lm/scorers.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | import torch
3 | from typing import List, Mapping
4 | from torch import FloatTensor
5 | from .inference import InferenceModelFactory
6 | 
7 | class Scorer(ABC, torch.nn.Module):
8 | 
9 |     @abstractmethod
10 |     def score_consecutive(self, x: List[FloatTensor]) -> FloatTensor:
11 |         """
12 |         gives a score to a list of consecutive sentences
13 |         returns a tensor of size len(x)-1 with the score of each consecutive pair
14 |         :param x: list of consecutive sentences
15 |         :return: score
16 |         """
17 |         pass
18 | 
19 |     @abstractmethod
20 |     def to(self, device):
21 |         """
22 |         moves scorer to device
23 |         :param device: device
24 |         """
25 |         pass
26 | 
27 | 
28 | class ScorerFactory:
29 | 
30 |     @staticmethod
31 |     def get_scorer(scorer_config:Mapping,base_path="../models") -> Scorer:
32 |         """
33 |         factory method to get a scorer from a config
34 |         :param scorer_config: scorer config
35 |         :return: scorer
36 |         """
37 |         scorer_type = scorer_config['type']
38 |         if scorer_type == 'pmi':
39 |             return SpeechPMIscorer(scorer_config,base_path=base_path)
40 |         else:
41 |             raise NotImplementedError(f'No such scorer type: {scorer_type}')
42 | 
43 | 
44 | class SpeechPMIscorer(Scorer):
45 | 
46 |     def __init__(self, config,base_path="../models"):
47 |         super().__init__()
48 |         inference_config = config['inference_model']
49 |         self.inference_model = InferenceModelFactory.get_model(inference_config,base_path=base_path)
50 |         self.batch_size = config.get('batch_size',-1)
51 |         self.device = "cpu"
52 | 
53 | 
54 |     def _batch_log_likelihood(self, sentences: List[torch.Tensor]) -> FloatTensor:
55 |         """
56 |         return likelihood for a batch of audio sentences
57 |         :param sentences: batch of sentences
58 |         :return: likelihoods
59 |         """
60 |         likelihoods = []
61 |         for i in range(0, len(sentences), self.batch_size):  # score in chunks of batch_size
62 |             likelihoods.append(self.inference_model.log_likelihood(sentences[i:i+self.batch_size]))
63 |         return
torch.cat(likelihoods) 64 | 65 | @torch.no_grad() 66 | def score_consecutive(self, sentences: List[torch.Tensor]) -> FloatTensor: 67 | """ 68 | gives a score to a list of consecutive sentences 69 | returns a tensor of size len(sentences)-1 with the score of each consecutive pair 70 | sentences: List of sentences 71 | return: score 72 | """ 73 | assert len(sentences) > 1, "need at least two sentences to score" 74 | concat = [torch.cat((x, y), dim=-1) for x, y in zip(sentences[:-1], sentences[1:])] 75 | if self.batch_size > 0: 76 | sentences_log_likelihoods = self._batch_log_likelihood(sentences) 77 | concat_log_likelihoods = self._batch_log_likelihood(concat) 78 | else: 79 | sentences_log_likelihoods = self.inference_model.log_likelihood(sentences) 80 | concat_log_likelihoods = self.inference_model.log_likelihood(concat) 81 | log_numerator = concat_log_likelihoods 82 | log_denominator_x1 = sentences_log_likelihoods[:-1] 83 | log_denominator_x2 = sentences_log_likelihoods[1:] 84 | return log_numerator - (log_denominator_x1 + log_denominator_x2) 85 | 86 | def to(self, device): 87 | """ 88 | moves scorer to device 89 | """ 90 | self.inference_model.to(device) 91 | self.device = device -------------------------------------------------------------------------------- /speech_lm/spans_selector.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from torch import FloatTensor 3 | from typing import List,Mapping 4 | import torch 5 | 6 | 7 | class SpansSelector(ABC): 8 | """ 9 | SpansSelector is an abstract class that gets the sentences and scores and decides the spans for the segmentations 10 | """ 11 | 12 | @abstractmethod 13 | def decide(self,sentences: List[FloatTensor], scores: FloatTensor) -> int: 14 | """ 15 | decides the number of segments to use 16 | :param sentences: list of sentences 17 | :param scores: scores neighbor sentences 18 | :return: number of segments 19 | """ 20 | pass 21 | 22 | @abstractmethod 23 | def get_spans(self,sentences: List[FloatTensor], scores: FloatTensor) -> List[int]: 24 | 25 | """ 26 | gets the spans for the segmentations 27 | :param sentences: list of sentences 28 | :param scores: scores neighbor sentences 29 | :return: list of spans 30 | """ 31 | pass 32 | 33 | class SpansSelectorFactory: 34 | 35 | @classmethod 36 | def get_span_selector(cls, sselector_config:Mapping)->SpansSelector: 37 | """ 38 | factrory method to get a decider from a config 39 | """ 40 | sselector_type = sselector_config['type'] 41 | if sselector_type == 'constant': 42 | return ConstantSpansSelector(sselector_config) 43 | elif sselector_type == 'adaptive': 44 | return AdaptiveSpansSelector(sselector_config) 45 | elif sselector_type == 'threshold': 46 | return ThresholdSpansSelector(sselector_config) 47 | else: 48 | raise NotImplementedError(f'No such decider type: {sselector_type}') 49 | 50 | 51 | class ConstantSpansSelector(SpansSelector): 52 | """ 53 | ConstantSpansSelector is a SpansSelector that always returns the same number of segments. 54 | """ 55 | 56 | def __init__(self, config: Mapping): 57 | """ 58 | :param config: config for the decider. needs to contain a num_segments int. 
59 |         """
60 |         self.num_segments = config["num_segments"]
61 |         self.descending = config.get("descending", False)
62 | 
63 |     def decide(self, _: List[FloatTensor], __: FloatTensor) -> int:
64 |         """
65 |         decides the number of segments to use
66 |         :param sentences: list of sentences
67 |         :return: number of segments
68 |         """
69 |         return self.num_segments
70 | 
71 |     def get_spans(self, sentences: List[FloatTensor], scores: FloatTensor) -> List[int]:
72 |         """
73 |         gets the spans for the segmentations
74 |         :param sentences: list of sentences
75 |         :param scores: scores neighbor sentences
76 |         :return: list of spans
77 |         """
78 |         scores, indicies = torch.sort(scores, descending=self.descending)
79 |         top_indicies = indicies[:self.num_segments - 1]  # boundaries at the num_segments-1 lowest (or highest) scores
80 |         argsort = torch.argsort(top_indicies)
81 |         spans = [0] + (top_indicies[argsort] + 1).detach().cpu().tolist()
82 |         spans.append(len(sentences))
83 |         return spans
84 | 
85 | class AdaptiveSpansSelector(SpansSelector):
86 |     """
87 |     AdaptiveSpansSelector is a SpansSelector that returns an adaptive number of segments based on the number of sentences.
88 |     """
89 | 
90 |     def __init__(self, config: Mapping):
91 |         """
92 |         :param config: config for the decider. needs to contain base_segments, len_offset and sentences_for_segment ints.
93 |         """
94 | 
95 |         self.base_segments = config["base_segments"]
96 |         self.len_offset = config["len_offset"]
97 |         self.sentences_for_segment = config["sentences_for_segment"]
98 |         self.descending = config.get("descending", False)
99 | 
100 |     def decide(self, sentences: List[FloatTensor], _: FloatTensor) -> int:
101 |         """
102 |         decides the number of segments to use
103 |         :param sentences: list of sentences
104 |         :return: number of segments
105 |         """
106 |         return max(0,len(sentences) - self.len_offset) // self.sentences_for_segment + self.base_segments
107 | 
108 |     def get_spans(self, sentences: List[FloatTensor], scores: FloatTensor) -> List[int]:
109 |         num_segments = self.decide(sentences, scores)
110 |         scores, indicies = torch.sort(scores, descending=self.descending)
111 |         top_indicies = indicies[:num_segments - 1]
112 |         argsort = torch.argsort(top_indicies)
113 |         spans = [0] + (top_indicies[argsort] + 1).detach().cpu().tolist()
114 |         spans.append(len(sentences))
115 |         return spans
116 | 
117 | 
118 | class ThresholdSpansSelector(SpansSelector):
119 |     """
120 |     ThresholdSpansSelector is a SpansSelector that places a boundary wherever the score crosses a configured threshold, so the number of segments depends on the scores rather than on the number of sentences.
121 |     """
122 | 
123 |     def __init__(self, config: Mapping):
124 |         """
125 |         :param config: config for the decider. needs to contain a threshold float and an optional larger_than bool (default False).
126 | """ 127 | self.threshold = config["threshold"] 128 | self.larger_than = config.get("larger_than", False) 129 | 130 | def decide(self, _: List[FloatTensor], scores: FloatTensor) -> int: 131 | """ 132 | decides the number of segments to use 133 | :param sentences: list of sentences 134 | :return: number of segments 135 | """ 136 | return self.threshold_items(scores).sum().item() 137 | 138 | def threshold_items(self, scores: FloatTensor) -> FloatTensor: 139 | if self.larger_than: 140 | return scores > self.threshold 141 | else: 142 | return scores < self.threshold 143 | 144 | def get_spans(self, sentences: List[FloatTensor], scores: FloatTensor) -> List[int]: 145 | indicies = self.threshold_items(scores).nonzero().squeeze(1) 146 | spans = [0] + (indicies + 1).detach().cpu().tolist() 147 | spans.append(len(sentences)) 148 | return spans -------------------------------------------------------------------------------- /speech_lm/speech_sentencer.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from torch import Tensor 3 | import torchaudio 4 | from typing import List, Tuple,Mapping,Any 5 | import torch 6 | 7 | 8 | class SpeechSentencer(ABC): 9 | """ 10 | SpeechSentencer is an abstract class that creates "sentences" from audio. 11 | for an audio file. it will create x1, x2, x3, ... xn "sentences" for the 12 | audio file. The sentences are not necessarily sentences in the traditional 13 | sense, but they are segments of the audio file that are related in some 14 | way. 15 | """ 16 | 17 | @abstractmethod 18 | def sentence(self, audio: Tensor, sr: int) -> List[Mapping[str,Any]]: 19 | """ 20 | Sentence audio 21 | :param sr: sampling rate of the audio 22 | :param audio: audio to sentence 23 | :return: list of audio sentences 24 | """ 25 | pass 26 | 27 | @abstractmethod 28 | def sentence_path(self, audio_file: str, resample_rate:int = -1, device="cpu") -> Tuple[List[Mapping[str,Any]], int]: 29 | """ 30 | Sentence audio file 31 | :param audio_file: audio file to sentence 32 | :param resample_rate: resample rate. if -1, no resampling will be done if different from audio sr will resample. 33 | :return: list of audio sentences 34 | """ 35 | pass 36 | 37 | class SpeechSentencerFactory: 38 | 39 | @staticmethod 40 | def get_sentencer(sentencer_config:Mapping)->SpeechSentencer: 41 | """ 42 | factrory method to get a sentencer from a config 43 | """ 44 | sentencer_type = sentencer_config['type'] 45 | if sentencer_type == 'length': 46 | return LengthSpeechSentencer(sentencer_config['length'],sentencer_config["min_length"],sentencer_config['drop_last']) 47 | else: 48 | raise NotImplementedError(f'No such sentencer type: {sentencer_type}') 49 | 50 | 51 | class LengthSpeechSentencer(SpeechSentencer): 52 | """ 53 | LengthSpeechSentencer is a SpeechSentencer that uses the length of the 54 | audio to sentence it. It will divide the audio into x1, x2, x3, ... xn 55 | "sentences" of equal audio length. 56 | """ 57 | 58 | def __init__(self, length: int,min_length=0.05,drop_last=False): 59 | """ 60 | :param length: length of each sentence in seconds 61 | :param min_length: check if the last sentence is shorter than min length. 
62 | if he is combine the last two sentences 63 | :param drop_last: drop last sentece 64 | """ 65 | self.length = length 66 | self.drop_last = drop_last 67 | self.min_length = min_length 68 | 69 | 70 | def sentence(self, audio: Tensor, sr: int) -> List[Mapping[str,Any]]: 71 | """ 72 | :param audio: audio to sentence 73 | :param sr: sampling rate of the audio 74 | :return: list of audio sentences with their start and end times 75 | """ 76 | l = int(self.length * sr) 77 | sentences = [ 78 | {"audio":audio[..., i:i + l], "start":i/sr, "end":min((i + l)/sr,audio.shape[-1]/sr)} 79 | for i in range(0, audio.shape[-1], l) 80 | ] 81 | if self.drop_last: 82 | sentences.pop() 83 | if sentences[-1]["audio"].shape[-1] < int(self.min_length * sr): 84 | last = sentences.pop() 85 | prev_last = sentences.pop() 86 | sentences.append( 87 | {"audio":torch.concat((prev_last["audio"],last["audio"]),dim=-1),"start":prev_last["start"],"end":last["end"]} 88 | ) 89 | return sentences 90 | 91 | 92 | def sentence_path(self, audio_path: str, resample_rate:int = -1,device="cpu") -> Tuple[List[Mapping[str,Any]], int]: 93 | """ 94 | Sentence audio file 95 | :param audio_path: audio file path to sentence 96 | :return: tuple of list of audio sentences and sampling rate 97 | """ 98 | audio, sr = torchaudio.load(audio_path) 99 | audio = audio.to(device) 100 | if resample_rate != -1 and sr != resample_rate: 101 | audio = torchaudio.functional.resample(audio,sr,resample_rate) 102 | sr = resample_rate 103 | return self.sentence(audio, sr), sr 104 | -------------------------------------------------------------------------------- /speech_lm/tokenizers.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List,Union,Mapping 2 | from torch import FloatTensor,LongTensor 3 | import torch 4 | from .utils import get_gslm_speech_encoder 5 | import json 6 | 7 | UNITS = "units" 8 | 9 | class SpeechTokenizer(torch.nn.Module): 10 | 11 | 12 | def __init__(self, config:Mapping) -> None: 13 | """ 14 | init method for SpeechTokenizer 15 | config needs to contain a dense_model_name, a quantizer_model_name, an encoder_vocab_size, and a deduplicate flag 16 | """ 17 | super().__init__() 18 | self.encoder = get_gslm_speech_encoder(config['dense_model_name'], config['quantizer_model_name'], config['encoder_vocab_size'], config['deduplicate'],need_f0 = config['need_f0']) 19 | 20 | @classmethod 21 | def from_pretrained(cls, path_to_config:str) -> 'SpeechTokenizer': 22 | """ 23 | a class method to create a SpeechTokenizer from a pretrained config path 24 | """ 25 | with open(path_to_config, 'r') as f: 26 | config = json.load(f) 27 | return cls(config) 28 | 29 | def forward(self, x:Union[List[FloatTensor],FloatTensor], offset:int) -> List[LongTensor]: 30 | """ 31 | tokenizes a list of audio tensors 32 | x: a list of audio tensors (or a single audio tensor) 33 | offset: the offset to add to the tokens 34 | """ 35 | if isinstance(x, FloatTensor): 36 | x = [x] 37 | offset = torch.tensor(offset,device=x[0].device) 38 | return [self.encoder(x_i)[UNITS].long() + offset for x_i in x] 39 | 40 | def to(self, device): 41 | """ 42 | moves tokenizer to device 43 | """ 44 | self.encoder.to(device) 45 | return self -------------------------------------------------------------------------------- /speech_lm/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import zipfile 3 | import wget 4 | import torch 5 | from transformers import AutoModelForCausalLM 
6 | from textless.data.speech_encoder import SpeechEncoder 7 | from torch import FloatTensor, LongTensor 8 | from torch.nn.functional import cross_entropy 9 | 10 | 11 | """ 12 | This file contains utils for speech_lm 13 | """ 14 | 15 | DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 16 | 17 | ROOT_URL = 'https://dl.fbaipublicfiles.com/textless_nlp/twist/lms/' 18 | 19 | 20 | def get_gslm_speech_encoder(dense_model_name, quantizer_model_name, vocab_size, 21 | deduplicate,need_f0,f0_func="yaapt",f0_normalizer=None,f0_quantizer=None,chunk_alignment=False): 22 | """ 23 | get speech encoder using textless library 24 | :param dense_model_name: dense model name 25 | :param quantizer_model_name: quantizer model name 26 | :param vocab_size: vocab size 27 | :param deduplicate: deduplicate 28 | :param need_f0: need f0 29 | """ 30 | return SpeechEncoder.by_name( 31 | dense_model_name=dense_model_name, 32 | quantizer_model_name=quantizer_model_name, 33 | vocab_size=vocab_size, 34 | deduplicate=deduplicate, 35 | need_f0=need_f0, 36 | f0_normalizer=f0_normalizer, 37 | f0_quantizer=f0_quantizer, 38 | f0_func=f0_func, 39 | chunk_alignment=chunk_alignment 40 | ) 41 | 42 | def unzip_file(zip_path, extract_path): 43 | """ 44 | unzip file 45 | :param zip_path: path to zip file 46 | :param extract_path: path to extract to 47 | """ 48 | with zipfile.ZipFile(zip_path, 'r') as zip_ref: 49 | zip_ref.extractall(extract_path) 50 | print(f"File extracted to {extract_path}") 51 | 52 | 53 | def maybe_download_speech_lm(name, base_path): 54 | """ 55 | downloads speech lm 56 | :param name: name of model 57 | :param base_path: base path to download to 58 | """ 59 | if not os.path.exists(base_path): 60 | os.mkdir(base_path) 61 | 62 | ckpt_dir = os.path.join(base_path, name) 63 | if not os.path.exists(ckpt_dir): 64 | url = ROOT_URL + name + '.zip' 65 | zip_path = ckpt_dir + '.zip' 66 | print(f"Downloading from {url}") 67 | filename = wget.download(url, zip_path) 68 | unzip_file(filename, ckpt_dir) 69 | 70 | return os.path.abspath(ckpt_dir) 71 | 72 | def build_speech_lm(model_type, base_path='./'): 73 | """ 74 | builds speech lm 75 | retruns model 76 | """ 77 | ckpt_dir = maybe_download_speech_lm(model_type, base_path) 78 | 79 | lm_model = AutoModelForCausalLM.from_pretrained(ckpt_dir) 80 | lm_model.eval() 81 | 82 | return lm_model 83 | 84 | def nll(logits:FloatTensor, target:LongTensor, mask:LongTensor,mean_nll:bool = False)->FloatTensor: 85 | """ 86 | calculate the negative log likelihood of the logits given the target 87 | :param logits: logits 88 | :param target: target 89 | :param mask: mask 90 | :return: nll 91 | """ 92 | # Calculate the cross-entropy loss for each sequence 93 | losses = cross_entropy( 94 | logits.contiguous().view(-1, logits.size(-1)), 95 | target.long().contiguous().view(-1), reduction='none') 96 | 97 | # Reshape the losses to match the original sequences 98 | losses = losses.view(*target.size()) 99 | 100 | # Use the mask to ignore the losses of the padding tokens 101 | masked_losses = losses * mask 102 | 103 | # Sum the losses to get the total loss for each sequence 104 | ll = masked_losses.sum(dim=-1) 105 | if mean_nll: 106 | return ll / mask.sum(dim=-1) 107 | return ll -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import torch 4 | from librosa import get_duration 5 | 6 | VAD_SR = 16000 7 | VAD_THRESHOLD = 0.4 8 | 9 | 10 | 
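The repo-level utils.py starting above wraps silero-vad (loaded via torch.hub in load_vad() below) together with simple chunking helpers. How the silero helper tuple is unpacked is not shown in this file, so the sketch below is an assumption based on the published silero-vad hub interface (get_speech_timestamps being the first element of utils); the random waveform is only a placeholder and would normally come from torchaudio.

import torch
from utils import load_vad, split_audio, VAD_SR, VAD_THRESHOLD  # assumes the repo root is the working directory

vad_model, utils = load_vad()
get_speech_timestamps = utils[0]   # assumption: first helper in the silero-vad utils tuple

wav = torch.randn(VAD_SR * 5)      # placeholder: 5 s of fake 16 kHz audio
speech = get_speech_timestamps(wav, vad_model, threshold=VAD_THRESHOLD, sampling_rate=VAD_SR)
print(speech)                      # list of {'start': ..., 'end': ...} in samples (empty for pure noise)

for chunk in split_audio(wav, VAD_SR, seconds=2.5):
    print(chunk.shape)             # 2.5 s chunks; the last one may be shorter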
def load_vad(): 11 | vad_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', 12 | model='silero_vad', 13 | force_reload=False, trust_repo=True) 14 | return vad_model, utils 15 | 16 | 17 | def split_audio(audio, sr, seconds): 18 | seg_length = int(seconds * sr) 19 | chunks = torch.split(audio, seg_length, dim=-1) 20 | return chunks 21 | 22 | 23 | def get_folder_duration(path, folder): 24 | total_duration = 0 25 | folder_path = os.path.join(path, folder) 26 | for root, _, files in os.walk(folder_path): 27 | for file in files: 28 | if not file.endswith(".wav"): 29 | continue 30 | file_path = os.path.join(root, file) 31 | total_duration += get_duration(path=file_path) 32 | return total_duration 33 | 34 | def load_utterInfo(inputFile): 35 | """ 36 | Load utterInfo from original IEMOCAP database 37 | """ 38 | # this regx allow to create a list with: 39 | # [START_TIME - END_TIME] TURN_NAME EMOTION [V, A, D] 40 | # [V, A, D] means [Valence, Arousal, Dominance] 41 | pattern = re.compile( 42 | "[\[]*[0-9]*[.][0-9]*[ -]*[0-9]*[.][0-9]*[\]][\t][a-z0-9_]*[\t][a-z]{3}[\t][\[][0-9]*[.][0-9]*[, ]+[0-9]*[.][0-9]*[, ]+[0-9]*[.][0-9]*[\]]", 43 | re.IGNORECASE, 44 | ) # noqa 45 | with open(inputFile, "r") as myfile: 46 | data = myfile.read().replace("\n", " ") 47 | result = pattern.findall(data) 48 | out = [] 49 | for i in result: 50 | a = i.replace("[", "") 51 | b = a.replace(" - ", "\t") 52 | c = b.replace("]", "") 53 | x = c.replace(", ", "\t") 54 | out.append(x.split("\t")) 55 | return out 56 | 57 | 58 | EMOTION_DICT = { 59 | "hap": "happy", 60 | "exc": "happy", 61 | "sad": "sad", 62 | "ang": "angry", 63 | "neu": "neutral", 64 | } 65 | 66 | def get_emotion_dict(text_file): 67 | """ 68 | Get emotion dict from original IEMOCAP database 69 | """ 70 | emotion_dict = {} 71 | with open(text_file) as f: 72 | utterance = load_utterInfo(text_file) 73 | for line in utterance: 74 | id = line[2] 75 | emo = line[3] 76 | if emo in EMOTION_DICT: 77 | emotion_dict[id] = EMOTION_DICT[emo] 78 | return emotion_dict --------------------------------------------------------------------------------
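Putting the PMI pieces together: SpeechPMIscorer scores each adjacent pair of windows with pmi = log p(xy) - (log p(x) + log p(y)), where the log-likelihoods are the negated nll() sums from SLMInferenceModel, and the threshold-based selectors in the configs ("larger_than": false) keep a boundary wherever that score falls below the threshold. A purely numerical toy illustration (the log-likelihood values are made up):

log_p_x = -42.0    # log-likelihood of window x under the speech LM
log_p_y = -39.5    # log-likelihood of window y
log_p_xy = -90.0   # log-likelihood of x and y concatenated

pmi = log_p_xy - (log_p_x + log_p_y)   # -90.0 - (-81.5) = -8.5
threshold = -8.0                       # as in pmi_segmentor_config_threshold_8_350M.json

# "larger_than": false  ->  cut where the score is BELOW the threshold,
# i.e. where the second window is poorly predicted as a continuation of the first.
place_boundary = pmi < threshold
print(pmi, place_boundary)             # -8.5 True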