'
102 |
103 | with torch.no_grad():
104 | if audio_tensor is None and audio_path is not None:  # fall back to loading audio from disk when no tensor is passed
105 | audio_tensor, sr = torchaudio.load(audio_path)
106 | audio_tensor = self.audio_processor(audio_tensor.squeeze(), return_tensors="pt", sampling_rate=16000).input_values  # assumes 16 kHz audio
107 |
108 | pre_tokenized_ids = self.llm_tokenizer(pre_speech_prompt, padding="do_not_pad", return_tensors='pt', truncation=False, add_special_tokens=False)["input_ids"]
109 | post_tokenized_ids = self.llm_tokenizer(post_speech_prompt, padding="do_not_pad", return_tensors='pt', truncation=False, add_special_tokens=False)["input_ids"]
110 | output_tokenized_ids = self.llm_tokenizer(output_prompt, padding="do_not_pad", return_tensors='pt', truncation=False, add_special_tokens=False)["input_ids"]
111 |
112 | combined_embeds, atts, label_ids = self.encode(audio_tensor.to(device), pre_tokenized_ids.to(device), post_tokenized_ids.to(device), output_tokenized_ids.to(device))
113 |
114 | out = self.llm_model.generate(
115 | inputs_embeds=combined_embeds,
116 | max_new_tokens=max_new_tokens,
117 | pad_token_id=self.llm_tokenizer.pad_token_id
118 | ).cpu().tolist()[0]
119 |
120 | output_text = self.llm_tokenizer.decode(out, skip_special_tokens=True)
121 | return output_text
122 |
123 |
--------------------------------------------------------------------------------
/huggingface/push_to_hub.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": []
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "# Load model and push to Huggingface"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "from hf_repo.config import SpeechLLMModelConfig\n",
24 | "from hf_repo.model import SpeechLLMModel\n",
25 | "import torch\n",
26 | "\n",
27 | "conf = SpeechLLMModelConfig()\n",
28 | "model = SpeechLLMModel(conf)"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": null,
34 | "metadata": {},
35 | "outputs": [],
36 | "source": [
37 | "model.load_state_dict(torch.load('/root/ml-research/multi-modal-llm/repo/paper_exp/checkpoints/pth/torch_model_best_checkpoints.pth')) "
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": null,
43 | "metadata": {},
44 | "outputs": [],
45 | "source": [
46 | "conf.register_for_auto_class()\n",
47 | "model.register_for_auto_class(\"AutoModel\")"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": null,
53 | "metadata": {},
54 | "outputs": [],
55 | "source": [
56 | "conf.push_to_hub('skit-ai/speechllm-1.5B', commit_message=\"checkpoint update\")\n",
57 | "model.push_to_hub('skit-ai/speechllm-1.5B', commit_message=\"checkpoint update\")"
58 | ]
59 | },
60 | {
61 | "cell_type": "markdown",
62 | "metadata": {},
63 | "source": [
64 | "# Infer after push"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": null,
70 | "metadata": {},
71 | "outputs": [],
72 | "source": [
73 | "# # Load model directly\n",
74 | "from transformers import AutoModel\n",
75 | "model = AutoModel.from_pretrained(\"skit-ai/speechllm-1.5B\", trust_remote_code=True)"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": null,
81 | "metadata": {},
82 | "outputs": [],
83 | "source": [
84 | "model = model.to(\"cuda\").eval()"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": null,
90 | "metadata": {},
91 | "outputs": [],
92 | "source": [
93 | "%%time\n",
94 | "model.generate_meta(\n",
95 | " \"/root/ml-research/multi-modal-llm/datadir/data/LibriSpeech/dev-other/1255/90407/1255-90407-0004.flac\", \n",
96 | " instruction=\"Give me the [SpeechActivity, Transcript, Gender, Age, Accent, Emotion] of the audio.\",\n",
97 | " )"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": null,
103 | "metadata": {},
104 | "outputs": [],
105 | "source": [
106 | "import torchaudio\n",
107 | "\n",
108 | "audio_tensor, rate = torchaudio.load(\"/root/ml-research/multi-modal-llm/datadir/data/LibriSpeech/dev-other/1255/90407/1255-90407-0004.flac\")\n",
109 | "model.generate_meta(\n",
110 | " audio_tensor=audio_tensor, \n",
111 | " instruction=\"Give me the [SpeechActivity, Transcript, Gender, Age, Accen, Emotion] of the audio.\",\n",
112 | " )"
113 | ]
114 | },
115 | {
116 | "cell_type": "code",
117 | "execution_count": null,
118 | "metadata": {},
119 | "outputs": [],
120 | "source": [
121 | "import IPython\n",
122 | "IPython.display.Audio(\"/root/ml-research/multi-modal-llm/datadir/data/LibriSpeech/dev-other/1255/90407/1255-90407-0004.flac\")"
123 | ]
124 | }
125 | ],
126 | "metadata": {
127 | "kernelspec": {
128 | "display_name": "mm_trainer",
129 | "language": "python",
130 | "name": "python3"
131 | },
132 | "language_info": {
133 | "codemirror_mode": {
134 | "name": "ipython",
135 | "version": 3
136 | },
137 | "file_extension": ".py",
138 | "mimetype": "text/x-python",
139 | "name": "python",
140 | "nbconvert_exporter": "python",
141 | "pygments_lexer": "ipython3",
142 | "version": "3.11.3"
143 | }
144 | },
145 | "nbformat": 4,
146 | "nbformat_minor": 2
147 | }
148 |
--------------------------------------------------------------------------------
/huggingface/save_checkpoint.py:
--------------------------------------------------------------------------------
1 | from trainer import SpeechLLMLightning
2 | import os
3 | import torch
4 |
5 | ckpt_path = "best_checkpoint.ckpt"
6 | model = SpeechLLMLightning.load_from_checkpoint(ckpt_path)
7 | model = model.eval()
8 |
9 | # Directory to save the models
10 | save_dir = "checkpoints/pth"
11 | os.makedirs(save_dir, exist_ok=True)
12 |
13 | model.llm_model = model.llm_model.merge_and_unload()
14 | torch.save(model.state_dict(), os.path.join(save_dir, 'torch_model_best_checkpoints.pth'))
15 | print("Model saved successfully.")
16 |
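17 | # Reload sketch (mirrors huggingface/push_to_hub.ipynb), assuming the hf_repo package
18 | # that defines SpeechLLMModelConfig / SpeechLLMModel is importable:
19 | #   model = SpeechLLMModel(SpeechLLMModelConfig())
20 | #   model.load_state_dict(torch.load("checkpoints/pth/torch_model_best_checkpoints.pth"))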
--------------------------------------------------------------------------------
/model/connector.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | def get_connector(name, audio_enc_dim, llm_dim, k):
6 | if name == 'linear-pool':
7 | return LinearPoolConnector(audio_enc_dim, llm_dim, k)
8 | elif name == 'linear':
9 | return LinearConnector(audio_enc_dim, llm_dim, k)
10 | elif name == 'cnn':
11 | return CNNConnector(audio_enc_dim, llm_dim, k)
12 | else:
13 | raise NotImplementedError
14 |
15 | class LinearConnector(nn.Module):
16 | def __init__(self, in_dim, out_dim, k):
17 | super().__init__()
18 | self.layer = nn.Linear(in_dim, out_dim)
19 | self.pool = nn.AvgPool1d(kernel_size=k, stride=k)
20 |
21 | def forward(self, x):
22 | x = self.layer(x)
23 | x = x.transpose(1, 2)
24 | x = self.pool(x)
25 | x = x.transpose(1, 2)
26 | return x
27 |
28 |
29 | class LinearPoolConnector(nn.Module):
30 | def __init__(self, input_dim, output_dim, k):
31 | super(LinearPoolConnector, self).__init__()
32 | self.linear1 = nn.Sequential(
33 | nn.Linear(input_dim, output_dim),
34 | nn.ReLU())
35 | self.pool = nn.AvgPool1d(kernel_size=k, stride=k)
36 | self.linear2 = nn.Sequential(
37 | nn.Linear(output_dim, output_dim),
38 | nn.ReLU(),
39 | nn.Linear(output_dim, output_dim))
40 |
41 | def forward(self, x):
42 | # x: [B, T, d]
43 | x = self.linear1(x) # x: [B, T, D]
44 | x = x.transpose(1, 2) # x: [B, D, T]
45 | x = self.pool(x) # x: [B, D, T']
46 | x = x.transpose(1, 2) # x: [B, T', D]
47 | x = self.linear2(x)
48 | return x
49 |
50 | class CNNConnector(nn.Module):
51 | def __init__(self, in_channels, out_channels, k):
52 | super().__init__()
53 | self.layer = nn.Sequential(
54 | nn.ReLU(),
55 | nn.Conv1d(in_channels, out_channels//2, kernel_size=5,
56 | stride=1, padding=0),
57 | nn.ReLU(),
58 | nn.Conv1d(out_channels//2, out_channels, kernel_size=5,
59 | stride=k, padding=0),
60 | nn.ReLU(),
61 | nn.Conv1d(out_channels, out_channels, kernel_size=5,
62 | stride=1, padding=0),
63 | )
64 |
65 | def forward(self, x):
66 | return self.layer(x.transpose(1,2)).transpose(1,2)
67 |
68 |
69 |
70 | if __name__ == "__main__":
71 | model = CNNConnector(128, 256, k=2)  # k sets the temporal downsampling stride
72 | x = torch.randn(4, 50, 128)
73 | z = model(x)
74 | print(z.shape)
--------------------------------------------------------------------------------
/model/encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from transformers import AutoModel
4 | from speechtokenizer import SpeechTokenizer
5 |
6 | def get_audio_encoder(name, finetune_encoder):
7 | if name == "facebook/hubert-xlarge-ll60k":
8 | return TransformerAudioEnoder(model_name='facebook/hubert-xlarge-ll60k', finetune=finetune_encoder)
9 | elif name == "microsoft/wavlm-large":
10 | return TransformerAudioEnoder(model_name='microsoft/wavlm-large', finetune=finetune_encoder)
11 | elif name == "openai/whisper-small":
12 | return WhisperAudioEncoder(finetune=finetune_encoder)
13 | elif name == 'speech-tokenizer':
14 | return SpeechTokenizerEnoder(finetune=finetune_encoder)
15 | elif name == 'audio-clip':
16 | return AudioCLIPEncoder(finetune=finetune_encoder)
17 | else:
18 | raise NotImplementedError
19 |
20 | class TransformerAudioEnoder(nn.Module):
21 | def __init__(self, model_name='facebook/hubert-xlarge-ll60k', finetune=False):
22 | super().__init__()
23 | self.encoder = AutoModel.from_pretrained(model_name)
24 | for param in self.encoder.parameters():
25 | param.requires_grad = finetune
26 |
27 | for param in self.encoder.encoder.layers[-15:].parameters():
28 | param.requires_grad = True
29 |
30 | def forward(self, x):
31 | return self.encoder(x).last_hidden_state
32 |
33 |
34 | if __name__ == "__main__":
35 | model = TransformerAudioEnoder(finetune=False)
36 | # print(model)
37 |
38 | x = torch.randn(2, 16000)  # (batch, samples) raw 16 kHz waveform
39 | z = model(x)
40 | print(z.shape)
--------------------------------------------------------------------------------
/model/llm.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModelForCausalLM, AutoTokenizer
2 | from peft import LoraConfig, get_peft_model, PeftModel
3 |
4 | def get_llm(name, use_lora, lora_r, lora_alpha):
5 | llm_tokenizer = AutoTokenizer.from_pretrained(name)
6 | llm_model = AutoModelForCausalLM.from_pretrained(
7 | name,
8 | trust_remote_code=True,
9 | )
10 |
11 | if use_lora:
12 | peft_config = LoraConfig(
13 | r=lora_r,
14 | lora_alpha=lora_alpha,
15 | target_modules="all-linear",
16 | lora_dropout=0.05,
17 | task_type="CAUSAL_LM",
18 | )
19 |
20 | llm_model = get_peft_model(llm_model, peft_config)
21 | llm_model.print_trainable_parameters()
22 |
23 | return llm_tokenizer, llm_model
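24 |
25 |
26 | if __name__ == "__main__":
27 | # Minimal usage sketch: load the TinyLlama checkpoint used elsewhere in this repo
28 | # and wrap it with a LoRA adapter (assumes the model can be fetched from the HF Hub).
29 | llm_tokenizer, llm_model = get_llm("TinyLlama/TinyLlama-1.1B-Chat-v1.0", use_lora=True, lora_r=8, lora_alpha=16)
30 | print(type(llm_model).__name__, len(llm_tokenizer))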
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.0.1
2 | wandb==0.15.3
3 | transformers==4.41.2
4 | torchaudio==2.0.2
5 | tokenizers==0.19.1
6 | pytorch-lightning==1.9.4
7 | peft==0.9.0
8 | librosa==0.10.1
9 | jiwer==3.0.3
10 | huggingface-hub==0.23.0
11 | datasets==2.2.1
12 | accelerate==0.30.0
13 | streamlit==1.34.0
14 | audio_recorder_streamlit==0.0.8
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | from pytorch_lightning import Trainer
2 | from pytorch_lightning.loggers import WandbLogger
3 | from trainer import SpeechLLMLightning
4 | from dataset import InstructionalAudioDataset
5 |
6 | import torch.utils.data as data_utils
7 | from dataset import InstructionalAudioDataset, MyCollator
8 |
9 | if __name__ == "__main__":
10 |
11 | model_config = {
12 | 'audio_enc_dim': 1024,
13 | 'llm_dim': 2048,
14 | 'audio_encoder_name': "microsoft/wavlm-large", #"facebook/hubert-xlarge-ll60k",
15 | 'connector_name': 'cnn',
16 | 'llm_name': "TinyLlama/TinyLlama-1.1B-Chat-v1.0", #"google/gemma-2b-it", #"TinyLlama/TinyLlama-1.1B-Chat-v1.0", #"microsoft/phi-2",
17 | 'finetune_encoder': False,
18 | 'connector_k': 2,
19 | 'use_lora': True,
20 | 'lora_r': 8,
21 | 'lora_alpha': 16,
22 | 'max_lr': 3e-4,
23 | 'total_training_step': 1000000,
24 | 'warmup_steps': 100,
25 | 'train_batch_per_epoch': 10000,
26 | 'grad_accumulate_steps': 8
27 | }
28 |
29 | model = SpeechLLMLightning.load_from_checkpoint("./path-to-checkpoint-dir/best_checkpoint.ckpt")
30 | tokenizer = model.llm_tokenizer
31 |
32 | test_dataset = InstructionalAudioDataset(
33 | csv_file='./path-to-test-dir/librispeech-test-clean.csv', # same format as train.csv and dev.csv
34 | mode='test'
35 | )
36 |
37 | my_collator = MyCollator(model_config['audio_encoder_name'], tokenizer)
38 | test_loader = data_utils.DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=my_collator, num_workers=3)
39 |
40 | trainer = Trainer(
41 | accelerator='gpu', devices=1
42 | )
43 | trainer.test(model=model, dataloaders=test_loader)
44 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | from pytorch_lightning import Trainer
4 | from pytorch_lightning.loggers import WandbLogger
5 | from trainer import SpeechLLMLightning
6 | from dataset import InstructionalAudioDataset, MyCollator
7 | from pytorch_lightning.strategies import DDPStrategy
8 |
9 | import torch.utils.data as data_utils
10 | from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
11 | import wandb
12 |
13 | if __name__ == "__main__":
14 | log_path = 'WavLM-CNN-tinyllama-run1'
15 | wandb.init(project="mmllm", name=log_path)
16 | logger = WandbLogger(project="mmllm", name=log_path)
17 |
18 | model_config = {
19 | 'audio_enc_dim': 1024,
20 | 'llm_dim': 2048,
21 | 'audio_encoder_name': "microsoft/wavlm-large",
22 | 'connector_name': 'cnn',
23 | 'llm_name': "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
24 | 'finetune_encoder': False,
25 | 'connector_k': 2,
26 | 'use_lora': True,
27 | 'lora_r': 8,
28 | 'lora_alpha': 16,
29 | 'max_lr': 1e-4,
30 | 'total_training_step': 10000000,
31 | 'warmup_steps': 100,
32 | 'train_batch_per_epoch': 10000,
33 | 'grad_accumulate_steps': 8
34 | }
35 |
36 | model = SpeechLLMLightning(**model_config)
37 | tokenizer = model.llm_tokenizer
38 |
39 | train_dataset = InstructionalAudioDataset(
40 | csv_file = './data_samples/train.csv',
41 | mode='train',
42 | random_keys_prob=0.2,
43 | )
44 |
45 | val_dataset = InstructionalAudioDataset(
46 | csv_file='./data_samples/dev.csv',
47 | mode='test'
48 | )
49 |
50 | print(len(train_dataset), len(val_dataset))
51 |
52 | my_collator = MyCollator(model_config['audio_encoder_name'], tokenizer)
53 | train_loader = data_utils.DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=my_collator, num_workers=3)
54 | val_loader = data_utils.DataLoader(val_dataset, batch_size=1, shuffle=False, collate_fn=my_collator, num_workers=3)
55 |
56 | checkpoint_callback = ModelCheckpoint(dirpath="checkpoints", filename=log_path+'-{epoch}', save_top_k=1, monitor="val/loss", save_last=True)
57 | early_stop_callback = EarlyStopping(monitor="val/loss", min_delta=0.00, patience=10, verbose=False, mode="min")
58 |
59 | trainer = Trainer(
60 | max_epochs=model_config['total_training_step']//model_config['train_batch_per_epoch'], gpus=2,
61 | strategy=DDPStrategy(find_unused_parameters=True),
62 | limit_train_batches=model_config['train_batch_per_epoch'],
63 | limit_val_batches=model_config['train_batch_per_epoch'],
64 | log_every_n_steps=model_config['train_batch_per_epoch'],
65 | enable_checkpointing=True,
66 | callbacks=[checkpoint_callback],
67 | fast_dev_run=False, logger=logger,
68 | accumulate_grad_batches=model_config['grad_accumulate_steps'],
69 | resume_from_checkpoint=None
70 | )
71 | trainer.fit(model, train_loader, val_loader)
72 |
73 |
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.optim import Adam
4 | from torch.optim.lr_scheduler import LambdaLR
5 |
6 | import wandb
7 | import pytorch_lightning as pl
8 | import numpy as np
9 | from jiwer import wer
10 | import torchmetrics
11 | import random
12 | import re
13 | import json
14 |
15 | from model.encoder import get_audio_encoder, TransformerAudioEnoder
16 | from model.connector import get_connector, LinearConnector, LinearPoolConnector, CNNConnector
17 | from model.llm import get_llm
18 |
19 | class SpeechLLMLightning(pl.LightningModule):
20 | def __init__(self,
21 | audio_enc_dim=512,
22 | llm_dim=2048,
23 | audio_encoder_name="speech-tokenizer",
24 | connector_name='linear-pool',
25 | llm_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
26 | finetune_encoder=False,
27 | connector_k=5,
28 | use_lora=True,
29 | lora_r=32,
30 | lora_alpha=2,
31 | max_lr=3e-4,
32 | total_training_step=500000,
33 | warmup_steps=1000,
34 | **kwargs
35 | ):
36 | super().__init__()
37 | self.save_hyperparameters()
38 |
39 | self.audio_enc_dim = audio_enc_dim
40 | self.llm_dim = llm_dim
41 | self.llm_name = llm_name
42 | self.finetune_encoder = finetune_encoder
43 | self.use_lora = use_lora
44 |
45 | self.audio_encoder = get_audio_encoder(audio_encoder_name, finetune_encoder)
46 | self.connector = get_connector(connector_name, audio_enc_dim, llm_dim, connector_k)
47 | self.llm_tokenizer, self.llm_model = get_llm(llm_name, use_lora, lora_r, lora_alpha)
48 |
49 | self.max_lr = max_lr
50 | self.total_training_step = total_training_step
51 | self.warmup_steps = warmup_steps
52 | self.use_embedding_loss = False
53 | self.num_validation_samples = 5000
54 |
55 | def configure_optimizers(self):
56 | opt = [
57 | {"params": self.audio_encoder.parameters(), "lr": 1e-5},
58 | {"params": self.connector.parameters(), "lr": self.max_lr},
59 | {"params": self.llm_model.parameters(), "lr": self.max_lr},
60 | ]
61 | optimizer = Adam(opt, lr=self.max_lr)
62 | return optimizer
63 |
64 | def encode(self, mel, pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids, return_embedding_loss=False):
65 | batch_size = mel.shape[0]
66 |
67 | speech_embeds = self.audio_encoder(mel)
68 | speech_embeds = self.connector(speech_embeds)
69 |
70 | embedder = self.llm_model.model.model.embed_tokens
71 | pre_prompt_embeds = embedder(pre_tokenized_ids)
72 | post_prompt_embeds = embedder(post_tokenized_ids)
73 | output_prompt_embeds = embedder(output_tokenized_ids)
74 |
75 | combined_embeds = torch.cat([pre_prompt_embeds, speech_embeds, post_prompt_embeds, output_prompt_embeds], dim=1)
76 | atts = torch.ones(combined_embeds.size()[:-1], dtype=torch.long).to(combined_embeds.device)
77 |
78 | input_token_length = pre_tokenized_ids.shape[1] + speech_embeds.shape[1] + post_tokenized_ids.shape[1]
79 | label_ids = torch.cat([
80 | torch.ones([batch_size, input_token_length], device=combined_embeds.device)*-100,
81 | output_tokenized_ids
82 | ], 1).to(combined_embeds.device).to(torch.int64)
83 | return combined_embeds, atts, label_ids
84 |
85 | def forward(self, embeds, atts, label_ids):
86 | out = self.llm_model(
87 | inputs_embeds=embeds,
88 | attention_mask=atts,
89 | labels=label_ids,
90 | )
91 | return out
92 |
93 | def training_step(self, batch, batch_idx):
94 | mel, pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids = batch
95 | embeds, atts, label_ids = self.encode(mel, pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids)
96 | outputs = self.forward(embeds, atts, label_ids)
97 | loss = outputs["loss"]
98 | self.log("train/loss", loss, on_epoch=False)
99 | return loss
100 |
101 | def validation_step(self, batch, batch_idx):
102 | mel, pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids = batch
103 | embeds, atts, label_ids = self.encode(mel, pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids)
104 | outputs = self.forward(embeds, atts, label_ids)
105 | loss = outputs["loss"]
106 | self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
107 |
108 | logits = outputs.logits
109 | predicted_ids = torch.argmax(logits, dim=-1).cpu()
110 |
111 | generated_output_text = self.llm_tokenizer.decode(predicted_ids[0], skip_special_tokens=False)
112 | target_text = self.llm_tokenizer.decode(output_tokenized_ids[0], skip_special_tokens=False)
113 |
114 | extracted_pred = self.extract_prediction_values(generated_output_text)
115 | extracted_target = self.extract_prediction_values(target_text)
116 |
117 | keys = extracted_target.keys()
118 | pred_keys = extracted_pred.keys()
119 |
120 | for key in keys:
121 | if key not in pred_keys:
122 | extracted_pred[key] = "NA"
123 |
124 | if 'Transcript' in keys:
125 | target_transcript = extracted_target['Transcript']
126 | predicted_transcript = extracted_pred['Transcript']
127 | wer_metric = wer(target_transcript.lower(), predicted_transcript.lower())
128 | self.log("val/wer", wer_metric, on_step=False, on_epoch=True, prog_bar=True, logger=True)
129 |
130 | if 'Response' in keys:
131 | target_transcript = extracted_target['Response']
132 | predicted_transcript = extracted_pred['Response']
133 | wer_metric = wer(target_transcript.lower(), predicted_transcript.lower())
134 | self.log("val/response_wer", wer_metric, on_step=False, on_epoch=True, prog_bar=True, logger=True)
135 |
136 | if 'SpeechActivity' in keys:
137 | target_isspeech = extracted_target['SpeechActivity']
138 | predicted_isspeech = extracted_pred['SpeechActivity']
139 | self.log("val/speech_activity", float(target_isspeech.lower()==predicted_isspeech.lower()), on_step=False, on_epoch=True, prog_bar=True, logger=True)
140 |
141 | if 'Gender' in keys:
142 | target_gender = extracted_target['Gender']
143 | predicted_gender = extracted_pred['Gender']
144 | self.log("val/gender", float(target_gender.lower()==predicted_gender.lower()), on_step=False, on_epoch=True, prog_bar=True, logger=True)
145 |
146 | if 'Emotion' in keys:
147 | target_emotion = extracted_target['Emotion']
148 | predicted_emotion = extracted_pred['Emotion']
149 | self.log("val/emotion", float(target_emotion.lower()==predicted_emotion.lower()), on_step=False, on_epoch=True, prog_bar=True, logger=True)
150 |
151 | if 'Age' in keys:
152 | target_age = extracted_target['Age']
153 | predicted_age = extracted_pred['Age']
154 | self.log("val/age", float(target_age.lower()==predicted_age.lower()), on_step=False, on_epoch=True, prog_bar=True, logger=True)
155 |
156 | if 'Accent' in keys:
157 | target_accent = extracted_target['Accent']
158 | predicted_accent = extracted_pred['Accent']
159 | self.log("val/accent", float(target_accent.lower()==predicted_accent.lower()), on_step=False, on_epoch=True, prog_bar=True, logger=True)
160 |
161 | if batch_idx in self.selected_samples_for_logging:
162 | sample_idx = self.selected_samples_for_logging.index(batch_idx)
163 | # Use wandb.log to log prediction and truth texts
164 | wandb.log({
165 | f"val_sample_{sample_idx}_pred": wandb.Html(f"{str(extracted_pred)}<br>"),
166 | f"val_sample_{sample_idx}_target": wandb.Html(f"{str(target_text).replace('<s>', '').replace('</s>', '')}<br>"),
167 | f"val_sample_{sample_idx}_gen": wandb.Html(f"{generated_output_text.replace('<s>', '').replace('</s>', '')}<br>"),
168 | }, commit=False)
169 |
170 | return {"val_loss": loss}
171 |
172 | def test_step(self, batch, batch_idx):
173 | mel, pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids = batch
174 | embeds, atts, label_ids = self.encode(mel, pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids)
175 | outputs = self.forward(embeds, atts, label_ids)
176 | loss = outputs["loss"]
177 | self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
178 |
179 | logits = outputs.logits
180 | predicted_ids = torch.argmax(logits, dim=-1)
181 |
182 | input_token_length = output_tokenized_ids.shape[1]
183 | generated_output_text = self.llm_tokenizer.decode(predicted_ids[0], skip_special_tokens=False)
184 | target_text = self.llm_tokenizer.decode(output_tokenized_ids[0], skip_special_tokens=False)
185 |
186 | extracted_pred = self.extract_prediction_values(generated_output_text)
187 | extracted_target = self.extract_prediction_values(target_text)
188 |
189 | keys = extracted_target.keys()
190 | pred_keys = extracted_pred.keys()
191 |
192 | for key in keys:
193 | if key not in pred_keys:
194 | extracted_pred[key] = "NA"
195 |
196 | if 'Transcript' in keys:
197 | target_transcript = extracted_target['Transcript']
198 | predicted_transcript = extracted_pred['Transcript']
199 | wer_metric = wer(target_transcript.lower(), predicted_transcript.lower())
200 | self.log("val/wer", wer_metric, on_step=False, on_epoch=True, prog_bar=True, logger=True)
201 |
202 | if 'Response' in keys:
203 | target_transcript = extracted_target['Response']
204 | predicted_transcript = extracted_pred['Response']
205 | wer_metric = wer(target_transcript.lower(), predicted_transcript.lower())
206 | self.log("val/response_wer", wer_metric, on_step=False, on_epoch=True, prog_bar=True, logger=True)
207 |
208 | if 'SpeechActivity' in keys:
209 | target_isspeech = extracted_target['SpeechActivity']
210 | predicted_isspeech = extracted_pred['SpeechActivity']
211 | self.log("val/speech_activity", float(target_isspeech.lower()==predicted_isspeech.lower()), on_step=False, on_epoch=True, prog_bar=True, logger=True)
212 |
213 | if 'Gender' in keys:
214 | target_gender = extracted_target['Gender']
215 | predicted_gender = extracted_pred['Gender']
216 | self.log("val/gender", float(target_gender.lower()==predicted_gender.lower()), on_step=False, on_epoch=True, prog_bar=True, logger=True)
217 |
218 | if 'Emotion' in keys:
219 | target_emotion = extracted_target['Emotion']
220 | predicted_emotion = extracted_pred['Emotion']
221 | self.log("val/emotion", float(target_emotion.lower()==predicted_emotion.lower()), on_step=False, on_epoch=True, prog_bar=True, logger=True)
222 |
223 | if 'Age' in keys:
224 | target_age = extracted_target['Age']
225 | predicted_age = extracted_pred['Age']
226 | self.log("val/age", float(target_age.lower()==predicted_age.lower()), on_step=False, on_epoch=True, prog_bar=True, logger=True)
227 |
228 | if 'Accent' in keys:
229 | target_accent = extracted_target['Accent']
230 | predicted_accent = extracted_pred['Accent']
231 | self.log("val/accent", float(target_accent.lower()==predicted_accent.lower()), on_step=False, on_epoch=True, prog_bar=True, logger=True)
232 |
233 | return {"val_loss": loss}
234 |
235 | def on_validation_epoch_start(self):
236 | """Select two random validation samples to log for each epoch."""
237 | self.selected_samples_for_logging = random.sample(range(self.num_validation_samples), 2)
238 |
239 |
240 | def extract_dictionary(self, input_string):
241 | pattern = r'\s*(\{.*?\})\s*'
242 | match = re.search(pattern, input_string, re.DOTALL)
243 | if match:
244 | dict_string = match.group(1)
245 | dict_string = re.sub(r',\s*}', '}', dict_string)
246 | try:
247 | return json.loads(dict_string)
248 | except json.JSONDecodeError as e:
249 | return {}
250 | else:
251 | return {}
252 |
253 | def extract_prediction_values(self, input_string):
254 | json_str_match = re.search(r'\s*\{.*?\}\s*', input_string)
255 | try:
256 | json_str = json_str_match.group(0)
257 | except AttributeError:  # re.search returned None, i.e. no JSON object in the text
258 | json_str = '{}'
259 | return self.extract_dictionary(json_str)
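260 |
261 |
262 | if __name__ == "__main__":
263 | # Minimal smoke-test sketch of the multimodal forward path, using the same
264 | # WavLM + CNN connector + TinyLlama setup as train.py. Assumes the pretrained
265 | # checkpoints can be downloaded from the Hugging Face Hub; the audio below is
266 | # random noise, so only tensor shapes (not the outputs) are meaningful.
267 | model = SpeechLLMLightning(
268 | audio_enc_dim=1024, llm_dim=2048,
269 | audio_encoder_name="microsoft/wavlm-large",
270 | connector_name='cnn', connector_k=2,
271 | llm_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
272 | use_lora=True, lora_r=8, lora_alpha=16,
273 | ).eval()
274 | wav = torch.randn(1, 16000)  # one second of fake 16 kHz audio
275 | ids = model.llm_tokenizer("Transcribe the audio.", return_tensors="pt")["input_ids"]
276 | with torch.no_grad():
277 | embeds, atts, label_ids = model.encode(wav, ids, ids, ids)
278 | print(embeds.shape, atts.shape, label_ids.shape)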
--------------------------------------------------------------------------------