├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── app.py ├── audioset_label.csv ├── package └── whisper-at │ ├── .flake8 │ ├── .gitattributes │ ├── .github │ └── workflows │ │ ├── python-publish.yml │ │ └── test.yml │ ├── LICENSE │ ├── MANIFEST.in │ ├── README.md │ ├── pyproject.toml │ ├── requirements.txt │ ├── setup.py │ └── whisper_at │ ├── __init__.py │ ├── __main__.py │ ├── assets │ ├── gpt2.tiktoken │ ├── label_name_dict.json │ ├── mel_filters.npz │ └── multilingual.tiktoken │ ├── at_post_processing.py │ ├── audio.py │ ├── decoding.py │ ├── model.py │ ├── normalizers │ ├── __init__.py │ ├── basic.py │ ├── english.json │ └── english.py │ ├── timing.py │ ├── tokenizer.py │ ├── transcribe.py │ ├── triton_ops.py │ ├── utils.py │ └── version.py ├── poster.pdf ├── poster.png ├── poster_low.png ├── pretrained_models └── README.md ├── review ├── author_response.pdf └── whisper_at_review.pdf ├── sample ├── whisper_at_demo.ipynb └── whisper_transcribe_test_simple.py ├── src ├── noise_robust_asr │ ├── asr_experiments │ │ ├── compute_wer.py │ │ ├── compute_wer_cla.py │ │ ├── gen_noisy_speech.py │ │ ├── transcribe_esc_hubert_xl.py │ │ ├── transcribe_hubert_large.py │ │ ├── transcribe_wav2vec_base.py │ │ ├── transcribe_wav2vec_robust.py │ │ └── transcribe_whisper.py │ ├── baseline_sound_classification.py │ ├── intermediate_feat_extract │ │ ├── as_full │ │ │ ├── batch_as_full_extract.sh │ │ │ ├── extract_as_full_whisper_all.py │ │ │ └── extract_as_full_whisper_all.sh │ │ ├── esc-50 │ │ │ ├── extract_esc50_hubert_xl_all_pool.py │ │ │ ├── extract_esc50_w2v_robust_all.py │ │ │ └── extract_esc50_whisper_all_pool.py │ │ └── whisper_feat_extracrt │ │ │ ├── .github │ │ │ └── workflows │ │ │ │ ├── python-publish.yml │ │ │ │ └── test.yml │ │ │ ├── LICENSE │ │ │ ├── MANIFEST.in │ │ │ ├── data │ │ │ ├── README.md │ │ │ └── meanwhile.json │ │ │ ├── notebooks │ │ │ ├── LibriSpeech.ipynb │ │ │ └── Multilingual_ASR.ipynb │ │ │ ├── requirements.txt │ │ │ ├── setup.py │ │ │ ├── tests │ │ │ ├── jfk.flac │ │ │ ├── test_audio.py │ │ │ ├── test_normalizer.py │ │ │ ├── test_tokenizer.py │ │ │ └── test_transcribe.py │ │ │ └── whisper │ │ │ ├── __init__.py │ │ │ ├── __main__.py │ │ │ ├── assets │ │ │ ├── gpt2 │ │ │ │ ├── merges.txt │ │ │ │ ├── special_tokens_map.json │ │ │ │ ├── tokenizer_config.json │ │ │ │ └── vocab.json │ │ │ ├── mel_filters.npz │ │ │ └── multilingual │ │ │ │ ├── added_tokens.json │ │ │ │ ├── merges.txt │ │ │ │ ├── special_tokens_map.json │ │ │ │ ├── tokenizer_config.json │ │ │ │ └── vocab.json │ │ │ ├── audio.py │ │ │ ├── decoding.py │ │ │ ├── model.py │ │ │ ├── normalizers │ │ │ ├── __init__.py │ │ │ ├── basic.py │ │ │ ├── english.json │ │ │ └── english.py │ │ │ ├── tokenizer.py │ │ │ ├── transcribe.py │ │ │ ├── utils.py │ │ │ └── version.py │ └── plots │ │ ├── plot_figure1_lower.py │ │ ├── plot_figure1_upper.py │ │ ├── plot_figure2.py │ │ └── plot_figure3.py └── whisper_at_train │ ├── class_labels_indices.csv │ ├── datafiles │ └── README.md │ ├── dataloader_feat.py │ ├── gen_weight_file.py │ ├── log │ ├── base_ori.txt │ ├── large-v1_low.txt │ ├── large-v2_low.txt │ ├── large-v2_ori.txt │ ├── large_v1.txt │ ├── medium_low.txt │ ├── medium_ori.txt │ ├── small_low.txt │ ├── small_ori.txt │ └── tiny_ori.txt │ ├── models.py │ ├── run.py │ ├── run_as_full_train.sh │ ├── traintest.py │ └── utilities │ ├── __init__.py │ ├── compute_flops.py │ ├── compute_mAP.py │ ├── rename_state_dict.py │ ├── stats.py │ ├── util.py │ └── whisper_at_as_eval.py └── tltr.png /.gitattributes: 
-------------------------------------------------------------------------------- 1 | * linguist-vendored 2 | *.py linguist-vendored=false -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *$py.class 4 | *.egg-info 5 | .pytest_cache 6 | .ipynb_checkpoints 7 | thumbs.db 8 | .DS_Store 9 | .idea 10 | old/* 11 | *.pptx -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2023, Yuan Gong 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 19 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 20 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 21 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 22 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 23 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 24 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import whisper_at as whisper 3 | 4 | link = "https://github.com/YuanGongND/whisper-AT" 5 | text = "[Github]" 6 | paper_link = "https://arxiv.org/pdf/2307.03183.pdf" 7 | paper_text = "[Paper]" 8 | 9 | model_large = whisper.load_model("large-v1") 10 | model_tiny = whisper.load_model("tiny") 11 | model_tiny_en = whisper.load_model("tiny.en") 12 | model_small = whisper.load_model("small") 13 | 14 | mdl_dict = {"tiny": model_tiny, "tiny.en": model_tiny_en, "small": model_small, "large": model_large} 15 | lan_dict = {"English": 'en', "Chinese": 'zh'} 16 | 17 | def round_time_resolution(time_resolution): 18 | multiple = float(time_resolution) / 0.4 19 | rounded_multiple = round(multiple) 20 | rounded_time_resolution = rounded_multiple * 0.4 21 | return rounded_time_resolution 22 | 23 | def predict(audio_path_m, audio_path_t, model_size, language, time_resolution): 24 | # print(audio_path_m, audio_path_t) 25 | # print(type(audio_path_m), type(audio_path_t)) 26 | #return audio_path_m, audio_path_t 27 | if ((audio_path_m is None) != (audio_path_t is None)) == False: 28 | return "Please upload and only upload one recording, either upload the audio file or record using microphone.", "Please upload and only upload one recording, either upload the audio file or record using microphone." 
29 | else: 30 | audio_path = audio_path_m or audio_path_t 31 | audio_tagging_time_resolution = round_time_resolution(time_resolution) 32 | model = mdl_dict[model_size] 33 | if language == 'Auto Detection': 34 | result = model.transcribe(audio_path, at_time_res=audio_tagging_time_resolution) 35 | else: 36 | result = model.transcribe(audio_path, at_time_res=audio_tagging_time_resolution, language=lan_dict[language]) 37 | audio_tag_result = whisper.parse_at_label(result, language='follow_asr', top_k=5, p_threshold=-1, include_class_list=list(range(527))) 38 | asr_output = "" 39 | for segment in result['segments']: 40 | asr_output = asr_output + format(segment['start'], ".1f") + 's-' + format(segment['end'], ".1f") + 's: ' + segment['text'] + '\n' 41 | at_output = "" 42 | for segment in audio_tag_result: 43 | print(segment) 44 | at_output = at_output + format(segment['time']['start'], ".1f") + 's-' + format(segment['time']['end'], ".1f") + 's: ' + ', '.join([x[0] for x in segment['audio tags']]) + '\n' 45 | print(at_output) 46 | return asr_output, at_output 47 | 48 | iface = gr.Interface(fn=predict, 49 | inputs=[gr.Audio(type="filepath", source='microphone', label='Please either upload an audio file or record using the microphone.', show_label=True), gr.Audio(type="filepath"), 50 | gr.Radio(["tiny", "tiny.en", "small", "large"], value='large', label="Model size", info="The larger the model, the better the performance and the slower the speed."), 51 | gr.Radio(["Auto Detection", "English", "Chinese"], value='Auto Detection', label="Language", info="Please specify the language, or let the model detect it automatically."), 52 | gr.Textbox(value='10', label='Time Resolution in Seconds (Must be an integer multiple of 0.4, e.g., 0.4, 2, 10)')], 53 | outputs=[gr.Textbox(label="Speech Output"), gr.Textbox(label="Audio Tag Output")], 54 | cache_examples=True, 55 | title="Quick Demo of Whisper-AT", 56 | description="We are glad to introduce Whisper-AT - a new joint audio tagging and speech recognition model. It outputs background sound labels in addition to text." + f"{paper_text} " + f"{text}
" + 57 | "Whisper-AT is authored by Yuan Gong, Sameer Khurana, Leonid Karlinsky, and James Glass (MIT & MIT-IBM Watson AI Lab). It is an Interspeech 2023 paper.") 58 | iface.launch(debug=True, share=True) -------------------------------------------------------------------------------- /package/whisper-at/.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | per-file-ignores = 3 | */__init__.py: F401 4 | 5 | -------------------------------------------------------------------------------- /package/whisper-at/.gitattributes: -------------------------------------------------------------------------------- 1 | # Override jupyter in Github language stats for more accurate estimate of repo code languages 2 | # reference: https://github.com/github/linguist/blob/master/docs/overrides.md#generated-code 3 | *.ipynb linguist-generated 4 | -------------------------------------------------------------------------------- /package/whisper-at/.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v3 12 | - uses: actions-ecosystem/action-regex-match@v2 13 | id: regex-match 14 | with: 15 | text: ${{ github.event.head_commit.message }} 16 | regex: '^Release ([^ ]+)' 17 | - name: Set up Python 18 | uses: actions/setup-python@v4 19 | with: 20 | python-version: '3.8' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine 25 | - name: Release 26 | if: ${{ steps.regex-match.outputs.match != '' }} 27 | uses: softprops/action-gh-release@v1 28 | with: 29 | tag_name: v${{ steps.regex-match.outputs.group1 }} 30 | - name: Build and publish 31 | if: ${{ steps.regex-match.outputs.match != '' }} 32 | env: 33 | TWINE_USERNAME: __token__ 34 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 35 | run: | 36 | python setup.py sdist 37 | twine upload dist/* 38 | -------------------------------------------------------------------------------- /package/whisper-at/.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | on: 3 | push: 4 | branches: 5 | - main 6 | pull_request: 7 | branches: 8 | - main 9 | jobs: 10 | whisper-test: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: ['3.8', '3.9', '3.10', '3.11'] 15 | pytorch-version: [1.13.1, 2.0.0] 16 | exclude: 17 | - python-version: '3.11' 18 | pytorch-version: 1.13.1 19 | steps: 20 | - uses: conda-incubator/setup-miniconda@v2 21 | - run: conda install -n test ffmpeg python=${{ matrix.python-version }} 22 | - run: pip3 install torch==${{ matrix.pytorch-version }}+cpu --index-url https://download.pytorch.org/whl/cpu 23 | - uses: actions/checkout@v3 24 | - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH 25 | - run: pip install .["dev"] 26 | - run: black --check --diff -t py38 --include '(\.pyi?)$' . 27 | - run: isort --check --diff . 28 | - run: flake8 --ignore E203,W503,W504,E501,E731,E741 . 
29 | - run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda' 30 | -------------------------------------------------------------------------------- /package/whisper-at/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2023, Yuan Gong 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 19 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 20 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 21 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 22 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 23 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 24 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | -------------------------------------------------------------------------------- /package/whisper-at/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements.txt 2 | include README.md 3 | include LICENSE 4 | include whisper_at/assets/* 5 | include whisper_at/normalizers/english.json 6 | -------------------------------------------------------------------------------- /package/whisper-at/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | 3 | [tool.isort] 4 | profile = "black" 5 | include_trailing_comma = true 6 | line_length = 88 7 | multi_line_output = 3 8 | 9 | -------------------------------------------------------------------------------- /package/whisper-at/requirements.txt: -------------------------------------------------------------------------------- 1 | numba 2 | numpy 3 | torch 4 | tqdm 5 | more-itertools 6 | tiktoken==0.3.3 -------------------------------------------------------------------------------- /package/whisper-at/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import platform 3 | import sys 4 | 5 | import pkg_resources 6 | from setuptools import find_packages, setup 7 | 8 | requirements = [] 9 | if sys.platform.startswith("linux") and platform.machine() == "x86_64": 10 | requirements.append("triton==2.0.0") 11 | 12 | setup( 13 | name="whisper-at", 14 | py_modules=["whisper_at"], 15 | version=0.5, 16 | description="Joint speech recognition and audio tagging model.", 17 | long_description=open("README.md", encoding="utf-8").read(), 18 | long_description_content_type="text/markdown", 19 | readme="README.md", 20 | python_requires=">=3.8", 21 | author="Yuan Gong", 22 | 
url="https://github.com/YuanGongND/whisper-at", 23 | license="BSD", 24 | packages=find_packages(exclude=["tests*"]), 25 | install_requires=requirements 26 | + [ 27 | str(r) 28 | for r in pkg_resources.parse_requirements( 29 | open(os.path.join(os.path.dirname(__file__), "requirements.txt")) 30 | ) 31 | ], 32 | entry_points={ 33 | "console_scripts": ["whisper=whisper.transcribe:cli"], 34 | }, 35 | include_package_data=True, 36 | extras_require={"dev": ["pytest", "scipy", "black", "flake8", "isort"]}, 37 | ) 38 | -------------------------------------------------------------------------------- /package/whisper-at/whisper_at/__init__.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import io 3 | import os 4 | import urllib 5 | import warnings 6 | from typing import List, Optional, Union 7 | 8 | import torch 9 | from tqdm import tqdm 10 | 11 | from .audio import load_audio, log_mel_spectrogram, pad_or_trim 12 | from .decoding import DecodingOptions, DecodingResult, decode, detect_language 13 | from .model import ModelDimensions, Whisper 14 | from .transcribe import transcribe 15 | from .at_post_processing import * 16 | from .version import __version__ 17 | 18 | _MODELS = { 19 | "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt", 20 | "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt", 21 | "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt", 22 | "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt", 23 | "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt", 24 | "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt", 25 | "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt", 26 | "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt", 27 | "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt", 28 | "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt", 29 | "large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt", 30 | } 31 | 32 | _MODELS_AT = { 33 | "tiny.en": "https://www.dropbox.com/s/atq9so6w0qug5ai/tiny.en_ori.pth?dl=1", 34 | "tiny": "https://www.dropbox.com/s/cib4q4iz6g758l0/tiny_ori.pth?dl=1", 35 | "base.en": "https://www.dropbox.com/s/qtzgsbuquoz0afn/base.en_ori.pth?dl=1", 36 | "base": "https://www.dropbox.com/s/2odwh42u6e9ger7/base_ori.pth?dl=1", 37 | "small.en": "https://www.dropbox.com/s/cyx50ycl1ul7lji/small.en_ori.pth?dl=1", 38 | "small.en_low": "https://www.dropbox.com/s/507o66zgl8v6ddd/small.en_low.pth?dl=1", 39 | "small": "https://www.dropbox.com/s/jftj9s0kr4ycvr1/small_ori.pth?dl=1", 40 | "small_low": 
"https://www.dropbox.com/s/a1x0416v58f7wrf/small_low.pth?dl=1", 41 | "medium.en": "https://www.dropbox.com/s/bbvylvmgns8ja4p/medium.en_ori.pth?dl=1", 42 | "medium.en_low": "https://www.dropbox.com/s/2q5wprr8f9gti5t/medium.en_low.pth?dl=1", 43 | "medium": "https://www.dropbox.com/s/65aabayr7o819az/medium_ori.pth?dl=1", 44 | "medium_low": "https://www.dropbox.com/s/0mnfmcasram4n6o/medium_low.pth?dl=1", 45 | "large-v1": "https://www.dropbox.com/s/b8x2en1fdzc8nhk/large-v1_ori.pth?dl=1", 46 | "large-v1_low": "https://www.dropbox.com/s/5o79h70wyla8jlk/large-v1_low.pth?dl=1", 47 | "large-v2": "https://www.dropbox.com/s/3zxpyvdrxy22eq7/large-v2_ori.pth?dl=1", 48 | "large-v2_low": "https://www.dropbox.com/s/jw2rh4uylhqgn85/large-v2_low.pth?dl=1", 49 | "large": "https://www.dropbox.com/s/3zxpyvdrxy22eq7/large-v2_ori.pth?dl=1", 50 | "large_low": "https://www.dropbox.com/s/jw2rh4uylhqgn85/large-v2_low.pth?dl=1", 51 | } 52 | 53 | # base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are 54 | # highly correlated to the word-level timing, i.e. the alignment between audio and text tokens. 55 | _ALIGNMENT_HEADS = { 56 | "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00", 57 | "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO", 58 | "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00", 59 | "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00", 61 | "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P%R7%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9", 64 | "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj", 66 | "large": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj", 67 | } 68 | 69 | 70 | def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]: 71 | os.makedirs(root, exist_ok=True) 72 | 73 | expected_sha256 = url.split("/")[-2] 74 | parsed_url = urllib.parse.urlparse(url).path 75 | download_target = os.path.join(root, os.path.basename(parsed_url)) 76 | 77 | if os.path.exists(download_target) and not os.path.isfile(download_target): 78 | raise RuntimeError(f"{download_target} exists and is not a regular file") 79 | 80 | if os.path.isfile(download_target): 81 | with open(download_target, "rb") as f: 82 | model_bytes = f.read() 83 | #if hashlib.sha256(model_bytes).hexdigest() == expected_sha256: 84 | return model_bytes if in_memory else download_target 85 | # else: 86 | # warnings.warn( 87 | # f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" 88 | # ) 89 | 90 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 91 | with tqdm( 92 | total=int(source.info().get("Content-Length")), 93 | ncols=80, 94 | unit="iB", 95 | unit_scale=True, 96 | unit_divisor=1024, 97 | ) as loop: 98 | while True: 99 | buffer = source.read(8192) 100 | if not buffer: 101 | break 102 | 103 | output.write(buffer) 104 | loop.update(len(buffer)) 105 | 106 | model_bytes = open(download_target, "rb").read() 107 | # if hashlib.sha256(model_bytes).hexdigest() != expected_sha256: 108 | # raise RuntimeError( 109 | # "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model." 
110 | # ) 111 | 112 | return model_bytes if in_memory else download_target 113 | 114 | 115 | def available_models() -> List[str]: 116 | """Returns the names of available models""" 117 | return list(_MODELS.keys()) 118 | 119 | 120 | def load_model( 121 | name: str, 122 | device: Optional[Union[str, torch.device]] = None, 123 | download_root: str = None, 124 | in_memory: bool = False, 125 | at_low_compute = False 126 | ) -> Whisper: 127 | """ 128 | Load a Whisper ASR model 129 | 130 | Parameters 131 | ---------- 132 | name : str 133 | one of the official model names listed by `whisper.available_models()`, or 134 | path to a model checkpoint containing the model dimensions and the model state_dict. 135 | device : Union[str, torch.device] 136 | the PyTorch device to put the model into 137 | download_root: str 138 | path to download the model files; by default, it uses "~/.cache/whisper" 139 | in_memory: bool 140 | whether to preload the model weights into host memory 141 | 142 | Returns 143 | ------- 144 | model : Whisper 145 | The Whisper ASR model instance 146 | """ 147 | 148 | if device is None: 149 | device = "cuda" if torch.cuda.is_available() else "cpu" 150 | if download_root is None: 151 | default = os.path.join(os.path.expanduser("~"), ".cache") 152 | download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper") 153 | 154 | # if use low-dim proj, only applied for large, medium, and small model 155 | if at_low_compute == True: 156 | at_mdl_name = name + '_low' 157 | else: 158 | at_mdl_name = name 159 | 160 | if name in _MODELS: 161 | checkpoint_file = _download(_MODELS[name], download_root, in_memory) 162 | checkpoint_file_at = _download(_MODELS_AT[at_mdl_name], download_root, in_memory) 163 | alignment_heads = _ALIGNMENT_HEADS[name] 164 | elif os.path.isfile(name): 165 | checkpoint_file = open(name, "rb").read() if in_memory else name 166 | alignment_heads = None 167 | else: 168 | raise RuntimeError( 169 | f"Model {name} not found; available models = {available_models()}" 170 | ) 171 | 172 | with ( 173 | io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb") 174 | ) as fp: 175 | checkpoint = torch.load(fp, map_location=device) 176 | del checkpoint_file 177 | 178 | with ( 179 | io.BytesIO(checkpoint_file_at) if in_memory else open(checkpoint_file_at, "rb") 180 | ) as fp: 181 | checkpoint_at = torch.load(fp, map_location=device) 182 | del checkpoint_file_at 183 | 184 | dims = ModelDimensions(**checkpoint["dims"]) 185 | model = Whisper(dims, at_low_compute=at_low_compute) 186 | 187 | combined_state_dict = {} 188 | combined_state_dict.update(checkpoint["model_state_dict"]) 189 | combined_state_dict.update(checkpoint_at) 190 | 191 | model.load_state_dict(combined_state_dict, strict=True) 192 | 193 | if alignment_heads is not None: 194 | model.set_alignment_heads(alignment_heads) 195 | 196 | return model.to(device) 197 | -------------------------------------------------------------------------------- /package/whisper-at/whisper_at/__main__.py: -------------------------------------------------------------------------------- 1 | from .transcribe import cli 2 | 3 | cli() 4 | -------------------------------------------------------------------------------- /package/whisper-at/whisper_at/assets/mel_filters.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuanGongND/whisper-at/17d94d6acd53866390ce70f95afa13507dcb8aef/package/whisper-at/whisper_at/assets/mel_filters.npz 
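The loader above is typically combined with `transcribe` and `parse_at_label` (defined in `at_post_processing.py` below). A minimal end-to-end sketch, assuming a placeholder `audio.wav` exists and ffmpeg is on PATH; it mirrors `sample/whisper_transcribe_test_simple.py` later in this repo:

```python
import whisper_at as whisper

# Load the Whisper ASR weights plus the audio-tagging (AT) head; both checkpoints are
# downloaded to ~/.cache/whisper by default, as in load_model() above.
model = whisper.load_model("tiny")

# "audio.wav" is a placeholder path. at_time_res sets the audio-tagging window size and,
# per the demo app, should be an integer multiple of 0.4 seconds.
result = model.transcribe("audio.wav", at_time_res=10)
print(result["text"])

# Convert the raw AudioSet logits into (label, logit) pairs for each tagged window.
at_res = whisper.parse_at_label(result, language="follow_asr", top_k=5, p_threshold=-1)
for seg in at_res:
    print(seg["time"], [label for label, _ in seg["audio tags"]])
```

Each entry of `at_res` covers one `at_time_res`-second window and lists the top-k AudioSet labels whose logits exceed `p_threshold`.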
-------------------------------------------------------------------------------- /package/whisper-at/whisper_at/at_post_processing.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 6/1/23 3:33 PM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : at_post_processing.py 7 | 8 | import os 9 | import json 10 | import torch 11 | import warnings 12 | from .tokenizer import LANGUAGES 13 | 14 | def parse_at_label(result, language='follow_asr', top_k=5, p_threshold=-1, include_class_list=list(range(527))): 15 | """ 16 | :param result: The result dict returned by the whisper-at transcribe function. 17 | :param language: The audio tag label name language, e.g., 'en', 'zh'. Default='follow_asr', i.e., same with ASR result. 18 | :param top_k: Output up to k sound classes that have logits above p_threshold. Default=5. 19 | :param p_threshold: The logit threshold to predict a sound class. Default=-1. 20 | :param p_threshold: A list of indexes that of interest. Default = list(range(527)) (all classes). 21 | :return: A dictionary of audio tagging results 22 | """ 23 | asr_language = result['language'] 24 | at_time_res = result['at_time_res'] 25 | audio_tag = result['audio_tag'] 26 | 27 | if language == 'follow_asr': 28 | language = asr_language 29 | 30 | with open(os.path.join(os.path.dirname(__file__), "assets", "label_name_dict.json")) as json_file: 31 | label_name_dict = json.load(json_file) 32 | 33 | if language not in label_name_dict.keys(): 34 | warnings.warn("{:s} language not supported. Use English label names instead. If you wish to use label names of a specific language, please specify the language argument".format(language)) 35 | language = 'en' 36 | 37 | label_name_list = label_name_dict[language] 38 | 39 | all_res = [] 40 | for i in range(audio_tag.shape[0]): 41 | top_values, top_indices = torch.topk(audio_tag[i], k=top_k) 42 | cur_time_stamp = {'start': i*at_time_res, 'end': (i+1)*at_time_res} 43 | cur_labels_list = [] 44 | for j in range(top_indices.shape[0]): 45 | if top_values[j] > p_threshold and top_indices[j] in include_class_list: 46 | cur_label = (label_name_list[top_indices[j]], top_values[j].item()) 47 | cur_labels_list.append(cur_label) 48 | all_res.append({'time': cur_time_stamp, 'audio tags': cur_labels_list}) 49 | return all_res 50 | 51 | def print_label_name(language='en'): 52 | with open(os.path.join(os.path.dirname(__file__), "assets", "label_name_dict.json")) as json_file: 53 | label_name_dict = json.load(json_file) 54 | label_name_list = label_name_dict[language] 55 | for i in range(len(label_name_list)): 56 | print("index: {:d} : {:s}".format(i, label_name_list[i])) 57 | 58 | def print_support_language(): 59 | with open(os.path.join(os.path.dirname(__file__), "assets", "label_name_dict.json")) as json_file: 60 | label_name_dict = json.load(json_file) 61 | for key in label_name_dict.keys(): 62 | print("language code: {:s} : {:s}".format(key, LANGUAGES[key])) 63 | 64 | if __name__ == '__main__': 65 | print_support_language() 66 | print_label_name(language='zh') 67 | 68 | -------------------------------------------------------------------------------- /package/whisper-at/whisper_at/audio.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import lru_cache 3 | from subprocess import CalledProcessError, run 4 | from typing import Optional, Union 5 | 6 | 
import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from .utils import exact_div 11 | 12 | # hard-coded audio hyperparameters 13 | SAMPLE_RATE = 16000 14 | N_FFT = 400 15 | N_MELS = 80 16 | HOP_LENGTH = 160 17 | CHUNK_LENGTH = 30 18 | N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk 19 | N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input 20 | 21 | N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2 22 | FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame 23 | TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token 24 | 25 | 26 | def load_audio(file: str, sr: int = SAMPLE_RATE): 27 | """ 28 | Open an audio file and read as mono waveform, resampling as necessary 29 | 30 | Parameters 31 | ---------- 32 | file: str 33 | The audio file to open 34 | 35 | sr: int 36 | The sample rate to resample the audio if necessary 37 | 38 | Returns 39 | ------- 40 | A NumPy array containing the audio waveform, in float32 dtype. 41 | """ 42 | 43 | # This launches a subprocess to decode audio while down-mixing 44 | # and resampling as necessary. Requires the ffmpeg CLI in PATH. 45 | # fmt: off 46 | cmd = [ 47 | "ffmpeg", 48 | "-nostdin", 49 | "-threads", "0", 50 | "-i", file, 51 | "-f", "s16le", 52 | "-ac", "1", 53 | "-acodec", "pcm_s16le", 54 | "-ar", str(sr), 55 | "-" 56 | ] 57 | # fmt: on 58 | try: 59 | out = run(cmd, capture_output=True, check=True).stdout 60 | except CalledProcessError as e: 61 | raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e 62 | 63 | return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 64 | 65 | 66 | def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): 67 | """ 68 | Pad or trim the audio array to N_SAMPLES, as expected by the encoder. 69 | """ 70 | if torch.is_tensor(array): 71 | if array.shape[axis] > length: 72 | array = array.index_select( 73 | dim=axis, index=torch.arange(length, device=array.device) 74 | ) 75 | 76 | if array.shape[axis] < length: 77 | pad_widths = [(0, 0)] * array.ndim 78 | pad_widths[axis] = (0, length - array.shape[axis]) 79 | array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes]) 80 | else: 81 | if array.shape[axis] > length: 82 | array = array.take(indices=range(length), axis=axis) 83 | 84 | if array.shape[axis] < length: 85 | pad_widths = [(0, 0)] * array.ndim 86 | pad_widths[axis] = (0, length - array.shape[axis]) 87 | array = np.pad(array, pad_widths) 88 | 89 | return array 90 | 91 | 92 | @lru_cache(maxsize=None) 93 | def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor: 94 | """ 95 | load the mel filterbank matrix for projecting STFT into a Mel spectrogram. 
96 | Allows decoupling librosa dependency; saved using: 97 | 98 | np.savez_compressed( 99 | "mel_filters.npz", 100 | mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), 101 | ) 102 | """ 103 | assert n_mels == 80, f"Unsupported n_mels: {n_mels}" 104 | with np.load( 105 | os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz") 106 | ) as f: 107 | return torch.from_numpy(f[f"mel_{n_mels}"]).to(device) 108 | 109 | 110 | def log_mel_spectrogram( 111 | audio: Union[str, np.ndarray, torch.Tensor], 112 | n_mels: int = N_MELS, 113 | padding: int = 0, 114 | device: Optional[Union[str, torch.device]] = None, 115 | ): 116 | """ 117 | Compute the log-Mel spectrogram of 118 | 119 | Parameters 120 | ---------- 121 | audio: Union[str, np.ndarray, torch.Tensor], shape = (*) 122 | The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz 123 | 124 | n_mels: int 125 | The number of Mel-frequency filters, only 80 is supported 126 | 127 | padding: int 128 | Number of zero samples to pad to the right 129 | 130 | device: Optional[Union[str, torch.device]] 131 | If given, the audio tensor is moved to this device before STFT 132 | 133 | Returns 134 | ------- 135 | torch.Tensor, shape = (80, n_frames) 136 | A Tensor that contains the Mel spectrogram 137 | """ 138 | if not torch.is_tensor(audio): 139 | if isinstance(audio, str): 140 | audio = load_audio(audio) 141 | audio = torch.from_numpy(audio) 142 | 143 | if device is not None: 144 | audio = audio.to(device) 145 | if padding > 0: 146 | audio = F.pad(audio, (0, padding)) 147 | window = torch.hann_window(N_FFT).to(audio.device) 148 | stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) 149 | magnitudes = stft[..., :-1].abs() ** 2 150 | 151 | filters = mel_filters(audio.device, n_mels) 152 | mel_spec = filters @ magnitudes 153 | 154 | log_spec = torch.clamp(mel_spec, min=1e-10).log10() 155 | log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) 156 | log_spec = (log_spec + 4.0) / 4.0 157 | return log_spec 158 | -------------------------------------------------------------------------------- /package/whisper-at/whisper_at/normalizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .basic import BasicTextNormalizer as BasicTextNormalizer 2 | from .english import EnglishTextNormalizer as EnglishTextNormalizer 3 | -------------------------------------------------------------------------------- /package/whisper-at/whisper_at/normalizers/basic.py: -------------------------------------------------------------------------------- 1 | import re 2 | import unicodedata 3 | 4 | import regex 5 | 6 | # non-ASCII letters that are not separated by "NFKD" normalization 7 | ADDITIONAL_DIACRITICS = { 8 | "œ": "oe", 9 | "Œ": "OE", 10 | "ø": "o", 11 | "Ø": "O", 12 | "æ": "ae", 13 | "Æ": "AE", 14 | "ß": "ss", 15 | "ẞ": "SS", 16 | "đ": "d", 17 | "Đ": "D", 18 | "ð": "d", 19 | "Ð": "D", 20 | "þ": "th", 21 | "Þ": "th", 22 | "ł": "l", 23 | "Ł": "L", 24 | } 25 | 26 | 27 | def remove_symbols_and_diacritics(s: str, keep=""): 28 | """ 29 | Replace any other markers, symbols, and punctuations with a space, 30 | and drop any diacritics (category 'Mn' and some manual mappings) 31 | """ 32 | return "".join( 33 | c 34 | if c in keep 35 | else ADDITIONAL_DIACRITICS[c] 36 | if c in ADDITIONAL_DIACRITICS 37 | else "" 38 | if unicodedata.category(c) == "Mn" 39 | else " " 40 | if unicodedata.category(c)[0] in "MSP" 41 | else c 42 | for c in 
unicodedata.normalize("NFKD", s) 43 | ) 44 | 45 | 46 | def remove_symbols(s: str): 47 | """ 48 | Replace any other markers, symbols, punctuations with a space, keeping diacritics 49 | """ 50 | return "".join( 51 | " " if unicodedata.category(c)[0] in "MSP" else c 52 | for c in unicodedata.normalize("NFKC", s) 53 | ) 54 | 55 | 56 | class BasicTextNormalizer: 57 | def __init__(self, remove_diacritics: bool = False, split_letters: bool = False): 58 | self.clean = ( 59 | remove_symbols_and_diacritics if remove_diacritics else remove_symbols 60 | ) 61 | self.split_letters = split_letters 62 | 63 | def __call__(self, s: str): 64 | s = s.lower() 65 | s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets 66 | s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis 67 | s = self.clean(s).lower() 68 | 69 | if self.split_letters: 70 | s = " ".join(regex.findall(r"\X", s, regex.U)) 71 | 72 | s = re.sub( 73 | r"\s+", " ", s 74 | ) # replace any successive whitespace characters with a space 75 | 76 | return s 77 | -------------------------------------------------------------------------------- /package/whisper-at/whisper_at/triton_ops.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | 3 | import numpy as np 4 | import torch 5 | 6 | try: 7 | import triton 8 | import triton.language as tl 9 | except ImportError: 10 | raise RuntimeError("triton import failed; try `pip install --pre triton`") 11 | 12 | 13 | @triton.jit 14 | def dtw_kernel( 15 | cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr 16 | ): 17 | offsets = tl.arange(0, BLOCK_SIZE) 18 | mask = offsets < M 19 | 20 | for k in range(1, N + M + 1): # k = i + j 21 | tl.debug_barrier() 22 | 23 | p0 = cost + (k - 1) * cost_stride 24 | p1 = cost + k * cost_stride 25 | p2 = cost + k * cost_stride + 1 26 | 27 | c0 = tl.load(p0 + offsets, mask=mask) 28 | c1 = tl.load(p1 + offsets, mask=mask) 29 | c2 = tl.load(p2 + offsets, mask=mask) 30 | 31 | x_row = tl.load(x + (k - 1) * x_stride + offsets, mask=mask, other=0) 32 | cost_row = x_row + tl.minimum(tl.minimum(c0, c1), c2) 33 | 34 | cost_ptr = cost + (k + 1) * cost_stride + 1 35 | tl.store(cost_ptr + offsets, cost_row, mask=mask) 36 | 37 | trace_ptr = trace + (k + 1) * trace_stride + 1 38 | tl.store(trace_ptr + offsets, 2, mask=mask & (c2 <= c0) & (c2 <= c1)) 39 | tl.store(trace_ptr + offsets, 1, mask=mask & (c1 <= c0) & (c1 <= c2)) 40 | tl.store(trace_ptr + offsets, 0, mask=mask & (c0 <= c1) & (c0 <= c2)) 41 | 42 | 43 | @lru_cache(maxsize=None) 44 | def median_kernel(filter_width: int): 45 | @triton.jit 46 | def kernel( 47 | y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr 48 | ): # x.shape[-1] == filter_width 49 | row_idx = tl.program_id(0) 50 | offsets = tl.arange(0, BLOCK_SIZE) 51 | mask = offsets < y_stride 52 | 53 | x_ptr = x + row_idx * x_stride # noqa: F841 54 | y_ptr = y + row_idx * y_stride 55 | 56 | LOAD_ALL_ROWS_HERE # noqa: F821 57 | 58 | BUBBLESORT_HERE # noqa: F821 59 | 60 | tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821 61 | 62 | kernel = triton.JITFunction(kernel.fn) 63 | kernel.src = kernel.src.replace( 64 | " LOAD_ALL_ROWS_HERE", 65 | "\n".join( 66 | [ 67 | f" row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)" 68 | for i in range(filter_width) 69 | ] 70 | ), 71 | ) 72 | kernel.src = kernel.src.replace( 73 | " BUBBLESORT_HERE", 74 | "\n\n".join( 75 | [ 76 | "\n\n".join( 77 | [ 78 | "\n".join( 79 | [ 80 | f" smaller = tl.where(row{j} < 
row{j + 1}, row{j}, row{j + 1})", 81 | f" larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})", 82 | f" row{j} = smaller", 83 | f" row{j + 1} = larger", 84 | ] 85 | ) 86 | for j in range(filter_width - i - 1) 87 | ] 88 | ) 89 | for i in range(filter_width // 2 + 1) 90 | ] 91 | ), 92 | ) 93 | kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}") 94 | 95 | return kernel 96 | 97 | 98 | def median_filter_cuda(x: torch.Tensor, filter_width: int): 99 | """Apply a median filter of given width along the last dimension of x""" 100 | slices = x.contiguous().unfold(-1, filter_width, 1) 101 | grid = np.prod(slices.shape[:-2]) 102 | 103 | kernel = median_kernel(filter_width) 104 | y = torch.empty_like(slices[..., 0]) 105 | 106 | BLOCK_SIZE = 1 << (y.stride(-2) - 1).bit_length() 107 | kernel[(grid,)](y, x, x.stride(-2), y.stride(-2), BLOCK_SIZE=BLOCK_SIZE) 108 | 109 | return y 110 | -------------------------------------------------------------------------------- /package/whisper-at/whisper_at/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | import sys 5 | import zlib 6 | from typing import Callable, Optional, TextIO 7 | 8 | system_encoding = sys.getdefaultencoding() 9 | 10 | if system_encoding != "utf-8": 11 | 12 | def make_safe(string): 13 | # replaces any character not representable using the system default encoding with an '?', 14 | # avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729). 15 | return string.encode(system_encoding, errors="replace").decode(system_encoding) 16 | 17 | else: 18 | 19 | def make_safe(string): 20 | # utf-8 can encode any Unicode code point, so no need to do the round-trip encoding 21 | return string 22 | 23 | 24 | def exact_div(x, y): 25 | assert x % y == 0 26 | return x // y 27 | 28 | 29 | def str2bool(string): 30 | str2val = {"True": True, "False": False} 31 | if string in str2val: 32 | return str2val[string] 33 | else: 34 | raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}") 35 | 36 | 37 | def optional_int(string): 38 | return None if string == "None" else int(string) 39 | 40 | 41 | def optional_float(string): 42 | return None if string == "None" else float(string) 43 | 44 | 45 | def compression_ratio(text) -> float: 46 | text_bytes = text.encode("utf-8") 47 | return len(text_bytes) / len(zlib.compress(text_bytes)) 48 | 49 | 50 | def format_timestamp( 51 | seconds: float, always_include_hours: bool = False, decimal_marker: str = "." 52 | ): 53 | assert seconds >= 0, "non-negative timestamp expected" 54 | milliseconds = round(seconds * 1000.0) 55 | 56 | hours = milliseconds // 3_600_000 57 | milliseconds -= hours * 3_600_000 58 | 59 | minutes = milliseconds // 60_000 60 | milliseconds -= minutes * 60_000 61 | 62 | seconds = milliseconds // 1_000 63 | milliseconds -= seconds * 1_000 64 | 65 | hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else "" 66 | return ( 67 | f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" 68 | ) 69 | 70 | 71 | class ResultWriter: 72 | extension: str 73 | 74 | def __init__(self, output_dir: str): 75 | self.output_dir = output_dir 76 | 77 | def __call__(self, result: dict, audio_path: str, options: dict): 78 | audio_basename = os.path.basename(audio_path) 79 | audio_basename = os.path.splitext(audio_basename)[0] 80 | output_path = os.path.join( 81 | self.output_dir, audio_basename + "." 
+ self.extension 82 | ) 83 | 84 | with open(output_path, "w", encoding="utf-8") as f: 85 | self.write_result(result, file=f, options=options) 86 | 87 | def write_result(self, result: dict, file: TextIO, options: dict): 88 | raise NotImplementedError 89 | 90 | 91 | class WriteTXT(ResultWriter): 92 | extension: str = "txt" 93 | 94 | def write_result(self, result: dict, file: TextIO, options: dict): 95 | for segment in result["segments"]: 96 | print(segment["text"].strip(), file=file, flush=True) 97 | 98 | 99 | class SubtitlesWriter(ResultWriter): 100 | always_include_hours: bool 101 | decimal_marker: str 102 | 103 | def iterate_result(self, result: dict, options: dict): 104 | raw_max_line_width: Optional[int] = options["max_line_width"] 105 | max_line_count: Optional[int] = options["max_line_count"] 106 | highlight_words: bool = options["highlight_words"] 107 | max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width 108 | preserve_segments = max_line_count is None or raw_max_line_width is None 109 | 110 | def iterate_subtitles(): 111 | line_len = 0 112 | line_count = 1 113 | # the next subtitle to yield (a list of word timings with whitespace) 114 | subtitle: list[dict] = [] 115 | last = result["segments"][0]["words"][0]["start"] 116 | for segment in result["segments"]: 117 | for i, original_timing in enumerate(segment["words"]): 118 | timing = original_timing.copy() 119 | long_pause = not preserve_segments and timing["start"] - last > 3.0 120 | has_room = line_len + len(timing["word"]) <= max_line_width 121 | seg_break = i == 0 and len(subtitle) > 0 and preserve_segments 122 | if line_len > 0 and has_room and not long_pause and not seg_break: 123 | # line continuation 124 | line_len += len(timing["word"]) 125 | else: 126 | # new line 127 | timing["word"] = timing["word"].strip() 128 | if ( 129 | len(subtitle) > 0 130 | and max_line_count is not None 131 | and (long_pause or line_count >= max_line_count) 132 | or seg_break 133 | ): 134 | # subtitle break 135 | yield subtitle 136 | subtitle = [] 137 | line_count = 1 138 | elif line_len > 0: 139 | # line break 140 | line_count += 1 141 | timing["word"] = "\n" + timing["word"] 142 | line_len = len(timing["word"].strip()) 143 | subtitle.append(timing) 144 | last = timing["start"] 145 | if len(subtitle) > 0: 146 | yield subtitle 147 | 148 | if "words" in result["segments"][0]: 149 | for subtitle in iterate_subtitles(): 150 | subtitle_start = self.format_timestamp(subtitle[0]["start"]) 151 | subtitle_end = self.format_timestamp(subtitle[-1]["end"]) 152 | subtitle_text = "".join([word["word"] for word in subtitle]) 153 | if highlight_words: 154 | last = subtitle_start 155 | all_words = [timing["word"] for timing in subtitle] 156 | for i, this_word in enumerate(subtitle): 157 | start = self.format_timestamp(this_word["start"]) 158 | end = self.format_timestamp(this_word["end"]) 159 | if last != start: 160 | yield last, start, subtitle_text 161 | 162 | yield start, end, "".join( 163 | [ 164 | re.sub(r"^(\s*)(.*)$", r"\1\2", word) 165 | if j == i 166 | else word 167 | for j, word in enumerate(all_words) 168 | ] 169 | ) 170 | last = end 171 | else: 172 | yield subtitle_start, subtitle_end, subtitle_text 173 | else: 174 | for segment in result["segments"]: 175 | segment_start = self.format_timestamp(segment["start"]) 176 | segment_end = self.format_timestamp(segment["end"]) 177 | segment_text = segment["text"].strip().replace("-->", "->") 178 | yield segment_start, segment_end, segment_text 179 | 180 | def format_timestamp(self, 
seconds: float): 181 | return format_timestamp( 182 | seconds=seconds, 183 | always_include_hours=self.always_include_hours, 184 | decimal_marker=self.decimal_marker, 185 | ) 186 | 187 | 188 | class WriteVTT(SubtitlesWriter): 189 | extension: str = "vtt" 190 | always_include_hours: bool = False 191 | decimal_marker: str = "." 192 | 193 | def write_result(self, result: dict, file: TextIO, options: dict): 194 | print("WEBVTT\n", file=file) 195 | for start, end, text in self.iterate_result(result, options): 196 | print(f"{start} --> {end}\n{text}\n", file=file, flush=True) 197 | 198 | 199 | class WriteSRT(SubtitlesWriter): 200 | extension: str = "srt" 201 | always_include_hours: bool = True 202 | decimal_marker: str = "," 203 | 204 | def write_result(self, result: dict, file: TextIO, options: dict): 205 | for i, (start, end, text) in enumerate( 206 | self.iterate_result(result, options), start=1 207 | ): 208 | print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True) 209 | 210 | 211 | class WriteTSV(ResultWriter): 212 | """ 213 | Write a transcript to a file in TSV (tab-separated values) format containing lines like: 214 | \t\t 215 | 216 | Using integer milliseconds as start and end times means there's no chance of interference from 217 | an environment setting a language encoding that causes the decimal in a floating point number 218 | to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++. 219 | """ 220 | 221 | extension: str = "tsv" 222 | 223 | def write_result(self, result: dict, file: TextIO, options: dict): 224 | print("start", "end", "text", sep="\t", file=file) 225 | for segment in result["segments"]: 226 | print(round(1000 * segment["start"]), file=file, end="\t") 227 | print(round(1000 * segment["end"]), file=file, end="\t") 228 | print(segment["text"].strip().replace("\t", " "), file=file, flush=True) 229 | 230 | 231 | class WriteJSON(ResultWriter): 232 | extension: str = "json" 233 | 234 | def write_result(self, result: dict, file: TextIO, options: dict): 235 | json.dump(result, file) 236 | 237 | 238 | def get_writer( 239 | output_format: str, output_dir: str 240 | ) -> Callable[[dict, TextIO, dict], None]: 241 | writers = { 242 | "txt": WriteTXT, 243 | "vtt": WriteVTT, 244 | "srt": WriteSRT, 245 | "tsv": WriteTSV, 246 | "json": WriteJSON, 247 | } 248 | 249 | if output_format == "all": 250 | all_writers = [writer(output_dir) for writer in writers.values()] 251 | 252 | def write_all(result: dict, file: TextIO, options: dict): 253 | for writer in all_writers: 254 | writer(result, file, options) 255 | 256 | return write_all 257 | 258 | return writers[output_format](output_dir) 259 | -------------------------------------------------------------------------------- /package/whisper-at/whisper_at/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.5" 2 | -------------------------------------------------------------------------------- /poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuanGongND/whisper-at/17d94d6acd53866390ce70f95afa13507dcb8aef/poster.pdf -------------------------------------------------------------------------------- /poster.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuanGongND/whisper-at/17d94d6acd53866390ce70f95afa13507dcb8aef/poster.png 
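The writer classes in `whisper_at/utils.py` above can also dump a `transcribe()` result to txt/vtt/srt/tsv/json. A minimal sketch, assuming `whisper_at` is installed; the hand-made `result` dict below only illustrates the expected segment structure:

```python
from whisper_at.utils import get_writer

# A minimal stand-in for the dict returned by model.transcribe(...).
result = {"segments": [{"start": 0.0, "end": 2.5, "text": " Hello world."}]}

writer = get_writer("srt", output_dir=".")  # also "txt", "vtt", "tsv", "json", or "all"
options = {"max_line_width": None, "max_line_count": None, "highlight_words": False}
writer(result, "audio.wav", options)  # writes ./audio.srt
```

Passing "all" as the format writes every supported format at once.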
-------------------------------------------------------------------------------- /poster_low.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuanGongND/whisper-at/17d94d6acd53866390ce70f95afa13507dcb8aef/poster_low.png -------------------------------------------------------------------------------- /pretrained_models/README.md: -------------------------------------------------------------------------------- 1 | # Pretrained Weights 2 | 3 | The Whisper-AT script downloads the original OpenAI Whisper model and our AT model automatically, so you normally do not need to download them manually. 4 | But in case your device does not have Internet access, here are the links. Download the models and place them in the same directory (by default `~/.cache/whisper`) as the original OpenAI Whisper model. 5 | 6 | These links support `wget`. 7 | 8 | ```python 9 | dropbox_path = { 10 | "tiny.en": "https://www.dropbox.com/s/atq9so6w0qug5ai/tiny.en_ori.pth?dl=1", 11 | "tiny": "https://www.dropbox.com/s/06f2h29aki39q9r/tiny_ori.pth?dl=1", 12 | "base.en": "https://www.dropbox.com/s/qtzgsbuquoz0afn/base.en_ori.pth?dl=1", 13 | "base": "https://www.dropbox.com/s/4vn2oatda321y7h/base_ori.pth?dl=1", 14 | "small.en": "https://www.dropbox.com/s/cyx50ycl1ul7lji/small.en_ori.pth?dl=1", 15 | "small.en_low": "https://www.dropbox.com/s/507o66zgl8v6ddd/small.en_low.pth?dl=1", 16 | "small": "https://www.dropbox.com/s/5zqzs3e47zwhjd3/small_ori.pth?dl=1", 17 | "small_low": "https://www.dropbox.com/s/3lxlmh437tneifl/small_low.pth?dl=1", 18 | "medium.en": "https://www.dropbox.com/s/bbvylvmgns8ja4p/medium.en_ori.pth?dl=1", 19 | "medium.en_low": "https://www.dropbox.com/s/2q5wprr8f9gti5t/medium.en_low.pth?dl=1", 20 | "medium": "https://www.dropbox.com/s/93zfj4afmv0qfyl/medium_ori.pth?dl=1", 21 | "medium_low": "https://www.dropbox.com/s/g66h1vtn1u426dj/medium_low.pth?dl=1", 22 | "large-v1": "https://www.dropbox.com/s/b8x2en1fdzc8nhk/large-v1_ori.pth?dl=1", 23 | "large-v1_low": "https://www.dropbox.com/s/5o79h70wyla8jlk/large-v1_low.pth?dl=1", 24 | "large-v2": "https://www.dropbox.com/s/94x7wqw4hscpls0/large-v2_ori.pth?dl=1", 25 | "large-v2_low": "https://www.dropbox.com/s/wk5dyxustpji06c/large-v2_low.pth?dl=1", 26 | "large": "https://www.dropbox.com/s/94x7wqw4hscpls0/large-v2_ori.pth?dl=1", 27 | "large_low": "https://www.dropbox.com/s/wk5dyxustpji06c/large-v2_low.pth?dl=1"} 28 | ``` 29 | 30 | ## China Mirror Links 镜像链接 31 | 32 | The models are hosted on Dropbox. If Dropbox is not accessible, use a VPN or the mirror link below; you will have to download the models manually and place them in the same directory (by default `~/.cache/whisper`) as the original OpenAI Whisper model (a minimal download sketch follows below).
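For reference, a rough Python equivalent of the manual `wget` workflow above, assuming Dropbox is reachable; the `tiny` URL is copied from the table above and the target directory matches `load_model()`'s default cache location:

```python
import os
import urllib.request

# Manually fetch one AT checkpoint (example: the "tiny" model) into Whisper's default cache.
url = "https://www.dropbox.com/s/06f2h29aki39q9r/tiny_ori.pth?dl=1"
target_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
os.makedirs(target_dir, exist_ok=True)
urllib.request.urlretrieve(url, os.path.join(target_dir, "tiny_ori.pth"))
```

The file name must stay `tiny_ori.pth` so that `load_model("tiny")` finds the cached checkpoint instead of re-downloading it.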
33 | [[镜像链接(腾讯微云)]](https://share.weiyun.com/bVxQWxTe) -------------------------------------------------------------------------------- /review/author_response.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuanGongND/whisper-at/17d94d6acd53866390ce70f95afa13507dcb8aef/review/author_response.pdf -------------------------------------------------------------------------------- /review/whisper_at_review.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuanGongND/whisper-at/17d94d6acd53866390ce70f95afa13507dcb8aef/review/whisper_at_review.pdf -------------------------------------------------------------------------------- /sample/whisper_transcribe_test_simple.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 5/28/23 2:36 PM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : whisper_transcribe_test.py 7 | 8 | import whisper_at as whisper 9 | 10 | model = whisper.load_model("large-v1") 11 | result = model.transcribe("/data/sls/scratch/yuangong/dataset/adress_train/ADReSS-IS2020-data/train/Full_wave_enhanced_audio/cc/S024.wav") 12 | #result = model.transcribe("/data/sls/scratch/yuangong/whisper-at/sample_audio/007P6bFgRCU_10.000.flac", at_time_res=2) 13 | 14 | print(result['text']) 15 | print(result['segments']) 16 | text_segments = result['segments'] 17 | text_annotation = [(x['start'], x['end'], x['text']) for x in text_segments] 18 | print(text_annotation) 19 | at_res = whisper.parse_at_label(result, language='en', p_threshold=-1) 20 | print(at_res) 21 | 22 | all_seg = [] 23 | for segment in at_res: 24 | cur_start = segment['time']['start'] 25 | cur_end = segment['time']['end'] 26 | cur_tags = segment['audio tags'] 27 | cur_tags = [x[0] for x in cur_tags] 28 | cur_tags = '; '.join(cur_tags) 29 | all_seg.append((cur_start, cur_end, cur_tags)) 30 | print(all_seg) 31 | 32 | whisper.print_support_language() 33 | whisper.print_label_name(language='zh') -------------------------------------------------------------------------------- /src/noise_robust_asr/asr_experiments/compute_wer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 3/3/23 6:27 PM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : cal_wer.py 7 | 8 | import os 9 | import editdistance 10 | import jiwer 11 | import numpy as np 12 | 13 | def fileList(source): 14 | matches = [] 15 | for root, dirnames, filenames in os.walk(source): 16 | for filename in filenames: 17 | if filename.endswith(('.txt')): 18 | matches.append(os.path.join(root, filename)) 19 | return matches 20 | 21 | def calculate_wer(seqs_hat, seqs_true): 22 | """Calculate sentence-level WER score. 
23 | :param list seqs_hat: prediction 24 | :param list seqs_true: reference 25 | :return: average sentence-level WER score 26 | :rtype float 27 | """ 28 | word_eds, word_ref_lens = [], [] 29 | for i in range(len(seqs_true)): 30 | seq_true_text = seqs_true[i] 31 | seq_hat_text = seqs_hat[i] 32 | hyp_words = seq_hat_text.split() 33 | ref_words = seq_true_text.split() 34 | word_eds.append(editdistance.eval(hyp_words, ref_words)) 35 | word_ref_lens.append(len(ref_words)) 36 | return float(sum(word_eds)) / sum(word_ref_lens) 37 | 38 | def eval_noise_wer(trans_path, result_path): 39 | whisper_trans = fileList(trans_path) 40 | truth_path = '/data/sls/scratch/yuangong/whisper-a/src/noisy_exp/ground_truth_trans/' 41 | truth_trans = fileList(truth_path) 42 | print(len(whisper_trans), len(truth_trans)) 43 | 44 | def preprocess_text(cur_trans): 45 | cur_trans = jiwer.ToUpperCase()(cur_trans) 46 | cur_trans = jiwer.RemovePunctuation()(cur_trans) 47 | return cur_trans 48 | 49 | wer_list = [] 50 | for db in [-20, -15, -10, -5, 0, 5, 10, 15, 20]: 51 | cur_trans_list, cur_truth_list = [], [] 52 | for trans_name in whisper_trans: 53 | if int(trans_name.split('/')[-1].split('_')[0]) == db: 54 | with open(trans_name, "r") as f: 55 | cur_trans = f.read() 56 | cur_trans = preprocess_text(cur_trans) 57 | cur_trans_list.append(cur_trans) 58 | print('trans: ', cur_trans) 59 | 60 | cur_truth_name = '/data/sls/scratch/yuangong/whisper-a/src/noisy_exp/ground_truth_trans/' + trans_name.split('/')[-1].split('_mix_')[0].split('_')[2] + '.txt' 61 | with open(cur_truth_name, "r") as f: 62 | cur_truth = f.read() 63 | cur_truth = preprocess_text(cur_truth) 64 | cur_truth_list.append(cur_truth) 65 | print('truth: ', cur_truth) 66 | wer = calculate_wer(cur_trans_list, cur_truth_list) 67 | print('wer is ', wer) 68 | wer_list.append(wer) 69 | print(wer_list) 70 | np.savetxt(result_path, wer_list, delimiter=',') 71 | 72 | eval_noise_wer('/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_text_whisper_large-v1/', '/data/sls/scratch/yuangong/whisper-a/src/noisy_exp/results_camera/whisper_large-v1.csv') 73 | eval_noise_wer('/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_text_whisper_tiny.en/', '/data/sls/scratch/yuangong/whisper-a/src/noisy_exp/results_camera/whisper_tiny.csv') 74 | eval_noise_wer('/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_text_whisper_small.en/', '/data/sls/scratch/yuangong/whisper-a/src/noisy_exp/results_camera/whisper_small.csv') 75 | eval_noise_wer('/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_text_whisper_base.en/', '/data/sls/scratch/yuangong/whisper-a/src/noisy_exp/results_camera/whisper_base.csv') 76 | eval_noise_wer('/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_text_whisper_medium.en/', '/data/sls/scratch/yuangong/whisper-a/src/noisy_exp/results_camera/whisper_medium.csv') 77 | 78 | -------------------------------------------------------------------------------- /src/noise_robust_asr/asr_experiments/compute_wer_cla.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 3/3/23 6:27 PM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : cal_wer.py 7 | 8 | import os 9 | import editdistance 10 | import jiwer 11 | import numpy as np 12 | 13 | 14 | def fileList(source): 15 | matches = [] 16 | for root, dirnames, filenames in os.walk(source): 17 | for filename in filenames: 18 | if filename.endswith(('.txt')): 
19 | matches.append(os.path.join(root, filename)) 20 | return matches 21 | 22 | def calculate_wer(seqs_hat, seqs_true): 23 | """Calculate sentence-level WER score. 24 | :param list seqs_hat: prediction 25 | :param list seqs_true: reference 26 | :return: average sentence-level WER score 27 | :rtype float 28 | """ 29 | word_eds, word_ref_lens = [], [] 30 | for i in range(len(seqs_true)): 31 | seq_true_text = seqs_true[i] 32 | seq_hat_text = seqs_hat[i] 33 | hyp_words = seq_hat_text.split() 34 | ref_words = seq_true_text.split() 35 | word_eds.append(editdistance.eval(hyp_words, ref_words)) 36 | word_ref_lens.append(len(ref_words)) 37 | return float(sum(word_eds)) / sum(word_ref_lens) 38 | 39 | def eval_noise_wer(trans_path, result_path): 40 | whisper_trans = fileList(trans_path) 41 | truth_path = '/data/sls/scratch/yuangong/whisper-a/src/noisy_exp/ground_truth_trans/' 42 | truth_trans = fileList(truth_path) 43 | print(len(whisper_trans), len(truth_trans)) 44 | 45 | def preprocess_text(cur_trans): 46 | cur_trans = jiwer.ToUpperCase()(cur_trans) 47 | cur_trans = jiwer.RemovePunctuation()(cur_trans) 48 | return cur_trans 49 | 50 | all_wer_list = [] 51 | for db in [-20, -15, -10, -5, 0, 5, 10, 15, 20]: 52 | wer_list = [] 53 | for cla in range(50): 54 | cur_trans_list, cur_truth_list = [], [] 55 | for trans_name in whisper_trans: 56 | if int(trans_name.split('/')[-1].split('_')[0]) == db and int(trans_name.split('/')[-1].split('_')[1]) == cla: 57 | with open(trans_name, "r") as f: 58 | cur_trans = f.read() 59 | cur_trans = preprocess_text(cur_trans) 60 | cur_trans_list.append(cur_trans) 61 | #print('trans: ', cur_trans) 62 | 63 | cur_truth_name = '/data/sls/scratch/yuangong/whisper-a/src/noisy_exp/ground_truth_trans/' + trans_name.split('/')[-1].split('_mix_')[0].split('_')[2] + '.txt' 64 | with open(cur_truth_name, "r") as f: 65 | cur_truth = f.read() 66 | cur_truth = preprocess_text(cur_truth) 67 | cur_truth_list.append(cur_truth) 68 | #print('truth: ', cur_truth) 69 | #print(len(cur_trans_list), len(cur_truth_list)) 70 | wer = calculate_wer(cur_trans_list, cur_truth_list) 71 | #print('wer is ', wer) 72 | wer_list.append(wer) 73 | #print(wer_list) 74 | all_wer_list.append(wer_list) 75 | np.savetxt(result_path, all_wer_list, delimiter=',') 76 | 77 | # eval_noise_wer('/data/sls/scratch/yuangong/whisper-a/noisy_speech_text_hubert_large/', '/data/sls/scratch/yuangong/whisper-a/src/noisy_exp/results/hubert_large_cla.csv') 78 | # eval_noise_wer('/data/sls/scratch/yuangong/whisper-a/noisy_speech_text_hubert_xlarge/', '/data/sls/scratch/yuangong/whisper-a/src/noisy_exp/results/hubert_xlarge_cla.csv') 79 | # eval_noise_wer('/data/sls/scratch/yuangong/whisper-a/noisy_speech_text_w2v_base/', '/data/sls/scratch/yuangong/whisper-a/src/noisy_exp/results/w2v_base_cla.csv') 80 | # eval_noise_wer('/data/sls/scratch/yuangong/whisper-a/noisy_speech_text_w2v_large_robust/', '/data/sls/scratch/yuangong/whisper-a/src/noisy_exp/results/w2v_large_robust_cla.csv') 81 | eval_noise_wer('/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_text_whisper_large-v1/', '/data/sls/scratch/yuangong/whisper-a/src/noisy_exp/results_camera/whisper_large-v1_cla.csv') 82 | 83 | -------------------------------------------------------------------------------- /src/noise_robust_asr/asr_experiments/gen_noisy_speech.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 3/3/23 3:04 AM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of 
Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : gen_noisy_speech.py 7 | 8 | import numpy as np 9 | import torchaudio 10 | import torch 11 | import os 12 | 13 | def fileList(source): 14 | matches = [] 15 | for root, dirnames, filenames in os.walk(source): 16 | for filename in filenames: 17 | if filename.endswith(('.flac', '.wav')): 18 | matches.append(os.path.join(root, filename)) 19 | return matches 20 | 21 | def add_noise_torch(speech_path, noise_path, noise_db, tar_path): 22 | speech, sr_s = torchaudio.load(speech_path) 23 | noise, sr_n = torchaudio.load(noise_path) 24 | 25 | assert sr_s == sr_n 26 | power_speech = (speech ** 2).mean() 27 | power_noise = (noise ** 2).mean() 28 | 29 | scale = (10 ** (-noise_db / 20) * np.sqrt(power_speech) / np.sqrt(max(power_noise, 1e-10))) 30 | 31 | # if speech is longer than the noise 32 | if speech.shape[1] > noise.shape[1]: 33 | ratio = int(np.ceil(speech.shape[1] / noise.shape[1])) 34 | noise = torch.concat([noise for _ in range(ratio)], dim=1) 35 | 36 | if speech.shape[1] < noise.shape[1]: 37 | noise = noise[:, :speech.shape[1]] 38 | 39 | speech = speech + scale * noise 40 | torchaudio.save(tar_path, speech, sample_rate=sr_s) 41 | 42 | all_speech = fileList('/data/sls/scratch/yuangong/whisper-a/sample_audio/test-clean') 43 | all_speech.sort() 44 | all_speech = all_speech[:40] 45 | print(all_speech) 46 | 47 | all_noise_dict = {} 48 | all_noise = fileList('/data/sls/scratch/yuangong/sslast2/egs/esc50/data/ESC-50-master/audio_16k/') 49 | for noise in all_noise: 50 | cla = int(noise.split('.')[-2].split('-')[-1]) 51 | if cla not in all_noise_dict: 52 | all_noise_dict[cla] = [noise] 53 | else: 54 | all_noise_dict[cla].append(noise) 55 | print(all_noise_dict[0]) 56 | print(len(all_noise_dict[0])) 57 | 58 | for db in [-20, -15, -10, -5, 0, 5, 10, 15, 20]: 59 | for cla in range(50): 60 | # if os.path.exists('/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_ready/{:d}/{:d}'.format(db,cla)) == False: 61 | # os.makedirs('/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_ready/{:d}/{:d}'.format(db,cla)) 62 | # for each snr, for each class, test 40 librispeech samples 63 | for idx in range(40): 64 | tar_name = '/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_ready/' + str(db) + '_' + str(cla) + '_' + all_speech[idx].split('/')[-1].split('.')[-2] + '_mix_' + all_noise_dict[cla][idx].split('/')[-1].split('.')[-2] + '.wav' 65 | add_noise_torch(all_speech[idx], all_noise_dict[cla][idx], db, tar_name) -------------------------------------------------------------------------------- /src/noise_robust_asr/asr_experiments/transcribe_esc_hubert_xl.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 1/20/23 2:09 AM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : transcribe_aus.py 7 | 8 | import sys 9 | argument = sys.argv[1] 10 | if argument=='4': 11 | argument='0,1,2,3' 12 | import os 13 | if argument != '-1': 14 | os.environ["CUDA_VISIBLE_DEVICES"]=argument 15 | 16 | import os 17 | import torch 18 | import soundfile 19 | from transformers import Wav2Vec2Processor, HubertForCTC 20 | 21 | def fileList(source): 22 | matches = [] 23 | for root, dirnames, filenames in os.walk(source): 24 | for filename in filenames: 25 | if filename.endswith(('.flac', '.wav')): 26 | matches.append(os.path.join(root, filename)) 27 | return matches 28 | 29 | device = torch.device("cuda" if 
torch.cuda.is_available() else "cpu") 30 | print(device) 31 | 32 | processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-xlarge-ls960-ft") 33 | model = HubertForCTC.from_pretrained("facebook/hubert-xlarge-ls960-ft").to(device) 34 | 35 | tar_path = '/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_text_hubert_xlarge/' 36 | if os.path.exists(tar_path) == False: 37 | os.mkdir(tar_path) 38 | 39 | audio_list = fileList('/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_ready/') 40 | 41 | audio_list.sort() 42 | start_file = int(argument) * 4500 43 | end_file = int(argument) * 4500 + 4500 44 | audio_list = audio_list[start_file: end_file] 45 | 46 | print('number of files to transcribe: ', len(audio_list)) 47 | 48 | for i in range(len(audio_list)): 49 | audio_path = audio_list[i] 50 | if os.path.exists(tar_path + audio_path.split('/')[-1].split('.')[-2] + '.txt') == False: 51 | if audio_path[-3:] == 'wav': 52 | audio_path = audio_list[i] 53 | source, curr_sample_rate = soundfile.read(audio_path, dtype="float32") 54 | input_features = processor(source, 55 | sampling_rate=curr_sample_rate, 56 | return_tensors="pt").input_values 57 | logits = model(input_features.to(device)).logits 58 | predicted_ids = torch.argmax(logits, dim=-1) 59 | text = processor.decode(predicted_ids[0]) 60 | with open(tar_path + audio_path.split('/')[-1].split('.')[-2] + '.txt', "w") as text_file: 61 | text_file.write(text) 62 | del logits, input_features 63 | if i % 100 == 0: 64 | print("{:d} / {:d} processd from processor {:s}".format(i, len(audio_list), argument)) 65 | del model -------------------------------------------------------------------------------- /src/noise_robust_asr/asr_experiments/transcribe_hubert_large.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 1/20/23 2:09 AM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : transcribe_aus.py 7 | 8 | import sys 9 | argument = sys.argv[1] 10 | if argument=='4': 11 | argument='0,1,2,3' 12 | import os 13 | if argument != '-1': 14 | os.environ["CUDA_VISIBLE_DEVICES"]=argument 15 | import torch 16 | import soundfile 17 | from transformers import Wav2Vec2Processor, HubertForCTC 18 | 19 | def fileList(source): 20 | matches = [] 21 | for root, dirnames, filenames in os.walk(source): 22 | for filename in filenames: 23 | if filename.endswith(('.flac', '.wav')): 24 | matches.append(os.path.join(root, filename)) 25 | return matches 26 | 27 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 28 | print(device) 29 | 30 | processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft") 31 | model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft").to(device) 32 | 33 | tar_path = '/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_text_hubert_large/' 34 | if os.path.exists(tar_path) == False: 35 | os.mkdir(tar_path) 36 | 37 | audio_list = fileList('/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_ready/') 38 | 39 | audio_list.sort() 40 | start_file = int(argument) * 4500 41 | end_file = int(argument) * 4500 + 4500 42 | audio_list = audio_list[start_file: end_file] 43 | 44 | print('number of files to transcribe: ', len(audio_list)) 45 | 46 | for i in range(len(audio_list)): 47 | audio_path = audio_list[i] 48 | if os.path.exists(tar_path + audio_path.split('/')[-1].split('.')[-2] + '.txt') == False: 49 | if audio_path[-3:] == 'wav': 50 | 
audio_path = audio_list[i] 51 | source, curr_sample_rate = soundfile.read(audio_path, dtype="float32") 52 | input_features = processor(source, 53 | sampling_rate=curr_sample_rate, 54 | return_tensors="pt").input_values 55 | logits = model(input_features.to(device)).logits 56 | predicted_ids = torch.argmax(logits, dim=-1) 57 | text = processor.decode(predicted_ids[0]) 58 | with open(tar_path + audio_path.split('/')[-1].split('.')[-2] + '.txt', "w") as text_file: 59 | text_file.write(text) 60 | del logits, input_features 61 | if i % 100 == 0: 62 | print("{:d} / {:d} processd from processor {:s}".format(i, len(audio_list), argument)) 63 | del model -------------------------------------------------------------------------------- /src/noise_robust_asr/asr_experiments/transcribe_wav2vec_base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 1/20/23 2:09 AM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : transcribe_aus.py 7 | 8 | import sys 9 | argument = sys.argv[1] 10 | if argument=='4': 11 | argument='0,1,2,3' 12 | import os 13 | if argument != '-1': 14 | os.environ["CUDA_VISIBLE_DEVICES"]=argument 15 | import torch 16 | import soundfile 17 | from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC 18 | 19 | def fileList(source): 20 | matches = [] 21 | for root, dirnames, filenames in os.walk(source): 22 | for filename in filenames: 23 | if filename.endswith(('.flac', '.wav')): 24 | matches.append(os.path.join(root, filename)) 25 | return matches 26 | 27 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 28 | print(device) 29 | 30 | processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h") 31 | model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to(device) 32 | 33 | tar_path = '/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_text_w2v_base/' 34 | if os.path.exists(tar_path) == False: 35 | os.mkdir(tar_path) 36 | 37 | audio_list = fileList('/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_ready/') 38 | 39 | audio_list.sort() 40 | start_file = int(argument) * 4500 41 | end_file = int(argument) * 4500 + 4500 42 | audio_list = audio_list[start_file: end_file] 43 | 44 | print('number of files to transcribe: ', len(audio_list)) 45 | 46 | for i in range(len(audio_list)): 47 | audio_path = audio_list[i] 48 | if os.path.exists(tar_path + audio_path.split('/')[-1].split('.')[-2] + '.txt') == False: 49 | if audio_path[-3:] == 'wav': 50 | audio_path = audio_list[i] 51 | source, curr_sample_rate = soundfile.read(audio_path, dtype="float32") 52 | input_features = processor(source, 53 | sampling_rate=curr_sample_rate, 54 | return_tensors="pt").input_values 55 | logits = model(input_features.to(device)).logits 56 | predicted_ids = torch.argmax(logits, dim=-1) 57 | text = processor.decode(predicted_ids[0]) 58 | with open(tar_path + audio_path.split('/')[-1].split('.')[-2] + '.txt', "w") as text_file: 59 | text_file.write(text) 60 | del logits, input_features 61 | if i % 100 == 0: 62 | print("{:d} / {:d} processd from processor {:s}".format(i, len(audio_list), argument)) 63 | del model -------------------------------------------------------------------------------- /src/noise_robust_asr/asr_experiments/transcribe_wav2vec_robust.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 1/20/23 2:09 AM 3 | # 
@Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : transcribe_aus.py 7 | 8 | import sys 9 | argument = sys.argv[1] 10 | if argument=='4': 11 | argument='0,1,2,3' 12 | import os 13 | if argument != '-1': 14 | os.environ["CUDA_VISIBLE_DEVICES"]=argument 15 | import torch 16 | import soundfile 17 | from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC 18 | 19 | def fileList(source): 20 | matches = [] 21 | for root, dirnames, filenames in os.walk(source): 22 | for filename in filenames: 23 | if filename.endswith(('.flac', '.wav')): 24 | matches.append(os.path.join(root, filename)) 25 | return matches 26 | 27 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 28 | print(device) 29 | 30 | processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-robust-ft-swbd-300h") 31 | model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-robust-ft-swbd-300h").to(device) 32 | 33 | tar_path = '/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_text_w2v_large_robust/' 34 | if os.path.exists(tar_path) == False: 35 | os.mkdir(tar_path) 36 | 37 | audio_list = fileList('/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_ready/') 38 | 39 | audio_list.sort() 40 | start_file = int(argument) * 4500 41 | end_file = int(argument) * 4500 + 4500 42 | audio_list = audio_list[start_file: end_file] 43 | 44 | print('number of files to transcribe: ', len(audio_list)) 45 | 46 | for i in range(len(audio_list)): 47 | audio_path = audio_list[i] 48 | if os.path.exists(tar_path + audio_path.split('/')[-1].split('.')[-2] + '.txt') == False: 49 | if audio_path[-3:] == 'wav': 50 | audio_path = audio_list[i] 51 | source, curr_sample_rate = soundfile.read(audio_path, dtype="float32") 52 | input_features = processor(source, 53 | sampling_rate=curr_sample_rate, 54 | return_tensors="pt").input_values 55 | logits = model(input_features.to(device)).logits 56 | predicted_ids = torch.argmax(logits, dim=-1) 57 | text = processor.decode(predicted_ids[0]) 58 | with open(tar_path + audio_path.split('/')[-1].split('.')[-2] + '.txt', "w") as text_file: 59 | text_file.write(text) 60 | del logits, input_features 61 | if i % 100 == 0: 62 | print("{:d} / {:d} processd from processor {:s}".format(i, len(audio_list), argument)) 63 | del model -------------------------------------------------------------------------------- /src/noise_robust_asr/asr_experiments/transcribe_whisper.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 1/20/23 2:09 AM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : transcribe_aus.py 7 | 8 | # transcribe adress-m datasets 9 | import sys 10 | argument = sys.argv[1] 11 | if argument=='4': 12 | argument='0,1,2,3' 13 | import os 14 | if argument != '-1': 15 | os.environ["CUDA_VISIBLE_DEVICES"]=argument 16 | 17 | os.environ["XDG_CACHE_HOME"] = './' 18 | import ssl 19 | ssl._create_default_https_context = ssl._create_unverified_context 20 | import torch 21 | import whisper 22 | import torchaudio 23 | 24 | def show_twod(input): 25 | return "{:.2f}".format(input) 26 | 27 | def get_immediate_files(a_dir): 28 | return [a_dir + '/' + name for name in os.listdir(a_dir) if os.path.isfile(os.path.join(a_dir, name))] 29 | 30 | def fileList(source): 31 | matches = [] 32 | for root, dirnames, filenames in os.walk(source): 33 | for filename in filenames: 34 | if 
filename.endswith(('.flac', '.wav')): 35 | matches.append(os.path.join(root, filename)) 36 | return matches 37 | 38 | audio_list = fileList('/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_ready/') 39 | 40 | audio_list.sort() 41 | 42 | print('number of files to transcribe: ', len(audio_list)) 43 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 44 | print(device) 45 | 46 | whisper_mdl_list = ['base.en'] # medium.en, small.en, base.en, tiny.en, medium, large-v1 47 | 48 | for mdl_size in whisper_mdl_list: 49 | model = whisper.load_model(mdl_size, device) 50 | for beam_size in [0]: 51 | tar_path = '/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_text_whisper_{:s}/'.format(mdl_size) 52 | if os.path.exists(tar_path) == False: 53 | os.mkdir(tar_path) 54 | for i in range(len(audio_list)): 55 | audio_path = audio_list[i] 56 | if os.path.exists(tar_path + audio_path.split('/')[-1].split('.')[-2] + '.txt') == False: 57 | if audio_path[-3:] == 'wav': 58 | ori_waveform, sr = torchaudio.load(audio_path) 59 | wav_len = ori_waveform.shape[1] 60 | assert sr == 16000 61 | if beam_size == 0: 62 | result = model.transcribe(audio_path, language='en') 63 | else: 64 | result = model.transcribe(audio_path, beam_size=beam_size) 65 | 66 | # remove the first space 67 | text = result["text"][1:] 68 | if os.path.exists(tar_path) == False: 69 | os.mkdir(tar_path) 70 | with open(tar_path + audio_path.split('/')[-1].split('.')[-2] + '.txt', "w") as text_file: 71 | text_file.write(text) 72 | if i % 100 == 0: 73 | print("{:d} / {:d} processed from processor {:s}".format(i, len(audio_list), argument)) 74 | del model 75 | -------------------------------------------------------------------------------- /src/noise_robust_asr/baseline_sound_classification.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 1/20/23 1:35 AM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : cluster_esc50_feat2.py 7 | 8 | # use the new all-layer features; note that these are already pooled over time 9 | # for whisper, w2v, and hubert models 10 | 11 | import json 12 | import os 13 | os.environ["XDG_CACHE_HOME"] = './' 14 | import numpy as np 15 | 16 | from sklearn.pipeline import Pipeline 17 | from sklearn.metrics import accuracy_score, classification_report 18 | 19 | from sklearn.preprocessing import StandardScaler 20 | from sklearn.neural_network import MLPClassifier 21 | 22 | def cluster_feat(dataset_json_file, tar_path): 23 | with open(dataset_json_file, 'r') as fp: 24 | data_json = json.load(fp) 25 | data = data_json['data'] 26 | num_sample = len(data) 27 | for idx, entry in enumerate(data): 28 | wav = entry["wav"] 29 | # the first sample 30 | if idx == 0: 31 | cur_sample = np.load(tar_path + '/' + wav.split('/')[-1][:-3] + 'npy') 32 | num_layer = cur_sample.shape[0] 33 | feat_dim = cur_sample.shape[-1] 34 | print('number of layers {:d} feat dim {:d}'.format(num_layer, feat_dim)) 35 | all_feat = np.zeros((num_layer + 1, num_sample, feat_dim)) 36 | all_label = [] 37 | 38 | cur_rep = np.load(tar_path + '/' + wav.split('/')[-1][:-3] + 'npy') 39 | for layer in range(cur_rep.shape[0]): 40 | all_feat[layer, idx] = np.mean(cur_rep[layer], axis=0) 41 | 42 | all_feat[-1, idx] = np.mean(np.mean(cur_rep, axis=0), axis=0) 43 | 44 | cur_label = int(wav.split('.')[-2].split('-')[-1]) 45 | all_label.append(cur_label) 46 | 47 | assert all_feat[0].shape[0] == len(all_label)
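# all_feat has shape (num_layer + 1, num_sample, feat_dim): rows 0..num_layer-1 hold the
# time-pooled feature of each layer, and the extra last row is the representation averaged
# over all layers, so the linear probe below can compare per-layer and layer-averaged features.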
48 | return all_feat, all_label 49 | 50 | def get_immediate_dir(a_dir): 51 | return [name for name in os.listdir(a_dir) if os.path.isdir(os.path.join(a_dir, name))] 52 | 53 | mdl_size_list = ['wav2vec2-large-robust-ft-swbd-300h'] 54 | 55 | for mdl_size in mdl_size_list: 56 | for fold in range(1, 6): 57 | print(mdl_size) 58 | if 'whisper' not in mdl_size: 59 | tar_path = '/data/sls/scratch/yuangong/whisper-a/feat_all/' + mdl_size + '_all_layer/' 60 | else: 61 | tar_path = '/data/sls/scratch/yuangong/whisper-a/feat_all/' + mdl_size + '/' 62 | esc_train1 = '/data/sls/scratch/yuangong/whisper-a/egs/esc-50/feat_extract/data/datafiles/esc_train_data_' + str(fold) + '.json' 63 | esc_eval1 = '/data/sls/scratch/yuangong/whisper-a/egs/esc-50/feat_extract/data/datafiles/esc_eval_data_' + str(fold) + '.json' 64 | all_tr_feat, all_tr_label = cluster_feat(esc_train1, tar_path) 65 | all_te_feat, all_te_label = cluster_feat(esc_eval1, tar_path) 66 | 67 | num_layer = all_tr_feat.shape[0] 68 | print(all_tr_feat.shape, all_te_feat.shape) 69 | 70 | for lr in [0.001]: 71 | all_res = [] 72 | for layer in range(num_layer): 73 | cla = MLPClassifier(hidden_layer_sizes=(), learning_rate='adaptive', learning_rate_init=lr, max_iter=5000, random_state=0) 74 | pipe = Pipeline([('scaler', StandardScaler()), ('svc', cla)]) 75 | pipe.fit(all_tr_feat[layer], all_tr_label) 76 | pred = pipe.predict(all_te_feat[layer]) 77 | acc = accuracy_score(all_te_label, pred) 78 | all_acc = classification_report(all_te_label, pred, output_dict=True) 79 | all_acc = [all_acc[str(x)]['f1-score'] for x in range(50)] 80 | res = [mdl_size, fold, all_te_feat[0].shape[1], lr, layer, acc] + all_acc 81 | all_res.append(res) 82 | np.savetxt('./baseline_res/esc_{:s}_fold{:d}_lr_{:.4f}.csv'.format(mdl_size, fold, lr), all_res, delimiter=',', fmt='%s') -------------------------------------------------------------------------------- /src/noise_robust_asr/intermediate_feat_extract/as_full/batch_as_full_extract.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # need to batch to speed up processing 4 | 5 | for((fold=0;fold<=39;fold++)); 6 | do 7 | sbatch extract_as_full_whisper_all.sh ${fold} 8 | done -------------------------------------------------------------------------------- /src/noise_robust_asr/intermediate_feat_extract/as_full/extract_as_full_whisper_all.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 1/19/23 11:35 PM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : extract_esc50.py 7 | 8 | # extract representation for all layers for whisper model, pool by 20, not include the input mel. 
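# (with the standard 30 s Whisper window the encoder emits 1500 frames per layer, so the
# block_reduce pooling by 20 along time below keeps roughly 75 frames per layer; exact sizes
# depend on what the customized transcribe_audio call returns)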
9 | # save as npz to save space 10 | 11 | import json 12 | import torch 13 | import os 14 | os.environ["XDG_CACHE_HOME"] = './' 15 | import numpy as np 16 | from whisper.model import Whisper, ModelDimensions 17 | import skimage.measure 18 | import argparse 19 | 20 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 21 | parser.add_argument("--split", type=int, default=0, help="which split") 22 | args = parser.parse_args() 23 | 24 | def extract_audio(dataset_json_file, mdl, tar_path): 25 | if os.path.exists(tar_path) == False: 26 | os.mkdir((tar_path)) 27 | with open(dataset_json_file, 'r') as fp: 28 | data_json = json.load(fp) 29 | data = data_json['data'] 30 | for idx, entry in enumerate(data): 31 | wav = entry["wav"] 32 | 33 | if os.path.exists(tar_path + '/' + wav.split('/')[-1][:-4] + 'npz') == False: 34 | _, audio_rep = mdl.transcribe_audio(wav) 35 | audio_rep = audio_rep[0] 36 | audio_rep = torch.permute(audio_rep, (2, 0, 1)).detach().cpu().numpy() 37 | audio_rep = skimage.measure.block_reduce(audio_rep, (1, 20, 1), np.mean) 38 | audio_rep = audio_rep[1:] 39 | if idx == 0: 40 | print(audio_rep.shape) 41 | np.savez_compressed(tar_path + '/' + wav.split('/')[-1][:-4] + 'npz', audio_rep) 42 | if idx % 50 == 0: 43 | print(idx) 44 | 45 | mdl_size_list = ['medium'] # , 'large-v1', 'medium.en' 46 | mdl_size_list = mdl_size_list[::-1] 47 | for mdl_size in mdl_size_list: 48 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 49 | print(device) 50 | checkpoint_path = '/data/sls/scratch/yuangong/whisper-a/src/{:s}.pt'.format(mdl_size) 51 | checkpoint = torch.load(checkpoint_path, map_location=device) 52 | dims = ModelDimensions(**checkpoint["dims"]) 53 | model = Whisper(dims) 54 | model.load_state_dict(checkpoint["model_state_dict"], strict=False) 55 | model.to(device) 56 | 57 | tar_path = '/data/sls/scratch/yuangong/whisper-a/feat_as_full/' + 'whisper_' + mdl_size + '/' 58 | esc_train1 = '/data/sls/scratch/yuangong/whisper-a/egs/audioset/feat_extract/split_json/{:d}.json'.format(args.split) 59 | extract_audio(esc_train1, model, tar_path) 60 | del model -------------------------------------------------------------------------------- /src/noise_robust_asr/intermediate_feat_extract/as_full/extract_as_full_whisper_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ##SBATCH -p 1080,sm 3 | ##SBATCH -x sls-sm-1,sls-2080-[1,3],sls-1080-[1,2],sls-sm-[1,2,11,13] 4 | #SBATCH -p gpu 5 | #SBATCH --gres=gpu:1 6 | #SBATCH -c 1 7 | #SBATCH -n 1 8 | 9 | #SBATCH --mem=24000 10 | #SBATCH --job-name="as_extract" 11 | #SBATCH --output=../../log/%j_as_extract.txt 12 | 13 | set -x 14 | # comment this line if not running on sls cluster 15 | #. 
/data/sls/scratch/share-201907/slstoolchainrc 16 | #source /data/sls/scratch/yuangong/whisper-a/venv-wa/bin/activate 17 | export TORCH_HOME=../../pretrained_models 18 | 19 | python extract_as_full_whisper_all.py --split $1 -------------------------------------------------------------------------------- /src/noise_robust_asr/intermediate_feat_extract/esc-50/extract_esc50_hubert_xl_all_pool.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 1/19/23 11:35 PM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : extract_esc50.py 7 | 8 | # extract representation for all layers from hubert xl 9 | 10 | import json 11 | import torch 12 | import os 13 | os.environ["XDG_CACHE_HOME"] = './' 14 | import numpy as np 15 | from transformers import Wav2Vec2Processor, HubertModel 16 | import soundfile as sf 17 | import skimage.measure 18 | 19 | def extract_audio(dataset_json_file, model, processor, tar_path): 20 | if os.path.exists(tar_path) == False: 21 | os.mkdir((tar_path)) 22 | with open(dataset_json_file, 'r') as fp: 23 | data_json = json.load(fp) 24 | data = data_json['data'] 25 | for idx, entry in enumerate(data): 26 | wav = entry["wav"] 27 | audio, sr = sf.read(wav) 28 | assert sr == 16000 29 | 30 | input_values = processor(audio, sampling_rate=sr, return_tensors="pt").input_values.to(device) # Batch size 1 31 | audio_rep = model(input_values, output_hidden_states=True).hidden_states 32 | audio_rep = torch.stack(audio_rep, dim=0).squeeze(1) 33 | audio_rep = audio_rep.detach().cpu().numpy() 34 | audio_rep = skimage.measure.block_reduce(audio_rep, (1, 10, 1), np.mean) 35 | audio_rep = audio_rep[1:] 36 | np.savez_compressed(tar_path + '/' + wav.split('/')[-1][:-3] + 'npz', audio_rep) 37 | 38 | mdl_size_list = ['facebook/hubert-xlarge-ls960-ft'] 39 | for mdl_size in mdl_size_list: 40 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 41 | processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-xlarge-ls960-ft") 42 | model = HubertModel.from_pretrained(mdl_size) 43 | model.to(device) 44 | model.eval() 45 | 46 | tar_path = '/data/sls/scratch/yuangong/whisper-a/feat_esc_pool/' + mdl_size.split('/')[-1] + '/' 47 | esc_train1 = '/data/sls/scratch/yuangong/whisper-a/egs/esc-50/feat_extract/data/datafiles/esc_train_data_1.json' 48 | esc_eval1 = '/data/sls/scratch/yuangong/whisper-a/egs/esc-50/feat_extract/data/datafiles/esc_eval_data_1.json' 49 | extract_audio(esc_train1, model, processor, tar_path) 50 | extract_audio(esc_eval1, model, processor, tar_path) 51 | del model -------------------------------------------------------------------------------- /src/noise_robust_asr/intermediate_feat_extract/esc-50/extract_esc50_w2v_robust_all.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 1/19/23 11:35 PM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : extract_esc50.py 7 | 8 | # extract representation for all layers from hubert xl 9 | 10 | import json 11 | import torch 12 | import os 13 | os.environ["XDG_CACHE_HOME"] = './' 14 | import numpy as np 15 | from transformers import Wav2Vec2Processor, Wav2Vec2Model 16 | import soundfile as sf 17 | import skimage.measure 18 | 19 | def extract_audio(dataset_json_file, model, processor, tar_path): 20 | if os.path.exists(tar_path) == False: 
21 | os.mkdir((tar_path)) 22 | with open(dataset_json_file, 'r') as fp: 23 | data_json = json.load(fp) 24 | data = data_json['data'] 25 | for idx, entry in enumerate(data): 26 | wav = entry["wav"] 27 | audio, sr = sf.read(wav) 28 | assert sr == 16000 29 | 30 | input_values = processor(audio, sampling_rate=sr, return_tensors="pt").input_values.to(device) 31 | audio_rep = model(input_values, output_hidden_states=True).hidden_states 32 | audio_rep = torch.stack(audio_rep, dim=0).squeeze(1) 33 | audio_rep = audio_rep.detach().cpu().numpy() 34 | audio_rep = skimage.measure.block_reduce(audio_rep, (1, 10, 1), np.mean) 35 | audio_rep = audio_rep[1:] 36 | np.savez_compressed(tar_path + '/' + wav.split('/')[-1][:-3] + 'npz', audio_rep) 37 | 38 | mdl_size_list = ['facebook/wav2vec2-large-robust-ft-swbd-300h'] 39 | for mdl_size in mdl_size_list: 40 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 41 | processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-robust-ft-swbd-300h") 42 | model = Wav2Vec2Model.from_pretrained(mdl_size) 43 | model.to(device) 44 | model.eval() 45 | 46 | tar_path = '/data/sls/scratch/yuangong/whisper-a/feat_esc_pool/' + mdl_size.split('/')[-1] + '/' 47 | esc_train1 = '/data/sls/scratch/yuangong/whisper-a/egs/esc-50/feat_extract/data/datafiles/esc_train_data_1.json' 48 | esc_eval1 = '/data/sls/scratch/yuangong/whisper-a/egs/esc-50/feat_extract/data/datafiles/esc_eval_data_1.json' 49 | extract_audio(esc_train1, model, processor, tar_path) 50 | extract_audio(esc_eval1, model, processor, tar_path) 51 | del model -------------------------------------------------------------------------------- /src/noise_robust_asr/intermediate_feat_extract/esc-50/extract_esc50_whisper_all_pool.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 1/19/23 11:35 PM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : extract_esc50.py 7 | 8 | # extract representation for all layers for whisper model, pool by 10, not include the input mel. 
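# (np.savez_compressed below stores the array positionally, so it can be read back later
# with np.load(path)['arr_0'])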
9 | # save as npz to save space 10 | 11 | import json 12 | import torch 13 | import os 14 | os.environ["XDG_CACHE_HOME"] = './' 15 | import whisper 16 | import numpy as np 17 | from whisper.model import Whisper, ModelDimensions 18 | import skimage.measure 19 | import argparse 20 | 21 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 22 | parser.add_argument("--split", type=int, default=0, help="which split") 23 | args = parser.parse_args() 24 | 25 | def extract_audio(dataset_json_file, mdl, tar_path): 26 | if os.path.exists(tar_path) == False: 27 | os.mkdir((tar_path)) 28 | with open(dataset_json_file, 'r') as fp: 29 | data_json = json.load(fp) 30 | data = data_json['data'] 31 | for idx, entry in enumerate(data): 32 | wav = entry["wav"] 33 | 34 | if os.path.exists(tar_path + '/' + wav.split('/')[-1][:-3] + 'npz') == False: # match the name written by np.savez_compressed below 35 | # NOTE: this uses a customized whisper model for feature extraction; the original whisper model does not have a transcribe_audio function 36 | _, audio_rep = mdl.transcribe_audio(wav) 37 | audio_rep = audio_rep[0] 38 | audio_rep = torch.permute(audio_rep, (2, 0, 1)).detach().cpu().numpy() 39 | audio_rep = skimage.measure.block_reduce(audio_rep, (1, 10, 1), np.mean) # downsample x10 for esc, 20 for audioset 40 | audio_rep = audio_rep[1:] 41 | np.savez_compressed(tar_path + '/' + wav.split('/')[-1][:-3] + 'npz', audio_rep) 42 | 43 | mdl_size_list = ['large-v2', 'large-v1', 'medium.en', 'medium', 'small.en', 'small', 'base.en', 'base', 'tiny.en', 'tiny'] 44 | for mdl_size in mdl_size_list: 45 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 46 | print(device) 47 | checkpoint_path = '/data/sls/scratch/yuangong/whisper-a/src/{:s}.pt'.format(mdl_size) 48 | checkpoint = torch.load(checkpoint_path, map_location=device) 49 | dims = ModelDimensions(**checkpoint["dims"]) 50 | model = Whisper(dims) # NOTE: this uses a customized whisper model for feature extraction; the original whisper model does not have a transcribe_audio function 51 | model.load_state_dict(checkpoint["model_state_dict"], strict=False) 52 | model.to(device) 53 | 54 | tar_path = '/data/sls/scratch/yuangong/whisper-a/feat_esc_pool/' + 'whisper_' + mdl_size + '/' 55 | esc_train1 = '/data/sls/scratch/yuangong/whisper-a/egs/esc-50/feat_extract/data/datafiles/esc_train_data_1.json' 56 | esc_eval1 = '/data/sls/scratch/yuangong/whisper-a/egs/esc-50/feat_extract/data/datafiles/esc_eval_data_1.json' # esc-50 is 5-fold cross-validation, so the 1st train and eval splits together cover all the data 57 | extract_audio(esc_train1, model, tar_path) 58 | extract_audio(esc_eval1, model, tar_path) 59 | del model -------------------------------------------------------------------------------- /src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - uses: actions-ecosystem/action-regex-match@v2 13 | id: regex-match 14 | with: 15 | text: ${{ github.event.head_commit.message }} 16 | regex: '^Release ([^ ]+)' 17 | - name: Set up Python 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: '3.8' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine 25 | - name: Release 26 | if: ${{ steps.regex-match.outputs.match != '' }}
27 | uses: softprops/action-gh-release@v1 28 | with: 29 | tag_name: v${{ steps.regex-match.outputs.group1 }} 30 | - name: Build and publish 31 | if: ${{ steps.regex-match.outputs.match != '' }} 32 | env: 33 | TWINE_USERNAME: __token__ 34 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 35 | run: | 36 | python setup.py sdist 37 | twine upload dist/* 38 | -------------------------------------------------------------------------------- /src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | on: 3 | push: 4 | branches: 5 | - main 6 | pull_request: 7 | branches: 8 | - main 9 | jobs: 10 | whisper-test: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: ['3.8', '3.9', '3.10'] 15 | pytorch-version: [1.10.2, 1.13.1] 16 | exclude: 17 | - python-version: '3.10' 18 | pytorch-version: 1.10.2 19 | steps: 20 | - uses: conda-incubator/setup-miniconda@v2 21 | - run: conda install -n test ffmpeg python=${{ matrix.python-version }} pytorch=${{ matrix.pytorch-version }} cpuonly -c pytorch 22 | - uses: actions/checkout@v2 23 | - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH 24 | - run: pip install pytest 25 | - run: pip install . 26 | - run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' 27 | -------------------------------------------------------------------------------- /src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 OpenAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements.txt 2 | include README.md 3 | include LICENSE 4 | include whisper/assets/* 5 | include whisper/assets/gpt2/* 6 | include whisper/assets/multilingual/* 7 | include whisper/normalizers/english.json 8 | -------------------------------------------------------------------------------- /src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/data/README.md: -------------------------------------------------------------------------------- 1 | This directory supplements the paper with more details on how we prepared the data for evaluation, to help replicate our experiments. 2 | 3 | ## Short-form English-only datasets 4 | 5 | ### LibriSpeech 6 | 7 | We used the test-clean and test-other splits from the [LibriSpeech ASR corpus](https://www.openslr.org/12). 8 | 9 | ### TED-LIUM 3 10 | 11 | We used the test split of [TED-LIUM Release 3](https://www.openslr.org/51/), using the segmented manual transcripts included in the release. 12 | 13 | ### Common Voice 5.1 14 | 15 | We downloaded the English subset of Common Voice Corpus 5.1 from [the official website](https://commonvoice.mozilla.org/en/datasets) 16 | 17 | ### Artie 18 | 19 | We used the [Artie bias corpus](https://github.com/artie-inc/artie-bias-corpus). This is a subset of the Common Voice dataset. 20 | 21 | ### CallHome & Switchboard 22 | 23 | We used the two corpora from [LDC2002S09](https://catalog.ldc.upenn.edu/LDC2002S09) and [LDC2002T43](https://catalog.ldc.upenn.edu/LDC2002T43) and followed the [eval2000_data_prep.sh](https://github.com/kaldi-asr/kaldi/blob/master/egs/fisher_swbd/s5/local/eval2000_data_prep.sh) script for preprocessing. The `wav.scp` files can be converted to WAV files with the following bash commands: 24 | 25 | ```bash 26 | mkdir -p wav 27 | while read name cmd; do 28 | echo $name 29 | echo ${cmd/\|/} wav/$name.wav | bash 30 | done < wav.scp 31 | ``` 32 | 33 | 34 | ### WSJ 35 | 36 | We used [LDC93S6B](https://catalog.ldc.upenn.edu/LDC93S6B) and [LDC94S13B](https://catalog.ldc.upenn.edu/LDC94S13B) and followed the [s5 recipe](https://github.com/kaldi-asr/kaldi/tree/master/egs/wsj/s5) to preprocess the dataset. 37 | 38 | ### CORAAL 39 | 40 | We used the 231 interviews from [CORAAL (v. 2021.07)](https://oraal.uoregon.edu/coraal) and used the segmentations from [the FairSpeech project](https://github.com/stanford-policylab/asr-disparities/blob/master/input/CORAAL_transcripts.csv). 41 | 42 | ### CHiME-6 43 | 44 | We downloaded the [CHiME-5 dataset](https://spandh.dcs.shef.ac.uk//chime_challenge/CHiME5/download.html) and followed the stage 0 of the [s5_track1 recipe](https://github.com/kaldi-asr/kaldi/tree/master/egs/chime6/s5_track1) to create the CHiME-6 dataset which fixes synchronization. We then used the binaural recordings (`*_P??.wav`) and the corresponding transcripts. 45 | 46 | ### AMI-IHM, AMI-SDM1 47 | 48 | We preprocessed the [AMI Corpus](https://groups.inf.ed.ac.uk/ami/corpus/overview.shtml) by following the stage 0 ad 2 of the [s5b recipe](https://github.com/kaldi-asr/kaldi/tree/master/egs/ami/s5b). 
49 | 50 | 51 | ## Long-form English-only datasets 52 | 53 | ### TED-LIUM 3 54 | 55 | To create a long-form transcription dataset from the [TED-LIUM3](https://www.openslr.org/51/) dataset, we sliced the audio between the beginning of the first labeled segment and the end of the last labeled segment of each talk, and we used the concatenated text as the label. Below are the timestamps used for slicing each of the 11 TED talks in the test split. 56 | 57 | | Filename | Begin time (s) | End time (s) | 58 | |---------------------|----------------|--------------| 59 | | DanBarber_2010 | 16.09 | 1116.24 | 60 | | JaneMcGonigal_2010 | 15.476 | 1187.61 | 61 | | BillGates_2010 | 15.861 | 1656.94 | 62 | | TomWujec_2010U | 16.26 | 402.17 | 63 | | GaryFlake_2010 | 16.06 | 367.14 | 64 | | EricMead_2009P | 18.434 | 536.44 | 65 | | MichaelSpecter_2010 | 16.11 | 979.312 | 66 | | DanielKahneman_2010 | 15.8 | 1199.44 | 67 | | AimeeMullins_2009P | 17.82 | 1296.59 | 68 | | JamesCameron_2010 | 16.75 | 1010.65 | 69 | | RobertGupta_2010U | 16.8 | 387.03 | 70 | 71 | ### Meanwhile 72 | 73 | This dataset consists of 64 segments from The Late Show with Stephen Colbert. The YouTube video ID, start and end timestamps, and the labels can be found in [meanwhile.json](meanwhile.json). The labels are collected from the closed-caption data for each video and corrected with manual inspection. 74 | 75 | ### Rev16 76 | 77 | We use a subset of 16 files from the 30 podcast episodes in [Rev.AI's Podcast Transcription Benchmark](https://www.rev.ai/blog/podcast-transcription-benchmark-part-1/), after finding that there are multiple cases where a significant portion of the audio and the labels did not match, mostly on the parts introducing the sponsors. We selected 16 episodes that do not have this error, whose "file number" are: 78 | 79 | 3 4 9 10 11 14 17 18 20 21 23 24 26 27 29 32 80 | 81 | ### Kincaid46 82 | 83 | This dataset consists of 46 audio files and the corresponding transcripts compiled in the blog article [Which automatic transcription service is the most accurate - 2018](https://medium.com/descript/which-automatic-transcription-service-is-the-most-accurate-2018-2e859b23ed19) by Jason Kincaid. We used the 46 audio files and reference transcripts from the Airtable widget in the article. 84 | 85 | For the human transcription benchmark in the paper, we use a subset of 25 examples from this data, whose "Ref ID" are: 86 | 87 | 2 4 5 8 9 10 12 13 14 16 19 21 23 25 26 28 29 30 33 35 36 37 42 43 45 88 | 89 | ### Earnings-21, Earnings-22 90 | 91 | For these datasets, we used the files available in [the speech-datasets repository](https://github.com/revdotcom/speech-datasets), as of their `202206` version. 92 | 93 | ### CORAAL 94 | 95 | We used the 231 interviews from [CORAAL (v. 2021.07)](https://oraal.uoregon.edu/coraal) and used the full-length interview files and transcripts. 96 | 97 | 98 | ## Multilingual datasets 99 | 100 | ### Multilingual LibriSpeech 101 | 102 | We used the test splits from each language in [the Multilingual LibriSpeech (MLS) corpus](https://www.openslr.org/94/). 103 | 104 | ### Fleurs 105 | 106 | We collected audio files and transcripts using the implementation available as [HuggingFace datasets](https://huggingface.co/datasets/google/fleurs/blob/main/fleurs.py). To use as a translation dataset, we matched the numerical utterance IDs to find the corresponding transcript in English. 
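For reference, a minimal sketch of loading one FLEURS split with the HuggingFace `datasets` library (this is not part of the original collection code; the `google/fleurs` config and field names below follow the current Hub dataset card):

```python
from datasets import load_dataset

# Example only: pull the English test split of FLEURS from the HuggingFace Hub.
fleurs = load_dataset("google/fleurs", "en_us", split="test")

sample = fleurs[0]
print(sample["id"], sample["transcription"])        # numerical utterance ID and transcript
waveform = sample["audio"]["array"]                 # decoded mono waveform
sampling_rate = sample["audio"]["sampling_rate"]    # 16 kHz for FLEURS
```

The numerical `id` field is what allows matching an utterance in one language to its English counterpart for the translation setup.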
107 | 108 | ### VoxPopuli 109 | 110 | We used the `get_asr_data.py` script from [the official repository](https://github.com/facebookresearch/voxpopuli) to collect the ASR data in 14 languages. 111 | 112 | ### Common Voice 9 113 | 114 | We downloaded the Common Voice Corpus 9 from [the official website](https://commonvoice.mozilla.org/en/datasets) 115 | 116 | ### CoVOST 2 117 | 118 | We collected the `X into English` data collected using [the official repository](https://github.com/facebookresearch/covost). 119 | -------------------------------------------------------------------------------- /src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | torch 3 | tqdm 4 | more-itertools 5 | transformers>=4.19.0 6 | ffmpeg-python==0.2.0 7 | -------------------------------------------------------------------------------- /src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pkg_resources 4 | from setuptools import setup, find_packages 5 | 6 | 7 | def read_version(fname="whisper/version.py"): 8 | exec(compile(open(fname, encoding="utf-8").read(), fname, "exec")) 9 | return locals()["__version__"] 10 | 11 | 12 | setup( 13 | name="openai-whisper", 14 | py_modules=["whisper"], 15 | version=read_version(), 16 | description="Robust Speech Recognition via Large-Scale Weak Supervision", 17 | long_description=open("README.md", encoding="utf-8").read(), 18 | long_description_content_type="text/markdown", 19 | readme="README.md", 20 | python_requires=">=3.7", 21 | author="OpenAI", 22 | url="https://github.com/openai/whisper", 23 | license="MIT", 24 | packages=find_packages(exclude=["tests*"]), 25 | install_requires=[ 26 | str(r) 27 | for r in pkg_resources.parse_requirements( 28 | open(os.path.join(os.path.dirname(__file__), "requirements.txt")) 29 | ) 30 | ], 31 | entry_points={ 32 | "console_scripts": ["whisper=whisper.transcribe:cli"], 33 | }, 34 | include_package_data=True, 35 | extras_require={"dev": ["pytest"]}, 36 | ) 37 | -------------------------------------------------------------------------------- /src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/tests/jfk.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuanGongND/whisper-at/17d94d6acd53866390ce70f95afa13507dcb8aef/src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/tests/jfk.flac -------------------------------------------------------------------------------- /src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/tests/test_audio.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | 3 | import numpy as np 4 | 5 | from whisper.audio import load_audio, log_mel_spectrogram, SAMPLE_RATE 6 | 7 | 8 | def test_audio(): 9 | audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac") 10 | audio = load_audio(audio_path) 11 | assert audio.ndim == 1 12 | assert SAMPLE_RATE * 10 < audio.shape[0] < SAMPLE_RATE * 12 13 | assert 0 < audio.std() < 1 14 | 15 | mel_from_audio = log_mel_spectrogram(audio) 16 | mel_from_file = log_mel_spectrogram(audio_path) 17 | 18 | assert np.allclose(mel_from_audio, mel_from_file) 19 | assert mel_from_audio.max() - mel_from_audio.min() <= 2.0 20 | 
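# To run just this test locally:  pytest tests/test_audio.py -v
# (requires pytest, the ffmpeg CLI on PATH, and the bundled tests/jfk.flac sample)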
-------------------------------------------------------------------------------- /src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/tests/test_normalizer.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from whisper.normalizers import EnglishTextNormalizer 4 | from whisper.normalizers.english import EnglishNumberNormalizer, EnglishSpellingNormalizer 5 | 6 | 7 | @pytest.mark.parametrize("std", [EnglishNumberNormalizer(), EnglishTextNormalizer()]) 8 | def test_number_normalizer(std): 9 | assert std("two") == "2" 10 | assert std("thirty one") == "31" 11 | assert std("five twenty four") == "524" 12 | assert std("nineteen ninety nine") == "1999" 13 | assert std("twenty nineteen") == "2019" 14 | 15 | assert std("two point five million") == "2500000" 16 | assert std("four point two billions") == "4200000000s" 17 | assert std("200 thousand") == "200000" 18 | assert std("200 thousand dollars") == "$200000" 19 | assert std("$20 million") == "$20000000" 20 | assert std("€52.4 million") == "€52400000" 21 | assert std("£77 thousands") == "£77000s" 22 | 23 | assert std("two double o eight") == "2008" 24 | 25 | assert std("three thousand twenty nine") == "3029" 26 | assert std("forty three thousand two hundred sixty") == "43260" 27 | assert std("forty three thousand two hundred and sixty") == "43260" 28 | 29 | assert std("nineteen fifties") == "1950s" 30 | assert std("thirty first") == "31st" 31 | assert std("thirty three thousand and three hundred and thirty third") == "33333rd" 32 | 33 | assert std("three billion") == "3000000000" 34 | assert std("millions") == "1000000s" 35 | 36 | assert std("july third twenty twenty") == "july 3rd 2020" 37 | assert std("august twenty sixth twenty twenty one") == "august 26th 2021" 38 | assert std("3 14") == "3 14" 39 | assert std("3.14") == "3.14" 40 | assert std("3 point 2") == "3.2" 41 | assert std("3 point 14") == "3.14" 42 | assert std("fourteen point 4") == "14.4" 43 | assert std("two point two five dollars") == "$2.25" 44 | assert std("two hundred million dollars") == "$200000000" 45 | assert std("$20.1 million") == "$20100000" 46 | 47 | assert std("ninety percent") == "90%" 48 | assert std("seventy six per cent") == "76%" 49 | 50 | assert std("double oh seven") == "007" 51 | assert std("double zero seven") == "007" 52 | assert std("nine one one") == "911" 53 | assert std("nine double one") == "911" 54 | assert std("one triple oh one") == "10001" 55 | 56 | assert std("two thousandth") == "2000th" 57 | assert std("thirty two thousandth") == "32000th" 58 | 59 | assert std("minus 500") == "-500" 60 | assert std("positive twenty thousand") == "+20000" 61 | 62 | assert std("two dollars and seventy cents") == "$2.70" 63 | assert std("3 cents") == "¢3" 64 | assert std("$0.36") == "¢36" 65 | assert std("three euros and sixty five cents") == "€3.65" 66 | 67 | assert std("three and a half million") == "3500000" 68 | assert std("forty eight and a half dollars") == "$48.5" 69 | assert std("b747") == "b 747" 70 | assert std("10 th") == "10th" 71 | assert std("10th") == "10th" 72 | 73 | 74 | def test_spelling_normalizer(): 75 | std = EnglishSpellingNormalizer() 76 | 77 | assert std("mobilisation") == "mobilization" 78 | assert std("cancelation") == "cancellation" 79 | 80 | 81 | def test_text_normalizer(): 82 | std = EnglishTextNormalizer() 83 | assert std("Let's") == "let us" 84 | assert std("he's like") == "he is like" 85 | assert std("she's been like") == "she has been like" 86 | assert 
std("10km") == "10 km" 87 | assert std("10mm") == "10 mm" 88 | assert std("RC232") == "rc 232" 89 | 90 | assert ( 91 | std("Mr. Park visited Assoc. Prof. Kim Jr.") 92 | == "mister park visited associate professor kim junior" 93 | ) 94 | -------------------------------------------------------------------------------- /src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/tests/test_tokenizer.py: -------------------------------------------------------------------------------- 1 | from whisper.tokenizer import get_tokenizer 2 | 3 | 4 | def test_tokenizer(): 5 | gpt2_tokenizer = get_tokenizer(multilingual=False) 6 | multilingual_tokenizer = get_tokenizer(multilingual=True) 7 | 8 | text = "다람쥐 헌 쳇바퀴에 타고파" 9 | gpt2_tokens = gpt2_tokenizer.encode(text) 10 | multilingual_tokens = multilingual_tokenizer.encode(text) 11 | 12 | assert gpt2_tokenizer.decode(gpt2_tokens) == text 13 | assert multilingual_tokenizer.decode(multilingual_tokens) == text 14 | assert len(gpt2_tokens) > len(multilingual_tokens) 15 | -------------------------------------------------------------------------------- /src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/tests/test_transcribe.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | import torch 5 | 6 | import whisper 7 | 8 | 9 | @pytest.mark.parametrize("model_name", whisper.available_models()) 10 | def test_transcribe(model_name: str): 11 | device = "cuda" if torch.cuda.is_available() else "cpu" 12 | model = whisper.load_model(model_name).to(device) 13 | audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac") 14 | 15 | language = "en" if model_name.endswith(".en") else None 16 | result = model.transcribe(audio_path, language=language, temperature=0.0) 17 | assert result["language"] == "en" 18 | 19 | transcription = result["text"].lower() 20 | assert "my fellow americans" in transcription 21 | assert "your country" in transcription 22 | assert "do for you" in transcription 23 | -------------------------------------------------------------------------------- /src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/whisper/__init__.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import io 3 | import os 4 | import urllib 5 | import warnings 6 | from typing import List, Optional, Union 7 | 8 | import torch 9 | from tqdm import tqdm 10 | 11 | from .audio import load_audio, log_mel_spectrogram, pad_or_trim 12 | from .decoding import DecodingOptions, DecodingResult, decode, detect_language 13 | from .model import Whisper, ModelDimensions 14 | from .transcribe import transcribe 15 | from .transcribe import transcribe_audio 16 | from .version import __version__ 17 | 18 | 19 | _MODELS = { 20 | "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt", 21 | "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt", 22 | "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt", 23 | "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt", 24 | "small.en": 
"https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt", 25 | "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt", 26 | "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt", 27 | "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt", 28 | "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt", 29 | "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt", 30 | "large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt", 31 | } 32 | 33 | 34 | def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]: 35 | os.makedirs(root, exist_ok=True) 36 | 37 | expected_sha256 = url.split("/")[-2] 38 | download_target = os.path.join(root, os.path.basename(url)) 39 | 40 | if os.path.exists(download_target) and not os.path.isfile(download_target): 41 | raise RuntimeError(f"{download_target} exists and is not a regular file") 42 | 43 | if os.path.isfile(download_target): 44 | with open(download_target, "rb") as f: 45 | model_bytes = f.read() 46 | if hashlib.sha256(model_bytes).hexdigest() == expected_sha256: 47 | return model_bytes if in_memory else download_target 48 | else: 49 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 50 | 51 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 52 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 53 | while True: 54 | buffer = source.read(8192) 55 | if not buffer: 56 | break 57 | 58 | output.write(buffer) 59 | loop.update(len(buffer)) 60 | 61 | model_bytes = open(download_target, "rb").read() 62 | if hashlib.sha256(model_bytes).hexdigest() != expected_sha256: 63 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model.") 64 | 65 | return model_bytes if in_memory else download_target 66 | 67 | 68 | def available_models() -> List[str]: 69 | """Returns the names of available models""" 70 | return list(_MODELS.keys()) 71 | 72 | 73 | def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False) -> Whisper: 74 | """ 75 | Load a Whisper ASR model 76 | 77 | Parameters 78 | ---------- 79 | name : str 80 | one of the official model names listed by `whisper.available_models()`, or 81 | path to a model checkpoint containing the model dimensions and the model state_dict. 
82 | device : Union[str, torch.device] 83 | the PyTorch device to put the model into 84 | download_root: str 85 | path to download the model files; by default, it uses "~/.cache/whisper" 86 | in_memory: bool 87 | whether to preload the model weights into host memory 88 | 89 | Returns 90 | ------- 91 | model : Whisper 92 | The Whisper ASR model instance 93 | """ 94 | 95 | if device is None: 96 | device = "cuda" if torch.cuda.is_available() else "cpu" 97 | if download_root is None: 98 | download_root = os.getenv( 99 | "XDG_CACHE_HOME", 100 | os.path.join(os.path.expanduser("~"), ".cache", "whisper") 101 | ) 102 | 103 | if name in _MODELS: 104 | checkpoint_file = _download(_MODELS[name], download_root, in_memory) 105 | elif os.path.isfile(name): 106 | checkpoint_file = open(name, "rb").read() if in_memory else name 107 | else: 108 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 109 | 110 | with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp: 111 | checkpoint = torch.load(fp, map_location=device) 112 | del checkpoint_file 113 | 114 | dims = ModelDimensions(**checkpoint["dims"]) 115 | model = Whisper(dims) 116 | model.load_state_dict(checkpoint["model_state_dict"], strict=False) 117 | 118 | return model.to(device) 119 | -------------------------------------------------------------------------------- /src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/whisper/__main__.py: -------------------------------------------------------------------------------- 1 | from .transcribe import cli 2 | 3 | 4 | cli() 5 | -------------------------------------------------------------------------------- /src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/whisper/assets/gpt2/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"} -------------------------------------------------------------------------------- /src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/whisper/assets/gpt2/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | {"unk_token": "<|endoftext|>", "bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "gpt2", "tokenizer_class": "GPT2Tokenizer"} -------------------------------------------------------------------------------- /src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/whisper/assets/mel_filters.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuanGongND/whisper-at/17d94d6acd53866390ce70f95afa13507dcb8aef/src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/whisper/assets/mel_filters.npz -------------------------------------------------------------------------------- /src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/whisper/assets/multilingual/added_tokens.json: -------------------------------------------------------------------------------- 1 | {"<|endoftext|>": 50257} 2 | -------------------------------------------------------------------------------- /src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/whisper/assets/multilingual/special_tokens_map.json: 
-------------------------------------------------------------------------------- 1 | {"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"} -------------------------------------------------------------------------------- /src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/whisper/assets/multilingual/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | {"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "multilingual", "errors": "replace", "tokenizer_class": "GPT2Tokenizer"} -------------------------------------------------------------------------------- /src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/whisper/audio.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import lru_cache 3 | from typing import Union 4 | 5 | import ffmpeg 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from .utils import exact_div 11 | 12 | # hard-coded audio hyperparameters 13 | SAMPLE_RATE = 16000 14 | N_FFT = 400 15 | N_MELS = 80 16 | HOP_LENGTH = 160 17 | CHUNK_LENGTH = 30 18 | N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk 19 | N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input 20 | 21 | 22 | def load_audio(file: str, sr: int = SAMPLE_RATE): 23 | """ 24 | Open an audio file and read as mono waveform, resampling as necessary 25 | 26 | Parameters 27 | ---------- 28 | file: str 29 | The audio file to open 30 | 31 | sr: int 32 | The sample rate to resample the audio if necessary 33 | 34 | Returns 35 | ------- 36 | A NumPy array containing the audio waveform, in float32 dtype. 37 | """ 38 | try: 39 | # This launches a subprocess to decode audio while down-mixing and resampling as necessary. 40 | # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed. 41 | out, _ = ( 42 | ffmpeg.input(file, threads=0) 43 | .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr) 44 | .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True) 45 | ) 46 | except ffmpeg.Error as e: 47 | raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e 48 | 49 | return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 50 | 51 | 52 | def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): 53 | """ 54 | Pad or trim the audio array to N_SAMPLES, as expected by the encoder. 
55 | """ 56 | if torch.is_tensor(array): 57 | if array.shape[axis] > length: 58 | array = array.index_select(dim=axis, index=torch.arange(length, device=array.device)) 59 | 60 | if array.shape[axis] < length: 61 | pad_widths = [(0, 0)] * array.ndim 62 | pad_widths[axis] = (0, length - array.shape[axis]) 63 | array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes]) 64 | else: 65 | if array.shape[axis] > length: 66 | array = array.take(indices=range(length), axis=axis) 67 | 68 | if array.shape[axis] < length: 69 | pad_widths = [(0, 0)] * array.ndim 70 | pad_widths[axis] = (0, length - array.shape[axis]) 71 | array = np.pad(array, pad_widths) 72 | 73 | return array 74 | 75 | 76 | @lru_cache(maxsize=None) 77 | def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor: 78 | """ 79 | load the mel filterbank matrix for projecting STFT into a Mel spectrogram. 80 | Allows decoupling librosa dependency; saved using: 81 | 82 | np.savez_compressed( 83 | "mel_filters.npz", 84 | mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), 85 | ) 86 | """ 87 | assert n_mels == 80, f"Unsupported n_mels: {n_mels}" 88 | with np.load(os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")) as f: 89 | return torch.from_numpy(f[f"mel_{n_mels}"]).to(device) 90 | 91 | 92 | def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS): 93 | """ 94 | Compute the log-Mel spectrogram of 95 | 96 | Parameters 97 | ---------- 98 | audio: Union[str, np.ndarray, torch.Tensor], shape = (*) 99 | The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz 100 | 101 | n_mels: int 102 | The number of Mel-frequency filters, only 80 is supported 103 | 104 | Returns 105 | ------- 106 | torch.Tensor, shape = (80, n_frames) 107 | A Tensor that contains the Mel spectrogram 108 | """ 109 | if not torch.is_tensor(audio): 110 | if isinstance(audio, str): 111 | audio = load_audio(audio) 112 | audio = torch.from_numpy(audio) 113 | 114 | window = torch.hann_window(N_FFT).to(audio.device) 115 | stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) 116 | magnitudes = stft[..., :-1].abs() ** 2 117 | 118 | filters = mel_filters(audio.device, n_mels) 119 | mel_spec = filters @ magnitudes 120 | 121 | log_spec = torch.clamp(mel_spec, min=1e-10).log10() 122 | log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) 123 | log_spec = (log_spec + 4.0) / 4.0 124 | return log_spec 125 | -------------------------------------------------------------------------------- /src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/whisper/normalizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .basic import BasicTextNormalizer 2 | from .english import EnglishTextNormalizer 3 | -------------------------------------------------------------------------------- /src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/whisper/normalizers/basic.py: -------------------------------------------------------------------------------- 1 | import re 2 | import unicodedata 3 | 4 | import regex 5 | 6 | # non-ASCII letters that are not separated by "NFKD" normalization 7 | ADDITIONAL_DIACRITICS = { 8 | "œ": "oe", 9 | "Œ": "OE", 10 | "ø": "o", 11 | "Ø": "O", 12 | "æ": "ae", 13 | "Æ": "AE", 14 | "ß": "ss", 15 | "ẞ": "SS", 16 | "đ": "d", 17 | "Đ": "D", 18 | "ð": "d", 19 | "Ð": "D", 20 | "þ": "th", 21 | "Þ": "th", 22 | "ł": "l", 23 | "Ł": "L", 24 | } 25 | 26 | 27 | def 
remove_symbols_and_diacritics(s: str, keep=""): 28 | """ 29 | Replace any other markers, symbols, and punctuations with a space, 30 | and drop any diacritics (category 'Mn' and some manual mappings) 31 | """ 32 | return "".join( 33 | c 34 | if c in keep 35 | else ADDITIONAL_DIACRITICS[c] 36 | if c in ADDITIONAL_DIACRITICS 37 | else "" 38 | if unicodedata.category(c) == "Mn" 39 | else " " 40 | if unicodedata.category(c)[0] in "MSP" 41 | else c 42 | for c in unicodedata.normalize("NFKD", s) 43 | ) 44 | 45 | 46 | def remove_symbols(s: str): 47 | """ 48 | Replace any other markers, symbols, punctuations with a space, keeping diacritics 49 | """ 50 | return "".join( 51 | " " if unicodedata.category(c)[0] in "MSP" else c for c in unicodedata.normalize("NFKC", s) 52 | ) 53 | 54 | 55 | class BasicTextNormalizer: 56 | def __init__(self, remove_diacritics: bool = False, split_letters: bool = False): 57 | self.clean = remove_symbols_and_diacritics if remove_diacritics else remove_symbols 58 | self.split_letters = split_letters 59 | 60 | def __call__(self, s: str): 61 | s = s.lower() 62 | s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets 63 | s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis 64 | s = self.clean(s).lower() 65 | 66 | if self.split_letters: 67 | s = " ".join(regex.findall(r"\X", s, regex.U)) 68 | 69 | s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space 70 | 71 | return s 72 | -------------------------------------------------------------------------------- /src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/whisper/tokenizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | from functools import lru_cache 4 | from typing import List, Optional, Tuple, Union 5 | 6 | import numpy as np 7 | import torch 8 | from transformers import GPT2TokenizerFast 9 | 10 | LANGUAGES = { 11 | "en": "english", 12 | "zh": "chinese", 13 | "de": "german", 14 | "es": "spanish", 15 | "ru": "russian", 16 | "ko": "korean", 17 | "fr": "french", 18 | "ja": "japanese", 19 | "pt": "portuguese", 20 | "tr": "turkish", 21 | "pl": "polish", 22 | "ca": "catalan", 23 | "nl": "dutch", 24 | "ar": "arabic", 25 | "sv": "swedish", 26 | "it": "italian", 27 | "id": "indonesian", 28 | "hi": "hindi", 29 | "fi": "finnish", 30 | "vi": "vietnamese", 31 | "he": "hebrew", 32 | "uk": "ukrainian", 33 | "el": "greek", 34 | "ms": "malay", 35 | "cs": "czech", 36 | "ro": "romanian", 37 | "da": "danish", 38 | "hu": "hungarian", 39 | "ta": "tamil", 40 | "no": "norwegian", 41 | "th": "thai", 42 | "ur": "urdu", 43 | "hr": "croatian", 44 | "bg": "bulgarian", 45 | "lt": "lithuanian", 46 | "la": "latin", 47 | "mi": "maori", 48 | "ml": "malayalam", 49 | "cy": "welsh", 50 | "sk": "slovak", 51 | "te": "telugu", 52 | "fa": "persian", 53 | "lv": "latvian", 54 | "bn": "bengali", 55 | "sr": "serbian", 56 | "az": "azerbaijani", 57 | "sl": "slovenian", 58 | "kn": "kannada", 59 | "et": "estonian", 60 | "mk": "macedonian", 61 | "br": "breton", 62 | "eu": "basque", 63 | "is": "icelandic", 64 | "hy": "armenian", 65 | "ne": "nepali", 66 | "mn": "mongolian", 67 | "bs": "bosnian", 68 | "kk": "kazakh", 69 | "sq": "albanian", 70 | "sw": "swahili", 71 | "gl": "galician", 72 | "mr": "marathi", 73 | "pa": "punjabi", 74 | "si": "sinhala", 75 | "km": "khmer", 76 | "sn": "shona", 77 | "yo": "yoruba", 78 | "so": "somali", 79 | "af": "afrikaans", 80 | "oc": "occitan", 81 | "ka": "georgian", 82 | 
"be": "belarusian", 83 | "tg": "tajik", 84 | "sd": "sindhi", 85 | "gu": "gujarati", 86 | "am": "amharic", 87 | "yi": "yiddish", 88 | "lo": "lao", 89 | "uz": "uzbek", 90 | "fo": "faroese", 91 | "ht": "haitian creole", 92 | "ps": "pashto", 93 | "tk": "turkmen", 94 | "nn": "nynorsk", 95 | "mt": "maltese", 96 | "sa": "sanskrit", 97 | "lb": "luxembourgish", 98 | "my": "myanmar", 99 | "bo": "tibetan", 100 | "tl": "tagalog", 101 | "mg": "malagasy", 102 | "as": "assamese", 103 | "tt": "tatar", 104 | "haw": "hawaiian", 105 | "ln": "lingala", 106 | "ha": "hausa", 107 | "ba": "bashkir", 108 | "jw": "javanese", 109 | "su": "sundanese", 110 | } 111 | 112 | # language code lookup by name, with a few language aliases 113 | TO_LANGUAGE_CODE = { 114 | **{language: code for code, language in LANGUAGES.items()}, 115 | "burmese": "my", 116 | "valencian": "ca", 117 | "flemish": "nl", 118 | "haitian": "ht", 119 | "letzeburgesch": "lb", 120 | "pushto": "ps", 121 | "panjabi": "pa", 122 | "moldavian": "ro", 123 | "moldovan": "ro", 124 | "sinhalese": "si", 125 | "castilian": "es", 126 | } 127 | 128 | 129 | @dataclass(frozen=True) 130 | class Tokenizer: 131 | """A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens""" 132 | 133 | tokenizer: "GPT2TokenizerFast" 134 | language: Optional[str] 135 | sot_sequence: Tuple[int] 136 | 137 | def encode(self, text, **kwargs): 138 | return self.tokenizer.encode(text, **kwargs) 139 | 140 | def decode(self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs): 141 | return self.tokenizer.decode(token_ids, **kwargs) 142 | 143 | def decode_with_timestamps(self, tokens) -> str: 144 | """ 145 | Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. 146 | This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>". 
147 | """ 148 | outputs = [[]] 149 | for token in tokens: 150 | if token >= self.timestamp_begin: 151 | timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>" 152 | outputs.append(timestamp) 153 | outputs.append([]) 154 | else: 155 | outputs[-1].append(token) 156 | outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs] 157 | return "".join(outputs) 158 | 159 | @property 160 | @lru_cache() 161 | def eot(self) -> int: 162 | return self.tokenizer.eos_token_id 163 | 164 | @property 165 | @lru_cache() 166 | def sot(self) -> int: 167 | return self._get_single_token_id("<|startoftranscript|>") 168 | 169 | @property 170 | @lru_cache() 171 | def sot_lm(self) -> int: 172 | return self._get_single_token_id("<|startoflm|>") 173 | 174 | @property 175 | @lru_cache() 176 | def sot_prev(self) -> int: 177 | return self._get_single_token_id("<|startofprev|>") 178 | 179 | @property 180 | @lru_cache() 181 | def no_speech(self) -> int: 182 | return self._get_single_token_id("<|nospeech|>") 183 | 184 | @property 185 | @lru_cache() 186 | def no_timestamps(self) -> int: 187 | return self._get_single_token_id("<|notimestamps|>") 188 | 189 | @property 190 | @lru_cache() 191 | def timestamp_begin(self) -> int: 192 | return self.tokenizer.all_special_ids[-1] + 1 193 | 194 | @property 195 | @lru_cache() 196 | def language_token(self) -> int: 197 | """Returns the token id corresponding to the value of the `language` field""" 198 | if self.language is None: 199 | raise ValueError(f"This tokenizer does not have language token configured") 200 | 201 | additional_tokens = dict( 202 | zip( 203 | self.tokenizer.additional_special_tokens, 204 | self.tokenizer.additional_special_tokens_ids, 205 | ) 206 | ) 207 | candidate = f"<|{self.language}|>" 208 | if candidate in additional_tokens: 209 | return additional_tokens[candidate] 210 | 211 | raise KeyError(f"Language {self.language} not found in tokenizer.") 212 | 213 | @property 214 | @lru_cache() 215 | def all_language_tokens(self) -> Tuple[int]: 216 | result = [] 217 | for token, token_id in zip( 218 | self.tokenizer.additional_special_tokens, 219 | self.tokenizer.additional_special_tokens_ids, 220 | ): 221 | if token.strip("<|>") in LANGUAGES: 222 | result.append(token_id) 223 | return tuple(result) 224 | 225 | @property 226 | @lru_cache() 227 | def all_language_codes(self) -> Tuple[str]: 228 | return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens) 229 | 230 | @property 231 | @lru_cache() 232 | def sot_sequence_including_notimestamps(self) -> Tuple[int]: 233 | return tuple(list(self.sot_sequence) + [self.no_timestamps]) 234 | 235 | @property 236 | @lru_cache() 237 | def non_speech_tokens(self) -> Tuple[int]: 238 | """ 239 | Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech 240 | annotations, to prevent sampling texts that are not actually spoken in the audio, e.g. 241 | 242 | - ♪♪♪ 243 | - ( SPEAKING FOREIGN LANGUAGE ) 244 | - [DAVID] Hey there, 245 | 246 | keeping basic punctuations like commas, periods, question marks, exclamation points, etc. 247 | """ 248 | symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』") 249 | symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split() 250 | 251 | # symbols that may be a single token or multiple tokens depending on the tokenizer. 
252 | # In case they're multiple tokens, suppress the first token, which is safe because: 253 | # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress 254 | # in generations, and in the 3-byte UTF-8 representation they share the first two bytes. 255 | miscellaneous = set("♩♪♫♬♭♮♯") 256 | assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous) 257 | 258 | # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word 259 | result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]} 260 | for symbol in symbols + list(miscellaneous): 261 | for tokens in [self.tokenizer.encode(symbol), self.tokenizer.encode(" " + symbol)]: 262 | if len(tokens) == 1 or symbol in miscellaneous: 263 | result.add(tokens[0]) 264 | 265 | return tuple(sorted(result)) 266 | 267 | def _get_single_token_id(self, text) -> int: 268 | tokens = self.tokenizer.encode(text) 269 | assert len(tokens) == 1, f"{text} is not encoded as a single token" 270 | return tokens[0] 271 | 272 | 273 | @lru_cache(maxsize=None) 274 | def build_tokenizer(name: str = "gpt2"): 275 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 276 | path = os.path.join(os.path.dirname(__file__), "assets", name) 277 | tokenizer = GPT2TokenizerFast.from_pretrained(path) 278 | 279 | specials = [ 280 | "<|startoftranscript|>", 281 | *[f"<|{lang}|>" for lang in LANGUAGES.keys()], 282 | "<|translate|>", 283 | "<|transcribe|>", 284 | "<|startoflm|>", 285 | "<|startofprev|>", 286 | "<|nospeech|>", 287 | "<|notimestamps|>", 288 | ] 289 | 290 | tokenizer.add_special_tokens(dict(additional_special_tokens=specials)) 291 | return tokenizer 292 | 293 | 294 | @lru_cache(maxsize=None) 295 | def get_tokenizer( 296 | multilingual: bool, 297 | *, 298 | task: Optional[str] = None, # Literal["transcribe", "translate", None] 299 | language: Optional[str] = None, 300 | ) -> Tokenizer: 301 | if language is not None: 302 | language = language.lower() 303 | if language not in LANGUAGES: 304 | if language in TO_LANGUAGE_CODE: 305 | language = TO_LANGUAGE_CODE[language] 306 | else: 307 | raise ValueError(f"Unsupported language: {language}") 308 | 309 | if multilingual: 310 | tokenizer_name = "multilingual" 311 | task = task or "transcribe" 312 | language = language or "en" 313 | else: 314 | tokenizer_name = "gpt2" 315 | task = None 316 | language = None 317 | 318 | tokenizer = build_tokenizer(name=tokenizer_name) 319 | all_special_ids: List[int] = tokenizer.all_special_ids 320 | sot: int = all_special_ids[1] 321 | translate: int = all_special_ids[-6] 322 | transcribe: int = all_special_ids[-5] 323 | 324 | langs = tuple(LANGUAGES.keys()) 325 | sot_sequence = [sot] 326 | if language is not None: 327 | sot_sequence.append(sot + 1 + langs.index(language)) 328 | if task is not None: 329 | sot_sequence.append(transcribe if task == "transcribe" else translate) 330 | 331 | return Tokenizer(tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence)) 332 | -------------------------------------------------------------------------------- /src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/whisper/utils.py: -------------------------------------------------------------------------------- 1 | import zlib 2 | from typing import Iterator, TextIO 3 | 4 | 5 | def exact_div(x, y): 6 | assert x % y == 0 7 | return x // y 8 | 9 | 10 | def str2bool(string): 11 | str2val = {"True": True, "False": False} 12 | if string in str2val: 13 | return str2val[string] 14 | else: 15 | raise ValueError(f"Expected 
one of {set(str2val.keys())}, got {string}") 16 | 17 | 18 | def optional_int(string): 19 | return None if string == "None" else int(string) 20 | 21 | 22 | def optional_float(string): 23 | return None if string == "None" else float(string) 24 | 25 | 26 | def compression_ratio(text) -> float: 27 | text_bytes = text.encode("utf-8") 28 | return len(text_bytes) / len(zlib.compress(text_bytes)) 29 | 30 | 31 | def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'): 32 | assert seconds >= 0, "non-negative timestamp expected" 33 | milliseconds = round(seconds * 1000.0) 34 | 35 | hours = milliseconds // 3_600_000 36 | milliseconds -= hours * 3_600_000 37 | 38 | minutes = milliseconds // 60_000 39 | milliseconds -= minutes * 60_000 40 | 41 | seconds = milliseconds // 1_000 42 | milliseconds -= seconds * 1_000 43 | 44 | hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else "" 45 | return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" 46 | 47 | 48 | def write_txt(transcript: Iterator[dict], file: TextIO): 49 | for segment in transcript: 50 | print(segment['text'].strip(), file=file, flush=True) 51 | 52 | 53 | def write_vtt(transcript: Iterator[dict], file: TextIO): 54 | print("WEBVTT\n", file=file) 55 | for segment in transcript: 56 | print( 57 | f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n" 58 | f"{segment['text'].strip().replace('-->', '->')}\n", 59 | file=file, 60 | flush=True, 61 | ) 62 | 63 | 64 | def write_srt(transcript: Iterator[dict], file: TextIO): 65 | """ 66 | Write a transcript to a file in SRT format. 67 | 68 | Example usage: 69 | from pathlib import Path 70 | from whisper.utils import write_srt 71 | 72 | result = transcribe(model, audio_path, temperature=temperature, **args) 73 | 74 | # save SRT 75 | audio_basename = Path(audio_path).stem 76 | with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt: 77 | write_srt(result["segments"], file=srt) 78 | """ 79 | for i, segment in enumerate(transcript, start=1): 80 | # write srt lines 81 | print( 82 | f"{i}\n" 83 | f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> " 84 | f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n" 85 | f"{segment['text'].strip().replace('-->', '->')}\n", 86 | file=file, 87 | flush=True, 88 | ) 89 | -------------------------------------------------------------------------------- /src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/whisper/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "20230117" 2 | -------------------------------------------------------------------------------- /src/noise_robust_asr/plots/plot_figure1_lower.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2/13/23 3:40 AM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : get_summary.py 7 | 8 | # summarize the results of esc experiments 9 | 10 | import json 11 | import os 12 | os.environ["XDG_CACHE_HOME"] = './' 13 | import numpy as np 14 | from matplotlib import pyplot as plt 15 | 16 | mdl_size_list = ['whisper_large-v1', 17 | 'hubert-xlarge-ll60k', 18 | 'hubert-xlarge-ls960-ft', 19 | 'wav2vec2-large-robust', 20 | 'wav2vec2-large-robust-ft-swbd-300h', 21 | 'hubert-large-ls960-ft', 22 | 
'wav2vec2-base-960h'] 23 | 24 | legend_list = ['Whisper-Large', 'Hubert-XLarge-PR', 'Hubert-XLarge-FT', 'Wav2vec2-Large-Robust-PR', 'Wav2vec2-Large-Robust-FT', 'Hubert-Large-FT', 'Wav2vec2-Base-FT'] 25 | 26 | for i, mdl_size in enumerate(mdl_size_list): 27 | all_res = [] 28 | for fold in range(1, 6): 29 | for lr in [0.001]: 30 | cur_res = np.loadtxt('./baseline_res/esc_{:s}_fold{:d}_lr_{:.4f}.csv'.format(mdl_size, fold, lr), delimiter=',', usecols=(5)).tolist() 31 | all_res.append(cur_res) 32 | all_res = np.array(all_res) 33 | all_res = np.mean(all_res, axis=0)[1:-1] * 100 34 | print(all_res.shape) 35 | num_layer = all_res.shape[0] 36 | if i == 0: # whisper 37 | plt.plot(list(range(1, num_layer+1)), all_res, '-o', label = legend_list[i], linewidth=2) 38 | elif i == 1: 39 | plt.plot(list(range(1, num_layer + 1)), all_res, 'g-', label=legend_list[i], linewidth=2, alpha=0.5) 40 | elif i == 2: 41 | plt.plot(list(range(1, num_layer + 1)), all_res, 'g-x', label=legend_list[i], linewidth=2) 42 | elif i == 3: 43 | plt.plot(list(range(1, num_layer + 1)), all_res, 'c-', label=legend_list[i], linewidth=2, alpha=0.5) 44 | elif i == 4: 45 | plt.plot(list(range(1, num_layer + 1)), all_res, 'c-*', label=legend_list[i], linewidth=2) 46 | elif i == 5: 47 | plt.plot(list(range(1, num_layer + 1)), all_res, '-^', label=legend_list[i], linewidth=2) 48 | elif i == 6: 49 | plt.plot(list(range(1, num_layer + 1)), all_res, 'r-d', label=legend_list[i], linewidth=2) 50 | 51 | plt.ylim([0, 1]) 52 | plt.xlabel('Classifying Using Representation of Layer # as Input', fontsize=13.5) 53 | plt.ylabel('Sound Classification Accuracy (%)', fontsize=14) 54 | plt.legend(fontsize=10) 55 | plt.grid() 56 | plt.ylim([28, 90]) 57 | plt.xlim([0, 50]) 58 | figure = plt.gcf() 59 | figure.set_size_inches(6, 4) 60 | plt.savefig('./formal_plot/result_summary_' + str(lr) + '_cr.pdf', dpi=300, bbox_inches='tight') 61 | plt.close() -------------------------------------------------------------------------------- /src/noise_robust_asr/plots/plot_figure1_upper.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 3/4/23 9:11 PM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : plot_snr.py 7 | import numpy as np 8 | from matplotlib import pyplot as plt 9 | 10 | mdl_list = ['Whisper-Large', 'Hubert-XLarge-FT', 'Wav2vec2-Large-Robust-FT', 'Hubert-Large-FT', 'Wav2vec2-Base-FT'] 11 | 12 | exp_name_list = ['whisper_large-v1', 'hubert_xlarge', 'w2v_large_robust', 'hubert_large', 'w2v_base'] 13 | snr_list = [-20, -15, -10, -5, 0, 5, 10, 15, 20] 14 | 15 | for i in range(len(exp_name_list)): 16 | exp_name = exp_name_list[i] 17 | cur_res = np.loadtxt('/data/sls/scratch/yuangong/whisper-a/src/noisy_exp/results_camera/{:s}.csv'.format(exp_name)) 18 | cur_res = cur_res * 100 19 | print(exp_name, cur_res.shape) 20 | if i == 0: 21 | plt.plot(snr_list, cur_res, '-o', label=mdl_list[i], linewidth=2) 22 | elif i == 1: 23 | plt.plot(snr_list, cur_res, 'g-x', label=mdl_list[i], linewidth=2) 24 | elif i == 2: 25 | plt.plot(snr_list, cur_res, 'c-*', label=mdl_list[i], linewidth=2) 26 | elif i == 3: 27 | plt.plot(snr_list, cur_res, '-^', label=mdl_list[i], linewidth=2) 28 | elif i == 4: 29 | plt.plot(snr_list, cur_res, 'r-d', label=mdl_list[i], linewidth=2) 30 | 31 | plt.xlabel('Signal-to-Noise Ratio (dB)', fontsize=14) 32 | plt.ylabel('Word Error Rate (%)', fontsize=14) 33 | plt.legend(fontsize=10) 34 | 
plt.gca().invert_xaxis() 35 | plt.grid() 36 | figure = plt.gcf() 37 | figure.set_size_inches(6, 2.5) 38 | plt.savefig('./snr_plot_cr.pdf', dpi=300, bbox_inches='tight') 39 | plt.close() -------------------------------------------------------------------------------- /src/noise_robust_asr/plots/plot_figure2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 3/4/23 9:11 PM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : plot_snr.py 7 | import numpy as np 8 | from matplotlib import pyplot as plt 9 | 10 | label_list = np.loadtxt('/data/sls/scratch/yuangong/whisper-a/egs/esc-50/feat_extract/data/esc_class_labels_indices.csv', delimiter=',', dtype=str, skiprows=1, usecols=(2)).tolist() 11 | print(label_list) 12 | 13 | # classes not used in fitting the line 14 | outlier_list = [8, 10, 11, 16, 22, 27, 35, 36, 41, 45, 46, 49] 15 | not_outliear_list = [] 16 | for x in range(50): 17 | if x not in outlier_list: 18 | not_outliear_list.append(x) 19 | 20 | start = 2 # -10 epoch 21 | snr_res = np.loadtxt('/data/sls/scratch/yuangong/whisper-a/src/noisy_exp/results_camera/whisper_large-v1_cla.csv', delimiter=',') 22 | snr_drop = snr_res[start] - snr_res[-1] # from -10 (3nd row, index 2) to 20 (last row) snr 23 | print(snr_drop.shape) 24 | snr_drop = [x*100 for x in snr_drop] 25 | 26 | # sound classification result 27 | all_res = [] 28 | mdl_size = 'whisper_large-v1' 29 | for fold in range(1, 6): 30 | for lr in [0.001]: 31 | cur_res = np.loadtxt('/data/sls/scratch/yuangong/whisper-a/src/baseline_cla/baseline_res/esc_{:s}_fold{:d}_lr_{:.4f}.csv'.format(mdl_size, fold, lr), delimiter=',', usecols=list(range(6,56)))#.tolist() 32 | cur_res = cur_res[-2, :].tolist() # get the none-wa mean of representation, [50,] corresponds to 50 classes, -2 is the last layer out, 0 is the input layer out, -1 is the wa out 33 | all_res.append(cur_res) 34 | all_res = np.array(all_res) # [5, 50] , 5 folds, 50 classes 35 | 36 | sound_res = np.mean(all_res, axis=0) * 100 37 | print(sound_res.shape) 38 | 39 | print(start, 'corr', np.corrcoef(sound_res, snr_drop)[0, 1]) 40 | 41 | b, a = np.polyfit(np.array(sound_res)[not_outliear_list], np.array(snr_drop)[not_outliear_list], deg=1) 42 | 43 | print(start, 'corr', np.corrcoef(np.array(sound_res)[not_outliear_list], np.array(snr_drop)[not_outliear_list])[0, 1]) 44 | 45 | # Create sequence of 100 numbers from 0 to 100 46 | xseq = np.linspace(50, 100, num=50) 47 | 48 | # Plot regression line 49 | plt.plot(xseq, a + b * xseq, '--', lw=2.5, alpha=0.7) 50 | plt.fill_between(xseq, a+22 + (b - 0.465) * xseq, 70, alpha=0.3, color='lightblue') 51 | 52 | plt.scatter(sound_res, snr_drop) 53 | 54 | font_size = 2 55 | for cla_i in range(50): 56 | plt.annotate(label_list[cla_i][1:-1], (sound_res[cla_i], snr_drop[cla_i]), fontsize=font_size) 57 | 58 | plt.ylim([-5, 110]) 59 | plt.xlim([50, 100]) 60 | plt.grid() 61 | plt.gca().invert_yaxis() 62 | plt.xlabel('ESC-50 Class-wise F1-Score') 63 | plt.ylabel('Word Error Rate Increase from 20dB to -10dB SNR (%)') 64 | figure = plt.gcf() 65 | figure.set_size_inches(5, 5) 66 | plt.savefig('./figure_test/snr_plot_cla_{:d}_cr_outlier.pdf'.format(start), dpi=300, bbox_inches='tight') 67 | plt.close() -------------------------------------------------------------------------------- /src/noise_robust_asr/plots/plot_figure3.py: -------------------------------------------------------------------------------- 1 | # 
-*- coding: utf-8 -*- 2 | # @Time : 2/13/23 3:40 AM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : plot_figure3.py 7 | 8 | # hist of best layer for each sound class 9 | 10 | import os 11 | os.environ["XDG_CACHE_HOME"] = './' 12 | import numpy as np 13 | from matplotlib import pyplot as plt 14 | 15 | label_list = np.loadtxt('/data/sls/scratch/yuangong/whisper-a/egs/esc-50/feat_extract/data/esc_class_labels_indices.csv', delimiter=',', dtype=str, skiprows=1, usecols=(2)).tolist() 16 | 17 | all_res = [] 18 | mdl_size = 'whisper_large-v1' 19 | for fold in range(1, 6): 20 | for lr in [0.001]: 21 | cur_res = np.loadtxt('/data/sls/scratch/yuangong/whisper-a/src/baseline_cla/baseline_res/esc_{:s}_fold{:d}_lr_{:.4f}.csv'.format(mdl_size, fold, lr), delimiter=',', usecols=list(range(6, 56))) 22 | cur_res = cur_res[1:-1, :].tolist() # exclude the input and last avg layer 23 | all_res.append(cur_res) 24 | all_res = np.array(all_res) # [5, 50] , 5 folds, 50 classes 25 | sound_res = np.mean(all_res, axis=0) * 100 26 | 27 | best_layer_list = [] 28 | for i in range(sound_res.shape[1]): 29 | best_layer_list.append(np.argmax(sound_res[:, i] + 1)) 30 | 31 | print(best_layer_list) 32 | 33 | plt.hist(best_layer_list, bins=16, histtype ='bar', rwidth=0.7) 34 | plt.xlabel('Representation of Layer', fontsize=14) 35 | plt.ylabel('# Classes', fontsize=14) 36 | plt.xticks(range(1, 33, 2)) 37 | plt.grid() 38 | figure = plt.gcf() 39 | figure.set_size_inches(6, 2) 40 | plt.savefig('./formal_plot/best_layer.pdf', dpi=300, bbox_inches='tight') 41 | plt.close() -------------------------------------------------------------------------------- /src/whisper_at_train/datafiles/README.md: -------------------------------------------------------------------------------- 1 | # AudioSet Datafiles 2 | 3 | We provide the json files used to train and evaluate the model to facilitate reproduction. Each file contains sample ids and the corresponding labels; the actual audio files are not included.
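For reference, the minimal sketch below shows the layout these json files are expected to follow, inferred from how `dataloader_feat.py` parses them (a top-level `data` list whose entries carry a `wav` path and a comma-separated string of label mids from `class_labels_indices.csv`); the file name and entries are illustrative placeholders, not real AudioSet samples.

```python
import json

# Hypothetical toy datafile in the expected layout; only the "data" list,
# the "wav" path (its file stem is used to locate the pre-extracted feature),
# and the comma-separated AudioSet label mids are read by the dataloader.
toy_datafile = {
    "data": [
        {"wav": "/path/to/audio/sample_0001.flac", "labels": "/m/09x0r,/m/04rlf"},
        {"wav": "/path/to/audio/sample_0002.flac", "labels": "/m/04rlf"},
    ]
}

with open("toy_train_data.json", "w") as f:
    json.dump(toy_datafile, f, indent=2)
```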
4 | 5 | [AudioSet-2M Training Json File](https://www.dropbox.com/scl/fi/szjlerzblw17f8d2fykgt/whole_train_data.json?rlkey=dr3rdri1jaql0g9lgfpiyfihg&dl=1) 6 | 7 | [AudioSet-Eval Json File](https://www.dropbox.com/scl/fi/cfu70jnxqgphi9d1nm89w/eval_data.json?rlkey=t7b44qk27iznrrtwl80hnz12q&dl=1) -------------------------------------------------------------------------------- /src/whisper_at_train/dataloader_feat.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 6/19/21 12:23 AM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : dataloader.py 7 | 8 | # modified from: 9 | # Author: David Harwath 10 | # with some functions borrowed from https://github.com/SeanNaren/deepspeech.pytorch 11 | # load from whisper feats 12 | 13 | import csv 14 | import json 15 | import os.path 16 | 17 | import torchaudio 18 | import numpy as np 19 | import torch 20 | import torch.nn.functional 21 | from torch.utils.data import Dataset 22 | import random 23 | from whisper.audio import log_mel_spectrogram, pad_or_trim, load_audio 24 | 25 | def make_index_dict(label_csv): 26 | index_lookup = {} 27 | with open(label_csv, 'r') as f: 28 | csv_reader = csv.DictReader(f) 29 | line_count = 0 30 | for row in csv_reader: 31 | index_lookup[row['mid']] = row['index'] 32 | line_count += 1 33 | return index_lookup 34 | 35 | def make_name_dict(label_csv): 36 | name_lookup = {} 37 | with open(label_csv, 'r') as f: 38 | csv_reader = csv.DictReader(f) 39 | line_count = 0 40 | for row in csv_reader: 41 | name_lookup[row['index']] = row['display_name'] 42 | line_count += 1 43 | return name_lookup 44 | 45 | def lookup_list(index_list, label_csv): 46 | label_list = [] 47 | table = make_name_dict(label_csv) 48 | for item in index_list: 49 | label_list.append(table[item]) 50 | return label_list 51 | 52 | def preemphasis(signal,coeff=0.97): 53 | return np.append(signal[0],signal[1:]-coeff*signal[:-1]) 54 | 55 | class AudiosetDataset(Dataset): 56 | def __init__(self, dataset_json_file, audio_conf, label_csv=None): 57 | self.datapath = dataset_json_file 58 | with open(dataset_json_file, 'r') as fp: 59 | data_json = json.load(fp) 60 | 61 | self.data = data_json['data'] 62 | self.data = self.pro_data(self.data) 63 | print('Dataset has {:d} samples'.format(self.data.shape[0])) 64 | self.num_samples = self.data.shape[0] 65 | self.audio_conf = audio_conf 66 | self.label_smooth = self.audio_conf.get('label_smooth', 0.0) 67 | print('Using Label Smoothing: ' + str(self.label_smooth)) 68 | self.freqm = self.audio_conf.get('freqm', 0) 69 | self.timem = self.audio_conf.get('timem', 0) 70 | print('Using Following Mask: {:d} Freq, {:d} Time'.format(self.audio_conf.get('freqm'), self.audio_conf.get('timem'))) 71 | self.mixup = self.audio_conf.get('mixup', 0) 72 | print('Using Mix-up with Rate {:f}'.format(self.mixup)) 73 | self.dataset = self.audio_conf.get('dataset') 74 | print('Now Process ' + self.dataset) 75 | 76 | self.index_dict = make_index_dict(label_csv) 77 | self.label_num = len(self.index_dict) 78 | print('Number of Classes is {:d}'.format(self.label_num)) 79 | 80 | self.tar_path= self.audio_conf.get('tar_path') 81 | print('Now load features from {:s}'.format(self.tar_path)) 82 | 83 | # change python list to numpy array to avoid memory leak. 
84 | def pro_data(self, data_json): 85 | for i in range(len(data_json)): 86 | data_json[i] = [data_json[i]['wav'], data_json[i]['labels']] 87 | data_np = np.array(data_json, dtype=str) 88 | return data_np 89 | 90 | # reformat numpy data to original json format, make it compatible with old code 91 | def decode_data(self, np_data): 92 | datum = {} 93 | datum['wav'] = np_data[0] 94 | datum['labels'] = np_data[1] 95 | return datum 96 | 97 | def load_rep(self, path): 98 | try: 99 | # if npy file 100 | if path[-3:] == 'npy': 101 | return np.load(path) 102 | elif path[-3:] == 'npz': 103 | return np.load(path)['arr_0'] 104 | except: 105 | print('a missing file', path) 106 | return np.zeros((6, 25, 512)) # should only work for whisper-base model, which has missing file problem 107 | 108 | def _wav2fbank(self, filename, filename2=None, mix_lambda=-1): 109 | if 'feat_as' in self.tar_path or 'feat_esc_pool' in self.tar_path: 110 | fmt = '.npz' 111 | else: 112 | fmt = '.npy' 113 | 114 | tar_path = self.tar_path + '/' 115 | if filename2 == None: 116 | filename = tar_path + '.'.join(filename.split('/')[-1].split('.')[:-1]) + fmt 117 | feat = self.load_rep(filename) 118 | feat = torch.Tensor(feat) 119 | 120 | # 25 is the time length after pooling 121 | if feat.shape[1] < 25: 122 | len_diff = 25 - feat.shape[1] 123 | feat = torch.nn.functional.pad(feat, (0, 0, 0, len_diff)) 124 | else: 125 | feat = feat[:, :25, :] 126 | 127 | else: 128 | filename = tar_path + '.'.join(filename.split('/')[-1].split('.')[:-1]) + fmt 129 | feat = self.load_rep(filename) 130 | feat = torch.Tensor(feat) 131 | 132 | filename2 = tar_path + '.'.join(filename2.split('/')[-1].split('.')[:-1]) + fmt 133 | feat2 = self.load_rep(filename2) 134 | feat2 = torch.Tensor(feat2) 135 | 136 | if feat.shape[1] < 25: 137 | len_diff = 25 - feat.shape[1] 138 | feat = torch.nn.functional.pad(feat, (0, 0, 0, len_diff)) 139 | else: 140 | feat = feat[:, :25, :] 141 | if feat2.shape[1] < 25: 142 | len_diff = 25 - feat2.shape[1] 143 | feat2 = torch.nn.functional.pad(feat2, (0, 0, 0, len_diff)) 144 | else: 145 | feat2 = feat2[:, :25, :] 146 | feat = mix_lambda * feat + (1 - mix_lambda) * feat2 147 | 148 | return feat 149 | 150 | def __getitem__(self, index): 151 | if random.random() < self.mixup: 152 | datum = self.data[index] 153 | datum = self.decode_data(datum) 154 | mix_sample_idx = random.randint(0, self.num_samples-1) 155 | mix_datum = self.data[mix_sample_idx] 156 | mix_datum = self.decode_data(mix_datum) 157 | # get the mixed fbank 158 | mix_lambda = np.random.beta(10, 10) 159 | fbank = self._wav2fbank(datum['wav'], mix_datum['wav'], mix_lambda) 160 | label_indices = np.zeros(self.label_num) + (self.label_smooth / self.label_num) 161 | for label_str in datum['labels'].split(','): 162 | label_indices[int(self.index_dict[label_str])] += mix_lambda * (1.0 - self.label_smooth) 163 | for label_str in mix_datum['labels'].split(','): 164 | label_indices[int(self.index_dict[label_str])] += (1.0 - mix_lambda) * (1.0 - self.label_smooth) 165 | label_indices = torch.FloatTensor(label_indices) 166 | 167 | else: 168 | datum = self.data[index] 169 | datum = self.decode_data(datum) 170 | # label smooth for negative samples, epsilon/label_num 171 | label_indices = np.zeros(self.label_num) + (self.label_smooth / self.label_num) 172 | fbank = self._wav2fbank(datum['wav'], None, 0) 173 | for label_str in datum['labels'].split(','): 174 | label_indices[int(self.index_dict[label_str])] = 1.0 - self.label_smooth 175 | label_indices = torch.FloatTensor(label_indices) 
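        # Worked example with illustrative numbers (the run.py default is label_smooth = 0.0):
        # with label_num = 527 AudioSet classes and label_smooth = 0.1, every negative class
        # starts at 0.1 / 527 ≈ 1.9e-4 and each positive class is set to 1 - 0.1 = 0.9; in the
        # mixup branch above, the two clips' positive targets are additionally scaled by
        # mix_lambda and (1 - mix_lambda) on top of that baseline.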
176 | 177 | # SpecAug, not do for eval set, input feat shape in [25, 1280], i.e. t-f, need to transpose 178 | freqm = torchaudio.transforms.FrequencyMasking(self.freqm) 179 | timem = torchaudio.transforms.TimeMasking(self.timem) 180 | fbank = fbank.transpose(1, 2) 181 | if self.freqm != 0: 182 | fbank = freqm(fbank) 183 | if self.timem != 0: 184 | fbank = timem(fbank) 185 | fbank = fbank.transpose(1, 2) 186 | return fbank, label_indices 187 | 188 | def __len__(self): 189 | return self.num_samples -------------------------------------------------------------------------------- /src/whisper_at_train/gen_weight_file.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 11/17/20 3:22 PM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : gen_weight_file.py 7 | 8 | # gen sample weight = sum(label_weight) for label in all labels of the audio clip, where label_weight is the reciprocal of the total sample count of that class. 9 | # Note audioset is a multi-label dataset 10 | 11 | import argparse 12 | import json 13 | import numpy as np 14 | import sys, os, csv 15 | 16 | def make_index_dict(label_csv): 17 | index_lookup = {} 18 | with open(label_csv, 'r') as f: 19 | csv_reader = csv.DictReader(f) 20 | line_count = 0 21 | for row in csv_reader: 22 | index_lookup[row['mid']] = row['index'] 23 | line_count += 1 24 | return index_lookup 25 | 26 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 27 | parser.add_argument("--data_path", type=str, default='path_to_your_data_file.json', help="the root path of data json file") 28 | 29 | if __name__ == '__main__': 30 | args = parser.parse_args() 31 | data_path = args.data_path 32 | 33 | index_dict = make_index_dict('./class_labels_indices.csv') 34 | label_count = np.zeros(527) 35 | 36 | with open(data_path, 'r', encoding='utf8')as fp: 37 | data = json.load(fp) 38 | data = data['data'] 39 | 40 | for sample in data: 41 | sample_labels = sample['labels'].split(',') 42 | for label in sample_labels: 43 | label_idx = int(index_dict[label]) 44 | label_count[label_idx] = label_count[label_idx] + 1 45 | 46 | # the reason not using 1 is to avoid underflow for majority classes, add small value to avoid underflow 47 | label_weight = 1000.0 / (label_count + 0.01) 48 | #label_weight = 1000.0 / (label_count + 0.00) 49 | sample_weight = np.zeros(len(data)) 50 | 51 | for i, sample in enumerate(data): 52 | sample_labels = sample['labels'].split(',') 53 | for label in sample_labels: 54 | label_idx = int(index_dict[label]) 55 | # summing up the weight of all appeared classes in the sample, note audioset is multiple-label classification 56 | sample_weight[i] += label_weight[label_idx] 57 | np.savetxt(data_path[:-5]+'_weight.csv', sample_weight, delimiter=',') -------------------------------------------------------------------------------- /src/whisper_at_train/models.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 1/26/23 11:13 PM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : high_mdls.py 7 | 8 | # high models 9 | 10 | import numpy as np 11 | import torch 12 | import math 13 | import torch.nn.functional as F 14 | from torch import Tensor 15 | from torch import nn 16 | from whisper.model import ResidualAttentionBlock, Linear 17 | 
18 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 19 | def norm_cdf(x): 20 | # Computes standard normal cumulative distribution function 21 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 22 | 23 | with torch.no_grad(): 24 | # Values are generated by using a truncated uniform distribution and 25 | # then using the inverse CDF for the normal distribution. 26 | # Get upper and lower cdf values 27 | l = norm_cdf((a - mean) / std) 28 | u = norm_cdf((b - mean) / std) 29 | 30 | # Uniformly fill tensor with values from [l, u], then translate to 31 | # [2l-1, 2u-1]. 32 | tensor.uniform_(2 * l - 1, 2 * u - 1) 33 | 34 | # Use inverse cdf transform for normal distribution to get truncated 35 | # standard normal 36 | tensor.erfinv_() 37 | 38 | # Transform to proper mean, std 39 | tensor.mul_(std * math.sqrt(2.)) 40 | tensor.add_(mean) 41 | 42 | # Clamp to ensure it's in the proper range 43 | tensor.clamp_(min=a, max=b) 44 | return tensor 45 | 46 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 47 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 48 | 49 | class TLTR(nn.Module): 50 | def __init__(self, label_dim=527, n_layer=33, rep_dim=1280, mode='basic'): 51 | super().__init__() 52 | self.mode = mode 53 | self.n_layer = n_layer 54 | self.rep_dim = rep_dim 55 | self.label_dim = label_dim 56 | 57 | # (baseline) mean pool over time and layer, and mlp head 58 | if mode == 'mean_mlp' or mode == 'last_mlp': 59 | self.mlp_layer = nn.Sequential(nn.LayerNorm(self.rep_dim), nn.Linear(self.rep_dim, self.label_dim)) 60 | 61 | # (baseline) mean pool over time, and weight average over layers, and mlp head 62 | if mode == 'wa_mlp': 63 | self.mlp_layer = nn.Sequential(nn.LayerNorm(self.rep_dim), nn.Linear(self.rep_dim, self.label_dim)) 64 | self.layer_weight = torch.nn.Parameter(torch.tensor([1 / self.n_layer] * self.n_layer)) 65 | 66 | # (baseline) mean pool over layer, and apply a original rep_dim transformer 67 | if 'mean_tr' in mode or 'last_tr' in mode: 68 | self.num_att_head = int(mode.split('_')[-1]) 69 | self.time_tr = ResidualAttentionBlock(self.rep_dim, self.num_att_head) 70 | self.mlp_layer = nn.Sequential(nn.LayerNorm(self.rep_dim), nn.Linear(self.rep_dim, self.label_dim)) 71 | 72 | # (baseline) weight average over layers, and apply a original rep_dim transformer 73 | if 'wa_tr' in mode: 74 | self.num_att_head = int(mode.split('_')[-1]) 75 | self.layer_weight = torch.nn.Parameter(torch.tensor([1 / self.n_layer] * self.n_layer)) 76 | self.time_tr = ResidualAttentionBlock(self.rep_dim, self.num_att_head) 77 | self.mlp_layer = nn.Sequential(nn.LayerNorm(self.rep_dim), nn.Linear(self.rep_dim, self.label_dim)) 78 | 79 | # (baseline) weight average over layers, and apply a low-dimensional transformer 80 | if 'wa_down_tr' in mode: # 512_1 81 | self.inter_rep_dim = int(mode.split('_')[-2]) 82 | self.num_att_head = int(mode.split('_')[-1]) 83 | 84 | self.down_layer = nn.Sequential(nn.LayerNorm(self.rep_dim), nn.Linear(self.rep_dim, self.inter_rep_dim)) 85 | self.layer_weight = torch.nn.Parameter(torch.tensor([1 / self.n_layer] * self.n_layer)) 86 | self.time_tr = ResidualAttentionBlock(self.inter_rep_dim, self.num_att_head) 87 | self.mlp_layer = nn.Sequential(nn.LayerNorm(self.inter_rep_dim), nn.Linear(self.inter_rep_dim, self.label_dim)) 88 | 89 | # (proposed), tl-tr, weight average over layers, and apply a original rep_dim transformer 90 | if 'lw_tr' in mode: 91 | self.num_tatt_head = int(mode.split('_')[-2]) 92 | self.num_latt_head = int(mode.split('_')[-1]) 93 | self.time_tr = 
ResidualAttentionBlock(self.rep_dim, self.num_tatt_head) 94 | self.layer_tr = ResidualAttentionBlock(self.rep_dim, self.num_latt_head) 95 | self.mlp_layer = nn.Sequential(nn.LayerNorm(self.rep_dim), nn.Linear(self.rep_dim, self.label_dim)) 96 | 97 | # (proposed), tl-tr with low-dimension projection, lower the dimension of the transformer # lw_down_tr_512_1_8 98 | if 'lw_down_tr' in mode: 99 | self.inter_rep_dim = int(mode.split('_')[-3]) 100 | self.num_tatt_head = int(mode.split('_')[-2]) 101 | self.num_latt_head = int(mode.split('_')[-1]) 102 | 103 | self.down_layer = nn.Sequential(nn.LayerNorm(self.rep_dim), nn.Linear(self.rep_dim, self.inter_rep_dim)) 104 | self.time_tr = ResidualAttentionBlock(self.inter_rep_dim, self.num_tatt_head) 105 | self.layer_tr = ResidualAttentionBlock(self.inter_rep_dim, self.num_latt_head) 106 | self.mlp_layer = nn.Sequential(nn.LayerNorm(self.inter_rep_dim), nn.Linear(self.inter_rep_dim, self.label_dim)) 107 | 108 | def forward(self, audio_rep): 109 | # audio_rep in shape (# batch size, #whisper_enc_layer, time length after (20x) pooling, whisper_enc_dim) 110 | # e.g., (B, 32, 25, 1280) for whisper large-v1 111 | 112 | # (baseline) 113 | if self.mode == 'mean_mlp': 114 | audio_rep = torch.mean(audio_rep, dim=1) 115 | audio_rep = torch.mean(audio_rep, dim=1) 116 | audio_rep = self.mlp_layer(audio_rep) 117 | return audio_rep 118 | 119 | # (baseline) 120 | elif self.mode == 'last_mlp': 121 | audio_rep = audio_rep[:, -1, :, :] # get the last layer 122 | audio_rep = torch.mean(audio_rep, dim=1) 123 | audio_rep = self.mlp_layer(audio_rep) 124 | return audio_rep 125 | 126 | # (baseline) 127 | elif self.mode == 'wa_mlp': 128 | audio_rep = torch.mean(audio_rep, dim=2) # [B, 32 1280] 129 | audio_rep = torch.permute(audio_rep, (0, 2, 1)) # (B, 1280, 32) 130 | audio_rep = (audio_rep @ self.layer_weight) / self.layer_weight.sum() 131 | audio_rep = self.mlp_layer(audio_rep) 132 | return audio_rep 133 | 134 | # (baseline) 135 | elif 'mean_tr' in self.mode: 136 | audio_rep = torch.mean(audio_rep, dim=1) # [B, 25, 1280] 137 | audio_rep = self.time_tr(audio_rep) # [B, 25, 1280] 138 | audio_rep = torch.mean(audio_rep, dim=1) # [B*32, 1280] 139 | audio_rep = self.mlp_layer(audio_rep) 140 | return audio_rep 141 | 142 | # (baseline) time transformer on the last layer representation 143 | elif 'last_tr' in self.mode: 144 | audio_rep = audio_rep[:, -1, :, :] # [B, 25, 1280] 145 | audio_rep = self.time_tr(audio_rep) # [B, 25, 1280] 146 | audio_rep = torch.mean(audio_rep, dim=1) # [B*32, 1280] 147 | audio_rep = self.mlp_layer(audio_rep) 148 | return audio_rep 149 | 150 | # (baseline) time transformer on the layer-wise weight-average representation 151 | elif 'wa_tr' in self.mode: 152 | audio_rep = torch.permute(audio_rep, (0, 2, 3, 1)) # (B, 25, 1280, 32) 153 | audio_rep = (audio_rep @ self.layer_weight) / self.layer_weight.sum() # [B, 25, 1280] 154 | audio_rep = self.time_tr(audio_rep) # [B, 25, 1280] 155 | audio_rep = torch.mean(audio_rep, dim=1) # [B*25, 1280] 156 | audio_rep = self.mlp_layer(audio_rep) 157 | return audio_rep 158 | 159 | # (baseline) weight average with low-dimension projection 160 | elif 'wa_down_tr' in self.mode: 161 | audio_rep = torch.permute(audio_rep, (0, 2, 3, 1)) # (B, 25, 1280, 32) 162 | audio_rep = (audio_rep @ self.layer_weight) / self.layer_weight.sum() # [B, 25, 1280] 163 | audio_rep = self.down_layer(audio_rep) 164 | audio_rep = self.time_tr(audio_rep) # [B, 25, 1280] 165 | audio_rep = torch.mean(audio_rep, dim=1) # [B*32, 1280] 166 | audio_rep = 
self.mlp_layer(audio_rep) 167 | return audio_rep 168 | 169 | # (proposed) tl-tr 170 | elif 'lw_tr' in self.mode: 171 | B = audio_rep.shape[0] 172 | audio_rep = audio_rep.reshape(B*self.n_layer, audio_rep.shape[2], audio_rep.shape[3]) # [B*32, 25, 1280] 173 | audio_rep = self.time_tr(audio_rep) # [B*32, 25, 1280] 174 | audio_rep = torch.mean(audio_rep, dim=1) # [B*32, 1280] 175 | audio_rep = audio_rep.reshape(B, self.n_layer, audio_rep.shape[1]) # [B, 32, 1280] 176 | audio_rep = self.layer_tr(audio_rep) # [B, 32, 1280] 177 | audio_rep = torch.mean(audio_rep, dim=1) # [B, 1280] 178 | audio_rep = self.mlp_layer(audio_rep) 179 | return audio_rep 180 | 181 | #(proposed) tl-tr with low-dimensional projection 182 | elif 'lw_down_tr' in self.mode: 183 | B = audio_rep.shape[0] 184 | audio_rep = self.down_layer(audio_rep) 185 | audio_rep = audio_rep.reshape(B*self.n_layer, audio_rep.shape[2], audio_rep.shape[3]) # [B*32, 25, 1280] 186 | audio_rep = self.time_tr(audio_rep) # [B*32, 25, 1280] 187 | audio_rep = torch.mean(audio_rep, dim=1) # [B*32, 1280] 188 | audio_rep = audio_rep.reshape(B, self.n_layer, audio_rep.shape[1]) # [B, 32, 1280] 189 | audio_rep = self.layer_tr(audio_rep) # [B, 32, 1280] 190 | audio_rep = torch.mean(audio_rep, dim=1) # [B, 1280] 191 | audio_rep = self.mlp_layer(audio_rep) 192 | return audio_rep 193 | 194 | -------------------------------------------------------------------------------- /src/whisper_at_train/run.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 6/11/21 12:57 AM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : run.py 7 | 8 | import argparse 9 | import os 10 | os.environ['MPLCONFIGDIR'] = './plt/' 11 | os.environ['TRANSFORMERS_CACHE'] = './tr/' 12 | import ast 13 | import pickle 14 | import sys 15 | import time 16 | import torch 17 | from torch.utils.data import WeightedRandomSampler 18 | basepath = os.path.dirname(os.path.dirname(sys.path[0])) 19 | sys.path.append(basepath) 20 | import dataloader_feat as dataloader 21 | import numpy as np 22 | from traintest import train, validate 23 | from models import TLTR 24 | 25 | print("I am process %s, running on %s: starting (%s)" % (os.getpid(), os.uname()[1], time.asctime())) 26 | 27 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 28 | parser.add_argument("--data-train", type=str, default='', help="training data json") 29 | parser.add_argument("--data-val", type=str, default='', help="validation data json") 30 | parser.add_argument("--data-eval", type=str, default=None, help="evaluation data json") 31 | parser.add_argument("--label-csv", type=str, default='', help="csv with class labels") 32 | parser.add_argument("--n_class", type=int, default=527, help="number of classes") 33 | parser.add_argument("--model", type=str, default='ast', help="the model used") 34 | parser.add_argument("--dataset", type=str, default="audioset", help="the dataset used") 35 | 36 | parser.add_argument("--exp-dir", type=str, default="", help="directory to dump experiments") 37 | parser.add_argument('--lr', '--learning-rate', default=0.001, type=float, metavar='LR', help='initial learning rate') 38 | parser.add_argument("--optim", type=str, default="adam", help="training optimizer", choices=["sgd", "adam"]) 39 | parser.add_argument('-b', '--batch-size', default=12, type=int, metavar='N', help='mini-batch size') 40 | parser.add_argument('-w', 
'--num-workers', default=8, type=int, metavar='NW', help='# of workers for dataloading (default: 32)') 41 | parser.add_argument("--n-epochs", type=int, default=1, help="number of maximum training epochs") 42 | # not used in the formal experiments 43 | parser.add_argument("--lr_patience", type=int, default=1, help="how many epoch to wait to reduce lr if mAP doesn't improve") 44 | parser.add_argument("--lr_adapt", help='if use adaptive learning rate', type=ast.literal_eval) 45 | parser.add_argument("--metrics", type=str, default="mAP", help="the main evaluation metrics in finetuning", choices=["mAP", "acc"]) 46 | parser.add_argument("--loss", type=str, default="BCE", help="the loss function for finetuning, depend on the task", choices=["BCE", "CE"]) 47 | parser.add_argument('--warmup', help='if use warmup learning rate scheduler', type=ast.literal_eval, default='True') 48 | parser.add_argument("--lrscheduler_start", default=10, type=int, help="when to start decay in finetuning") 49 | parser.add_argument("--lrscheduler_step", default=5, type=int, help="the number of step to decrease the learning rate in finetuning") 50 | parser.add_argument("--lrscheduler_decay", default=0.5, type=float, help="the learning rate decay ratio in finetuning") 51 | 52 | parser.add_argument("--wa", help='if do weight averaging in finetuning', type=ast.literal_eval) 53 | parser.add_argument("--wa_start", type=int, default=16, help="which epoch to start weight averaging in finetuning") 54 | parser.add_argument("--wa_end", type=int, default=30, help="which epoch to end weight averaging in finetuning") 55 | 56 | parser.add_argument("--n-print-steps", type=int, default=100, help="number of steps to print statistics") 57 | parser.add_argument('--save_model', help='save the model or not', type=ast.literal_eval) 58 | 59 | parser.add_argument('--freqm', help='frequency mask max length', type=int, default=0) 60 | parser.add_argument('--timem', help='time mask max length', type=int, default=0) 61 | parser.add_argument("--mixup", type=float, default=0, help="how many (0-1) samples need to be mixup during training") 62 | parser.add_argument("--bal", type=str, default=None, help="use balanced sampling or not") 63 | parser.add_argument("--model_size", type=str, default='medium.en', help="The model size") 64 | parser.add_argument("--label_smooth", type=float, default=0.0, help="label smoothing factor") 65 | parser.add_argument("--weight_file", type=str, default=None, help="path to weight file") 66 | parser.add_argument("--ftmode", type=str, default='last', help="pretrained model path") 67 | parser.add_argument("--pretrain_epoch", type=int, default=0, help="number of pretrained epochs") 68 | parser.add_argument("--head_lr", type=float, default=1.0, help="learning rate ratio between mlp/base") 69 | args = parser.parse_args() 70 | 71 | if args.dataset == 'esc': 72 | if args.model_size == 'hubert-xlarge-ls960-ft' or args.model_size == 'wav2vec2-large-robust-ft-swbd-300h': 73 | train_tar_path = '/data/sls/scratch/yuangong/whisper-a/feat_esc_pool/' + args.model_size 74 | eval_tar_path = '/data/sls/scratch/yuangong/whisper-a/feat_esc_pool/' + args.model_size 75 | else: 76 | train_tar_path = '/data/sls/scratch/yuangong/whisper-a/feat_esc_pool/whisper_' + args.model_size 77 | eval_tar_path = '/data/sls/scratch/yuangong/whisper-a/feat_esc_pool/whisper_' + args.model_size 78 | shuffle = True 79 | elif args.dataset == 'as-bal' or args.dataset == 'as-full': 80 | if args.model_size == 'hubert-xlarge-ls960-ft' or args.model_size == 
'wav2vec2-large-robust-ft-swbd-300h': 81 | train_tar_path = '/data/sls/scratch/yuangong/whisper-a/feat_as_full/' + args.model_size 82 | eval_tar_path = '/data/sls/scratch/yuangong/whisper-a/feat_as_full/' + args.model_size 83 | else: 84 | train_tar_path = '/data/sls/scratch/yuangong/whisper-a/feat_as_full/whisper_' + args.model_size 85 | eval_tar_path = '/data/sls/scratch/yuangong/whisper-a/feat_as_eval/whisper_' + args.model_size 86 | shuffle = True 87 | 88 | audio_conf = {'freqm': args.freqm, 'timem': args.timem, 'mixup': args.mixup, 'dataset': args.dataset, 'label_smooth': args.label_smooth, 'tar_path': train_tar_path} 89 | val_audio_conf = {'freqm': 0, 'timem': 0, 'mixup': 0, 'dataset': args.dataset, 'tar_path': eval_tar_path} 90 | 91 | if args.bal == 'bal': 92 | print('balanced sampler is being used') 93 | if args.weight_file == None: 94 | samples_weight = np.loadtxt(args.data_train[:-5]+'_weight.csv', delimiter=',') 95 | else: 96 | samples_weight = np.loadtxt(args.data_train[:-5] + '_' + args.weight_file + '.csv', delimiter=',') 97 | sampler = WeightedRandomSampler(samples_weight, len(samples_weight), replacement=True) 98 | 99 | train_loader = torch.utils.data.DataLoader( 100 | dataloader.AudiosetDataset(args.data_train, label_csv=args.label_csv, audio_conf=audio_conf), 101 | batch_size=args.batch_size, sampler=sampler, num_workers=args.num_workers, pin_memory=True, drop_last=True) 102 | else: 103 | print('balanced sampler is not used') 104 | train_loader = torch.utils.data.DataLoader( 105 | dataloader.AudiosetDataset(args.data_train, label_csv=args.label_csv, audio_conf=audio_conf), 106 | batch_size=args.batch_size, shuffle=shuffle, num_workers=args.num_workers, pin_memory=True, drop_last=True) 107 | 108 | val_loader = torch.utils.data.DataLoader( 109 | dataloader.AudiosetDataset(args.data_val, label_csv=args.label_csv, audio_conf=val_audio_conf), 110 | batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True, drop_last=True) 111 | 112 | if args.data_eval != None: 113 | eval_loader = torch.utils.data.DataLoader( 114 | dataloader.AudiosetDataset(args.data_eval, label_csv=args.label_csv, audio_conf=val_audio_conf), 115 | batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True) 116 | 117 | def get_feat_shape(path, args): 118 | mdl_size = args.model_size 119 | n_rep_dim_dict = {'tiny.en': 384, 'tiny': 384, 'base.en': 512, 'base': 512, 'small.en': 768, 'small': 768, 'medium.en': 1024, 'medium': 1024, 'large-v1': 1280, 'large-v2': 1280, 'wav2vec2-large-robust-ft-swbd-300h': 1024, 'hubert-xlarge-ls960-ft': 1280} 120 | n_layer_dict = {'tiny.en': 4, 'tiny': 4, 'base.en': 6, 'base': 6, 'small.en': 12, 'small': 12, 'medium.en': 24, 'medium': 24, 'large-v1': 32, 'large-v2': 32, 'wav2vec2-large-robust-ft-swbd-300h': 24, 'hubert-xlarge-ls960-ft': 48} 121 | return n_layer_dict[mdl_size], n_rep_dim_dict[mdl_size] 122 | 123 | if 'whisper-high' in args.model: 124 | mode = args.model.split('-')[-1] 125 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 126 | n_layer, rep_dim = get_feat_shape(train_tar_path, args) 127 | print(mode, args.model_size, n_layer, rep_dim) 128 | audio_model = TLTR(label_dim=args.n_class, n_layer=n_layer, rep_dim=rep_dim, mode=mode) 129 | else: 130 | raise ValueError('model not supported') 131 | 132 | # use data parallel 133 | if not isinstance(audio_model, torch.nn.DataParallel): 134 | audio_model = torch.nn.DataParallel(audio_model) 135 | 136 | print("\nCreating experiment directory: %s" % 
args.exp_dir) 137 | try: 138 | os.makedirs("%s/models" % args.exp_dir) 139 | except: 140 | pass 141 | with open("%s/args.pkl" % args.exp_dir, "wb") as f: 142 | pickle.dump(args, f) 143 | 144 | code_path = args.exp_dir + '/src/' 145 | if os.path.exists(code_path) == False: 146 | os.mkdir(code_path) 147 | copy_path = '/data/sls/scratch/yuangong/whisper-a/src/' 148 | os.system('cp ' + copy_path + '/*.sh ' + code_path) 149 | os.system('cp ' + copy_path + '/*.py ' + code_path) 150 | 151 | print('Now starting training for {:d} epochs'.format(args.n_epochs)) 152 | train(audio_model, train_loader, val_loader, args) 153 | 154 | 155 | def wa_model(exp_dir, start_epoch=16, end_epoch=30): 156 | sdA = torch.load(exp_dir + '/models/audio_model.' + str(start_epoch) + '.pth', map_location=device) 157 | model_cnt = 1 158 | for epoch in range(start_epoch+1, end_epoch+1): 159 | if os.path.exists(exp_dir + '/models/audio_model.' + str(epoch) + '.pth') == True: 160 | sdB = torch.load(exp_dir + '/models/audio_model.' + str(epoch) + '.pth', map_location=device) 161 | for key in sdA: 162 | sdA[key] = sdA[key] + sdB[key] 163 | model_cnt += 1 164 | print('wa {:d} models from {:d} to {:d}'.format(model_cnt, start_epoch, end_epoch)) 165 | # averaging 166 | for key in sdA: 167 | sdA[key] = sdA[key] / float(model_cnt) 168 | torch.save(sdA, exp_dir + '/models/audio_model_wa.pth') 169 | return sdA 170 | 171 | # do model weight averaging 172 | if args.wa == True: 173 | sdA = wa_model(args.exp_dir, args.wa_start, args.wa_end) 174 | msg = audio_model.load_state_dict(sdA, strict=True) 175 | print(msg) 176 | audio_model.eval() 177 | stats, _ = validate(audio_model, val_loader, args) 178 | wa_res = np.mean([stat['AP'] for stat in stats]) 179 | print('mAP of model with weights averaged from checkpoint {:d}-{:d} is {:.4f}'.format(args.wa_start, args.wa_end, wa_res)) 180 | np.savetxt(args.exp_dir + '/wa_res.csv', [args.wa_start, args.wa_end, wa_res], delimiter=',') -------------------------------------------------------------------------------- /src/whisper_at_train/run_as_full_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p a5 3 | #SBATCH --gres=gpu:1 4 | #SBATCH -c 16 5 | #SBATCH --qos regular 6 | #SBATCH --mem=48000 7 | #SBATCH --job-name="w-as-high" 8 | #SBATCH --output=./log/%j_as.txt 9 | 10 | set -x 11 | # comment this line if not running on sls cluster 12 | . 
/data/sls/scratch/share-201907/slstoolchainrc 13 | source /data/sls/scratch/yuangong/whisper-a/venv-a5/bin/activate 14 | export TORCH_HOME=../../pretrained_models 15 | 16 | lr=5e-5 17 | freqm=0 18 | timem=10 19 | mixup=0.5 20 | batch_size=48 21 | model=whisper-high-lw_tr_1_8 #whisper-high-lw_tr_1_8 (tl-tr, lr=5e-5) whisper-high-lw_down_tr_512_1_8 (tl-tr-512, w/ low-dim proj, lr=1e-4) 22 | model_size=large-v2 23 | 24 | dataset=as-full 25 | bal=bal 26 | epoch=30 27 | lrscheduler_start=15 28 | lrscheduler_decay=0.75 29 | lrscheduler_step=5 30 | wa=True 31 | wa_start=16 32 | wa_end=30 33 | lr_adapt=False 34 | tr_data=/data/sls/scratch/yuangong/aed-pc/src/enhance_label/datafiles_local/whole_train_data.json 35 | te_data=/data/sls/scratch/yuangong/aed-pc/src/enhance_label/datafiles_local/eval_data.json 36 | label_smooth=0.1 37 | 38 | exp_dir=./exp/test-${dataset}-${model}-${model_size}-${lr}-${lrscheduler_start}-${lrscheduler_decay}-bs${batch_size}-lda${lr_adapt}-mix${mixup}-${freqm}-${timem} 39 | mkdir -p $exp_dir 40 | 41 | python -W ignore ./run.py --model ${model} --dataset ${dataset} \ 42 | --data-train ${tr_data} --data-val ${te_data} --exp-dir $exp_dir \ 43 | --label-csv /data/sls/scratch/yuangong/convast/egs/audioset/data/class_labels_indices.csv --n_class 527 \ 44 | --lr $lr --n-epochs ${epoch} --batch-size $batch_size --save_model True \ 45 | --freqm $freqm --timem $timem --mixup ${mixup} --bal ${bal} \ 46 | --model_size ${model_size} --label_smooth ${label_smooth} \ 47 | --lrscheduler_start ${lrscheduler_start} --lrscheduler_decay ${lrscheduler_decay} --lrscheduler_step ${lrscheduler_step} \ 48 | --loss BCE --metrics mAP --warmup True \ 49 | --wa ${wa} --wa_start ${wa_start} --wa_end ${wa_end} --lr_adapt ${lr_adapt} \ 50 | --num-workers 8 -------------------------------------------------------------------------------- /src/whisper_at_train/traintest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 6/10/21 11:00 PM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : traintest.py 7 | 8 | import sys 9 | import os 10 | import datetime 11 | sys.path.append(os.path.dirname(os.path.dirname(sys.path[0]))) 12 | from utilities import * 13 | import time 14 | import torch 15 | from torch import nn 16 | import numpy as np 17 | import pickle 18 | from torch.cuda.amp import autocast,GradScaler 19 | 20 | def train(audio_model, train_loader, test_loader, args): 21 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 22 | print('running on ' + str(device)) 23 | torch.set_grad_enabled(True) 24 | 25 | # Initialize all of the statistics we want to keep track of 26 | batch_time = AverageMeter() 27 | per_sample_time = AverageMeter() 28 | data_time = AverageMeter() 29 | per_sample_data_time = AverageMeter() 30 | loss_meter = AverageMeter() 31 | per_sample_dnn_time = AverageMeter() 32 | progress = [] 33 | # best_cum_mAP is checkpoint ensemble from the first epoch to the best epoch 34 | best_epoch, best_cum_epoch, best_mAP, best_acc, best_cum_mAP = 0, 0, -np.inf, -np.inf, -np.inf 35 | global_step, epoch = 0, 0 36 | start_time = time.time() 37 | exp_dir = args.exp_dir 38 | 39 | def _save_progress(): 40 | progress.append([epoch, global_step, best_epoch, best_mAP, time.time() - start_time]) 41 | with open("%s/progress.pkl" % exp_dir, "wb") as f: 42 | pickle.dump(progress, f) 43 | 44 | if not isinstance(audio_model, nn.DataParallel): 45 | 
audio_model = nn.DataParallel(audio_model) 46 | 47 | audio_model = audio_model.to(device) 48 | 49 | trainables = [p for p in audio_model.parameters() if p.requires_grad] 50 | print('Total parameter number is : {:.3f} million'.format(sum(p.numel() for p in audio_model.parameters()) / 1e6)) 51 | print('Total trainable parameter number is : {:.3f} million'.format(sum(p.numel() for p in trainables) / 1e6)) 52 | 53 | optimizer = torch.optim.Adam(trainables, args.lr, weight_decay=5e-7, betas=(0.95, 0.999)) 54 | 55 | if args.lr_adapt == True: 56 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=args.lr_patience, verbose=True) 57 | print('Override to use adaptive learning rate scheduler.') 58 | else: 59 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, list(range(args.lrscheduler_start, 1000, args.lrscheduler_step)),gamma=args.lrscheduler_decay) 60 | print('The learning rate scheduler starts at epoch {:d} with decay rate of {:.3f} every {:d} epochs'.format(args.lrscheduler_start, args.lrscheduler_decay, args.lrscheduler_step)) 61 | main_metrics = args.metrics 62 | if args.loss == 'BCE': 63 | loss_fn = nn.BCEWithLogitsLoss() 64 | elif args.loss == 'CE': 65 | loss_fn = nn.CrossEntropyLoss() 66 | args.loss_fn = loss_fn 67 | 68 | print('now training with {:s}, main metrics: {:s}, loss function: {:s}, learning rate scheduler: {:s}'.format(str(args.dataset), str(main_metrics), str(loss_fn), str(scheduler))) 69 | 70 | epoch += 1 71 | scaler = GradScaler() 72 | 73 | print("current #steps=%s, #epochs=%s" % (global_step, epoch)) 74 | print("start training...") 75 | result = np.zeros([args.n_epochs, 4]) 76 | audio_model.train() 77 | while epoch < args.n_epochs + 1: 78 | begin_time = time.time() 79 | end_time = time.time() 80 | audio_model.train() 81 | print('---------------') 82 | print(datetime.datetime.now()) 83 | print("current #epochs=%s, #steps=%s" % (epoch, global_step)) 84 | 85 | for i, (a_input, labels) in enumerate(train_loader): 86 | 87 | B = a_input.size(0) 88 | a_input = a_input.to(device, non_blocking=True) 89 | labels = labels.to(device, non_blocking=True) 90 | 91 | data_time.update(time.time() - end_time) 92 | per_sample_data_time.update((time.time() - end_time) / a_input.shape[0]) 93 | dnn_start_time = time.time() 94 | 95 | with autocast(): 96 | audio_output = audio_model(a_input) 97 | loss = loss_fn(audio_output, labels) 98 | 99 | # optimization step with AMP gradient scaling 100 | optimizer.zero_grad() 101 | scaler.scale(loss).backward() 102 | scaler.step(optimizer) 103 | scaler.update() 104 | 105 | # record loss 106 | loss_meter.update(loss.item(), B) 107 | batch_time.update(time.time() - end_time) 108 | per_sample_time.update((time.time() - end_time)/a_input.shape[0]) 109 | per_sample_dnn_time.update((time.time() - dnn_start_time)/a_input.shape[0]) 110 | 111 | print_step = global_step % args.n_print_steps == 0 112 | early_print_step = epoch == 0 and global_step % (args.n_print_steps/10) == 0 113 | print_step = print_step or early_print_step 114 | 115 | if print_step and global_step != 0: 116 | print('Epoch: [{0}][{1}/{2}]\t' 117 | 'Per Sample Total Time {per_sample_time.avg:.5f}\t' 118 | 'Per Sample Data Time {per_sample_data_time.avg:.5f}\t' 119 | 'Per Sample DNN Time {per_sample_dnn_time.avg:.5f}\t' 120 | 'Train Loss {loss_meter.val:.4f}\t'.format( 121 | epoch, i, len(train_loader), per_sample_time=per_sample_time, per_sample_data_time=per_sample_data_time, 122 | per_sample_dnn_time=per_sample_dnn_time, loss_meter=loss_meter), flush=True) 
123 | if np.isnan(loss_meter.avg): 124 | print("training diverged...") 125 | return 126 | 127 | end_time = time.time() 128 | global_step += 1 129 | 130 | # for audioset-full, break every 10% of the epoch, i.e., equivalent epochs = 0.1 * specified epochs 131 | if args.dataset == 'as-full': 132 | if i > 0.1 * len(train_loader): 133 | break 134 | 135 | print('start validation') 136 | 137 | stats, valid_loss = validate(audio_model, test_loader, args) 138 | 139 | mAP = np.mean([stat['AP'] for stat in stats]) 140 | mAUC = np.mean([stat['auc'] for stat in stats]) 141 | acc = stats[0]['acc'] 142 | 143 | if main_metrics == 'mAP': 144 | print("mAP: {:.6f}".format(mAP)) 145 | else: 146 | print("acc: {:.6f}".format(acc)) 147 | print("AUC: {:.6f}".format(mAUC)) 148 | print("d_prime: {:.6f}".format(d_prime(mAUC))) 149 | print("train_loss: {:.6f}".format(loss_meter.avg)) 150 | print("valid_loss: {:.6f}".format(valid_loss)) 151 | 152 | result[epoch-1, :] = [acc, mAP, mAUC, optimizer.param_groups[0]['lr']] 153 | np.savetxt(exp_dir + '/result.csv', result, delimiter=',') 154 | print('validation finished') 155 | 156 | if mAP > best_mAP: 157 | best_mAP = mAP 158 | if main_metrics == 'mAP': 159 | best_epoch = epoch 160 | 161 | if acc > best_acc: 162 | best_acc = acc 163 | if main_metrics == 'acc': 164 | best_epoch = epoch 165 | 166 | if best_epoch == epoch: 167 | pass 168 | #torch.save(audio_model.state_dict(), "%s/models/best_audio_model.pth" % (exp_dir)) 169 | if args.save_model == True: 170 | torch.save(audio_model.state_dict(), "%s/models/audio_model.%d.pth" % (exp_dir, epoch)) 171 | 172 | if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): 173 | if main_metrics == 'mAP': 174 | scheduler.step(mAP) 175 | elif main_metrics == 'acc': 176 | scheduler.step(acc) 177 | else: 178 | scheduler.step() 179 | 180 | print('Epoch-{0} lr: {1}'.format(epoch, optimizer.param_groups[0]['lr'])) 181 | 182 | with open(exp_dir + '/stats_' + str(epoch) +'.pickle', 'wb') as handle: 183 | pickle.dump(stats, handle, protocol=pickle.HIGHEST_PROTOCOL) 184 | _save_progress() 185 | 186 | finish_time = time.time() 187 | print('epoch {:d} training time: {:.3f}'.format(epoch, finish_time-begin_time)) 188 | 189 | epoch += 1 190 | 191 | batch_time.reset() 192 | per_sample_time.reset() 193 | data_time.reset() 194 | per_sample_data_time.reset() 195 | loss_meter.reset() 196 | per_sample_dnn_time.reset() 197 | 198 | def validate(audio_model, val_loader, args): 199 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 200 | batch_time = AverageMeter() 201 | if not isinstance(audio_model, nn.DataParallel): 202 | audio_model = nn.DataParallel(audio_model) 203 | audio_model = audio_model.to(device) 204 | # switch to evaluate mode 205 | audio_model.eval() 206 | 207 | end = time.time() 208 | A_predictions = [] 209 | A_targets = [] 210 | A_loss = [] 211 | with torch.no_grad(): 212 | for i, (a_input, labels) in enumerate(val_loader): 213 | a_input = a_input.to(device, non_blocking=True) 214 | 215 | with autocast(): 216 | audio_output = audio_model(a_input) 217 | 218 | predictions = audio_output.to('cpu').detach() 219 | 220 | A_predictions.append(predictions) 221 | A_targets.append(labels) 222 | 223 | # compute the loss 224 | labels = labels.to(device) 225 | loss = args.loss_fn(audio_output, labels) 226 | A_loss.append(loss.to('cpu').detach()) 227 | 228 | batch_time.update(time.time() - end) 229 | end = time.time() 230 | 231 | audio_output = torch.cat(A_predictions) 232 | target = torch.cat(A_targets) 233 | loss = 
np.mean(A_loss) 234 | stats = calculate_stats(audio_output, target) 235 | return stats, loss -------------------------------------------------------------------------------- /src/whisper_at_train/utilities/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 6/19/21 4:39 PM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : __init__.py 7 | 8 | from .util import * 9 | from .stats import * -------------------------------------------------------------------------------- /src/whisper_at_train/utilities/compute_flops.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 3/7/23 1:13 AM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : check_flops.py 7 | 8 | # check model size and flops 9 | import torch 10 | from fvcore.nn import FlopCountAnalysis 11 | from fvcore.nn import flop_count_table 12 | from high_mdls import HighMDL, HighMDLPool, HighMDLLayer, HighMDLFormal 13 | # from whisper.model import Whisper, ModelDimensions 14 | 15 | def cnt_flops(model, input): 16 | flops = FlopCountAnalysis(model, input) 17 | print(flop_count_table(flops)) 18 | print(flops.total()/1e9) 19 | print(flops.by_operator()) 20 | #print(flops.by_module()) 21 | #print(flops.by_module_and_operator()) 22 | 23 | # # original whisper model 24 | # checkpoint_path = '/data/sls/scratch/yuangong/whisper-a/src/{:s}.pt'.format('small.en') 25 | # checkpoint = torch.load(checkpoint_path, map_location='cpu') 26 | # dims = ModelDimensions(**checkpoint["dims"]) 27 | # print(dims) 28 | # model = Whisper(dims, label_dim=527, cla='mlp_1') 29 | # input = torch.rand([1, 80, 512*2]) 30 | # cnt_flops(model, input) 31 | 32 | 33 | def get_feat_shape(mdl_size): 34 | n_rep_dim_dict = {'tiny.en': 384, 'tiny': 384, 'base.en': 512, 'base': 512, 'small.en': 768, 'small': 768, 'medium.en': 1024, 'medium': 1024, 'large-v1': 1280, 'large-v2': 1280, 'wav2vec2-large-robust-ft-swbd-300h': 1024, 'hubert-xlarge-ls960-ft': 1280} 35 | n_layer_dict = {'tiny.en': 4, 'tiny': 4, 'base.en': 6, 'base': 6, 'small.en': 12, 'small': 12, 'medium.en': 24, 'medium': 24, 'large-v1': 32, 'large-v2': 32, 'wav2vec2-large-robust-ft-swbd-300h': 24, 'hubert-xlarge-ls960-ft': 48} 36 | return n_layer_dict[mdl_size], n_rep_dim_dict[mdl_size] 37 | 38 | model_name = 'whisper-high-lw_down_tr_768_1_8' 39 | model_size = 'large-v1' 40 | mode = model_name.split('-')[-1] 41 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 42 | n_layer, rep_dim = get_feat_shape(model_size) 43 | print(mode, model_size, n_layer, rep_dim) 44 | model = HighMDLFormal(label_dim=527, n_layer=n_layer, rep_dim=rep_dim, mode=mode) 45 | 46 | # for large-v1 47 | cnt_flops(model, torch.rand([1, n_layer, 25, rep_dim])) -------------------------------------------------------------------------------- /src/whisper_at_train/utilities/compute_mAP.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 6/1/23 12:40 PM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : compute_mAP.py 7 | 8 | # compute mAP on whisper-at 9 | 10 | from stats import * 11 | 12 | mdl_size_list = [ 13 | 'tiny_low_False', 14 | 'tiny.en_low_False', 15 | 'base_low_False', 
16 | 'base.en_low_False', 17 | 'small_low_False', 18 | 'small_low_True', 19 | 'small.en_low_False', 20 | 'small.en_low_True', 21 | 'medium_low_False', 22 | 'medium_low_True', 23 | 'medium.en_low_False', 24 | 'medium.en_low_True', 25 | 'large-v1_low_False', 26 | 'large-v1_low_True', 27 | 'large-v2_low_False', 28 | 'large-v2_low_True'] 29 | 30 | for mdl_size in mdl_size_list: 31 | all_truth = np.load('/data/sls/scratch/yuangong/whisper-at/old/at_res/all_truth_' + mdl_size + '.npy') 32 | all_pred = np.load('/data/sls/scratch/yuangong/whisper-at/old/at_res/all_pred_' + mdl_size + '.npy') 33 | print(mdl_size) 34 | print(all_truth.shape, all_pred.shape) 35 | 36 | stats = calculate_stats(all_pred, all_truth) 37 | mAP = np.mean([stat['AP'] for stat in stats]) 38 | print(mAP) -------------------------------------------------------------------------------- /src/whisper_at_train/utilities/rename_state_dict.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 5/30/23 3:00 AM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : rename_state_dict.py 7 | 8 | # rename state dict (trained with feats) to put together with whisper-at model 9 | 10 | import torch 11 | import os 12 | 13 | def get_immediate_files_with_extension(directory, extension='pth'): 14 | file_list = [] 15 | for file in os.listdir(directory): 16 | if file.endswith(extension) and os.path.isfile(os.path.join(directory, file)): 17 | file_list.append(os.path.join(directory, file)) 18 | return file_list 19 | 20 | def replace_name(ori_mdl_path, tar_mdl_path): 21 | sd = torch.load(ori_mdl_path, map_location='cpu') 22 | mdl_key_list = sd.keys() 23 | 24 | whisper_at_dict = {} 25 | for mdl_key in mdl_key_list: 26 | new_mdl_key = mdl_key.replace('module.', 'at_model.') 27 | #print(new_mdl_key) 28 | whisper_at_dict[new_mdl_key] = sd[mdl_key] 29 | 30 | print(len(sd.keys()), len(whisper_at_dict.keys())) 31 | torch.save(whisper_at_dict, tar_mdl_path) 32 | 33 | mdl_list = get_immediate_files_with_extension('/data/sls/scratch/yuangong/whisper-at/exp/') 34 | print(mdl_list) 35 | tar_path = '/data/sls/scratch/yuangong/whisper-at/exp/converted_to_whisper_at/' 36 | for mdl in mdl_list: 37 | print(mdl) 38 | print('-----------------------') 39 | replace_name(mdl, tar_path + mdl.split('/')[-1]) 40 | -------------------------------------------------------------------------------- /src/whisper_at_train/utilities/stats.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import stats 3 | from sklearn import metrics 4 | import torch 5 | 6 | def d_prime(auc): 7 | standard_normal = stats.norm() 8 | d_prime = standard_normal.ppf(auc) * np.sqrt(2.0) 9 | return d_prime 10 | 11 | def calculate_stats(output, target): 12 | """Calculate statistics including mAP, AUC, etc. 13 | 14 | Args: 15 | output: 2d array, (samples_num, classes_num) 16 | target: 2d array, (samples_num, classes_num) 17 | 18 | Returns: 19 | stats: list of statistic of each class. 
20 | """ 21 | 22 | classes_num = target.shape[-1] 23 | stats = [] 24 | 25 | # Accuracy, only used for single-label classification such as esc-50, not for multiple label one such as AudioSet 26 | acc = metrics.accuracy_score(np.argmax(target, 1), np.argmax(output, 1)) 27 | 28 | # Class-wise statistics 29 | for k in range(classes_num): 30 | 31 | # Average precision 32 | avg_precision = metrics.average_precision_score( 33 | target[:, k], output[:, k], average=None) 34 | 35 | # AUC 36 | try: 37 | auc = metrics.roc_auc_score(target[:, k], output[:, k], average=None) 38 | 39 | # Precisions, recalls 40 | (precisions, recalls, thresholds) = metrics.precision_recall_curve( 41 | target[:, k], output[:, k]) 42 | 43 | # FPR, TPR 44 | (fpr, tpr, thresholds) = metrics.roc_curve(target[:, k], output[:, k]) 45 | 46 | save_every_steps = 1000 # Sample statistics to reduce size 47 | dict = {'precisions': precisions[0::save_every_steps], 48 | 'recalls': recalls[0::save_every_steps], 49 | 'AP': avg_precision, 50 | 'fpr': fpr[0::save_every_steps], 51 | 'fnr': 1. - tpr[0::save_every_steps], 52 | 'auc': auc, 53 | # note acc is not class-wise, this is just to keep consistent with other metrics 54 | 'acc': acc 55 | } 56 | except: 57 | dict = {'precisions': -1, 58 | 'recalls': -1, 59 | 'AP': avg_precision, 60 | 'fpr': -1, 61 | 'fnr': -1, 62 | 'auc': -1, 63 | # note acc is not class-wise, this is just to keep consistent with other metrics 64 | 'acc': acc 65 | } 66 | print('class {:s} no true sample'.format(str(k))) 67 | stats.append(dict) 68 | 69 | return stats -------------------------------------------------------------------------------- /src/whisper_at_train/utilities/util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pickle 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import random 7 | from collections import namedtuple 8 | 9 | def calc_recalls(S): 10 | """ 11 | Computes recall at 1, 5, and 10 given a similarity matrix S. 12 | By convention, rows of S are assumed to correspond to images and columns are captions. 
13 | """ 14 | assert(S.dim() == 2) 15 | assert(S.size(0) == S.size(1)) 16 | if isinstance(S, torch.autograd.Variable): 17 | S = S.data 18 | n = S.size(0) 19 | A2I_scores, A2I_ind = S.topk(10, 0) 20 | I2A_scores, I2A_ind = S.topk(10, 1) 21 | A_r1 = AverageMeter() 22 | A_r5 = AverageMeter() 23 | A_r10 = AverageMeter() 24 | I_r1 = AverageMeter() 25 | I_r5 = AverageMeter() 26 | I_r10 = AverageMeter() 27 | for i in range(n): 28 | A_foundind = -1 29 | I_foundind = -1 30 | for ind in range(10): 31 | if A2I_ind[ind, i] == i: 32 | I_foundind = ind 33 | if I2A_ind[i, ind] == i: 34 | A_foundind = ind 35 | # do r1s 36 | if A_foundind == 0: 37 | A_r1.update(1) 38 | else: 39 | A_r1.update(0) 40 | if I_foundind == 0: 41 | I_r1.update(1) 42 | else: 43 | I_r1.update(0) 44 | # do r5s 45 | if A_foundind >= 0 and A_foundind < 5: 46 | A_r5.update(1) 47 | else: 48 | A_r5.update(0) 49 | if I_foundind >= 0 and I_foundind < 5: 50 | I_r5.update(1) 51 | else: 52 | I_r5.update(0) 53 | # do r10s 54 | if A_foundind >= 0 and A_foundind < 10: 55 | A_r10.update(1) 56 | else: 57 | A_r10.update(0) 58 | if I_foundind >= 0 and I_foundind < 10: 59 | I_r10.update(1) 60 | else: 61 | I_r10.update(0) 62 | 63 | recalls = {'A_r1':A_r1.avg, 'A_r5':A_r5.avg, 'A_r10':A_r10.avg, 64 | 'I_r1':I_r1.avg, 'I_r5':I_r5.avg, 'I_r10':I_r10.avg} 65 | #'A_meanR':A_meanR.avg, 'I_meanR':I_meanR.avg} 66 | 67 | return recalls 68 | 69 | def computeMatchmap(I, A): 70 | assert(I.dim() == 3) 71 | assert(A.dim() == 2) 72 | D = I.size(0) 73 | H = I.size(1) 74 | W = I.size(2) 75 | T = A.size(1) 76 | Ir = I.view(D, -1).t() 77 | matchmap = torch.mm(Ir, A) 78 | matchmap = matchmap.view(H, W, T) 79 | return matchmap 80 | 81 | def matchmapSim(M, simtype): 82 | assert(M.dim() == 3) 83 | if simtype == 'SISA': 84 | return M.mean() 85 | elif simtype == 'MISA': 86 | M_maxH, _ = M.max(0) 87 | M_maxHW, _ = M_maxH.max(0) 88 | return M_maxHW.mean() 89 | elif simtype == 'SIMA': 90 | M_maxT, _ = M.max(2) 91 | return M_maxT.mean() 92 | else: 93 | raise ValueError 94 | 95 | def sampled_margin_rank_loss(image_outputs, audio_outputs, nframes, margin=1., simtype='MISA'): 96 | """ 97 | Computes the triplet margin ranking loss for each anchor image/caption pair 98 | The impostor image/caption is randomly sampled from the minibatch 99 | """ 100 | assert(image_outputs.dim() == 4) 101 | assert(audio_outputs.dim() == 3) 102 | n = image_outputs.size(0) 103 | loss = torch.zeros(1, device=image_outputs.device, requires_grad=True) 104 | for i in range(n): 105 | I_imp_ind = i 106 | A_imp_ind = i 107 | while I_imp_ind == i: 108 | I_imp_ind = np.random.randint(0, n) 109 | while A_imp_ind == i: 110 | A_imp_ind = np.random.randint(0, n) 111 | nF = nframes[i] 112 | nFimp = nframes[A_imp_ind] 113 | anchorsim = matchmapSim(computeMatchmap(image_outputs[i], audio_outputs[i][:, 0:nF]), simtype) 114 | Iimpsim = matchmapSim(computeMatchmap(image_outputs[I_imp_ind], audio_outputs[i][:, 0:nF]), simtype) 115 | Aimpsim = matchmapSim(computeMatchmap(image_outputs[i], audio_outputs[A_imp_ind][:, 0:nFimp]), simtype) 116 | A2I_simdif = margin + Iimpsim - anchorsim 117 | if (A2I_simdif.data > 0).all(): 118 | loss = loss + A2I_simdif 119 | I2A_simdif = margin + Aimpsim - anchorsim 120 | if (I2A_simdif.data > 0).all(): 121 | loss = loss + I2A_simdif 122 | loss = loss / n 123 | return loss 124 | 125 | def compute_matchmap_similarity_matrix(image_outputs, audio_outputs, nframes, simtype='MISA'): 126 | """ 127 | Assumes image_outputs is a (batchsize, embedding_dim, rows, height) tensor 128 | Assumes audio_outputs 
is a (batchsize, embedding_dim, 1, time) tensor 129 | Returns similarity matrix S where images are rows and audios are along the columns 130 | """ 131 | assert(image_outputs.dim() == 4) 132 | assert(audio_outputs.dim() == 3) 133 | n = image_outputs.size(0) 134 | S = torch.zeros(n, n, device=image_outputs.device) 135 | for image_idx in range(n): 136 | for audio_idx in range(n): 137 | nF = max(1, nframes[audio_idx]) 138 | S[image_idx, audio_idx] = matchmapSim(computeMatchmap(image_outputs[image_idx], audio_outputs[audio_idx][:, 0:nF]), simtype) 139 | return S 140 | 141 | def compute_pooldot_similarity_matrix(image_outputs, audio_outputs, nframes): 142 | """ 143 | Assumes image_outputs is a (batchsize, embedding_dim, rows, height) tensor 144 | Assumes audio_outputs is a (batchsize, embedding_dim, 1, time) tensor 145 | Returns similarity matrix S where images are rows and audios are along the columns 146 | S[i][j] is computed as the dot product between the meanpooled embeddings of 147 | the ith image output and jth audio output 148 | """ 149 | assert(image_outputs.dim() == 4) 150 | assert(audio_outputs.dim() == 4) 151 | n = image_outputs.size(0) 152 | imagePoolfunc = nn.AdaptiveAvgPool2d((1, 1)) 153 | pooled_image_outputs = imagePoolfunc(image_outputs).squeeze(3).squeeze(2) 154 | audioPoolfunc = nn.AdaptiveAvgPool2d((1, 1)) 155 | pooled_audio_outputs_list = [] 156 | for idx in range(n): 157 | nF = max(1, nframes[idx]) 158 | pooled_audio_outputs_list.append(audioPoolfunc(audio_outputs[idx][:, :, 0:nF]).unsqueeze(0)) 159 | pooled_audio_outputs = torch.cat(pooled_audio_outputs_list).squeeze(3).squeeze(2) 160 | S = torch.mm(pooled_image_outputs, pooled_audio_outputs.t()) 161 | return S 162 | 163 | def one_imposter_index(i, N): 164 | imp_ind = random.randint(0, N - 2) 165 | if imp_ind == i: 166 | imp_ind = N - 1 167 | return imp_ind 168 | 169 | def basic_get_imposter_indices(N): 170 | imposter_idc = [] 171 | for i in range(N): 172 | # Select an imposter index for example i: 173 | imp_ind = one_imposter_index(i, N) 174 | imposter_idc.append(imp_ind) 175 | return imposter_idc 176 | 177 | def semihardneg_triplet_loss_from_S(S, margin): 178 | """ 179 | Input: Similarity matrix S as an autograd.Variable 180 | Output: The one-way triplet loss from rows of S to columns of S. Impostors are taken 181 | to be the most similar point to the anchor that is still less similar to the anchor 182 | than the positive example. 183 | You would need to run this function twice, once with S and once with S.t(), 184 | in order to compute the triplet loss in both directions. 
185 | """ 186 | assert(S.dim() == 2) 187 | assert(S.size(0) == S.size(1)) 188 | N = S.size(0) 189 | loss = torch.autograd.Variable(torch.zeros(1).type(S.data.type()), requires_grad=True) 190 | # Imposter - ground truth 191 | Sdiff = S - torch.diag(S).view(-1, 1) 192 | eps = 1e-12 193 | # All examples less similar than ground truth 194 | mask = (Sdiff < -eps).type(torch.LongTensor) 195 | maskf = mask.type_as(S) 196 | # Mask out all examples >= gt with minimum similarity 197 | Sp = maskf * Sdiff + (1 - maskf) * torch.min(Sdiff).detach() 198 | # Find the index maximum similar of the remaining 199 | _, idc = Sp.max(dim=1) 200 | idc = idc.data.cpu() 201 | # Vector mask: 1 iff there exists an example < gt 202 | has_neg = (mask.sum(dim=1) > 0).data.type(torch.LongTensor) 203 | # Random imposter indices 204 | random_imp_ind = torch.LongTensor(basic_get_imposter_indices(N)) 205 | # Use hardneg if there exists an example < gt, otherwise use random imposter 206 | imp_idc = has_neg * idc + (1 - has_neg) * random_imp_ind 207 | # This could probably be vectorized too, but I haven't. 208 | for i, imp in enumerate(imp_idc): 209 | local_loss = Sdiff[i, imp] + margin 210 | if (local_loss.data > 0).all(): 211 | loss = loss + local_loss 212 | loss = loss / N 213 | return loss 214 | 215 | def sampled_triplet_loss_from_S(S, margin): 216 | """ 217 | Input: Similarity matrix S as an autograd.Variable 218 | Output: The one-way triplet loss from rows of S to columns of S. Imposters are 219 | randomly sampled from the columns of S. 220 | You would need to run this function twice, once with S and once with S.t(), 221 | in order to compute the triplet loss in both directions. 222 | """ 223 | assert(S.dim() == 2) 224 | assert(S.size(0) == S.size(1)) 225 | N = S.size(0) 226 | loss = torch.autograd.Variable(torch.zeros(1).type(S.data.type()), requires_grad=True) 227 | # Imposter - ground truth 228 | Sdiff = S - torch.diag(S).view(-1, 1) 229 | imp_ind = torch.LongTensor(basic_get_imposter_indices(N)) 230 | # This could probably be vectorized too, but I haven't. 
231 | for i, imp in enumerate(imp_ind): 232 | local_loss = Sdiff[i, imp] + margin 233 | if (local_loss.data > 0).all(): 234 | loss = loss + local_loss 235 | loss = loss / N 236 | return loss 237 | 238 | class AverageMeter(object): 239 | """Computes and stores the average and current value""" 240 | def __init__(self): 241 | self.reset() 242 | 243 | def reset(self): 244 | self.val = 0 245 | self.avg = 0 246 | self.sum = 0 247 | self.count = 0 248 | 249 | def update(self, val, n=1): 250 | self.val = val 251 | self.sum += val * n 252 | self.count += n 253 | self.avg = self.sum / self.count 254 | 255 | def adjust_learning_rate(base_lr, lr_decay, optimizer, epoch): 256 | """Sets the learning rate to the initial LR decayed by 10 every lr_decay epochs""" 257 | lr = base_lr * (0.1 ** (epoch // lr_decay)) 258 | print('now learning rate changed to {:f}'.format(lr)) 259 | for param_group in optimizer.param_groups: 260 | param_group['lr'] = lr 261 | 262 | def adjust_learning_rate2(base_lr, lr_decay, optimizer, epoch): 263 | """Decays the current learning rate by a factor of 10 (base_lr, lr_decay, and epoch are unused)""" 264 | for param_group in optimizer.param_groups: 265 | cur_lr = param_group['lr'] 266 | print('current learning rate is {:f}'.format(cur_lr)) 267 | lr = cur_lr * 0.1 268 | print('now learning rate changed to {:f}'.format(lr)) 269 | for param_group in optimizer.param_groups: 270 | param_group['lr'] = lr 271 | 272 | 273 | def load_progress(prog_pkl, quiet=False): 274 | """ 275 | load progress pkl file 276 | Args: 277 | prog_pkl(str): path to progress pkl file 278 | Return: 279 | progress(list): 280 | epoch(int): 281 | global_step(int): 282 | best_epoch(int): 283 | best_avg_r10(float): 284 | """ 285 | def _print(msg): 286 | if not quiet: 287 | print(msg) 288 | 289 | with open(prog_pkl, "rb") as f: 290 | prog = pickle.load(f) 291 | epoch, global_step, best_epoch, best_avg_r10, _ = prog[-1] 292 | 293 | _print("\nPrevious Progress:") 294 | msg = "[%5s %7s %5s %7s %6s]" % ("epoch", "step", "best_epoch", "best_avg_r10", "time") 295 | _print(msg) 296 | return prog, epoch, global_step, best_epoch, best_avg_r10 297 | 298 | def count_parameters(model): 299 | return sum([p.numel() for p in model.parameters() if p.requires_grad]) 300 | 301 | PrenetConfig = namedtuple( 302 | 'PrenetConfig', ['input_size', 'hidden_size', 'num_layers', 'dropout']) 303 | 304 | RNNConfig = namedtuple( 305 | 'RNNConfig', 306 | ['input_size', 'hidden_size', 'num_layers', 'dropout', 'residual']) 307 | -------------------------------------------------------------------------------- /src/whisper_at_train/utilities/whisper_at_as_eval.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 5/28/23 2:36 PM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : whisper_at_as_eval.py 7 | 8 | # evaluate whisper-at on the as-eval set 9 | # note this uses a 30s window, so performance will be slightly lower than with a 10s window 10 | 11 | import sys 12 | argument = sys.argv[1] 13 | if argument=='4': 14 | argument='0,1,2,3' 15 | import os 16 | if argument != '-1': 17 | os.environ["CUDA_VISIBLE_DEVICES"]=argument 18 | 19 | import whisper_at as whisper 20 | import numpy as np 21 | import csv 22 | import json 23 | import torch 24 | 25 | def make_index_dict(label_csv): 26 | index_lookup = {} 27 | with open(label_csv, 'r') as f: 28 | csv_reader = csv.DictReader(f) 29 | line_count = 0 30 | for row in csv_reader: 31 | 
index_lookup[row['mid']] = row['index'] 32 | line_count += 1 33 | return index_lookup 34 | 35 | def make_name_dict(label_csv): 36 | name_lookup = {} 37 | with open(label_csv, 'r') as f: 38 | csv_reader = csv.DictReader(f) 39 | line_count = 0 40 | for row in csv_reader: 41 | name_lookup[row['index']] = row['display_name'] 42 | line_count += 1 43 | return name_lookup 44 | 45 | def lookup_list(index_list, label_csv): 46 | label_list = [] 47 | table = make_name_dict(label_csv) 48 | for item in index_list: 49 | label_list.append(table[item]) 50 | return label_list 51 | 52 | mdl_size='large-v1' 53 | at_low_compute=False 54 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 55 | model = whisper.load_model(mdl_size, at_low_compute=at_low_compute).to(device) 56 | 57 | index_dict = make_index_dict('/data/sls/scratch/yuangong/whisper-at/src/class_labels_indices.csv') 58 | with open('/data/sls/scratch/yuangong/whisper-at/src/eval_data.json') as json_file: 59 | data = json.load(json_file)['data'] 60 | num_file = len(data) 61 | 62 | all_pred, all_truth = torch.zeros([num_file, 527]).to(device), torch.zeros([num_file, 527]).to(device) 63 | for i, entry in enumerate(data): 64 | cur_wav = entry['wav'] 65 | labels = entry['labels'].split(',') 66 | for label in labels: 67 | all_truth[i, int(index_dict[label])] = 1.0 68 | result = model.transcribe(cur_wav, language='en', logprob_threshold=None, compression_ratio_threshold=None)['audio_tag'] 69 | all_pred[i] = result[0] 70 | 71 | if i % 100 == 0: 72 | print(i) 73 | np.save('./at_res/all_pred_{:s}_low_{:s}.npy'.format(mdl_size, str(at_low_compute)), all_pred.cpu().numpy()) 74 | np.save('./at_res/all_truth_{:s}_low_{:s}.npy'.format(mdl_size, str(at_low_compute)), all_truth.cpu().numpy()) 75 | 76 | np.save('./at_res/all_pred_{:s}_low_{:s}.npy'.format(mdl_size, str(at_low_compute)), all_pred.cpu().numpy()) 77 | np.save('./at_res/all_truth_{:s}_low_{:s}.npy'.format(mdl_size, str(at_low_compute)), all_truth.cpu().numpy()) -------------------------------------------------------------------------------- /tltr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuanGongND/whisper-at/17d94d6acd53866390ce70f95afa13507dcb8aef/tltr.png --------------------------------------------------------------------------------