├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── app.py
├── audioset_label.csv
├── package
└── whisper-at
│ ├── .flake8
│ ├── .gitattributes
│ ├── .github
│ └── workflows
│ │ ├── python-publish.yml
│ │ └── test.yml
│ ├── LICENSE
│ ├── MANIFEST.in
│ ├── README.md
│ ├── pyproject.toml
│ ├── requirements.txt
│ ├── setup.py
│ └── whisper_at
│ ├── __init__.py
│ ├── __main__.py
│ ├── assets
│ ├── gpt2.tiktoken
│ ├── label_name_dict.json
│ ├── mel_filters.npz
│ └── multilingual.tiktoken
│ ├── at_post_processing.py
│ ├── audio.py
│ ├── decoding.py
│ ├── model.py
│ ├── normalizers
│ ├── __init__.py
│ ├── basic.py
│ ├── english.json
│ └── english.py
│ ├── timing.py
│ ├── tokenizer.py
│ ├── transcribe.py
│ ├── triton_ops.py
│ ├── utils.py
│ └── version.py
├── poster.pdf
├── poster.png
├── poster_low.png
├── pretrained_models
└── README.md
├── review
├── author_response.pdf
└── whisper_at_review.pdf
├── sample
├── whisper_at_demo.ipynb
└── whisper_transcribe_test_simple.py
├── src
├── noise_robust_asr
│ ├── asr_experiments
│ │ ├── compute_wer.py
│ │ ├── compute_wer_cla.py
│ │ ├── gen_noisy_speech.py
│ │ ├── transcribe_esc_hubert_xl.py
│ │ ├── transcribe_hubert_large.py
│ │ ├── transcribe_wav2vec_base.py
│ │ ├── transcribe_wav2vec_robust.py
│ │ └── transcribe_whisper.py
│ ├── baseline_sound_classification.py
│ ├── intermediate_feat_extract
│ │ ├── as_full
│ │ │ ├── batch_as_full_extract.sh
│ │ │ ├── extract_as_full_whisper_all.py
│ │ │ └── extract_as_full_whisper_all.sh
│ │ ├── esc-50
│ │ │ ├── extract_esc50_hubert_xl_all_pool.py
│ │ │ ├── extract_esc50_w2v_robust_all.py
│ │ │ └── extract_esc50_whisper_all_pool.py
│ │ └── whisper_feat_extracrt
│ │ │ ├── .github
│ │ │ └── workflows
│ │ │ │ ├── python-publish.yml
│ │ │ │ └── test.yml
│ │ │ ├── LICENSE
│ │ │ ├── MANIFEST.in
│ │ │ ├── data
│ │ │ ├── README.md
│ │ │ └── meanwhile.json
│ │ │ ├── notebooks
│ │ │ ├── LibriSpeech.ipynb
│ │ │ └── Multilingual_ASR.ipynb
│ │ │ ├── requirements.txt
│ │ │ ├── setup.py
│ │ │ ├── tests
│ │ │ ├── jfk.flac
│ │ │ ├── test_audio.py
│ │ │ ├── test_normalizer.py
│ │ │ ├── test_tokenizer.py
│ │ │ └── test_transcribe.py
│ │ │ └── whisper
│ │ │ ├── __init__.py
│ │ │ ├── __main__.py
│ │ │ ├── assets
│ │ │ ├── gpt2
│ │ │ │ ├── merges.txt
│ │ │ │ ├── special_tokens_map.json
│ │ │ │ ├── tokenizer_config.json
│ │ │ │ └── vocab.json
│ │ │ ├── mel_filters.npz
│ │ │ └── multilingual
│ │ │ │ ├── added_tokens.json
│ │ │ │ ├── merges.txt
│ │ │ │ ├── special_tokens_map.json
│ │ │ │ ├── tokenizer_config.json
│ │ │ │ └── vocab.json
│ │ │ ├── audio.py
│ │ │ ├── decoding.py
│ │ │ ├── model.py
│ │ │ ├── normalizers
│ │ │ ├── __init__.py
│ │ │ ├── basic.py
│ │ │ ├── english.json
│ │ │ └── english.py
│ │ │ ├── tokenizer.py
│ │ │ ├── transcribe.py
│ │ │ ├── utils.py
│ │ │ └── version.py
│ └── plots
│ │ ├── plot_figure1_lower.py
│ │ ├── plot_figure1_upper.py
│ │ ├── plot_figure2.py
│ │ └── plot_figure3.py
└── whisper_at_train
│ ├── class_labels_indices.csv
│ ├── datafiles
│ └── README.md
│ ├── dataloader_feat.py
│ ├── gen_weight_file.py
│ ├── log
│ ├── base_ori.txt
│ ├── large-v1_low.txt
│ ├── large-v2_low.txt
│ ├── large-v2_ori.txt
│ ├── large_v1.txt
│ ├── medium_low.txt
│ ├── medium_ori.txt
│ ├── small_low.txt
│ ├── small_ori.txt
│ └── tiny_ori.txt
│ ├── models.py
│ ├── run.py
│ ├── run_as_full_train.sh
│ ├── traintest.py
│ └── utilities
│ ├── __init__.py
│ ├── compute_flops.py
│ ├── compute_mAP.py
│ ├── rename_state_dict.py
│ ├── stats.py
│ ├── util.py
│ └── whisper_at_as_eval.py
└── tltr.png
/.gitattributes:
--------------------------------------------------------------------------------
1 | * linguist-vendored
2 | *.py linguist-vendored=false
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.py[cod]
3 | *$py.class
4 | *.egg-info
5 | .pytest_cache
6 | .ipynb_checkpoints
7 | thumbs.db
8 | .DS_Store
9 | .idea
10 | old/*
11 | *.pptx
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 2-Clause License
2 |
3 | Copyright (c) 2023, Yuan Gong
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | 1. Redistributions of source code must retain the above copyright notice, this
9 | list of conditions and the following disclaimer.
10 |
11 | 2. Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
19 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
21 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
22 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
23 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
24 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | import whisper_at as whisper
3 |
4 | link = "https://github.com/YuanGongND/whisper-AT"
5 | text = "[Github]"
6 | paper_link = "https://arxiv.org/pdf/2307.03183.pdf"
7 | paper_text = "[Paper]"
8 |
9 | model_large = whisper.load_model("large-v1")
10 | model_tiny = whisper.load_model("tiny")
11 | model_tiny_en = whisper.load_model("tiny.en")
12 | model_small = whisper.load_model("small")
13 |
14 | mdl_dict = {"tiny": model_tiny, "tiny.en": model_tiny_en, "small": model_small, "large": model_large}
15 | lan_dict = {"English": 'en', "Chinese": 'zh'}
16 |
17 | def round_time_resolution(time_resolution):
18 | multiple = float(time_resolution) / 0.4
19 | rounded_multiple = round(multiple)
20 | rounded_time_resolution = rounded_multiple * 0.4
21 | return rounded_time_resolution
22 |
23 | def predict(audio_path_m, audio_path_t, model_size, language, time_resolution):
24 | # print(audio_path_m, audio_path_t)
25 | # print(type(audio_path_m), type(audio_path_t))
26 | #return audio_path_m, audio_path_t
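   |         # exactly one of the two audio inputs (microphone recording or uploaded file) must be provided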
27 |         if (audio_path_m is None) == (audio_path_t is None):
28 |             return "Please provide exactly one recording: either upload an audio file or record with the microphone.", "Please provide exactly one recording: either upload an audio file or record with the microphone."
29 | else:
30 | audio_path = audio_path_m or audio_path_t
31 | audio_tagging_time_resolution = round_time_resolution(time_resolution)
32 | model = mdl_dict[model_size]
33 | if language == 'Auto Detection':
34 | result = model.transcribe(audio_path, at_time_res=audio_tagging_time_resolution)
35 | else:
36 | result = model.transcribe(audio_path, at_time_res=audio_tagging_time_resolution, language=lan_dict[language])
37 | audio_tag_result = whisper.parse_at_label(result, language='follow_asr', top_k=5, p_threshold=-1, include_class_list=list(range(527)))
38 | asr_output = ""
39 | for segment in result['segments']:
40 | asr_output = asr_output + format(segment['start'], ".1f") + 's-' + format(segment['end'], ".1f") + 's: ' + segment['text'] + '\n'
41 | at_output = ""
42 | for segment in audio_tag_result:
43 | print(segment)
44 | at_output = at_output + format(segment['time']['start'], ".1f") + 's-' + format(segment['time']['end'], ".1f") + 's: ' + ', '.join([x[0] for x in segment['audio tags']]) + '\n'
45 | print(at_output)
46 | return asr_output, at_output
47 |
48 | iface = gr.Interface(fn=predict,
49 | inputs=[gr.Audio(type="filepath", source='microphone', label='Please either upload an audio file or record using the microphone.', show_label=True), gr.Audio(type="filepath"),
50 | gr.Radio(["tiny", "tiny.en", "small", "large"], value='large', label="Model size", info="The larger the model, the better the performance and the slower the speed."),
51 | gr.Radio(["Auto Detection", "English", "Chinese"], value='Auto Detection', label="Language", info="Please specify the language, or let the model detect it automatically"),
52 |                              gr.Textbox(value='10', label='Time Resolution in Seconds (Must be an integer multiple of 0.4, e.g., 0.4, 2, 10)')],
53 | outputs=[gr.Textbox(label="Speech Output"), gr.Textbox(label="Audio Tag Output")],
54 | cache_examples=True,
55 | title="Quick Demo of Whisper-AT",
56 |                      description="We are glad to introduce Whisper-AT - A new joint audio tagging and speech recognition model. It outputs background sound labels in addition to text." + f"{paper_text} " + f"{text}" + "<br>" +
57 | "Whisper-AT is authored by Yuan Gong, Sameer Khurana, Leonid Karlinsky, and James Glass (MIT & MIT-IBM Watson AI Lab). It is an Interspeech 2023 paper.")
58 | iface.launch(debug=True, share=True)
--------------------------------------------------------------------------------
/package/whisper-at/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | per-file-ignores =
3 | */__init__.py: F401
4 |
5 |
--------------------------------------------------------------------------------
/package/whisper-at/.gitattributes:
--------------------------------------------------------------------------------
1 | # Override jupyter in Github language stats for more accurate estimate of repo code languages
2 | # reference: https://github.com/github/linguist/blob/master/docs/overrides.md#generated-code
3 | *.ipynb linguist-generated
4 |
--------------------------------------------------------------------------------
/package/whisper-at/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | name: Release
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | jobs:
8 | deploy:
9 | runs-on: ubuntu-latest
10 | steps:
11 | - uses: actions/checkout@v3
12 | - uses: actions-ecosystem/action-regex-match@v2
13 | id: regex-match
14 | with:
15 | text: ${{ github.event.head_commit.message }}
16 | regex: '^Release ([^ ]+)'
17 | - name: Set up Python
18 | uses: actions/setup-python@v4
19 | with:
20 | python-version: '3.8'
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install setuptools wheel twine
25 | - name: Release
26 | if: ${{ steps.regex-match.outputs.match != '' }}
27 | uses: softprops/action-gh-release@v1
28 | with:
29 | tag_name: v${{ steps.regex-match.outputs.group1 }}
30 | - name: Build and publish
31 | if: ${{ steps.regex-match.outputs.match != '' }}
32 | env:
33 | TWINE_USERNAME: __token__
34 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
35 | run: |
36 | python setup.py sdist
37 | twine upload dist/*
38 |
--------------------------------------------------------------------------------
/package/whisper-at/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: test
2 | on:
3 | push:
4 | branches:
5 | - main
6 | pull_request:
7 | branches:
8 | - main
9 | jobs:
10 | whisper-test:
11 | runs-on: ubuntu-latest
12 | strategy:
13 | matrix:
14 | python-version: ['3.8', '3.9', '3.10', '3.11']
15 | pytorch-version: [1.13.1, 2.0.0]
16 | exclude:
17 | - python-version: '3.11'
18 | pytorch-version: 1.13.1
19 | steps:
20 | - uses: conda-incubator/setup-miniconda@v2
21 | - run: conda install -n test ffmpeg python=${{ matrix.python-version }}
22 | - run: pip3 install torch==${{ matrix.pytorch-version }}+cpu --index-url https://download.pytorch.org/whl/cpu
23 | - uses: actions/checkout@v3
24 | - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
25 | - run: pip install .["dev"]
26 | - run: black --check --diff -t py38 --include '(\.pyi?)$' .
27 | - run: isort --check --diff .
28 | - run: flake8 --ignore E203,W503,W504,E501,E731,E741 .
29 | - run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda'
30 |
--------------------------------------------------------------------------------
/package/whisper-at/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 2-Clause License
2 |
3 | Copyright (c) 2023, Yuan Gong
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | 1. Redistributions of source code must retain the above copyright notice, this
9 | list of conditions and the following disclaimer.
10 |
11 | 2. Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
19 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
21 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
22 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
23 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
24 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 |
--------------------------------------------------------------------------------
/package/whisper-at/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include requirements.txt
2 | include README.md
3 | include LICENSE
4 | include whisper_at/assets/*
5 | include whisper_at/normalizers/english.json
6 |
--------------------------------------------------------------------------------
/package/whisper-at/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 |
3 | [tool.isort]
4 | profile = "black"
5 | include_trailing_comma = true
6 | line_length = 88
7 | multi_line_output = 3
8 |
9 |
--------------------------------------------------------------------------------
/package/whisper-at/requirements.txt:
--------------------------------------------------------------------------------
1 | numba
2 | numpy
3 | torch
4 | tqdm
5 | more-itertools
6 | tiktoken==0.3.3
--------------------------------------------------------------------------------
/package/whisper-at/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import platform
3 | import sys
4 |
5 | import pkg_resources
6 | from setuptools import find_packages, setup
7 |
8 | requirements = []
9 | if sys.platform.startswith("linux") and platform.machine() == "x86_64":
10 | requirements.append("triton==2.0.0")
11 |
12 | setup(
13 | name="whisper-at",
14 | py_modules=["whisper_at"],
15 |     version="0.5",
16 | description="Joint speech recognition and audio tagging model.",
17 | long_description=open("README.md", encoding="utf-8").read(),
18 | long_description_content_type="text/markdown",
19 | readme="README.md",
20 | python_requires=">=3.8",
21 | author="Yuan Gong",
22 | url="https://github.com/YuanGongND/whisper-at",
23 | license="BSD",
24 | packages=find_packages(exclude=["tests*"]),
25 | install_requires=requirements
26 | + [
27 | str(r)
28 | for r in pkg_resources.parse_requirements(
29 | open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
30 | )
31 | ],
32 | entry_points={
33 |         "console_scripts": ["whisper_at=whisper_at.transcribe:cli"],
34 | },
35 | include_package_data=True,
36 | extras_require={"dev": ["pytest", "scipy", "black", "flake8", "isort"]},
37 | )
38 |
--------------------------------------------------------------------------------
/package/whisper-at/whisper_at/__init__.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import io
3 | import os
4 | import urllib
5 | import warnings
6 | from typing import List, Optional, Union
7 |
8 | import torch
9 | from tqdm import tqdm
10 |
11 | from .audio import load_audio, log_mel_spectrogram, pad_or_trim
12 | from .decoding import DecodingOptions, DecodingResult, decode, detect_language
13 | from .model import ModelDimensions, Whisper
14 | from .transcribe import transcribe
15 | from .at_post_processing import *
16 | from .version import __version__
17 |
18 | _MODELS = {
19 | "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
20 | "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
21 | "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
22 | "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
23 | "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
24 | "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
25 | "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
26 | "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
27 | "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
28 | "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
29 | "large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
30 | }
31 |
32 | _MODELS_AT = {
33 | "tiny.en": "https://www.dropbox.com/s/atq9so6w0qug5ai/tiny.en_ori.pth?dl=1",
34 | "tiny": "https://www.dropbox.com/s/cib4q4iz6g758l0/tiny_ori.pth?dl=1",
35 | "base.en": "https://www.dropbox.com/s/qtzgsbuquoz0afn/base.en_ori.pth?dl=1",
36 | "base": "https://www.dropbox.com/s/2odwh42u6e9ger7/base_ori.pth?dl=1",
37 | "small.en": "https://www.dropbox.com/s/cyx50ycl1ul7lji/small.en_ori.pth?dl=1",
38 | "small.en_low": "https://www.dropbox.com/s/507o66zgl8v6ddd/small.en_low.pth?dl=1",
39 | "small": "https://www.dropbox.com/s/jftj9s0kr4ycvr1/small_ori.pth?dl=1",
40 | "small_low": "https://www.dropbox.com/s/a1x0416v58f7wrf/small_low.pth?dl=1",
41 | "medium.en": "https://www.dropbox.com/s/bbvylvmgns8ja4p/medium.en_ori.pth?dl=1",
42 | "medium.en_low": "https://www.dropbox.com/s/2q5wprr8f9gti5t/medium.en_low.pth?dl=1",
43 | "medium": "https://www.dropbox.com/s/65aabayr7o819az/medium_ori.pth?dl=1",
44 | "medium_low": "https://www.dropbox.com/s/0mnfmcasram4n6o/medium_low.pth?dl=1",
45 | "large-v1": "https://www.dropbox.com/s/b8x2en1fdzc8nhk/large-v1_ori.pth?dl=1",
46 | "large-v1_low": "https://www.dropbox.com/s/5o79h70wyla8jlk/large-v1_low.pth?dl=1",
47 | "large-v2": "https://www.dropbox.com/s/3zxpyvdrxy22eq7/large-v2_ori.pth?dl=1",
48 | "large-v2_low": "https://www.dropbox.com/s/jw2rh4uylhqgn85/large-v2_low.pth?dl=1",
49 | "large": "https://www.dropbox.com/s/3zxpyvdrxy22eq7/large-v2_ori.pth?dl=1",
50 | "large_low": "https://www.dropbox.com/s/jw2rh4uylhqgn85/large-v2_low.pth?dl=1",
51 | }
52 |
53 | # base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
54 | # highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
55 | _ALIGNMENT_HEADS = {
56 | "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
57 | "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
58 | "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
59 | "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
61 | "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P%R7%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
64 | "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
66 | "large": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
67 | }
68 |
69 |
70 | def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
71 | os.makedirs(root, exist_ok=True)
72 |
73 | expected_sha256 = url.split("/")[-2]
74 | parsed_url = urllib.parse.urlparse(url).path
75 | download_target = os.path.join(root, os.path.basename(parsed_url))
76 |
77 | if os.path.exists(download_target) and not os.path.isfile(download_target):
78 | raise RuntimeError(f"{download_target} exists and is not a regular file")
79 |
80 | if os.path.isfile(download_target):
81 | with open(download_target, "rb") as f:
82 | model_bytes = f.read()
83 | #if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
84 | return model_bytes if in_memory else download_target
85 | # else:
86 | # warnings.warn(
87 | # f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
88 | # )
89 |
90 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
91 | with tqdm(
92 | total=int(source.info().get("Content-Length")),
93 | ncols=80,
94 | unit="iB",
95 | unit_scale=True,
96 | unit_divisor=1024,
97 | ) as loop:
98 | while True:
99 | buffer = source.read(8192)
100 | if not buffer:
101 | break
102 |
103 | output.write(buffer)
104 | loop.update(len(buffer))
105 |
106 | model_bytes = open(download_target, "rb").read()
107 | # if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
108 | # raise RuntimeError(
109 | # "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
110 | # )
111 |
112 | return model_bytes if in_memory else download_target
113 |
114 |
115 | def available_models() -> List[str]:
116 | """Returns the names of available models"""
117 | return list(_MODELS.keys())
118 |
119 |
120 | def load_model(
121 | name: str,
122 | device: Optional[Union[str, torch.device]] = None,
123 | download_root: str = None,
124 | in_memory: bool = False,
125 | at_low_compute = False
126 | ) -> Whisper:
127 | """
128 | Load a Whisper ASR model
129 |
130 | Parameters
131 | ----------
132 | name : str
133 | one of the official model names listed by `whisper.available_models()`, or
134 | path to a model checkpoint containing the model dimensions and the model state_dict.
135 | device : Union[str, torch.device]
136 | the PyTorch device to put the model into
137 | download_root: str
138 | path to download the model files; by default, it uses "~/.cache/whisper"
139 | in_memory: bool
140 | whether to preload the model weights into host memory
141 |
142 | Returns
143 | -------
144 | model : Whisper
145 | The Whisper ASR model instance
146 | """
147 |
148 | if device is None:
149 | device = "cuda" if torch.cuda.is_available() else "cpu"
150 | if download_root is None:
151 | default = os.path.join(os.path.expanduser("~"), ".cache")
152 | download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
153 |
154 |     # the low-compute (low-dim projection) AT variant is only available for the small, medium, and large models
155 |     if at_low_compute:
156 | at_mdl_name = name + '_low'
157 | else:
158 | at_mdl_name = name
159 |
160 | if name in _MODELS:
161 | checkpoint_file = _download(_MODELS[name], download_root, in_memory)
162 | checkpoint_file_at = _download(_MODELS_AT[at_mdl_name], download_root, in_memory)
163 | alignment_heads = _ALIGNMENT_HEADS[name]
164 | elif os.path.isfile(name):
165 | checkpoint_file = open(name, "rb").read() if in_memory else name
166 | alignment_heads = None
167 | else:
168 | raise RuntimeError(
169 | f"Model {name} not found; available models = {available_models()}"
170 | )
171 |
172 | with (
173 | io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
174 | ) as fp:
175 | checkpoint = torch.load(fp, map_location=device)
176 | del checkpoint_file
177 |
178 | with (
179 | io.BytesIO(checkpoint_file_at) if in_memory else open(checkpoint_file_at, "rb")
180 | ) as fp:
181 | checkpoint_at = torch.load(fp, map_location=device)
182 | del checkpoint_file_at
183 |
184 | dims = ModelDimensions(**checkpoint["dims"])
185 | model = Whisper(dims, at_low_compute=at_low_compute)
186 |
187 | combined_state_dict = {}
188 | combined_state_dict.update(checkpoint["model_state_dict"])
189 | combined_state_dict.update(checkpoint_at)
190 |
191 | model.load_state_dict(combined_state_dict, strict=True)
192 |
193 | if alignment_heads is not None:
194 | model.set_alignment_heads(alignment_heads)
195 |
196 | return model.to(device)
197 |
--------------------------------------------------------------------------------
/package/whisper-at/whisper_at/__main__.py:
--------------------------------------------------------------------------------
1 | from .transcribe import cli
2 |
3 | cli()
4 |
--------------------------------------------------------------------------------
/package/whisper-at/whisper_at/assets/mel_filters.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuanGongND/whisper-at/17d94d6acd53866390ce70f95afa13507dcb8aef/package/whisper-at/whisper_at/assets/mel_filters.npz
--------------------------------------------------------------------------------
/package/whisper-at/whisper_at/at_post_processing.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 6/1/23 3:33 PM
3 | # @Author : Yuan Gong
4 | # @Affiliation : Massachusetts Institute of Technology
5 | # @Email : yuangong@mit.edu
6 | # @File : at_post_processing.py
7 |
8 | import os
9 | import json
10 | import torch
11 | import warnings
12 | from .tokenizer import LANGUAGES
13 |
14 | def parse_at_label(result, language='follow_asr', top_k=5, p_threshold=-1, include_class_list=list(range(527))):
15 | """
16 | :param result: The result dict returned by the whisper-at transcribe function.
17 | :param language: The audio tag label name language, e.g., 'en', 'zh'. Default='follow_asr', i.e., same with ASR result.
18 | :param top_k: Output up to k sound classes that have logits above p_threshold. Default=5.
19 | :param p_threshold: The logit threshold to predict a sound class. Default=-1.
20 |     :param include_class_list: A list of class indexes of interest. Default = list(range(527)) (all classes).
21 | :return: A dictionary of audio tagging results
22 | """
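   |     # Illustrative example (made-up values), assuming `result` came from model.transcribe(..., at_time_res=10):
   |     #   parse_at_label(result, top_k=3)
   |     #   -> [{'time': {'start': 0, 'end': 10},
   |     #        'audio tags': [('Music', 2.1), ('Speech', 1.4), ('Guitar', 0.3)]}, ...]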
23 | asr_language = result['language']
24 | at_time_res = result['at_time_res']
25 | audio_tag = result['audio_tag']
26 |
27 | if language == 'follow_asr':
28 | language = asr_language
29 |
30 | with open(os.path.join(os.path.dirname(__file__), "assets", "label_name_dict.json")) as json_file:
31 | label_name_dict = json.load(json_file)
32 |
33 | if language not in label_name_dict.keys():
34 |         warnings.warn("Language '{:s}' is not supported; using English label names instead. To use label names in a specific language, please specify the language argument.".format(language))
35 | language = 'en'
36 |
37 | label_name_list = label_name_dict[language]
38 |
39 | all_res = []
40 | for i in range(audio_tag.shape[0]):
41 | top_values, top_indices = torch.topk(audio_tag[i], k=top_k)
42 | cur_time_stamp = {'start': i*at_time_res, 'end': (i+1)*at_time_res}
43 | cur_labels_list = []
44 | for j in range(top_indices.shape[0]):
45 | if top_values[j] > p_threshold and top_indices[j] in include_class_list:
46 | cur_label = (label_name_list[top_indices[j]], top_values[j].item())
47 | cur_labels_list.append(cur_label)
48 | all_res.append({'time': cur_time_stamp, 'audio tags': cur_labels_list})
49 | return all_res
50 |
51 | def print_label_name(language='en'):
52 | with open(os.path.join(os.path.dirname(__file__), "assets", "label_name_dict.json")) as json_file:
53 | label_name_dict = json.load(json_file)
54 | label_name_list = label_name_dict[language]
55 | for i in range(len(label_name_list)):
56 | print("index: {:d} : {:s}".format(i, label_name_list[i]))
57 |
58 | def print_support_language():
59 | with open(os.path.join(os.path.dirname(__file__), "assets", "label_name_dict.json")) as json_file:
60 | label_name_dict = json.load(json_file)
61 | for key in label_name_dict.keys():
62 | print("language code: {:s} : {:s}".format(key, LANGUAGES[key]))
63 |
64 | if __name__ == '__main__':
65 | print_support_language()
66 | print_label_name(language='zh')
67 |
68 |
--------------------------------------------------------------------------------
/package/whisper-at/whisper_at/audio.py:
--------------------------------------------------------------------------------
1 | import os
2 | from functools import lru_cache
3 | from subprocess import CalledProcessError, run
4 | from typing import Optional, Union
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 |
10 | from .utils import exact_div
11 |
12 | # hard-coded audio hyperparameters
13 | SAMPLE_RATE = 16000
14 | N_FFT = 400
15 | N_MELS = 80
16 | HOP_LENGTH = 160
17 | CHUNK_LENGTH = 30
18 | N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
19 | N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input
20 |
21 | N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2
22 | FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
23 | TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
24 |
25 |
26 | def load_audio(file: str, sr: int = SAMPLE_RATE):
27 | """
28 | Open an audio file and read as mono waveform, resampling as necessary
29 |
30 | Parameters
31 | ----------
32 | file: str
33 | The audio file to open
34 |
35 | sr: int
36 | The sample rate to resample the audio if necessary
37 |
38 | Returns
39 | -------
40 | A NumPy array containing the audio waveform, in float32 dtype.
41 | """
42 |
43 | # This launches a subprocess to decode audio while down-mixing
44 | # and resampling as necessary. Requires the ffmpeg CLI in PATH.
45 | # fmt: off
46 | cmd = [
47 | "ffmpeg",
48 | "-nostdin",
49 | "-threads", "0",
50 | "-i", file,
51 | "-f", "s16le",
52 | "-ac", "1",
53 | "-acodec", "pcm_s16le",
54 | "-ar", str(sr),
55 | "-"
56 | ]
57 | # fmt: on
58 | try:
59 | out = run(cmd, capture_output=True, check=True).stdout
60 | except CalledProcessError as e:
61 | raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
62 |
63 | return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
64 |
65 |
66 | def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
67 | """
68 | Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
69 | """
70 | if torch.is_tensor(array):
71 | if array.shape[axis] > length:
72 | array = array.index_select(
73 | dim=axis, index=torch.arange(length, device=array.device)
74 | )
75 |
76 | if array.shape[axis] < length:
77 | pad_widths = [(0, 0)] * array.ndim
78 | pad_widths[axis] = (0, length - array.shape[axis])
79 | array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
80 | else:
81 | if array.shape[axis] > length:
82 | array = array.take(indices=range(length), axis=axis)
83 |
84 | if array.shape[axis] < length:
85 | pad_widths = [(0, 0)] * array.ndim
86 | pad_widths[axis] = (0, length - array.shape[axis])
87 | array = np.pad(array, pad_widths)
88 |
89 | return array
90 |
91 |
92 | @lru_cache(maxsize=None)
93 | def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
94 | """
95 | load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
96 | Allows decoupling librosa dependency; saved using:
97 |
98 | np.savez_compressed(
99 | "mel_filters.npz",
100 | mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
101 | )
102 | """
103 | assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
104 | with np.load(
105 | os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
106 | ) as f:
107 | return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
108 |
109 |
110 | def log_mel_spectrogram(
111 | audio: Union[str, np.ndarray, torch.Tensor],
112 | n_mels: int = N_MELS,
113 | padding: int = 0,
114 | device: Optional[Union[str, torch.device]] = None,
115 | ):
116 | """
117 |     Compute the log-Mel spectrogram of the input audio.
118 |
119 | Parameters
120 | ----------
121 | audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
122 |         The path to an audio file, or a NumPy array or Tensor containing the audio waveform sampled at 16 kHz
123 |
124 | n_mels: int
125 | The number of Mel-frequency filters, only 80 is supported
126 |
127 | padding: int
128 | Number of zero samples to pad to the right
129 |
130 | device: Optional[Union[str, torch.device]]
131 | If given, the audio tensor is moved to this device before STFT
132 |
133 | Returns
134 | -------
135 | torch.Tensor, shape = (80, n_frames)
136 | A Tensor that contains the Mel spectrogram
137 | """
138 | if not torch.is_tensor(audio):
139 | if isinstance(audio, str):
140 | audio = load_audio(audio)
141 | audio = torch.from_numpy(audio)
142 |
143 | if device is not None:
144 | audio = audio.to(device)
145 | if padding > 0:
146 | audio = F.pad(audio, (0, padding))
147 | window = torch.hann_window(N_FFT).to(audio.device)
148 | stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
149 | magnitudes = stft[..., :-1].abs() ** 2
150 |
151 | filters = mel_filters(audio.device, n_mels)
152 | mel_spec = filters @ magnitudes
153 |
154 | log_spec = torch.clamp(mel_spec, min=1e-10).log10()
155 | log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
156 | log_spec = (log_spec + 4.0) / 4.0
157 | return log_spec
158 |
--------------------------------------------------------------------------------
/package/whisper-at/whisper_at/normalizers/__init__.py:
--------------------------------------------------------------------------------
1 | from .basic import BasicTextNormalizer as BasicTextNormalizer
2 | from .english import EnglishTextNormalizer as EnglishTextNormalizer
3 |
--------------------------------------------------------------------------------
/package/whisper-at/whisper_at/normalizers/basic.py:
--------------------------------------------------------------------------------
1 | import re
2 | import unicodedata
3 |
4 | import regex
5 |
6 | # non-ASCII letters that are not separated by "NFKD" normalization
7 | ADDITIONAL_DIACRITICS = {
8 | "œ": "oe",
9 | "Œ": "OE",
10 | "ø": "o",
11 | "Ø": "O",
12 | "æ": "ae",
13 | "Æ": "AE",
14 | "ß": "ss",
15 | "ẞ": "SS",
16 | "đ": "d",
17 | "Đ": "D",
18 | "ð": "d",
19 | "Ð": "D",
20 | "þ": "th",
21 | "Þ": "th",
22 | "ł": "l",
23 | "Ł": "L",
24 | }
25 |
26 |
27 | def remove_symbols_and_diacritics(s: str, keep=""):
28 | """
29 | Replace any other markers, symbols, and punctuations with a space,
30 | and drop any diacritics (category 'Mn' and some manual mappings)
31 | """
32 | return "".join(
33 | c
34 | if c in keep
35 | else ADDITIONAL_DIACRITICS[c]
36 | if c in ADDITIONAL_DIACRITICS
37 | else ""
38 | if unicodedata.category(c) == "Mn"
39 | else " "
40 | if unicodedata.category(c)[0] in "MSP"
41 | else c
42 | for c in unicodedata.normalize("NFKD", s)
43 | )
44 |
45 |
46 | def remove_symbols(s: str):
47 | """
48 | Replace any other markers, symbols, punctuations with a space, keeping diacritics
49 | """
50 | return "".join(
51 | " " if unicodedata.category(c)[0] in "MSP" else c
52 | for c in unicodedata.normalize("NFKC", s)
53 | )
54 |
55 |
56 | class BasicTextNormalizer:
57 | def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
58 | self.clean = (
59 | remove_symbols_and_diacritics if remove_diacritics else remove_symbols
60 | )
61 | self.split_letters = split_letters
62 |
63 | def __call__(self, s: str):
64 | s = s.lower()
65 | s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
66 | s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
67 | s = self.clean(s).lower()
68 |
69 | if self.split_letters:
70 | s = " ".join(regex.findall(r"\X", s, regex.U))
71 |
72 | s = re.sub(
73 | r"\s+", " ", s
74 | ) # replace any successive whitespace characters with a space
75 |
76 | return s
77 |
--------------------------------------------------------------------------------
/package/whisper-at/whisper_at/triton_ops.py:
--------------------------------------------------------------------------------
1 | from functools import lru_cache
2 |
3 | import numpy as np
4 | import torch
5 |
6 | try:
7 | import triton
8 | import triton.language as tl
9 | except ImportError:
10 | raise RuntimeError("triton import failed; try `pip install --pre triton`")
11 |
12 |
13 | @triton.jit
14 | def dtw_kernel(
15 | cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr
16 | ):
17 | offsets = tl.arange(0, BLOCK_SIZE)
18 | mask = offsets < M
19 |
20 | for k in range(1, N + M + 1): # k = i + j
21 | tl.debug_barrier()
22 |
23 | p0 = cost + (k - 1) * cost_stride
24 | p1 = cost + k * cost_stride
25 | p2 = cost + k * cost_stride + 1
26 |
27 | c0 = tl.load(p0 + offsets, mask=mask)
28 | c1 = tl.load(p1 + offsets, mask=mask)
29 | c2 = tl.load(p2 + offsets, mask=mask)
30 |
31 | x_row = tl.load(x + (k - 1) * x_stride + offsets, mask=mask, other=0)
32 | cost_row = x_row + tl.minimum(tl.minimum(c0, c1), c2)
33 |
34 | cost_ptr = cost + (k + 1) * cost_stride + 1
35 | tl.store(cost_ptr + offsets, cost_row, mask=mask)
36 |
37 | trace_ptr = trace + (k + 1) * trace_stride + 1
38 | tl.store(trace_ptr + offsets, 2, mask=mask & (c2 <= c0) & (c2 <= c1))
39 | tl.store(trace_ptr + offsets, 1, mask=mask & (c1 <= c0) & (c1 <= c2))
40 | tl.store(trace_ptr + offsets, 0, mask=mask & (c0 <= c1) & (c0 <= c2))
41 |
42 |
43 | @lru_cache(maxsize=None)
44 | def median_kernel(filter_width: int):
45 | @triton.jit
46 | def kernel(
47 | y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr
48 | ): # x.shape[-1] == filter_width
49 | row_idx = tl.program_id(0)
50 | offsets = tl.arange(0, BLOCK_SIZE)
51 | mask = offsets < y_stride
52 |
53 | x_ptr = x + row_idx * x_stride # noqa: F841
54 | y_ptr = y + row_idx * y_stride
55 |
56 | LOAD_ALL_ROWS_HERE # noqa: F821
57 |
58 | BUBBLESORT_HERE # noqa: F821
59 |
60 | tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821
61 |
62 | kernel = triton.JITFunction(kernel.fn)
63 | kernel.src = kernel.src.replace(
64 | " LOAD_ALL_ROWS_HERE",
65 | "\n".join(
66 | [
67 | f" row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)"
68 | for i in range(filter_width)
69 | ]
70 | ),
71 | )
72 | kernel.src = kernel.src.replace(
73 | " BUBBLESORT_HERE",
74 | "\n\n".join(
75 | [
76 | "\n\n".join(
77 | [
78 | "\n".join(
79 | [
80 | f" smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})",
81 | f" larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})",
82 | f" row{j} = smaller",
83 | f" row{j + 1} = larger",
84 | ]
85 | )
86 | for j in range(filter_width - i - 1)
87 | ]
88 | )
89 | for i in range(filter_width // 2 + 1)
90 | ]
91 | ),
92 | )
93 | kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
94 |
95 | return kernel
96 |
97 |
98 | def median_filter_cuda(x: torch.Tensor, filter_width: int):
99 | """Apply a median filter of given width along the last dimension of x"""
100 | slices = x.contiguous().unfold(-1, filter_width, 1)
101 | grid = np.prod(slices.shape[:-2])
102 |
103 | kernel = median_kernel(filter_width)
104 | y = torch.empty_like(slices[..., 0])
105 |
106 | BLOCK_SIZE = 1 << (y.stride(-2) - 1).bit_length()
107 | kernel[(grid,)](y, x, x.stride(-2), y.stride(-2), BLOCK_SIZE=BLOCK_SIZE)
108 |
109 | return y
110 |
--------------------------------------------------------------------------------
/package/whisper-at/whisper_at/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import re
4 | import sys
5 | import zlib
6 | from typing import Callable, Optional, TextIO
7 |
8 | system_encoding = sys.getdefaultencoding()
9 |
10 | if system_encoding != "utf-8":
11 |
12 | def make_safe(string):
13 |         # replaces any character not representable using the system default encoding with a '?',
14 | # avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
15 | return string.encode(system_encoding, errors="replace").decode(system_encoding)
16 |
17 | else:
18 |
19 | def make_safe(string):
20 | # utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
21 | return string
22 |
23 |
24 | def exact_div(x, y):
25 | assert x % y == 0
26 | return x // y
27 |
28 |
29 | def str2bool(string):
30 | str2val = {"True": True, "False": False}
31 | if string in str2val:
32 | return str2val[string]
33 | else:
34 | raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
35 |
36 |
37 | def optional_int(string):
38 | return None if string == "None" else int(string)
39 |
40 |
41 | def optional_float(string):
42 | return None if string == "None" else float(string)
43 |
44 |
45 | def compression_ratio(text) -> float:
46 | text_bytes = text.encode("utf-8")
47 | return len(text_bytes) / len(zlib.compress(text_bytes))
48 |
49 |
50 | def format_timestamp(
51 | seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
52 | ):
53 | assert seconds >= 0, "non-negative timestamp expected"
54 | milliseconds = round(seconds * 1000.0)
55 |
56 | hours = milliseconds // 3_600_000
57 | milliseconds -= hours * 3_600_000
58 |
59 | minutes = milliseconds // 60_000
60 | milliseconds -= minutes * 60_000
61 |
62 | seconds = milliseconds // 1_000
63 | milliseconds -= seconds * 1_000
64 |
65 | hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
66 | return (
67 | f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
68 | )
69 |
70 |
71 | class ResultWriter:
72 | extension: str
73 |
74 | def __init__(self, output_dir: str):
75 | self.output_dir = output_dir
76 |
77 | def __call__(self, result: dict, audio_path: str, options: dict):
78 | audio_basename = os.path.basename(audio_path)
79 | audio_basename = os.path.splitext(audio_basename)[0]
80 | output_path = os.path.join(
81 | self.output_dir, audio_basename + "." + self.extension
82 | )
83 |
84 | with open(output_path, "w", encoding="utf-8") as f:
85 | self.write_result(result, file=f, options=options)
86 |
87 | def write_result(self, result: dict, file: TextIO, options: dict):
88 | raise NotImplementedError
89 |
90 |
91 | class WriteTXT(ResultWriter):
92 | extension: str = "txt"
93 |
94 | def write_result(self, result: dict, file: TextIO, options: dict):
95 | for segment in result["segments"]:
96 | print(segment["text"].strip(), file=file, flush=True)
97 |
98 |
99 | class SubtitlesWriter(ResultWriter):
100 | always_include_hours: bool
101 | decimal_marker: str
102 |
103 | def iterate_result(self, result: dict, options: dict):
104 | raw_max_line_width: Optional[int] = options["max_line_width"]
105 | max_line_count: Optional[int] = options["max_line_count"]
106 | highlight_words: bool = options["highlight_words"]
107 | max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width
108 | preserve_segments = max_line_count is None or raw_max_line_width is None
109 |
110 | def iterate_subtitles():
111 | line_len = 0
112 | line_count = 1
113 | # the next subtitle to yield (a list of word timings with whitespace)
114 | subtitle: list[dict] = []
115 | last = result["segments"][0]["words"][0]["start"]
116 | for segment in result["segments"]:
117 | for i, original_timing in enumerate(segment["words"]):
118 | timing = original_timing.copy()
119 | long_pause = not preserve_segments and timing["start"] - last > 3.0
120 | has_room = line_len + len(timing["word"]) <= max_line_width
121 | seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
122 | if line_len > 0 and has_room and not long_pause and not seg_break:
123 | # line continuation
124 | line_len += len(timing["word"])
125 | else:
126 | # new line
127 | timing["word"] = timing["word"].strip()
128 | if (
129 | len(subtitle) > 0
130 | and max_line_count is not None
131 | and (long_pause or line_count >= max_line_count)
132 | or seg_break
133 | ):
134 | # subtitle break
135 | yield subtitle
136 | subtitle = []
137 | line_count = 1
138 | elif line_len > 0:
139 | # line break
140 | line_count += 1
141 | timing["word"] = "\n" + timing["word"]
142 | line_len = len(timing["word"].strip())
143 | subtitle.append(timing)
144 | last = timing["start"]
145 | if len(subtitle) > 0:
146 | yield subtitle
147 |
148 | if "words" in result["segments"][0]:
149 | for subtitle in iterate_subtitles():
150 | subtitle_start = self.format_timestamp(subtitle[0]["start"])
151 | subtitle_end = self.format_timestamp(subtitle[-1]["end"])
152 | subtitle_text = "".join([word["word"] for word in subtitle])
153 | if highlight_words:
154 | last = subtitle_start
155 | all_words = [timing["word"] for timing in subtitle]
156 | for i, this_word in enumerate(subtitle):
157 | start = self.format_timestamp(this_word["start"])
158 | end = self.format_timestamp(this_word["end"])
159 | if last != start:
160 | yield last, start, subtitle_text
161 |
162 | yield start, end, "".join(
163 | [
164 | re.sub(r"^(\s*)(.*)$", r"\1\2", word)
165 | if j == i
166 | else word
167 | for j, word in enumerate(all_words)
168 | ]
169 | )
170 | last = end
171 | else:
172 | yield subtitle_start, subtitle_end, subtitle_text
173 | else:
174 | for segment in result["segments"]:
175 | segment_start = self.format_timestamp(segment["start"])
176 | segment_end = self.format_timestamp(segment["end"])
177 | segment_text = segment["text"].strip().replace("-->", "->")
178 | yield segment_start, segment_end, segment_text
179 |
180 | def format_timestamp(self, seconds: float):
181 | return format_timestamp(
182 | seconds=seconds,
183 | always_include_hours=self.always_include_hours,
184 | decimal_marker=self.decimal_marker,
185 | )
186 |
187 |
188 | class WriteVTT(SubtitlesWriter):
189 | extension: str = "vtt"
190 | always_include_hours: bool = False
191 | decimal_marker: str = "."
192 |
193 | def write_result(self, result: dict, file: TextIO, options: dict):
194 | print("WEBVTT\n", file=file)
195 | for start, end, text in self.iterate_result(result, options):
196 | print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
197 |
198 |
199 | class WriteSRT(SubtitlesWriter):
200 | extension: str = "srt"
201 | always_include_hours: bool = True
202 | decimal_marker: str = ","
203 |
204 | def write_result(self, result: dict, file: TextIO, options: dict):
205 | for i, (start, end, text) in enumerate(
206 | self.iterate_result(result, options), start=1
207 | ):
208 | print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
209 |
210 |
211 | class WriteTSV(ResultWriter):
212 | """
213 | Write a transcript to a file in TSV (tab-separated values) format containing lines like:
214 |     <start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>
215 |
216 | Using integer milliseconds as start and end times means there's no chance of interference from
217 | an environment setting a language encoding that causes the decimal in a floating point number
218 | to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
219 | """
220 |
221 | extension: str = "tsv"
222 |
223 | def write_result(self, result: dict, file: TextIO, options: dict):
224 | print("start", "end", "text", sep="\t", file=file)
225 | for segment in result["segments"]:
226 | print(round(1000 * segment["start"]), file=file, end="\t")
227 | print(round(1000 * segment["end"]), file=file, end="\t")
228 | print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
229 |
230 |
231 | class WriteJSON(ResultWriter):
232 | extension: str = "json"
233 |
234 | def write_result(self, result: dict, file: TextIO, options: dict):
235 | json.dump(result, file)
236 |
237 |
238 | def get_writer(
239 | output_format: str, output_dir: str
240 | ) -> Callable[[dict, TextIO, dict], None]:
241 | writers = {
242 | "txt": WriteTXT,
243 | "vtt": WriteVTT,
244 | "srt": WriteSRT,
245 | "tsv": WriteTSV,
246 | "json": WriteJSON,
247 | }
248 |
249 | if output_format == "all":
250 | all_writers = [writer(output_dir) for writer in writers.values()]
251 |
252 | def write_all(result: dict, file: TextIO, options: dict):
253 | for writer in all_writers:
254 | writer(result, file, options)
255 |
256 | return write_all
257 |
258 | return writers[output_format](output_dir)
259 |
--------------------------------------------------------------------------------
/package/whisper-at/whisper_at/version.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.5"
2 |
--------------------------------------------------------------------------------
/poster.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuanGongND/whisper-at/17d94d6acd53866390ce70f95afa13507dcb8aef/poster.pdf
--------------------------------------------------------------------------------
/poster.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuanGongND/whisper-at/17d94d6acd53866390ce70f95afa13507dcb8aef/poster.png
--------------------------------------------------------------------------------
/poster_low.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuanGongND/whisper-at/17d94d6acd53866390ce70f95afa13507dcb8aef/poster_low.png
--------------------------------------------------------------------------------
/pretrained_models/README.md:
--------------------------------------------------------------------------------
1 | # Pretrained Weights
2 |
3 | The Whisper-AT script automatically downloads both the original OpenAI Whisper model and our audio tagging (AT) model, so you normally do not need to download anything manually.
4 | If your device has no Internet access, use the links below: download the model and place it in the same directory as the original OpenAI Whisper model (by default `~/.cache/whisper`).
5 |
6 | These links support `wget`.
7 |
8 | ```python
9 | dropbox_path = {
10 | "tiny.en": "https://www.dropbox.com/s/atq9so6w0qug5ai/tiny.en_ori.pth?dl=1",
11 | "tiny": "https://www.dropbox.com/s/06f2h29aki39q9r/tiny_ori.pth?dl=1",
12 | "base.en": "https://www.dropbox.com/s/qtzgsbuquoz0afn/base.en_ori.pth?dl=1",
13 | "base": "https://www.dropbox.com/s/4vn2oatda321y7h/base_ori.pth?dl=1",
14 | "small.en": "https://www.dropbox.com/s/cyx50ycl1ul7lji/small.en_ori.pth?dl=1",
15 | "small.en_low": "https://www.dropbox.com/s/507o66zgl8v6ddd/small.en_low.pth?dl=1",
16 | "small": "https://www.dropbox.com/s/5zqzs3e47zwhjd3/small_ori.pth?dl=1",
17 | "small_low": "https://www.dropbox.com/s/3lxlmh437tneifl/small_low.pth?dl=1",
18 | "medium.en": "https://www.dropbox.com/s/bbvylvmgns8ja4p/medium.en_ori.pth?dl=1",
19 | "medium.en_low": "https://www.dropbox.com/s/2q5wprr8f9gti5t/medium.en_low.pth?dl=1",
20 | "medium": "https://www.dropbox.com/s/93zfj4afmv0qfyl/medium_ori.pth?dl=1",
21 | "medium_low": "https://www.dropbox.com/s/g66h1vtn1u426dj/medium_low.pth?dl=1",
22 | "large-v1": "https://www.dropbox.com/s/b8x2en1fdzc8nhk/large-v1_ori.pth?dl=1",
23 | "large-v1_low": "https://www.dropbox.com/s/5o79h70wyla8jlk/large-v1_low.pth?dl=1",
24 | "large-v2": "https://www.dropbox.com/s/94x7wqw4hscpls0/large-v2_ori.pth?dl=1",
25 | "large-v2_low": "https://www.dropbox.com/s/wk5dyxustpji06c/large-v2_low.pth?dl=1",
26 | "large": "https://www.dropbox.com/s/94x7wqw4hscpls0/large-v2_ori.pth?dl=1",
27 |     "large_low": "https://www.dropbox.com/s/wk5dyxustpji06c/large-v2_low.pth?dl=1"}
28 | ```
29 |
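   | If `wget` is not available, the snippet below is a minimal Python sketch (not an official repo script; it assumes the default cache directory searched by `load_model`, i.e. `$XDG_CACHE_HOME/whisper`, falling back to `~/.cache/whisper`) for fetching a single checkpoint:
   |
   | ```python
   | # Hypothetical helper, not part of the repo: download one AT checkpoint into the
   | # default Whisper cache directory so that load_model() can find it offline.
   | import os
   | import urllib.request
   |
   | url = "https://www.dropbox.com/s/atq9so6w0qug5ai/tiny.en_ori.pth?dl=1"  # "tiny.en" entry above
   | cache_dir = os.path.join(
   |     os.getenv("XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache")), "whisper"
   | )
   | os.makedirs(cache_dir, exist_ok=True)
   | target = os.path.join(cache_dir, "tiny.en_ori.pth")  # keep the original Dropbox file name
   | urllib.request.urlretrieve(url, target)
   | print("saved to", target)
   | ```
   |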
30 | ## China Mirror Links 镜像链接
31 |
32 | The models are hosted on Dropbox. If Dropbox is not accessible, either use a VPN or use the mirror link below; in the latter case you will need to download the model manually and place it in the same directory (by default `~/.cache/whisper`) as the original OpenAI Whisper model.
33 | [[Mirror Link (Tencent Weiyun) 镜像链接(腾讯微云)]](https://share.weiyun.com/bVxQWxTe)
--------------------------------------------------------------------------------
/review/author_response.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuanGongND/whisper-at/17d94d6acd53866390ce70f95afa13507dcb8aef/review/author_response.pdf
--------------------------------------------------------------------------------
/review/whisper_at_review.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuanGongND/whisper-at/17d94d6acd53866390ce70f95afa13507dcb8aef/review/whisper_at_review.pdf
--------------------------------------------------------------------------------
/sample/whisper_transcribe_test_simple.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 5/28/23 2:36 PM
3 | # @Author : Yuan Gong
4 | # @Affiliation : Massachusetts Institute of Technology
5 | # @Email : yuangong@mit.edu
6 | # @File : whisper_transcribe_test.py
7 |
8 | import whisper_at as whisper
9 |
10 | model = whisper.load_model("large-v1")
11 | result = model.transcribe("/data/sls/scratch/yuangong/dataset/adress_train/ADReSS-IS2020-data/train/Full_wave_enhanced_audio/cc/S024.wav")
12 | #result = model.transcribe("/data/sls/scratch/yuangong/whisper-at/sample_audio/007P6bFgRCU_10.000.flac", at_time_res=2)
13 |
14 | print(result['text'])
15 | print(result['segments'])
16 | text_segments = result['segments']
17 | text_annotation = [(x['start'], x['end'], x['text']) for x in text_segments]
18 | print(text_annotation)
19 | at_res = whisper.parse_at_label(result, language='en', p_threshold=-1)
20 | print(at_res)
21 |
22 | all_seg = []
23 | for segment in at_res:
24 | cur_start = segment['time']['start']
25 | cur_end = segment['time']['end']
26 | cur_tags = segment['audio tags']
27 | cur_tags = [x[0] for x in cur_tags]
28 | cur_tags = '; '.join(cur_tags)
29 | all_seg.append((cur_start, cur_end, cur_tags))
30 | print(all_seg)
31 |
32 | whisper.print_support_language()
33 | whisper.print_label_name(language='zh')
--------------------------------------------------------------------------------
/src/noise_robust_asr/asr_experiments/compute_wer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 3/3/23 6:27 PM
3 | # @Author : Yuan Gong
4 | # @Affiliation : Massachusetts Institute of Technology
5 | # @Email : yuangong@mit.edu
6 | # @File : cal_wer.py
7 |
8 | import os
9 | import editdistance
10 | import jiwer
11 | import numpy as np
12 |
13 | def fileList(source):
14 | matches = []
15 | for root, dirnames, filenames in os.walk(source):
16 | for filename in filenames:
17 | if filename.endswith(('.txt')):
18 | matches.append(os.path.join(root, filename))
19 | return matches
20 |
21 | def calculate_wer(seqs_hat, seqs_true):
22 | """Calculate sentence-level WER score.
23 | :param list seqs_hat: prediction
24 | :param list seqs_true: reference
25 | :return: average sentence-level WER score
26 | :rtype float
27 | """
28 | word_eds, word_ref_lens = [], []
29 | for i in range(len(seqs_true)):
30 | seq_true_text = seqs_true[i]
31 | seq_hat_text = seqs_hat[i]
32 | hyp_words = seq_hat_text.split()
33 | ref_words = seq_true_text.split()
34 | word_eds.append(editdistance.eval(hyp_words, ref_words))
35 | word_ref_lens.append(len(ref_words))
36 | return float(sum(word_eds)) / sum(word_ref_lens)
37 |
38 | def eval_noise_wer(trans_path, result_path):
39 | whisper_trans = fileList(trans_path)
40 | truth_path = '/data/sls/scratch/yuangong/whisper-a/src/noisy_exp/ground_truth_trans/'
41 | truth_trans = fileList(truth_path)
42 | print(len(whisper_trans), len(truth_trans))
43 |
44 | def preprocess_text(cur_trans):
45 | cur_trans = jiwer.ToUpperCase()(cur_trans)
46 | cur_trans = jiwer.RemovePunctuation()(cur_trans)
47 | return cur_trans
48 |
49 | wer_list = []
50 | for db in [-20, -15, -10, -5, 0, 5, 10, 15, 20]:
51 | cur_trans_list, cur_truth_list = [], []
52 | for trans_name in whisper_trans:
53 | if int(trans_name.split('/')[-1].split('_')[0]) == db:
54 | with open(trans_name, "r") as f:
55 | cur_trans = f.read()
56 | cur_trans = preprocess_text(cur_trans)
57 | cur_trans_list.append(cur_trans)
58 | print('trans: ', cur_trans)
59 |
60 | cur_truth_name = '/data/sls/scratch/yuangong/whisper-a/src/noisy_exp/ground_truth_trans/' + trans_name.split('/')[-1].split('_mix_')[0].split('_')[2] + '.txt'
61 | with open(cur_truth_name, "r") as f:
62 | cur_truth = f.read()
63 | cur_truth = preprocess_text(cur_truth)
64 | cur_truth_list.append(cur_truth)
65 | print('truth: ', cur_truth)
66 | wer = calculate_wer(cur_trans_list, cur_truth_list)
67 | print('wer is ', wer)
68 | wer_list.append(wer)
69 | print(wer_list)
70 | np.savetxt(result_path, wer_list, delimiter=',')
71 |
72 | eval_noise_wer('/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_text_whisper_large-v1/', '/data/sls/scratch/yuangong/whisper-a/src/noisy_exp/results_camera/whisper_large-v1.csv')
73 | eval_noise_wer('/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_text_whisper_tiny.en/', '/data/sls/scratch/yuangong/whisper-a/src/noisy_exp/results_camera/whisper_tiny.csv')
74 | eval_noise_wer('/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_text_whisper_small.en/', '/data/sls/scratch/yuangong/whisper-a/src/noisy_exp/results_camera/whisper_small.csv')
75 | eval_noise_wer('/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_text_whisper_base.en/', '/data/sls/scratch/yuangong/whisper-a/src/noisy_exp/results_camera/whisper_base.csv')
76 | eval_noise_wer('/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_text_whisper_medium.en/', '/data/sls/scratch/yuangong/whisper-a/src/noisy_exp/results_camera/whisper_medium.csv')
77 |
78 |
--------------------------------------------------------------------------------
/src/noise_robust_asr/asr_experiments/compute_wer_cla.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 3/3/23 6:27 PM
3 | # @Author : Yuan Gong
4 | # @Affiliation : Massachusetts Institute of Technology
5 | # @Email : yuangong@mit.edu
6 | # @File : cal_wer.py
7 |
8 | import os
9 | import editdistance
10 | import jiwer
11 | import numpy as np
12 |
13 |
14 | def fileList(source):
15 | matches = []
16 | for root, dirnames, filenames in os.walk(source):
17 | for filename in filenames:
18 | if filename.endswith(('.txt')):
19 | matches.append(os.path.join(root, filename))
20 | return matches
21 |
22 | def calculate_wer(seqs_hat, seqs_true):
23 | """Calculate sentence-level WER score.
24 | :param list seqs_hat: prediction
25 | :param list seqs_true: reference
26 | :return: average sentence-level WER score
27 |     :rtype: float
28 | """
29 | word_eds, word_ref_lens = [], []
30 | for i in range(len(seqs_true)):
31 | seq_true_text = seqs_true[i]
32 | seq_hat_text = seqs_hat[i]
33 | hyp_words = seq_hat_text.split()
34 | ref_words = seq_true_text.split()
35 | word_eds.append(editdistance.eval(hyp_words, ref_words))
36 | word_ref_lens.append(len(ref_words))
37 | return float(sum(word_eds)) / sum(word_ref_lens)
38 |
39 | def eval_noise_wer(trans_path, result_path):
40 | whisper_trans = fileList(trans_path)
41 | truth_path = '/data/sls/scratch/yuangong/whisper-a/src/noisy_exp/ground_truth_trans/'
42 | truth_trans = fileList(truth_path)
43 | print(len(whisper_trans), len(truth_trans))
44 |
45 | def preprocess_text(cur_trans):
46 | cur_trans = jiwer.ToUpperCase()(cur_trans)
47 | cur_trans = jiwer.RemovePunctuation()(cur_trans)
48 | return cur_trans
49 |
50 | all_wer_list = []
51 | for db in [-20, -15, -10, -5, 0, 5, 10, 15, 20]:
52 | wer_list = []
53 | for cla in range(50):
54 | cur_trans_list, cur_truth_list = [], []
55 | for trans_name in whisper_trans:
56 | if int(trans_name.split('/')[-1].split('_')[0]) == db and int(trans_name.split('/')[-1].split('_')[1]) == cla:
57 | with open(trans_name, "r") as f:
58 | cur_trans = f.read()
59 | cur_trans = preprocess_text(cur_trans)
60 | cur_trans_list.append(cur_trans)
61 | #print('trans: ', cur_trans)
62 |
63 | cur_truth_name = '/data/sls/scratch/yuangong/whisper-a/src/noisy_exp/ground_truth_trans/' + trans_name.split('/')[-1].split('_mix_')[0].split('_')[2] + '.txt'
64 | with open(cur_truth_name, "r") as f:
65 | cur_truth = f.read()
66 | cur_truth = preprocess_text(cur_truth)
67 | cur_truth_list.append(cur_truth)
68 | #print('truth: ', cur_truth)
69 | #print(len(cur_trans_list), len(cur_truth_list))
70 | wer = calculate_wer(cur_trans_list, cur_truth_list)
71 | #print('wer is ', wer)
72 | wer_list.append(wer)
73 | #print(wer_list)
74 | all_wer_list.append(wer_list)
75 | np.savetxt(result_path, all_wer_list, delimiter=',')
76 |
77 | # eval_noise_wer('/data/sls/scratch/yuangong/whisper-a/noisy_speech_text_hubert_large/', '/data/sls/scratch/yuangong/whisper-a/src/noisy_exp/results/hubert_large_cla.csv')
78 | # eval_noise_wer('/data/sls/scratch/yuangong/whisper-a/noisy_speech_text_hubert_xlarge/', '/data/sls/scratch/yuangong/whisper-a/src/noisy_exp/results/hubert_xlarge_cla.csv')
79 | # eval_noise_wer('/data/sls/scratch/yuangong/whisper-a/noisy_speech_text_w2v_base/', '/data/sls/scratch/yuangong/whisper-a/src/noisy_exp/results/w2v_base_cla.csv')
80 | # eval_noise_wer('/data/sls/scratch/yuangong/whisper-a/noisy_speech_text_w2v_large_robust/', '/data/sls/scratch/yuangong/whisper-a/src/noisy_exp/results/w2v_large_robust_cla.csv')
81 | eval_noise_wer('/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_text_whisper_large-v1/', '/data/sls/scratch/yuangong/whisper-a/src/noisy_exp/results_camera/whisper_large-v1_cla.csv')
82 |
83 |
--------------------------------------------------------------------------------
/src/noise_robust_asr/asr_experiments/gen_noisy_speech.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 3/3/23 3:04 AM
3 | # @Author : Yuan Gong
4 | # @Affiliation : Massachusetts Institute of Technology
5 | # @Email : yuangong@mit.edu
6 | # @File : gen_noisy_speech.py
7 |
8 | import numpy as np
9 | import torchaudio
10 | import torch
11 | import os
12 |
13 | def fileList(source):
14 | matches = []
15 | for root, dirnames, filenames in os.walk(source):
16 | for filename in filenames:
17 | if filename.endswith(('.flac', '.wav')):
18 | matches.append(os.path.join(root, filename))
19 | return matches
20 |
21 | def add_noise_torch(speech_path, noise_path, noise_db, tar_path):
22 | speech, sr_s = torchaudio.load(speech_path)
23 | noise, sr_n = torchaudio.load(noise_path)
24 |
25 | assert sr_s == sr_n
26 | power_speech = (speech ** 2).mean()
27 | power_noise = (noise ** 2).mean()
28 |
29 | scale = (10 ** (-noise_db / 20) * np.sqrt(power_speech) / np.sqrt(max(power_noise, 1e-10)))
30 |
31 | # if speech is longer than the noise
32 | if speech.shape[1] > noise.shape[1]:
33 | ratio = int(np.ceil(speech.shape[1] / noise.shape[1]))
34 | noise = torch.concat([noise for _ in range(ratio)], dim=1)
35 |
36 | if speech.shape[1] < noise.shape[1]:
37 | noise = noise[:, :speech.shape[1]]
38 |
39 | speech = speech + scale * noise
40 | torchaudio.save(tar_path, speech, sample_rate=sr_s)
41 |
42 | all_speech = fileList('/data/sls/scratch/yuangong/whisper-a/sample_audio/test-clean')
43 | all_speech.sort()
44 | all_speech = all_speech[:40]
45 | print(all_speech)
46 |
47 | all_noise_dict = {}
48 | all_noise = fileList('/data/sls/scratch/yuangong/sslast2/egs/esc50/data/ESC-50-master/audio_16k/')
49 | for noise in all_noise:
50 | cla = int(noise.split('.')[-2].split('-')[-1])
51 | if cla not in all_noise_dict:
52 | all_noise_dict[cla] = [noise]
53 | else:
54 | all_noise_dict[cla].append(noise)
55 | print(all_noise_dict[0])
56 | print(len(all_noise_dict[0]))
57 |
58 | for db in [-20, -15, -10, -5, 0, 5, 10, 15, 20]:
59 | for cla in range(50):
60 | # if os.path.exists('/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_ready/{:d}/{:d}'.format(db,cla)) == False:
61 | # os.makedirs('/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_ready/{:d}/{:d}'.format(db,cla))
62 | # for each snr, for each class, test 40 librispeech samples
63 | for idx in range(40):
64 | tar_name = '/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_ready/' + str(db) + '_' + str(cla) + '_' + all_speech[idx].split('/')[-1].split('.')[-2] + '_mix_' + all_noise_dict[cla][idx].split('/')[-1].split('.')[-2] + '.wav'
65 | add_noise_torch(all_speech[idx], all_noise_dict[cla][idx], db, tar_name)
--------------------------------------------------------------------------------
/src/noise_robust_asr/asr_experiments/transcribe_esc_hubert_xl.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 1/20/23 2:09 AM
3 | # @Author : Yuan Gong
4 | # @Affiliation : Massachusetts Institute of Technology
5 | # @Email : yuangong@mit.edu
6 | # @File : transcribe_aus.py
7 |
8 | import sys
9 | argument = sys.argv[1]
10 | if argument=='4':
11 | argument='0,1,2,3'
12 | import os
13 | if argument != '-1':
14 | os.environ["CUDA_VISIBLE_DEVICES"]=argument
15 |
16 | import os
17 | import torch
18 | import soundfile
19 | from transformers import Wav2Vec2Processor, HubertForCTC
20 |
21 | def fileList(source):
22 | matches = []
23 | for root, dirnames, filenames in os.walk(source):
24 | for filename in filenames:
25 | if filename.endswith(('.flac', '.wav')):
26 | matches.append(os.path.join(root, filename))
27 | return matches
28 |
29 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30 | print(device)
31 |
32 | processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-xlarge-ls960-ft")
33 | model = HubertForCTC.from_pretrained("facebook/hubert-xlarge-ls960-ft").to(device)
34 |
35 | tar_path = '/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_text_hubert_xlarge/'
36 | if os.path.exists(tar_path) == False:
37 | os.mkdir(tar_path)
38 |
39 | audio_list = fileList('/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_ready/')
40 |
41 | audio_list.sort()
42 | start_file = int(argument) * 4500
43 | end_file = int(argument) * 4500 + 4500
44 | audio_list = audio_list[start_file: end_file]
45 |
46 | print('number of files to transcribe: ', len(audio_list))
47 |
48 | for i in range(len(audio_list)):
49 | audio_path = audio_list[i]
50 | if os.path.exists(tar_path + audio_path.split('/')[-1].split('.')[-2] + '.txt') == False:
51 | if audio_path[-3:] == 'wav':
52 | audio_path = audio_list[i]
53 | source, curr_sample_rate = soundfile.read(audio_path, dtype="float32")
54 | input_features = processor(source,
55 | sampling_rate=curr_sample_rate,
56 | return_tensors="pt").input_values
57 | logits = model(input_features.to(device)).logits
58 | predicted_ids = torch.argmax(logits, dim=-1)
59 | text = processor.decode(predicted_ids[0])
60 | with open(tar_path + audio_path.split('/')[-1].split('.')[-2] + '.txt', "w") as text_file:
61 | text_file.write(text)
62 | del logits, input_features
63 | if i % 100 == 0:
64 |         print("{:d} / {:d} processed from processor {:s}".format(i, len(audio_list), argument))
65 | del model
--------------------------------------------------------------------------------
/src/noise_robust_asr/asr_experiments/transcribe_hubert_large.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 1/20/23 2:09 AM
3 | # @Author : Yuan Gong
4 | # @Affiliation : Massachusetts Institute of Technology
5 | # @Email : yuangong@mit.edu
6 | # @File : transcribe_aus.py
7 |
8 | import sys
9 | argument = sys.argv[1]
10 | if argument=='4':
11 | argument='0,1,2,3'
12 | import os
13 | if argument != '-1':
14 | os.environ["CUDA_VISIBLE_DEVICES"]=argument
15 | import torch
16 | import soundfile
17 | from transformers import Wav2Vec2Processor, HubertForCTC
18 |
19 | def fileList(source):
20 | matches = []
21 | for root, dirnames, filenames in os.walk(source):
22 | for filename in filenames:
23 | if filename.endswith(('.flac', '.wav')):
24 | matches.append(os.path.join(root, filename))
25 | return matches
26 |
27 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28 | print(device)
29 |
30 | processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
31 | model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft").to(device)
32 |
33 | tar_path = '/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_text_hubert_large/'
34 | if os.path.exists(tar_path) == False:
35 | os.mkdir(tar_path)
36 |
37 | audio_list = fileList('/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_ready/')
38 |
39 | audio_list.sort()
40 | start_file = int(argument) * 4500
41 | end_file = int(argument) * 4500 + 4500
42 | audio_list = audio_list[start_file: end_file]
43 |
44 | print('number of files to transcribe: ', len(audio_list))
45 |
46 | for i in range(len(audio_list)):
47 | audio_path = audio_list[i]
48 | if os.path.exists(tar_path + audio_path.split('/')[-1].split('.')[-2] + '.txt') == False:
49 | if audio_path[-3:] == 'wav':
50 | audio_path = audio_list[i]
51 | source, curr_sample_rate = soundfile.read(audio_path, dtype="float32")
52 | input_features = processor(source,
53 | sampling_rate=curr_sample_rate,
54 | return_tensors="pt").input_values
55 | logits = model(input_features.to(device)).logits
56 | predicted_ids = torch.argmax(logits, dim=-1)
57 | text = processor.decode(predicted_ids[0])
58 | with open(tar_path + audio_path.split('/')[-1].split('.')[-2] + '.txt', "w") as text_file:
59 | text_file.write(text)
60 | del logits, input_features
61 | if i % 100 == 0:
62 |         print("{:d} / {:d} processed from processor {:s}".format(i, len(audio_list), argument))
63 | del model
--------------------------------------------------------------------------------
/src/noise_robust_asr/asr_experiments/transcribe_wav2vec_base.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 1/20/23 2:09 AM
3 | # @Author : Yuan Gong
4 | # @Affiliation : Massachusetts Institute of Technology
5 | # @Email : yuangong@mit.edu
6 | # @File : transcribe_aus.py
7 |
8 | import sys
9 | argument = sys.argv[1]
10 | if argument=='4':
11 | argument='0,1,2,3'
12 | import os
13 | if argument != '-1':
14 | os.environ["CUDA_VISIBLE_DEVICES"]=argument
15 | import torch
16 | import soundfile
17 | from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
18 |
19 | def fileList(source):
20 | matches = []
21 | for root, dirnames, filenames in os.walk(source):
22 | for filename in filenames:
23 | if filename.endswith(('.flac', '.wav')):
24 | matches.append(os.path.join(root, filename))
25 | return matches
26 |
27 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28 | print(device)
29 |
30 | processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
31 | model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to(device)
32 |
33 | tar_path = '/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_text_w2v_base/'
34 | if os.path.exists(tar_path) == False:
35 | os.mkdir(tar_path)
36 |
37 | audio_list = fileList('/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_ready/')
38 |
39 | audio_list.sort()
40 | start_file = int(argument) * 4500
41 | end_file = int(argument) * 4500 + 4500
42 | audio_list = audio_list[start_file: end_file]
43 |
44 | print('number of files to transcribe: ', len(audio_list))
45 |
46 | for i in range(len(audio_list)):
47 | audio_path = audio_list[i]
48 | if os.path.exists(tar_path + audio_path.split('/')[-1].split('.')[-2] + '.txt') == False:
49 | if audio_path[-3:] == 'wav':
50 | audio_path = audio_list[i]
51 | source, curr_sample_rate = soundfile.read(audio_path, dtype="float32")
52 | input_features = processor(source,
53 | sampling_rate=curr_sample_rate,
54 | return_tensors="pt").input_values
55 | logits = model(input_features.to(device)).logits
56 | predicted_ids = torch.argmax(logits, dim=-1)
57 | text = processor.decode(predicted_ids[0])
58 | with open(tar_path + audio_path.split('/')[-1].split('.')[-2] + '.txt', "w") as text_file:
59 | text_file.write(text)
60 | del logits, input_features
61 | if i % 100 == 0:
62 |         print("{:d} / {:d} processed from processor {:s}".format(i, len(audio_list), argument))
63 | del model
--------------------------------------------------------------------------------
/src/noise_robust_asr/asr_experiments/transcribe_wav2vec_robust.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 1/20/23 2:09 AM
3 | # @Author : Yuan Gong
4 | # @Affiliation : Massachusetts Institute of Technology
5 | # @Email : yuangong@mit.edu
6 | # @File : transcribe_aus.py
7 |
8 | import sys
9 | argument = sys.argv[1]
10 | if argument=='4':
11 | argument='0,1,2,3'
12 | import os
13 | if argument != '-1':
14 | os.environ["CUDA_VISIBLE_DEVICES"]=argument
15 | import torch
16 | import soundfile
17 | from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
18 |
19 | def fileList(source):
20 | matches = []
21 | for root, dirnames, filenames in os.walk(source):
22 | for filename in filenames:
23 | if filename.endswith(('.flac', '.wav')):
24 | matches.append(os.path.join(root, filename))
25 | return matches
26 |
27 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28 | print(device)
29 |
30 | processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-robust-ft-swbd-300h")
31 | model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-robust-ft-swbd-300h").to(device)
32 |
33 | tar_path = '/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_text_w2v_large_robust/'
34 | if os.path.exists(tar_path) == False:
35 | os.mkdir(tar_path)
36 |
37 | audio_list = fileList('/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_ready/')
38 |
39 | audio_list.sort()
40 | start_file = int(argument) * 4500
41 | end_file = int(argument) * 4500 + 4500
42 | audio_list = audio_list[start_file: end_file]
43 |
44 | print('number of files to transcribe: ', len(audio_list))
45 |
46 | for i in range(len(audio_list)):
47 | audio_path = audio_list[i]
48 | if os.path.exists(tar_path + audio_path.split('/')[-1].split('.')[-2] + '.txt') == False:
49 | if audio_path[-3:] == 'wav':
50 | audio_path = audio_list[i]
51 | source, curr_sample_rate = soundfile.read(audio_path, dtype="float32")
52 | input_features = processor(source,
53 | sampling_rate=curr_sample_rate,
54 | return_tensors="pt").input_values
55 | logits = model(input_features.to(device)).logits
56 | predicted_ids = torch.argmax(logits, dim=-1)
57 | text = processor.decode(predicted_ids[0])
58 | with open(tar_path + audio_path.split('/')[-1].split('.')[-2] + '.txt', "w") as text_file:
59 | text_file.write(text)
60 | del logits, input_features
61 | if i % 100 == 0:
62 |         print("{:d} / {:d} processed from processor {:s}".format(i, len(audio_list), argument))
63 | del model
--------------------------------------------------------------------------------
/src/noise_robust_asr/asr_experiments/transcribe_whisper.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 1/20/23 2:09 AM
3 | # @Author : Yuan Gong
4 | # @Affiliation : Massachusetts Institute of Technology
5 | # @Email : yuangong@mit.edu
6 | # @File : transcribe_aus.py
7 |
8 | # transcribe the noisy speech set (LibriSpeech test-clean mixed with ESC-50 noise) with Whisper models
9 | import sys
10 | argument = sys.argv[1]
11 | if argument=='4':
12 | argument='0,1,2,3'
13 | import os
14 | if argument != '-1':
15 | os.environ["CUDA_VISIBLE_DEVICES"]=argument
16 |
17 | os.environ["XDG_CACHE_HOME"] = './'
18 | import ssl
19 | ssl._create_default_https_context = ssl._create_unverified_context
20 | import torch
21 | import whisper
22 | import torchaudio
23 |
24 | def show_twod(input):
25 | return "{:.2f}".format(input)
26 |
27 | def get_immediate_files(a_dir):
28 | return [a_dir + '/' + name for name in os.listdir(a_dir) if os.path.isfile(os.path.join(a_dir, name))]
29 |
30 | def fileList(source):
31 | matches = []
32 | for root, dirnames, filenames in os.walk(source):
33 | for filename in filenames:
34 | if filename.endswith(('.flac', '.wav')):
35 | matches.append(os.path.join(root, filename))
36 | return matches
37 |
38 | audio_list = fileList('/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_ready/')
39 |
40 | audio_list.sort()
41 |
42 | print('number of files to transcribe: ', len(audio_list))
43 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
44 | print(device)
45 |
46 | whisper_mdl_list = ['base.en'] # medium.en, small.en, base.en, tiny.en, medium, large-v1
47 |
48 | for mdl_size in whisper_mdl_list:
49 | model = whisper.load_model(mdl_size, device)
50 | for beam_size in [0]:
51 | tar_path = '/data/sls/scratch/yuangong/whisper-a/noisy_speech_camera_text_whisper_{:s}/'.format(mdl_size)
52 | if os.path.exists(tar_path) == False:
53 | os.mkdir(tar_path)
54 | for i in range(len(audio_list)):
55 | audio_path = audio_list[i]
56 | if os.path.exists(tar_path + audio_path.split('/')[-1].split('.')[-2] + '.txt') == False:
57 | if audio_path[-3:] == 'wav':
58 | ori_waveform, sr = torchaudio.load(audio_path)
59 | wav_len = ori_waveform.shape[1]
60 | assert sr == 16000
61 | if beam_size == 0:
62 | result = model.transcribe(audio_path, language='en')
63 | else:
64 | result = model.transcribe(audio_path, beam_size=beam_size)
65 |
66 | # remove the first space
67 | text = result["text"][1:]
68 | if os.path.exists(tar_path) == False:
69 | os.mkdir(tar_path)
70 | with open(tar_path + audio_path.split('/')[-1].split('.')[-2] + '.txt', "w") as text_file:
71 | text_file.write(text)
72 | if i % 100 == 0:
73 |                 print("{:d} / {:d} processed from processor {:s}".format(i, len(audio_list), argument))
74 | del model
75 |
--------------------------------------------------------------------------------
/src/noise_robust_asr/baseline_sound_classification.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 1/20/23 1:35 AM
3 | # @Author : Yuan Gong
4 | # @Affiliation : Massachusetts Institute of Technology
5 | # @Email : yuangong@mit.edu
6 | # @File : cluster_esc50_feat2.py
7 |
8 | # use new all layer feat, note these feats are already pooled over time
9 | # for whisper, w2v, and hubert models
10 |
11 | import json
12 | import os
13 | os.environ["XDG_CACHE_HOME"] = './'
14 | import numpy as np
15 |
16 | from sklearn.pipeline import Pipeline
17 | from sklearn.metrics import accuracy_score, classification_report
18 |
19 | from sklearn.preprocessing import StandardScaler
20 | from sklearn.neural_network import MLPClassifier
21 |
22 | def cluster_feat(dataset_json_file, tar_path):
23 | with open(dataset_json_file, 'r') as fp:
24 | data_json = json.load(fp)
25 | data = data_json['data']
26 | num_sample = len(data)
27 | for idx, entry in enumerate(data):
28 | wav = entry["wav"]
29 | # the first sample
30 | if idx == 0:
31 | cur_sample = np.load(tar_path + '/' + wav.split('/')[-1][:-3] + 'npy')
32 | num_layer = cur_sample.shape[0]
33 | feat_dim = cur_sample.shape[-1]
34 | print('number of layers {:d} feat dim {:d}'.format(num_layer, feat_dim))
35 | all_feat = np.zeros((num_layer + 1, num_sample, feat_dim))
36 | all_label = []
37 |
38 | cur_rep = np.load(tar_path + '/' + wav.split('/')[-1][:-3] + 'npy')
39 | for layer in range(cur_rep.shape[0]):
40 | all_feat[layer, idx] = np.mean(cur_rep[layer], axis=0)
41 |
42 | all_feat[-1, idx] = np.mean(np.mean(cur_rep, axis=0), axis=0)
43 |
44 | cur_label = int(wav.split('.')[-2].split('-')[-1])
45 | all_label.append(cur_label)
46 |
47 | assert all_feat[0].shape[0] == len(all_label)
48 | return all_feat, all_label
49 |
50 | def get_immediate_dir(a_dir):
51 | return [name for name in os.listdir(a_dir) if os.path.isdir(os.path.join(a_dir, name))]
52 |
53 | mdl_size_list = ['wav2vec2-large-robust-ft-swbd-300h']
54 |
55 | for mdl_size in mdl_size_list:
56 | for fold in range(1, 6):
57 | print(mdl_size)
58 | if 'whisper' not in mdl_size:
59 | tar_path = '/data/sls/scratch/yuangong/whisper-a/feat_all/' + mdl_size + '_all_layer/'
60 | else:
61 | tar_path = '/data/sls/scratch/yuangong/whisper-a/feat_all/' + mdl_size + '/'
62 | esc_train1 = '/data/sls/scratch/yuangong/whisper-a/egs/esc-50/feat_extract/data/datafiles/esc_train_data_' + str(fold) + '.json'
63 | esc_eval1 = '/data/sls/scratch/yuangong/whisper-a/egs/esc-50/feat_extract/data/datafiles/esc_eval_data_' + str(fold) + '.json'
64 | all_tr_feat, all_tr_label = cluster_feat(esc_train1, tar_path)
65 | all_te_feat, all_te_label = cluster_feat(esc_eval1, tar_path)
66 |
67 | num_layer = all_tr_feat.shape[0]
68 | print(all_tr_feat.shape, all_te_feat.shape)
69 |
70 | for lr in [0.001]:
71 | all_res = []
72 | for layer in range(num_layer):
73 | cla = MLPClassifier(hidden_layer_sizes=(), learning_rate='adaptive', learning_rate_init=lr, max_iter=5000, random_state=0)
74 | pipe = Pipeline([('scaler', StandardScaler()), ('svc', cla)])
75 | pipe.fit(all_tr_feat[layer], all_tr_label)
76 | pred = pipe.predict(all_te_feat[layer])
77 | acc = accuracy_score(all_te_label, pred)
78 | all_acc = classification_report(all_te_label, pred, output_dict=True)
79 | all_acc = [all_acc[str(x)]['f1-score'] for x in range(50)]
80 | res = [mdl_size, fold, all_te_feat[0].shape[1], lr, layer, acc] + all_acc
81 | all_res.append(res)
82 | np.savetxt('./baseline_res/esc_{:s}_fold{:d}_lr_{:.4f}.csv'.format(mdl_size, fold, lr), all_res, delimiter=',', fmt='%s')
--------------------------------------------------------------------------------
/src/noise_robust_asr/intermediate_feat_extract/as_full/batch_as_full_extract.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # need to batch to speed up processing
4 |
5 | for((fold=0;fold<=39;fold++));
6 | do
7 | sbatch extract_as_full_whisper_all.sh ${fold}
8 | done
--------------------------------------------------------------------------------
/src/noise_robust_asr/intermediate_feat_extract/as_full/extract_as_full_whisper_all.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 1/19/23 11:35 PM
3 | # @Author : Yuan Gong
4 | # @Affiliation : Massachusetts Institute of Technology
5 | # @Email : yuangong@mit.edu
6 | # @File : extract_esc50.py
7 |
8 | # extract representations from all layers of the Whisper model, mean-pooled over time by a factor of 20; the input mel is not included.
9 | # save as npz to save space
10 |
11 | import json
12 | import torch
13 | import os
14 | os.environ["XDG_CACHE_HOME"] = './'
15 | import numpy as np
16 | from whisper.model import Whisper, ModelDimensions
17 | import skimage.measure
18 | import argparse
19 |
20 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
21 | parser.add_argument("--split", type=int, default=0, help="which split")
22 | args = parser.parse_args()
23 |
24 | def extract_audio(dataset_json_file, mdl, tar_path):
25 | if os.path.exists(tar_path) == False:
26 | os.mkdir((tar_path))
27 | with open(dataset_json_file, 'r') as fp:
28 | data_json = json.load(fp)
29 | data = data_json['data']
30 | for idx, entry in enumerate(data):
31 | wav = entry["wav"]
32 |
33 | if os.path.exists(tar_path + '/' + wav.split('/')[-1][:-4] + 'npz') == False:
34 | _, audio_rep = mdl.transcribe_audio(wav)
35 | audio_rep = audio_rep[0]
36 | audio_rep = torch.permute(audio_rep, (2, 0, 1)).detach().cpu().numpy()
37 | audio_rep = skimage.measure.block_reduce(audio_rep, (1, 20, 1), np.mean)
38 | audio_rep = audio_rep[1:]
39 | if idx == 0:
40 | print(audio_rep.shape)
41 | np.savez_compressed(tar_path + '/' + wav.split('/')[-1][:-4] + 'npz', audio_rep)
42 | if idx % 50 == 0:
43 | print(idx)
44 |
45 | mdl_size_list = ['medium'] # , 'large-v1', 'medium.en'
46 | mdl_size_list = mdl_size_list[::-1]
47 | for mdl_size in mdl_size_list:
48 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
49 | print(device)
50 | checkpoint_path = '/data/sls/scratch/yuangong/whisper-a/src/{:s}.pt'.format(mdl_size)
51 | checkpoint = torch.load(checkpoint_path, map_location=device)
52 | dims = ModelDimensions(**checkpoint["dims"])
53 | model = Whisper(dims)
54 | model.load_state_dict(checkpoint["model_state_dict"], strict=False)
55 | model.to(device)
56 |
57 | tar_path = '/data/sls/scratch/yuangong/whisper-a/feat_as_full/' + 'whisper_' + mdl_size + '/'
58 | esc_train1 = '/data/sls/scratch/yuangong/whisper-a/egs/audioset/feat_extract/split_json/{:d}.json'.format(args.split)
59 | extract_audio(esc_train1, model, tar_path)
60 | del model
--------------------------------------------------------------------------------
/src/noise_robust_asr/intermediate_feat_extract/as_full/extract_as_full_whisper_all.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | ##SBATCH -p 1080,sm
3 | ##SBATCH -x sls-sm-1,sls-2080-[1,3],sls-1080-[1,2],sls-sm-[1,2,11,13]
4 | #SBATCH -p gpu
5 | #SBATCH --gres=gpu:1
6 | #SBATCH -c 1
7 | #SBATCH -n 1
8 |
9 | #SBATCH --mem=24000
10 | #SBATCH --job-name="as_extract"
11 | #SBATCH --output=../../log/%j_as_extract.txt
12 |
13 | set -x
14 | # comment this line if not running on sls cluster
15 | #. /data/sls/scratch/share-201907/slstoolchainrc
16 | #source /data/sls/scratch/yuangong/whisper-a/venv-wa/bin/activate
17 | export TORCH_HOME=../../pretrained_models
18 |
19 | python extract_as_full_whisper_all.py --split $1
--------------------------------------------------------------------------------
/src/noise_robust_asr/intermediate_feat_extract/esc-50/extract_esc50_hubert_xl_all_pool.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 1/19/23 11:35 PM
3 | # @Author : Yuan Gong
4 | # @Affiliation : Massachusetts Institute of Technology
5 | # @Email : yuangong@mit.edu
6 | # @File : extract_esc50.py
7 |
8 | # extract representations from all layers of HuBERT-XLarge
9 |
10 | import json
11 | import torch
12 | import os
13 | os.environ["XDG_CACHE_HOME"] = './'
14 | import numpy as np
15 | from transformers import Wav2Vec2Processor, HubertModel
16 | import soundfile as sf
17 | import skimage.measure
18 |
19 | def extract_audio(dataset_json_file, model, processor, tar_path):
20 | if os.path.exists(tar_path) == False:
21 | os.mkdir((tar_path))
22 | with open(dataset_json_file, 'r') as fp:
23 | data_json = json.load(fp)
24 | data = data_json['data']
25 | for idx, entry in enumerate(data):
26 | wav = entry["wav"]
27 | audio, sr = sf.read(wav)
28 | assert sr == 16000
29 |
30 | input_values = processor(audio, sampling_rate=sr, return_tensors="pt").input_values.to(device) # Batch size 1
31 | audio_rep = model(input_values, output_hidden_states=True).hidden_states
32 | audio_rep = torch.stack(audio_rep, dim=0).squeeze(1)
33 | audio_rep = audio_rep.detach().cpu().numpy()
34 | audio_rep = skimage.measure.block_reduce(audio_rep, (1, 10, 1), np.mean)
35 | audio_rep = audio_rep[1:]
36 | np.savez_compressed(tar_path + '/' + wav.split('/')[-1][:-3] + 'npz', audio_rep)
37 |
38 | mdl_size_list = ['facebook/hubert-xlarge-ls960-ft']
39 | for mdl_size in mdl_size_list:
40 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41 | processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-xlarge-ls960-ft")
42 | model = HubertModel.from_pretrained(mdl_size)
43 | model.to(device)
44 | model.eval()
45 |
46 | tar_path = '/data/sls/scratch/yuangong/whisper-a/feat_esc_pool/' + mdl_size.split('/')[-1] + '/'
47 | esc_train1 = '/data/sls/scratch/yuangong/whisper-a/egs/esc-50/feat_extract/data/datafiles/esc_train_data_1.json'
48 | esc_eval1 = '/data/sls/scratch/yuangong/whisper-a/egs/esc-50/feat_extract/data/datafiles/esc_eval_data_1.json'
49 | extract_audio(esc_train1, model, processor, tar_path)
50 | extract_audio(esc_eval1, model, processor, tar_path)
51 | del model
--------------------------------------------------------------------------------
/src/noise_robust_asr/intermediate_feat_extract/esc-50/extract_esc50_w2v_robust_all.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 1/19/23 11:35 PM
3 | # @Author : Yuan Gong
4 | # @Affiliation : Massachusetts Institute of Technology
5 | # @Email : yuangong@mit.edu
6 | # @File : extract_esc50.py
7 |
8 | # extract representations from all layers of wav2vec2-large-robust
9 |
10 | import json
11 | import torch
12 | import os
13 | os.environ["XDG_CACHE_HOME"] = './'
14 | import numpy as np
15 | from transformers import Wav2Vec2Processor, Wav2Vec2Model
16 | import soundfile as sf
17 | import skimage.measure
18 |
19 | def extract_audio(dataset_json_file, model, processor, tar_path):
20 | if os.path.exists(tar_path) == False:
21 | os.mkdir((tar_path))
22 | with open(dataset_json_file, 'r') as fp:
23 | data_json = json.load(fp)
24 | data = data_json['data']
25 | for idx, entry in enumerate(data):
26 | wav = entry["wav"]
27 | audio, sr = sf.read(wav)
28 | assert sr == 16000
29 |
30 | input_values = processor(audio, sampling_rate=sr, return_tensors="pt").input_values.to(device)
31 | audio_rep = model(input_values, output_hidden_states=True).hidden_states
32 | audio_rep = torch.stack(audio_rep, dim=0).squeeze(1)
33 | audio_rep = audio_rep.detach().cpu().numpy()
34 | audio_rep = skimage.measure.block_reduce(audio_rep, (1, 10, 1), np.mean)
35 | audio_rep = audio_rep[1:]
36 | np.savez_compressed(tar_path + '/' + wav.split('/')[-1][:-3] + 'npz', audio_rep)
37 |
38 | mdl_size_list = ['facebook/wav2vec2-large-robust-ft-swbd-300h']
39 | for mdl_size in mdl_size_list:
40 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41 | processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-robust-ft-swbd-300h")
42 | model = Wav2Vec2Model.from_pretrained(mdl_size)
43 | model.to(device)
44 | model.eval()
45 |
46 | tar_path = '/data/sls/scratch/yuangong/whisper-a/feat_esc_pool/' + mdl_size.split('/')[-1] + '/'
47 | esc_train1 = '/data/sls/scratch/yuangong/whisper-a/egs/esc-50/feat_extract/data/datafiles/esc_train_data_1.json'
48 | esc_eval1 = '/data/sls/scratch/yuangong/whisper-a/egs/esc-50/feat_extract/data/datafiles/esc_eval_data_1.json'
49 | extract_audio(esc_train1, model, processor, tar_path)
50 | extract_audio(esc_eval1, model, processor, tar_path)
51 | del model
--------------------------------------------------------------------------------
/src/noise_robust_asr/intermediate_feat_extract/esc-50/extract_esc50_whisper_all_pool.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 1/19/23 11:35 PM
3 | # @Author : Yuan Gong
4 | # @Affiliation : Massachusetts Institute of Technology
5 | # @Email : yuangong@mit.edu
6 | # @File : extract_esc50.py
7 |
8 | # extract representations from all layers of the Whisper model, mean-pooled over time by a factor of 10; the input mel is not included.
9 | # save as npz to save space
10 |
11 | import json
12 | import torch
13 | import os
14 | os.environ["XDG_CACHE_HOME"] = './'
15 | import whisper
16 | import numpy as np
17 | from whisper.model import Whisper, ModelDimensions
18 | import skimage.measure
19 | import argparse
20 |
21 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
22 | parser.add_argument("--split", type=int, default=0, help="which split")
23 | args = parser.parse_args()
24 |
25 | def extract_audio(dataset_json_file, mdl, tar_path):
26 | if os.path.exists(tar_path) == False:
27 | os.mkdir((tar_path))
28 | with open(dataset_json_file, 'r') as fp:
29 | data_json = json.load(fp)
30 | data = data_json['data']
31 | for idx, entry in enumerate(data):
32 | wav = entry["wav"]
33 |
34 |         if os.path.exists(tar_path + '/' + wav.split('/')[-1][:-3] + 'npz') == False:  # use [:-3] so the check matches the '.npz' filename written below
35 |             # NOTE: this uses a customized Whisper model for feature extraction; the original whisper package does not have a transcribe_audio function
36 | _, audio_rep = mdl.transcribe_audio(wav)
37 | audio_rep = audio_rep[0]
38 | audio_rep = torch.permute(audio_rep, (2, 0, 1)).detach().cpu().numpy()
39 | audio_rep = skimage.measure.block_reduce(audio_rep, (1, 10, 1), np.mean) # downsample x10 for esc, 20 for audioset
40 | audio_rep = audio_rep[1:]
41 | np.savez_compressed(tar_path + '/' + wav.split('/')[-1][:-3] + 'npz', audio_rep)
42 |
43 | mdl_size_list = ['large-v2', 'large-v1', 'medium.en', 'medium', 'small.en', 'small', 'base.en', 'base', 'tiny.en', 'tiny']
44 | for mdl_size in mdl_size_list:
45 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
46 | print(device)
47 | checkpoint_path = '/data/sls/scratch/yuangong/whisper-a/src/{:s}.pt'.format(mdl_size)
48 | checkpoint = torch.load(checkpoint_path, map_location=device)
49 | dims = ModelDimensions(**checkpoint["dims"])
50 |     model = Whisper(dims)  # NOTE: this uses a customized Whisper model for feature extraction; the original whisper package does not have a transcribe_audio function
51 | model.load_state_dict(checkpoint["model_state_dict"], strict=False)
52 | model.to(device)
53 |
54 | tar_path = '/data/sls/scratch/yuangong/whisper-a/feat_esc_pool/' + 'whisper_' + mdl_size + '/'
55 | esc_train1 = '/data/sls/scratch/yuangong/whisper-a/egs/esc-50/feat_extract/data/datafiles/esc_train_data_1.json'
56 |     esc_eval1 = '/data/sls/scratch/yuangong/whisper-a/egs/esc-50/feat_extract/data/datafiles/esc_eval_data_1.json' # ESC-50 uses 5-fold cross-validation, so the 1st train and eval splits together cover all data
57 | extract_audio(esc_train1, model, tar_path)
58 | extract_audio(esc_eval1, model, tar_path)
59 | del model
--------------------------------------------------------------------------------
/src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | name: Release
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | jobs:
8 | deploy:
9 | runs-on: ubuntu-latest
10 | steps:
11 | - uses: actions/checkout@v2
12 | - uses: actions-ecosystem/action-regex-match@v2
13 | id: regex-match
14 | with:
15 | text: ${{ github.event.head_commit.message }}
16 | regex: '^Release ([^ ]+)'
17 | - name: Set up Python
18 | uses: actions/setup-python@v2
19 | with:
20 | python-version: '3.8'
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install setuptools wheel twine
25 | - name: Release
26 | if: ${{ steps.regex-match.outputs.match != '' }}
27 | uses: softprops/action-gh-release@v1
28 | with:
29 | tag_name: v${{ steps.regex-match.outputs.group1 }}
30 | - name: Build and publish
31 | if: ${{ steps.regex-match.outputs.match != '' }}
32 | env:
33 | TWINE_USERNAME: __token__
34 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
35 | run: |
36 | python setup.py sdist
37 | twine upload dist/*
38 |
--------------------------------------------------------------------------------
/src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: test
2 | on:
3 | push:
4 | branches:
5 | - main
6 | pull_request:
7 | branches:
8 | - main
9 | jobs:
10 | whisper-test:
11 | runs-on: ubuntu-latest
12 | strategy:
13 | matrix:
14 | python-version: ['3.8', '3.9', '3.10']
15 | pytorch-version: [1.10.2, 1.13.1]
16 | exclude:
17 | - python-version: '3.10'
18 | pytorch-version: 1.10.2
19 | steps:
20 | - uses: conda-incubator/setup-miniconda@v2
21 | - run: conda install -n test ffmpeg python=${{ matrix.python-version }} pytorch=${{ matrix.pytorch-version }} cpuonly -c pytorch
22 | - uses: actions/checkout@v2
23 | - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
24 | - run: pip install pytest
25 | - run: pip install .
26 | - run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]'
27 |
--------------------------------------------------------------------------------
/src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 OpenAI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include requirements.txt
2 | include README.md
3 | include LICENSE
4 | include whisper/assets/*
5 | include whisper/assets/gpt2/*
6 | include whisper/assets/multilingual/*
7 | include whisper/normalizers/english.json
8 |
--------------------------------------------------------------------------------
/src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/data/README.md:
--------------------------------------------------------------------------------
1 | This directory supplements the paper with more details on how we prepared the data for evaluation, to help replicate our experiments.
2 |
3 | ## Short-form English-only datasets
4 |
5 | ### LibriSpeech
6 |
7 | We used the test-clean and test-other splits from the [LibriSpeech ASR corpus](https://www.openslr.org/12).
8 |
9 | ### TED-LIUM 3
10 |
11 | We used the test split of [TED-LIUM Release 3](https://www.openslr.org/51/), using the segmented manual transcripts included in the release.
12 |
13 | ### Common Voice 5.1
14 |
15 | We downloaded the English subset of Common Voice Corpus 5.1 from [the official website](https://commonvoice.mozilla.org/en/datasets).
16 |
17 | ### Artie
18 |
19 | We used the [Artie bias corpus](https://github.com/artie-inc/artie-bias-corpus). This is a subset of the Common Voice dataset.
20 |
21 | ### CallHome & Switchboard
22 |
23 | We used the two corpora from [LDC2002S09](https://catalog.ldc.upenn.edu/LDC2002S09) and [LDC2002T43](https://catalog.ldc.upenn.edu/LDC2002T43) and followed the [eval2000_data_prep.sh](https://github.com/kaldi-asr/kaldi/blob/master/egs/fisher_swbd/s5/local/eval2000_data_prep.sh) script for preprocessing. The `wav.scp` files can be converted to WAV files with the following bash commands:
24 |
25 | ```bash
26 | mkdir -p wav
27 | while read name cmd; do
28 | echo $name
29 | echo ${cmd/\|/} wav/$name.wav | bash
30 | done < wav.scp
31 | ```
32 |
33 |
34 | ### WSJ
35 |
36 | We used [LDC93S6B](https://catalog.ldc.upenn.edu/LDC93S6B) and [LDC94S13B](https://catalog.ldc.upenn.edu/LDC94S13B) and followed the [s5 recipe](https://github.com/kaldi-asr/kaldi/tree/master/egs/wsj/s5) to preprocess the dataset.
37 |
38 | ### CORAAL
39 |
40 | We used the 231 interviews from [CORAAL (v. 2021.07)](https://oraal.uoregon.edu/coraal) and used the segmentations from [the FairSpeech project](https://github.com/stanford-policylab/asr-disparities/blob/master/input/CORAAL_transcripts.csv).
41 |
42 | ### CHiME-6
43 |
44 | We downloaded the [CHiME-5 dataset](https://spandh.dcs.shef.ac.uk//chime_challenge/CHiME5/download.html) and followed stage 0 of the [s5_track1 recipe](https://github.com/kaldi-asr/kaldi/tree/master/egs/chime6/s5_track1) to create the CHiME-6 dataset, which fixes synchronization. We then used the binaural recordings (`*_P??.wav`) and the corresponding transcripts.
45 |
46 | ### AMI-IHM, AMI-SDM1
47 |
48 | We preprocessed the [AMI Corpus](https://groups.inf.ed.ac.uk/ami/corpus/overview.shtml) by following stages 0 and 2 of the [s5b recipe](https://github.com/kaldi-asr/kaldi/tree/master/egs/ami/s5b).
49 |
50 |
51 | ## Long-form English-only datasets
52 |
53 | ### TED-LIUM 3
54 |
55 | To create a long-form transcription dataset from the [TED-LIUM3](https://www.openslr.org/51/) dataset, we sliced the audio between the beginning of the first labeled segment and the end of the last labeled segment of each talk, and we used the concatenated text as the label. Below are the timestamps used for slicing each of the 11 TED talks in the test split.
56 |
57 | | Filename | Begin time (s) | End time (s) |
58 | |---------------------|----------------|--------------|
59 | | DanBarber_2010 | 16.09 | 1116.24 |
60 | | JaneMcGonigal_2010 | 15.476 | 1187.61 |
61 | | BillGates_2010 | 15.861 | 1656.94 |
62 | | TomWujec_2010U | 16.26 | 402.17 |
63 | | GaryFlake_2010 | 16.06 | 367.14 |
64 | | EricMead_2009P | 18.434 | 536.44 |
65 | | MichaelSpecter_2010 | 16.11 | 979.312 |
66 | | DanielKahneman_2010 | 15.8 | 1199.44 |
67 | | AimeeMullins_2009P | 17.82 | 1296.59 |
68 | | JamesCameron_2010 | 16.75 | 1010.65 |
69 | | RobertGupta_2010U | 16.8 | 387.03 |
70 |
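As a minimal sketch of the slicing step described above (not part of this repository; the paths and output name here are hypothetical, and the begin/end values come from the table), each talk could be trimmed with torchaudio:

```python
import torchaudio

def slice_talk(wav_path: str, begin_s: float, end_s: float, out_path: str):
    # Keep only the audio between the first and last labeled segment of the talk.
    waveform, sr = torchaudio.load(wav_path)
    start, end = int(begin_s * sr), int(end_s * sr)
    torchaudio.save(out_path, waveform[:, start:end], sample_rate=sr)

# e.g., the first row of the table above
slice_talk("DanBarber_2010.wav", 16.09, 1116.24, "DanBarber_2010_longform.wav")
```
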
71 | ### Meanwhile
72 |
73 | This dataset consists of 64 segments from The Late Show with Stephen Colbert. The YouTube video ID, start and end timestamps, and the labels can be found in [meanwhile.json](meanwhile.json). The labels are collected from the closed-caption data for each video and corrected with manual inspection.
74 |
75 | ### Rev16
76 |
77 | We used a subset of 16 files from the 30 podcast episodes in [Rev.AI's Podcast Transcription Benchmark](https://www.rev.ai/blog/podcast-transcription-benchmark-part-1/), after finding that there are multiple cases where a significant portion of the audio and the labels did not match, mostly on the parts introducing the sponsors. We selected 16 episodes that do not have this error, whose "file numbers" are:
78 |
79 | 3 4 9 10 11 14 17 18 20 21 23 24 26 27 29 32
80 |
81 | ### Kincaid46
82 |
83 | This dataset consists of 46 audio files and the corresponding transcripts compiled in the blog article [Which automatic transcription service is the most accurate - 2018](https://medium.com/descript/which-automatic-transcription-service-is-the-most-accurate-2018-2e859b23ed19) by Jason Kincaid. We used the 46 audio files and reference transcripts from the Airtable widget in the article.
84 |
85 | For the human transcription benchmark in the paper, we use a subset of 25 examples from this data, whose "Ref IDs" are:
86 |
87 | 2 4 5 8 9 10 12 13 14 16 19 21 23 25 26 28 29 30 33 35 36 37 42 43 45
88 |
89 | ### Earnings-21, Earnings-22
90 |
91 | For these datasets, we used the files available in [the speech-datasets repository](https://github.com/revdotcom/speech-datasets), as of their `202206` version.
92 |
93 | ### CORAAL
94 |
95 | We used the 231 interviews from [CORAAL (v. 2021.07)](https://oraal.uoregon.edu/coraal) and used the full-length interview files and transcripts.
96 |
97 |
98 | ## Multilingual datasets
99 |
100 | ### Multilingual LibriSpeech
101 |
102 | We used the test splits from each language in [the Multilingual LibriSpeech (MLS) corpus](https://www.openslr.org/94/).
103 |
104 | ### Fleurs
105 |
106 | We collected audio files and transcripts using the implementation available as [HuggingFace datasets](https://huggingface.co/datasets/google/fleurs/blob/main/fleurs.py). To use it as a translation dataset, we matched the numerical utterance IDs to find the corresponding transcripts in English.
107 |
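A minimal sketch of that ID-matching step (not the paper's code), assuming the HuggingFace `google/fleurs` configs expose `id` and `transcription` columns; the config names and the French example are illustrative:

```python
from datasets import load_dataset

# Load one source-language split and the English split of FLEURS.
fr = load_dataset("google/fleurs", "fr_fr", split="test")
en = load_dataset("google/fleurs", "en_us", split="test")

# Index the English transcripts by numerical utterance ID (column names assumed).
en_by_id = {ex["id"]: ex["transcription"] for ex in en}

# Keep (source transcript, English transcript) pairs that share an utterance ID.
pairs = [(ex["transcription"], en_by_id[ex["id"]]) for ex in fr if ex["id"] in en_by_id]
print(len(pairs))
```
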
108 | ### VoxPopuli
109 |
110 | We used the `get_asr_data.py` script from [the official repository](https://github.com/facebookresearch/voxpopuli) to collect the ASR data in 14 languages.
111 |
112 | ### Common Voice 9
113 |
114 | We downloaded the Common Voice Corpus 9 from [the official website](https://commonvoice.mozilla.org/en/datasets).
115 |
116 | ### CoVOST 2
117 |
118 | We collected the `X into English` data using [the official repository](https://github.com/facebookresearch/covost).
119 |
--------------------------------------------------------------------------------
/src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | torch
3 | tqdm
4 | more-itertools
5 | transformers>=4.19.0
6 | ffmpeg-python==0.2.0
7 |
--------------------------------------------------------------------------------
/src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pkg_resources
4 | from setuptools import setup, find_packages
5 |
6 |
7 | def read_version(fname="whisper/version.py"):
8 | exec(compile(open(fname, encoding="utf-8").read(), fname, "exec"))
9 | return locals()["__version__"]
10 |
11 |
12 | setup(
13 | name="openai-whisper",
14 | py_modules=["whisper"],
15 | version=read_version(),
16 | description="Robust Speech Recognition via Large-Scale Weak Supervision",
17 | long_description=open("README.md", encoding="utf-8").read(),
18 | long_description_content_type="text/markdown",
19 | readme="README.md",
20 | python_requires=">=3.7",
21 | author="OpenAI",
22 | url="https://github.com/openai/whisper",
23 | license="MIT",
24 | packages=find_packages(exclude=["tests*"]),
25 | install_requires=[
26 | str(r)
27 | for r in pkg_resources.parse_requirements(
28 | open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
29 | )
30 | ],
31 | entry_points={
32 | "console_scripts": ["whisper=whisper.transcribe:cli"],
33 | },
34 | include_package_data=True,
35 | extras_require={"dev": ["pytest"]},
36 | )
37 |
--------------------------------------------------------------------------------
/src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/tests/jfk.flac:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuanGongND/whisper-at/17d94d6acd53866390ce70f95afa13507dcb8aef/src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/tests/jfk.flac
--------------------------------------------------------------------------------
/src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/tests/test_audio.py:
--------------------------------------------------------------------------------
1 | import os.path
2 |
3 | import numpy as np
4 |
5 | from whisper.audio import load_audio, log_mel_spectrogram, SAMPLE_RATE
6 |
7 |
8 | def test_audio():
9 | audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
10 | audio = load_audio(audio_path)
11 | assert audio.ndim == 1
12 | assert SAMPLE_RATE * 10 < audio.shape[0] < SAMPLE_RATE * 12
13 | assert 0 < audio.std() < 1
14 |
15 | mel_from_audio = log_mel_spectrogram(audio)
16 | mel_from_file = log_mel_spectrogram(audio_path)
17 |
18 | assert np.allclose(mel_from_audio, mel_from_file)
19 | assert mel_from_audio.max() - mel_from_audio.min() <= 2.0
20 |
--------------------------------------------------------------------------------
/src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/tests/test_normalizer.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from whisper.normalizers import EnglishTextNormalizer
4 | from whisper.normalizers.english import EnglishNumberNormalizer, EnglishSpellingNormalizer
5 |
6 |
7 | @pytest.mark.parametrize("std", [EnglishNumberNormalizer(), EnglishTextNormalizer()])
8 | def test_number_normalizer(std):
9 | assert std("two") == "2"
10 | assert std("thirty one") == "31"
11 | assert std("five twenty four") == "524"
12 | assert std("nineteen ninety nine") == "1999"
13 | assert std("twenty nineteen") == "2019"
14 |
15 | assert std("two point five million") == "2500000"
16 | assert std("four point two billions") == "4200000000s"
17 | assert std("200 thousand") == "200000"
18 | assert std("200 thousand dollars") == "$200000"
19 | assert std("$20 million") == "$20000000"
20 | assert std("€52.4 million") == "€52400000"
21 | assert std("£77 thousands") == "£77000s"
22 |
23 | assert std("two double o eight") == "2008"
24 |
25 | assert std("three thousand twenty nine") == "3029"
26 | assert std("forty three thousand two hundred sixty") == "43260"
27 | assert std("forty three thousand two hundred and sixty") == "43260"
28 |
29 | assert std("nineteen fifties") == "1950s"
30 | assert std("thirty first") == "31st"
31 | assert std("thirty three thousand and three hundred and thirty third") == "33333rd"
32 |
33 | assert std("three billion") == "3000000000"
34 | assert std("millions") == "1000000s"
35 |
36 | assert std("july third twenty twenty") == "july 3rd 2020"
37 | assert std("august twenty sixth twenty twenty one") == "august 26th 2021"
38 | assert std("3 14") == "3 14"
39 | assert std("3.14") == "3.14"
40 | assert std("3 point 2") == "3.2"
41 | assert std("3 point 14") == "3.14"
42 | assert std("fourteen point 4") == "14.4"
43 | assert std("two point two five dollars") == "$2.25"
44 | assert std("two hundred million dollars") == "$200000000"
45 | assert std("$20.1 million") == "$20100000"
46 |
47 | assert std("ninety percent") == "90%"
48 | assert std("seventy six per cent") == "76%"
49 |
50 | assert std("double oh seven") == "007"
51 | assert std("double zero seven") == "007"
52 | assert std("nine one one") == "911"
53 | assert std("nine double one") == "911"
54 | assert std("one triple oh one") == "10001"
55 |
56 | assert std("two thousandth") == "2000th"
57 | assert std("thirty two thousandth") == "32000th"
58 |
59 | assert std("minus 500") == "-500"
60 | assert std("positive twenty thousand") == "+20000"
61 |
62 | assert std("two dollars and seventy cents") == "$2.70"
63 | assert std("3 cents") == "¢3"
64 | assert std("$0.36") == "¢36"
65 | assert std("three euros and sixty five cents") == "€3.65"
66 |
67 | assert std("three and a half million") == "3500000"
68 | assert std("forty eight and a half dollars") == "$48.5"
69 | assert std("b747") == "b 747"
70 | assert std("10 th") == "10th"
71 | assert std("10th") == "10th"
72 |
73 |
74 | def test_spelling_normalizer():
75 | std = EnglishSpellingNormalizer()
76 |
77 | assert std("mobilisation") == "mobilization"
78 | assert std("cancelation") == "cancellation"
79 |
80 |
81 | def test_text_normalizer():
82 | std = EnglishTextNormalizer()
83 | assert std("Let's") == "let us"
84 | assert std("he's like") == "he is like"
85 | assert std("she's been like") == "she has been like"
86 | assert std("10km") == "10 km"
87 | assert std("10mm") == "10 mm"
88 | assert std("RC232") == "rc 232"
89 |
90 | assert (
91 | std("Mr. Park visited Assoc. Prof. Kim Jr.")
92 | == "mister park visited associate professor kim junior"
93 | )
94 |
--------------------------------------------------------------------------------
/src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/tests/test_tokenizer.py:
--------------------------------------------------------------------------------
1 | from whisper.tokenizer import get_tokenizer
2 |
3 |
4 | def test_tokenizer():
5 | gpt2_tokenizer = get_tokenizer(multilingual=False)
6 | multilingual_tokenizer = get_tokenizer(multilingual=True)
7 |
8 | text = "다람쥐 헌 쳇바퀴에 타고파"
9 | gpt2_tokens = gpt2_tokenizer.encode(text)
10 | multilingual_tokens = multilingual_tokenizer.encode(text)
11 |
12 | assert gpt2_tokenizer.decode(gpt2_tokens) == text
13 | assert multilingual_tokenizer.decode(multilingual_tokens) == text
14 | assert len(gpt2_tokens) > len(multilingual_tokens)
15 |
--------------------------------------------------------------------------------
/src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/tests/test_transcribe.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pytest
4 | import torch
5 |
6 | import whisper
7 |
8 |
9 | @pytest.mark.parametrize("model_name", whisper.available_models())
10 | def test_transcribe(model_name: str):
11 | device = "cuda" if torch.cuda.is_available() else "cpu"
12 | model = whisper.load_model(model_name).to(device)
13 | audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
14 |
15 | language = "en" if model_name.endswith(".en") else None
16 | result = model.transcribe(audio_path, language=language, temperature=0.0)
17 | assert result["language"] == "en"
18 |
19 | transcription = result["text"].lower()
20 | assert "my fellow americans" in transcription
21 | assert "your country" in transcription
22 | assert "do for you" in transcription
23 |
--------------------------------------------------------------------------------
/src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/whisper/__init__.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import io
3 | import os
4 | import urllib
5 | import warnings
6 | from typing import List, Optional, Union
7 |
8 | import torch
9 | from tqdm import tqdm
10 |
11 | from .audio import load_audio, log_mel_spectrogram, pad_or_trim
12 | from .decoding import DecodingOptions, DecodingResult, decode, detect_language
13 | from .model import Whisper, ModelDimensions
14 | from .transcribe import transcribe
15 | from .transcribe import transcribe_audio
16 | from .version import __version__
17 |
18 |
19 | _MODELS = {
20 | "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
21 | "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
22 | "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
23 | "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
24 | "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
25 | "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
26 | "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
27 | "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
28 | "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
29 | "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
30 | "large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
31 | }
32 |
33 |
34 | def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
35 | os.makedirs(root, exist_ok=True)
36 |
37 | expected_sha256 = url.split("/")[-2]
38 | download_target = os.path.join(root, os.path.basename(url))
39 |
40 | if os.path.exists(download_target) and not os.path.isfile(download_target):
41 | raise RuntimeError(f"{download_target} exists and is not a regular file")
42 |
43 | if os.path.isfile(download_target):
44 | with open(download_target, "rb") as f:
45 | model_bytes = f.read()
46 | if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
47 | return model_bytes if in_memory else download_target
48 | else:
49 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
50 |
51 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
52 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
53 | while True:
54 | buffer = source.read(8192)
55 | if not buffer:
56 | break
57 |
58 | output.write(buffer)
59 | loop.update(len(buffer))
60 |
61 | model_bytes = open(download_target, "rb").read()
62 | if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
63 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model.")
64 |
65 | return model_bytes if in_memory else download_target
66 |
67 |
68 | def available_models() -> List[str]:
69 | """Returns the names of available models"""
70 | return list(_MODELS.keys())
71 |
72 |
73 | def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False) -> Whisper:
74 | """
75 | Load a Whisper ASR model
76 |
77 | Parameters
78 | ----------
79 | name : str
80 | one of the official model names listed by `whisper.available_models()`, or
81 | path to a model checkpoint containing the model dimensions and the model state_dict.
82 | device : Union[str, torch.device]
83 | the PyTorch device to put the model into
84 | download_root: str
85 | path to download the model files; by default, it uses "~/.cache/whisper"
86 | in_memory: bool
87 | whether to preload the model weights into host memory
88 |
89 | Returns
90 | -------
91 | model : Whisper
92 | The Whisper ASR model instance
93 | """
94 |
95 | if device is None:
96 | device = "cuda" if torch.cuda.is_available() else "cpu"
97 | if download_root is None:
98 | download_root = os.getenv(
99 | "XDG_CACHE_HOME",
100 | os.path.join(os.path.expanduser("~"), ".cache", "whisper")
101 | )
102 |
103 | if name in _MODELS:
104 | checkpoint_file = _download(_MODELS[name], download_root, in_memory)
105 | elif os.path.isfile(name):
106 | checkpoint_file = open(name, "rb").read() if in_memory else name
107 | else:
108 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
109 |
110 | with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp:
111 | checkpoint = torch.load(fp, map_location=device)
112 | del checkpoint_file
113 |
114 | dims = ModelDimensions(**checkpoint["dims"])
115 | model = Whisper(dims)
116 | model.load_state_dict(checkpoint["model_state_dict"], strict=False)
117 |
118 | return model.to(device)
119 |
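120 | # Illustrative usage sketch (not part of the original module; "audio.flac" is a placeholder path):
121 | #   import whisper
122 | #   print(whisper.available_models())        # ["tiny.en", "tiny", ..., "large"]
123 | #   model = whisper.load_model("base")       # cached under ~/.cache/whisper by default
124 | #   result = model.transcribe("audio.flac")  # as in tests/test_transcribe.py
125 | #   print(result["text"])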
--------------------------------------------------------------------------------
/src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/whisper/__main__.py:
--------------------------------------------------------------------------------
1 | from .transcribe import cli
2 |
3 |
4 | cli()
5 |
--------------------------------------------------------------------------------
/src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/whisper/assets/gpt2/special_tokens_map.json:
--------------------------------------------------------------------------------
1 | {"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}
--------------------------------------------------------------------------------
/src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/whisper/assets/gpt2/tokenizer_config.json:
--------------------------------------------------------------------------------
1 | {"unk_token": "<|endoftext|>", "bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "gpt2", "tokenizer_class": "GPT2Tokenizer"}
--------------------------------------------------------------------------------
/src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/whisper/assets/mel_filters.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuanGongND/whisper-at/17d94d6acd53866390ce70f95afa13507dcb8aef/src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/whisper/assets/mel_filters.npz
--------------------------------------------------------------------------------
/src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/whisper/assets/multilingual/added_tokens.json:
--------------------------------------------------------------------------------
1 | {"<|endoftext|>": 50257}
2 |
--------------------------------------------------------------------------------
/src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/whisper/assets/multilingual/special_tokens_map.json:
--------------------------------------------------------------------------------
1 | {"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}
--------------------------------------------------------------------------------
/src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/whisper/assets/multilingual/tokenizer_config.json:
--------------------------------------------------------------------------------
1 | {"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "multilingual", "errors": "replace", "tokenizer_class": "GPT2Tokenizer"}
--------------------------------------------------------------------------------
/src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/whisper/audio.py:
--------------------------------------------------------------------------------
1 | import os
2 | from functools import lru_cache
3 | from typing import Union
4 |
5 | import ffmpeg
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 |
10 | from .utils import exact_div
11 |
12 | # hard-coded audio hyperparameters
13 | SAMPLE_RATE = 16000
14 | N_FFT = 400
15 | N_MELS = 80
16 | HOP_LENGTH = 160
17 | CHUNK_LENGTH = 30
18 | N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
19 | N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input
20 |
21 |
22 | def load_audio(file: str, sr: int = SAMPLE_RATE):
23 | """
24 | Open an audio file and read as mono waveform, resampling as necessary
25 |
26 | Parameters
27 | ----------
28 | file: str
29 | The audio file to open
30 |
31 | sr: int
32 | The sample rate to resample the audio if necessary
33 |
34 | Returns
35 | -------
36 | A NumPy array containing the audio waveform, in float32 dtype.
37 | """
38 | try:
39 | # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
40 | # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
41 | out, _ = (
42 | ffmpeg.input(file, threads=0)
43 | .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
44 | .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
45 | )
46 | except ffmpeg.Error as e:
47 | raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
48 |
49 | return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
50 |
51 |
52 | def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
53 | """
54 | Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
55 | """
56 | if torch.is_tensor(array):
57 | if array.shape[axis] > length:
58 | array = array.index_select(dim=axis, index=torch.arange(length, device=array.device))
59 |
60 | if array.shape[axis] < length:
61 | pad_widths = [(0, 0)] * array.ndim
62 | pad_widths[axis] = (0, length - array.shape[axis])
63 | array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
64 | else:
65 | if array.shape[axis] > length:
66 | array = array.take(indices=range(length), axis=axis)
67 |
68 | if array.shape[axis] < length:
69 | pad_widths = [(0, 0)] * array.ndim
70 | pad_widths[axis] = (0, length - array.shape[axis])
71 | array = np.pad(array, pad_widths)
72 |
73 | return array
74 |
75 |
76 | @lru_cache(maxsize=None)
77 | def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
78 | """
79 | load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
80 | Allows decoupling librosa dependency; saved using:
81 |
82 | np.savez_compressed(
83 | "mel_filters.npz",
84 | mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
85 | )
86 | """
87 | assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
88 | with np.load(os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")) as f:
89 | return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
90 |
91 |
92 | def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):
93 | """
94 | Compute the log-Mel spectrogram of the input audio
95 |
96 | Parameters
97 | ----------
98 | audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
99 | The path to an audio file, or a NumPy array or Tensor containing the audio waveform at 16 kHz
100 |
101 | n_mels: int
102 | The number of Mel-frequency filters, only 80 is supported
103 |
104 | Returns
105 | -------
106 | torch.Tensor, shape = (80, n_frames)
107 | A Tensor that contains the Mel spectrogram
108 | """
109 | if not torch.is_tensor(audio):
110 | if isinstance(audio, str):
111 | audio = load_audio(audio)
112 | audio = torch.from_numpy(audio)
113 |
114 | window = torch.hann_window(N_FFT).to(audio.device)
115 | stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
116 | magnitudes = stft[..., :-1].abs() ** 2
117 |
118 | filters = mel_filters(audio.device, n_mels)
119 | mel_spec = filters @ magnitudes
120 |
121 | log_spec = torch.clamp(mel_spec, min=1e-10).log10()
122 | log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
123 | log_spec = (log_spec + 4.0) / 4.0
124 | return log_spec
125 |
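126 | # Illustrative preprocessing sketch (not part of the original module; "audio.flac" is a placeholder path):
127 | #   audio = load_audio("audio.flac")   # float32 mono waveform resampled to 16 kHz
128 | #   audio = pad_or_trim(audio)         # pad/trim to N_SAMPLES (30 s)
129 | #   mel = log_mel_spectrogram(audio)   # torch.Tensor of shape (80, 3000)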
--------------------------------------------------------------------------------
/src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/whisper/normalizers/__init__.py:
--------------------------------------------------------------------------------
1 | from .basic import BasicTextNormalizer
2 | from .english import EnglishTextNormalizer
3 |
--------------------------------------------------------------------------------
/src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/whisper/normalizers/basic.py:
--------------------------------------------------------------------------------
1 | import re
2 | import unicodedata
3 |
4 | import regex
5 |
6 | # non-ASCII letters that are not separated by "NFKD" normalization
7 | ADDITIONAL_DIACRITICS = {
8 | "œ": "oe",
9 | "Œ": "OE",
10 | "ø": "o",
11 | "Ø": "O",
12 | "æ": "ae",
13 | "Æ": "AE",
14 | "ß": "ss",
15 | "ẞ": "SS",
16 | "đ": "d",
17 | "Đ": "D",
18 | "ð": "d",
19 | "Ð": "D",
20 | "þ": "th",
21 | "Þ": "th",
22 | "ł": "l",
23 | "Ł": "L",
24 | }
25 |
26 |
27 | def remove_symbols_and_diacritics(s: str, keep=""):
28 | """
29 | Replace any other markers, symbols, and punctuations with a space,
30 | and drop any diacritics (category 'Mn' and some manual mappings)
31 | """
32 | return "".join(
33 | c
34 | if c in keep
35 | else ADDITIONAL_DIACRITICS[c]
36 | if c in ADDITIONAL_DIACRITICS
37 | else ""
38 | if unicodedata.category(c) == "Mn"
39 | else " "
40 | if unicodedata.category(c)[0] in "MSP"
41 | else c
42 | for c in unicodedata.normalize("NFKD", s)
43 | )
44 |
45 |
46 | def remove_symbols(s: str):
47 | """
48 | Replace any other markers, symbols, punctuations with a space, keeping diacritics
49 | """
50 | return "".join(
51 | " " if unicodedata.category(c)[0] in "MSP" else c for c in unicodedata.normalize("NFKC", s)
52 | )
53 |
54 |
55 | class BasicTextNormalizer:
56 | def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
57 | self.clean = remove_symbols_and_diacritics if remove_diacritics else remove_symbols
58 | self.split_letters = split_letters
59 |
60 | def __call__(self, s: str):
61 | s = s.lower()
62 | s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
63 | s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
64 | s = self.clean(s).lower()
65 |
66 | if self.split_letters:
67 | s = " ".join(regex.findall(r"\X", s, regex.U))
68 |
69 | s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space
70 |
71 | return s
72 |
--------------------------------------------------------------------------------
/src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/whisper/tokenizer.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import dataclass
3 | from functools import lru_cache
4 | from typing import List, Optional, Tuple, Union
5 |
6 | import numpy as np
7 | import torch
8 | from transformers import GPT2TokenizerFast
9 |
10 | LANGUAGES = {
11 | "en": "english",
12 | "zh": "chinese",
13 | "de": "german",
14 | "es": "spanish",
15 | "ru": "russian",
16 | "ko": "korean",
17 | "fr": "french",
18 | "ja": "japanese",
19 | "pt": "portuguese",
20 | "tr": "turkish",
21 | "pl": "polish",
22 | "ca": "catalan",
23 | "nl": "dutch",
24 | "ar": "arabic",
25 | "sv": "swedish",
26 | "it": "italian",
27 | "id": "indonesian",
28 | "hi": "hindi",
29 | "fi": "finnish",
30 | "vi": "vietnamese",
31 | "he": "hebrew",
32 | "uk": "ukrainian",
33 | "el": "greek",
34 | "ms": "malay",
35 | "cs": "czech",
36 | "ro": "romanian",
37 | "da": "danish",
38 | "hu": "hungarian",
39 | "ta": "tamil",
40 | "no": "norwegian",
41 | "th": "thai",
42 | "ur": "urdu",
43 | "hr": "croatian",
44 | "bg": "bulgarian",
45 | "lt": "lithuanian",
46 | "la": "latin",
47 | "mi": "maori",
48 | "ml": "malayalam",
49 | "cy": "welsh",
50 | "sk": "slovak",
51 | "te": "telugu",
52 | "fa": "persian",
53 | "lv": "latvian",
54 | "bn": "bengali",
55 | "sr": "serbian",
56 | "az": "azerbaijani",
57 | "sl": "slovenian",
58 | "kn": "kannada",
59 | "et": "estonian",
60 | "mk": "macedonian",
61 | "br": "breton",
62 | "eu": "basque",
63 | "is": "icelandic",
64 | "hy": "armenian",
65 | "ne": "nepali",
66 | "mn": "mongolian",
67 | "bs": "bosnian",
68 | "kk": "kazakh",
69 | "sq": "albanian",
70 | "sw": "swahili",
71 | "gl": "galician",
72 | "mr": "marathi",
73 | "pa": "punjabi",
74 | "si": "sinhala",
75 | "km": "khmer",
76 | "sn": "shona",
77 | "yo": "yoruba",
78 | "so": "somali",
79 | "af": "afrikaans",
80 | "oc": "occitan",
81 | "ka": "georgian",
82 | "be": "belarusian",
83 | "tg": "tajik",
84 | "sd": "sindhi",
85 | "gu": "gujarati",
86 | "am": "amharic",
87 | "yi": "yiddish",
88 | "lo": "lao",
89 | "uz": "uzbek",
90 | "fo": "faroese",
91 | "ht": "haitian creole",
92 | "ps": "pashto",
93 | "tk": "turkmen",
94 | "nn": "nynorsk",
95 | "mt": "maltese",
96 | "sa": "sanskrit",
97 | "lb": "luxembourgish",
98 | "my": "myanmar",
99 | "bo": "tibetan",
100 | "tl": "tagalog",
101 | "mg": "malagasy",
102 | "as": "assamese",
103 | "tt": "tatar",
104 | "haw": "hawaiian",
105 | "ln": "lingala",
106 | "ha": "hausa",
107 | "ba": "bashkir",
108 | "jw": "javanese",
109 | "su": "sundanese",
110 | }
111 |
112 | # language code lookup by name, with a few language aliases
113 | TO_LANGUAGE_CODE = {
114 | **{language: code for code, language in LANGUAGES.items()},
115 | "burmese": "my",
116 | "valencian": "ca",
117 | "flemish": "nl",
118 | "haitian": "ht",
119 | "letzeburgesch": "lb",
120 | "pushto": "ps",
121 | "panjabi": "pa",
122 | "moldavian": "ro",
123 | "moldovan": "ro",
124 | "sinhalese": "si",
125 | "castilian": "es",
126 | }
127 |
128 |
129 | @dataclass(frozen=True)
130 | class Tokenizer:
131 | """A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens"""
132 |
133 | tokenizer: "GPT2TokenizerFast"
134 | language: Optional[str]
135 | sot_sequence: Tuple[int]
136 |
137 | def encode(self, text, **kwargs):
138 | return self.tokenizer.encode(text, **kwargs)
139 |
140 | def decode(self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs):
141 | return self.tokenizer.decode(token_ids, **kwargs)
142 |
143 | def decode_with_timestamps(self, tokens) -> str:
144 | """
145 | Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
146 | This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
147 | """
148 | outputs = [[]]
149 | for token in tokens:
150 | if token >= self.timestamp_begin:
151 | timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
152 | outputs.append(timestamp)
153 | outputs.append([])
154 | else:
155 | outputs[-1].append(token)
156 | outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
157 | return "".join(outputs)
158 |
159 | @property
160 | @lru_cache()
161 | def eot(self) -> int:
162 | return self.tokenizer.eos_token_id
163 |
164 | @property
165 | @lru_cache()
166 | def sot(self) -> int:
167 | return self._get_single_token_id("<|startoftranscript|>")
168 |
169 | @property
170 | @lru_cache()
171 | def sot_lm(self) -> int:
172 | return self._get_single_token_id("<|startoflm|>")
173 |
174 | @property
175 | @lru_cache()
176 | def sot_prev(self) -> int:
177 | return self._get_single_token_id("<|startofprev|>")
178 |
179 | @property
180 | @lru_cache()
181 | def no_speech(self) -> int:
182 | return self._get_single_token_id("<|nospeech|>")
183 |
184 | @property
185 | @lru_cache()
186 | def no_timestamps(self) -> int:
187 | return self._get_single_token_id("<|notimestamps|>")
188 |
189 | @property
190 | @lru_cache()
191 | def timestamp_begin(self) -> int:
192 | return self.tokenizer.all_special_ids[-1] + 1
193 |
194 | @property
195 | @lru_cache()
196 | def language_token(self) -> int:
197 | """Returns the token id corresponding to the value of the `language` field"""
198 | if self.language is None:
199 | raise ValueError(f"This tokenizer does not have language token configured")
200 |
201 | additional_tokens = dict(
202 | zip(
203 | self.tokenizer.additional_special_tokens,
204 | self.tokenizer.additional_special_tokens_ids,
205 | )
206 | )
207 | candidate = f"<|{self.language}|>"
208 | if candidate in additional_tokens:
209 | return additional_tokens[candidate]
210 |
211 | raise KeyError(f"Language {self.language} not found in tokenizer.")
212 |
213 | @property
214 | @lru_cache()
215 | def all_language_tokens(self) -> Tuple[int]:
216 | result = []
217 | for token, token_id in zip(
218 | self.tokenizer.additional_special_tokens,
219 | self.tokenizer.additional_special_tokens_ids,
220 | ):
221 | if token.strip("<|>") in LANGUAGES:
222 | result.append(token_id)
223 | return tuple(result)
224 |
225 | @property
226 | @lru_cache()
227 | def all_language_codes(self) -> Tuple[str]:
228 | return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)
229 |
230 | @property
231 | @lru_cache()
232 | def sot_sequence_including_notimestamps(self) -> Tuple[int]:
233 | return tuple(list(self.sot_sequence) + [self.no_timestamps])
234 |
235 | @property
236 | @lru_cache()
237 | def non_speech_tokens(self) -> Tuple[int]:
238 | """
239 | Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
240 | annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
241 |
242 | - ♪♪♪
243 | - ( SPEAKING FOREIGN LANGUAGE )
244 | - [DAVID] Hey there,
245 |
246 | keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
247 | """
248 | symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』")
249 | symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
250 |
251 | # symbols that may be a single token or multiple tokens depending on the tokenizer.
252 | # In case they're multiple tokens, suppress the first token, which is safe because:
253 | # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
254 | # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
255 | miscellaneous = set("♩♪♫♬♭♮♯")
256 | assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
257 |
258 | # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
259 | result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]}
260 | for symbol in symbols + list(miscellaneous):
261 | for tokens in [self.tokenizer.encode(symbol), self.tokenizer.encode(" " + symbol)]:
262 | if len(tokens) == 1 or symbol in miscellaneous:
263 | result.add(tokens[0])
264 |
265 | return tuple(sorted(result))
266 |
267 | def _get_single_token_id(self, text) -> int:
268 | tokens = self.tokenizer.encode(text)
269 | assert len(tokens) == 1, f"{text} is not encoded as a single token"
270 | return tokens[0]
271 |
272 |
273 | @lru_cache(maxsize=None)
274 | def build_tokenizer(name: str = "gpt2"):
275 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
276 | path = os.path.join(os.path.dirname(__file__), "assets", name)
277 | tokenizer = GPT2TokenizerFast.from_pretrained(path)
278 |
279 | specials = [
280 | "<|startoftranscript|>",
281 | *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
282 | "<|translate|>",
283 | "<|transcribe|>",
284 | "<|startoflm|>",
285 | "<|startofprev|>",
286 | "<|nospeech|>",
287 | "<|notimestamps|>",
288 | ]
289 |
290 | tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
291 | return tokenizer
292 |
293 |
294 | @lru_cache(maxsize=None)
295 | def get_tokenizer(
296 | multilingual: bool,
297 | *,
298 | task: Optional[str] = None, # Literal["transcribe", "translate", None]
299 | language: Optional[str] = None,
300 | ) -> Tokenizer:
301 | if language is not None:
302 | language = language.lower()
303 | if language not in LANGUAGES:
304 | if language in TO_LANGUAGE_CODE:
305 | language = TO_LANGUAGE_CODE[language]
306 | else:
307 | raise ValueError(f"Unsupported language: {language}")
308 |
309 | if multilingual:
310 | tokenizer_name = "multilingual"
311 | task = task or "transcribe"
312 | language = language or "en"
313 | else:
314 | tokenizer_name = "gpt2"
315 | task = None
316 | language = None
317 |
318 | tokenizer = build_tokenizer(name=tokenizer_name)
319 | all_special_ids: List[int] = tokenizer.all_special_ids
320 | sot: int = all_special_ids[1]
321 | translate: int = all_special_ids[-6]
322 | transcribe: int = all_special_ids[-5]
323 |
324 | langs = tuple(LANGUAGES.keys())
325 | sot_sequence = [sot]
326 | if language is not None:
327 | sot_sequence.append(sot + 1 + langs.index(language))
328 | if task is not None:
329 | sot_sequence.append(transcribe if task == "transcribe" else translate)
330 |
331 | return Tokenizer(tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence))
332 |
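333 | # Illustrative sketch (not part of the original module; mirrors tests/test_tokenizer.py):
334 | #   tokenizer = get_tokenizer(multilingual=True, task="transcribe", language="en")
335 | #   ids = tokenizer.encode("hello world")
336 | #   assert tokenizer.decode(ids) == "hello world"
337 | #   print(tokenizer.sot_sequence)  # ids of <|startoftranscript|><|en|><|transcribe|>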
--------------------------------------------------------------------------------
/src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/whisper/utils.py:
--------------------------------------------------------------------------------
1 | import zlib
2 | from typing import Iterator, TextIO
3 |
4 |
5 | def exact_div(x, y):
6 | assert x % y == 0
7 | return x // y
8 |
9 |
10 | def str2bool(string):
11 | str2val = {"True": True, "False": False}
12 | if string in str2val:
13 | return str2val[string]
14 | else:
15 | raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
16 |
17 |
18 | def optional_int(string):
19 | return None if string == "None" else int(string)
20 |
21 |
22 | def optional_float(string):
23 | return None if string == "None" else float(string)
24 |
25 |
26 | def compression_ratio(text) -> float:
27 | text_bytes = text.encode("utf-8")
28 | return len(text_bytes) / len(zlib.compress(text_bytes))
29 |
30 |
31 | def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'):
32 | assert seconds >= 0, "non-negative timestamp expected"
33 | milliseconds = round(seconds * 1000.0)
34 |
35 | hours = milliseconds // 3_600_000
36 | milliseconds -= hours * 3_600_000
37 |
38 | minutes = milliseconds // 60_000
39 | milliseconds -= minutes * 60_000
40 |
41 | seconds = milliseconds // 1_000
42 | milliseconds -= seconds * 1_000
43 |
44 | hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
45 | return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
46 |
47 |
48 | def write_txt(transcript: Iterator[dict], file: TextIO):
49 | for segment in transcript:
50 | print(segment['text'].strip(), file=file, flush=True)
51 |
52 |
53 | def write_vtt(transcript: Iterator[dict], file: TextIO):
54 | print("WEBVTT\n", file=file)
55 | for segment in transcript:
56 | print(
57 | f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
58 | f"{segment['text'].strip().replace('-->', '->')}\n",
59 | file=file,
60 | flush=True,
61 | )
62 |
63 |
64 | def write_srt(transcript: Iterator[dict], file: TextIO):
65 | """
66 | Write a transcript to a file in SRT format.
67 |
68 | Example usage:
69 | from pathlib import Path
70 | from whisper.utils import write_srt
71 |
72 | result = transcribe(model, audio_path, temperature=temperature, **args)
73 |
74 | # save SRT
75 | audio_basename = Path(audio_path).stem
76 | with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
77 | write_srt(result["segments"], file=srt)
78 | """
79 | for i, segment in enumerate(transcript, start=1):
80 | # write srt lines
81 | print(
82 | f"{i}\n"
83 | f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
84 | f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
85 | f"{segment['text'].strip().replace('-->', '->')}\n",
86 | file=file,
87 | flush=True,
88 | )
89 |
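90 | # Illustrative sketch (not part of the original module):
91 | #   format_timestamp(3661.5)                                               # "01:01:01.500"
92 | #   format_timestamp(7.25, always_include_hours=True, decimal_marker=",")  # "00:00:07,250" (SRT style)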
--------------------------------------------------------------------------------
/src/noise_robust_asr/intermediate_feat_extract/whisper_feat_extracrt/whisper/version.py:
--------------------------------------------------------------------------------
1 | __version__ = "20230117"
2 |
--------------------------------------------------------------------------------
/src/noise_robust_asr/plots/plot_figure1_lower.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2/13/23 3:40 AM
3 | # @Author : Yuan Gong
4 | # @Affiliation : Massachusetts Institute of Technology
5 | # @Email : yuangong@mit.edu
6 | # @File : get_summary.py
7 |
8 | # summarize the results of esc experiments
9 |
10 | import json
11 | import os
12 | os.environ["XDG_CACHE_HOME"] = './'
13 | import numpy as np
14 | from matplotlib import pyplot as plt
15 |
16 | mdl_size_list = ['whisper_large-v1',
17 | 'hubert-xlarge-ll60k',
18 | 'hubert-xlarge-ls960-ft',
19 | 'wav2vec2-large-robust',
20 | 'wav2vec2-large-robust-ft-swbd-300h',
21 | 'hubert-large-ls960-ft',
22 | 'wav2vec2-base-960h']
23 |
24 | legend_list = ['Whisper-Large', 'Hubert-XLarge-PR', 'Hubert-XLarge-FT', 'Wav2vec2-Large-Robust-PR', 'Wav2vec2-Large-Robust-FT', 'Hubert-Large-FT', 'Wav2vec2-Base-FT']
25 |
26 | for i, mdl_size in enumerate(mdl_size_list):
27 | all_res = []
28 | for fold in range(1, 6):
29 | for lr in [0.001]:
30 | cur_res = np.loadtxt('./baseline_res/esc_{:s}_fold{:d}_lr_{:.4f}.csv'.format(mdl_size, fold, lr), delimiter=',', usecols=(5)).tolist()
31 | all_res.append(cur_res)
32 | all_res = np.array(all_res)
33 | all_res = np.mean(all_res, axis=0)[1:-1] * 100
34 | print(all_res.shape)
35 | num_layer = all_res.shape[0]
36 | if i == 0: # whisper
37 | plt.plot(list(range(1, num_layer+1)), all_res, '-o', label = legend_list[i], linewidth=2)
38 | elif i == 1:
39 | plt.plot(list(range(1, num_layer + 1)), all_res, 'g-', label=legend_list[i], linewidth=2, alpha=0.5)
40 | elif i == 2:
41 | plt.plot(list(range(1, num_layer + 1)), all_res, 'g-x', label=legend_list[i], linewidth=2)
42 | elif i == 3:
43 | plt.plot(list(range(1, num_layer + 1)), all_res, 'c-', label=legend_list[i], linewidth=2, alpha=0.5)
44 | elif i == 4:
45 | plt.plot(list(range(1, num_layer + 1)), all_res, 'c-*', label=legend_list[i], linewidth=2)
46 | elif i == 5:
47 | plt.plot(list(range(1, num_layer + 1)), all_res, '-^', label=legend_list[i], linewidth=2)
48 | elif i == 6:
49 | plt.plot(list(range(1, num_layer + 1)), all_res, 'r-d', label=legend_list[i], linewidth=2)
50 |
51 | plt.ylim([0, 1])
52 | plt.xlabel('Classifying Using Representation of Layer # as Input', fontsize=13.5)
53 | plt.ylabel('Sound Classification Accuracy (%)', fontsize=14)
54 | plt.legend(fontsize=10)
55 | plt.grid()
56 | plt.ylim([28, 90])
57 | plt.xlim([0, 50])
58 | figure = plt.gcf()
59 | figure.set_size_inches(6, 4)
60 | plt.savefig('./formal_plot/result_summary_' + str(lr) + '_cr.pdf', dpi=300, bbox_inches='tight')
61 | plt.close()
--------------------------------------------------------------------------------
/src/noise_robust_asr/plots/plot_figure1_upper.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 3/4/23 9:11 PM
3 | # @Author : Yuan Gong
4 | # @Affiliation : Massachusetts Institute of Technology
5 | # @Email : yuangong@mit.edu
6 | # @File : plot_snr.py
7 | import numpy as np
8 | from matplotlib import pyplot as plt
9 |
10 | mdl_list = ['Whisper-Large', 'Hubert-XLarge-FT', 'Wav2vec2-Large-Robust-FT', 'Hubert-Large-FT', 'Wav2vec2-Base-FT']
11 |
12 | exp_name_list = ['whisper_large-v1', 'hubert_xlarge', 'w2v_large_robust', 'hubert_large', 'w2v_base']
13 | snr_list = [-20, -15, -10, -5, 0, 5, 10, 15, 20]
14 |
15 | for i in range(len(exp_name_list)):
16 | exp_name = exp_name_list[i]
17 | cur_res = np.loadtxt('/data/sls/scratch/yuangong/whisper-a/src/noisy_exp/results_camera/{:s}.csv'.format(exp_name))
18 | cur_res = cur_res * 100
19 | print(exp_name, cur_res.shape)
20 | if i == 0:
21 | plt.plot(snr_list, cur_res, '-o', label=mdl_list[i], linewidth=2)
22 | elif i == 1:
23 | plt.plot(snr_list, cur_res, 'g-x', label=mdl_list[i], linewidth=2)
24 | elif i == 2:
25 | plt.plot(snr_list, cur_res, 'c-*', label=mdl_list[i], linewidth=2)
26 | elif i == 3:
27 | plt.plot(snr_list, cur_res, '-^', label=mdl_list[i], linewidth=2)
28 | elif i == 4:
29 | plt.plot(snr_list, cur_res, 'r-d', label=mdl_list[i], linewidth=2)
30 |
31 | plt.xlabel('Signal-to-Noise Ratio (dB)', fontsize=14)
32 | plt.ylabel('Word Error Rate (%)', fontsize=14)
33 | plt.legend(fontsize=10)
34 | plt.gca().invert_xaxis()
35 | plt.grid()
36 | figure = plt.gcf()
37 | figure.set_size_inches(6, 2.5)
38 | plt.savefig('./snr_plot_cr.pdf', dpi=300, bbox_inches='tight')
39 | plt.close()
--------------------------------------------------------------------------------
/src/noise_robust_asr/plots/plot_figure2.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 3/4/23 9:11 PM
3 | # @Author : Yuan Gong
4 | # @Affiliation : Massachusetts Institute of Technology
5 | # @Email : yuangong@mit.edu
6 | # @File : plot_snr.py
7 | import numpy as np
8 | from matplotlib import pyplot as plt
9 |
10 | label_list = np.loadtxt('/data/sls/scratch/yuangong/whisper-a/egs/esc-50/feat_extract/data/esc_class_labels_indices.csv', delimiter=',', dtype=str, skiprows=1, usecols=(2)).tolist()
11 | print(label_list)
12 |
13 | # classes not used in fitting the line
14 | outlier_list = [8, 10, 11, 16, 22, 27, 35, 36, 41, 45, 46, 49]
15 | not_outliear_list = []
16 | for x in range(50):
17 | if x not in outlier_list:
18 | not_outliear_list.append(x)
19 |
20 | start = 2  # index 2 corresponds to -10 dB SNR
21 | snr_res = np.loadtxt('/data/sls/scratch/yuangong/whisper-a/src/noisy_exp/results_camera/whisper_large-v1_cla.csv', delimiter=',')
22 | snr_drop = snr_res[start] - snr_res[-1] # WER increase from 20 dB SNR (last row) to -10 dB SNR (3rd row, index 2)
23 | print(snr_drop.shape)
24 | snr_drop = [x*100 for x in snr_drop]
25 |
26 | # sound classification result
27 | all_res = []
28 | mdl_size = 'whisper_large-v1'
29 | for fold in range(1, 6):
30 | for lr in [0.001]:
31 | cur_res = np.loadtxt('/data/sls/scratch/yuangong/whisper-a/src/baseline_cla/baseline_res/esc_{:s}_fold{:d}_lr_{:.4f}.csv'.format(mdl_size, fold, lr), delimiter=',', usecols=list(range(6,56)))#.tolist()
32 | cur_res = cur_res[-2, :].tolist() # get the none-wa mean of representation, [50,] corresponds to 50 classes, -2 is the last layer out, 0 is the input layer out, -1 is the wa out
33 | all_res.append(cur_res)
34 | all_res = np.array(all_res) # [5, 50] , 5 folds, 50 classes
35 |
36 | sound_res = np.mean(all_res, axis=0) * 100
37 | print(sound_res.shape)
38 |
39 | print(start, 'corr', np.corrcoef(sound_res, snr_drop)[0, 1])
40 |
41 | b, a = np.polyfit(np.array(sound_res)[not_outliear_list], np.array(snr_drop)[not_outliear_list], deg=1)
42 |
43 | print(start, 'corr', np.corrcoef(np.array(sound_res)[not_outliear_list], np.array(snr_drop)[not_outliear_list])[0, 1])
44 |
45 | # Create a sequence of 50 evenly spaced values from 50 to 100
46 | xseq = np.linspace(50, 100, num=50)
47 |
48 | # Plot regression line
49 | plt.plot(xseq, a + b * xseq, '--', lw=2.5, alpha=0.7)
50 | plt.fill_between(xseq, a+22 + (b - 0.465) * xseq, 70, alpha=0.3, color='lightblue')
51 |
52 | plt.scatter(sound_res, snr_drop)
53 |
54 | font_size = 2
55 | for cla_i in range(50):
56 | plt.annotate(label_list[cla_i][1:-1], (sound_res[cla_i], snr_drop[cla_i]), fontsize=font_size)
57 |
58 | plt.ylim([-5, 110])
59 | plt.xlim([50, 100])
60 | plt.grid()
61 | plt.gca().invert_yaxis()
62 | plt.xlabel('ESC-50 Class-wise F1-Score')
63 | plt.ylabel('Word Error Rate Increase from 20dB to -10dB SNR (%)')
64 | figure = plt.gcf()
65 | figure.set_size_inches(5, 5)
66 | plt.savefig('./figure_test/snr_plot_cla_{:d}_cr_outlier.pdf'.format(start), dpi=300, bbox_inches='tight')
67 | plt.close()
--------------------------------------------------------------------------------
/src/noise_robust_asr/plots/plot_figure3.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2/13/23 3:40 AM
3 | # @Author : Yuan Gong
4 | # @Affiliation : Massachusetts Institute of Technology
5 | # @Email : yuangong@mit.edu
6 | # @File : plot_figure3.py
7 |
8 | # hist of best layer for each sound class
9 |
10 | import os
11 | os.environ["XDG_CACHE_HOME"] = './'
12 | import numpy as np
13 | from matplotlib import pyplot as plt
14 |
15 | label_list = np.loadtxt('/data/sls/scratch/yuangong/whisper-a/egs/esc-50/feat_extract/data/esc_class_labels_indices.csv', delimiter=',', dtype=str, skiprows=1, usecols=(2)).tolist()
16 |
17 | all_res = []
18 | mdl_size = 'whisper_large-v1'
19 | for fold in range(1, 6):
20 | for lr in [0.001]:
21 | cur_res = np.loadtxt('/data/sls/scratch/yuangong/whisper-a/src/baseline_cla/baseline_res/esc_{:s}_fold{:d}_lr_{:.4f}.csv'.format(mdl_size, fold, lr), delimiter=',', usecols=list(range(6, 56)))
22 | cur_res = cur_res[1:-1, :].tolist() # exclude the input and last avg layer
23 | all_res.append(cur_res)
24 | all_res = np.array(all_res) # [5, n_layer, 50]: 5 folds, per-layer results, 50 classes
25 | sound_res = np.mean(all_res, axis=0) * 100
26 |
27 | best_layer_list = []
28 | for i in range(sound_res.shape[1]):
29 | best_layer_list.append(np.argmax(sound_res[:, i]) + 1)  # +1 so layer indices are 1-based, matching the x-ticks
30 |
31 | print(best_layer_list)
32 |
33 | plt.hist(best_layer_list, bins=16, histtype ='bar', rwidth=0.7)
34 | plt.xlabel('Representation of Layer', fontsize=14)
35 | plt.ylabel('# Classes', fontsize=14)
36 | plt.xticks(range(1, 33, 2))
37 | plt.grid()
38 | figure = plt.gcf()
39 | figure.set_size_inches(6, 2)
40 | plt.savefig('./formal_plot/best_layer.pdf', dpi=300, bbox_inches='tight')
41 | plt.close()
--------------------------------------------------------------------------------
/src/whisper_at_train/datafiles/README.md:
--------------------------------------------------------------------------------
1 | # AudioSet Datafiles
2 |
3 | To facilitate reproduction, we provide the JSON files used to train the model. They contain the sample IDs and the corresponding labels; the actual audio files are not included.
4 |
5 | [AudioSet-2M Training Json File](https://www.dropbox.com/scl/fi/szjlerzblw17f8d2fykgt/whole_train_data.json?rlkey=dr3rdri1jaql0g9lgfpiyfihg&dl=1)
6 |
7 | [AudioSet-Eval Json File](https://www.dropbox.com/scl/fi/cfu70jnxqgphi9d1nm89w/eval_data.json?rlkey=t7b44qk27iznrrtwl80hnz12q&dl=1)
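8 |
9 | For reference, the snippet below is a minimal sketch of how these datafiles are structured and read (field names follow `dataloader_feat.py`; the file name, `wav` path, and label mids are placeholders):
10 |
11 | ```python
12 | import json
13 |
14 | # each entry pairs one audio clip with its comma-separated AudioSet label mids
15 | data = json.load(open('whole_train_data.json', 'r'))['data']
16 | print(data[0])  # e.g. {"wav": "/path/to/clip.flac", "labels": "/m/09x0r,/m/05zppz"}
17 | ```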
--------------------------------------------------------------------------------
/src/whisper_at_train/dataloader_feat.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 6/19/21 12:23 AM
3 | # @Author : Yuan Gong
4 | # @Affiliation : Massachusetts Institute of Technology
5 | # @Email : yuangong@mit.edu
6 | # @File : dataloader.py
7 |
8 | # modified from:
9 | # Author: David Harwath
10 | # with some functions borrowed from https://github.com/SeanNaren/deepspeech.pytorch
11 | # load from whisper feats
12 |
13 | import csv
14 | import json
15 | import os.path
16 |
17 | import torchaudio
18 | import numpy as np
19 | import torch
20 | import torch.nn.functional
21 | from torch.utils.data import Dataset
22 | import random
23 | from whisper.audio import log_mel_spectrogram, pad_or_trim, load_audio
24 |
25 | def make_index_dict(label_csv):
26 | index_lookup = {}
27 | with open(label_csv, 'r') as f:
28 | csv_reader = csv.DictReader(f)
29 | line_count = 0
30 | for row in csv_reader:
31 | index_lookup[row['mid']] = row['index']
32 | line_count += 1
33 | return index_lookup
34 |
35 | def make_name_dict(label_csv):
36 | name_lookup = {}
37 | with open(label_csv, 'r') as f:
38 | csv_reader = csv.DictReader(f)
39 | line_count = 0
40 | for row in csv_reader:
41 | name_lookup[row['index']] = row['display_name']
42 | line_count += 1
43 | return name_lookup
44 |
45 | def lookup_list(index_list, label_csv):
46 | label_list = []
47 | table = make_name_dict(label_csv)
48 | for item in index_list:
49 | label_list.append(table[item])
50 | return label_list
51 |
52 | def preemphasis(signal,coeff=0.97):
53 | return np.append(signal[0],signal[1:]-coeff*signal[:-1])
54 |
55 | class AudiosetDataset(Dataset):
56 | def __init__(self, dataset_json_file, audio_conf, label_csv=None):
57 | self.datapath = dataset_json_file
58 | with open(dataset_json_file, 'r') as fp:
59 | data_json = json.load(fp)
60 |
61 | self.data = data_json['data']
62 | self.data = self.pro_data(self.data)
63 | print('Dataset has {:d} samples'.format(self.data.shape[0]))
64 | self.num_samples = self.data.shape[0]
65 | self.audio_conf = audio_conf
66 | self.label_smooth = self.audio_conf.get('label_smooth', 0.0)
67 | print('Using Label Smoothing: ' + str(self.label_smooth))
68 | self.freqm = self.audio_conf.get('freqm', 0)
69 | self.timem = self.audio_conf.get('timem', 0)
70 | print('Using Following Mask: {:d} Freq, {:d} Time'.format(self.audio_conf.get('freqm'), self.audio_conf.get('timem')))
71 | self.mixup = self.audio_conf.get('mixup', 0)
72 | print('Using Mix-up with Rate {:f}'.format(self.mixup))
73 | self.dataset = self.audio_conf.get('dataset')
74 | print('Now Process ' + self.dataset)
75 |
76 | self.index_dict = make_index_dict(label_csv)
77 | self.label_num = len(self.index_dict)
78 | print('Number of Classes is {:d}'.format(self.label_num))
79 |
80 | self.tar_path= self.audio_conf.get('tar_path')
81 | print('Now load features from {:s}'.format(self.tar_path))
82 |
83 | # change python list to numpy array to avoid memory leak.
84 | def pro_data(self, data_json):
85 | for i in range(len(data_json)):
86 | data_json[i] = [data_json[i]['wav'], data_json[i]['labels']]
87 | data_np = np.array(data_json, dtype=str)
88 | return data_np
89 |
90 | # reformat numpy data to original json format, make it compatible with old code
91 | def decode_data(self, np_data):
92 | datum = {}
93 | datum['wav'] = np_data[0]
94 | datum['labels'] = np_data[1]
95 | return datum
96 |
97 | def load_rep(self, path):
98 | try:
99 | # if npy file
100 | if path[-3:] == 'npy':
101 | return np.load(path)
102 | elif path[-3:] == 'npz':
103 | return np.load(path)['arr_0']
104 | except:
105 | print('a missing file', path)
106 | return np.zeros((6, 25, 512)) # fallback only valid for the whisper-base model, which has a few missing feature files
107 |
108 | def _wav2fbank(self, filename, filename2=None, mix_lambda=-1):
109 | if 'feat_as' in self.tar_path or 'feat_esc_pool' in self.tar_path:
110 | fmt = '.npz'
111 | else:
112 | fmt = '.npy'
113 |
114 | tar_path = self.tar_path + '/'
115 | if filename2 == None:
116 | filename = tar_path + '.'.join(filename.split('/')[-1].split('.')[:-1]) + fmt
117 | feat = self.load_rep(filename)
118 | feat = torch.Tensor(feat)
119 |
120 | # 25 is the time length after pooling
121 | if feat.shape[1] < 25:
122 | len_diff = 25 - feat.shape[1]
123 | feat = torch.nn.functional.pad(feat, (0, 0, 0, len_diff))
124 | else:
125 | feat = feat[:, :25, :]
126 |
127 | else:
128 | filename = tar_path + '.'.join(filename.split('/')[-1].split('.')[:-1]) + fmt
129 | feat = self.load_rep(filename)
130 | feat = torch.Tensor(feat)
131 |
132 | filename2 = tar_path + '.'.join(filename2.split('/')[-1].split('.')[:-1]) + fmt
133 | feat2 = self.load_rep(filename2)
134 | feat2 = torch.Tensor(feat2)
135 |
136 | if feat.shape[1] < 25:
137 | len_diff = 25 - feat.shape[1]
138 | feat = torch.nn.functional.pad(feat, (0, 0, 0, len_diff))
139 | else:
140 | feat = feat[:, :25, :]
141 | if feat2.shape[1] < 25:
142 | len_diff = 25 - feat2.shape[1]
143 | feat2 = torch.nn.functional.pad(feat2, (0, 0, 0, len_diff))
144 | else:
145 | feat2 = feat2[:, :25, :]
146 | feat = mix_lambda * feat + (1 - mix_lambda) * feat2
147 |
148 | return feat
149 |
150 | def __getitem__(self, index):
151 | if random.random() < self.mixup:
152 | datum = self.data[index]
153 | datum = self.decode_data(datum)
154 | mix_sample_idx = random.randint(0, self.num_samples-1)
155 | mix_datum = self.data[mix_sample_idx]
156 | mix_datum = self.decode_data(mix_datum)
157 | # get the mixed fbank
158 | mix_lambda = np.random.beta(10, 10)
159 | fbank = self._wav2fbank(datum['wav'], mix_datum['wav'], mix_lambda)
160 | label_indices = np.zeros(self.label_num) + (self.label_smooth / self.label_num)
161 | for label_str in datum['labels'].split(','):
162 | label_indices[int(self.index_dict[label_str])] += mix_lambda * (1.0 - self.label_smooth)
163 | for label_str in mix_datum['labels'].split(','):
164 | label_indices[int(self.index_dict[label_str])] += (1.0 - mix_lambda) * (1.0 - self.label_smooth)
165 | label_indices = torch.FloatTensor(label_indices)
166 |
167 | else:
168 | datum = self.data[index]
169 | datum = self.decode_data(datum)
170 | # label smooth for negative samples, epsilon/label_num
171 | label_indices = np.zeros(self.label_num) + (self.label_smooth / self.label_num)
172 | fbank = self._wav2fbank(datum['wav'], None, 0)
173 | for label_str in datum['labels'].split(','):
174 | label_indices[int(self.index_dict[label_str])] = 1.0 - self.label_smooth
175 | label_indices = torch.FloatTensor(label_indices)
176 |
177 | # SpecAug (not applied to the eval set, where freqm/timem are 0); input feat is in [25, 1280], i.e. time-freq, so transpose before masking
178 | freqm = torchaudio.transforms.FrequencyMasking(self.freqm)
179 | timem = torchaudio.transforms.TimeMasking(self.timem)
180 | fbank = fbank.transpose(1, 2)
181 | if self.freqm != 0:
182 | fbank = freqm(fbank)
183 | if self.timem != 0:
184 | fbank = timem(fbank)
185 | fbank = fbank.transpose(1, 2)
186 | return fbank, label_indices
187 |
188 | def __len__(self):
189 | return self.num_samples
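190 |
191 | # Illustrative sketch (not part of the original file; paths and audio_conf values are placeholders):
192 | #   audio_conf = {'label_smooth': 0.1, 'freqm': 0, 'timem': 0, 'mixup': 0.5,
193 | #                 'dataset': 'audioset', 'tar_path': '/path/to/precomputed_whisper_feats'}
194 | #   dataset = AudiosetDataset('datafiles/whole_train_data.json', audio_conf, label_csv='class_labels_indices.csv')
195 | #   fbank, label_indices = dataset[0]  # fbank: per-layer Whisper features, e.g. (n_layer, 25, rep_dim); label_indices: (527,) label vector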
--------------------------------------------------------------------------------
/src/whisper_at_train/gen_weight_file.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 11/17/20 3:22 PM
3 | # @Author : Yuan Gong
4 | # @Affiliation : Massachusetts Institute of Technology
5 | # @Email : yuangong@mit.edu
6 | # @File : gen_weight_file.py
7 |
8 | # gen sample weight = sum(label_weight) for label in all labels of the audio clip, where label_weight is the reciprocal of the total sample count of that class.
9 | # Note audioset is a multi-label dataset
10 |
11 | import argparse
12 | import json
13 | import numpy as np
14 | import sys, os, csv
15 |
16 | def make_index_dict(label_csv):
17 | index_lookup = {}
18 | with open(label_csv, 'r') as f:
19 | csv_reader = csv.DictReader(f)
20 | line_count = 0
21 | for row in csv_reader:
22 | index_lookup[row['mid']] = row['index']
23 | line_count += 1
24 | return index_lookup
25 |
26 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
27 | parser.add_argument("--data_path", type=str, default='path_to_your_data_file.json', help="the root path of data json file")
28 |
29 | if __name__ == '__main__':
30 | args = parser.parse_args()
31 | data_path = args.data_path
32 |
33 | index_dict = make_index_dict('./class_labels_indices.csv')
34 | label_count = np.zeros(527)
35 |
36 | with open(data_path, 'r', encoding='utf8')as fp:
37 | data = json.load(fp)
38 | data = data['data']
39 |
40 | for sample in data:
41 | sample_labels = sample['labels'].split(',')
42 | for label in sample_labels:
43 | label_idx = int(index_dict[label])
44 | label_count[label_idx] = label_count[label_idx] + 1
45 |
46 | # use 1000.0 rather than 1.0 in the numerator to avoid underflow for majority classes; the small additive constant avoids division by zero for empty classes
47 | label_weight = 1000.0 / (label_count + 0.01)
48 | #label_weight = 1000.0 / (label_count + 0.00)
49 | sample_weight = np.zeros(len(data))
50 |
51 | for i, sample in enumerate(data):
52 | sample_labels = sample['labels'].split(',')
53 | for label in sample_labels:
54 | label_idx = int(index_dict[label])
55 | # sum the weights of all classes present in the sample; note AudioSet is multi-label classification
56 | sample_weight[i] += label_weight[label_idx]
57 | np.savetxt(data_path[:-5]+'_weight.csv', sample_weight, delimiter=',')
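58 |
59 | # Illustrative follow-up sketch (an assumption about downstream use; run.py imports WeightedRandomSampler):
60 | #   from torch.utils.data import WeightedRandomSampler
61 | #   sample_weight = np.loadtxt(data_path[:-5] + '_weight.csv', delimiter=',')
62 | #   sampler = WeightedRandomSampler(sample_weight, len(sample_weight), replacement=True)
63 | #   # then pass sampler=sampler to the training DataLoader for class-balanced sampling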
--------------------------------------------------------------------------------
/src/whisper_at_train/models.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 1/26/23 11:13 PM
3 | # @Author : Yuan Gong
4 | # @Affiliation : Massachusetts Institute of Technology
5 | # @Email : yuangong@mit.edu
6 | # @File : high_mdls.py
7 |
8 | # high models
9 |
10 | import numpy as np
11 | import torch
12 | import math
13 | import torch.nn.functional as F
14 | from torch import Tensor
15 | from torch import nn
16 | from whisper.model import ResidualAttentionBlock, Linear
17 |
18 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
19 | def norm_cdf(x):
20 | # Computes standard normal cumulative distribution function
21 | return (1. + math.erf(x / math.sqrt(2.))) / 2.
22 |
23 | with torch.no_grad():
24 | # Values are generated by using a truncated uniform distribution and
25 | # then using the inverse CDF for the normal distribution.
26 | # Get upper and lower cdf values
27 | l = norm_cdf((a - mean) / std)
28 | u = norm_cdf((b - mean) / std)
29 |
30 | # Uniformly fill tensor with values from [l, u], then translate to
31 | # [2l-1, 2u-1].
32 | tensor.uniform_(2 * l - 1, 2 * u - 1)
33 |
34 | # Use inverse cdf transform for normal distribution to get truncated
35 | # standard normal
36 | tensor.erfinv_()
37 |
38 | # Transform to proper mean, std
39 | tensor.mul_(std * math.sqrt(2.))
40 | tensor.add_(mean)
41 |
42 | # Clamp to ensure it's in the proper range
43 | tensor.clamp_(min=a, max=b)
44 | return tensor
45 |
46 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
47 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
48 |
49 | class TLTR(nn.Module):
50 | def __init__(self, label_dim=527, n_layer=33, rep_dim=1280, mode='basic'):
51 | super().__init__()
52 | self.mode = mode
53 | self.n_layer = n_layer
54 | self.rep_dim = rep_dim
55 | self.label_dim = label_dim
56 |
57 | # (baseline) mean pool over time and layer, and mlp head
58 | if mode == 'mean_mlp' or mode == 'last_mlp':
59 | self.mlp_layer = nn.Sequential(nn.LayerNorm(self.rep_dim), nn.Linear(self.rep_dim, self.label_dim))
60 |
61 | # (baseline) mean pool over time, and weight average over layers, and mlp head
62 | if mode == 'wa_mlp':
63 | self.mlp_layer = nn.Sequential(nn.LayerNorm(self.rep_dim), nn.Linear(self.rep_dim, self.label_dim))
64 | self.layer_weight = torch.nn.Parameter(torch.tensor([1 / self.n_layer] * self.n_layer))
65 |
66 | # (baseline) mean pool over layers (or take the last layer), and apply a transformer at the original rep_dim
67 | if 'mean_tr' in mode or 'last_tr' in mode:
68 | self.num_att_head = int(mode.split('_')[-1])
69 | self.time_tr = ResidualAttentionBlock(self.rep_dim, self.num_att_head)
70 | self.mlp_layer = nn.Sequential(nn.LayerNorm(self.rep_dim), nn.Linear(self.rep_dim, self.label_dim))
71 |
72 | # (baseline) weighted average over layers, and apply a transformer at the original rep_dim
73 | if 'wa_tr' in mode:
74 | self.num_att_head = int(mode.split('_')[-1])
75 | self.layer_weight = torch.nn.Parameter(torch.tensor([1 / self.n_layer] * self.n_layer))
76 | self.time_tr = ResidualAttentionBlock(self.rep_dim, self.num_att_head)
77 | self.mlp_layer = nn.Sequential(nn.LayerNorm(self.rep_dim), nn.Linear(self.rep_dim, self.label_dim))
78 |
79 | # (baseline) weight average over layers, and apply a low-dimensional transformer
80 | if 'wa_down_tr' in mode: # 512_1
81 | self.inter_rep_dim = int(mode.split('_')[-2])
82 | self.num_att_head = int(mode.split('_')[-1])
83 |
84 | self.down_layer = nn.Sequential(nn.LayerNorm(self.rep_dim), nn.Linear(self.rep_dim, self.inter_rep_dim))
85 | self.layer_weight = torch.nn.Parameter(torch.tensor([1 / self.n_layer] * self.n_layer))
86 | self.time_tr = ResidualAttentionBlock(self.inter_rep_dim, self.num_att_head)
87 | self.mlp_layer = nn.Sequential(nn.LayerNorm(self.inter_rep_dim), nn.Linear(self.inter_rep_dim, self.label_dim))
88 |
89 | # (proposed) tl-tr: a time transformer per layer followed by a layer transformer, at the original rep_dim
90 | if 'lw_tr' in mode:
91 | self.num_tatt_head = int(mode.split('_')[-2])
92 | self.num_latt_head = int(mode.split('_')[-1])
93 | self.time_tr = ResidualAttentionBlock(self.rep_dim, self.num_tatt_head)
94 | self.layer_tr = ResidualAttentionBlock(self.rep_dim, self.num_latt_head)
95 | self.mlp_layer = nn.Sequential(nn.LayerNorm(self.rep_dim), nn.Linear(self.rep_dim, self.label_dim))
96 |
97 | # (proposed) tl-tr with a low-dimensional projection to lower the transformer dimension, e.g. mode 'lw_down_tr_512_1_8'
98 | if 'lw_down_tr' in mode:
99 | self.inter_rep_dim = int(mode.split('_')[-3])
100 | self.num_tatt_head = int(mode.split('_')[-2])
101 | self.num_latt_head = int(mode.split('_')[-1])
102 |
103 | self.down_layer = nn.Sequential(nn.LayerNorm(self.rep_dim), nn.Linear(self.rep_dim, self.inter_rep_dim))
104 | self.time_tr = ResidualAttentionBlock(self.inter_rep_dim, self.num_tatt_head)
105 | self.layer_tr = ResidualAttentionBlock(self.inter_rep_dim, self.num_latt_head)
106 | self.mlp_layer = nn.Sequential(nn.LayerNorm(self.inter_rep_dim), nn.Linear(self.inter_rep_dim, self.label_dim))
107 |
108 | def forward(self, audio_rep):
109 | # audio_rep in shape (# batch size, #whisper_enc_layer, time length after (20x) pooling, whisper_enc_dim)
110 | # e.g., (B, 32, 25, 1280) for whisper large-v1
111 |
112 | # (baseline)
113 | if self.mode == 'mean_mlp':
114 | audio_rep = torch.mean(audio_rep, dim=1)
115 | audio_rep = torch.mean(audio_rep, dim=1)
116 | audio_rep = self.mlp_layer(audio_rep)
117 | return audio_rep
118 |
119 | # (baseline)
120 | elif self.mode == 'last_mlp':
121 | audio_rep = audio_rep[:, -1, :, :] # get the last layer
122 | audio_rep = torch.mean(audio_rep, dim=1)
123 | audio_rep = self.mlp_layer(audio_rep)
124 | return audio_rep
125 |
126 | # (baseline)
127 | elif self.mode == 'wa_mlp':
128 | audio_rep = torch.mean(audio_rep, dim=2) # [B, 32, 1280]
129 | audio_rep = torch.permute(audio_rep, (0, 2, 1)) # (B, 1280, 32)
130 | audio_rep = (audio_rep @ self.layer_weight) / self.layer_weight.sum()
131 | audio_rep = self.mlp_layer(audio_rep)
132 | return audio_rep
133 |
134 | # (baseline)
135 | elif 'mean_tr' in self.mode:
136 | audio_rep = torch.mean(audio_rep, dim=1) # [B, 25, 1280]
137 | audio_rep = self.time_tr(audio_rep) # [B, 25, 1280]
138 | audio_rep = torch.mean(audio_rep, dim=1) # [B, 1280]
139 | audio_rep = self.mlp_layer(audio_rep)
140 | return audio_rep
141 |
142 | # (baseline) time transformer on the last layer representation
143 | elif 'last_tr' in self.mode:
144 | audio_rep = audio_rep[:, -1, :, :] # [B, 25, 1280]
145 | audio_rep = self.time_tr(audio_rep) # [B, 25, 1280]
146 | audio_rep = torch.mean(audio_rep, dim=1) # [B, 1280]
147 | audio_rep = self.mlp_layer(audio_rep)
148 | return audio_rep
149 |
150 | # (baseline) time transformer on the layer-wise weight-average representation
151 | elif 'wa_tr' in self.mode:
152 | audio_rep = torch.permute(audio_rep, (0, 2, 3, 1)) # (B, 25, 1280, 32)
153 | audio_rep = (audio_rep @ self.layer_weight) / self.layer_weight.sum() # [B, 25, 1280]
154 | audio_rep = self.time_tr(audio_rep) # [B, 25, 1280]
155 | audio_rep = torch.mean(audio_rep, dim=1) # [B, 1280]
156 | audio_rep = self.mlp_layer(audio_rep)
157 | return audio_rep
158 |
159 | # (baseline) weight average with low-dimension projection
160 | elif 'wa_down_tr' in self.mode:
161 | audio_rep = torch.permute(audio_rep, (0, 2, 3, 1)) # (B, 25, 1280, 32)
162 | audio_rep = (audio_rep @ self.layer_weight) / self.layer_weight.sum() # [B, 25, 1280]
163 | audio_rep = self.down_layer(audio_rep) # [B, 25, inter_rep_dim]
164 | audio_rep = self.time_tr(audio_rep) # [B, 25, inter_rep_dim]
165 | audio_rep = torch.mean(audio_rep, dim=1) # [B, inter_rep_dim]
166 | audio_rep = self.mlp_layer(audio_rep)
167 | return audio_rep
168 |
169 | # (proposed) tl-tr
170 | elif 'lw_tr' in self.mode:
171 | B = audio_rep.shape[0]
172 | audio_rep = audio_rep.reshape(B*self.n_layer, audio_rep.shape[2], audio_rep.shape[3]) # [B*32, 25, 1280]
173 | audio_rep = self.time_tr(audio_rep) # [B*32, 25, 1280]
174 | audio_rep = torch.mean(audio_rep, dim=1) # [B*32, 1280]
175 | audio_rep = audio_rep.reshape(B, self.n_layer, audio_rep.shape[1]) # [B, 32, 1280]
176 | audio_rep = self.layer_tr(audio_rep) # [B, 32, 1280]
177 | audio_rep = torch.mean(audio_rep, dim=1) # [B, 1280]
178 | audio_rep = self.mlp_layer(audio_rep)
179 | return audio_rep
180 |
181 | # (proposed) tl-tr with low-dimensional projection
182 | elif 'lw_down_tr' in self.mode:
183 | B = audio_rep.shape[0]
184 | audio_rep = self.down_layer(audio_rep) # [B, 32, 25, inter_rep_dim]
185 | audio_rep = audio_rep.reshape(B*self.n_layer, audio_rep.shape[2], audio_rep.shape[3]) # [B*32, 25, inter_rep_dim]
186 | audio_rep = self.time_tr(audio_rep) # [B*32, 25, inter_rep_dim]
187 | audio_rep = torch.mean(audio_rep, dim=1) # [B*32, inter_rep_dim]
188 | audio_rep = audio_rep.reshape(B, self.n_layer, audio_rep.shape[1]) # [B, 32, inter_rep_dim]
189 | audio_rep = self.layer_tr(audio_rep) # [B, 32, inter_rep_dim]
190 | audio_rep = torch.mean(audio_rep, dim=1) # [B, inter_rep_dim]
191 | audio_rep = self.mlp_layer(audio_rep)
192 | return audio_rep
193 |
194 |
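
A minimal smoke test of the TLTR module above can help verify feature shapes; this is only a sketch, assuming the constructor interface used in run.py (label_dim, n_layer, rep_dim, mode) and the 'lw_tr_1_8' mode string from run_as_full_train.sh. The dummy input follows the forward() docstring, i.e., (batch, # whisper encoder layers, pooled time, rep_dim) = (B, 32, 25, 1280) for whisper large.

if __name__ == '__main__':
    import torch
    tltr = TLTR(label_dim=527, n_layer=32, rep_dim=1280, mode='lw_tr_1_8')
    dummy = torch.rand(2, 32, 25, 1280)  # (batch, whisper encoder layers, pooled time, rep_dim)
    logits = tltr(dummy)                 # raw logits for BCEWithLogitsLoss, expected shape (2, 527)
    print(logits.shape)
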
--------------------------------------------------------------------------------
/src/whisper_at_train/run.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 6/11/21 12:57 AM
3 | # @Author : Yuan Gong
4 | # @Affiliation : Massachusetts Institute of Technology
5 | # @Email : yuangong@mit.edu
6 | # @File : run.py
7 |
8 | import argparse
9 | import os
10 | os.environ['MPLCONFIGDIR'] = './plt/'
11 | os.environ['TRANSFORMERS_CACHE'] = './tr/'
12 | import ast
13 | import pickle
14 | import sys
15 | import time
16 | import torch
17 | from torch.utils.data import WeightedRandomSampler
18 | basepath = os.path.dirname(os.path.dirname(sys.path[0]))
19 | sys.path.append(basepath)
20 | import dataloader_feat as dataloader
21 | import numpy as np
22 | from traintest import train, validate
23 | from models import TLTR
24 |
25 | print("I am process %s, running on %s: starting (%s)" % (os.getpid(), os.uname()[1], time.asctime()))
26 |
27 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
28 | parser.add_argument("--data-train", type=str, default='', help="training data json")
29 | parser.add_argument("--data-val", type=str, default='', help="validation data json")
30 | parser.add_argument("--data-eval", type=str, default=None, help="evaluation data json")
31 | parser.add_argument("--label-csv", type=str, default='', help="csv with class labels")
32 | parser.add_argument("--n_class", type=int, default=527, help="number of classes")
33 | parser.add_argument("--model", type=str, default='ast', help="the model used")
34 | parser.add_argument("--dataset", type=str, default="audioset", help="the dataset used")
35 |
36 | parser.add_argument("--exp-dir", type=str, default="", help="directory to dump experiments")
37 | parser.add_argument('--lr', '--learning-rate', default=0.001, type=float, metavar='LR', help='initial learning rate')
38 | parser.add_argument("--optim", type=str, default="adam", help="training optimizer", choices=["sgd", "adam"])
39 | parser.add_argument('-b', '--batch-size', default=12, type=int, metavar='N', help='mini-batch size')
40 | parser.add_argument('-w', '--num-workers', default=8, type=int, metavar='NW', help='# of workers for dataloading (default: 8)')
41 | parser.add_argument("--n-epochs", type=int, default=1, help="number of maximum training epochs")
42 | # not used in the formal experiments
43 | parser.add_argument("--lr_patience", type=int, default=1, help="how many epoch to wait to reduce lr if mAP doesn't improve")
44 | parser.add_argument("--lr_adapt", help='if use adaptive learning rate', type=ast.literal_eval)
45 | parser.add_argument("--metrics", type=str, default="mAP", help="the main evaluation metrics in finetuning", choices=["mAP", "acc"])
46 | parser.add_argument("--loss", type=str, default="BCE", help="the loss function for finetuning, depend on the task", choices=["BCE", "CE"])
47 | parser.add_argument('--warmup', help='if use warmup learning rate scheduler', type=ast.literal_eval, default='True')
48 | parser.add_argument("--lrscheduler_start", default=10, type=int, help="when to start decay in finetuning")
49 | parser.add_argument("--lrscheduler_step", default=5, type=int, help="the number of step to decrease the learning rate in finetuning")
50 | parser.add_argument("--lrscheduler_decay", default=0.5, type=float, help="the learning rate decay ratio in finetuning")
51 |
52 | parser.add_argument("--wa", help='if do weight averaging in finetuning', type=ast.literal_eval)
53 | parser.add_argument("--wa_start", type=int, default=16, help="which epoch to start weight averaging in finetuning")
54 | parser.add_argument("--wa_end", type=int, default=30, help="which epoch to end weight averaging in finetuning")
55 |
56 | parser.add_argument("--n-print-steps", type=int, default=100, help="number of steps to print statistics")
57 | parser.add_argument('--save_model', help='save the model or not', type=ast.literal_eval)
58 |
59 | parser.add_argument('--freqm', help='frequency mask max length', type=int, default=0)
60 | parser.add_argument('--timem', help='time mask max length', type=int, default=0)
61 | parser.add_argument("--mixup", type=float, default=0, help="how many (0-1) samples need to be mixup during training")
62 | parser.add_argument("--bal", type=str, default=None, help="use balanced sampling or not")
63 | parser.add_argument("--model_size", type=str, default='medium.en', help="The model size")
64 | parser.add_argument("--label_smooth", type=float, default=0.0, help="label smoothing factor")
65 | parser.add_argument("--weight_file", type=str, default=None, help="path to weight file")
66 | parser.add_argument("--ftmode", type=str, default='last', help="pretrained model path")
67 | parser.add_argument("--pretrain_epoch", type=int, default=0, help="number of pretrained epochs")
68 | parser.add_argument("--head_lr", type=float, default=1.0, help="learning rate ratio between mlp/base")
69 | args = parser.parse_args()
70 |
71 | if args.dataset == 'esc':
72 | if args.model_size == 'hubert-xlarge-ls960-ft' or args.model_size == 'wav2vec2-large-robust-ft-swbd-300h':
73 | train_tar_path = '/data/sls/scratch/yuangong/whisper-a/feat_esc_pool/' + args.model_size
74 | eval_tar_path = '/data/sls/scratch/yuangong/whisper-a/feat_esc_pool/' + args.model_size
75 | else:
76 | train_tar_path = '/data/sls/scratch/yuangong/whisper-a/feat_esc_pool/whisper_' + args.model_size
77 | eval_tar_path = '/data/sls/scratch/yuangong/whisper-a/feat_esc_pool/whisper_' + args.model_size
78 | shuffle = True
79 | elif args.dataset == 'as-bal' or args.dataset == 'as-full':
80 | if args.model_size == 'hubert-xlarge-ls960-ft' or args.model_size == 'wav2vec2-large-robust-ft-swbd-300h':
81 | train_tar_path = '/data/sls/scratch/yuangong/whisper-a/feat_as_full/' + args.model_size
82 | eval_tar_path = '/data/sls/scratch/yuangong/whisper-a/feat_as_full/' + args.model_size
83 | else:
84 | train_tar_path = '/data/sls/scratch/yuangong/whisper-a/feat_as_full/whisper_' + args.model_size
85 | eval_tar_path = '/data/sls/scratch/yuangong/whisper-a/feat_as_eval/whisper_' + args.model_size
86 | shuffle = True
87 |
88 | audio_conf = {'freqm': args.freqm, 'timem': args.timem, 'mixup': args.mixup, 'dataset': args.dataset, 'label_smooth': args.label_smooth, 'tar_path': train_tar_path}
89 | val_audio_conf = {'freqm': 0, 'timem': 0, 'mixup': 0, 'dataset': args.dataset, 'tar_path': eval_tar_path}
90 |
91 | if args.bal == 'bal':
92 | print('balanced sampler is being used')
93 | if args.weight_file == None:
94 | samples_weight = np.loadtxt(args.data_train[:-5]+'_weight.csv', delimiter=',')
95 | else:
96 | samples_weight = np.loadtxt(args.data_train[:-5] + '_' + args.weight_file + '.csv', delimiter=',')
97 | sampler = WeightedRandomSampler(samples_weight, len(samples_weight), replacement=True)
98 |
99 | train_loader = torch.utils.data.DataLoader(
100 | dataloader.AudiosetDataset(args.data_train, label_csv=args.label_csv, audio_conf=audio_conf),
101 | batch_size=args.batch_size, sampler=sampler, num_workers=args.num_workers, pin_memory=True, drop_last=True)
102 | else:
103 | print('balanced sampler is not used')
104 | train_loader = torch.utils.data.DataLoader(
105 | dataloader.AudiosetDataset(args.data_train, label_csv=args.label_csv, audio_conf=audio_conf),
106 | batch_size=args.batch_size, shuffle=shuffle, num_workers=args.num_workers, pin_memory=True, drop_last=True)
107 |
108 | val_loader = torch.utils.data.DataLoader(
109 | dataloader.AudiosetDataset(args.data_val, label_csv=args.label_csv, audio_conf=val_audio_conf),
110 | batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True, drop_last=True)
111 |
112 | if args.data_eval != None:
113 | eval_loader = torch.utils.data.DataLoader(
114 | dataloader.AudiosetDataset(args.data_eval, label_csv=args.label_csv, audio_conf=val_audio_conf),
115 | batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True)
116 |
117 | def get_feat_shape(path, args):
118 | mdl_size = args.model_size
119 | n_rep_dim_dict = {'tiny.en': 384, 'tiny': 384, 'base.en': 512, 'base': 512, 'small.en': 768, 'small': 768, 'medium.en': 1024, 'medium': 1024, 'large-v1': 1280, 'large-v2': 1280, 'wav2vec2-large-robust-ft-swbd-300h': 1024, 'hubert-xlarge-ls960-ft': 1280}
120 | n_layer_dict = {'tiny.en': 4, 'tiny': 4, 'base.en': 6, 'base': 6, 'small.en': 12, 'small': 12, 'medium.en': 24, 'medium': 24, 'large-v1': 32, 'large-v2': 32, 'wav2vec2-large-robust-ft-swbd-300h': 24, 'hubert-xlarge-ls960-ft': 48}
121 | return n_layer_dict[mdl_size], n_rep_dim_dict[mdl_size]
122 |
123 | if 'whisper-high' in args.model:
124 | mode = args.model.split('-')[-1]
125 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
126 | n_layer, rep_dim = get_feat_shape(train_tar_path, args)
127 | print(mode, args.model_size, n_layer, rep_dim)
128 | audio_model = TLTR(label_dim=args.n_class, n_layer=n_layer, rep_dim=rep_dim, mode=mode)
129 | else:
130 | raise ValueError('model not supported')
131 |
132 | # use data parallel
133 | if not isinstance(audio_model, torch.nn.DataParallel):
134 | audio_model = torch.nn.DataParallel(audio_model)
135 |
136 | print("\nCreating experiment directory: %s" % args.exp_dir)
137 | try:
138 | os.makedirs("%s/models" % args.exp_dir)
139 | except:
140 | pass
141 | with open("%s/args.pkl" % args.exp_dir, "wb") as f:
142 | pickle.dump(args, f)
143 |
144 | code_path = args.exp_dir + '/src/'
145 | if os.path.exists(code_path) == False:
146 | os.mkdir(code_path)
147 | copy_path = '/data/sls/scratch/yuangong/whisper-a/src/'
148 | os.system('cp ' + copy_path + '/*.sh ' + code_path)
149 | os.system('cp ' + copy_path + '/*.py ' + code_path)
150 |
151 | print('Now starting training for {:d} epochs'.format(args.n_epochs))
152 | train(audio_model, train_loader, val_loader, args)
153 |
154 |
155 | def wa_model(exp_dir, start_epoch=16, end_epoch=30):
156 | sdA = torch.load(exp_dir + '/models/audio_model.' + str(start_epoch) + '.pth', map_location=device)
157 | model_cnt = 1
158 | for epoch in range(start_epoch+1, end_epoch+1):
159 | if os.path.exists(exp_dir + '/models/audio_model.' + str(epoch) + '.pth') == True:
160 | sdB = torch.load(exp_dir + '/models/audio_model.' + str(epoch) + '.pth', map_location=device)
161 | for key in sdA:
162 | sdA[key] = sdA[key] + sdB[key]
163 | model_cnt += 1
164 | print('wa {:d} models from {:d} to {:d}'.format(model_cnt, start_epoch, end_epoch))
165 | # averaging
166 | for key in sdA:
167 | sdA[key] = sdA[key] / float(model_cnt)
168 | torch.save(sdA, exp_dir + '/models/audio_model_wa.pth')
169 | return sdA
170 |
171 | # do model weight averaging
172 | if args.wa == True:
173 | sdA = wa_model(args.exp_dir, args.wa_start, args.wa_end)
174 | msg = audio_model.load_state_dict(sdA, strict=True)
175 | print(msg)
176 | audio_model.eval()
177 | stats, _ = validate(audio_model, val_loader, args)
178 | wa_res = np.mean([stat['AP'] for stat in stats])
179 | print('mAP of model with weights averaged from checkpoint {:d}-{:d} is {:.4f}'.format(args.wa_start, args.wa_end, wa_res))
180 | np.savetxt(args.exp_dir + '/wa_res.csv', [args.wa_start, args.wa_end, wa_res], delimiter=',')
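
The weight-averaged checkpoint saved above keeps the nn.DataParallel 'module.' prefix in its keys (the same prefix that utilities/rename_state_dict.py maps to 'at_model.'). Below is a sketch of loading audio_model_wa.pth back into a bare TLTR for offline evaluation; the experiment path is hypothetical, and the n_layer/rep_dim/mode arguments must match the training configuration.

import torch
from models import TLTR

sd = torch.load('./exp/some_exp/models/audio_model_wa.pth', map_location='cpu')
sd = {k.replace('module.', '', 1): v for k, v in sd.items()}  # strip the DataParallel prefix
tltr = TLTR(label_dim=527, n_layer=32, rep_dim=1280, mode='lw_tr_1_8')
tltr.load_state_dict(sd, strict=True)
tltr.eval()
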
--------------------------------------------------------------------------------
/src/whisper_at_train/run_as_full_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH -p a5
3 | #SBATCH --gres=gpu:1
4 | #SBATCH -c 16
5 | #SBATCH --qos regular
6 | #SBATCH --mem=48000
7 | #SBATCH --job-name="w-as-high"
8 | #SBATCH --output=./log/%j_as.txt
9 |
10 | set -x
11 | # comment this line if not running on sls cluster
12 | . /data/sls/scratch/share-201907/slstoolchainrc
13 | source /data/sls/scratch/yuangong/whisper-a/venv-a5/bin/activate
14 | export TORCH_HOME=../../pretrained_models
15 |
16 | lr=5e-5
17 | freqm=0
18 | timem=10
19 | mixup=0.5
20 | batch_size=48
21 | model=whisper-high-lw_tr_1_8 #whisper-high-lw_tr_1_8 (tl-tr, lr=5e-5) whisper-high-lw_down_tr_512_1_8 (tl-tr-512, w/ low-dim proj, lr=1e-4)
22 | model_size=large-v2
23 |
24 | dataset=as-full
25 | bal=bal
26 | epoch=30
27 | lrscheduler_start=15
28 | lrscheduler_decay=0.75
29 | lrscheduler_step=5
30 | wa=True
31 | wa_start=16
32 | wa_end=30
33 | lr_adapt=False
34 | tr_data=/data/sls/scratch/yuangong/aed-pc/src/enhance_label/datafiles_local/whole_train_data.json
35 | te_data=/data/sls/scratch/yuangong/aed-pc/src/enhance_label/datafiles_local/eval_data.json
36 | label_smooth=0.1
37 |
38 | exp_dir=./exp/test-${dataset}-${model}-${model_size}-${lr}-${lrscheduler_start}-${lrscheduler_decay}-bs${batch_size}-lda${lr_adapt}-mix${mixup}-${freqm}-${timem}
39 | mkdir -p $exp_dir
40 |
41 | python -W ignore ./run.py --model ${model} --dataset ${dataset} \
42 | --data-train ${tr_data} --data-val ${te_data} --exp-dir $exp_dir \
43 | --label-csv /data/sls/scratch/yuangong/convast/egs/audioset/data/class_labels_indices.csv --n_class 527 \
44 | --lr $lr --n-epochs ${epoch} --batch-size $batch_size --save_model True \
45 | --freqm $freqm --timem $timem --mixup ${mixup} --bal ${bal} \
46 | --model_size ${model_size} --label_smooth ${label_smooth} \
47 | --lrscheduler_start ${lrscheduler_start} --lrscheduler_decay ${lrscheduler_decay} --lrscheduler_step ${lrscheduler_step} \
48 | --loss BCE --metrics mAP --warmup True \
49 | --wa ${wa} --wa_start ${wa_start} --wa_end ${wa_end} --lr_adapt ${lr_adapt} \
50 | --num-workers 8
--------------------------------------------------------------------------------
/src/whisper_at_train/traintest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 6/10/21 11:00 PM
3 | # @Author : Yuan Gong
4 | # @Affiliation : Massachusetts Institute of Technology
5 | # @Email : yuangong@mit.edu
6 | # @File : traintest.py
7 |
8 | import sys
9 | import os
10 | import datetime
11 | sys.path.append(os.path.dirname(os.path.dirname(sys.path[0])))
12 | from utilities import *
13 | import time
14 | import torch
15 | from torch import nn
16 | import numpy as np
17 | import pickle
18 | from torch.cuda.amp import autocast,GradScaler
19 |
20 | def train(audio_model, train_loader, test_loader, args):
21 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22 | print('running on ' + str(device))
23 | torch.set_grad_enabled(True)
24 |
25 | # Initialize all of the statistics we want to keep track of
26 | batch_time = AverageMeter()
27 | per_sample_time = AverageMeter()
28 | data_time = AverageMeter()
29 | per_sample_data_time = AverageMeter()
30 | loss_meter = AverageMeter()
31 | per_sample_dnn_time = AverageMeter()
32 | progress = []
33 | # best_cum_mAP is checkpoint ensemble from the first epoch to the best epoch
34 | best_epoch, best_cum_epoch, best_mAP, best_acc, best_cum_mAP = 0, 0, -np.inf, -np.inf, -np.inf
35 | global_step, epoch = 0, 0
36 | start_time = time.time()
37 | exp_dir = args.exp_dir
38 |
39 | def _save_progress():
40 | progress.append([epoch, global_step, best_epoch, best_mAP, time.time() - start_time])
41 | with open("%s/progress.pkl" % exp_dir, "wb") as f:
42 | pickle.dump(progress, f)
43 |
44 | if not isinstance(audio_model, nn.DataParallel):
45 | audio_model = nn.DataParallel(audio_model)
46 |
47 | audio_model = audio_model.to(device)
48 |
49 | trainables = [p for p in audio_model.parameters() if p.requires_grad]
50 | print('Total parameter number is : {:.3f} million'.format(sum(p.numel() for p in audio_model.parameters()) / 1e6))
51 | print('Total trainable parameter number is : {:.3f} million'.format(sum(p.numel() for p in trainables) / 1e6))
52 |
53 | optimizer = torch.optim.Adam(trainables, args.lr, weight_decay=5e-7, betas=(0.95, 0.999))
54 |
55 | if args.lr_adapt == True:
56 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=args.lr_patience, verbose=True)
57 | print('Override to use adaptive learning rate scheduler.')
58 | else:
59 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, list(range(args.lrscheduler_start, 1000, args.lrscheduler_step)),gamma=args.lrscheduler_decay)
60 | print('The learning rate scheduler starts at epoch {:d} with a decay rate of {:.3f} every {:d} epochs'.format(args.lrscheduler_start, args.lrscheduler_decay, args.lrscheduler_step))
61 | main_metrics = args.metrics
62 | if args.loss == 'BCE':
63 | loss_fn = nn.BCEWithLogitsLoss()
64 | elif args.loss == 'CE':
65 | loss_fn = nn.CrossEntropyLoss()
66 | args.loss_fn = loss_fn
67 |
68 | print('now training with {:s}, main metrics: {:s}, loss function: {:s}, learning rate scheduler: {:s}'.format(str(args.dataset), str(main_metrics), str(loss_fn), str(scheduler)))
69 |
70 | epoch += 1
71 | scaler = GradScaler()
72 |
73 | print("current #steps=%s, #epochs=%s" % (global_step, epoch))
74 | print("start training...")
75 | result = np.zeros([args.n_epochs, 4])
76 | audio_model.train()
77 | while epoch < args.n_epochs + 1:
78 | begin_time = time.time()
79 | end_time = time.time()
80 | audio_model.train()
81 | print('---------------')
82 | print(datetime.datetime.now())
83 | print("current #epochs=%s, #steps=%s" % (epoch, global_step))
84 |
85 | for i, (a_input, labels) in enumerate(train_loader):
86 |
87 | B = a_input.size(0)
88 | a_input = a_input.to(device, non_blocking=True)
89 | labels = labels.to(device, non_blocking=True)
90 |
91 | data_time.update(time.time() - end_time)
92 | per_sample_data_time.update((time.time() - end_time) / a_input.shape[0])
93 | dnn_start_time = time.time()
94 |
95 | with autocast():
96 | audio_output = audio_model(a_input)
97 | loss = loss_fn(audio_output, labels)
98 |
99 | # optimization step with AMP gradient scaling
100 | optimizer.zero_grad()
101 | scaler.scale(loss).backward()
102 | scaler.step(optimizer)
103 | scaler.update()
104 |
105 | # record loss
106 | loss_meter.update(loss.item(), B)
107 | batch_time.update(time.time() - end_time)
108 | per_sample_time.update((time.time() - end_time)/a_input.shape[0])
109 | per_sample_dnn_time.update((time.time() - dnn_start_time)/a_input.shape[0])
110 |
111 | print_step = global_step % args.n_print_steps == 0
112 | early_print_step = epoch == 0 and global_step % (args.n_print_steps/10) == 0
113 | print_step = print_step or early_print_step
114 |
115 | if print_step and global_step != 0:
116 | print('Epoch: [{0}][{1}/{2}]\t'
117 | 'Per Sample Total Time {per_sample_time.avg:.5f}\t'
118 | 'Per Sample Data Time {per_sample_data_time.avg:.5f}\t'
119 | 'Per Sample DNN Time {per_sample_dnn_time.avg:.5f}\t'
120 | 'Train Loss {loss_meter.val:.4f}\t'.format(
121 | epoch, i, len(train_loader), per_sample_time=per_sample_time, per_sample_data_time=per_sample_data_time,
122 | per_sample_dnn_time=per_sample_dnn_time, loss_meter=loss_meter), flush=True)
123 | if np.isnan(loss_meter.avg):
124 | print("training diverged...")
125 | return
126 |
127 | end_time = time.time()
128 | global_step += 1
129 |
130 | # for audioset-full, break after 10% of each epoch, i.e., equivalent epochs = 0.1 * specified epochs
131 | if args.dataset == 'as-full':
132 | if i > 0.1 * len(train_loader):
133 | break
134 |
135 | print('start validation')
136 |
137 | stats, valid_loss = validate(audio_model, test_loader, args)
138 |
139 | mAP = np.mean([stat['AP'] for stat in stats])
140 | mAUC = np.mean([stat['auc'] for stat in stats])
141 | acc = stats[0]['acc']
142 |
143 | if main_metrics == 'mAP':
144 | print("mAP: {:.6f}".format(mAP))
145 | else:
146 | print("acc: {:.6f}".format(acc))
147 | print("AUC: {:.6f}".format(mAUC))
148 | print("d_prime: {:.6f}".format(d_prime(mAUC)))
149 | print("train_loss: {:.6f}".format(loss_meter.avg))
150 | print("valid_loss: {:.6f}".format(valid_loss))
151 |
152 | result[epoch-1, :] = [acc, mAP, mAUC, optimizer.param_groups[0]['lr']]
153 | np.savetxt(exp_dir + '/result.csv', result, delimiter=',')
154 | print('validation finished')
155 |
156 | if mAP > best_mAP:
157 | best_mAP = mAP
158 | if main_metrics == 'mAP':
159 | best_epoch = epoch
160 |
161 | if acc > best_acc:
162 | best_acc = acc
163 | if main_metrics == 'acc':
164 | best_epoch = epoch
165 |
166 | if best_epoch == epoch:
167 | pass
168 | #torch.save(audio_model.state_dict(), "%s/models/best_audio_model.pth" % (exp_dir))
169 | if args.save_model == True:
170 | torch.save(audio_model.state_dict(), "%s/models/audio_model.%d.pth" % (exp_dir, epoch))
171 |
172 | if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
173 | if main_metrics == 'mAP':
174 | scheduler.step(mAP)
175 | elif main_metrics == 'acc':
176 | scheduler.step(acc)
177 | else:
178 | scheduler.step()
179 |
180 | print('Epoch-{0} lr: {1}'.format(epoch, optimizer.param_groups[0]['lr']))
181 |
182 | with open(exp_dir + '/stats_' + str(epoch) +'.pickle', 'wb') as handle:
183 | pickle.dump(stats, handle, protocol=pickle.HIGHEST_PROTOCOL)
184 | _save_progress()
185 |
186 | finish_time = time.time()
187 | print('epoch {:d} training time: {:.3f}'.format(epoch, finish_time-begin_time))
188 |
189 | epoch += 1
190 |
191 | batch_time.reset()
192 | per_sample_time.reset()
193 | data_time.reset()
194 | per_sample_data_time.reset()
195 | loss_meter.reset()
196 | per_sample_dnn_time.reset()
197 |
198 | def validate(audio_model, val_loader, args):
199 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
200 | batch_time = AverageMeter()
201 | if not isinstance(audio_model, nn.DataParallel):
202 | audio_model = nn.DataParallel(audio_model)
203 | audio_model = audio_model.to(device)
204 | # switch to evaluate mode
205 | audio_model.eval()
206 |
207 | end = time.time()
208 | A_predictions = []
209 | A_targets = []
210 | A_loss = []
211 | with torch.no_grad():
212 | for i, (a_input, labels) in enumerate(val_loader):
213 | a_input = a_input.to(device, non_blocking=True)
214 |
215 | with autocast():
216 | audio_output = audio_model(a_input)
217 |
218 | predictions = audio_output.to('cpu').detach()
219 |
220 | A_predictions.append(predictions)
221 | A_targets.append(labels)
222 |
223 | # compute the loss
224 | labels = labels.to(device)
225 | loss = args.loss_fn(audio_output, labels)
226 | A_loss.append(loss.to('cpu').detach())
227 |
228 | batch_time.update(time.time() - end)
229 | end = time.time()
230 |
231 | audio_output = torch.cat(A_predictions)
232 | target = torch.cat(A_targets)
233 | loss = np.mean(A_loss)
234 | stats = calculate_stats(audio_output, target)
235 | return stats, loss
--------------------------------------------------------------------------------
/src/whisper_at_train/utilities/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 6/19/21 4:39 PM
3 | # @Author : Yuan Gong
4 | # @Affiliation : Massachusetts Institute of Technology
5 | # @Email : yuangong@mit.edu
6 | # @File : __init__.py
7 |
8 | from .util import *
9 | from .stats import *
--------------------------------------------------------------------------------
/src/whisper_at_train/utilities/compute_flops.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 3/7/23 1:13 AM
3 | # @Author : Yuan Gong
4 | # @Affiliation : Massachusetts Institute of Technology
5 | # @Email : yuangong@mit.edu
6 | # @File    : compute_flops.py
7 |
8 | # check model size and flops
9 | import torch
10 | from fvcore.nn import FlopCountAnalysis
11 | from fvcore.nn import flop_count_table
12 | from high_mdls import HighMDL, HighMDLPool, HighMDLLayer, HighMDLFormal
13 | # from whisper.model import Whisper, ModelDimensions
14 |
15 | def cnt_flops(model, input):
16 | flops = FlopCountAnalysis(model, input)
17 | print(flop_count_table(flops))
18 | print(flops.total()/1e9)
19 | print(flops.by_operator())
20 | #print(flops.by_module())
21 | #print(flops.by_module_and_operator())
22 |
23 | # # original whisper model
24 | # checkpoint_path = '/data/sls/scratch/yuangong/whisper-a/src/{:s}.pt'.format('small.en')
25 | # checkpoint = torch.load(checkpoint_path, map_location='cpu')
26 | # dims = ModelDimensions(**checkpoint["dims"])
27 | # print(dims)
28 | # model = Whisper(dims, label_dim=527, cla='mlp_1')
29 | # input = torch.rand([1, 80, 512*2])
30 | # cnt_flops(model, input)
31 |
32 |
33 | def get_feat_shape(mdl_size):
34 | n_rep_dim_dict = {'tiny.en': 384, 'tiny': 384, 'base.en': 512, 'base': 512, 'small.en': 768, 'small': 768, 'medium.en': 1024, 'medium': 1024, 'large-v1': 1280, 'large-v2': 1280, 'wav2vec2-large-robust-ft-swbd-300h': 1024, 'hubert-xlarge-ls960-ft': 1280}
35 | n_layer_dict = {'tiny.en': 4, 'tiny': 4, 'base.en': 6, 'base': 6, 'small.en': 12, 'small': 12, 'medium.en': 24, 'medium': 24, 'large-v1': 32, 'large-v2': 32, 'wav2vec2-large-robust-ft-swbd-300h': 24, 'hubert-xlarge-ls960-ft': 48}
36 | return n_layer_dict[mdl_size], n_rep_dim_dict[mdl_size]
37 |
38 | model_name = 'whisper-high-lw_down_tr_768_1_8'
39 | model_size = 'large-v1'
40 | mode = model_name.split('-')[-1]
41 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42 | n_layer, rep_dim = get_feat_shape(model_size)
43 | print(mode, model_size, n_layer, rep_dim)
44 | model = HighMDLFormal(label_dim=527, n_layer=n_layer, rep_dim=rep_dim, mode=mode)
45 |
46 | # for large-v1
47 | cnt_flops(model, torch.rand([1, n_layer, 25, rep_dim]))
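
Note that high_mdls is not included in this directory; assuming the TLTR class in models.py plays the role of HighMDLFormal here, the same fvcore-based count can be sketched as below (the mode string and large-v1 shape are illustrative).

import torch
from fvcore.nn import FlopCountAnalysis, flop_count_table
from models import TLTR

tltr = TLTR(label_dim=527, n_layer=32, rep_dim=1280, mode='lw_down_tr_512_1_8')
flops = FlopCountAnalysis(tltr, torch.rand(1, 32, 25, 1280))
print(flop_count_table(flops))
print(flops.total() / 1e9)  # total GFLOPs for a (1, 32, 25, 1280) input
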
--------------------------------------------------------------------------------
/src/whisper_at_train/utilities/compute_mAP.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 6/1/23 12:40 PM
3 | # @Author : Yuan Gong
4 | # @Affiliation : Massachusetts Institute of Technology
5 | # @Email : yuangong@mit.edu
6 | # @File : compute_mAP.py
7 |
8 | # compute mAP on whisper-at
9 |
10 | from stats import *
11 |
12 | mdl_size_list = [
13 | 'tiny_low_False',
14 | 'tiny.en_low_False',
15 | 'base_low_False',
16 | 'base.en_low_False',
17 | 'small_low_False',
18 | 'small_low_True',
19 | 'small.en_low_False',
20 | 'small.en_low_True',
21 | 'medium_low_False',
22 | 'medium_low_True',
23 | 'medium.en_low_False',
24 | 'medium.en_low_True',
25 | 'large-v1_low_False',
26 | 'large-v1_low_True',
27 | 'large-v2_low_False',
28 | 'large-v2_low_True']
29 |
30 | for mdl_size in mdl_size_list:
31 | all_truth = np.load('/data/sls/scratch/yuangong/whisper-at/old/at_res/all_truth_' + mdl_size + '.npy')
32 | all_pred = np.load('/data/sls/scratch/yuangong/whisper-at/old/at_res/all_pred_' + mdl_size + '.npy')
33 | print(mdl_size)
34 | print(all_truth.shape, all_pred.shape)
35 |
36 | stats = calculate_stats(all_pred, all_truth)
37 | mAP = np.mean([stat['AP'] for stat in stats])
38 | print(mAP)
--------------------------------------------------------------------------------
/src/whisper_at_train/utilities/rename_state_dict.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 5/30/23 3:00 AM
3 | # @Author : Yuan Gong
4 | # @Affiliation : Massachusetts Institute of Technology
5 | # @Email : yuangong@mit.edu
6 | # @File : rename_state_dict.py
7 |
8 | # rename a state dict trained on extracted features so it can be loaded together with the whisper-at model
9 |
10 | import torch
11 | import os
12 |
13 | def get_immediate_files_with_extension(directory, extension='pth'):
14 | file_list = []
15 | for file in os.listdir(directory):
16 | if file.endswith(extension) and os.path.isfile(os.path.join(directory, file)):
17 | file_list.append(os.path.join(directory, file))
18 | return file_list
19 |
20 | def replace_name(ori_mdl_path, tar_mdl_path):
21 | sd = torch.load(ori_mdl_path, map_location='cpu')
22 | mdl_key_list = sd.keys()
23 |
24 | whisper_at_dict = {}
25 | for mdl_key in mdl_key_list:
26 | new_mdl_key = mdl_key.replace('module.', 'at_model.')
27 | #print(new_mdl_key)
28 | whisper_at_dict[new_mdl_key] = sd[mdl_key]
29 |
30 | print(len(sd.keys()), len(whisper_at_dict.keys()))
31 | torch.save(whisper_at_dict, tar_mdl_path)
32 |
33 | mdl_list = get_immediate_files_with_extension('/data/sls/scratch/yuangong/whisper-at/exp/')
34 | print(mdl_list)
35 | tar_path = '/data/sls/scratch/yuangong/whisper-at/exp/converted_to_whisper_at/'
36 | for mdl in mdl_list:
37 | print(mdl)
38 | print('-----------------------')
39 | replace_name(mdl, tar_path + mdl.split('/')[-1])
40 |
--------------------------------------------------------------------------------
/src/whisper_at_train/utilities/stats.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy import stats
3 | from sklearn import metrics
4 | import torch
5 |
6 | def d_prime(auc):
7 | standard_normal = stats.norm()
8 | d_prime = standard_normal.ppf(auc) * np.sqrt(2.0)
9 | return d_prime
10 |
11 | def calculate_stats(output, target):
12 | """Calculate statistics including mAP, AUC, etc.
13 |
14 | Args:
15 | output: 2d array, (samples_num, classes_num)
16 | target: 2d array, (samples_num, classes_num)
17 |
18 | Returns:
19 | stats: list of statistic of each class.
20 | """
21 |
22 | classes_num = target.shape[-1]
23 | stats = []
24 |
25 | # Accuracy, only used for single-label classification such as ESC-50, not for multi-label tasks such as AudioSet
26 | acc = metrics.accuracy_score(np.argmax(target, 1), np.argmax(output, 1))
27 |
28 | # Class-wise statistics
29 | for k in range(classes_num):
30 |
31 | # Average precision
32 | avg_precision = metrics.average_precision_score(
33 | target[:, k], output[:, k], average=None)
34 |
35 | # AUC
36 | try:
37 | auc = metrics.roc_auc_score(target[:, k], output[:, k], average=None)
38 |
39 | # Precisions, recalls
40 | (precisions, recalls, thresholds) = metrics.precision_recall_curve(
41 | target[:, k], output[:, k])
42 |
43 | # FPR, TPR
44 | (fpr, tpr, thresholds) = metrics.roc_curve(target[:, k], output[:, k])
45 |
46 | save_every_steps = 1000 # Sample statistics to reduce size
47 | dict = {'precisions': precisions[0::save_every_steps],
48 | 'recalls': recalls[0::save_every_steps],
49 | 'AP': avg_precision,
50 | 'fpr': fpr[0::save_every_steps],
51 | 'fnr': 1. - tpr[0::save_every_steps],
52 | 'auc': auc,
53 | # note acc is not class-wise, this is just to keep consistent with other metrics
54 | 'acc': acc
55 | }
56 | except:
57 | dict = {'precisions': -1,
58 | 'recalls': -1,
59 | 'AP': avg_precision,
60 | 'fpr': -1,
61 | 'fnr': -1,
62 | 'auc': -1,
63 | # note acc is not class-wise, this is just to keep consistent with other metrics
64 | 'acc': acc
65 | }
66 | print('class {:s} has no true samples'.format(str(k)))
67 | stats.append(dict)
68 |
69 | return stats
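
For a rough sanity check of how traintest.py and compute_mAP.py reduce these per-class statistics (mAP as the mean per-class AP, d-prime derived from the mean AUC), the sketch below runs calculate_stats on random multi-label data; the sample count and label density are arbitrary.

if __name__ == '__main__':
    rng = np.random.default_rng(0)
    pred = rng.random((200, 527))                         # random scores for 527 classes
    truth = (rng.random((200, 527)) > 0.9).astype(float)  # random multi-hot targets
    stats = calculate_stats(pred, truth)
    mAP = np.mean([s['AP'] for s in stats])
    mAUC = np.mean([s['auc'] for s in stats if s['auc'] != -1])
    print('mAP: {:.4f}, mAUC: {:.4f}, d-prime: {:.4f}'.format(mAP, mAUC, d_prime(mAUC)))
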
--------------------------------------------------------------------------------
/src/whisper_at_train/utilities/util.py:
--------------------------------------------------------------------------------
1 | import math
2 | import pickle
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import random
7 | from collections import namedtuple
8 |
9 | def calc_recalls(S):
10 | """
11 | Computes recall at 1, 5, and 10 given a similarity matrix S.
12 | By convention, rows of S are assumed to correspond to images and columns are captions.
13 | """
14 | assert(S.dim() == 2)
15 | assert(S.size(0) == S.size(1))
16 | if isinstance(S, torch.autograd.Variable):
17 | S = S.data
18 | n = S.size(0)
19 | A2I_scores, A2I_ind = S.topk(10, 0)
20 | I2A_scores, I2A_ind = S.topk(10, 1)
21 | A_r1 = AverageMeter()
22 | A_r5 = AverageMeter()
23 | A_r10 = AverageMeter()
24 | I_r1 = AverageMeter()
25 | I_r5 = AverageMeter()
26 | I_r10 = AverageMeter()
27 | for i in range(n):
28 | A_foundind = -1
29 | I_foundind = -1
30 | for ind in range(10):
31 | if A2I_ind[ind, i] == i:
32 | I_foundind = ind
33 | if I2A_ind[i, ind] == i:
34 | A_foundind = ind
35 | # do r1s
36 | if A_foundind == 0:
37 | A_r1.update(1)
38 | else:
39 | A_r1.update(0)
40 | if I_foundind == 0:
41 | I_r1.update(1)
42 | else:
43 | I_r1.update(0)
44 | # do r5s
45 | if A_foundind >= 0 and A_foundind < 5:
46 | A_r5.update(1)
47 | else:
48 | A_r5.update(0)
49 | if I_foundind >= 0 and I_foundind < 5:
50 | I_r5.update(1)
51 | else:
52 | I_r5.update(0)
53 | # do r10s
54 | if A_foundind >= 0 and A_foundind < 10:
55 | A_r10.update(1)
56 | else:
57 | A_r10.update(0)
58 | if I_foundind >= 0 and I_foundind < 10:
59 | I_r10.update(1)
60 | else:
61 | I_r10.update(0)
62 |
63 | recalls = {'A_r1':A_r1.avg, 'A_r5':A_r5.avg, 'A_r10':A_r10.avg,
64 | 'I_r1':I_r1.avg, 'I_r5':I_r5.avg, 'I_r10':I_r10.avg}
65 | #'A_meanR':A_meanR.avg, 'I_meanR':I_meanR.avg}
66 |
67 | return recalls
68 |
69 | def computeMatchmap(I, A):
70 | assert(I.dim() == 3)
71 | assert(A.dim() == 2)
72 | D = I.size(0)
73 | H = I.size(1)
74 | W = I.size(2)
75 | T = A.size(1)
76 | Ir = I.view(D, -1).t()
77 | matchmap = torch.mm(Ir, A)
78 | matchmap = matchmap.view(H, W, T)
79 | return matchmap
80 |
81 | def matchmapSim(M, simtype):
82 | assert(M.dim() == 3)
83 | if simtype == 'SISA':
84 | return M.mean()
85 | elif simtype == 'MISA':
86 | M_maxH, _ = M.max(0)
87 | M_maxHW, _ = M_maxH.max(0)
88 | return M_maxHW.mean()
89 | elif simtype == 'SIMA':
90 | M_maxT, _ = M.max(2)
91 | return M_maxT.mean()
92 | else:
93 | raise ValueError
94 |
95 | def sampled_margin_rank_loss(image_outputs, audio_outputs, nframes, margin=1., simtype='MISA'):
96 | """
97 | Computes the triplet margin ranking loss for each anchor image/caption pair
98 | The impostor image/caption is randomly sampled from the minibatch
99 | """
100 | assert(image_outputs.dim() == 4)
101 | assert(audio_outputs.dim() == 3)
102 | n = image_outputs.size(0)
103 | loss = torch.zeros(1, device=image_outputs.device, requires_grad=True)
104 | for i in range(n):
105 | I_imp_ind = i
106 | A_imp_ind = i
107 | while I_imp_ind == i:
108 | I_imp_ind = np.random.randint(0, n)
109 | while A_imp_ind == i:
110 | A_imp_ind = np.random.randint(0, n)
111 | nF = nframes[i]
112 | nFimp = nframes[A_imp_ind]
113 | anchorsim = matchmapSim(computeMatchmap(image_outputs[i], audio_outputs[i][:, 0:nF]), simtype)
114 | Iimpsim = matchmapSim(computeMatchmap(image_outputs[I_imp_ind], audio_outputs[i][:, 0:nF]), simtype)
115 | Aimpsim = matchmapSim(computeMatchmap(image_outputs[i], audio_outputs[A_imp_ind][:, 0:nFimp]), simtype)
116 | A2I_simdif = margin + Iimpsim - anchorsim
117 | if (A2I_simdif.data > 0).all():
118 | loss = loss + A2I_simdif
119 | I2A_simdif = margin + Aimpsim - anchorsim
120 | if (I2A_simdif.data > 0).all():
121 | loss = loss + I2A_simdif
122 | loss = loss / n
123 | return loss
124 |
125 | def compute_matchmap_similarity_matrix(image_outputs, audio_outputs, nframes, simtype='MISA'):
126 | """
127 | Assumes image_outputs is a (batchsize, embedding_dim, rows, height) tensor
128 | Assumes audio_outputs is a (batchsize, embedding_dim, 1, time) tensor
129 | Returns similarity matrix S where images are rows and audios are along the columns
130 | """
131 | assert(image_outputs.dim() == 4)
132 | assert(audio_outputs.dim() == 3)
133 | n = image_outputs.size(0)
134 | S = torch.zeros(n, n, device=image_outputs.device)
135 | for image_idx in range(n):
136 | for audio_idx in range(n):
137 | nF = max(1, nframes[audio_idx])
138 | S[image_idx, audio_idx] = matchmapSim(computeMatchmap(image_outputs[image_idx], audio_outputs[audio_idx][:, 0:nF]), simtype)
139 | return S
140 |
141 | def compute_pooldot_similarity_matrix(image_outputs, audio_outputs, nframes):
142 | """
143 | Assumes image_outputs is a (batchsize, embedding_dim, rows, height) tensor
144 | Assumes audio_outputs is a (batchsize, embedding_dim, 1, time) tensor
145 | Returns similarity matrix S where images are rows and audios are along the columns
146 | S[i][j] is computed as the dot product between the meanpooled embeddings of
147 | the ith image output and jth audio output
148 | """
149 | assert(image_outputs.dim() == 4)
150 | assert(audio_outputs.dim() == 4)
151 | n = image_outputs.size(0)
152 | imagePoolfunc = nn.AdaptiveAvgPool2d((1, 1))
153 | pooled_image_outputs = imagePoolfunc(image_outputs).squeeze(3).squeeze(2)
154 | audioPoolfunc = nn.AdaptiveAvgPool2d((1, 1))
155 | pooled_audio_outputs_list = []
156 | for idx in range(n):
157 | nF = max(1, nframes[idx])
158 | pooled_audio_outputs_list.append(audioPoolfunc(audio_outputs[idx][:, :, 0:nF]).unsqueeze(0))
159 | pooled_audio_outputs = torch.cat(pooled_audio_outputs_list).squeeze(3).squeeze(2)
160 | S = torch.mm(pooled_image_outputs, pooled_audio_outputs.t())
161 | return S
162 |
163 | def one_imposter_index(i, N):
164 | imp_ind = random.randint(0, N - 2)
165 | if imp_ind == i:
166 | imp_ind = N - 1
167 | return imp_ind
168 |
169 | def basic_get_imposter_indices(N):
170 | imposter_idc = []
171 | for i in range(N):
172 | # Select an imposter index for example i:
173 | imp_ind = one_imposter_index(i, N)
174 | imposter_idc.append(imp_ind)
175 | return imposter_idc
176 |
177 | def semihardneg_triplet_loss_from_S(S, margin):
178 | """
179 | Input: Similarity matrix S as an autograd.Variable
180 | Output: The one-way triplet loss from rows of S to columns of S. Impostors are taken
181 | to be the most similar point to the anchor that is still less similar to the anchor
182 | than the positive example.
183 | You would need to run this function twice, once with S and once with S.t(),
184 | in order to compute the triplet loss in both directions.
185 | """
186 | assert(S.dim() == 2)
187 | assert(S.size(0) == S.size(1))
188 | N = S.size(0)
189 | loss = torch.autograd.Variable(torch.zeros(1).type(S.data.type()), requires_grad=True)
190 | # Imposter - ground truth
191 | Sdiff = S - torch.diag(S).view(-1, 1)
192 | eps = 1e-12
193 | # All examples less similar than ground truth
194 | mask = (Sdiff < -eps).type(torch.LongTensor)
195 | maskf = mask.type_as(S)
196 | # Mask out all examples >= gt with minimum similarity
197 | Sp = maskf * Sdiff + (1 - maskf) * torch.min(Sdiff).detach()
198 | # Find the index maximum similar of the remaining
199 | _, idc = Sp.max(dim=1)
200 | idc = idc.data.cpu()
201 | # Vector mask: 1 iff there exists an example < gt
202 | has_neg = (mask.sum(dim=1) > 0).data.type(torch.LongTensor)
203 | # Random imposter indices
204 | random_imp_ind = torch.LongTensor(basic_get_imposter_indices(N))
205 | # Use hardneg if there exists an example < gt, otherwise use random imposter
206 | imp_idc = has_neg * idc + (1 - has_neg) * random_imp_ind
207 | # This could probably be vectorized too, but I haven't.
208 | for i, imp in enumerate(imp_idc):
209 | local_loss = Sdiff[i, imp] + margin
210 | if (local_loss.data > 0).all():
211 | loss = loss + local_loss
212 | loss = loss / N
213 | return loss
214 |
215 | def sampled_triplet_loss_from_S(S, margin):
216 | """
217 | Input: Similarity matrix S as an autograd.Variable
218 | Output: The one-way triplet loss from rows of S to columns of S. Imposters are
219 | randomly sampled from the columns of S.
220 | You would need to run this function twice, once with S and once with S.t(),
221 | in order to compute the triplet loss in both directions.
222 | """
223 | assert(S.dim() == 2)
224 | assert(S.size(0) == S.size(1))
225 | N = S.size(0)
226 | loss = torch.autograd.Variable(torch.zeros(1).type(S.data.type()), requires_grad=True)
227 | # Imposter - ground truth
228 | Sdiff = S - torch.diag(S).view(-1, 1)
229 | imp_ind = torch.LongTensor(basic_get_imposter_indices(N))
230 | # This could probably be vectorized too, but I haven't.
231 | for i, imp in enumerate(imp_ind):
232 | local_loss = Sdiff[i, imp] + margin
233 | if (local_loss.data > 0).all():
234 | loss = loss + local_loss
235 | loss = loss / N
236 | return loss
237 |
238 | class AverageMeter(object):
239 | """Computes and stores the average and current value"""
240 | def __init__(self):
241 | self.reset()
242 |
243 | def reset(self):
244 | self.val = 0
245 | self.avg = 0
246 | self.sum = 0
247 | self.count = 0
248 |
249 | def update(self, val, n=1):
250 | self.val = val
251 | self.sum += val * n
252 | self.count += n
253 | self.avg = self.sum / self.count
254 |
255 | def adjust_learning_rate(base_lr, lr_decay, optimizer, epoch):
256 | """Sets the learning rate to the initial LR decayed by 10 every lr_decay epochs"""
257 | lr = base_lr * (0.1 ** (epoch // lr_decay))
258 | print('now learning rate changed to {:f}'.format(lr))
259 | for param_group in optimizer.param_groups:
260 | param_group['lr'] = lr
261 |
262 | def adjust_learning_rate2(base_lr, lr_decay, optimizer, epoch):
263 | """Sets the learning rate to the initial LR decayed by 10 every lr_decay epochs"""
264 | for param_group in optimizer.param_groups:
265 | cur_lr = param_group['lr']
266 | print('current learning rate is {:f}'.format(cur_lr))
267 | lr = cur_lr * 0.1
268 | print('now learning rate changed to {:f}'.format(lr))
269 | for param_group in optimizer.param_groups:
270 | param_group['lr'] = lr
271 |
272 |
273 | def load_progress(prog_pkl, quiet=False):
274 | """
275 | load progress pkl file
276 | Args:
277 | prog_pkl(str): path to progress pkl file
278 | Return:
279 | progress(list):
280 | epoch(int):
281 | global_step(int):
282 | best_epoch(int):
283 | best_avg_r10(float):
284 | """
285 | def _print(msg):
286 | if not quiet:
287 | print(msg)
288 |
289 | with open(prog_pkl, "rb") as f:
290 | prog = pickle.load(f)
291 | epoch, global_step, best_epoch, best_avg_r10, _ = prog[-1]
292 |
293 | _print("\nPrevious Progress:")
294 | msg = "[%5s %7s %5s %7s %6s]" % ("epoch", "step", "best_epoch", "best_avg_r10", "time")
295 | _print(msg)
296 | return prog, epoch, global_step, best_epoch, best_avg_r10
297 |
298 | def count_parameters(model):
299 | return sum([p.numel() for p in model.parameters() if p.requires_grad])
300 |
301 | PrenetConfig = namedtuple(
302 | 'PrenetConfig', ['input_size', 'hidden_size', 'num_layers', 'dropout'])
303 |
304 | RNNConfig = namedtuple(
305 | 'RNNConfig',
306 | ['input_size', 'hidden_size', 'num_layers', 'dropout', 'residual'])
307 |
--------------------------------------------------------------------------------
/src/whisper_at_train/utilities/whisper_at_as_eval.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 5/28/23 2:36 PM
3 | # @Author : Yuan Gong
4 | # @Affiliation : Massachusetts Institute of Technology
5 | # @Email : yuangong@mit.edu
6 | # @File    : whisper_at_as_eval.py
7 |
8 | # evaluate whisper-at on as-eval set
9 | # note this uses a 30s window; performance will be slightly lower than with a 10s window
10 |
11 | import sys
12 | argument = sys.argv[1]
13 | if argument=='4':
14 | argument='0,1,2,3'
15 | import os
16 | if argument != '-1':
17 | os.environ["CUDA_VISIBLE_DEVICES"]=argument
18 |
19 | import whisper_at as whisper
20 | import numpy as np
21 | import csv
22 | import json
23 | import torch
24 |
25 | def make_index_dict(label_csv):
26 | index_lookup = {}
27 | with open(label_csv, 'r') as f:
28 | csv_reader = csv.DictReader(f)
29 | line_count = 0
30 | for row in csv_reader:
31 | index_lookup[row['mid']] = row['index']
32 | line_count += 1
33 | return index_lookup
34 |
35 | def make_name_dict(label_csv):
36 | name_lookup = {}
37 | with open(label_csv, 'r') as f:
38 | csv_reader = csv.DictReader(f)
39 | line_count = 0
40 | for row in csv_reader:
41 | name_lookup[row['index']] = row['display_name']
42 | line_count += 1
43 | return name_lookup
44 |
45 | def lookup_list(index_list, label_csv):
46 | label_list = []
47 | table = make_name_dict(label_csv)
48 | for item in index_list:
49 | label_list.append(table[item])
50 | return label_list
51 |
52 | mdl_size='large-v1'
53 | at_low_compute=False
54 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
55 | model = whisper.load_model(mdl_size, at_low_compute=at_low_compute).to(device)
56 |
57 | index_dict = make_index_dict('/data/sls/scratch/yuangong/whisper-at/src/class_labels_indices.csv')
58 | with open('/data/sls/scratch/yuangong/whisper-at/src/eval_data.json') as json_file:
59 | data = json.load(json_file)['data']
60 | num_file = len(data)
61 |
62 | all_pred, all_truth = torch.zeros([num_file, 527]).to(device), torch.zeros([num_file, 527]).to(device)
63 | for i, entry in enumerate(data):
64 | cur_wav = entry['wav']
65 | labels = entry['labels'].split(',')
66 | for label in labels:
67 | all_truth[i, int(index_dict[label])] = 1.0
68 | result = model.transcribe(cur_wav, language='en', logprob_threshold=None, compression_ratio_threshold=None)['audio_tag']
69 | all_pred[i] = result[0]
70 |
71 | if i % 100 == 0:
72 | print(i)
73 | np.save('./at_res/all_pred_{:s}_low_{:s}.npy'.format(mdl_size, str(at_low_compute)), all_pred.cpu().numpy())
74 | np.save('./at_res/all_truth_{:s}_low_{:s}.npy'.format(mdl_size, str(at_low_compute)), all_truth.cpu().numpy())
75 |
76 | np.save('./at_res/all_pred_{:s}_low_{:s}.npy'.format(mdl_size, str(at_low_compute)), all_pred.cpu().numpy())
77 | np.save('./at_res/all_truth_{:s}_low_{:s}.npy'.format(mdl_size, str(at_low_compute)), all_truth.cpu().numpy())
--------------------------------------------------------------------------------
/tltr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuanGongND/whisper-at/17d94d6acd53866390ce70f95afa13507dcb8aef/tltr.png
--------------------------------------------------------------------------------