├── README.md ├── data.py ├── train.py ├── .gitignore ├── eval.py ├── model.py ├── computation_units.py ├── requirements.txt ├── raw_data_preprocessor.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Paraformer 2 | This repository contains the code for the paper: Attentive deep neural networks for legal document retrieval 3 | 4 | # Resource 5 | 6 | 7 | * Data provided by COLIEE Competition 8 | * Pretrained checkpoint can be downloaded [here](https://github.com/nguyenthanhasia/paraformer/releases/download/0.2/Paraformer.ckpt) 9 | * Project page can be accessed [here](https://nguyenthanhasia.github.io/AttentiveDNN4LegalDocRetrieval/). 10 | 11 | 12 | ## Installation 13 | 14 | 15 | ``` 16 | conda create -n paraformer python=3.6 17 | conda activate paraformer 18 | pip install -r requirements.txt 19 | ``` 20 | 21 | ## Training 22 | 23 | ``` 24 | python train.py \ 25 | --data-dir DATA_DIR \ 26 | --test-file TEST_FILE \ 27 | --max-epochs NUM_EPOCHES 28 | ``` 29 | 30 | ## Evaluating 31 | 32 | ``` 33 | python eval.py \ 34 | --data-dir DATA_DIR \ 35 | --test-file TEST_FILE \ 36 | --checkpoint CHECKPOINT_PATH \ 37 | --bm25-top-n TOP_N \ 38 | --alpha ALPHA 39 | ``` 40 | 41 | ## Citation 42 | 43 | - Nguyen, H., Phi, M., Ngo, X., Tran, V., Nguyen, L., & Tu, M. (2022). Attentive deep neural networks for legal document retrieval. Artificial Intelligence and Law, 1-30. Springer. 44 | 45 | BibTeX: 46 | ```bibtex 47 | @article{nguyen2022attentive, 48 | title={Attentive deep neural networks for legal document retrieval}, 49 | author={Nguyen, Ha-Thanh and Phi, Manh-Kien and Ngo, Xuan-Bach and Tran, Vu and Nguyen, Le-Minh and Tu, Minh-Phuong}, 50 | journal={Artificial Intelligence and Law}, 51 | pages={1--30}, 52 | year={2022}, 53 | publisher={Springer} 54 | } 55 | ``` 56 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader, TensorDataset, random_split, RandomSampler, Dataset 2 | from pytorch_lightning import LightningDataModule 3 | BATCH_SIZE=1 4 | 5 | class Paraformer_Dataset(Dataset): 6 | def __init__(self, df): 7 | self.content = df["content"] 8 | self.article_content = df["article_content"] 9 | self.article_id = df["article_id"] 10 | self.label = df["label"] 11 | 12 | def __getitem__(self, idx): 13 | return self.content[idx], self.article_content[idx], self.article_id[idx],self.label[idx] 14 | 15 | def __len__(self): 16 | return len(self.content) 17 | 18 | 19 | class Paraformer_DataModule(LightningDataModule): 20 | 21 | def __init__(self,df_train=None, df_val=None, df_test=None): 22 | super().__init__() 23 | self.df_train, self.df_val, self.df_test=df_train, df_val, df_test 24 | 25 | def setup(self): 26 | if self.df_train is not None: 27 | self.train_dataset = Paraformer_Dataset(self.df_train) 28 | if self.df_val is not None: 29 | self.val_dataset= Paraformer_Dataset(self.df_val) 30 | if self.df_test is not None: 31 | self.test_dataset=Paraformer_Dataset(self.df_test) 32 | 33 | 34 | def train_dataloader(self): 35 | return DataLoader(self.train_dataset,batch_size=BATCH_SIZE, shuffle = True, num_workers=4) 36 | 37 | def val_dataloader(self): 38 | return DataLoader(self.val_dataset,batch_size= BATCH_SIZE, shuffle = False, num_workers=4) 39 | 40 | def test_dataloader(self): 41 | return DataLoader(self.test_dataset,batch_size= BATCH_SIZE, shuffle = False, num_workers=4) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from utils import evaluate, load_data_coliee, create_df 3 | import pandas as pd 4 | from data import Paraformer_DataModule 5 | from pytorch_lightning import Trainer 6 | from model import Paraformer_Model 7 | 8 | 9 | def main(*args, **kargs): 10 | data_dir=hyperparams.data_dir 11 | test_file=hyperparams.test_file 12 | max_epochs=hyperparams.max_epochs 13 | save_top_k=hyperparams.save_top_k 14 | patience=hyperparams.patience 15 | gpus=hyperparams.gpus 16 | 17 | 18 | c_docs, c_keys, val_q, test_q, train_q, _ = load_data_coliee(data_dir,test_file=test_file) 19 | 20 | civil_dict={} 21 | for key,value in zip(c_keys,c_docs): 22 | civil_dict[key] = value 23 | 24 | df_train=create_df(train_q,civil_dict) 25 | df_val=create_df(val_q,civil_dict) 26 | df_test=create_df(test_q,civil_dict,neg_sampling=False) 27 | 28 | data_module=Paraformer_DataModule(df_train,df_val,df_test) 29 | data_module.setup() 30 | 31 | from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint 32 | checkpoint_callback = ModelCheckpoint( 33 | monitor='avg_val_loss', 34 | filename='Paraformer', 35 | save_top_k=save_top_k, # save the top 3 models 36 | mode='min', # mode of the monitored quantity for optimization 37 | ) 38 | from pytorch_lightning.callbacks import EarlyStopping 39 | early_stop_callback = EarlyStopping( 40 | monitor='avg_val_loss', 41 | min_delta=0.00, 42 | patience=patience, 43 | verbose=False, 44 | mode='min' 45 | ) 46 | 47 | trainer = Trainer(max_epochs = max_epochs , gpus =gpus, callbacks=[checkpoint_callback,early_stop_callback]) 48 | 49 | model=Paraformer_Model() 50 | trainer.fit(model, data_module) 51 | 52 | if __name__ == "__main__": 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument("--data-dir") 55 | parser.add_argument("--test-file") 56 | parser.add_argument("--max-epochs",type=int,default=200) 57 | parser.add_argument("--save-top-k",type=int,default=3) 58 | parser.add_argument("--patience",type=int,default=20) 59 | parser.add_argument("--gpus",type=int,default=1) 60 | 61 | hyperparams = parser.parse_args() 62 | 63 | main(hyperparams) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from utils import evaluate, load_data_coliee, create_df 3 | from model import Paraformer_Model 4 | from rank_bm25 import BM25Okapi 5 | 6 | def weighted_sorted_relevance(query,model,c_docs,c_keys,bm25_top_n=10,alpha=0.1,top_articles=1): 7 | 8 | #Get BM25 Scores 9 | bm25_score=[] 10 | corpus=list(c_docs) 11 | tokenized_corpus = [doc.split(" ") for doc in corpus] 12 | bm25 = BM25Okapi(tokenized_corpus) 13 | bm25_scores = bm25.get_scores(query.split(" ")) 14 | 15 | bm25_filter_list=sorted(list(zip(c_keys,c_docs,bm25_scores)), 16 | key=lambda tup: tup[2],reverse=True)[:bm25_top_n] 17 | 18 | #Get Deep Scores 19 | c_keys,c_docs,bm25_scores = zip(*bm25_filter_list) 20 | deep_scores=[] 21 | final_scores=[] 22 | 23 | #Get Weighted Scores 24 | for article in bm25_filter_list: 25 | article_content=[sent.strip() for sent in article[1].split("\n") if sent.strip()!=""] 26 | deep_score=model.get_score(query, article_content) 27 | deep_scores.append(deep_score) 28 | final_scores.append(alpha*deep_score+(1-alpha)*article[2]) 29 | 30 | final_list=sorted(list(zip(c_keys,c_docs,bm25_scores,deep_scores,final_scores)), 31 | key=lambda tup: tup[4],reverse=True)[:top_articles] 32 | 33 | 34 | return [article[0] for article in (final_list)] 35 | 36 | def main(*args, **kargs): 37 | print(hyperparams) 38 | data_dir=hyperparams.data_dir 39 | test_file=hyperparams.test_file 40 | ckpt_path=hyperparams.checkpoint 41 | alpha=hyperparams.alpha 42 | bm25_top_n=hyperparams.bm25_top_n 43 | 44 | c_docs, c_keys, val_q, test_q, train_q, _ = load_data_coliee(data_dir,test_file=test_file) 45 | 46 | civil_dict={} 47 | for key,value in zip(c_keys,c_docs): 48 | civil_dict[key] = value 49 | 50 | df=create_df(test_q,civil_dict,neg_sampling=False) 51 | 52 | model = Paraformer_Model.load_from_checkpoint(checkpoint_path=ckpt_path) 53 | model.eval() 54 | # model.cuda() 55 | 56 | df.drop_duplicates(subset =["content","article_id"], keep = "first", inplace = True) 57 | df = df.groupby('content').article_id.apply(list).reset_index() #group the id 58 | 59 | df["preds"]=df["content"].apply(weighted_sorted_relevance,c_docs=c_docs, c_keys=c_keys, 60 | model=model,bm25_top_n=bm25_top_n,alpha=alpha) 61 | 62 | df["true_positive"] = [list(set(df.loc[r, "article_id"]) & set(df.loc[r, "preds"])) for r in range(len(df))] 63 | df["precision"]=df["true_positive"].apply(len)/df["preds"].apply(len) 64 | df["recall"]=df["true_positive"].apply(len)/df["article_id"].apply(len) 65 | df["f2"]=5*df["precision"]*df["recall"]/(4*df["precision"]+df["recall"]) 66 | df["f2"]=df["f2"].fillna(0) 67 | print(f'Precision: {df["precision"].mean():.3f} | Recall: {df["recall"].mean():.3f} | F2: {df["f2"].mean():.3f}.') 68 | 69 | 70 | if __name__ == "__main__": 71 | parser = argparse.ArgumentParser() 72 | parser.add_argument("--data-dir") 73 | parser.add_argument("--test-file") 74 | parser.add_argument("--checkpoint") 75 | parser.add_argument("--bm25-top-n",type=int,default=10) 76 | parser.add_argument("--alpha",type=float,default=0.1) 77 | 78 | hyperparams = parser.parse_args() 79 | 80 | main(hyperparams) -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning import LightningModule 2 | from sentence_transformers import SentenceTransformer 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from computation_units import Attention, Sparsemax 8 | 9 | def flat_accuracy(preds, labels): 10 | pred_flat = torch.argmax(preds.detach(), axis=1).flatten() 11 | labels_flat = labels.flatten() 12 | return torch.sum(pred_flat == labels_flat) / len(labels_flat) 13 | 14 | class Paraformer_Model(LightningModule): 15 | # Set up the classifier 16 | def __init__(self,base_model="paraphrase-mpnet-base-v2"): 17 | super().__init__() 18 | self.plm=SentenceTransformer(base_model) 19 | self.general_attn=Attention(768) 20 | self.classifier=nn.Linear(768,2) 21 | 22 | self.criterion = torch.nn.CrossEntropyLoss() 23 | 24 | def forward(self, query, article): 25 | q_vec = self.plm.encode(query,convert_to_tensor=True) 26 | q_vec = torch.unsqueeze(q_vec,1) 27 | # print(q_vec.size()) 28 | a_vecs=torch.stack([self.plm.encode(sent,convert_to_tensor=True) for sent in article]) 29 | # print(a_vecs.size()) 30 | a_vecs=a_vecs.permute(1,0,2) 31 | # print(a_vecs.size()) 32 | attn_output=self.general_attn(q_vec,a_vecs)[0] 33 | # print(attn_output.size()) 34 | out=self.classifier(attn_output) 35 | out = torch.squeeze(out,1) 36 | # print(out.size()) 37 | return out 38 | 39 | def training_step(self, batch, batch_idx): 40 | b_content, b_article_content, b_article_id, b_labels = batch 41 | logits = self.forward(b_content, b_article_content) 42 | 43 | loss = self.criterion(logits,b_labels) 44 | acc = flat_accuracy(logits, b_labels) 45 | 46 | return {'loss' : loss, 'acc':acc} 47 | 48 | def training_epoch_end(self, outputs): 49 | avg_loss = torch.stack([x['loss'] for x in outputs]).mean() 50 | avg_acc = torch.stack([x['acc'] for x in outputs]).mean() 51 | self.log('avg_train_loss', avg_loss) 52 | self.log('avg_train_acc', avg_acc) 53 | 54 | def validation_step(self, batch, batch_idx): 55 | b_content, b_article_content, b_article_id, b_labels = batch 56 | with torch.no_grad(): 57 | logits = self.forward(b_content, b_article_content) 58 | 59 | loss = self.criterion(logits,b_labels) 60 | acc = flat_accuracy(logits, b_labels) 61 | 62 | return {'loss' : loss, 'acc':acc} 63 | 64 | def validation_epoch_end(self, outputs): 65 | avg_loss = torch.stack([x['loss'] for x in outputs]).mean() 66 | avg_acc = torch.stack([x['acc'] for x in outputs]).mean() 67 | self.log('avg_val_loss', avg_loss) 68 | self.log('avg_val_acc', avg_acc) 69 | 70 | def test_step(self, batch, batch_idx): 71 | b_content, b_article_content, b_article_id, b_labels = batch 72 | with torch.no_grad(): 73 | logits = self.forward(b_content, b_article_content) 74 | 75 | loss = self.criterion(logits,b_labels) 76 | acc = flat_accuracy(logits, b_labels) 77 | 78 | return {'loss' : loss, 'acc':acc} 79 | 80 | def test_epoch_end(self, outputs): 81 | avg_loss = torch.stack([x['loss'] for x in outputs]).mean() 82 | avg_acc = torch.stack([x['acc'] for x in outputs]).mean() 83 | self.test_results = {"loss":avg_loss, "acc":avg_acc} 84 | print('avg_test_loss', avg_loss) 85 | print('avg_test_acc', avg_acc) 86 | self.log('avg_test_loss', avg_loss) 87 | self.log('avg_test_acc', avg_acc) 88 | 89 | 90 | 91 | def get_score(self, query, article): 92 | with torch.no_grad(): 93 | q_vec = self.plm.encode(query,convert_to_tensor=True) 94 | q_vec = torch.unsqueeze(q_vec,0)#for attention 95 | q_vec = torch.unsqueeze(q_vec,0)#batch 96 | a_vecs=torch.stack([self.plm.encode(sent,convert_to_tensor=True) for sent in article]) 97 | a_vecs=torch.unsqueeze(a_vecs,1)#batch 98 | a_vecs=a_vecs.permute(1,0,2) 99 | # print(q_vec.size(),a_vecs.size()) 100 | attn_output=self.general_attn(q_vec.cpu().detach(),a_vecs.cpu().detach())[0] 101 | out=self.classifier(attn_output) 102 | out = torch.squeeze(out,1) 103 | # print(out,out.size()) 104 | return out.cpu().detach().numpy()[0][1] 105 | 106 | 107 | def predict(self, query, article): 108 | with torch.no_grad(): 109 | q_vec = self.plm.encode(query,convert_to_tensor=True) 110 | q_vec = torch.unsqueeze(q_vec,0)#for attention 111 | q_vec = torch.unsqueeze(q_vec,0)#batch 112 | a_vecs=torch.stack([self.plm.encode(sent,convert_to_tensor=True) for sent in article]) 113 | a_vecs=torch.unsqueeze(a_vecs,1)#batch 114 | a_vecs=a_vecs.permute(1,0,2) 115 | # print(q_vec.size(),a_vecs.size()) 116 | attn_output=self.general_attn(q_vec.detach(),a_vecs.detach())[0] 117 | out=self.classifier(attn_output) 118 | out = torch.squeeze(out,1) 119 | return torch.argmax(out).cpu().detach().numpy() 120 | 121 | # def get_backbone(self): 122 | # return self.model 123 | 124 | def configure_optimizers(self): 125 | FULL_FINETUNING = True 126 | if FULL_FINETUNING: 127 | param_optimizer = list(self.plm.named_parameters())+list(self.general_attn.named_parameters())+list(self.classifier.named_parameters()) 128 | no_decay = ['bias', 'gamma', 'beta'] 129 | optimizer_grouped_parameters = [ 130 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 131 | 'weight_decay_rate': 0.01}, 132 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 133 | 'weight_decay_rate': 0.0} 134 | ] 135 | else: 136 | param_optimizer = list(self.general_attn.named_parameters()) 137 | optimizer_grouped_parameters = [{"params": [p for n, p in param_optimizer]}] 138 | 139 | return torch.optim.Adam(optimizer_grouped_parameters, lr=3e-5,eps=1e-8) -------------------------------------------------------------------------------- /computation_units.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class Sparsemax(nn.Module): 7 | """Sparsemax function.""" 8 | 9 | def __init__(self, dim=None): 10 | """Initialize sparsemax activation 11 | 12 | Args: 13 | dim (int, optional): The dimension over which to apply the sparsemax function. 14 | """ 15 | super(Sparsemax, self).__init__() 16 | 17 | self.dim = -1 if dim is None else dim 18 | 19 | def forward(self, input): 20 | """Forward function. 21 | 22 | Args: 23 | input (torch.Tensor): Input tensor. First dimension should be the batch size 24 | 25 | Returns: 26 | torch.Tensor: [batch_size x number_of_logits] Output tensor 27 | 28 | """ 29 | # Sparsemax currently only handles 2-dim tensors, 30 | # so we reshape to a convenient shape and reshape back after sparsemax 31 | input = input.transpose(0, self.dim) 32 | original_size = input.size() 33 | input = input.reshape(input.size(0), -1) 34 | input = input.transpose(0, 1) 35 | dim = 1 36 | 37 | number_of_logits = input.size(dim) 38 | 39 | # Translate input by max for numerical stability 40 | input = input - torch.max(input, dim=dim, keepdim=True)[0].expand_as(input) 41 | 42 | # Sort input in descending order. 43 | # (NOTE: Can be replaced with linear time selection method described here: 44 | # http://stanford.edu/~jduchi/projects/DuchiShSiCh08.html) 45 | zs = torch.sort(input=input, dim=dim, descending=True)[0] 46 | range = torch.arange(start=1, end=number_of_logits + 1, step=1, dtype=input.dtype, device="cpu" if zs.get_device()==-1 else zs.get_device()).view(1, -1) 47 | range = range.expand_as(zs) 48 | 49 | # Determine sparsity of projection 50 | # print(range.get_device(),zs.get_device()) 51 | bound = 1 + range * zs 52 | cumulative_sum_zs = torch.cumsum(zs, dim) 53 | is_gt = torch.gt(bound, cumulative_sum_zs).type(input.type()) 54 | k = torch.max(is_gt * range, dim, keepdim=True)[0] 55 | 56 | # Compute threshold function 57 | zs_sparse = is_gt * zs 58 | 59 | # Compute taus 60 | taus = (torch.sum(zs_sparse, dim, keepdim=True) - 1) / k 61 | taus = taus.expand_as(input) 62 | 63 | # Sparsemax 64 | self.output = torch.max(torch.zeros_like(input), input - taus) 65 | 66 | # Reshape back to original shape 67 | output = self.output 68 | output = output.transpose(0, 1) 69 | output = output.reshape(original_size) 70 | output = output.transpose(0, self.dim) 71 | 72 | return output 73 | 74 | class Attention(nn.Module): 75 | """ Applies attention mechanism on the `context` using the `query`. 76 | 77 | **Thank you** to IBM for their initial implementation of :class:`Attention`. Here is 78 | their `License 79 | `__. 80 | 81 | Args: 82 | dimensions (int): Dimensionality of the query and context. 83 | attention_type (str, optional): How to compute the attention score: 84 | 85 | * dot: :math:`score(H_j,q) = H_j^T q` 86 | * general: :math:`score(H_j, q) = H_j^T W_a q` 87 | 88 | Example: 89 | 90 | >>> attention = Attention(256) 91 | >>> query = torch.randn(5, 1, 256) 92 | >>> context = torch.randn(5, 5, 256) 93 | >>> output, weights = attention(query, context) 94 | >>> output.size() 95 | torch.Size([5, 1, 256]) 96 | >>> weights.size() 97 | torch.Size([5, 1, 5]) 98 | """ 99 | 100 | def __init__(self, dimensions, attention_type='general'): 101 | super(Attention, self).__init__() 102 | 103 | if attention_type not in ['dot', 'general']: 104 | raise ValueError('Invalid attention type selected.') 105 | 106 | self.attention_type = attention_type 107 | if self.attention_type == 'general': 108 | self.linear_in = nn.Linear(dimensions, dimensions, bias=False) 109 | 110 | self.linear_out = nn.Linear(dimensions * 2, dimensions, bias=False) 111 | self.sparsemax = Sparsemax(dim=-1) 112 | self.tanh = nn.Tanh() 113 | 114 | def forward(self, query, context): 115 | """ 116 | Args: 117 | query (:class:`torch.FloatTensor` [batch size, output length, dimensions]): Sequence of 118 | queries to query the context. 119 | context (:class:`torch.FloatTensor` [batch size, query length, dimensions]): Data 120 | overwhich to apply the attention mechanism. 121 | 122 | Returns: 123 | :class:`tuple` with `output` and `weights`: 124 | * **output** (:class:`torch.LongTensor` [batch size, output length, dimensions]): 125 | Tensor containing the attended features. 126 | * **weights** (:class:`torch.FloatTensor` [batch size, output length, query length]): 127 | Tensor containing attention weights. 128 | """ 129 | batch_size, output_len, dimensions = query.size() 130 | query_len = context.size(1) 131 | 132 | if self.attention_type == "general": 133 | query = query.reshape(batch_size * output_len, dimensions) 134 | query = self.linear_in(query) 135 | query = query.reshape(batch_size, output_len, dimensions) 136 | 137 | # TODO: Include mask on PADDING_INDEX? 138 | 139 | # (batch_size, output_len, dimensions) * (batch_size, query_len, dimensions) -> 140 | # (batch_size, output_len, query_len) 141 | attention_scores = torch.bmm(query, context.transpose(1, 2).contiguous()) 142 | 143 | # Compute weights across every context sequence 144 | attention_scores = attention_scores.view(batch_size * output_len, query_len) 145 | attention_weights = self.sparsemax(attention_scores) 146 | attention_weights = attention_weights.view(batch_size, output_len, query_len) 147 | 148 | # (batch_size, output_len, query_len) * (batch_size, query_len, dimensions) -> 149 | # (batch_size, output_len, dimensions) 150 | mix = torch.bmm(attention_weights, context) 151 | 152 | # concat -> (batch_size * output_len, 2*dimensions) 153 | combined = torch.cat((mix, query), dim=2) 154 | combined = combined.view(batch_size * output_len, 2 * dimensions) 155 | 156 | # Apply linear_out on every 2nd dimension of concat 157 | # output -> (batch_size, output_len, dimensions) 158 | output = self.linear_out(combined).view(batch_size, output_len, dimensions) 159 | output = self.tanh(output) 160 | 161 | return output, attention_weights -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.12.0 2 | aiohttp==3.7.4.post0 3 | alabaster==0.7.12 4 | albumentations==0.1.12 5 | altair==4.1.0 6 | appdirs==1.4.4 7 | argon2-cffi==20.1.0 8 | arviz==0.11.2 9 | astor==0.8.1 10 | astropy==4.2.1 11 | astunparse==1.6.3 12 | async-generator==1.10 13 | async-timeout==3.0.1 14 | atari-py==0.2.9 15 | atomicwrites==1.4.0 16 | attrs==21.2.0 17 | audioread==2.1.9 18 | autograd==1.3 19 | Babel==2.9.1 20 | backcall==0.2.0 21 | beautifulsoup4==4.6.3 22 | bleach==3.3.0 23 | blis==0.4.1 24 | bokeh==2.3.2 25 | Bottleneck==1.3.2 26 | branca==0.4.2 27 | bs4==0.0.1 28 | CacheControl==0.12.6 29 | cached-property==1.5.2 30 | cachetools==4.2.2 31 | catalogue==1.0.0 32 | certifi==2021.5.30 33 | cffi==1.14.5 34 | cftime==1.5.0 35 | chardet==3.0.4 36 | click==7.1.2 37 | cloudpickle==1.3.0 38 | cmake==3.12.0 39 | cmdstanpy==0.9.5 40 | colorcet==2.0.6 41 | colorlover==0.3.0 42 | community==1.0.0b1 43 | configparser==5.0.2 44 | contextlib2==0.5.5 45 | convertdate==2.3.2 46 | coverage==3.7.1 47 | coveralls==0.5 48 | crcmod==1.7 49 | cufflinks==0.17.3 50 | cupy-cuda101==9.1.0 51 | cvxopt==1.2.6 52 | cvxpy==1.0.31 53 | cycler==0.10.0 54 | cymem==2.0.5 55 | Cython==0.29.23 56 | daft==0.0.4 57 | dask==2.12.0 58 | datascience==0.10.6 59 | debugpy==1.0.0 60 | decorator==4.4.2 61 | defusedxml==0.7.1 62 | descartes==1.1.0 63 | dill==0.3.4 64 | distributed==1.25.3 65 | dlib==19.18.0 66 | dm-tree==0.1.6 67 | docker-pycreds==0.4.0 68 | docopt==0.6.2 69 | docutils==0.17.1 70 | dopamine-rl==1.0.5 71 | earthengine-api==0.1.269 72 | easydict==1.9 73 | ecos==2.0.7.post1 74 | editdistance==0.5.3 75 | en-core-web-sm==2.2.5 76 | entrypoints==0.3 77 | ephem==4.0.0.2 78 | et-xmlfile==1.1.0 79 | fa2==0.3.5 80 | fastai==1.0.61 81 | fastdtw==0.3.4 82 | fastprogress==1.0.0 83 | fastrlock==0.6 84 | fbprophet==0.7.1 85 | feather-format==0.4.1 86 | filelock==3.0.12 87 | firebase-admin==4.4.0 88 | fix-yahoo-finance==0.0.22 89 | Flask==1.1.4 90 | flatbuffers==1.12 91 | folium==0.8.3 92 | fsspec==2021.6.1 93 | future==0.18.2 94 | gast==0.4.0 95 | GDAL==2.2.2 96 | gdown==3.6.4 97 | gensim==3.6.0 98 | geographiclib==1.50 99 | geopy==1.17.0 100 | gin-config==0.4.0 101 | gitdb==4.0.7 102 | GitPython==3.1.18 103 | glob2==0.7 104 | google==2.0.3 105 | google-api-core==1.26.3 106 | google-api-python-client==1.12.8 107 | google-auth==1.31.0 108 | google-auth-httplib2==0.0.4 109 | google-auth-oauthlib==0.4.4 110 | google-cloud-bigquery==1.21.0 111 | google-cloud-bigquery-storage==1.1.0 112 | google-cloud-core==1.0.3 113 | google-cloud-datastore==1.8.0 114 | google-cloud-firestore==1.7.0 115 | google-cloud-language==1.2.0 116 | google-cloud-storage==1.18.1 117 | google-cloud-translate==1.5.0 118 | google-colab==1.0.0 119 | google-pasta==0.2.0 120 | google-resumable-media==0.4.1 121 | googleapis-common-protos==1.53.0 122 | googledrivedownloader==0.4 123 | graphviz==0.10.1 124 | greenlet==1.1.0 125 | grpcio==1.34.1 126 | gspread==3.0.1 127 | gspread-dataframe==3.0.8 128 | gym==0.17.3 129 | h5py==3.1.0 130 | HeapDict==1.0.1 131 | hijri-converter==2.1.2 132 | holidays==0.10.5.2 133 | holoviews==1.14.4 134 | html5lib==1.0.1 135 | httpimport==0.5.18 136 | httplib2==0.17.4 137 | httplib2shim==0.0.3 138 | huggingface-hub==0.0.13 139 | humanize==0.5.1 140 | hyperopt==0.1.2 141 | ideep4py==2.0.0.post3 142 | idna==2.10 143 | imageio==2.4.1 144 | imagesize==1.2.0 145 | imbalanced-learn==0.4.3 146 | imblearn==0.0 147 | imgaug==0.2.9 148 | importlib-metadata==4.5.0 149 | importlib-resources==5.1.4 150 | imutils==0.5.4 151 | inflect==2.1.0 152 | iniconfig==1.1.1 153 | install==1.3.4 154 | intel-openmp==2021.2.0 155 | intervaltree==2.1.0 156 | ipykernel==4.10.1 157 | ipython==5.5.0 158 | ipython-genutils==0.2.0 159 | ipython-sql==0.3.9 160 | ipywidgets==7.6.3 161 | itsdangerous==1.1.0 162 | jax==0.2.13 163 | jaxlib==0.1.66+cuda110 164 | jdcal==1.4.1 165 | jedi==0.18.0 166 | jieba==0.42.1 167 | Jinja2==2.11.3 168 | joblib==1.0.1 169 | jpeg4py==0.1.4 170 | jsonschema==2.6.0 171 | jupyter==1.0.0 172 | jupyter-client==5.3.5 173 | jupyter-console==5.2.0 174 | jupyter-core==4.7.1 175 | jupyterlab-pygments==0.1.2 176 | jupyterlab-widgets==1.0.0 177 | kaggle==1.5.12 178 | kapre==0.3.5 179 | Keras==2.4.3 180 | keras-nightly==2.5.0.dev2021032900 181 | Keras-Preprocessing==1.1.2 182 | keras-vis==0.4.1 183 | kiwisolver==1.3.1 184 | korean-lunar-calendar==0.2.1 185 | librosa==0.8.1 186 | lightgbm==2.2.3 187 | llvmlite==0.34.0 188 | lmdb==0.99 189 | LunarCalendar==0.0.9 190 | lxml==4.2.6 191 | Markdown==3.3.4 192 | MarkupSafe==2.0.1 193 | matplotlib==3.2.2 194 | matplotlib-inline==0.1.2 195 | matplotlib-venn==0.11.6 196 | missingno==0.4.2 197 | mistune==0.8.4 198 | mizani==0.6.0 199 | mkl==2019.0 200 | mlxtend==0.14.0 201 | more-itertools==8.8.0 202 | moviepy==0.2.3.5 203 | mpmath==1.2.1 204 | msgpack==1.0.2 205 | multidict==5.1.0 206 | multiprocess==0.70.12.2 207 | multitasking==0.0.9 208 | murmurhash==1.0.5 209 | music21==5.5.0 210 | natsort==5.5.0 211 | nbclient==0.5.3 212 | nbconvert==5.6.1 213 | nbformat==5.1.3 214 | nest-asyncio==1.5.1 215 | netCDF4==1.5.6 216 | networkx==2.5.1 217 | nibabel==3.0.2 218 | nltk==3.2.5 219 | notebook==5.3.1 220 | numba==0.51.2 221 | numexpr==2.7.3 222 | numpy==1.19.5 223 | nvidia-ml-py3==7.352.0 224 | oauth2client==4.1.3 225 | oauthlib==3.1.1 226 | okgrade==0.4.3 227 | opencv-contrib-python==4.1.2.30 228 | opencv-python==4.1.2.30 229 | openpyxl==2.5.9 230 | opt-einsum==3.3.0 231 | osqp==0.6.2.post0 232 | packaging==20.9 233 | palettable==3.3.0 234 | pandas==1.1.5 235 | pandas-datareader==0.9.0 236 | pandas-gbq==0.13.3 237 | pandas-profiling==1.4.1 238 | pandocfilters==1.4.3 239 | panel==0.11.3 240 | param==1.10.1 241 | parso==0.8.2 242 | pathlib==1.0.1 243 | pathtools==0.1.2 244 | patsy==0.5.1 245 | pexpect==4.8.0 246 | pickleshare==0.7.5 247 | Pillow==7.1.2 248 | pip-tools==4.5.1 249 | plac==1.1.3 250 | plotly==4.4.1 251 | plotnine==0.6.0 252 | pluggy==0.7.1 253 | pooch==1.4.0 254 | portpicker==1.3.9 255 | prefetch-generator==1.0.1 256 | preshed==3.0.5 257 | prettytable==2.1.0 258 | progressbar2==3.38.0 259 | prometheus-client==0.11.0 260 | promise==2.3 261 | prompt-toolkit==1.0.18 262 | protobuf==3.12.4 263 | psutil==5.4.8 264 | psycopg2==2.7.6.1 265 | ptyprocess==0.7.0 266 | py==1.10.0 267 | pyarrow==3.0.0 268 | pyasn1==0.4.8 269 | pyasn1-modules==0.2.8 270 | pycocotools==2.0.2 271 | pycparser==2.20 272 | pyct==0.4.8 273 | pydata-google-auth==1.2.0 274 | pyDeprecate==0.3.0 275 | pydot==1.3.0 276 | pydot-ng==2.0.0 277 | pydotplus==2.0.2 278 | PyDrive==1.3.1 279 | pyemd==0.5.1 280 | pyerfa==2.0.0 281 | pyglet==1.5.0 282 | Pygments==2.6.1 283 | pygobject==3.26.1 284 | pymc3==3.11.2 285 | PyMeeus==0.5.11 286 | pymongo==3.11.4 287 | pymystem3==0.2.0 288 | PyOpenGL==3.1.5 289 | pyparsing==2.4.7 290 | pyrsistent==0.17.3 291 | pysndfile==1.3.8 292 | PySocks==1.7.1 293 | pystan==2.19.1.1 294 | pytest==3.6.4 295 | python-apt==0.0.0 296 | python-chess==0.23.11 297 | python-dateutil==2.8.1 298 | python-louvain==0.15 299 | python-slugify==5.0.2 300 | python-utils==2.5.6 301 | pytorch-lightning==1.3.7.post0 302 | pytz==2018.9 303 | pyviz-comms==2.0.2 304 | PyWavelets==1.1.1 305 | PyYAML==5.4.1 306 | pyzmq==22.1.0 307 | qdldl==0.1.5.post0 308 | qtconsole==5.1.0 309 | QtPy==1.9.0 310 | rank-bm25==0.2.1 311 | regex==2019.12.20 312 | requests==2.23.0 313 | requests-oauthlib==1.3.0 314 | resampy==0.2.2 315 | retrying==1.3.3 316 | rpy2==3.4.5 317 | rsa==4.7.2 318 | sacremoses==0.0.45 319 | scikit-image==0.16.2 320 | scikit-learn==0.22.2.post1 321 | scipy==1.4.1 322 | screen-resolution-extra==0.0.0 323 | scs==2.1.4 324 | seaborn==0.11.1 325 | semver==2.13.0 326 | Send2Trash==1.5.0 327 | sentence-transformers==2.0.0 328 | sentencepiece==0.1.96 329 | sentry-sdk==1.1.0 330 | setuptools-git==1.2 331 | Shapely==1.7.1 332 | shortuuid==1.0.1 333 | simplegeneric==0.8.1 334 | six==1.15.0 335 | sklearn==0.0 336 | sklearn-pandas==1.8.0 337 | smart-open==5.1.0 338 | smmap==4.0.0 339 | snowballstemmer==2.1.0 340 | sortedcontainers==2.4.0 341 | SoundFile==0.10.3.post1 342 | spacy==2.2.4 343 | Sphinx==1.8.5 344 | sphinxcontrib-serializinghtml==1.1.5 345 | sphinxcontrib-websupport==1.2.4 346 | SQLAlchemy==1.4.18 347 | sqlparse==0.4.1 348 | srsly==1.0.5 349 | statsmodels==0.10.2 350 | subprocess32==3.5.4 351 | sympy==1.7.1 352 | tables==3.4.4 353 | tabulate==0.8.9 354 | tblib==1.7.0 355 | tensorboard==2.4.1 356 | tensorboard-data-server==0.6.1 357 | tensorboard-plugin-wit==1.8.0 358 | tensorflow==2.5.0 359 | tensorflow-datasets==4.0.1 360 | tensorflow-estimator==2.5.0 361 | tensorflow-gcs-config==2.5.0 362 | tensorflow-hub==0.12.0 363 | tensorflow-metadata==1.0.0 364 | tensorflow-probability==0.12.1 365 | termcolor==1.1.0 366 | terminado==0.10.1 367 | testpath==0.5.0 368 | text-unidecode==1.3 369 | textblob==0.15.3 370 | Theano-PyMC==1.1.2 371 | thinc==7.4.0 372 | tifffile==2021.6.14 373 | tokenizers==0.10.3 374 | toml==0.10.2 375 | toolz==0.11.1 376 | torch==1.9.0+cu102 377 | torchmetrics==0.4.0 378 | torchsummary==1.5.1 379 | torchtext==0.10.0 380 | torchvision==0.10.0+cu102 381 | tornado==5.1.1 382 | tqdm==4.41.1 383 | traitlets==5.0.5 384 | transformers==4.8.2 385 | tweepy==3.10.0 386 | typeguard==2.7.1 387 | typing-extensions==3.7.4.3 388 | tzlocal==1.5.1 389 | uritemplate==3.0.1 390 | urllib3==1.24.3 391 | vega-datasets==0.9.0 392 | wandb==0.10.33 393 | wasabi==0.8.2 394 | wcwidth==0.2.5 395 | webencodings==0.5.1 396 | Werkzeug==1.0.1 397 | widgetsnbextension==3.5.1 398 | wordcloud==1.5.0 399 | wrapt==1.12.1 400 | xarray==0.18.2 401 | xgboost==0.90 402 | xkit==0.0.0 403 | xlrd==1.1.0 404 | xlwt==1.3.0 405 | yarl==1.6.3 406 | yellowbrick==0.9.1 407 | zict==2.0.0 408 | zipp==3.4.1 409 | -------------------------------------------------------------------------------- /raw_data_preprocessor.py: -------------------------------------------------------------------------------- 1 | import re 2 | import xml.etree.ElementTree as Et 3 | 4 | 5 | def _parse_article_text(article_text): 6 | article_element = {} 7 | article_id = None 8 | for _l in article_text.split("\n"): 9 | if _l.startswith("Article"): 10 | article_id = _l[len("Article "):] 11 | else: 12 | if article_id is not None: 13 | if article_id not in article_element: 14 | article_element[article_id] = "" 15 | article_element[article_id] = article_element[article_id] + " \n " + _l 16 | else: 17 | print("[W] error id = {} with text = {}".format(article_id, _l)) 18 | 19 | return article_element 20 | 21 | 22 | def load_samples(filexml, file_alignment=None): 23 | try: 24 | if file_alignment is not None: 25 | tree_alignment = Et.parse(file_alignment) 26 | root_alignment = tree_alignment.getroot() 27 | tree = Et.parse(filexml) 28 | root = tree.getroot() 29 | samples = [] 30 | for i in range(0, len(root)): 31 | sample = {'result': []} 32 | for j, e in enumerate(root[i]): 33 | if e.tag == "t1": 34 | if file_alignment is not None: 35 | article_elements = _parse_article_text( 36 | root_alignment[i][j].text.strip()) 37 | else: 38 | article_elements = _parse_article_text(e.text.strip()) 39 | 40 | sample['result'] = list(article_elements.keys()) 41 | elif e.tag == "t2": 42 | question = e.text.strip() 43 | sample['content'] = question if len(question) > 0 else None 44 | sample.update( 45 | {'index': root[i].attrib['id'], 'label': root[i].attrib.get('label', "N")}) 46 | 47 | # filter the noise samples 48 | if sample['content'] is not None: 49 | samples.append(sample) 50 | else: 51 | print("[Important warning] samples {} is ignored".format(sample)) 52 | 53 | return samples 54 | except Exception as e: 55 | print(e) 56 | print("[Err] parse tree error {}".format(filexml)) 57 | 58 | 59 | def load_civil_codes(file_path, path_data_alignment=None): 60 | article_elements = {} 61 | article_id = None 62 | civil_name = "" 63 | chapter_name = "" 64 | section_name = "" 65 | subsection_name = "" 66 | division_name = "" 67 | part_name = "" 68 | annotated_line = "" 69 | 70 | # load data alignment in english language 71 | if path_data_alignment is not None: 72 | with open(path_data_alignment, "rt") as file_align: 73 | data_alignment = [_l.strip() for _l in file_align.readlines()] 74 | 75 | # load data 76 | with open(file_path, "rt") as file_civil: 77 | for i, data_l in enumerate(file_civil.readlines()): 78 | data_l = data_l.strip() 79 | _l = data_alignment[i] 80 | 81 | if _l.startswith("Civil Code "): 82 | civil_name = data_l 83 | part_name, chapter_name, section_name, subsection_name, division_name, annotated_line = \ 84 | "", "", "", "", "", "" 85 | elif _l.startswith('Part '): 86 | part_name = data_l 87 | chapter_name, section_name, subsection_name, division_name, annotated_line = "", "", "", "", "" 88 | elif _l.startswith('Chapter '): 89 | chapter_name = data_l 90 | section_name, subsection_name, division_name, annotated_line = "", "", "", "" 91 | elif _l.startswith('Section '): 92 | section_name = data_l 93 | subsection_name, division_name, annotated_line = "", "", "" 94 | elif _l.startswith('Subsection '): 95 | subsection_name = data_l 96 | division_name, annotated_line = "", "" 97 | elif _l.startswith('Division '): 98 | division_name = data_l 99 | annotated_line = "" 100 | 101 | elif re.fullmatch(r'\([^)]*\)', _l.strip()) is not None: 102 | annotated_line = data_l 103 | # print("[W] Skip line {}".format(_l)) 104 | elif _l.startswith("Article") and "deleted" not in _l.lower(): 105 | article_id = re.search(r"Article ([^ ]*) ", _l).group(1) 106 | 107 | # get article content with out id part 108 | article_info = data_l.split('\u3000') 109 | if len(article_info) > 1 and len(article_info[1].strip()) > 0: 110 | article_content = article_info[1].strip() 111 | else: 112 | article_content = data_l 113 | 114 | # save article 115 | if article_id not in article_elements: 116 | article_elements[article_id] = { 117 | "civil_name": civil_name, 118 | "chapter_name": chapter_name, 119 | "section_name": section_name, 120 | "subsection_name": subsection_name, 121 | "division_name": division_name, 122 | "part_name": part_name, 123 | "annotated_line": annotated_line, 124 | "content": article_content, 125 | } 126 | # print(article_id) 127 | else: 128 | if article_id is not None: 129 | if article_id not in article_elements: 130 | article_elements[article_id] = { 131 | "civil_name": civil_name, 132 | "chapter_name": chapter_name, 133 | "section_name": section_name, 134 | "subsection_name": subsection_name, 135 | "division_name": division_name, 136 | "part_name": part_name, 137 | "annotated_line": annotated_line, 138 | "content": article_content, 139 | } 140 | article_elements[article_id]["content"] = article_elements[article_id]["content"] + " \n " + data_l 141 | else: 142 | print("[W] error id = {} with text = {}".format( 143 | article_id, _l)) 144 | 145 | return article_elements 146 | 147 | 148 | def _article_content(article_info, chunk_content_info=None, tokenizer=None): 149 | if chunk_content_info is not None and len(article_info["content"]) > 0: 150 | chunk_content_size, chunk_content_stride = chunk_content_info[0], chunk_content_info[1] 151 | sub_contents = [] 152 | full_content = article_info["content"] 153 | words = tokenizer( 154 | full_content) if tokenizer is not None else full_content.split(" ") 155 | separate_w = '' if tokenizer is not None else " " 156 | 157 | if len(words) > chunk_content_size: 158 | for i_start in range(0, len(words), chunk_content_size-chunk_content_stride): 159 | sub_cont = separate_w.join( 160 | words[i_start:i_start + chunk_content_size]) 161 | sub_contents.append(sub_cont) 162 | if len(words[i_start:i_start + chunk_content_size]) < chunk_content_size: 163 | break 164 | 165 | articles = ["{} {} {} {} {} {} {}".format(article_info["part_name"], 166 | article_info["chapter_name"], 167 | article_info["section_name"], 168 | article_info["subsection_name"], 169 | article_info["division_name"], 170 | article_info["annotated_line"], 171 | full_content, )] + ["{} {} {} {} {} {} {}".format(article_info["part_name"], 172 | article_info["chapter_name"], 173 | article_info["section_name"], 174 | article_info["subsection_name"], 175 | article_info["division_name"], 176 | article_info["annotated_line"], 177 | sub_content) for sub_content in sub_contents] 178 | return articles 179 | 180 | else: 181 | return ["{} {} {} {} {} {} {}".format(article_info["part_name"], 182 | article_info["chapter_name"], 183 | article_info["section_name"], 184 | article_info["subsection_name"], 185 | article_info["division_name"], 186 | article_info["annotated_line"], 187 | article_info["content"], )] 188 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from pprint import pprint 2 | from typing import Any, Dict, List 3 | from rank_bm25 import BM25Okapi 4 | 5 | import nltk 6 | import glob 7 | import pickle 8 | import torch 9 | import pandas as pd 10 | import re 11 | 12 | from raw_data_preprocessor import load_civil_codes, load_samples, _article_content 13 | 14 | 15 | def f_score(p, r, beta=1): 16 | y = (beta * beta * p + r) 17 | return (1 + beta * beta) * p * r / y if y != 0 else 0.0 18 | 19 | 20 | def micro_result(count_real_lb, count_predicted, count_true, count_gold_lb=138): 21 | p = count_true/count_predicted if count_predicted != 0 else 0.0 22 | r = count_true/count_real_lb if count_real_lb != 0 else 0.0 23 | result = {"count_real_lb": count_real_lb, 24 | "count_predicted": count_predicted, 25 | "count_gold_lb": count_gold_lb, 26 | "count_true": count_true, 27 | "P": p, 28 | "R": r, 29 | "f1": f_score(p, r, 1), 30 | "f2": f_score(p, r, 2), 31 | "f2_": f_score(p, count_true/count_gold_lb, 2)} 32 | print(result) 33 | return result 34 | 35 | 36 | def evaluate_by_similarity(similarities_, gold_data, c_keys, topk=150): 37 | count_true = 0 38 | count_all_prediction = 0 39 | count_all_gold_lb = 0 40 | 41 | idx_result = similarities_.argsort()[:, -topk:] 42 | for i in range(idx_result.shape[0]): 43 | gold_lb = gold_data[i]['result'] 44 | count_all_gold_lb += len(gold_lb) 45 | 46 | pred = [c_keys[idx] for idx in idx_result[i]] 47 | count_all_prediction += len(pred) 48 | 49 | for i, pred_lb in enumerate(pred): 50 | if pred_lb in gold_lb: 51 | count_true += 1 52 | 53 | print(count_true, count_all_prediction, count_all_gold_lb, 54 | 'P: ', count_true/count_all_prediction, 55 | 'R: ', count_true/count_all_gold_lb, 56 | 'F1: ', f_score(count_true*1.0/count_all_prediction, 57 | count_true*1.0/count_all_gold_lb), 58 | 'F2: ', f_score(count_true*1.0/count_all_prediction, 59 | count_true*1.0/count_all_gold_lb, beta=2), 60 | ) 61 | return idx_result 62 | 63 | 64 | def evaluate_by_label(prediction_file, test_dat_file, ensemble_files=None): 65 | test_dat = pd.read_csv(test_dat_file, sep=',') 66 | predictions = [] 67 | 68 | count_real_lb = 0 69 | count_gold_lb = 138 70 | count_true = 0 71 | count_predicted = 0 72 | ensemble_files = ensemble_files or [] 73 | 74 | if prediction_file not in ensemble_files: 75 | ensemble_files.append(prediction_file) 76 | 77 | for pred_file in ensemble_files: 78 | prediction_ = pd.read_csv(pred_file, sep='\t') 79 | predictions.append(prediction_) 80 | 81 | for i in range(len(test_dat)): 82 | if test_dat['label'][i] == 1: 83 | count_real_lb += 1 84 | for prediction in predictions: 85 | if prediction['prediction'][i] == 1: 86 | count_true += 1 87 | break 88 | 89 | for prediction in predictions: 90 | if prediction['prediction'][i] == 1: 91 | count_predicted += 1 92 | break 93 | 94 | return micro_result(count_real_lb, count_predicted, count_true, count_gold_lb) 95 | 96 | 97 | def evaluate(similarities_, gold_data, topk=150, c_keys=None): 98 | try: 99 | count_true = 0 100 | count_all_prediction = 0 101 | count_all_gold_lb = 0 102 | 103 | idx_result = similarities_.argsort()[:, -topk:] 104 | for i_gold in range(idx_result.shape[0]): 105 | gold_lb = gold_data[i_gold]['result'] 106 | count_all_gold_lb += len(gold_lb) 107 | 108 | pred = [c_keys[idx] for idx in idx_result[i_gold]] 109 | count_all_prediction += len(pred) 110 | 111 | for _i, pred_lb in enumerate(pred): 112 | if pred_lb in gold_lb: 113 | count_true += 1 114 | 115 | print(count_true, count_all_prediction, count_all_gold_lb, 116 | 'P: ', count_true / count_all_prediction, 117 | 'R: ', count_true / count_all_gold_lb, 118 | 'F1: ', f_score(count_true * 1.0 / count_all_prediction, 119 | count_true * 1.0 / count_all_gold_lb), 120 | 'F2: ', f_score(count_true * 1.0 / count_all_prediction, 121 | count_true * 1.0 / count_all_gold_lb, beta=2), 122 | ) 123 | return idx_result 124 | except Exception as e: 125 | print(e) 126 | return idx_result 127 | 128 | 129 | def load_data_coliee(path_folder_base="../coliee3_2020/", postags_select=None, ids_test=None, ids_dev=None, 130 | lang='en', path_data_alignment=None, chunk_content_info=None, tokenizer=None, test_file=None): 131 | if ids_test is None: 132 | ids_test = ['R01'] 133 | if ids_dev is None or len(ids_dev) == 0: 134 | ids_dev = ids_test 135 | #print('Test ids = {}, Dev ids = {}', ids_test, ids_dev) 136 | 137 | if lang != 'en' and path_data_alignment is None: 138 | print("[Warn] Miss input meta-data for alignment and parsing structure for language {}".format(lang)) 139 | if lang == 'en' and path_data_alignment is None: 140 | path_data_alignment = path_folder_base 141 | articles = load_civil_codes("{}/text/civil_code_{}-1to724-2.txt".format(path_folder_base, lang), 142 | path_data_alignment="{}/text/civil_code_en-1to724-2.txt".format(path_data_alignment)) 143 | #print(len(articles)) 144 | #print(articles["2"]) 145 | 146 | # load annotated data 147 | data = [] 148 | for file_path in glob.glob("{}/train/*.xml".format(path_folder_base)): 149 | data = data + load_samples(file_path, file_path.replace('_{}.'.format(lang), '_en.') 150 | .replace(path_folder_base, path_data_alignment)) 151 | data_test = [] 152 | if test_file is not None: 153 | data_test = load_samples(test_file) 154 | #print('len (data) = {}, len(data_test) = {}'.format(len(data), len(data_test))) 155 | 156 | #pprint(_article_content(articles["5"], 157 | # chunk_content_info, tokenizer=tokenizer)) 158 | 159 | if len(data_test) > 0: 160 | test_q = [q for q in data_test] 161 | else: 162 | test_q = [q for q in data if q['index'].split("-")[0] in ids_test] 163 | #print('Len test_q', len(test_q)) 164 | dev_q = [q for q in data if q['index'].split("-")[0] in ids_dev] 165 | #print('Len dev_q', len(dev_q)) 166 | train_q = [q for q in data if q['index'].split( 167 | "-")[0] not in ids_test + ids_dev] 168 | #print('Len train_q', len(train_q)) 169 | 170 | c_docs = [] 171 | c_sub_docs = [] 172 | c_keys = [] 173 | c_sub_keys = [] 174 | sub_key_mapping = {} 175 | for k, c in articles.items(): 176 | if postags_select is not None and isinstance(postags_select, list): 177 | c = postag_filter(_article_content( 178 | c, chunk_content_info, tokenizer=tokenizer), tags_filter=postags_select) 179 | else: 180 | c = _article_content(c, chunk_content_info, tokenizer=tokenizer) 181 | c_docs.append(c[0]) 182 | c_keys.append(k) 183 | for i_civil, c_sub in enumerate(c[1:]): 184 | c_sub_docs.append(c_sub) 185 | c_sub_keys.append("{}-sub{}".format(k, i_civil)) 186 | if k not in sub_key_mapping: 187 | sub_key_mapping[k] = [] 188 | sub_key_mapping[k].append("{}-sub{}".format(k, i_civil)) 189 | 190 | if chunk_content_info is not None: 191 | return c_docs, c_keys, dev_q, test_q, train_q, (c_sub_docs, c_sub_keys, sub_key_mapping) 192 | else: 193 | return c_docs, c_keys, dev_q, test_q, train_q, None 194 | 195 | 196 | def postag_filter(input_str, tags_filter=["V", "N", "P", "."]): 197 | words = nltk.word_tokenize(input_str) 198 | pos = nltk.pos_tag(words) 199 | new_words = [] 200 | 201 | for p in pos: 202 | if p[1][0] in tags_filter: 203 | new_words.append(p[0]) 204 | return " ".join(new_words) 205 | 206 | 207 | def aggregate_results(base_folder, aggregate_predictions=None, keys=None): 208 | prediction_mt = pickle.load( 209 | open("{}/predictions.pkl".format(base_folder), "rb")) 210 | test_dat = pd.read_csv("{}/test.tsv".format(base_folder), sep="\t") 211 | prediction = pd.read_csv( 212 | "{}/test_results_mrpc.txt".format(base_folder), sep="\t") 213 | probs = torch.softmax(torch.from_numpy(prediction_mt), dim=1) 214 | 215 | # aggregate gold values 216 | if aggregate_predictions is None and keys is None: 217 | aggregate_predictions = {} 218 | keys = [] 219 | 220 | predicted_pairs = set() 221 | for k, v_s in aggregate_predictions.items(): 222 | for v in v_s: 223 | predicted_pairs.add((v[0], v[1])) 224 | 225 | for i in range(len(test_dat)): 226 | if prediction['prediction'][i] == 1: 227 | # H30-1-A Q0 886 1 0.193 JNLP 228 | query_id = test_dat["#1 ID"][i] 229 | c_id = test_dat["#2 ID"][i] 230 | score = probs[i][1] 231 | if query_id not in aggregate_predictions: 232 | keys.append(query_id) 233 | aggregate_predictions[query_id] = [] 234 | 235 | if (query_id, c_id) not in predicted_pairs: 236 | aggregate_predictions[query_id].append((query_id, c_id, score)) 237 | predicted_pairs.add((query_id, c_id)) 238 | 239 | return aggregate_predictions, keys 240 | 241 | 242 | def aggregate_all_results(prediction_files, gold_test_file, topk=1, append_unpredicted_q=True): 243 | prediction_mt_s = [pickle.load(open(f_, 'rb')) for f_ in prediction_files] 244 | 245 | # load test file - gold data for question id and article idß 246 | test_dat = pd.read_csv(gold_test_file, sep=",") 247 | 248 | predicted_pairs = {} 249 | unpredicted_pairs = {} 250 | individual_model_stats = [[] for i in range(len(prediction_files))] 251 | 252 | for i_mod, prediction_mt in enumerate(prediction_mt_s): 253 | probs = torch.softmax(torch.from_numpy(prediction_mt), dim=1) 254 | 255 | for i in range(len(test_dat)): 256 | # H30-1-A Q0 886 1 0.193 JNLP 257 | query_id = test_dat["#1 ID"][i] 258 | c_id = re.sub(r'-sub.*', '', test_dat["#2 ID"][i]) 259 | score = probs[i][1] 260 | 261 | if probs[i][1] > probs[i][0]: 262 | if (query_id, c_id) not in predicted_pairs: 263 | predicted_pairs[(query_id, c_id)] = [] 264 | 265 | predicted_pairs[(query_id, c_id)].append(score) 266 | else: 267 | if (query_id, c_id) not in unpredicted_pairs: 268 | unpredicted_pairs[(query_id, c_id)] = [] 269 | 270 | unpredicted_pairs[(query_id, c_id)].append(score) 271 | 272 | # stats each model 273 | individual_model_stats[i_mod].append((query_id, c_id, score)) 274 | 275 | # sort stats each model 276 | new_stats = [{} for i in range(len(prediction_files))] 277 | for i_mod, result in enumerate(individual_model_stats): 278 | for stat_e in result: 279 | if stat_e[0] not in new_stats[i_mod]: 280 | new_stats[i_mod][stat_e[0]] = [] 281 | new_stats[i_mod][stat_e[0]].append((stat_e[1], stat_e[2].item())) 282 | for q_id, v in new_stats[i_mod].items(): 283 | new_stats[i_mod][q_id].sort(key=lambda x: x[1], reverse=True) 284 | new_stats[i_mod][q_id] = new_stats[i_mod][q_id][:topk] 285 | individual_model_stats = new_stats 286 | 287 | # 288 | # aggregrate result from many models 289 | def aggregrate_result_(pairs_): 290 | aggregate_results = {} 291 | for k, v in pairs_.items(): 292 | if k[0] not in aggregate_results: 293 | aggregate_results[k[0]] = [] 294 | # aggregate_results[k[0]].append((k[0], k[1], max(v))) 295 | aggregate_results[k[0]].append((k[0], k[1], sum(v) / len(v))) 296 | return aggregate_results 297 | 298 | predicted_results = aggregrate_result_(predicted_pairs) 299 | unpredicted_results = aggregrate_result_(unpredicted_pairs) 300 | 301 | # append unpredicted question by top 1 302 | miss_prediction_keys = set() 303 | if append_unpredicted_q: 304 | miss_prediction_keys = set(unpredicted_results.keys()).difference( 305 | set(predicted_results.keys())) 306 | print('Miss question ids: {}'.format(miss_prediction_keys)) 307 | for q_id in miss_prediction_keys: 308 | unpredicted_results[q_id].sort(key=lambda x: x[2], reverse=True) 309 | predicted_results[q_id] = unpredicted_results[q_id][:1] 310 | 311 | # 312 | # aggregrate gold label 313 | gold_results = {} 314 | gold_all_q_ids = set() 315 | for i in range(len(test_dat)): 316 | query_id = test_dat["#1 ID"][i] 317 | # test_dat["#2 ID"][i] 318 | c_id = re.sub(r'-sub.*', '', test_dat["#2 ID"][i]) 319 | gold_all_q_ids.add(query_id) 320 | 321 | if test_dat['label'][i] == 1: 322 | if query_id not in gold_results: 323 | gold_results[query_id] = [] 324 | gold_results[query_id].append((query_id, c_id, 1)) 325 | # 326 | # compute performance by accuracy task 4 327 | stats_task4 = {'pred': [], 'gold': []} 328 | for q_id in gold_all_q_ids: 329 | if q_id in gold_results: 330 | stats_task4['gold'].append((q_id, True)) 331 | else: 332 | stats_task4['gold'].append((q_id, False)) 333 | 334 | if q_id in predicted_results: 335 | stats_task4['pred'].append((q_id, True)) 336 | else: 337 | stats_task4['pred'].append((q_id, False)) 338 | right_count = len(set(stats_task4['pred']).intersection( 339 | set(stats_task4['gold']))) 340 | stats_task4['acc'] = right_count / len(gold_all_q_ids) 341 | stats_task4['correct_count'] = right_count 342 | stats_task4['total'] = len(gold_all_q_ids) 343 | 344 | # 345 | # compute performance by some metrics 346 | stats_result = {} 347 | for q_id in gold_all_q_ids: 348 | stats_result[q_id] = {} 349 | if q_id not in gold_results or q_id not in predicted_results: 350 | stats_result[q_id]['pred'] = [x[1] 351 | for x in predicted_results.get(q_id, [])] 352 | stats_result[q_id]['enssemble_score'] = [x[2].item() 353 | for x in predicted_results.get(q_id, [])] 354 | stats_result[q_id]['gold'] = [] 355 | stats_result[q_id]["P"] = 0 356 | stats_result[q_id]["R"] = 0 357 | stats_result[q_id]["F2"] = 0 358 | else: 359 | articles_prediction = [x[1]for x in predicted_results[q_id]] 360 | articles_gold = [x[1]for x in gold_results[q_id]] 361 | stats_result[q_id]['pred'] = articles_prediction 362 | stats_result[q_id]['enssemble_score'] = [x[2].item() 363 | for x in predicted_results[q_id]] 364 | stats_result[q_id]['gold'] = articles_gold 365 | count_true = len( 366 | set(articles_prediction).intersection(set(articles_gold))) 367 | stats_result[q_id]["P"] = count_true / \ 368 | len(set(articles_prediction)) 369 | stats_result[q_id]["R"] = count_true / len(set(articles_gold)) 370 | stats_result[q_id]["F2"] = f_score( stats_result[q_id]["P"], stats_result[q_id]["R"], beta=2) 371 | 372 | stats_result[q_id]['found_by_model'] = q_id not in miss_prediction_keys 373 | stats_result[q_id]['detail_scores'] = [individual_model_stats[i][q_id] 374 | for i in range(len(prediction_files))] 375 | 376 | all_p = [stats_result[q_id]['P'] for q_id in stats_result] 377 | p = sum(all_p) / len(all_p) 378 | 379 | all_r = [stats_result[q_id]['R'] for q_id in stats_result] 380 | r = sum(all_r) / len(all_r) 381 | 382 | all_f2 = [stats_result[q_id]['F2'] for q_id in stats_result] 383 | macro_f2 = sum(all_f2) / len(all_f2) 384 | 385 | f2 = f_score(p, r, beta=2) 386 | 387 | overall_result = {'p': p, 'r': r, 'f2': f2, 'macro_f2': macro_f2, 'acc_task4': stats_task4} 388 | stats_result.update(overall_result) 389 | # pprint(stats_result) 390 | print('task 4:', "{:2.2f}".format(stats_task4['acc']*100), stats_task4['correct_count'], stats_task4['total']) 391 | 392 | return stats_result 393 | 394 | 395 | def generate_file_submission(stats_result: Dict[str, Any], file_name: str, topk: int = None): 396 | predictions = {} 397 | for q_id, a_info in stats_result.items(): 398 | if '-' not in q_id: 399 | continue 400 | if q_id not in predictions: 401 | predictions[q_id] = [] 402 | if topk is None: 403 | for i_pred, pred in enumerate(zip(a_info['pred'], a_info['enssemble_score'])): 404 | predictions[q_id].append((q_id, pred[0], pred[1])) 405 | else: 406 | enssemble_scores = {} 407 | # aggregate all score 408 | for scores_model_i in a_info['detail_scores']: 409 | for score in scores_model_i: 410 | a_id = score[0] 411 | score_raw = score[1] 412 | if a_id not in enssemble_scores: 413 | enssemble_scores[a_id] = [] 414 | enssemble_scores[a_id].append(score_raw) 415 | # get mean all score 416 | for a_id in enssemble_scores: 417 | enssemble_scores[a_id] = sum(enssemble_scores[a_id]) / len(enssemble_scores[a_id]) # max(enssemble_scores[a_id]) # 418 | 419 | for a_id, score_enss in enssemble_scores.items(): 420 | predictions[q_id].append((q_id, a_id, score_enss)) 421 | 422 | keys_ = predictions.keys() 423 | for query_id in keys_: 424 | predictions[query_id].sort(key=lambda x: x[2], reverse=True) 425 | if topk is not None: 426 | # if len(predictions[query_id]) < topk: 427 | # print("exception in {}, countpred = {}, topk={}".format(query_id, len(predictions[query_id]), topk)) 428 | predictions[query_id] = predictions[query_id][:topk] 429 | prediction_str = [] 430 | for query_id in keys_: 431 | for i, prediction_info in enumerate(predictions[query_id]): 432 | template = "{} {} {} {} {:.9f} {}" 433 | 434 | # H30-1-A Q0 886 1 0.193 JNLP 435 | prediction_str.append( 436 | template.format(query_id, "Q0", prediction_info[1], i + 1, prediction_info[2], "JNLP")) 437 | 438 | with open(file_name, "wt", encoding="utf8") as f: 439 | f.write("\n".join(prediction_str)) 440 | 441 | 442 | def generate_file_submission_task4(stats_results_in: List[Any], file_name: str, topk: int = None): 443 | stats_results = [] 444 | gold = {} 445 | enss_res = {} 446 | 447 | for stats_info in stats_results_in: 448 | stats_results.append(dict(stats_info['acc_task4']['pred'])) 449 | if len(gold) == 0: 450 | gold = dict(stats_info['acc_task4']['gold']) 451 | 452 | for k, v in gold.items(): 453 | count_true = sum([e[k] for e in stats_results]) 454 | if count_true >= len(stats_results) / 2.0: 455 | enss_res[k] = True 456 | else: 457 | enss_res[k] = False 458 | 459 | count_correct = 0 460 | for k, v in gold.items(): 461 | if enss_res[k] == v: 462 | count_correct += 1 463 | print(count_correct, len(gold), count_correct / len(gold)) 464 | 465 | prediction_str =[] 466 | for k, v in enss_res.items(): 467 | # H30-1-A Y JNLP 468 | template = "{} {} {}" 469 | prediction_str.append( template.format(k, "Y" if v else "N", "JNLP")) 470 | 471 | with open(file_name, "wt", encoding="ascii") as f: 472 | print("writing file: {} ...".format(file_name)) 473 | f.write("\n".join(prediction_str)) 474 | 475 | prediction_str =[] 476 | 477 | for i, sub_re in enumerate(stats_results): 478 | prediction_str = [] 479 | for k, v in sub_re.items(): 480 | # H30-1-A Y JNLP 481 | template = "{} {} {}" 482 | prediction_str.append( template.format(k, "Y" if v else "N", "JNLP")) 483 | 484 | sub_file_name = file_name.replace(".txt", ".{}.txt".format(i)) 485 | with open(sub_file_name, "wt", encoding="ascii") as f: 486 | print("writing file: {} ...".format(sub_file_name)) 487 | f.write("\n".join(prediction_str)) 488 | 489 | return enss_res 490 | 491 | 492 | def create_df(query_set,civil_dict, neg_sampling=True,top_bm25=10): 493 | res=[] 494 | corpus=list(civil_dict.values()) 495 | tokenized_corpus = [doc.split(" ") for doc in corpus] 496 | bm25 = BM25Okapi(tokenized_corpus) 497 | for query in query_set: 498 | #negative sampling (using bm25) 499 | neg_list=bm25.get_top_n(query["content"].split(" "), corpus, n=top_bm25) 500 | for idx in query["result"]: 501 | res.append({ 502 | "content":query["content"], 503 | "article_content":[sent.strip() for sent in civil_dict[idx].split("\n") if sent.strip()!=""], 504 | "article_id":idx, 505 | "label":1 506 | }) 507 | #remove positive sample from neg=list 508 | while civil_dict[idx] in neg_list: 509 | neg_list.remove(civil_dict[idx]) 510 | if neg_sampling: #if it is a train set or dev set 511 | for i in range(len(neg_list)): 512 | res.append({ 513 | "content":query["content"], 514 | "article_content":[sent.strip() for sent in neg_list[i].split("\n") if sent.strip()!=""], 515 | "article_id":"neg", 516 | "label":0 517 | }) 518 | return pd.DataFrame(res) --------------------------------------------------------------------------------