├── README.md ├── Preprocess_json.ipynb ├── NAR-BERT-ASR.ipynb ├── Pretrain_Encoder.ipynb ├── utils.py └── bert_asr.py /README.md: -------------------------------------------------------------------------------- 1 | # NAR-BERT-ASR 2 | 3 | ## Instructions 4 | - step 1: We use ESPnet recipe preprocess AISHELL-1 dataset to stage 2 5 | - step 2: Run Preprocess_json.ipynb to re-tokenize the json file 6 | - step 3: Pretrain the Encoder 7 | - step 4: Fine-tune the NAR-BERT-ASR 8 | -------------------------------------------------------------------------------- /Preprocess_json.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "##Preprocess\n", 10 | "from transformers import BertTokenizer, BertModel, BertForMaskedLM, BertTokenizerFast\n", 11 | "\n", 12 | "## BERT\n", 13 | "PRETRAINED_MODEL_NAME = \"bert-base-chinese\"\n", 14 | "tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)\n", 15 | "\n", 16 | "## bert_tokenize tokenize text to id\n", 17 | "## run {train, dev, test set}\n", 18 | "import re\n", 19 | "import json\n", 20 | "json_file = \"\"\n", 21 | "\n", 22 | "with open(\"./espnet/egs/aishell/asr1/dump/train/deltafalse/data.json\", \"r\", encoding=\"utf8\") as f:\n", 23 | " json_file = json.load(f)['utts']\n", 24 | "\n", 25 | "for key in json_file.keys():\n", 26 | " token = tokenizer(json_file[key][\"output\"][0][\"text\"])[\"input_ids\"]\n", 27 | " token = [str(i) for i in token]\n", 28 | " json_file[key][\"output\"][0][\"token\"] = \" \".join(tokenizer.tokenize(json_file[key][\"output\"][0][\"text\"]))\n", 29 | " json_file[key][\"output\"][0][\"shape\"][0] = len(token)\n", 30 | " #to id\n", 31 | " json_file[key][\"output\"][0][\"tokenid\"] = \" \".join(token)\n", 32 | "\n", 33 | "json_file={'utts':json_file}\n", 34 | "with 
open(\"./espnet/egs/aishell/asr1/dump/train/deltafalse/data.json\", \"w\", encoding=\"utf8\") as f:\n", 35 | " f.write(json.dumps(json_file, sort_keys=True, ensure_ascii=False, indent=4))\n" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [] 44 | } 45 | ], 46 | "metadata": { 47 | "kernelspec": { 48 | "display_name": "Python 3", 49 | "language": "python", 50 | "name": "python3" 51 | }, 52 | "language_info": { 53 | "codemirror_mode": { 54 | "name": "ipython", 55 | "version": 3 56 | }, 57 | "file_extension": ".py", 58 | "mimetype": "text/x-python", 59 | "name": "python", 60 | "nbconvert_exporter": "python", 61 | "pygments_lexer": "ipython3", 62 | "version": "3.7.6" 63 | } 64 | }, 65 | "nbformat": 4, 66 | "nbformat_minor": 4 67 | } 68 | -------------------------------------------------------------------------------- /NAR-BERT-ASR.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import torch\n", 11 | "from transformers import BertTokenizer, BertModel, BertForMaskedLM, BertTokenizerFast, BertForSequenceClassification\n", 12 | "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "import numpy as np\n", 22 | "import kaldiio\n", 23 | "import json\n", 24 | "from torch.nn import CrossEntropyLoss\n", 25 | "from torch.nn.utils.rnn import pad_sequence\n", 26 | "from torch.utils.data.dataset import Dataset\n", 27 | "from torch.utils.data import DataLoader\n", 28 | "from torch.utils.data.dataloader import default_collate\n", 29 | "from utils import *\n", 30 | "\n", 31 | "train_json = \"/path/espnet/egs/aishell/asrNoperAddBERTToken/dump/train/deltafalse/data.json\"\n", 
32 | "with open(train_json, \"r\") as f:\n", 33 | " train_json = json.load(f)[\"utts\"]\n", 34 | "trainset = make_batchset(train_json,\n", 35 | " min_batch_size=1,\n", 36 | " shortest_first=True,\n", 37 | " count=\"frame\",\n", 38 | " batch_frames_in=10000,\n", 39 | " )\n", 40 | "\n", 41 | "dev_json = \"/path/espnet/egs/aishell/asrNoperAddBERTToken/dump/dev/deltafalse/data.json\"\n", 42 | "with open(dev_json, \"r\") as f:\n", 43 | " dev_json = json.load(f)[\"utts\"]\n", 44 | "devset = make_batchset(dev_json,\n", 45 | " min_batch_size=1,\n", 46 | " shortest_first=True,\n", 47 | " count=\"frame\",\n", 48 | " batch_frames_in=10000,\n", 49 | " )\n", 50 | "\n", 51 | "def collate(minibatch):\n", 52 | " fbanks = []\n", 53 | " tokens = []\n", 54 | " for key, info in minibatch[0]:\n", 55 | " fbanks.append(torch.tensor(spec_augment(kaldiio.load_mat(info[\"input\"][0][\"feat\"]))))\n", 56 | " s = info[\"output\"][0][\"tokenid\"].split()\n", 57 | " if len(s)<60:\n", 58 | " for i in range(60-len(s)):\n", 59 | " s+=[torch.tensor([0])]\n", 60 | " if len(s)>60:\n", 61 | " s=s[0:60]\n", 62 | " tokens.append(torch.tensor([int(st) for st in s]))\n", 63 | " ilens = torch.tensor([x.shape[0] for x in fbanks])\n", 64 | " return pad_sequence(fbanks, batch_first=True), pad_sequence(tokens, batch_first=True)\n", 65 | "\n", 66 | "def collate_dev(minibatch):\n", 67 | " fbanks = []\n", 68 | " tokens = []\n", 69 | " for key, info in minibatch[0]:\n", 70 | " fbanks.append(torch.tensor(kaldiio.load_mat(info[\"input\"][0][\"feat\"])))\n", 71 | " s = info[\"output\"][0][\"tokenid\"].split()\n", 72 | " if len(s)<60:\n", 73 | " for i in range(60-len(s)):\n", 74 | " s+=[torch.tensor([0])]\n", 75 | " if len(s)>60:\n", 76 | " s=s[0:60]\n", 77 | " tokens.append(torch.tensor([int(st) for st in s]))\n", 78 | " ilens = torch.tensor([x.shape[0] for x in fbanks])\n", 79 | " return pad_sequence(fbanks, batch_first=True), pad_sequence(tokens, batch_first=True)\n", 80 | "\n", 81 | "train_loader = 
DataLoader(trainset, collate_fn=collate, shuffle=True, pin_memory=True)\n", 82 | "dev_loader = DataLoader(devset, collate_fn=collate_dev, pin_memory=True)" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": null, 88 | "metadata": {}, 89 | "outputs": [], 90 | "source": [ 91 | "from distutils.version import LooseVersion\n", 92 | "from typing import Union\n", 93 | "from torch.optim.lr_scheduler import _LRScheduler\n", 94 | "\n", 95 | "class WarmupLR(_LRScheduler):\n", 96 | " def __init__(\n", 97 | " self,\n", 98 | " optimizer: torch.optim.Optimizer,\n", 99 | " warmup_steps: Union[int, float] = 25000,\n", 100 | " last_epoch: int = -1,\n", 101 | " ):\n", 102 | " self.warmup_steps = warmup_steps\n", 103 | " super().__init__(optimizer, last_epoch)\n", 104 | "\n", 105 | " def __repr__(self):\n", 106 | " return f\"{self.__class__.__name__}(warmup_steps={self.warmup_steps})\"\n", 107 | "\n", 108 | " def get_lr(self):\n", 109 | " step_num = self.last_epoch + 1\n", 110 | " return [\n", 111 | " (768) ** -0.5\n", 112 | " * min(step_num ** -0.5, step_num * self.warmup_steps ** -1.5)*0.1\n", 113 | " for lr in self.base_lrs\n", 114 | " ]\n" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "import bert_asr\n", 124 | "from tqdm.notebook import trange, tqdm\n", 125 | "\n", 126 | "PRETRAINED_MODEL_NAME = \"bert-base-chinese\"\n", 127 | "bertmodel = bert_asr.BertForMaskedLMForBERTASR.from_pretrained(PRETRAINED_MODEL_NAME)\n", 128 | "encoder_model = bert_asr.BERTASR_Encoder(83,21128)\n", 129 | "device = torch.device('cpu')\n", 130 | "encoder_model.load_state_dict(torch.load(\"./path/pretraining_avg10\", map_location=device))\n", 131 | "model = bert_asr.BERTASR(encoder_model, bertmodel)\n", 132 | "model = model.cuda()\n", 133 | "\n", 134 | "optimizer =torch.optim.Adam(model.parameters(), lr=0.001)\n", 135 | "scheduler = WarmupLR(optimizer, 12000)\n", 136 | "EPOCHS = 
130\n", 137 | "\n", 138 | "print(\"----Training Start----\")\n", 139 | "optimizer.zero_grad()\n", 140 | "for epoch in range(EPOCHS):\n", 141 | "\n", 142 | " running_loss = 0.0\n", 143 | " train_step = 0\n", 144 | " train_index = 0\n", 145 | " model.train()\n", 146 | " for x, y in tqdm(train_loader):\n", 147 | " # forward pass\n", 148 | " outputs = model(x.cuda(), y.cuda())[0]\n", 149 | " loss = outputs/12.0\n", 150 | " # backward\n", 151 | " loss.backward()\n", 152 | " train_index += 1\n", 153 | " if train_index % 12 == 0:\n", 154 | " optimizer.step()\n", 155 | " scheduler.step()\n", 156 | " optimizer.zero_grad()\n", 157 | " train_index = 0\n", 158 | "\n", 159 | " # log batch loss\n", 160 | " running_loss += loss.item()\n", 161 | " train_step += 1\n", 162 | " #\n", 163 | " model.eval()\n", 164 | " dev_loss = 0.0\n", 165 | " dev_loss2 = 0.0\n", 166 | " dev_step = 0\n", 167 | " with torch.no_grad():\n", 168 | " for dev_x, dev_y in tqdm(dev_loader):\n", 169 | " outputs = model(dev_x.cuda(), dev_y.cuda())[0]\n", 170 | " dev_loss += outputs.item()\n", 171 | " dev_step += 1\n", 172 | " #\n", 173 | " outputs = model(dev_x.cuda())[0]\n", 174 | " loss_fct = CrossEntropyLoss()\n", 175 | " outputs = loss_fct(outputs.view(-1, model.encoder.odim), dev_y.cuda().view(-1))\n", 176 | " dev_loss2 += outputs.item()\n", 177 | " print('[epoch %d] train loss: %.3f | dev loss: %.3f | dev loss(w/o sm): %.3f' %\n", 178 | " (epoch + 1, running_loss/train_step*12, dev_loss/dev_step, dev_loss2/dev_step))\n", 179 | " torch.save(model.state_dict(), \"./bertasr.\"+str(epoch + 1))\n", 180 | " print(\"save for epoch:\" + str(epoch+1))\n" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": null, 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "#Average 10 model 121~130\n", 190 | "bert_asr.avg_model(root=\"./bertasr.\", avg_num=10, last_num=130, save_path=\"./bertasr_avg10\")" 191 | ] 192 | } 193 | ], 194 | "metadata": { 195 | "kernelspec": { 196 | 
"display_name": "Python 3", 197 | "language": "python", 198 | "name": "python3" 199 | }, 200 | "language_info": { 201 | "codemirror_mode": { 202 | "name": "ipython", 203 | "version": 3 204 | }, 205 | "file_extension": ".py", 206 | "mimetype": "text/x-python", 207 | "name": "python", 208 | "nbconvert_exporter": "python", 209 | "pygments_lexer": "ipython3", 210 | "version": "3.7.6" 211 | } 212 | }, 213 | "nbformat": 4, 214 | "nbformat_minor": 4 215 | } 216 | -------------------------------------------------------------------------------- /Pretrain_Encoder.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import torch\n", 20 | "from transformers import BertTokenizer, BertModel, BertForMaskedLM, BertTokenizerFast\n", 21 | "\n", 22 | "## BERT\n", 23 | "PRETRAINED_MODEL_NAME = \"bert-base-chinese\"\n", 24 | "ch_bert = BertModel.from_pretrained(PRETRAINED_MODEL_NAME)" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "import numpy as np\n", 34 | "import kaldiio\n", 35 | "import json\n", 36 | "from torch.nn import CrossEntropyLoss\n", 37 | "from torch.nn.utils.rnn import pad_sequence\n", 38 | "from torch.utils.data.dataset import Dataset\n", 39 | "from torch.utils.data import DataLoader\n", 40 | "from torch.utils.data.dataloader import default_collate\n", 41 | "from utils import *\n", 42 | "\n", 43 | "train_json = \"/path/espnet/egs/aishell/asrNoperAddBERTToken/dump/train/deltafalse/data.json\"\n", 44 | "with open(train_json, \"r\") as f:\n", 45 | " train_json = json.load(f)[\"utts\"]\n", 46 | "trainset = 
make_batchset(train_json,\n", 47 | " min_batch_size=1,\n", 48 | " shortest_first=True,\n", 49 | " count=\"frame\",\n", 50 | " batch_frames_in=10000,\n", 51 | " )\n", 52 | "\n", 53 | "dev_json = \"/path/espnet/egs/aishell/asrNoperAddBERTToken/dump/dev/deltafalse/data.json\"\n", 54 | "with open(dev_json, \"r\") as f:\n", 55 | " dev_json = json.load(f)[\"utts\"]\n", 56 | "devset = make_batchset(dev_json,\n", 57 | " min_batch_size=1,\n", 58 | " shortest_first=True,\n", 59 | " count=\"frame\",\n", 60 | " batch_frames_in=10000,\n", 61 | " )\n", 62 | "\n", 63 | "def collate(minibatch):\n", 64 | " fbanks = []\n", 65 | " tokens = []\n", 66 | " for key, info in minibatch[0]:\n", 67 | " fbanks.append(torch.tensor(spec_augment(kaldiio.load_mat(info[\"input\"][0][\"feat\"]))))\n", 68 | " s = info[\"output\"][0][\"tokenid\"].split()\n", 69 | " if len(s)<60:\n", 70 | " for i in range(60-len(s)):\n", 71 | " s+=[torch.tensor([0])]\n", 72 | " if len(s)>60:\n", 73 | " s=s[0:60]\n", 74 | " tokens.append(torch.tensor([int(st) for st in s]))\n", 75 | " ilens = torch.tensor([x.shape[0] for x in fbanks])\n", 76 | " return pad_sequence(fbanks, batch_first=True), pad_sequence(tokens, batch_first=True)\n", 77 | "\n", 78 | "def collate_dev(minibatch):\n", 79 | " fbanks = []\n", 80 | " tokens = []\n", 81 | " for key, info in minibatch[0]:\n", 82 | " fbanks.append(torch.tensor(kaldiio.load_mat(info[\"input\"][0][\"feat\"])))\n", 83 | " s = info[\"output\"][0][\"tokenid\"].split()\n", 84 | " if len(s)<60:\n", 85 | " for i in range(60-len(s)):\n", 86 | " s+=[torch.tensor([0])]\n", 87 | " if len(s)>60:\n", 88 | " s=s[0:60]\n", 89 | " tokens.append(torch.tensor([int(st) for st in s]))\n", 90 | " ilens = torch.tensor([x.shape[0] for x in fbanks])\n", 91 | " return pad_sequence(fbanks, batch_first=True), pad_sequence(tokens, batch_first=True)\n", 92 | "\n", 93 | "train_loader = DataLoader(trainset, collate_fn=collate, shuffle=True, pin_memory=True)\n", 94 | "dev_loader = DataLoader(devset, 
collate_fn=collate_dev, pin_memory=True)\n" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "from distutils.version import LooseVersion\n", 104 | "from typing import Union\n", 105 | "from torch.optim.lr_scheduler import _LRScheduler\n", 106 | "\n", 107 | "class WarmupLR(_LRScheduler):\n", 108 | " def __init__(\n", 109 | " self,\n", 110 | " optimizer: torch.optim.Optimizer,\n", 111 | " warmup_steps: Union[int, float] = 25000,\n", 112 | " last_epoch: int = -1,\n", 113 | " ):\n", 114 | " self.warmup_steps = warmup_steps\n", 115 | " super().__init__(optimizer, last_epoch)\n", 116 | "\n", 117 | " def __repr__(self):\n", 118 | " return f\"{self.__class__.__name__}(warmup_steps={self.warmup_steps})\"\n", 119 | "\n", 120 | " def get_lr(self):\n", 121 | " step_num = self.last_epoch + 1\n", 122 | " return [\n", 123 | " (768) ** -0.5\n", 124 | " * min(step_num ** -0.5, step_num * self.warmup_steps ** -1.5)\n", 125 | " for lr in self.base_lrs\n", 126 | " ]\n" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "import bert_asr\n", 136 | "from tqdm.notebook import trange, tqdm\n", 137 | "\n", 138 | "model = bert_asr.BERTASR_Encoder(83, 21128)\n", 139 | "BERTWordEmbed = ch_bert.embeddings.word_embeddings.state_dict()\n", 140 | "model.classifier.load_state_dict(BERTWordEmbed)\n", 141 | "model = model.cuda()\n", 142 | "\n", 143 | "# Pre-training\n", 144 | "for param in model.classifier.parameters():\n", 145 | " param.requires_grad = False\n", 146 | "\n", 147 | "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n", 148 | "scheduler = WarmupLR(optimizer, 12000)\n", 149 | "EPOCHS = 130\n", 150 | "\n", 151 | "print(\"----Training Start----\")\n", 152 | "optimizer.zero_grad()\n", 153 | "for epoch in range(EPOCHS):\n", 154 | " running_loss = 0.0\n", 155 | " train_step = 0\n", 156 | " 
train_index = 0\n", 157 | " model.train()\n", 158 | " for x, y in tqdm(train_loader):\n", 159 | " # forward pass\n", 160 | " outputs = model(x.cuda(), y.cuda())\n", 161 | " loss = outputs/12.0\n", 162 | " # backward\n", 163 | " loss.backward()\n", 164 | " train_index += 1\n", 165 | " if train_index % 12 == 0:\n", 166 | " optimizer.step()\n", 167 | " scheduler.step()\n", 168 | " optimizer.zero_grad()\n", 169 | " train_index = 0\n", 170 | "\n", 171 | " # log batch loss\n", 172 | " running_loss += loss.item()\n", 173 | " train_step += 1\n", 174 | " #\n", 175 | " model.eval()\n", 176 | " dev_loss = 0.0\n", 177 | " dev_loss2 = 0.0\n", 178 | " dev_step = 0\n", 179 | " with torch.no_grad():\n", 180 | " for dev_x, dev_y in tqdm(dev_loader):\n", 181 | " outputs = model(dev_x.cuda(), dev_y.cuda())\n", 182 | " dev_loss += outputs.item()\n", 183 | " dev_step += 1\n", 184 | "\n", 185 | " outputs = model(dev_x.cuda())\n", 186 | " loss_fct = CrossEntropyLoss()\n", 187 | " outputs = loss_fct(outputs.view(-1, model.odim),\n", 188 | " dev_y.cuda().view(-1))\n", 189 | " dev_loss2 += outputs.item()\n", 190 | "\n", 191 | " print('[epoch %d] train loss: %.3f | dev loss: %.3f | dev loss(w/o sm): %.3f' %\n", 192 | " (epoch + 1, running_loss/train_step*12, dev_loss/dev_step, dev_loss2/dev_step))\n", 193 | " torch.save(model.state_dict(), \"./pretraining.\"+str(epoch + 1))\n" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": null, 199 | "metadata": {}, 200 | "outputs": [], 201 | "source": [ 202 | "#Average 10 model 121~130\n", 203 | "bert_asr.avg_model(root=\"./pretraining.\", avg_num=10, last_num=130, save_path=\"./pretraining_avg10\")" 204 | ] 205 | } 206 | ], 207 | "metadata": { 208 | "kernelspec": { 209 | "display_name": "Python 3", 210 | "language": "python", 211 | "name": "python3" 212 | }, 213 | "language_info": { 214 | "codemirror_mode": { 215 | "name": "ipython", 216 | "version": 3 217 | }, 218 | "file_extension": ".py", 219 | "mimetype": "text/x-python", 
def batchfy_by_frame(
    sorted_data,
    max_frames_in,
    max_frames_out,
    max_frames_inout,
    num_batches=0,
    min_batch_size=1,
    shortest_first=False,
    ikey="input",
    okey="output",
):
    """Build minibatches whose total frame counts stay under the given caps.

    Greedily grows each minibatch from ``sorted_data`` (a list of
    ``(utt_id, info_dict)`` pairs, already sorted by length) until adding one
    more utterance would push ``max(len) * batch_size`` past any enabled cap.

    :param list sorted_data: list of (key, info) tuples; ``info[ikey][0]["shape"][0]``
        and ``info[okey][0]["shape"][0]`` give input/output lengths in frames
    :param int max_frames_in: cap on padded input frames per batch (0 disables)
    :param int max_frames_out: cap on padded output frames per batch (0 disables)
    :param int max_frames_inout: cap on padded input+output frames (0 disables)
    :param int num_batches: if > 0, truncate to this many batches (debugging)
    :param int min_batch_size: batches smaller than this steal samples from
        the previous batch
    :param bool shortest_first: reverse each batch so shortest samples come first
    :param str ikey: json key of the input stream
    :param str okey: json key of the output stream
    :raises ValueError: if no cap is enabled, or a single sample exceeds a cap
    :return: list of minibatches (each a list of (key, info) tuples)
    """
    if max_frames_in <= 0 and max_frames_out <= 0 and max_frames_inout <= 0:
        raise ValueError(
            "At least, one of `--batch-frames-in`, `--batch-frames-out` or "
            "`--batch-frames-inout` should be > 0"
        )
    length = len(sorted_data)
    minibatches = []
    start = 0
    end = 0
    while end != length:
        # Dynamic batch size depending on size of samples
        b = 0
        max_olen = 0
        max_ilen = 0
        while (start + b) < length:
            ilen = int(sorted_data[start + b][1][ikey][0]["shape"][0])
            if ilen > max_frames_in and max_frames_in != 0:
                raise ValueError(
                    f"Can't fit one sample in --batch-frames-in ({max_frames_in}): "
                    f"Please increase the value"
                )
            olen = int(sorted_data[start + b][1][okey][0]["shape"][0])
            if olen > max_frames_out and max_frames_out != 0:
                raise ValueError(
                    f"Can't fit one sample in --batch-frames-out ({max_frames_out}): "
                    f"Please increase the value"
                )
            if ilen + olen > max_frames_inout and max_frames_inout != 0:
                raise ValueError(
                    f"Can't fit one sample in --batch-frames-out ({max_frames_inout}): "
                    f"Please increase the value"
                )
            max_olen = max(max_olen, olen)
            max_ilen = max(max_ilen, ilen)
            # cost model: every sample in the batch is padded to the longest
            # seen so far, hence max_len * (b + 1)
            in_ok = max_ilen * (b + 1) <= max_frames_in or max_frames_in == 0
            out_ok = max_olen * (b + 1) <= max_frames_out or max_frames_out == 0
            inout_ok = (max_ilen + max_olen) * (
                b + 1
            ) <= max_frames_inout or max_frames_inout == 0
            if in_ok and out_ok and inout_ok:
                # add more seq in the minibatch
                b += 1
            else:
                # no more seq in the minibatch
                break
        end = min(length, start + b)
        batch = sorted_data[start:end]
        if shortest_first:
            batch.reverse()
        minibatches.append(batch)
        # Check for min_batch_size and fixes the batches if needed
        # (walks backwards, moving samples from the previous batch; if the
        # first batch itself is short it is merged into its successor)
        i = -1
        while len(minibatches[i]) < min_batch_size:
            missing = min_batch_size - len(minibatches[i])
            if -i == len(minibatches):
                minibatches[i + 1].extend(minibatches[i])
                minibatches = minibatches[1:]
                break
            else:
                minibatches[i].extend(minibatches[i - 1][:missing])
                minibatches[i - 1] = minibatches[i - 1][missing:]
                i -= 1
        start = end
    if num_batches > 0:
        minibatches = minibatches[:num_batches]
    lengths = [len(x) for x in minibatches]
    logging.info(
        str(len(minibatches))
        + " batches containing from "
        + str(min(lengths))
        + " to "
        + str(max(lengths))
        + " samples"
        + "(avg "
        + str(int(np.mean(lengths)))
        + " samples)."
    )

    return minibatches
| if count != "seq" and batch_sort_key == "shuffle": 169 | raise ValueError("batch_sort_key=shuffle is only available if batch_count=seq") 170 | 171 | category2data = {} # Dict[str, dict] 172 | for k, v in data.items(): 173 | category2data.setdefault(v.get("category"), {})[k] = v 174 | 175 | batches_list = [] # List[List[List[Tuple[str, dict]]]] 176 | for d in category2data.values(): 177 | if batch_sort_key == "shuffle": 178 | batches = batchfy_shuffle( 179 | d, batch_size, min_batch_size, num_batches, shortest_first 180 | ) 181 | batches_list.append(batches) 182 | continue 183 | 184 | # sort it by input lengths (long to short) 185 | sorted_data = sorted( 186 | d.items(), 187 | key=lambda data: int(data[1][batch_sort_key][batch_sort_axis]["shape"][0]), 188 | reverse=not shortest_first, 189 | ) 190 | logging.info("# utts: " + str(len(sorted_data))) 191 | if count == "seq": 192 | raise ValueError(f"Error") 193 | if count == "bin": 194 | raise ValueError(f"Error") 195 | if count == "frame": 196 | batches = batchfy_by_frame( 197 | sorted_data, 198 | max_frames_in=batch_frames_in, 199 | max_frames_out=batch_frames_out, 200 | max_frames_inout=batch_frames_inout, 201 | min_batch_size=min_batch_size, 202 | shortest_first=shortest_first, 203 | ikey=ikey, 204 | okey=okey, 205 | ) 206 | batches_list.append(batches) 207 | 208 | if len(batches_list) == 1: 209 | batches = batches_list[0] 210 | else: 211 | # Concat list. 
This way is faster than "sum(batch_list, [])" 212 | batches = list(itertools.chain(*batches_list)) 213 | 214 | # for debugging 215 | if num_batches > 0: 216 | batches = batches[:num_batches] 217 | logging.info("# minibatches: " + str(len(batches))) 218 | 219 | # batch: List[List[Tuple[str, dict]]] 220 | return batches 221 | 222 | #SpecAugment 223 | def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"): 224 | """time warp for spec augment 225 | move random center frame by the random width ~ uniform(-window, window) 226 | :param numpy.ndarray x: spectrogram (time, freq) 227 | :param int max_time_warp: maximum time frames to warp 228 | :param bool inplace: overwrite x with the result 229 | :param str mode: "PIL" (default, fast, not differentiable) or "sparse_image_warp" 230 | (slow, differentiable) 231 | :returns numpy.ndarray: time warped spectrogram (time, freq) 232 | """ 233 | window = max_time_warp 234 | if mode == "PIL": 235 | t = x.shape[0] 236 | if t - window <= window: 237 | return x 238 | # NOTE: randrange(a, b) emits a, a + 1, ..., b - 1 239 | center = random.randrange(window, t - window) 240 | warped = random.randrange(center - window, center + window) + 1 # 1 ... t - 1 241 | 242 | left = Image.fromarray(x[:center]).resize((x.shape[1], warped), BICUBIC) 243 | right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped), BICUBIC) 244 | if inplace: 245 | x[:warped] = left 246 | x[warped:] = right 247 | return x 248 | return numpy.concatenate((left, right), 0) 249 | elif mode == "sparse_image_warp": 250 | import torch 251 | 252 | from espnet.utils import spec_augment 253 | 254 | # TODO(karita): make this differentiable again 255 | return spec_augment.time_warp(torch.from_numpy(x), window).numpy() 256 | else: 257 | raise NotImplementedError( 258 | "unknown resize mode: " 259 | + mode 260 | + ", choose one from (PIL, sparse_image_warp)." 
def freq_mask(x, F=30, n_mask=2, replace_with_zero=True, inplace=False):
    """Apply SpecAugment frequency masking to a spectrogram.

    Draws ``n_mask`` (width, end-offset) pairs and zeroes (or mean-fills)
    the corresponding frequency bands.

    :param numpy.ndarray x: spectrogram (time, freq)
    :param int F: exclusive upper bound of each random mask width
    :param int n_mask: the number of masks
    :param bool replace_with_zero: pad zero on mask if true else use mean
    :param bool inplace: overwrite x instead of masking a copy
    :returns numpy.ndarray: masked spectrogram (time, freq)
    """
    if inplace:
        cloned = x
    else:
        cloned = x.copy()

    num_mel_channels = cloned.shape[1]
    fs = numpy.random.randint(0, F, size=(n_mask, 2))

    for f, mask_end in fs:
        # skip widths that do not fit the frequency axis; without this,
        # randrange(0, num_mel_channels - f) raises ValueError whenever
        # f >= num_mel_channels (mirrors the guard already in time_mask)
        if num_mel_channels - f <= 0:
            continue
        f_zero = random.randrange(0, num_mel_channels - f)
        mask_end += f_zero

        # avoids randrange error if values are equal and range is empty
        if f_zero == f_zero + f:
            continue

        if replace_with_zero:
            cloned[:, f_zero:mask_end] = 0
        else:
            cloned[:, f_zero:mask_end] = cloned.mean()
    return cloned
max_freq_width=30, 332 | n_freq_mask=2, 333 | max_time_width=40, 334 | n_time_mask=2, 335 | inplace=True, 336 | replace_with_zero=False, 337 | ): 338 | """spec agument 339 | apply random time warping and time/freq masking 340 | default setting is based on LD (Librispeech double) in Table 2 341 | https://arxiv.org/pdf/1904.08779.pdf 342 | :param numpy.ndarray x: (time, freq) 343 | :param str resize_mode: "PIL" (fast, nondifferentiable) or "sparse_image_warp" 344 | (slow, differentiable) 345 | :param int max_time_warp: maximum frames to warp the center frame in spectrogram (W) 346 | :param int freq_mask_width: maximum width of the random freq mask (F) 347 | :param int n_freq_mask: the number of the random freq mask (m_F) 348 | :param int time_mask_width: maximum width of the random time mask (T) 349 | :param int n_time_mask: the number of the random time mask (m_T) 350 | :param bool inplace: overwrite intermediate array 351 | :param bool replace_with_zero: pad zero on mask if true else use mean 352 | """ 353 | assert isinstance(x, numpy.ndarray) 354 | assert x.ndim == 2 355 | x = time_warp(x, max_time_warp, inplace=inplace, mode=resize_mode) 356 | x = freq_mask( 357 | x, 358 | max_freq_width, 359 | n_freq_mask, 360 | inplace=inplace, 361 | replace_with_zero=replace_with_zero, 362 | ) 363 | x = time_mask( 364 | x, 365 | max_time_width, 366 | n_time_mask, 367 | inplace=inplace, 368 | replace_with_zero=replace_with_zero, 369 | ) 370 | return x 371 | -------------------------------------------------------------------------------- /bert_asr.py: -------------------------------------------------------------------------------- 1 | import math 2 | import copy 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch import Tensor 7 | from torch.nn import CrossEntropyLoss, MSELoss, MultiheadAttention, Linear, Dropout, LayerNorm, ModuleList 8 | from transformers import BertTokenizer, BertModel, BertForSequenceClassification, 
class BasicPositionalEncoding(nn.Module):
    """Sinusoidal positional encoding added to a (batch, time, d_model) input.

    Standard "Attention Is All You Need" table: even channels carry sin,
    odd channels carry cos, with geometrically spaced wavelengths.
    """

    def __init__(self, d_model=256, dropout=0.1, max_len=5000):
        super(BasicPositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        inv_freq = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * inv_freq)
        table[:, 1::2] = torch.cos(positions * inv_freq)
        # buffer (not a parameter): follows .to()/.cuda() but is never trained
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        """Add positional encodings for the first x.size(1) steps, then dropout."""
        return self.dropout(x + self.pe[:, : x.size(1)])
61 | torch.Tensor: Subsampled mask (#batch, 1, time'), 62 | where time' = time // 4. 63 | """ 64 | x = x.unsqueeze(1) # (b, c, t, f) 65 | x = self.conv(x) 66 | b, c, t, f = x.size() 67 | x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) 68 | if x_mask is None: 69 | return x 70 | return x 71 | 72 | class PositionalEncoding(nn.Module): 73 | def __init__(self, d_model=256, dropout=0.1, max_len=60): 74 | super(PositionalEncoding, self).__init__() 75 | self.dropout = nn.Dropout(p=dropout) 76 | pe = torch.zeros(max_len, d_model) 77 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 78 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 79 | pe[:, 0::2] = torch.sin(position * div_term) 80 | pe[:, 1::2] = torch.cos(position * div_term) 81 | pe = pe.unsqueeze(0).transpose(0, 1) 82 | self.register_buffer('pe', pe) 83 | 84 | def forward(self): 85 | #x: [seq, batch ,d] 86 | x = self.pe[:, :] 87 | return self.dropout(x) 88 | 89 | def _get_clones(module, N): 90 | return ModuleList([copy.deepcopy(module) for i in range(N)]) 91 | 92 | class PDSLayer(nn.Module): 93 | def __init__(self, d_model=256, nhead=8, dim_feedforward=2048, dropout=0.1): 94 | super(PDSLayer, self).__init__() 95 | self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout) 96 | # Implementation of Feedforward model 97 | self.linear1 = Linear(d_model, dim_feedforward) 98 | self.dropout = Dropout(dropout) 99 | self.linear2 = Linear(dim_feedforward//2, d_model) 100 | self.norm2 = LayerNorm(d_model) 101 | self.norm3 = LayerNorm(d_model) 102 | self.dropout2 = Dropout(dropout) 103 | self.dropout3 = Dropout(dropout) 104 | self.activation = F.glu 105 | 106 | def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, 107 | tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor: 108 | tgt = self.norm2(tgt) 109 | tgt2 = 
self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask, 110 | key_padding_mask=memory_key_padding_mask)[0] 111 | tgt = tgt + self.dropout2(tgt2) 112 | tgt = self.norm3(tgt) 113 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 114 | tgt = tgt + self.dropout3(tgt2) 115 | return tgt 116 | 117 | 118 | class PDS(nn.Module): 119 | def __init__(self, decoder_layer, num_layers=4, norm=None): 120 | super(PDS, self).__init__() 121 | self.position_multihead_attn = MultiheadAttention(256, 8, dropout=0.1) 122 | self.norm1 = LayerNorm(256) 123 | self.linear1 = Linear(256, 2048) 124 | self.dropout = Dropout(0.1) 125 | self.linear2 = Linear(2048//2, 256) 126 | self.activation = F.glu 127 | self.dropout1 = Dropout(0.1) 128 | self.dropout2 = Dropout(0.1) 129 | self.layers = _get_clones(decoder_layer, num_layers-1) 130 | self.num_layers = num_layers-1 131 | self.norm = norm 132 | 133 | def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, 134 | memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, 135 | memory_key_padding_mask: Optional[Tensor] = None) -> Tensor: 136 | 137 | tgt2 = self.position_multihead_attn(tgt, memory, memory, attn_mask=tgt_mask, 138 | key_padding_mask=tgt_key_padding_mask)[0] 139 | tgt = tgt + self.dropout1(tgt2) 140 | tgt = self.norm1(tgt) 141 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 142 | tgt = tgt + self.dropout2(tgt2) 143 | ####### 144 | output = tgt 145 | for mod in self.layers: 146 | output = mod(output, memory, tgt_mask=tgt_mask, 147 | memory_mask=memory_mask, 148 | tgt_key_padding_mask=tgt_key_padding_mask, 149 | memory_key_padding_mask=memory_key_padding_mask) 150 | 151 | if self.norm is not None: 152 | output = self.norm(output) 153 | 154 | return output 155 | 156 | 157 | class TransformerEncoderLayerPre(nn.Module): 158 | def __init__(self, d_model=256, nhead=8, dim_feedforward=2048, dropout=0.1): 159 | 
super(TransformerEncoderLayerPre, self).__init__() 160 | self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout) 161 | # Implementation of Feedforward model 162 | self.linear1 = Linear(d_model, dim_feedforward) 163 | self.dropout = Dropout(dropout) 164 | self.linear2 = Linear(dim_feedforward//2, d_model) 165 | self.norm1 = LayerNorm(d_model) 166 | self.norm2 = LayerNorm(d_model) 167 | self.dropout1 = Dropout(dropout) 168 | self.dropout2 = Dropout(dropout) 169 | self.activation = F.glu 170 | def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor: 171 | src = self.norm1(src) 172 | src2 = self.self_attn(src, src, src, attn_mask=src_mask, 173 | key_padding_mask=src_key_padding_mask)[0] 174 | src = src + self.dropout1(src2) 175 | src = self.norm2(src) 176 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 177 | src = src + self.dropout2(src2) 178 | return src 179 | 180 | def linear_combination(x, y, epsilon): 181 | return epsilon*x + (1-epsilon)*y 182 | 183 | def reduce_loss(loss, reduction='mean'): 184 | return loss.mean() if reduction=='mean' else loss.sum() if reduction=='sum' else loss 185 | 186 | class LabelSmoothingCrossEntropy(nn.Module): 187 | def __init__(self, epsilon:float=0.1, reduction='mean'): 188 | super().__init__() 189 | self.epsilon = epsilon 190 | self.reduction = reduction 191 | 192 | def forward(self, preds, target): 193 | n = preds.size()[-1] 194 | log_preds = F.log_softmax(preds, dim=-1) 195 | loss = reduce_loss(-log_preds.sum(dim=-1), self.reduction) 196 | nll = F.nll_loss(log_preds, target, reduction=self.reduction) 197 | return linear_combination(loss/n, nll, self.epsilon) 198 | 199 | class BertForMaskedLMForBERTASR(BertForMaskedLM): 200 | def __init__(self, config): 201 | super().__init__(config) 202 | 203 | def forward( 204 | self, 205 | input_ids=None, 206 | attention_mask=None, 207 | token_type_ids=None, 208 | position_ids=None, 209 | 
head_mask=None, 210 | inputs_embeds=None, 211 | encoder_hidden_states=None, 212 | encoder_attention_mask=None, 213 | labels=None, 214 | output_attentions=None, 215 | output_hidden_states=None, 216 | return_dict=None, 217 | ): 218 | 219 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 220 | 221 | outputs = self.bert( 222 | input_ids, 223 | attention_mask=attention_mask, 224 | token_type_ids=token_type_ids, 225 | position_ids=position_ids, 226 | head_mask=head_mask, 227 | inputs_embeds=inputs_embeds, 228 | encoder_hidden_states=encoder_hidden_states, 229 | encoder_attention_mask=encoder_attention_mask, 230 | output_attentions=output_attentions, 231 | output_hidden_states=output_hidden_states, 232 | return_dict=return_dict, 233 | ) 234 | 235 | sequence_output = outputs[0] 236 | prediction_scores = self.cls(sequence_output) 237 | 238 | masked_lm_loss = None 239 | if labels is not None: 240 | loss_fct = LabelSmoothingCrossEntropy() # -100 index = padding token 241 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) 242 | 243 | if not return_dict: 244 | output = (prediction_scores,) + outputs[2:] 245 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output 246 | 247 | return MaskedLMOutput( 248 | loss=masked_lm_loss, 249 | logits=prediction_scores, 250 | hidden_states=outputs.hidden_states, 251 | attentions=outputs.attentions, 252 | ) 253 | 254 | class BERTASR_Encoder(nn.Module): 255 | def __init__(self, idim, odim, attention_dim=256, attention_heads=8, dropout_rate=0.1): 256 | super(BERTASR_Encoder, self).__init__() 257 | self.model_type = 'Transformer' 258 | self.odim = odim 259 | self.embed = Conv2dSubsampling(idim, attention_dim, dropout_rate) 260 | self.pe = PositionalEncoding(attention_dim) 261 | encoder_layers = TransformerEncoderLayerPre(d_model=attention_dim, nhead=attention_heads, dropout=dropout_rate) 262 | self.encoder_norm = nn.LayerNorm(attention_dim) 
263 | self.encoder = nn.TransformerEncoder(encoder_layers, num_layers=6, norm= None) 264 | pdslayer = PDSLayer(d_model=attention_dim, nhead=attention_heads, dropout=dropout_rate) 265 | self.pds = PDS(pdslayer, num_layers=4) 266 | self.decoder = nn.TransformerEncoder(encoder_layers, num_layers=6) 267 | self.mlp = nn.Linear(attention_dim, 768) 268 | self.classifier = nn.Linear(768, odim, bias=False) 269 | 270 | def forward(self, src, label=None): 271 | #src: [batch, time, idim] 272 | #conv2D 273 | src = self.embed(src) 274 | #convert to [time, batch, idim] (For Pytorch transformer) 275 | src = src.transpose(0,1) 276 | src = self.encoder_norm(self.encoder(src)) 277 | q = self.pe().repeat(1,src.shape[1],1) 278 | src = self.pds(q, src) 279 | src = self.decoder(src) 280 | #convert to [batch, time, idim] 281 | src = src.transpose(0,1) 282 | #classification 283 | output = self.classifier(self.mlp(src)) 284 | if label !=None: 285 | loss_fct = LabelSmoothingCrossEntropy() 286 | loss = loss_fct(output.view(-1, self.odim), label.view(-1)) 287 | return loss 288 | else: 289 | return output 290 | 291 | class BERTASR(nn.Module): 292 | def __init__(self, encoder, BertMLM): 293 | super(BERTASR, self).__init__() 294 | self.encoder = encoder 295 | self.bertmodel = BertMLM 296 | 297 | def forward(self, src, label=None): 298 | #Conv 299 | src = self.encoder.embed(src) 300 | #convert to [time, batch, idim] (For Pytorch transformer) 301 | src = src.transpose(0, 1) 302 | src = self.encoder.encoder_norm(self.encoder.encoder(src)) 303 | q = self.encoder.pe().repeat(1, src.shape[1], 1) 304 | src = self.encoder.pds(q, src) 305 | src = self.encoder.decoder(src) 306 | #convert to [batch, time, idim] 307 | src = src.transpose(0, 1) 308 | src = self.encoder.mlp(src) 309 | # TO BERT Classification 310 | return self.bertmodel(inputs_embeds=src, labels=label) 311 | 312 | def avg_model(root="./pretraining.", avg_num=10, last_num=130, save_path="./pretraining_avg10"): 313 | avg=None 314 | # sum 315 | 
for num in range(0,avg_num): 316 | states = torch.load(root+str(last_num-num), map_location=torch.device("cpu")) 317 | if avg is None: 318 | avg = states 319 | else: 320 | for k in avg.keys(): 321 | avg[k] += states[k] 322 | # average 323 | for k in avg.keys(): 324 | if avg[k] is not None: 325 | if avg[k].is_floating_point(): 326 | avg[k] /= avg_num 327 | else: 328 | avg[k] //= avg_num 329 | torch.save(avg, save_path) --------------------------------------------------------------------------------