├── .github
│   └── workflows
│       └── importtest.yml
├── .gitignore
├── LICENSE
├── Makefile
├── README.md
├── doc
│   ├── Makefile
│   └── images
│       ├── Makefile
│       ├── example_id.svg
│       ├── example_id.tex
│       └── room.svg
├── scripts
│   ├── cmd.sh
│   ├── get_nnet3_model.bash
│   ├── get_tri3_model.bash
│   ├── path.sh
│   ├── prepare_nnet3_model_training.bash
│   └── run_ivector_common.sh
├── setup.py
├── sms_wsj
│   ├── __init__.py
│   ├── database
│   │   ├── __init__.py
│   │   ├── create_intermediate_json.py
│   │   ├── create_json_for_written_files.py
│   │   ├── create_rirs.py
│   │   ├── database.py
│   │   ├── dynamic_mixing.py
│   │   ├── utils.py
│   │   ├── write_files.py
│   │   └── wsj
│   │       ├── __init__.py
│   │       ├── create_json.py
│   │       └── write_wav.py
│   ├── examples
│   │   ├── metric_target_comparison.py
│   │   └── reference_systems.py
│   ├── io.py
│   ├── kaldi
│   │   ├── __init__.py
│   │   ├── get_kaldi_wer.py
│   │   └── utils.py
│   ├── reverb
│   │   ├── __init__.py
│   │   ├── reverb_utils.py
│   │   ├── rotation.py
│   │   └── scenario.py
│   ├── train_baseline_asr.py
│   └── visualization.py
└── tests
    ├── database.py
    └── test_import.py

/.github/workflows/importtest.yml:
--------------------------------------------------------------------------------
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Import-Test

on:
  push:
    branches: [ master ]
  pull_request:
    branches: [ master ]

jobs:
  build:

    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ['3.10', '3.11']
    steps:
    - uses: actions/checkout@v3
    - uses: mpi4py/setup-mpi@v1
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v4
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade cython scipy numpy
        python -m pip install --upgrade setuptools
        python -m pip install flake8
        # https://github.com/pypa/pip/issues/12030#issuecomment-1546344047
        python -m pip install wheel
        if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
        python -m pip list
        python -m pip install '.[all]'
    - name: Test with pytest
      run: |
        python -m pytest tests/test_import.py  # other tests require the database, e.g., sms-wsj files

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# PyCharm
.idea/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Department of Communications Engineering University of Paderborn

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
SHELL = /bin/bash

# Uses the environment variable WSJ_DIR if it is defined, otherwise falls back to the default /net/fastdb/wsj
WSJ_DIR ?= /net/fastdb/wsj
WSJ0_DIR ?= $(WSJ_DIR)
WSJ1_DIR ?= $(WSJ_DIR)
SMS_WSJ_DIR ?= cache
RIR_DIR = $(SMS_WSJ_DIR)/rirs
JSON_DIR ?= $(SMS_WSJ_DIR)
WSJ_8K_ZEROMEAN_DIR ?= $(SMS_WSJ_DIR)/wsj_8k_zeromean
WRITE_ALL = True  # If False, only the observation is written to SMS_WSJ_DIR; the reverberated signals are then calculated on the fly.
num_jobs = $(shell nproc --all)
DEBUG = False  # If True, create fewer entries in sms_wsj.json, just for debugging.
# Example call on the Paderborn parallel computing center:
# ccsalloc --res=rset=1:mem=2G:ncpus=8 -t 4h make all num_jobs=8

export OMP_NUM_THREADS = 1
export MKL_NUM_THREADS = 1

all: sms_wsj sms_wsj.json

wsj_8k_zeromean: $(WSJ_8K_ZEROMEAN_DIR)
$(WSJ_8K_ZEROMEAN_DIR): $(WSJ_DIR)
	@echo creating $(WSJ_8K_ZEROMEAN_DIR)
	@echo using $(num_jobs) parallel jobs
	mpiexec -np ${num_jobs} python -m sms_wsj.database.wsj.write_wav \
		with dst_dir=$(WSJ_8K_ZEROMEAN_DIR) wsj0_root=$(WSJ0_DIR) wsj1_root=$(WSJ1_DIR) sample_rate=8000

wsj_8k_zeromean.json: $(JSON_DIR)/wsj_8k_zeromean.json
$(JSON_DIR)/wsj_8k_zeromean.json: $(WSJ_8K_ZEROMEAN_DIR) | $(JSON_DIR)
	@echo creating $(JSON_DIR)/wsj_8k_zeromean.json
	python -m sms_wsj.database.wsj.create_json \
		with json_path=$(JSON_DIR)/wsj_8k_zeromean.json database_dir=$(WSJ_8K_ZEROMEAN_DIR) as_wav=True

intermediate_sms_wsj.json: $(JSON_DIR)/intermediate_sms_wsj.json
$(JSON_DIR)/intermediate_sms_wsj.json: $(JSON_DIR)/wsj_8k_zeromean.json | $(JSON_DIR) $(RIR_DIR)
	@echo creating $(JSON_DIR)/intermediate_sms_wsj.json
	python -m sms_wsj.database.create_intermediate_json \
		with json_path=$(JSON_DIR)/intermediate_sms_wsj.json rir_dir=$(RIR_DIR) wsj_json_path=$(JSON_DIR)/wsj_8k_zeromean.json debug=$(DEBUG)

sms_wsj: $(SMS_WSJ_DIR)/observation
$(SMS_WSJ_DIR)/observation: $(JSON_DIR)/intermediate_sms_wsj.json | $(SMS_WSJ_DIR)
	@echo creating $(SMS_WSJ_DIR) files
	@echo using $(num_jobs) parallel jobs
	mpiexec -np ${num_jobs} python -m sms_wsj.database.write_files \
		with dst_dir=$(SMS_WSJ_DIR) json_path=$(JSON_DIR)/intermediate_sms_wsj.json write_all=$(WRITE_ALL) debug=$(DEBUG)

clean_write_files:
	rm -rf $(SMS_WSJ_DIR)/{observation,noise,early,tail,speech_source}

sms_wsj.json: $(JSON_DIR)/sms_wsj.json | $(SMS_WSJ_DIR)/observation
$(JSON_DIR)/sms_wsj.json: $(JSON_DIR)/intermediate_sms_wsj.json | $(SMS_WSJ_DIR)
	@echo creating $(JSON_DIR)/sms_wsj.json
	@echo This amends the sms_wsj.json with the new paths.
	python -m sms_wsj.database.create_json_for_written_files \
		with db_dir=$(SMS_WSJ_DIR) intermed_json_path=$(JSON_DIR)/intermediate_sms_wsj.json write_all=$(WRITE_ALL) json_path=$(JSON_DIR)/sms_wsj.json debug=$(DEBUG)

# The room impulse responses can be downloaded, so they do not have to be created.
# However, if you want to recreate them, use "make rirs RIR_DIR=/path/to/storage/".
rirs:
	@echo creating $(RIR_DIR)
	pip install git+https://github.com/boeddeker/rirgen
	mpiexec -np ${num_jobs} python -m sms_wsj.database.create_rirs \
		with database_path=$(RIR_DIR)

# To manually download and extract the RIRs, execute the following after downloading all files from https://zenodo.org/record/3517889
# cat $(RIR_DIR)/sms_wsj.tar.gz.* > $(RIR_DIR)/sms_wsj.tar.gz
# tar -C $(RIR_DIR)/ -zxvf $(RIR_DIR)/sms_wsj.tar.gz
$(RIR_DIR):
	@echo "RIR directory does not exist, starting download; to recreate the RIRs use 'make rirs'."
	mkdir -p $(RIR_DIR)
	echo $(RIR_DIR)
	wget -qO- https://zenodo.org/record/3517889/files/sms_wsj.tar.gz.parta{a,b,c,d,e} \
		| tar -C $(RIR_DIR)/ -zx --checkpoint=10000 --checkpoint-action=echo="%u/5530000 %c"

$(JSON_DIR):
	@echo "JSON_DIR is wrongly set or the directory does not exist."
	@echo "Please specify an existing JSON_DIR directory using the JSON_DIR variable, JSON_DIR =" $(JSON_DIR)
	exit 1

$(WSJ_DIR):
	@echo "WSJ directory does not exist."
	@echo "Please specify an existing WSJ directory using the WSJ_DIR variable, WSJ_DIR =" $(WSJ_DIR)
	exit 1

cache:
	mkdir cache

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SMS-WSJ: A database for in-depth analysis of multi-channel source separation algorithms

![Example ID](doc/images/room.svg)

This repository includes the scripts required to create the SMS-WSJ database, a spatial clustering baseline for separation,
and a baseline ASR system using Kaldi (http://github.com/kaldi-asr/kaldi).

## Why does this database exist?

In multi-speaker ASR the [WSJ0-2MIX database](https://www.merl.com/demos/deep-clustering) and the spatialized version thereof are widely used.
Observing that research in multi-speaker ASR is often hard to compare because some researchers pretrain on WSJ, while others train only on WSJ0-2MIX or create other sub-lists of WSJ, we decided to use a fixed file list which is suitable for training an ASR system without additional audio data.
Utterances with verbalized punctuation are filtered out to further facilitate end-to-end ASR experiments.

Further, we argue that the tooling around the [WSJ0-2MIX database](https://www.merl.com/demos/deep-clustering) and the spatialized version thereof is very limited.
Therefore, we provide a spatial clustering baseline and a Kaldi ASR baseline.
Researchers can now easily improve parts of the pipeline while ensuring that they can fairly compare with the baseline results reported in the associated arXiv paper.

## How can I cite this work? Where are baseline results?

The associated paper can be found here: https://arxiv.org/abs/1910.13934
If you are using this code, please cite the paper as follows:

```
@Article{SmsWsj19,
  author  = {Drude, Lukas and Heitkaemper, Jens and Boeddeker, Christoph and Haeb-Umbach, Reinhold},
  title   = {{SMS-WSJ}: Database, performance measures, and baseline recipe for multi-channel source separation and recognition},
  journal = {arXiv preprint arXiv:1910.13934},
  year    = {2019},
}
```

## Installation

Does not work with Windows.

Clone this repository and install the package:
```bash
$ git clone https://github.com/fgnt/sms_wsj.git
$ cd sms_wsj
$ pip install --user -e ./
```

Set your KALDI_ROOT environment variable:
```bash
$ export KALDI_ROOT=/path/to/kaldi
```
We assume that the Kaldi WSJ baseline has been created with the `run.sh` script.
This is important to be able to use the Kaldi language model:
the ASR baseline relies on the directory structures and the language models created during the first stage of the `run.sh` script.
Afterwards you can create the database:
```bash
$ make WSJ_DIR=/path/to/wsj SMS_WSJ_DIR=/path/to/write/db/to
```
If you have one folder for WSJ0 and one for WSJ1, you can create the database with:
```
$ make WSJ0_DIR=/path/to/wsj0 WSJ1_DIR=/path/to/wsj1 SMS_WSJ_DIR=/path/to/write/db/to
```
If desired, the number of parallel jobs may be specified using the additional
input num_jobs. By default `nproc --all` parallel jobs are used.


The RIRs are downloaded by default; to generate them yourself, see [here](#q-i-want-to-generate-the-rirs-myself-how-can-i-do-that).


Use the following command to train the baseline ASR model:
```bash
$ python -m sms_wsj.train_baseline_asr with egs_path=$KALDI_ROOT/egs/ json_path=/path/to/sms_wsj.json
```
The script has been tested with the Kaldi Git hash "7637de77e0a77bf280bef9bf484e4f37c4eb9475".


## Properties

- Simulated
  - Two-speaker mixtures
  - 33561 train, 982 dev and 1332 test mixtures
  - The longest speaker utterance determines the mixture length:
    - ASR on both speakers possible
  - WSJ based: WSJ0 and WSJ1 are used as clean utterances.
  - Sample rate: 8 kHz
- Reverberated
  - RIR generator: [Habets](https://github.com/ehabets/RIR-Generator). We use
    [this](https://github.com/boeddeker/rirgen) Python port.
  - Random room with 6 microphones, see the first image in this README.
  - T60: 200-500 ms
  - Time of Flight (ToF) compensation jointly over all channels without RIR
    truncation
    - A ToF compensation allows using the source signal as a target
      for signal-level metrics like BSSEval SDR and PESQ, and it also allows
      the use of ASR alignments for an ASR training.
    - We do not remove the samples in the RIR before the estimated ToF,
      because that would imply that we assume there is an error in the RIR
      generator.
  - Early-late split
    - We propose a split of the RIR into an early and a late part. In this way
      the early RIR convolved with the speech source can be used as a target
      for NN losses (e.g. negative SDR). Note: This is not a target signal
      for metrics, because it is far from a unique definition (don't
      modify the target signal if you want to judge your system).
    - Proposed default: 50 ms (motivated by the REVERB challenge)
- Noise
  - 20-30 dB additive white Gaussian noise (AWGN)
  - We decided to use just simple noise, because we don't know how to
    simulate realistic multi-channel noise
    (e.g. a point noise source is unrealistic).
  - We used low-volume noise, because it is just AWGN.
- **Each unique utterance appears exactly equally often**
  - While the utterances that are used to create a mixture are chosen randomly,
    we used a sampling algorithm that guarantees that each utterance is used
    equally often. This ensures that the word distribution is exactly
    the same as the distribution of WSJ0 and WSJ1.
  - Many other mixture databases just sample the utterances randomly and
    don't ensure that each utterance appears equally often.
  - The randomization approach can be generalized to more speakers.
- Random and deterministic
- Verbalized punctuation is excluded

## How to use this database?

Once you have installed this repository and created the SMS-WSJ database,
there are a few ways to use it:

- Manually read the files from the file system (recommended when you don't work with Python or don't want to use the provided code)
- Manually read the JSON (not recommended)
- Use some of our helper functions to:
  - Load all desired files from disk (recommended for local file systems)
  - Load only the original WSJ utterances and the RIRs and generate the examples on the fly with the help of the JSON (recommended for remote file systems; we mainly use this)
    - This requires some CPU time. It can be done in a background thread pool, e.g. `lazy_dataset.Dataset.prefetch`, for NN experiments, where the CPU often idles while the GPU is working.
    - This allows dynamic mixing of the examples, e.g. creating a nearly infinitely large training dataset.

On the file system you will find files like `.../<signal_type>/<dataset_name>/<rir_id>_<1_wsj_id>_<2_wsj_id>[_<source_index>].wav` (e.g. `.../observation/train_si284/0_4axc0218_01kc020f.wav`).
Here is an explanation of how the paths and file names are generated:
- `<signal_type>`:
  - Possible values: `observation`, `speech_source`, `early`, `tail` or `noise`
  - `observation` = `early` + `tail` + `noise`
  - `speech_source`: The padded signal from WSJ (i.e. `original_source`).
  - `early`/`tail`: `speech_source` convolved with the initial/late part of the `rirs`
    - Note: `speech_image` must be calculated as `early` + `tail`
    - `speech_image` = `early` + `tail` = `speech_source` convolved with `rirs`
  - Note: The WSJ files are mirrored to `wsj_8k_zeromean`, converted to `wav` files and downsampled. Because we simply mirror, the easiest way to find the `original_source` is to use the JSON.
- `<dataset_name>`:
  - Possible values: `train_si284`/`cv_dev93`/`test_eval92`
  - The original WSJ dataset name.
- `<rir_id>`:
  - A running index for the generated room impulse responses (RIRs).
- `<1_wsj_id>`, `<2_wsj_id>`:
  - The WSJ utterance IDs that are used to generate the mixture.
- `<source_index>`:
  - Possible values: `0` and `1`
  - An index indicating which WSJ utterance/speaker is present in the wav file.
  - Omitted for the observation.

The database creation generates a JSON file. This file contains all information about the database.
The pattern is as follows:
```python
{
    "datasets": {
        dataset_name: {  # "train_si284", "cv_dev93" or "test_eval92"
            example_id: {  # <rir_id>_<1_wsj_id>_<2_wsj_id>
                "room_dimensions": [[...], [...], [...]],
                "sound_decay_time": ...,
                "source_position": [[..., ...], [..., ...], [..., ...]],
                "sensor_position": [[..., ..., ..., ..., ..., ...], [..., ..., ..., ..., ..., ...], [..., ..., ..., ..., ..., ...]],
                "example_id": "...",
                "num_speakers": 2,
                "speaker_id": ["...", "..."],
                "gender": ["...", "..."],  # "male" or "female"
                "kaldi_transcription": ["...", "..."],
                "log_weights": [..., ...],  # weights of the utterances before they are added
                "num_samples": {
                    "original_source": [..., ...],
                    "observation": ...,
                },
                "offset": [..., ...],  # offset of each utterance start in samples
                "snr": ...,
                "audio_path": {
                    "original_source": ["...", "..."],
                    "speech_source": ["...", "..."],
                    "rir": ["...", "..."],
                    "speech_reverberation_early": ["...", "..."],
                    "speech_reverberation_tail": ["...", "..."],
                    "noise_image": "...",
                    "observation": "...",
                }
            }
        }
    }
}
```
This file can be used to get all details for each example.
To read it with Python, we have some helper functions:


```python
from sms_wsj.database import SmsWsj, AudioReader

db = SmsWsj(json_path='.../sms_wsj.json')
ds = db.get_dataset('train_si284')  # "train_si284", "cv_dev93" or "test_eval92"
ds = ds.map(AudioReader((
    'observation',
    'speech_source',
    # 'original_source',
    # 'speech_reverberation_early',
    # 'speech_reverberation_tail',
    'speech_image',
    # 'noise_image',
    # 'rir',
)))
```

Now you can access the examples with the dataset instance.
You can iterate over the dataset (e.g. `for example in ds: ...`) or access examples by their ID, e.g. `ds['0_4axc0218_01kc020f']`
(access with an index (e.g. `ds[42]`) only works when the dataset is not shuffled).
The audio files that are requested from the `AudioReader` will be loaded on demand and will be available under the key `audio_data` in the `example`.

If you want to reduce the IO, you can use the `scenario_map_fn`:
```python
from sms_wsj.database import SmsWsj, AudioReader, scenario_map_fn

db = SmsWsj(json_path='.../sms_wsj.json')
ds = db.get_dataset('cv_dev93')  # "train_si284", "cv_dev93" or "test_eval92"
ds = ds.map(AudioReader((
    'original_source',
    'rir',
)))
ds = ds.map(scenario_map_fn)  # Calculates all signals from `original_source` and `rir`
```
This avoids reading the multi-channel signals.
Since the `scenario_map_fn` calculates the convolutions, it can be useful to use `prefetch`, so the convolution is done in the background
(note: the [GIL](https://en.wikipedia.org/wiki/Global_interpreter_lock) is released, so a ThreadPool is enough).

The last option that we provide is dynamic mixing.
With each iteration over the dataset, you will get a different utterance.
The `rir` will always be the same, but the utterances will differ (the simulation of the RIRs is too expensive to do on demand):
```python
from sms_wsj.database import SmsWsj, AudioReader, scenario_map_fn
from sms_wsj.database.dynamic_mixing import SMSWSJRandomDataset

db = SmsWsj(json_path='.../sms_wsj.json')
ds = db.get_dataset('train_si284')
ds = SMSWSJRandomDataset(ds)
ds = ds.map(AudioReader((
    'original_source',
    'rir',
)))
ds = ds.map(scenario_map_fn)  # Calculates all signals from `original_source` and `rir`
```

Once you have a `Dataset` instance, you can perform shuffling, batching (with a collate function) and prefetching with a thread/process pool:

```python
ds = ds.shuffle(reshuffle=True)
ds = ds.batch(batch_size)  # Create a list of `batch_size` consecutive examples
ds = ds.map(my_collate_fn)  # e.g. sort the batch, pad/cut examples, move the outer list to a batch axis, ...
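# `my_collate_fn` is user-supplied and not part of SMS-WSJ. A minimal sketch
# that zips the list of example dicts into a dict of lists (assuming all
# examples share the same keys) could look like this:
#
#     def my_collate_fn(batch):
#         return {key: [example[key] for example in batch] for key in batch[0]}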
ds = ds.prefetch(4, 8)  # Use a ThreadPool with 4 threads to prefetch examples
```

## FAQ
### Q: How much disk space does the database require?
A: The total disk usage is 442.1 GiB.

 directory          |     disk usage
:-------------------|--------------:
 tail               |      120.1 GiB
 early              |      120.1 GiB
 observation        |       60.0 GiB
 noise              |       60.0 GiB
 rirs               |       52.6 GiB
 wsj_8k_zeromean    |       29.2 GiB
 sms_wsj.json       |      139.7 MiB
 wsj_8k.json        |       31.6 MiB

### Q: How long does the database creation take?
A: Using 32 cores, the database creation without recalculating the RIRs takes around 4 hours.

### Q: What does the example ID `0_4k6c0303_4k4c0319` mean?
A: The example ID is a unique identifier for an example (sometimes also known as utterance ID).
The example ID is a composition of the speakers, the utterances and a scenario counter:

![Example ID](doc/images/example_id.svg)

### Q: What to do if Kaldi uses Python 3 instead of Python 2?
The Python code in this repository requires Python 3.6. However, Kaldi runs
on Python 2.7. To resolve this mismatch, Kaldi has to be forced to switch the
Python version using the `path.sh`. Therefore, add the following line to
the `${KALDI_ROOT}/tools/env.sh` file:
```
export PATH=path/to/your/python2/bin/:${PATH}
```

### Q: I want to generate the RIRs myself. How can I do that?
To generate the RIRs you can run the following command:
```bash
$ mpiexec -np $(nproc --all) python -m sms_wsj.database.create_rirs with database_path=cache/rirs
```
The expected runtime will be around `1900/(ncpus - 1)` hours.
When you have access to an HPC system, you can replace `mpiexec -np $(nproc --all)` with the corresponding HPC command.
It is enough if each job has access to 2 GB of RAM.
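### Q: How can I check that the signal files are consistent?
A: The decomposition described above (`observation` = `early` + `tail` + `noise`) can be verified numerically. The following sketch is not part of this repository; it assumes that the per-speaker signals carry the speaker axis first and that, since the signals are stored as wav files, the equality only holds up to quantization noise:

```python
import numpy as np
from sms_wsj.database import SmsWsj, AudioReader

db = SmsWsj(json_path='.../sms_wsj.json')
ds = db.get_dataset('cv_dev93')
ds = ds.map(AudioReader((
    'observation',
    'speech_reverberation_early',
    'speech_reverberation_tail',
    'noise_image',
)))

example = next(iter(ds))
audio = example['audio_data']

# `speech_image` is not written to disk; it is the sum of the early and
# late (tail) reverberation parts.
speech_image = (
    audio['speech_reverberation_early']
    + audio['speech_reverberation_tail']
)

# The observation is the sum of both speakers' speech images plus the noise.
np.testing.assert_allclose(
    audio['observation'],
    speech_image.sum(axis=0) + audio['noise_image'],
    atol=1e-3,  # wav quantization makes the match only approximate
)
```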

--------------------------------------------------------------------------------
/doc/Makefile:
--------------------------------------------------------------------------------
default: images/example_id.svg

# Does not work for this image
#images/example_id.svg: images/example_id.pdf
#	inkscape -l $@ $^

images/example_id.svg: images/example_id.tex
	cd images && pdflatex -shell-escape example_id.tex

--------------------------------------------------------------------------------
/doc/images/Makefile:
--------------------------------------------------------------------------------
all: example_id.svg room.svg

example_id.svg: example_id.tex
	pdflatex --shell-escape example_id.tex

room.svg: room.tex
	pdflatex --shell-escape room.tex

clean:
	rm -f example_id.fls
	rm -f example_id.log
	rm -f example_id.svg
	rm -f example_id.pdf
	rm -f example_id.blg
	rm -f example_id.fdb_latexmk
	rm -f example_id.aux
	rm -f example_id.bbl

--------------------------------------------------------------------------------
/doc/images/example_id.tex:
--------------------------------------------------------------------------------
%\documentclass[tikz,border=5,convert={outfile=\jobname.svg}]{standalone}
\documentclass[tikz,border=5,convert=pdf2svg]{standalone}

% Use TikZ to draw images
\usepackage{pgfplots}  % Import plots from Matlab
\usepackage{tikz}
\pgfplotsset{compat=1.9}

\usetikzlibrary{arrows}
\usetikzlibrary{patterns}
\usetikzlibrary{backgrounds}
\usetikzlibrary{fit}
\usetikzlibrary{positioning}
\usetikzlibrary{shapes.geometric}
\usetikzlibrary{calc}
\usetikzlibrary{shapes.multipart}
\usetikzlibrary{shapes.misc}

% \usepackage[decimalsymbol=comma, expproduct=times]{siunitx}
\usepackage[expproduct=times]{siunitx}

\begin{document}
\begin{tikzpicture}[x=5em, y=3em, scale=1]

\tikzstyle{part}=[inner sep=0, anchor=west]

% Example id: 0_4k6c0303_4k4c0319


\node[part] (scenario) {\Huge 0};
\node[part] (underscore2) at (scenario.east) {\Huge \vphantom{A}\_};
\node[part] (spk1) at (underscore2.east) {\Huge 4k6};
\node[part] (spk1utt) at (spk1.east) {\Huge c0303};
\node[part] (underscore1) at (spk1utt.east) {\Huge \vphantom{A}\_};
\node[part] (spk2) at (underscore1.east) {\Huge 4k4};
\node[part] (spk2utt) at (spk2.east) {\Huge c0319};


\tikzstyle{brace1}=[
    thick,
    decoration={
        brace,
        mirror,
        raise=0.5em,
        amplitude=3,
    },
    decorate,
]
\tikzstyle{brace1text}=[pos=0.5,anchor=north,yshift=-0.5em-0.2em, scale=1]
\tikzstyle{brace1textUp}=[pos=0.5,anchor=south,yshift=0.5em+0.2em, scale=1]

\tikzstyle{brace2}=[
    thick,
    decoration={
        brace,
        mirror,
        raise=2.5em,
        amplitude=3,
    },
    decorate,
]
\tikzstyle{brace2text}=[pos=0.5,anchor=north,yshift=-2.5em-0.2em, scale=1]

\draw [brace1] (spk1.south west) -- (spk1.south east) node [brace1text] {Speaker 0};
\draw [brace1] (spk2.south west) -- (spk2.south east) node [brace1text] {Speaker 1};
\draw [brace1] (scenario.north east) -- (scenario.north west) node [brace1textUp] {Scenario counter};

\draw [brace2] (spk1.south west) -- (spk1utt.south east) node [brace2text] {WSJ utterance 0};
\draw [brace2] (spk2.south west) -- (spk2utt.south east) node
    [brace2text] {WSJ utterance 1};


\end{tikzpicture}
\end{document}

--------------------------------------------------------------------------------
/scripts/cmd.sh:
--------------------------------------------------------------------------------
# You can change cmd.sh depending on what type of queue you are using.
# If you have no queueing system and want to run on a local machine, you
# can change all instances of 'queue.pl' to 'run.pl' (but be careful and run
# commands one by one: most recipes will exhaust the memory on your
# machine). queue.pl works with GridEngine (qsub). slurm.pl works
# with Slurm. Different queues are configured differently, with different
# queue names and different ways of specifying things like memory;
# to account for these differences you can create and edit the file
# conf/queue.conf to match your queue's configuration. Search for
# conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information,
# or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl.

export train_cmd="run.pl --mem 2G"
export decode_cmd="run.pl --mem 4G"

--------------------------------------------------------------------------------
/scripts/get_nnet3_model.bash:
--------------------------------------------------------------------------------
#! /bin/bash

# This script is an adjusted version of Kaldi's WSJ run_tdnn_1g.sh script.
# We reduced the number of TDNN-F layers, the batch size and the number of
# training jobs to fit smaller GPUs.
# Training and evaluation on wsj_8k leads to the following WER:
#                          this script   tdnn1g_sp
# WER dev93 (tgpr)                7.40        6.68
# WER eval92 (tgpr)               5.59        4.54
# Training and evaluation on sms_wsj with a single speaker leads to the following WER:
# WER cv_dev93 (tgpr)            12.20
# WER test_eval92 (tgpr)          8.93

# Exit on error: https://stackoverflow.com/a/1379904/911441
set -e

dest_dir=
nj=16
dataset=sms
train_set=train_si284
cv_sets=cv_dev93

gmm=tri4b
gmm_data_type=wsj_8k
ali_data_type=sms_early
stage=5

num_threads_ubm=32
nnet3_affix=       # affix for exp dirs, e.g. it was _cleaned in tedlium.

# Options which are not passed through to run_ivector_common.sh
affix=1a           # affix for the TDNN+LSTM directory, e.g. "1a" or "1b", in case we change the configuration.
common_egs_dir=
reporting_email=

# LSTM/chain options
train_stage=-10
xent_regularize=0.1
dropout_schedule='0,0@0.20,0.5@0.50,0'

# training chunk-options
chunk_width=140,100,160
# we don't need extra left/right context for TDNN systems.
chunk_left_context=0
chunk_right_context=0

# training options
srand=0
remove_egs=true

. ./cmd.sh
. ./path.sh
. ${KALDI_ROOT}/egs/wsj/s5/utils/parse_options.sh

green='\033[0;32m'
NC='\033[0m'  # No Color
trap 'echo -e "${green}$ $BASH_COMMAND ${NC}"' DEBUG

if ! cuda-compiled; then
cat <<EOF > $dir/configs/network.xconfig
input dim=100 name=ivector
input dim=40 name=input

# please note that it is important to have input layer with the name=input
# as the layer immediately preceding the fixed-affine-layer to enable
# the use of short notation for the descriptor
fixed-affine-layer name=lda input=Append(-1,0,1,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat

# the first splicing is moved before the lda layer, so no splicing here
relu-batchnorm-dropout-layer name=tdnn1 $tdnn_opts dim=512
tdnnf-layer name=tdnnf2 $tdnnf_opts dim=512 bottleneck-dim=128 time-stride=1
tdnnf-layer name=tdnnf3 $tdnnf_opts dim=512 bottleneck-dim=128 time-stride=1
tdnnf-layer name=tdnnf4 $tdnnf_opts dim=512 bottleneck-dim=128 time-stride=1
tdnnf-layer name=tdnnf5 $tdnnf_opts dim=512 bottleneck-dim=128 time-stride=0
tdnnf-layer name=tdnnf6 $tdnnf_opts dim=512 bottleneck-dim=128 time-stride=3
tdnnf-layer name=tdnnf7 $tdnnf_opts dim=512 bottleneck-dim=128 time-stride=3
tdnnf-layer name=tdnnf8 $tdnnf_opts dim=512 bottleneck-dim=128 time-stride=3
linear-component name=prefinal-l dim=192 $linear_opts


prefinal-layer name=prefinal-chain input=prefinal-l $prefinal_opts big-dim=1024 small-dim=192
output-layer name=output include-log-softmax=false dim=$num_targets $output_opts

prefinal-layer name=prefinal-xent input=prefinal-l $prefinal_opts big-dim=1024 small-dim=192
output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor $output_opts
EOF
  steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/
fi

################################################################################
# Train network
#############################################################################
if [ $stage -le 18 ]; then

  steps/nnet3/chain/train.py --stage=$train_stage \
    --cmd="$decode_cmd" \
    --feat.online-ivector-dir=$train_ivector_dir \
    --feat.cmvn-opts="--norm-means=false --norm-vars=false" \
    --chain.xent-regularize $xent_regularize \
    --chain.leaky-hmm-coefficient=0.1 \
    --chain.l2-regularize=0.0 \
    --chain.apply-deriv-weights=false \
    --chain.lm-opts="--num-extra-lm-states=2000" \
    --trainer.dropout-schedule $dropout_schedule \
    --trainer.add-option="--optimization.memory-compression-level=2" \
    --trainer.srand=$srand \
    --trainer.max-param-change=2.0 \
    --trainer.num-epochs=10 \
    --trainer.frames-per-iter=5000000 \
    --trainer.optimization.num-jobs-initial=1 \
    --trainer.optimization.num-jobs-final=1 \
    --trainer.optimization.initial-effective-lrate=0.0005 \
    --trainer.optimization.final-effective-lrate=0.00005 \
    --trainer.num-chunk-per-minibatch=128,64 \
    --trainer.optimization.momentum=0.0 \
    --egs.chunk-width=$chunk_width \
    --egs.chunk-left-context=0 \
    --egs.chunk-right-context=0 \
    --egs.dir="$common_egs_dir" \
    --egs.opts="--frames-overlap-per-eg 0" \
    --cleanup.remove-egs=$remove_egs \
    --use-gpu=true \
    --reporting.email="$reporting_email" \
    --feat-dir=$train_data_dir \
    --tree-dir=$tree_dir \
    --lat-dir=$lat_dir \
    --dir=$dir || exit 1;
fi

################################################################################
# Decode on the dev set with lm rescoring
#############################################################################
if [ $stage -le 19 ]; then
  # The reason we are using data/lang here, instead of $lang, is just to
  # emphasize that it's not actually important to give mkgraph.sh the
  # lang directory with the matched topology (since it gets the
  # topology file from the model). So you could give it a different
  # lang directory, one that contained a wordlist and LM of your choice,
  # as long as phones.txt was compatible.

  utils/lang/check_phones_compatible.sh \
    data/lang_test_tgpr/phones.txt $lang/phones.txt
  utils/mkgraph.sh \
    --self-loop-scale 1.0 data/lang_test_tgpr \
    $tree_dir $tree_dir/graph_tgpr || exit 1;
fi
if [ $stage -le 20 ]; then
  frames_per_chunk=$(echo $chunk_width | cut -d, -f1)
  rm $dir/.error 2>/dev/null || true

  for data in $cv_sets; do
    data_affix=$(echo $data | sed s/test_//)
    nspk=$(wc -l

--------------------------------------------------------------------------------
/scripts/path.sh:
--------------------------------------------------------------------------------
[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1
. $KALDI_ROOT/tools/config/common_path.sh
export LC_ALL=C

--------------------------------------------------------------------------------
/scripts/prepare_nnet3_model_training.bash:
--------------------------------------------------------------------------------
#! /bin/bash
# Exit on error: https://stackoverflow.com/a/1379904/911441
set -e

dest_dir=
nj=16
dataset=sms
train_set=train_si284
cv_sets=cv_dev93
train_cmd=run.pl
gmm=tri4b
gmm_data_type=wsj_8k
ali_data_type=sms_early
stage=5

num_threads_ubm=32
nnet3_affix=       # affix for exp dirs, e.g. it was _cleaned in tedlium.



. ./cmd.sh
. ./path.sh
. ${KALDI_ROOT}/egs/wsj/s5/utils/parse_options.sh

cd ${dest_dir}

green='\033[0;32m'
NC='\033[0m'  # No Color
trap 'echo -e "${green}$ $BASH_COMMAND ${NC}"' DEBUG


################################################################################
# Extract MFCC features
#############################################################################

# Now make MFCC features.
# mfccdir should be some place with a largish disk where you
# want to store MFCC features.
export mfccdir=mfcc
if [ $stage -le 6 ]; then
  for x in ${train_set} $cv_sets ; do
    steps/make_mfcc.sh --nj $nj --cmd "$train_cmd" \
      data/$dataset/$x exp/$dataset/make_mfcc/$x $mfccdir
    steps/compute_cmvn_stats.sh data/$dataset/$x exp/$dataset/make_mfcc/$x $mfccdir
    utils/fix_data_dir.sh data/$dataset/$x
  done
fi
################################################################################
# cleanup data
#############################################################################
# ToDo: should we use Kaldi's cleanup?
#steps/cleanup/clean_and_segment_data.sh --nj 64 --cmd run.pl \
#  --segmentation-opts "--min-segment-length 0.3 --min-new-segment-length 0.6" \
#  data/${dataset}/$train_set data/lang exp/$gmm exp/{dataset}/tri4b_cleaned \
#  data/${dataset}/${train_set}_cleaned

################################################################################
# Estimate ivectors
#############################################################################
# The iVector-extraction and feature-dumping parts are the same as the standard
# nnet3 setup, and you can skip them by setting "--stage 11" if you have already
# run those things.
local_sms/run_ivector_common.sh \
  --stage $stage --nj $nj \
  --test_sets $cv_sets \
  --train-set $train_set --dataset $dataset \
  --gmm $gmm --gmm_data_type $gmm_data_type \
  --ali_data_type $ali_data_type \
  --num-threads-ubm $num_threads_ubm \
  --nnet3-affix "$nnet3_affix" || exit 1;



gmm_dir=exp/$gmm_data_type/${gmm}
ali_dir=exp/$ali_data_type/${gmm}_ali_${train_set}_sp
lat_dir=exp/$dataset/chain${nnet3_affix}/${gmm}_${train_set}_sp_lats
dir=exp/$dataset/chain${nnet3_affix}/tdnn${affix}_sp
train_data_dir=data/$dataset/${train_set}_sp_hires
train_ivector_dir=exp/$dataset/nnet3${nnet3_affix}/ivectors_${train_set}_sp_hires
lores_train_data_dir=data/$ali_data_type/${train_set}_sp


################################################################################
# get tree, lats etc.
#############################################################################
# note: you don't necessarily have to change the treedir name
# each time you do a new experiment -- only if you change the
# configuration in a way that affects the tree.
tree_dir=exp/$dataset/chain${nnet3_affix}/tree_a_sp
# the 'lang' directory is created by this script.
# If you create such a directory with a non-standard topology
# you should probably name it differently.
lang=data/lang_chain

for f in $train_data_dir/feats.scp $train_ivector_dir/ivector_online.scp \
    $lores_train_data_dir/feats.scp $gmm_dir/final.mdl \
    $ali_dir/ali.1.gz $gmm_dir/final.mdl; do
  [ ! -f $f ] && echo "$0: expected file $f to exist" && exit 1
done


if [ $stage -le 14 ]; then
  echo "$0: creating lang directory $lang with chain-type topology"
  # Create a version of the lang/ directory that has one state per phone in the
  # topo file. [note, it really has two states.. the first one is only repeated
  # once, the second one has zero or more repeats.]
  if [ -d $lang ]; then
    if [ $lang/L.fst -nt data/lang/L.fst ]; then
      echo "$0: $lang already exists, not overwriting it; continuing"
    else
      echo "$0: $lang already exists and seems to be older than data/lang..."
      echo " ... not sure what to do. Exiting."
      exit 1;
    fi
  else
    cp -r data/lang $lang
    silphonelist=$(cat $lang/phones/silence.csl) || exit 1;
    nonsilphonelist=$(cat $lang/phones/nonsilence.csl) || exit 1;
    # Use our special topology... note that later on may have to tune this
    # topology.
    steps/nnet3/chain/gen_topo.py $nonsilphonelist $silphonelist >$lang/topo
  fi
fi

if [ $stage -le 15 ]; then
  # Get the alignments as lattices (gives the chain training more freedom).
  # use the same num-jobs as the alignments
  steps/align_fmllr_lats.sh --nj $nj --cmd "$train_cmd" ${lores_train_data_dir} \
    data/lang $gmm_dir $lat_dir
  rm $lat_dir/fsts.*.gz  # save space
fi

if [ $stage -le 16 ]; then
  # Build a tree using our new topology. We know we have alignments for the
  # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use
  # those. The num-leaves is always somewhat less than the num-leaves from
  # the GMM baseline.
  if [ -f $tree_dir/final.mdl ]; then
    echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it."
    exit 1;
  fi
  steps/nnet3/chain/build_tree.sh \
    --frame-subsampling-factor 3 \
    --context-opts "--context-width=2 --central-position=1" \
    --cmd "$train_cmd" 3500 ${lores_train_data_dir} \
    $lang $ali_dir $tree_dir
fi

--------------------------------------------------------------------------------
/scripts/run_ivector_common.sh:
--------------------------------------------------------------------------------
#!/bin/bash

set -e -o pipefail

# This is an adjusted copy of the common ivector estimation script provided
# by Kaldi. Most changes concern the expected file structure and the stages.

# This script is called from scripts like local_sms/get_nnet3_model.bash
# (and may eventually be called by more scripts). It
# contains the common feature preparation and iVector-related parts of the
# script. See those scripts for examples of usage.


stage=7
nj=30
train_set=train_si284   # you might set this to e.g. train.
test_sets="cv_dev93 test_eval92"
gmm=tri4b               # This specifies a GMM-dir from the features of the type you're training the system on;
gmm_data_type=wsj_8k
ali_data_type=sms_early
dataset=sms
num_threads_ubm=32
nnet3_affix=            # affix for exp/nnet3 directory to put iVector stuff in (e.g.
                        # in the tedlium recipe it's _cleaned).

. ./cmd.sh
. ./path.sh
. utils/parse_options.sh

gmm_dir=exp/$gmm_data_type/${gmm}
ali_dir=exp/$ali_data_type/${gmm}_ali_${train_set}_sp

for f in data/$dataset/${train_set}/feats.scp ${gmm_dir}/final.mdl; do
  if [ ! -f $f ]; then
    echo "$0: expected file $f to exist"
    exit 1
  fi
done



if [ $stage -le 7 ] && [ -f data/$dataset/${train_set}_sp_hires/feats.scp ]; then
  echo "$0: data/${train_set}_sp_hires/feats.scp already exists."
  echo " ... Please either remove it, or rerun this script with stage > 2."
  exit 1
fi


if [ $stage -le 7 ]; then
  echo "$0: preparing directory for speed-perturbed data"
  utils/data/perturb_data_dir_speed_3way.sh data/$dataset/${train_set} data/$dataset/${train_set}_sp
fi

if [ $stage -le 8 ]; then
  echo "$0: creating high-resolution MFCC features"


  mfccdir=data/$dataset/${train_set}_sp_hires/data

  for datadir in ${train_set}_sp ${test_sets}; do
    utils/copy_data_dir.sh data/$dataset/$datadir data/$dataset/${datadir}_hires
  done

  # do volume-perturbation on the training data prior to extracting hires
  # features; this helps make trained nnets more invariant to test data volume.
  utils/data/perturb_data_dir_volume.sh data/$dataset/${train_set}_sp_hires

  for datadir in ${train_set}_sp ${test_sets}; do
    steps/make_mfcc.sh --nj $nj --mfcc-config conf/mfcc_hires.conf \
      --cmd "$train_cmd" data/$dataset/${datadir}_hires
    steps/compute_cmvn_stats.sh data/$dataset/${datadir}_hires
    utils/fix_data_dir.sh data/$dataset/${datadir}_hires
  done
fi

if [ $stage -le 8 ]; then
  echo "$0: computing a subset of data to train the diagonal UBM."

  mkdir -p exp/$dataset/nnet3${nnet3_affix}/diag_ubm
  temp_data_root=exp/$dataset/nnet3${nnet3_affix}/diag_ubm

  # train a diagonal UBM using a subset of about a quarter of the data
  num_utts_total=$(wc -l

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
        'lazy_dataset>=1.3.0',
    ],

    # Installation problems in a clean, new environment:
    # 1. `cython` and `scipy` must be installed manually before using
    #    `pip install`
    # 2. `pyzmq` has to be installed manually, otherwise `pymatbridge` will
    #    complain

    # List additional groups of dependencies here (e.g. development
    # dependencies). You can install these using the following syntax,
    # for example:
    # $ pip install -e .[dev,test]
    extras_require={
        'all': [
            'pytest',
            'IPython',
            'matplotlib',
            'scipy',
            'pandas',
            'einops',
            'nara_wpe',
            'paderbox',
            'pb_bss @ git+https://github.com/fgnt/pb_bss',  # Used for the table that is shown in the publication
        ]
    },
)

--------------------------------------------------------------------------------
/sms_wsj/__init__.py:
--------------------------------------------------------------------------------
from pathlib import Path
git_root = Path(__file__).parent.parent.resolve().expanduser()

--------------------------------------------------------------------------------
/sms_wsj/database/__init__.py:
--------------------------------------------------------------------------------
from .database import SmsWsj, AudioReader
from .utils import scenario_map_fn

--------------------------------------------------------------------------------
/sms_wsj/database/create_intermediate_json.py:
--------------------------------------------------------------------------------
import os
import json
import functools
from collections import defaultdict
from copy import copy
from pathlib import Path
import warnings

import numpy as np
import sacred
import soundfile
from lazy_dataset.database import JsonDatabase

from sms_wsj.database.create_rirs import get_rng

ex = sacred.Experiment('Create intermediate SMS-WSJ json')

PUNCTUATION_SYMBOLS = set('''
&ERSAND
,COMMA
;SEMI-COLON
:COLON
!EXCLAMATION-POINT
...ELLIPSIS
-HYPHEN
.PERIOD
.DOT
?QUESTION-MARK

.DECIMAL
.PERCENT
/SLASH

'SINGLE-QUOTE
"DOUBLE-QUOTE
"QUOTE
"UNQUOTE
"END-OF-QUOTE
"END-QUOTE
"CLOSE-QUOTE
"IN-QUOTES

(PAREN
(PARENTHESES
(IN-PARENTHESIS
(BRACE
(LEFT-PAREN
(PARENTHETICALLY
(BEGIN-PARENS
)CLOSE-PAREN
)CLOSE_PAREN
)END-THE-PAREN
)END-OF-PAREN
)END-PARENS
)CLOSE-BRACE
)RIGHT-PAREN
)UN-PARENTHESES
)PAREN

{LEFT-BRACE
}RIGHT-BRACE
'''.split())


DEBUG_EXAMPLE_LIMIT = 10


def filter_punctuation_pronunciation(example):
    transcription = example['kaldi_transcription'].split()
    return len(PUNCTUATION_SYMBOLS.intersection(transcription)) == 0


def test_example_composition(a, b, speaker_ids):
    """

    Args:
        a: List of permutation example indices
        b: List of permutation example indices
        speaker_ids: Speaker id corresponding to an index

    Returns:

    >>> speaker_ids = np.array(['Alice', 'Bob', 'Carol', 'Carol'])
    >>> test_example_composition([0, 1, 2, 3], [2, 3, 1, 0], speaker_ids)
    >>> test_example_composition([0, 1, 2, 3], [0, 1, 2, 3], speaker_ids)
    Traceback (most recent call last):
    ...
    AssertionError: ('utterance duplicate', [0, 1, 2, 3], [0, 1, 2, 3])
    >>> test_example_composition([0, 1, 2, 3], [1, 0, 3, 2], speaker_ids)
    Traceback (most recent call last):
    ...
    AssertionError: ('speaker duplicate', 2)
    >>> test_example_composition([0, 1, 2, 3], [2, 3, 0, 1], speaker_ids)
    Traceback (most recent call last):
    ...
    AssertionError: ('duplicate pair', 2)

    """
    # Ensure that a speaker is not mixed with itself
    # This also ensures that an utterance is not mixed with itself
    assert np.all(speaker_ids[a] != speaker_ids[b]), ('speaker duplicate', len(a) - np.sum(speaker_ids[a] != speaker_ids[b]))

    # Ensure that any pair of utterances does not appear more than once
    tmp = [tuple(sorted(ab)) for ab in zip(a, b)]
    assert len(set(tuple(tmp))) == len(a), ('duplicate pair', len(a) - len(set(tuple(tmp))))


def extend_composition_example_greedy(rng, speaker_ids, example_compositions=None, tries=500):
    """

    Args:
        rng:
        speaker_ids: Speaker id corresponding to an index
        example_compositions:
        tries:

    Returns:

    >>> rng = np.random.RandomState(0)
    >>> speaker_ids = np.array(['Alice', 'Bob', 'Carol', 'Dave', 'Eve'])
    >>> comp = extend_composition_example_greedy(rng, speaker_ids)
    >>> comp
    array([[2],
           [0],
           [1],
           [3],
           [4]])
    >>> comp = extend_composition_example_greedy(rng, speaker_ids, comp)
    >>> comp
    array([[2, 3],
           [0, 4],
           [1, 2],
           [3, 0],
           [4, 1]])
    >>> comp = extend_composition_example_greedy(rng, speaker_ids, comp)
    >>> comp
    array([[2, 3, 1],
           [0, 4, 2],
           [1, 2, 3],
           [3, 0, 4],
           [4, 1, 0]])
    >>> speaker_ids[comp]
    array([['Carol', 'Dave', 'Bob'],
           ['Alice', 'Eve', 'Carol'],
           ['Bob', 'Carol', 'Dave'],
           ['Dave', 'Alice', 'Eve'],
           ['Eve', 'Bob', 'Alice']], dtype='<U5')
        assert excess_samples >= 0, excess_samples
        example["offset"].append(rng.randint(0, excess_samples + 1))

    example['audio_path']['original_source'] = [
        exa['audio_path']['observation'] for exa in source_examples
    ]
    # example['audio_path']['rir']: Already defined in rir_example.
    return example


def combine_rirs_and_sources(
        rir_dataset,
        source_dataset,
        num_speakers,
        dataset_name,
):
    # The keys of rir_dataset are integers. Sort the rirs based on this
    # integer.
    rir_dataset = rir_dataset.sort(sort_fn=functools.partial(sorted, key=int))

    assert len(rir_dataset) % len(source_dataset) == 0, (len(rir_dataset), len(source_dataset))
    repetitions = len(rir_dataset) // len(source_dataset)

    source_dataset = source_dataset.sort()
    source_dataset = list(source_dataset.tile(repetitions))

    speaker_ids = [example['speaker_id'] for example in source_dataset]

    rng = get_rng(dataset_name, 'example_compositions')

    composition_examples = None
    for _ in range(num_speakers):
        composition_examples = extend_composition_example_greedy(
            rng, speaker_ids, example_compositions=composition_examples,
        )

    ex_dict = dict()
    assert len(rir_dataset) == len(composition_examples), (len(rir_dataset), len(composition_examples))
    for rir_example, composition_example in zip(
            rir_dataset, composition_examples
    ):
        source_examples = [source_dataset[i] for i in composition_example]

        example = get_randomized_example(
            rir_example,
            source_examples,
            rng,
            dataset_name,
        )
        ex_dict[example['example_id']] = example

    return ex_dict


@ex.config
def config():
    rir_dir = None
    json_path = None
    wsj_json_path = None
    if rir_dir is None and 'RIR_DIR' in os.environ:
        rir_dir = os.environ['RIR_DIR']
    assert rir_dir is not None, 'You have to specify the rir dir'
    if wsj_json_path is None and 'WSJ_JSON' in os.environ:
        wsj_json_path = os.environ['WSJ_JSON']
    assert wsj_json_path is not None, 'You have to specify a wsj_json_path'
    if json_path is None and 'SMS_WSJ_JSON' in os.environ:
        json_path = os.environ['SMS_WSJ_JSON']
    assert json_path is not None, 'You have to specify a path for the new json'

    num_speakers = 2
    debug = False  # If `True`, only creates a few examples per dataset.


@ex.automain
def main(
        json_path: Path,
        rir_dir: Path,
        wsj_json_path: Path,
        num_speakers: int,
        debug: bool,
):
    wsj_json_path = Path(wsj_json_path).expanduser().resolve()
    json_path = Path(json_path).expanduser().resolve()
    rir_dir = Path(rir_dir).expanduser().resolve()
    assert wsj_json_path.is_file(), wsj_json_path
    assert rir_dir.exists(), rir_dir

    # ToDo: What was the motivation for defining this "setup"?
    setup = dict(
        train_si284=dict(source_dataset_name="train_si284"),
        cv_dev93=dict(source_dataset_name="cv_dev93"),
        test_eval92=dict(source_dataset_name="test_eval92"),
    )

    rir_db = JsonDatabase(rir_dir / "scenarios.json")

    source_db = JsonDatabase(wsj_json_path)

    target_db = dict()
    target_db['datasets'] = defaultdict(dict)

    for dataset_name in setup.keys():
        source_dataset_name = setup[dataset_name]["source_dataset_name"]
        source_dataset = source_db.get_dataset(source_dataset_name)
        print(f'length of source {dataset_name}: {len(source_dataset)}')
        source_dataset = source_dataset.filter(
            filter_fn=filter_punctuation_pronunciation, lazy=False
        )
        print(
            f'length of source {dataset_name}: {len(source_dataset)} '
            '(after punctuation filter)'
        )

        def add_rir_path(rir_ex):
            assert 'audio_path' not in rir_ex, rir_ex
            example_id = rir_ex['example_id']
            rir_ex['audio_path'] = {'rir': [
                str(rir_dir / dataset_name / example_id / f"h_{k}.wav")
                for k in range(num_speakers)
            ]}
            return rir_ex

        rir_dataset = rir_db.get_dataset(dataset_name).map(add_rir_path)

        assert len(rir_dataset) % len(source_dataset) == 0, (
            f'To avoid a bias towards certain utterances, the length of the '
            f'rir_dataset ({len(rir_dataset)}) should be an integer '
            f'multiple of the length of the source_dataset ({len(source_dataset)}).'
        )

        print(f'length of rir {dataset_name}: {len(rir_dataset)}')

        probe_path = rir_dir / dataset_name / "0"
        available_speaker_positions = len(list(probe_path.glob('h_*.wav')))
        assert num_speakers <= available_speaker_positions, (
            f'Requested {num_speakers} num_speakers, while found only '
            f'{available_speaker_positions} rirs in {probe_path}.'
        )

        info = soundfile.info(str(rir_dir / dataset_name / "0" / "h_0.wav"))
        sample_rate_rir = info.samplerate

        ex_wsj = source_dataset.random_choice(1)[0]
        info = soundfile.SoundFile(ex_wsj['audio_path']['observation'])
        sample_rate_wsj = info.samplerate
        assert sample_rate_rir == sample_rate_wsj, (
            sample_rate_rir, sample_rate_wsj
        )

        if debug:
            rir_dataset = rir_dataset[:DEBUG_EXAMPLE_LIMIT]
            # Use step_size to avoid that only one speaker is in
            # source_iterator.
            step_size = len(source_dataset) // DEBUG_EXAMPLE_LIMIT
            source_dataset = source_dataset[::step_size]

        ex_dict = combine_rirs_and_sources(
            rir_dataset=rir_dataset,
            source_dataset=source_dataset,
            num_speakers=num_speakers,
            dataset_name=dataset_name,
        )

        target_db['datasets'][dataset_name] = ex_dict

    json_path.parent.mkdir(exist_ok=True, parents=True)
    with json_path.open('w') as f:
        json.dump(target_db, f, indent=2, ensure_ascii=False)
    print(f'{json_path} written.')

--------------------------------------------------------------------------------
/sms_wsj/database/create_json_for_written_files.py:
--------------------------------------------------------------------------------
"""
This script writes a new json which includes the files
written to disk with sms_wsj.database.write_files.py

Additionally, the script allows updating the paths
in case of a change in the database location by using
the old sms_wsj.json as intermediate json.
8 | However, this script does not change the speaker 9 | and utterance combination, log weights, etc. which are 10 | specified in the intermediate json. 11 | 12 | """ 13 | 14 | from sms_wsj.database.write_files import check_files, KEY_MAPPER 15 | from sms_wsj.database.utils import _example_id_to_rng 16 | import json 17 | import sacred 18 | from pathlib import Path 19 | from lazy_dataset.database import JsonDatabase 20 | 21 | ex = sacred.Experiment('Write SMS-WSJ json after wav files are written') 22 | 23 | 24 | def create_json(db_dir, intermediate_json_path, write_all, snr_range=(20, 30)): 25 | db = JsonDatabase(intermediate_json_path) 26 | json_dict = dict(datasets=dict()) 27 | database_dict = db.data['datasets'] 28 | 29 | if write_all: 30 | key_mapper = KEY_MAPPER 31 | else: 32 | key_mapper = {'observation': 'observation'} 33 | 34 | for dataset_name, dataset in database_dict.items(): 35 | dataset_dict = dict() 36 | for ex_id, ex in dataset.items(): 37 | for key, data_type in key_mapper.items(): 38 | current_path = db_dir / data_type / dataset_name 39 | if key in ['observation', 'noise_image']: 40 | ex['audio_path'][key] = str(current_path / f'{ex_id}.wav') 41 | else: 42 | ex['audio_path'][key] = [ 43 | str(current_path / f'{ex_id}_{k}.wav') 44 | for k in range(len(ex['speaker_id'])) 45 | ] 46 | 47 | if 'original_source' not in ex['audio_path']: 48 | # legacy code 49 | ex['audio_path']['original_source'] = ex['audio_path']['speech_source'] 50 | 51 | ex['audio_path']['original_source'] = [ 52 | # .../sms_wsj/cache/wsj_8k_zeromean/13-11.1/wsj1/si_tr_s/4ax/4axc0218.wav 53 | str(db_dir.joinpath(*Path(rir).parts[-6:])) 54 | for rir in ex['audio_path']['original_source'] 55 | ] 56 | 57 | rng = _example_id_to_rng(ex_id) 58 | snr = rng.uniform(*snr_range) 59 | if 'dataset' in ex: 60 | del ex['dataset'] 61 | ex["snr"] = snr 62 | dataset_dict[ex_id] = ex 63 | json_dict['datasets'][dataset_name] = dataset_dict 64 | return json_dict 65 | 66 | 67 | @ex.config 68 | def config(): 69 | db_dir = None 70 | intermed_json_path = None 71 | 72 | # If `False`, expects only observation to exist, 73 | # else expect all intermediate signals. 74 | write_all = True 75 | 76 | # Default behavior is to overwrite an existing `sms_wsj.json`. You may 77 | # specify a different path here to change where the JSON is written to. 
78 |     json_path = None
79 | 
80 |     snr_range = (20, 30)
81 | 
82 |     assert db_dir is not None, 'You have to specify a database dir'
83 |     assert intermed_json_path is not None, 'You have to specify a path' \
84 |                                            ' to the original sms_wsj.json'
85 | 
86 |     debug = False
87 | 
88 | 
89 | @ex.automain
90 | def main(db_dir, intermed_json_path, write_all, json_path, snr_range):
91 |     intermed_json_path = Path(intermed_json_path).expanduser().resolve()
92 |     db_dir = Path(db_dir).expanduser().resolve()
93 |     if json_path is not None:
94 |         json_path = Path(json_path).expanduser().resolve()
95 |     else:
96 |         json_path = intermed_json_path
97 |     print(f'Creating a new json and saving it to {json_path}')
98 |     num_wav_files = len(check_files(db_dir))
99 |     message = f'Not all wav files seem to exist, you have {num_wav_files},' \
100 |               f' please check your db directory: {db_dir}'
101 |     if write_all:
102 |         assert num_wav_files in [(3 * speakers + 2) * 35875 for speakers in [2, 3, 4]], message  # 3 * speakers + 2 files per example, see write_files
103 |     else:
104 |         assert num_wav_files == 35875, message
105 |     updated_json = create_json(db_dir, intermed_json_path, write_all,
106 |                                snr_range=snr_range)
107 |     json_path.parent.mkdir(exist_ok=True, parents=True)
108 |     with json_path.open('w') as f:
109 |         json.dump(updated_json, f, indent=4, ensure_ascii=False)
110 |     print(f'{json_path} written.')
111 | 
-------------------------------------------------------------------------------- /sms_wsj/database/create_rirs.py: --------------------------------------------------------------------------------
1 | """Call instructions:
2 | 
3 | # When you do not have MPI:
4 | python -m sms_wsj.database.create_rirs with database_path=/Users/lukas/Downloads/temp_wsj_bss debug=True
5 | python -m sms_wsj.database.create_rirs with database_path=temp_wsj_bss debug=True
6 | 
7 | # When you have MPI:
8 | mpiexec -np 3 python -m sms_wsj.database.create_rirs with database_path=temp_wsj_bss debug=True
9 | 
10 | """
11 | import hashlib
12 | import json
13 | from collections import defaultdict
14 | from pathlib import Path
15 | 
16 | import numpy as np
17 | import soundfile
18 | from sacred import Experiment
19 | from sms_wsj.reverb.reverb_utils import generate_rir
20 | from sms_wsj.reverb.scenario import generate_random_source_positions
21 | from sms_wsj.reverb.scenario import generate_sensor_positions
22 | from sms_wsj.reverb.scenario import sample_from_random_box
23 | 
24 | import dlp_mpi
25 | 
26 | experiment = Experiment(Path(__file__).stem)
27 | 
28 | 
29 | @experiment.config
30 | def config():
31 |     debug = False
32 |     database_path = ""
33 | 
34 |     # Either set it to zero or above 0.15 s. Otherwise, the RIR contains NaN.
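    # (Note: the range below samples T60 uniformly from [0.2, 0.5] s, which
    # stays safely above the 0.15 s threshold mentioned above.)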
35 | sound_decay_time_range = dict(low=0.2, high=0.5) 36 | 37 | geometry = dict( 38 | number_of_sources=4, 39 | number_of_sensors=6, 40 | sensor_shape="circular", 41 | center=[[4.], [3.], [1.5]], # m 42 | scale=0.1, # m 43 | room=[[8], [6], [3]], # m 44 | random_box=[[0.4], [0.4], [0.4]], # m 45 | ) 46 | 47 | datasets = dict( 48 | train_si284=dict( 49 | count=33561, # 33561 unique non-pp utterances 50 | ), 51 | cv_dev93=dict( 52 | count=491 * 2, # 491 unique non-pp utterances 53 | ), 54 | test_eval92=dict( 55 | count=333 * 4, # 333 unique non-pp utterances 56 | ), 57 | ) 58 | 59 | sample_rate = 8000 60 | filter_length = 2 ** 13 # 1.024 seconds when sample_rate == 8000 61 | 62 | 63 | def get_rng(dataset, example_id): 64 | string = f"{dataset}_{example_id}" 65 | seed = ( 66 | int(hashlib.sha256(string.encode("utf-8")).hexdigest(), 67 | 16) % 2 ** 32 68 | ) 69 | return np.random.RandomState(seed=seed) 70 | 71 | 72 | @experiment.command 73 | def scenarios( 74 | database_path, 75 | datasets, 76 | geometry, 77 | sound_decay_time_range, 78 | debug, 79 | ): 80 | if not dlp_mpi.IS_MASTER: 81 | # It is enough, when one process generates the scenarios.json 82 | return 83 | 84 | assert len(database_path) > 0, "Database path can not be empty." 85 | database_path = Path(database_path).expanduser().resolve() 86 | scenario_json = Path(database_path) / "scenarios.json" 87 | 88 | print(f'from: random') 89 | print(f'to: {database_path}') 90 | 91 | database = defaultdict(lambda: defaultdict(dict)) 92 | for dataset, dataset_config in datasets.items(): 93 | for example_id in range(dataset_config["count"]): 94 | if debug and example_id >= 2: 95 | break 96 | 97 | example_id = str(example_id) 98 | rng = get_rng(dataset, example_id) 99 | room_dimensions = sample_from_random_box( 100 | geometry["room"], geometry["random_box"], rng=rng 101 | ) 102 | center = sample_from_random_box( 103 | geometry["center"], geometry["random_box"], rng=rng 104 | ) 105 | source_positions = generate_random_source_positions( 106 | center=center, 107 | sources=geometry["number_of_sources"], 108 | rng=rng, 109 | ) 110 | sensor_positions = generate_sensor_positions( 111 | shape=geometry["sensor_shape"], 112 | center=center, 113 | scale=geometry["scale"], 114 | number_of_sensors=geometry["number_of_sensors"], 115 | rotate_x=rng.uniform(0, 0.01 * 2 * np.pi), 116 | rotate_y=rng.uniform(0, 0.01 * 2 * np.pi), 117 | rotate_z=rng.uniform(0, 2 * np.pi), 118 | ) 119 | sound_decay_time = rng.uniform(**sound_decay_time_range) 120 | database['datasets'][dataset][example_id] = { 121 | 'room_dimensions': room_dimensions, 122 | 'sound_decay_time': sound_decay_time, 123 | 'source_position': source_positions, 124 | 'sensor_position': sensor_positions, 125 | } 126 | # Use round to make it easier to read the numbers. 127 | # The rounding at this position is allowed, because all values are 128 | # independent of each other. 129 | # This introduces a small jitter on all positions. 
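            # (Clarification: decimals=3 with positions in meters means
            # millimeter resolution, so the jitter from rounding is at most
            # 0.5 mm per coordinate.)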
130 | database['datasets'][dataset][example_id] = { 131 | k: np.round(v, decimals=3) 132 | for k, v in database['datasets'][dataset][example_id].items() 133 | } 134 | database['datasets'][dataset][example_id].update({ 135 | k: v.tolist() 136 | for k, v in database['datasets'][dataset][example_id].items() 137 | if isinstance(v, np.ndarray) 138 | }) 139 | database['datasets'][dataset][example_id]['example_id'] = \ 140 | example_id 141 | 142 | scenario_json.parent.mkdir(exist_ok=False, parents=True) 143 | with scenario_json.open('w') as f: 144 | json.dump(database, f, indent=2, ensure_ascii=False) 145 | 146 | 147 | @experiment.command 148 | def rirs( 149 | database_path, 150 | datasets, 151 | sample_rate, 152 | filter_length, 153 | ): 154 | database_path = Path(database_path) 155 | 156 | if dlp_mpi.IS_MASTER: 157 | scenario_json = database_path / "scenarios.json" 158 | with scenario_json.open() as f: 159 | database = json.load(f) 160 | for dataset in datasets: 161 | dataset_path = database_path / dataset 162 | dataset_path.mkdir(parents=True, exist_ok=True) 163 | else: 164 | database = None 165 | database = dlp_mpi.bcast(database) 166 | 167 | for dataset_name, dataset in database['datasets'].items(): 168 | 169 | for _example_id, example in dlp_mpi.split_managed( 170 | list(sorted(dataset.items())), 171 | progress_bar=True, 172 | is_indexable=True, 173 | ): 174 | h = generate_rir( 175 | room_dimensions=example['room_dimensions'], 176 | source_positions=example['source_position'], 177 | sensor_positions=example['sensor_position'], 178 | sound_decay_time=example['sound_decay_time'], 179 | sample_rate=sample_rate, 180 | filter_length=filter_length, 181 | sensor_orientations=None, 182 | sensor_directivity=None, 183 | sound_velocity=343 184 | ) 185 | assert not np.any( 186 | np.isnan(h) 187 | ), f"{np.sum(np.isnan(h))} values of {h.size} are NaN." 188 | 189 | K, D, T = h.shape 190 | directory = database_path / dataset_name / _example_id 191 | directory.mkdir(parents=False, exist_ok=False) 192 | 193 | for k in range(K): 194 | # Although storing as np.float64 does not allow every reader 195 | # to access the files, it does not require normalization and 196 | # we are unsure how much precision is needed for RIRs. 197 | with soundfile.SoundFile( 198 | str(directory / f"h_{k}.wav"), subtype='DOUBLE', 199 | samplerate=sample_rate, mode='w', channels=h.shape[1] 200 | ) as f: 201 | f.write(h[k, :, :].T) 202 | 203 | dlp_mpi.barrier() 204 | 205 | print(f'RANK={dlp_mpi.RANK}, SIZE={dlp_mpi.SIZE}:' 206 | f' Finished {dataset_name}.') 207 | 208 | 209 | @experiment.automain 210 | def main(): 211 | scenarios() 212 | rirs() 213 | -------------------------------------------------------------------------------- /sms_wsj/database/database.py: -------------------------------------------------------------------------------- 1 | import os 2 | import dataclasses 3 | 4 | import numpy as np 5 | 6 | import lazy_dataset.database 7 | 8 | 9 | class SmsWsj(lazy_dataset.database.JsonDatabase): 10 | """ 11 | >>> from pprint import pprint 12 | >>> db = SmsWsj() 13 | >>> db.get_dataset() 14 | Traceback (most recent call last): 15 | ... 
16 |     TypeError: Missing dataset_name, use e.g.: ('train_si284', 'cv_dev93', 'test_eval92')
17 |     >>> db.get_dataset('train_si284')
18 |     DictDataset(name='train_si284', len=33561)
19 |     MapDataset(_pickle.loads)
20 |     >>> db.get_dataset('cv_dev93')
21 |     DictDataset(name='cv_dev93', len=982)
22 |     MapDataset(_pickle.loads)
23 |     >>> db.get_dataset('test_eval92')
24 |     DictDataset(name='test_eval92', len=1332)
25 |     MapDataset(_pickle.loads)
26 |     >>> db.get_dataset(['train_si284', 'cv_dev93', 'test_eval92'])
27 |     DictDataset(name='train_si284', len=33561)
28 |     MapDataset(_pickle.loads)
29 |     DictDataset(name='cv_dev93', len=982)
30 |     MapDataset(_pickle.loads)
31 |     DictDataset(name='test_eval92', len=1332)
32 |     MapDataset(_pickle.loads)
33 |     ConcatenateDataset()
34 |     >>> ds = db.get_dataset('cv_dev93')
35 |     >>> pprint(ds[0], width=79-4)  # doctest: +ELLIPSIS
36 |     {'audio_path': {'noise_image': ...,
37 |                     'observation': ...,
38 |                     'rir': [...,...],
39 |                     'speech_reverberation_early': [...,...],
40 |                     'speech_reverberation_tail': [...,...],
41 |                     'speech_source': [...,...]},
42 |      'dataset': 'cv_dev93',
43 |      'example_id': '0_4k6c0303_4k4c0319',
44 |      'gender': ['male', 'female'],
45 |      'kaldi_transcription': [...,...],
46 |      'log_weights': [1.2027951449295022, -1.2027951449295022],
47 |      'num_samples': {'observation': 93389, 'original_source': [31633, 93389]},
48 |      'num_speakers': 2,
49 |      'offset': [52476, 0],
50 |      'room_dimensions': [[8.169], [5.905], [3.073]],
51 |      'sensor_position': [[4.015, 3.973, 4.03, 4.129, 4.172, 4.115],
52 |                          [3.265, 3.175, 3.093, 3.102, 3.192, 3.274],
53 |                          [1.55, 1.556, 1.563, 1.563, 1.558, 1.551]],
54 |      'snr': 23.287502642941252,
55 |      'sound_decay_time': 0.387,
56 |      'source_id': ['4k6c0303', '4k4c0319'],
57 |      'source_position': [[3.312, 3.0], [1.921, 2.379], [1.557, 1.557]],
58 |      'speaker_id': ['4k6', '4k4']}
59 |     """
60 | 
61 |     @classmethod
62 |     def default_json_path(cls):
63 |         try:
64 |             return os.environ['SMS_WSJ_JSON']
65 |         except KeyError as e:
66 |             name = cls.__name__
67 |             raise ValueError(
68 |                 f'To instantiate the {name} database,\n'
69 |                 f'you have to provide the path to the json that\n'
70 |                 f'describes the database.\n'
71 |                 f'This can be done with\n'
72 |                 f'\t>>> `{name}(json_path)`\n'
73 |                 f'or setting the environment variable\n'
74 |                 f'\t$ export SMS_WSJ_JSON=\n'
75 |                 f'and drop the argument in python\n'
76 |                 f'\t>>> `{name}()`'
77 |             ) from e
78 | 
79 |     def __init__(self, json_path=None):
80 |         if json_path is None:
81 |             json_path = self.default_json_path()
82 | 
83 |         super().__init__(json_path)
84 | 
85 | 
86 | class AudioReader:
87 |     """
88 |     Reads the audio data of an example.
89 |     The paths are in `example['audio_path']` and will be written to
90 |     `example['audio_data']`.
91 | This reader is usually used as a mapping in a dataset: 92 | 93 | >>> from IPython.lib.pretty import pprint 94 | >>> np.set_string_function(lambda a: f'array(shape={a.shape}, dtype={a.dtype})') 95 | 96 | >>> db = SmsWsj() 97 | >>> ds = db.get_dataset('cv_dev93') 98 | >>> ds = ds.map(AudioReader()) 99 | >>> example = ds[0] 100 | >>> pprint(example['audio_data']) 101 | {'observation': array(shape=(6, 103650), dtype=float64), 102 | 'speech_source': array(shape=(2, 103650), dtype=float64), 103 | 'speech_reverberation_early': array(shape=(2, 6, 103650), dtype=float64), 104 | 'speech_reverberation_tail': array(shape=(2, 6, 103650), dtype=float64), 105 | 'speech_image': array(shape=(2, 6, 103650), dtype=float64), 106 | 'noise_image': array(shape=(6, 103650), dtype=float64)} 107 | """ 108 | all_keys = ( 109 | 'observation', 110 | 'speech_source', 111 | 'original_source', 112 | 'speech_reverberation_early', 113 | 'speech_reverberation_tail', 114 | 'speech_image', 115 | 'noise_image', 116 | 'rir', 117 | ) 118 | 119 | def __init__( 120 | self, 121 | keys=( 122 | 'observation', 123 | 'speech_source', 124 | 'original_source', 125 | 'speech_reverberation_early', 126 | 'speech_reverberation_tail', 127 | 'speech_image', 128 | 'noise_image', 129 | # 'rir', 130 | ), 131 | sync_speech_source: bool = True, # legacy 132 | ): 133 | keys = list(keys) 134 | 135 | if 'speech_source' in keys: 136 | if 'original_source' not in keys: 137 | keys.append('original_source') 138 | keys.remove('speech_source') 139 | self.speech_source = True 140 | else: 141 | self.speech_source = False 142 | 143 | if 'speech_image' in keys: 144 | if 'speech_reverberation_early' not in keys: 145 | keys.append('speech_reverberation_early') 146 | if 'speech_reverberation_tail' not in keys: 147 | keys.append('speech_reverberation_tail') 148 | self.speech_image = True 149 | keys.remove('speech_image') 150 | else: 151 | self.speech_image = False 152 | 153 | self.keys = tuple(keys) 154 | self.sync_speech_source = sync_speech_source 155 | 156 | @classmethod 157 | def _rec_audio_read(cls, file): 158 | import soundfile 159 | 160 | if isinstance(file, (tuple, list)): 161 | return np.array([cls._rec_audio_read(f) for f in file]) 162 | elif isinstance(file, (dict)): 163 | return {k: cls._rec_audio_read(v) for k, v in file.items()} 164 | else: 165 | data, sample_rate = soundfile.read(file) 166 | return data.T 167 | 168 | def __call__(self, example): 169 | data = {} 170 | path = example['audio_path'] 171 | 172 | for k in self.keys: 173 | if k == 'original_source' and k not in path: 174 | # legacy code 175 | path[k] = path['speech_source'] 176 | data[k] = self._rec_audio_read(path[k]) 177 | 178 | if self.speech_source: 179 | if self.sync_speech_source: 180 | from sms_wsj.database.utils import synchronize_speech_source 181 | data['speech_source'] = synchronize_speech_source( 182 | data['original_source'], 183 | example['offset'], 184 | T=example['num_samples']['observation'], 185 | ) 186 | else: 187 | # legacy code 188 | data['speech_source'] = data['original_source'] 189 | 190 | if self.speech_image: 191 | data['speech_image'] = ( 192 | data['speech_reverberation_early'] 193 | + data['speech_reverberation_tail'] 194 | ) 195 | 196 | example['audio_data'] = data 197 | return example 198 | -------------------------------------------------------------------------------- /sms_wsj/database/dynamic_mixing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import lazy_dataset 4 | from 
sms_wsj.database.create_intermediate_json import combine_rirs_and_sources 5 | 6 | 7 | def filter_duplicates(l): 8 | """ 9 | >>> filter_duplicates([{'a': 1}, {'b': 1}, {'a': 1}]) 10 | [{'a': 1}, {'b': 1}] 11 | """ 12 | def make_hashable(o): 13 | try: 14 | hash(o) 15 | return o 16 | except TypeError: 17 | return helper[type(o)](o) 18 | 19 | helper = { 20 | set: lambda o: tuple([make_hashable(e) for e in o]), 21 | tuple: lambda o: tuple([make_hashable(e) for e in o]), 22 | list: lambda o: tuple([make_hashable(e) for e in o]), 23 | dict: lambda o: frozenset( 24 | [(make_hashable(k), make_hashable(v)) for k, v in 25 | o.items()]), 26 | } 27 | 28 | l = list(l) 29 | 30 | return list({ 31 | hashable: entry 32 | for hashable, entry in zip(make_hashable(l), l) 33 | }.values()) 34 | 35 | 36 | def split_rirs_and_sources( 37 | ds 38 | ): 39 | """ 40 | Split a dataset in the rir and source dataset. 41 | These datasets can be used to recreate a dataset with combine_rirs_and_sources. 42 | The `dataset_name` argument can be used to either get the same dataset 43 | or a random new dataset with different utterance pairs. 44 | 45 | db = SmsWsj(...) 46 | ds = db.get_dataset('train_si284') 47 | rir_ds, source_ds = split_rirs_and_sources(ds) 48 | ds_new = lazy_dataset.new(combine_rirs_and_sources(rir_ds, source_ds, 2, f'train_si284_rng{np.random.randint(0, 2**32)}')) 49 | 50 | Note: The new dataset have to use the `scenario_map_fn`, because the 51 | observation and the intermediate signal have to be calculated on 52 | demand. 53 | 54 | >>> import os, re 55 | >>> from paderbox.utils.pretty import pprint 56 | >>> from pprint import pprint, pformat 57 | >>> from sms_wsj.database.database import SmsWsj 58 | >>> def print_ex(ex): 59 | ... print(re.sub(r"'[^']+(?=/(?:wsj_8k_zeromean|speech_source|early|tail|rirs|noise|observation)/[^']+.wav')", r"'...", pformat(ex))) 60 | >>> db = SmsWsj(os.environ.get('NT_DATABASE_JSONS_DIR') + '/sms_wsj.json') 61 | >>> ds = db.get_dataset('cv_dev93') 62 | >>> print_ex(ds[0]) 63 | {'audio_path': {'noise_image': '.../noise/cv_dev93/0_4k6c0303_4k4c0319.wav', 64 | 'observation': '.../observation/cv_dev93/0_4k6c0303_4k4c0319.wav', 65 | 'original_source': ['.../wsj_8k_zeromean/13-16.1/wsj1/si_dt_20/4k6/4k6c0303.wav', 66 | '.../wsj_8k_zeromean/13-16.1/wsj1/si_dt_20/4k4/4k4c0319.wav'], 67 | 'rir': ['.../rirs/cv_dev93/0/h_0.wav', 68 | '.../rirs/cv_dev93/0/h_1.wav'], 69 | 'speech_reverberation_early': ['.../early/cv_dev93/0_4k6c0303_4k4c0319_0.wav', 70 | '.../early/cv_dev93/0_4k6c0303_4k4c0319_1.wav'], 71 | 'speech_reverberation_tail': ['.../tail/cv_dev93/0_4k6c0303_4k4c0319_0.wav', 72 | '.../tail/cv_dev93/0_4k6c0303_4k4c0319_1.wav'], 73 | 'speech_source': ['.../speech_source/cv_dev93/0_4k6c0303_4k4c0319_0.wav', 74 | '.../speech_source/cv_dev93/0_4k6c0303_4k4c0319_1.wav']}, 75 | 'dataset': 'cv_dev93', 76 | 'example_id': '0_4k6c0303_4k4c0319', 77 | 'gender': ['male', 'female'], 78 | 'kaldi_transcription': ['IN ADDITION TO DEFORESTATION EXAMPLES ARE', 79 | 'THE PROFIT HAS BEEN PLOWED BACK INTO THE BANK WHICH ' 80 | 'HAS PURSUED ITS MISSION TO REBUILD A DECAYING ' 81 | 'NEIGHBORHOOD WITH A SINGULAR FOCUS'], 82 | 'log_weights': [0.9885484337248203, -0.9885484337248203], 83 | 'num_samples': {'observation': 93389, 'original_source': [31633, 93389]}, 84 | 'num_speakers': 2, 85 | 'offset': [52476, 0], 86 | 'room_dimensions': [[8.169], [5.905], [3.073]], 87 | 'sensor_position': [[4.015, 3.973, 4.03, 4.129, 4.172, 4.115], 88 | [3.265, 3.175, 3.093, 3.102, 3.192, 3.274], 89 | [1.55, 1.556, 1.563, 1.563, 
1.558, 1.551]], 90 | 'snr': 23.287502642941252, 91 | 'sound_decay_time': 0.387, 92 | 'source_id': ['4k6c0303', '4k4c0319'], 93 | 'source_position': [[3.312, 3.0], [1.921, 2.379], [1.557, 1.557]], 94 | 'speaker_id': ['4k6', '4k4']} 95 | >>> rir_ds, source_ds = split_rirs_and_sources(ds) 96 | >>> print_ex(rir_ds[0]) 97 | {'audio_path': {'rir': ['.../rirs/cv_dev93/0/h_0.wav', 98 | '.../rirs/cv_dev93/0/h_1.wav']}, 99 | 'example_id': '0', 100 | 'room_dimensions': [[8.169], [5.905], [3.073]], 101 | 'sensor_position': [[4.015, 3.973, 4.03, 4.129, 4.172, 4.115], 102 | [3.265, 3.175, 3.093, 3.102, 3.192, 3.274], 103 | [1.55, 1.556, 1.563, 1.563, 1.558, 1.551]], 104 | 'sound_decay_time': 0.387, 105 | 'source_position': [[3.312, 3.0], [1.921, 2.379], [1.557, 1.557]]} 106 | >>> print_ex(source_ds[0]) 107 | {'audio_path': {'observation': '.../wsj_8k_zeromean/13-16.1/wsj1/si_dt_20/4k6/4k6c0303.wav'}, 108 | 'example_id': '4k6c0303', 109 | 'gender': 'male', 110 | 'kaldi_transcription': 'IN ADDITION TO DEFORESTATION EXAMPLES ARE', 111 | 'num_samples': 31633, 112 | 'speaker_id': '4k6'} 113 | >>> ds_new = lazy_dataset.new(combine_rirs_and_sources(rir_ds, source_ds, 2, 'cv_dev93')) 114 | 115 | >>> print_ex(ds_new[0]) 116 | {'audio_path': {'original_source': ['.../wsj_8k_zeromean/13-16.1/wsj1/si_dt_20/4k6/4k6c0303.wav', 117 | '.../wsj_8k_zeromean/13-16.1/wsj1/si_dt_20/4k4/4k4c0319.wav'], 118 | 'rir': ['.../rirs/cv_dev93/0/h_0.wav', 119 | '.../rirs/cv_dev93/0/h_1.wav']}, 120 | 'dataset': 'cv_dev93', 121 | 'example_id': '0_4k6c0303_4k4c0319', 122 | 'gender': ['male', 'female'], 123 | 'kaldi_transcription': ['IN ADDITION TO DEFORESTATION EXAMPLES ARE', 124 | 'THE PROFIT HAS BEEN PLOWED BACK INTO THE BANK WHICH ' 125 | 'HAS PURSUED ITS MISSION TO REBUILD A DECAYING ' 126 | 'NEIGHBORHOOD WITH A SINGULAR FOCUS'], 127 | 'log_weights': [0.9885484337248203, -0.9885484337248203], 128 | 'num_samples': {'observation': 93389, 'original_source': [31633, 93389]}, 129 | 'num_speakers': 2, 130 | 'offset': [52476, 0], 131 | 'room_dimensions': [[8.169], [5.905], [3.073]], 132 | 'sensor_position': [[4.015, 3.973, 4.03, 4.129, 4.172, 4.115], 133 | [3.265, 3.175, 3.093, 3.102, 3.192, 3.274], 134 | [1.55, 1.556, 1.563, 1.563, 1.558, 1.551]], 135 | 'sound_decay_time': 0.387, 136 | 'source_id': ['4k6c0303', '4k4c0319'], 137 | 'source_position': [[3.312, 3.0], [1.921, 2.379], [1.557, 1.557]], 138 | 'speaker_id': ['4k6', '4k4']} 139 | >>> ex_new, ex = ds_new[0], ds[0] 140 | >>> del ex['snr'] 141 | >>> for k in ['speech_reverberation_early', 'speech_source', 'noise_image', 142 | ... 'observation', 'speech_reverberation_tail']: 143 | ... del ex['audio_path'][k] 144 | >>> set.symmetric_difference(set(ex_new.keys()), set(ex.keys())) 145 | set() 146 | >>> assert ex_new == ex 147 | >>> from paderbox.utils.nested import flatten 148 | >>> ex_new, ex = flatten(ex_new), flatten(ex) 149 | >>> set.symmetric_difference(set(ex_new.keys()), set(ex.keys())) 150 | set() 151 | >>> for k in sorted(set(ex_new.keys()) | set(ex.keys())): 152 | ... 
assert ex_new[k] == ex[k], (k, ex_new[k], ex[k]) 153 | 154 | >>> ds_new = lazy_dataset.new(combine_rirs_and_sources(rir_ds, source_ds, 2, 'cv_dev93_rng1')) 155 | >>> print_ex(ds_new[0]) 156 | {'audio_path': {'original_source': ['.../wsj_8k_zeromean/13-16.1/wsj1/si_dt_20/4k2/4k2c031a.wav', 157 | '.../wsj_8k_zeromean/13-16.1/wsj1/si_dt_20/4k3/4k3c0316.wav'], 158 | 'rir': ['.../rirs/cv_dev93/0/h_0.wav', 159 | '.../rirs/cv_dev93/0/h_1.wav']}, 160 | 'dataset': 'cv_dev93_rng1', 161 | 'example_id': '0_4k2c031a_4k3c0316', 162 | 'gender': ['female', 'female'], 163 | 'kaldi_transcription': ['THE S. E. C. IS MANEUVERING TO E- CURB WHAT FUNDS ' 164 | 'CAN SAY IN NEWSLETTERS JUST AS HOLDERS ARE DEMANDING ' 165 | 'MORE INFORMATION', 166 | 'IN MISSOURI STATE PARTY LEADERS ARE ACTIVELY ' 167 | 'COURTING DEMOCRATS WHO VOTED FOR MR. REAGAN'], 168 | 'log_weights': [1.7891449809156743, -1.7891449809156748], 169 | 'num_samples': {'observation': 59008, 'original_source': [58326, 59008]}, 170 | 'num_speakers': 2, 171 | 'offset': [659, 0], 172 | 'room_dimensions': [[8.169], [5.905], [3.073]], 173 | 'sensor_position': [[4.015, 3.973, 4.03, 4.129, 4.172, 4.115], 174 | [3.265, 3.175, 3.093, 3.102, 3.192, 3.274], 175 | [1.55, 1.556, 1.563, 1.563, 1.558, 1.551]], 176 | 'sound_decay_time': 0.387, 177 | 'source_id': ['4k2c031a', '4k3c0316'], 178 | 'source_position': [[3.312, 3.0], [1.921, 2.379], [1.557, 1.557]], 179 | 'speaker_id': ['4k2', '4k3']} 180 | """ 181 | def get_sources(ex): 182 | num_speakers = len(ex['speaker_id']) 183 | 184 | for spk_idx in range(num_speakers): 185 | example_id = ex['source_id'][spk_idx] 186 | yield ( 187 | example_id, 188 | { 189 | 'audio_path': { 190 | 'observation': (ex['audio_path'].get('original_source') or 191 | ex['audio_path']['speech_source'])[ 192 | spk_idx], 193 | }, 194 | 'example_id': example_id, 195 | **{ 196 | k: ex[k][spk_idx] 197 | for k in ['gender', 'kaldi_transcription', 'speaker_id'] 198 | }, 199 | 'num_samples': (ex['num_samples'].get('original_source') or 200 | ex['num_samples']['speech_source'])[spk_idx], 201 | }, 202 | ) 203 | 204 | source_ds = lazy_dataset.new(dict(list(ds.map(get_sources).unbatch()))) 205 | 206 | def get_rir_ex(ex): 207 | example_id = ex['example_id'].split('_')[0] 208 | return ( 209 | example_id, 210 | { 211 | 'audio_path': {'rir': ex['audio_path']['rir']}, 212 | 'example_id': example_id, 213 | **{ 214 | k: ex[k] 215 | for k in 216 | ['room_dimensions', 'sound_decay_time', 'source_position', 217 | 'sensor_position'] 218 | }, 219 | } 220 | ) 221 | 222 | rir_ds = lazy_dataset.new(dict(list(ds.map(get_rir_ex)))) 223 | assert len(rir_ds) == len(ds), (len(rir_ds), len(ds)) 224 | 225 | return rir_ds, source_ds 226 | 227 | 228 | class SMSWSJRandomDataset(lazy_dataset.Dataset): 229 | """ 230 | >>> import os, re 231 | >>> from paderbox.utils.pretty import pprint 232 | >>> from paderbox.utils.timer import Timer 233 | >>> from pprint import pprint, pformat 234 | >>> from sms_wsj.database.database import SmsWsj 235 | >>> def print_ex(ex): 236 | ... print(re.sub(r"'[^']+(?=/(?:wsj_8k_zeromean|early|tail|rirs|noise|observation)/[^']+.wav')", r"'...", pformat(ex))) 237 | >>> db = SmsWsj(os.environ.get('NT_DATABASE_JSONS_DIR') + '/sms_wsj.json') 238 | >>> ds = db.get_dataset('train_si284') 239 | >>> ds = SMSWSJRandomDataset(ds) 240 | >>> with Timer() as t: 241 | ... 
ds = ds.copy(freeze=True) 242 | >>> print(t) 243 | : 16.2 s 244 | >>> print_ex(ds[0]) 245 | {'audio_path': {'original_source': ['.../wsj_8k_zeromean/13-6.1/wsj1/si_tr_s/498/498c040z.wav', 246 | '.../wsj_8k_zeromean/13-3.1/wsj1/si_tr_s/478/478c040q.wav'], 247 | 'rir': ['.../rirs/train_si284/0/h_0.wav', 248 | '.../rirs/train_si284/0/h_1.wav']}, 249 | 'dataset': 'cv_dev93_rng4193246114', 250 | 'example_id': '0_498c040z_478c040q', 251 | 'gender': ['female', 'female'], 252 | 'kaldi_transcription': ['THUS LESLEY ASKS CAN THERE BE ANY DOUBT THAT JESUS ' 253 | 'IS ALSO ON THE SIDE OF THE A. N. C.', 254 | " AFTER ALL HE ISN'T THE ONE WHO HAS TO RISK " 255 | 'GETTING HIT OVER THE HEAD WITH A METAL PIPE'], 256 | 'log_weights': [-0.5361178511234126, 0.5361178511234126], 257 | 'num_samples': {'observation': 52086, 'original_source': [52086, 42571]}, 258 | 'num_speakers': 2, 259 | 'offset': [0, 2547], 260 | 'room_dimensions': [[7.875], [5.839], [3.088]], 261 | 'sensor_position': [[3.974, 3.923, 3.823, 3.774, 3.825, 3.925], 262 | [2.979, 3.065, 3.063, 2.976, 2.89, 2.891], 263 | [1.418, 1.421, 1.426, 1.427, 1.424, 1.42]], 264 | 'sound_decay_time': 0.413, 265 | 'source_id': ['498c040z', '478c040q'], 266 | 'source_position': [[3.81, 5.333], [1.919, 3.777], [1.423, 1.423]], 267 | 'speaker_id': ['498', '478']} 268 | 269 | 270 | """ 271 | def __init__(self, dataset, num_speakers=2, rng=np.random): 272 | self.dataset = dataset 273 | dataset_name = set(dataset.map(lambda ex: ex['dataset'])) 274 | assert len(dataset_name) == 1, dataset_name 275 | self.dataset_name = dataset_name.pop() 276 | self.num_speakers = num_speakers 277 | self.rng = rng 278 | self.rir_ds, self.source_ds = split_rirs_and_sources(dataset) 279 | 280 | def get_new_dataset(self): 281 | return lazy_dataset.new(combine_rirs_and_sources( 282 | self.rir_ds, self.source_ds, self.num_speakers, 283 | f'{self.dataset_name}_rng{self.rng.randint(0, 2**32)}')) 284 | 285 | def copy(self, freeze: bool = False) -> 'lazy_dataset.Dataset': 286 | if freeze: 287 | return self.get_new_dataset() 288 | else: 289 | return SMSWSJRandomDataset( 290 | self.dataset, self.num_speakers, self.rng, 291 | ) 292 | 293 | def __iter__(self): 294 | return iter(self.get_new_dataset()) 295 | 296 | def __len__(self): 297 | return len(self.dataset) 298 | 299 | # ToDo: Implement getitem for str. 
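    # A possible sketch (hypothetical, not implemented): because the example
    # ids change with every random draw, a string lookup is only meaningful
    # on a frozen copy, e.g.:
    #
    #     def __getitem__(self, item):
    #         if isinstance(item, str):
    #             return self.copy(freeze=True)[item]
    #         return super().__getitem__(item)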
300 | # 301 | -------------------------------------------------------------------------------- /sms_wsj/database/utils.py: -------------------------------------------------------------------------------- 1 | 2 | from hashlib import md5 3 | 4 | import numpy as np 5 | from scipy.signal import fftconvolve 6 | from sms_wsj.reverb.reverb_utils import get_rir_start_sample 7 | 8 | __all__ = [ 9 | 'scenario_map_fn', 10 | ] 11 | 12 | 13 | def _example_id_to_rng(example_id): 14 | """ 15 | >>> _example_id_to_rng('example_id').get_state()[1][0] 16 | 2915193065 17 | """ 18 | hash_value = md5(example_id.encode()) 19 | hash_value = int(hash_value.hexdigest(), 16) 20 | hash_value -= 1 # legacy operation 21 | hash_value = hash_value % 2 ** 32 22 | return np.random.RandomState(hash_value) 23 | 24 | 25 | def extract_piece(x, offset, target_length): 26 | """ 27 | >>> extract_piece(np.arange(4), -1, 5) 28 | array([1, 2, 3, 0, 0]) 29 | 30 | >>> extract_piece(np.arange(6), -1, 5) 31 | array([1, 2, 3, 4, 5]) 32 | 33 | >>> extract_piece(np.arange(2), -2, 5) 34 | array([0, 0, 0, 0, 0]) 35 | 36 | >>> extract_piece(np.arange(2), 1, 5) 37 | array([0, 0, 1, 0, 0]) 38 | 39 | >>> extract_piece(np.arange(4), 1, 5) 40 | array([0, 0, 1, 2, 3]) 41 | 42 | >>> extract_piece(np.arange(2), 5, 5) 43 | array([0, 0, 0, 0, 0]) 44 | 45 | 46 | Args: 47 | x: 48 | offset: 49 | If negative, cut left side. 50 | If positive: pad left side. 51 | target_length: 52 | 53 | Returns: 54 | 55 | """ 56 | def pad_axis(array, pad_width, axis=-1): 57 | array = np.asarray(array) 58 | 59 | npad = np.zeros([array.ndim, 2], dtype=int) 60 | npad[axis, :] = pad_width 61 | return np.pad(array, pad_width=npad, mode='constant') 62 | 63 | if offset < 0: 64 | x = x[..., -offset:] 65 | else: 66 | x = pad_axis(x, (offset, 0), axis=-1) 67 | 68 | if x.shape[-1] < target_length: 69 | x = pad_axis(x, (0, target_length - x.shape[-1]), axis=-1) 70 | else: 71 | x = x[..., :target_length] 72 | 73 | return x 74 | 75 | 76 | def get_white_noise_for_signal( 77 | time_signal, 78 | *, 79 | snr, 80 | rng_state: np.random.RandomState = np.random 81 | ): 82 | """ 83 | Args: 84 | time_signal: 85 | snr: SNR or single speaker SNR. 86 | rng_state: A random number generator object or np.random 87 | """ 88 | noise_signal = rng_state.normal(size=time_signal.shape) 89 | 90 | power_time_signal = np.mean(time_signal ** 2, keepdims=True) 91 | power_noise_signal = np.mean(noise_signal ** 2, keepdims=True) 92 | current_snr = 10 * np.log10(power_time_signal / power_noise_signal) 93 | 94 | factor = 10 ** (-(snr - current_snr) / 20) 95 | 96 | noise_signal *= factor 97 | return noise_signal 98 | 99 | 100 | def synchronize_speech_source(original_source, offset, T): 101 | """ 102 | >>> from sms_wsj.database.database import SmsWsj, AudioReader 103 | >>> ds = SmsWsj().get_dataset('cv_dev93') 104 | >>> example = ds[0] 105 | >>> original_source = AudioReader._rec_audio_read( 106 | ... example['audio_path']['original_source']) 107 | >>> [s.shape for s in original_source] 108 | [(103650,), (93411,)] 109 | >>> synchronize_speech_source( 110 | ... original_source, 111 | ... example['offset'], 112 | ... T=example['num_samples']['observation'], 113 | ... 
).shape
114 |     (2, 103650)
115 |     """
116 |     return np.array([
117 |         extract_piece(x_, offset_, T)
118 |         for x_, offset_ in zip(
119 |             original_source,
120 |             offset,
121 |         )
122 |     ])
123 | 
124 | 
125 | def scenario_map_fn(
126 |         example,
127 |         *,
128 |         snr_range: tuple = (20, 30),
129 | 
130 |         sync_speech_source=True,
131 | 
132 |         add_speech_image=True,
133 |         add_speech_reverberation_early=True,
134 |         add_speech_reverberation_tail=True,
135 |         add_noise_image=True,
136 | 
137 |         early_rir_samples: int = int(8000 * 0.05),  # 50 milliseconds
138 |         channel_slice: [None, slice, tuple, list] = None,
139 | 
140 |         details=False,
141 | ):
142 |     """
143 |     This takes care of the convolution with the RIR and also generates noise.
144 |     The random noise generator is fixed based on the example ID. It will
145 |     therefore generate the same SNR and the same noise sequence the next time
146 |     you use this DB.
147 | 
148 |     Args:
149 |         example: Example dictionary.
150 |         snr_range: Required for the noise generation.
151 |         sync_speech_source: Legacy option. The new convention is that
152 |             original_source is the unpadded signal, while speech_source is
153 |             the padded signal.
154 |             If True, pad and/or cut the source signal to match the length
155 |             of the observation, taking the offset into account.
156 |         add_speech_image:
157 |             The speech_image is always computed, but when it is not needed,
158 |             this option can reduce the memory consumption.
159 |         add_speech_reverberation_early:
160 |             Calculate the speech_reverberation_early signal, i.e., the speech
161 |             source (padded original source) convolved with the early part of
162 |             the RIR.
163 |         add_speech_reverberation_tail:
164 |             Calculate the speech_reverberation_tail signal, i.e., the speech
165 |             source (padded original source) convolved with the tail part of
166 |             the RIR.
167 |         add_noise_image:
168 |             If True, add the noise_image to the returned example.
169 |             This option has no effect on the computation time or the peak
170 |             memory consumption.
171 |         early_rir_samples:
172 |             The number of samples that we count as the early RIR, default 50ms.
173 |             The remaining part of the RIR we call the tail.
174 |             Note: The length of the early RIR is the time of flight plus this
175 |             value.
176 |         channel_slice: e.g. None (all channels), [4] (a single channel), ...
177 |             Warning: Use this only for training. It will change the scale of
178 |             the data and the added white noise.
179 |             For the scale the standard deviation is estimated and the generated
180 |             noise shape changes, hence also the values.
181 |             With this option you can select the channels of interest.
182 |             All RIRs are used to estimate the time of flight, but only the
183 |             selected channels are convolved with the original/speech source.
184 |             This reduces computation time and memory consumption.
185 | 
186 |     Returns:
187 | 
188 |     """
189 |     h = example['audio_data']['rir']  # Shape (speaker, channel, sample)
190 | 
191 |     # Estimate the start sample first, to make it independent of channel_slice.
192 |     # Calculate one rir_start_sample (i.e. time of flight) for each speaker.
193 |     rir_start_sample = np.array([get_rir_start_sample(h_k) for h_k in h])
194 | 
195 |     if channel_slice is not None:
196 |         assert h.ndim == 3, h.shape
197 |         h = h[:, channel_slice, :]
198 |         assert h.ndim == 3, h.shape
199 | 
200 |     _, D, rir_length = h.shape
201 | 
202 |     # TODO: SAMPLE_RATE not defined
203 |     # rir_stop_sample = rir_start_sample + int(SAMPLE_RATE * 0.05)
204 |     # Use 50 milliseconds as the early rir part, excluding the propagation delay
205 |     # (i.e.
"rir_start_sample") 206 | assert isinstance(early_rir_samples, int), (type(early_rir_samples), early_rir_samples) 207 | rir_stop_sample = rir_start_sample + early_rir_samples 208 | 209 | log_weights = example['log_weights'] 210 | 211 | # The two sources have to be cut to same length 212 | K = example['num_speakers'] 213 | T = example['num_samples']['observation'] 214 | if 'original_source' not in example['audio_data']: 215 | # legacy code 216 | example['audio_data']['original_source'] = example['audio_data']['speech_source'] 217 | if 'original_source' not in example['num_samples']: 218 | # legacy code 219 | example['num_samples']['original_source'] = example['num_samples']['speech_source'] 220 | s = example['audio_data']['original_source'] 221 | 222 | def get_convolved_signals(h): 223 | assert len(s) == h.shape[0], (len(s), h.shape) 224 | x = [fftconvolve(s_[..., None, :], h_, axes=-1) 225 | for s_, h_ in zip(s, h)] 226 | 227 | assert len(x) == len(example['num_samples']['original_source']), (len(x), len(example['num_samples']['original_source'])) 228 | for x_, T_ in zip(x, example['num_samples']['original_source']): 229 | assert x_.shape == (D, T_ + rir_length - 1), ( 230 | x_.shape, D, T_ + rir_length - 1) 231 | 232 | # This is Jahn's heuristic to be able to still use WSJ alignments. 233 | offset = [ 234 | offset_ - rir_start_sample_ 235 | for offset_, rir_start_sample_ in zip( 236 | example['offset'], rir_start_sample) 237 | ] 238 | 239 | assert len(x) == len(offset) 240 | x = [extract_piece(x_, offset_, T) for x_, offset_ in zip(x, offset)] 241 | x = np.stack(x, axis=0) 242 | assert x.shape == (K, D, T), x.shape 243 | return x 244 | 245 | x = get_convolved_signals(h) 246 | 247 | # Note: scale depends on channel mode 248 | std = np.maximum( 249 | np.std(x, axis=(-2, -1), keepdims=True), 250 | np.finfo(x.dtype).tiny, 251 | ) 252 | 253 | # Rescale such that invasive SIR is as close as possible to `log_weights`. 
254 |     scale = (10 ** (np.asarray(log_weights)[:, None, None] / 20)) / std
255 |     # divide by 71 to ensure that all values are between -1 and 1
256 |     scale /= 71
257 | 
258 |     x *= scale
259 |     if add_speech_image:
260 |         example['audio_data']['speech_image'] = x
261 | 
262 |     clean_mix = np.sum(x, axis=0)
263 |     del x  # Reduce memory consumption for the case of `not add_speech_image`
264 | 
265 |     if add_speech_reverberation_early:
266 |         h_early = h.copy()
267 |         # Replace this with advanced indexing
268 |         for i in range(h_early.shape[0]):
269 |             h_early[i, ..., rir_stop_sample[i]:] = 0
270 |         x_early = get_convolved_signals(h_early)
271 |         x_early *= scale
272 |         example['audio_data']['speech_reverberation_early'] = x_early
273 | 
274 |         if details:
275 |             example['audio_data']['rir_early'] = h_early
276 | 
277 |     if add_speech_reverberation_tail:
278 |         h_tail = h.copy()
279 |         for i in range(h_tail.shape[0]):
280 |             h_tail[i, ..., :rir_stop_sample[i]] = 0
281 |         x_tail = get_convolved_signals(h_tail)
282 |         x_tail *= scale
283 |         example['audio_data']['speech_reverberation_tail'] = x_tail
284 | 
285 |         if details:
286 |             example['audio_data']['rir_tail'] = h_tail
287 | 
288 |     if sync_speech_source:
289 |         example['audio_data']['speech_source'] = synchronize_speech_source(
290 |             example['audio_data']['original_source'],
291 |             offset=example['offset'],
292 |             T=T,
293 |         )
294 |     else:
295 |         # legacy code
296 |         example['audio_data']['speech_source'] = \
297 |             example['audio_data']['original_source']
298 | 
299 |     rng = _example_id_to_rng(example['example_id'])
300 |     snr = rng.uniform(*snr_range)
301 |     example["snr"] = snr
302 | 
303 |     rng = _example_id_to_rng(example['example_id'])
304 | 
305 |     n = get_white_noise_for_signal(clean_mix, snr=snr, rng_state=rng)
306 |     if add_noise_image:
307 |         example['audio_data']['noise_image'] = n
308 | 
309 |     observation = clean_mix
310 |     observation += n  # Inplace to reduce memory consumption
311 |     example['audio_data']['observation'] = observation
312 | 
313 |     return example
314 | 
315 | 
316 | def get_valid_mird_rirs(mird_dir, rng=np.random):
317 |     import scipy.io
318 |     import scipy.signal  # scipy.signal.resample_poly is used below
319 |     def minus_with_wrap(angle1, angle2):
320 |         return np.angle(np.exp(1j * (angle1 - angle2)))
321 | 
322 |     K = 2
323 |     t60 = rng.choice(['0.160', '0.360', '0.610'])
324 |     spacing = rng.choice(['3-3-3-8-3-3-3', '4-4-4-8-4-4-4', '8-8-8-8-8-8-8'])
325 |     distance = rng.choice(['1', '2'], size=2, replace=True)
326 | 
327 |     angular_distance_ok = False
328 |     while not angular_distance_ok:
329 |         angle_degree = rng.choice([
330 |             '000',
331 |             '015', '030', '045', '060', '075', '090',
332 |             '270', '285', '300', '315', '330', '345'
333 |         ], size=2, replace=False)
334 |         angular_distance = np.abs(minus_with_wrap(
335 |             float(angle_degree[1]) / 180 * np.pi,
336 |             float(angle_degree[0]) / 180 * np.pi,
337 |         ) / np.pi * 180)
338 |         if angular_distance > 37.5:
339 |             angular_distance_ok = True
340 | 
341 |     rirs = np.stack([
342 |         scipy.io.loadmat(str(
343 |             mird_dir /
344 |             f'Impulse_response_Acoustic_Lab_Bar-Ilan_University_('
345 |             f'Reverberation_{t60}s)_{spacing}_{distance[k]}m_'
346 |             f'{angle_degree[k]}.mat'
347 |         ))['impulse_response'].T
348 |         for k in range(K)
349 |     ])
350 | 
351 |     return scipy.signal.resample_poly(rirs, up=1, down=6, axis=-1)
352 | 
-------------------------------------------------------------------------------- /sms_wsj/database/write_files.py: --------------------------------------------------------------------------------
1 | """
2 | Example calls:
3 | python -m sms_wsj.database.write_files with
dst_dir=/destination/dir json_path=/path/to/sms_wsj.json write_all=True
4 | 
5 | mpiexec -np 20 python -m sms_wsj.database.write_files with dst_dir=/destination/dir json_path=/path/to/sms_wsj.json write_all=True
6 | 
7 | """
8 | 
9 | from functools import partial
10 | from pathlib import Path
11 | 
12 | import json
13 | import sacred
14 | import numpy as np
15 | import soundfile
16 | from lazy_dataset.database import JsonDatabase
17 | from sms_wsj.database.utils import scenario_map_fn, _example_id_to_rng
18 | import dlp_mpi
19 | 
20 | 
21 | ex = sacred.Experiment('Write SMS-WSJ files')
22 | 
23 | KEY_MAPPER = {
24 |     'speech_reverberation_early': 'early',
25 |     'speech_reverberation_tail': 'tail',
26 |     'noise_image': 'noise',
27 |     'observation': 'observation',
28 |     'speech_source': 'speech_source',
29 | }
30 | 
31 | 
32 | def check_files(dst_dir):
33 |     return [
34 |         p for p in dst_dir.rglob("*.wav") if any(
35 |             [
36 |                 (
37 |                     p.match(str(dst_dir / f'{data_type}/**/*.wav')) or
38 |                     p.match(str(dst_dir / f'{data_type}/*.wav'))
39 |                 ) for data_type in KEY_MAPPER.values()
40 |             ]
41 |         )
42 |     ]
43 | 
44 | 
45 | def audio_read(example):
46 |     """
47 |     :param example: example dict
48 |     :return: example dict with audio_data added
49 |     """
50 |     keys = list(example['audio_path'].keys())
51 |     if 'wsj_source' in keys:
52 |         audio_keys = ['rir', 'wsj_source']
53 |     else:
54 |         # legacy code
55 |         audio_keys = ['rir', 'original_source']
56 |     example['audio_data'] = dict()
57 |     for audio_key in audio_keys:
58 |         assert audio_key in keys, (
59 |             f'Trying to read {audio_key} but only {keys} are available'
60 |         )
61 |         audio_data = list()
62 |         for wav_file in example['audio_path'][audio_key]:
63 | 
64 |             with soundfile.SoundFile(wav_file, mode='r') as f:
65 |                 audio_data.append(f.read().T)
66 |         example['audio_data'][audio_key] = np.array(audio_data)
67 |     return example
68 | 
69 | 
70 | def write_wavs(dst_dir, json_path, write_all=False, snr_range=(20, 30)):
71 |     db = JsonDatabase(json_path)
72 |     if write_all:
73 |         if dlp_mpi.IS_MASTER:
74 |             [(dst_dir / data_type).mkdir(exist_ok=False)
75 |              for data_type in KEY_MAPPER.values()]
76 |         map_fn = partial(
77 |             scenario_map_fn,
78 |             snr_range=snr_range,
79 |             sync_speech_source=True,
80 |             add_speech_reverberation_early=True,
81 |             add_speech_reverberation_tail=True
82 |         )
83 |     else:
84 |         if dlp_mpi.IS_MASTER:
85 |             (dst_dir / 'observation').mkdir(exist_ok=False)
86 |         map_fn = partial(
87 |             scenario_map_fn,
88 |             snr_range=snr_range,
89 |             sync_speech_source=True,
90 |             add_speech_reverberation_early=False,
91 |             add_speech_reverberation_tail=False
92 |         )
93 |     for dataset in ['train_si284', 'cv_dev93', 'test_eval92']:
94 |         if dlp_mpi.IS_MASTER:
95 |             [
96 |                 (dst_dir / data_type / dataset).mkdir(exist_ok=False)
97 |                 for data_type in KEY_MAPPER.values()
98 |             ]
99 |         ds = db.get_dataset(dataset).map(audio_read).map(map_fn)
100 |         for example in dlp_mpi.split_managed(
101 |                 ds,
102 |                 is_indexable=True,
103 |                 allow_single_worker=True,
104 |                 progress_bar=True,
105 |         ):
106 |             audio_dict = example['audio_data']
107 |             example_id = example['example_id']
108 |             if not write_all:
109 |                 del audio_dict['speech_reverberation_early']
110 |                 del audio_dict['speech_reverberation_tail']
111 |                 del audio_dict['noise_image']
112 | 
113 |             def get_abs_max(a):
114 |                 if isinstance(a, np.ndarray):
115 |                     if a.dtype == object:  # np.object was removed in NumPy 1.24
116 |                         return np.max(list(map(get_abs_max, a)))
117 |                     else:
118 |                         return np.max(np.abs(a))
119 |                 elif isinstance(a, (tuple, list)):
120 |                     return np.max(list(map(get_abs_max, a)))
121 |                 elif isinstance(a, dict):
122 |                     return np.max(list(map(get_abs_max, a.values())))
123 |                 else:
124 |                     raise TypeError(a)
125 | 
126 |             assert get_abs_max(audio_dict), (
127 |                 example_id, {
128 |                     k: get_abs_max(v) for k, v in audio_dict.items()
129 |                 }
130 |             )
131 |             for key, value in audio_dict.items():
132 |                 if key not in KEY_MAPPER:
133 |                     continue
134 |                 path = dst_dir / KEY_MAPPER[key] / dataset
135 |                 if key in ['observation', 'noise_image']:
136 |                     value = value[None]
137 |                 for idx, signal in enumerate(value):
138 |                     appendix = f'_{idx}' if len(value) > 1 else ''
139 |                     filename = example_id + appendix + '.wav'
140 |                     audio_path = str(path / filename)
141 |                     with soundfile.SoundFile(
142 |                             audio_path, subtype='FLOAT', mode='w',
143 |                             samplerate=8000,
144 |                             channels=1 if signal.ndim == 1 else signal.shape[0]
145 |                     ) as f:
146 |                         f.write(signal.T)
147 | 
148 |     dlp_mpi.barrier()
149 | 
150 |     if dlp_mpi.IS_MASTER:
151 |         created_files = check_files(dst_dir)
152 |         print(f"Wrote {len(created_files)} wav files.")
153 |         if write_all:
154 |             # TODO: Expect less, if you do a test run.
155 |             num_speakers = 2  # todo: infer num_speakers from the json
156 |             # num_speakers files each for: early, tail, speech_source
157 |             # 1 file each for: observation, noise
158 |             expect = (3 * num_speakers + 2) * 35875
159 |             assert len(created_files) == expect, (
160 |                 len(created_files), expect
161 |             )
162 |         else:
163 |             assert len(created_files) == 35875, len(created_files)
164 | 
165 | 
166 | @ex.config
167 | def config():
168 |     dst_dir = None
169 |     json_path = None
170 | 
171 |     # If `False`, only write the observation, else write all intermediate signals.
172 |     write_all = True
173 | 
174 |     snr_range = (20, 30)
175 | 
176 |     assert dst_dir is not None, 'You have to specify a destination dir'
177 |     assert json_path is not None, 'You have to specify a path to sms_wsj.json'
178 | 
179 |     debug = False
180 | 
181 | 
182 | @ex.automain
183 | def main(dst_dir, json_path, write_all, snr_range):
184 |     json_path = Path(json_path).expanduser().resolve()
185 |     dst_dir = Path(dst_dir).expanduser().resolve()
186 |     if dlp_mpi.IS_MASTER:
187 |         assert json_path.exists(), json_path
188 |         dst_dir.mkdir(exist_ok=True, parents=True)
189 |         if not any([
190 |             (dst_dir / data_type).exists()
191 |             for data_type in KEY_MAPPER.values()
192 |         ]):
193 |             write_files = True
194 |         else:
195 |             write_files = False
196 |             num_wav_files = len(check_files(dst_dir))
197 |             # TODO: Expect less, if you do a test run.
198 |             if write_all and num_wav_files == (3 * 2 + 2) * 35875:  # matches the assert in write_wavs
199 |                 print('Wav files seem to exist. They are not overwritten.')
200 |             elif (
201 |                 not write_all and num_wav_files == 35875
202 |                 and (dst_dir / 'observation').exists()
203 |             ):
204 |                 print('Wav files seem to exist. They are not overwritten.')
205 |             else:
206 |                 raise ValueError(
207 |                     'Not all wav files exist. '
208 |                     'However, the directory structure already exists.'
209 |                 )
210 |     else:
211 |         write_files = None
212 |     write_files = dlp_mpi.COMM.bcast(write_files, root=dlp_mpi.MASTER)
213 |     if write_files:
214 |         write_wavs(dst_dir, json_path, write_all=write_all, snr_range=snr_range)
215 | 
-------------------------------------------------------------------------------- /sms_wsj/database/wsj/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/fgnt/sms_wsj/f6b94bb987d620fc8bbe005a92bd042e09a68d9b/sms_wsj/database/wsj/__init__.py
-------------------------------------------------------------------------------- /sms_wsj/database/wsj/create_json.py: --------------------------------------------------------------------------------
1 | """
2 | Example calls:
3 | python -m sms_wsj.database.wsj.create_json with database_dir=/path/to/written/wsj json_path=/path/to/wsj.json
4 | 
5 | 
6 | """
7 | import json
8 | import os
9 | import re
10 | import tempfile
11 | from pathlib import Path
12 | 
13 | import sacred
14 | import sh
15 | import soundfile as sf
16 | 
17 | ex = sacred.Experiment('Create wsj json')
18 | kaldi_root = Path(os.environ['KALDI_ROOT'])
19 | kaldi_wsj_egs_dir = kaldi_root / 'egs' / 'wsj' / 's5'
20 | kaldi_wsj_data_dir = kaldi_wsj_egs_dir / 'data' / 'local' / 'data'
21 | kaldi_wsj_tools = kaldi_wsj_egs_dir / 'local'
22 | 
23 | 
24 | def create_official_datasets(
25 |         official_sets, official_names, wsj_root, as_wav, genders, transcript
26 | ):
27 | 
28 |     _examples = dict()
29 | 
30 |     for idx, set_list in enumerate(official_sets):
31 |         set_name = official_names[idx]
32 |         _examples[set_name] = dict()
33 |         for ods in set_list:
34 |             set_path = wsj_root / ods
35 |             if set_path.match('*.ndx'):
36 |                 _example = read_ndx(
37 |                     set_path, wsj_root, as_wav, genders, transcript
38 |                 )
39 |             else:
40 |                 if as_wav:
41 |                     wav_files = list(set_path.glob('*/*.wav'))
42 |                 else:
43 |                     wav_files = list(set_path.glob('*/*.wv1'))
44 |                 _example = process_example_paths(
45 |                     wav_files, genders, transcript
46 |                 )
47 |             _examples[set_name].update(_example)
48 | 
49 |     return _examples
50 | 
51 | 
52 | def read_nsamples(audio_path):
53 | 
54 |     if audio_path.suffix == '.wv1':
55 |         with open(audio_path, 'rb') as f:
56 |             header = f.read(1024).decode("utf-8")  # the nist header is a
57 |             # multiple of 1024 bytes
58 |         nsamples = int(re.search("sample_count -i (.+?)\n", header).group(1))
59 |     else:
60 |         info = sf.info(str(audio_path), verbose=True)
61 |         nsamples = info.frames
62 |     return nsamples
63 | 
64 | 
65 | def read_ndx(ndx_file: Path, wsj_root, as_wav,
66 |              genders, transcript):
67 |     assert ndx_file.match('*.ndx')
68 | 
69 |     with open(ndx_file) as fid:
70 |         if ndx_file.match('*/si_et_20.ndx') or \
71 |                 ndx_file.match('*/si_et_05.ndx'):
72 |             lines = [line.rstrip() + ".wv1" for line in fid
73 |                      if not line.startswith(";")]
74 |         else:
75 |             lines = [line.rstrip() for line in fid
76 |                      if line.lower().rstrip().endswith(".wv1")
77 |                      ]
78 | 
79 |     fixed_paths = list()
80 | 
81 |     for line in lines:
82 |         disk, wav_path = line.split(':')
83 |         disk = '{}-{}.{}'.format(*disk.split('_'))
84 | 
85 |         # wrong disk-ids for test_eval93 and test_eval93_5k
86 |         disk = disk.replace('13-32.1', '13-33.1')
87 |         wav_path = wav_path.lstrip(' /')  # remove leading whitespace and
88 |         # slash
89 |         audio_path = wsj_root / disk / wav_path
90 |         if as_wav:
91 |             audio_path = audio_path.with_suffix('.wav')
92 |         if "11-2.1/wsj0/si_tr_s/401" in str(audio_path):
93 |             continue  # skip the 401 subdirectory in the train sets
94 |         fixed_paths.append(audio_path)
95 | 
96 |     _examples = 
process_example_paths(fixed_paths, genders, transcript) 97 | 98 | return _examples 99 | 100 | 101 | def process_example_paths(example_paths, genders, transcript): 102 | """ 103 | Creates an entry in keys.EXAMPLE for every example in `example_paths` 104 | 105 | :param example_paths: List of Paths to example .wv files 106 | :type: List 107 | :param genders: Mapping from speaker id to gender 108 | :type: dict 109 | :param transcript: Mapping from raw example id to dirty, clean and kaldi 110 | transcription 111 | :type: dict 112 | 113 | :return _examples: Partial entries in keys.EXAMPLE for examples in 114 | `set_name` 115 | :type: dict 116 | """ 117 | _examples = dict() 118 | 119 | for path in example_paths: 120 | 121 | wav_file = path.parts[-1] 122 | example_id = wav_file.split('.')[0] 123 | 124 | speaker_id = example_id[0:3] 125 | nsamples = read_nsamples(path) 126 | 127 | gender = genders[speaker_id] 128 | 129 | example = { 130 | 'example_id': example_id, 131 | 'audio_path': { 132 | 'observation': str(path) 133 | }, 134 | 'num_samples': { 135 | 'observation': nsamples 136 | }, 137 | 'speaker_id': speaker_id, 138 | 'gender': gender, 139 | 'transcription': transcript['clean word'][example_id], 140 | 'kaldi_transcription': transcript['kaldi'][example_id] 141 | } 142 | 143 | _examples[example_id] = example 144 | 145 | return _examples 146 | 147 | 148 | def get_transcriptions(root: Path, wsj_root: Path): 149 | word = dict() 150 | 151 | dot_files = list(root.rglob('*.dot')) 152 | ptx_files = list(root.rglob('*.ptx')) 153 | ptx_files = [ptx_file for ptx_file in ptx_files if Path( 154 | str(ptx_file).replace('.ptx', '.dot')) not in dot_files] 155 | 156 | for file_path in dot_files + ptx_files: 157 | with open(file_path) as fid: 158 | matches = re.findall("^(.+)\s+\((\S+)\)$", fid.read(), flags=re.M) 159 | word.update({utt_id: trans for trans, utt_id in matches}) 160 | 161 | kaldi = dict() 162 | files = list(kaldi_wsj_data_dir.glob('*.txt')) 163 | for file in files: 164 | with open(file) as fid: 165 | matches = re.findall("^(\S+) (.+)$", fid.read(), flags=re.M) 166 | kaldi.update({utt_id: trans for utt_id, trans in matches}) 167 | 168 | data_dict = dict() 169 | data_dict["word"] = word 170 | data_dict["clean word"] = normalize_transcription(word, wsj_root) 171 | data_dict["kaldi"] = kaldi 172 | return data_dict 173 | 174 | 175 | def normalize_transcription(transcriptions, wsj_root: Path): 176 | """ Passes the dirty transcription dict to a Kaldi Perl script for cleanup. 177 | 178 | We use the original Perl file, to make sure, that the cleanup is done 179 | exactly as it is done by Kaldi. 180 | 181 | :param transcriptions: Dirty transcription dictionary 182 | :param wsj_root: Path to WSJ database 183 | 184 | :return result: Clean transcription dictionary 185 | """ 186 | assert len(transcriptions) > 0, 'No transcriptions to clean up.' 
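    # The dirty transcriptions are dumped to a temporary text file, piped
    # through Kaldi's normalize_transcript.pl, and the script's stdout is
    # parsed back into a dict below.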
187 | with tempfile.TemporaryDirectory() as temporary_directory: 188 | temporary_directory = Path(temporary_directory).absolute() 189 | with open(temporary_directory / 'dirty.txt', 'w') as f: 190 | for key, value in transcriptions.items(): 191 | f.write('{} {}\n'.format(key, value)) 192 | result = sh.perl( 193 | sh.cat(str(temporary_directory / 'dirty.txt')), 194 | kaldi_wsj_tools / 'normalize_transcript.pl', 195 | '' 196 | ) 197 | result = [line.split(maxsplit=1) for line in result.strip().split('\n')] 198 | result = {k: v for k, v in result} 199 | return result 200 | 201 | 202 | def get_gender_mapping(wsj_root: Path): 203 | spkrinfo_wsj = list(wsj_root.glob('**/wsj?/doc/**/*spkrinfo.txt')) 204 | spkrinfo_kaldi = list(kaldi_wsj_data_dir.glob('**/*spkrinfo.txt')) 205 | spkrinfo = spkrinfo_wsj + spkrinfo_kaldi 206 | 207 | if len(spkrinfo) == 0: 208 | raise RuntimeError( 209 | f'Could not find "{wsj_root}/**/wsj?/doc/**/*spkrinfo.txt" and' 210 | f'"{kaldi_wsj_data_dir}/**/*spkrinfo.txt".' 211 | ) 212 | if len(spkrinfo_wsj) == 0: 213 | raise RuntimeError( 214 | f'Could not find "{wsj_root}/**/wsj?/doc/**/*spkrinfo.txt".' 215 | ) 216 | if len(spkrinfo_kaldi) == 0: 217 | raise RuntimeError( 218 | f'Could not find "{kaldi_wsj_data_dir}/**/*spkrinfo.txt". ' 219 | f'Did you forget to run the data preparation for WSJ in Kaldi?' 220 | ) 221 | 222 | _spkr_gender_mapping = dict() 223 | 224 | for path in spkrinfo: 225 | with open(path, 'r') as fid: 226 | for line in fid: 227 | if not (line.startswith(';') or line.startswith('---')): 228 | line = line.split() 229 | _spkr_gender_mapping[line[0].lower()] = 'male' \ 230 | if line[1] == 'M' else 'female' 231 | 232 | if len(_spkr_gender_mapping) == 0 or '01i' not in _spkr_gender_mapping: 233 | raise RuntimeError( 234 | f'Could not read the gender information from "{spkrinfo}".' 
235 | ) 236 | 237 | return _spkr_gender_mapping 238 | 239 | 240 | @ex.config 241 | def config(): 242 | database_dir = None 243 | json_path = None 244 | wsj_json = None 245 | as_wav = True 246 | assert database_dir is not None, 'You have to specify the database dir' 247 | assert json_path is not None, 'You have to specify a path for the new json' 248 | 249 | 250 | @ex.automain 251 | def create_database(database_dir, json_path, as_wav): 252 | database_dir = Path(database_dir).expanduser().resolve() 253 | json_path = Path(json_path).expanduser().resolve() 254 | if json_path.exists(): 255 | raise FileExistsError(json_path) 256 | assert database_dir.exists(), database_dir 257 | 258 | database_dir = Path(database_dir) 259 | 260 | train_sets = [ 261 | ["11-13.1/wsj0/doc/indices/train/tr_s_wv1.ndx"], 262 | ["13-34.1/wsj1/doc/indices/si_tr_s.ndx", 263 | "11-13.1/wsj0/doc/indices/train/tr_s_wv1.ndx"] 264 | ] 265 | train_set_names = [ 266 | "train_si84", # 7138 examples 267 | "train_si284" # 37416 examples 268 | ] 269 | 270 | test_sets = [ 271 | ["11-13.1/wsj0/doc/indices/test/nvp/si_et_20.ndx"], 272 | ["11-13.1/wsj0/doc/indices/test/nvp/si_et_05.ndx"], 273 | ["13-32.1/wsj1/doc/indices/wsj1/eval/h1_p0.ndx"], 274 | ["13-32.1/wsj1/doc/indices/wsj1/eval/h2_p0.ndx"] 275 | ] 276 | 277 | test_set_names = [ 278 | "test_eval92", # 333 examples 279 | "test_eval92_5k", # 330 examples 280 | "test_eval93", # 213 examples 281 | "test_eval93_5k" # 215 examples 282 | ] 283 | 284 | dev_sets = [ 285 | ["13-34.1/wsj1/doc/indices/h1_p0.ndx"], 286 | ["13-34.1/wsj1/doc/indices/h2_p0.ndx"], 287 | ] 288 | dev_set_names = [ 289 | "cv_dev93", # 503 examples 290 | "cv_dev93_5k", # 513 examples 291 | ] 292 | 293 | transcriptions = get_transcriptions(database_dir, database_dir) 294 | gender_mapping = get_gender_mapping(database_dir) 295 | 296 | examples = dict() 297 | 298 | examples_tr = create_official_datasets( 299 | train_sets, 300 | train_set_names, 301 | database_dir, 302 | as_wav, 303 | gender_mapping, 304 | transcriptions 305 | ) 306 | examples.update(examples_tr) 307 | 308 | examples_dt = create_official_datasets( 309 | dev_sets, 310 | dev_set_names, 311 | database_dir, 312 | as_wav, gender_mapping, 313 | transcriptions 314 | ) 315 | examples.update(examples_dt) 316 | 317 | examples_et = create_official_datasets( 318 | test_sets, 319 | test_set_names, 320 | database_dir, 321 | as_wav, 322 | gender_mapping, 323 | transcriptions 324 | ) 325 | examples.update(examples_et) 326 | 327 | database = { 328 | 'datasets': examples, 329 | } 330 | json_path.parent.mkdir(exist_ok=True, parents=True) 331 | with json_path.open('w') as f: 332 | json.dump(database, f, indent=4, ensure_ascii=False) 333 | print(f'{json_path} written') 334 | -------------------------------------------------------------------------------- /sms_wsj/database/wsj/write_wav.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example calls: 3 | python -m sms_wsj.database.wsj.write_wav with dst_dir=/DEST/DIR wsj_root=/WSJ/ROOT/DIR --sample_rate=8000 4 | 5 | mpiexec -np 20 python -m sms_wsj.database.wsj.write_wav with dst_dir=/DEST/DIR wsj_root=/WSJ/ROOT/DIR sample_rate=8000 6 | 7 | """ 8 | 9 | import os 10 | import io 11 | import fnmatch 12 | import shutil 13 | import subprocess 14 | from pathlib import Path 15 | 16 | import numpy as np 17 | import sacred 18 | import soundfile 19 | import warnings 20 | 21 | import dlp_mpi 22 | 23 | ex = sacred.Experiment('Write WSJ waves') 24 | 25 | kaldi_root = 
Path(os.environ['KALDI_ROOT']) 26 | 27 | 28 | def read_nist_wsj(path, expected_sample_rate=16000): 29 | """ 30 | Converts a nist/sphere file of wsj and reads it with soundfile. 31 | 32 | :param path: file path to audio file. 33 | :param audioread_function: Function to use to read the resulting audio file 34 | :return: 35 | """ 36 | cmd = [ 37 | kaldi_root / 'tools' / 'sph2pipe_v2.5' / 'sph2pipe', 38 | '-f', 'wav', 39 | str(path), 40 | ] 41 | 42 | completed_process = subprocess.run( 43 | cmd, universal_newlines=False, shell=False, stdout=subprocess.PIPE, 44 | stderr=subprocess.PIPE, check=True, env=None, cwd=None 45 | ) 46 | 47 | signal, sample_rate = soundfile.read(io.BytesIO(completed_process.stdout)) 48 | assert sample_rate == expected_sample_rate, (sample_rate, expected_sample_rate) 49 | return signal 50 | 51 | 52 | def resample_with_sox(x, rate_in, rate_out): 53 | """ 54 | 55 | Args: 56 | x: 57 | rate_in: 58 | rate_out: 59 | 60 | Returns: 61 | Resampled signal. 62 | 63 | >>> from paderbox.utils.pretty import pprint 64 | >>> x = np.ones(10000) 65 | >>> pprint(x.flags) 66 | C_CONTIGUOUS : True 67 | F_CONTIGUOUS : True 68 | OWNDATA : True 69 | WRITEABLE : True 70 | ALIGNED : True 71 | WRITEBACKIFCOPY : False 72 | UPDATEIFCOPY : False 73 | >>> y = resample_with_sox(x, 16000, 8000) 74 | >>> pprint(y) 75 | array(shape=(5000,), dtype=float32) 76 | >>> pprint(y.flags) 77 | C_CONTIGUOUS : True 78 | F_CONTIGUOUS : True 79 | OWNDATA : False 80 | WRITEABLE : False 81 | ALIGNED : True 82 | WRITEBACKIFCOPY : False 83 | UPDATEIFCOPY : False 84 | >>> y = resample_with_sox(x, 16000, 16000) 85 | >>> pprint(y) 86 | array(shape=(10000,), dtype=float64) 87 | >>> pprint(y.flags) 88 | C_CONTIGUOUS : True 89 | F_CONTIGUOUS : True 90 | OWNDATA : False 91 | WRITEABLE : False 92 | ALIGNED : True 93 | WRITEBACKIFCOPY : False 94 | UPDATEIFCOPY : False 95 | """ 96 | if rate_in == rate_out: 97 | # Mirror readonly output property from np.frombuffer 98 | x = x.view() 99 | x.flags.writeable = False 100 | return x 101 | 102 | x = x.astype(np.float32) 103 | command = ( 104 | f'sox -N -V1 -t f32 -r {rate_in} -c 1 - -t f32 -r {rate_out} -c 1 -' 105 | ).split() 106 | process = subprocess.run( 107 | command, 108 | shell=False, 109 | stdout=subprocess.PIPE, 110 | stderr=subprocess.PIPE, 111 | input=x.tobytes(order="f") 112 | ) 113 | return np.frombuffer(process.stdout, dtype=np.float32) 114 | 115 | 116 | @ex.config 117 | def config(): 118 | dst_dir = None 119 | wsj_root = None 120 | wsj0_root = wsj_root 121 | wsj1_root = wsj_root 122 | sample_rate = 16000 123 | assert dst_dir is not None, 'You have to specify a destination dir' 124 | assert wsj0_root is not None, 'You have to specify a wsj0_root' 125 | assert wsj1_root is not None, 'You have to specify a wsj1_root' 126 | 127 | 128 | @ex.automain 129 | def write_wavs(dst_dir: Path, wsj0_root: Path, wsj1_root: Path, sample_rate): 130 | wsj0_root = Path(wsj0_root).expanduser().resolve() 131 | wsj1_root = Path(wsj1_root).expanduser().resolve() 132 | dst_dir = Path(dst_dir).expanduser().resolve() 133 | assert wsj0_root.exists(), wsj0_root 134 | assert wsj1_root.exists(), wsj1_root 135 | 136 | assert not dst_dir == wsj0_root, (wsj0_root, dst_dir) 137 | assert not dst_dir == wsj1_root, (wsj1_root, dst_dir) 138 | # Expect, that the dst_dir does not exist to make sure to not overwrite. 139 | if dlp_mpi.IS_MASTER: 140 | dst_dir.mkdir(parents=True, exist_ok=False) 141 | 142 | if dlp_mpi.IS_MASTER: 143 | # Search for CD numbers, e.g. "13-34.1" 144 | # CD stands for compact disk. 
145 | cds_0 = list(wsj0_root.rglob("*-*.*"))
146 | cds_1 = list(wsj1_root.rglob("*-*.*"))
147 | cds = set(cds_0 + cds_1)
148 | 
149 | expected_number_of_files = {
150 | 'pl': 3, 'ndx': 106, 'ptx': 3547, 'dot': 3585, 'txt': 256
151 | }
152 | number_of_written_files = dict()
153 | for suffix in expected_number_of_files.keys():
154 | files_0 = list(wsj0_root.rglob(f"*.{suffix}"))
155 | files_1 = list(wsj1_root.rglob(f"*.{suffix}"))
156 | files = set(files_0 + files_1)
157 | # Filter files that do not have a folder that matches "*-*.*".
158 | files = {
159 | file
160 | for file in files
161 | if any([fnmatch.fnmatch(part, "*-*.*") for part in file.parts])
162 | }
163 | 
164 | # the readme.txt file in the parent directory is not copied
165 | print(f"About to write ca. {len(files)} {suffix} files.")
166 | for cd in cds:
167 | cd_files = list(cd.rglob(f"*.{suffix}"))
168 | for file in cd_files:
169 | target = dst_dir / file.relative_to(cd.parent)
170 | target.parent.mkdir(parents=True, exist_ok=True)
171 | if not target.is_file():
172 | shutil.copy(file, target.parent)
173 | number_of_written_files[suffix] = len(
174 | list(dst_dir.rglob(f"*.{suffix}"))
175 | )
176 | print(
177 | f"Wrote {number_of_written_files[suffix]} {suffix} files."
178 | )
179 | print(
180 | f'Expected {expected_number_of_files[suffix]} {suffix} files.'
181 | )
182 | 
183 | for suffix in expected_number_of_files.keys():
184 | message = (
185 | f'Expected that '
186 | f'{expected_number_of_files[suffix]} '
187 | f'files with the suffix {suffix} are written. '
188 | f'But only {number_of_written_files[suffix]} are written. '
189 | )
190 | if (
191 | number_of_written_files[suffix]
192 | != expected_number_of_files[suffix]
193 | ):
194 | warnings.warn(message)
195 | 
196 | if suffix == 'pl' and number_of_written_files[suffix] == 1:
197 | raise RuntimeError(
198 | 'Found only one pl file although we expected three. '
199 | 'A typical reason is having only WSJ0. '
200 | 'Please make sure you have WSJ0+1 = WSJ COMPLETE.'
201 | )
202 | 
203 | if dlp_mpi.IS_MASTER:
204 | # Ignore .wv2 files since they are not referenced in our database
205 | # anyway
206 | wsj_nist_files = [(cd, nist_file) for cd in cds
207 | for nist_file in cd.rglob("*.wv1")]
208 | print(f"About to write {len(wsj_nist_files)} wav files.")
209 | else:
210 | wsj_nist_files = None
211 | 
212 | wsj_nist_files = dlp_mpi.bcast(wsj_nist_files)
213 | 
214 | for nist_file_tuple in dlp_mpi.split_managed(wsj_nist_files):
215 | cd, nist_file = nist_file_tuple
216 | assert isinstance(nist_file, Path), nist_file
217 | signal = read_nist_wsj(nist_file, expected_sample_rate=16000)
218 | file = nist_file.with_suffix('.wav')
219 | target = dst_dir / file.relative_to(cd.parent)
220 | assert not target == nist_file, (nist_file, target)
221 | target.parent.mkdir(parents=True, exist_ok=True)
222 | signal = resample_with_sox(signal, rate_in=16000, rate_out=sample_rate)
223 | # normalization to mean 0:
224 | signal = signal - np.mean(signal)
225 | # normalization:
226 | # Correction, because the allowed values are in the range [-1, 1).
227 | # => "1" is not a valid value
228 | correction = (2 ** 15 - 1) / (2 ** 15)
229 | signal = signal * (correction / np.amax(np.abs(signal)))
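# Added sanity check (illustrative only, not part of the original script):
# with the correction above, the loudest sample stays strictly below 1.0,
# so a later float -> int16 conversion via round(x * 2**15) cannot overflow:
#
#     x = np.array([0.5, -1.0])
#     x = x * (correction / np.amax(np.abs(x)))  # max |x| becomes 32767/32768
#     assert round(np.amax(np.abs(x)) * 2 ** 15) <= 2 ** 15 - 1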
230 | with soundfile.SoundFile(
231 | str(target), samplerate=sample_rate, channels=1,
232 | subtype='FLOAT', mode='w',
233 | ) as f:
234 | f.write(signal.T)
235 | 
236 | dlp_mpi.barrier()
237 | if dlp_mpi.IS_MASTER:
238 | created_files = list(set(list(dst_dir.rglob("*.wav"))))
239 | print(f"Written {len(created_files)} wav files.")
240 | assert len(wsj_nist_files) == len(created_files), (len(wsj_nist_files), len(created_files))
241 | 
--------------------------------------------------------------------------------
/sms_wsj/examples/metric_target_comparison.py:
--------------------------------------------------------------------------------
1 | """
2 | 
3 | ### Description how to start this script
4 | 
5 | # Define how you want to parallelize.
6 | # ccsalloc is an HPC scheduler; this command requests 100 workers, each with 2 GB of memory
7 | run="ccsalloc --group=hpc-prf-nt1 --res=rset=100:mem=2g:ncpus=1 --tracefile=ompi.%reqid.trace -t 2h ompi -- "
8 | # To run on the machine where you are logged in, use mpiexec:
9 | run="mpiexec -np 16 "
10 | 
11 | # To start the experiment you can then execute the following commands.
12 | # They will generate two files in your current working directory.
13 | ${run} python -m sms_wsj.examples.metric_target_comparison with dataset=cv_dev93 out=oracle_experiment_dev.json # takes approx 30 min with 16 workers
14 | ${run} python -m sms_wsj.examples.metric_target_comparison with dataset=test_eval92 out=oracle_experiment_eval.json # takes approx 40 min with 16 workers
15 | 
16 | # Display the summary again from the json:
17 | python -m sms_wsj.examples.metric_target_comparison summary with out=oracle_experiment_dev.json
18 | python -m sms_wsj.examples.metric_target_comparison summary with out=oracle_experiment_eval.json
19 | 
20 | """
21 | 
22 | import json
23 | 
24 | from IPython.lib.pretty import pprint
25 | import numpy as np
26 | import pandas as pd
27 | import sacred
28 | 
29 | from pb_bss.evaluation.wrapper import OutputMetrics
30 | import dlp_mpi
31 | 
32 | from sms_wsj.io import dump_json
33 | from sms_wsj.database import SmsWsj, AudioReader
34 | 
35 | experiment = sacred.Experiment('Oracle Experiment')
36 | 
37 | 
38 | @experiment.config
39 | def config():
40 | dataset = 'cv_dev93' # or 'test_eval92'
41 | out = None # json file to write the detailed results
42 | 
43 | 
44 | @experiment.capture
45 | def get_dataset(dataset):
46 | """
47 | >>> np.set_string_function(lambda a: f'array(shape={a.shape}, dtype={a.dtype})')
48 | >>> pprint(get_dataset('cv_dev93')[0]) # doctest: +ELLIPSIS
49 | {'audio_path': ...,
50 | ...,
51 | 'example_id': '4k0c0301_4k6c030t_0',
52 | ...,
53 | 'kaldi_transcription': ...,
54 | ...,
55 | 'audio_data': {'speech_source': array(shape=(2, 103650), dtype=float64),
56 | 'rir': array(shape=(2, 6, 8192), dtype=float64),
57 | 'speech_image': array(shape=(2, 6, 103650), dtype=float64),
58 | 'speech_reverberation_early': array(shape=(2, 6, 103650), dtype=float64),
59 | 'speech_reverberation_tail': array(shape=(2, 6, 103650), dtype=float64),
60 | 'noise_image': array(shape=(1, 1), dtype=float64),
61 | 'observation': array(shape=(6, 103650), dtype=float64)},
62 | 'snr': 29.749852569493584}
63 | """
64 | db = SmsWsj()
65 | ds = db.get_dataset(dataset)
66 | ds = ds.map(AudioReader())
67 | return ds
68 | 
69 | 
70 | def get_scores(ex, prediction, source):
71 | """
72 | Calculate the scores, where the prediction/estimated signal is tested
73 | against the source/desired signal.
74 | This function is for an oracle test to figure out which metric can work
75 | with which source signal.
76 | 
77 | Example:
78 | SI-SDR does not work when the desired signal is the signal before the
79 | room impulse response, and it gives strange results when the channel is
80 | changed.
81 | 
82 | >>> pprint(get_scores(get_dataset('cv_dev93')[0], 'image_0', 'early_0'))
83 | {'pesq': array([2.861]),
84 | 'stoi': array([0.97151566]),
85 | 'mir_eval_sxr_sdr': array([13.39136665]),
86 | 'si_sdr': array([10.81039897])}
87 | >>> pprint(get_scores(get_dataset('cv_dev93')[0], 'image_0', 'source'))
88 | {'pesq': array([2.234]),
89 | 'stoi': array([0.8005423]),
90 | 'mir_eval_sxr_sdr': array([12.11446204]),
91 | 'si_sdr': array([-20.05244551])}
92 | >>> pprint(get_scores(get_dataset('cv_dev93')[0], 'image_0', 'image_1'))
93 | {'pesq': array([3.608]),
94 | 'stoi': array([0.92216845]),
95 | 'mir_eval_sxr_sdr': array([9.55425598]),
96 | 'si_sdr': array([-0.16858895])}
97 | """
98 | def get_signal(ex, name):
99 | assert isinstance(ex, dict), ex
100 | assert 'audio_data' in ex, ex
101 | assert isinstance(ex['audio_data'], dict), ex
102 | if name == 'source':
103 | return ex['audio_data']['speech_source'][:]
104 | elif name == 'early_0':
105 | return ex['audio_data']['speech_reverberation_early'][:, 0]
106 | elif name == 'early_1':
107 | return ex['audio_data']['speech_reverberation_early'][:, 1]
108 | elif name == 'image_0':
109 | return ex['audio_data']['speech_image'][:, 0]
110 | elif name == 'image_1':
111 | return ex['audio_data']['speech_image'][:, 1]
112 | elif name == 'image_0_noise':
113 | return ex['audio_data']['speech_image'][:, 0] + \
114 | ex['audio_data']['noise_image'][0]
115 | elif name == 'image_1_noise':
116 | return ex['audio_data']['speech_image'][:, 1] + \
117 | ex['audio_data']['noise_image'][0]
118 | else:
119 | raise ValueError(name)
120 | 
121 | speech_prediction = get_signal(ex, prediction)
122 | speech_source = get_signal(ex, source)
123 | 
124 | metric = OutputMetrics(
125 | speech_prediction=speech_prediction,
126 | speech_source=speech_source,
127 | sample_rate=8000,
128 | enable_si_sdr=True,
129 | )
130 | 
131 | result = metric.as_dict()
132 | del result['mir_eval_sxr_selection']
133 | del result['mir_eval_sxr_sar']
134 | del result['mir_eval_sxr_sir']
135 | 
136 | return result
137 | 
138 | 
139 | @experiment.command
140 | def summary(out):
141 | if dlp_mpi.IS_MASTER:
142 | if isinstance(out, str):
143 | assert out.endswith('.json'), out
144 | with open(out, 'r') as fd:
145 | data = json.load(fd)
146 | else:
147 | data = out
148 | 
149 | df = pd.DataFrame(data)
150 | 
151 | def force_order(df, key):
152 | # https://stackoverflow.com/a/28686885/5766934
153 | df[key] = df[key].astype("category")
154 | df[key].cat.set_categories(pd.unique(df['source']), inplace=True)
155 | 
156 | force_order(df, 'source')
157 | force_order(df, 'prediction')
158 | 
159 | with pd.option_context(
160 | "display.precision", 2,
161 | 'display.width', 200,
162 | ):
163 | # print(pd.pivot_table(
164 | # df,
165 | # index=['prediction', 'source'],
166 | # columns='score_name',
167 | # values='value',
168 | # aggfunc=np.mean, # average over examples
169 | # ))
170 | print()
171 | print(pd.pivot_table(
172 | df.query('score_name == "mir_eval_sxr_sdr"'),
173 | index=['prediction'],
174 | columns=['score_name', 'source'],
175 | values='value',
176 | aggfunc=np.mean, # average over examples
177 | ))
178 | print()
179 | print(pd.pivot_table(
180 | 
df.query('score_name == "si_sdr"'), 181 | index=['prediction'], 182 | columns=['score_name', 'source'], 183 | values='value', 184 | aggfunc=np.mean, # average over examples 185 | )) 186 | 187 | 188 | @experiment.automain 189 | def main(_run, out): 190 | if dlp_mpi.IS_MASTER: 191 | from sacred.commands import print_config 192 | print_config(_run) 193 | 194 | ds = get_dataset() 195 | 196 | data = [] 197 | 198 | for ex in dlp_mpi.split_managed(ds.sort(), allow_single_worker=True): 199 | for prediction in [ 200 | 'source', 201 | 'early_0', 202 | 'early_1', 203 | 'image_0', 204 | 'image_1', 205 | 'image_0_noise', 206 | 'image_1_noise', 207 | ]: 208 | for source in [ 209 | 'source', 210 | 'early_0', 211 | 'early_1', 212 | 'image_0', 213 | 'image_1', 214 | 'image_0_noise', 215 | 'image_1_noise', 216 | ]: 217 | scores = get_scores(ex, prediction=prediction, source=source) 218 | for score_name, score_value in scores.items(): 219 | data.append(dict( 220 | score_name=score_name, 221 | prediction=prediction, 222 | source=source, 223 | example_id=ex['example_id'], 224 | value=score_value, 225 | )) 226 | 227 | data = dlp_mpi.gather(data) 228 | 229 | if dlp_mpi.IS_MASTER: 230 | data = [ 231 | entry 232 | for worker_data in data 233 | for entry in worker_data 234 | ] 235 | 236 | if out is not None: 237 | assert isinstance(out, str), out 238 | assert out.endswith('.json'), out 239 | print(f'Write details to {out}.') 240 | dump_json(data, out) 241 | 242 | summary(data) 243 | -------------------------------------------------------------------------------- /sms_wsj/io.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | import numpy as np 4 | import soundfile 5 | 6 | 7 | class NumpyEncoder(json.JSONEncoder): 8 | # https://stackoverflow.com/a/47626762/5766934 9 | def default(self, obj): 10 | if isinstance(obj, np.ndarray): 11 | return obj.tolist() 12 | return json.JSONEncoder.default(self, obj) 13 | 14 | 15 | def dump_json(obj, file, indent=2): 16 | with open(file, 'w') as fd: 17 | json.dump(obj, fd, cls=NumpyEncoder, indent=indent) 18 | 19 | 20 | def dump_audio(obj, file, samplerate=8000, mkdir=True, normalize=True): 21 | if normalize: 22 | # Correction, because the allowed values are in the range [-1, 1). 
23 | # => "1" is not a valid value
24 | correction = (2**15 - 1) / (2**15)
25 | obj = obj * (correction / np.amax(np.abs(obj)))
26 | 
27 | if isinstance(file, Path):
28 | file = str(file)
29 | try:
30 | soundfile.write(
31 | file=file,
32 | data=obj,
33 | samplerate=samplerate,
34 | )
35 | except RuntimeError:
36 | if mkdir:
37 | # Assume mkdir is rarely necessary, hence first try to write
38 | Path(file).parent.mkdir(
39 | parents=True,
40 | exist_ok=True, # Allow concurrent mkdir
41 | )
42 | soundfile.write(
43 | file=file,
44 | data=obj,
45 | samplerate=samplerate,
46 | )
47 | else:
48 | raise
49 | 
--------------------------------------------------------------------------------
/sms_wsj/kaldi/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fgnt/sms_wsj/f6b94bb987d620fc8bbe005a92bd042e09a68d9b/sms_wsj/kaldi/__init__.py
--------------------------------------------------------------------------------
/sms_wsj/kaldi/get_kaldi_wer.py:
--------------------------------------------------------------------------------
1 | """
2 | Example call on local machine:
3 | Automatically takes all available cores:
4 | $ python -m sms_wsj.kaldi.get_kaldi_wer -F /EXP/DIR with kaldi_data_dir=/KALDI/DATA/DIR model_egs_dir=/MODEL/EGS/DIR
5 | 
6 | Uses the json_path to create a kaldi data dir
7 | $ python -m sms_wsj.kaldi.get_kaldi_wer -F /EXP/DIR with json_path=/JSON/PATH model_egs_dir=/MODEL/EGS/DIR
8 | 
9 | Evaluates audio data in audio_dir; expects audio_dir/dataset to exist and audio files of the format {example_id}_0.wav. The format can be changed using the variable id_to_file_name
10 | $ python -m sms_wsj.kaldi.get_kaldi_wer -F /EXP/DIR with audio_dir=/AUDIO/DIR json_path=/JSON/PATH model_egs_dir=/MODEL/EGS/DIR
11 | 
12 | The following command is used to directly decode a dataset. Expects kaldi_data_dir/dataset to exist.
13 | $ python -m sms_wsj.kaldi.get_kaldi_wer -F /EXP/DIR decode with kaldi_data_dir=/KALDI/DATA/DIR model_egs_dir=/MODEL/EGS/DIR dataset=test_eval92
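(Added note on the audio_dir call above, with hypothetical IDs: the default
template id_to_file_name='{id}_{spk}.wav' maps example_id='4k0c0301_4k6c030t_0'
and speaker index 0 to the wav file name '4k0c0301_4k6c030t_0_0.wav'.)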
14 | 
15 | 
16 | Example call on pc2 (HPC system in Paderborn):
17 | $ ccsalloc --group=hpc-prf-nt1 --res=rset=64:mem=2G:ncpus=1 -t 2h ompi -- python -m sms_wsj.kaldi.get_kaldi_wer -F /EXP/DIR decode with kaldi_data_dir=/KALDI/DATA/DIR model_egs_dir=/MODEL/EGS/DIR dataset=test_eval92
18 | 
19 | """
20 | 
21 | import os
22 | from pathlib import Path
23 | from shutil import copytree
24 | 
25 | import sacred
26 | from lazy_dataset.database import JsonDatabase
27 | from sms_wsj.kaldi.utils import calculate_mfccs, calculate_ivectors
28 | from sms_wsj.kaldi.utils import create_data_dir_from_audio_dir
29 | from sms_wsj.kaldi.utils import create_kaldi_dir, create_data_dir
30 | from sms_wsj.kaldi.utils import run_process
31 | 
32 | ex = sacred.Experiment('Kaldi array')
33 | kaldi_root = Path(os.environ['KALDI_ROOT'])
34 | 
35 | 
36 | @ex.command
37 | def create_dir(
38 | audio_dir: Path, dataset_names=None, base_dir=None, json_path=None, db=None,
39 | data_type='sms_enh', id_to_file_name='{id}_{spk}.wav', target_speaker=0,
40 | sample_rate=8000,
41 | ):
42 | """
43 | 
44 | Args:
45 | audio_dir: path to audio_dir
46 | dataset_names: datasets to create a data_dir for
47 | base_dir: directory in which all information is copied or generated
48 | json_path: path to wsj_bss.json file
49 | db: JsonDatabase object
50 | data_type: name of data type to evaluate
51 | id_to_file_name: template to get the wav file name from the example_id
52 | target_speaker: index of speaker to decode
53 | sample_rate:
54 | 
55 | Returns:
56 | 
57 | """
58 | if base_dir is None:
59 | assert len(ex.current_run.observers) == 1, (
60 | 'FileObserver` missing. Add a `FileObserver` with `-F foo/bar/`.'
61 | )
62 | base_dir = Path(
63 | ex.current_run.observers[0].basedir).expanduser().resolve()
64 | 
65 | audio_dir = Path(audio_dir).expanduser().resolve()
66 | create_data_dir_from_audio_dir(
67 | audio_dir, base_dir, id_to_file_name=id_to_file_name, db=db,
68 | json_path=json_path, dataset_names=dataset_names, data_type=data_type,
69 | target_speaker=target_speaker, sample_rate=sample_rate
70 | )
71 | 
72 | 
73 | @ex.command
74 | def decode(model_egs_dir, dataset_dir, base_dir=None, model_data_type='sms',
75 | model_dir='chain/tdnn1a_sp', ivector_dir=True, extractor_dir=None,
76 | data_type='sms_enh', hires=True, num_jobs=8, kaldi_cmd='run.pl'):
77 | """
78 | 
79 | Args:
80 | model_egs_dir: path to the egs dir the model was trained in
81 | dataset_dir: kaldi egs dir for the decoding
82 | e.g.: "/data/cv_dev_93"
83 | base_dir: directory in which all information is copied or generated
84 | model_data_type: data type on which the model was trained
85 | model_dir: name of model or Path to model_dir
86 | ivector_dir: directory or name for the ivectors (may be None or False)
87 | extractor_dir: directory of the ivector extractor (may be None)
88 | data_type: name of data type to evaluate
89 | hires: flag for using high resolution mfcc features (True / False)
90 | num_jobs: number of parallel jobs
91 | kaldi_cmd: kaldi cmd, for example run.pl, ssh.pl or queue.pl
92 | 
93 | Returns:
94 | 
95 | """
96 | 
97 | if base_dir is None:
98 | assert len(ex.current_run.observers) == 1, (
99 | 'FileObserver` missing. Add a `FileObserver` with `-F foo/bar/`.'
100 | )
101 | base_dir = Path(ex.current_run.observers[0].basedir)
102 | base_dir = base_dir.expanduser().resolve()
103 | dataset_dir = dataset_dir.expanduser().resolve()
104 | assert dataset_dir.exists(), dataset_dir
105 | copytree(dataset_dir, base_dir / 'data' / dataset_dir.name,
106 | symlinks=True)
107 | dataset_dir = base_dir / 'data' / dataset_dir.name
108 | run_process([
109 | f'utils/fix_data_dir.sh', f'{dataset_dir}'],
110 | cwd=str(base_dir), stdout=None, stderr=None)
111 | else:
112 | base_dir = base_dir.expanduser().resolve()
113 | model_egs_dir = Path(model_egs_dir).expanduser().resolve()
114 | if isinstance(model_dir, str):
115 | model_dir = model_egs_dir / 'exp' / model_data_type / model_dir
116 | 
117 | assert model_dir.exists(), f'{model_dir} does not exist'
118 | 
119 | os.environ['PATH'] = f'{base_dir}/utils:{os.environ["PATH"]}'
120 | decode_dir = base_dir / 'exp' / model_data_type / model_dir.name
121 | if not decode_dir.exists():
122 | decode_dir.mkdir(parents=True)
123 | [os.symlink(str(file), str(decode_dir / file.name))
124 | for file in model_dir.glob('*') if file.is_file()]
125 | assert (decode_dir / 'final.mdl').exists(), (
126 | f'final.mdl not in decode_dir: {decode_dir}, '
127 | f'maybe using the wrong model_egs_dir: {model_egs_dir}?'
128 | )
129 | decode_name = f'decode_{data_type}_{dataset_dir.name}'
130 | (decode_dir / decode_name).mkdir(exist_ok=False)
131 | if not base_dir == model_egs_dir and not (base_dir / 'steps').exists():
132 | create_kaldi_dir(base_dir, model_egs_dir, exist_ok=True)
133 | if kaldi_cmd == 'ssh.pl':
134 | CCS_NODEFILE = Path(os.environ['CCS_NODEFILE'])
135 | (base_dir / '.queue').mkdir()
136 | (base_dir / '.queue' / 'machines').write_text(
137 | CCS_NODEFILE.read_text())
138 | elif kaldi_cmd == 'run.pl':
139 | pass
140 | else:
141 | raise ValueError(kaldi_cmd)
142 | config = 'mfcc_hires.conf' if hires else 'mfcc.conf'
143 | calculate_mfccs(base_dir, dataset_dir, num_jobs=num_jobs,
144 | config=config, recalc=True, kaldi_cmd=kaldi_cmd)
145 | ivector_dir = calculate_ivectors(
146 | ivector_dir, base_dir, dataset_dir, extractor_dir, model_egs_dir,
147 | model_data_type, data_type, num_jobs, kaldi_cmd
148 | )
149 | run_process([
150 | 'steps/nnet3/decode.sh', '--acwt', '1.0',
151 | '--post-decode-acwt', '10.0',
152 | '--extra-left-context', '0', '--extra-right-context', '0',
153 | '--extra-left-context-initial', '0', '--extra-right-context-final',
154 | '0', '--frames-per-chunk', '140', '--nj', str(num_jobs),
155 | '--cmd', f'{kaldi_cmd}', '--online-ivector-dir',
156 | str(ivector_dir), f'{model_dir.parent}/tree_a_sp/graph_tgpr',
157 | str(dataset_dir), str(decode_dir / decode_name)],
158 | cwd=str(base_dir),
159 | stdout=None, stderr=None
160 | )
161 | print((decode_dir / decode_name / 'scoring_kaldi' / 'best_wer'
162 | ).read_text())
163 | 
164 | 
165 | @ex.config
166 | def default():
167 | """
168 | If audio_dir and json_path are defined, the wavs in audio_dir will be
169 | decoded. If necessary a mapping from example_id to the wav names can be
170 | specified using id_to_file_name.
171 | If kaldi_data_dir is defined, it will be used as data_dir for decoding. In
172 | this case audio_dir has to be None.
173 | If neither audio_dir nor kaldi_data_dir is defined, but json_path is not
174 | None, kaldi data_dirs for all data_type and dataset_names are created and
175 | decoded.
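A schematic overview of the three modes (added illustration; the paths are
placeholders):
audio_dir=/A json_path=/J  -> decode the wav files found in audio_dir
kaldi_data_dir=/K          -> decode an existing kaldi data dir
json_path=/J               -> create kaldi data_dirs for all data_type and
                              dataset_names, then decode them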
176 | 
177 | model_egs_dir: egs directory of the trained model with data and exp dir
178 | num_jobs: if not specified, takes the number of cores as default
179 | 
180 | """
181 | model_egs_dir = None
182 | 
183 | # Only one of these two variables has to be defined
184 | audio_dir = None
185 | kaldi_data_dir = None
186 | 
187 | json_path = None
188 | 
189 | model_data_type = 'sms_single_speaker'
190 | dataset_names = ['test_eval92', 'cv_dev93']
191 | 
192 | # This is only used when decode is called directly
193 | if isinstance(dataset_names, str):
194 | dataset_dir = f'{kaldi_data_dir}/{dataset_names}'
195 | else:
196 | dataset_dir = None
197 | if kaldi_data_dir is None and audio_dir is None and dataset_dir is None:
198 | data_type = ['wsj_8k', 'sms_early', 'sms_image',
199 | 'sms_single_speaker', 'sms']
200 | else:
201 | data_type = 'sms_enh'
202 | 
203 | # only used with audio_dir
204 | id_to_file_name = '{id}_{spk}.wav'
205 | # id_to_file_name = '{}_{}.wav' is another possible default, but only
206 | # if the first {} represents the example id and the second the speaker id
207 | target_speaker = [0, 1]
208 | 
209 | ref_channels = 0
210 | 
211 | if ref_channels > 0 and isinstance(data_type, (list, tuple)):
212 | assert 'wsj_8k' not in data_type, data_type
213 | else:
214 | assert not data_type == 'wsj_8k', (ref_channels, data_type)
215 | 
216 | # am specific values which usually do not have to be changed
217 | ivector_dir = True
218 | extractor_dir = 'nnet3/extractor'
219 | model_dir = 'chain/tdnn1a_sp'
220 | hires = True
221 | kaldi_cmd = 'run.pl'
222 | sample_rate = 8000
223 | 
224 | # only used for the Paderborn parallel computing center
225 | if 'CCS_NODEFILE' in os.environ:
226 | num_jobs = len(list(
227 | Path(os.environ['CCS_NODEFILE']).read_text().strip().splitlines()
228 | ))
229 | else:
230 | # WSJ dev has only 8 speakers and Kaldi fails when num_jobs is higher.
231 | num_jobs = min(8, os.cpu_count())
232 | 
233 | 
234 | def check_config_element(element):
235 | if element is not None and not isinstance(element, bool):
236 | element_path = element
237 | if Path(element_path).exists():
238 | element_path = Path(element_path)
239 | elif isinstance(element, bool):
240 | element_path = element
241 | else:
242 | element_path = None
243 | return element_path
244 | 
245 | 
246 | @ex.automain
247 | def run(_config, _run, audio_dir, kaldi_data_dir, json_path):
248 | assert Path(kaldi_root).exists(), kaldi_root
249 | 
250 | assert len(ex.current_run.observers) == 1, (
251 | 'FileObserver` missing. Add a `FileObserver` with `-F foo/bar/`.'
252 | )
253 | base_dir = Path(ex.current_run.observers[0].basedir)
254 | base_dir = base_dir.expanduser().resolve()
255 | if audio_dir is not None:
256 | audio_dir = Path(audio_dir).expanduser().resolve()
257 | assert audio_dir.exists(), audio_dir
258 | json_path = Path(json_path).expanduser().resolve()
259 | assert json_path.exists(), json_path
260 | db = JsonDatabase(json_path)
261 | elif kaldi_data_dir is not None:
262 | kaldi_data_dir = Path(kaldi_data_dir).expanduser().resolve()
263 | assert kaldi_data_dir.exists(), kaldi_data_dir
264 | assert json_path is None, json_path
265 | elif json_path is not None:
266 | json_path = Path(json_path).expanduser().resolve()
267 | assert json_path.exists(), json_path
268 | db = JsonDatabase(json_path)
269 | else:
270 | raise ValueError('Either json_path, audio_dir or kaldi_data_dir has '
271 | 'to be defined.')
272 | if _config['model_egs_dir'] is None:
273 | model_egs_dir = kaldi_root / 'egs' / 'sms_wsj' / 's5'
274 | else:
275 | model_egs_dir = Path(_config['model_egs_dir']).expanduser().resolve()
276 | assert model_egs_dir.exists(), model_egs_dir
277 | 
278 | dataset_names = _config['dataset_names']
279 | if not isinstance(dataset_names, (tuple, list)):
280 | dataset_names = [dataset_names]
281 | data_type = _config['data_type']
282 | if not isinstance(data_type, (tuple, list)):
283 | data_type = [data_type]
284 | 
285 | kaldi_cmd = _config['kaldi_cmd']
286 | if not base_dir == model_egs_dir and not (base_dir / 'steps').exists():
287 | create_kaldi_dir(base_dir, model_egs_dir, exist_ok=True,
288 | sample_rate=_config['sample_rate'])
289 | if kaldi_cmd == 'ssh.pl':
290 | CCS_NODEFILE = Path(os.environ['CCS_NODEFILE'])
291 | (base_dir / '.queue').mkdir()
292 | (base_dir / '.queue' / 'machines').write_text(
293 | CCS_NODEFILE.read_text())
294 | elif kaldi_cmd == 'run.pl':
295 | pass
296 | else:
297 | raise ValueError(kaldi_cmd)
298 | 
299 | for d_type in data_type:
300 | for dset in dataset_names:
301 | dataset_dir = base_dir / 'data' / d_type / dset
302 | if audio_dir is not None:
303 | assert len(data_type) == 1, data_type
304 | create_dir(
305 | audio_dir, base_dir=base_dir, db=db, dataset_names=dset
306 | )
307 | elif kaldi_data_dir is None:
308 | create_data_dir(
309 | base_dir, db=db, data_type=d_type, dataset_names=dset,
310 | ref_channels=_config['ref_channels'],
311 | target_speaker=_config['target_speaker'],
312 | sample_rate=_config['sample_rate'],
313 | )
314 | else:
315 | assert len(data_type) == 1, (
316 | 'when using a predefined kaldi_data_dir not more than one '
317 | 'data_type should be defined. 
Better use the decode '
318 | 'command directly'
319 | )
320 | copytree(kaldi_data_dir / dset, dataset_dir, symlinks=True)
321 | run_process([
322 | f'utils/fix_data_dir.sh', f'{dataset_dir}'],
323 | cwd=str(base_dir), stdout=None, stderr=None)
324 | 
325 | decode(
326 | base_dir=base_dir,
327 | model_egs_dir=model_egs_dir,
328 | dataset_dir=dataset_dir,
329 | model_dir=check_config_element(_config['model_dir']),
330 | ivector_dir=check_config_element(_config['ivector_dir']),
331 | extractor_dir=check_config_element(_config['extractor_dir']),
332 | data_type=d_type
333 | )
334 | 
--------------------------------------------------------------------------------
/sms_wsj/kaldi/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import stat
4 | import subprocess
5 | from collections import defaultdict
6 | from functools import partial
7 | from pathlib import Path
8 | 
9 | from lazy_dataset.database import JsonDatabase
10 | from sms_wsj import git_root
11 | 
12 | DB2AudioKeyMapper = dict(
13 | wsj_8k='original_source',
14 | sms_early='speech_reverberation_early',
15 | sms='observation',
16 | noise='noise_image',
17 | sms_late='speech_reverberation_tail'
18 | )
19 | 
20 | kaldi_root = Path(os.environ['KALDI_ROOT'])
21 | 
22 | REQUIRED_FILES = []
23 | REQUIRED_DIRS = ['data/lang', 'data/local',
24 | 'local', 'steps', 'utils']
25 | DIRS_WITH_CHANGEABLE_FILES = ['conf', 'data/lang_test_tgpr',
26 | 'data/lang_test_tg']
27 | 
28 | 
29 | def create_kaldi_dir(egs_path, org_dir=None, exist_ok=False, sample_rate=8000):
30 | """
31 | 
32 | Args:
33 | egs_path:
34 | org_dir:
35 | An egs folder (e.g. $KALDI_ROOT/egs/wsj/s5). This folder is used as
36 | reference to create the new egs folder.
37 | e.g.
38 | - make symlinks to the 'local', 'steps', 'utils', 'data/lang' and
39 | 'data/local' folder
40 | - copy 'conf', 'data/lang_test_tgpr' and 'data/lang_test_tg' to
41 | the new folder
42 | 
43 | 
44 | 
45 | Returns:
46 | 
47 | """
48 | print(f'Create {egs_path} directory')
49 | (egs_path / 'data').mkdir(exist_ok=exist_ok, parents=True)
50 | if org_dir is None:
51 | org_dir = (egs_path / '..' / '..'
/ 'wsj' / 's5').resolve() 52 | for file in REQUIRED_FILES: 53 | os.symlink(org_dir / file, egs_path / file) 54 | for dirs in REQUIRED_DIRS: 55 | os.symlink(org_dir / dirs, egs_path / dirs) 56 | for dirs in DIRS_WITH_CHANGEABLE_FILES: 57 | shutil.copytree(org_dir / dirs, egs_path / dirs) 58 | for script in (git_root / 'scripts').glob('*'): 59 | if script.name in ['path.sh', 'cmd.sh']: 60 | new_script_path = egs_path / script.name 61 | else: 62 | (egs_path / 'local_sms').mkdir(exist_ok=True) 63 | new_script_path = egs_path / 'local_sms' / script.name 64 | 65 | shutil.copyfile(script, new_script_path) 66 | if script.name == 'path.sh': 67 | with new_script_path.open('r+') as f: 68 | content = f.read() 69 | f.seek(0, 0) 70 | f.write(f'export KALDI_ROOT={kaldi_root}' + '\n' + content) 71 | # make script executable 72 | st = os.stat(new_script_path) 73 | os.chmod(new_script_path, st.st_mode | stat.S_IEXEC) 74 | 75 | if sample_rate != 16000: 76 | for file in ['mfcc.conf', 'mfcc_hires.conf']: 77 | with (egs_path / 'conf' / file).open('a') as fd: 78 | fd.writelines(f"--sample-frequency={sample_rate}\n") 79 | 80 | 81 | def _get_wav_command_for_json(example, ref_ch, spk, audio_key): 82 | if isinstance(audio_key, (list, tuple)): 83 | mix_command = 'sox -m -v 1 ' + ' -v 1 '.join( 84 | [str(example['audio_path'][audio][spk]) 85 | if isinstance(example['audio_path'][audio], (list, tuple)) 86 | else str(example['audio_path'][audio]) for audio in audio_key] 87 | ) 88 | wav_command = f'{mix_command} -t wav - | sox -t wav -' \ 89 | f' -t wav -b 16 - remix {ref_ch + 1} |' 90 | else: 91 | if isinstance(example['audio_path'][audio_key], (list, tuple)): 92 | wav = example['audio_path'][audio_key][spk] 93 | else: 94 | wav = example['audio_path'][audio_key] 95 | wav_command = f'sox {wav} -t wav -b 16 - remix {ref_ch + 1} |' 96 | return wav_command 97 | 98 | 99 | def _get_wav_command_for_audio_dir( 100 | example, ref_ch, spk, audio_dir, id_to_file_name_fn): 101 | dataset_name = example['dataset'] 102 | ex_id = example['example_id'] 103 | try: 104 | audio_path = audio_dir / dataset_name / id_to_file_name_fn(ex_id, spk) 105 | assert audio_path.exists(), audio_path 106 | except AssertionError: 107 | audio_path = audio_dir / id_to_file_name_fn(ex_id, spk) 108 | assert audio_path.exists(), audio_path 109 | 110 | wav_command = f'sox {audio_path} -t wav -b 16 - remix {ref_ch + 1} |' 111 | return wav_command 112 | 113 | 114 | def create_data_dir( 115 | kaldi_dir, db=None, json_path=None, dataset_names=None, 116 | data_type='wsj_8k', target_speaker=0, ref_channels=0, 117 | sample_rate=8000): 118 | """ 119 | Wrapper calling _create_data_dir for data_dirs from json or db object 120 | """ 121 | if data_type == 'sms_single_speaker': 122 | audio_key = [DB2AudioKeyMapper[data] 123 | for data in ['sms_early', 'sms_late', 'noise']] 124 | elif data_type == 'sms_image': 125 | audio_key = [DB2AudioKeyMapper[data] 126 | for data in ['sms_early', 'sms_late']] 127 | else: 128 | audio_key = DB2AudioKeyMapper[data_type] 129 | get_wav_command_fn = partial( 130 | _get_wav_command_for_json, audio_key=audio_key 131 | ) 132 | _create_data_dir( 133 | get_wav_command_fn, kaldi_dir=kaldi_dir, db=db, json_path=json_path, 134 | dataset_names=dataset_names, data_type=data_type, 135 | target_speaker=target_speaker, ref_channels=ref_channels, 136 | sample_rate=sample_rate 137 | ) 138 | 139 | 140 | def create_data_dir_from_audio_dir( 141 | audio_dir, kaldi_dir, id_to_file_name='{id}_{spk}.wav', db=None, 142 | json_path=None, dataset_names=None, 
data_type='wsj_8k', 143 | target_speaker=0, ref_channels=0, sample_rate=8000, 144 | ): 145 | """ 146 | Wrapper calling _create_data_dir for data_dirs from audio_dir 147 | """ 148 | if isinstance(id_to_file_name, str): 149 | 150 | if '{}' in id_to_file_name: 151 | id_to_file_name_fn = lambda _id, spk: id_to_file_name.format(_id, spk) 152 | else: 153 | id_to_file_name_fn = lambda _id, spk: id_to_file_name.format( 154 | id=_id, spk=spk) 155 | else: 156 | id_to_file_name_fn = id_to_file_name 157 | assert callable(id_to_file_name_fn), id_to_file_name_fn 158 | if isinstance(target_speaker, (list, tuple)) and len(target_speaker) > 1: 159 | assert id_to_file_name_fn('id1', 'spk1') != id_to_file_name_fn( 160 | 'id1', 'spk2'), (id_to_file_name_fn('id1', 'spk1'), 161 | id_to_file_name_fn('id1', 'spk2')) 162 | assert id_to_file_name_fn('id1', 'spk1') != id_to_file_name_fn( 163 | 'id2', 'spk1'), (id_to_file_name_fn('id1', 'spk1'), 164 | id_to_file_name_fn('id2', 'spk1')) 165 | 166 | get_wav_command_fn = partial( 167 | _get_wav_command_for_audio_dir, audio_dir=audio_dir, 168 | id_to_file_name_fn=id_to_file_name_fn 169 | ) 170 | _create_data_dir( 171 | get_wav_command_fn, kaldi_dir=kaldi_dir, db=db, json_path=json_path, 172 | dataset_names=dataset_names, data_type=data_type, 173 | target_speaker=target_speaker, ref_channels=ref_channels, 174 | sample_rate=sample_rate 175 | ) 176 | 177 | 178 | def _create_data_dir( 179 | get_wav_command_fn, kaldi_dir, db=None, json_path=None, 180 | dataset_names=None, data_type='wsj_8k', target_speaker=0, 181 | ref_channels=0, sample_rate=8000, 182 | ): 183 | """ 184 | 185 | Args: 186 | get_wav_command_fn: 187 | kaldi_dir: 188 | db: 189 | json_path: 190 | dataset_names: 191 | data_type: 192 | target_speaker: 193 | ref_channels: 194 | 195 | Returns: 196 | 197 | """ 198 | 199 | assert not (db is None and json_path is None), (db, json_path) 200 | if db is None: 201 | db = JsonDatabase(json_path) 202 | 203 | kaldi_dir = Path(kaldi_dir).expanduser().resolve() 204 | 205 | data_dir = kaldi_dir / 'data' / data_type 206 | data_dir.mkdir(exist_ok=True, parents=True) 207 | 208 | if not isinstance(ref_channels, (list, tuple)): 209 | ref_channels = [ref_channels] 210 | example_id_to_wav = dict() 211 | example_id_to_speaker = dict() 212 | example_id_to_trans = dict() 213 | example_id_to_duration = dict() 214 | speaker_to_gender = defaultdict(lambda: defaultdict(list)) 215 | dataset_to_example_id = defaultdict(list) 216 | 217 | if dataset_names is None: 218 | dataset_names = ('train_si284', 'cv_dev93', 'test_eval92') 219 | elif isinstance(dataset_names, str): 220 | dataset_names = [dataset_names] 221 | if not isinstance(target_speaker, (list, tuple)): 222 | target_speaker = [target_speaker] 223 | assert not any([ 224 | (data_dir / dataset_name).exists() for dataset_name in dataset_names 225 | ]), ( 226 | 'One of the following directories already exists: ' 227 | f'{[data_dir / ds_name for ds_name in dataset_names]}\n' 228 | 'Delete them if you want to restart this stage' 229 | ) 230 | 231 | print( 232 | 'Create data dir for ' 233 | f'{", ".join([f"{data_type}/{ds_name}" for ds_name in dataset_names])} ' 234 | 'data' 235 | ) 236 | 237 | dataset = db.get_dataset(dataset_names) 238 | for example in dataset: 239 | for ref_ch in ref_channels: 240 | org_example_id = example['example_id'] 241 | dataset_name = example['dataset'] 242 | for t_spk in target_speaker: 243 | speaker_id = example['speaker_id'][t_spk] 244 | example_id = speaker_id + '_' + org_example_id 245 | example_id += f'_c{ref_ch}' 
if len(ref_channels) > 1 else '' 246 | example_id_to_wav[example_id] = get_wav_command_fn( 247 | example, ref_ch=ref_ch, spk=t_spk) 248 | try: 249 | transcription = example['kaldi_transcription'][t_spk] 250 | except KeyError: 251 | transcription = example['transcription'][t_spk] 252 | example_id_to_trans[example_id] = transcription 253 | 254 | example_id_to_speaker[example_id] = speaker_id 255 | gender = example['gender'][t_spk] 256 | speaker_to_gender[dataset_name][speaker_id] = gender 257 | if isinstance(example['num_samples'], dict): 258 | num_samples = example['num_samples']['observation'] 259 | else: 260 | num_samples = example['num_samples'] 261 | example_id_to_duration[ 262 | example_id] = f"{num_samples / sample_rate:.2f}" 263 | dataset_to_example_id[dataset_name].append(example_id) 264 | 265 | assert len(example_id_to_speaker) > 0, dataset 266 | for dataset_name in dataset_names: 267 | path = data_dir / dataset_name 268 | path.mkdir(exist_ok=False, parents=False) 269 | for name, dictionary in ( 270 | ("utt2spk", example_id_to_speaker), 271 | ("text", example_id_to_trans), 272 | ("utt2dur", example_id_to_duration), 273 | ("wav.scp", example_id_to_wav) 274 | ): 275 | dictionary = {key: value for key, value in dictionary.items() 276 | if key in dataset_to_example_id[dataset_name]} 277 | 278 | assert len(dictionary) > 0, (dataset_name, name) 279 | if name == 'utt2dur': 280 | dump_keyed_lines(dictionary, path / 'reco2dur') 281 | dump_keyed_lines(dictionary, path / name) 282 | dictionary = speaker_to_gender[dataset_name] 283 | assert len(dictionary) > 0, (dataset_name, name) 284 | dump_keyed_lines(dictionary, path / 'spk2gender') 285 | run_process([ 286 | f'utils/fix_data_dir.sh', f'{path}'], 287 | cwd=str(kaldi_dir), stdout=None, stderr=None 288 | ) 289 | 290 | 291 | def calculate_mfccs(base_dir, dataset, num_jobs=20, config='mfcc.conf', 292 | recalc=False, kaldi_cmd='run.pl'): 293 | """ 294 | 295 | :param base_dir: kaldi egs directory with steps and utils dir 296 | :param dataset: name of folder in data 297 | :param num_jobs: number of parallel jobs 298 | :param config: mfcc config 299 | :param recalc: recalc feats if already calculated 300 | :param kaldi_cmd: 301 | :return: 302 | """ 303 | base_dir = base_dir.expanduser().resolve() 304 | 305 | if isinstance(dataset, str): 306 | dataset = base_dir / 'data' / dataset 307 | assert dataset.exists(), dataset 308 | if not (dataset / 'feats.scp').exists() or recalc: 309 | run_process([ 310 | 'steps/make_mfcc.sh', '--nj', str(num_jobs), 311 | '--mfcc-config', f'{base_dir}/conf/{config}', 312 | '--cmd', f'{kaldi_cmd}', f'{dataset}', 313 | f'{dataset}/make_mfcc', f'{dataset}/mfcc'], 314 | cwd=str(base_dir), stdout=None, stderr=None 315 | ) 316 | 317 | if not (dataset / 'cmvn.scp').exists() or recalc: 318 | run_process([ 319 | f'steps/compute_cmvn_stats.sh', 320 | f'{dataset}', f'{dataset}/make_mfcc', f'{dataset}/mfcc'], 321 | cwd=str(base_dir), stdout=None, stderr=None 322 | ) 323 | run_process([ 324 | f'utils/fix_data_dir.sh', f'{dataset}'], 325 | cwd=str(base_dir), stdout=None, stderr=None 326 | ) 327 | 328 | 329 | def calculate_ivectors(ivector_dir, dest_dir, dataset_dir, extractor_dir=None, 330 | org_dir=None, model_data_type='sms', 331 | data_type='sms', num_jobs=8, kaldi_cmd='run.pl'): 332 | """ 333 | 334 | Args: 335 | ivector_dir: ivector directory may be a string, bool or Path 336 | dest_dir: kaldi egs directory with steps and utils dir 337 | dataset_dir: kaldi data dir 338 | extractor_dir: directory of the ivector extractor (may be 
None) 339 | org_dir: kaldi egs directory used if extractor_dir is only a string 340 | model_data_type: dataset specifier for the extractor data type 341 | data_type: dataset specifier for the input data 342 | num_jobs: number of parallel jobs 343 | kaldi_cmd: 344 | 345 | Returns: 346 | 347 | """ 348 | 349 | dest_dir = dest_dir.expanduser().resolve() 350 | 351 | if isinstance(ivector_dir, str): 352 | ivector_dir = dest_dir / 'exp' / model_data_type / 'nnet3' / \ 353 | ivector_dir 354 | elif ivector_dir is True: 355 | ivector_dir = dest_dir / 'exp' / model_data_type / 'nnet3' / ( 356 | f'ivectors_{data_type}_{dataset_dir.name}') 357 | elif isinstance(ivector_dir, Path): 358 | ivector_dir = ivector_dir 359 | else: 360 | raise ValueError(f'ivector_dir {ivector_dir} has to be either' 361 | f' a Path, a string or bolean') 362 | if not ivector_dir.exists(): 363 | if extractor_dir is None: 364 | extractor_dir = org_dir / f'exp/{model_data_type}/' \ 365 | f'nnet3/extractor' 366 | else: 367 | if isinstance(extractor_dir, str): 368 | extractor_dir = org_dir / f'exp/{model_data_type}/' \ 369 | f'{extractor_dir}' 370 | assert extractor_dir.exists(), extractor_dir 371 | print(f'Directory {ivector_dir} not found, estimating ivectors') 372 | run_process([ 373 | 'steps/online/nnet2/extract_ivectors_online.sh', 374 | '--cmd', f'{kaldi_cmd}', '--nj', f'{num_jobs}', f'{dataset_dir}', 375 | f'{extractor_dir}', str(ivector_dir)], 376 | cwd=str(dest_dir), 377 | stdout=None, stderr=None 378 | ) 379 | return ivector_dir 380 | 381 | 382 | def get_alignments(egs_dir, num_jobs, kaldi_cmd='run.pl', 383 | gmm_data_type=None, data_type='sms_early', 384 | dataset_names=None): 385 | if dataset_names is None: 386 | dataset_names = ('train_si284', 'cv_dev93') 387 | if gmm_data_type is None: 388 | gmm_data_type = data_type 389 | 390 | for dataset in dataset_names: 391 | dataset_dir = egs_dir / 'data' / data_type / dataset 392 | if not (dataset_dir / 'feats.scp').exists(): 393 | calculate_mfccs(egs_dir, dataset_dir, num_jobs=num_jobs, 394 | kaldi_cmd=kaldi_cmd) 395 | run_process([ 396 | f'{egs_dir}/steps/align_fmllr.sh', 397 | '--cmd', kaldi_cmd, 398 | '--nj', str(num_jobs), 399 | f'{dataset_dir}', 400 | f'{egs_dir}/data/lang', 401 | f'{egs_dir}/exp/{gmm_data_type}/tri4b', 402 | f'{egs_dir}/exp/{data_type}/tri4b_ali_{dataset}' 403 | ], 404 | cwd=str(egs_dir) 405 | ) 406 | 407 | 408 | def run_process(cmd, cwd=None, stderr=None, stdout=None): 409 | if isinstance(cmd, str): 410 | shell = True 411 | else: 412 | shell = False 413 | subprocess.run( 414 | cmd, universal_newlines=True, shell=shell, stdout=stdout, 415 | stderr=stderr, check=True, env=None, cwd=cwd 416 | ) 417 | 418 | 419 | def dump_keyed_lines(data_dict: dict, file: Path): 420 | """ 421 | Used to write Kaldi files 422 | 423 | """ 424 | file = Path(file) 425 | file = Path(file).expanduser().resolve() 426 | if file.name in ['utt2dur', 'spk2gender']: 427 | kaldi_type = file.name 428 | else: 429 | kaldi_type = None 430 | items = data_dict.items() 431 | # text_file = Path(text_file) 432 | data = [] 433 | for k, text in items: 434 | if isinstance(text, list): 435 | text = ' '.join(map(str, text)) 436 | if kaldi_type == 'utt2dur': 437 | text_number = float(text) 438 | assert 0. 
< text_number < 1000., (
439 | f'Strange duration: {k}: {text_number} s'
440 | )
441 | elif kaldi_type == 'spk2gender':
442 | text = dict(male='m', female='f', m='m', f='f')[text]
443 | else:
444 | pass
445 | data.append(f'{k} {text}')
446 | 
447 | file.write_text('\n'.join(data) + '\n')
448 | 
449 | 
450 | def pc2_environ(kaldi_dir):
451 | CCS_NODEFILE = Path(os.environ['CCS_NODEFILE'])
452 | if (kaldi_dir / '.queue').exists():
453 | print('Deleting already existing .queue directory')
454 | shutil.rmtree(kaldi_dir / '.queue')
455 | (kaldi_dir / '.queue').mkdir()
456 | (kaldi_dir / '.queue' / 'machines').write_text(CCS_NODEFILE.read_text())
457 | 
--------------------------------------------------------------------------------
/sms_wsj/reverb/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fgnt/sms_wsj/f6b94bb987d620fc8bbe005a92bd042e09a68d9b/sms_wsj/reverb/__init__.py
--------------------------------------------------------------------------------
/sms_wsj/reverb/reverb_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Offers methods for calculating room impulse responses and convolutions of these
3 | with audio signals.
4 | """
5 | 
6 | import numpy as np
7 | import scipy
8 | import scipy.signal
9 | 
10 | eps = 1e-60
11 | window_length = 256
12 | 
13 | 
14 | # TODO: Refactor
15 | def generate_rir(
16 | room_dimensions,
17 | source_positions,
18 | sensor_positions,
19 | sound_decay_time,
20 | sample_rate=16000,
21 | filter_length=2 ** 13,
22 | sensor_orientations=None,
23 | sensor_directivity=None,
24 | sound_velocity=343
25 | ):
26 | """ Wrapper for different RIR generators. Will replace generate_RIR().
27 | 
28 | Args:
29 | room_dimensions: Numpy array with shape (3, 1)
30 | which holds coordinates x, y and z.
31 | source_positions: Numpy array with shape (3, number_of_sources)
32 | which holds coordinates x, y and z in each column.
33 | sensor_positions: Numpy array with shape (3, number_of_sensors)
34 | which holds coordinates x, y and z in each column.
35 | sound_decay_time: Reverberation time in seconds.
36 | sample_rate: Sampling rate in Hertz.
37 | filter_length: Filter length, typically 2**13.
38 | Use a longer filter for large reverberation times.
39 | sensor_orientations: Numpy array with shape (2, 1)
40 | which holds azimuth and elevation angle in each column.
41 | sensor_directivity: String determining directivity for all sensors.
42 | sound_velocity: Set to 343 m/s.
43 | 
44 | Returns: Numpy array of room impulse responses with
45 | shape (number_of_sources, number_of_sensors, filter_length).
46 | """ 47 | import rirgen 48 | room_dimensions = np.array(room_dimensions) 49 | source_positions = np.array(source_positions) 50 | sensor_positions = np.array(sensor_positions) 51 | 52 | if np.ndim(source_positions) == 1: 53 | source_positions = np.reshape(source_positions, (-1, 1)) 54 | if np.ndim(room_dimensions) == 1: 55 | room_dimensions = np.reshape(room_dimensions, (-1, 1)) 56 | if np.ndim(sensor_positions) == 1: 57 | sensor_positions = np.reshape(sensor_positions, (-1, 1)) 58 | 59 | assert room_dimensions.shape == (3, 1) 60 | assert source_positions.shape[0] == 3 61 | assert sensor_positions.shape[0] == 3 62 | 63 | number_of_sources = source_positions.shape[1] 64 | number_of_sensors = sensor_positions.shape[1] 65 | 66 | if sensor_orientations is None: 67 | sensor_orientations = np.zeros((2, number_of_sources)) 68 | else: 69 | raise NotImplementedError(sensor_orientations) 70 | 71 | if sensor_directivity is None: 72 | sensor_directivity = 'omnidirectional' 73 | else: 74 | raise NotImplementedError(sensor_directivity) 75 | 76 | assert filter_length is not None 77 | rir = np.zeros( 78 | (number_of_sources, number_of_sensors, filter_length), 79 | dtype=np.float64 80 | ) 81 | for k in range(number_of_sources): 82 | temp = rirgen.generate_rir( 83 | room_measures=room_dimensions[:, 0], 84 | source_position=source_positions[:, k], 85 | receiver_positions=sensor_positions.T, 86 | reverb_time=sound_decay_time, 87 | sound_velocity=sound_velocity, 88 | fs=sample_rate, 89 | n_samples=filter_length 90 | ) 91 | rir[k, :, :] = np.asarray(temp) 92 | 93 | assert rir.shape[0] == number_of_sources 94 | assert rir.shape[1] == number_of_sensors 95 | assert rir.shape[2] == filter_length 96 | 97 | return rir 98 | 99 | 100 | def blackman_harris_window(x): 101 | # Can not be replaced by from scipy.signal import blackmanharris. 102 | a0 = 0.35875 103 | a1 = 0.48829 104 | a2 = 0.14128 105 | a3 = 0.01168 106 | x = np.pi * (x - window_length / 2) / window_length 107 | x = a0 - a1 * np.cos(2.0 * x) + a2 * np.cos(4.0 * x) - a3 * np.cos(6.0 * x) 108 | return np.maximum(x, 0) 109 | 110 | 111 | def convolve(signal, impulse_response, truncate=False): 112 | """ Convolution of time signal with impulse response. 113 | 114 | Takes audio signals and the impulse responses according to their position 115 | and returns the convolution. The number of audio signals in x are required 116 | to correspond to the number of sources in the given RIR. 117 | Convolution is conducted through frequency domain via FFT. 118 | 119 | x = h conv s 120 | 121 | Args: 122 | signal: Time signal with shape (..., samples) 123 | impulse_response: Shape (..., sensors, filter_length) 124 | truncate: Truncates result to input signal length if True. 125 | 126 | Alternative args: 127 | signal: Time signal with shape (samples,) 128 | impulse_response: Shape (filter_length,) 129 | 130 | Returns: Convolution result with shape (..., sensors, length) or (length,) 131 | 132 | >>> signal = np.asarray([1, 2, 3]) 133 | >>> impulse_response = np.asarray([1, 1]) 134 | >>> print(convolve(signal, impulse_response)) 135 | [1. 3. 5. 3.] 
136 | 137 | >>> K, T, D, filter_length = 2, 12, 3, 5 138 | >>> signal = np.random.normal(size=(K, T)) 139 | >>> impulse_response = np.random.normal(size=(K, D, filter_length)) 140 | >>> convolve(signal, impulse_response).shape 141 | (2, 3, 16) 142 | 143 | >>> signal = np.random.normal(size=(T,)) 144 | >>> impulse_response = np.random.normal(size=(D, filter_length)) 145 | >>> convolve(signal, impulse_response).shape 146 | (3, 16) 147 | """ 148 | signal = np.array(signal) 149 | impulse_response = np.array(impulse_response) 150 | 151 | if impulse_response.ndim == 1: 152 | x = convolve(signal, impulse_response[None, ...], truncate=truncate) 153 | x = np.squeeze(x, axis=0) 154 | return x 155 | 156 | *independent, samples = signal.shape 157 | *independent_, sensors, filter_length = impulse_response.shape 158 | assert independent == independent_, f'signal.shape {signal.shape} does' \ 159 | f' not match impulse_response.shape {impulse_response.shape}' 160 | 161 | x = scipy.signal.fftconvolve( 162 | signal[..., None, :], 163 | impulse_response, 164 | axes=-1 165 | ) 166 | 167 | return x[..., :samples] if truncate else x 168 | 169 | 170 | def get_rir_start_sample(h, level_ratio=1e-1): 171 | """Finds start sample in a room impulse response. 172 | 173 | Selects that index as start sample where the first time 174 | a value larger than `level_ratio * max_abs_value` 175 | occurs. 176 | 177 | If you intend to use this heuristic, test it on simulated and real RIR 178 | first. This heuristic is developed on MIRD database RIRs and on some 179 | simulated RIRs but may not be appropriate for your database. 180 | 181 | If you want to use it to shorten impulse responses, keep the initial part 182 | of the room impulse response intact and just set the tail to zero. 183 | 184 | Params: 185 | h: Room impulse response with Shape (num_samples,) 186 | level_ratio: Ratio between start value and max value. 187 | 188 | >>> get_rir_start_sample(np.array([0, 0, 1, 0.5, 0.1])) 189 | 2 190 | """ 191 | assert level_ratio < 1, level_ratio 192 | if h.ndim > 1: 193 | assert h.shape[0] < 20, h.shape 194 | h = np.reshape(h, (-1, h.shape[-1])) 195 | return np.min( 196 | [get_rir_start_sample(h_, level_ratio=level_ratio) for h_ in h] 197 | ) 198 | 199 | abs_h = np.abs(h) 200 | max_index = np.argmax(abs_h) 201 | max_abs_value = abs_h[max_index] 202 | # +1 because python excludes the last value 203 | larger_than_threshold = abs_h[:max_index + 1] > level_ratio * max_abs_value 204 | 205 | # Finds first occurrence of max 206 | rir_start_sample = np.argmax(larger_than_threshold) 207 | return rir_start_sample 208 | 209 | 210 | if __name__ == "__main__": 211 | import doctest 212 | 213 | doctest.testmod() 214 | -------------------------------------------------------------------------------- /sms_wsj/reverb/rotation.py: -------------------------------------------------------------------------------- 1 | """Contains functions regarding 3D rotation matrices. 
Used in `scenario.py`."""
2 | import numpy as np
3 | 
4 | 
5 | def rot_x(alpha):
6 | """Returns rotation matrix."""
7 | return np.asarray(
8 | [
9 | [1, 0, 0],
10 | [0, np.cos(alpha), -np.sin(alpha)],
11 | [0, np.sin(alpha), np.cos(alpha)]
12 | ]
13 | )
14 | 
15 | 
16 | def rot_y(alpha):
17 | """Returns rotation matrix."""
18 | return np.asarray(
19 | [
20 | [np.cos(alpha), 0, np.sin(alpha)],
21 | [0, 1, 0],
22 | [-np.sin(alpha), 0, np.cos(alpha)]
23 | ]
24 | )
25 | 
26 | 
27 | def rot_z(alpha):
28 | """Returns rotation matrix."""
29 | return np.asarray(
30 | [
31 | [np.cos(alpha), -np.sin(alpha), 0],
32 | [np.sin(alpha), np.cos(alpha), 0],
33 | [0, 0, 1]
34 | ]
35 | )
36 | 
--------------------------------------------------------------------------------
/sms_wsj/reverb/scenario.py:
--------------------------------------------------------------------------------
1 | """
2 | Helps to quickly create source and sensor positions.
3 | Try it with the following code:
4 | 
5 | >>> import numpy as np
6 | >>> import sms_wsj.reverb.scenario as scenario
7 | >>> src = scenario.generate_random_source_positions(dims=2, sources=1000)
8 | >>> src[1, :] = np.abs(src[1, :])
9 | >>> mic = scenario.generate_sensor_positions(shape='linear', scale=0.1, number_of_sensors=6)
10 | """
11 | 
12 | import numpy as np
13 | from sms_wsj.reverb.rotation import rot_x, rot_y, rot_z
14 | 
15 | 
16 | def sample_from_random_box(center, edge_lengths, rng=np.random):
17 | """ Sample from a random box to get somewhat random locations.
18 | 
19 | >>> points = np.asarray([sample_from_random_box(
20 | ...     [[10], [20], [30]], [[1], [2], [3]]
21 | ... ) for _ in range(1000)])
22 | >>> import matplotlib.pyplot as plt
23 | >>> from mpl_toolkits.mplot3d import Axes3D
24 | >>> fig = plt.figure()
25 | >>> ax = fig.add_subplot(111, projection='3d')
26 | >>> _ = ax.scatter(points[:, 0, 0], points[:, 1, 0], points[:, 2, 0])
27 | >>> _ = plt.show()
28 | 
29 | Args:
30 | center: Original center (mean).
31 | edge_lengths: Edge length of the box to be sampled from.
32 | 
33 | Returns:
34 | 
35 | """
36 | center = np.asarray(center)
37 | edge_lengths = np.asarray(edge_lengths)
38 | return center + rng.uniform(
39 | low=-edge_lengths / 2,
40 | high=edge_lengths / 2
41 | )
42 | 
43 | 
44 | def generate_sensor_positions(
45 | shape='cube',
46 | center=np.zeros((3, 1), dtype=np.float64),
47 | scale=0.01,
48 | number_of_sensors=None,
49 | jitter=None,
50 | rng=np.random,
51 | rotate_x=0, rotate_y=0, rotate_z=0
52 | ):
53 | """ Generate different sensor configurations.
54 | 
55 | Sensors are indexed counter-clockwise starting with the 0th sensor below
56 | the x axis. This is done such that the first two sensors point towards
57 | the x axis.
58 | 
59 | :param shape: A shape, i.e. 'cube', 'triangle', 'linear' or 'circular'.
60 | :param center: Numpy array with shape (3, 1)
61 | which holds coordinates x, y and z.
62 | :param scale: Scalar responsible for scale of the array. See individual
63 | implementations, if it is used as radius or edge length.
64 | :param jitter: Add random Gaussian noise with standard deviation ``jitter``
65 | to sensor positions.
66 | :return: Numpy array with shape (3, number_of_sensors).
67 | """
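# Added note (illustrative, not from the original source): expected output
# shapes for two of the configurations described above:
#
#     generate_sensor_positions(shape='circular', scale=0.1,
#                               number_of_sensors=6).shape == (3, 6)
#     generate_sensor_positions(shape='cube', scale=0.1).shape == (3, 8)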
67 |     """
68 | 
69 |     center = np.array(center)
70 |     if center.ndim == 1:
71 |         center = center[:, None]
72 | 
73 |     if shape == 'cube':
74 |         b = scale / 2
75 |         sensor_positions = np.array([
76 |             [-b, -b, -b],
77 |             [-b, -b, b],
78 |             [-b, b, -b],
79 |             [-b, b, b],
80 |             [b, -b, -b],
81 |             [b, -b, b],
82 |             [b, b, -b],
83 |             [b, b, b]
84 |         ]).T
85 | 
86 |     elif shape == 'triangle':
87 |         assert number_of_sensors == 3, (
88 |             "triangle is only defined for 3 sensors",
89 |             number_of_sensors)
90 |         sensor_positions = generate_sensor_positions(
91 |             shape='circular', scale=scale, number_of_sensors=3, rng=rng
92 |         )
93 | 
94 |     elif shape == 'linear':
95 |         sensor_positions = np.zeros((3, number_of_sensors), dtype=np.float64)
96 |         sensor_positions[1, :] = scale * np.arange(number_of_sensors)
97 |         sensor_positions -= np.mean(sensor_positions, keepdims=True, axis=1)
98 | 
99 |     elif shape == 'circular':
100 |         if number_of_sensors == 1:
101 |             sensor_positions = np.zeros((3, 1), dtype=np.float64)
102 |         else:
103 |             radius = scale
104 |             delta_phi = 2 * np.pi / number_of_sensors
105 |             phi_0 = delta_phi / 2
106 |             phi = np.arange(0, number_of_sensors) * delta_phi - phi_0
107 |             sensor_positions = np.asarray([
108 |                 radius * np.cos(phi),
109 |                 radius * np.sin(phi),
110 |                 np.zeros(phi.shape)
111 |             ])
112 | 
113 |     elif shape == 'chime3':
114 |         assert scale is None, scale
115 |         assert (
116 |             number_of_sensors is None or number_of_sensors == 6
117 |         ), number_of_sensors
118 | 
119 |         sensor_positions = np.asarray(
120 |             [
121 |                 [-0.1, 0, 0.1, -0.1, 0, 0.1],
122 |                 [0.095, 0.095, 0.095, -0.095, -0.095, -0.095],
123 |                 [0, -0.02, 0, 0, 0, 0]
124 |             ]
125 |         )
126 | 
127 |     else:
128 |         raise NotImplementedError('Given shape is not implemented.')
129 | 
130 |     sensor_positions = rot_x(rotate_x) @ sensor_positions
131 |     sensor_positions = rot_y(rotate_y) @ sensor_positions
132 |     sensor_positions = rot_z(rotate_z) @ sensor_positions
133 | 
134 |     if jitter is not None:
135 |         sensor_positions += rng.normal(
136 |             0., jitter, size=sensor_positions.shape
137 |         )
138 | 
139 |     return np.asarray(sensor_positions + center)
140 | 
141 | 
142 | def generate_random_source_positions(
143 |         center=np.zeros((3, 1)),
144 |         sources=1,
145 |         distance_interval=(1, 2),
146 |         dims=2,
147 |         minimum_angular_distance=None,
148 |         maximum_angular_distance=None,
149 |         rng=np.random
150 | ):
151 |     """ Generates random positions on a hollow sphere or circle.
152 | 
153 |     Samples are drawn from a uniform distribution on a hollow sphere with
154 |     inner and outer radius according to `distance_interval`.
155 | 
156 |     The idea is to sample the directions from an angular central Gaussian distribution.
157 | 
158 |     Params:
159 |         center: Center of the sphere or circle, shape (3, 1).
160 |         sources: Number of source positions to sample.
161 |         distance_interval: Tuple (inner_radius, outer_radius) of the hollow sphere or circle.
162 |         dims: 2 to sample in the x-y plane, 3 to sample on a sphere.
163 |         minimum_angular_distance: In radians or None.
164 |         maximum_angular_distance: In radians or None.
165 |         rng: Random number generator, if you need to set the seed.
166 |     """
167 |     enforce_angular_constraints = (
168 |         minimum_angular_distance is not None or
169 |         maximum_angular_distance is not None
170 |     )
171 | 
172 |     if dims != 2 and enforce_angular_constraints:
173 |         raise NotImplementedError(
174 |             'Angular distance constraints are only implemented for 2D.'
175 |         )
176 | 
177 |     accept = False
178 |     while not accept:  # rejection sampling: redraw until the angular constraints hold
179 |         x = rng.normal(size=(3, sources))
180 |         if dims == 2:
181 |             x[2, :] = 0
182 | 
183 |         if enforce_angular_constraints:
184 |             if sources != 2:
185 |                 raise NotImplementedError(sources)
186 |             angle = np.arctan2(x[1, :], x[0, :])
187 |             difference = np.angle(
188 |                 np.exp(1j * (angle[None, :] - angle[:, None])))  # wrapped to [-pi, pi]
189 |             difference = difference[np.triu_indices_from(difference, k=1)]
190 |             distance = np.abs(difference)
191 |             if (
192 |                 minimum_angular_distance is not None and
193 |                 minimum_angular_distance > np.min(distance)
194 |             ):
195 |                 continue
196 |             if (
197 |                 maximum_angular_distance is not None and
198 |                 maximum_angular_distance < np.max(distance)
199 |             ):
200 |                 continue
201 |         accept = True
202 | 
203 |     x /= np.linalg.norm(x, axis=0)  # project the directions onto the unit sphere
204 | 
205 |     radius = rng.uniform(
206 |         distance_interval[0] ** dims,
207 |         distance_interval[1] ** dims,
208 |         size=(1, sources)
209 |     ) ** (1 / dims)  # inverse CDF sampling: uniform density within the hollow sphere/circle
210 | 
211 |     x *= radius
212 |     return np.asarray(x + center)
213 | 
--------------------------------------------------------------------------------
/sms_wsj/train_baseline_asr.py:
--------------------------------------------------------------------------------
1 | """
2 | Training script for the baseline ASR system. Expects all SMS-WSJ files to be
3 | written to storage and `json_path` to point to a json that uses those files.
4 | Example call:
5 | 
6 | python -m sms_wsj.train_baseline_asr with egs_path=$KALDI_ROOT/egs/ json_path=$JSON_PATH/sms_wsj.json
7 | """
8 | import os
9 | import shutil
10 | from pathlib import Path
11 | 
12 | import sacred
13 | from lazy_dataset.database import JsonDatabase
14 | from sms_wsj.kaldi.utils import create_data_dir, create_kaldi_dir
15 | from sms_wsj.kaldi.utils import run_process, pc2_environ
16 | 
17 | kaldi_root = Path(os.environ['KALDI_ROOT'])
18 | assert kaldi_root.exists(), (
19 |     f'The environment variable KALDI_ROOT has to be set to a working kaldi'
20 |     f' root, at the moment it points to {kaldi_root}'
21 | )
22 | assert (kaldi_root / 'src').exists(), (
23 |     f'The environment variable KALDI_ROOT has to be set to a working kaldi'
24 |     f' root, at the moment it points to {kaldi_root}'
25 | )
26 | assert (kaldi_root / 'src' / 'base' / '.depend.mk').exists(), (
27 |     'The kaldi your KALDI_ROOT points to is not installed, please refer to'
28 |     ' the kaldi documentation for further information on how to install it'
29 | )
30 | ex = sacred.Experiment('Kaldi ASR baseline training')
31 | 
32 | 
33 | @ex.config
34 | def config():
35 |     egs_path = None
36 |     json_path = None
37 |     # Only used on the Paderborn parallel computing center (PC2)
38 |     if 'CCS_NODEFILE' in os.environ:
39 |         num_jobs = len(list(
40 |             Path(os.environ['CCS_NODEFILE']).read_text().strip().splitlines()
41 |         ))
42 |     else:
43 |         num_jobs = os.cpu_count()
44 |     stage = 0
45 |     end_stage = 20
46 |     kaldi_cmd = 'run.pl'
47 |     ali_data_type = 'sms_early'
48 |     train_data_type = 'sms_single_speaker'
49 |     target_speaker = [0, 1]
50 |     channels = [0, 2, 4]
51 |     sample_rate = 8000
52 |     gmm_dir = None
53 |     # ToDo: change to kaldi_root/egs/ if no egs_path is defined?
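    # `ali_data_type` selects the signals from which the forced alignments are
    # computed (here 'sms_early', i.e. the early/direct speech images), while
    # `train_data_type` selects the signals the acoustic model is trained on.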
54 |     assert egs_path is not None, \
55 |         'The directory where all asr training related data is stored has' \
56 |         ' to be defined, use "with egs_path=/path/to/storage/dir"'
57 |     assert json_path is not None, \
58 |         'The path to the json describing the SMS-WSJ database has to be' \
59 |         ' defined, use "with json_path=/path/to/json/sms_wsj.json"' \
60 |         ' (for creating the json use ...)'
61 | 
62 | 
63 | @ex.automain
64 | def run(_config, egs_path, json_path, stage, end_stage, gmm_dir,
65 |         ali_data_type, train_data_type, target_speaker, channels,
66 |         sample_rate, kaldi_cmd, num_jobs):
67 |     sms_db = JsonDatabase(json_path)
68 |     sms_kaldi_dir = Path(egs_path).resolve().expanduser()
69 |     sms_kaldi_dir = sms_kaldi_dir / train_data_type / 's5'
70 |     if stage <= 1 < end_stage:
71 |         create_kaldi_dir(sms_kaldi_dir, sample_rate=sample_rate)
72 | 
73 |         if kaldi_cmd == 'ssh.pl':
74 |             if 'CCS_NODEFILE' in os.environ:
75 |                 pc2_environ(sms_kaldi_dir)
76 |             with (sms_kaldi_dir / 'cmd.sh').open('a') as fd:
77 |                 fd.write('export train_cmd="ssh.pl"\n')
78 |         elif kaldi_cmd == 'run.pl':
79 |             with (sms_kaldi_dir / 'cmd.sh').open('a') as fd:
80 |                 fd.write('export train_cmd="run.pl"\n')
81 |         else:
82 |             raise ValueError(kaldi_cmd)
83 | 
84 |     if gmm_dir is None:
85 |         gmm = 'tri4b'
86 |     else:
87 |         gmm_dir = Path(gmm_dir)
88 |         gmm = gmm_dir.name
89 |     if stage <= 2 < end_stage:
90 |         if gmm_dir is None:
91 |             create_data_dir(sms_kaldi_dir, db=sms_db, data_type='wsj_8k',
92 |                             target_speaker=target_speaker,
93 |                             sample_rate=sample_rate)
94 |             print('Start training tri3 model on wsj_8k')
95 |             run_process([
96 |                 f'{sms_kaldi_dir}/local_sms/get_tri3_model.bash',
97 |                 '--dest_dir', f'{sms_kaldi_dir}',
98 |                 '--nj', str(num_jobs)],
99 |                 cwd=str(sms_kaldi_dir),
100 |                 stdout=None, stderr=None
101 |             )
102 |         else:
103 |             assert gmm_dir.exists()
104 |             gmm_parent_dir = sms_kaldi_dir / 'exp' / 'wsj_8k'
105 |             gmm_parent_dir.mkdir(parents=True)
106 |             shutil.copytree(gmm_dir, gmm_parent_dir / gmm)
107 | 
108 |     if stage <= 3 < end_stage and ali_data_type != train_data_type:
109 |         create_data_dir(
110 |             sms_kaldi_dir, db=sms_db, data_type=ali_data_type,
111 |             ref_channels=channels, target_speaker=target_speaker,
112 |             sample_rate=sample_rate
113 |         )
114 | 
115 |     if stage <= 4 < end_stage:
116 |         create_data_dir(
117 |             sms_kaldi_dir, db=sms_db, data_type=train_data_type,
118 |             ref_channels=channels, target_speaker=target_speaker,
119 |             sample_rate=sample_rate
120 |         )
121 | 
122 |     if stage <= 16 < end_stage:
123 |         print('Prepare data for nnet3 model training on sms_wsj')
124 |         run_process([
125 |             f'{sms_kaldi_dir}/local_sms/prepare_nnet3_model_training.bash',
126 |             '--dest_dir', f'{sms_kaldi_dir}',
127 |             '--cv_sets', "cv_dev93",
128 |             '--stage', str(stage),
129 |             '--gmm_data_type', 'wsj_8k',
130 |             '--gmm', gmm,
131 |             '--ali_data_type', ali_data_type,
132 |             '--dataset', train_data_type,
133 |             '--nj', str(num_jobs)],
134 |             cwd=str(sms_kaldi_dir),
135 |             stdout=None, stderr=None
136 |         )
137 | 
138 |     if stage <= 20 and end_stage >= 17:  # the nnet3 recipe covers stages 17 to 20
139 |         print('Start training nnet3 model on sms_wsj')
140 |         run_process([
141 |             f'{sms_kaldi_dir}/local_sms/get_nnet3_model.bash',
142 |             '--dest_dir', f'{sms_kaldi_dir}',
143 |             '--cv_sets', '"cv_dev93"',
144 |             '--stage', str(stage),
145 |             '--gmm_data_type', 'wsj_8k',
146 |             '--gmm', gmm,
147 |             '--ali_data_type', ali_data_type,
148 |             '--dataset', train_data_type,
149 |             '--nj', str(num_jobs)],
150 |             cwd=str(sms_kaldi_dir),
151 |             stdout=None, stderr=None
152 |         )
153 | 
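# Resuming a run: every step above is guarded by `stage <= N < end_stage`
# (the final nnet3 block by `stage <= 20 and end_stage >= 17`), so a
# partially completed training can be continued by overriding `stage`.
# A hypothetical resume call (paths are placeholders):
#
#     python -m sms_wsj.train_baseline_asr with \
#         egs_path=$KALDI_ROOT/egs/ json_path=$JSON_PATH/sms_wsj.json stage=17
#
# This skips the kaldi directory setup, GMM training and data preparation
# (steps 1 to 16) and resumes directly with the nnet3 model training.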
--------------------------------------------------------------------------------
/sms_wsj/visualization.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import matplotlib.patches
4 | 
5 | 
6 | def plot_scenario(ex, ax: plt.Axes = None, *, figsize=(12, 8)):
7 |     """
8 |     Plot the source positions, sensor positions and room dimensions of an
9 |     SMS-WSJ example in the (x, y) plane, i.e. the z-axis is ignored.
10 | 
11 |     Args:
12 |         ex: An example from SMS-WSJ
13 |         ax: Axes to plot into. If None, a new figure is created.
14 | 
15 |     Returns:
16 |         The matplotlib Axes that was plotted into.
17 |     """
18 | 
19 |     if ax is None:
20 |         ax = plt.subplots(1, figsize=figsize)[1]
21 | 
22 |     speaker_id_to_source_position = {}
23 |     for i, source_position in enumerate(np.array(ex['source_position']).T):
24 |         x, y, z = source_position
25 |         speaker_id = ex["speaker_id"][i]
26 |         if speaker_id in speaker_id_to_source_position:
27 |             np.testing.assert_equal(
28 |                 source_position, speaker_id_to_source_position[speaker_id])
29 |         else:
30 |             speaker_id_to_source_position[speaker_id] = source_position
31 |             ax.scatter(x, y, label=f'Speaker {speaker_id}')
32 | 
33 |     xs, ys, zs = np.array(ex['sensor_position'])
34 |     ax.scatter(xs, ys, label='Microphones')
35 | 
36 |     room_dimensions = ex['room_dimensions']
37 |     (x,), (y,), (z,) = room_dimensions
38 |     # Draw the walls as a hatched border around the room outline
39 |     w = 0.30
40 |     ax.add_patch(matplotlib.patches.Polygon(np.array([
41 |         (0, 0), (x, 0), (x, y), (0, y), (0, -w), (-w, -w), (-w, y+w),
42 |         (x+w, y+w), (x+w, -w), (0, -w)
43 |     ]), fill=False, hatch='/', linewidth=0))
44 |     ax.add_patch(matplotlib.patches.Rectangle((0, 0), x, y, fill=None))
45 | 
46 |     ax.set(title=f'Dataset: {ex["dataset"]!r}, ExID: {ex["example_id"]!r}, RT60: {ex["sound_decay_time"]}')
47 |     ax.autoscale(tight=True)
48 |     ax.set_aspect('equal')
49 |     ax.legend()
50 |     return ax
51 | 
--------------------------------------------------------------------------------
/tests/database.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | 
3 | import numpy as np
4 | 
5 | import sms_wsj
6 | from sms_wsj.database.database import SmsWsj, AudioReader
7 | 
8 | 
9 | json_path = Path(sms_wsj.__file__).parent.parent / 'cache' / 'sms_wsj.json'  # repository root / cache
10 | 
11 | 
12 | def test_first_example():
13 |     db = SmsWsj(json_path)
14 |     ds = db.get_dataset('cv_dev93')
15 |     ds = ds.map(AudioReader(AudioReader.all_keys))
16 | 
17 |     example = ds[0]
18 | 
19 |     np.testing.assert_allclose(
20 |         example['audio_data']['observation'],
21 |         np.sum(
22 |             example['audio_data']['speech_reverberation_early']
23 |             + example['audio_data']['speech_reverberation_tail'],
24 |             axis=0  # sum over speakers
25 |         ) + example['audio_data']['noise_image'],
26 |         atol=1e-7
27 |     )
28 | 
29 |     assert len(example['audio_data']) == 8
30 |     assert example['audio_data']['observation'].shape == (6, 93389)
31 |     assert example['audio_data']['noise_image'].shape == (6, 93389)
32 |     assert example['audio_data']['speech_reverberation_early'].shape == (2, 6, 93389)
33 |     assert example['audio_data']['speech_reverberation_tail'].shape == (2, 6, 93389)
34 |     assert example['audio_data']['speech_source'].shape == (2, 93389)
35 |     assert example['audio_data']['speech_image'].shape == (2, 6, 93389)
36 |     assert example['audio_data']['rir'].shape == (2, 6, 8192)
37 |     assert example['audio_data']['original_source'][0].shape == (31633,)
38 |     assert example['audio_data']['original_source'][1].shape == (93389,)
39 | 
40 |     assert list(example.keys()) == [
41 |         'room_dimensions', 'sound_decay_time', 'source_position',
42 |         'sensor_position', 'example_id', 'num_speakers', 'speaker_id',
43 |         'source_id', 'gender', 'kaldi_transcription', 'log_weights',
44 |         'num_samples', 'offset', 'audio_path', 'snr', 'dataset', 'audio_data'
45 |     ]
46 | 
47 |     assert example['example_id'] == '0_4k6c0303_4k4c0319'
48 |     assert example['snr'] == 23.287502642941252
49 |     assert example['room_dimensions'] == [[8.169], [5.905], [3.073]]
50 |     assert example['source_position'] == [[3.312, 3.0], [1.921, 2.379], [1.557, 1.557]]
51 |     assert example['sensor_position'] == [
52 |         [4.015, 3.973, 4.03, 4.129, 4.172, 4.115],
53 |         [3.265, 3.175, 3.093, 3.102, 3.192, 3.274],
54 |         [1.55, 1.556, 1.563, 1.563, 1.558, 1.551]]
55 |     assert example['sound_decay_time'] == 0.387
56 |     assert example['offset'] == [52476, 0]
57 |     assert example['log_weights'] == [0.9885484337248203, -0.9885484337248203]
58 |     assert example['num_samples'] == {'observation': 93389,
59 |                                       'original_source': [31633, 93389]}
60 | 
61 | 
62 | def test_random_example():
63 |     db = SmsWsj(json_path)
64 |     ds = db.get_dataset('cv_dev93')
65 |     ds = ds.map(AudioReader(AudioReader.all_keys))
66 | 
67 |     example = ds.random_choice()
68 | 
69 |     np.testing.assert_allclose(
70 |         example['audio_data']['observation'],
71 |         np.sum(
72 |             example['audio_data']['speech_reverberation_early']
73 |             + example['audio_data']['speech_reverberation_tail'],
74 |             axis=0  # sum over speakers
75 |         ) + example['audio_data']['noise_image'],
76 |         atol=1e-7
77 |     )
78 | 
79 | 
80 | def test_order():
81 |     db = SmsWsj(json_path)
82 | 
83 |     ds = db.get_dataset('cv_dev93')
84 |     for scenario_id, example in enumerate(ds):
85 |         assert scenario_id == int(example['example_id'].split('_')[0])
86 | 
87 |     ds = db.get_dataset('test_eval92')
88 |     for scenario_id, example in enumerate(ds):
89 |         assert scenario_id == int(example['example_id'].split('_')[0])
90 | 
91 |     ds = db.get_dataset('train_si284')
92 |     for scenario_id, example in enumerate(ds):
93 |         assert scenario_id == int(example['example_id'].split('_')[0])
94 | 
--------------------------------------------------------------------------------
/tests/test_import.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | import importlib
4 | from pathlib import Path
5 | 
6 | import pytest
7 | 
8 | import sms_wsj
9 | 
10 | 
11 | def get_module_name_from_file(file, package_path=Path(sms_wsj.__file__).parent):
12 |     """
13 |     >> import sms_wsj
14 |     >> file = sms_wsj.io.__file__
15 |     >> file  # doctest: +ELLIPSIS
16 |     '.../sms_wsj/io.py'
17 |     >> get_module_name_from_file(file)
18 |     'sms_wsj.io'
19 |     >> file = sms_wsj.database.__file__
20 |     >> file  # doctest: +ELLIPSIS
21 |     '.../sms_wsj/database/__init__.py'
22 |     >> get_module_name_from_file(file)
23 |     'sms_wsj.database'
24 |     """
25 | 
26 |     assert package_path in file.parents, (package_path, file)
27 |     file = file.relative_to(package_path.parent)
28 |     parts = list(file.with_suffix('').parts)
29 |     if parts[-1] == '__init__':
30 |         parts.pop(-1)
31 |     module_path = '.'.join(parts)
32 |     return module_path
33 | 
34 | 
35 | @pytest.fixture(scope="session", autouse=True)
36 | def dummy_kaldi_root(tmp_path_factory):
37 |     kaldi_root = tmp_path_factory.mktemp("kaldi")
38 |     (kaldi_root / 'src' / 'base' / '.depend.mk').mkdir(parents=True)  # a directory is enough to satisfy the `.exists()` checks
39 | 
40 |     if 'KALDI_ROOT' not in os.environ:
41 |         os.environ['KALDI_ROOT'] = str(kaldi_root)
42 | 
43 |     return kaldi_root
44 | 
45 | 
46 | class TestImport:
47 |     python_files = Path(sms_wsj.__file__).parent.glob('**/*.py')
48 | 
49 |     @pytest.mark.parametrize('py_file', [
50 |         pytest.param(
51 |             py_file,
52 |             id=get_module_name_from_file(py_file))
53 |         for py_file in python_files
54 |     ])
55 |     def test_import(self, py_file: Path, with_importlib=True):
56 |         """
57 |         Import `py_file` and fail the test if the import raises.
58 | 
59 |         Args:
60 |             py_file: Python file to import
61 |             with_importlib: If True, use `importlib` for importing. Else, use
62 |                 `subprocess.run`; it is considerably slower but may
63 |                 produce more readable test output
64 |         """
65 |         import_name = get_module_name_from_file(py_file)
66 |         suffix = Path(py_file).suffix
67 |         try:
68 |             if with_importlib:
69 |                 _ = importlib.import_module(import_name)
70 |             else:
71 |                 _ = subprocess.run(
72 |                     ['python', '-c', f'import {import_name}'],
73 |                     stdout=subprocess.PIPE,
74 |                     stderr=subprocess.PIPE,
75 |                     check=True,
76 |                     universal_newlines=True,
77 |                 )
78 |         except (
79 |             ImportError,
80 |             ModuleNotFoundError,
81 |             subprocess.CalledProcessError,
82 |         ) as e:
83 |             try:
84 |                 err = e.stderr
85 |             except AttributeError:
86 |                 err = 'See Traceback above'
87 |             assert False, f'Cannot import file "{import_name}{suffix}" \n\n' \
88 |                 f'stderr: {err}'
89 | 
--------------------------------------------------------------------------------