├── km_model.pt
├── causal_hubert
│   ├── __init__.py
│   ├── kmeans.py
│   └── causal_hubert.py
├── audio_dir
│   ├── 1
│   │   ├── 1272-128104-0000.wav
│   │   └── 1272-128104-0000.txt
│   └── 2
│       ├── 1272-128104-0001.wav
│       └── 1272-128104-0001.txt
├── .gitignore
├── README.md
├── speech2unit.py
├── requirements.txt
├── speech2unit_dir.py
└── modeling_hubert.py

/km_model.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nervjack2/Speech2Unit/HEAD/km_model.pt
--------------------------------------------------------------------------------
/causal_hubert/__init__.py:
--------------------------------------------------------------------------------
1 | from .causal_hubert import *
2 | from .kmeans import *
3 |
--------------------------------------------------------------------------------
/audio_dir/1/1272-128104-0000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nervjack2/Speech2Unit/HEAD/audio_dir/1/1272-128104-0000.wav
--------------------------------------------------------------------------------
/audio_dir/2/1272-128104-0001.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nervjack2/Speech2Unit/HEAD/audio_dir/2/1272-128104-0001.wav
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore Python bytecode files
2 | *.pyc
3 | *.pyo
4 | *.pyd
5 | __pycache__/
6 |
7 | # Ignore virtual environments
8 | venv/
9 | env/
10 |
11 | # Ignore Jupyter Notebook checkpoints
12 | .ipynb_checkpoints/
13 |
14 | # Ignore system files
15 | .DS_Store
16 | Thumbs.db
17 |
18 | # Ignore logs
19 | *.log
20 |
21 | # Ignore ckpt directories or files
22 | hubert_ckpt/
23 | llama_ckpt/
24 | vocoder_ckpt/
25 | *.tar.gz
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Speech to Interleaving Unit
2 | Transform audio into an interleaved sequence of transcribed words and discrete speech units.
3 | We use Whisper for ASR transcription and causal HuBERT + k-means for the speech units.
4 |
5 | See audio_dir/ for examples of interleaving results.
6 | ## Install (Important)
7 | python == 3.8
8 | ```
9 | git clone https://github.com/nervjack2/Speech2Unit.git
10 | cd Speech2Unit
11 | pip install -r requirements.txt
12 | git clone https://github.com/huggingface/transformers.git
13 | cp modeling_hubert.py transformers/src/transformers/models/hubert/modeling_hubert.py
14 | cd transformers
15 | pip install .
16 | ```
17 |
18 | ## Usage
19 | ```
20 | python3 speech2unit_dir.py --audio_dir AUDIO_DIR_PATH --ext EXT --downsample 2
21 | ```
22 | AUDIO_DIR_PATH: the directory containing the audio files
23 |
24 | EXT: extension of the audio files, e.g.
wav 25 | -------------------------------------------------------------------------------- /audio_dir/2/1272-128104-0001.txt: -------------------------------------------------------------------------------- 1 | Nor<|1131|><|162|><|162|><|162|><|162|><|1454|><|1454|><|891|><|891|><|162|><|162|><|61|><|61|><|385|><|297|><|1071|><|222|><|1019|><|944|><|1447|> is<|944|><|24|><|310|><|372|><|370|><|1083|><|7|> Mr.<|1565|><|1951|><|1367|><|703|><|1660|><|1788|><|1356|><|603|> Quilter's<|1409|><|565|><|1173|><|576|><|515|><|19|><|1418|><|1520|><|1156|><|1887|><|736|><|1951|> manner<|494|><|998|><|1935|><|1605|><|1413|><|1405|><|944|> less<|944|><|944|><|1357|><|968|><|1988|><|341|><|1545|><|1606|><|305|><|257|> interesting<|1660|><|1149|><|1674|><|905|><|1329|><|1870|><|167|><|1229|><|1229|><|23|><|590|><|1650|><|305|><|1660|><|1959|> than<|438|><|1263|><|160|><|1836|><|849|><|1432|><|530|><|1501|> his<|62|><|250|><|190|><|713|><|1269|><|1499|><|1305|> matter.<|736|><|1660|><|1804|><|1327|><|494|><|525|><|262|><|495|><|673|><|42|><|1570|><|944|><|944|><|1447|><|1292|><|1454|><|1698|><|501|><|651|><|1088|><|47|><|872|><|872|><|87|><|872|><|93|> 2 | -------------------------------------------------------------------------------- /causal_hubert/kmeans.py: -------------------------------------------------------------------------------- 1 | import joblib 2 | import torch 3 | import numpy as np 4 | 5 | 6 | class ApplyKmeans(object): 7 | def __init__(self, km_path, use_gpu): 8 | self.km_model = joblib.load(km_path) 9 | self.C_np = self.km_model.cluster_centers_.transpose() 10 | self.Cnorm_np = (self.C_np**2).sum(0, keepdims=True) 11 | 12 | self.C = torch.from_numpy(self.C_np) 13 | self.Cnorm = torch.from_numpy(self.Cnorm_np) 14 | if use_gpu and torch.cuda.is_available(): 15 | self.C = self.C.cuda() 16 | self.Cnorm = self.Cnorm.cuda() 17 | 18 | def __call__(self, x): 19 | if isinstance(x, torch.Tensor): 20 | x = x.to(self.C.device) 21 | dist = ( 22 | x.pow(2).sum(1, keepdim=True) - 2 * torch.matmul(x, self.C) + self.Cnorm 23 | ) 24 | return dist.argmin(dim=1).cpu().numpy() 25 | else: 26 | dist = ( 27 | (x**2).sum(1, keepdims=True) 28 | - 2 * np.matmul(x, self.C_np) 29 | + self.Cnorm_np 30 | ) 31 | return np.argmin(dist, axis=1) 32 | 33 | -------------------------------------------------------------------------------- /audio_dir/1/1272-128104-0000.txt: -------------------------------------------------------------------------------- 1 | Mr.<|162|><|162|><|162|><|162|><|162|><|162|><|162|><|162|><|162|><|61|><|61|><|61|><|795|><|218|><|1356|><|1129|><|869|><|1546|><|224|><|1638|><|881|><|1437|> Quilter<|565|><|587|><|812|><|1984|><|843|><|224|><|944|><|944|> is<|1229|><|1152|><|1212|><|1005|><|148|> the<|257|><|1316|><|1523|><|141|><|135|> apostle<|1546|><|904|><|586|><|1603|><|298|><|1026|><|893|><|943|><|1583|> of<|1317|><|1402|><|139|><|472|><|13|><|13|><|93|> the<|788|><|1982|><|277|><|1169|> middle<|55|><|1379|><|1023|><|89|><|845|> classes,<|42|><|1510|><|586|><|1450|><|1275|><|1882|><|1640|><|1002|><|1411|><|736|><|270|><|158|><|1499|><|1887|><|1007|><|7|><|1660|><|598|><|1262|><|980|> and<|1075|> we<|1339|><|223|><|150|><|201|> are<|652|><|1212|><|1662|> glad<|506|><|1970|><|1557|><|237|><|237|><|1413|> to<|1234|><|1035|><|1102|><|904|><|1309|><|1108|><|916|> welcome<|1704|><|595|><|748|><|345|><|125|><|531|> his<|277|><|375|><|1608|><|1604|><|1604|><|713|><|1156|><|1001|><|50|> 
gospel.<|1660|><|1462|><|506|><|854|><|1272|><|1002|><|305|><|1958|><|1660|><|1149|><|1339|><|1402|><|755|><|755|><|1513|><|501|><|307|><|872|><|872|><|1205|><|87|><|766|><|1481|><|501|><|93|> 2 | -------------------------------------------------------------------------------- /speech2unit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import joblib 3 | import soundfile as sf 4 | # from transformers import Wav2Vec2Model 5 | from causal_hubert import DiscreteHubertEncoder, ApplyKmeans 6 | from argparse import ArgumentParser 7 | from faster_whisper import WhisperModel 8 | 9 | TPS = 50 10 | 11 | def transcribe(audio_path): 12 | # Read audio 13 | audio, sr = sf.read(audio_path) 14 | assert sr == 16000, "Sample rate of audio should be 16000 Hz" 15 | # Maybe we can use batch pipeline in faster whisper for better efficiency 16 | segments, info = ASR.transcribe(audio, beam_size=5, language="en", condition_on_previous_text=False, word_timestamps=True) 17 | return segments 18 | 19 | def quantize(audio_path): 20 | feat, leng = encoder.encode(audio_path) 21 | ssl_units = apply_kmeans(feat) 22 | return [f"<|{p}|>" for p in ssl_units] 23 | 24 | 25 | def combine(kms, segments): 26 | words = [] 27 | for segment in segments: 28 | for w in segment.words: 29 | words.append((w.word, int(w.start * TPS))) 30 | for i, (w, s) in enumerate(words): 31 | kms.insert(i + s, ' ' + w) 32 | 33 | return ''.join(kms) 34 | 35 | if __name__ == '__main__': 36 | parser = ArgumentParser() 37 | parser.add_argument("--input_audio", type=str, help="Input audio file") 38 | parser.add_argument("--output_path", type=str, default="tmp.txt", help="Path to save interleaving sequence") 39 | parser.add_argument("--device", type=str, default="cuda", help="Acceleration device") 40 | parser.add_argument("--km_model", type=str, default="./km_model.pt") 41 | parser.add_argument("--fp16", action="store_true", help="Data types for quantizing HuBERT features. 
Using flash_attention_2 (float16), which is faster, but sometimes results in different results") 42 | args = parser.parse_args() 43 | 44 | # Initialize Whisper model for transcribing 45 | ASR = WhisperModel("andybi7676/cool-whisper", device=args.device, compute_type="float16") 46 | 47 | # Initialize causal HuBERT and kmeans quantize module 48 | encoder = DiscreteHubertEncoder() 49 | apply_kmeans = ApplyKmeans(args.km_model, use_gpu=True) 50 | 51 | # Transcribe given audio 52 | segments = transcribe(args.input_audio) 53 | 54 | # Quantize Causal HuBERT features 55 | kms = quantize(args.input_audio) 56 | 57 | # Generate interleaving sequence 58 | interleave = combine(kms, segments) 59 | 60 | # Dump results 61 | with open(args.output_path, 'w') as f: 62 | f.write(interleave + '\n') 63 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | annotated-types==0.7.0 2 | anyio==4.4.0 3 | audioread==3.0.1 4 | av==12.2.0 5 | cachetools==5.4.0 6 | certifi==2024.7.4 7 | cffi==1.16.0 8 | charset-normalizer==3.3.2 9 | coloredlogs==15.0.1 10 | contourpy==1.1.1 11 | ctranslate2==4.3.1 12 | cycler==0.12.1 13 | decorator==5.1.1 14 | distro==1.9.0 15 | exceptiongroup==1.2.2 16 | faster-whisper==1.0.3 17 | filelock==3.15.4 18 | flatbuffers==24.3.25 19 | fonttools==4.53.1 20 | fsspec==2024.6.1 21 | google-api-core==2.19.1 22 | google-auth==2.32.0 23 | google-cloud-bigquery==3.25.0 24 | google-cloud-core==2.4.1 25 | google-cloud-texttospeech==2.16.4 26 | google-crc32c==1.5.0 27 | google-resumable-media==2.7.1 28 | googleapis-common-protos==1.63.2 29 | grpcio==1.65.1 30 | grpcio-status==1.65.1 31 | h11==0.14.0 32 | httpcore==1.0.5 33 | httpx==0.27.0 34 | huggingface-hub==0.23.5 35 | humanfriendly==10.0 36 | idna==3.7 37 | importlib_metadata==8.0.0 38 | importlib_resources==6.4.4 39 | Jinja2==3.1.4 40 | joblib==1.4.2 41 | kiwisolver==1.4.5 42 | lazy_loader==0.4 43 | librosa==0.10.2.post1 44 | llvmlite==0.41.1 45 | MarkupSafe==2.1.5 46 | matplotlib==3.7.5 47 | mpmath==1.3.0 48 | msgpack==1.0.8 49 | networkx==3.1 50 | numba==0.58.1 51 | numpy==1.24.4 52 | nvidia-cublas-cu12==12.1.3.1 53 | nvidia-cuda-cupti-cu12==12.1.105 54 | nvidia-cuda-nvrtc-cu12==12.1.105 55 | nvidia-cuda-runtime-cu12==12.1.105 56 | nvidia-cudnn-cu12==8.9.2.26 57 | nvidia-cufft-cu12==11.0.2.54 58 | nvidia-curand-cu12==10.3.2.106 59 | nvidia-cusolver-cu12==11.4.5.107 60 | nvidia-cusparse-cu12==12.1.0.106 61 | nvidia-nccl-cu12==2.20.5 62 | nvidia-nvjitlink-cu12==12.5.82 63 | nvidia-nvtx-cu12==12.1.105 64 | onnxruntime==1.18.1 65 | openai==1.35.14 66 | opus-fast-mosestokenizer==0.0.8.6 67 | packaging==24.1 68 | pillow==10.4.0 69 | platformdirs==4.2.2 70 | pooch==1.8.2 71 | proto-plus==1.24.0 72 | protobuf==5.27.2 73 | pyasn1==0.6.0 74 | pyasn1_modules==0.4.0 75 | pycparser==2.22 76 | pydantic==2.8.2 77 | pydantic_core==2.20.1 78 | pyparsing==3.1.2 79 | python-dateutil==2.9.0.post0 80 | PyYAML==6.0.1 81 | regex==2024.7.24 82 | requests==2.32.3 83 | rsa==4.9 84 | safetensors==0.4.3 85 | scikit-learn==1.3.2 86 | scipy==1.10.1 87 | sentencepiece==0.2.0 88 | six==1.16.0 89 | sniffio==1.3.1 90 | soundfile==0.12.1 91 | soxr==0.3.7 92 | sympy==1.13.0 93 | threadpoolctl==3.5.0 94 | tokenizers==0.19.1 95 | torch==2.3.1 96 | torchaudio==2.3.1 97 | torchvision==0.18.1 98 | tqdm==4.66.4 99 | triton==2.3.1 100 | typing_extensions==4.12.2 101 | urllib3==2.2.2 102 | zipp==3.19.2 103 | 
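The interleaving step implemented by combine() in speech2unit.py above works as follows: each transcribed word is spliced into the unit-token list at the frame index of its start time (50 HuBERT units per second), with the loop index compensating for the words already inserted. Below is a minimal, self-contained sketch of that logic; the words, timestamps, and unit IDs are invented purely for illustration.

```
# Toy illustration of combine() from speech2unit.py (all values made up).
TPS = 50  # HuBERT unit tokens per second of audio

units = [f"<|{u}|>" for u in [162, 162, 61, 795, 218, 1356, 1129, 869]]
words = [("Mr.", 0.00), ("Quilter", 0.10)]  # (word, start time in seconds)

kms = list(units)
for i, (word, start) in enumerate(words):
    # i accounts for the words that have already been inserted before this one
    kms.insert(i + int(start * TPS), " " + word)

print("".join(kms))
# -> ' Mr.<|162|><|162|><|61|><|795|><|218|> Quilter<|1356|><|1129|><|869|>'
```

speech2unit_dir.py (next file) follows the same logic, but keeps only every `downsample`-th unit and scales TPS to 50/downsample accordingly.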
-------------------------------------------------------------------------------- /speech2unit_dir.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import joblib 3 | import glob 4 | import librosa 5 | import os 6 | import soundfile as sf 7 | from causal_hubert import DiscreteHubertEncoder, ApplyKmeans 8 | from argparse import ArgumentParser 9 | from faster_whisper import WhisperModel 10 | from tqdm import tqdm 11 | 12 | 13 | def transcribe(audio_path): 14 | # Read audio 15 | audio, sr = sf.read(audio_path) 16 | if sr != 16000: 17 | audio = librosa.resample(audio, orig_sr=sr, target_sr=16000) 18 | # assert sr == 16000, "Sample rate of audio should be 16000 Hz" 19 | # Maybe we can use batch pipeline in faster whisper for better efficiency 20 | segments, info = ASR.transcribe(audio, beam_size=5, language=args.lan, condition_on_previous_text=False, word_timestamps=True) 21 | return segments 22 | 23 | def quantize(audio_path, downsample): 24 | feat, leng = encoder.encode(audio_path) 25 | ssl_units = apply_kmeans(feat) 26 | return [f"<|{p}|>" for p in ssl_units][::downsample] 27 | 28 | def combine(kms, segments): 29 | words = [] 30 | for segment in segments: 31 | for w in segment.words: 32 | words.append((w.word, int(w.start * TPS))) 33 | for i, (w, s) in enumerate(words): 34 | # print(w, s) 35 | kms.insert(i + s, ' ' + w) 36 | 37 | return ''.join(kms) 38 | 39 | if __name__ == '__main__': 40 | parser = ArgumentParser() 41 | parser.add_argument("--audio_dir", type=str, help="Audio dir") 42 | parser.add_argument("--ext", type=str, help="Wave format", default='wav') 43 | parser.add_argument("--downsample", type=int, help="Downsample ratio", default=2) 44 | parser.add_argument("--lan", type=str, help="Language code", default='zh') 45 | parser.add_argument("--device", type=str, default="cuda", help="Acceleration device") 46 | parser.add_argument("--km_model", type=str, default="./km_model.pt") 47 | parser.add_argument("--fp16", action="store_true", help="Data types for quantizing HuBERT features. 
Using flash_attention_2 (float16), which is faster, but sometimes results in different results") 48 | args = parser.parse_args() 49 | 50 | TPS = 50/args.downsample 51 | # Initialize Whisper model for transcribing 52 | ASR = WhisperModel("andybi7676/cool-whisper", device=args.device, compute_type="float16") 53 | 54 | # Initialize causal HuBERT and kmeans quantize module 55 | encoder = DiscreteHubertEncoder() 56 | apply_kmeans = ApplyKmeans(args.km_model, use_gpu=True) 57 | 58 | file_lists = list(glob.glob(os.path.join(args.audio_dir, f'**/*.{args.ext}'), recursive=True)) 59 | print(f"Generate Interleaving Data for {len(file_lists)} files.") 60 | 61 | for audio_path in tqdm(file_lists): 62 | # Transcribe given audio 63 | segments = transcribe(audio_path) 64 | # Quantize Causal HuBERT features 65 | kms = quantize(audio_path, args.downsample) 66 | # Generate interleaving sequence 67 | interleave = combine(kms, segments) 68 | # Output path 69 | output_path = os.path.splitext(audio_path)[0] + ".txt" 70 | # Dump results 71 | with open(output_path, 'w') as f: 72 | f.write(interleave + '\n') -------------------------------------------------------------------------------- /causal_hubert/causal_hubert.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import soundfile as sf 3 | from transformers import ( 4 | Wav2Vec2FeatureExtractor, 5 | HubertModel, 6 | ) 7 | from tqdm import tqdm 8 | import torch 9 | import numpy as np 10 | import librosa 11 | 12 | 13 | class DiscreteHubertEncoder(): 14 | def __init__(self, batch_size=16, device="cuda"): 15 | model_path = "TencentGameMate/chinese-hubert-base" 16 | 17 | self.batch_size = batch_size 18 | self.device = device 19 | self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_path) 20 | self.model = HubertModel.from_pretrained(model_path) 21 | self.model = self.model.to(self.device) 22 | self.model.eval() 23 | 24 | 25 | def batch_encode(self, file_list): 26 | feats, lens = [], [] 27 | for i in tqdm(range(0, len(file_list), self.batch_size)): 28 | start_idx = i 29 | end_idx = min(i+self.batch_size, len(file_list)) 30 | wavs = [] 31 | for j in range(start_idx, end_idx): 32 | wav, sr = sf.read(file_list[j]) 33 | wavs.append(torch.from_numpy(wav)) 34 | 35 | batch_feats, batch_lens = self._encode(wavs) 36 | 37 | btz = batch_feats.shape[0] 38 | feats.extend([batch_feats[j, :batch_lens[j], :].numpy() for j in range(btz)]) 39 | lens.extend(batch_lens) 40 | 41 | torch.cuda.empty_cache() 42 | 43 | # if len(feats) >= SHARD_SIZE: 44 | # torch.save({ 45 | # "feats": feats[:SHARD_SIZE], "lens": lens[:SHARD_SIZE] 46 | # }, f"km_data_new/yt-data-{shard_id}.pt") 47 | # shard_id += 1 48 | # feats = feats[SHARD_SIZE:] 49 | # lens = lens[SHARD_SIZE:] 50 | # 51 | # torch.save({ 52 | # "feats": feats, "lens": lens 53 | # }, f"km_data_new/yt-data-{shard_id}.pt") 54 | 55 | return feats, lens 56 | 57 | 58 | def encode(self, audio_input, sr=16000): 59 | """ 60 | Encode one audio into discrete Hubert units. 
61 | 62 | Parameters: 63 | audio_input(str, np.ndarray): can be path string or numpy array 64 | sr(int): sampling rate for audio_input 65 | """ 66 | 67 | if type(audio_input) == str: 68 | wav, sr = sf.read(audio_input) 69 | elif type(audio_input) == np.ndarray: 70 | if sr != 16000: 71 | audio_input = librosa.resample(audio_input, orig_sr=sr, target_sr=16000) 72 | wav = audio_input 73 | else: 74 | raise NotImplementedError 75 | 76 | feats, lens = self._encode([wav]) 77 | return feats[0], lens[0] 78 | 79 | 80 | def _encode(self, wavs): 81 | """ 82 | Encode list of audio into Hubert features 83 | 84 | Parameters: 85 | wavs: list of np.ndarray 86 | 87 | Returns: 88 | feats: list of torch.tensor, representing the L6 Hubert features 89 | lens: list of integer, representing the lengths of the features 90 | """ 91 | is_batch = (len(wavs) > 1) 92 | wavs = [torch.tensor(wav, dtype=torch.float32) for wav in wavs] 93 | max_len = max(wav.shape[0] for wav in wavs) 94 | wavs_padded = [F.pad(wav, (0, max_len - wav.shape[0])) for wav in wavs] 95 | wavs_padded = torch.vstack(wavs_padded).squeeze() 96 | 97 | input_values = self.feature_extractor(wavs_padded, return_tensors="pt", sampling_rate=16000).input_values 98 | input_values = input_values.to(self.device) 99 | if is_batch: 100 | input_values = input_values.squeeze() 101 | outputs = self.model(input_values, attention_mask=torch.ones(input_values.shape[0]).to(self.device), output_hidden_states=True) 102 | feats = outputs.hidden_states[6].detach().cpu() 103 | lens = [(l.shape[0]-80)//320 for l in wavs] 104 | 105 | return feats, lens 106 | 107 | -------------------------------------------------------------------------------- /modeling_hubert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
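# NOTE: This file is a patched copy of transformers' modeling_hubert.py; the
# README copies it over the installed library so that HuBERT features are
# extracted with a restricted, causal attention pattern. The key change is in
# HubertSdpaAttention.forward below: instead of the usual padding mask, it
# builds a sliding-window causal mask via create_causal_attention_mask(seq_len, 50 * 5),
# so each frame attends only to itself and the previous 249 frames (roughly
# 5 seconds of 50 Hz HuBERT features). For example, create_causal_attention_mask(4, 2)
# returns
#     [[  0., -inf, -inf, -inf],
#      [  0.,   0., -inf, -inf],
#      [-inf,   0.,   0., -inf],
#      [-inf, -inf,   0.,   0.]]
# where 0 means "may attend" and -inf is added to the attention logits to block
# that position before the softmax.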
15 | """PyTorch Hubert model.""" 16 | 17 | import warnings 18 | from typing import Optional, Tuple, Union 19 | 20 | import numpy as np 21 | import torch 22 | import torch.utils.checkpoint 23 | from torch import nn 24 | from torch.nn import CrossEntropyLoss, attention 25 | 26 | from ...activations import ACT2FN 27 | from ...integrations.deepspeed import is_deepspeed_zero3_enabled 28 | from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput 29 | from ...modeling_utils import PreTrainedModel 30 | from ...utils import ( 31 | add_code_sample_docstrings, 32 | add_start_docstrings, 33 | add_start_docstrings_to_model_forward, 34 | is_flash_attn_2_available, 35 | is_flash_attn_greater_or_equal_2_10, 36 | logging, 37 | replace_return_docstrings, 38 | ) 39 | from .configuration_hubert import HubertConfig 40 | 41 | 42 | if is_flash_attn_2_available(): 43 | from ...modeling_flash_attention_utils import _flash_attention_forward 44 | 45 | 46 | logger = logging.get_logger(__name__) 47 | 48 | _HIDDEN_STATES_START_POSITION = 1 49 | 50 | # General docstring 51 | _CONFIG_FOR_DOC = "HubertConfig" 52 | 53 | # Base docstring 54 | _CHECKPOINT_FOR_DOC = "facebook/hubert-large-ls960-ft" 55 | _EXPECTED_OUTPUT_SHAPE = [1, 292, 768] 56 | 57 | # CTC docstring 58 | _CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'" 59 | _CTC_EXPECTED_LOSS = 22.68 60 | 61 | # Audio class docstring 62 | _SEQ_CLASS_CHECKPOINT = "superb/hubert-base-superb-ks" 63 | _SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'" 64 | _SEQ_CLASS_EXPECTED_LOSS = 8.53 65 | 66 | 67 | # Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices 68 | def _compute_mask_indices( 69 | shape: Tuple[int, int], 70 | mask_prob: float, 71 | mask_length: int, 72 | attention_mask: Optional[torch.LongTensor] = None, 73 | min_masks: int = 0, 74 | ) -> np.ndarray: 75 | """ 76 | Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for 77 | ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on 78 | CPU as part of the preprocessing during training. 79 | 80 | Args: 81 | shape: The shape for which to compute masks. This should be of a tuple of size 2 where 82 | the first element is the batch size and the second element is the length of the axis to span. 83 | mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of 84 | independently generated mask spans of length `mask_length` is computed by 85 | `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the 86 | actual percentage will be smaller. 87 | mask_length: size of the mask 88 | min_masks: minimum number of masked spans 89 | attention_mask: A (right-padded) attention mask which independently shortens the feature axis of 90 | each batch dimension. 
91 | """ 92 | batch_size, sequence_length = shape 93 | 94 | if mask_length < 1: 95 | raise ValueError("`mask_length` has to be bigger than 0.") 96 | 97 | if mask_length > sequence_length: 98 | raise ValueError( 99 | f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" 100 | f" and `sequence_length`: {sequence_length}`" 101 | ) 102 | 103 | # epsilon is used for probabilistic rounding 104 | epsilon = np.random.rand(1).item() 105 | 106 | def compute_num_masked_span(input_length): 107 | """Given input length, compute how many spans should be masked""" 108 | num_masked_span = int(mask_prob * input_length / mask_length + epsilon) 109 | num_masked_span = max(num_masked_span, min_masks) 110 | 111 | # make sure num masked span <= sequence_length 112 | if num_masked_span * mask_length > sequence_length: 113 | num_masked_span = sequence_length // mask_length 114 | 115 | # make sure num_masked span is also <= input_length - (mask_length - 1) 116 | if input_length - (mask_length - 1) < num_masked_span: 117 | num_masked_span = max(input_length - (mask_length - 1), 0) 118 | 119 | return num_masked_span 120 | 121 | # compute number of masked spans in batch 122 | input_lengths = ( 123 | attention_mask.sum(-1).detach().tolist() 124 | if attention_mask is not None 125 | else [sequence_length for _ in range(batch_size)] 126 | ) 127 | 128 | # SpecAugment mask to fill 129 | spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) 130 | spec_aug_mask_idxs = [] 131 | 132 | max_num_masked_span = compute_num_masked_span(sequence_length) 133 | 134 | if max_num_masked_span == 0: 135 | return spec_aug_mask 136 | 137 | for input_length in input_lengths: 138 | # compute num of masked spans for this input 139 | num_masked_span = compute_num_masked_span(input_length) 140 | 141 | # get random indices to mask 142 | spec_aug_mask_idx = np.random.choice( 143 | np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False 144 | ) 145 | 146 | # pick first sampled index that will serve as a dummy index to pad vector 147 | # to ensure same dimension for all batches due to probabilistic rounding 148 | # Picking first sample just pads those vectors twice. 
149 | if len(spec_aug_mask_idx) == 0: 150 | # this case can only happen if `input_length` is strictly smaller then 151 | # `sequence_length` in which case the last token has to be a padding 152 | # token which we can use as a dummy mask id 153 | dummy_mask_idx = sequence_length - 1 154 | else: 155 | dummy_mask_idx = spec_aug_mask_idx[0] 156 | 157 | spec_aug_mask_idx = np.concatenate( 158 | [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx] 159 | ) 160 | spec_aug_mask_idxs.append(spec_aug_mask_idx) 161 | 162 | spec_aug_mask_idxs = np.array(spec_aug_mask_idxs) 163 | 164 | # expand masked indices to masked spans 165 | spec_aug_mask_idxs = np.broadcast_to( 166 | spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length) 167 | ) 168 | spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length) 169 | 170 | # add offset to the starting indexes so that indexes now create a span 171 | offsets = np.arange(mask_length)[None, None, :] 172 | offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape( 173 | batch_size, max_num_masked_span * mask_length 174 | ) 175 | spec_aug_mask_idxs = spec_aug_mask_idxs + offsets 176 | 177 | # ensure that we cannot have indices larger than sequence_length 178 | if spec_aug_mask_idxs.max() > sequence_length - 1: 179 | spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 180 | 181 | # scatter indices to mask 182 | np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) 183 | 184 | return spec_aug_mask 185 | 186 | 187 | # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->Hubert 188 | class HubertNoLayerNormConvLayer(nn.Module): 189 | def __init__(self, config, layer_id=0): 190 | super().__init__() 191 | self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 192 | self.out_conv_dim = config.conv_dim[layer_id] 193 | 194 | self.conv = nn.Conv1d( 195 | self.in_conv_dim, 196 | self.out_conv_dim, 197 | kernel_size=config.conv_kernel[layer_id], 198 | stride=config.conv_stride[layer_id], 199 | bias=config.conv_bias, 200 | ) 201 | self.activation = ACT2FN[config.feat_extract_activation] 202 | 203 | def forward(self, hidden_states): 204 | hidden_states = self.conv(hidden_states) 205 | hidden_states = self.activation(hidden_states) 206 | return hidden_states 207 | 208 | 209 | # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->Hubert 210 | class HubertLayerNormConvLayer(nn.Module): 211 | def __init__(self, config, layer_id=0): 212 | super().__init__() 213 | self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 214 | self.out_conv_dim = config.conv_dim[layer_id] 215 | 216 | self.conv = nn.Conv1d( 217 | self.in_conv_dim, 218 | self.out_conv_dim, 219 | kernel_size=config.conv_kernel[layer_id], 220 | stride=config.conv_stride[layer_id], 221 | bias=config.conv_bias, 222 | ) 223 | self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True) 224 | self.activation = ACT2FN[config.feat_extract_activation] 225 | 226 | def forward(self, hidden_states): 227 | hidden_states = self.conv(hidden_states) 228 | 229 | hidden_states = hidden_states.transpose(-2, -1) 230 | hidden_states = self.layer_norm(hidden_states) 231 | hidden_states = hidden_states.transpose(-2, -1) 232 | 233 | hidden_states = self.activation(hidden_states) 234 | return hidden_states 235 | 236 | 237 | # Copied 
from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->Hubert 238 | class HubertGroupNormConvLayer(nn.Module): 239 | def __init__(self, config, layer_id=0): 240 | super().__init__() 241 | self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 242 | self.out_conv_dim = config.conv_dim[layer_id] 243 | 244 | self.conv = nn.Conv1d( 245 | self.in_conv_dim, 246 | self.out_conv_dim, 247 | kernel_size=config.conv_kernel[layer_id], 248 | stride=config.conv_stride[layer_id], 249 | bias=config.conv_bias, 250 | ) 251 | self.activation = ACT2FN[config.feat_extract_activation] 252 | 253 | self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True) 254 | 255 | def forward(self, hidden_states): 256 | hidden_states = self.conv(hidden_states) 257 | hidden_states = self.layer_norm(hidden_states) 258 | hidden_states = self.activation(hidden_states) 259 | return hidden_states 260 | 261 | 262 | # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->Hubert 263 | class HubertPositionalConvEmbedding(nn.Module): 264 | def __init__(self, config): 265 | super().__init__() 266 | self.conv = nn.Conv1d( 267 | config.hidden_size, 268 | config.hidden_size, 269 | kernel_size=config.num_conv_pos_embeddings, 270 | padding=config.num_conv_pos_embeddings // 2, 271 | groups=config.num_conv_pos_embedding_groups, 272 | ) 273 | 274 | weight_norm = nn.utils.weight_norm 275 | if hasattr(nn.utils.parametrizations, "weight_norm"): 276 | weight_norm = nn.utils.parametrizations.weight_norm 277 | 278 | if is_deepspeed_zero3_enabled(): 279 | import deepspeed 280 | 281 | with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): 282 | self.conv = weight_norm(self.conv, name="weight", dim=2) 283 | if hasattr(self.conv, "parametrizations"): 284 | weight_g = self.conv.parametrizations.weight.original0 285 | weight_v = self.conv.parametrizations.weight.original1 286 | else: 287 | weight_g = self.conv.weight_g 288 | weight_v = self.conv.weight_v 289 | deepspeed.zero.register_external_parameter(self, weight_v) 290 | deepspeed.zero.register_external_parameter(self, weight_g) 291 | else: 292 | self.conv = weight_norm(self.conv, name="weight", dim=2) 293 | 294 | self.padding = HubertSamePadLayer(config.num_conv_pos_embeddings) 295 | self.activation = ACT2FN[config.feat_extract_activation] 296 | 297 | def forward(self, hidden_states): 298 | hidden_states = hidden_states.transpose(1, 2) 299 | 300 | hidden_states = self.conv(hidden_states) 301 | hidden_states = self.padding(hidden_states) 302 | hidden_states = self.activation(hidden_states) 303 | 304 | hidden_states = hidden_states.transpose(1, 2) 305 | return hidden_states 306 | 307 | 308 | # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->Hubert 309 | class HubertSamePadLayer(nn.Module): 310 | def __init__(self, num_conv_pos_embeddings): 311 | super().__init__() 312 | self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0 313 | 314 | def forward(self, hidden_states): 315 | if self.num_pad_remove > 0: 316 | hidden_states = hidden_states[:, :, : -self.num_pad_remove] 317 | return hidden_states 318 | 319 | 320 | # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->Hubert 321 | class HubertFeatureEncoder(nn.Module): 322 | """Construct the features from raw audio waveform""" 323 | 324 | def __init__(self, config): 325 | 
super().__init__() 326 | 327 | if config.feat_extract_norm == "group": 328 | conv_layers = [HubertGroupNormConvLayer(config, layer_id=0)] + [ 329 | HubertNoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1) 330 | ] 331 | elif config.feat_extract_norm == "layer": 332 | conv_layers = [HubertLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)] 333 | else: 334 | raise ValueError( 335 | f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']" 336 | ) 337 | self.conv_layers = nn.ModuleList(conv_layers) 338 | self.gradient_checkpointing = False 339 | self._requires_grad = True 340 | 341 | def _freeze_parameters(self): 342 | for param in self.parameters(): 343 | param.requires_grad = False 344 | self._requires_grad = False 345 | 346 | def forward(self, input_values): 347 | hidden_states = input_values[:, None] 348 | 349 | # make sure hidden_states require grad for gradient_checkpointing 350 | if self._requires_grad and self.training: 351 | hidden_states.requires_grad = True 352 | 353 | for conv_layer in self.conv_layers: 354 | if self._requires_grad and self.gradient_checkpointing and self.training: 355 | hidden_states = self._gradient_checkpointing_func( 356 | conv_layer.__call__, 357 | hidden_states, 358 | ) 359 | else: 360 | hidden_states = conv_layer(hidden_states) 361 | 362 | return hidden_states 363 | 364 | 365 | class HubertFeatureExtractor(HubertFeatureEncoder): 366 | def __init__(self, config): 367 | super().__init__(config) 368 | warnings.warn( 369 | f"The class `{self.__class__.__name__}` has been depreciated " 370 | "and will be removed in Transformers v5. " 371 | f"Use `{self.__class__.__bases__[0].__name__}` instead.", 372 | FutureWarning, 373 | ) 374 | 375 | 376 | class HubertFeatureProjection(nn.Module): 377 | def __init__(self, config): 378 | super().__init__() 379 | self.feat_proj_layer_norm = config.feat_proj_layer_norm 380 | if self.feat_proj_layer_norm: 381 | self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps) 382 | self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size) 383 | self.dropout = nn.Dropout(config.feat_proj_dropout) 384 | 385 | def forward(self, hidden_states): 386 | # non-projected hidden states are needed for quantization 387 | if self.feat_proj_layer_norm: 388 | hidden_states = self.layer_norm(hidden_states) 389 | hidden_states = self.projection(hidden_states) 390 | hidden_states = self.dropout(hidden_states) 391 | return hidden_states 392 | 393 | 394 | # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Hubert 395 | class HubertAttention(nn.Module): 396 | """Multi-headed attention from 'Attention Is All You Need' paper""" 397 | 398 | def __init__( 399 | self, 400 | embed_dim: int, 401 | num_heads: int, 402 | dropout: float = 0.0, 403 | is_decoder: bool = False, 404 | bias: bool = True, 405 | is_causal: bool = False, 406 | config: Optional[HubertConfig] = None, 407 | ): 408 | super().__init__() 409 | self.embed_dim = embed_dim 410 | self.num_heads = num_heads 411 | self.dropout = dropout 412 | self.head_dim = embed_dim // num_heads 413 | self.config = config 414 | 415 | if (self.head_dim * num_heads) != self.embed_dim: 416 | raise ValueError( 417 | f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" 418 | f" and `num_heads`: {num_heads})." 
419 | ) 420 | self.scaling = self.head_dim**-0.5 421 | self.is_decoder = is_decoder 422 | self.is_causal = is_causal 423 | 424 | self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 425 | self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 426 | self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 427 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 428 | 429 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 430 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 431 | 432 | def forward( 433 | self, 434 | hidden_states: torch.Tensor, 435 | key_value_states: Optional[torch.Tensor] = None, 436 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 437 | attention_mask: Optional[torch.Tensor] = None, 438 | layer_head_mask: Optional[torch.Tensor] = None, 439 | output_attentions: bool = False, 440 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 441 | """Input shape: Batch x Time x Channel""" 442 | 443 | # if key_value_states are provided this layer is used as a cross-attention layer 444 | # for the decoder 445 | is_cross_attention = key_value_states is not None 446 | 447 | bsz, tgt_len, _ = hidden_states.size() 448 | 449 | # get query proj 450 | query_states = self.q_proj(hidden_states) * self.scaling 451 | # get key, value proj 452 | # `past_key_value[0].shape[2] == key_value_states.shape[1]` 453 | # is checking that the `sequence_length` of the `past_key_value` is the same as 454 | # the provided `key_value_states` to support prefix tuning 455 | if ( 456 | is_cross_attention 457 | and past_key_value is not None 458 | and past_key_value[0].shape[2] == key_value_states.shape[1] 459 | ): 460 | # reuse k,v, cross_attentions 461 | key_states = past_key_value[0] 462 | value_states = past_key_value[1] 463 | elif is_cross_attention: 464 | # cross_attentions 465 | key_states = self._shape(self.k_proj(key_value_states), -1, bsz) 466 | value_states = self._shape(self.v_proj(key_value_states), -1, bsz) 467 | elif past_key_value is not None: 468 | # reuse k, v, self_attention 469 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 470 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 471 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 472 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 473 | else: 474 | # self_attention 475 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 476 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 477 | 478 | if self.is_decoder: 479 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 480 | # Further calls to cross_attention layer can then reuse all cross-attention 481 | # key/value_states (first "if" case) 482 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 483 | # all previous decoder key/value_states. 
Further calls to uni-directional self-attention 484 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 485 | # if encoder bi-directional self-attention `past_key_value` is always `None` 486 | past_key_value = (key_states, value_states) 487 | 488 | proj_shape = (bsz * self.num_heads, -1, self.head_dim) 489 | query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) 490 | key_states = key_states.reshape(*proj_shape) 491 | value_states = value_states.reshape(*proj_shape) 492 | 493 | src_len = key_states.size(1) 494 | attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) 495 | 496 | if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): 497 | raise ValueError( 498 | f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" 499 | f" {attn_weights.size()}" 500 | ) 501 | 502 | if attention_mask is not None: 503 | if attention_mask.size() != (bsz, 1, tgt_len, src_len): 504 | raise ValueError( 505 | f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" 506 | ) 507 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask 508 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 509 | 510 | attn_weights = nn.functional.softmax(attn_weights, dim=-1) 511 | 512 | if layer_head_mask is not None: 513 | if layer_head_mask.size() != (self.num_heads,): 514 | raise ValueError( 515 | f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" 516 | f" {layer_head_mask.size()}" 517 | ) 518 | attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 519 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 520 | 521 | if output_attentions: 522 | # this operation is a bit awkward, but it's required to 523 | # make sure that attn_weights keeps its gradient. 524 | # In order to do so, attn_weights have to be reshaped 525 | # twice and have to be reused in the following 526 | attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 527 | attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) 528 | else: 529 | attn_weights_reshaped = None 530 | 531 | attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) 532 | 533 | attn_output = torch.bmm(attn_probs, value_states) 534 | 535 | if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): 536 | raise ValueError( 537 | f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" 538 | f" {attn_output.size()}" 539 | ) 540 | 541 | attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) 542 | attn_output = attn_output.transpose(1, 2) 543 | 544 | # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be 545 | # partitioned across GPUs when using tensor-parallelism. 546 | attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) 547 | 548 | attn_output = self.out_proj(attn_output) 549 | 550 | return attn_output, attn_weights_reshaped, past_key_value 551 | 552 | 553 | # Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Hubert 554 | class HubertFlashAttention2(HubertAttention): 555 | """ 556 | Hubert flash attention module. This module inherits from `HubertAttention` as the weights of the module stays 557 | untouched. 
The only required change would be on the forward pass where it needs to correctly call the public API of 558 | flash attention and deal with padding tokens in case the input contains any of them. 559 | """ 560 | 561 | # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ 562 | def __init__(self, *args, **kwargs): 563 | super().__init__(*args, **kwargs) 564 | 565 | # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. 566 | # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. 567 | # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 568 | self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() 569 | 570 | def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 571 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) 572 | 573 | def forward( 574 | self, 575 | hidden_states: torch.Tensor, 576 | key_value_states: Optional[torch.Tensor] = None, 577 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 578 | attention_mask: Optional[torch.Tensor] = None, 579 | layer_head_mask: Optional[torch.Tensor] = None, 580 | output_attentions: bool = False, 581 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 582 | # HubertFlashAttention2 attention does not support output_attentions 583 | if output_attentions: 584 | raise ValueError("HubertFlashAttention2 attention does not support output_attentions") 585 | 586 | # if key_value_states are provided this layer is used as a cross-attention layer 587 | # for the decoder 588 | is_cross_attention = key_value_states is not None 589 | 590 | bsz, q_len, _ = hidden_states.size() 591 | 592 | # get query proj 593 | query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) 594 | # get key, value proj 595 | # `past_key_value[0].shape[2] == key_value_states.shape[1]` 596 | # is checking that the `sequence_length` of the `past_key_value` is the same as 597 | # the provided `key_value_states` to support prefix tuning 598 | if ( 599 | is_cross_attention 600 | and past_key_value is not None 601 | and past_key_value[0].shape[2] == key_value_states.shape[1] 602 | ): 603 | # reuse k,v, cross_attentions 604 | key_states = past_key_value[0].transpose(1, 2) 605 | value_states = past_key_value[1].transpose(1, 2) 606 | elif is_cross_attention: 607 | # cross_attentions 608 | key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) 609 | value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) 610 | elif past_key_value is not None: 611 | # reuse k, v, self_attention 612 | key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) 613 | value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) 614 | key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) 615 | value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) 616 | else: 617 | # self_attention 618 | key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) 619 | value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) 620 | 621 | if self.is_decoder: 622 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 
623 | # Further calls to cross_attention layer can then reuse all cross-attention 624 | # key/value_states (first "if" case) 625 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 626 | # all previous decoder key/value_states. Further calls to uni-directional self-attention 627 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 628 | # if encoder bi-directional self-attention `past_key_value` is always `None` 629 | past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) 630 | 631 | kv_seq_len = key_states.shape[-2] 632 | if past_key_value is not None: 633 | kv_seq_len += past_key_value[0].shape[-2] 634 | 635 | # In PEFT, usually we cast the layer norms in float32 for training stability reasons 636 | # therefore the input hidden states gets silently casted in float32. Hence, we need 637 | # cast them back in the correct dtype just to be sure everything works as expected. 638 | # This might slowdown training & inference so it is recommended to not cast the LayerNorms 639 | # in fp32. (LlamaRMSNorm handles it correctly) 640 | 641 | input_dtype = query_states.dtype 642 | if input_dtype == torch.float32: 643 | if torch.is_autocast_enabled(): 644 | target_dtype = torch.get_autocast_gpu_dtype() 645 | # Handle the case where the model is quantized 646 | elif hasattr(self.config, "_pre_quantization_dtype"): 647 | target_dtype = self.config._pre_quantization_dtype 648 | else: 649 | target_dtype = self.q_proj.weight.dtype 650 | 651 | logger.warning_once( 652 | f"The input hidden states seems to be silently casted in float32, this might be related to" 653 | f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" 654 | f" {target_dtype}." 
655 | ) 656 | 657 | query_states = query_states.to(target_dtype) 658 | key_states = key_states.to(target_dtype) 659 | value_states = value_states.to(target_dtype) 660 | 661 | attn_output = _flash_attention_forward( 662 | query_states, 663 | key_states, 664 | value_states, 665 | attention_mask, 666 | q_len, 667 | dropout=self.dropout, 668 | is_causal=self.is_causal, 669 | use_top_left_mask=self._flash_attn_uses_top_left_mask, 670 | ) 671 | 672 | attn_output = attn_output.reshape(bsz, q_len, -1) 673 | attn_output = self.out_proj(attn_output) 674 | 675 | if not output_attentions: 676 | attn_weights = None 677 | 678 | return attn_output, attn_weights, past_key_value 679 | 680 | 681 | def create_causal_attention_mask(seq_len, window_size): 682 | mask = torch.full([seq_len, seq_len], float("-inf")) 683 | 684 | for i in range(seq_len): 685 | start_index = max(0, i - window_size + 1) 686 | mask[i, start_index:i+1] = 0 687 | 688 | return mask 689 | 690 | # import pickle 691 | # import os 692 | # 693 | # if not os.path.exists("custom_masks.pkl"): 694 | # custom_masks = [None] + [create_causal_attention_mask(i, 35*5) for i in range(1, 2000)] 695 | # pickle.dump(custom_masks, open("custom_masks.pkl", "wb")) 696 | # else: 697 | # custom_masks = pickle.load(open("custom_masks.pkl", "rb")) 698 | 699 | custom_masks = [None] * 4000 700 | 701 | 702 | class HubertSdpaAttention(HubertAttention): 703 | # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->Hubert 704 | def forward( 705 | self, 706 | hidden_states: torch.Tensor, 707 | key_value_states: Optional[torch.Tensor] = None, 708 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 709 | attention_mask: Optional[torch.Tensor] = None, 710 | layer_head_mask: Optional[torch.Tensor] = None, 711 | output_attentions: bool = False, 712 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 713 | """Input shape: Batch x Time x Channel""" 714 | if output_attentions or layer_head_mask is not None: 715 | # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. 716 | logger.warning_once( 717 | "HubertModel is using HubertSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" 718 | ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
719 | ) 720 | return super().forward( 721 | hidden_states, 722 | key_value_states=key_value_states, 723 | past_key_value=past_key_value, 724 | attention_mask=attention_mask, 725 | layer_head_mask=layer_head_mask, 726 | output_attentions=output_attentions, 727 | ) 728 | 729 | # if key_value_states are provided this layer is used as a cross-attention layer 730 | # for the decoder 731 | is_cross_attention = key_value_states is not None 732 | 733 | bsz, tgt_len, _ = hidden_states.size() 734 | 735 | # get query proj 736 | query_states = self.q_proj(hidden_states) 737 | # get key, value proj 738 | # `past_key_value[0].shape[2] == key_value_states.shape[1]` 739 | # is checking that the `sequence_length` of the `past_key_value` is the same as 740 | # the provided `key_value_states` to support prefix tuning 741 | if ( 742 | is_cross_attention 743 | and past_key_value is not None 744 | and past_key_value[0].shape[2] == key_value_states.shape[1] 745 | ): 746 | # reuse k,v, cross_attentions 747 | key_states = past_key_value[0] 748 | value_states = past_key_value[1] 749 | elif is_cross_attention: 750 | # cross_attentions 751 | key_states = self._shape(self.k_proj(key_value_states), -1, bsz) 752 | value_states = self._shape(self.v_proj(key_value_states), -1, bsz) 753 | elif past_key_value is not None: 754 | # reuse k, v, self_attention 755 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 756 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 757 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 758 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 759 | else: 760 | # self_attention 761 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 762 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 763 | 764 | if self.is_decoder: 765 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 766 | # Further calls to cross_attention layer can then reuse all cross-attention 767 | # key/value_states (first "if" case) 768 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 769 | # all previous decoder key/value_states. Further calls to uni-directional self-attention 770 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 771 | # if encoder bi-directional self-attention `past_key_value` is always `None` 772 | past_key_value = (key_states, value_states) 773 | 774 | query_states = self._shape(query_states, tgt_len, bsz) 775 | 776 | # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment 777 | # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 778 | # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. 779 | is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False 780 | 781 | # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, 782 | # but we are fine here as `_shape` do call `.contiguous()`. 
Reference: https://github.com/pytorch/pytorch/issues/112577 783 | # print("Use custom attention") 784 | btz = attention_mask.shape[0] 785 | seq_len = attention_mask.shape[-1] 786 | if custom_masks[seq_len] is None: 787 | custom_masks[seq_len] = create_causal_attention_mask(seq_len, 50*5) 788 | # masks = create_causal_attention_mask(seq_len, 35 * 5).to("cuda") 789 | 790 | from copy import deepcopy 791 | mask = deepcopy(custom_masks[seq_len]).to("cuda") 792 | masks = mask.repeat(btz, 1).reshape(attention_mask.shape) 793 | 794 | attn_output = torch.nn.functional.scaled_dot_product_attention( 795 | query_states, 796 | key_states, 797 | value_states, 798 | attn_mask=masks, 799 | dropout_p=self.dropout if self.training else 0.0, 800 | is_causal=is_causal, 801 | ) 802 | 803 | if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): 804 | raise ValueError( 805 | f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" 806 | f" {attn_output.size()}" 807 | ) 808 | 809 | attn_output = attn_output.transpose(1, 2) 810 | 811 | # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be 812 | # partitioned across GPUs when using tensor-parallelism. 813 | attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) 814 | 815 | attn_output = self.out_proj(attn_output) 816 | 817 | return attn_output, None, past_key_value 818 | 819 | 820 | HUBERT_ATTENTION_CLASSES = { 821 | "eager": HubertAttention, 822 | "sdpa": HubertSdpaAttention, 823 | "flash_attention_2": HubertFlashAttention2, 824 | } 825 | 826 | 827 | # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->Hubert 828 | class HubertFeedForward(nn.Module): 829 | def __init__(self, config): 830 | super().__init__() 831 | self.intermediate_dropout = nn.Dropout(config.activation_dropout) 832 | 833 | self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size) 834 | if isinstance(config.hidden_act, str): 835 | self.intermediate_act_fn = ACT2FN[config.hidden_act] 836 | else: 837 | self.intermediate_act_fn = config.hidden_act 838 | 839 | self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size) 840 | self.output_dropout = nn.Dropout(config.hidden_dropout) 841 | 842 | def forward(self, hidden_states): 843 | hidden_states = self.intermediate_dense(hidden_states) 844 | hidden_states = self.intermediate_act_fn(hidden_states) 845 | hidden_states = self.intermediate_dropout(hidden_states) 846 | 847 | hidden_states = self.output_dense(hidden_states) 848 | hidden_states = self.output_dropout(hidden_states) 849 | return hidden_states 850 | 851 | 852 | # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayer with Wav2Vec2->Hubert, WAV2VEC2->HUBERT 853 | class HubertEncoderLayer(nn.Module): 854 | def __init__(self, config): 855 | super().__init__() 856 | self.attention = HUBERT_ATTENTION_CLASSES[config._attn_implementation]( 857 | embed_dim=config.hidden_size, 858 | num_heads=config.num_attention_heads, 859 | dropout=config.attention_dropout, 860 | is_decoder=False, 861 | ) 862 | 863 | self.dropout = nn.Dropout(config.hidden_dropout) 864 | self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 865 | self.feed_forward = HubertFeedForward(config) 866 | self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 867 | 868 | def forward(self, hidden_states, attention_mask=None, output_attentions=False): 869 | 
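        # Post-LayerNorm encoder block: self-attention with a residual connection,
        # LayerNorm, then a feed-forward sub-layer with a residual connection and a
        # final LayerNorm.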
attn_residual = hidden_states 870 | hidden_states, attn_weights, _ = self.attention( 871 | hidden_states, attention_mask=attention_mask, output_attentions=output_attentions 872 | ) 873 | hidden_states = self.dropout(hidden_states) 874 | hidden_states = attn_residual + hidden_states 875 | 876 | hidden_states = self.layer_norm(hidden_states) 877 | hidden_states = hidden_states + self.feed_forward(hidden_states) 878 | hidden_states = self.final_layer_norm(hidden_states) 879 | 880 | outputs = (hidden_states,) 881 | 882 | if output_attentions: 883 | outputs += (attn_weights,) 884 | 885 | return outputs 886 | 887 | 888 | # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AttnAdapterLayer with Wav2Vec2->Hubert 889 | class HubertAttnAdapterLayer(nn.Module): 890 | def __init__(self, config): 891 | """ 892 | Implements adapter modules directly with 3D tensor weight as parameters and without using ModuleList to speed 893 | up training throughput. 894 | """ 895 | super().__init__() 896 | self.input_dim = config.adapter_attn_dim 897 | self.hidden_dim = config.hidden_size 898 | 899 | self.norm = nn.LayerNorm(self.hidden_dim) 900 | self.linear_1 = nn.Linear(self.hidden_dim, self.input_dim) 901 | self.act_fn = nn.ReLU() 902 | self.linear_2 = nn.Linear(self.input_dim, self.hidden_dim) 903 | 904 | def forward(self, hidden_states: torch.FloatTensor): 905 | hidden_states = self.norm(hidden_states) 906 | 907 | hidden_states = self.linear_1(hidden_states) 908 | hidden_states = self.act_fn(hidden_states) 909 | hidden_states = self.linear_2(hidden_states) 910 | 911 | return hidden_states 912 | 913 | 914 | # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayerStableLayerNorm with Wav2Vec2->Hubert, WAV2VEC2->HUBERT 915 | class HubertEncoderLayerStableLayerNorm(nn.Module): 916 | def __init__(self, config): 917 | super().__init__() 918 | self.attention = HUBERT_ATTENTION_CLASSES[config._attn_implementation]( 919 | embed_dim=config.hidden_size, 920 | num_heads=config.num_attention_heads, 921 | dropout=config.attention_dropout, 922 | is_decoder=False, 923 | ) 924 | self.dropout = nn.Dropout(config.hidden_dropout) 925 | self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 926 | self.feed_forward = HubertFeedForward(config) 927 | self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 928 | 929 | if getattr(config, "adapter_attn_dim", None) is not None: 930 | self.adapter_layer = HubertAttnAdapterLayer(config) 931 | else: 932 | self.adapter_layer = None 933 | 934 | def forward( 935 | self, 936 | hidden_states: torch.Tensor, 937 | attention_mask: Optional[torch.Tensor] = None, 938 | output_attentions: bool = False, 939 | ): 940 | attn_residual = hidden_states 941 | hidden_states = self.layer_norm(hidden_states) 942 | hidden_states, attn_weights, _ = self.attention( 943 | hidden_states, attention_mask=attention_mask, output_attentions=output_attentions 944 | ) 945 | hidden_states = self.dropout(hidden_states) 946 | hidden_states = attn_residual + hidden_states 947 | hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states)) 948 | 949 | if self.adapter_layer is not None: 950 | hidden_states = hidden_states + self.adapter_layer(hidden_states) 951 | 952 | outputs = (hidden_states,) 953 | 954 | if output_attentions: 955 | outputs += (attn_weights,) 956 | 957 | return outputs 958 | 959 | 960 | # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Encoder with Wav2Vec2->Hubert 961 | class 
HubertEncoder(nn.Module): 962 | def __init__(self, config): 963 | super().__init__() 964 | self.config = config 965 | self.pos_conv_embed = HubertPositionalConvEmbedding(config) 966 | self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 967 | self.dropout = nn.Dropout(config.hidden_dropout) 968 | self.layers = nn.ModuleList([HubertEncoderLayer(config) for _ in range(config.num_hidden_layers)]) 969 | self.gradient_checkpointing = False 970 | self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" 971 | 972 | def forward( 973 | self, 974 | hidden_states: torch.tensor, 975 | attention_mask: Optional[torch.Tensor] = None, 976 | output_attentions: bool = False, 977 | output_hidden_states: bool = False, 978 | return_dict: bool = True, 979 | ): 980 | all_hidden_states = () if output_hidden_states else None 981 | all_self_attentions = () if output_attentions else None 982 | 983 | if attention_mask is not None: 984 | # make sure padded tokens output 0 985 | expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) 986 | hidden_states[~expand_attention_mask] = 0 987 | if self._use_flash_attention_2: 988 | # 2d mask is passed through the layers 989 | attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None 990 | else: 991 | # extend attention_mask 992 | attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) 993 | attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min 994 | attention_mask = attention_mask.expand( 995 | attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] 996 | ) 997 | 998 | position_embeddings = self.pos_conv_embed(hidden_states) 999 | hidden_states = hidden_states + position_embeddings 1000 | hidden_states = self.layer_norm(hidden_states) 1001 | hidden_states = self.dropout(hidden_states) 1002 | 1003 | deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() 1004 | 1005 | for layer in self.layers: 1006 | if output_hidden_states: 1007 | all_hidden_states = all_hidden_states + (hidden_states,) 1008 | 1009 | # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) 1010 | dropout_probability = torch.rand([]) 1011 | 1012 | skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False 1013 | if not skip_the_layer or deepspeed_zero3_is_enabled: 1014 | # under deepspeed zero3 all gpus must run in sync 1015 | if self.gradient_checkpointing and self.training: 1016 | layer_outputs = self._gradient_checkpointing_func( 1017 | layer.__call__, 1018 | hidden_states, 1019 | attention_mask, 1020 | output_attentions, 1021 | ) 1022 | else: 1023 | layer_outputs = layer( 1024 | hidden_states, attention_mask=attention_mask, output_attentions=output_attentions 1025 | ) 1026 | hidden_states = layer_outputs[0] 1027 | 1028 | if skip_the_layer: 1029 | layer_outputs = (None, None) 1030 | 1031 | if output_attentions: 1032 | all_self_attentions = all_self_attentions + (layer_outputs[1],) 1033 | 1034 | if output_hidden_states: 1035 | all_hidden_states = all_hidden_states + (hidden_states,) 1036 | 1037 | if not return_dict: 1038 | return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) 1039 | return BaseModelOutput( 1040 | last_hidden_state=hidden_states, 1041 | hidden_states=all_hidden_states, 1042 | attentions=all_self_attentions, 1043 | ) 1044 | 1045 | 1046 | # Copied from 
transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderStableLayerNorm with Wav2Vec2->Hubert 1047 | class HubertEncoderStableLayerNorm(nn.Module): 1048 | def __init__(self, config): 1049 | super().__init__() 1050 | self.config = config 1051 | self.pos_conv_embed = HubertPositionalConvEmbedding(config) 1052 | self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 1053 | self.dropout = nn.Dropout(config.hidden_dropout) 1054 | self.layers = nn.ModuleList( 1055 | [HubertEncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)] 1056 | ) 1057 | self.gradient_checkpointing = False 1058 | self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" 1059 | 1060 | def forward( 1061 | self, 1062 | hidden_states, 1063 | attention_mask=None, 1064 | output_attentions=False, 1065 | output_hidden_states=False, 1066 | return_dict=True, 1067 | ): 1068 | all_hidden_states = () if output_hidden_states else None 1069 | all_self_attentions = () if output_attentions else None 1070 | 1071 | if attention_mask is not None: 1072 | # make sure padded tokens are not attended to 1073 | expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) 1074 | hidden_states[~expand_attention_mask] = 0 1075 | if self._use_flash_attention_2: 1076 | # 2d mask is passed through the layers 1077 | attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None 1078 | else: 1079 | # extend attention_mask 1080 | attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) 1081 | attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min 1082 | attention_mask = attention_mask.expand( 1083 | attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] 1084 | ) 1085 | 1086 | position_embeddings = self.pos_conv_embed(hidden_states) 1087 | hidden_states = hidden_states + position_embeddings 1088 | hidden_states = self.dropout(hidden_states) 1089 | 1090 | deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() 1091 | 1092 | for layer in self.layers: 1093 | if output_hidden_states: 1094 | all_hidden_states = all_hidden_states + (hidden_states,) 1095 | 1096 | # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) 1097 | dropout_probability = torch.rand([]) 1098 | 1099 | skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False 1100 | if not skip_the_layer or deepspeed_zero3_is_enabled: 1101 | # under deepspeed zero3 all gpus must run in sync 1102 | # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication 1103 | if self.gradient_checkpointing and self.training: 1104 | layer_outputs = self._gradient_checkpointing_func( 1105 | layer.__call__, 1106 | hidden_states, 1107 | attention_mask, 1108 | output_attentions, 1109 | ) 1110 | else: 1111 | layer_outputs = layer( 1112 | hidden_states, attention_mask=attention_mask, output_attentions=output_attentions 1113 | ) 1114 | hidden_states = layer_outputs[0] 1115 | 1116 | if skip_the_layer: 1117 | layer_outputs = (None, None) 1118 | 1119 | if output_attentions: 1120 | all_self_attentions = all_self_attentions + (layer_outputs[1],) 1121 | 1122 | hidden_states = self.layer_norm(hidden_states) 1123 | 1124 | if output_hidden_states: 1125 | all_hidden_states = all_hidden_states + (hidden_states,) 1126 | 1127 | if not return_dict: 1128 | return tuple(v for v in [hidden_states, all_hidden_states, 
all_self_attentions] if v is not None) 1129 | return BaseModelOutput( 1130 | last_hidden_state=hidden_states, 1131 | hidden_states=all_hidden_states, 1132 | attentions=all_self_attentions, 1133 | ) 1134 | 1135 | 1136 | class HubertPreTrainedModel(PreTrainedModel): 1137 | """ 1138 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 1139 | models. 1140 | """ 1141 | 1142 | config_class = HubertConfig 1143 | base_model_prefix = "hubert" 1144 | main_input_name = "input_values" 1145 | supports_gradient_checkpointing = True 1146 | _supports_flash_attn_2 = True 1147 | _supports_sdpa = True 1148 | 1149 | def _init_weights(self, module): 1150 | """Initialize the weights""" 1151 | if isinstance(module, nn.Linear): 1152 | # Slightly different from the TF version which uses truncated_normal for initialization 1153 | # cf https://github.com/pytorch/pytorch/pull/5617 1154 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 1155 | elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): 1156 | module.bias.data.zero_() 1157 | module.weight.data.fill_(1.0) 1158 | elif isinstance(module, nn.Conv1d): 1159 | if is_deepspeed_zero3_enabled(): 1160 | import deepspeed 1161 | 1162 | if hasattr(module, "weight_v") and hasattr(module, "weight_g"): 1163 | with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0): 1164 | nn.init.kaiming_normal_(module.weight.data) 1165 | else: 1166 | with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0): 1167 | nn.init.kaiming_normal_(module.weight.data) 1168 | else: 1169 | nn.init.kaiming_normal_(module.weight.data) 1170 | 1171 | if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: 1172 | module.bias.data.zero_() 1173 | 1174 | def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): 1175 | """ 1176 | Computes the output length of the convolutional layers 1177 | """ 1178 | 1179 | def _conv_out_length(input_length, kernel_size, stride): 1180 | # 1D convolutional layer output length formula taken 1181 | # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html 1182 | return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1 1183 | 1184 | for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): 1185 | input_lengths = _conv_out_length(input_lengths, kernel_size, stride) 1186 | 1187 | return input_lengths 1188 | 1189 | def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor): 1190 | output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) 1191 | batch_size = attention_mask.shape[0] 1192 | 1193 | attention_mask = torch.zeros( 1194 | (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device 1195 | ) 1196 | # these two operations makes sure that all values before the output lengths idxs are attended to 1197 | attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1 1198 | attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() 1199 | return attention_mask 1200 | 1201 | 1202 | HUBERT_START_DOCSTRING = r""" 1203 | Hubert was proposed in [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden 1204 | Units](https://arxiv.org/abs/2106.07447) by Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia, 1205 | Ruslan 
Salakhutdinov, Abdelrahman Mohamed. 1206 | 1207 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the 1208 | library implements for all its model (such as downloading or saving etc.). 1209 | 1210 | This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use 1211 | it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and 1212 | behavior. 1213 | 1214 | Parameters: 1215 | config ([`HubertConfig`]): Model configuration class with all the parameters of the model. 1216 | Initializing with a config file does not load the weights associated with the model, only the 1217 | configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 1218 | """ 1219 | 1220 | 1221 | HUBERT_INPUTS_DOCSTRING = r""" 1222 | Args: 1223 | input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): 1224 | Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file 1225 | into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install 1226 | soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and 1227 | conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details. 1228 | attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1229 | Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0, 1230 | 1]`: 1231 | 1232 | - 1 for tokens that are **not masked**, 1233 | - 0 for tokens that are **masked**. 1234 | 1235 | [What are attention masks?](../glossary#attention-mask) 1236 | 1237 | 1238 | 1239 | `attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask == 1240 | True`. For all models whose processor has `config.return_attention_mask == False`, such as 1241 | [hubert-base](https://huggingface.co/facebook/hubert-base-ls960), `attention_mask` should **not** be passed 1242 | to avoid degraded performance when doing batched inference. For such models `input_values` should simply be 1243 | padded with 0 and passed without `attention_mask`. Be aware that these models also yield slightly different 1244 | results depending on whether `input_values` is padded or not. 1245 | 1246 | 1247 | 1248 | output_attentions (`bool`, *optional*): 1249 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 1250 | tensors for more detail. 1251 | output_hidden_states (`bool`, *optional*): 1252 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 1253 | more detail. 1254 | return_dict (`bool`, *optional*): 1255 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
1256 | """ 1257 | 1258 | 1259 | @add_start_docstrings( 1260 | "The bare Hubert Model transformer outputting raw hidden-states without any specific head on top.", 1261 | HUBERT_START_DOCSTRING, 1262 | ) 1263 | class HubertModel(HubertPreTrainedModel): 1264 | def __init__(self, config: HubertConfig): 1265 | super().__init__(config) 1266 | self.config = config 1267 | self.feature_extractor = HubertFeatureEncoder(config) 1268 | self.feature_projection = HubertFeatureProjection(config) 1269 | 1270 | if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0: 1271 | self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_()) 1272 | 1273 | if config.do_stable_layer_norm: 1274 | self.encoder = HubertEncoderStableLayerNorm(config) 1275 | else: 1276 | self.encoder = HubertEncoder(config) 1277 | 1278 | # Initialize weights and apply final processing 1279 | self.post_init() 1280 | 1281 | # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states 1282 | def _mask_hidden_states( 1283 | self, 1284 | hidden_states: torch.FloatTensor, 1285 | mask_time_indices: Optional[torch.FloatTensor] = None, 1286 | attention_mask: Optional[torch.LongTensor] = None, 1287 | ): 1288 | """ 1289 | Masks extracted features along time axis and/or along feature axis according to 1290 | [SpecAugment](https://arxiv.org/abs/1904.08779). 1291 | """ 1292 | 1293 | # `config.apply_spec_augment` can set masking to False 1294 | if not getattr(self.config, "apply_spec_augment", True): 1295 | return hidden_states 1296 | 1297 | # generate indices & apply SpecAugment along time axis 1298 | batch_size, sequence_length, hidden_size = hidden_states.size() 1299 | 1300 | if mask_time_indices is not None: 1301 | # apply SpecAugment along time axis with given mask_time_indices 1302 | hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) 1303 | elif self.config.mask_time_prob > 0 and self.training: 1304 | mask_time_indices = _compute_mask_indices( 1305 | (batch_size, sequence_length), 1306 | mask_prob=self.config.mask_time_prob, 1307 | mask_length=self.config.mask_time_length, 1308 | attention_mask=attention_mask, 1309 | min_masks=self.config.mask_time_min_masks, 1310 | ) 1311 | mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) 1312 | hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) 1313 | 1314 | if self.config.mask_feature_prob > 0 and self.training: 1315 | # generate indices & apply SpecAugment along feature axis 1316 | mask_feature_indices = _compute_mask_indices( 1317 | (batch_size, hidden_size), 1318 | mask_prob=self.config.mask_feature_prob, 1319 | mask_length=self.config.mask_feature_length, 1320 | min_masks=self.config.mask_feature_min_masks, 1321 | ) 1322 | mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool) 1323 | mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1) 1324 | hidden_states[mask_feature_indices] = 0 1325 | 1326 | return hidden_states 1327 | 1328 | @add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING) 1329 | @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) 1330 | def forward( 1331 | self, 1332 | input_values: Optional[torch.Tensor], 1333 | attention_mask: Optional[torch.Tensor] = None, 1334 | mask_time_indices: Optional[torch.FloatTensor] = None, 1335 | output_attentions: Optional[bool] = None, 1336 | output_hidden_states: 
Optional[bool] = None, 1337 | return_dict: Optional[bool] = None, 1338 | ) -> Union[Tuple, BaseModelOutput]: 1339 | """ 1340 | 1341 | Returns: 1342 | 1343 | Example: 1344 | 1345 | ```python 1346 | >>> from transformers import AutoProcessor, HubertModel 1347 | >>> from datasets import load_dataset 1348 | >>> import soundfile as sf 1349 | 1350 | >>> processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft") 1351 | >>> model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft") 1352 | 1353 | 1354 | >>> def map_to_array(batch): 1355 | ... speech, _ = sf.read(batch["file"]) 1356 | ... batch["speech"] = speech 1357 | ... return batch 1358 | 1359 | 1360 | >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") 1361 | >>> ds = ds.map(map_to_array) 1362 | 1363 | >>> input_values = processor(ds["speech"][0], return_tensors="pt").input_values # Batch size 1 1364 | >>> hidden_states = model(input_values).last_hidden_state 1365 | ```""" 1366 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 1367 | output_hidden_states = ( 1368 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1369 | ) 1370 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1371 | 1372 | extract_features = self.feature_extractor(input_values) 1373 | extract_features = extract_features.transpose(1, 2) 1374 | 1375 | if attention_mask is not None: 1376 | # compute reduced attention_mask corresponding to feature vectors 1377 | attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask) 1378 | 1379 | hidden_states = self.feature_projection(extract_features) 1380 | hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices) 1381 | 1382 | encoder_outputs = self.encoder( 1383 | hidden_states, 1384 | attention_mask=attention_mask, 1385 | output_attentions=output_attentions, 1386 | output_hidden_states=output_hidden_states, 1387 | return_dict=return_dict, 1388 | ) 1389 | 1390 | hidden_states = encoder_outputs[0] 1391 | 1392 | if not return_dict: 1393 | return (hidden_states,) + encoder_outputs[1:] 1394 | 1395 | return BaseModelOutput( 1396 | last_hidden_state=hidden_states, 1397 | hidden_states=encoder_outputs.hidden_states, 1398 | attentions=encoder_outputs.attentions, 1399 | ) 1400 | 1401 | 1402 | @add_start_docstrings( 1403 | """Hubert Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""", 1404 | HUBERT_START_DOCSTRING, 1405 | ) 1406 | # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->Hubert, wav2vec2->hubert, WAV_2_VEC_2->HUBERT 1407 | class HubertForCTC(HubertPreTrainedModel): 1408 | def __init__(self, config, target_lang: Optional[str] = None): 1409 | super().__init__(config) 1410 | 1411 | self.hubert = HubertModel(config) 1412 | self.dropout = nn.Dropout(config.final_dropout) 1413 | 1414 | self.target_lang = target_lang 1415 | 1416 | if config.vocab_size is None: 1417 | raise ValueError( 1418 | f"You are trying to instantiate {self.__class__} with a configuration that " 1419 | "does not define the vocabulary size of the language model head. Please " 1420 | "instantiate the model as follows: `HubertForCTC.from_pretrained(..., vocab_size=vocab_size)`. " 1421 | "or define `vocab_size` of your model's configuration." 
1422 | ) 1423 | output_hidden_size = ( 1424 | config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size 1425 | ) 1426 | self.lm_head = nn.Linear(output_hidden_size, config.vocab_size) 1427 | 1428 | # Initialize weights and apply final processing 1429 | self.post_init() 1430 | 1431 | def tie_weights(self): 1432 | """ 1433 | This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when 1434 | passing `target_lang=...` to `from_pretrained(...)`. 1435 | 1436 | This method is **not** supposed to be called by the user and is prone to be changed in the future. 1437 | """ 1438 | 1439 | # Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to 1440 | # correctly load adapter layers for Hubert so that we do not have to introduce a new API to 1441 | # [`PreTrainedModel`]. While slightly hacky, Hubert never has to tie input and output embeddings, so that it is 1442 | # ok to repurpose this function here. 1443 | target_lang = self.target_lang 1444 | 1445 | if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None: 1446 | raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.") 1447 | elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None: 1448 | logger.info("By default `target_lang` is set to 'eng'.") 1449 | elif target_lang is not None: 1450 | self.load_adapter(target_lang, force_load=True) 1451 | 1452 | def freeze_feature_extractor(self): 1453 | """ 1454 | Calling this function will disable the gradient computation for the feature encoder so that its parameter will 1455 | not be updated during training. 1456 | """ 1457 | warnings.warn( 1458 | "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. " 1459 | "Please use the equivalent `freeze_feature_encoder` method instead.", 1460 | FutureWarning, 1461 | ) 1462 | self.freeze_feature_encoder() 1463 | 1464 | def freeze_feature_encoder(self): 1465 | """ 1466 | Calling this function will disable the gradient computation for the feature encoder so that its parameter will 1467 | not be updated during training. 1468 | """ 1469 | self.hubert.feature_extractor._freeze_parameters() 1470 | 1471 | def freeze_base_model(self): 1472 | """ 1473 | Calling this function will disable the gradient computation for the base model so that its parameters will not 1474 | be updated during training. Only the classification head will be updated. 1475 | """ 1476 | for param in self.hubert.parameters(): 1477 | param.requires_grad = False 1478 | 1479 | @add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING) 1480 | @add_code_sample_docstrings( 1481 | checkpoint=_CHECKPOINT_FOR_DOC, 1482 | output_type=CausalLMOutput, 1483 | config_class=_CONFIG_FOR_DOC, 1484 | expected_output=_CTC_EXPECTED_OUTPUT, 1485 | expected_loss=_CTC_EXPECTED_LOSS, 1486 | ) 1487 | def forward( 1488 | self, 1489 | input_values: Optional[torch.Tensor], 1490 | attention_mask: Optional[torch.Tensor] = None, 1491 | output_attentions: Optional[bool] = None, 1492 | output_hidden_states: Optional[bool] = None, 1493 | return_dict: Optional[bool] = None, 1494 | labels: Optional[torch.Tensor] = None, 1495 | ) -> Union[Tuple, CausalLMOutput]: 1496 | r""" 1497 | labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*): 1498 | Labels for connectionist temporal classification. 
Note that `target_length` has to be smaller or equal to 1499 | the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. 1500 | All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., 1501 | config.vocab_size - 1]`. 1502 | """ 1503 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1504 | 1505 | if labels is not None and labels.max() >= self.config.vocab_size: 1506 | raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}") 1507 | 1508 | outputs = self.hubert( 1509 | input_values, 1510 | attention_mask=attention_mask, 1511 | output_attentions=output_attentions, 1512 | output_hidden_states=output_hidden_states, 1513 | return_dict=return_dict, 1514 | ) 1515 | 1516 | hidden_states = outputs[0] 1517 | hidden_states = self.dropout(hidden_states) 1518 | 1519 | logits = self.lm_head(hidden_states) 1520 | 1521 | loss = None 1522 | if labels is not None: 1523 | # retrieve loss input_lengths from attention_mask 1524 | attention_mask = ( 1525 | attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long) 1526 | ) 1527 | input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) 1528 | 1529 | # assuming that padded tokens are filled with -100 1530 | # when not being attended to 1531 | labels_mask = labels >= 0 1532 | target_lengths = labels_mask.sum(-1) 1533 | flattened_targets = labels.masked_select(labels_mask) 1534 | 1535 | # ctc_loss doesn't support fp16 1536 | log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1) 1537 | 1538 | with torch.backends.cudnn.flags(enabled=False): 1539 | loss = nn.functional.ctc_loss( 1540 | log_probs, 1541 | flattened_targets, 1542 | input_lengths, 1543 | target_lengths, 1544 | blank=self.config.pad_token_id, 1545 | reduction=self.config.ctc_loss_reduction, 1546 | zero_infinity=self.config.ctc_zero_infinity, 1547 | ) 1548 | 1549 | if not return_dict: 1550 | output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] 1551 | return ((loss,) + output) if loss is not None else output 1552 | 1553 | return CausalLMOutput( 1554 | loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions 1555 | ) 1556 | 1557 | 1558 | @add_start_docstrings( 1559 | """ 1560 | Hubert Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like 1561 | SUPERB Keyword Spotting. 
1562 | """, 1563 | HUBERT_START_DOCSTRING, 1564 | ) 1565 | # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->Hubert, wav2vec2->hubert, WAV_2_VEC_2->HUBERT 1566 | class HubertForSequenceClassification(HubertPreTrainedModel): 1567 | def __init__(self, config): 1568 | super().__init__(config) 1569 | 1570 | if hasattr(config, "add_adapter") and config.add_adapter: 1571 | raise ValueError( 1572 | "Sequence classification does not support the use of Hubert adapters (config.add_adapter=True)" 1573 | ) 1574 | self.hubert = HubertModel(config) 1575 | num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings 1576 | if config.use_weighted_layer_sum: 1577 | self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers) 1578 | self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size) 1579 | self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels) 1580 | 1581 | # Initialize weights and apply final processing 1582 | self.post_init() 1583 | 1584 | def freeze_feature_extractor(self): 1585 | """ 1586 | Calling this function will disable the gradient computation for the feature encoder so that its parameters will 1587 | not be updated during training. 1588 | """ 1589 | warnings.warn( 1590 | "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. " 1591 | "Please use the equivalent `freeze_feature_encoder` method instead.", 1592 | FutureWarning, 1593 | ) 1594 | self.freeze_feature_encoder() 1595 | 1596 | def freeze_feature_encoder(self): 1597 | """ 1598 | Calling this function will disable the gradient computation for the feature encoder so that its parameter will 1599 | not be updated during training. 1600 | """ 1601 | self.hubert.feature_extractor._freeze_parameters() 1602 | 1603 | def freeze_base_model(self): 1604 | """ 1605 | Calling this function will disable the gradient computation for the base model so that its parameters will not 1606 | be updated during training. Only the classification head will be updated. 1607 | """ 1608 | for param in self.hubert.parameters(): 1609 | param.requires_grad = False 1610 | 1611 | @add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING) 1612 | @add_code_sample_docstrings( 1613 | checkpoint=_SEQ_CLASS_CHECKPOINT, 1614 | output_type=SequenceClassifierOutput, 1615 | config_class=_CONFIG_FOR_DOC, 1616 | modality="audio", 1617 | expected_output=_SEQ_CLASS_EXPECTED_OUTPUT, 1618 | expected_loss=_SEQ_CLASS_EXPECTED_LOSS, 1619 | ) 1620 | def forward( 1621 | self, 1622 | input_values: Optional[torch.Tensor], 1623 | attention_mask: Optional[torch.Tensor] = None, 1624 | output_attentions: Optional[bool] = None, 1625 | output_hidden_states: Optional[bool] = None, 1626 | return_dict: Optional[bool] = None, 1627 | labels: Optional[torch.Tensor] = None, 1628 | ) -> Union[Tuple, SequenceClassifierOutput]: 1629 | r""" 1630 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1631 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 1632 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 1633 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
1634 | """ 1635 | 1636 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1637 | output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states 1638 | 1639 | outputs = self.hubert( 1640 | input_values, 1641 | attention_mask=attention_mask, 1642 | output_attentions=output_attentions, 1643 | output_hidden_states=output_hidden_states, 1644 | return_dict=return_dict, 1645 | ) 1646 | 1647 | if self.config.use_weighted_layer_sum: 1648 | hidden_states = outputs[_HIDDEN_STATES_START_POSITION] 1649 | hidden_states = torch.stack(hidden_states, dim=1) 1650 | norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) 1651 | hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) 1652 | else: 1653 | hidden_states = outputs[0] 1654 | 1655 | hidden_states = self.projector(hidden_states) 1656 | if attention_mask is None: 1657 | pooled_output = hidden_states.mean(dim=1) 1658 | else: 1659 | padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask) 1660 | hidden_states[~padding_mask] = 0.0 1661 | pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1) 1662 | 1663 | logits = self.classifier(pooled_output) 1664 | 1665 | loss = None 1666 | if labels is not None: 1667 | loss_fct = CrossEntropyLoss() 1668 | loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) 1669 | 1670 | if not return_dict: 1671 | output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] 1672 | return ((loss,) + output) if loss is not None else output 1673 | 1674 | return SequenceClassifierOutput( 1675 | loss=loss, 1676 | logits=logits, 1677 | hidden_states=outputs.hidden_states, 1678 | attentions=outputs.attentions, 1679 | ) 1680 | --------------------------------------------------------------------------------