├── .DS_Store ├── MDERank ├── .DS_Store └── mderank_main.py ├── run.sh ├── utils ├── beam_search.py ├── length_figure.py ├── avg_f1.py ├── data_process.py ├── quick_rank.py ├── random_selection.py ├── statistic.py ├── cos_can_doc.py ├── attention.py └── cos_mask_doc.py ├── README.md └── .gitignore /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LinhanZ/mderank/HEAD/.DS_Store -------------------------------------------------------------------------------- /MDERank/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LinhanZ/mderank/HEAD/MDERank/.DS_Store -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | # This is the script for testing MDERank. 2 | # Test MDERank 3 | # Dataset names: Inspec, SemEval2010, SemEval2017, DUC2001, nus, krapivin 4 | # Please download the data first and save it in the 'data' folder.
5 | dataset_name=Inspec 6 | CUDA_VISIBLE_DEVICES=0 python MDERank/mderank_main.py --dataset_dir data/$dataset_name --batch_size 1 --distance cos --doc_embed_mode max \ 7 | --model_name_or_path model_name_or_path --log_dir log_path --dataset_name $dataset_name --layer_num -1 8 | 9 | -------------------------------------------------------------------------------- /utils/beam_search.py: -------------------------------------------------------------------------------- 1 | from math import log 2 | from numpy import array 3 | from numpy import nanargmax, nanargmin 4 | # beam search: keep the k index sequences with the lowest cumulative negative log-probability 5 | def beam_search_decoder(data, k): 6 | sequences = [[list(), 0.0]] # [index sequence, sum of -log(p)]; the empty sequence scores 0.0 7 | # walk over each step in sequence 8 | for row in data: 9 | all_candidates = list() 10 | # expand each current candidate 11 | for i in range(len(sequences)): 12 | seq, score = sequences[i] 13 | for j in range(len(row)): 14 | candidate = [seq + [j], score - log(row[j])] # log-probabilities add (multiplying -log values is not a sequence likelihood); log raises ValueError if row[j] <= 0 15 | all_candidates.append(candidate) 16 | # order all candidates by score (lower score = more probable) 17 | ordered = sorted(all_candidates, key=lambda tup :tup[1]) 18 | # select k best 19 | sequences = ordered[:k] 20 | return sequences 21 | def greedy_decoder(data): 22 | # index of the smallest value in each row (NaNs ignored); NOTE(review): beam_search_decoder favours the largest values - confirm which is intended 23 | return [nanargmin(s) for s in data] 24 | # define a sequence of 10 words over a vocab of 5 words 25 | # data = [[0.1, 0.2, 0.3, 0.4, 0.5], 26 | # [0.5, 0.4, 0.3, 0.2, 0.1], 27 | # [0.1, 0.2, 0.3, 0.4, 0.5], 28 | # [0.5, 0.4, 0.3, 0.2, 0.1], 29 | # [0.1, 0.2, 0.3, 0.4, 0.5], 30 | # [0.5, 0.4, 0.3, 0.2, 0.1], 31 | # [0.1, 0.2, 0.3, 0.4, 0.5], 32 | # [0.5, 0.4, 0.3, 0.2, 0.1], 33 | # [0.1, 0.2, 0.3, 0.4, 0.5], 34 | # [0.5, 0.4, 0.3, 0.2, 0.1]] 35 | # data = array(data) 36 | # # decode sequence 37 | # result = beam_search_decoder(data, 3) 38 | # # print result 39 | # for seq in result: 40 | # print(seq) -------------------------------------------------------------------------------- /utils/length_figure.py:
-------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | mde_5 = [12.86, 14.45, 15.24] 5 | mde_10 = [16.06, 16.01, 18.33] 6 | mde_15 = [16.67, 16.64, 17.95] 7 | pd_5 = [8.76, 5.86, 3.75] 8 | pd_10 = [14.75, 10.19, 6.34] 9 | pd_15 = [16.28, 12.90, 8.11] 10 | 11 | x = np.arange(3) 12 | x_l = [128, 256, 512] 13 | mde = [mde_5, mde_10, mde_15] 14 | pd = [pd_5, pd_10, pd_15] 15 | i = 0 16 | fig, axes = plt.subplots(1,3) 17 | plt.setp(axes, xticks=x, xticklabels=x_l) # plt.xticks would only label the current (last) subplot 18 | K = [5, 10, 15] 19 | 20 | 21 | axes[0].plot(x, mde_5, label="MDERank",color="royalblue", marker="o", markersize = 2) 22 | axes[0].plot(x, pd_5, label="Phrase-Document",color="orange", marker="p", markersize = 2) 23 | axes[1].plot(x, mde_10, label="MDERank",color="royalblue", marker="o", markersize = 2) 24 | axes[1].plot(x, pd_10, label="Phrase-Document",color="orange", marker="p", markersize = 2) 25 | axes[2].plot(x, mde_15, label="MDERank",color="royalblue", marker="o", markersize = 2) 26 | axes[2].plot(x, pd_15, label="Phrase-Document",color="orange", marker="p", markersize = 2) 27 | 28 | # x tick positions/labels for all three subplots are set once via plt.setp above 29 | 30 | 31 | 32 | axes[0].set_title("F1@5") 33 | axes[1].set_title("F1@10") 34 | axes[2].set_title("F1@15") 35 | 36 | plt.subplots_adjust(left=0.125, 37 | bottom=0.1, 38 | right=0.9, 39 | top=0.9, 40 | wspace=0.2, 41 | hspace=0.35) 42 | plt.legend(["MDERank", "Phrase-Document"], loc='upper left') 43 | plt.savefig("sequence_length.png") # save before show(): after a blocking show() returns the figure is closed, so a later savefig writes an empty image 44 | 45 | plt.show() -------------------------------------------------------------------------------- /utils/avg_f1.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | sifrank = [0.2938,0.3912, 0.3982,0.2238,0.3260,0.3725,0.1116,0.1603,0.1842,0.2430,0.2760,0.2796,0.0162,0.0252,0.0300,0.0301,0.0534,0.0586] 4 | sifrank_f5 = [sifrank[i]*100 for i in range(0,len(sifrank),3)] 5 |
sifrank_f10 = [sifrank[i]*100 for i in range(1,len(sifrank),3)] 6 | sifrank_f15 = [sifrank[i]*100 for i in range(2,len(sifrank), 3)] 7 | # each *_f5/_f10/_f15 list holds one score per dataset, so average over its length instead of a hard-coded dataset count 8 | print (sifrank_f15) 9 | print ("sifrank_f5: ",sum(sifrank_f5)/len(sifrank_f5)) 10 | print ("sifrank_f10: ",sum(sifrank_f10)/len(sifrank_f10)) 11 | print ("sifrank_f15: ",sum(sifrank_f15)/len(sifrank_f15)) 12 | 13 | 14 | mderank = [0.2617,0.3381,0.3617,0.2281,0.3251,0.3718,0.1295,0.1707,0.2009,0.1305,0.1731,0.1913,0.1178,0.1293,0.1258,0.1524,0.1833,0.1795] 15 | mderank_f5 = [mderank[i]*100 for i in range(0,len(mderank),3)] 16 | mderank_f10 = [mderank[i]*100 for i in range(1,len(mderank),3)] 17 | mderank_f15 = [mderank[i]*100 for i in range(2,len(mderank), 3)] 18 | print ("-------------------\n") 19 | print ("mderank_f5: ",sum(mderank_f5)/len(mderank_f5)) 20 | print ("mderank_f10: ",sum(mderank_f10)/len(mderank_f10)) 21 | print ("mderank_f15: ",sum(mderank_f15)/len(mderank_f15)) 22 | 23 | bert_kp = [0.2806,0.3580,0.3743,0.2163,0.3223,0.3752,0.1295,0.1795,0.2069,0.2251,0.2697,0.2628,0.1291,0.1436,0.1358,0.1411,0.1772,0.1795] 24 | bert_kp_f5 = [bert_kp[i]*100 for i in range(0,len(bert_kp),3)] 25 | bert_kp_f10 = [bert_kp[i]*100 for i in range(1,len(bert_kp),3)] 26 | bert_kp_f15 = [bert_kp[i]*100 for i in range(2,len(bert_kp), 3)] 27 | 28 | print ("-------------------\n") 29 | print ("bert_kp_f5: ",sum(bert_kp_f5)/len(bert_kp_f5)) 30 | print ("bert_kp_f10: ",sum(bert_kp_f10)/len(bert_kp_f10)) 31 | print ("bert_kp_f15: ",sum(bert_kp_f15)/len(bert_kp_f15)) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # mderank 2 | 3 | This is the code for the paper "MDERank: A Masked Document Embedding Rank Approach for Unsupervised Keyphrase Extraction". 4 | The data is from [OpenNMT-kpg-release](https://github.com/memray/OpenNMT-kpg-release) and [SIFRank](https://github.com/sunyilgdx/SIFRank). 5 | (Inspec, DUC2001, and SemEval2017 are from SIFRank.)
6 | 7 | ## Table of Contents 8 | 9 | * [Environment](#environment) 10 | * [Usage](#usage) 11 | * [Cite](#cite) 12 | 13 | ## Environment 14 | ``` 15 | Python 3.7 16 | nltk 3.4.3 17 | StanfordCoreNLP 3.9.1.1 18 | torch 1.1.0 19 | allennlp 0.8.4 20 | pke 1.8.1 21 | transformers 4.14.1 22 | CUDA version 10.2 23 | ``` 24 | 25 | ## Usage 26 | We use the run.sh script to run MDERank. 27 | ``` 28 | sh run.sh 29 | ``` 30 | In run.sh, --model_name_or_path sets the model used for predictions; the initial MDERank uses bert-base-uncased. 31 | 32 | ## Cite 33 | If you use this code, please cite this paper: 34 | ``` 35 | @article{DBLP:journals/corr/abs-2110-06651, 36 | author = {Linhan Zhang and 37 | Qian Chen and 38 | Wen Wang and 39 | Chong Deng and 40 | Shiliang Zhang and 41 | Bing Li and 42 | Wei Wang and 43 | Xin Cao}, 44 | title = {MDERank: {A} Masked Document Embedding Rank Approach for Unsupervised 45 | Keyphrase Extraction}, 46 | journal = {CoRR}, 47 | volume = {abs/2110.06651}, 48 | year = {2021}, 49 | url = {https://arxiv.org/abs/2110.06651}, 50 | eprinttype = {arXiv}, 51 | eprint = {2110.06651}, 52 | timestamp = {Fri, 22 Oct 2021 13:33:09 +0200}, 53 | biburl = {https://dblp.org/rec/journals/corr/abs-2110-06651.bib}, 54 | bibsource = {dblp computer science bibliography, https://dblp.org} 55 | } 56 | ``` 57 | -------------------------------------------------------------------------------- /utils/data_process.py: --------------------------------------------------------------------------------
1 | import torch 2 | import numpy as np 3 | import codecs 4 | from tqdm import tqdm 5 | import json 6 | import re 7 | 8 | def load_dataset(file_path): 9 | """Load a JSON-lines file and return its records as a list of parsed objects. 10 | 11 | Raises ValueError (chained to the underlying json error) if a line is not valid JSON. 12 | """ 13 | data_list = [] 14 | with codecs.open(file_path, 'r', 'utf-8') as f: 15 | json_text = f.readlines() 16 | for i, line in tqdm(enumerate(json_text), desc="Loading Doc ..."): 17 | try: 18 | jsonl = json.loads(line) 19 | except json.JSONDecodeError as err: 20 | raise ValueError("Invalid JSON on line {} of {}".format(i + 1, file_path)) from err 21 | data_list.append(jsonl) 22 | 23 | return data_list 24 | 25 | def generate_doc(dataset_dir, save_txt_path, save_label_path): 26 | """Append the abstracts and keyphrases of a JSON-lines dataset to two text files. 27 | 28 | :param dataset_dir: path of the JSON-lines file passed to load_dataset 29 | :param save_txt_path: file to which one punctuation-split abstract per record is appended 30 | :param save_label_path: file to which one ';'-joined keyphrase line per record is appended 31 | :return: (doc_list, keyphrases, doc_tok_num) - the written abstracts, the written keyphrase strings and the total number of space-separated tokens 32 | """ 33 | doc_list = [] 34 | keyphrases = [] 35 | doc_tok_num = 0 36 | dataset = load_dataset(dataset_dir) 37 | # 'with' closes both files even if a record is missing a key 38 | with open(save_txt_path, "a", encoding="utf-8") as txt_file, open(save_label_path, "a", encoding="utf-8") as label_file: 39 | for example in dataset: 40 | keywords = example['keywords'] 41 | if not isinstance(keywords, str): 42 | keywords = ';'.join(keywords) # list of keyphrases -> one ';'-separated line 43 | label_file.write(keywords + "\n") 44 | # put spaces around '. ' and ', ' so punctuation becomes separate tokens 45 | doc = example['abstract'] 46 | doc = re.sub(r'\. ', ' . ', doc) 47 | doc = re.sub(', ', ' , ', doc) 48 | doc_tok_num += len(doc.split(' ')) 49 | txt_file.write(doc + "\n") 50 | 51 | doc_list.append(doc) 52 | keyphrases.append(keywords) 53 | return doc_list, keyphrases, doc_tok_num 54 | 55 | 56 | if __name__ == "__main__": 57 | # Paths from the original author's environment; adjust before running. 58 | dataset = "/home/zhanglinhan.zlh/unsupervised_bert_kpe/Dataset/json/kp20k/kp20k_train.json" 59 | save_train_txt = "/home/zhanglinhan.zlh/kpe_test_experiment/data/kp20k_train_text.txt" 60 | save_label_txt = "/home/zhanglinhan.zlh/kpe_test_experiment/data/kp20k_train_label.txt" 61 | 62 | generate_doc(dataset, save_train_txt, save_label_txt) 63 | 64 | 65 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | -------------------------------------------------------------------------------- /utils/quick_rank.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | 4 | from collections import defaultdict 5 | textrank = {"inspec":[21.58,27.53,27.62], "sm2017":[16.43,25.83,30.50],"sm2010":[7.42,11.27,13.47], "duc":[11.02,17.45,18.84],"krapivin":[6.04,9.43,9.95],"nus":[1.80,3.02,3.53]} 6 | singlerank = {"inspec":[14.88,21.50,24.13], "sm2017":[18.23,27.73,31.73],"sm2010":[8.69,12.94,14.40], "duc":[19.14,23.86,23.43],"krapivin":[8.12,10.53,10.42],"nus":[2.98,4.51,4.92]} 7 | topicrank = {"inspec":[12.20,17.24,19.33], "sm2017":[17.10,22.62,24.87],"sm2010":[9.93,12.52,12.26], "duc":[19.97,21.73,20.97],"krapivin":[8.94,9.01,8.30],"nus":[4.54,7.93,9.37]} 8 | multipartite = {"inspec":[13.41,18.18,20.52],"sm2017":[17.39,23.73,26.87],"sm2010":[10.13,12.91,13.24], "duc":[21.70,24.10,23.62], "krapivin":[9.29,9.35,9.16], "nus":[6.17,8.57,10.82]} 9 | YAKE = {"inspec":[8.02,11.47,13.65], "sm2017":[11.84,18.14,20.55],"sm2010":[6.82,11.01,12.55], "duc":[11.99,14.18,14.28],"krapivin":[8.09,9.35,9.12],"nus":[7.85,11.05,13.09]} 10 | EmbedRank = {"inspec":[14.51,21.02,23.79], "sm2017":[20.21,29.59,33.94],"sm2010":[9.63,13.90,14.79], "duc":[21.75,25.09,24.68],"krapivin":[8.44,10.47,10.17],"nus":[2.13,2.94,3.56]} 11 | SIFRank = 
{"inspec":[29.38,39.12,39.82], "sm2017":[22.38,32.60,37.25],"sm2010":[11.16,16.03,18.42], "duc":[24.30,27.60,27.96],"krapivin":[1.62,2.52,3.00],"nus":[3.01,5.34,5.86]} 12 | PD = {"inspec":[28.92,38.55,39.77], "sm2017":[20.03,31.01,36.72],"sm2010":[10.46,16.35,19.35], "duc":[8.12,11.62,13.58],"krapivin":[4.05,6.60,7.84],"nus":[3.75,6.34,8.11]} 13 | # once = {"inspec":[27.93,37.38,39.11], "sm2017":[20.56,30.95,36.07],"sm2010":[10.16,15.40,17.69], "duc":[9.11,13.49,16.47],"krapivin":[4.61,7.21,8.15],"nus":[3.92,6.52,8.85]} 14 | # subset = {"inspec":[29.25,36.55,38.08], "sm2017":[21.50,31.30,36.67],"sm2010":[10.26,15.88,17.83], "duc":[12.05,16.73,19.19],"krapivin":[8.50,9.99,10.48],"nus":[9.61,13.43,14.65]} 15 | all = {"inspec":[26.17,33.81,36.17], "sm2017":[22.81,32.51,37.18],"sm2010":[12.95,17.07,20.09], "duc":[13.05,17.31,19.13],"krapivin":[11.78,12.93,12.58],"nus":[15.24,18.33,17.95]} 16 | ab = {"inspec":[28.06,35.80,37.43], "sm2017":[21.63,32.23,37.52],"sm2010":[12.95,17.95,20.69], "duc":[22.51,26.97,26.28],"krapivin":[12.91,14.36,13.58],"nus":[14.11,17.72,17.95]} 17 | re = {"inspec":[27.85,34.36,36.40], "sm2017":[20.37,31.21,36.63],"sm2010":[13.05,18.27,20.35], "duc":[23.31,26.65,26.42],"krapivin":[12.35,14.31,13.31],"nus":[14.39,18.46,19.41]} 18 | methods = [textrank, singlerank, topicrank, multipartite, YAKE, EmbedRank, SIFRank, PD, all, ab, re] 19 | name = ["textrank", "singlerank", "topicrank", "multipartite", "YAKE", "EmbedRank", "SIFRank", "PD", "all", "ab", "re"] 20 | #cal avg 21 | data_avg = dict() 22 | for i, m in enumerate(methods): 23 | l5, l10, l15 = [],[],[] 24 | for ds, f1 in m.items(): 25 | l5.append(f1[0]) 26 | l10.append(f1[1]) 27 | l15.append(f1[2]) 28 | avg5 = round(sum(l5)/len(l5),2) 29 | avg10 = round(sum(l10)/len(l10),2) 30 | avg15 = round(sum(l15)/len(l15),2) 31 | data_avg[i] = [avg5, avg10, avg15] 32 | 33 | #cal rank 34 | ds_name = textrank.keys() 35 | np.set_printoptions(precision=2) 36 | m_array = [] 37 | for i, m in enumerate(methods): 
38 | array = [] 39 | for ds, f1 in m.items(): 40 | array.append(f1) 41 | m_array.append(array) 42 | m_array = np.array(m_array) 43 | total_rank = [] 44 | for i, n in enumerate([5,10,15]): 45 | rank = np.zeros((len(methods), len(ds_name))) # one row per method, one column per dataset (was hard-coded (11,6)) 46 | df = pd.DataFrame(m_array[:, :, i], columns=ds_name, index=name) 47 | for j, ds in enumerate(ds_name): # j, not n: n holds the current K 48 | rank[:,j]=list(df[ds].rank(method="first", ascending=False).values) 49 | avg = np.nanmean(rank, axis=1) 50 | print ("avg: ", avg) 51 | std = np.std(rank, axis=1) 52 | print ("std: ", std) -------------------------------------------------------------------------------- /utils/random_selection.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import re, os 4 | import nltk 5 | from stanfordcorenlp import StanfordCoreNLP 6 | import time 7 | from tqdm import tqdm 8 | from nltk.corpus import stopwords 9 | import codecs 10 | import json 11 | en_model = StanfordCoreNLP(r'/home/zhanglinhan.zlh/SIFRank/stanford-corenlp-full-2018-02-27',quiet=True) 12 | stopword_dict = set(stopwords.words('english')) 13 | MAX_LEN =512 14 | 15 | GRAMMAR1 = """ NP: 16 | {<NN.*|JJ>*<NN.*>} # Adjective(s)(optional) + Noun(s)""" 17 | 18 | GRAMMAR2 = """ NP: 19 | {<JJ|VBG>*<NN.*>{0,3}} # Adjective(s)(optional) + Noun(s)""" 20 | 21 | GRAMMAR3 = """ NP: 22 | {<NN.*|JJ|VBG|VBN>*<NN.*>} # Adjective(s)(optional) + Noun(s)""" 23 | 24 | def read_file(input_path): 25 | with open(input_path, 'r', errors='replace_with_space') as input_file: 26 | return input_file.read() 27 | 28 | def clean_text(text="",database="Inspec"): 29 | 30 | #Specially for Duc2001 Database 31 | if(database=="Duc2001" or database=="Semeval2017"): 32 | pattern2 = re.compile(r'[\s,]' + '[\n]{1}') 33 | while (True): 34 | if (pattern2.search(text) is not None): 35 | position = pattern2.search(text) 36 | start = position.start() 37 | end = position.end() 38 | # start = int(position[0]) 39 | text_new = text[:start] + "\n" + text[start + 2:] 40 | text = text_new 41 | else: 42 | break 43 | 44
| pattern2 = re.compile(r'[a-zA-Z0-9,\s]' + '[\n]{1}') 45 | while (True): 46 | if (pattern2.search(text) is not None): 47 | position = pattern2.search(text) 48 | start = position.start() 49 | end = position.end() 50 | # start = int(position[0]) 51 | text_new = text[:start + 1] + " " + text[start + 2:] 52 | text = text_new 53 | else: 54 | break 55 | 56 | pattern3 = re.compile(r'\s{2,}') 57 | while (True): 58 | if (pattern3.search(text) is not None): 59 | position = pattern3.search(text) 60 | start = position.start() 61 | end = position.end() 62 | # start = int(position[0]) 63 | text_new = text[:start + 1] + "" + text[start + 2:] 64 | text = text_new 65 | else: 66 | break 67 | 68 | pattern1 = re.compile(r'[<>[\]{}]') 69 | text = pattern1.sub(' ', text) 70 | text = text.replace("\t", " ") 71 | text = text.replace(' p ','\n') 72 | text = text.replace(' /p \n','\n') 73 | lines = text.splitlines() 74 | # delete blank line 75 | text_new="" 76 | for line in lines: 77 | if(line!='\n'): 78 | text_new+=line+'\n' 79 | 80 | return text_new 81 | 82 | 83 | def get_inspec_data(file_path="data/Inspec"): 84 | 85 | data={} 86 | labels={} 87 | for dirname, dirnames, filenames in os.walk(file_path): 88 | for fname in filenames: 89 | left, right = fname.split('.') 90 | if (right == "abstr"): 91 | infile = os.path.join(dirname, fname) 92 | f=open(infile) 93 | text=f.read() 94 | text = text.replace("%", '') 95 | text=clean_text(text) 96 | data[left]=text 97 | if (right == "uncontr"): 98 | infile = os.path.join(dirname, fname) 99 | f=open(infile) 100 | text=f.read() 101 | text=text.replace("\n",' ') 102 | text=clean_text(text,database="Inspec") 103 | text=text.lower() 104 | label=text.split("; ") 105 | labels[left]=label 106 | return data,labels 107 | 108 | def extract_candidates(tokens_tagged, no_subset=False): 109 | """ 110 | Based on part of speech return a list of candidate phrases 111 | :param text_obj: Input text Representation see @InputTextObj 112 | :param no_subset: if true won't 
put a candidate which is the subset of an other candidate 113 | :return keyphrase_candidate: list of list of candidate phrases: [tuple(string,tuple(start_index,end_index))] 114 | """ 115 | np_parser = nltk.RegexpParser(GRAMMAR1) # Noun phrase parser 116 | keyphrase_candidate = [] 117 | np_pos_tag_tokens = np_parser.parse(tokens_tagged) 118 | count = 0 119 | for token in np_pos_tag_tokens: 120 | if (isinstance(token, nltk.tree.Tree) and token._label == "NP"): 121 | np = ' '.join(word for word, tag in token.leaves()) 122 | length = len(token.leaves()) 123 | start_end = (count, count + length) 124 | count += length 125 | keyphrase_candidate.append((np, start_end)) 126 | 127 | else: 128 | count += 1 129 | 130 | return keyphrase_candidate 131 | 132 | 133 | def get_long_data(file_path="data/nus/nus_test.json"): 134 | """ Load file.jsonl .""" 135 | data = {} 136 | labels = {} 137 | with codecs.open(file_path, 'r', 'utf-8') as f: 138 | json_text = f.readlines() 139 | for i, line in tqdm(enumerate(json_text), desc="Loading Doc ..."): 140 | try: 141 | jsonl = json.loads(line) 142 | keywords = jsonl['keywords'].lower().split(";") 143 | abstract = jsonl['abstract'] 144 | fulltxt = jsonl['fulltext'] 145 | doc = ' '.join([abstract, fulltxt]) 146 | doc = re.sub('\. ', ' . 
', doc) 147 | doc = re.sub(', ', ' , ', doc) 148 | 149 | doc = clean_text(doc, database="nus") 150 | doc = doc.replace('\n', ' ') 151 | data[jsonl['name']] = doc 152 | labels[jsonl['name']] = keywords 153 | except Exception as err: 154 | raise ValueError("Malformed record on line {} of {}".format(i + 1, file_path)) from err 155 | return data,labels 156 | 157 | def get_short_data(file_path="data/kp20k/kp20k_valid2k_test.json"): 158 | """ Load file.jsonl .""" 159 | data = {} 160 | labels = {} 161 | with codecs.open(file_path, 'r', 'utf-8') as f: 162 | json_text = f.readlines() 163 | for i, line in tqdm(enumerate(json_text), desc="Loading Doc ..."): 164 | try: 165 | jsonl = json.loads(line) 166 | keywords = jsonl['keywords'].lower().split(";") 167 | abstract = jsonl['abstract'] 168 | doc =abstract 169 | doc = re.sub(r'\. ', ' . ', doc) 170 | doc = re.sub(', ', ' , ', doc) 171 | 172 | doc = clean_text(doc, database="kp20k") 173 | doc = doc.replace('\n', ' ') 174 | data[i] = doc 175 | labels[i] = keywords 176 | except Exception as err: 177 | raise ValueError("Malformed record on line {} of {}".format(i + 1, file_path)) from err 178 | return data,labels 179 | 180 | 181 | def get_duc2001_data(file_path="data/DUC2001"): 182 | pattern = re.compile(r'<TEXT>(.*?)</TEXT>', re.S) # body of the document's TEXT element 183 | data = {} 184 | labels = {} 185 | for dirname, dirnames, filenames in os.walk(file_path): 186 | for fname in filenames: 187 | if (fname == "annotations.txt"): 188 | # left, right = fname.split('.') 189 | infile = os.path.join(dirname, fname) 190 | f = open(infile,'rb') 191 | text = f.read().decode('utf8') 192 | lines = text.splitlines() 193 | for line in lines: 194 | left, right = line.split("@") 195 | d = right.split(";")[:-1] 196 | l = left 197 | labels[l] = d 198 | f.close() 199 | else: 200 | infile = os.path.join(dirname, fname) 201 | with open(infile,'rb') as f: 202 | text = f.read().decode('utf8') 203 | text = re.findall(pattern, text)[0] 204 | 205 | text = text.lower() 206 | text = clean_text(text,database="Duc2001") 207 | data[fname]=text.strip("\n") 208 | # data[fname] = text 209 | return data,labels 210 | 211 | def get_inspec_data(file_path="data/Inspec"): 212 | 213 | data={} 214
| labels={} 215 | for dirname, dirnames, filenames in os.walk(file_path): 216 | for fname in filenames: 217 | left, right = fname.split('.') 218 | if (right == "abstr"): 219 | infile = os.path.join(dirname, fname) 220 | f=open(infile) 221 | text=f.read() 222 | text = text.replace("%", '') 223 | text=clean_text(text) 224 | data[left]=text 225 | if (right == "uncontr"): 226 | infile = os.path.join(dirname, fname) 227 | f=open(infile) 228 | text=f.read() 229 | text=text.replace("\n",' ') 230 | text=clean_text(text,database="Inspec") 231 | text=text.lower() 232 | label=text.split("; ") 233 | labels[left]=label 234 | return data,labels 235 | 236 | def get_semeval2017_data(data_path="data/SemEval2017/docsutf8",labels_path="data/SemEval2017/keys"): 237 | 238 | data={} 239 | labels={} 240 | for dirname, dirnames, filenames in os.walk(data_path): 241 | for fname in filenames: 242 | left, right = fname.split('.') 243 | infile = os.path.join(dirname, fname) 244 | # f = open(infile, 'rb') 245 | # text = f.read().decode('utf8') 246 | with codecs.open(infile, "r", "utf-8") as fi: 247 | text = fi.read() 248 | text = text.replace("%", '') 249 | text = clean_text(text,database="Semeval2017") 250 | data[left] = text.lower() 251 | # f.close() 252 | for dirname, dirnames, filenames in os.walk(labels_path): 253 | for fname in filenames: 254 | left, right = fname.split('.') 255 | infile = os.path.join(dirname, fname) 256 | f = open(infile, 'rb') 257 | text = f.read().decode('utf8') 258 | text = text.strip() 259 | ls=text.splitlines() 260 | labels[left] = ls 261 | f.close() 262 | return data,labels 263 | 264 | class InputTextObj: 265 | """Represent the input text in which we want to extract keyphrases""" 266 | 267 | def __init__(self, en_model, text=""): 268 | """ 269 | :param is_sectioned: If we want to section the text. 
270 | :param en_model: the pipeline of tokenization and POS-tagger 271 | :param considered_tags: The POSs we want to keep 272 | """ 273 | self.considered_tags = {'NN', 'NNS', 'NNP', 'NNPS', 'JJ'} 274 | 275 | self.tokens = [] 276 | self.tokens_tagged = [] 277 | self.tokens = en_model.word_tokenize(text) 278 | self.tokens_tagged = en_model.pos_tag(text) 279 | assert len(self.tokens) == len(self.tokens_tagged) 280 | for i, token in enumerate(self.tokens): 281 | if token.lower() in stopword_dict: 282 | self.tokens_tagged[i] = (token, "IN") 283 | self.keyphrase_candidate = extract_candidates(self.tokens_tagged, en_model) 284 | 285 | if __name__ == '__main__': 286 | 287 | dataset_name = "krapivin" 288 | dataset_dir = "../SIFRank/data/" 289 | if dataset_name == "SemEval2017": 290 | data, referneces = get_semeval2017_data(dataset_dir +dataset_name + "/docsutf8", dataset_dir + dataset_name + "/keys") 291 | elif dataset_name == "DUC2001": 292 | data, referneces = get_duc2001_data(dataset_dir + dataset_name) 293 | elif dataset_name == "nus": 294 | data, referneces = get_long_data(dataset_dir + dataset_name+ "/nus_test.json") 295 | elif dataset_name == "krapivin": 296 | data, referneces = get_long_data(dataset_dir + dataset_name +"/krapivin_test.json") 297 | elif dataset_name == "kp20k": 298 | data, referneces = get_short_data(dataset_dir + dataset_name +"/kp20k_valid2k_test.json") 299 | elif dataset_name == "SemEval2010": 300 | data, referneces = get_short_data(dataset_dir + dataset_name +"/semeval_test.json") 301 | else: 302 | data, referneces = get_inspec_data(dataset_dir + dataset_name) 303 | 304 | 305 | docs_pairs = [] 306 | doc_list = [] 307 | key_list = [] 308 | labels = [] 309 | labels_stemed = [] 310 | porter = nltk.PorterStemmer() 311 | set_cans_num = [] 312 | set_cans = [] 313 | max_can_num = 0 314 | max_reference_num = 0 315 | for idx, (key, doc) in enumerate(data.items()): 316 | set_can = set() 317 | labels.append([ref.replace(" \n", "") for ref in 
referneces[key]]) 318 | labels_s = [] 319 | set_total_cans = set() 320 | for l in referneces[key]: 321 | tokens = l.split() 322 | labels_s.append(' '.join(porter.stem(t) for t in tokens)) 323 | if len(labels_s) > max_reference_num: 324 | max_reference_num = len(labels_s) 325 | labels_stemed.append(labels_s) 326 | try: 327 | text_obj = InputTextObj(en_model, doc) 328 | doc_list.append(doc) 329 | except: 330 | continue 331 | cans = text_obj.keyphrase_candidate 332 | candidates = [] 333 | for can, pos in cans: 334 | set_can.add(can) 335 | set_cans.append(set_can) 336 | set_cans_num.append(len(set_can)) 337 | if len(set_can) > max_can_num: 338 | max_can_num = len(set_can) 339 | 340 | for t in range(100): 341 | f1_doc =[] 342 | for idx, doc in enumerate(doc_list): 343 | print("Calculating doc {} in process {}".format(idx, t)) 344 | doc_p = 0 345 | doc_r = 0 346 | doc_f = 0 347 | doc_cans = list(set_cans[idx]) 348 | doc_ref = labels[idx] 349 | doc_shuffled_cans = random.sample(doc_cans, len(doc_cans)) 350 | 351 | f1_x = [] 352 | for x in range(1,max_can_num): 353 | candidates = doc_shuffled_cans[:x] 354 | f1_y = [] 355 | for y in range(1, max_reference_num): 356 | selection = random.choices(candidates, k = y) 357 | m = 0 358 | for s in selection: 359 | if s in doc_ref: 360 | m +=1 361 | p = m/len(candidates) 362 | r = m/len(doc_ref) 363 | if (p + r == 0.0): 364 | f1 = 0 365 | else: 366 | f1 = 2 * p * r / (p + r) 367 | f1_y.append(f1) 368 | exp_f1_y = sum(f1_y)/len(f1_y) 369 | f1_x.append(exp_f1_y) 370 | exp_f1_x = sum(f1_x)/len(f1_x) 371 | f1_doc.append(exp_f1_x) 372 | exp_f1_doc = sum(f1_doc)/len(f1_doc) 373 | print("Expect f1: ", exp_f1_doc) 374 | -------------------------------------------------------------------------------- /utils/statistic.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | import os 4 | import json 5 | from tqdm import tqdm 6 | import codecs 7 | from nltk.stem import 
PorterStemmer 8 | from stanfordcorenlp import StanfordCoreNLP 9 | from nltk.corpus import stopwords 10 | import matplotlib.pyplot as plt 11 | import nltk # extract_candidates needs nltk.RegexpParser / nltk.tree.Tree; 'from nltk.stem import ...' does not bind the name nltk 12 | 13 | porter = PorterStemmer() 14 | en_model = StanfordCoreNLP(r'/home/zhanglinhan.zlh/SIFRank/stanford-corenlp-full-2018-02-27',quiet=True) 15 | 16 | stopword_dict = set(stopwords.words('english')) 17 | 18 | GRAMMAR1 = """ NP: 19 | {<NN.*|JJ>*<NN.*>} # Adjective(s)(optional) + Noun(s)""" 20 | 21 | GRAMMAR2 = """ NP: 22 | {<JJ|VBG>*<NN.*>{0,3}} # Adjective(s)(optional) + Noun(s)""" 23 | 24 | GRAMMAR3 = """ NP: 25 | {<NN.*|JJ|VBG|VBN>*<NN.*>} # Adjective(s)(optional) + Noun(s)""" 26 | 27 | 28 | def extract_candidates(tokens_tagged, no_subset=False): 29 | """ 30 | Based on part of speech return a list of candidate phrases (NP chunks matched by GRAMMAR1) 31 | :param tokens_tagged: list of (token, POS tag) pairs 32 | :param no_subset: currently unused 33 | :return keyphrase_candidate: list of (phrase string, (start_index, end_index)) tuples with token-level indices 34 | """ 35 | np_parser = nltk.RegexpParser(GRAMMAR1) # Noun phrase parser 36 | keyphrase_candidate = [] 37 | np_pos_tag_tokens = np_parser.parse(tokens_tagged) 38 | count = 0 39 | for token in np_pos_tag_tokens: 40 | if (isinstance(token, nltk.tree.Tree) and token.label() == "NP"): 41 | np = ' '.join(word for word, tag in token.leaves()) 42 | length = len(token.leaves()) 43 | start_end = (count, count + length) 44 | count += length 45 | keyphrase_candidate.append((np, start_end)) 46 | 47 | else: 48 | count += 1 49 | 50 | return keyphrase_candidate 51 | 52 | class InputTextObj: 53 | """Represent the input text in which we want to extract keyphrases""" 54 | 55 | def __init__(self, en_model, text=""): 56 | """ 57 | :param is_sectioned: If we want to section the text.
58 | :param en_model: the pipeline of tokenization and POS-tagger 59 | :param considered_tags: The POSs we want to keep 60 | """ 61 | self.considered_tags = {'NN', 'NNS', 'NNP', 'NNPS', 'JJ'} 62 | 63 | self.tokens = [] 64 | self.tokens_tagged = [] 65 | self.tokens = en_model.word_tokenize(text) 66 | self.tokens_tagged = en_model.pos_tag(text) 67 | assert len(self.tokens) == len(self.tokens_tagged) 68 | for i, token in enumerate(self.tokens): 69 | if token.lower() in stopword_dict: 70 | self.tokens_tagged[i] = (token, "IN") 71 | self.keyphrase_candidate = extract_candidates(self.tokens_tagged, en_model) 72 | 73 | 74 | def clean_text(text="",database="Inspec"): 75 | 76 | #Specially for Duc2001 Database 77 | if(database=="Duc2001" or database=="Semeval2017"): 78 | pattern2 = re.compile(r'[\s,]' + '[\n]{1}') 79 | while (True): 80 | if (pattern2.search(text) is not None): 81 | position = pattern2.search(text) 82 | start = position.start() 83 | end = position.end() 84 | # start = int(position[0]) 85 | text_new = text[:start] + "\n" + text[start + 2:] 86 | text = text_new 87 | else: 88 | break 89 | 90 | pattern2 = re.compile(r'[a-zA-Z0-9,\s]' + '[\n]{1}') 91 | while (True): 92 | if (pattern2.search(text) is not None): 93 | position = pattern2.search(text) 94 | start = position.start() 95 | end = position.end() 96 | # start = int(position[0]) 97 | text_new = text[:start + 1] + " " + text[start + 2:] 98 | text = text_new 99 | else: 100 | break 101 | 102 | pattern3 = re.compile(r'\s{2,}') 103 | while (True): 104 | if (pattern3.search(text) is not None): 105 | position = pattern3.search(text) 106 | start = position.start() 107 | end = position.end() 108 | # start = int(position[0]) 109 | text_new = text[:start + 1] + "" + text[start + 2:] 110 | text = text_new 111 | else: 112 | break 113 | 114 | pattern1 = re.compile(r'[<>[\]{}]') 115 | text = pattern1.sub(' ', text) 116 | text = text.replace("\t", " ") 117 | text = text.replace(' p ','\n') 118 | text = text.replace(' /p 
\n','\n') 119 | lines = text.splitlines() 120 | # delete blank line 121 | text_new="" 122 | for line in lines: 123 | if(line!='\n'): 124 | text_new+=line+'\n' 125 | 126 | return text_new 127 | 128 | 129 | 130 | def get_long_data(file_path="data/nus/nus_test.json"): 131 | """ Load file.jsonl .""" 132 | data = {} 133 | labels = {} 134 | with codecs.open(file_path, 'r', 'utf-8') as f: 135 | json_text = f.readlines() 136 | for i, line in tqdm(enumerate(json_text), desc="Loading Doc ..."): 137 | try: 138 | jsonl = json.loads(line) 139 | keywords = jsonl['keywords'].lower().split(";") 140 | abstract = jsonl['abstract'] 141 | fulltxt = jsonl['fulltext'] 142 | doc = ' '.join([abstract, fulltxt]) 143 | doc = re.sub('\. ', ' . ', doc) 144 | doc = re.sub(', ', ' , ', doc) 145 | 146 | doc = clean_text(doc, database="nus") 147 | doc = doc.replace('\n', ' ') 148 | data[jsonl['name']] = doc 149 | labels[jsonl['name']] = keywords 150 | except: 151 | raise ValueError 152 | return data,labels 153 | 154 | def get_duc2001_data(file_path="data/DUC2001"): 155 | pattern = re.compile(r'(.*?)', re.S) 156 | data = {} 157 | labels = {} 158 | for dirname, dirnames, filenames in os.walk(file_path): 159 | for fname in filenames: 160 | if (fname == "annotations.txt"): 161 | # left, right = fname.split('.') 162 | infile = os.path.join(dirname, fname) 163 | f = open(infile,'rb') 164 | text = f.read().decode('utf8') 165 | lines = text.splitlines() 166 | for line in lines: 167 | left, right = line.split("@") 168 | d = right.split(";")[:-1] 169 | l = left 170 | labels[l] = d 171 | f.close() 172 | else: 173 | infile = os.path.join(dirname, fname) 174 | f = open(infile,'rb') 175 | text = f.read().decode('utf8') 176 | text = re.findall(pattern, text)[0] 177 | 178 | text = text.lower() 179 | text = clean_text(text,database="Duc2001") 180 | data[fname]=text.strip("\n") 181 | # data[fname] = text 182 | return data,labels 183 | 184 | def get_inspec_data(file_path="data/Inspec"): 185 | 186 | data={} 187 | 
labels={} 188 | for dirname, dirnames, filenames in os.walk(file_path): 189 | for fname in filenames: 190 | left, right = fname.split('.') 191 | if (right == "abstr"): 192 | infile = os.path.join(dirname, fname) 193 | f=open(infile) 194 | text=f.read() 195 | text = text.replace("%", '') 196 | text=clean_text(text) 197 | data[left]=text 198 | if (right == "uncontr"): 199 | infile = os.path.join(dirname, fname) 200 | f=open(infile) 201 | text=f.read() 202 | text=text.replace("\n",' ') 203 | text=clean_text(text,database="Inspec") 204 | text=text.lower() 205 | label=text.split("; ") 206 | labels[left]=label 207 | return data,labels 208 | 209 | def get_semeval2017_data(data_path="data/SemEval2017/docsutf8",labels_path="data/SemEval2017/keys"): 210 | 211 | data={} 212 | labels={} 213 | for dirname, dirnames, filenames in os.walk(data_path): 214 | for fname in filenames: 215 | left, right = fname.split('.') 216 | infile = os.path.join(dirname, fname) 217 | # f = open(infile, 'rb') 218 | # text = f.read().decode('utf8') 219 | with codecs.open(infile, "r", "utf-8") as fi: 220 | text = fi.read() 221 | text = text.replace("%", '') 222 | text = clean_text(text,database="Semeval2017") 223 | data[left] = text.lower() 224 | # f.close() 225 | for dirname, dirnames, filenames in os.walk(labels_path): 226 | for fname in filenames: 227 | left, right = fname.split('.') 228 | infile = os.path.join(dirname, fname) 229 | f = open(infile, 'rb') 230 | text = f.read().decode('utf8') 231 | text = text.strip() 232 | ls=text.splitlines() 233 | labels[left] = ls 234 | f.close() 235 | return data,labels 236 | 237 | def get_short_data(file_path="data/kp20k/kp20k_valid2k_test.json"): 238 | """ Load file.jsonl .""" 239 | data = {} 240 | labels = {} 241 | with codecs.open(file_path, 'r', 'utf-8') as f: 242 | json_text = f.readlines() 243 | for i, line in tqdm(enumerate(json_text), desc="Loading Doc ..."): 244 | try: 245 | jsonl = json.loads(line) 246 | keywords = 
jsonl['keywords'].lower().split(";") 247 | abstract = jsonl['abstract'] 248 | doc = abstract 249 | doc = re.sub('\. ', ' . ', doc) 250 | doc = re.sub(', ', ' , ', doc) 251 | 252 | doc = clean_text(doc, database="kp20k") 253 | doc = doc.replace('\n', ' ') 254 | data[i] = doc 255 | labels[i] = keywords 256 | except: 257 | raise ValueError 258 | return data,labels 259 | 260 | 261 | 262 | def matched_label(doc, golds): 263 | matched_gold = [] 264 | for gold in golds: 265 | try: 266 | gold_pattern = re.compile(r"\b" + gold + r"\b") 267 | matched = gold_pattern.findall(doc) 268 | except: 269 | continue 270 | if len(matched)>=1: 271 | matched_gold.append(gold) 272 | matched_count = len(matched_gold) 273 | 274 | return matched_count 275 | 276 | 277 | 278 | dataset_dir = "/home/zhanglinhan.zlh/SIFRank-master/data" 279 | # dataset_name = "nus" 280 | 281 | figure_data = ["DUC2001","nus","krapivin"] 282 | x_axis = {} 283 | y_axis = {} 284 | max_length_list = {} 285 | all_max_length = 0 286 | for dataset_name in figure_data: 287 | if dataset_name == "SemEval2017": 288 | data, referneces = get_semeval2017_data(dataset_dir + "/SemEval2017/docsutf8", dataset_dir + "/SemEval2017/keys") 289 | elif dataset_name == "DUC2001": 290 | data, referneces = get_duc2001_data(dataset_dir+"/DUC2001") 291 | elif dataset_name == "nus": 292 | data, referneces = get_long_data(dataset_dir + "/nus/nus_test.json") 293 | elif dataset_name == "krapivin": 294 | data, referneces = get_long_data(dataset_dir + "/krapivin/krapivin_test.json") 295 | elif dataset_name == "kp20k": 296 | data, referneces = get_short_data(dataset_dir + "/kp20k/kp20k_valid2k_test.json") 297 | elif dataset_name == "SemEval2010": 298 | data, referneces = get_short_data(dataset_dir + "/SemEval2010/semeval_test.json") 299 | else: 300 | data, referneces = get_inspec_data(dataset_dir+'/Inspec') 301 | print(len(data)) 302 | print(len(referneces)) 303 | labels = [] 304 | labels_stemed = [] 305 | doc_list = [] 306 | total_words_num = 0 307 
| max_doc_length = 0 308 | token_num_total = 0 309 | total_label_num = 0 310 | total_words_num = 0 311 | 312 | for idx, (key, doc) in tqdm(enumerate(data.items()), desc="Importing document ..."): 313 | 314 | labels_doc = [ref.replace(" \n", "") for ref in referneces[key]] 315 | labels.append(labels_doc) 316 | labels_s = [] 317 | set_total_cans = set() 318 | token_num = 0 319 | for l in labels_doc: 320 | tokens = l.split() 321 | token_num += len(tokens) 322 | labels_s.append(' '.join(porter.stem(t) for t in tokens)) 323 | total_label_num +=len(labels_doc) 324 | token_num_total += token_num/len(labels_doc) 325 | labels_stemed.append(labels_s) 326 | # if len(doc.split()) > 510: 327 | # doc = ' '.join(doc.split()[:510]) 328 | total_words_num += len(doc.split()) 329 | if len(doc.split()) > max_doc_length: 330 | max_doc_length = len(doc.split()) 331 | doc_list.append(doc) 332 | 333 | max_length_list[dataset_name] = max_doc_length 334 | if max_doc_length > all_max_length: 335 | all_max_length = max_doc_length 336 | print("Avg token num: ", token_num_total/len(doc_list)) 337 | print("Avg labels: ", total_label_num/len(doc_list)) 338 | print("Words num: ", total_words_num/len(doc_list)) 339 | 340 | total_matched_count = 0 341 | total_matched_count = [] 342 | for id, doc in tqdm(enumerate(doc_list[:2]),desc="Making statistic ..."): 343 | golds = labels[id] 344 | label_num = len(golds) 345 | doc_tokens = doc.split() 346 | doc_length = len(doc.split()) 347 | matched_count = [0] 348 | for n in range(max_doc_length): 349 | doc_n = ' '.join(doc_tokens[:n]) 350 | if doc_length < n: 351 | matched_count.append(matched_count[-1]) 352 | else: 353 | matched_count_n = matched_label(doc_n, golds) 354 | matched_count.append(matched_count_n/label_num) 355 | 356 | total_matched_count.append(matched_count) 357 | total_label_num +=label_num 358 | total_matched_count = np.array(total_matched_count) 359 | avg_matched_count = np.nanmean(total_matched_count, axis=0) 360 | y_axis[dataset_name]= 
avg_matched_count 361 | 362 | 363 | # total_matched_count = total_matched_count/len(doc_list) 364 | # total_label_num = total_label_num/len(doc_list) 365 | # print("Avg Matched labels: ", total_matched_count) 366 | 367 | 368 | colors = {"DUC2001":'goldenrod', "nus":'darkseagreen', "krapivin":'olivedrab'} 369 | x = np.arange(0, all_max_length, 500) 370 | fig,plt = plt.subplots() 371 | for dataset_name, y in y_axis.items(): 372 | x_axis = np.arange(0, max_length_list[dataset_name],500) 373 | plt.plot(x[0:len(x_axis)], [y[i] for i in x_axis], label= dataset_name, color = colors[dataset_name], marker='o') 374 | # plt.bar(methods, [19.00,12.33,20.36,17.54,15.02], align = 'center', color=['gold','olivedrab','darkkhaki','darkseagreen','cadetblue'], alpha = .7 ) 375 | # plt.xlabel('Doc Words Num', fontsize=18) 376 | # plt.ylabel('Matched Words', fontsize=18) 377 | # plt.xticks(x_index, x) 378 | plt.legend(loc="upper left") 379 | plt.show() 380 | plt.savefig("length_match.png") 381 | # print("{} total_words_num: {}".format(dataset_name, total_words_num/len(doc_list))) 382 | 383 | -------------------------------------------------------------------------------- /utils/cos_can_doc.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | from torch.utils.data import Dataset 4 | from tqdm import tqdm 5 | from transformers import BertForMaskedLM, BertTokenizer, BertModel 6 | from torch.utils.data import DataLoader 7 | import pandas as pd 8 | from pke.unsupervised import TextRank 9 | import numpy as np 10 | import logging 11 | import argparse 12 | import codecs 13 | import json 14 | import os 15 | import string 16 | import nltk 17 | from accelerate import Accelerator 18 | # nltk.download('averaged_perceptron_tagger') 19 | from nltk.stem import PorterStemmer 20 | import itertools 21 | 22 | MAX_LEN =512 23 | 24 | class Logger(object): 25 | level_relations = { 26 | 'debug': logging.DEBUG, 27 | 'info': logging.INFO, 28 | 
'warning': logging.WARNING, 29 | 'error': logging.ERROR, 30 | 'crit': logging.CRITICAL 31 | } # 日志级别关系映射 32 | 33 | def __init__(self, filename, level='info'): 34 | 35 | self.logger = logging.getLogger(filename) 36 | # # format_str = logging.Formatter(fmt) # 设置日志格式 37 | # if args.local_rank == 0 : 38 | # level = level 39 | # else: 40 | # level = 'warning' 41 | self.logger.setLevel(self.level_relations.get(level)) # 设置日志级别 42 | sh = logging.StreamHandler() # 往屏幕上输出 43 | # sh.setFormatter(format_str) # 设置屏幕上显示的格式 44 | 45 | th = logging.FileHandler(filename,'w') 46 | formatter = logging.Formatter('%(asctime)s => %(name)s * %(levelname)s : %(message)s') 47 | th.setFormatter(formatter) 48 | 49 | self.logger.addHandler(sh) # 代表在屏幕上输出,如果注释掉,屏幕将不输出 50 | self.logger.addHandler(th) # 代表在log文件中输出,如果注释掉,将不再向文件中写入数据 51 | 52 | 53 | class PhraseKPE_Dataset(Dataset): 54 | 55 | def __init__(self, docs_pairs): 56 | 57 | self.docs_pairs = docs_pairs 58 | self.total_examples = len(self.docs_pairs) 59 | self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True) 60 | 61 | def __len__(self): 62 | return self.total_examples 63 | 64 | def __getitem__(self, idx): 65 | 66 | doc_pair = self.docs_pairs[idx] 67 | ori_doc = doc_pair[0] 68 | candidate = doc_pair[1] 69 | doc_id = doc_pair[2] 70 | 71 | tokenized_ori_doc = self.tokenized_doc(ori_doc, self.tokenizer, candidate) 72 | tokenized_candidate = self.tokenized_doc(candidate, self.tokenizer, candidate) 73 | 74 | return [tokenized_ori_doc, tokenized_candidate, doc_id] 75 | 76 | def tokenized_doc(self, text, tokenizer, candidate): 77 | 78 | max_len = MAX_LEN 79 | 80 | encoded_dict = tokenizer.encode_plus( 81 | text, # Sentence to encode. 82 | add_special_tokens=True, # Add '[CLS]' and '[SEP]' 83 | max_length=max_len, # Pad & truncate all sentences. 84 | padding='max_length', 85 | return_attention_mask=True, # Construct attn. masks. 86 | return_tensors='pt', # Return pytorch tensors. 
87 | truncation=True 88 | ) 89 | input_ids = encoded_dict["input_ids"] 90 | attention_mask = encoded_dict["attention_mask"] 91 | token_type_ids = encoded_dict["token_type_ids"] 92 | 93 | example = { 94 | "input_ids": input_ids, 95 | "token_type_ids": token_type_ids, 96 | "attention_mask": attention_mask, 97 | "candidate": candidate 98 | } 99 | 100 | return example 101 | 102 | def load_dataset(file_path): 103 | """ Load file.jsonl .""" 104 | data_list = [] 105 | with codecs.open(file_path, 'r', 'utf-8') as f: 106 | json_text = f.readlines() 107 | for i, line in tqdm(enumerate(json_text), desc="Loading Doc ..."): 108 | try: 109 | jsonl = json.loads(line) 110 | data_list.append(jsonl) 111 | except: 112 | raise ValueError 113 | 114 | return data_list 115 | 116 | def generate_doc(dataset_dir, dataset_name): 117 | 118 | doc_list = [] 119 | keyphrases = [] 120 | doc_tok_num = 0 121 | dataset = load_dataset(dataset_dir) 122 | for idx, example in enumerate(dataset): 123 | keywords = example['keywords'].lower() 124 | abstract = example['abstract'] 125 | doc = abstract 126 | doc = re.sub('\. ', ' . 
', doc) 127 | doc = re.sub(', ', ' , ', doc) 128 | doc_tok_num +=len(doc.split(' ')) 129 | doc_list.append(doc) 130 | keyphrases.append(keywords) 131 | return doc_list, keyphrases, doc_tok_num/len(dataset) 132 | 133 | 134 | def extract_candidate_words(text, good_tags=set(['JJ','JJR','JJS','NN','NNP','NNS','NNPS'])): 135 | 136 | punct = set(string.punctuation) 137 | 138 | stop_words = set(nltk.corpus.stopwords.words('english')) 139 | tagged_words = itertools.chain.from_iterable(nltk.pos_tag_sents(nltk.word_tokenize(sent) for sent in nltk.sent_tokenize(text))) 140 | candidate_phrase = [] 141 | candidates = [] 142 | for word, tag in tagged_words: 143 | if tag in good_tags and word.lower() not in stop_words and not all(char in punct for char in word): 144 | candidate_phrase.append(word) 145 | continue 146 | else: 147 | if candidate_phrase: 148 | candidates.append(candidate_phrase) 149 | candidate_phrase = [] 150 | else: 151 | continue 152 | 153 | candiates_num = len(candidates) 154 | 155 | return candidates, candiates_num 156 | 157 | def dedup(candidates): 158 | new_can = {} 159 | for can in candidates: 160 | can_set = can.split() 161 | candidate_len = len(can_set) 162 | # can = ' '.join(can) 163 | new_can[can] = candidate_len 164 | 165 | return new_can 166 | 167 | 168 | def eval_metric(cans, refs): 169 | precision_scores, recall_scores, f1_scores = {5: [], 10: [], 15:[]},{5: [], 10: [], 15:[]},{5: [], 10: [], 15:[]} 170 | 171 | stemmer = PorterStemmer() 172 | references = refs.split(";") 173 | ref_num = len(references) 174 | 175 | for i, reference in enumerate(references): 176 | reference = stemmer.stem(reference.lower()) 177 | references[i] = reference.lower() 178 | candidates_clean = set() 179 | candidates = [] 180 | for i, can in enumerate(cans): 181 | can = stemmer.stem(can[0].lower()) 182 | if can in candidates_clean: 183 | continue 184 | else: 185 | candidates_clean.add(can) 186 | candidates.append(can) 187 | 188 | 189 | for topk in [5, 10, 15]: 190 | m_can = 0 
191 | for i,candidate in enumerate(candidates[:topk]): 192 | if candidate in references: 193 | m_can += 1 194 | micropk = m_can / float(topk) 195 | micrork = m_can / float(ref_num) 196 | 197 | if micropk + micrork > 0: 198 | microf1 = float(2 * (micropk * micrork)) / (micropk + micrork) 199 | else: 200 | microf1 = 0.0 201 | 202 | precision_scores[topk].append(micropk) 203 | recall_scores[topk].append(micrork) 204 | f1_scores[topk].append(microf1) 205 | 206 | return f1_scores, precision_scores, recall_scores, candidates, references, ref_num 207 | 208 | def mean_pooling(model_output, attention_mask): 209 | hidden_states = model_output.hidden_states 210 | token_embeddings = hidden_states[-2] #First element of model_output contains all token embeddings 211 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 212 | return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) 213 | 214 | def keyphrases_selection(doc_list, references, model, dataloader, log, doc_avg_tok_num): 215 | 216 | model.eval() 217 | 218 | cos_similarity_list = {} 219 | candidate_list = [] 220 | cos_score_list = [] 221 | doc_id_list = [] 222 | 223 | for id, [ori_doc, candidate, doc_id] in enumerate(tqdm(dataloader,desc="Evaluating:")): 224 | 225 | ori_input_ids = torch.squeeze(ori_doc["input_ids"].to('cuda'),1) 226 | ori_token_type_ids = torch.squeeze(ori_doc["token_type_ids"].to('cuda'),1) 227 | ori_attention_mask = torch.squeeze(ori_doc["attention_mask"].to('cuda'),1) 228 | 229 | can_input_ids = torch.squeeze(candidate["input_ids"].to('cuda'),1) 230 | can_token_type_ids = torch.squeeze(candidate["token_type_ids"].to('cuda'),1) 231 | can_attention_mask = torch.squeeze(candidate["attention_mask"].to('cuda'),1) 232 | can = candidate["candidate"] 233 | 234 | 235 | with torch.no_grad(): 236 | # See the models docstrings for the detail of the inputs 237 | ori_outputs = model(input_ids=ori_input_ids, 
attention_mask=ori_attention_mask, token_type_ids=ori_token_type_ids, output_hidden_states=True) 238 | masked_outputs = model(input_ids=can_input_ids, attention_mask=can_attention_mask, token_type_ids=can_token_type_ids, output_hidden_states=True) 239 | # Transformers models always output tuples. 240 | # See the models docstrings for the detail of all the outputs 241 | # In our case, the first element is the hidden state of the last layer of the Bert model 242 | ori_doc_embed = mean_pooling(ori_outputs, ori_attention_mask) 243 | can_embed = mean_pooling(masked_outputs, can_attention_mask) 244 | 245 | cosine_similarity = torch.cosine_similarity(ori_doc_embed, can_embed, dim=1).cpu() 246 | 247 | doc_id_list.extend(doc_id.numpy().tolist()) 248 | candidate_list.extend(can) 249 | cos_score_list.extend(cosine_similarity.numpy()) 250 | 251 | cos_similarity_list["doc_id"] = doc_id_list 252 | cos_similarity_list["candidate"] = candidate_list 253 | cos_similarity_list["cos"] = cos_score_list 254 | 255 | cosine_similarity_rank = pd.DataFrame(cos_similarity_list) 256 | total_f1_socres, total_precision_scores, total_recall_scores = np.zeros([len(doc_list),3]),\ 257 | np.zeros([len(doc_list),3]),\ 258 | np.zeros([len(doc_list),3]) 259 | 260 | doc_num = len(doc_list) 261 | ref_total_len = 0 262 | for i in range(len(doc_list)): 263 | doc_results = cosine_similarity_rank.loc[cosine_similarity_rank['doc_id']==i] 264 | ranked_keyphrases = doc_results.sort_values(by='cos', ascending=False) 265 | top_k = ranked_keyphrases.reset_index(drop = True) 266 | print(top_k) 267 | top_k = top_k.loc[:, ['candidate']].values.tolist() 268 | doc_references = references[i] 269 | 270 | f1_scores, precision_scores, recall_scores, candidates_clean, references_clean, ref_num = eval_metric(top_k, doc_references) 271 | ref_total_len +=ref_num 272 | for idx, key in enumerate([5,10,15]): 273 | total_f1_socres[i][idx] = f1_scores[key][0] 274 | total_precision_scores[i][idx] = precision_scores[key][0] 275 | 
total_recall_scores[i][idx] = recall_scores[key][0] 276 | # if args.local_rank == 0: 277 | log.logger.info("Doc {} results:\n {}".format(i, candidates_clean)) 278 | log.logger.info("Reference:\n {}".format(references_clean)) 279 | log.logger.info("###########################") 280 | log.logger.info("F1: {} ".format(f1_scores)) 281 | log.logger.info("P: {} ".format(precision_scores)) 282 | log.logger.info("R: {} ".format(recall_scores)) 283 | log.logger.info("###########################\n") 284 | 285 | 286 | log.logger.info("############ Total Result ############") 287 | for i , key in enumerate([5,10,15]): 288 | log.logger.info("ref_avg_len: {}".format(ref_total_len/doc_num)) 289 | log.logger.info("doc_avg_len: {}".format(doc_avg_tok_num)) 290 | log.logger.info("@{}".format(key)) 291 | log.logger.info("F1:{}".format(np.mean(total_f1_socres[:,i], axis=0))) 292 | log.logger.info("P:{}".format(np.mean(total_precision_scores[:,i], axis=0))) 293 | log.logger.info("R:{}".format(np.mean(total_recall_scores[:,i], axis=0))) 294 | log.logger.info("#########################\n") 295 | # doc_sentences = list(filter(None, ex.split('.'))) 296 | # 297 | # 298 | 299 | 300 | if __name__ == '__main__': 301 | 302 | parser = argparse.ArgumentParser() 303 | parser.add_argument("--dataset_dir", 304 | default=None, 305 | type=str, 306 | required=True, 307 | help="The input dataset.") 308 | parser.add_argument("--dataset_name", 309 | default=None, 310 | type=str, 311 | required=True, 312 | help="The input dataset name.") 313 | parser.add_argument("--batch_size", 314 | default=None, 315 | type=int, 316 | required=True, 317 | help="Batch size for testing.") 318 | parser.add_argument("--checkpoints", 319 | default=None, 320 | type=str, 321 | required=False, 322 | help="Checkpoint for pre-trained Bert model") 323 | parser.add_argument("--log_dir", 324 | default=None, 325 | type=str, 326 | required=True, 327 | help="Path for Logging file") 328 | parser.add_argument("--local_rank", 329 | 
default=-1, 330 | type=int, 331 | help="local_rank for distributed training on gpus") 332 | parser.add_argument("--no_cuda", 333 | action="store_true", 334 | help="Whether not to use CUDA when available") 335 | args = parser.parse_args() 336 | 337 | 338 | log = Logger(args.log_dir + args.dataset_name + '.kpe.log') 339 | if args.local_rank == -1 or args.no_cuda: 340 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 341 | n_gpu = torch.cuda.device_count() 342 | 343 | doc_list, references, doc_avg_tok_num = generate_doc(args.dataset_dir, args.dataset_name) 344 | 345 | docs_pairs = [] 346 | for idx, doc in tqdm(enumerate(doc_list),desc="generating pairs..."): 347 | # candidates, candidates_num = extract_candidate_words(doc) 348 | extractor = TextRank() 349 | extractor.load_document(input=doc, 350 | language="en", 351 | normalization="none") 352 | extractor.candidate_selection(pos={'NOUN', 'PROPN', 'ADJ'}) 353 | candidates = list(extractor.candidates.keys()) 354 | 355 | candidates_num = len(candidates) 356 | for can in candidates: 357 | doc_pair = [doc, can, idx] 358 | docs_pairs.append(doc_pair) 359 | 360 | dataset = PhraseKPE_Dataset(docs_pairs) 361 | dataloader = DataLoader(dataset, batch_size=args.batch_size) 362 | 363 | model = BertForMaskedLM.from_pretrained('bert-base-uncased') 364 | if os.path.exists(args.checkpoints): 365 | if args.local_rank == 0: 366 | log.logger.info("Loading Checkpoint ...") 367 | accelerator = Accelerator() 368 | unwrapped_model = accelerator.unwrap_model(model) 369 | unwrapped_model.load_state_dict(torch.load(args.checkpoints)) 370 | log.logger.info("Start Testing ...") 371 | model.to(device) 372 | keyphrases_selection(doc_list, references, model, dataloader, log, doc_avg_tok_num) 373 | 374 | 375 | 376 | 377 | -------------------------------------------------------------------------------- /utils/attention.py: -------------------------------------------------------------------------------- 1 | 
from transformers import BertForMaskedLM, BertTokenizer 2 | from bertviz.neuron_view import show 3 | import re 4 | import torch 5 | from numpy import * 6 | from tqdm import tqdm 7 | import pandas as pd 8 | import numpy as np 9 | import logging 10 | import argparse 11 | import codecs 12 | import json 13 | import os 14 | import string 15 | import nltk 16 | from stanfordcorenlp import StanfordCoreNLP 17 | from accelerate import Accelerator 18 | import time 19 | from nltk.corpus import stopwords 20 | # nltk.download('averaged_perceptron_tagger') 21 | from nltk.stem import PorterStemmer 22 | import itertools 23 | import matplotlib.pyplot as plt 24 | import pandas as pd 25 | from itertools import groupby 26 | 27 | MAX_LEN =512 28 | en_model = StanfordCoreNLP(r'/home/zhanglinhan.zlh/SIFRank/stanford-corenlp-full-2018-02-27',quiet=True) 29 | 30 | 31 | 32 | stopword_dict = set(stopwords.words('english')) 33 | 34 | GRAMMAR1 = """ NP: 35 | {*} # Adjective(s)(optional) + Noun(s)""" 36 | 37 | GRAMMAR2 = """ NP: 38 | {*{0,3}} # Adjective(s)(optional) + Noun(s)""" 39 | 40 | GRAMMAR3 = """ NP: 41 | {*} # Adjective(s)(optional) + Noun(s)""" 42 | 43 | class Logger(object): 44 | level_relations = { 45 | 'debug': logging.DEBUG, 46 | 'info': logging.INFO, 47 | 'warning': logging.WARNING, 48 | 'error': logging.ERROR, 49 | 'crit': logging.CRITICAL 50 | } # 日志级别关系映射 51 | 52 | def __init__(self, filename, level='info'): 53 | 54 | self.logger = logging.getLogger(filename) 55 | # # format_str = logging.Formatter(fmt) # 设置日志格式 56 | # if args.local_rank == 0 : 57 | # level = level 58 | # else: 59 | # level = 'warning' 60 | self.logger.setLevel(self.level_relations.get(level)) # 设置日志级别 61 | sh = logging.StreamHandler() # 往屏幕上输出 62 | # sh.setFormatter(format_str) # 设置屏幕上显示的格式 63 | 64 | th = logging.FileHandler(filename,'w') 65 | # formatter = logging.Formatter('%(asctime)s => %(name)s * %(levelname)s : %(message)s') 66 | # th.setFormatter(formatter) 67 | 68 | self.logger.addHandler(sh) # 
代表在屏幕上输出,如果注释掉,屏幕将不输出 69 | self.logger.addHandler(th) # 代表在log文件中输出,如果注释掉,将不再向文件中写入数据 70 | 71 | def extract_candidates(tokens_tagged, no_subset=False): 72 | """ 73 | Based on part of speech return a list of candidate phrases 74 | :param text_obj: Input text Representation see @InputTextObj 75 | :param no_subset: if true won't put a candidate which is the subset of an other candidate 76 | :return keyphrase_candidate: list of list of candidate phrases: [tuple(string,tuple(start_index,end_index))] 77 | """ 78 | np_parser = nltk.RegexpParser(GRAMMAR1) # Noun phrase parser 79 | keyphrase_candidate = [] 80 | np_pos_tag_tokens = np_parser.parse(tokens_tagged) 81 | count = 0 82 | for token in np_pos_tag_tokens: 83 | if (isinstance(token, nltk.tree.Tree) and token._label == "NP"): 84 | np = ' '.join(word for word, tag in token.leaves()) 85 | length = len(token.leaves()) 86 | start_end = (count, count + length) 87 | count += length 88 | keyphrase_candidate.append((np, start_end)) 89 | 90 | else: 91 | count += 1 92 | 93 | return keyphrase_candidate 94 | 95 | class InputTextObj: 96 | """Represent the input text in which we want to extract keyphrases""" 97 | 98 | def __init__(self, en_model, text=""): 99 | """ 100 | :param is_sectioned: If we want to section the text. 
101 | :param en_model: the pipeline of tokenization and POS-tagger 102 | :param considered_tags: The POSs we want to keep 103 | """ 104 | self.considered_tags = {'NN', 'NNS', 'NNP', 'NNPS', 'JJ'} 105 | 106 | self.tokens = [] 107 | self.tokens_tagged = [] 108 | self.tokens = en_model.word_tokenize(text) 109 | self.tokens_tagged = en_model.pos_tag(text) 110 | assert len(self.tokens) == len(self.tokens_tagged) 111 | for i, token in enumerate(self.tokens): 112 | if token.lower() in stopword_dict: 113 | self.tokens_tagged[i] = (token, "IN") 114 | self.keyphrase_candidate = extract_candidates(self.tokens_tagged, en_model) 115 | 116 | 117 | def clean_text(text="",database="Inspec"): 118 | 119 | #Specially for Duc2001 Database 120 | if(database=="Duc2001" or database=="Semeval2017"): 121 | pattern2 = re.compile(r'[\s,]' + '[\n]{1}') 122 | while (True): 123 | if (pattern2.search(text) is not None): 124 | position = pattern2.search(text) 125 | start = position.start() 126 | end = position.end() 127 | # start = int(position[0]) 128 | text_new = text[:start] + "\n" + text[start + 2:] 129 | text = text_new 130 | else: 131 | break 132 | 133 | pattern2 = re.compile(r'[a-zA-Z0-9,\s]' + '[\n]{1}') 134 | while (True): 135 | if (pattern2.search(text) is not None): 136 | position = pattern2.search(text) 137 | start = position.start() 138 | end = position.end() 139 | # start = int(position[0]) 140 | text_new = text[:start + 1] + " " + text[start + 2:] 141 | text = text_new 142 | else: 143 | break 144 | 145 | pattern3 = re.compile(r'\s{2,}') 146 | while (True): 147 | if (pattern3.search(text) is not None): 148 | position = pattern3.search(text) 149 | start = position.start() 150 | end = position.end() 151 | # start = int(position[0]) 152 | text_new = text[:start + 1] + "" + text[start + 2:] 153 | text = text_new 154 | else: 155 | break 156 | 157 | pattern1 = re.compile(r'[<>[\]{}]') 158 | text = pattern1.sub(' ', text) 159 | text = text.replace("\t", " ") 160 | text = text.replace(' p 
','\n') 161 | text = text.replace(' /p \n','\n') 162 | lines = text.splitlines() 163 | # delete blank line 164 | text_new="" 165 | for line in lines: 166 | if(line!='\n'): 167 | text_new+=line+'\n' 168 | 169 | return text_new 170 | 171 | 172 | def get_inspec_data(file_path="data/Inspec"): 173 | 174 | data={} 175 | labels={} 176 | for dirname, dirnames, filenames in os.walk(file_path): 177 | for fname in filenames: 178 | left, right = fname.split('.') 179 | if (right == "abstr"): 180 | infile = os.path.join(dirname, fname) 181 | f=open(infile) 182 | text=f.read() 183 | text = text.replace("%", '') 184 | text=clean_text(text) 185 | data[left]=text 186 | if (right == "uncontr"): 187 | infile = os.path.join(dirname, fname) 188 | f=open(infile) 189 | text=f.read() 190 | text=text.replace("\n",' ') 191 | text=clean_text(text,database="Inspec") 192 | text=text.lower() 193 | label=text.split("; ") 194 | labels[left]=label 195 | return data,labels 196 | 197 | def extract_candidate_words(text, good_tags=set(['JJ','JJR','JJS','NN','NNP','NNS','NNPS'])): 198 | 199 | punct = set(string.punctuation) 200 | 201 | stop_words = set(nltk.corpus.stopwords.words('english')) 202 | tagged_words = itertools.chain.from_iterable(nltk.pos_tag_sents(nltk.word_tokenize(sent) for sent in nltk.sent_tokenize(text))) 203 | candidate_phrase = [] 204 | candidates = [] 205 | for word, tag in tagged_words: 206 | if tag in good_tags and word.lower() not in stop_words and not all(char in punct for char in word): 207 | candidate_phrase.append(word) 208 | continue 209 | else: 210 | if candidate_phrase: 211 | candidates.append(candidate_phrase) 212 | candidate_phrase = [] 213 | else: 214 | continue 215 | 216 | candiates_num = len(candidates) 217 | 218 | return candidates, candiates_num 219 | 220 | def dedup(candidates): 221 | new_can = {} 222 | for can in candidates: 223 | can_set = can.split() 224 | candidate_len = len(can_set) 225 | # can = ' '.join(can) 226 | new_can[can] = candidate_len 227 | 228 | 
return new_can 229 | 230 | def generate_absent_doc(doc, candidates, idx): 231 | 232 | doc_pairs = [] 233 | #每个文章的candidate, 可能有多个 234 | doc_candidate = dedup(candidates) 235 | for id, candidate in enumerate(doc_candidate.keys()): 236 | candidate_len = doc_candidate[candidate] 237 | mask = ' '.join(['[MASK]']*candidate_len) 238 | try: 239 | candidate_re = re.compile(r"\b" + candidate + r"\b") 240 | masked_doc = re.sub(candidate_re, mask, doc.lower()) 241 | except: 242 | continue 243 | 244 | doc_pairs.append([doc.lower(), masked_doc, candidate, idx]) 245 | # print("Candidate: ", candidate) 246 | # print("Masked Doc {} : {}".format(idx, masked_doc)) 247 | # print("Ori_doc {}: {}".format(idx, doc.lower())) 248 | 249 | return doc_pairs 250 | 251 | porter = nltk.PorterStemmer() 252 | data, referneces = get_inspec_data("../SIFRank/data/Inspec") 253 | docs_pairs = [] 254 | doc_list = [] 255 | labels = [] 256 | labels_stemed = [] 257 | candidates_num = 0 258 | 259 | 260 | def largest_indices(ary, n): 261 | flat = ary.flatten() 262 | indices = np.argpartition(flat, -n)[-n:] 263 | indices = indices[np.argsort(-flat[indices])] 264 | return np.unravel_index(indices, ary.shape) 265 | 266 | def get_PRF(num_c, num_e, num_s): 267 | F1 = 0.0 268 | P = float(num_c) / float(num_e) if num_e!=0 else 0.0 269 | R = float(num_c) / float(num_s) if num_s!=0 else 0.0 270 | if (P + R == 0.0): 271 | F1 = 0 272 | else: 273 | F1 = 2 * P * R / (P + R) 274 | return P, R, F1 275 | 276 | 277 | def print_PRF(P, R, F1, N): 278 | 279 | log.logger.info("\nN=" + str(N)) 280 | log.logger.info("P=" + str(P)) 281 | log.logger.info("R=" + str(R)) 282 | log.logger.info("F1=" + str(F1)) 283 | return 0 284 | 285 | P_5 = R_5 = F1_5 = 0.0 286 | P_10 = R_10 = F1_10 = 0.0 287 | P_15 = R_15 = F1_15 = 0.0 288 | num_c_5 = num_c_10 = num_c_15 = 0 289 | num_e_5 = num_e_10 = num_e_15 = 0 290 | num_s = 0 291 | lamda = 0.0 292 | 293 | log = Logger('result/Inspec.kpe.cls.attention.5.log') 294 | tokenizer = 
BertTokenizer.from_pretrained('bert-base-uncased') 295 | model = BertForMaskedLM.from_pretrained('bert-base-uncased') 296 | 297 | for idx, (key, doc) in enumerate(data.items()): 298 | set_can = set() 299 | labels_o = [] 300 | for ref in referneces[key]: 301 | labels_o.append(ref.replace(" \n","")) 302 | labels.append(labels_o) 303 | labels_s = [] 304 | for l in labels_o: 305 | tokens = l.split() 306 | labels_s.append(' '.join(porter.stem(t) for t in tokens)) 307 | labels_stemed.append(labels_s) 308 | doc_list.append(doc) 309 | encoded_dict = tokenizer.encode_plus( 310 | doc, # Sentence to encode. 311 | add_special_tokens=True, # Add '[CLS]' and '[SEP]' 312 | return_attention_mask=True, # Construct attn. masks. 313 | return_tensors='pt', # Return pytorch tensors. 314 | truncation=True 315 | ) 316 | outputs = model(**encoded_dict, output_attentions=True) 317 | attention = outputs.attentions[-1] 318 | 319 | 320 | text_obj = InputTextObj(en_model, doc) 321 | cans = text_obj.keyphrase_candidate 322 | candidates = [] 323 | for can, pos in cans: 324 | candidates.append(can) 325 | set_can.add(can.lower()) 326 | candidates_num += len(set_can) 327 | candidate_att = [] 328 | for can in set_can: 329 | doc_pair = [doc, can, idx] 330 | docs_pairs.append(doc_pair) 331 | can_tokens = tokenizer.tokenize(can) 332 | can_token_ids = tokenizer.convert_tokens_to_ids(can_tokens) 333 | print("can token ids: ", can_token_ids) 334 | can_token_num = len(can_token_ids) 335 | head_att = [] 336 | for i in range(attention.size(1)): 337 | att_matrix = attention[0][i][0].detach().numpy() 338 | token_ids = encoded_dict["input_ids"].squeeze() 339 | cans_pos = [] 340 | for idx in can_token_ids: 341 | cans_pos.append(torch.nonzero(token_ids == idx)) 342 | print("can positions: ", cans_pos) 343 | can_att_avg = 0 344 | min_app = 0 345 | positions = [] 346 | pos_set = set() 347 | for c in range(len(cans_pos)): 348 | for p in cans_pos[c]: 349 | positions.append(int(p[0].numpy())) 350 | 
pos_set.add(int(p[0].numpy())) 351 | 352 | positions = sorted(positions) 353 | pos_set = sorted([p for p in pos_set]) 354 | print("pos set: ", pos_set) 355 | fun = lambda x: x[1] - x[0] 356 | can_att = [] 357 | for k, g in groupby(enumerate(pos_set), fun): 358 | l1 = [j for i, j in g] # 连续数字的列表 359 | print("l1: ", l1) 360 | print("can token num: ", can_token_num) 361 | if len(l1) >= can_token_num: 362 | for l in l1: 363 | can_att_avg +=att_matrix[l] 364 | can_att.append(can_att_avg/len(l1)) 365 | if len(can_att) == 0: 366 | print("can tokens: ", can_token_ids) 367 | print("tokens: ", token_ids) 368 | # head_att.append(sum(can_att)/len(can_att)) #不同head的attention 369 | if len(can_att) == 0: 370 | break 371 | else: 372 | head_att.append(sum(can_att)/len(can_att)) 373 | candidate_att.append(head_att) 374 | heads_num=np.arange(attention.size(1)) 375 | candidate_att = pd.DataFrame(candidate_att,index=set_can, columns=heads_num) 376 | candidate_att["max"] = candidate_att.max(axis=1) #取出该最大值 377 | candidate_att["mean"] = candidate_att.mean(axis=1) 378 | 379 | # 380 | # for i in heads_num: 381 | results = candidate_att.sort_values(by=[5],ascending=False) #降序排列 382 | top_k = results.index.values.tolist() 383 | doc_labels = [ref.replace(" \n", "") for ref in referneces[key]] 384 | log.logger.info("Doc {} Head {} \nTop15 Candidates: {} \nReference: {}".format(idx, 5, top_k[:15], doc_labels)) 385 | 386 | j = 0 387 | matched_candidate = [] 388 | for temp in top_k[0:15]: 389 | tokens = temp.lower().split() 390 | tt = ' '.join(porter.stem(t) for t in tokens) 391 | if (tt in labels_s or temp in labels_o): 392 | if (j < 5): 393 | num_c_5 += 1 394 | num_c_10 += 1 395 | num_c_15 += 1 396 | 397 | elif (j < 10 and j >= 5): 398 | num_c_10 += 1 399 | num_c_15 += 1 400 | 401 | elif (j < 15 and j >= 10): 402 | num_c_15 += 1 403 | matched_candidate.append(temp) 404 | j += 1 405 | 406 | if (len(top_k[0:5]) == 5): 407 | num_e_5 += 5 408 | else: 409 | num_e_5 += len(top_k[0:5]) 410 | 411 | if 
(len(top_k[0:10]) == 10): 412 | num_e_10 += 10 413 | else: 414 | num_e_10 += len(top_k[0:10]) 415 | 416 | if (len(top_k[0:15]) == 15): 417 | num_e_15 += 15 418 | else: 419 | num_e_15 += len(top_k[0:15]) 420 | 421 | num_s += len(labels_s) 422 | log.logger.info("Matched Candidate: {} \n".format(matched_candidate)) 423 | 424 | # Micro-averaged P/R/F1 at 5, 10 and 15: the counts are summed over all documents before dividing. 425 | p, r, f = get_PRF(num_c_5, num_e_5, num_s) 426 | print_PRF(p, r, f, 5) 427 | p, r, f = get_PRF(num_c_10, num_e_10, num_s) 428 | print_PRF(p, r, f, 10) 429 | p, r, f = get_PRF(num_c_15, num_e_15, num_s) 430 | print_PRF(p, r, f, 15) 431 | 432 | 433 | en_model.close() 434 | -------------------------------------------------------------------------------- /utils/cos_mask_doc.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | from torch.utils.data import Dataset 4 | from tqdm import tqdm 5 | from transformers import BertForMaskedLM, BertTokenizer 6 | from torch.utils.data import DataLoader 7 | import pandas as pd 8 | from pke.unsupervised import TextRank 9 | import numpy as np 10 | import logging 11 | import argparse 12 | from accelerate import Accelerator 13 | import codecs 14 | import json 15 | import os 16 | import string 17 | import nltk 18 | import time 19 | # nltk.download('averaged_perceptron_tagger') 20 | from nltk.stem import PorterStemmer 21 | import itertools 22 | 23 | MAX_LEN =512 24 | 25 | class Logger(object): 26 | level_relations = { 27 | 'debug': logging.DEBUG, 28 | 'info': logging.INFO, 29 | 'warning': logging.WARNING, 30 | 'error': logging.ERROR, 31 | 'crit': logging.CRITICAL 32 | } # map level names to logging levels 33 | 34 | def __init__(self, filename, level='info'): 35 | # Attach a screen handler and a file handler (opened in mode 'w') to the logger named after filename. 36 | self.logger = logging.getLogger(filename) 37 | # # format_str = logging.Formatter(fmt) # set the log format 38 | # if args.local_rank == 0 : 39 | # level = level 40 | # else: 41 | # level = 'warning' 42 | self.logger.setLevel(self.level_relations.get(level)) # set the log level 43 | sh = logging.StreamHandler() # write to the screen 44 | # 
sh.setFormatter(format_str) # set the on-screen format 45 | 46 | th = logging.FileHandler(filename,'w') 47 | formatter = logging.Formatter('%(asctime)s => %(name)s * %(levelname)s : %(message)s') 48 | th.setFormatter(formatter) 49 | # NOTE(review): logging.getLogger(filename) returns the same logger for the same name, so a second Logger with the same filename adds duplicate handlers. 50 | self.logger.addHandler(sh) # log to the screen; comment out to disable screen output 51 | self.logger.addHandler(th) # log to the file; comment out to stop writing to the file 52 | 53 | 54 | class KPE_Dataset(Dataset): 55 | def __init__(self, docs_pairs): 56 | # docs_pairs: list of [ori_doc, masked_doc, candidate, doc_id] rows (see __getitem__). 57 | # print("generated candidates: ", doc_candidate_list) 58 | # total_pairs = [] 59 | # for doc_id in docs_pairs.keys(): 60 | # doc_pairs = docs_pairs[doc_id] 61 | # for pair in doc_pairs: 62 | # pair.append(doc_id) 63 | # total_pairs.append(pair) 64 | self.docs_pairs = docs_pairs 65 | self.total_examples = len(self.docs_pairs) 66 | self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True) 67 | def __len__(self): 68 | return self.total_examples 69 | 70 | def __getitem__(self, idx): 71 | # Tokenize both texts of one pair; returns [ori_dict, masked_dict, doc_id]. 72 | doc_pair = self.docs_pairs[idx] 73 | ori_doc = doc_pair[0] 74 | masked_doc = doc_pair[1] 75 | candidate = doc_pair[2] 76 | doc_id = doc_pair[3] 77 | 78 | tokenized_ori_doc = self.tokenized_doc(ori_doc, self.tokenizer, candidate, mode='ori') 79 | tokenized_masked_doc = self.tokenized_doc(masked_doc, self.tokenizer, candidate, mode='mask') 80 | 81 | return [tokenized_ori_doc, tokenized_masked_doc, doc_id] 82 | 83 | def tokenized_doc(self, doc, tokenizer, candidate, mode): 84 | # Encode doc padded/truncated to MAX_LEN tokens and return the encoding plus the candidate in a dict. 85 | max_len = MAX_LEN 86 | 87 | encoded_dict = tokenizer.encode_plus( 88 | doc, # Sentence to encode. 89 | add_special_tokens=True, # Add '[CLS]' and '[SEP]' 90 | max_length=max_len, # Pad & truncate all sentences. 91 | padding='max_length', 92 | return_attention_mask=True, # Construct attn. masks. 93 | return_tensors='pt', # Return pytorch tensors. 
94 | truncation=True 95 | ) 96 | input_ids = encoded_dict["input_ids"] 97 | attention_mask = encoded_dict["attention_mask"] 98 | token_type_ids = encoded_dict["token_type_ids"] 99 | 100 | # 保存在字典里 101 | if mode == 'mask': 102 | example = { 103 | "input_ids": input_ids, 104 | "token_type_ids": token_type_ids, 105 | "attention_mask": attention_mask, 106 | "candidate": candidate 107 | } 108 | 109 | else: 110 | example = { 111 | "input_ids": input_ids, 112 | "token_type_ids": token_type_ids, 113 | "attention_mask": attention_mask, 114 | "candidate": candidate 115 | } 116 | 117 | return example 118 | 119 | def load_dataset(file_path): 120 | """ Load file.jsonl .""" 121 | data_list = [] 122 | with codecs.open(file_path, 'r', 'utf-8') as f: 123 | json_text = f.readlines() 124 | for i, line in tqdm(enumerate(json_text), desc="Loading Doc ..."): 125 | try: 126 | jsonl = json.loads(line) 127 | data_list.append(jsonl) 128 | except: 129 | raise ValueError 130 | 131 | return data_list 132 | 133 | def generate_doc(dataset_dir, dataset_name): 134 | 135 | doc_list = [] 136 | keyphrases = [] 137 | doc_tok_num = 0 138 | dataset = load_dataset(dataset_dir) 139 | for idx, example in enumerate(dataset): 140 | keywords = example['keywords'].lower() 141 | abstract = example['abstract'] 142 | # if dataset_name =="semeval" or dataset_name =="nus": 143 | # fulltxt = example['fulltext'] 144 | # doc = ' '.join([abstract,fulltxt]) 145 | # else: 146 | # doc = abstract 147 | doc = abstract 148 | doc = re.sub('\. ', ' . 
', doc) 149 | doc = re.sub(', ', ' , ', doc) 150 | doc_tok = doc.split(' ') 151 | if len(doc_tok) > 510: 152 | doc_tok = doc_tok [: 510] 153 | doc_tok_num +=len(doc_tok) 154 | doc_list.append(' '.join(doc_tok)) 155 | keyphrases.append(keywords) 156 | return doc_list, keyphrases, doc_tok_num/len(dataset) 157 | 158 | 159 | def extract_candidate_words(text, good_tags=set(['JJ','JJR','JJS','NN','NNP','NNS','NNPS'])): 160 | 161 | punct = set(string.punctuation) 162 | 163 | stop_words = set(nltk.corpus.stopwords.words('english')) 164 | tagged_words = itertools.chain.from_iterable(nltk.pos_tag_sents(nltk.word_tokenize(sent) for sent in nltk.sent_tokenize(text))) 165 | candidate_phrase = [] 166 | candidates = [] 167 | for word, tag in tagged_words: 168 | if tag in good_tags and word.lower() not in stop_words and not all(char in punct for char in word): 169 | candidate_phrase.append(word) 170 | continue 171 | else: 172 | if candidate_phrase: 173 | candidates.append(candidate_phrase) 174 | candidate_phrase = [] 175 | else: 176 | continue 177 | 178 | candiates_num = len(candidates) 179 | 180 | return candidates, candiates_num 181 | 182 | def dedup(candidates): 183 | new_can = {} 184 | for can in candidates: 185 | can_set = can.split() 186 | candidate_len = len(can_set) 187 | # can = ' '.join(can) 188 | new_can[can] = candidate_len 189 | 190 | return new_can 191 | 192 | def generate_absent_doc(doc, candidates, idx): 193 | 194 | doc_pairs = [] 195 | #每个文章的candidate, 可能有多个 196 | doc_candidate = dedup(candidates) 197 | for id, candidate in enumerate(doc_candidate.keys()): 198 | candidate_len = doc_candidate[candidate] 199 | mask = ' '.join(['[MASK]']*candidate_len) 200 | try: 201 | candidate_re = re.compile(r"\b" + candidate + r"\b") 202 | masked_doc = re.sub(candidate_re, mask, doc.lower()) 203 | except: 204 | continue 205 | 206 | doc_pairs.append([doc.lower(), masked_doc, candidate, idx]) 207 | # print("Candidate: ", candidate) 208 | # print("Masked Doc {} : {}".format(idx, 
masked_doc)) 209 | # print("Ori_doc {}: {}".format(idx, doc.lower())) 210 | 211 | return doc_pairs 212 | 213 | def eval_metric(cans, refs): 214 | precision_scores, recall_scores, f1_scores = {5: [], 10: [], 15:[]},{5: [], 10: [], 15:[]},{5: [], 10: [], 15:[]} 215 | # Per-document P/R/F1 at 5/10/15 for ranked candidates cans (list of [text] rows) against ';'-separated refs. NOTE(review): stemmer.stem is applied to each whole phrase, not per word - confirm this is intended. 216 | stemmer = PorterStemmer() 217 | references = refs.split(";") 218 | ref_num = len(references) 219 | 220 | for i, reference in enumerate(references): 221 | reference = stemmer.stem(reference.lower()) 222 | references[i] = reference.lower() 223 | candidates_clean = set() 224 | candidates = [] 225 | for i, can in enumerate(cans): 226 | can = stemmer.stem(can[0].lower()) 227 | if can in candidates_clean: 228 | continue 229 | else: 230 | candidates_clean.add(can) 231 | candidates.append(can) 232 | 233 | # Precision divides by topk even when fewer than topk candidates exist. 234 | for topk in [5, 10, 15]: 235 | m_can = 0 236 | for i,candidate in enumerate(candidates[:topk]): 237 | if candidate in references: 238 | m_can += 1 239 | micropk = m_can / float(topk) 240 | micrork = m_can / float(ref_num) 241 | 242 | if micropk + micrork > 0: 243 | microf1 = float(2 * (micropk * micrork)) / (micropk + micrork) 244 | else: 245 | microf1 = 0.0 246 | 247 | precision_scores[topk].append(micropk) 248 | recall_scores[topk].append(micrork) 249 | f1_scores[topk].append(microf1) 250 | 251 | return f1_scores, precision_scores, recall_scores, candidates, references, ref_num 252 | # Masked mean over the sequence dimension of hidden_states[-2]; assumes attention_mask is (batch, seq) - TODO confirm. 253 | def mean_pooling(model_output, attention_mask): 254 | hidden_states = model_output.hidden_states 255 | token_embeddings = hidden_states[-2] # second-to-last hidden layer (not the final layer) 256 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 257 | return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) 258 | 259 | 260 | def keyphrases_selection(doc_list, references, model, dataloader, log, doc_avg_tok_num): 261 | # Score each (original, masked) pair by cosine similarity of pooled embeddings, rank candidates per document and log P/R/F1. 262 | model.eval() 263 | 264 | cos_similarity_list = {} 265 | candidate_list = [] 266 | 
cos_score_list = [] 267 | doc_id_list = [] 268 | 269 | for id, [ori_doc, masked_doc, doc_id] in enumerate(tqdm(dataloader,desc="Evaluating:")): 270 | 271 | ori_input_ids = torch.squeeze(ori_doc["input_ids"].to('cuda'),1) 272 | ori_token_type_ids = torch.squeeze(ori_doc["token_type_ids"].to('cuda'), 1) 273 | ori_attention_mask = torch.squeeze(ori_doc["attention_mask"].to('cuda'), 1) 274 | 275 | masked_input_ids = torch.squeeze(masked_doc["input_ids"].to('cuda'), 1) 276 | masked_token_type_ids = torch.squeeze(masked_doc["token_type_ids"].to('cuda'), 1) 277 | masked_attention_mask = torch.squeeze(masked_doc["attention_mask"].to('cuda'), 1) 278 | candidate = masked_doc["candidate"] 279 | # print("masked doc candiate: ", candidate) 280 | # log.logger.info("candiate: {}".format(candidate)) 281 | # Predict hidden states features for each layer 282 | with torch.no_grad(): 283 | # See the models docstrings for the detail of the inputs 284 | ori_outputs = model(input_ids=ori_input_ids, attention_mask=ori_attention_mask, token_type_ids=ori_token_type_ids, output_hidden_states=True) 285 | masked_outputs = model(input_ids=masked_input_ids, attention_mask=masked_attention_mask, token_type_ids=masked_token_type_ids, output_hidden_states=True) 286 | # Transformers models always output tuples. 
287 | # See the models docstrings for the detail of all the outputs 288 | # mean_pooling pools hidden_states[-2] (the second-to-last layer), not the last layer 289 | ori_doc_embed = mean_pooling(ori_outputs, ori_attention_mask) 290 | masked_doc_embed = mean_pooling(masked_outputs, masked_attention_mask) 291 | cosine_similarity = torch.cosine_similarity(ori_doc_embed, masked_doc_embed, dim=1).cpu() 292 | 293 | doc_id_list.extend(doc_id.numpy().tolist()) 294 | candidate_list.extend(candidate) 295 | cos_score_list.extend(cosine_similarity.numpy()) 296 | # One row per (doc_id, candidate, cos); within a document candidates are ranked by ascending cos, so the candidate whose masking moved the embedding most comes first. 297 | cos_similarity_list["doc_id"] = doc_id_list 298 | cos_similarity_list["candidate"] = candidate_list 299 | cos_similarity_list["cos"] = cos_score_list 300 | 301 | cosine_similarity_rank = pd.DataFrame(cos_similarity_list) 302 | total_f1_socres, total_precision_scores, total_recall_scores = np.zeros([len(doc_list),3]),\ 303 | np.zeros([len(doc_list),3]),\ 304 | np.zeros([len(doc_list),3]) 305 | doc_num = len(doc_list) 306 | ref_total_len = 0 307 | for i in range(len(doc_list)): 308 | doc_results = cosine_similarity_rank.loc[cosine_similarity_rank['doc_id']==i] 309 | ranked_keyphrases = doc_results.sort_values(by='cos') 310 | top_k = ranked_keyphrases.reset_index(drop = True) 311 | print(top_k) 312 | top_k = top_k.loc[:, ['candidate']].values.tolist() 313 | doc_references = references[i] 314 | 315 | f1_scores, precision_scores, recall_scores, candidates_clean, references_clean, ref_num = eval_metric(top_k, doc_references) 316 | ref_total_len +=ref_num 317 | for idx, key in enumerate([5,10,15]): 318 | total_f1_socres[i][idx] = f1_scores[key][0] 319 | total_precision_scores[i][idx] = precision_scores[key][0] 320 | total_recall_scores[i][idx] = recall_scores[key][0] 321 | 322 | log.logger.info("Doc {} results:\n {}".format(i, candidates_clean)) 323 | log.logger.info("Reference:\n {}".format(references_clean)) 324 | log.logger.info("###########################") 325 | log.logger.info("F1: {} ".format(f1_scores)) 326
| log.logger.info("P: {} ".format(precision_scores)) 327 | log.logger.info("R: {} ".format(recall_scores)) 328 | log.logger.info("###########################\n") 329 | 330 | # Macro average: mean of the per-document scores at each cutoff. 331 | log.logger.info("############ Total Result ############") 332 | for i , key in enumerate([5,10,15]): 333 | log.logger.info("ref_avg_len: {}".format(ref_total_len/doc_num)) 334 | log.logger.info("doc_avg_len: {}".format(doc_avg_tok_num)) 335 | log.logger.info("@{}".format(key)) 336 | log.logger.info("F1:{}".format(np.mean(total_f1_socres[:,i], axis=0))) 337 | log.logger.info("P:{}".format(np.mean(total_precision_scores[:,i], axis=0))) 338 | log.logger.info("R:{}".format(np.mean(total_recall_scores[:,i], axis=0))) 339 | log.logger.info("#########################\n") 340 | 341 | 342 | 343 | if __name__ == '__main__': 344 | 345 | parser = argparse.ArgumentParser() 346 | parser.add_argument("--dataset_dir", 347 | default=None, 348 | type=str, 349 | required=True, 350 | help="The input dataset.") 351 | parser.add_argument("--dataset_name", 352 | default=None, 353 | type=str, 354 | required=True, 355 | help="The input dataset name.") 356 | parser.add_argument("--batch_size", 357 | default=None, 358 | type=int, 359 | required=True, 360 | help="Batch size for testing.") 361 | parser.add_argument("--checkpoints", 362 | default=None, 363 | type=str, 364 | required=False, 365 | help="Checkpoint for pre-trained Bert model") 366 | parser.add_argument("--log_dir", 367 | default=None, 368 | type=str, 369 | required=True, 370 | help="Path for Logging file") 371 | parser.add_argument("--local_rank", 372 | default=-1, 373 | type=int, 374 | help="local_rank for distributed training on gpus") 375 | parser.add_argument("--no_cuda", 376 | action="store_true", 377 | help="Whether not to use CUDA when available") 378 | args = parser.parse_args() 379 | 380 | # log_dir is string-concatenated with the file name, so it should end with a path separator. 381 | log = Logger(args.log_dir + args.dataset_name + '.kpe.log') 382 | if args.local_rank == -1 or args.no_cuda: 383 | device = torch.device("cuda" if 
torch.cuda.is_available() and not args.no_cuda else "cpu") 384 | n_gpu = torch.cuda.device_count() 385 | 386 | doc_list, references, doc_avg_tok_num = generate_doc(args.dataset_dir, args.dataset_name) 387 | 388 | docs_pairs = [] 389 | for idx, doc in tqdm(enumerate(doc_list),desc="generating pairs..."): 390 | # candidates, candidates_num = extract_candidate_words(doc) 391 | extractor = TextRank() 392 | extractor.load_document(input=doc, 393 | language="en", 394 | normalization=None) 395 | extractor.candidate_selection(pos={'NOUN', 'PROPN', 'ADJ'}) 396 | candidates = list(extractor.candidates.keys()) 397 | 398 | candidates_num = len(candidates) 399 | doc_pairs = generate_absent_doc(doc, candidates, idx) 400 | docs_pairs.extend(doc_pairs) 401 | 402 | dataset = KPE_Dataset(docs_pairs) 403 | dataloader = DataLoader(dataset, batch_size=args.batch_size) 404 | 405 | model = BertForMaskedLM.from_pretrained('bert-base-uncased') 406 | if os.path.exists(args.checkpoints): 407 | if args.local_rank == 0: 408 | log.logger.info("Loading Checkpoint ...") 409 | accelerator = Accelerator() 410 | unwrapped_model = accelerator.unwrap_model(model) 411 | unwrapped_model.load_state_dict(torch.load(args.checkpoints)) 412 | 413 | log.logger.info("Start Testing ...") 414 | model.to(device) 415 | keyphrases_selection(doc_list, references, model, dataloader, log, doc_avg_tok_num) 416 | 417 | 418 | 419 | 420 | -------------------------------------------------------------------------------- /MDERank/mderank_main.py: -------------------------------------------------------------------------------- 1 | import re 2 | import time 3 | import torch 4 | from torch.utils.data import Dataset 5 | from tqdm import tqdm 6 | from transformers import BertForMaskedLM, BertTokenizer, RobertaTokenizer, RobertaForMaskedLM 7 | from torch.utils.data import DataLoader 8 | import pandas as pd 9 | import numpy as np 10 | import logging 11 | import argparse 12 | import codecs 13 | import random 14 | import json 15 | 
import os 16 | import string 17 | import nltk 18 | from stanfordcorenlp import StanfordCoreNLP 19 | from accelerate import Accelerator 20 | 21 | from nltk.corpus import stopwords 22 | from itertools import groupby 23 | # nltk.download('averaged_perceptron_tagger') 24 | from nltk.stem import PorterStemmer 25 | import itertools 26 | 27 | MAX_LEN =512 28 | en_model = StanfordCoreNLP(r'stanford-corenlp-full-2018-02-27',quiet=True) 29 | 30 | 31 | 32 | stopword_dict = set(stopwords.words('english')) 33 | 34 | GRAMMAR1 = """ NP: 35 | {*} # Adjective(s)(optional) + Noun(s)""" 36 | 37 | GRAMMAR2 = """ NP: 38 | {*{0,3}} # Adjective(s)(optional) + Noun(s)""" 39 | 40 | GRAMMAR3 = """ NP: 41 | {*} # Adjective(s)(optional) + Noun(s)""" 42 | 43 | 44 | def extract_candidates(tokens_tagged, no_subset=False): 45 | """ 46 | Based on part of speech return a list of candidate phrases 47 | :param text_obj: Input text Representation see @InputTextObj 48 | :param no_subset: if true won't put a candidate which is the subset of an other candidate 49 | :return keyphrase_candidate: list of list of candidate phrases: [tuple(string,tuple(start_index,end_index))] 50 | """ 51 | np_parser = nltk.RegexpParser(GRAMMAR1) # Noun phrase parser 52 | keyphrase_candidate = [] 53 | np_pos_tag_tokens = np_parser.parse(tokens_tagged) 54 | count = 0 55 | for token in np_pos_tag_tokens: 56 | if (isinstance(token, nltk.tree.Tree) and token._label == "NP"): 57 | np = ' '.join(word for word, tag in token.leaves()) 58 | length = len(token.leaves()) 59 | start_end = (count, count + length) 60 | count += length 61 | keyphrase_candidate.append((np, start_end)) 62 | 63 | else: 64 | count += 1 65 | 66 | return keyphrase_candidate 67 | 68 | class InputTextObj: 69 | """Represent the input text in which we want to extract keyphrases""" 70 | 71 | def __init__(self, en_model, text=""): 72 | """ 73 | :param is_sectioned: If we want to section the text. 
74 | :param en_model: the pipeline of tokenization and POS-tagger 75 | :param considered_tags: The POSs we want to keep 76 | """ 77 | self.considered_tags = {'NN', 'NNS', 'NNP', 'NNPS', 'JJ'} 78 | 79 | self.tokens = [] 80 | self.tokens_tagged = [] 81 | self.tokens = en_model.word_tokenize(text) 82 | self.tokens_tagged = en_model.pos_tag(text) 83 | assert len(self.tokens) == len(self.tokens_tagged) 84 | for i, token in enumerate(self.tokens): 85 | if token.lower() in stopword_dict: 86 | self.tokens_tagged[i] = (token, "IN") 87 | self.keyphrase_candidate = extract_candidates(self.tokens_tagged, en_model) 88 | 89 | 90 | 91 | class Logger(object): 92 | level_relations = { 93 | 'debug': logging.DEBUG, 94 | 'info': logging.INFO, 95 | 'warning': logging.WARNING, 96 | 'error': logging.ERROR, 97 | 'crit': logging.CRITICAL 98 | } # 日志级别关系映射 99 | 100 | def __init__(self, filename, level='info'): 101 | 102 | self.logger = logging.getLogger(filename) 103 | # # format_str = logging.Formatter(fmt) # 设置日志格式 104 | # if args.local_rank == 0 : 105 | # level = level 106 | # else: 107 | # level = 'warning' 108 | self.logger.setLevel(self.level_relations.get(level)) # 设置日志级别 109 | sh = logging.StreamHandler() # 往屏幕上输出 110 | # sh.setFormatter(format_str) # 设置屏幕上显示的格式 111 | 112 | th = logging.FileHandler(filename,'w') 113 | # formatter = logging.Formatter('%(asctime)s => %(name)s * %(levelname)s : %(message)s') 114 | # th.setFormatter(formatter) 115 | 116 | self.logger.addHandler(sh) # 代表在屏幕上输出,如果注释掉,屏幕将不输出 117 | self.logger.addHandler(th) # 代表在log文件中输出,如果注释掉,将不再向文件中写入数据 118 | 119 | 120 | class KPE_Dataset(Dataset): 121 | def __init__(self, docs_pairs): 122 | 123 | self.docs_pairs = docs_pairs 124 | self.total_examples = len(self.docs_pairs) 125 | 126 | def __len__(self): 127 | return self.total_examples 128 | 129 | def __getitem__(self, idx): 130 | 131 | doc_pair = self.docs_pairs[idx] 132 | ori_example = doc_pair[0] 133 | masked_example = doc_pair[1] 134 | doc_id = doc_pair[2] 135 | 136 
| return [ori_example, masked_example, doc_id] 137 | 138 | 139 | 140 | """pre process code from SIFRank""" 141 | class Result: 142 | 143 | def __init__(self,N=15): 144 | self.database="" 145 | self.predict_keyphrases = [] 146 | self.true_keyphrases = [] 147 | self.file_names = [] 148 | self.lamda=0.0 149 | self.beta=0.0 150 | 151 | def update_result(self, file_name, pre_kp, true_kp): 152 | self.file_names.append(file_name) 153 | self.predict_keyphrases.append(pre_kp) 154 | self.true_keyphrases.append(true_kp) 155 | 156 | def get_parameters(self,database="",lamda=0.6,beta=0.0): 157 | self.database = database 158 | self.lamda = lamda 159 | self.beta = beta 160 | 161 | def write_results(self): 162 | return 0 163 | 164 | def write_string(s, output_path): 165 | with open(output_path, 'w') as output_file: 166 | output_file.write(s) 167 | 168 | 169 | def read_file(input_path): 170 | with open(input_path, 'r', errors='replace_with_space') as input_file: 171 | return input_file.read() 172 | 173 | def clean_text(text="",database="Inspec"): 174 | 175 | #Specially for Duc2001 Database 176 | if(database=="Duc2001" or database=="Semeval2017"): 177 | pattern2 = re.compile(r'[\s,]' + '[\n]{1}') 178 | while (True): 179 | if (pattern2.search(text) is not None): 180 | position = pattern2.search(text) 181 | start = position.start() 182 | end = position.end() 183 | # start = int(position[0]) 184 | text_new = text[:start] + "\n" + text[start + 2:] 185 | text = text_new 186 | else: 187 | break 188 | 189 | pattern2 = re.compile(r'[a-zA-Z0-9,\s]' + '[\n]{1}') 190 | while (True): 191 | if (pattern2.search(text) is not None): 192 | position = pattern2.search(text) 193 | start = position.start() 194 | end = position.end() 195 | # start = int(position[0]) 196 | text_new = text[:start + 1] + " " + text[start + 2:] 197 | text = text_new 198 | else: 199 | break 200 | 201 | pattern3 = re.compile(r'\s{2,}') 202 | while (True): 203 | if (pattern3.search(text) is not None): 204 | position = 
pattern3.search(text) 205 | start = position.start() 206 | end = position.end() 207 | # start = int(position[0]) 208 | text_new = text[:start + 1] + "" + text[start + 2:] 209 | text = text_new 210 | else: 211 | break 212 | 213 | pattern1 = re.compile(r'[<>[\]{}]') 214 | text = pattern1.sub(' ', text) 215 | text = text.replace("\t", " ") 216 | text = text.replace(' p ','\n') 217 | text = text.replace(' /p \n','\n') 218 | lines = text.splitlines() 219 | # delete blank line 220 | text_new="" 221 | for line in lines: 222 | if(line!='\n'): 223 | text_new+=line+'\n' 224 | 225 | return text_new 226 | 227 | def get_long_data(file_path="data/nus/nus_test.json"): 228 | """ Load file.jsonl .""" 229 | data = {} 230 | labels = {} 231 | with codecs.open(file_path, 'r', 'utf-8') as f: 232 | json_text = f.readlines() 233 | for i, line in tqdm(enumerate(json_text), desc="Loading Doc ..."): 234 | try: 235 | jsonl = json.loads(line) 236 | keywords = jsonl['keywords'].lower().split(";") 237 | abstract = jsonl['abstract'] 238 | fulltxt = jsonl['fulltext'] 239 | doc = ' '.join([abstract, fulltxt]) 240 | doc = re.sub('\. ', ' . ', doc) 241 | doc = re.sub(', ', ' , ', doc) 242 | 243 | doc = clean_text(doc, database="nus") 244 | doc = doc.replace('\n', ' ') 245 | data[jsonl['name']] = doc 246 | labels[jsonl['name']] = keywords 247 | except: 248 | raise ValueError 249 | return data,labels 250 | 251 | def get_short_data(file_path="data/kp20k/kp20k_valid2k_test.json"): 252 | """ Load file.jsonl .""" 253 | data = {} 254 | labels = {} 255 | with codecs.open(file_path, 'r', 'utf-8') as f: 256 | json_text = f.readlines() 257 | for i, line in tqdm(enumerate(json_text), desc="Loading Doc ..."): 258 | try: 259 | jsonl = json.loads(line) 260 | keywords = jsonl['keywords'].lower().split(";") 261 | abstract = jsonl['abstract'] 262 | doc =abstract 263 | doc = re.sub('\. ', ' . 
', doc) 264 | doc = re.sub(', ', ' , ', doc) 265 | 266 | doc = clean_text(doc, database="kp20k") 267 | doc = doc.replace('\n', ' ') 268 | data[i] = doc 269 | labels[i] = keywords 270 | except: 271 | raise ValueError 272 | return data,labels 273 | 274 | 275 | def get_duc2001_data(file_path="data/DUC2001"): 276 | pattern = re.compile(r'(.*?)', re.S) 277 | data = {} 278 | labels = {} 279 | for dirname, dirnames, filenames in os.walk(file_path): 280 | for fname in filenames: 281 | if (fname == "annotations.txt"): 282 | # left, right = fname.split('.') 283 | infile = os.path.join(dirname, fname) 284 | f = open(infile,'rb') 285 | text = f.read().decode('utf8') 286 | lines = text.splitlines() 287 | for line in lines: 288 | left, right = line.split("@") 289 | d = right.split(";")[:-1] 290 | l = left 291 | labels[l] = d 292 | f.close() 293 | else: 294 | infile = os.path.join(dirname, fname) 295 | f = open(infile,'rb') 296 | text = f.read().decode('utf8') 297 | text = re.findall(pattern, text)[0] 298 | 299 | text = text.lower() 300 | text = clean_text(text,database="Duc2001") 301 | data[fname]=text.strip("\n") 302 | # data[fname] = text 303 | return data,labels 304 | 305 | def get_inspec_data(file_path="data/Inspec"): 306 | 307 | data={} 308 | labels={} 309 | for dirname, dirnames, filenames in os.walk(file_path): 310 | for fname in filenames: 311 | left, right = fname.split('.') 312 | if (right == "abstr"): 313 | infile = os.path.join(dirname, fname) 314 | f=open(infile) 315 | text=f.read() 316 | text = text.replace("%", '') 317 | text=clean_text(text) 318 | data[left]=text 319 | if (right == "uncontr"): 320 | infile = os.path.join(dirname, fname) 321 | f=open(infile) 322 | text=f.read() 323 | text=text.replace("\n",' ') 324 | text=clean_text(text,database="Inspec") 325 | text=text.lower() 326 | label=text.split("; ") 327 | labels[left]=label 328 | return data,labels 329 | 330 | def 
get_semeval2017_data(data_path="data/SemEval2017/docsutf8",labels_path="data/SemEval2017/keys"): 331 | # Read SemEval2017 documents (lowercased, keyed by file stem) and one keyphrase per line from the matching key files. 332 | data={} 333 | labels={} 334 | for dirname, dirnames, filenames in os.walk(data_path): 335 | for fname in filenames: 336 | left, right = fname.split('.') 337 | infile = os.path.join(dirname, fname) 338 | # f = open(infile, 'rb') 339 | # text = f.read().decode('utf8') 340 | with codecs.open(infile, "r", "utf-8") as fi: 341 | text = fi.read() 342 | text = text.replace("%", '') 343 | text = clean_text(text,database="Semeval2017") 344 | data[left] = text.lower() 345 | # f.close() 346 | for dirname, dirnames, filenames in os.walk(labels_path): 347 | for fname in filenames: 348 | left, right = fname.split('.') 349 | infile = os.path.join(dirname, fname) 350 | f = open(infile, 'rb') 351 | text = f.read().decode('utf8') 352 | text = text.strip() 353 | ls=text.splitlines() 354 | labels[left] = ls 355 | f.close() 356 | return data,labels 357 | 358 | # True if stripping the characters in remove_chars changes the whitespace-token count, i.e. text contains a token made only of those characters. 359 | def remove (text): 360 | text_len = len(text.split()) 361 | remove_chars = '[’!"#$%&\'()*+,./:;<=>?@,。?★、…【】《》?“”‘’![\\]^_`{|}~]+' 362 | text = re.sub(remove_chars, '', text) 363 | re_text_len = len(text.split()) 364 | if text_len != re_text_len: 365 | return True 366 | else: 367 | return False 368 | 369 | # Keep one original spelling (the last seen) per Porter-stemmed form, mapped lowercased to its tokenizer-token count. NOTE(review): keeps only candidates for which remove() is True, the opposite of generate_absent_doc's filter - confirm intent. Relies on module-level `porter` and `tokenizer`, presumably defined elsewhere in this module. 370 | def dedup_stem(candidates): 371 | new_can = {} 372 | can_dedup_stemmed = {} 373 | for can in candidates: 374 | can_dedup_stemmed[' '.join(porter.stem(t) for t in can.split())] = can 375 | 376 | for stemmed_can, can in can_dedup_stemmed.items(): 377 | re_flag = remove(can) 378 | if re_flag: 379 | candidate_tokens = tokenizer.tokenize(can) 380 | candidate_len = len(candidate_tokens) 381 | new_can[can.lower()] = candidate_len 382 | return new_can 383 | 384 | 385 | # Mask each candidate in the tokenized original document; returns ([ori_encode_dict, masked_encode_dict, idx] rows, number of skipped candidates). 386 | def generate_absent_doc(ori_encode_dict, candidates, idx): 387 | 388 | count = 0 389 | doc_pairs = [] 390 | ori_input_ids = ori_encode_dict["input_ids"].squeeze() 391 | ori_tokens = tokenizer.convert_ids_to_tokens(ori_input_ids) 392 | 393 | # 
There are multi candidates for a document 394 | for id, candidate in enumerate(candidates): 395 | 396 | # Remove stopwords in a candidate 397 | if remove(candidate): 398 | count +=1 399 | continue 400 | 401 | tok_candidate = tokenizer.tokenize(candidate) 402 | candidate_len = len(tok_candidate) 403 | mask = ' '.join(['[MASK]'] * candidate_len) 404 | ori_doc = ' '.join(ori_tokens) 405 | can_token = ' '.join(tok_candidate) 406 | 407 | try: 408 | candidate_re = re.compile(r"\b" + can_token + r"\b") 409 | masked_doc = re.sub(candidate_re, mask, ori_doc) 410 | match = candidate_re.findall(ori_doc) 411 | except: 412 | print("cannot replace: ", candidate) 413 | count +=1 414 | continue 415 | if len(match) == 0: 416 | count +=1 417 | print("candidate:", can_token) 418 | print("ori_docs: ", ori_tokens) 419 | print("do not find: ", candidate) 420 | continue 421 | 422 | masked_tokens = masked_doc.split() 423 | masked_input_ids = tokenizer.convert_tokens_to_ids(masked_tokens) 424 | len_masked_tokens = len(masked_tokens) - masked_tokens.count('[PAD]') 425 | 426 | try: 427 | assert len(masked_input_ids) == 512 428 | except: 429 | count +=1 430 | print("unmcatched: ", candidate) 431 | continue 432 | 433 | masked_attention_mask = np.zeros(MAX_LEN) 434 | masked_attention_mask[:len_masked_tokens] = 1 435 | masked_token_type_ids = np.zeros(MAX_LEN) 436 | masked_encode_dict = { 437 | "input_ids": torch.Tensor(masked_input_ids).to(torch.long), 438 | "token_type_ids": torch.Tensor(masked_token_type_ids).to(torch.long), 439 | "attention_mask": torch.Tensor(masked_attention_mask).to(torch.long), 440 | "candidate": candidate, 441 | "freq": len(match) 442 | } 443 | doc_pairs.append([ori_encode_dict, masked_encode_dict, idx]) 444 | 445 | return doc_pairs, count 446 | 447 | 448 | def get_PRF(num_c, num_e, num_s): 449 | F1 = 0.0 450 | P = float(num_c) / float(num_e) if num_e!=0 else 0.0 451 | R = float(num_c) / float(num_s) if num_s!=0 else 0.0 452 | if (P + R == 0.0): 453 | F1 = 0 454 | else: 
455 | F1 = 2 * P * R / (P + R) 456 | return P, R, F1 457 | 458 | 459 | def print_PRF(P, R, F1, N): 460 | 461 | log.logger.info("\nN=" + str(N)) 462 | log.logger.info("P=" + str(P)) 463 | log.logger.info("R=" + str(R)) 464 | log.logger.info("F1=" + str(F1)) 465 | return 0 466 | 467 | def mean_pooling(model_output, attention_mask): 468 | hidden_states = model_output.hidden_states 469 | token_embeddings = hidden_states[args.layer_num] #First element of model_output contains all token embeddings 470 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 471 | return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) 472 | 473 | def max_pooling(model_output, attention_mask): 474 | hidden_states = model_output.hidden_states 475 | token_embeddings = hidden_states[args.layer_num] # First element of model_output contains all token embeddings 476 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 477 | token_embeddings[input_mask_expanded == 0] = -1e9 # Set padding tokens to large negative value 478 | return torch.max(token_embeddings, 1)[0] 479 | 480 | def cls_emebddings(model_output): 481 | hidden_states = model_output.hidden_states 482 | token_embeddings = hidden_states[args.layer_num] #First element of model_output contains all token embeddings 483 | doc_embeddings = token_embeddings[:,0,:] 484 | return doc_embeddings 485 | 486 | 487 | def keyphrases_selection(doc_list, labels_stemed, labels, model, dataloader, log): 488 | 489 | model.eval() 490 | 491 | cos_similarity_list = {} 492 | candidate_list = [] 493 | cos_score_list = [] 494 | doc_id_list = [] 495 | 496 | P = R = F1 = 0.0 497 | num_c_5 = num_c_10 = num_c_15 = 0 498 | num_e_5 = num_e_10 = num_e_15 = 0 499 | num_s = 0 500 | lamda = 0.0 501 | 502 | for id, [ori_doc, masked_doc, doc_id] in enumerate(tqdm(dataloader,desc="Evaluating:")): 503 | 504 | ori_input_ids = 
torch.squeeze(ori_doc["input_ids"].to('cuda'),1) 505 | ori_token_type_ids = torch.squeeze(ori_doc["token_type_ids"].to('cuda'), 1) 506 | ori_attention_mask = torch.squeeze(ori_doc["attention_mask"].to('cuda'), 1) 507 | 508 | masked_input_ids = torch.squeeze(masked_doc["input_ids"].to('cuda'), 1) 509 | masked_token_type_ids = torch.squeeze(masked_doc["token_type_ids"].to('cuda'), 1) 510 | masked_attention_mask = torch.squeeze(masked_doc["attention_mask"].to('cuda'), 1) 511 | candidate = masked_doc["candidate"] 512 | 513 | # Predict hidden states features for each layer 514 | with torch.no_grad(): 515 | # See the models docstrings for the detail of the inputs 516 | ori_outputs = model(input_ids=ori_input_ids, attention_mask=ori_attention_mask, token_type_ids=ori_token_type_ids, output_hidden_states=True) 517 | masked_outputs = model(input_ids=masked_input_ids, attention_mask=masked_attention_mask, token_type_ids=masked_token_type_ids, output_hidden_states=True) 518 | # Transformers models always output tuples. 
519 | # See the models docstrings for the detail of all the outputs 520 | # In our case, the first element is the hidden state of the last layer of the Bert model 521 | if args.doc_embed_mode == "mean": 522 | ori_doc_embed =mean_pooling(ori_outputs, ori_attention_mask) 523 | masked_doc_embed = mean_pooling(masked_outputs, masked_attention_mask) 524 | elif args.doc_embed_mode == "cls": 525 | ori_doc_embed = cls_emebddings(ori_outputs) 526 | masked_doc_embed = cls_emebddings(masked_outputs) 527 | elif args.doc_embed_mode == "max": 528 | ori_doc_embed = max_pooling(ori_outputs, ori_attention_mask) 529 | masked_doc_embed = max_pooling(masked_outputs, masked_attention_mask) 530 | 531 | cosine_similarity = torch.cosine_similarity(ori_doc_embed, masked_doc_embed, dim=1).cpu() 532 | score = cosine_similarity 533 | doc_id_list.extend(doc_id.numpy().tolist()) 534 | candidate_list.extend(candidate) 535 | cos_score_list.extend(score.numpy()) 536 | 537 | cos_similarity_list["doc_id"] = doc_id_list 538 | cos_similarity_list["candidate"] = candidate_list 539 | cos_similarity_list["score"] = cos_score_list 540 | cosine_similarity_rank = pd.DataFrame(cos_similarity_list) 541 | 542 | for i in range(len(doc_list)): 543 | doc_results = cosine_similarity_rank.loc[cosine_similarity_rank['doc_id']==i] 544 | ranked_keyphrases = doc_results.sort_values(by='score') 545 | top_k = ranked_keyphrases.reset_index(drop = True) 546 | top_k_can = top_k.loc[:, ['candidate']].values.tolist() 547 | print(top_k) 548 | 549 | candidates_set = set() 550 | candidates_dedup = [] 551 | for temp in top_k_can: 552 | temp = temp[0].lower() 553 | if temp in candidates_set: 554 | continue 555 | else: 556 | candidates_set.add(temp) 557 | candidates_dedup.append(temp) 558 | 559 | #piror 长度2-5且有'-' 560 | # re_cleaned_candidates= [] 561 | # short_cans = [] 562 | # for can in candidates_dedup: 563 | # can_words = can.split() 564 | # if len(can_words) >1: 565 | # re_cleaned_candidates.append(can) 566 | # elif 
len(can_words) == 1: 567 | # punc = can.split('-') 568 | # if len(punc) >1: 569 | # re_cleaned_candidates.append(can) 570 | # else: 571 | # short_cans.append(can) 572 | # else: 573 | # continue 574 | # re_cleaned_candidates = re_cleaned_candidates + short_cans 575 | log.logger.info("Sorted_Candidate: {} \n".format(top_k_can)) 576 | log.logger.info("Candidates_Dedup: {} \n".format(candidates_dedup)) 577 | 578 | j = 0 579 | Matched = candidates_dedup[:15] 580 | for id, temp in enumerate(candidates_dedup[0:15]): 581 | tokens = temp.split() 582 | tt = ' '.join(porter.stem(t) for t in tokens) 583 | if (tt in labels_stemed[i] or temp in labels[i]): 584 | Matched[id] = [temp] 585 | if (j < 5): 586 | num_c_5 += 1 587 | num_c_10 += 1 588 | num_c_15 += 1 589 | 590 | elif (j < 10 and j >= 5): 591 | num_c_10 += 1 592 | num_c_15 += 1 593 | 594 | elif (j < 15 and j >= 10): 595 | num_c_15 += 1 596 | j += 1 597 | 598 | log.logger.info("TOP-K {}: {} \n".format(i, Matched)) 599 | log.logger.info("Reference {}: {} \n".format(i,labels[i])) 600 | 601 | if (len(top_k[0:5]) == 5): 602 | num_e_5 += 5 603 | else: 604 | num_e_5 += len(top_k[0:5]) 605 | 606 | if (len(top_k[0:10]) == 10): 607 | num_e_10 += 10 608 | else: 609 | num_e_10 += len(top_k[0:10]) 610 | 611 | if (len(top_k[0:15]) == 15): 612 | num_e_15 += 15 613 | else: 614 | num_e_15 += len(top_k[0:15]) 615 | 616 | num_s += len(labels[i]) 617 | 618 | 619 | en_model.close() 620 | p, r, f = get_PRF(num_c_5, num_e_5, num_s) 621 | print_PRF(p, r, f, 5) 622 | p, r, f = get_PRF(num_c_10, num_e_10, num_s) 623 | print_PRF(p, r, f, 10) 624 | p, r, f = get_PRF(num_c_15, num_e_15, num_s) 625 | print_PRF(p, r, f, 15) 626 | 627 | 628 | 629 | 630 | if __name__ == '__main__': 631 | 632 | parser = argparse.ArgumentParser() 633 | parser.add_argument("--dataset_dir", 634 | default=None, 635 | type=str, 636 | required=True, 637 | help="The input dataset.") 638 | parser.add_argument("--dataset_name", 639 | default=None, 640 | type=str, 641 | 
required=True, 642 | help="The input dataset name.") 643 | parser.add_argument("--doc_embed_mode", 644 | default="mean", 645 | type=str, 646 | required=True, 647 | help="The method for doc embedding.") 648 | parser.add_argument("--layer_num", 649 | default=-1, 650 | type=int, 651 | help="The hidden state layer of BERT.") 652 | parser.add_argument("--batch_size", 653 | default=None, 654 | type=int, 655 | required=True, 656 | help="Batch size for testing.") 657 | parser.add_argument("--checkpoints", 658 | default=None, 659 | type=str, 660 | required=False, 661 | help="Checkpoint for pre-trained Bert model") 662 | parser.add_argument("--log_dir", 663 | default=None, 664 | type=str, 665 | required=True, 666 | help="Path for Logging file") 667 | parser.add_argument("--local_rank", 668 | default=-1, 669 | type=int, 670 | help="local_rank for distributed training on gpus") 671 | parser.add_argument("--no_cuda", 672 | action="store_true", 673 | help="Whether not to use CUDA when available") 674 | args = parser.parse_args() 675 | 676 | start = time.time() 677 | log = Logger(args.log_dir + args.dataset_name + '.kpe.' 
+ args.doc_embed_mode + '.log') 678 | if args.local_rank == -1 or args.no_cuda: 679 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 680 | n_gpu = torch.cuda.device_count() 681 | porter = nltk.PorterStemmer() 682 | 683 | if args.dataset_name =="SemEval2017": 684 | data, referneces = get_semeval2017_data(args.dataset_dir + "/docsutf8", args.dataset_dir + "/keys") 685 | elif args.dataset_name == "DUC2001": 686 | data, referneces = get_duc2001_data(args.dataset_dir) 687 | elif args.dataset_name == "nus" : 688 | data, referneces = get_long_data(args.dataset_dir + "/nus_test.json") 689 | elif args.dataset_name == "krapivin": 690 | data, referneces = get_long_data(args.dataset_dir + "/krapivin_test.json") 691 | elif args.dataset_name == "kp20k": 692 | data, referneces = get_short_data(args.dataset_dir + "/kp20k_valid200_test.json") 693 | elif args.dataset_name == "SemEval2010": 694 | data, referneces = get_short_data(args.dataset_dir + "/semeval_test.json") 695 | else: 696 | data, referneces = get_inspec_data(args.dataset_dir) 697 | 698 | docs_pairs = [] 699 | doc_list = [] 700 | labels = [] 701 | labels_stemed = [] 702 | t_n = 0 703 | candidate_num = 0 704 | 705 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 706 | model = BertForMaskedLM.from_pretrained('bert-base-uncased') 707 | 708 | if os.path.exists(args.checkpoints): 709 | log.logger.info("Loading Checkpoint ...") 710 | accelerator = Accelerator() 711 | unwrapped_model = accelerator.unwrap_model(model) 712 | unwrapped_model.load_state_dict(torch.load(args.checkpoints)) 713 | 714 | log.logger.info("Start Testing ...") 715 | model.to(device) 716 | 717 | for idx, (key, doc) in enumerate(data.items()): 718 | 719 | # Get stemmed labels and document segments 720 | labels.append([ref.replace(" \n", "") for ref in referneces[key]]) 721 | labels_s = [] 722 | set_total_cans = set() 723 | for l in referneces[key]: 724 | tokens = l.split() 725 | labels_s.append(' 
'.join(porter.stem(t) for t in tokens)) 726 | 727 | doc = ' '.join(doc.split()[:512]) 728 | labels_stemed.append(labels_s) 729 | doc_list.append(doc) 730 | 731 | # Statistic on empty docs 732 | empty_doc = 0 733 | try: 734 | text_obj = InputTextObj(en_model, doc) 735 | except: 736 | empty_doc += 1 737 | print("doc: ", doc) 738 | 739 | # Generate candidates (lower) 740 | cans = text_obj.keyphrase_candidate 741 | candidates = [] 742 | for can, pos in cans: 743 | candidates.append(can.lower()) 744 | candidate_num += len(candidates) 745 | 746 | ori_encode_dict = tokenizer.encode_plus( 747 | doc, # Sentence to encode. 748 | add_special_tokens=True, # Add '[CLS]' and '[SEP]' 749 | max_length=MAX_LEN, # Pad & truncate all sentences. 750 | padding='max_length', 751 | return_attention_mask=True, # Construct attn. masks. 752 | return_tensors='pt', # Return pytorch tensors. 753 | truncation=True 754 | ) 755 | 756 | doc_pairs, count = generate_absent_doc(ori_encode_dict, candidates, idx) 757 | docs_pairs.extend(doc_pairs) 758 | t_n +=count 759 | 760 | print("candidate_num: ", candidate_num) 761 | print("unmatched: ", t_n) 762 | dataset = KPE_Dataset(docs_pairs) 763 | print("examples: ", dataset.total_examples) 764 | dataloader = DataLoader(dataset, batch_size=args.batch_size) 765 | 766 | keyphrases_selection(doc_list, labels_stemed, labels, model, dataloader, log) 767 | end = time.time() 768 | log.logger.info("Processing time: {}".format(end-start)) 769 | 770 | 771 | 772 | --------------------------------------------------------------------------------