├── .github └── workflows │ ├── python-publish.yml │ └── test.yml ├── .gitignore ├── LICENSE ├── README.md ├── kaldiio ├── __init__.py ├── compression_header.py ├── highlevel.py ├── matio.py ├── python_wave.py ├── utils.py └── wavio.py ├── requirements.txt ├── setup.cfg ├── setup.py └── tests ├── __init__.py ├── arks ├── create_arks.sh ├── incorrect_header.wav ├── test.ark ├── test.cm1.ark ├── test.cm3.ark ├── test.cm5.ark └── test.text.ark ├── test_extended_ark.py ├── test_highlevel.py ├── test_limited_size_dict.py ├── test_mat_ark.py ├── test_multi_file_descriptor.py ├── test_open_like_kaldi.py ├── test_parse_specifier.py └── test_wav.py /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 
8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v3 25 | - name: Set up Python 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: '3.x' 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install build 33 | - name: Build package 34 | run: python -m build 35 | - name: Publish package 36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 37 | with: 38 | user: __token__ 39 | password: ${{ secrets.PYPI_API_TOKEN }} 40 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Linter and test 5 | on: 6 | push: 7 | branches: [ "master" ] 8 | pull_request: 9 | branches: [ "master" ] 10 | 11 | jobs: 12 | build: 13 | 14 | runs-on: ubuntu-latest 15 | strategy: 16 | fail-fast: false 17 | matrix: 18 | python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"] 19 | 20 | steps: 21 | - uses: actions/checkout@v3 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v3 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Install dependencies 27 | run: | 28 | sudo apt-get install sox 29 | python -m pip install --upgrade pip 30 | python -m pip install -e . 
31 | python -m pip install flake8 pytest codecov pytest-cov soundfile 32 | - name: Test 33 | run: | 34 | pytest tests 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .cache 2 | *.pyc 3 | __pycache__ 4 | .idea 5 | .coverage 6 | htmlcov 7 | .eggs 8 | *.egg-info 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | SOFTWARE LICENSE AGREEMENT FOR EVALUATION 2 | 3 | This SOFTWARE EVALUATION LICENSE AGREEMENT (this "Agreement") is a legal contract between a person who uses or otherwise accesses or installs the Software (“User(s)”), and Nippon Telegraph and Telephone corporation ("NTT"). 4 | READ THE TERMS AND CONDITIONS OF THIS AGREEMENT CAREFULLY BEFORE INSTALLING OR OTHERWISE ACCESSING OR USING NTT'S PROPRIETARY SOFTWARE ACCOMPANIED BY THIS AGREEMENT (the "SOFTWARE"). THE SOFTWARE IS COPYRIGHTED AND IT IS LICENSED TO USER UNDER THIS AGREEMENT, NOT SOLD TO USER. BY INSTALLING OR OTHERWISE ACCESSING OR USING THE SOFTWARE, USER ACKNOWLEDGES THAT USER HAS READ THIS AGREEMENT, THAT USER UNDERSTANDS IT, AND THAT USER ACCEPTS AND AGREES TO BE BOUND BY ITS TERMS. IF AT ANY TIME USER IS NOT WILLING TO BE BOUND BY THE TERMS OF THIS AGREEMENT, USER SHOULD TERMINATE THE INSTALLATION PROCESS, IMMEDIATELY CEASE AND REFRAIN FROM ACCESSING OR USING THE SOFTWARE AND DELETE ANY COPIES USER MAY HAVE. THIS AGREEMENT REPRESENTS THE ENTIRE AGREEMENT BETWEEN USER AND NTT CONCERNING THE SOFTWARE. 5 | 6 | 7 | BACKGROUND 8 | A. NTT is the owner of all rights, including all patent rights, and copyrights in and to the Software and related documentation listed in Exhibit A to this Agreement. 9 | B. 
User wishes to obtain a royalty free license to use the Software to enable User to evaluate, and NTT wishes to grant such a license to User, pursuant and subject to the terms and conditions of this Agreement. 10 | C. As a condition to NTT's provision of the Software to User, NTT has required User to execute this Agreement. 11 | In consideration of these premises, and the mutual promises and conditions in this Agreement, the parties hereby agree as follows: 12 | 1. Grant of Evaluation License. NTT hereby grants to User, and User hereby accepts, under the terms and conditions of this Agreement, a royalty free, nontransferable and nonexclusive license to use the Software internally for the purposes of testing, analyzing, and evaluating the methods or mechanisms as shown in "NTT Neural Machine Translation Systems at WAT 2017, Morishita et al., WAT 2017". User may make a reasonable number of backup copies of the Software solely for User's internal use pursuant to the license granted in this Section 1. 13 | 2. Shipment and Installation. NTT will ship or deliver the Software by any method that NTT deems appropriate. User shall be solely responsible for proper installation of the Software. 14 | 3. Term. This Agreement is effective whichever is earlier (i) upon User’s acceptance of the Agreement, or (ii) upon User’s installing, accessing, and using the Software, even if User has not expressly accepted this Agreement. Without prejudice to any other rights, NTT may terminate this Agreement without notice to User. User may terminate this Agreement at any time by User’s decision to terminate the Agreement to NTT and ceasing use of the Software. Upon any termination or expiration of this Agreement for any reason, User agrees to uninstall the Software and destroy all copies of the Software. 15 | 4. 
Proprietary Rights 16 | (a) The Software is the valuable and proprietary property of NTT, and NTT shall retain exclusive title to this property both during the term and after the termination of this Agreement. Without limitation, User acknowledges that all patent rights and copyrights in the Software shall remain the exclusive property of NTT at all times. User shall use not less than reasonable care in safeguarding the confidentiality of the Software. 17 | (b) USER SHALL NOT, IN WHOLE OR IN PART, AT ANY TIME DURING THE TERM OF OR AFTER THE TERMINATION OF THIS AGREEMENT: (i) SELL, ASSIGN, LEASE, DISTRIBUTE, OR OTHERWISE TRANSFER THE SOFTWARE TO ANY THIRD PARTY; (ii) EXCEPT AS OTHERWISE PROVIDED HEREIN, COPY OR REPRODUCE THE SOFTWARE IN ANY MANNER; OR (iii) ALLOW ANY PERSON OR ENTITY TO COMMIT ANY OF THE ACTIONS DESCRIBED IN (i) THROUGH (ii) ABOVE. 18 | (c) User shall take appropriate action, by instruction, agreement, or otherwise, with respect to its employees permitted under this Agreement to have access to the Software to ensure that all of User's obligations under this Section 4 shall be satisfied. 19 | 5. Indemnity. User shall defend, indemnify and hold harmless NTT, its agents and employees, from any loss, damage, or liability arising in connection with User's improper or unauthorized use of the Software. NTT SHALL HAVE THE SOLE RIGHT TO CONDUCT DEFEND ANY ACTTION RELATING TO THE SOFTWARE. 20 | 6. Disclaimer. THE SOFTWARE IS LICENSED TO USER "AS IS," WITHOUT ANY TRAINING, MAINTENANCE, OR SERVICE OBLIGATIONS WHATSOEVER ON THE PART OF NTT. NTT MAKES NO EXPRESS OR IMPLIED WARRANTIES OF ANY TYPE WHATSOEVER, INCLUDING WITHOUT LIMITATION THE IMPLIED WARRANTIES OF MERCHANTABILITY, OF FITNESS FOR A PARTICULAR PURPOSE AND OF NON-INFRINGEMENT ON COPYRIGHT OR ANY OTHER RIGHT OF THIRD PARTIES. 
THIS DISCLAIMER OF LIABILITY SHALL APPLY REGARDLESS OF THE FORM OF ACTION THAT MAY BE BROUGHT AGAINST NTT, WHETHER IN CONTRACT OR TORT, INCLUDING WITHOUT LIMITATION ANY ACTION FOR NEGLIGENCE.
27 | (d) If either party to this Agreement initiates a legal action or proceeding to enforce or interpret any part of this Agreement, the prevailing party in such action shall be entitled to recover, as an element of the costs of such action and not as damages, its attorneys' fees and other costs associated with such action or proceeding. 28 | (e) This Agreement shall be governed by and interpreted under the laws of Japan, without reference to conflicts of law principles. All disputes arising out of or in connection with this Agreement shall be finally settled by arbitration in Tokyo in accordance with the Commercial Arbitration Rules of the Japan Commercial Arbitration Association. The arbitration shall be conducted by three (3) arbitrators and in Japanese. The award rendered by the arbitrators shall be final and binding upon the parties. Judgment upon the award may be entered in any court having jurisdiction thereof. 29 | (f) NTT shall not be liable to the User or to any third party for any delay or failure to perform NTT’s obligation set forth under this Agreement due to any cause beyond NTT’s reasonable control. 30 |   31 | EXHIBIT A 32 | 33 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Kaldiio 2 | [![pypi](https://img.shields.io/pypi/v/kaldiio.svg)](https://pypi.python.org/pypi/kaldiio) 3 | [![Supported Python versions](https://img.shields.io/pypi/pyversions/kaldiio.svg)](https://pypi.python.org/pypi/kaldiio) 4 | [![codecov](https://codecov.io/gh/nttcslab-sp/kaldiio/branch/master/graph/badge.svg)](https://codecov.io/gh/nttcslab-sp/kaldiio) 5 | 6 | A pure python module for reading and writing kaldi ark files 7 | 8 | - [Introduction](#introduction) 9 | - [What is this? 
`kaldiio` is an IO utility implemented in pure Python for several file formats used in [kaldi](https://github.com/kaldi-asr/kaldi), which are named `ark` and `scp`. `ark` and `scp` are used in order to archive some objects defined in Kaldi, typically Matrix objects of Kaldi.
Unix command can be used as `rspecifier` and `wspecifier`
"ark: cat feats.ark |" 99 | - Read wav.scp / wav.ark 100 | - (New!) Some extended ark format **not supported** in Kaldi originally. 101 | - The ark file for numpy, pickle, wav, flac files. 102 | 103 | The followings are **not supported** 104 | 105 | - Write in existing scp file 106 | - NNet2/NNet3 egs 107 | - Lattice file 108 | 109 | ### Similar projects 110 | - Python-C++ binding 111 | - https://github.com/pykaldi/pykaldi 112 | - Looks great. I recommend pykaldi if you aren't particular about pure python. 113 | - https://github.com/janchorowski/kaldi-python/ 114 | - Maybe not enough maintained now. 115 | - https://github.com/t13m/kaldi-readers-for-tensorflow 116 | - Ark reader for tensorflow 117 | - https://github.com/csukuangfj/kaldi_native_io 118 | - Implemented in C++ 119 | - Have interface for Python 120 | - Support all types of `rspecifier` and `wspecifier` 121 | - Have a uniform interface for writing, sequential reading, and random access reading 122 | - `pip install kaldi_native_io` 123 | - Pure Python 124 | - https://github.com/vesis84/kaldi-io-for-python 125 | - `kaldiio` is based on this module, but `kaldiio` supports more features than it. 126 | - https://github.com/funcwj/kaldi-python-io 127 | - Python>=3.6. `nnet3-egs`is also supported. 128 | 129 | ## Install 130 | 131 | ```bash 132 | pip install kaldiio 133 | ``` 134 | 135 | ## Usage 136 | `kaldiio` doesn't distinguish the API for each kaldi-objects, i.e. 137 | `Kaldi-Matrix`, `Kaldi-Vector`, not depending on whether it is binary or text, or compressed or not, 138 | can be handled by the same API. 139 | 140 | ### ReadHelper 141 | `ReadHelper` supports sequential accessing for `scp` or `ark`. If you need to access randomly, then use `kaldiio.load_scp`. 142 | 143 | 144 | - Read matrix-scp 145 | 146 | ```python 147 | from kaldiio import ReadHelper 148 | with ReadHelper('scp:file.scp') as reader: 149 | for key, numpy_array in reader: 150 | ... 
You can load `wav.scp` without any additional argument.
- Write matrices to stdout
Extended ark format using numpy, pickle, soundfile 252 | 253 | ```python 254 | import numpy 255 | from kaldiio import WriteHelper 256 | 257 | # NPY ARK 258 | with WriteHelper('ark:-', write_function="numpy") as writer: 259 | writer("foo", numpy.random.randn(10, 10)) 260 | 261 | # PICKLE ARK 262 | with WriteHelper('ark:-', write_function="pickle") as writer: 263 | writer("foo", numpy.random.randn(10, 10)) 264 | 265 | # FLAC ARK 266 | with WriteHelper('ark:-', write_function="soundfile_flac") as writer: 267 | writer("foo", numpy.random.randn(1000)) 268 | ``` 269 | 270 | Note that `soundfile` is an optional module and you need to install it to use this feature. 271 | 272 | ```sh 273 | pip install soundfile 274 | ``` 275 | 276 | ## More low level API 277 | `WriteHelper` and `ReadHelper` are high level wrapper of the following API to support kaldi style arguments. 278 | 279 | ### load_ark 280 | 281 | ```python 282 | import kaldiio 283 | 284 | d = kaldiio.load_ark('a.ark') # d is a generator object 285 | for key, numpy_array in d: 286 | ... 287 | 288 | # === load_ark can accepts file descriptor, too 289 | with open('a.ark') as fd: 290 | for key, numpy_array in kaldiio.load_ark(fd): 291 | ... 292 | 293 | # === Use with open_like_kaldi 294 | from kaldiio import open_like_kaldi 295 | with open_like_kaldi('gunzip -c file.ark.gz |', 'r') as f: 296 | for key, numpy_array in kaldiio.load_ark(fd): 297 | ... 298 | ``` 299 | 300 | - `load_ark` can load both matrices of ark and vectors of ark and also, it can be both text and binary. 301 | 302 | ### load_scp 303 | `load_scp` creates "lazy dict", i.e. 304 | The data are loaded in memory when accessing the element. 
305 | 306 | ```python 307 | import kaldiio 308 | 309 | d = kaldiio.load_scp('a.scp') 310 | for key in d: 311 | numpy_array = d[key] 312 | 313 | 314 | with open('a.scp') as fd: 315 | kaldiio.load_scp(fd) 316 | 317 | d = kaldiio.load_scp('data/train/wav.scp', segments='data/train/segments') 318 | for key in d: 319 | rate, numpy_array = d[key] 320 | ``` 321 | 322 | The object created by `load_scp` is a dict-like object, thus it has methods of `dict`. 323 | 324 | ```python 325 | import kaldiio 326 | d = kaldiio.load_scp('a.scp') 327 | d.keys() 328 | d.items() 329 | d.values() 330 | 'uttid' in d 331 | d.get('uttid') 332 | ``` 333 | 334 | ### load_scp_sequential (from v2.13.0) 335 | 336 | `load_scp_sequential` creates "generator" as same as `load_ark`. 337 | If you don't need random-accessing for each elements 338 | and use it just to iterate for whole data, 339 | then this method possibly performs faster than `load_scp`. 340 | 341 | ```python 342 | import kaldiio 343 | d = kaldiio.load_scp_sequential('a.scp') 344 | for key, numpy_array in d: 345 | ... 346 | ``` 347 | 348 | ### load_wav_scp 349 | ```python 350 | d = kaldiio.load_scp('wav.scp') 351 | for key in d: 352 | rate, numpy_array = d[key] 353 | 354 | # Supporting "segments" 355 | d = kaldiio.load_scp('data/train/wav.scp', segments='data/train/segments') 356 | for key in d: 357 | rate, numpy_array = d[key] 358 | ``` 359 | 360 | - v2.11.0: `load_wav_scp` is deprecated now. Use `load_scp`. 
361 | 362 | ### load_mat 363 | ```python 364 | numpy_array = kaldiio.load_mat('a.mat') 365 | numpy_array = kaldiio.load_mat('a.ark:1134') # Seek and load 366 | 367 | # If the file is wav, gets Tuple[int, numpy.ndarray] 368 | rate, numpy_array = kaldiio.load_mat('a.wav') 369 | ``` 370 | - `load_mat` can load kaldi-matrix, kaldi-vector, and wave 371 | 372 | ### save_ark 373 | ```python 374 | 375 | # === Create ark file from numpy 376 | kaldiio.save_ark('b.ark', {'key': numpy_array, 'key2': numpy_array2}) 377 | # Create ark with scp _file, too 378 | kaldiio.save_ark('b.ark', {'key': numpy_array, 'key2': numpy_array2}, 379 | scp='b.scp') 380 | 381 | # === Writes arrays to sys.stdout 382 | import sys 383 | kaldiio.save_ark(sys.stdout, {'key': numpy_array}) 384 | 385 | # === Writes arrays for each keys 386 | # generate a.ark 387 | kaldiio.save_ark('a.ark', {'key': numpy_array, 'key2': numpy_array2}) 388 | # After here, a.ark is opened with 'a' (append) mode. 389 | kaldiio.save_ark('a.ark', {'key3': numpy_array3}, append=True) 390 | 391 | 392 | # === Use with open_like_kaldi 393 | from kaldiio import open_like_kaldi 394 | with open_like_kaldi('| gzip a.ark.gz', 'w') as f: 395 | kaldiio.save_ark(f, {'key': numpy_array}) 396 | kaldiio.save_ark(f, {'key2': numpy_array2}) 397 | ``` 398 | ### save_mat 399 | ```python 400 | # array.ndim must be 1 or 2 401 | kaldiio.save_mat('a.mat', numpy_array) 402 | ``` 403 | - `save_mat` can save both kaldi-matrix and kaldi-vector 404 | 405 | 406 | ### open_like_kaldi 407 | 408 | ``kaldiio.open_like_kaldi`` is a useful tool if you are familiar with Kaldi. 
This function performs as follows,
448 | ``` 449 | -------------------------------------------------------------------------------- /kaldiio/__init__.py: -------------------------------------------------------------------------------- 1 | from kaldiio.matio import load_ark 2 | from kaldiio.matio import load_mat 3 | from kaldiio.matio import load_scp 4 | from kaldiio.matio import load_scp_sequential 5 | from kaldiio.matio import load_wav_scp 6 | from kaldiio.matio import save_ark 7 | from kaldiio.matio import save_mat 8 | from kaldiio.highlevel import ReadHelper 9 | from kaldiio.highlevel import WriteHelper 10 | from kaldiio.utils import open_like_kaldi 11 | from kaldiio.utils import parse_specifier 12 | 13 | __version__ = "2.18.1" 14 | -------------------------------------------------------------------------------- /kaldiio/compression_header.py: -------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals 2 | 3 | import struct 4 | 5 | import numpy as np 6 | 7 | 8 | kAutomaticMethod = 1 9 | kSpeechFeature = 2 10 | kTwoByteAuto = 3 11 | kTwoByteSignedInteger = 4 12 | kOneByteAuto = 5 13 | kOneByteUnsignedInteger = 6 14 | kOneByteZeroOne = 7 15 | 16 | 17 | class GlobalHeader(object): 18 | """This is a imitation class of the structure "GlobalHeader" """ 19 | 20 | def __init__(self, type, min_value, range, rows, cols, endian="<"): 21 | if type in ("CM", "CM2"): 22 | c = 65535.0 23 | elif type == "CM3": 24 | c = 255.0 25 | else: 26 | raise RuntimeError("Not supported type={}".format(type)) 27 | self.type = type 28 | self.c = c 29 | self.min_value = min_value 30 | self.range = range 31 | self.rows = rows 32 | self.cols = cols 33 | self.endian = endian 34 | 35 | @property 36 | def size(self): 37 | return 17 + len(self.type) 38 | 39 | @staticmethod 40 | def read(fd, type="CM", endian="<"): 41 | min_value = struct.unpack(endian + "f", fd.read(4))[0] 42 | range = struct.unpack(endian + "f", fd.read(4))[0] 43 | rows = struct.unpack(endian + "i", 
fd.read(4))[0] 44 | cols = struct.unpack(endian + "i", fd.read(4))[0] 45 | return GlobalHeader(type, min_value, range, rows, cols, endian) 46 | 47 | def write(self, fd, endian=None): 48 | if endian is None: 49 | endian = self.endian 50 | fd.write(self.type.encode() + b" ") 51 | fd.write(struct.pack(endian + "f", self.min_value)) 52 | fd.write(struct.pack(endian + "f", self.range)) 53 | fd.write(struct.pack(endian + "i", self.rows)) 54 | fd.write(struct.pack(endian + "i", self.cols)) 55 | return self.size 56 | 57 | @staticmethod 58 | def compute(array, compression_method, endian="<"): 59 | if compression_method == kAutomaticMethod: 60 | if array.shape[0] > 8: 61 | compression_method = kSpeechFeature 62 | else: 63 | compression_method = kTwoByteAuto 64 | 65 | if compression_method == kSpeechFeature: 66 | matrix_type = "CM" 67 | elif ( 68 | compression_method == kTwoByteAuto 69 | or compression_method == kTwoByteSignedInteger 70 | ): 71 | matrix_type = "CM2" 72 | elif ( 73 | compression_method == kOneByteAuto 74 | or compression_method == kOneByteUnsignedInteger 75 | or compression_method == kOneByteZeroOne 76 | ): 77 | matrix_type = "CM3" 78 | else: 79 | raise ValueError( 80 | "Unknown compression_method: {}".format(compression_method) 81 | ) 82 | 83 | if ( 84 | compression_method == kSpeechFeature 85 | or compression_method == kTwoByteAuto 86 | or compression_method == kOneByteAuto 87 | ): 88 | min_value = array.min() 89 | max_value = array.max() 90 | if min_value == max_value: 91 | max_value = min_value + (1.0 + abs(min_value)) 92 | range_ = max_value - min_value 93 | elif compression_method == kTwoByteSignedInteger: 94 | min_value = -32768.0 95 | range_ = 65535.0 96 | elif compression_method == kOneByteUnsignedInteger: 97 | min_value = 0.0 98 | range_ = 255.0 99 | elif compression_method == kOneByteZeroOne: 100 | min_value = 0.0 101 | range_ = 1.0 102 | else: 103 | raise ValueError( 104 | "Unknown compression_method: {}".format(compression_method) 105 | ) 106 | 
107 | return GlobalHeader( 108 | matrix_type, min_value, range_, array.shape[0], array.shape[1], endian 109 | ) 110 | 111 | def float_to_uint(self, array): 112 | if self.c == 65535.0: 113 | dtype = np.dtype(self.endian + "u2") 114 | else: 115 | dtype = np.dtype(self.endian + "u1") 116 | # + 0.499 is to round to closest int 117 | array = (array - self.min_value) / self.range * self.c + 0.499 118 | return array.astype(np.dtype(dtype)) 119 | 120 | def uint_to_float(self, array): 121 | array = array.astype(np.float32) 122 | return self.min_value + array * self.range / self.c 123 | 124 | 125 | class PerColHeader(object): 126 | """This is a imitation class of the structure "PerColHeader" """ 127 | 128 | def __init__(self, p0, p25, p75, p100, endian="<"): 129 | # p means percentile 130 | self.p0 = p0 131 | self.p25 = p25 132 | self.p75 = p75 133 | self.p100 = p100 134 | self.endian = endian 135 | 136 | @property 137 | def size(self): 138 | return 8 * self.p0.shape[0] 139 | 140 | @staticmethod 141 | def read(fd, global_header): 142 | endian = global_header.endian 143 | # Read PerColHeader 144 | size_of_percolheader = 8 145 | buf = fd.read(size_of_percolheader * global_header.cols) 146 | header_array = np.frombuffer(buf, dtype=np.dtype(endian + "u2")) 147 | header_array = np.asarray(header_array, np.float32) 148 | # Decompress header 149 | header_array = global_header.uint_to_float(header_array) 150 | header_array = header_array.reshape(-1, 4, 1) 151 | return PerColHeader( 152 | header_array[:, 0], 153 | header_array[:, 1], 154 | header_array[:, 2], 155 | header_array[:, 3], 156 | endian, 157 | ) 158 | 159 | def write(self, fd, global_header, endian=None): 160 | if endian is None: 161 | endian = self.endian 162 | header_array = np.concatenate([self.p0, self.p25, self.p75, self.p100], axis=1) 163 | header_array = global_header.float_to_uint(header_array) 164 | header_array = header_array.astype(np.dtype(endian + "u2")) 165 | byte_str = header_array.tobytes() 166 | 
    @staticmethod
    def compute(array, global_header):
        """Compute per-column 0/25/75/100 percentile headers for *array*.

        The quantized percentile knots are forced to be strictly increasing
        so the three interpolation segments used by float_to_char and
        char_to_float are always well defined.
        """
        quarter_nr = array.shape[0] // 4
        if array.shape[0] >= 5:
            # Partial sort: only the 4 needed order statistics are placed
            # at their sorted positions (cheaper than a full np.sort).
            srows = np.partition(
                array, [0, quarter_nr, 3 * quarter_nr, array.shape[0] - 1], axis=0
            )
            p0 = srows[0]
            p25 = srows[quarter_nr]
            p75 = srows[3 * quarter_nr]
            p100 = srows[array.shape[0] - 1]
        else:
            # Fewer than 5 rows: use the rows that exist and synthesize
            # the remaining percentiles one step apart.
            srows = np.sort(array, axis=0)
            p0 = srows[0]
            if array.shape[0] > 1:
                p25 = srows[1]
            else:
                p25 = p0 + 1
            if array.shape[0] > 2:
                p75 = srows[2]
            else:
                p75 = p25 + 1
            if array.shape[0] > 3:
                p100 = srows[3]
            else:
                p100 = p75 + 1
        # Round-trip through the quantizer so the header stored on disk and
        # the values used for compression are bit-identical.
        p0 = global_header.float_to_uint(p0)
        p25 = global_header.float_to_uint(p25)
        p75 = global_header.float_to_uint(p75)
        p100 = global_header.float_to_uint(p100)

        # Clamp into strictly increasing uint16 values (p0 < p25 < p75 < p100).
        p0 = np.minimum(p0, 65532)
        p25 = np.minimum(np.maximum(p25, p0 + 1), 65533)
        p75 = np.minimum(np.maximum(p75, p25 + 1), 65534)
        p100 = np.maximum(p100, p75 + 1)

        p0 = global_header.uint_to_float(p0)
        p25 = global_header.uint_to_float(p25)
        p75 = global_header.uint_to_float(p75)
        p100 = global_header.uint_to_float(p100)

        # Shape (cols,) -> (cols, 1) so they broadcast against (cols, rows).
        p0 = p0[:, None]
        p25 = p25[:, None]
        p75 = p75[:, None]
        p100 = p100[:, None]
        return PerColHeader(p0, p25, p75, p100, global_header.endian)

    def float_to_char(self, array):
        """Quantize floats to uint8 codes using the percentile segments:
        codes 0-64 cover p0..p25, 64-192 cover p25..p75, 192-255 cover
        p75..p100.
        """
        p0, p25, p75, p100 = self.p0, self.p25, self.p75, self.p100

        ma1 = array < p25
        ma3 = array >= p75
        ma2 = ~ma1 * ~ma3

        # +0.5 round to the closest int
        tmp = (array - p0) / (p25 - p0) * 64.0 + 0.5
        tmp = np.where(tmp < 0.0, 0.0, np.where(tmp > 64.0, 64.0, tmp))

        tmp2 = (array - p25) / (p75 - p25) * 128.0 + 64.5
        tmp2 = np.where(tmp2 < 64.0, 64.0, np.where(tmp2 > 192.0, 192.0, tmp2))

        tmp3 = (array - p75) / (p100 - p75) * 63.0 + 192.5
        tmp3 = np.where(tmp3 < 192.0, 192.0, np.where(tmp3 > 255.0, 255.0, tmp3))
        array = np.where(ma1, tmp, np.where(ma2, tmp2, tmp3))
        return array.astype(np.dtype(self.endian + "u1"))
p75) * 63.0 + 192.5 231 | tmp3 = np.where(tmp3 < 192.0, 192.0, np.where(tmp3 > 255.0, 255.0, tmp3)) 232 | array = np.where(ma1, tmp, np.where(ma2, tmp2, tmp3)) 233 | return array.astype(np.dtype(self.endian + "u1")) 234 | 235 | def char_to_float(self, array): 236 | array = array.astype(np.float32) 237 | p0, p25, p75, p100 = self.p0, self.p25, self.p75, self.p100 238 | 239 | ma1 = array <= 64 240 | ma3 = array > 192 241 | ma2 = ~ma1 * ~ma3 # 192 >= array > 64 242 | 243 | return np.where( 244 | ma1, 245 | p0 + (p25 - p0) * array * (1 / 64.0), 246 | np.where( 247 | ma2, 248 | p25 + (p75 - p25) * (array - 64.0) * (1 / 128.0), 249 | p75 + (p100 - p75) * (array - 192.0) * (1 / 63.0), 250 | ), 251 | ) 252 | -------------------------------------------------------------------------------- /kaldiio/highlevel.py: -------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals 2 | 3 | import warnings 4 | 5 | from kaldiio.matio import load_ark 6 | from kaldiio.matio import load_scp_sequential 7 | from kaldiio.matio import save_ark 8 | from kaldiio.utils import open_like_kaldi 9 | from kaldiio.utils import parse_specifier 10 | 11 | 12 | class WriteHelper(object): 13 | """A heghlevel interface to write ark or/and scp 14 | 15 | >>> helper = WriteHelper('ark,scp:a.ark,b.ark') 16 | >>> helper('uttid', array) 17 | 18 | """ 19 | 20 | def __init__(self, wspecifier, compression_method=None, write_function=None): 21 | self.initialized = False 22 | self.closed = False 23 | 24 | self.compression_method = compression_method 25 | self.write_function = write_function 26 | spec_dict = parse_specifier(wspecifier) 27 | if spec_dict["scp"] is not None and spec_dict["ark"] is None: 28 | raise ValueError( 29 | "Writing only in a scp file is not supported. " 30 | "Please specify a ark file with a scp file." 
31 | ) 32 | for k in spec_dict: 33 | if spec_dict[k] and k not in ("scp", "ark", "t", "f"): 34 | warnings.warn( 35 | "{} option is given, but currently it never affects".format(k) 36 | ) 37 | 38 | self.text = spec_dict["t"] 39 | self.flush = spec_dict["f"] 40 | ark_file = spec_dict["ark"] 41 | self.fark = open_like_kaldi(ark_file, "wb") 42 | if spec_dict["scp"] is not None: 43 | self.fscp = open_like_kaldi(spec_dict["scp"], "w") 44 | else: 45 | self.fscp = None 46 | self.initialized = True 47 | 48 | def __call__(self, key, array): 49 | if self.closed: 50 | raise RuntimeError("WriteHelper has been already closed") 51 | save_ark( 52 | self.fark, 53 | {key: array}, 54 | scp=self.fscp, 55 | text=self.text, 56 | compression_method=self.compression_method, 57 | write_function=self.write_function, 58 | ) 59 | 60 | if self.flush: 61 | if self.fark is not None: 62 | self.fark.flush() 63 | if self.fscp is not None: 64 | self.fscp.flush() 65 | 66 | def __setitem__(self, key, value): 67 | self(key, value) 68 | 69 | def __enter__(self): 70 | return self 71 | 72 | def __exit__(self, exc_type, exc_val, exc_tb): 73 | self.close() 74 | 75 | def close(self): 76 | if self.initialized and not self.closed: 77 | self.fark.close() 78 | if self.fscp is not None: 79 | self.fscp.close() 80 | self.closed = True 81 | 82 | 83 | class ReadHelper(object): 84 | """A highlevel interface to load ark or scp 85 | 86 | >>> import numpy 87 | >>> array_in = numpy.random.randn(10, 10) 88 | >>> save_ark('feats.ark', {'foo': array_in}, scp='feats.scp') 89 | >>> helper = ReadHelper('ark:cat feats.ark |') 90 | >>> for uttid, array_out in helper: 91 | ... assert uttid == 'foo' 92 | ... numpy.testing.assert_array_equal(array_in, array_out) 93 | >>> helper = ReadHelper('scp:feats.scp') 94 | >>> for uttid, array_out in helper: 95 | ... assert uttid == 'foo' 96 | ... 
    def __init__(self, wspecifier, segments=None):
        """Parse *wspecifier* ("scp:..." or "ark:...") and open the source.

        Args:
            wspecifier (str): kaldi-style rspecifier, e.g. "ark:foo.ark",
                "scp:foo.scp", or with a pipe: "ark:cat foo.ark |".
            segments (str): path of a kaldi "segments" file; only supported
                together with scp.
        """
        self.initialized = False
        self.scp = None
        self.closed = False
        self.generator = None
        self.segments = segments

        spec_dict = parse_specifier(wspecifier)
        if spec_dict["scp"] is not None and spec_dict["ark"] is not None:
            raise RuntimeError("Specify one of scp or ark in rspecifier")
        for k in spec_dict:
            # Only "p" (permissive) is honored among the option flags.
            if spec_dict[k] and k not in ("scp", "ark", "p"):
                warnings.warn(
                    "{} option is given, but currently it never affects".format(k)
                )
        self.permissive = spec_dict["p"]

        if spec_dict["scp"] is not None:
            self.scp = spec_dict["scp"]
        else:
            self.scp = False

        if self.scp:
            # scp mode: entries are produced lazily, one pair at a time.
            self.generator = load_scp_sequential(spec_dict["scp"], segments=segments)

            self.file = None
        else:
            if segments is not None:
                raise ValueError('Not supporting "segments" argument with ark file')
            self.file = open_like_kaldi(spec_dict["ark"], "rb")
        self.initialized = True

    def __iter__(self):
        """Yield (key, array) pairs.

        In permissive mode ("p" option) a read error silently terminates
        iteration instead of propagating.
        """
        if self.scp:
            while True:
                try:
                    k, v = next(self.generator)
                except StopIteration:
                    break
                except Exception:
                    if self.permissive:
                        # Permissive mode: swallow the error and stop iterating
                        break
                    else:
                        raise
                yield k, v

        else:
            # ark mode: the file is closed when iteration finishes.
            with self.file as f:
                it = load_ark(f)
                while True:
                    try:
                        k, v = next(it)
                    except StopIteration:
                        break
                    except Exception:
                        if self.permissive:
                            # Permissive mode: swallow the error and stop iterating
                            break
                        else:
                            raise
                    yield k, v
            self.closed = True

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # ark mode owns an open descriptor; close it unless iteration did.
        if not self.scp and not self.closed:
            self.close()

    def close(self):
        if self.initialized and not self.scp and not self.closed:
            self.file.close()
            self.closed = True
if PY3:
    from collections.abc import Mapping

    binary_type = bytes
    string_types = (str,)

    def to_bytes(n, length, endianess="little"):
        """Int -> fixed-length byte string (thin wrapper over int.to_bytes)."""
        return n.to_bytes(length, endianess)

    def from_bytes(s, endianess="little"):
        """Byte string -> int (thin wrapper over int.from_bytes)."""
        return int.from_bytes(s, endianess)

else:
    from collections import Mapping

    binary_type = str
    string_types = (basestring,)  # noqa: F821

    def to_bytes(n, length, endianess="little"):
        # Python 2 fallback: build the hex representation by hand and
        # pad/reverse it to emulate int.to_bytes.
        assert endianess in ("big", "little"), endianess
        h = b"%x" % n
        s = codecs.decode((b"0" * (len(h) % 2) + h).zfill(length * 2), "hex")
        return s if endianess == "big" else s[::-1]

    def from_bytes(s, endianess="little"):
        if endianess == "little":
            s = s[::-1]
        return int(codecs.encode(s, "hex"), 16)


def load_scp(fname, endian="<", separator=None, segments=None, max_cache_fd=0):
    """Lazy loader for kaldi scp file.

    Args:
        fname (str or file(text mode)):
        endian (str): "<" or ">".
        separator (str): token separator (default: any whitespace).
        segments (str): The path of a kaldi segments file.
        max_cache_fd (int): max number of ark file descriptors kept open
            between lookups (0 disables caching).

    Returns:
        dict-like object keyed by utterance id (LazyLoader, or
        SegmentsExtractor in segments mode).
    """
    assert endian in ("<", ">"), endian

    if max_cache_fd != 0:
        if segments is not None:
            raise ValueError("max_cache_fd is not supported for segments mode")
        d = LimitedSizeDict(max_cache_fd)
    else:
        d = None

    if segments is None:
        load_func = partial(load_mat, endian=endian, fd_dict=d)
        loader = LazyLoader(load_func)
        with open_like_kaldi(fname, "r") as fd:
            for line in fd:
                seps = line.split(separator, 1)
                if len(seps) != 2:
                    raise ValueError("Invalid line is found:\n> {}".format(line))
                token, arkname = seps
                loader[token] = arkname.rstrip()
        return loader
    else:
        return SegmentsExtractor(fname, separator=separator, segments=segments)
def load_scp_sequential(fname, endian="<", separator=None, segments=None):
    """Lazy generator over a kaldi scp file, yielding (key, value) in file order.

    Args:
        fname (str or file(text mode)):
        endian (str): "<" or ">".
        separator (str): token separator (default: any whitespace).
        segments (str): The path of a kaldi segments file.
    """
    assert endian in ("<", ">"), endian
    if segments is None:
        with open_like_kaldi(fname, "r") as fd:
            prev_ark = None
            prev_arkfd = None

            try:
                for line in fd:
                    seps = line.split(separator, 1)
                    if len(seps) != 2:
                        raise ValueError("Invalid line is found:\n> {}".format(line))
                    token, arkname = seps
                    arkname = arkname.rstrip()

                    ark, offset, slices = _parse_arkpath(arkname)

                    # scp files are typically grouped by ark file: reuse the
                    # previous descriptor for consecutive entries of one ark.
                    if prev_ark == ark:
                        arkfd = prev_arkfd
                        mat = _load_mat(arkfd, offset, slices, endian=endian)
                    else:
                        if prev_arkfd is not None:
                            prev_arkfd.close()
                        arkfd = open_like_kaldi(ark, "rb")
                        mat = _load_mat(arkfd, offset, slices, endian=endian)

                    prev_ark = ark
                    prev_arkfd = arkfd
                    yield token, mat
            except Exception:
                if prev_arkfd is not None:
                    prev_arkfd.close()
                raise

    else:
        for data in SegmentsExtractor(
            fname, separator=separator, segments=segments
        ).generator():
            yield data


def load_wav_scp(fname, segments=None, separator=None):
    # Deprecated alias: load_scp handles wav.scp transparently.
    warnings.warn("Use load_scp instead of load_wav_scp", DeprecationWarning)
    return load_scp(fname, separator=separator, segments=segments)


class SegmentsExtractor(Mapping):
    """Emulate the following,

    https://github.com/kaldi-asr/kaldi/blob/master/src/featbin/extract-segments.cc

    Args:
        segments (str): Each line has the form
            "<utt-id> <recording-id> <start-time> <end-time>"
            e.g. "call-861225-A-0050-0065 call-861225-A 5.0 6.5"
    """

    def __init__(self, fname, segments=None, separator=None):
        self.wav_scp = fname
        self.wav_loader = load_scp(self.wav_scp, separator=separator)

        self.segments = segments
        self._segments_dict = {}
        with open_or_fd(self.segments, "r") as f:
            for line in f:
                sps = line.rstrip().split(separator)
                if len(sps) != 4:
                    raise RuntimeError("Format is invalid: {}".format(line))
                uttid, recodeid, st, et = sps
                self._segments_dict[uttid] = (recodeid, float(st), float(et))

                # Fail fast if a segment references an unknown recording.
                if recodeid not in self.wav_loader:
                    raise RuntimeError(
                        'Not found "{}" in {}'.format(recodeid, self.wav_scp)
                    )

    def generator(self):
        """Yield (utt-id, (rate, sliced_samples)) for every segment.

        Each recording is loaded once and cached only until its last segment
        has been served, bounding memory use.
        """
        recodeid_counter = {}
        for utt, (recodeid, st, et) in self._segments_dict.items():
            recodeid_counter[recodeid] = recodeid_counter.get(recodeid, 0) + 1

        cached = {}
        for utt, (recodeid, st, et) in self._segments_dict.items():
            if recodeid not in cached:
                cached[recodeid] = self.wav_loader[recodeid]
            array = cached[recodeid]

            # Keep array until the last query
            recodeid_counter[recodeid] -= 1
            if recodeid_counter[recodeid] == 0:
                cached.pop(recodeid)

            yield utt, self._return(array, st, et)

    def __iter__(self):
        return iter(self._segments_dict)

    def __contains__(self, item):
        return item in self._segments_dict

    def __len__(self):
        return len(self._segments_dict)

    def __getitem__(self, key):
        recodeid, st, et = self._segments_dict[key]
        array = self.wav_loader[recodeid]
        return self._return(array, st, et)

    def _return(self, array, st, et):
        # wav.scp entries load as (rate, samples); anything else means the
        # scp given as *fname* was not a wav.scp.
        if isinstance(array, (tuple, list)):
            rate, array = array
        else:
            raise RuntimeError("{} is not wav.scp?".format(self.wav_scp))

        # Convert starting time of the segment to corresponding sample number.
        # If end time is -1 then use the whole file starting from start time.
        if et != -1:
            return rate, array[int(st * rate) : int(et * rate)]
        else:
            return rate, array[int(st * rate) :]
corresponding sample number. 217 | # If end time is -1 then use the whole file starting from start time. 218 | if et != -1: 219 | return rate, array[int(st * rate) : int(et * rate)] 220 | else: 221 | return rate, array[int(st * rate) :] 222 | 223 | 224 | def load_mat(ark_name, endian="<", fd_dict=None): 225 | assert endian in ("<", ">"), endian 226 | if fd_dict is not None and not isinstance(fd_dict, Mapping): 227 | raise RuntimeError( 228 | "fd_dict must be dict or None, bot got {}".format(type(fd_dict)) 229 | ) 230 | 231 | ark, offset, slices = _parse_arkpath(ark_name) 232 | 233 | if fd_dict is not None and not (ark.strip()[-1] == "|" or ark.strip()[0] == "|"): 234 | if ark not in fd_dict: 235 | fd_dict[ark] = open_like_kaldi(ark, "rb") 236 | fd = fd_dict[ark] 237 | return _load_mat(fd, offset, slices, endian=endian) 238 | else: 239 | with open_like_kaldi(ark, "rb") as fd: 240 | return _load_mat(fd, offset, slices, endian=endian) 241 | 242 | 243 | def _parse_arkpath(ark_name): 244 | """Parse arkpath 245 | 246 | Args: 247 | ark_name (str): 248 | Returns: 249 | Tuple[str, int, Optional[Tuple[slice, ...]]] 250 | Examples: 251 | >>> _parse_arkpath('a.ark') 252 | 'a.ark', None, None 253 | >>> _parse_arkpath('a.ark:12') 254 | 'a.ark', 12, None 255 | >>> _parse_arkpath('a.ark:12[3:4]') 256 | 'a.ark', 12, (slice(3, 4, None),) 257 | >>> _parse_arkpath('cat "fo:o.ark" |') 258 | 'cat "fo:o.ark" |', None, None 259 | 260 | """ 261 | 262 | if ark_name.strip()[-1] == "|" or ark_name.strip()[0] == "|": 263 | # Something like: "| cat foo" or "cat bar|" shouldn't be parsed 264 | return ark_name, None, None 265 | 266 | slices = None 267 | if "[" in ark_name and "]" in ark_name: 268 | _ark_name, Range = ark_name.split("[") 269 | Range = Range.replace("]", "").strip() 270 | try: 271 | slices = _convert_to_slice(Range) 272 | except Exception: 273 | pass 274 | else: 275 | ark_name = _ark_name 276 | 277 | if ":" in ark_name: 278 | fname, offset = ark_name.rsplit(":", 1) 279 | try: 280 | 
offset = int(offset) 281 | except ValueError: 282 | fname = ark_name 283 | offset = None 284 | else: 285 | fname = ark_name 286 | offset = None 287 | return fname, offset, slices 288 | 289 | 290 | def _convert_to_slice(string): 291 | """Convert slice-str to slice 292 | 293 | Examples: 294 | >>> _convert_to_slice('0:51') 295 | (slice(0, 52),) 296 | >>> _convert_to_slice('0:51,6:10') 297 | (slice(0, 52), slice(6, 11)) 298 | >>> _convert_to_slice(',6:10') 299 | (slice(None), slice(6, 11)) 300 | 301 | """ 302 | 303 | slices = [] 304 | for ele in string.split(","): 305 | if ele == "" or ele == ":": 306 | slices.append(slice(None)) 307 | else: 308 | sps = [] 309 | for sp in ele.split(":"): 310 | try: 311 | sps.append(int(sp)) 312 | except ValueError: 313 | raise ValueError("Format error: {}".format(string)) 314 | if len(sps) == 1: 315 | sl = slice(sps[0], sps[0] + 1) 316 | elif len(sps) == 2: 317 | sl = slice(sps[0], sps[1] + 1) 318 | elif len(sps) == 3: 319 | sl = slice(sps[0], sps[1] + 1, sps[2]) 320 | else: 321 | raise RuntimeError("Too many : {}".format(string)) 322 | 323 | slices.append(sl) 324 | return tuple(slices) 325 | 326 | 327 | def _load_mat(fd, offset, slices=None, endian="<"): 328 | if offset is not None: 329 | fd.seek(offset) 330 | array = read_kaldi(fd, endian) 331 | 332 | if slices is not None: 333 | if isinstance(array, (tuple, list)): 334 | array = (array[0], array[1][slices]) 335 | else: 336 | array = array[slices] 337 | return array 338 | 339 | 340 | def load_ark(fname, endian="<"): 341 | assert endian in ("<", ">"), endian 342 | with open_or_fd(fname, "rb") as fd: 343 | while True: 344 | token = read_token(fd) 345 | if token is None: 346 | break 347 | array = read_kaldi(fd, endian) 348 | yield token, array 349 | 350 | 351 | def read_token(fd): 352 | """Read token 353 | 354 | Args: 355 | fd (file): 356 | """ 357 | token = [] 358 | # Keep the loop until finding ' ' or end of char 359 | while True: 360 | c = fd.read(1) 361 | if c == b" " or c == b"": 
def read_token(fd):
    """Read a whitespace-terminated token from a binary stream.

    Args:
        fd (file): binary-mode file object.

    Returns:
        str or None: the decoded token, or None at end of file.
    """
    token = []
    # Keep the loop until finding ' ' or end of char
    while True:
        c = fd.read(1)
        if c == b" " or c == b"":
            break
        token.append(c)
    if len(token) == 0:  # End of file
        return None
    decoded = b"".join(token).decode(encoding=default_encoding)
    return decoded


def read_kaldi(fd, endian="<", audio_loader="soundfile", load_kwargs=None):
    """Load kaldi

    Dispatches on a leading magic string: RIFF (wav), fLaC, NPY (numpy
    save payload), PKL (pickle payload), AUDIO (arbitrary soundfile
    payload), "\\0B" (kaldi binary matrix/vector), otherwise kaldi
    ascii matrix.

    Args:
        fd (file): Binary mode file object. Cannot input string
        endian (str):
        audio_loader: (Union[str, callable]): only "soundfile" is accepted
            as a string.
        load_kwargs (dict): extra kwargs forwarded to the payload loader.
    """
    assert endian in ("<", ">"), endian
    if load_kwargs is None:
        load_kwargs = {}

    max_flag_length = len(b"AUDIO")

    binary_flag = fd.read(max_flag_length)
    assert isinstance(binary_flag, binary_type), type(binary_flag)

    if seekable(fd):
        # Rewind so each branch below sees the stream from its start.
        fd.seek(-max_flag_length, 1)
    else:
        # Non-seekable stream (e.g. a pipe): prepend the consumed bytes.
        fd = MultiFileDescriptor(BytesIO(binary_flag), fd)

    if binary_flag[:4] == b"RIFF":
        # array: Tuple[int, np.ndarray]
        array = read_wav(fd)

    elif binary_flag[:4] == b"fLaC":
        import soundfile
        buf = fd.read()
        _fd = BytesIO(buf)
        audio, rate = soundfile.read(_fd)
        array = (rate, audio,)
    elif binary_flag[:3] == b"NPY":
        fd.read(3)
        length_ = _read_length_header(fd)
        buf = fd.read(length_)
        _fd = BytesIO(buf)
        array = np.load(_fd, **load_kwargs)

    elif binary_flag[:3] == b"PKL":
        fd.read(3)
        # NOTE(review): pickle.load on untrusted ark data can execute
        # arbitrary code; only read arks from trusted sources.
        array = pickle.load(fd, **load_kwargs)

    elif binary_flag[:5] == b"AUDIO":
        fd.read(5)
        length_ = _read_length_header(fd)
        buf = fd.read(length_)
        _fd = BytesIO(buf)

        if audio_loader == "soundfile":
            import soundfile

            audio_loader = soundfile.read
        else:
            raise ValueError("Not supported: audio_loader={}".format(audio_loader))

        x1, x2 = audio_loader(_fd, **load_kwargs)

        # array: Tuple[int, np.ndarray] according to scipy wav read
        if isinstance(x1, int) and isinstance(x2, np.ndarray):
            array = (x1, x2)
        elif isinstance(x1, np.ndarray) and isinstance(x2, int):
            # Normalize loader output to (rate, samples) ordering
            array = (x2, x1)
        else:
            raise RuntimeError(
                "Got unexpected type from audio_loader: ({}, {})".format(
                    type(x1), type(x2)
                )
            )

    # Load as binary
    elif binary_flag[:2] == b"\0B":
        if binary_flag[2:3] == b"\4":  # This is int32Vector
            array = read_int32vector(fd, endian)
        else:
            array = read_matrix_or_vector(fd, endian)
    # Load as ascii
    else:
        array = read_ascii_mat(fd)

    return array
def read_int32vector(fd, endian="<", return_size=False):
    """Read a kaldi binary int32 vector from *fd*.

    Wire format: b"\\0B", then one size-marker byte b"\\4" + int32 length,
    then a b"\\4" + int32 pair per element.

    Args:
        fd (file): binary-mode file positioned at the "\\0B" marker.
        endian (str): "<" or ">".
        return_size (bool): also return the number of bytes consumed.

    Returns:
        np.ndarray (dtype int32), or (array, size) if return_size.
    """
    assert fd.read(2) == b"\0B"
    assert fd.read(1) == b"\4"
    (num_elements,) = struct.unpack(endian + "i", fd.read(4))
    values = np.empty(num_elements, dtype=np.int32)
    for idx in range(num_elements):
        assert fd.read(1) == b"\4"
        (values[idx],) = struct.unpack(endian + "i", fd.read(4))
    if return_size:
        # 2 bytes of "\0B" + (length field + each element) * 5 bytes each
        return values, (num_elements + 1) * 5 + 2
    return values
"CM2" == Type: 501 | # Read GlobalHeader 502 | global_header = GlobalHeader.read(fd, Type, endian) 503 | size += global_header.size 504 | 505 | # Read matrix 506 | buf = fd.read(2 * global_header.rows * global_header.cols) 507 | array = np.frombuffer(buf, dtype=np.dtype(endian + "u2")) 508 | array = array.reshape((global_header.rows, global_header.cols)) 509 | 510 | # Decompress 511 | array = global_header.uint_to_float(array) 512 | 513 | elif "CM3" == Type: 514 | # Read GlobalHeader 515 | global_header = GlobalHeader.read(fd, Type, endian) 516 | size += global_header.size 517 | 518 | # Read matrix 519 | buf = fd.read(global_header.rows * global_header.cols) 520 | array = np.frombuffer(buf, dtype=np.dtype(endian + "u1")) 521 | array = array.reshape((global_header.rows, global_header.cols)) 522 | 523 | # Decompress 524 | array = global_header.uint_to_float(array) 525 | 526 | else: 527 | if Type == "FM" or Type == "FV": 528 | dtype = endian + "f" 529 | bytes_per_sample = 4 530 | elif Type == "DM" or Type == "DV": 531 | dtype = endian + "d" 532 | bytes_per_sample = 8 533 | else: 534 | raise ValueError( 535 | 'Unexpected format: "{}". 
Now FM, FV, DM, DV, ' 536 | "CM, CM2, CM3 are supported.".format(Type) 537 | ) 538 | 539 | assert fd.read(1) == b"\4" 540 | size += 1 541 | rows = struct.unpack(endian + "i", fd.read(4))[0] 542 | size += 4 543 | dim = rows 544 | if "M" in Type: # As matrix 545 | assert fd.read(1) == b"\4" 546 | size += 1 547 | cols = struct.unpack(endian + "i", fd.read(4))[0] 548 | size += 4 549 | dim = rows * cols 550 | 551 | buf = fd.read(dim * bytes_per_sample) 552 | size += dim * bytes_per_sample 553 | array = np.frombuffer(buf, dtype=np.dtype(dtype)) 554 | 555 | if "M" in Type: # As matrix 556 | array = np.reshape(array, (rows, cols)) 557 | 558 | if return_size: 559 | return array, size 560 | else: 561 | return array 562 | 563 | 564 | def read_ascii_mat(fd, return_size=False): 565 | """Call from load_kaldi_file 566 | 567 | Args: 568 | fd (file): binary mode 569 | return_size (bool): 570 | """ 571 | string = [] 572 | size = 0 573 | 574 | # Find '[' char 575 | while True: 576 | b = fd.read(1) 577 | try: 578 | char = b.decode(encoding=default_encoding) 579 | except UnicodeDecodeError: 580 | raise ValueError("File format is wrong?") 581 | size += 1 582 | if char == " " or char == "\n": 583 | continue 584 | elif char == "[": 585 | hasparent = True 586 | break 587 | else: 588 | string.append(char) 589 | hasparent = False 590 | break 591 | 592 | # Read data 593 | ndmin = 1 594 | while True: 595 | char = fd.read(1).decode(encoding=default_encoding) 596 | size += 1 597 | if hasparent: 598 | if char == "]": 599 | char = fd.read(1).decode(encoding=default_encoding) 600 | size += 1 601 | assert char == "\n" or char == "" 602 | break 603 | elif char == "\n": 604 | ndmin = 2 605 | elif char == "": 606 | raise ValueError("There are no corresponding bracket ']' with '['") 607 | else: 608 | if char == "\n" or char == "": 609 | break 610 | string.append(char) 611 | string = "".join(string) 612 | assert len(string) != 0 613 | 614 | # Examine dtype 615 | match = re.match(r" *([^ \n]+) *", string) 
    def tell(self):
        """Return the current read offset within the chunk payload."""
        if self.closed:
            raise ValueError("I/O operation on closed file")
        return self.size_read

    def read(self, size=-1):
        """Read at most size bytes from the chunk.
        If size is omitted or negative, read until the end
        of the chunk.
        """

        if self.closed:
            raise ValueError("I/O operation on closed file")
        if self.size_read >= self.chunksize:
            return b''
        if size < 0:
            size = self.chunksize - self.size_read
        if size > self.chunksize - self.size_read:
            # Never read past the declared end of this chunk.
            size = self.chunksize - self.size_read
        data = self.file.read(size)
        self.size_read = self.size_read + len(data)
        if self.size_read == self.chunksize and \
           self.align and \
           (self.chunksize & 1):
            # RIFF chunks are word-aligned: consume the trailing pad byte.
            dummy = self.file.read(1)
            self.size_read = self.size_read + len(dummy)
        return data

    def skip(self):
        """Skip the rest of the chunk.
        If you are not interested in the contents of the chunk,
        this method should be called so that the file points to
        the start of the next chunk.
        """

        if self.closed:
            raise ValueError("I/O operation on closed file")
        if self.seekable:
            try:
                n = self.chunksize - self.size_read
                # maybe fix alignment
                if self.align and (self.chunksize & 1):
                    n = n + 1
                self.file.seek(n, 1)
                self.size_read = self.size_read + n
                return
            except OSError:
                pass
        # Fallback for non-seekable streams: read and discard the rest.
        while self.size_read < self.chunksize:
            n = min(8192, self.chunksize - self.size_read)
            dummy = self.read(n)
            if not dummy:
                raise EOFError
    def initfp(self, file):
        """Parse the RIFF/WAVE structure of *file* and locate fmt/data chunks."""
        self._convert = None
        self._soundpos = 0
        self._file = _Chunk(file, bigendian = 0)
        if self._file.getname() != b'RIFF':
            raise Error('file does not start with RIFF id')
        if self._file.read(4) != b'WAVE':
            raise Error('not a WAVE file')
        self._fmt_chunk_read = 0
        self._data_chunk = None
        while 1:
            self._data_seek_needed = 1
            try:
                chunk = _Chunk(self._file, bigendian = 0)
            except EOFError:
                break
            chunkname = chunk.getname()
            if chunkname == b'fmt ':
                self._read_fmt_chunk(chunk)
                self._fmt_chunk_read = 1
            elif chunkname == b'data':
                # 'data' must come after 'fmt ' so _framesize is known.
                if not self._fmt_chunk_read:
                    raise Error('data chunk before fmt chunk')
                self._data_chunk = chunk
                self._nframes = chunk.chunksize // self._framesize
                self._data_seek_needed = 0
                break
            chunk.skip()
        if not self._fmt_chunk_read or not self._data_chunk:
            raise Error('fmt chunk and/or data chunk missing')

    def __init__(self, f):
        self._i_opened_the_file = None
        if isinstance(f, str):
            f = builtins.open(f, 'rb')
            self._i_opened_the_file = f
        # else, assume it is an open file object already
        try:
            self.initfp(f)
        except:
            # Only close the file if this object opened it.
            if self._i_opened_the_file:
                f.close()
            raise

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    #
    # User visible methods.
    #
    def getfp(self):
        return self._file

    def rewind(self):
        # Reset the logical position; the actual seek happens lazily in
        # readframes().
        self._data_seek_needed = 1
        self._soundpos = 0

    def close(self):
        self._file = None
        file = self._i_opened_the_file
        if file:
            self._i_opened_the_file = None
            file.close()

    def tell(self):
        return self._soundpos

    def getnchannels(self):
        return self._nchannels

    def getnframes(self):
        return self._nframes

    def getsampwidth(self):
        return self._sampwidth

    def getframerate(self):
        return self._framerate

    def getcomptype(self):
        return self._comptype

    def getcompname(self):
        return self._compname

    def getparams(self):
        """Return all parameters as a _wave_params namedtuple."""
        return _wave_params(self.getnchannels(), self.getsampwidth(),
                            self.getframerate(), self.getnframes(),
                            self.getcomptype(), self.getcompname())

    def getmarkers(self):
        import warnings
        warnings._deprecated("Wave_read.getmarkers", remove=(3, 15))
        return None

    def getmark(self, id):
        import warnings
        warnings._deprecated("Wave_read.getmark", remove=(3, 15))
        raise Error('no marks')

    def setpos(self, pos):
        """Set the read position in frames; applied lazily by readframes()."""
        if pos < 0 or pos > self._nframes:
            raise Error('position not in range')
        self._soundpos = pos
        self._data_seek_needed = 1

    def readframes(self, nframes):
        """Read up to *nframes* frames, returning raw little-endian bytes."""
        if self._data_seek_needed:
            self._data_chunk.seek(0, 0)
            pos = self._soundpos * self._framesize
            if pos:
                self._data_chunk.seek(pos, 0)
            self._data_seek_needed = 0
        if nframes == 0:
            return b''
        data = self._data_chunk.read(nframes * self._framesize)
        if self._sampwidth != 1 and sys.byteorder == 'big':
            # WAV samples are little-endian on disk; swap on big-endian hosts.
            data = _byteswap(data, self._sampwidth)
        if self._convert and data:
            data = self._convert(data)
        self._soundpos = self._soundpos + len(data) // (self._nchannels * self._sampwidth)
        return data
self._sampwidth) 377 | return data 378 | 379 | # 380 | # Internal methods. 381 | # 382 | 383 | def _read_fmt_chunk(self, chunk): 384 | try: 385 | wFormatTag, self._nchannels, self._framerate, dwAvgBytesPerSec, wBlockAlign = struct.unpack_from(' 4: 499 | raise Error('bad sample width') 500 | self._sampwidth = sampwidth 501 | 502 | def getsampwidth(self): 503 | if not self._sampwidth: 504 | raise Error('sample width not set') 505 | return self._sampwidth 506 | 507 | def setframerate(self, framerate): 508 | if self._datawritten: 509 | raise Error('cannot change parameters after starting to write') 510 | if framerate <= 0: 511 | raise Error('bad frame rate') 512 | self._framerate = int(round(framerate)) 513 | 514 | def getframerate(self): 515 | if not self._framerate: 516 | raise Error('frame rate not set') 517 | return self._framerate 518 | 519 | def setnframes(self, nframes): 520 | if self._datawritten: 521 | raise Error('cannot change parameters after starting to write') 522 | self._nframes = nframes 523 | 524 | def getnframes(self): 525 | return self._nframeswritten 526 | 527 | def setcomptype(self, comptype, compname): 528 | if self._datawritten: 529 | raise Error('cannot change parameters after starting to write') 530 | if comptype not in ('NONE',): 531 | raise Error('unsupported compression type') 532 | self._comptype = comptype 533 | self._compname = compname 534 | 535 | def getcomptype(self): 536 | return self._comptype 537 | 538 | def getcompname(self): 539 | return self._compname 540 | 541 | def setparams(self, params): 542 | nchannels, sampwidth, framerate, nframes, comptype, compname = params 543 | if self._datawritten: 544 | raise Error('cannot change parameters after starting to write') 545 | self.setnchannels(nchannels) 546 | self.setsampwidth(sampwidth) 547 | self.setframerate(framerate) 548 | self.setnframes(nframes) 549 | self.setcomptype(comptype, compname) 550 | 551 | def getparams(self): 552 | if not self._nchannels or not self._sampwidth or not 
self._framerate: 553 | raise Error('not all parameters set') 554 | return _wave_params(self._nchannels, self._sampwidth, self._framerate, 555 | self._nframes, self._comptype, self._compname) 556 | 557 | def setmark(self, id, pos, name): 558 | import warnings 559 | warnings._deprecated("Wave_write.setmark", remove=(3, 15)) 560 | raise Error('setmark() not supported') 561 | 562 | def getmark(self, id): 563 | import warnings 564 | warnings._deprecated("Wave_write.getmark", remove=(3, 15)) 565 | raise Error('no marks') 566 | 567 | def getmarkers(self): 568 | import warnings 569 | warnings._deprecated("Wave_write.getmarkers", remove=(3, 15)) 570 | return None 571 | 572 | def tell(self): 573 | return self._nframeswritten 574 | 575 | def writeframesraw(self, data): 576 | if not isinstance(data, (bytes, bytearray)): 577 | data = memoryview(data).cast('B') 578 | self._ensure_header_written(len(data)) 579 | nframes = len(data) // (self._sampwidth * self._nchannels) 580 | if self._convert: 581 | data = self._convert(data) 582 | if self._sampwidth != 1 and sys.byteorder == 'big': 583 | data = _byteswap(data, self._sampwidth) 584 | self._file.write(data) 585 | self._datawritten += len(data) 586 | self._nframeswritten = self._nframeswritten + nframes 587 | 588 | def writeframes(self, data): 589 | self.writeframesraw(data) 590 | if self._datalength != self._datawritten: 591 | self._patchheader() 592 | 593 | def close(self): 594 | try: 595 | if self._file: 596 | self._ensure_header_written(0) 597 | if self._datalength != self._datawritten: 598 | self._patchheader() 599 | self._file.flush() 600 | finally: 601 | self._file = None 602 | file = self._i_opened_the_file 603 | if file: 604 | self._i_opened_the_file = None 605 | file.close() 606 | 607 | # 608 | # Internal methods. 
609 | # 610 | 611 | def _ensure_header_written(self, datasize): 612 | if not self._headerwritten: 613 | if not self._nchannels: 614 | raise Error('# channels not specified') 615 | if not self._sampwidth: 616 | raise Error('sample width not specified') 617 | if not self._framerate: 618 | raise Error('sampling rate not specified') 619 | self._write_header(datasize) 620 | 621 | def _write_header(self, initlength): 622 | assert not self._headerwritten 623 | self._file.write(b'RIFF') 624 | if not self._nframes: 625 | self._nframes = initlength // (self._nchannels * self._sampwidth) 626 | self._datalength = self._nframes * self._nchannels * self._sampwidth 627 | try: 628 | self._form_length_pos = self._file.tell() 629 | except (AttributeError, OSError): 630 | self._form_length_pos = None 631 | self._file.write(struct.pack('=3.7, 39 | 40 | - preferred encoding 41 | 42 | locale.getpreferredencoding(). Used for open(). 43 | The default value depends on the local in your unix system. 44 | 45 | - stdout and stdin 46 | 47 | If PYTHONIOENCODING is set, then it's used, 48 | else if in a terminal, same as filesystem encoding 49 | else same as preferred encoding 50 | 51 | - default encoding 52 | 53 | The default encoding for str.encode() or bytes.decode(). 54 | If Python2, it's ascii, if Python3, it's utf-8. 
if PY3:

    def my_popen(cmd, mode="r", buffering=-1):
        """Open a pipe to or from ``cmd``, like ``os.popen``.

        Extends the os-module implementation with the binary modes
        "rb" and "wb" in addition to "r" and "w".

        Args:
            cmd (str or bytes): Shell command line.
            mode (str): One of "r", "rb", "w", "wb".
            buffering (int): Buffer size handed to subprocess.Popen.
        """
        if isinstance(cmd, text_type):
            cmd = cmd.encode(default_encoding)
        if buffering == 0 or buffering is None:
            raise ValueError("popen() does not support unbuffered streams")
        if mode not in ("r", "rb", "w", "wb"):
            raise TypeError("Unsupported mode == {}".format(mode))

        # "r"/"rb" read the child's stdout; "w"/"wb" feed its stdin.
        if mode in ("r", "rb"):
            proc = subprocess.Popen(
                cmd, shell=True, stdout=subprocess.PIPE, bufsize=buffering
            )
            stream = proc.stdout
        else:
            proc = subprocess.Popen(
                cmd, shell=True, stdin=subprocess.PIPE, bufsize=buffering
            )
            stream = proc.stdin

        # Text modes get a decoding wrapper around the raw pipe.
        if "b" not in mode:
            stream = io.TextIOWrapper(stream, encoding=default_encoding)
        return _wrap_close(stream, proc)

else:
    my_popen = os.popen
def open_like_kaldi(name, mode="r"):
    """Open ``name`` with kaldi-style extended semantics.

    ``"command |"`` reads from the command's stdout, ``"| command"``
    writes into the command's stdin, and ``"-"`` maps to stdin/stdout
    depending on ``mode``; anything else is opened as a regular file.

    Args:
        name (str or file): Path, pipe specifier, "-", or an open file.
        mode (str): e.g. "r", "rb", "w", "wb", "a".
    """
    if not isinstance(name, string_types):
        # Already a file object; unwrap to the underlying byte buffer
        # when a binary mode is requested on a text stream.
        if PY3 and "b" in mode and isinstance(name, TextIOBase):
            return name.buffer
        return name

    stripped = name.strip()
    if stripped.endswith("|"):
        # "command |": read from the command's stdout
        # (the original comments had the pipe directions swapped).
        return my_popen(stripped[:-1].encode(default_encoding), mode)
    if stripped.startswith("|"):
        # "| command": write into the command's stdin.
        return my_popen(stripped[1:].encode(default_encoding), mode)
    if name == "-" and "r" in mode:
        # Standard input; the wrapper never closes it.
        if not PY3:
            return _stdstream_wrap(sys.stdin)
        if mode == "rb":
            return _stdstream_wrap(sys.stdin.buffer)
        return _stdstream_wrap(
            io.TextIOWrapper(sys.stdin.buffer, encoding=default_encoding)
        )
    if name == "-" and ("w" in mode or "a" in mode):
        # Standard output; the wrapper never closes it.
        if not PY3:
            return _stdstream_wrap(sys.stdout)
        if mode in ("wb", "ab"):
            return _stdstream_wrap(sys.stdout.buffer)
        return _stdstream_wrap(
            io.TextIOWrapper(sys.stdout.buffer, encoding=default_encoding)
        )
    # Plain file path.
    encoding = None if "b" in mode else default_encoding
    return io.open(name, mode, encoding=encoding)
class CountFileDescriptor(object):
    """Wrap a file object and count the bytes that pass through it.

    ``position`` accumulates the number of bytes read from, or written
    to, the wrapped file object ``f``.  Seeking is forbidden on purpose:
    the wrapper exists to track the offset of streams (pipes, stdin)
    whose own position cannot be queried reliably.
    """

    def __init__(self, f):
        self.f = f
        # Bytes read/written so far through this wrapper.
        self.position = 0

    def close(self):
        return self.f.close()

    def closed(self):
        # ``closed`` is an attribute of file objects, not a method; the
        # original ``self.f.closed()`` raised TypeError when called.
        return self.f.closed

    def fileno(self):
        # Fixed typo: the original called ``self.f.flieno()``.
        return self.f.fileno()

    def flush(self):
        return self.f.flush()

    def isatty(self):
        return self.f.isatty()

    def readable(self):
        return self.f.readable()

    # Backward-compatible alias for the original misspelled name.
    readbale = readable

    def read(self, size=-1):
        data = self.f.read(size)
        self.position += len(data)
        return data

    def readall(self):
        data = self.f.readall()
        self.position += len(data)
        return data

    def readinto(self, b):
        # Fixed: the original was spelled ``readinfo`` and delegated to
        # a nonexistent ``self.f.readinfo`` (AttributeError).
        nbyte = self.f.readinto(b)
        self.position += nbyte
        return nbyte

    # Backward-compatible alias for the original misspelled name.
    readinfo = readinto

    def readline(self, size=-1):
        line = self.f.readline(size)
        self.position += len(line)
        return line

    def readlines(self, hint=-1):
        lines = self.f.readlines(hint)
        for line in lines:
            self.position += len(line)
        return lines

    def seek(self, offset, whence=0):
        raise RuntimeError("Can't use seek")

    def seekable(self):
        return False

    def tell(self):
        # NOTE(review): returns the underlying file's offset, not
        # ``self.position``; the wrapped stream may be unseekable --
        # confirm which one callers actually expect.
        return self.f.tell()

    def truncate(self, size=None):
        # Fixed typo: the original called ``self.f.trauncate(size)``.
        return self.f.truncate(size)

    def writable(self):
        return self.f.writable()

    def write(self, b):
        # Fixed: the original added the buffer object itself to
        # ``position`` (TypeError) and then recursed into ``self.write``.
        self.position += len(b)
        return self.f.write(b)

    def writelines(self, lines):
        for line in lines:
            self.position += len(line)
        return self.f.writelines(lines)

    def __del__(self):
        # Fixed: the original returned ``self.__del__()``, recursing
        # forever during garbage collection.  The wrapped file object
        # finalizes itself, so there is nothing to do here.
        pass
class LazyLoader(MutableMapping):
    """A mapping from key to ark name whose values are loaded on access.

    Don't use this class directly: ``__getitem__`` materializes the
    stored ark name through the ``loader`` callable on every lookup.
    """

    def __init__(self, loader):
        # key -> ark name; values are produced lazily by ``loader``.
        self._loader = loader
        self._dict = {}

    def __repr__(self):
        return "LazyLoader [{} keys]".format(len(self))

    def __getitem__(self, key):
        name = self._dict[key]
        try:
            return self._loader(name)
        except Exception:
            warnings.warn('An error happend when loading "{}"'.format(name))
            raise

    def __setitem__(self, key, value):
        self._dict[key] = value

    def __delitem__(self, key):
        self._dict.pop(key)

    def __iter__(self):
        return iter(self._dict)

    def __len__(self):
        return len(self._dict)

    def __contains__(self, item):
        return item in self._dict
def read_wav(fd, return_size=False):
    """Read a RIFF/WAVE stream and return ``(rate, samples)``.

    Args:
        fd: File path or binary file object containing WAVE data.
        return_size (bool): Also return the number of bytes the WAVE
            data occupies (canonical 44-byte header plus the payload).
    Returns:
        ``(rate, array)`` or ``((rate, array), size)``; ``array`` has
        shape ``(nframes,)`` for mono and ``(nframes, nchannels)``
        otherwise.
    Raises:
        ValueError: If the sample width is not 1, 2 or 4 bytes.
    """
    wd = wave.open(fd)
    rate = wd.getframerate()
    nchannels = wd.getnchannels()
    nbytes = wd.getsampwidth()
    if nbytes == 1:
        # 8bit-PCM is unsigned
        dtype = "uint8"
    elif nbytes == 2:
        dtype = "int16"
    elif nbytes == 4:
        # 32bit-PCM: previously rejected even though the error message
        # advertised support for 4-byte samples.
        dtype = "int32"
    else:
        # The original message claimed "1, 2, 4 or 8" while handling
        # only 1 and 2; keep the message consistent with reality.
        raise ValueError("bytes_per_sample must be 1, 2 or 4")
    data = wd.readframes(wd.getnframes())
    # NOTE(review): assumes a canonical 44-byte header; files with extra
    # chunks (e.g. LIST) would make this an underestimate -- confirm.
    size = 44 + len(data)
    array = np.frombuffer(data, dtype=np.dtype(dtype))
    if nchannels > 1:
        array = array.reshape(-1, nchannels)

    if return_size:
        return (rate, array), size
    else:
        return rate, array
array.tobytes() 52 | w.writeframes(data) 53 | 54 | return 44 + len(data) 55 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [aliases] 2 | test=pytest 3 | [tool:pytest] 4 | testpaths = tests 5 | addopts = --cov=kaldiio --cov-report html 6 | 7 | [flake8] 8 | ignore = W503,E203 9 | max-line-length = 88 10 | [pycodestyle] 11 | ignore = W503,E203 12 | max-line-length = 88 13 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import io 3 | import os.path 4 | from setuptools import setup 5 | 6 | setup( 7 | name="kaldiio", 8 | version="2.18.1", 9 | description="Kaldi-ark loading and writing module", 10 | author="nttcslab-sp", 11 | # author_email='', 12 | url="https://github.com/nttcslab-sp/kaldiio", 13 | long_description=io.open( 14 | os.path.join(os.path.dirname(__file__), "README.md"), "r", encoding="utf-8" 15 | ).read(), 16 | long_description_content_type="text/markdown", 17 | packages=["kaldiio"], 18 | install_requires=["numpy"], 19 | setup_requires=["pytest-runner"], 20 | tests_require=["pytest", "pytest-cov", "soundfile"], 21 | classifiers=[ 22 | "Development Status :: 5 - Production/Stable", 23 | "Intended Audience :: Science/Research", 24 | "Programming Language :: Python :: 3.8", 25 | "Programming Language :: Python :: 3.9", 26 | "Programming Language :: Python :: 3.10", 27 | "Programming Language :: Python :: 3.11", 28 | "Programming Language :: Python :: 3.12", 29 | "Programming Language :: Python :: 3.13", 30 | "Topic :: Multimedia :: Sound/Audio :: Analysis", 31 | ], 32 | ) 33 | 
-------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nttcslab-sp/kaldiio/7533933a533281c22046b53fd92995f87ad05cad/tests/__init__.py -------------------------------------------------------------------------------- /tests/arks/create_arks.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # python -c 'import kaldiio as k;import numpy as np;k.save_ark("test.ark", {"test" + str(i): np.random.randn(10, 20).astype(np.float32) for i in range(3)})' 4 | # type=CM 5 | copy-feats --compress=true --compression-method=1 ark:test.ark ark:test.cm1.ark 6 | # type=CM2 7 | copy-feats --compress=true --compression-method=3 ark:test.ark ark:test.cm3.ark 8 | # type=CM3 9 | copy-feats --compress=true --compression-method=5 ark:test.ark ark:test.cm5.ark 10 | copy-feats ark:test.ark ark,t:test.text.ark 11 | -------------------------------------------------------------------------------- /tests/arks/incorrect_header.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nttcslab-sp/kaldiio/7533933a533281c22046b53fd92995f87ad05cad/tests/arks/incorrect_header.wav -------------------------------------------------------------------------------- /tests/arks/test.ark: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nttcslab-sp/kaldiio/7533933a533281c22046b53fd92995f87ad05cad/tests/arks/test.ark -------------------------------------------------------------------------------- /tests/arks/test.cm1.ark: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nttcslab-sp/kaldiio/7533933a533281c22046b53fd92995f87ad05cad/tests/arks/test.cm1.ark 
-------------------------------------------------------------------------------- /tests/arks/test.cm3.ark: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nttcslab-sp/kaldiio/7533933a533281c22046b53fd92995f87ad05cad/tests/arks/test.cm3.ark -------------------------------------------------------------------------------- /tests/arks/test.cm5.ark: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nttcslab-sp/kaldiio/7533933a533281c22046b53fd92995f87ad05cad/tests/arks/test.cm5.ark -------------------------------------------------------------------------------- /tests/arks/test.text.ark: -------------------------------------------------------------------------------- 1 | test0 [ 2 | -0.9555095 1.454947 -1.008481 0.2451129 -0.7873046 -0.0189307 -1.231949 0.3024222 -0.9821944 -2.064091 -0.7519853 -0.7602366 1.534634 -0.3720366 0.9163448 -0.5454536 0.2661778 1.004537 -0.1362433 -0.4863206 3 | -0.1615189 -1.286037 -0.278406 1.383011 -0.4565755 -0.2526966 -0.3234434 -0.2018496 0.07208711 0.2978341 1.110052 1.160063 -0.1374429 0.3717514 0.2956631 0.2012479 1.881859 1.403757 -1.358232 0.9189199 4 | 0.436644 0.8273382 0.6248685 0.4088997 0.6348245 2.072881 0.3483515 0.5029373 0.2768012 1.213575 -1.868161 -0.06029448 1.138834 0.09501906 0.6609563 0.5600834 -1.449978 0.239671 -0.1063323 -0.09148861 5 | 1.230149 0.7318826 -0.02516201 0.8443054 0.04025598 1.208913 -0.2594405 -0.8716439 -1.62826 -0.1396292 -0.1044207 -1.550279 0.5556436 -0.008538663 -1.173256 -1.424557 2.202492 -0.5055786 -0.4521027 -1.164939 6 | -1.892946 -0.2540782 0.1985293 -0.6708191 -1.000573 0.07442529 1.083191 -0.1431073 0.44705 1.10043 -0.7878925 0.06547084 1.782503 -0.3680364 1.05319 -1.792711 -0.2661398 0.680892 3.015188 -0.2786612 7 | -0.9575458 -1.241638 0.6394869 -0.02755863 0.7843663 0.1754117 -0.6328281 1.110241 -2.619615 -0.3521794 0.2037585 -0.7685624 -0.06049781 
-0.7342759 -0.07526525 0.3623313 0.6251125 -1.701245 1.416456 0.8119849 8 | -0.0428966 -1.112612 -2.030228 0.4425827 0.8099385 -1.63304 -1.452492 -1.289263 -0.7333626 1.709762 0.4754612 1.028463 -0.1699338 1.639834 0.2369575 -0.6846613 -0.389258 -0.563786 -0.711178 0.9648252 9 | -1.284601 -1.447733 -1.907235 -0.08635759 0.4865483 1.054216 -1.376037 -0.04105163 0.8636565 1.489118 -0.1840287 -2.115296 -0.4525339 -0.03172754 2.250825 0.7784247 -0.6923981 0.714159 -0.02567131 -0.8188892 10 | 0.16787 -1.511532 -0.07795387 0.1117329 1.464449 3.132158 -0.3491844 -0.3571101 0.3014558 -0.8904001 0.8553936 0.6268699 0.4164898 1.413707 0.5557342 1.287184 -0.2259016 1.513568 -1.217759 -0.0979333 11 | 1.02295 -0.6095613 -0.7592701 -1.165664 0.8258702 -1.100448 -1.34527 0.2474493 0.2693125 0.4231107 -0.4276125 -0.0122413 -1.56883 -0.7808103 0.9110163 -0.2070457 -1.141785 0.1767025 -1.302306 -0.6661522 ] 12 | test1 [ 13 | -0.6844308 0.9551033 0.1185881 -0.3902329 -0.4759946 -0.300018 -0.5822846 0.1268353 -0.2234675 -0.7354357 1.00286 -0.7464286 2.196428 -1.179742 0.135118 0.4769018 -0.327175 0.7099938 0.5665276 0.6952059 14 | 0.7880214 -0.1726646 -0.3973365 0.3147721 -0.1853947 0.6145713 0.4840802 -0.4425978 -0.698534 0.7619419 0.6283662 -0.09393702 0.0005893837 -0.629052 -0.301087 1.508289 -0.4314308 -1.290576 -0.6170319 -0.1029666 15 | 0.01357115 -0.2854042 -1.641771 -0.1965098 0.5777639 -1.136725 -0.02763914 -0.7202342 0.158676 -2.51132 0.5286252 -1.848845 0.9032698 -0.3685609 -0.3868743 -1.406672 -0.04530309 0.4566102 -0.05803674 -0.5606434 16 | 0.5957915 2.08742 0.6181515 1.760105 -1.286737 0.1731578 0.3143925 0.04014806 0.6161535 2.186375 -0.2330954 1.616964 -1.152532 -0.7229942 -2.632406 0.005835837 1.281452 -1.161146 -0.7036497 1.447765 17 | -0.1019134 -0.2462025 1.631101 -0.7971479 0.0628774 1.593686 -0.3647632 -1.16713 -0.7721201 -0.4507362 -0.3335608 0.6188009 -0.7608762 -0.8286696 0.8494864 -0.5085956 -0.1838554 1.13931 -1.55555 -0.8727736 18 | 0.2715224 -0.09891139 
0.115285 0.343922 0.8591443 -0.6102904 -1.971848 -0.9055794 0.8599834 0.04912258 -0.08519455 -0.3412841 -1.247487 0.8330864 0.4989747 -0.878387 -0.7572019 -1.737416 0.2565427 0.2976327 19 | -0.09848356 1.193955 1.099459 -0.06504259 1.904736 0.9157236 1.281467 -0.7776058 -1.457422 -1.467779 0.7131976 -0.03832197 0.3191243 1.171605 1.839011 -1.506045 -2.226002 -0.2974362 -0.6475887 -1.523697 20 | 1.024535 -0.1319958 1.807397 1.403119 -0.6603892 -0.5215808 -1.969197 0.338027 -0.7207099 1.025891 -0.3526199 0.6416234 2.117505 0.833358 -0.3696733 2.189285 1.373135 0.2757411 -0.8004436 0.6600307 21 | -0.1197911 0.1083441 -1.002847 -0.1523209 0.1178646 -0.7255216 0.108596 -1.03938 0.03690559 0.1410743 -0.3745326 -0.8649802 0.6908674 0.05708767 0.7592875 -0.5275413 -0.2940069 -1.103511 0.7052273 -0.2835359 22 | -1.032998 -0.8557081 0.3231594 -0.7145321 -0.5193874 -0.7942839 0.3961963 -0.1388498 0.2117817 -0.7407426 -0.9254794 0.2423774 0.1330372 -1.828296 -0.1919031 -0.7574898 -1.537131 0.0829481 0.4538945 0.6518058 ] 23 | test2 [ 24 | -1.701658 -2.680414 1.175492 -0.1581941 -0.1296562 -1.927805 0.8222864 -0.9010202 -0.3237416 -0.7964234 1.379909 0.2361659 0.4044839 0.5384488 -0.588733 0.9855443 0.06156125 -0.6611108 -2.786002 0.1846388 25 | -1.078388 0.3899579 -0.739109 0.7121035 -0.2019595 -1.002578 -0.2465368 -0.5528726 0.1960258 1.06741 0.2354005 -1.629054 -0.6344082 0.7297142 1.077592 -0.8398266 0.9826697 0.2997488 0.3143 -1.942825 26 | -0.237495 -0.8185795 1.180302 -0.1031644 1.136648 -1.072904 0.3376451 -0.2180097 1.171386 0.5075268 1.168654 -1.278127 0.5058408 0.7390035 -0.3650108 -0.8843033 -1.320734 0.3514549 -0.9180931 -1.165465 27 | -1.004107 -0.8204701 1.262943 0.1553162 0.2463018 0.7926398 -0.3523607 0.2606941 -0.7753636 -0.08443997 -1.616649 0.634791 0.490848 1.647278 0.2318057 -1.190175 -1.065687 0.7044011 0.6867521 0.6456593 28 | -1.353486 -0.4680033 0.8019905 -2.387373 1.327594 0.2467267 -1.404506 0.5431932 -1.262602 -0.05173817 -1.574367 -1.738925 
0.434441 -2.093164 -0.5382612 1.363735 -1.982175 0.1070962 0.3393866 -1.499526 29 | -1.261056 -0.1887451 2.259133 0.1043354 1.471273 -0.4148357 1.286983 0.1118475 0.1243153 0.1026321 -0.9160644 1.208413 0.5685107 0.3896375 0.2734956 0.2725626 0.6029116 1.583123 0.1414515 0.9257417 30 | -3.863745 0.1899447 0.546055 -1.346787 1.632556 1.969914 -0.2030162 0.03311367 -0.365466 -0.3592365 0.7666726 1.123307 0.05261933 -0.7461058 -0.2659309 -1.877691 -1.009092 1.757111 0.7521642 0.4094087 31 | 0.9262108 0.9081749 -1.092547 -0.504918 0.2030315 -0.1855687 1.332459 0.2603446 -0.2249646 -0.9849409 -1.436265 1.232928 0.3634579 -0.6490468 -0.07655149 0.1481647 1.445654 -0.53654 -0.09138415 -1.729867 32 | 1.682236 0.8436629 0.2957647 0.9255861 -0.002097577 -0.2440814 0.4314352 -0.2439245 -1.310752 0.1134503 -0.5786554 0.8362756 0.3380536 -0.1914093 -0.1595648 1.097495 -0.8916295 -0.1022337 -0.7276955 1.207484 33 | 0.8573562 0.7800487 -0.8665044 0.5702928 0.7183536 2.113982 -0.2193797 -2.491414 0.1160529 -0.2336408 0.2698385 -0.299866 0.9956051 2.306329 -0.2409541 0.06064559 1.560999 -0.9510009 0.5474395 -0.907324 ] 34 | -------------------------------------------------------------------------------- /tests/test_extended_ark.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from kaldiio.matio import load_ark 5 | from kaldiio.matio import load_scp 6 | from kaldiio.matio import load_scp_sequential 7 | from kaldiio.matio import save_ark 8 | from kaldiio.utils import open_like_kaldi 9 | 10 | 11 | @pytest.mark.parametrize("dtype", [np.int16]) 12 | @pytest.mark.parametrize("func", [load_scp, load_scp_sequential]) 13 | @pytest.mark.parametrize("write_function", ["soundfile", "pickle", "numpy"]) 14 | def test_read_write(tmpdir, func, dtype, write_function): 15 | path = tmpdir.mkdir("test") 16 | ark = path.join("a.ark").strpath 17 | scp = path.join("a.scp").strpath 18 | 19 | # Write as pcm16 20 | array = 
np.random.randint(-1000, 1000, 100).astype(np.double) / float( 21 | abs(np.iinfo(np.int16).min) 22 | ) 23 | array2 = np.random.randint(-1000, 1000, 100).astype(np.double) / abs( 24 | np.iinfo(np.int16).min 25 | ) 26 | 27 | if write_function == "numpy": 28 | d = {"utt": array, "utt2": array2} 29 | else: 30 | d = {"utt": (8000, array), "utt2": (8000, array2)} 31 | save_ark(ark, d, scp=scp, write_function=write_function) 32 | 33 | d = dict(func(scp)) 34 | if write_function == "numpy": 35 | test = d["utt"] 36 | else: 37 | rate, test = d["utt"] 38 | assert rate == 8000 39 | np.testing.assert_allclose(array, test) 40 | 41 | if write_function == "numpy": 42 | test = d["utt2"] 43 | else: 44 | rate, test = d["utt2"] 45 | assert rate == 8000 46 | np.testing.assert_allclose(array2, test) 47 | 48 | d = dict(load_ark(ark)) 49 | if write_function == "numpy": 50 | test = d["utt"] 51 | else: 52 | rate, test = d["utt"] 53 | assert rate == 8000 54 | np.testing.assert_allclose(array, test) 55 | 56 | if write_function == "numpy": 57 | test = d["utt2"] 58 | else: 59 | rate, test = d["utt2"] 60 | assert rate == 8000 61 | np.testing.assert_allclose(array2, test) 62 | 63 | 64 | @pytest.mark.parametrize("dtype", [np.int16]) 65 | @pytest.mark.parametrize("write_function", ["soundfile", "pickle", "numpy"]) 66 | def test_wavark_stream(tmpdir, dtype, write_function): 67 | path = tmpdir.mkdir("test") 68 | ark = path.join("a.ark").strpath 69 | 70 | # Write as pcm16 71 | array = np.random.randint(-1000, 1000, 100).astype(np.double) / abs( 72 | np.iinfo(np.int16).min 73 | ) 74 | array2 = np.random.randint(-1000, 1000, 100).astype(np.double) / abs( 75 | np.iinfo(np.int16).min 76 | ) 77 | if write_function == "numpy": 78 | d = {"utt": array, "utt2": array2} 79 | else: 80 | d = {"utt": (8000, array), "utt2": (8000, array2)} 81 | save_ark(ark, d, write_function=write_function) 82 | 83 | with open_like_kaldi("cat {}|".format(ark), "rb") as f: 84 | d = dict(load_ark(f)) 85 | if write_function == 
"numpy": 86 | test = d["utt"] 87 | else: 88 | rate, test = d["utt"] 89 | assert rate == 8000 90 | np.testing.assert_allclose(array, test) 91 | 92 | if write_function == "numpy": 93 | test = d["utt2"] 94 | else: 95 | rate, test = d["utt2"] 96 | assert rate == 8000 97 | np.testing.assert_allclose(array2, test) 98 | -------------------------------------------------------------------------------- /tests/test_highlevel.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | 3 | from kaldiio.highlevel import ReadHelper 4 | from kaldiio.highlevel import WriteHelper 5 | from kaldiio.matio import load_ark 6 | from kaldiio.matio import load_scp 7 | from kaldiio.matio import save_ark 8 | from kaldiio.wavio import write_wav 9 | 10 | 11 | def test_read_helper(tmpdir): 12 | path = tmpdir.strpath 13 | array_in = numpy.random.randn(10, 10) 14 | save_ark( 15 | "{}/feats.ark".format(path), {"foo": array_in}, scp="{}/feats.scp".format(path) 16 | ) 17 | helper = ReadHelper("ark:cat {}/feats.ark |".format(path)) 18 | for uttid, array_out in helper: 19 | assert uttid == "foo" 20 | numpy.testing.assert_array_equal(array_in, array_out) 21 | 22 | helper = ReadHelper("scp:{}/feats.scp".format(path)) 23 | for uttid, array_out in helper: 24 | assert uttid == "foo" 25 | numpy.testing.assert_array_equal(array_in, array_out) 26 | 27 | 28 | def test_read_helper_ascii(tmpdir): 29 | path = tmpdir.strpath 30 | array_in = numpy.random.randn(10, 10) 31 | save_ark( 32 | "{}/feats.ark".format(path), 33 | {"foo": array_in}, 34 | scp="{}/feats.scp".format(path), 35 | text=True, 36 | ) 37 | helper = ReadHelper("ark:cat {}/feats.ark |".format(path)) 38 | for uttid, array_out in helper: 39 | assert uttid == "foo" 40 | numpy.testing.assert_allclose(array_in, array_out) 41 | 42 | helper = ReadHelper("ark:{}/feats.ark".format(path)) 43 | for uttid, array_out in helper: 44 | assert uttid == "foo" 45 | numpy.testing.assert_allclose(array_in, array_out) 46 | 47 | 
48 | def test_write_helper(tmpdir): 49 | path = tmpdir.strpath 50 | d = {"foo": numpy.random.randn(10, 10), "bar": numpy.random.randn(10, 10)} 51 | 52 | with WriteHelper("ark,f,scp:{p}/out.ark,{p}/out.scp".format(p=path)) as w: 53 | for k, v in d.items(): 54 | w(k, v) 55 | from_ark = dict(load_ark("{p}/out.ark".format(p=path))) 56 | from_scp = load_scp("{p}/out.scp".format(p=path)) 57 | _compare(from_ark, d) 58 | _compare(from_scp, d) 59 | 60 | 61 | def test_write_helper_scp_ark(tmpdir): 62 | path = tmpdir.strpath 63 | d = {"foo": numpy.random.randn(10, 10), "bar": numpy.random.randn(10, 10)} 64 | 65 | with WriteHelper("scp,f,ark:{p}/out.scp,{p}/out.ark".format(p=path)) as w: 66 | for k, v in d.items(): 67 | w(k, v) 68 | from_ark = dict(load_ark("{p}/out.ark".format(p=path))) 69 | from_scp = load_scp("{p}/out.scp".format(p=path)) 70 | _compare(from_ark, d) 71 | _compare(from_scp, d) 72 | 73 | 74 | def test_write_helper_ascii(tmpdir): 75 | path = tmpdir.strpath 76 | d = {"foo": numpy.random.randn(10, 10), "bar": numpy.random.randn(10, 10)} 77 | 78 | with WriteHelper("ark,t,f,scp:{p}/out.ark,{p}/out.scp".format(p=path)) as w: 79 | for k, v in d.items(): 80 | w(k, v) 81 | from_ark = dict(load_ark("{p}/out.ark".format(p=path))) 82 | from_scp = load_scp("{p}/out.scp".format(p=path)) 83 | _compare_allclose(from_ark, d) 84 | _compare_allclose(from_scp, d) 85 | 86 | 87 | def test_scpwav_stream(tmpdir): 88 | path = tmpdir.mkdir("test") 89 | wav = path.join("aaa.wav").strpath 90 | wav2 = path.join("bbb.wav").strpath 91 | scp = path.join("wav.scp").strpath 92 | 93 | # Write as pcm16 94 | array = numpy.random.randint(0, 10, 10, dtype=numpy.int16) 95 | write_wav(wav, 8000, array) 96 | 97 | array2 = numpy.random.randint(0, 10, 10, dtype=numpy.int16) 98 | write_wav(wav2, 8000, array2) 99 | 100 | valid = {"aaa": array, "bbb": array2} 101 | 102 | with open(scp, "w") as f: 103 | f.write("aaa cat {wav} |\n".format(wav=wav)) 104 | f.write("bbb cat {wav} |\n".format(wav=wav2)) 105 | 
106 | with ReadHelper("scp:{}".format(scp)) as r: 107 | for k, (rate, array) in r: 108 | assert rate == 8000 109 | numpy.testing.assert_array_equal(array, valid[k]) 110 | 111 | 112 | def test_segments(tmpdir): 113 | # Create wav.scp 114 | path = tmpdir.mkdir("test") 115 | wavscp = path.join("wav.scp").strpath 116 | 117 | rate = 500 118 | with open(wavscp, "w") as f: 119 | wav = path.join("0.wav").strpath 120 | array0 = numpy.random.randint(0, 10, 2000, dtype=numpy.int16) 121 | write_wav(wav, rate, array0) 122 | f.write("wav0 {}\n".format(wav)) 123 | 124 | wav = path.join("1.wav").strpath 125 | array1 = numpy.random.randint(0, 10, 2000, dtype=numpy.int16) 126 | write_wav(wav, rate, array1) 127 | f.write("wav1 {}\n".format(wav)) 128 | 129 | # Create segments 130 | segments = path.join("segments").strpath 131 | with open(segments, "w") as f: 132 | f.write("utt1 wav0 0.1 0.2\n") 133 | f.write("utt2 wav0 0.4 0.6\n") 134 | f.write("utt3 wav1 0.4 0.5\n") 135 | f.write("utt4 wav1 0.6 0.8\n") 136 | 137 | with ReadHelper("scp:{}".format(wavscp), segments=segments) as r: 138 | d = {k: a for k, a in r} 139 | 140 | numpy.testing.assert_array_equal( 141 | d["utt1"][1], array0[int(0.1 * rate) : int(0.2 * rate)] 142 | ) 143 | numpy.testing.assert_array_equal( 144 | d["utt2"][1], array0[int(0.4 * rate) : int(0.6 * rate)] 145 | ) 146 | numpy.testing.assert_array_equal( 147 | d["utt3"][1], array1[int(0.4 * rate) : int(0.5 * rate)] 148 | ) 149 | numpy.testing.assert_array_equal( 150 | d["utt4"][1], array1[int(0.6 * rate) : int(0.8 * rate)] 151 | ) 152 | 153 | 154 | def _compare(d1, d2): 155 | assert len(d1) != 0 156 | assert set(d1.keys()) == set(d2.keys()) 157 | for key in d1: 158 | numpy.testing.assert_array_equal(d1[key], d2[key]) 159 | 160 | 161 | def _compare_allclose(d1, d2, rtol=1e-07, atol=0.0): 162 | assert len(d1) != 0 163 | assert set(d1.keys()) == set(d2.keys()) 164 | for key in d1: 165 | numpy.testing.assert_allclose(d1[key], d2[key], rtol, atol) 166 | 
-------------------------------------------------------------------------------- /tests/test_limited_size_dict.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | from __future__ import unicode_literals 3 | 4 | from kaldiio.utils import LimitedSizeDict 5 | 6 | 7 | def test_limted_size_dict(): 8 | d = LimitedSizeDict(3) 9 | d["foo"] = 1 10 | d["bar"] = 2 11 | d["baz"] = 3 12 | 13 | assert "foo" in d 14 | assert "bar" in d 15 | assert "baz" in d 16 | 17 | d["foo2"] = 4 18 | assert "foo" not in d 19 | assert "foo2" in d 20 | 21 | d["bar2"] = 4 22 | assert "bar" not in d 23 | assert "bar2" in d 24 | -------------------------------------------------------------------------------- /tests/test_mat_ark.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | from __future__ import unicode_literals 3 | 4 | import glob 5 | import io 6 | import os 7 | 8 | import numpy as np 9 | import pytest 10 | 11 | import kaldiio 12 | from kaldiio.matio import _parse_arkpath 13 | 14 | arkdir = os.path.join(os.path.dirname(__file__), "arks") 15 | 16 | 17 | @pytest.mark.parametrize("fname", glob.glob(os.path.join(arkdir, "*.ark"))) 18 | def test_read_arks(fname): 19 | # Assume arks dir existing at the same directory 20 | ark0 = dict( 21 | kaldiio.load_ark(os.path.join(os.path.dirname(__file__), "arks", "test.ark")) 22 | ) 23 | ark = dict(kaldiio.load_ark(fname)) 24 | _compare_allclose(ark, ark0, atol=1e-1) 25 | 26 | 27 | @pytest.mark.parametrize( 28 | "shape1,shape2", [[(1000, 120), (10, 120)], [(0, 0), (0, 120)], [(100,), (120,)]] 29 | ) 30 | @pytest.mark.parametrize("endian", ["<", ">"]) 31 | @pytest.mark.parametrize("dtype", [np.float32, np.float64]) 32 | @pytest.mark.parametrize("max_cache_fd", [0, 3]) 33 | def test_write_read(tmpdir, shape1, shape2, endian, dtype, max_cache_fd): 34 | path = tmpdir.mkdir("test") 35 | 36 | a = np.random.rand(*shape1).astype(dtype) 37 | b = 
np.random.rand(*shape2).astype(dtype) 38 | origin = {"Ï,é,à": a, "あいうえお": b} 39 | kaldiio.save_ark( 40 | path.join("a.ark").strpath, 41 | origin, 42 | scp=path.join("b.scp").strpath, 43 | endian=endian, 44 | ) 45 | 46 | d2 = {k: v for k, v in kaldiio.load_ark(path.join("a.ark").strpath, endian=endian)} 47 | d5 = { 48 | k: v 49 | for k, v in kaldiio.load_scp( 50 | path.join("b.scp").strpath, endian=endian, max_cache_fd=max_cache_fd 51 | ).items() 52 | } 53 | with io.open(path.join("a.ark").strpath, "rb") as fd: 54 | d6 = {k: v for k, v in kaldiio.load_ark(fd, endian=endian)} 55 | _compare(d2, origin) 56 | _compare(d5, origin) 57 | _compare(d6, origin) 58 | 59 | 60 | @pytest.mark.parametrize("endian", ["<", ">"]) 61 | @pytest.mark.parametrize("dtype", [np.float32, np.float64]) 62 | def test_write_read_multiark(tmpdir, endian, dtype): 63 | path = tmpdir.mkdir("test") 64 | 65 | a = np.random.rand(1000, 120).astype(dtype) 66 | b = np.random.rand(10, 120).astype(dtype) 67 | origin = {"Ï,é,à": a, "あいうえお": b} 68 | 69 | kaldiio.save_ark( 70 | path.join("a.ark").strpath, 71 | origin, 72 | scp=path.join("b.scp").strpath, 73 | endian=endian, 74 | ) 75 | 76 | c = np.random.rand(1000, 120).astype(dtype) 77 | d = np.random.rand(10, 120).astype(dtype) 78 | origin.update({"c": c, "d": d}) 79 | with io.open(path.join("b.scp").strpath, "a", encoding="utf-8") as f: 80 | kaldiio.save_ark(path.join("b.ark").strpath, origin, scp=f, endian=endian) 81 | 82 | d5 = { 83 | k: v 84 | for k, v in kaldiio.load_scp(path.join("b.scp").strpath, endian=endian).items() 85 | } 86 | _compare(d5, origin) 87 | 88 | 89 | @pytest.mark.parametrize("endian", ["<", ">"]) 90 | def test_write_read_sequential(tmpdir, endian): 91 | path = tmpdir.mkdir("test") 92 | 93 | a = np.random.rand(1000, 120).astype(np.float32) 94 | b = np.random.rand(10, 120).astype(np.float32) 95 | origin = {"Ï,é,à": a, "あいうえお": b} 96 | kaldiio.save_ark( 97 | path.join("a.ark").strpath, 98 | origin, 99 | scp=path.join("b.scp").strpath, 
100 | endian=endian, 101 | ) 102 | 103 | d5 = { 104 | k: v 105 | for k, v in kaldiio.load_scp_sequential( 106 | path.join("b.scp").strpath, endian=endian 107 | ) 108 | } 109 | _compare(d5, origin) 110 | 111 | 112 | @pytest.mark.parametrize("endian", ["<", ">"]) 113 | def test_write_read_multiark_sequential(tmpdir, endian): 114 | path = tmpdir.mkdir("test") 115 | 116 | a = np.random.rand(1000, 120).astype(np.float32) 117 | b = np.random.rand(10, 120).astype(np.float32) 118 | origin = {"Ï,é,à": a, "あいうえお": b} 119 | 120 | kaldiio.save_ark( 121 | path.join("a.ark").strpath, 122 | origin, 123 | scp=path.join("b.scp").strpath, 124 | endian=endian, 125 | ) 126 | 127 | c = np.random.rand(1000, 120).astype(np.float32) 128 | d = np.random.rand(10, 120).astype(np.float32) 129 | origin.update({"c": c, "d": d}) 130 | with io.open(path.join("b.scp").strpath, "a", encoding="utf-8") as f: 131 | kaldiio.save_ark(path.join("b.ark").strpath, origin, scp=f, endian=endian) 132 | 133 | d5 = { 134 | k: v 135 | for k, v in kaldiio.load_scp_sequential( 136 | path.join("b.scp").strpath, endian=endian 137 | ) 138 | } 139 | _compare(d5, origin) 140 | 141 | 142 | def test_write_read_ascii(tmpdir): 143 | path = tmpdir.mkdir("test") 144 | a = np.random.rand(10, 10).astype(np.float32) 145 | b = np.random.rand(5, 35).astype(np.float32) 146 | origin = {"Ï,é,à": a, "あいうえお": b} 147 | kaldiio.save_ark( 148 | path.join("a.ark").strpath, origin, scp=path.join("a.scp").strpath, text=True 149 | ) 150 | d2 = {k: v for k, v in kaldiio.load_ark(path.join("a.ark").strpath)} 151 | d5 = {k: v for k, v in kaldiio.load_scp(path.join("a.scp").strpath).items()} 152 | _compare_allclose(d2, origin) 153 | _compare_allclose(d5, origin) 154 | 155 | 156 | @pytest.mark.parametrize("endian", ["<", ">"]) 157 | def test_write_read_int32_vector(tmpdir, endian): 158 | path = tmpdir.mkdir("test") 159 | 160 | a = np.random.randint(1, 128, 10, dtype=np.int32) 161 | b = np.random.randint(1, 128, 10, dtype=np.int32) 162 | origin = 
{"Ï,é,à": a, "あいうえお": b} 163 | kaldiio.save_ark( 164 | path.join("a.ark").strpath, 165 | origin, 166 | scp=path.join("b.scp").strpath, 167 | endian=endian, 168 | ) 169 | 170 | d2 = {k: v for k, v in kaldiio.load_ark(path.join("a.ark").strpath, endian=endian)} 171 | d5 = { 172 | k: v 173 | for k, v in kaldiio.load_scp(path.join("b.scp").strpath, endian=endian).items() 174 | } 175 | with io.open(path.join("a.ark").strpath, "rb") as fd: 176 | d6 = {k: v for k, v in kaldiio.load_ark(fd, endian=endian)} 177 | _compare(d2, origin) 178 | _compare(d5, origin) 179 | _compare(d6, origin) 180 | 181 | 182 | def test_write_read_int32_vector_ascii(tmpdir): 183 | path = tmpdir.mkdir("test") 184 | 185 | a = np.random.randint(1, 128, 10, dtype=np.int32) 186 | b = np.random.randint(1, 128, 10, dtype=np.int32) 187 | origin = {"Ï,é,à": a, "あいうえお": b} 188 | kaldiio.save_ark( 189 | path.join("a.ark").strpath, origin, scp=path.join("b.scp").strpath, text=True 190 | ) 191 | 192 | d2 = {k: v for k, v in kaldiio.load_ark(path.join("a.ark").strpath)} 193 | d5 = {k: v for k, v in kaldiio.load_scp(path.join("b.scp").strpath).items()} 194 | with io.open(path.join("a.ark").strpath, "rb") as fd: 195 | d6 = {k: v for k, v in kaldiio.load_ark(fd)} 196 | _compare_allclose(d2, origin) 197 | _compare_allclose(d5, origin) 198 | _compare_allclose(d6, origin) 199 | 200 | 201 | @pytest.mark.parametrize("compression_method", [1, 3, 5]) 202 | def test_write_compressed_arks(tmpdir, compression_method): 203 | # Assume arks dir existing at the same directory 204 | ark0 = dict( 205 | kaldiio.load_ark(os.path.join(os.path.dirname(__file__), "arks", "test.ark")) 206 | ) 207 | path = tmpdir.mkdir("test").join("c.ark").strpath 208 | kaldiio.save_ark(path, ark0, compression_method=compression_method) 209 | arkc = dict(kaldiio.load_ark(path)) 210 | arkc_valid = dict( 211 | kaldiio.load_ark( 212 | os.path.join( 213 | os.path.dirname(__file__), 214 | "arks", 215 | "test.cm{}.ark".format(compression_method), 216 | ) 217 
| ) 218 | ) 219 | _compare_allclose(arkc, arkc_valid, atol=1e-4) 220 | 221 | 222 | @pytest.mark.parametrize("endian", ["<", ">"]) 223 | @pytest.mark.parametrize("compression_method", [2, 3, 7]) 224 | def test_write_read_compress(tmpdir, compression_method, endian): 225 | path = tmpdir.mkdir("test") 226 | 227 | a = np.random.rand(1000, 120).astype(np.float32) 228 | b = np.random.rand(10, 120).astype(np.float32) 229 | origin = {"Ï,é,à": a, "あいうえお": b} 230 | kaldiio.save_ark( 231 | path.join("a.ark").strpath, 232 | origin, 233 | scp=path.join("b.scp").strpath, 234 | compression_method=compression_method, 235 | endian=endian, 236 | ) 237 | 238 | d2 = {k: v for k, v in kaldiio.load_ark(path.join("a.ark").strpath, endian=endian)} 239 | d5 = { 240 | k: v 241 | for k, v in kaldiio.load_scp(path.join("b.scp").strpath, endian=endian).items() 242 | } 243 | with io.open(path.join("a.ark").strpath, "rb") as fd: 244 | d6 = {k: v for k, v in kaldiio.load_ark(fd, endian=endian)} 245 | _compare_allclose(d2, origin, atol=1e-1) 246 | _compare_allclose(d5, origin, atol=1e-1) 247 | _compare_allclose(d6, origin, atol=1e-1) 248 | 249 | 250 | def test_append_mode(tmpdir): 251 | path = tmpdir.mkdir("test") 252 | 253 | a = np.random.rand(1000, 120).astype(np.float32) 254 | b = np.random.rand(10, 120).astype(np.float32) 255 | origin = {"Ï,é,à": a, "あいうえお": b} 256 | kaldiio.save_ark(path.join("a.ark").strpath, origin, scp=path.join("b.scp").strpath) 257 | 258 | kaldiio.save_ark( 259 | path.join("a2.ark").strpath, 260 | {"Ï,é,à": a}, 261 | scp=path.join("b2.scp").strpath, 262 | append=True, 263 | ) 264 | kaldiio.save_ark( 265 | path.join("a2.ark").strpath, 266 | {"あいうえお": b}, 267 | scp=path.join("b2.scp").strpath, 268 | append=True, 269 | ) 270 | d1 = {k: v for k, v in kaldiio.load_ark(path.join("a.ark").strpath)} 271 | d2 = {k: v for k, v in kaldiio.load_scp(path.join("b.scp").strpath).items()} 272 | d3 = {k: v for k, v in kaldiio.load_ark(path.join("a2.ark").strpath)} 273 | d4 = {k: v for k, 
v in kaldiio.load_scp(path.join("b2.scp").strpath).items()} 274 | _compare(d1, origin) 275 | _compare(d2, origin) 276 | _compare(d3, origin) 277 | _compare(d4, origin) 278 | 279 | 280 | @pytest.mark.parametrize("endian", ["<", ">"]) 281 | @pytest.mark.parametrize("dtype", [np.float32, np.float64]) 282 | def test_write_read_mat(tmpdir, endian, dtype): 283 | path = tmpdir.mkdir("test") 284 | valid = np.random.rand(1000, 120).astype(dtype) 285 | kaldiio.save_mat(path.join("a.mat").strpath, valid, endian=endian) 286 | test = kaldiio.load_mat(path.join("a.mat").strpath, endian=endian) 287 | np.testing.assert_array_equal(test, valid) 288 | 289 | 290 | def test__parse_arkpath(): 291 | assert _parse_arkpath("a.ark") == ("a.ark", None, None) 292 | assert _parse_arkpath("a.ark:12") == ("a.ark", 12, None) 293 | assert _parse_arkpath("a.ark:12[4]") == ("a.ark", 12, (slice(4, 5, None),)) 294 | assert _parse_arkpath("a.ark:12[3:4]") == ("a.ark", 12, (slice(3, 5, None),)) 295 | assert _parse_arkpath("a.ark:12[3:10:2]") == ("a.ark", 12, (slice(3, 11, 2),)) 296 | assert _parse_arkpath("a.ark:12[2:6,3:4]") == ( 297 | "a.ark", 298 | 12, 299 | (slice(2, 7), slice(3, 5, None)), 300 | ) 301 | assert _parse_arkpath('cat "fo:o.ark" |') == ('cat "fo:o.ark" |', None, None) 302 | 303 | 304 | def _compare(d1, d2): 305 | assert len(d1) != 0 306 | assert set(d1.keys()) == set(d2.keys()) 307 | for key in d1: 308 | np.testing.assert_array_equal(d1[key], d2[key]) 309 | 310 | 311 | def _compare_allclose(d1, d2, rtol=1e-07, atol=0.0): 312 | assert len(d1) != 0 313 | assert set(d1.keys()) == set(d2.keys()) 314 | for key in d1: 315 | np.testing.assert_allclose(d1[key], d2[key], rtol, atol) 316 | 317 | 318 | if __name__ == "__main__": 319 | pytest.main() 320 | -------------------------------------------------------------------------------- /tests/test_multi_file_descriptor.py: -------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals 2 
| 3 | from io import StringIO 4 | 5 | import pytest 6 | 7 | from kaldiio.utils import MultiFileDescriptor 8 | 9 | 10 | def test_read(): 11 | fd = StringIO("abc") 12 | fd2 = StringIO("xdef") 13 | fd2.read(1) 14 | fd3 = StringIO("ghi") 15 | mfd = MultiFileDescriptor(fd, fd2, fd3) 16 | 17 | assert mfd.read(3) == "abc" 18 | assert mfd.read(6) == "defghi" 19 | 20 | 21 | @pytest.mark.parametrize("offset", [3, 4, 5, 6, 7, 8]) 22 | @pytest.mark.parametrize("offset2", [3, 4, 5, 6]) 23 | def test_read_tell(offset, offset2): 24 | fd = StringIO("abc") 25 | fd2 = StringIO("xdef") 26 | fd2.read(1) 27 | fd3 = StringIO("ghi") 28 | mfd = MultiFileDescriptor(fd, fd2, fd3) 29 | 30 | mfd.read(offset) 31 | assert mfd.tell() == min(offset, 9) 32 | mfd.read(offset2) 33 | assert mfd.tell() == min(offset + offset2, 9) 34 | 35 | 36 | def test_seek(): 37 | fd = StringIO("abc") 38 | fd2 = StringIO("xdef") 39 | fd2.read(1) 40 | fd3 = StringIO("ghi") 41 | mfd = MultiFileDescriptor(fd, fd2, fd3) 42 | 43 | assert mfd.read() == "abcdefghi" 44 | mfd.seek(0) 45 | assert mfd.read() == "abcdefghi" 46 | -------------------------------------------------------------------------------- /tests/test_open_like_kaldi.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | from __future__ import unicode_literals 3 | 4 | import io 5 | import sys 6 | 7 | from kaldiio.utils import open_like_kaldi 8 | 9 | PY3 = sys.version_info[0] == 3 10 | 11 | 12 | def test_open_like_kaldi(tmpdir): 13 | with open_like_kaldi("echo あああ |", "r") as f: 14 | if PY3: 15 | assert f.read() == "あああ\n" 16 | else: 17 | assert f.read().decode("utf-8") == "あああ\n" 18 | txt = tmpdir.mkdir("test").join("out.txt").strpath 19 | with open_like_kaldi("| cat > {}".format(txt), "w") as f: 20 | if PY3: 21 | f.write("あああ") 22 | else: 23 | f.write("あああ".encode("utf-8")) 24 | with io.open(txt, "r", encoding="utf-8") as f: 25 | assert f.read() == "あああ" 26 | 
-------------------------------------------------------------------------------- /tests/test_parse_specifier.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from kaldiio import parse_specifier 4 | 5 | 6 | def test_ark(): 7 | d = parse_specifier("ark:file.ark") 8 | assert d["ark"] == "file.ark" 9 | 10 | 11 | def test_scp(): 12 | d = parse_specifier("scp:file.scp") 13 | assert d["scp"] == "file.scp" 14 | 15 | 16 | def test_ark_scp(): 17 | d = parse_specifier("ark,scp:file.ark,file.scp") 18 | assert d["ark"] == "file.ark" 19 | assert d["scp"] == "file.scp" 20 | 21 | 22 | def test_scp_ark(): 23 | d = parse_specifier("scp,ark:file.scp,file.ark") 24 | assert d["ark"] == "file.ark" 25 | assert d["scp"] == "file.scp" 26 | 27 | 28 | def test_error1(): 29 | with pytest.raises(ValueError): 30 | parse_specifier("ffdafafaf") 31 | 32 | 33 | def test_error2(): 34 | with pytest.raises(ValueError): 35 | parse_specifier("ak,sp:file.ark,file.scp") 36 | 37 | 38 | def test_error3(): 39 | with pytest.raises(ValueError): 40 | parse_specifier("ark,scp:file.ark") 41 | 42 | 43 | def test_error4(): 44 | with pytest.raises(ValueError): 45 | parse_specifier("ark,ark:file.ark,file.ark") 46 | -------------------------------------------------------------------------------- /tests/test_wav.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | from kaldiio.matio import load_ark 7 | from kaldiio.matio import load_scp 8 | from kaldiio.matio import load_scp_sequential 9 | from kaldiio.matio import save_ark 10 | from kaldiio.utils import open_like_kaldi 11 | from kaldiio.wavio import read_wav 12 | from kaldiio.wavio import write_wav 13 | 14 | 15 | @pytest.mark.parametrize("dtype", [np.uint8, np.int16]) 16 | @pytest.mark.parametrize("func", [read_wav]) 17 | def test_read_wav(tmpdir, func, dtype): 18 | path = tmpdir.mkdir("test") 19 | wav = 
path.join("a.wav").strpath 20 | # Write as pcm16 21 | array = np.random.randint(0, 10, 10, dtype=dtype) 22 | write_wav(wav, 8000, array) 23 | with open(wav, "rb") as f: 24 | rate, array2 = func(f) 25 | np.testing.assert_array_equal(array, array2) 26 | 27 | 28 | @pytest.mark.parametrize("dtype", [np.uint8, np.int16]) 29 | @pytest.mark.parametrize("func", [load_scp, load_scp_sequential]) 30 | def test_load_wav_scp(tmpdir, func, dtype): 31 | path = tmpdir.mkdir("test") 32 | wav = path.join("a.wav").strpath 33 | scp = path.join("wav.scp").strpath 34 | 35 | # Write as pcm16 36 | array = np.random.randint(0, 10, 10, dtype=dtype) 37 | write_wav(wav, 8000, array) 38 | with open(scp, "w") as f: 39 | f.write("aaa {wav}\n".format(wav=wav)) 40 | rate, array2 = list(dict(func(scp)).values())[0] 41 | np.testing.assert_array_equal(array, array2) 42 | 43 | 44 | @pytest.mark.parametrize("dtype", [np.uint8, np.int16]) 45 | @pytest.mark.parametrize("func", [load_scp, load_scp_sequential]) 46 | def test_read_write_wav(tmpdir, func, dtype): 47 | path = tmpdir.mkdir("test") 48 | ark = path.join("a.ark").strpath 49 | scp = path.join("a.scp").strpath 50 | 51 | # Write as pcm16 52 | array = np.random.randint(0, 10, 10, dtype=dtype) 53 | array2 = np.random.randint(0, 10, 10, dtype=dtype) 54 | d = {"utt": (8000, array), "utt2": (8000, array2)} 55 | save_ark(ark, d, scp=scp) 56 | 57 | d = dict(func(scp)) 58 | rate, test = d["utt"] 59 | assert rate == 8000 60 | np.testing.assert_array_equal(array, test) 61 | 62 | rate, test = d["utt2"] 63 | assert rate == 8000 64 | np.testing.assert_array_equal(array2, test) 65 | 66 | d = dict(load_ark(ark)) 67 | rate, test = d["utt"] 68 | assert rate == 8000 69 | np.testing.assert_array_equal(array, test) 70 | 71 | rate, test = d["utt2"] 72 | assert rate == 8000 73 | np.testing.assert_array_equal(array2, test) 74 | 75 | 76 | @pytest.mark.parametrize("dtype", [np.uint8, np.int16]) 77 | @pytest.mark.parametrize("func", [load_scp, load_scp_sequential]) 78 | def 
test_scpwav_stream(tmpdir, func, dtype): 79 | path = tmpdir.mkdir("test") 80 | wav = path.join("aaa.wav").strpath 81 | wav2 = path.join("bbb.wav").strpath 82 | scp = path.join("wav.scp").strpath 83 | 84 | # Write as pcm16 85 | array = np.random.randint(0, 10, 40, dtype=dtype).reshape(5, 8) 86 | write_wav(wav, 8000, array) 87 | 88 | array2 = np.random.randint(0, 10, 10, dtype=dtype) 89 | write_wav(wav2, 8000, array2) 90 | 91 | with open(scp, "w") as f: 92 | f.write("aaa sox {wav} -t wav - |\n".format(wav=wav)) 93 | f.write("bbb cat {wav} |\n".format(wav=wav2)) 94 | rate, test = dict(func(scp))["aaa"] 95 | rate, test2 = dict(func(scp))["bbb"] 96 | np.testing.assert_array_equal(array, test) 97 | np.testing.assert_array_equal(array2, test2) 98 | 99 | 100 | @pytest.mark.parametrize("dtype", [np.uint8, np.int16]) 101 | def test_wavark_stream(tmpdir, dtype): 102 | path = tmpdir.mkdir("test") 103 | ark = path.join("a.ark").strpath 104 | 105 | # Write as pcm16 106 | array = np.random.randint(0, 10, 10, dtype=dtype) 107 | array2 = np.random.randint(0, 10, 10, dtype=dtype) 108 | d = {"utt": (8000, array), "utt2": (8000, array2)} 109 | save_ark(ark, d) 110 | 111 | with open_like_kaldi("cat {}|".format(ark), "rb") as f: 112 | d = dict(load_ark(f)) 113 | rate, test = d["utt"] 114 | assert rate == 8000 115 | np.testing.assert_array_equal(array, test) 116 | 117 | rate, test = d["utt2"] 118 | assert rate == 8000 119 | np.testing.assert_array_equal(array2, test) 120 | 121 | 122 | @pytest.mark.parametrize("dtype", [np.uint8, np.int16]) 123 | @pytest.mark.parametrize("func", [load_scp, load_scp_sequential]) 124 | def test_segments(tmpdir, func, dtype): 125 | # Create wav.scp 126 | path = tmpdir.mkdir("test") 127 | wavscp = path.join("wav.scp").strpath 128 | 129 | rate = 500 130 | with open(wavscp, "w") as f: 131 | wav = path.join("0.wav").strpath 132 | array0 = np.random.randint(0, 10, 2000, dtype=dtype) 133 | write_wav(wav, rate, array0) 134 | f.write("wav0 {}\n".format(wav)) 135 | 
136 | wav = path.join("1.wav").strpath 137 | array1 = np.random.randint(0, 10, 2000, dtype=dtype) 138 | write_wav(wav, rate, array1) 139 | f.write("wav1 {}\n".format(wav)) 140 | 141 | # Create segments 142 | segments = path.join("segments").strpath 143 | with open(segments, "w") as f: 144 | f.write("utt1 wav0 0.1 0.2\n") 145 | f.write("utt2 wav0 0.4 0.6\n") 146 | f.write("utt3 wav1 0.4 0.5\n") 147 | f.write("utt4 wav1 0.6 0.8\n") 148 | d = dict(func(wavscp, segments=segments)) 149 | 150 | np.testing.assert_array_equal( 151 | d["utt1"][1], array0[int(0.1 * rate) : int(0.2 * rate)] 152 | ) 153 | np.testing.assert_array_equal( 154 | d["utt2"][1], array0[int(0.4 * rate) : int(0.6 * rate)] 155 | ) 156 | np.testing.assert_array_equal( 157 | d["utt3"][1], array1[int(0.4 * rate) : int(0.5 * rate)] 158 | ) 159 | np.testing.assert_array_equal( 160 | d["utt4"][1], array1[int(0.6 * rate) : int(0.8 * rate)] 161 | ) 162 | 163 | 164 | @pytest.mark.parametrize("func", [load_scp, load_scp_sequential]) 165 | def test_incorrect_header_wav(tmpdir, func): 166 | wav = os.path.join(os.path.dirname(__file__), "arks", "incorrect_header.wav") 167 | _, array = read_wav(wav) 168 | path = tmpdir.mkdir("test") 169 | scp = path.join("wav.scp").strpath 170 | 171 | with open(scp, "w") as f: 172 | f.write("aaa sox {wav} -t wav - |\n".format(wav=wav)) 173 | rate, test = dict(func(scp))["aaa"] 174 | np.testing.assert_array_equal(array, test) 175 | --------------------------------------------------------------------------------