├── .gitignore
├── BERTScore
│   ├── README.md
│   ├── filterBanglaBert.py
│   ├── logBERT.py
│   ├── testBertScore.py
│   └── utils
│       ├── score.py
│       └── utils.py
├── N-gram Repitition Filter
│   ├── README.md
│   └── n_gram_repeatition_filter.py
├── PINCScore
│   ├── PINCscore.py
│   ├── README.md
│   └── filterPINC.py
├── Punctuation Filter
│   ├── README.md
│   └── punctuation_filter.py
├── README.md
├── filter.sh
├── images
│   └── filter_sequence.png
└── requirements.txt

/.gitignore:
--------------------------------------------------------------------------------
1 | **/.DS_Store
--------------------------------------------------------------------------------
/BERTScore/README.md:
--------------------------------------------------------------------------------
1 | ### Setting up BERTScore to use BanglaBERT encoding
2 | Install the requirements.
3 | ```
4 | pip install git+https://github.com/csebuetnlp/normalizer
5 | pip install jsonlines
6 | ```
7 | Install BERTScore.
8 | ```
9 | git clone https://github.com/Tiiiger/bert_score
10 | cd bert_score
11 | pip install .
12 | ```
13 | To use BanglaBERT encoding, replace the `bert_score/bert_score/score.py` and `bert_score/bert_score/utils.py` files with our corresponding provided files.
14 | ### Generate the log file
15 | Place `logBERT.py` inside the `bert_score/bert_score/` folder. Then generate the log file by running the following command from the `bert_score/bert_score/` folder.
16 | ```
17 | python logBERT.py --l <source> --t <target>
18 | ```
19 | where `source` is the path to the jsonl file containing sentences and their corresponding paraphrases as key-value pairs, and `target` is the path of the log file to be generated.
20 | ### Calculate BERTScore
21 | To calculate BERTScore, place `testBertScore.py` inside the `bert_score/bert_score/` folder. Then compute the scores by running the following command from the `bert_score/bert_score/` folder.
22 | ```
23 | python testBertScore.py --s <source> --p <prediction>
24 | ```
25 | where `source` and `prediction` are respectively the paths to the files containing sources and corresponding predictions.
26 | 
27 | ### Filter with BERTScore
28 | To filter with BERTScore using the log file, run the following command.
29 | ```
30 | python filterBanglaBert.py --j <log_file> --s <source> --t <target> --l <lower_limit> --u <upper_limit>
31 | ```
32 | where `log_file` is the path to the log file, and `source` and `target` are the paths to the generated files containing the sources and their corresponding paraphrases. `lower_limit` and `upper_limit` are the limits for BERTScore on a scale of 0 to 1.
--------------------------------------------------------------------------------
/BERTScore/filterBanglaBert.py:
--------------------------------------------------------------------------------
1 | from ast import arg
2 | import jsonlines
3 | import json
4 | import argparse
5 | 
6 | # filters from the log files of the PINC-filtered jsonl files.
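# Expected log format (inferred from the parsing below): each jsonl line maps one source
# sentence to a list of [paraphrase, BERTScore] pairs, e.g.
#   {"<source sentence>": [["<paraphrase 1>", 0.91], ["<paraphrase 2>", 0.83]]}
# For every source, only the highest-scoring paraphrase is kept, and the pair is written out
# only if that score falls within [lower_limit, upper_limit).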
7 | 8 | if __name__ == '__main__': 9 | 10 | # Create the parser 11 | parser = argparse.ArgumentParser( 12 | description='path to jsonl log file, output source and output target and the lower and upper limit of BertScore') 13 | 14 | # Add the arguments 15 | parser.add_argument('--j', 16 | metavar='j', 17 | type=str, 18 | help='the path to the jsonl log file') 19 | 20 | parser.add_argument('--s', 21 | metavar='t', 22 | type=str, 23 | help='the path to the generated source file') 24 | 25 | parser.add_argument('--t', 26 | metavar='t', 27 | type=str, 28 | help='the path to the generated target file') 29 | 30 | parser.add_argument('--l', 31 | metavar='l', 32 | type=float, 33 | help='the lower limit of bbertscore') 34 | 35 | parser.add_argument('--u', 36 | metavar='u', 37 | type=float, 38 | help='the upper limit of bbertscore') 39 | 40 | # Execute the parse_args() method 41 | args = parser.parse_args() 42 | 43 | banglabert_path = args.j 44 | source_path = args.s 45 | target_path = args.t 46 | threshold1 = args.l 47 | threshold2 = args.u 48 | 49 | linecount = 0 50 | sourcebuffer = "" 51 | targetbuffer = "" 52 | banglabertfilteredcount = 0 53 | 54 | banglabertfile = jsonlines.open(banglabert_path) 55 | sourcefile = open( 56 | source_path, 'w', encoding='utf-8') 57 | targetfile = open( 58 | target_path, 'w', encoding='utf-8') 59 | 60 | for line in banglabertfile.iter(): 61 | srcbangla = "" 62 | trgtbangla = "" 63 | maxbangla = -1 64 | for key, values in line.items(): 65 | srcbangla = key 66 | for value in values: 67 | if value[1] > maxbangla: 68 | trgtbangla = value[0] 69 | maxbangla = value[1] 70 | 71 | if(maxbangla >= threshold1 and maxbangla < threshold2): 72 | sourcebuffer += srcbangla + "\n" 73 | targetbuffer += trgtbangla + "\n" 74 | banglabertfilteredcount += 1 75 | if(banglabertfilteredcount == 50000): 76 | sourcefile.write("%s" % sourcebuffer) 77 | targetfile.write("%s" % targetbuffer) 78 | sourcebuffer = "" 79 | targetbuffer = "" 80 | banglabertfilteredcount = 0 81 | 82 | linecount += 1 83 | if linecount % 50000 == 0: 84 | print(linecount) 85 | sourcefile.write("%s" % sourcebuffer) 86 | targetfile.write("%s" % targetbuffer) 87 | # closing all the files 88 | banglabertfile.close() 89 | sourcefile.close() 90 | targetfile.close 91 | -------------------------------------------------------------------------------- /BERTScore/logBERT.py: -------------------------------------------------------------------------------- 1 | from score import score 2 | import torch 3 | from nltk import ngrams 4 | import argparse 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import jsonlines 8 | import json 9 | 10 | 11 | def generate_bertscore_logs(jsonl_file, target_file): 12 | """ 13 | generates the logs by processing jsonl input file and writes the parallel sentences with bertscore to target file 14 | 15 | Args: 16 | jsonl_file (file object) : file to be read 17 | target_file (file object) : file to be written the logs to 18 | Returns: 19 | None 20 | """ 21 | 22 | tracks = [] 23 | linecount = 0 24 | refs = [] 25 | cands = [] 26 | originalrefs = [] 27 | originalcands = [] 28 | lines = 0 29 | for line in jsonl_file.iter(): 30 | lines += 1 31 | tracks.append(0) 32 | for key, values in line.items(): 33 | tracks[linecount] += len(values) 34 | originalkey = key 35 | key = key.strip() 36 | for value in enumerate(values): 37 | originalrefs.append(originalkey) 38 | originalcands.append(value[-1]) 39 | refs.append(key) 40 | cands.append(value[-1].strip()) 41 | linecount += 1 42 | if linecount == 4000: 43 | 
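# score the accumulated candidate/reference batch every 4000 jsonl lines, then flush the
# per-source results to the log file and clear the buffers (presumably to bound memory use)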
_, _, F1 = score(cands, refs, lang='bn', verbose=False) 44 | F1_list = F1.tolist() 45 | bertindex = 0 46 | for track in tracks: 47 | objtowrite = {} 48 | objtowrite[originalrefs[bertindex]] = [] 49 | for i in range(track): 50 | objtowrite[originalrefs[bertindex]].append( 51 | (originalcands[bertindex], F1_list[bertindex])) 52 | bertindex += 1 53 | json.dump(objtowrite, target_file, ensure_ascii=False) 54 | target_file.write("%s" % '\n') 55 | tracks = [] 56 | refs = [] 57 | cands = [] 58 | originalrefs = [] 59 | originalcands = [] 60 | linecount = 0 61 | print(lines) 62 | 63 | _, _, F1 = score(cands, refs, lang='bn', verbose=False) 64 | F1_list = F1.tolist() 65 | bertindex = 0 66 | for track in tracks: 67 | objtowrite = {} 68 | objtowrite[originalrefs[bertindex]] = [] 69 | for i in range(track): 70 | objtowrite[originalrefs[bertindex]].append( 71 | (originalcands[bertindex], F1_list[bertindex])) 72 | bertindex += 1 73 | json.dump(objtowrite, target_file, ensure_ascii=False) 74 | target_file.write("%s" % '\n') 75 | tracks = [] 76 | refs = [] 77 | cands = [] 78 | originalrefs = [] 79 | originalcands = [] 80 | linecount = 0 81 | print(lines) 82 | 83 | 84 | if __name__ == '__main__': 85 | 86 | # Create the parser 87 | parser = argparse.ArgumentParser( 88 | description='path to jsonl input file and generated log file') 89 | 90 | # Add the arguments 91 | parser.add_argument('--l', 92 | metavar='l', 93 | type=str, 94 | help='the path to the jsonl file with sources and corresponding paraphrases') 95 | 96 | parser.add_argument('--t', 97 | metavar='t', 98 | type=str, 99 | help='the path to the generated log file') 100 | 101 | # Execute the parse_args() method 102 | args = parser.parse_args() 103 | 104 | jsonl_path = args.l 105 | target_path = args.t 106 | 107 | jsonl_file = jsonlines.open(jsonl_path) 108 | target_file = open(target_path, 'w', encoding='utf-8') 109 | 110 | generate_bertscore_logs(jsonl_file, target_file) 111 | 112 | # closing all the files 113 | jsonl_file.close() 114 | target_file.close() 115 | -------------------------------------------------------------------------------- /BERTScore/testBertScore.py: -------------------------------------------------------------------------------- 1 | # Must Be Kept at /bert_score/bert_score folder 2 | 3 | from score import score 4 | import numpy as np 5 | from nltk import ngrams 6 | import argparse 7 | 8 | # Create the parser 9 | parser = argparse.ArgumentParser(description='path to source and prediction') 10 | 11 | # Add the arguments 12 | parser.add_argument('--s', 13 | metavar='s', 14 | type=str, 15 | help='the path to the source') 16 | 17 | 18 | parser.add_argument('--p', 19 | metavar='p', 20 | type=str, 21 | help='the path to the generated prediction file') 22 | 23 | args = parser.parse_args() 24 | 25 | pred_path = args.p 26 | source_path = args.s 27 | 28 | with open(pred_path) as f: 29 | cands = [line.strip() for line in f] 30 | 31 | with open(source_path) as f: 32 | refs = [line.strip() for line in f] 33 | 34 | 35 | P, R, F1 = score(cands, refs, lang='bn', verbose=True) 36 | 37 | 38 | P_mean = P.mean() 39 | R_mean = R.mean() 40 | F1_mean= F1.mean() 41 | 42 | print(f"System level precision: {P_mean :.3f}") 43 | print(f"System level recall: {R_mean:.3f}") 44 | print(f"System level F1 score: {F1_mean:.3f}") -------------------------------------------------------------------------------- /BERTScore/utils/score.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | 
import pathlib 5 | import torch 6 | import matplotlib.pyplot as plt 7 | from mpl_toolkits.axes_grid1 import make_axes_locatable 8 | import numpy as np 9 | import pandas as pd 10 | 11 | 12 | from collections import defaultdict 13 | from transformers import AutoTokenizer 14 | from utils import get_model, get_tokenizer, get_idf_dict, bert_cos_score_idf, get_bert_embedding, lang2model, model2layers, get_hash, cache_scibert, sent_encode 15 | 16 | 17 | # from .utils import ( 18 | # get_model, 19 | # get_tokenizer, 20 | # get_idf_dict, 21 | # bert_cos_score_idf, 22 | # get_bert_embedding, 23 | # lang2model, 24 | # model2layers, 25 | # get_hash, 26 | # cache_scibert, 27 | # sent_encode, 28 | # ) 29 | 30 | 31 | __all__ = ["score", "plot_example"] 32 | 33 | 34 | def score( 35 | cands, 36 | refs, 37 | model_type=None, 38 | num_layers=None, 39 | verbose=False, 40 | idf=False, 41 | device=None, 42 | batch_size=64, 43 | nthreads=4, 44 | all_layers=False, 45 | lang=None, 46 | return_hash=False, 47 | rescale_with_baseline=False, 48 | baseline_path=None, 49 | use_fast_tokenizer=False 50 | ): 51 | """ 52 | BERTScore metric. 53 | 54 | Args: 55 | - :param: `cands` (list of str): candidate sentences 56 | - :param: `refs` (list of str or list of list of str): reference sentences 57 | - :param: `model_type` (str): bert specification, default using the suggested 58 | model for the target langauge; has to specify at least one of 59 | `model_type` or `lang` 60 | - :param: `num_layers` (int): the layer of representation to use. 61 | default using the number of layer tuned on WMT16 correlation data 62 | - :param: `verbose` (bool): turn on intermediate status update 63 | - :param: `idf` (bool or dict): use idf weighting, can also be a precomputed idf_dict 64 | - :param: `device` (str): on which the contextual embedding model will be allocated on. 65 | If this argument is None, the model lives on cuda:0 if cuda is available. 66 | - :param: `nthreads` (int): number of threads 67 | - :param: `batch_size` (int): bert score processing batch size 68 | - :param: `lang` (str): language of the sentences; has to specify 69 | at least one of `model_type` or `lang`. `lang` needs to be 70 | specified when `rescale_with_baseline` is True. 71 | - :param: `return_hash` (bool): return hash code of the setting 72 | - :param: `rescale_with_baseline` (bool): rescale bertscore with pre-computed baseline 73 | - :param: `baseline_path` (str): customized baseline file 74 | - :param: `use_fast_tokenizer` (bool): `use_fast` parameter passed to HF tokenizer 75 | 76 | Return: 77 | - :param: `(P, R, F)`: each is of shape (N); N = number of input 78 | candidate reference pairs. if returning hashcode, the 79 | output will be ((P, R, F), hashcode). If a candidate have 80 | multiple references, the returned score of this candidate is 81 | the *best* score among all references. 
82 | """ 83 | assert len(cands) == len(refs), "Different number of candidates and references" 84 | 85 | assert lang is not None or model_type is not None, "Either lang or model_type should be specified" 86 | 87 | ref_group_boundaries = None 88 | if not isinstance(refs[0], str): 89 | ref_group_boundaries = [] 90 | ori_cands, ori_refs = cands, refs 91 | cands, refs = [], [] 92 | count = 0 93 | for cand, ref_group in zip(ori_cands, ori_refs): 94 | cands += [cand] * len(ref_group) 95 | refs += ref_group 96 | ref_group_boundaries.append((count, count + len(ref_group))) 97 | count += len(ref_group) 98 | 99 | if rescale_with_baseline: 100 | assert lang is not None, "Need to specify Language when rescaling with baseline" 101 | 102 | if model_type is None: 103 | lang = lang.lower() 104 | model_type = lang2model[lang] 105 | if num_layers is None: 106 | num_layers = model2layers[model_type] 107 | 108 | tokenizer = get_tokenizer(model_type, use_fast_tokenizer) 109 | model = get_model(model_type, num_layers, all_layers) 110 | if device is None: 111 | device = "cuda" if torch.cuda.is_available() else "cpu" 112 | model.to(device) 113 | 114 | if not idf: 115 | idf_dict = defaultdict(lambda: 1.0) 116 | # set idf for [SEP] and [CLS] to 0 117 | idf_dict[tokenizer.sep_token_id] = 0 118 | idf_dict[tokenizer.cls_token_id] = 0 119 | elif isinstance(idf, dict): 120 | if verbose: 121 | print("using predefined IDF dict...") 122 | idf_dict = idf 123 | else: 124 | if verbose: 125 | print("preparing IDF dict...") 126 | start = time.perf_counter() 127 | idf_dict = get_idf_dict(refs, tokenizer, nthreads=nthreads) 128 | if verbose: 129 | print("done in {:.2f} seconds".format(time.perf_counter() - start)) 130 | 131 | if verbose: 132 | print("calculating scores...") 133 | start = time.perf_counter() 134 | all_preds = bert_cos_score_idf( 135 | model, 136 | refs, 137 | cands, 138 | tokenizer, 139 | idf_dict, 140 | verbose=verbose, 141 | device=device, 142 | batch_size=batch_size, 143 | all_layers=all_layers, 144 | ).cpu() 145 | 146 | if ref_group_boundaries is not None: 147 | max_preds = [] 148 | for beg, end in ref_group_boundaries: 149 | max_preds.append(all_preds[beg:end].max(dim=0)[0]) 150 | all_preds = torch.stack(max_preds, dim=0) 151 | 152 | use_custom_baseline = baseline_path is not None 153 | if rescale_with_baseline: 154 | if baseline_path is None: 155 | baseline_path = os.path.join(os.path.dirname(__file__), f"rescale_baseline/{lang}/{model_type}.tsv") 156 | if os.path.isfile(baseline_path): 157 | if not all_layers: 158 | baselines = torch.from_numpy(pd.read_csv(baseline_path).iloc[num_layers].to_numpy())[1:].float() 159 | else: 160 | baselines = torch.from_numpy(pd.read_csv(baseline_path).to_numpy())[:, 1:].unsqueeze(1).float() 161 | 162 | all_preds = (all_preds - baselines) / (1 - baselines) 163 | else: 164 | print( 165 | f"Warning: Baseline not Found for {model_type} on {lang} at {baseline_path}", file=sys.stderr, 166 | ) 167 | 168 | out = all_preds[..., 0], all_preds[..., 1], all_preds[..., 2] # P, R, F 169 | 170 | if verbose: 171 | time_diff = time.perf_counter() - start 172 | print(f"done in {time_diff:.2f} seconds, {len(refs) / time_diff:.2f} sentences/sec") 173 | 174 | if return_hash: 175 | return tuple( 176 | [ 177 | out, 178 | get_hash(model_type, num_layers, idf, rescale_with_baseline, 179 | use_custom_baseline=use_custom_baseline, 180 | use_fast_tokenizer=use_fast_tokenizer), 181 | ] 182 | ) 183 | 184 | return out 185 | 186 | 187 | def plot_example( 188 | candidate, 189 | reference, 190 | 
model_type=None, 191 | num_layers=None, 192 | lang=None, 193 | rescale_with_baseline=False, 194 | baseline_path=None, 195 | use_fast_tokenizer=False, 196 | fname="", 197 | ): 198 | """ 199 | BERTScore metric. 200 | 201 | Args: 202 | - :param: `candidate` (str): a candidate sentence 203 | - :param: `reference` (str): a reference sentence 204 | - :param: `verbose` (bool): turn on intermediate status update 205 | - :param: `model_type` (str): bert specification, default using the suggested 206 | model for the target langauge; has to specify at least one of 207 | `model_type` or `lang` 208 | - :param: `num_layers` (int): the layer of representation to use 209 | - :param: `lang` (str): language of the sentences; has to specify 210 | at least one of `model_type` or `lang`. `lang` needs to be 211 | specified when `rescale_with_baseline` is True. 212 | - :param: `return_hash` (bool): return hash code of the setting 213 | - :param: `rescale_with_baseline` (bool): rescale bertscore with pre-computed baseline 214 | - :param: `use_fast_tokenizer` (bool): `use_fast` parameter passed to HF tokenizer 215 | - :param: `fname` (str): path to save the output plot 216 | """ 217 | assert isinstance(candidate, str) 218 | assert isinstance(reference, str) 219 | 220 | assert lang is not None or model_type is not None, "Either lang or model_type should be specified" 221 | 222 | if rescale_with_baseline: 223 | assert lang is not None, "Need to specify Language when rescaling with baseline" 224 | 225 | if model_type is None: 226 | lang = lang.lower() 227 | model_type = lang2model[lang] 228 | if num_layers is None: 229 | num_layers = model2layers[model_type] 230 | 231 | tokenizer = get_tokenizer(model_type, use_fast_tokenizer) 232 | model = get_model(model_type, num_layers) 233 | device = "cuda" if torch.cuda.is_available() else "cpu" 234 | model.to(device) 235 | 236 | idf_dict = defaultdict(lambda: 1.0) 237 | # set idf for [SEP] and [CLS] to 0 238 | idf_dict[tokenizer.sep_token_id] = 0 239 | idf_dict[tokenizer.cls_token_id] = 0 240 | 241 | hyp_embedding, masks, padded_idf = get_bert_embedding( 242 | [candidate], model, tokenizer, idf_dict, device=device, all_layers=False 243 | ) 244 | ref_embedding, masks, padded_idf = get_bert_embedding( 245 | [reference], model, tokenizer, idf_dict, device=device, all_layers=False 246 | ) 247 | ref_embedding.div_(torch.norm(ref_embedding, dim=-1).unsqueeze(-1)) 248 | hyp_embedding.div_(torch.norm(hyp_embedding, dim=-1).unsqueeze(-1)) 249 | sim = torch.bmm(hyp_embedding, ref_embedding.transpose(1, 2)) 250 | sim = sim.squeeze(0).cpu() 251 | 252 | # remove [CLS] and [SEP] tokens 253 | r_tokens = [tokenizer.decode([i]) for i in sent_encode(tokenizer, reference)][1:-1] 254 | h_tokens = [tokenizer.decode([i]) for i in sent_encode(tokenizer, candidate)][1:-1] 255 | sim = sim[1:-1, 1:-1] 256 | 257 | if rescale_with_baseline: 258 | if baseline_path is None: 259 | baseline_path = os.path.join(os.path.dirname(__file__), f"rescale_baseline/{lang}/{model_type}.tsv") 260 | if os.path.isfile(baseline_path): 261 | baselines = torch.from_numpy(pd.read_csv(baseline_path).iloc[num_layers].to_numpy())[1:].float() 262 | sim = (sim - baselines[2].item()) / (1 - baselines[2].item()) 263 | else: 264 | print( 265 | f"Warning: Baseline not Found for {model_type} on {lang} at {baseline_path}", file=sys.stderr, 266 | ) 267 | 268 | fig, ax = plt.subplots(figsize=(len(r_tokens), len(h_tokens))) 269 | im = ax.imshow(sim, cmap="Blues", vmin=0, vmax=1) 270 | 271 | # We want to show all ticks... 
272 | ax.set_xticks(np.arange(len(r_tokens))) 273 | ax.set_yticks(np.arange(len(h_tokens))) 274 | # ... and label them with the respective list entries 275 | ax.set_xticklabels(r_tokens, fontsize=10) 276 | ax.set_yticklabels(h_tokens, fontsize=10) 277 | ax.grid(False) 278 | plt.xlabel("Reference (tokenized)", fontsize=14) 279 | plt.ylabel("Candidate (tokenized)", fontsize=14) 280 | title = "Similarity Matrix" 281 | if rescale_with_baseline: 282 | title += " (after Rescaling)" 283 | plt.title(title, fontsize=14) 284 | 285 | divider = make_axes_locatable(ax) 286 | cax = divider.append_axes("right", size="2%", pad=0.2) 287 | fig.colorbar(im, cax=cax) 288 | 289 | # Rotate the tick labels and set their alignment. 290 | plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") 291 | 292 | # Loop over data dimensions and create text annotations. 293 | for i in range(len(h_tokens)): 294 | for j in range(len(r_tokens)): 295 | text = ax.text( 296 | j, 297 | i, 298 | "{:.3f}".format(sim[i, j].item()), 299 | ha="center", 300 | va="center", 301 | color="k" if sim[i, j].item() < 0.5 else "w", 302 | ) 303 | 304 | fig.tight_layout() 305 | if fname != "": 306 | plt.savefig(fname, dpi=100) 307 | print("Saved figure to file: ", fname) 308 | plt.show() 309 | -------------------------------------------------------------------------------- /BERTScore/utils/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import torch 4 | from math import log 5 | from itertools import chain 6 | from collections import defaultdict, Counter 7 | from multiprocessing import Pool 8 | from functools import partial 9 | from tqdm.auto import tqdm 10 | from torch.nn.utils.rnn import pad_sequence 11 | from distutils.version import LooseVersion 12 | from normalizer import normalize 13 | 14 | from transformers import BertConfig, XLNetConfig, XLMConfig, RobertaConfig 15 | from transformers import AutoModel, GPT2Tokenizer, AutoTokenizer 16 | 17 | __version__ = "0.3.10" 18 | 19 | from transformers import __version__ as trans_version 20 | 21 | __all__ = [] 22 | 23 | SCIBERT_URL_DICT = { 24 | "scibert-scivocab-uncased": "https://s3-us-west-2.amazonaws.com/ai2-s2-research/scibert/pytorch_models/scibert_scivocab_uncased.tar", # recommend by the SciBERT authors 25 | "scibert-scivocab-cased": "https://s3-us-west-2.amazonaws.com/ai2-s2-research/scibert/pytorch_models/scibert_scivocab_cased.tar", 26 | "scibert-basevocab-uncased": "https://s3-us-west-2.amazonaws.com/ai2-s2-research/scibert/pytorch_models/scibert_basevocab_uncased.tar", 27 | "scibert-basevocab-cased": "https://s3-us-west-2.amazonaws.com/ai2-s2-research/scibert/pytorch_models/scibert_basevocab_cased.tar", 28 | } 29 | 30 | 31 | lang2model = defaultdict(lambda: "bert-base-multilingual-cased") 32 | lang2model.update( 33 | { 34 | "en": "roberta-large", 35 | "zh": "bert-base-chinese", 36 | "tr": "dbmdz/bert-base-turkish-cased", 37 | "en-sci": "allenai/scibert_scivocab_uncased", 38 | 'bn': 'csebuetnlp/banglabert' 39 | } 40 | ) 41 | 42 | 43 | model2layers = { 44 | 'csebuetnlp/banglabert': 12, 45 | "bert-base-uncased": 9, # 0.6925188074454226 46 | "bert-large-uncased": 18, # 0.7210358126642836 47 | "bert-base-cased-finetuned-mrpc": 9, # 0.6721947475618048 48 | "bert-base-multilingual-cased": 9, # 0.6680687802637132 49 | "bert-base-chinese": 8, 50 | "roberta-base": 10, # 0.706288719158983 51 | "roberta-large": 17, # 0.7385974720781534 52 | "roberta-large-mnli": 19, # 0.7535618640417984 53 | 
"roberta-base-openai-detector": 7, # 0.7048158349432633 54 | "roberta-large-openai-detector": 15, # 0.7462770207355116 55 | "xlnet-base-cased": 5, # 0.6630103662114238 56 | "xlnet-large-cased": 7, # 0.6598800720297179 57 | "xlm-mlm-en-2048": 6, # 0.651262570131464 58 | "xlm-mlm-100-1280": 10, # 0.6475166424401905 59 | # "scibert-scivocab-uncased": 8, # 0.6590354319927313 60 | # "scibert-scivocab-cased": 9, # 0.6536375053937445 61 | # "scibert-basevocab-uncased": 9, # 0.6748944832703548 62 | # "scibert-basevocab-cased": 9, # 0.6524624150542374 63 | 'allenai/scibert_scivocab_uncased': 8, # 0.6590354393124127 64 | 'allenai/scibert_scivocab_cased': 9, # 0.6536374902465466 65 | 'nfliu/scibert_basevocab_uncased': 9, # 0.6748945076082333 66 | "distilroberta-base": 5, # 0.6797558139322964 67 | "distilbert-base-uncased": 5, # 0.6756659152782033 68 | "distilbert-base-uncased-distilled-squad": 4, # 0.6718318036382493 69 | "distilbert-base-multilingual-cased": 5, # 0.6178131050889238 70 | "albert-base-v1": 10, # 0.654237567249745 71 | "albert-large-v1": 17, # 0.6755890754323239 72 | "albert-xlarge-v1": 16, # 0.7031844211905911 73 | "albert-xxlarge-v1": 8, # 0.7508642218461096 74 | "albert-base-v2": 9, # 0.6682455591837927 75 | "albert-large-v2": 14, # 0.7008537594374035 76 | "albert-xlarge-v2": 13, # 0.7317228357869254 77 | "albert-xxlarge-v2": 8, # 0.7505160257184014 78 | "xlm-roberta-base": 9, # 0.6506799445871697 79 | "xlm-roberta-large": 17, # 0.6941551437476826 80 | "google/electra-small-generator": 9, # 0.6659421842117754 81 | "google/electra-small-discriminator": 11, # 0.6534639151385759 82 | "google/electra-base-generator": 10, # 0.6730033453857188 83 | "google/electra-base-discriminator": 9, # 0.7032089590812965 84 | "google/electra-large-generator": 18, # 0.6813370013104459 85 | "google/electra-large-discriminator": 14, # 0.6896675824733477 86 | "google/bert_uncased_L-2_H-128_A-2": 1, # 0.5887998733228855 87 | "google/bert_uncased_L-2_H-256_A-4": 1, # 0.6114863547661203 88 | "google/bert_uncased_L-2_H-512_A-8": 1, # 0.6177345529192847 89 | "google/bert_uncased_L-2_H-768_A-12": 2, # 0.6191261237956839 90 | "google/bert_uncased_L-4_H-128_A-2": 3, # 0.6076202863798991 91 | "google/bert_uncased_L-4_H-256_A-4": 3, # 0.6205239036810148 92 | "google/bert_uncased_L-4_H-512_A-8": 3, # 0.6375351621856903 93 | "google/bert_uncased_L-4_H-768_A-12": 3, # 0.6561849979644787 94 | "google/bert_uncased_L-6_H-128_A-2": 5, # 0.6200458425360283 95 | "google/bert_uncased_L-6_H-256_A-4": 5, # 0.6277501629539081 96 | "google/bert_uncased_L-6_H-512_A-8": 5, # 0.641952305130849 97 | "google/bert_uncased_L-6_H-768_A-12": 5, # 0.6762186226247106 98 | "google/bert_uncased_L-8_H-128_A-2": 7, # 0.6186876506711779 99 | "google/bert_uncased_L-8_H-256_A-4": 7, # 0.6447993208267708 100 | "google/bert_uncased_L-8_H-512_A-8": 6, # 0.6489729408169956 101 | "google/bert_uncased_L-8_H-768_A-12": 7, # 0.6705203359541737 102 | "google/bert_uncased_L-10_H-128_A-2": 8, # 0.6126762064125278 103 | "google/bert_uncased_L-10_H-256_A-4": 8, # 0.6376350032576573 104 | "google/bert_uncased_L-10_H-512_A-8": 9, # 0.6579006292799915 105 | "google/bert_uncased_L-10_H-768_A-12": 8, # 0.6861146692220176 106 | "google/bert_uncased_L-12_H-128_A-2": 10, # 0.6184105693383591 107 | "google/bert_uncased_L-12_H-256_A-4": 11, # 0.6374004994430261 108 | "google/bert_uncased_L-12_H-512_A-8": 10, # 0.65880012149526 109 | "google/bert_uncased_L-12_H-768_A-12": 9, # 0.675911357700092 110 | "amazon/bort": 0, # 0.41927911053036643 111 | "facebook/bart-base": 6, 
# 0.7122259132414092 112 | "facebook/bart-large": 10, # 0.7448671872459683 113 | "facebook/bart-large-cnn": 10, # 0.7393148105835096 114 | "facebook/bart-large-mnli": 11, # 0.7531665445691358 115 | "facebook/bart-large-xsum": 9, # 0.7496408866539556 116 | "t5-small": 6, # 0.6813843919496912 117 | "t5-base": 11, # 0.7096044814981418 118 | "t5-large": 23, # 0.7244153820191929 119 | "vinai/bertweet-base": 9, # 0.6529471006118857 120 | "microsoft/deberta-base": 9, # 0.7088459455930344 121 | "microsoft/deberta-base-mnli": 9, # 0.7395257063907247 122 | "microsoft/deberta-large": 16, # 0.7511806792052013 123 | "microsoft/deberta-large-mnli": 18, # 0.7736263649679905 124 | "microsoft/deberta-xlarge": 18, # 0.7568670944373346 125 | "microsoft/deberta-xlarge-mnli": 40, # 0.7780600929333213 126 | "YituTech/conv-bert-base": 10, # 0.7058253551080789 127 | "YituTech/conv-bert-small": 10, # 0.6544473011107349 128 | "YituTech/conv-bert-medium-small": 9, # 0.6590097075123257 129 | "microsoft/mpnet-base": 8, # 0.724976539498804 130 | "squeezebert/squeezebert-uncased": 9, # 0.6543868703018726 131 | "squeezebert/squeezebert-mnli": 9, # 0.6654799051284791 132 | "squeezebert/squeezebert-mnli-headless": 9, # 0.6654799051284791 133 | "tuner007/pegasus_paraphrase": 15, # 0.7188349436772694 134 | "google/pegasus-large": 8, # 0.63960462272448 135 | "google/pegasus-xsum": 11, # 0.6836878575233349 136 | "sshleifer/tiny-mbart": 2, # 0.028246072231946733 137 | "facebook/mbart-large-cc25": 12, # 0.6582922975802958 138 | "facebook/mbart-large-50": 12, # 0.6464972230103133 139 | "facebook/mbart-large-en-ro": 12, # 0.6791285137459857 140 | "facebook/mbart-large-50-many-to-many-mmt": 12, # 0.6904136529270892 141 | "facebook/mbart-large-50-one-to-many-mmt": 12, # 0.6847906439540236 142 | "allenai/led-base-16384": 6, # 0.7122259170564179 143 | "facebook/blenderbot_small-90M": 7, # 0.6489176335400088 144 | "facebook/blenderbot-400M-distill": 2, # 0.5874774070540008 145 | "microsoft/prophetnet-large-uncased": 4, # 0.586496184234925 146 | "microsoft/prophetnet-large-uncased-cnndm": 7, # 0.6478379437729287 147 | "SpanBERT/spanbert-base-cased": 8, # 0.6824006863686848 148 | "SpanBERT/spanbert-large-cased": 17, # 0.705352690855603 149 | "microsoft/xprophetnet-large-wiki100-cased": 7, # 0.5852499775879524 150 | "ProsusAI/finbert": 10, # 0.6923213940752796 151 | "Vamsi/T5_Paraphrase_Paws": 12, # 0.6941611753807352 152 | "ramsrigouthamg/t5_paraphraser": 11, # 0.7200917597031539 153 | "microsoft/deberta-v2-xlarge": 10, # 0.7393675784473045 154 | "microsoft/deberta-v2-xlarge-mnli": 17, # 0.7620620803716714 155 | "microsoft/deberta-v2-xxlarge": 21, # 0.7520547670281869 156 | "microsoft/deberta-v2-xxlarge-mnli": 22, # 0.7742603457742682 157 | "allenai/longformer-base-4096": 7, # 0.7089559593129316 158 | "allenai/longformer-large-4096": 14, # 0.732408493548181 159 | "allenai/longformer-large-4096-finetuned-triviaqa": 14, # 0.7365882744744722 160 | "zhiheng-huang/bert-base-uncased-embedding-relative-key": 4, # 0.5995636595368777 161 | "zhiheng-huang/bert-base-uncased-embedding-relative-key-query": 7, # 0.6303599452145718 162 | "zhiheng-huang/bert-large-uncased-whole-word-masking-embedding-relative-key-query": 19, # 0.6896878492850327 163 | 'google/mt5-small': 8, # 0.6401166527273479 164 | 'google/mt5-base': 11, # 0.5663956536597241 165 | 'google/mt5-large': 19, # 0.6430931371732798 166 | 'google/mt5-xl': 24, # 0.6707200963021145 167 | 'google/bigbird-roberta-base': 10, # 0.6695606423502717 168 | 'google/bigbird-roberta-large': 14, # 
0.6755874042374509 169 | 'google/bigbird-base-trivia-itc': 8, # 0.6930725491629892 170 | 'princeton-nlp/unsup-simcse-bert-base-uncased': 10, # 0.6703066531921142 171 | 'princeton-nlp/unsup-simcse-bert-large-uncased': 18, # 0.6958302800755326 172 | 'princeton-nlp/unsup-simcse-roberta-base': 8, # 0.6436615893535319 173 | 'princeton-nlp/unsup-simcse-roberta-large': 13, # 0.6812864385585965 174 | 'princeton-nlp/sup-simcse-bert-base-uncased': 10, # 0.7068074935240984 175 | 'princeton-nlp/sup-simcse-bert-large-uncased': 18, # 0.7111049471332378 176 | 'princeton-nlp/sup-simcse-roberta-base': 10, # 0.7253123806661946 177 | 'princeton-nlp/sup-simcse-roberta-large': 16, # 0.7497820277237173 178 | 'dbmdz/bert-base-turkish-cased': 10, # WMT18 seg en-tr 0.5522827687776142 179 | 'dbmdz/distilbert-base-turkish-cased': 4, # WMT18 seg en-tr 0.4742268041237113 180 | 'google/byt5-small': 1, # 0.5100025975052146 181 | 'google/byt5-base': 17, # 0.5810347173565313 182 | 'google/byt5-large': 30, # 0.6151895697554877 183 | } 184 | 185 | 186 | def sent_encode(tokenizer, sent): 187 | 188 | "Encoding as sentence based on the tokenizer" 189 | sent = sent.strip() 190 | 191 | sent = normalize(sent) 192 | 193 | if sent == "": 194 | return tokenizer.build_inputs_with_special_tokens([]) 195 | elif isinstance(tokenizer, GPT2Tokenizer): 196 | # for RoBERTa and GPT-2 197 | if LooseVersion(trans_version) >= LooseVersion("4.0.0"): 198 | return tokenizer.encode( 199 | sent, 200 | add_special_tokens=True, 201 | add_prefix_space=True, 202 | max_length=tokenizer.model_max_length, 203 | truncation=True, 204 | ) 205 | elif LooseVersion(trans_version) >= LooseVersion("3.0.0"): 206 | return tokenizer.encode( 207 | sent, add_special_tokens=True, add_prefix_space=True, max_length=tokenizer.max_len, truncation=True, 208 | ) 209 | elif LooseVersion(trans_version) >= LooseVersion("2.0.0"): 210 | return tokenizer.encode(sent, add_special_tokens=True, add_prefix_space=True, max_length=tokenizer.max_len) 211 | else: 212 | raise NotImplementedError(f"transformers version {trans_version} is not supported") 213 | else: 214 | if LooseVersion(trans_version) >= LooseVersion("4.0.0"): 215 | return tokenizer.encode( 216 | sent, add_special_tokens=True, max_length=tokenizer.model_max_length, truncation=True, 217 | ) 218 | elif LooseVersion(trans_version) >= LooseVersion("3.0.0"): 219 | return tokenizer.encode(sent, add_special_tokens=True, max_length=tokenizer.max_len, truncation=True) 220 | elif LooseVersion(trans_version) >= LooseVersion("2.0.0"): 221 | return tokenizer.encode(sent, add_special_tokens=True, max_length=tokenizer.max_len) 222 | else: 223 | raise NotImplementedError(f"transformers version {trans_version} is not supported") 224 | 225 | 226 | def get_model(model_type, num_layers, all_layers=None): 227 | if model_type.startswith("scibert"): 228 | model = AutoModel.from_pretrained(cache_scibert(model_type)) 229 | elif "t5" in model_type: 230 | from transformers import T5EncoderModel 231 | 232 | model = T5EncoderModel.from_pretrained(model_type) 233 | else: 234 | model = AutoModel.from_pretrained(model_type) 235 | model.eval() 236 | 237 | if hasattr(model, "decoder") and hasattr(model, "encoder"): 238 | model = model.encoder 239 | 240 | # drop unused layers 241 | if not all_layers: 242 | if hasattr(model, "n_layers"): # xlm 243 | assert ( 244 | 0 <= num_layers <= model.n_layers 245 | ), f"Invalid num_layers: num_layers should be between 0 and {model.n_layers} for {model_type}" 246 | model.n_layers = num_layers 247 | elif hasattr(model, 
"layer"): # xlnet 248 | assert ( 249 | 0 <= num_layers <= len(model.layer) 250 | ), f"Invalid num_layers: num_layers should be between 0 and {len(model.layer)} for {model_type}" 251 | model.layer = torch.nn.ModuleList([layer for layer in model.layer[:num_layers]]) 252 | elif hasattr(model, "encoder"): # albert 253 | if hasattr(model.encoder, "albert_layer_groups"): 254 | assert ( 255 | 0 <= num_layers <= model.encoder.config.num_hidden_layers 256 | ), f"Invalid num_layers: num_layers should be between 0 and {model.encoder.config.num_hidden_layers} for {model_type}" 257 | model.encoder.config.num_hidden_layers = num_layers 258 | elif hasattr(model.encoder, "block"): # t5 259 | assert ( 260 | 0 <= num_layers <= len(model.encoder.block) 261 | ), f"Invalid num_layers: num_layers should be between 0 and {len(model.encoder.block)} for {model_type}" 262 | model.encoder.block = torch.nn.ModuleList([layer for layer in model.encoder.block[:num_layers]]) 263 | else: # bert, roberta 264 | assert ( 265 | 0 <= num_layers <= len(model.encoder.layer) 266 | ), f"Invalid num_layers: num_layers should be between 0 and {len(model.encoder.layer)} for {model_type}" 267 | model.encoder.layer = torch.nn.ModuleList([layer for layer in model.encoder.layer[:num_layers]]) 268 | elif hasattr(model, "transformer"): # bert, roberta 269 | assert ( 270 | 0 <= num_layers <= len(model.transformer.layer) 271 | ), f"Invalid num_layers: num_layers should be between 0 and {len(model.transformer.layer)} for {model_type}" 272 | model.transformer.layer = torch.nn.ModuleList([layer for layer in model.transformer.layer[:num_layers]]) 273 | elif hasattr(model, "layers"): # bart 274 | assert ( 275 | 0 <= num_layers <= len(model.layers) 276 | ), f"Invalid num_layers: num_layers should be between 0 and {len(model.layers)} for {model_type}" 277 | model.layers = torch.nn.ModuleList([layer for layer in model.layers[:num_layers]]) 278 | else: 279 | raise ValueError("Not supported") 280 | else: 281 | if hasattr(model, "output_hidden_states"): 282 | model.output_hidden_states = True 283 | elif hasattr(model, "encoder"): 284 | model.encoder.output_hidden_states = True 285 | elif hasattr(model, "transformer"): 286 | model.transformer.output_hidden_states = True 287 | # else: 288 | # raise ValueError(f"Not supported model architecture: {model_type}") 289 | 290 | return model 291 | 292 | 293 | def get_tokenizer(model_type, use_fast=False): 294 | if model_type.startswith("scibert"): 295 | model_type = cache_scibert(model_type) 296 | 297 | if LooseVersion(trans_version) >= LooseVersion("4.0.0"): 298 | tokenizer = AutoTokenizer.from_pretrained(model_type, use_fast=use_fast) 299 | else: 300 | assert not use_fast, "Fast tokenizer is not available for version < 4.0.0" 301 | tokenizer = AutoTokenizer.from_pretrained(model_type) 302 | 303 | return tokenizer 304 | 305 | 306 | def padding(arr, pad_token, dtype=torch.long): 307 | lens = torch.LongTensor([len(a) for a in arr]) 308 | max_len = lens.max().item() 309 | padded = torch.ones(len(arr), max_len, dtype=dtype) * pad_token 310 | mask = torch.zeros(len(arr), max_len, dtype=torch.long) 311 | for i, a in enumerate(arr): 312 | padded[i, : lens[i]] = torch.tensor(a, dtype=dtype) 313 | mask[i, : lens[i]] = 1 314 | return padded, lens, mask 315 | 316 | 317 | def bert_encode(model, x, attention_mask, all_layers=False): 318 | model.eval() 319 | with torch.no_grad(): 320 | out = model(x, attention_mask=attention_mask, output_hidden_states=all_layers) 321 | if all_layers: 322 | emb = torch.stack(out[-1], dim=2) 
323 | else: 324 | emb = out[0] 325 | return emb 326 | 327 | 328 | def process(a, tokenizer=None): 329 | if tokenizer is not None: 330 | a = sent_encode(tokenizer, a) 331 | return set(a) 332 | 333 | 334 | def get_idf_dict(arr, tokenizer, nthreads=4): 335 | """ 336 | Returns mapping from word piece index to its inverse document frequency. 337 | 338 | 339 | Args: 340 | - :param: `arr` (list of str) : sentences to process. 341 | - :param: `tokenizer` : a BERT tokenizer corresponds to `model`. 342 | - :param: `nthreads` (int) : number of CPU threads to use 343 | """ 344 | idf_count = Counter() 345 | num_docs = len(arr) 346 | 347 | process_partial = partial(process, tokenizer=tokenizer) 348 | 349 | with Pool(nthreads) as p: 350 | idf_count.update(chain.from_iterable(p.map(process_partial, arr))) 351 | 352 | idf_dict = defaultdict(lambda: log((num_docs + 1) / (1))) 353 | idf_dict.update({idx: log((num_docs + 1) / (c + 1)) for (idx, c) in idf_count.items()}) 354 | return idf_dict 355 | 356 | 357 | def collate_idf(arr, tokenizer, idf_dict, device="cuda:0"): 358 | """ 359 | Helper function that pads a list of sentences to hvae the same length and 360 | loads idf score for words in the sentences. 361 | 362 | Args: 363 | - :param: `arr` (list of str): sentences to process. 364 | - :param: `tokenize` : a function that takes a string and return list 365 | of tokens. 366 | - :param: `numericalize` : a function that takes a list of tokens and 367 | return list of token indexes. 368 | - :param: `idf_dict` (dict): mapping a word piece index to its 369 | inverse document frequency 370 | - :param: `pad` (str): the padding token. 371 | - :param: `device` (str): device to use, e.g. 'cpu' or 'cuda' 372 | """ 373 | arr = [sent_encode(tokenizer, a) for a in arr] 374 | 375 | idf_weights = [[idf_dict[i] for i in a] for a in arr] 376 | 377 | pad_token = tokenizer.pad_token_id 378 | 379 | padded, lens, mask = padding(arr, pad_token, dtype=torch.long) 380 | padded_idf, _, _ = padding(idf_weights, 0, dtype=torch.float) 381 | 382 | padded = padded.to(device=device) 383 | mask = mask.to(device=device) 384 | lens = lens.to(device=device) 385 | return padded, padded_idf, lens, mask 386 | 387 | 388 | def get_bert_embedding(all_sens, model, tokenizer, idf_dict, batch_size=-1, device="cuda:0", all_layers=False): 389 | """ 390 | Compute BERT embedding in batches. 391 | 392 | Args: 393 | - :param: `all_sens` (list of str) : sentences to encode. 394 | - :param: `model` : a BERT model from `pytorch_pretrained_bert`. 395 | - :param: `tokenizer` : a BERT tokenizer corresponds to `model`. 396 | - :param: `idf_dict` (dict) : mapping a word piece index to its 397 | inverse document frequency 398 | - :param: `device` (str): device to use, e.g. 
'cpu' or 'cuda' 399 | """ 400 | 401 | padded_sens, padded_idf, lens, mask = collate_idf(all_sens, tokenizer, idf_dict, device=device) 402 | 403 | if batch_size == -1: 404 | batch_size = len(all_sens) 405 | 406 | embeddings = [] 407 | with torch.no_grad(): 408 | for i in range(0, len(all_sens), batch_size): 409 | batch_embedding = bert_encode( 410 | model, padded_sens[i : i + batch_size], attention_mask=mask[i : i + batch_size], all_layers=all_layers, 411 | ) 412 | embeddings.append(batch_embedding) 413 | del batch_embedding 414 | 415 | total_embedding = torch.cat(embeddings, dim=0) 416 | 417 | return total_embedding, mask, padded_idf 418 | 419 | 420 | def greedy_cos_idf(ref_embedding, ref_masks, ref_idf, hyp_embedding, hyp_masks, hyp_idf, all_layers=False): 421 | """ 422 | Compute greedy matching based on cosine similarity. 423 | 424 | Args: 425 | - :param: `ref_embedding` (torch.Tensor): 426 | embeddings of reference sentences, BxKxd, 427 | B: batch size, K: longest length, d: bert dimenison 428 | - :param: `ref_lens` (list of int): list of reference sentence length. 429 | - :param: `ref_masks` (torch.LongTensor): BxKxK, BERT attention mask for 430 | reference sentences. 431 | - :param: `ref_idf` (torch.Tensor): BxK, idf score of each word 432 | piece in the reference setence 433 | - :param: `hyp_embedding` (torch.Tensor): 434 | embeddings of candidate sentences, BxKxd, 435 | B: batch size, K: longest length, d: bert dimenison 436 | - :param: `hyp_lens` (list of int): list of candidate sentence length. 437 | - :param: `hyp_masks` (torch.LongTensor): BxKxK, BERT attention mask for 438 | candidate sentences. 439 | - :param: `hyp_idf` (torch.Tensor): BxK, idf score of each word 440 | piece in the candidate setence 441 | """ 442 | ref_embedding.div_(torch.norm(ref_embedding, dim=-1).unsqueeze(-1)) 443 | hyp_embedding.div_(torch.norm(hyp_embedding, dim=-1).unsqueeze(-1)) 444 | 445 | if all_layers: 446 | B, _, L, D = hyp_embedding.size() 447 | hyp_embedding = hyp_embedding.transpose(1, 2).transpose(0, 1).contiguous().view(L * B, hyp_embedding.size(1), D) 448 | ref_embedding = ref_embedding.transpose(1, 2).transpose(0, 1).contiguous().view(L * B, ref_embedding.size(1), D) 449 | batch_size = ref_embedding.size(0) 450 | sim = torch.bmm(hyp_embedding, ref_embedding.transpose(1, 2)) 451 | masks = torch.bmm(hyp_masks.unsqueeze(2).float(), ref_masks.unsqueeze(1).float()) 452 | if all_layers: 453 | masks = masks.unsqueeze(0).expand(L, -1, -1, -1).contiguous().view_as(sim) 454 | else: 455 | masks = masks.expand(batch_size, -1, -1).contiguous().view_as(sim) 456 | 457 | masks = masks.float().to(sim.device) 458 | sim = sim * masks 459 | 460 | word_precision = sim.max(dim=2)[0] 461 | word_recall = sim.max(dim=1)[0] 462 | 463 | hyp_idf.div_(hyp_idf.sum(dim=1, keepdim=True)) 464 | ref_idf.div_(ref_idf.sum(dim=1, keepdim=True)) 465 | precision_scale = hyp_idf.to(word_precision.device) 466 | recall_scale = ref_idf.to(word_recall.device) 467 | if all_layers: 468 | precision_scale = precision_scale.unsqueeze(0).expand(L, B, -1).contiguous().view_as(word_precision) 469 | recall_scale = recall_scale.unsqueeze(0).expand(L, B, -1).contiguous().view_as(word_recall) 470 | P = (word_precision * precision_scale).sum(dim=1) 471 | R = (word_recall * recall_scale).sum(dim=1) 472 | F = 2 * P * R / (P + R) 473 | 474 | hyp_zero_mask = hyp_masks.sum(dim=1).eq(2) 475 | ref_zero_mask = ref_masks.sum(dim=1).eq(2) 476 | 477 | if all_layers: 478 | P = P.view(L, B) 479 | R = R.view(L, B) 480 | F = F.view(L, B) 481 | 482 | if 
torch.any(hyp_zero_mask): 483 | print( 484 | "Warning: Empty candidate sentence detected; setting raw BERTscores to 0.", file=sys.stderr, 485 | ) 486 | P = P.masked_fill(hyp_zero_mask, 0.0) 487 | R = R.masked_fill(hyp_zero_mask, 0.0) 488 | 489 | if torch.any(ref_zero_mask): 490 | print("Warning: Empty reference sentence detected; setting raw BERTScores to 0.", file=sys.stderr) 491 | P = P.masked_fill(ref_zero_mask, 0.0) 492 | R = R.masked_fill(ref_zero_mask, 0.0) 493 | 494 | F = F.masked_fill(torch.isnan(F), 0.0) 495 | 496 | return P, R, F 497 | 498 | 499 | def bert_cos_score_idf( 500 | model, refs, hyps, tokenizer, idf_dict, verbose=False, batch_size=64, device="cuda:0", all_layers=False, 501 | ): 502 | """ 503 | Compute BERTScore. 504 | 505 | Args: 506 | - :param: `model` : a BERT model in `pytorch_pretrained_bert` 507 | - :param: `refs` (list of str): reference sentences 508 | - :param: `hyps` (list of str): candidate sentences 509 | - :param: `tokenzier` : a BERT tokenizer corresponds to `model` 510 | - :param: `idf_dict` : a dictionary mapping a word piece index to its 511 | inverse document frequency 512 | - :param: `verbose` (bool): turn on intermediate status update 513 | - :param: `batch_size` (int): bert score processing batch size 514 | - :param: `device` (str): device to use, e.g. 'cpu' or 'cuda' 515 | """ 516 | preds = [] 517 | 518 | def dedup_and_sort(l): 519 | return sorted(list(set(l)), key=lambda x: len(x.split(" ")), reverse=True) 520 | 521 | sentences = dedup_and_sort(refs + hyps) 522 | embs = [] 523 | iter_range = range(0, len(sentences), batch_size) 524 | if verbose: 525 | print("computing bert embedding.") 526 | iter_range = tqdm(iter_range) 527 | stats_dict = dict() 528 | for batch_start in iter_range: 529 | sen_batch = sentences[batch_start : batch_start + batch_size] 530 | embs, masks, padded_idf = get_bert_embedding( 531 | sen_batch, model, tokenizer, idf_dict, device=device, all_layers=all_layers 532 | ) 533 | embs = embs.cpu() 534 | masks = masks.cpu() 535 | padded_idf = padded_idf.cpu() 536 | for i, sen in enumerate(sen_batch): 537 | sequence_len = masks[i].sum().item() 538 | emb = embs[i, :sequence_len] 539 | idf = padded_idf[i, :sequence_len] 540 | stats_dict[sen] = (emb, idf) 541 | 542 | def pad_batch_stats(sen_batch, stats_dict, device): 543 | stats = [stats_dict[s] for s in sen_batch] 544 | emb, idf = zip(*stats) 545 | emb = [e.to(device) for e in emb] 546 | idf = [i.to(device) for i in idf] 547 | lens = [e.size(0) for e in emb] 548 | emb_pad = pad_sequence(emb, batch_first=True, padding_value=2.0) 549 | idf_pad = pad_sequence(idf, batch_first=True) 550 | 551 | def length_to_mask(lens): 552 | lens = torch.tensor(lens, dtype=torch.long) 553 | max_len = max(lens) 554 | base = torch.arange(max_len, dtype=torch.long).expand(len(lens), max_len) 555 | return base < lens.unsqueeze(1) 556 | 557 | pad_mask = length_to_mask(lens).to(device) 558 | return emb_pad, pad_mask, idf_pad 559 | 560 | device = next(model.parameters()).device 561 | iter_range = range(0, len(refs), batch_size) 562 | if verbose: 563 | print("computing greedy matching.") 564 | iter_range = tqdm(iter_range) 565 | 566 | with torch.no_grad(): 567 | for batch_start in iter_range: 568 | batch_refs = refs[batch_start : batch_start + batch_size] 569 | batch_hyps = hyps[batch_start : batch_start + batch_size] 570 | ref_stats = pad_batch_stats(batch_refs, stats_dict, device) 571 | hyp_stats = pad_batch_stats(batch_hyps, stats_dict, device) 572 | 573 | P, R, F1 = greedy_cos_idf(*ref_stats, *hyp_stats, 
all_layers) 574 | preds.append(torch.stack((P, R, F1), dim=-1).cpu()) 575 | preds = torch.cat(preds, dim=1 if all_layers else 0) 576 | return preds 577 | 578 | 579 | def get_hash(model, num_layers, idf, rescale_with_baseline, use_custom_baseline, use_fast_tokenizer): 580 | msg = "{}_L{}{}_version={}(hug_trans={})".format( 581 | model, num_layers, "_idf" if idf else "_no-idf", __version__, trans_version 582 | ) 583 | if rescale_with_baseline: 584 | if use_custom_baseline: 585 | msg += "-custom-rescaled" 586 | else: 587 | msg += "-rescaled" 588 | if use_fast_tokenizer: 589 | msg += "_fast-tokenizer" 590 | return msg 591 | 592 | 593 | def cache_scibert(model_type, cache_folder="~/.cache/torch/transformers"): 594 | if not model_type.startswith("scibert"): 595 | return model_type 596 | 597 | underscore_model_type = model_type.replace("-", "_") 598 | cache_folder = os.path.abspath(os.path.expanduser(cache_folder)) 599 | filename = os.path.join(cache_folder, underscore_model_type) 600 | 601 | # download SciBERT models 602 | if not os.path.exists(filename): 603 | cmd = f"mkdir -p {cache_folder}; cd {cache_folder};" 604 | cmd += f"wget {SCIBERT_URL_DICT[model_type]}; tar -xvf {underscore_model_type}.tar;" 605 | cmd += ( 606 | f"rm -f {underscore_model_type}.tar ; cd {underscore_model_type}; tar -zxvf weights.tar.gz; mv weights/* .;" 607 | ) 608 | cmd += f"rm -f weights.tar.gz; rmdir weights; mv bert_config.json config.json;" 609 | print(cmd) 610 | print(f"downloading {model_type} model") 611 | os.system(cmd) 612 | 613 | # fix the missing files in scibert 614 | json_file = os.path.join(filename, "special_tokens_map.json") 615 | if not os.path.exists(json_file): 616 | with open(json_file, "w") as f: 617 | print( 618 | '{"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}', 619 | file=f, 620 | ) 621 | 622 | json_file = os.path.join(filename, "added_tokens.json") 623 | if not os.path.exists(json_file): 624 | with open(json_file, "w") as f: 625 | print("{}", file=f) 626 | 627 | if "uncased" in model_type: 628 | json_file = os.path.join(filename, "tokenizer_config.json") 629 | if not os.path.exists(json_file): 630 | with open(json_file, "w") as f: 631 | print('{"do_lower_case": true, "max_len": 512, "init_inputs": []}', file=f) 632 | 633 | return filename 634 | -------------------------------------------------------------------------------- /N-gram Repitition Filter/README.md: -------------------------------------------------------------------------------- 1 | ### Filter sentences with N-gram repeats 2 | 3 | To remove sentences from the source and target file having n-gram repeats in the target file. The value of n is set to 2 by default. After removal, both the filtered source and target sentences are obtained. 4 | 5 | ``` 6 | python n_gram_repeatition_filter.py --s --t 7 | ``` 8 | 9 | where `source` and `target` are respectively the paths to the files containing sources and corresponding target paraphrases. 
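The check targets phrases of at least `n` stemmed words that are repeated back to back in the target sentence (e.g. a token sequence like `A B A B ...`); pairs whose target triggers the check are dropped, and all others are written to the filtered files. A minimal standalone sketch of the same idea, without the stemming and Bengali punctuation handling that `n_gram_repeatition_filter.py` performs:

```
def has_consecutive_ngram_repeat(tokens, n=2):
    """Return True if a run of at least n tokens is immediately repeated."""
    for i in range(len(tokens)):
        # compare the phrase starting at i with the phrase that directly follows it
        for length in range(n, (len(tokens) - i) // 2 + 1):
            if tokens[i:i + length] == tokens[i + length:i + 2 * length]:
                return True
    return False


print(has_consecutive_ngram_repeat("A B A B C".split()))  # True: bigram "A B" repeats back to back
print(has_consecutive_ngram_repeat("A B C A B".split()))  # False: the repeat is not adjacent
```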
10 | 
--------------------------------------------------------------------------------
/N-gram Repitition Filter/n_gram_repeatition_filter.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from bengali_stemmer.rafikamal2014 import RafiStemmer
3 | """Must be kept inside the rafikamal stemmer"""
4 | 
5 | 
6 | def stem_string(string):
7 |     """
8 |     returns a stemmed string without punctuation
9 | 
10 |     Args:
11 |         string (str) : string to be stemmed
12 |     Returns:
13 |         stemmed_string (str) : stemmed version of the string
14 |     """
15 |     stemmer = RafiStemmer()
16 |     punc = '''।,;:?!'."-[]{}()–—―~'''
17 | 
18 |     for ele in string:
19 |         if ele in punc:
20 |             string = string.replace(ele, "")
21 |     words = string.split()
22 |     return ' '.join([stemmer.stem_word(word) for word in words])
23 | 
24 | 
25 | def calculate_ngram_repeat(text, n_gram):
26 |     """
27 |     returns the consecutively repeated n-gram found in the text, or an empty string if there is none
28 | 
29 |     Args:
30 |         text (str) : string to be analyzed
31 |         n_gram (int) : minimum n-gram length to consider as a repeat
32 |     Returns:
33 |         repeated span (str) : the repeated word sequence, or '' if no repeat is found
34 |     """
35 |     stemmed = stem_string(text)
36 | 
37 |     splitted = stemmed.split()
38 | 
39 |     for i, baseword in enumerate(splitted):
40 |         for j, cmpword in enumerate(splitted[i+1:]):
41 |             if baseword == cmpword:
42 |                 if len(splitted) - i-j-1 > j:
43 |                     trackflag = True
44 |                     for k in range(1, j+1):
45 |                         if splitted[i+k] != splitted[i+1+j+k]:
46 |                             trackflag = False
47 |                             break
48 |                     if trackflag:
49 |                         if j+1 >= n_gram:
50 |                             return ' '.join([s for s in splitted[i:i+j+1]])
51 | 
52 |     return ''
53 | 
54 | 
55 | if __name__ == '__main__':
56 |     # Create the parser
57 |     parser = argparse.ArgumentParser(
58 |         description='path to the source and target file')
59 | 
60 |     # Add the arguments
61 |     parser.add_argument('--s',
62 |                         metavar='s',
63 |                         type=str,
64 |                         help='the path to the source file')
65 | 
66 |     parser.add_argument('--t',
67 |                         metavar='t',
68 |                         type=str,
69 |                         help='the path to the target file')
70 | 
71 |     args = parser.parse_args()
72 | 
73 |     source_path = args.s
74 |     target_path = args.t
75 | 
76 | 
77 |     source_file = open(source_path, encoding='utf-8')
78 |     target_file = open(target_path, encoding='utf-8')
79 | 
80 |     final_target = open('ngram_filtered_target.bn', 'w', encoding='utf-8')
81 |     final_source = open('ngram_filtered_source.bn', 'w', encoding='utf-8')
82 | 
83 |     target_lines = target_file.readlines()
84 |     source = source_file.readlines()
85 |     counter = 0
86 | 
87 |     for line_index, line in enumerate(target_lines):
88 |         output = calculate_ngram_repeat(line, n_gram=2)
89 | 
90 |         if (line_index+1) % 20000 == 0:
91 |             print(line_index + 1)
92 | 
93 |         if output != '':
94 |             counter += 1
95 |         else:
96 |             final_target.write(line)
97 |             final_source.write(source[line_index])
--------------------------------------------------------------------------------
/PINCScore/PINCscore.py:
--------------------------------------------------------------------------------
1 | import jsonlines
2 | import json
3 | import pandas as pd
4 | import torch
5 | from nltk import ngrams
6 | import argparse
7 | import os
8 | import sys
9 | from bengali_stemmer.rafikamal2014 import RafiStemmer
10 | 
11 | 
12 | def stem_string(string):
13 |     """
14 |     returns a stemmed string without punctuation
15 | 
16 |     Args:
17 |         string (str) : string to be stemmed
18 |     Returns:
19 |         stemmed_string (str) : stemmed version of the string
20 |     """
21 |     stemmer = RafiStemmer()
22 |     punc = '''।,;:?!'."-[]{}()–—―~'''
23 | 
24 |     for ele in string:
25 |         if ele in punc:
26 |             string = string.replace(ele, "")
27 |     words = string.split()
28 |     return ' '.join([stemmer.stem_word(word) for word in words])
29 | 
30 | 
31 | def calculateScore(sourcefile, predictionfile):
32 |     """
33 |     computes and prints the average PINC score over all source-prediction pairs
34 | 
35 |     Args:
36 |         sourcefile (file object) : the source file
37 |         predictionfile (file object) : the prediction file
38 |     Returns:
39 |         None
40 |     """
41 |     N = 4
42 |     datacount = 0
43 |     totalPINC = 0
44 |     for source in sourcefile:
45 |         prediction = predictionfile.readline()
46 |         datacount += 1
47 |         source = stem_string(source)
48 |         prediction = stem_string(prediction)
49 |         pinc_sum = 0
50 |         for i in range(N):
51 |             overlap_count = 0
52 |             key_n_grams = list(ngrams(source.split(), i + 1))
53 |             value_n_gram = list(ngrams(prediction.split(), i + 1))
54 |             value_ngram_size = len(value_n_gram)
55 |             for key_i in range(len(key_n_grams)):
56 |                 for value_i in range(len(value_n_gram)):
57 |                     if key_n_grams[key_i] == value_n_gram[value_i]:
58 |                         overlap_count += 1  # increasing overlap count
59 |                         # removing the exact n-gram after counting the overlap
60 |                         value_n_gram.pop(value_i)
61 |                         break
62 | 
63 |             # calculating the pinc sum
64 |             if value_ngram_size > 0:
65 |                 pinc_sum += (1 - (overlap_count / value_ngram_size))
66 |         pinc_sum = pinc_sum / N
67 |         totalPINC += pinc_sum
68 |     PINCscore = totalPINC/datacount
69 |     print("Average PINC Score : ", PINCscore)
70 | 
71 | 
72 | if __name__ == '__main__':
73 | 
74 |     # Create the parser
75 |     parser = argparse.ArgumentParser(
76 |         description='path to source and target files')
77 | 
78 |     # Add the arguments
79 |     parser.add_argument('--s',
80 |                         metavar='s',
81 |                         type=str,
82 |                         help='the path to the source file')
83 | 
84 |     parser.add_argument('--p',
85 |                         metavar='p',
86 |                         type=str,
87 |                         help='the path to the prediction file')
88 | 
89 |     # Execute the parse_args() method
90 |     args = parser.parse_args()
91 | 
92 |     source_path = args.s
93 |     prediction_path = args.p
94 | 
95 |     sourcefile = open(source_path, encoding='utf-8')
96 |     predictionfile = open(prediction_path, encoding='utf-8')
97 | 
98 |     calculateScore(sourcefile, predictionfile)
99 | 
100 |     # closing all the files
101 |     sourcefile.close()
102 |     predictionfile.close()
103 | 
--------------------------------------------------------------------------------
/PINCScore/README.md:
--------------------------------------------------------------------------------
1 | ### Requirements
2 | To run the scripts, place them inside the [bengali-stemmer](https://github.com/banglakit/bengali-stemmer) project.
3 | ### Calculate PINCScore
4 | 
5 | To calculate the PINC score between sources and generated paraphrases, run the following command from the mentioned project directory.
6 | 
7 | ```
8 | python PINCscore.py --s <source> --p <prediction>
9 | ```
10 | 
11 | where `source` and `prediction` are respectively the paths to the files containing sources and corresponding predictions.
12 | 
13 | ### Filter with PINCScore
14 | 
15 | To filter with PINCScore, run the following command from the mentioned project directory.
16 | 
17 | ```
18 | python filterPINC.py --l <source> --t <target> --p <pinc_score>
19 | ```
20 | 
21 | where `source` is the path to the jsonl file containing sentences and their corresponding paraphrases as key-value pairs, `target` is the path to the generated jsonl file containing the pairs that remain after filtering (in the same format), and `pinc_score` is the PINC score threshold used to filter the source file.
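For reference, the PINC score computed by `PINCscore.py` and `filterPINC.py` for a (source, paraphrase) pair is the average, over n = 1 to 4, of the fraction of the paraphrase's n-grams that do not occur in the source; a score near 1 means the paraphrase shares almost no n-grams with the source, while a verbatim copy scores 0. A minimal sketch of the same computation on already-tokenized input (the scripts additionally stem the sentences with RafiStemmer and strip punctuation before splitting):

```
from collections import Counter
from nltk import ngrams


def pinc(source_tokens, paraphrase_tokens, max_n=4):
    """Mean over n of (1 - overlapping n-grams / paraphrase n-grams)."""
    total = 0.0
    for n in range(1, max_n + 1):
        src = Counter(ngrams(source_tokens, n))
        par = Counter(ngrams(paraphrase_tokens, n))
        par_count = sum(par.values())
        if par_count == 0:
            continue  # as in the scripts, lengths with no paraphrase n-grams contribute 0
        overlap = sum((src & par).values())  # multiset intersection, matching the match-and-remove loop
        total += 1 - overlap / par_count
    return total / max_n
```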
22 | -------------------------------------------------------------------------------- /PINCScore/filterPINC.py: -------------------------------------------------------------------------------- 1 | import jsonlines 2 | import json 3 | import pandas as pd 4 | import torch 5 | from nltk import ngrams 6 | import argparse 7 | import os 8 | import sys 9 | from bengali_stemmer.rafikamal2014 import RafiStemmer 10 | 11 | 12 | def stem_string(string): 13 | """ 14 | returns a stemmed string without punctuations 15 | 16 | Args: 17 | string (str) : string to be stemmed 18 | Returns: 19 | stemmed_string (str) : stemmed version of the string 20 | """ 21 | stemmer = RafiStemmer() 22 | punc = '''।,;:?!'."-[]{}()–—―~''' 23 | 24 | for ele in string: 25 | if ele in punc: 26 | string = string.replace(ele, "") 27 | words = string.split() 28 | return ' '.join([stemmer.stem_word(word) for word in words]) 29 | 30 | 31 | def filter_dataset(jsonl_file, target_file, pinc_threshold): 32 | """ 33 | filters the jsonl file with the given pinc threshold and writes the parallel 34 | sentences to the target file 35 | 36 | Args: 37 | jsonl_file (file object) : file to be read and filtered 38 | target_file (file object) : file to be written to 39 | pinc_threshold (float): pinc threshold value 40 | Returns: 41 | None 42 | """ 43 | 44 | N = 4 45 | linecount = 0 46 | 47 | for line in jsonl_file.iter(): 48 | linecount += 1 49 | hasfound = False 50 | trgts = {} 51 | 52 | for key, values in line.items(): 53 | original_key = key 54 | key = stem_string(key) 55 | stemmed_values = [stem_string(value) for value in values] 56 | trgts[original_key] = [] 57 | 58 | for value_index, value in enumerate(stemmed_values): 59 | pinc_sum = 0 60 | 61 | for i in range(N): 62 | overlap_count = 0 63 | key_n_grams = list(ngrams(key.split(), i + 1)) 64 | value_n_gram = list(ngrams(value.split(), i + 1)) 65 | value_ngram_size = len(value_n_gram) 66 | 67 | for key_i in range(len(key_n_grams)): 68 | for value_i in range(len(value_n_gram)): 69 | if key_n_grams[key_i] == value_n_gram[value_i]: 70 | overlap_count += 1 # increasing overlap count 71 | # removing the exact n gram after calculating overlap 72 | value_n_gram.pop(value_i) 73 | break 74 | 75 | # calculating the pinc sum 76 | if value_ngram_size > 0: 77 | pinc_sum += (1 - (overlap_count / value_ngram_size)) 78 | pinc_sum = pinc_sum / N 79 | 80 | if pinc_sum >= pinc_threshold: 81 | hasfound = True 82 | trgts[original_key].append(values[value_index]) 83 | if hasfound: 84 | json.dump(trgts, target_file, ensure_ascii=False) 85 | target_file.write("%s" % '\n') 86 | if linecount % 10000 == 0: 87 | print(linecount) 88 | 89 | 90 | if __name__ == '__main__': 91 | 92 | # Create the parser 93 | parser = argparse.ArgumentParser( 94 | description='path to the input and output file and the pinc score threshold') 95 | 96 | # Add the arguments 97 | parser.add_argument('--l', 98 | metavar='l', 99 | type=str, 100 | help='the path to the jsonl file with sources and corresponding paraphrases') 101 | 102 | parser.add_argument('--t', 103 | metavar='t', 104 | type=str, 105 | help='the path to the generated target jsonl file') 106 | 107 | parser.add_argument('--p', 108 | metavar='p', 109 | type=float, 110 | help='the desired pinc score threshold (0 - 1)') 111 | 112 | # Execute the parse_args() method 113 | args = parser.parse_args() 114 | 115 | jsonl_path = args.l 116 | target_path = args.t 117 | pinc_threshold = args.p 118 | 119 | target_file = open(target_path, 'w', encoding='utf-8') 120 | jsonl_file = 
jsonlines.open(jsonl_path) 121 | 122 | filter_dataset(jsonl_file, target_file, pinc_threshold) 123 | 124 | # closing all the files 125 | target_file.close() 126 | jsonl_file.close() 127 | -------------------------------------------------------------------------------- /Punctuation Filter/README.md: -------------------------------------------------------------------------------- 1 | ### Filters sentences with misplaced punctuation at the end 2 | 3 | To remove sentences from the source and the target file having misplaced punctuation or no trailing punctuation at the end. In some cases, the sentences are simply modified to remove the unwanted punctuation at the end. 4 | 5 | ``` 6 | python punctuation_filter.py --s --t 7 | ``` 8 | 9 | where `source` and `target` are respectively the paths to the files containing sources and corresponding target paraphrases. 10 | -------------------------------------------------------------------------------- /Punctuation Filter/punctuation_filter.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | symbols = ''',;:?!'."-[]{}()–—―~''' 3 | 4 | 5 | # Create the parser 6 | parser = argparse.ArgumentParser( 7 | description='path to the source and target file') 8 | 9 | # Add the arguments 10 | parser.add_argument('--s', 11 | metavar='s', 12 | type=str, 13 | help='the path to the source file') 14 | 15 | parser.add_argument('--t', 16 | metavar='t', 17 | type=str, 18 | help='the path to the target file') 19 | 20 | 21 | args = parser.parse_args() 22 | 23 | source_path = args.s 24 | target_path = args.t 25 | 26 | source_file = open( 27 | source_path, encoding='utf-8') 28 | target_file = open( 29 | target_path, encoding='utf-8') 30 | 31 | 32 | final_source = open('./source.bn', 'w', encoding='utf-8') 33 | final_target = open('./target.bn', 'w', encoding='utf-8') 34 | 35 | 36 | counter = 0 37 | 38 | for index, line in enumerate(zip(source_file.readlines(), target_file.readlines())): 39 | 40 | src = line[0].strip() 41 | trg = line[1].strip() 42 | 43 | if src[-1] == '"' and trg[-1] == '"': 44 | # check if -2 pos has !?| or not 45 | if src[-2] == '?' or src[-2] == "!" or src[-2] == "।": 46 | if trg[-2] == '?' or trg[-2] == "!" or trg[-2] == "।": 47 | 48 | src_to_write = '' 49 | trg_to_write = '' 50 | 51 | if (src.count('"') % 2 == 0 and src.count('"') >= 2) and (trg.count('"') % 2 == 0 and trg.count('"') >= 2): 52 | src_to_write = src 53 | trg_to_write = trg 54 | 55 | if (src.count('"') == 1) and (trg.count('"') == 1): 56 | src_to_write = src[:-1] 57 | trg_to_write = trg[:-1] 58 | 59 | # write without the last quotation 60 | if src_to_write != '' and trg_to_write != '': 61 | final_source.write(src_to_write+'\n') 62 | final_target.write(trg_to_write+'\n') 63 | 64 | elif src[-1] == '\'' and trg[-1] == '\'': 65 | # check if -2 pos has !?| or not 66 | if src[-2] == '?' or src[-2] == "!" or src[-2] == "।": 67 | if trg[-2] == '?' or trg[-2] == "!" 
or trg[-2] == "।": 68 | 69 | src_to_write = '' 70 | trg_to_write = '' 71 | 72 | if (src.count('\'') % 2 == 0 and src.count('\'') >= 2) and (trg.count('\'') % 2 == 0 and trg.count('\'') >= 2): 73 | src_to_write = src 74 | trg_to_write = trg 75 | 76 | if (src.count('\'') == 1) and (trg.count('\'') == 1): 77 | src_to_write = src[:-1] 78 | trg_to_write = trg[:-1] 79 | 80 | # write without the last quotation 81 | if src_to_write != '' and trg_to_write != '': 82 | final_source.write(src_to_write+'\n') 83 | final_target.write(trg_to_write+'\n') 84 | 85 | elif src[-1] == '"' and (trg[-1] == '।' or trg[-1] == '?' or trg[-1] == '!') and src.count('"') == 1 and trg.count('"') == 0: 86 | 87 | # check if -2 pos has !?| or not 88 | if src[-2] == '?' or src[-2] == "!" or src[-2] == "।": 89 | 90 | final_source.write(src[:-1]+'\n') 91 | final_target.write(trg+'\n') 92 | 93 | elif src[-1] == '\'' and (trg[-1] == '।' or trg[-1] == '?' or trg[-1] == '!') and src.count('\'') == 1 and trg.count('\'') == 0: 94 | 95 | if src[-2] == '?' or src[-2] == "!" or src[-2] == "।": 96 | 97 | final_source.write(src[:-1]+'\n') 98 | final_target.write(trg+'\n') 99 | 100 | elif trg[-1] == '"' and (src[-1] == '।' or src[-1] == '?' or src[-1] == '!') and trg.count('"') == 1 and src.count('"') == 0: 101 | 102 | # check if -2 pos has !?| or not 103 | if trg[-2] == '?' or trg[-2] == "!" or trg[-2] == "।": 104 | 105 | final_source.write(src+'\n') 106 | final_target.write(trg[:-1]+'\n') 107 | 108 | elif trg[-1] == '\'' and (src[-1] == '।' or src[-1] == '?' or src[-1] == '!') and trg.count('\'') == 1 and src.count('\'') == 0: 109 | 110 | # check if -2 pos has !?| or not 111 | if trg[-2] == '?' or trg[-2] == "!" or trg[-2] == "।": 112 | 113 | final_source.write(src+'\n') 114 | final_target.write(trg[:-1]+'\n') 115 | 116 | elif (src[-1] == "।" or src[-1] == "!" or src[-1] == "?") and (trg[-1] == "।" or trg[-1] == "!" or trg[-1] == "?"): 117 | final_source.write(src+'\n') 118 | final_target.write(trg+'\n') 119 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BanglaParaphrase 2 | 3 | This repository contains the code, data, and associated models of the paper titled [**"BanglaParaphrase: A High-Quality Bangla Paraphrase Dataset"**](https://arxiv.org/abs/2210.05109), accepted in *Proceedings of the Asia-Pacific Chapter of the Association for Computational Linguistics: AACL 2022*. 4 | 5 | ## Table of Contents 6 | 7 | - [BanglaParaphrase](#banglaParaphrase) 8 | - [Table of Contents](#table-of-contents) 9 | - [Datasets](#datasets) 10 | - [Filtering Pipeline](#filtering-pipeline) 11 | - [Training & Evaluation](#training--evaluation) 12 | - [Models](#models) 13 | - [License](#license) 14 | - [Citation](#citation) 15 | 16 | ## Datasets 17 | 18 | ***Disclaimer: You must agree to the [license](#license) and terms of use before using the dataset.*** 19 | 20 | The dataset files are organized in `.jsonl` format i.e. one JSON per line. **Download the dataset from [here](https://huggingface.co/datasets/csebuetnlp/BanglaParaphrase/tree/main).** 21 | 22 | One example from the `test` part of the dataset is given below in JSON format. 
23 | ``` 24 | { 25 | "source": "খোঁজ খবর রাখতেন বিজ্ঞানের অগ্রগতি নিয়ে।", 26 | "target": "বিজ্ঞানের অগ্রগতির দিকে তিনি নজর রেখেছিলেন।" 27 | } 28 | ``` 29 | 30 | ### Data Splits 31 | Train, validation, and test example counts are given below: 32 | Language | ISO 639-1 Code | Train | Validation | Test | 33 | -------------- | ---------------- | ------- | ----- | ------ | 34 | Bengali | bn | 419,967 | 23,331 | 23,332 | 35 | 36 | ## Filtering Pipeline 37 | The following filtering pipeline was used to preprocess the raw dataset to ensure high quality. 38 | ![filter_pipeline](images/filter_sequence.png) 39 | 40 | | Filter Name | Significance | Filtering Parameters | 41 | | ----------- | ----------- |----------------------------| 42 | | PINC | Ensure diversity in generated paraphrases | 0.65, 0.76, 0.80| 43 | | BERTScore | Preserve semantic coherence with the source |lower 0.91 - 0.93, upper 0.98| 44 | |N-gram repetition|Reduce n-gram repetition during inference|2 - 4 grams| 45 | | Punctuation | Prevent generating non-terminating sentences during inference | N/A | 46 | 47 | Instructions on how to run the individual filtering and scoring scripts are provided in the respective folders. 48 | 49 | ### Run the full pipeline 50 | Install requirements from [requirements](https://github.com/csebuetnlp/banglaparaphrase/blob/master/requirements.txt) and then run the following command. 51 | ``` 52 | bash filter.sh -i -p -l -h 53 | ``` 54 | where `input` is the path to the jsonl file containing sentences and their corresponding paraphrases as key-value pairs, `pinc_threshold` is the threshold for PINCScore, and `lower_bert_score_threshold` and `higher_bert_score_threshold` are the limits for BERTScore on a scale of 0 to 1. 55 | 56 | This will generate two files named `source.bn` and `target.bn` in the working directory containing the filtered pairs after passing through all the filtering steps. 57 | 58 | ## Training & Evaluation 59 | For training and evaluation, please refer to the repository of [BanglaNLG](https://github.com/csebuetnlp/BanglaNLG). 60 | 61 | ## Models 62 | 63 | The model checkpoint from the paper is available at [huggingface model hub](https://huggingface.co/csebuetnlp/banglat5_banglaparaphrase). 64 | 65 | 66 | ## License 67 | Contents of this repository are restricted to only non-commercial research purposes under the [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License (CC BY-NC-SA 4.0)](https://creativecommons.org/licenses/by-nc-sa/4.0/). Copyright of the dataset contents belongs to the original copyright holders.
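As a usage note complementing the Datasets and Models sections above, the released dataset and checkpoint can be loaded with the standard Hugging Face `datasets` and `transformers` APIs. The snippet below is a minimal sketch and is not taken from this repository; the generation settings are illustrative, and in practice inputs are typically normalized (with the csebuetnlp normalizer package installed elsewhere in this repo) before tokenization.

```python
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Paraphrase pairs with "source"/"target" fields, split into train/validation/test.
dataset = load_dataset("csebuetnlp/BanglaParaphrase")
print(dataset["test"][0])

# BanglaT5-based paraphrase model; the tokenizer is usually loaded with use_fast=False.
model_name = "csebuetnlp/banglat5_banglaparaphrase"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Generate a paraphrase for one test source sentence (beam size and length are illustrative).
inputs = tokenizer(dataset["test"][0]["source"], return_tensors="pt")
outputs = model.generate(**inputs, num_beams=5, max_length=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```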
68 | 69 | 70 | ## Citation 71 | ``` 72 | @article{akil2022banglaparaphrase, 73 | title={BanglaParaphrase: A High-Quality Bangla Paraphrase Dataset}, 74 | author={Akil, Ajwad and Sultana, Najrin and Bhattacharjee, Abhik and Shahriyar, Rifat}, 75 | journal={arXiv preprint arXiv:2210.05109}, 76 | year={2022} 77 | } 78 | ``` 79 | -------------------------------------------------------------------------------- /filter.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | while getopts i:p:l:h: flag 3 | do 4 | case "$flag" in 5 | i) input=${OPTARG};; 6 | p) pinc_threshold=${OPTARG};; 7 | l) bert_low=${OPTARG};; 8 | h) bert_high=${OPTARG};; 9 | esac 10 | done 11 | git clone https://github.com/banglakit/bengali-stemmer.git 12 | cp PINCScore/filterPINC.py bengali-stemmer/ 13 | python bengali-stemmer/filterPINC.py --l $input --t "PINC-filtered.jsonl" --p $pinc_threshold 14 | pip install git+https://github.com/csebuetnlp/normalizer 15 | git clone https://github.com/Tiiiger/bert_score 16 | cd bert_score 17 | pip install . 18 | rm bert_score/score.py bert_score/utils.py 19 | cp "../BERTScore/utils/score.py" bert_score/ 20 | cp "../BERTScore/utils/utils.py" bert_score/ 21 | cp "../BERTScore/logBERT.py" bert_score/ 22 | python "bert_score/logBERT.py" --l "../PINC-filtered.jsonl" --t "BERTlog.jsonl" 23 | rm "../PINC-filtered.jsonl" 24 | python "../BERTScore/filterBanglaBert.py" --j "BERTlog.jsonl" --s "../source.bn" --t "../target.bn" --l $bert_low --u $bert_high 25 | rm "BERTlog.jsonl" 26 | cd .. 27 | rm -r bert_score 28 | cp "N-gram Repitition Filter/n_gram_repeatition_filter.py" bengali-stemmer/ 29 | python bengali-stemmer/n_gram_repeatition_filter.py --s source.bn --t target.bn 30 | python "Punctuation Filter/punctuation_filter.py" --s ngram_filtered_source.bn --t ngram_filtered_target.bn 31 | rm -r bengali-stemmer 32 | rm ngram_filtered_source.bn 33 | rm ngram_filtered_target.bn -------------------------------------------------------------------------------- /images/filter_sequence.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csebuetnlp/banglaparaphrase/65cbac46e20b26332da39ae9a228202a26574abb/images/filter_sequence.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jsonlines --------------------------------------------------------------------------------
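Since `requirements.txt` lists `jsonlines`, the following hedged sketch (not part of the repository) shows one way to produce the input that `filter.sh` and `filterPINC.py` expect: one JSON object per line, mapping each source sentence to a list of candidate paraphrases. The file name is illustrative, and the example pair is borrowed from the dataset sample in the README.

```python
import jsonlines

# One candidate set per line: {source sentence: [candidate paraphrase, ...]}.
candidates = {
    "খোঁজ খবর রাখতেন বিজ্ঞানের অগ্রগতি নিয়ে।": [
        "বিজ্ঞানের অগ্রগতির দিকে তিনি নজর রেখেছিলেন।",
    ],
}

# Illustrative file name; it can then be passed to the pipeline as: bash filter.sh -i candidates.jsonl ...
with jsonlines.open("candidates.jsonl", mode="w") as writer:
    writer.write(candidates)
```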