├── assets ├── paper.pdf └── model_figure_Mar8.png ├── src ├── jyupin2idx.npy ├── config.py ├── augmentation.py ├── argparser.py ├── focalloss.py ├── model.py ├── train.py └── dataset.py ├── INTERSPEECH_2023_refined.pdf ├── README.md ├── LICENSE └── apache.org_licenses_LICENSE-2.0.txt /assets/paper.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpii-cai/PunCantonese/HEAD/assets/paper.pdf -------------------------------------------------------------------------------- /src/jyupin2idx.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpii-cai/PunCantonese/HEAD/src/jyupin2idx.npy -------------------------------------------------------------------------------- /INTERSPEECH_2023_refined.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpii-cai/PunCantonese/HEAD/INTERSPEECH_2023_refined.pdf -------------------------------------------------------------------------------- /assets/model_figure_Mar8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpii-cai/PunCantonese/HEAD/assets/model_figure_Mar8.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PunCantonese: A Benchmark Corpus for Low-Resource Cantonese Punctuation Restoration from Speech Transcripts 2 | 3 | This repository contains the data and source code of the paper accepted at the InterSpeech2023, ***[PunCantonese: A Benchmark Corpus for Low-Resource Cantonese Punctuation Restoration from Speech Transcripts](https://www.isca-speech.org/archive/pdfs/interspeech_2023/li23z_interspeech.pdf).*** 4 | 5 | 6 | ## Data 7 | 8 | The dataset would be available upon request to this email address: 
yunxiang.li@link.cuhk.edu.hk 9 | 10 | 11 | ## Model 12 | We propose a Transformer-based neural network model to evaluate the **PunCantonese** corpus. 13 | 14 | The model exploits pre-trained language models to obtain a good network initialization, a multi-task learning objective to prevent the network from paying too much attention to the largest subset of written-style sentences and a novel Jyutping embedding layer to represent a Cantonese character with its Jyutping sequence which potentially enables the model to incorporate phonetic features that are not explicitly available in Cantonese characters. Then we have one Bi-LSTM and one linear layer for classification. 15 | 16 | ![](./assets/model_figure_Mar8.png) 17 | 18 | 19 | ## Training 20 | To train our model with default settings and with both multitask learning and jyutping embedding on, run the following command 21 | ``` 22 | python src/train_jyupin_multi.py --cuda=True --pretrained-model=bert-base-multilingual-uncased --freeze-bert=False --lstm-dim=-1 --seed=0 --lr=2e-5 --epoch=15 --use-crf=False --data-path=data --save-path=out --batch-size=32 --sequence-length=128 --loss=focal --multitask=True --jyutping=True 23 | ``` 24 | 25 | ## Acknowledgements 26 | 27 | This project is based on modifications made to the original code from [Punctuation Restoration using Transformer Models](https://github.com/xashru/punctuation-restoration). Special thanks to the authors from this work for providing the excellent code base. 28 | 29 | ## License 30 | 31 | This repository is under the [Apache License 2.0](https://github.com/apache/.github/blob/main/LICENSE). 
32 | -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | from transformers import * 2 | # special tokens indices in different models available in transformers 3 | TOKEN_IDX = { 4 | 'bert': { 5 | 'START_SEQ': 101, 6 | 'PAD': 0, 7 | 'END_SEQ': 102, 8 | 'UNK': 100 9 | }, 10 | 'xlm': { 11 | 'START_SEQ': 0, 12 | 'PAD': 2, 13 | 'END_SEQ': 1, 14 | 'UNK': 3 15 | }, 16 | 'roberta': { 17 | 'START_SEQ': 0, 18 | 'PAD': 1, 19 | 'END_SEQ': 2, 20 | 'UNK': 3 21 | }, 22 | 'albert': { 23 | 'START_SEQ': 2, 24 | 'PAD': 0, 25 | 'END_SEQ': 3, 26 | 'UNK': 1 27 | }, 28 | } 29 | 30 | # 'O' -> No punctuation 31 | #punctuation_dict = {'O': 0, 'COMMA': 1, 'PERIOD': 2, 'QUESTION': 3,'EXCLAMATION': 4} 32 | #punctuation_dict = {'O': 0, 'COMMA': 1, 'PERIOD': 2, 'QUESTION': 3,'EXCLAMATION': 4,'INFORMAL':5,'FORMAL':6} 33 | 34 | # pretrained model name: (model class, model tokenizer, output dimension, token style) 35 | MODELS = { 36 | 'bert-base-uncased': (BertModel, BertTokenizer, 768, 'bert'), 37 | 'bert-base-chinese': (BertModel, BertTokenizer, 768, 'bert'), 38 | 'bert-large-uncased': (BertModel, BertTokenizer, 1024, 'bert'), 39 | 'bert-base-multilingual-cased': (BertModel, BertTokenizer, 768, 'bert'), 40 | 'bert-base-multilingual-uncased': (BertModel, BertTokenizer, 768, 'bert'), 41 | 'xlm-mlm-en-2048': (XLMModel, XLMTokenizer, 2048, 'xlm'), 42 | 'xlm-mlm-100-1280': (XLMModel, XLMTokenizer, 1280, 'xlm'), 43 | 'roberta-base': (RobertaModel, RobertaTokenizer, 768, 'roberta'), 44 | 'roberta-large': (RobertaModel, RobertaTokenizer, 1024, 'roberta'), 45 | 'distilbert-base-uncased': (DistilBertModel, DistilBertTokenizer, 768, 'bert'), 46 | 'distilbert-base-multilingual-cased': (DistilBertModel, DistilBertTokenizer, 768, 'bert'), 47 | 'xlm-roberta-base': (XLMRobertaModel, XLMRobertaTokenizer, 768, 'roberta'), 48 | 'xlm-roberta-large': (XLMRobertaModel, XLMRobertaTokenizer, 
1024, 'roberta'), 49 | 'albert-base-v1': (AlbertModel, AlbertTokenizer, 768, 'albert'), 50 | 'albert-base-v2': (AlbertModel, AlbertTokenizer, 768, 'albert'), 51 | 'albert-large-v2': (AlbertModel, AlbertTokenizer, 1024, 'albert'), 52 | } 53 | -------------------------------------------------------------------------------- /src/augmentation.py: -------------------------------------------------------------------------------- 1 | from config import TOKEN_IDX 2 | import numpy as np 3 | 4 | 5 | # probability of applying substitution operation on tokens selected for augmentation 6 | alpha_sub = 0.40 7 | # probability of applying delete operation on tokens selected for augmentation 8 | alpha_del = 0.40 9 | 10 | tokenizer = None 11 | # substitution strategy: 'unk' -> replace with unknown tokens, 'rand' -> replace with random tokens from vocabulary 12 | sub_style = 'unk' 13 | 14 | 15 | def augment_none(x, y, y_mask, x_aug, y_aug, y_mask_aug, i, token_style): 16 | """ 17 | apply no augmentation 18 | """ 19 | x_aug.append(x[i]) 20 | y_aug.append(y[i]) 21 | y_mask_aug.append(y_mask[i]) 22 | 23 | 24 | def augment_substitute(x, y, y_mask, x_aug, y_aug, y_mask_aug, i, token_style): 25 | """ 26 | replace a token with a random token or the unknown token 27 | """ 28 | if sub_style == 'rand': 29 | x_aug.append(np.random.randint(tokenizer.vocab_size)) 30 | else: 31 | x_aug.append(TOKEN_IDX[token_style]['UNK']) 32 | y_aug.append(y[i]) 33 | y_mask_aug.append(y_mask[i]) 34 | 35 | 36 | def augment_insert(x, y, y_mask, x_aug, y_aug, y_mask_aug, i, token_style): 37 | """ 38 | insert the unknown token before this token 39 | """ 40 | x_aug.append(TOKEN_IDX[token_style]['UNK']) 41 | y_aug.append(0) 42 | y_mask_aug.append(1) 43 | x_aug.append(x[i]) 44 | y_aug.append(y[i]) 45 | y_mask_aug.append(y_mask[i]) 46 | 47 | 48 | def augment_delete(x, y, y_mask, x_aug, y_aug, y_mask_aug, i, token_style): 49 | """ 50 | remove this token i..e, not add in augmented tokens 51 | """ 52 | return 53 | 54 | 55 | 
def augment_all(x, y, y_mask, x_aug, y_aug, y_mask_aug, i, token_style): 56 | """ 57 | apply substitution with alpha_sub probability, deletion with alpha_sub probability and insertion with 58 | 1-(alpha_sub+alpha_sub) probability 59 | """ 60 | r = np.random.rand() 61 | if r < alpha_sub: 62 | augment_substitute(x, y, y_mask, x_aug, y_aug, y_mask_aug, i, token_style) 63 | elif r < alpha_sub + alpha_del: 64 | augment_delete(x, y, y_mask, x_aug, y_aug, y_mask_aug, i, token_style) 65 | else: 66 | augment_insert(x, y, y_mask, x_aug, y_aug, y_mask_aug, i, token_style) 67 | 68 | 69 | # supported augmentation techniques 70 | AUGMENTATIONS = { 71 | 'none': augment_none, 72 | 'substitute': augment_substitute, 73 | 'insert': augment_insert, 74 | 'delete': augment_delete, 75 | 'all': augment_all 76 | } 77 | -------------------------------------------------------------------------------- /src/argparser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def parse_arguments(): 5 | parser = argparse.ArgumentParser(description='Punctuation restoration') 6 | parser.add_argument('--name', default='punctuation-restore', type=str, help='name of run') 7 | parser.add_argument('--cuda', default=True, type=lambda x: (str(x).lower() == 'true'), help='use cuda if available') 8 | parser.add_argument('--seed', default=1, type=int, help='random seed') 9 | parser.add_argument('--pretrained-model', default='roberta-large', type=str, help='pretrained language model') 10 | parser.add_argument('--freeze-bert', default=False, type=lambda x: (str(x).lower() == 'true'), 11 | help='Freeze BERT layers or not') 12 | parser.add_argument('--lstm-dim', default=-1, type=int, 13 | help='hidden dimension in LSTM layer, if -1 is set equal to hidden dimension in language model') 14 | parser.add_argument('--use-crf', default=False, type=lambda x: (str(x).lower() == 'true'), 15 | help='whether to use CRF layer or not') 16 | 
parser.add_argument('--data-path', default='data/', type=str, help='path to train/dev/test datasets') 17 | parser.add_argument('--sequence-length', default=256, type=int, 18 | help='sequence length to use when preparing dataset (default 256)') 19 | parser.add_argument('--augment-rate', default=0., type=float, help='token augmentation probability') 20 | parser.add_argument('--augment-type', default='all', type=str, help='which augmentation to use') 21 | parser.add_argument('--sub-style', default='unk', type=str, help='replacement strategy for substitution augment') 22 | parser.add_argument('--alpha-sub', default=0.4, type=float, help='augmentation rate for substitution') 23 | parser.add_argument('--alpha-del', default=0.4, type=float, help='augmentation rate for deletion') 24 | parser.add_argument('--lr', default=5e-6, type=float, help='learning rate') 25 | parser.add_argument('--decay', default=0, type=float, help='weight decay (default: 0)') 26 | parser.add_argument('--gradient-clip', default=-1, type=float, help='gradient clipping (default: -1 i.e., none)') 27 | parser.add_argument('--batch-size', default=8, type=int, help='batch size (default: 8)') 28 | parser.add_argument('--epoch', default=10, type=int, help='total epochs (default: 10)') 29 | parser.add_argument('--save-path', default='out/', type=str, help='model and log save directory') 30 | parser.add_argument('--loss', default='focal', type=str, help='loss type') 31 | parser.add_argument('--multitask', default=True, type=lambda x: (str(x).lower() == 'true'), help='use multitask structure or not') 32 | parser.add_argument('--jyutping', default=True, type=lambda x: (str(x).lower() == 'true'), help='use jyutping embedding or not') 33 | args = parser.parse_args() 34 | return args 35 | -------------------------------------------------------------------------------- /src/focalloss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from torch 
import nn 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | class focal_loss(nn.Module): 7 | def __init__(self, alpha=[0.75,0.75,0.75,0.25,0.25,0.75,0.25], gamma=2, num_classes = 7, size_average=True): 8 | """ 9 | focal_loss, -α(1-yi)**γ *ce_loss(xi,yi) 10 | """ 11 | 12 | super(focal_loss,self).__init__() 13 | self.size_average = size_average 14 | if isinstance(alpha,list): 15 | assert len(alpha)==num_classes 16 | print("Focal_loss alpha = {}".format(alpha)) 17 | self.alpha = torch.Tensor(alpha) 18 | else: 19 | assert alpha<1 20 | print(" --- Focal_loss alpha = {}".format(alpha)) 21 | self.alpha = torch.zeros(num_classes) 22 | self.alpha[0] += alpha 23 | self.alpha[1:] += (1-alpha) 24 | self.gamma = gamma 25 | 26 | def forward(self, preds, labels): 27 | # assert preds.dim()==2 and labels.dim()==1 28 | preds = preds.view(-1,preds.size(-1)) 29 | self.alpha = self.alpha.to(preds.device) 30 | preds_softmax = F.softmax(preds, dim=1) 31 | preds_logsoft = torch.log(preds_softmax) 32 | preds_softmax = preds_softmax.gather(1,labels.view(-1,1)) 33 | preds_logsoft = preds_logsoft.gather(1,labels.view(-1,1)) 34 | self.alpha = self.alpha.gather(0,labels.view(-1)) 35 | loss = -torch.mul(torch.pow((1-preds_softmax), self.gamma), preds_logsoft) 36 | loss = torch.mul(self.alpha, loss.t()) 37 | if self.size_average: 38 | loss = loss.mean() 39 | else: 40 | loss = loss.sum() 41 | return loss 42 | 43 | class focal_loss_multi(nn.Module): 44 | def __init__(self, alpha=[0.25,0.75], gamma=2, num_classes = 2, size_average=True): 45 | 46 | super(focal_loss_multi,self).__init__() 47 | self.size_average = size_average 48 | if isinstance(alpha,list): 49 | assert len(alpha)==num_classes 50 | print("Focal_loss_multi alpha = {}".format(alpha)) 51 | self.alpha = torch.Tensor(alpha) 52 | else: 53 | assert alpha<1 54 | print(" --- Focal_loss_multi alpha = {} ".format(alpha)) 55 | self.alpha = torch.zeros(num_classes) 56 | self.alpha[0] += alpha 57 | self.alpha[1:] += (1-alpha) 58 | 
self.gamma = gamma 59 | 60 | def forward(self, preds, labels): 61 | """ 62 | 63 | """ 64 | # assert preds.dim()==2 and labels.dim()==1 65 | preds = preds.view(-1,preds.size(-1)) 66 | self.alpha = self.alpha.to(preds.device) 67 | preds_softmax = F.softmax(preds, dim=1) 68 | preds_logsoft = torch.log(preds_softmax) 69 | preds_softmax = preds_softmax.gather(1,labels.view(-1,1)) 70 | preds_logsoft = preds_logsoft.gather(1,labels.view(-1,1)) 71 | self.alpha = self.alpha.gather(0,labels.view(-1)) 72 | loss = -torch.mul(torch.pow((1-preds_softmax), self.gamma), preds_logsoft) 73 | loss = torch.mul(self.alpha, loss.t()) 74 | if self.size_average: 75 | loss = loss.mean() 76 | else: 77 | loss = loss.sum() 78 | return loss 79 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from config import * 4 | from torchcrf import CRF 5 | 6 | 7 | class DeepPunctuation(nn.Module): 8 | def __init__(self, pretrained_model, freeze_bert=False, lstm_dim=-1,punctuation_dict = {'O': 0, 'COMMA': 1, 'PERIOD': 2, 'QUESTION': 3,'EXCLAMATION': 4}): 9 | super(DeepPunctuation, self).__init__() 10 | self.output_dim = len(punctuation_dict) 11 | #self.bert_layer = MODELS[pretrained_model][0].from_pretrained(pretrained_model) 12 | #self.bert_layer = MODELS[pretrained_model][0].from_pretrained("../../../xlm-roberta-large") 13 | self.bert_layer = MODELS[pretrained_model][0].from_pretrained("../../../"+pretrained_model) 14 | # Freeze bert layers 15 | if freeze_bert: 16 | for p in self.bert_layer.parameters(): 17 | p.requires_grad = False 18 | bert_dim = MODELS[pretrained_model][2] 19 | if lstm_dim == -1: 20 | hidden_size = bert_dim 21 | else: 22 | hidden_size = lstm_dim 23 | self.lstm = nn.LSTM(input_size=bert_dim, hidden_size=hidden_size, num_layers=1, bidirectional=True) 24 | self.linear = nn.Linear(in_features=hidden_size*2, 
out_features=len(punctuation_dict)) 25 | 26 | def forward(self, x, attn_masks): 27 | if len(x.shape) == 1: 28 | x = x.view(1, x.shape[0]) # add dummy batch for single sample 29 | # (B, N, E) -> (B, N, E) 30 | x = self.bert_layer(x, attention_mask=attn_masks)[0] 31 | # (B, N, E) -> (N, B, E) 32 | x = torch.transpose(x, 0, 1) 33 | x, (_, _) = self.lstm(x) 34 | # (N, B, E) -> (B, N, E) 35 | x = torch.transpose(x, 0, 1) 36 | x = self.linear(x) 37 | return x 38 | 39 | 40 | class DeepPunctuationCRF(nn.Module): 41 | def __init__(self, pretrained_model, freeze_bert=False, lstm_dim=-1,punctuation_dict = {'O': 0, 'COMMA': 1, 'PERIOD': 2, 'QUESTION': 3,'EXCLAMATION': 4}): 42 | super(DeepPunctuationCRF, self).__init__() 43 | self.bert_lstm = DeepPunctuation(pretrained_model, freeze_bert, lstm_dim) 44 | self.crf = CRF(len(punctuation_dict), batch_first=True) 45 | 46 | def log_likelihood(self, x, attn_masks, y): 47 | x = self.bert_lstm(x, attn_masks) 48 | attn_masks = attn_masks.byte() 49 | return -self.crf(x, y, mask=attn_masks, reduction='token_mean') 50 | 51 | def forward(self, x, attn_masks, y): 52 | if len(x.shape) == 1: 53 | x = x.view(1, x.shape[0]) # add dummy batch for single sample 54 | x = self.bert_lstm(x, attn_masks) 55 | attn_masks = attn_masks.byte() 56 | dec_out = self.crf.decode(x, mask=attn_masks) 57 | y_pred = torch.zeros(y.shape).long().to(y.device) 58 | for i in range(len(dec_out)): 59 | y_pred[i, :len(dec_out[i])] = torch.tensor(dec_out[i]).to(y.device) 60 | return y_pred 61 | 62 | class DeepPunctuationCRF_jyupin(nn.Module): 63 | def __init__(self, pretrained_model, freeze_bert=False, lstm_dim=-1,punctuation_dict = {'O': 0, 'COMMA': 1, 'PERIOD': 2, 'QUESTION': 3,'EXCLAMATION': 4}): 64 | super(DeepPunctuationCRF_jyupin, self).__init__() 65 | self.bert_lstm = DeepPunctuation_jyupin(pretrained_model, freeze_bert, lstm_dim) 66 | self.crf = CRF(len(punctuation_dict), batch_first=True) 67 | 68 | def log_likelihood(self, x, attn_masks, y, jyupin): 69 | x = 
self.bert_lstm(x, attn_masks, jyupin) 70 | attn_masks = attn_masks.byte() 71 | return -self.crf(x, y, mask=attn_masks, reduction='token_mean') 72 | 73 | def forward(self, x, attn_masks, y, jyupin): 74 | if len(x.shape) == 1: 75 | x = x.view(1, x.shape[0]) # add dummy batch for single sample 76 | x = self.bert_lstm(x, attn_masks, jyupin) 77 | attn_masks = attn_masks.byte() 78 | dec_out = self.crf.decode(x, mask=attn_masks) 79 | y_pred = torch.zeros(y.shape).long().to(y.device) 80 | for i in range(len(dec_out)): 81 | y_pred[i, :len(dec_out[i])] = torch.tensor(dec_out[i]).to(y.device) 82 | return y_pred 83 | 84 | 85 | class DeepPunctuation_multitask_parallel(nn.Module): 86 | def __init__(self, pretrained_model, freeze_bert=False, lstm_dim=-1,punctuation_dict = {'O': 0, 'COMMA': 1, 'PERIOD': 2, 'QUESTION': 3,'EXCLAMATION': 4}): 87 | super(DeepPunctuation_multitask, self).__init__() 88 | self.output_dim = len(punctuation_dict) 89 | self.output_dim_multitask = 2 #change 90 | #self.bert_layer = MODELS[pretrained_model][0].from_pretrained(pretrained_model) 91 | #self.bert_layer = MODELS[pretrained_model][0].from_pretrained("../../../xlm-roberta-large") 92 | self.bert_layer = MODELS[pretrained_model][0].from_pretrained("../../../"+pretrained_model) 93 | # Freeze bert layers 94 | if freeze_bert: 95 | for p in self.bert_layer.parameters(): 96 | p.requires_grad = False 97 | bert_dim = MODELS[pretrained_model][2] 98 | if lstm_dim == -1: 99 | hidden_size = bert_dim 100 | else: 101 | hidden_size = lstm_dim 102 | self.lstm = nn.LSTM(input_size=bert_dim, hidden_size=hidden_size, num_layers=1, bidirectional=True) 103 | self.linear = nn.Linear(in_features=hidden_size*2, out_features=len(punctuation_dict)) 104 | self.linear_multitask = nn.Linear(in_features=hidden_size*2, out_features=2) 105 | 106 | def forward(self, x, attn_masks): 107 | if len(x.shape) == 1: 108 | x = x.view(1, x.shape[0]) # add dummy batch for single sample 109 | # (B, N, E) -> (B, N, E) 110 | x = 
self.bert_layer(x, attention_mask=attn_masks)[0] 111 | # (B, N, E) -> (N, B, E) 112 | x = torch.transpose(x, 0, 1) 113 | x, (_, _) = self.lstm(x) 114 | # (N, B, E) -> (B, N, E) 115 | x = torch.transpose(x, 0, 1) 116 | ''' 117 | print(x) 118 | print(x.size()) 119 | ''' 120 | multitask_input=x[:,0,:] 121 | x_multitask = self.linear_multitask(multitask_input) 122 | #print(x[:,0,:].size()) 123 | #print(x.size()) 124 | x = self.linear(x) 125 | 126 | #print(x.size()) 127 | return x,x_multitask 128 | 129 | 130 | class DeepPunctuation_jyupin(nn.Module): 131 | def __init__(self, pretrained_model, freeze_bert=False, lstm_dim=-1,punctuation_dict = {'O': 0, 'COMMA': 1, 'PERIOD': 2, 'QUESTION': 3,'EXCLAMATION': 4}): 132 | super(DeepPunctuation_jyupin, self).__init__() 133 | self.output_dim = len(punctuation_dict) 134 | #self.bert_layer = MODELS[pretrained_model][0].from_pretrained(pretrained_model) 135 | #self.bert_layer = MODELS[pretrained_model][0].from_pretrained("../../../xlm-roberta-large") 136 | self.bert_layer = MODELS[pretrained_model][0].from_pretrained("../../../"+pretrained_model) 137 | # Freeze bert layers 138 | if freeze_bert: 139 | for p in self.bert_layer.parameters(): 140 | p.requires_grad = False 141 | 142 | bert_dim = MODELS[pretrained_model][2] 143 | self.embedding = nn.Embedding(num_embeddings=31855,embedding_dim=768,padding_idx=31854) 144 | if lstm_dim == -1: 145 | hidden_size = bert_dim 146 | else: 147 | hidden_size = lstm_dim 148 | self.lstm = nn.LSTM(input_size=bert_dim, hidden_size=hidden_size, num_layers=1, bidirectional=True) 149 | self.linear = nn.Linear(in_features=hidden_size*2, out_features=len(punctuation_dict)) 150 | 151 | def forward(self, x, attn_masks, jyupin): 152 | if len(x.shape) == 1: 153 | x = x.view(1, x.shape[0]) # add dummy batch for single sample 154 | # (B, N, E) -> (B, N, E) 155 | 156 | x = self.bert_layer(x, attention_mask=attn_masks)[0] 157 | jyupin_embedding = self.embedding(jyupin) 158 | 159 | x += jyupin_embedding 160 | # (B, 
N, E) -> (N, B, E) 161 | x = torch.transpose(x, 0, 1) 162 | x, (_, _) = self.lstm(x) 163 | # (N, B, E) -> (B, N, E) 164 | x = torch.transpose(x, 0, 1) 165 | x = self.linear(x) 166 | return x 167 | 168 | 169 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /apache.org_licenses_LICENSE-2.0.txt: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. 
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | from torch.utils import data 6 | import torch.multiprocessing 7 | from tqdm import tqdm 8 | from focalloss import focal_loss 9 | 10 | 11 | from argparser import parse_arguments 12 | from dataset import Dataset_cant,Dataset_cant_jyupin, Dataset_cant_multitask_sequence 13 | from model import DeepPunctuation,DeepPunctuationCRF, DeepPunctuationCRF_jyupin, DeepPunctuation_jyupin 14 | from config import * 15 | import augmentation 16 | from transformers import XLMModel, XLMTokenizer,XLMRobertaModel,XLMRobertaTokenizer 17 | 18 | torch.multiprocessing.set_sharing_strategy('file_system') # https://github.com/pytorch/pytorch/issues/11201 19 | 20 | args = parse_arguments() 21 | 22 | # for reproducibility 23 | torch.manual_seed(args.seed) 24 | torch.backends.cudnn.deterministic = True 25 | torch.backends.cudnn.benchmark = False 26 | np.random.seed(args.seed) 27 | 28 | # tokenizer 
# tokenizer
# The original always loaded from a hard-coded local checkpoint directory
# ("../../../<model-name>"), which breaks the README-documented invocation
# when no local copy exists.  Keep the local directory when present, but
# fall back to the model identifier itself (resolved from the HuggingFace
# cache/hub) -- backward compatible with existing setups.
_local_model_dir = os.path.join('..', '..', '..', args.pretrained_model)
tokenizer = MODELS[args.pretrained_model][1].from_pretrained(
    _local_model_dir if os.path.isdir(_local_model_dir) else args.pretrained_model)

# Expose the tokenizer and augmentation hyper-parameters to the augmentation module.
augmentation.tokenizer = tokenizer
augmentation.sub_style = args.sub_style
augmentation.alpha_sub = args.alpha_sub
augmentation.alpha_del = args.alpha_del

token_style = MODELS[args.pretrained_model][3]
ar = args.augment_rate
sequence_len = args.sequence_length
aug_type = args.augment_type

# Label inventory: 5 token-level punctuation classes; multi-task training
# adds 2 sequence-level style classes (INFORMAL/FORMAL).
punctuation_dict = {'O': 0, 'COMMA': 1, 'PERIOD': 2, 'QUESTION': 3, 'EXCLAMATION': 4}
if args.multitask:
    punctuation_dict = {'O': 0, 'COMMA': 1, 'PERIOD': 2, 'QUESTION': 3,
                        'EXCLAMATION': 4, 'INFORMAL': 5, 'FORMAL': 6}


# Datasets
def _make_splits(dataset_cls, subdir, train_file, **extra):
    """Build (train, val, test) datasets from <data_path>/<subdir>/ with shared settings."""
    common = dict(tokenizer=tokenizer, sequence_len=sequence_len,
                  token_style=token_style, **extra)
    train = dataset_cls(os.path.join(args.data_path, subdir, train_file),
                        is_train=True, augment_rate=ar, augment_type=aug_type, **common)
    val = dataset_cls(os.path.join(args.data_path, subdir, 'val.txt'),
                      is_train=False, **common)
    test = dataset_cls(os.path.join(args.data_path, subdir, 'test.txt'),
                       is_train=False, **common)
    return train, val, test


if args.jyutping:
    if args.multitask:
        train_set, val_set, test_set_ref = _make_splits(
            Dataset_cant_jyupin, 'cant_jyupin_multi', 'train_wiki.txt', multitask=True)
    else:
        train_set, val_set, test_set_ref = _make_splits(
            Dataset_cant_jyupin, 'cant_jyupin', 'train_wiki.txt')
elif args.multitask:
    train_set, val_set, test_set_ref = _make_splits(
        Dataset_cant_multitask_sequence, 'cant_multitask', 'train.txt')
else:
    train_set, val_set, test_set_ref = _make_splits(
        Dataset_cant, 'cant_informal_wiki', 'train.txt')
test_set = [val_set, test_set_ref]

# Data Loaders
# NOTE(review): shuffle=True is also applied to the val/test loaders, as in
# the original -- harmless for aggregate metrics but worth confirming.
data_loader_params = {
    'batch_size': args.batch_size,
    'shuffle': True,
    'num_workers': 1
}
train_loader = torch.utils.data.DataLoader(train_set, **data_loader_params)
val_loader = torch.utils.data.DataLoader(val_set, **data_loader_params)
test_loaders = [torch.utils.data.DataLoader(x, **data_loader_params) for x in test_set]

# logs / checkpoints.  File names kept byte-identical to the original so
# existing artifacts stay discoverable (including the missing '_' before
# 'weights.pt' -- intentional no-change).
os.makedirs(args.save_path, exist_ok=True)
model_save_path = os.path.join(args.save_path, args.name+'weights.pt')
log_path = os.path.join(args.save_path, args.name + '_logs.txt')


# Model
device = torch.device('cuda' if (args.cuda and torch.cuda.is_available()) else 'cpu')
# Pick the architecture from the (use_crf, jyutping) flags; bool() mirrors
# the original's truthiness tests on these arguments.
_MODEL_CLASSES = {
    (True, True): DeepPunctuationCRF_jyupin,
    (True, False): DeepPunctuationCRF,
    (False, True): DeepPunctuation_jyupin,
    (False, False): DeepPunctuation,
}
deep_punctuation = _MODEL_CLASSES[(bool(args.use_crf), bool(args.jyutping))](
    args.pretrained_model, freeze_bert=args.freeze_bert,
    lstm_dim=args.lstm_dim, punctuation_dict=punctuation_dict)
deep_punctuation.to(device)


# loss: focal loss down-weights the dominant classes; the alpha vector grows
# with the two extra multi-task classes.
criterion = nn.CrossEntropyLoss()
if args.loss == 'CE':
    criterion = nn.CrossEntropyLoss()
elif args.loss == 'focal':
    if args.multitask:
        criterion = focal_loss(alpha=[0.75, 0.75, 0.75, 0.25, 0.25, 0.75, 0.75], num_classes=7)
    else:
        criterion = focal_loss(alpha=[0.75, 0.75, 0.75, 0.25, 0.25], num_classes=5)


optimizer = torch.optim.Adam(deep_punctuation.parameters(), lr=args.lr, weight_decay=args.decay)


def validate(data_loader):
    """
    Evaluate the current model on *data_loader*.

    :return: (validation accuracy, mean validation loss)
    """
    num_iteration = 0
    deep_punctuation.eval()
    correct = 0
    total = 0
    val_loss = 0
    with torch.no_grad():
        for batch in tqdm(data_loader, desc='eval'):
            if args.jyutping:
                x, y, att, y_mask, jyupin = (t.to(device) for t in batch)
            else:
                x, y, att, y_mask = (t.to(device) for t in batch)
                jyupin = None
            y_mask = y_mask.view(-1)
            if args.use_crf:
                # The CRF forward pass already returns decoded label ids;
                # log_likelihood supplies the loss.
                y_predict = (deep_punctuation(x, att, y, jyupin) if jyupin is not None
                             else deep_punctuation(x, att, y))
                loss = (deep_punctuation.log_likelihood(x, att, y, jyupin) if jyupin is not None
                        else deep_punctuation.log_likelihood(x, att, y))
                y_predict = y_predict.view(-1)
                y = y.view(-1)
            else:
                y_predict = (deep_punctuation(x, att, jyupin) if jyupin is not None
                             else deep_punctuation(x, att))
                y = y.view(-1)
                y_predict = y_predict.view(-1, y_predict.shape[2])
                loss = criterion(y_predict, y)
                y_predict = torch.argmax(y_predict, dim=1).view(-1)
            val_loss += loss.item()
            num_iteration += 1
            y_mask = y_mask.view(-1)
            if jyupin is not None:
                correct += torch.sum(y_mask * (y_predict == y).long()).item()
            else:
                # NOTE(review): this branch excludes position 0 (the
                # sequence-level token) from the accuracy count while the
                # jyutping branch does not; behaviour preserved as-is --
                # confirm whether the asymmetry is intentional.
                correct += torch.sum(y_mask[1:] * (y_predict[1:] == y[1:]).long()).item()
            total += torch.sum(y_mask).item()
    return correct / total, val_loss / num_iteration
| def test(data_loader): 198 | """ 199 | :return: precision[numpy array], recall[numpy array], f1 score [numpy array], accuracy, confusion matrix 200 | """ 201 | num_iteration = 0 202 | deep_punctuation.eval() 203 | # +1 for overall result 204 | tp = np.zeros(1+len(punctuation_dict), dtype=np.int) 205 | fp = np.zeros(1+len(punctuation_dict), dtype=np.int) 206 | fn = np.zeros(1+len(punctuation_dict), dtype=np.int) 207 | cm = np.zeros((len(punctuation_dict), len(punctuation_dict)), dtype=np.int) 208 | correct = 0 209 | total = 0 210 | with torch.no_grad(): 211 | if args.jyutping: 212 | for x, y, att, y_mask, jyupin in tqdm(data_loader, desc='test'): 213 | x, y, att, y_mask, jyupin = x.to(device), y.to(device), att.to(device), y_mask.to(device), jyupin.to(device) 214 | y_mask = y_mask.view(-1) 215 | if args.use_crf: 216 | y_predict = deep_punctuation(x, att, y, jyupin) 217 | y_predict = y_predict.view(-1) 218 | y = y.view(-1) 219 | else: 220 | y_predict = deep_punctuation(x, att, jyupin) 221 | y = y.view(-1) 222 | y_predict = y_predict.view(-1, y_predict.shape[2]) 223 | y_predict = torch.argmax(y_predict, dim=1).view(-1) 224 | num_iteration += 1 225 | y_mask = y_mask.view(-1) 226 | correct += torch.sum(y_mask * (y_predict == y).long()).item() 227 | total += torch.sum(y_mask).item() 228 | for i in range(y.shape[0]): 229 | if y_mask[i] == 0: 230 | # we can ignore this because we know there won't be any punctuation in this position 231 | # since we created this position due to padding or sub-word tokenization 232 | continue 233 | cor = y[i] 234 | prd = y_predict[i] 235 | if cor == prd: 236 | tp[cor] += 1 237 | else: 238 | fn[cor] += 1 239 | fp[prd] += 1 240 | cm[cor][prd] += 1 241 | else: 242 | for x, y, att, y_mask in tqdm(data_loader, desc='test'): 243 | x, y, att, y_mask = x.to(device), y.to(device), att.to(device), y_mask.to(device) 244 | y_mask = y_mask.view(-1) 245 | if args.use_crf: 246 | y_predict = deep_punctuation(x, att, y) 247 | y_predict = y_predict.view(-1) 
248 | y = y.view(-1) 249 | else: 250 | y_predict = deep_punctuation(x, att) 251 | y = y.view(-1) 252 | y_predict = y_predict.view(-1, y_predict.shape[2]) 253 | y_predict = torch.argmax(y_predict, dim=1).view(-1) 254 | num_iteration += 1 255 | y_mask = y_mask.view(-1) 256 | correct += torch.sum(y_mask * (y_predict == y).long()).item() 257 | total += torch.sum(y_mask).item() 258 | for i in range(y.shape[0]): 259 | if y_mask[i] == 0: 260 | # we can ignore this because we know there won't be any punctuation in this position 261 | # since we created this position due to padding or sub-word tokenization 262 | continue 263 | cor = y[i] 264 | prd = y_predict[i] 265 | if cor == prd: 266 | tp[cor] += 1 267 | else: 268 | fn[cor] += 1 269 | fp[prd] += 1 270 | cm[cor][prd] += 1 271 | 272 | # ignore first index which is for no punctuation 273 | tp[-1] = np.sum(tp[1:5]) 274 | fp[-1] = np.sum(fp[1:5]) 275 | fn[-1] = np.sum(fn[1:5]) 276 | precision = tp/(tp+fp) 277 | recall = tp/(tp+fn) 278 | f1 = 2 * precision * recall / (precision + recall) 279 | 280 | return precision, recall, f1, correct/total, cm 281 | 282 | 283 | def train(): 284 | with open(log_path, 'a') as f: 285 | f.write(str(args)+'\n') 286 | best_val_acc = 0 287 | for epoch in range(args.epoch): 288 | train_loss = 0.0 289 | train_iteration = 0 290 | correct = 0 291 | total = 0 292 | deep_punctuation.train() 293 | if args.jyutping: 294 | for x, y, att, y_mask, jyupin in tqdm(train_loader, desc='train'): 295 | x, y, att, y_mask, jyupin= x.to(device), y.to(device), att.to(device), y_mask.to(device), jyupin.to(device) 296 | y_mask = y_mask.view(-1) 297 | if args.use_crf: 298 | loss = deep_punctuation.log_likelihood(x, att, y, jyupin) 299 | y_predict = deep_punctuation(x, att, y, jyupin) 300 | y_predict = y_predict.view(-1) 301 | y = y.view(-1) 302 | else: 303 | y_predict = deep_punctuation(x, att, jyupin) 304 | y_predict = y_predict.view(-1, y_predict.shape[2]) 305 | y = y.view(-1) 306 | loss = criterion(y_predict, y) 307 | 
#loss = criterion(y_predict, y)+criterion1(y_predict, y) 308 | y_predict = torch.argmax(y_predict, dim=1).view(-1) 309 | 310 | correct += torch.sum(y_mask * (y_predict == y).long()).item() 311 | 312 | optimizer.zero_grad() 313 | train_loss += loss.item() 314 | train_iteration += 1 315 | loss.backward() 316 | 317 | if args.gradient_clip > 0: 318 | torch.nn.utils.clip_grad_norm_(deep_punctuation.parameters(), args.gradient_clip) 319 | optimizer.step() 320 | 321 | y_mask = y_mask.view(-1) 322 | 323 | total += torch.sum(y_mask).item() 324 | else: 325 | for x, y, att, y_mask in tqdm(train_loader, desc='train'): 326 | x, y, att, y_mask = x.to(device), y.to(device), att.to(device), y_mask.to(device) 327 | y_mask = y_mask.view(-1) 328 | if args.use_crf: 329 | loss = deep_punctuation.log_likelihood(x, att, y) 330 | # y_predict = deep_punctuation(x, att, y) 331 | # y_predict = y_predict.view(-1) 332 | y = y.view(-1) 333 | else: 334 | y_predict = deep_punctuation(x, att) 335 | y_predict = y_predict.view(-1, y_predict.shape[2]) 336 | y = y.view(-1) 337 | loss = criterion(y_predict, y) 338 | #loss = criterion(y_predict, y)+criterion1(y_predict, y) 339 | y_predict = torch.argmax(y_predict, dim=1).view(-1) 340 | 341 | correct += torch.sum(y_mask * (y_predict == y).long()).item() 342 | 343 | optimizer.zero_grad() 344 | train_loss += loss.item() 345 | train_iteration += 1 346 | loss.backward() 347 | 348 | if args.gradient_clip > 0: 349 | torch.nn.utils.clip_grad_norm_(deep_punctuation.parameters(), args.gradient_clip) 350 | optimizer.step() 351 | 352 | y_mask = y_mask.view(-1) 353 | 354 | total += torch.sum(y_mask).item() 355 | train_loss /= train_iteration 356 | log = 'epoch: {}, Train loss: {}, Train accuracy: {}'.format(epoch, train_loss, correct / total) 357 | with open(log_path, 'a') as f: 358 | f.write(log + '\n') 359 | print(log) 360 | 361 | val_acc, val_loss = validate(val_loader) 362 | log = 'epoch: {}, Val loss: {}, Val accuracy: {}'.format(epoch, val_loss, val_acc) 363 | 
with open(log_path, 'a') as f: 364 | f.write(log + '\n') 365 | print(log) 366 | if val_acc > best_val_acc: 367 | best_val_acc = val_acc 368 | torch.save(deep_punctuation.state_dict(), model_save_path) 369 | 370 | print('Best validation Acc:', best_val_acc) 371 | deep_punctuation.load_state_dict(torch.load(model_save_path)) 372 | for loader in test_loaders: 373 | precision, recall, f1, accuracy, cm = test(loader) 374 | log = 'Precision: ' + str(precision) + '\n' + 'Recall: ' + str(recall) + '\n' + 'F1 score: ' + str(f1) + \ 375 | '\n' + 'Accuracy:' + str(accuracy) + '\n' + 'Confusion Matrix' + str(cm) + '\n' 376 | print(log) 377 | with open(log_path, 'a') as f: 378 | f.write(log) 379 | log_text = '' 380 | for i in range(1, 5): 381 | log_text += str(precision[i] * 100) + ' ' + str(recall[i] * 100) + ' ' + str(f1[i] * 100) + ' ' 382 | with open(log_path, 'a') as f: 383 | f.write(log_text[:-1] + '\n\n') 384 | 385 | 386 | 387 | 388 | if __name__ == '__main__': 389 | train() 390 | -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from config import * 3 | from augmentation import * 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | 8 | def parse_data(file_path, tokenizer, sequence_len, token_style): 9 | """ 10 | 11 | :param file_path: text file path that contains tokens and punctuations separated by tab in lines 12 | :param tokenizer: tokenizer that will be used to further tokenize word for BERT like models 13 | :param sequence_len: maximum length of each sequence 14 | :param token_style: For getting index of special tokens in config.TOKEN_IDX 15 | :return: list of [tokens_index, punctuation_index, attention_masks, punctuation_mask], each having sequence_len 16 | punctuation_mask is used to ignore special indices like padding and intermediate sub-word token during evaluation 17 | """ 18 | punctuation_dict = {'O': 0, 
'COMMA': 1, 'PERIOD': 2, 'QUESTION': 3,'EXCLAMATION': 4} 19 | data_items = [] 20 | with open(file_path, 'r', encoding='utf-8') as f: 21 | lines = [line for line in f.read().split('\n') if line.strip()] 22 | idx = 0 23 | # loop until end of the entire text 24 | while idx < len(lines): 25 | x = [TOKEN_IDX[token_style]['START_SEQ']] 26 | y = [0] 27 | y_mask = [1] # which positions we need to consider while evaluating i.e., ignore pad or sub tokens 28 | 29 | # loop until we have required sequence length 30 | # -1 because we will have a special end of sequence token at the end 31 | while len(x) < sequence_len - 1 and idx < len(lines): 32 | word, punc = lines[idx].split(' ') 33 | tokens = tokenizer.tokenize(word) 34 | # if taking these tokens exceeds sequence length we finish current sequence with padding 35 | # then start next sequence from this token 36 | if len(tokens) + len(x) >= sequence_len: 37 | break 38 | else: 39 | for i in range(len(tokens) - 1): 40 | x.append(tokenizer.convert_tokens_to_ids(tokens[i])) 41 | y.append(0) 42 | y_mask.append(0) 43 | if len(tokens) > 0: 44 | x.append(tokenizer.convert_tokens_to_ids(tokens[-1])) 45 | else: 46 | x.append(TOKEN_IDX[token_style]['UNK']) 47 | y.append(punctuation_dict[punc]) 48 | y_mask.append(1) 49 | idx += 1 50 | 51 | 52 | x.append(TOKEN_IDX[token_style]['END_SEQ']) 53 | y.append(0) 54 | y_mask.append(1) 55 | if len(x) < sequence_len: 56 | x = x + [TOKEN_IDX[token_style]['PAD'] for _ in range(sequence_len - len(x))] 57 | y = y + [0 for _ in range(sequence_len - len(y))] 58 | y_mask = y_mask + [0 for _ in range(sequence_len - len(y_mask))] 59 | attn_mask = [1 if token != TOKEN_IDX[token_style]['PAD'] else 0 for token in x] 60 | data_items.append([x, y, attn_mask, y_mask]) 61 | return data_items 62 | 63 | def parse_data_cant(file_path, tokenizer, sequence_len, token_style): 64 | """ 65 | 66 | :param file_path: text file path that contains tokens and punctuations separated by tab in lines 67 | :param tokenizer: tokenizer 
that will be used to further tokenize word for BERT like models 68 | :param sequence_len: maximum length of each sequence 69 | :param token_style: For getting index of special tokens in config.TOKEN_IDX 70 | :return: list of [tokens_index, punctuation_index, attention_masks, punctuation_mask], each having sequence_len 71 | punctuation_mask is used to ignore special indices like padding and intermediate sub-word token during evaluation 72 | """ 73 | punctuation_dict = {'O': 0, 'COMMA': 1, 'PERIOD': 2, 'QUESTION': 3,'EXCLAMATION': 4} 74 | data_items = [] 75 | with open(file_path, 'r', encoding='utf-8') as f: 76 | lines = [line for line in f.read().split('\n') if line.strip()] 77 | #print(lines) 78 | idx = 0 79 | # loop until end of the entire text 80 | while idx < len(lines): 81 | x = [TOKEN_IDX[token_style]['START_SEQ']] 82 | y = [0] 83 | y_mask = [1] # which positions we need to consider while evaluating i.e., ignore pad or sub tokens 84 | 85 | # loop until we have required sequence length 86 | # -1 because we will have a special end of sequence token at the end 87 | while len(x) < sequence_len - 1 and idx < len(lines): # -2 for multitask -1 for normal 88 | word, punc = lines[idx].split(' ') #原來是/t 89 | #print(word) 90 | #print(punc) 91 | #exit() 92 | tokens = tokenizer.tokenize(word) 93 | # if taking these tokens exceeds sequence length we finish current sequence with padding 94 | # then start next sequence from this token 95 | if len(tokens) + len(x) >= sequence_len: 96 | break 97 | else: 98 | for i in range(len(tokens) - 1): 99 | x.append(tokenizer.convert_tokens_to_ids(tokens[i])) 100 | y.append(0) 101 | y_mask.append(0) 102 | if len(tokens) > 0: 103 | x.append(tokenizer.convert_tokens_to_ids(tokens[-1])) 104 | else: 105 | x.append(TOKEN_IDX[token_style]['UNK']) 106 | y.append(punctuation_dict[punc]) 107 | y_mask.append(1) 108 | idx += 1 109 | #yunxiang for multi-task 110 | if punctuation_dict[punc] in [2,3,4]: 111 | break 112 | 113 | 
x.append(TOKEN_IDX[token_style]['END_SEQ']) 114 | y.append(0) 115 | y_mask.append(1) 116 | 117 | if len(x) < sequence_len: 118 | x = x + [TOKEN_IDX[token_style]['PAD'] for _ in range(sequence_len - len(x))] 119 | y = y + [0 for _ in range(sequence_len - len(y))] 120 | y_mask = y_mask + [0 for _ in range(sequence_len - len(y_mask))] 121 | attn_mask = [1 if token != TOKEN_IDX[token_style]['PAD'] else 0 for token in x] 122 | data_items.append([x, y, attn_mask, y_mask]) 123 | 124 | return data_items 125 | 126 | def parse_data_cant_multitask_sequence(file_path, tokenizer, sequence_len, token_style): 127 | """ 128 | 129 | :param file_path: text file path that contains tokens and punctuations separated by tab in lines 130 | :param tokenizer: tokenizer that will be used to further tokenize word for BERT like models 131 | :param sequence_len: maximum length of each sequence 132 | :param token_style: For getting index of special tokens in config.TOKEN_IDX 133 | :return: list of [tokens_index, punctuation_index, attention_masks, punctuation_mask], each having sequence_len 134 | punctuation_mask is used to ignore special indices like padding and intermediate sub-word token during evaluation 135 | """ 136 | punctuation_dict = {'O': 0, 'COMMA': 1, 'PERIOD': 2, 'QUESTION': 3,'EXCLAMATION': 4,'INFORMAL':5,'FORMAL':6} 137 | data_items = [] 138 | with open(file_path, 'r', encoding='utf-8') as f: 139 | lines = [line for line in f.read().split('\n') if line.strip()] 140 | #print(lines) 141 | idx = 0 142 | # loop until end of the entire text 143 | while idx < len(lines): 144 | x = [TOKEN_IDX[token_style]['START_SEQ']] 145 | y = [5] 146 | 147 | y_mask = [1] # which positions we need to consider while evaluating i.e., ignore pad or sub tokens 148 | 149 | # loop until we have required sequence length 150 | # -1 because we will have a special end of sequence token at the end 151 | while len(x) < sequence_len - 1 and idx < len(lines): # -2 for multitask -1 for normal 152 | word, punc, 
scenario = lines[idx].split(' ') #原來是/t 153 | if int(scenario)==0: 154 | y[0] = 5 155 | else: 156 | y[0] = 6 157 | 158 | tokens = tokenizer.tokenize(word) 159 | # if taking these tokens exceeds sequence length we finish current sequence with padding 160 | # then start next sequence from this token 161 | if len(tokens) + len(x) >= sequence_len: 162 | break 163 | else: 164 | for i in range(len(tokens) - 1): 165 | x.append(tokenizer.convert_tokens_to_ids(tokens[i])) 166 | y.append(0) 167 | y_mask.append(0) 168 | if len(tokens) > 0: 169 | x.append(tokenizer.convert_tokens_to_ids(tokens[-1])) 170 | else: 171 | x.append(TOKEN_IDX[token_style]['UNK']) 172 | y.append(punctuation_dict[punc]) 173 | y_mask.append(1) 174 | idx += 1 175 | #yunxiang for multi-task 176 | if punctuation_dict[punc] in [2,3,4]: 177 | break 178 | 179 | x.append(TOKEN_IDX[token_style]['END_SEQ']) 180 | y.append(0) 181 | y_mask.append(1) 182 | 183 | if len(x) < sequence_len: 184 | x = x + [TOKEN_IDX[token_style]['PAD'] for _ in range(sequence_len - len(x))] 185 | y = y + [0 for _ in range(sequence_len - len(y))] 186 | y_mask = y_mask + [0 for _ in range(sequence_len - len(y_mask))] 187 | attn_mask = [1 if token != TOKEN_IDX[token_style]['PAD'] else 0 for token in x] 188 | data_items.append([x, y, attn_mask, y_mask]) 189 | 190 | return data_items 191 | 192 | def parse_data_cant_multitask(file_path, tokenizer, sequence_len, token_style): 193 | """ 194 | 195 | :param file_path: text file path that contains tokens and punctuations separated by tab in lines 196 | :param tokenizer: tokenizer that will be used to further tokenize word for BERT like models 197 | :param sequence_len: maximum length of each sequence 198 | :param token_style: For getting index of special tokens in config.TOKEN_IDX 199 | :return: list of [tokens_index, punctuation_index, attention_masks, punctuation_mask], each having sequence_len 200 | punctuation_mask is used to ignore special indices like padding and intermediate sub-word token 
during evaluation 201 | """ 202 | punctuation_dict = {'O': 0, 'COMMA': 1, 'PERIOD': 2, 'QUESTION': 3,'EXCLAMATION': 4} 203 | data_items = [] 204 | with open(file_path, 'r', encoding='utf-8') as f: 205 | lines = [line for line in f.read().split('\n') if line.strip()] 206 | #print(lines) 207 | idx = 0 208 | # loop until end of the entire text 209 | while idx < len(lines): 210 | x = [TOKEN_IDX[token_style]['START_SEQ']] 211 | y = [0] 212 | y_multitask = [0] 213 | y_mask = [1] # which positions we need to consider while evaluating i.e., ignore pad or sub tokens 214 | 215 | # loop until we have required sequence length 216 | # -1 because we will have a special end of sequence token at the end 217 | while len(x) < sequence_len - 1 and idx < len(lines): #-2 for multitask -1 for normal 218 | word, punc, scenario = lines[idx].split(' ') #原來是/t 219 | y_multitask[0] = int(scenario) 220 | 221 | tokens = tokenizer.tokenize(word) 222 | # if taking these tokens exceeds sequence length we finish current sequence with padding 223 | # then start next sequence from this token 224 | if len(tokens) + len(x) >= sequence_len: 225 | break 226 | else: 227 | for i in range(len(tokens) - 1): 228 | x.append(tokenizer.convert_tokens_to_ids(tokens[i])) 229 | y.append(0) 230 | y_mask.append(0) 231 | if len(tokens) > 0: 232 | x.append(tokenizer.convert_tokens_to_ids(tokens[-1])) 233 | else: 234 | x.append(TOKEN_IDX[token_style]['UNK']) 235 | y.append(punctuation_dict[punc]) 236 | y_mask.append(1) 237 | idx += 1 238 | #yunxiang for multi-task 239 | if punctuation_dict[punc] in [2,3,4]: 240 | break 241 | 242 | x.append(TOKEN_IDX[token_style]['END_SEQ']) 243 | y.append(0) 244 | y_mask.append(1) 245 | 246 | if len(x) < sequence_len: 247 | x = x + [TOKEN_IDX[token_style]['PAD'] for _ in range(sequence_len - len(x))] 248 | y = y + [0 for _ in range(sequence_len - len(y))] 249 | y_mask = y_mask + [0 for _ in range(sequence_len - len(y_mask))] 250 | attn_mask = [1 if token != 
TOKEN_IDX[token_style]['PAD'] else 0 for token in x] 251 | data_items.append([x, y, attn_mask, y_mask,y_multitask]) 252 | 253 | return data_items 254 | 255 | def parse_data_cant_jyupin(file_path, tokenizer, sequence_len, token_style,multitask): 256 | """ 257 | 258 | :param file_path: text file path that contains tokens and punctuations separated by tab in lines 259 | :param tokenizer: tokenizer that will be used to further tokenize word for BERT like models 260 | :param sequence_len: maximum length of each sequence 261 | :param token_style: For getting index of special tokens in config.TOKEN_IDX 262 | :return: list of [tokens_index, punctuation_index, attention_masks, punctuation_mask], each having sequence_len 263 | punctuation_mask is used to ignore special indices like padding and intermediate sub-word token during evaluation 264 | """ 265 | if multitask: 266 | punctuation_dict = {'O': 0, 'COMMA': 1, 'PERIOD': 2, 'QUESTION': 3,'EXCLAMATION': 4,'INFORMAL':5,'FORMAL':6} 267 | else: 268 | punctuation_dict = {'O': 0, 'COMMA': 1, 'PERIOD': 2, 'QUESTION': 3,'EXCLAMATION': 4} 269 | data_items = [] 270 | #load dict 271 | jyupin_dict = np.load('src/jyupin2idx.npy', allow_pickle=True).item() 272 | with open(file_path, 'r', encoding='utf-8') as f: 273 | lines = [line for line in f.read().split('\n') if line.strip()] 274 | bool=False 275 | #print(lines) 276 | idx = 0 277 | # loop until end of the entire text 278 | while idx < len(lines): 279 | x = [TOKEN_IDX[token_style]['START_SEQ']] 280 | y = [0] 281 | y_mask = [1] # which positions we need to consider while evaluating i.e., ignore pad or sub tokens 282 | jyupin_seq = [31852] 283 | # loop until we have required sequence length 284 | # -1 because we will have a special end of sequence token at the end 285 | while len(x) < sequence_len - 1 and idx < len(lines): #yunxiang -2 for multitask -1 for normal 286 | if multitask: 287 | word, punc, jyupin, scenario = lines[idx].split(' ') #原來是/t 288 | if int(scenario)==0: 289 | y[0] 
= 5 290 | else: 291 | y[0] = 6 292 | 293 | else: 294 | word, punc, jyupin = lines[idx].split(' ') #原來是/t 295 | 296 | try: 297 | jyupin_tokens = jyupin_dict[jyupin] 298 | except: 299 | jyupin_tokens = jyupin_dict[str(jyupin)] 300 | 301 | tokens = tokenizer.tokenize(word) 302 | 303 | # if taking these tokens exceeds sequence length we finish current sequence with padding 304 | # then start next sequence from this token 305 | if len(tokens) + len(x) >= sequence_len: 306 | 307 | break 308 | else: 309 | for i in range(len(tokens) - 1): 310 | bool=True 311 | x.append(tokenizer.convert_tokens_to_ids(tokens[i])) 312 | y.append(0) 313 | y_mask.append(0) 314 | jyupin_seq.append(jyupin_tokens) 315 | 316 | if len(tokens) > 0: 317 | x.append(tokenizer.convert_tokens_to_ids(tokens[-1])) 318 | jyupin_seq.append(jyupin_tokens) 319 | else: 320 | x.append(TOKEN_IDX[token_style]['UNK']) 321 | jyupin_seq.append(jyupin_tokens) 322 | y.append(punctuation_dict[punc]) 323 | y_mask.append(1) 324 | 325 | idx += 1 326 | #yunxiang for multi-task 327 | if punctuation_dict[punc] in [2,3,4]: 328 | break 329 | 330 | x.append(TOKEN_IDX[token_style]['END_SEQ']) 331 | y.append(0) 332 | y_mask.append(1) 333 | jyupin_seq.append(31853) 334 | 335 | if len(x) < sequence_len: 336 | x = x + [TOKEN_IDX[token_style]['PAD'] for _ in range(sequence_len - len(x))] 337 | jyupin_seq = jyupin_seq + [31854 for _ in range(sequence_len - len(jyupin_seq))] 338 | y = y + [0 for _ in range(sequence_len - len(y))] 339 | y_mask = y_mask + [0 for _ in range(sequence_len - len(y_mask))] 340 | attn_mask = [1 if token != TOKEN_IDX[token_style]['PAD'] else 0 for token in x] 341 | data_items.append([x, y, attn_mask, y_mask,jyupin_seq]) 342 | 343 | return data_items 344 | 345 | 346 | class Dataset_sentences(torch.utils.data.Dataset): 347 | def __init__(self, files, tokenizer, sequence_len, token_style, is_train=False, augment_rate=0.1, 348 | augment_type='substitute'): 349 | """ 350 | 351 | :param files: single file or list of 
text files containing tokens and punctuations separated by tab in lines 352 | :param tokenizer: tokenizer that will be used to further tokenize word for BERT like models 353 | :param sequence_len: length of each sequence 354 | :param token_style: For getting index of special tokens in config.TOKEN_IDX 355 | 356 | """ 357 | if isinstance(files, list): 358 | self.data = [] 359 | for file in files: 360 | self.data += parse_data(file, tokenizer, sequence_len, token_style) 361 | else: 362 | self.data = parse_data(files, tokenizer, sequence_len, token_style) 363 | self.sequence_len = sequence_len 364 | self.augment_rate = augment_rate 365 | self.token_style = token_style 366 | self.is_train = is_train 367 | self.augment_type = augment_type 368 | 369 | def __len__(self): 370 | return len(self.data) 371 | 372 | def _augment(self, x, y, y_mask): 373 | x_aug = [] 374 | y_aug = [] 375 | y_mask_aug = [] 376 | for i in range(len(x)): 377 | r = np.random.rand() 378 | if r < self.augment_rate: 379 | AUGMENTATIONS[self.augment_type](x, y, y_mask, x_aug, y_aug, y_mask_aug, i, self.token_style) 380 | else: 381 | x_aug.append(x[i]) 382 | y_aug.append(y[i]) 383 | y_mask_aug.append(y_mask[i]) 384 | 385 | if len(x_aug) > self.sequence_len: 386 | # len increased due to insert 387 | x_aug = x_aug[0:self.sequence_len] 388 | y_aug = y_aug[0:self.sequence_len] 389 | y_mask_aug = y_mask_aug[0:self.sequence_len] 390 | elif len(x_aug) < self.sequence_len: 391 | # len decreased due to delete 392 | x_aug = x_aug + [TOKEN_IDX[self.token_style]['PAD'] for _ in range(self.sequence_len - len(x_aug))] 393 | y_aug = y_aug + [0 for _ in range(self.sequence_len - len(y_aug))] 394 | y_mask_aug = y_mask_aug + [0 for _ in range(self.sequence_len - len(y_mask_aug))] 395 | 396 | attn_mask = [1 if token != TOKEN_IDX[self.token_style]['PAD'] else 0 for token in x] 397 | return x_aug, y_aug, attn_mask, y_mask_aug 398 | 399 | def __getitem__(self, index): 400 | x = self.data[index][0] 401 | y = 
self.data[index][1] 402 | attn_mask = self.data[index][2] 403 | y_mask = self.data[index][3] 404 | 405 | if self.is_train and self.augment_rate > 0: 406 | x, y, attn_mask, y_mask = self._augment(x, y, y_mask) 407 | 408 | x = torch.tensor(x) 409 | y = torch.tensor(y) 410 | attn_mask = torch.tensor(attn_mask) 411 | y_mask = torch.tensor(y_mask) 412 | 413 | return x, y, attn_mask, y_mask 414 | 415 | 416 | class Dataset_cant(torch.utils.data.Dataset): 417 | def __init__(self, files, tokenizer, sequence_len, token_style, is_train=False, augment_rate=0.1, 418 | augment_type='substitute'): 419 | """ 420 | 421 | :param files: single file or list of text files containing tokens and punctuations separated by tab in lines 422 | :param tokenizer: tokenizer that will be used to further tokenize word for BERT like models 423 | :param sequence_len: length of each sequence 424 | :param token_style: For getting index of special tokens in config.TOKEN_IDX 425 | 426 | """ 427 | if isinstance(files, list): 428 | self.data = [] 429 | for file in files: 430 | self.data += parse_data_cant(file, tokenizer, sequence_len, token_style) 431 | else: 432 | self.data = parse_data_cant(files, tokenizer, sequence_len, token_style) 433 | self.sequence_len = sequence_len 434 | self.augment_rate = augment_rate 435 | self.token_style = token_style 436 | self.is_train = is_train 437 | self.augment_type = augment_type 438 | 439 | def __len__(self): 440 | return len(self.data) 441 | 442 | def _augment(self, x, y, y_mask): 443 | x_aug = [] 444 | y_aug = [] 445 | y_mask_aug = [] 446 | for i in range(len(x)): 447 | r = np.random.rand() 448 | if r < self.augment_rate: 449 | AUGMENTATIONS[self.augment_type](x, y, y_mask, x_aug, y_aug, y_mask_aug, i, self.token_style) 450 | else: 451 | x_aug.append(x[i]) 452 | y_aug.append(y[i]) 453 | y_mask_aug.append(y_mask[i]) 454 | 455 | if len(x_aug) > self.sequence_len: 456 | # len increased due to insert 457 | x_aug = x_aug[0:self.sequence_len] 458 | y_aug = 
y_aug[0:self.sequence_len] 459 | y_mask_aug = y_mask_aug[0:self.sequence_len] 460 | elif len(x_aug) < self.sequence_len: 461 | # len decreased due to delete 462 | x_aug = x_aug + [TOKEN_IDX[self.token_style]['PAD'] for _ in range(self.sequence_len - len(x_aug))] 463 | y_aug = y_aug + [0 for _ in range(self.sequence_len - len(y_aug))] 464 | y_mask_aug = y_mask_aug + [0 for _ in range(self.sequence_len - len(y_mask_aug))] 465 | 466 | attn_mask = [1 if token != TOKEN_IDX[self.token_style]['PAD'] else 0 for token in x] 467 | return x_aug, y_aug, attn_mask, y_mask_aug 468 | 469 | def __getitem__(self, index): 470 | x = self.data[index][0] 471 | y = self.data[index][1] 472 | attn_mask = self.data[index][2] 473 | y_mask = self.data[index][3] 474 | 475 | if self.is_train and self.augment_rate > 0: 476 | x, y, attn_mask, y_mask = self._augment(x, y, y_mask) 477 | 478 | x = torch.tensor(x) 479 | y = torch.tensor(y) 480 | attn_mask = torch.tensor(attn_mask) 481 | y_mask = torch.tensor(y_mask) 482 | 483 | return x, y, attn_mask, y_mask 484 | 485 | 486 | class Dataset_cant_multitask(torch.utils.data.Dataset): 487 | def __init__(self, files, tokenizer, sequence_len, token_style, is_train=False, augment_rate=0.1, 488 | augment_type='substitute'): 489 | """ 490 | 491 | :param files: single file or list of text files containing tokens and punctuations separated by tab in lines 492 | :param tokenizer: tokenizer that will be used to further tokenize word for BERT like models 493 | :param sequence_len: length of each sequence 494 | :param token_style: For getting index of special tokens in config.TOKEN_IDX 495 | 496 | """ 497 | if isinstance(files, list): 498 | self.data = [] 499 | for file in files: 500 | self.data += parse_data_cant_multitask(file, tokenizer, sequence_len, token_style) 501 | else: 502 | self.data = parse_data_cant_multitask(files, tokenizer, sequence_len, token_style) 503 | self.sequence_len = sequence_len 504 | self.augment_rate = augment_rate 505 | 
self.token_style = token_style 506 | self.is_train = is_train 507 | self.augment_type = augment_type 508 | 509 | def __len__(self): 510 | return len(self.data) 511 | 512 | def _augment(self, x, y, y_mask): 513 | x_aug = [] 514 | y_aug = [] 515 | y_mask_aug = [] 516 | for i in range(len(x)): 517 | r = np.random.rand() 518 | if r < self.augment_rate: 519 | AUGMENTATIONS[self.augment_type](x, y, y_mask, x_aug, y_aug, y_mask_aug, i, self.token_style) 520 | else: 521 | x_aug.append(x[i]) 522 | y_aug.append(y[i]) 523 | y_mask_aug.append(y_mask[i]) 524 | 525 | if len(x_aug) > self.sequence_len: 526 | # len increased due to insert 527 | x_aug = x_aug[0:self.sequence_len] 528 | y_aug = y_aug[0:self.sequence_len] 529 | y_mask_aug = y_mask_aug[0:self.sequence_len] 530 | elif len(x_aug) < self.sequence_len: 531 | # len decreased due to delete 532 | x_aug = x_aug + [TOKEN_IDX[self.token_style]['PAD'] for _ in range(self.sequence_len - len(x_aug))] 533 | y_aug = y_aug + [0 for _ in range(self.sequence_len - len(y_aug))] 534 | y_mask_aug = y_mask_aug + [0 for _ in range(self.sequence_len - len(y_mask_aug))] 535 | 536 | attn_mask = [1 if token != TOKEN_IDX[self.token_style]['PAD'] else 0 for token in x] 537 | return x_aug, y_aug, attn_mask, y_mask_aug 538 | 539 | def __getitem__(self, index): 540 | x = self.data[index][0] 541 | y = self.data[index][1] 542 | attn_mask = self.data[index][2] 543 | y_mask = self.data[index][3] 544 | y_multitask = self.data[index][4] 545 | 546 | if self.is_train and self.augment_rate > 0: 547 | x, y, attn_mask, y_mask = self._augment(x, y, y_mask) 548 | 549 | x = torch.tensor(x) 550 | y = torch.tensor(y) 551 | y_multitask = torch.tensor(y_multitask) 552 | attn_mask = torch.tensor(attn_mask) 553 | y_mask = torch.tensor(y_mask) 554 | 555 | return x, y, attn_mask, y_mask,y_multitask 556 | 557 | class Dataset_cant_multitask_sequence(torch.utils.data.Dataset): 558 | def __init__(self, files, tokenizer, sequence_len, token_style, is_train=False, 
augment_rate=0.1, 559 | augment_type='substitute'): 560 | """ 561 | 562 | :param files: single file or list of text files containing tokens and punctuations separated by tab in lines 563 | :param tokenizer: tokenizer that will be used to further tokenize word for BERT like models 564 | :param sequence_len: length of each sequence 565 | :param token_style: For getting index of special tokens in config.TOKEN_IDX 566 | 567 | """ 568 | if isinstance(files, list): 569 | self.data = [] 570 | for file in files: 571 | self.data += parse_data_cant_multitask_sequence(file, tokenizer, sequence_len, token_style) 572 | else: 573 | self.data = parse_data_cant_multitask_sequence(files, tokenizer, sequence_len, token_style) 574 | self.sequence_len = sequence_len 575 | self.augment_rate = augment_rate 576 | self.token_style = token_style 577 | self.is_train = is_train 578 | self.augment_type = augment_type 579 | 580 | def __len__(self): 581 | return len(self.data) 582 | 583 | def _augment(self, x, y, y_mask): 584 | x_aug = [] 585 | y_aug = [] 586 | y_mask_aug = [] 587 | for i in range(len(x)): 588 | r = np.random.rand() 589 | if r < self.augment_rate: 590 | AUGMENTATIONS[self.augment_type](x, y, y_mask, x_aug, y_aug, y_mask_aug, i, self.token_style) 591 | else: 592 | x_aug.append(x[i]) 593 | y_aug.append(y[i]) 594 | y_mask_aug.append(y_mask[i]) 595 | 596 | if len(x_aug) > self.sequence_len: 597 | # len increased due to insert 598 | x_aug = x_aug[0:self.sequence_len] 599 | y_aug = y_aug[0:self.sequence_len] 600 | y_mask_aug = y_mask_aug[0:self.sequence_len] 601 | elif len(x_aug) < self.sequence_len: 602 | # len decreased due to delete 603 | x_aug = x_aug + [TOKEN_IDX[self.token_style]['PAD'] for _ in range(self.sequence_len - len(x_aug))] 604 | y_aug = y_aug + [0 for _ in range(self.sequence_len - len(y_aug))] 605 | y_mask_aug = y_mask_aug + [0 for _ in range(self.sequence_len - len(y_mask_aug))] 606 | 607 | attn_mask = [1 if token != TOKEN_IDX[self.token_style]['PAD'] else 0 for 
token in x] 608 | return x_aug, y_aug, attn_mask, y_mask_aug 609 | 610 | def __getitem__(self, index): 611 | x = self.data[index][0] 612 | y = self.data[index][1] 613 | attn_mask = self.data[index][2] 614 | y_mask = self.data[index][3] 615 | 616 | if self.is_train and self.augment_rate > 0: 617 | x, y, attn_mask, y_mask = self._augment(x, y, y_mask) 618 | 619 | x = torch.tensor(x) 620 | y = torch.tensor(y) 621 | attn_mask = torch.tensor(attn_mask) 622 | y_mask = torch.tensor(y_mask) 623 | 624 | return x, y, attn_mask, y_mask 625 | 626 | 627 | class Dataset_cant_jyupin(torch.utils.data.Dataset): 628 | def __init__(self, files, tokenizer, sequence_len, token_style, is_train=False, augment_rate=0.1, 629 | augment_type='substitute',multitask=False): 630 | """ 631 | 632 | :param files: single file or list of text files containing tokens and punctuations separated by tab in lines 633 | :param tokenizer: tokenizer that will be used to further tokenize word for BERT like models 634 | :param sequence_len: length of each sequence 635 | :param token_style: For getting index of special tokens in config.TOKEN_IDX 636 | 637 | """ 638 | if isinstance(files, list): 639 | self.data = [] 640 | for file in files: 641 | self.data += parse_data_cant_jyupin(file, tokenizer, sequence_len, token_style,multitask) 642 | else: 643 | self.data = parse_data_cant_jyupin(files, tokenizer, sequence_len, token_style,multitask) 644 | self.sequence_len = sequence_len 645 | self.augment_rate = augment_rate 646 | self.token_style = token_style 647 | self.is_train = is_train 648 | self.augment_type = augment_type 649 | 650 | def __len__(self): 651 | return len(self.data) 652 | 653 | def _augment(self, x, y, y_mask): 654 | x_aug = [] 655 | y_aug = [] 656 | y_mask_aug = [] 657 | for i in range(len(x)): 658 | r = np.random.rand() 659 | if r < self.augment_rate: 660 | AUGMENTATIONS[self.augment_type](x, y, y_mask, x_aug, y_aug, y_mask_aug, i, self.token_style) 661 | else: 662 | x_aug.append(x[i]) 663 | 
y_aug.append(y[i]) 664 | y_mask_aug.append(y_mask[i]) 665 | 666 | if len(x_aug) > self.sequence_len: 667 | # len increased due to insert 668 | x_aug = x_aug[0:self.sequence_len] 669 | y_aug = y_aug[0:self.sequence_len] 670 | y_mask_aug = y_mask_aug[0:self.sequence_len] 671 | elif len(x_aug) < self.sequence_len: 672 | # len decreased due to delete 673 | x_aug = x_aug + [TOKEN_IDX[self.token_style]['PAD'] for _ in range(self.sequence_len - len(x_aug))] 674 | y_aug = y_aug + [0 for _ in range(self.sequence_len - len(y_aug))] 675 | y_mask_aug = y_mask_aug + [0 for _ in range(self.sequence_len - len(y_mask_aug))] 676 | 677 | attn_mask = [1 if token != TOKEN_IDX[self.token_style]['PAD'] else 0 for token in x] 678 | return x_aug, y_aug, attn_mask, y_mask_aug 679 | 680 | def __getitem__(self, index): 681 | x = self.data[index][0] 682 | y = self.data[index][1] 683 | attn_mask = self.data[index][2] 684 | y_mask = self.data[index][3] 685 | jyupin = self.data[index][4] 686 | 687 | if self.is_train and self.augment_rate > 0: 688 | x, y, attn_mask, y_mask = self._augment(x, y, y_mask) 689 | 690 | x = torch.tensor(x) 691 | y = torch.tensor(y) 692 | attn_mask = torch.tensor(attn_mask) 693 | y_mask = torch.tensor(y_mask) 694 | jyupin = torch.tensor(jyupin) 695 | 696 | return x, y, attn_mask, y_mask, jyupin 697 | 698 | --------------------------------------------------------------------------------