├── .gitignore ├── requirements.txt ├── utils.py ├── README.md ├── models.py ├── main.py └── helpers.py /.gitignore: -------------------------------------------------------------------------------- 1 | trash 2 | model 3 | __pycache__ 4 | .idea -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | clean-text==0.3.0 2 | hazm==0.7.0 3 | numpy<1.20 4 | torch==1.7.1 5 | tqdm==4.46.1 6 | transformers==4.3.3 -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | def get_sentences_splitters(txt): 4 | splitters = ['? ', '! ', '. ', '.\n', '?\n', '!\n', ' ؟', '\n؟'] 5 | all_sents = [] 6 | last_sent_index = 0 7 | for i, (ch1, ch2) in enumerate(zip(txt, txt[1:])): 8 | if ch1 + ch2 in splitters: 9 | all_sents.append((txt[last_sent_index:i+len(ch1)], ch2)) 10 | last_sent_index = i + len(ch1+ch2) 11 | all_sents.append((txt[last_sent_index:], None)) 12 | return [item[0] for item in all_sents], [item[1] for item in all_sents[:-1]] 13 | 14 | def space_special_chars(txt): 15 | return re.sub('([.:،<>,!?()])', r' \1 ', txt) 16 | 17 | def de_space_special_chars(txt): 18 | txt = re.sub('( ([.:،<>,!?()]) )', r'\2', txt) 19 | txt = re.sub('( ([.:،<>,!?()]))', r'\2', txt) 20 | return re.sub('(([.:،<>,!?()]) )', r'\2', txt) 21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Nevise: A BERT-Based Spell-Checker for Persian 2 | 3 | Nevise is a deep-learning-based Persian spell checker developed by Dadmatech Co. Nevise is available in two versions. The second version offers greater accuracy, the ability to correct spacing errors, and better handling of special characters such as the half-space (zero-width non-joiner). Both versions can be accessed via web services and as demos. We provide public access to the code and model checkpoint of the first version here. 4 | 5 | ## Package installation 6 | 7 | Use the package manager [pip](https://pip.pypa.io/en/stable/) to install the required packages. 8 | 9 | ```bash 10 | pip install -r requirements.txt 11 | ``` 12 | ## Download the model checkpoint and vocab and put them in the `model` directory 13 | 14 | 15 | ```bash 16 | pip install gdown 17 | mkdir model 18 | cd model 19 | gdown https://drive.google.com/uc?id=1Ki5WGR4yxftDEjROQLf9Br8KHef95k1F 20 | gdown https://drive.google.com/uc?id=1nKeMdDnxIJpOv-OeFj00UnhoChuaY5Ns 21 | ``` 22 | ## Run 23 | 24 | 25 | ```bash 26 | python main.py 27 | ``` 28 | # Demo 29 | 30 | [Nevise (both versions)](https://dadmatech.ir/#/products/SpellChecker) 31 | 32 | # Results on [Nevise Dataset](https://github.com/Dadmatech/Nevise-Dataset/tree/main/nevise-news-title-539) 33 | 34 |
35 | 36 | | Algorithm | Wrong detection rate | Wrong correction rate | Correct to wrong rate | Precision | 37 | | -- | -- | -- | -- | -- | 38 | | Nevise 2 | **0.8314** | **0.7216** | 0.003 | 0.968 | 39 | | Paknevis | 0.7843 | 0.6706 | 0.228 | 0.7921 | 40 | | Nevise 1 | 0.7647 | 0.6824 | **0.0019** | **0.9774** | 41 | | Google | 0.7392 | 0.702 | 0.0045 | 0.9449 | 42 | | Virastman | 0.6 | 0.5 | 0.0032 | 0.9533 | 43 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn.utils.rnn import pad_sequence 4 | import torch.nn.functional as F 5 | import transformers 6 | 7 | ################################################# 8 | # SubwordBert 9 | ################################################# 10 | 11 | 12 | class SubwordBert(nn.Module): 13 | def __init__(self, screp_dim, padding_idx, output_dim): 14 | super(SubwordBert,self).__init__() 15 | 16 | self.bert_dropout = torch.nn.Dropout(0.2) 17 | self.bert_model = transformers.BertModel.from_pretrained("HooshvareLab/bert-fa-base-uncased") 18 | self.bertmodule_outdim = self.bert_model.config.hidden_size 19 | # Uncomment to freeze BERT layers 20 | # for param in self.bert_model.parameters(): 21 | # param.requires_grad = False 22 | 23 | # output module 24 | assert output_dim>0 25 | # self.dropout = nn.Dropout(p=0.4) 26 | self.dense = nn.Linear(self.bertmodule_outdim,output_dim) 27 | 28 | # loss 29 | # See https://pytorch.org/docs/stable/nn.html#crossentropyloss 30 | self.criterion = nn.CrossEntropyLoss(reduction='mean',ignore_index=padding_idx) 31 | 32 | @property 33 | def device(self) -> torch.device: 34 | return next(self.parameters()).device 35 | 36 | def get_merged_encodings(self, bert_seq_encodings, seq_splits, mode='avg'): 37 | bert_seq_encodings = bert_seq_encodings[:sum(seq_splits)+2,:] # 2 for [CLS] and [SEP] 38 | bert_seq_encodings = bert_seq_encodings[1:-1,:] 39 | # a tuple of tensors 40 | split_encoding = torch.split(bert_seq_encodings,seq_splits,dim=0) 41 | batched_encodings = pad_sequence(split_encoding,batch_first=True,padding_value=0) 42 | if mode=='avg': 43 | seq_splits = torch.tensor(seq_splits).reshape(-1,1).to(self.device) 44 | out = torch.div( torch.sum(batched_encodings,dim=1), seq_splits) 45 | elif mode=="add": 46 | out = torch.sum(batched_encodings,dim=1) 47 | else: 48 | raise Exception("Not Implemented") 49 | return out 50 | 51 | def forward(self, 52 | batch_bert_dict: "{'input_ids':tensor, 'attention_mask':tensor, 'token_type_ids':tensor}", 53 | batch_splits: "list[list[int]]", 54 | aux_word_embs: "tensor" = None, 55 | targets: "tensor" = None, 56 | topk = 1): 57 | 58 | # cnn 59 | batch_size = len(batch_splits) 60 | 61 | # bert 62 | # BS X max_nsubwords x self.bertmodule_outdim 63 | bert_encodings, cls_encoding = self.bert_model( 64 | input_ids=batch_bert_dict["input_ids"], 65 | attention_mask=batch_bert_dict["attention_mask"], 66 | token_type_ids=batch_bert_dict["token_type_ids"], 67 | return_dict=False 68 | ) 69 | bert_encodings = self.bert_dropout(bert_encodings) 70 | # BS X max_nwords x self.bertmodule_outdim 71 | bert_merged_encodings = pad_sequence( 72 | [self.get_merged_encodings(bert_seq_encodings, seq_splits, mode='avg') \ 73 | for bert_seq_encodings, seq_splits in zip(bert_encodings,batch_splits)], 74 | batch_first=True, 75 | padding_value=0 76 | ) 77 | 78 | # concat aux_embs 79 | # if not None, the expected dim for aux_word_embs: 
[BS,max_nwords,*] 80 | intermediate_encodings = bert_merged_encodings 81 | if aux_word_embs is not None: 82 | intermediate_encodings = torch.cat((intermediate_encodings,aux_word_embs),dim=2) 83 | 84 | # dense 85 | # [BS,max_nwords,*] or [BS,max_nwords,self.bertmodule_outdim]->[BS,max_nwords,output_dim] 86 | # logits = self.dense(self.dropout(intermediate_encodings)) 87 | logits = self.dense(intermediate_encodings) 88 | 89 | loss = None # remains None when no targets are provided 90 | if targets is not None: 91 | assert len(targets)==batch_size # targets: [BS,max_nwords] 92 | logits_permuted = logits.permute(0, 2, 1) # logits: [BS,output_dim,max_nwords] 93 | loss = self.criterion(logits_permuted,targets) 94 | 95 | # eval preds 96 | if not self.training: 97 | probs = F.softmax(logits,dim=-1) # [BS,max_nwords,output_dim] 98 | if topk>1: 99 | topk_values, topk_inds = \ 100 | torch.topk(probs, topk, dim=-1, largest=True, sorted=True) # -> (Tensor, LongTensor) of [BS,max_nwords,topk] 101 | elif topk==1: 102 | topk_inds = torch.argmax(probs,dim=-1) # [BS,max_nwords] 103 | 104 | # Note that for those positions with padded_idx, 105 | # the arg_max_prob above computes an index because 106 | # the bias term leads to non-uniform values in those positions 107 | 108 | return (loss.cpu().detach().numpy() if loss is not None else None), topk_inds.cpu().detach().numpy() 109 | return loss 110 | 111 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import re 4 | import time 5 | import torch 6 | import utils 7 | from helpers import load_vocab_dict 8 | from helpers import batch_iter, labelize, bert_tokenize_for_valid_examples 9 | from helpers import untokenize_without_unks, untokenize_without_unks2, get_model_nparams 10 | from hazm import Normalizer 11 | from models import SubwordBert 12 | from utils import get_sentences_splitters 13 | 14 | 15 | def model_inference(model, data, topk, DEVICE, BATCH_SIZE=16, vocab_=None): 16 | """ 17 | model: an instance of SubwordBert 18 | data: list of tuples, each tuple consisting of a correct and an incorrect 19 | sentence string (split at whitespace) 20 | topk: how many of the topk softmax predictions are considered for metrics calculations 21 | """ 22 | if vocab_ is not None: 23 | vocab = vocab_ 24 | print("###############################################") 25 | inference_st_time = time.time() 26 | _corr2corr, _corr2incorr, _incorr2corr, _incorr2incorr = 0, 0, 0, 0 27 | _mistakes = [] 28 | VALID_BATCH_SIZE = BATCH_SIZE 29 | valid_loss = 0.
30 | print("data size: {}".format(len(data))) 31 | data_iter = batch_iter(data, batch_size=VALID_BATCH_SIZE, shuffle=False) 32 | model.eval() 33 | model.to(DEVICE) 34 | results = [] 35 | line_index = 0 36 | for batch_id, (batch_labels, batch_sentences) in tqdm(enumerate(data_iter)): 37 | torch.cuda.empty_cache() 38 | st_time = time.time() 39 | # set batch data for bert 40 | batch_labels_, batch_sentences_, batch_bert_inp, batch_bert_splits = bert_tokenize_for_valid_examples( 41 | batch_labels, batch_sentences) 42 | if len(batch_labels_) == 0: 43 | print("################") 44 | print("Not predicting the following lines due to pre-processing mismatch: \n") 45 | print([(a, b) for a, b in zip(batch_labels, batch_sentences)]) 46 | print("################") 47 | continue 48 | else: 49 | batch_labels, batch_sentences = batch_labels_, batch_sentences_ 50 | batch_bert_inp = {k: v.to(DEVICE) for k, v in batch_bert_inp.items()} 51 | # set batch data for others 52 | batch_labels_ids, batch_lengths = labelize(batch_labels, vocab) 53 | batch_lengths = batch_lengths.to(DEVICE) 54 | batch_labels_ids = batch_labels_ids.to(DEVICE) 55 | 56 | try: 57 | with torch.no_grad(): 58 | """ 59 | NEW: batch_predictions can now be of shape (batch_size,batch_max_seq_len,topk) if topk>1, else (batch_size,batch_max_seq_len) 60 | """ 61 | batch_loss, batch_predictions = model(batch_bert_inp, batch_bert_splits, targets=batch_labels_ids, 62 | topk=topk) 63 | except RuntimeError: 64 | print(f"batch_bert_inp:{len(batch_bert_inp.keys())},batch_labels_ids:{batch_labels_ids.shape}") 65 | raise Exception("") 66 | valid_loss += batch_loss 67 | batch_lengths = batch_lengths.cpu().detach().numpy() 68 | if topk == 1: 69 | batch_predictions = untokenize_without_unks(batch_predictions, batch_lengths, vocab, batch_sentences) 70 | else: 71 | batch_predictions = untokenize_without_unks2(batch_predictions, batch_lengths, vocab, batch_sentences, 72 | topk=None) 73 | batch_clean_sentences = [line for line in batch_labels] 74 | batch_corrupt_sentences = [line for line in batch_sentences] 75 | batch_predictions = [line for line in batch_predictions] 76 | 77 | for i, (a, b, c) in enumerate(zip(batch_clean_sentences, batch_corrupt_sentences, batch_predictions)): 78 | results.append({"id": line_index + i, "original": a, "noised": b, "predicted": c, "topk": [], 79 | "topk_prediction_probs": [], "topk_reranker_losses": []}) 80 | line_index += len(batch_clean_sentences) 81 | 82 | ''' 83 | # update progress 84 | progressBar(batch_id+1, 85 | int(np.ceil(len(data) / VALID_BATCH_SIZE)), 86 | ["batch_time","batch_loss","avg_batch_loss","batch_acc","avg_batch_acc"], 87 | [time.time()-st_time,batch_loss,valid_loss/(batch_id+1),None,None]) 88 | ''' 89 | print(f"\nEpoch {None} valid_loss: {valid_loss / (batch_id + 1)}") 90 | print("total inference time for this data is: {:4f} secs".format(time.time() - inference_st_time)) 91 | print("###############################################") 92 | print("###############################################") 93 | return results 94 | 95 | def load_model(vocab): 96 | model = SubwordBert(3*len(vocab["chartoken2idx"]),vocab["token2idx"][ vocab["pad_token"] ],len(vocab["token_freq"])) 97 | print(model) 98 | print( get_model_nparams(model) ) 99 | return model 100 | 101 | 102 | def load_pretrained(model, checkpoint_path, optimizer=None, device='cuda'): 103 | if torch.cuda.is_available() and device != "cpu": 104 | map_location = lambda storage, loc: storage.cuda() 105 | else: 106 | map_location = 'cpu' 107 | print(f"Loading model 
params from checkpoint dir: {checkpoint_path}") 108 | checkpoint_data = torch.load(checkpoint_path, map_location=map_location) 109 | model.load_state_dict(checkpoint_data['model_state_dict']) 110 | if optimizer is not None: 111 | optimizer.load_state_dict(checkpoint_data['optimizer_state_dict']) 112 | max_dev_acc, argmax_dev_acc = checkpoint_data["max_dev_acc"], checkpoint_data["argmax_dev_acc"] 113 | 114 | if optimizer is not None: 115 | return model, optimizer, max_dev_acc, argmax_dev_acc 116 | return model 117 | 118 | def load_pre_model(vocab_path, model_checkpoint_path): 119 | DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu" 120 | print(f"loading vocab from {vocab_path}") 121 | vocab = load_vocab_dict(vocab_path) 122 | model = load_model(vocab) 123 | model = load_pretrained(model, model_checkpoint_path) 124 | return model, vocab, DEVICE 125 | 126 | 127 | def spell_checking_on_sents(model, vocab, device, normalizer, txt): 128 | sents, splitters = get_sentences_splitters(txt) 129 | sents = [utils.space_special_chars(s) for s in sents] 130 | sents = list(filter(lambda txt: (txt != '' and txt != ' '), sents)) 131 | test_data = [(normalizer.normalize(t), normalizer.normalize(t)) for t in sents] 132 | print('inputs:') 133 | for t in test_data: 134 | print(t) 135 | greedy_results = model_inference(model, test_data, topk=1, DEVICE=device, BATCH_SIZE=1, 136 | vocab_=vocab) 137 | out = [] 138 | for i, line in enumerate(greedy_results): 139 | ls = [(n, p) if n == p else ("**" + n + "**", "**" + p + "**") for n, p in 140 | zip(line["noised"].split(), line["predicted"].split())] 141 | y, z = map(list, zip(*ls)) 142 | try: 143 | z = ' '.join(z) 144 | z = re.sub(r'\*\*(\w+)\*\*', r'** \1 **', z) 145 | z = re.sub(r'\*\* (\w+) \*\*', r'**\1**', z) 146 | except: 147 | z = ' '.join(z) 148 | out.append((" ".join(y), z)) 149 | new_out = [] 150 | for i, sent in enumerate(out): 151 | new_out.append( (utils.de_space_special_chars(out[i][0]), utils.de_space_special_chars(out[i][1]))) 152 | return new_out, splitters 153 | 154 | 155 | if __name__ == '__main__': 156 | normalizer = Normalizer(punctuation_spacing=False, remove_extra_spaces=False) 157 | vocab_path = os.path.join('model', 'vocab.pkl') 158 | model_checkpoint_path = os.path.join('model', 'model.pth.tar') 159 | model, vocab, device = load_pre_model(vocab_path=vocab_path, model_checkpoint_path=model_checkpoint_path) 160 | #test 161 | sample_input = 'این یک مثالل صاده برالی ازرابی این سامانح اسصت' 162 | output = spell_checking_on_sents(model, vocab, device, normalizer, sample_input) 163 | print(output) -------------------------------------------------------------------------------- /helpers.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import os, sys 3 | import numpy as np 4 | import pickle 5 | import numpy as np 6 | import transformers 7 | import torch 8 | from torch.nn.utils.rnn import pad_sequence 9 | 10 | def progressBar(value, endvalue, names, values, bar_length=30): 11 | assert(len(names)==len(values)); 12 | percent = float(value) / endvalue 13 | arrow = '-' * int(round(percent * bar_length)-1) + '>' 14 | spaces = ' ' * (bar_length - len(arrow)); 15 | string = ''; 16 | for name, val in zip(names,values): 17 | temp = '|| {0}: {1:.4f} '.format(name, val) if val!=None else '|| {0}: {1} '.format(name, None) 18 | string+=temp; 19 | sys.stdout.write("\rPercent: [{0}] {1}% {2}".format(arrow + spaces, int(round(percent * 100)), string)) 20 | sys.stdout.flush() 21 | return 22 | 23 
| def load_data(base_path, corr_file, incorr_file): 24 | 25 | # load files 26 | if base_path: 27 | assert os.path.exists(base_path)==True 28 | incorr_data = [] 29 | opfile1 = open(os.path.join(base_path, incorr_file),"r") 30 | for line in opfile1: 31 | if line.strip()!="": incorr_data.append(line.strip()) 32 | opfile1.close() 33 | corr_data = [] 34 | opfile2 = open(os.path.join(base_path, corr_file),"r") 35 | for line in opfile2: 36 | if line.strip()!="": corr_data.append(line.strip()) 37 | opfile2.close() 38 | assert len(incorr_data)==len(corr_data) 39 | 40 | # verify if token split is same 41 | for i,(x,y) in tqdm(enumerate(zip(corr_data,incorr_data))): 42 | x_split, y_split = x.split(), y.split() 43 | try: 44 | assert len(x_split)==len(y_split) 45 | except AssertionError: 46 | print("# tokens in corr and incorr mismatch. retaining and trimming to min len.") 47 | # print(x_split, y_split) 48 | # mn = min([len(x_split),len(y_split)]) 49 | # corr_data[i] = " ".join(x_split[:mn]) 50 | # incorr_data[i] = " ".join(y_split[:mn]) 51 | # print(corr_data[i],incorr_data[i]) 52 | 53 | # return as pairs 54 | data = [] 55 | for x,y in tqdm(zip(corr_data,incorr_data)): 56 | data.append((x,y)) 57 | 58 | print(f"loaded tuples of (corr,incorr) examples from {base_path}") 59 | return data 60 | 61 | 62 | 63 | def batch_iter(data, batch_size, shuffle): 64 | """ 65 | each data item is a tuple of lables and text 66 | """ 67 | n_batches = int(np.ceil(len(data) / batch_size)) 68 | indices = list(range(len(data))) 69 | if shuffle: np.random.shuffle(indices) 70 | 71 | for i in range(n_batches): 72 | batch_indices = indices[i * batch_size: (i + 1) * batch_size] 73 | batch_labels = [data[idx][0] for idx in batch_indices] 74 | batch_sentences = [data[idx][1] for idx in batch_indices] 75 | 76 | yield (batch_labels,batch_sentences) 77 | 78 | def labelize(batch_labels, vocab): 79 | token2idx, pad_token, unk_token = vocab["token2idx"], vocab["pad_token"], vocab["unk_token"] 80 | list_list = [[token2idx[token] if token in token2idx else token2idx[unk_token] for token in line.split()] for line in batch_labels] 81 | list_tensors = [torch.tensor(x) for x in list_list] 82 | tensor_ = pad_sequence(list_tensors,batch_first=True,padding_value=token2idx[pad_token]) 83 | return tensor_, torch.tensor([len(x) for x in list_list]).long() 84 | 85 | def tokenize(batch_sentences, vocab): 86 | token2idx, pad_token, unk_token = vocab["token2idx"], vocab["pad_token"], vocab["unk_token"] 87 | list_list = [[token2idx[token] if token in token2idx else token2idx[unk_token] for token in line.split()] for line in batch_sentences] 88 | list_tensors = [torch.tensor(x) for x in list_list] 89 | tensor_ = pad_sequence(list_tensors,batch_first=True,padding_value=token2idx[pad_token]) 90 | return tensor_, torch.tensor([len(x) for x in list_list]).long() 91 | 92 | 93 | def untokenize(batch_predictions, batch_lengths, vocab): 94 | idx2token = vocab["idx2token"] 95 | unktoken = vocab["unk_token"] 96 | assert len(batch_predictions)==len(batch_lengths) 97 | batch_predictions = \ 98 | [ " ".join( [idx2token[idx] for idx in pred_[:len_]] ) \ 99 | for pred_,len_ in zip(batch_predictions,batch_lengths) ] 100 | return batch_predictions 101 | 102 | def untokenize_without_unks(batch_predictions, batch_lengths, vocab, batch_clean_sentences, backoff="pass-through"): 103 | assert backoff in ["neutral","pass-through"], print(f"selected backoff strategy not implemented: {backoff}") 104 | idx2token = vocab["idx2token"] 105 | unktoken = 
vocab["token2idx"][vocab["unk_token"]] 106 | assert len(batch_predictions)==len(batch_lengths)==len(batch_clean_sentences) 107 | batch_clean_sentences = [sent.split() for sent in batch_clean_sentences] 108 | if backoff=="pass-through": 109 | batch_predictions = \ 110 | [ " ".join( [ idx2token[idx] if idx!=unktoken else clean_[i] for i, idx in enumerate(pred_[:len_]) ] ) \ 111 | for pred_,len_,clean_ in zip(batch_predictions,batch_lengths,batch_clean_sentences) ] 112 | elif backoff=="neutral": 113 | batch_predictions = \ 114 | [ " ".join( [ idx2token[idx] if idx!=unktoken else "a" for i, idx in enumerate(pred_[:len_]) ] ) \ 115 | for pred_,len_,clean_ in zip(batch_predictions,batch_lengths,batch_clean_sentences) ] 116 | return batch_predictions 117 | 118 | def untokenize_without_unks2(batch_predictions, batch_lengths, vocab, batch_clean_sentences, topk=None): 119 | """ 120 | batch_predictions are softmax probabilities and should have shape (batch_size,max_seq_len,vocab_size) 121 | batch_lengths should have shape (batch_size) 122 | batch_clean_sentences should be strings of shape (batch_size) 123 | """ 124 | #print(batch_predictions.shape) 125 | idx2token = vocab["idx2token"] 126 | unktoken = vocab["token2idx"][vocab["unk_token"]] 127 | assert len(batch_predictions)==len(batch_lengths)==len(batch_clean_sentences) 128 | batch_clean_sentences = [sent.split() for sent in batch_clean_sentences] 129 | 130 | if topk is not None: 131 | # get topk items from dim=2 i.e top 5 prob inds 132 | batch_predictions = np.argpartition(-batch_predictions,topk,axis=-1)[:,:,:topk] # (batch_size,max_seq_len,5) 133 | #else: 134 | # batch_predictions = batch_predictions # already have the topk indices 135 | 136 | # get topk words 137 | idx_to_token = lambda idx,idx2token,corresponding_clean_token,unktoken: idx2token[idx] if idx!=unktoken else corresponding_clean_token 138 | batch_predictions = \ 139 | [[[idx_to_token(wordidx,idx2token,batch_clean_sentences[i][j],unktoken) \ 140 | for wordidx in topk_wordidxs] \ 141 | for j,topk_wordidxs in enumerate(predictions[:batch_lengths[i]])] \ 142 | for i,predictions in enumerate(batch_predictions)] 143 | 144 | return batch_predictions 145 | 146 | 147 | 148 | def get_model_nparams(model): 149 | ntotal = 0 150 | for param in list(model.parameters()): 151 | temp = 1 152 | for sz in list(param.size()): temp*=sz 153 | ntotal += temp 154 | return ntotal 155 | 156 | 157 | def load_vocab_dict(path_: str): 158 | """ 159 | path_: path where the vocab pickle file is saved 160 | """ 161 | with open(path_, 'rb') as fp: 162 | vocab = pickle.load(fp) 163 | return vocab 164 | 165 | 166 | 167 | BERT_TOKENIZER = transformers.BertTokenizer.from_pretrained("HooshvareLab/bert-fa-base-uncased", do_lower_case=False) 168 | BERT_TOKENIZER.do_basic_tokenize = False 169 | BERT_TOKENIZER.tokenize_chinese_chars = False 170 | BERT_MAX_SEQ_LEN = 512 171 | 172 | def merge_subtokens(tokens: "list"): 173 | merged_tokens = [] 174 | for token in tokens: 175 | if token.startswith("##"): merged_tokens[-1] = merged_tokens[-1]+token[2:] 176 | else: merged_tokens.append(token) 177 | text = " ".join(merged_tokens) 178 | return text 179 | 180 | 181 | def _custom_bert_tokenize_sentence(text): 182 | # from hazm import WordTokenizer 183 | new_tokens = [] 184 | tokens = BERT_TOKENIZER.tokenize(text) 185 | j = 0 186 | for i, t in enumerate(tokens): 187 | if t == '[UNK]': 188 | new_tokens.append(text.split()[j]) 189 | else: 190 | new_tokens.append(t) 191 | if t[0] != '#': 192 | j += 1 193 | tokens = new_tokens 194 | 
tokens = tokens[:BERT_MAX_SEQ_LEN-2] # 2 allowed for [CLS] and [SEP] 195 | idxs = np.array([idx for idx,token in enumerate(tokens) if not token.startswith("##")]+[len(tokens)]) 196 | split_sizes = (idxs[1:]-idxs[0:-1]).tolist() 197 | # NOTE: BERT tokenizer does more than just splitting at whitespace and tokenizing. So be careful. 198 | # -----> assert len(split_sizes)==len(text.split()), print(len(tokens), len(split_sizes), len(text.split()), split_sizes, text) 199 | # -----> hence do the following: 200 | text = merge_subtokens(tokens) 201 | assert len(split_sizes)==len(text.split()), print(len(tokens), len(split_sizes), len(text.split()), split_sizes, text) 202 | return text, tokens, split_sizes 203 | 204 | 205 | def _custom_bert_tokenize_sentences(list_of_texts): 206 | out = [_custom_bert_tokenize_sentence(text) for text in list_of_texts] 207 | texts, tokens, split_sizes = list(zip(*out)) 208 | return [*texts], [*tokens], [*split_sizes] 209 | 210 | _simple_bert_tokenize_sentences = \ 211 | lambda list_of_texts: [merge_subtokens( BERT_TOKENIZER.tokenize(text)[:BERT_MAX_SEQ_LEN-2] ) for text in list_of_texts] 212 | 213 | 214 | def bert_tokenize(batch_sentences): 215 | """ 216 | inputs: 217 | batch_sentences: List[str] 218 | a list of textual sentences to tokenized 219 | outputs: 220 | batch_attention_masks, batch_input_ids, batch_token_type_ids 221 | 2d tensors of shape (bs,max_len) 222 | batch_splits: List[List[Int]] 223 | specifies #sub-tokens for each word in each textual string after sub-word tokenization 224 | """ 225 | batch_sentences, batch_tokens, batch_splits = _custom_bert_tokenize_sentences(batch_sentences) 226 | 227 | # max_seq_len = max([len(tokens) for tokens in batch_tokens]) 228 | # batch_encoded_dicts = [BERT_TOKENIZER.encode_plus(tokens,max_length=max_seq_len,pad_to_max_length=True) for tokens in batch_tokens] 229 | batch_encoded_dicts = [BERT_TOKENIZER.encode_plus(tokens) for tokens in batch_tokens] 230 | 231 | batch_attention_masks = pad_sequence([torch.tensor(encoded_dict["attention_mask"]) for encoded_dict in batch_encoded_dicts],batch_first=True,padding_value=0) 232 | batch_input_ids = pad_sequence([torch.tensor(encoded_dict["input_ids"]) for encoded_dict in batch_encoded_dicts],batch_first=True,padding_value=0) 233 | batch_token_type_ids = pad_sequence([torch.tensor(encoded_dict["token_type_ids"]) for encoded_dict in batch_encoded_dicts],batch_first=True,padding_value=0) 234 | 235 | batch_bert_dict = {"attention_mask":batch_attention_masks, 236 | "input_ids":batch_input_ids, 237 | "token_type_ids":batch_token_type_ids} 238 | 239 | return batch_sentences, batch_bert_dict, batch_splits 240 | 241 | 242 | def bert_tokenize_for_valid_examples(batch_orginal_sentences, batch_noisy_sentences): 243 | """ 244 | inputs: 245 | batch_noisy_sentences: List[str] 246 | a list of textual sentences to tokenized 247 | batch_orginal_sentences: List[str] 248 | a list of texts to make sure lengths of input and output are same in the seq-modeling task 249 | outputs (only of batch_noisy_sentences): 250 | batch_attention_masks, batch_input_ids, batch_token_type_ids 251 | 2d tensors of shape (bs,max_len) 252 | batch_splits: List[List[Int]] 253 | specifies #sub-tokens for each word in each textual string after sub-word tokenization 254 | """ 255 | _batch_orginal_sentences = _simple_bert_tokenize_sentences(batch_orginal_sentences) 256 | _batch_noisy_sentences, _batch_tokens, _batch_splits = _custom_bert_tokenize_sentences(batch_noisy_sentences) 257 | valid_idxs = [idx for idx,(a,b) in 
enumerate(zip(_batch_orginal_sentences, _batch_noisy_sentences)) if len(a.split())==len(b.split())] 258 | batch_orginal_sentences = [line for idx,line in enumerate(_batch_orginal_sentences) if idx in valid_idxs] 259 | batch_noisy_sentences = [line for idx,line in enumerate(_batch_noisy_sentences) if idx in valid_idxs] 260 | batch_tokens = [line for idx,line in enumerate(_batch_tokens) if idx in valid_idxs] 261 | batch_splits = [line for idx,line in enumerate(_batch_splits) if idx in valid_idxs] 262 | 263 | batch_bert_dict = {"attention_mask":[],"input_ids":[],"token_type_ids":[]} 264 | if len(valid_idxs)>0: 265 | batch_encoded_dicts = [BERT_TOKENIZER.encode_plus(tokens) for tokens in batch_tokens] 266 | batch_attention_masks = pad_sequence([torch.tensor(encoded_dict["attention_mask"]) for encoded_dict in batch_encoded_dicts],batch_first=True,padding_value=0) 267 | batch_input_ids = pad_sequence([torch.tensor(encoded_dict["input_ids"]) for encoded_dict in batch_encoded_dicts],batch_first=True,padding_value=0) 268 | batch_token_type_ids = pad_sequence([torch.tensor(encoded_dict["token_type_ids"]) for encoded_dict in batch_encoded_dicts],batch_first=True,padding_value=0) 269 | batch_bert_dict = {"attention_mask":batch_attention_masks, 270 | "input_ids":batch_input_ids, 271 | "token_type_ids":batch_token_type_ids} 272 | 273 | return batch_orginal_sentences, batch_noisy_sentences, batch_bert_dict, batch_splits --------------------------------------------------------------------------------
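As a usage note: besides running `python main.py` (which only processes the hard-coded sample sentence in its `__main__` block), the pipeline can be driven from another script via `load_pre_model` and `spell_checking_on_sents`. The snippet below is a minimal sketch that mirrors the `__main__` block of `main.py`; it assumes the checkpoint and vocab have already been downloaded into `model/` as described in the README, and the file name `correct_text.py` is illustrative only.

```python
# correct_text.py: illustrative sketch, not part of the repository.
# Assumes model/vocab.pkl and model/model.pth.tar exist (see the README download step).
import os

from hazm import Normalizer

from main import load_pre_model, spell_checking_on_sents

# Same normalizer settings as in main.py.
normalizer = Normalizer(punctuation_spacing=False, remove_extra_spaces=False)

# Loads the vocab, builds SubwordBert, and restores the checkpoint weights.
model, vocab, device = load_pre_model(
    vocab_path=os.path.join('model', 'vocab.pkl'),
    model_checkpoint_path=os.path.join('model', 'model.pth.tar'),
)

# Sample misspelled sentence taken from main.py's __main__ block.
text = 'این یک مثالل صاده برالی ازرابی این سامانح اسصت'
corrections, splitters = spell_checking_on_sents(model, vocab, device, normalizer, text)

# Each item pairs the input sentence with its predicted correction;
# words that were changed are wrapped in ** markers.
for noised, predicted in corrections:
    print(noised, '->', predicted)
```

For batch evaluation on parallel (correct, noisy) sentence pairs, `model_inference` in `main.py` can also be called directly; it returns a list of dicts with `original`, `noised`, and `predicted` fields for each input pair.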