├── src ├── __init__.py ├── frame │ ├── __init__.py │ ├── ir │ │ ├── __init__.py │ │ ├── ir_config.py │ │ ├── ir_passage_tf_tdqfs.py │ │ ├── ir_passage_tf.py │ │ ├── ir_tf.py │ │ ├── ir_tools.py │ │ └── ir_tf_tdqfs.py │ ├── centrality │ │ ├── __init__.py │ │ ├── centrality_ensemble_config.py │ │ ├── centrality_config.py │ │ ├── centrality_tfidf.py │ │ └── centrality_tfidf_records.py │ ├── bert_passage │ │ ├── __init__.py │ │ ├── config_name.py │ │ ├── passage_obj.py │ │ ├── build_passage.py │ │ ├── build_passage_tdqfs.py │ │ ├── bert_input.py │ │ └── data_pipe_cluster.py │ ├── bert_qa │ │ ├── config_name.py │ │ ├── bert_input_cosine.py │ │ ├── qa_config.py │ │ ├── bert_input.py │ │ ├── data_pipe_cluster.py │ │ └── data_pipe_cluster_cosine.py │ └── bert_ensemble │ │ ├── ensemble_config.py │ │ └── ensemble.py ├── scripts │ ├── __init__.py │ └── proc_tdqfs.py ├── config │ ├── config_model_bert_base.yml │ ├── config_model_bert_qa.yml │ ├── config_model_bert_passage.yml │ └── config_meta.yml ├── data │ ├── data_pipe.py │ ├── data_tools.py │ ├── clip_and_mask_sl.py │ ├── bert_input_sl.py │ ├── bert_input.py │ ├── bert_input_sep.py │ └── data_pipe_cluster.py ├── baselines │ ├── lead_tqdfs.py │ ├── lead.py │ ├── lexrank │ │ ├── lexrank_tfidf.py │ │ ├── grsum.py │ │ └── lexrank_tfidf_tdqfs.py │ └── human.py ├── summ │ ├── rank_sent.py │ ├── build_summary_targets.py │ └── compute_rouge.py ├── tools │ ├── general_tools.py │ └── vec_tools.py └── utils │ ├── graph_io.py │ ├── config_loader.py │ ├── graph_tools.py │ └── tools.py ├── .gitignore ├── LICENSE ├── requirements.txt └── README.md /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/frame/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/src/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/frame/ir/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/frame/centrality/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/frame/bert_passage/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | bin/ 2 | lib/ 3 | include/ 4 | /data/ 5 | man/ 6 | /res/ 7 | /model/ 8 | /performances/ 9 | /pred/ 10 | /summary/ 11 | /graph/ 12 | /mturk/ 13 | etc/ 14 | share/ 15 | src/.ipynb_checkpoints/ 16 | src/temp 17 | /dataset/ 18 | *.tar.gz 19 | *.ipynb 20 | pip-selfcheck.json 21 | .Python 22 | *.log 23 | __pycache__ 24 | archived_env/ 25 | .vscode/ 26 | .idea/ 27 | pyvenv.cfg 28 | -------------------------------------------------------------------------------- /src/config/config_model_bert_base.yml: -------------------------------------------------------------------------------- 1 | --- 2 | variation: 'BertBase' 3 | 4 | model_name: 'BertBase' 5 | fine_tune: 'qa' 6 | 7 | n_epochs: 40 8 | n_batches: 10000 9 | 10 | # input size 11 | # input: d_batch * max_ns_doc * max_n_tokens 12 | d_batch: 32 # 32, 64, 128, 200 13 | 14 | max_ns_doc: 100 15 | 16 | max_nw_query: 100 17 | max_nw_sent: 70 18 | max_n_tokens: 173 # [CLS], [SEP] * 2 19 | 20 | d_embed: 768 -------------------------------------------------------------------------------- /src/config/config_model_bert_qa.yml: 
-------------------------------------------------------------------------------- 1 | --- 2 | variation: 'BertQA' 3 | 4 | model_name: 'BertQA' 5 | fine_tune: 'qa' 6 | 7 | n_epochs: 40 8 | n_batches: 10000 9 | 10 | # input size 11 | # input: d_batch * max_ns_doc * max_n_tokens 12 | d_batch: 32 # 32, 64, 128, 200 13 | 14 | max_ns_doc: 100 15 | 16 | max_nw_query: 100 17 | max_nw_sent: 70 18 | max_n_tokens: 173 # [CLS], [SEP] * 2 19 | 20 | d_embed: 768 21 | -------------------------------------------------------------------------------- /src/config/config_model_bert_passage.yml: -------------------------------------------------------------------------------- 1 | --- 2 | variation: 'BertPassage' 3 | 4 | model_name: 'BertPassage' 5 | fine_tune: 'passage' 6 | 7 | n_epochs: 40 8 | n_batches: 10000 9 | 10 | # input size 11 | # input: d_batch * max_ns_doc * max_n_tokens 12 | d_batch: 24 # 16, 32, 64, 128, 200 13 | 14 | max_ns_doc: 100 15 | 16 | max_nw_query: 100 # 100 17 | max_nw_sent: 50 18 | ns_passage: 8 19 | stride: 4 20 | max_n_tokens: 503 # [CLS], [SEP] * 2, 100 + 50 * 8 + 3 21 | 22 | d_embed: 768 -------------------------------------------------------------------------------- /src/config/config_meta.yml: -------------------------------------------------------------------------------- 1 | --- 2 | model_name: 'bert_qa' # bert_qa, bert_passage, bert_base 3 | test_year: '2007' # 2005, 2006, 2007, tdqfs 4 | grain: 'sent' # sent, passage 5 | 6 | mode: 'train' # train, rank_sent, select_sent 7 | para_org: null # True, null 8 | 9 | auto_parallel: True 10 | 11 | word_tokenizer: 'bert' # bert, nltk 12 | preload_model_tokenizer: True # True when using pre-trained Bert model: None, mb, qa, bing 13 | vocab: 'bert' # bert 14 | texttiling: True 15 | 16 | remove_dialog: False 17 | -------------------------------------------------------------------------------- /src/frame/bert_qa/config_name.py: -------------------------------------------------------------------------------- 1 | import 
yaml 2 | from io import open 3 | import os 4 | from os.path import join, dirname, abspath 5 | import sys 6 | 7 | sys.path.insert(0, dirname(dirname(abspath(__file__)))) 8 | 9 | config_root = join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'config') 10 | 11 | # meta 12 | config_meta_fp = os.path.join(config_root, 'config_meta.yml') 13 | config_meta = yaml.load(open(config_meta_fp, 'r', encoding='utf-8')) 14 | 15 | # model 16 | meta_model_name = config_meta['model_name'] 17 | config_model_fn = 'config_model_{0}.yml'.format(meta_model_name) 18 | config_model_fp = os.path.join(config_root, config_model_fn) 19 | config_model = yaml.load(open(config_model_fp, 'r')) 20 | 21 | if config_model['model_name']: 22 | model_name = config_model['model_name'] 23 | -------------------------------------------------------------------------------- /src/frame/bert_passage/config_name.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import logging.config 3 | import yaml 4 | from io import open 5 | import os 6 | from os.path import join, dirname, abspath 7 | import sys 8 | 9 | sys.path.insert(0, dirname(dirname(abspath(__file__)))) 10 | 11 | config_root = join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'config') 12 | 13 | # meta 14 | config_meta_fp = os.path.join(config_root, 'config_meta.yml') 15 | config_meta = yaml.load(open(config_meta_fp, 'r', encoding='utf-8')) 16 | 17 | # model 18 | meta_model_name = config_meta['model_name'] 19 | config_model_fn = 'config_model_{0}.yml'.format(meta_model_name) 20 | config_model_fp = os.path.join(config_root, config_model_fn) 21 | config_model = yaml.load(open(config_model_fp, 'r')) 22 | model_name = config_model['model_name'] 23 | -------------------------------------------------------------------------------- /src/data/data_pipe.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from os 
import listdir 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import Dataset, DataLoader 6 | from os.path import join, isdir, dirname, abspath 7 | import sys 8 | sys.path.insert(0, dirname(dirname(abspath(__file__)))) 9 | 10 | import utils.config_loader as config 11 | from utils.config_loader import path_parser, logger, config_model, config_meta 12 | 13 | from data.dataset_parser import dataset_parser 14 | import data.bert_input as bert_in 15 | import data.bert_input_sep as bert_input_sep 16 | import data.bert_input_sl as bert_input_sl 17 | import data.data_tools as data_tools 18 | import utils.tools as tools 19 | 20 | 21 | class To2DMat(object): 22 | def __call__(self, numpy_dict): 23 | for (k, v) in numpy_dict.items(): 24 | # logger.info('[BEFORE TO TENSOR] type of {0}: {1}'.format(k, v.dtype)) 25 | if k in ('token_ids', 'seg_ids', 'token_masks'): 26 | numpy_dict[k] = v.reshape(-1, config_model['max_n_tokens']) 27 | 28 | return numpy_dict 29 | 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Yumo Xu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/frame/centrality/centrality_ensemble_config.py: -------------------------------------------------------------------------------- 1 | import utils.config_loader as config 2 | import frame.bert_ensemble.ensemble_config as ensemble_config 3 | 4 | # set following macro vars 5 | QUERY_TYPE = config.NARR 6 | LINK_TYPE = 'uniform' # inter, intra, uniform 7 | 8 | QA_RECORD_DIR_NAME = ensemble_config.QA_RECORD_DIR_NAME # sentence lookup for using hard filtering 9 | 10 | DAMP = 0.85 # 1.0, 0.85 11 | 12 | DIVERSITY_ALGORITHM = 'wan' # wan 13 | OMEGA = 0 # 4 (sentence), 2 (passage) 14 | COS_THRESHOLD = 0.6 # 0.5, 0.6, 1.0 15 | 16 | if DIVERSITY_ALGORITHM == 'wan': 17 | DIVERSITY_PARAM_TUPLE = (OMEGA, DIVERSITY_ALGORITHM) 18 | else: 19 | raise ValueError('Invalid DIVERSITY_ALGORITHM: {}'.format(DIVERSITY_ALGORITHM)) 20 | 21 | QA_SCORE_DIR_NAME = QA_RECORD_DIR_NAME 22 | BIAS_TYPE = 'hard' # hard, soft 23 | 24 | CENTRALITY_MODEL_NAME_BASIC = 'centrality-{}_bias-{}_damp-{}'.format(BIAS_TYPE, DAMP, QA_SCORE_DIR_NAME) 25 | CENTRALITY_TUNE_DIR_NAME_BASIC = 'centrality_tune-{0}-{1}_cos-{2}'.format(CENTRALITY_MODEL_NAME_BASIC, 26 | COS_THRESHOLD, DIVERSITY_ALGORITHM) 27 | 28 | # LENGTH_BUDGET_TUPLE = ('ns', 7) 29 | LENGTH_BUDGET_TUPLE = ('nw', 250) 30 | -------------------------------------------------------------------------------- /src/frame/bert_qa/bert_input_cosine.py: 
-------------------------------------------------------------------------------- 1 | import utils.config_loader as config 2 | from utils.config_loader import config_model 3 | from data.dataset_parser import dataset_parser 4 | import numpy as np 5 | 6 | 7 | def _build_bert_in(tokens): 8 | in_size = [config_model['max_n_tokens'], ] 9 | 10 | token_ids = np.zeros(in_size, dtype=np.int32) 11 | seg_ids = np.zeros(in_size, dtype=np.int32) 12 | token_masks = np.zeros(in_size) 13 | 14 | tokens = ['[CLS]'] + tokens + ['[SEP]'] 15 | token_id_list = config.bert_tokenizer.convert_tokens_to_ids(tokens) 16 | n_tokens = len(token_id_list) 17 | 18 | token_ids[:n_tokens] = token_id_list 19 | token_masks[:n_tokens] = [1] * n_tokens 20 | 21 | bert_in = { 22 | 'token_ids': token_ids, 23 | 'seg_ids': seg_ids, 24 | 'token_masks': token_masks, 25 | } 26 | 27 | return bert_in 28 | 29 | 30 | def build_query(query): 31 | query_tokens = dataset_parser.parse_query(query) 32 | return _build_bert_in(query_tokens) 33 | 34 | 35 | def build_sentence(sentence): 36 | sentence_tokens = dataset_parser.sent2words(sentence)[:config_model['max_nw_sent']] 37 | return _build_bert_in(sentence_tokens) 38 | -------------------------------------------------------------------------------- /src/frame/ir/ir_config.py: -------------------------------------------------------------------------------- 1 | import utils.config_loader as config 2 | from utils.config_loader import config_meta 3 | 4 | if config.grain == 'sent': 5 | IR_META_NAME = 'ir' 6 | else: 7 | IR_META_NAME = 'ir-{}'.format(config.grain) # e.g., ir-passage 8 | 9 | QUERY_TYPE = None # config.NARR, config.TITLE, None (concat narr and title), REF (oracle) 10 | if QUERY_TYPE: 11 | CONCAT_TITLE_NARR = False 12 | IR_META_NAME = '{}-{}'.format(IR_META_NAME, QUERY_TYPE) 13 | else: 14 | CONCAT_TITLE_NARR = True 15 | 16 | test_year = config_meta['test_year'] 17 | IR_MODEL_NAME_TF = '{}-tf-{}'.format(IR_META_NAME, test_year) 18 | 19 | DEDUPLICATE = False 20 | 21 
| CONF_THRESHOLD_IR = 0.75 22 | TOP_NUM_IR = 90 23 | 24 | FILTER = 'conf' # conf, comp, topK (only used in ablation study) 25 | if FILTER == 'conf': 26 | FILTER_VAR = CONF_THRESHOLD_IR 27 | elif FILTER == 'topK': 28 | FILTER_VAR = TOP_NUM_IR 29 | else: 30 | raise ValueError('Invalid FILTER: {}'.format(FILTER)) 31 | 32 | IR_RECORDS_DIR_NAME_PATTERN = 'ir_records-{}-{}_ir_{}' 33 | 34 | if DEDUPLICATE: 35 | IR_RECORDS_DIR_NAME_PATTERN += '-dedup' 36 | 37 | IR_RECORDS_DIR_NAME_TF = IR_RECORDS_DIR_NAME_PATTERN.format(IR_MODEL_NAME_TF, FILTER_VAR, FILTER) 38 | IR_TUNE_DIR_NAME_TF = 'ir_tune-{}'.format(IR_MODEL_NAME_TF) 39 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.7.1 2 | apex 3 | appdirs==1.4.3 4 | astor==0.8.0 5 | bert-serving-client==1.9.8 6 | bert-serving-server==1.9.8 7 | boto==2.49.0 8 | boto3==1.9.188 9 | botocore==1.12.188 10 | certifi==2019.6.16 11 | chardet==3.0.4 12 | cycler==0.10.0 13 | dill==0.3.1.1 14 | docutils==0.14 15 | gast==0.2.2 16 | gensim==3.7.3 17 | GPUtil==1.4.0 18 | grpcio==1.20.1 19 | h5py==2.9.0 20 | idna==2.8 21 | importlib-metadata==0.8 22 | jmespath==0.9.4 23 | Keras-Applications==1.0.7 24 | Keras-Preprocessing==1.0.9 25 | kiwisolver==1.0.1 26 | lexrank==0.1.0 27 | Markdown==3.1.1 28 | matplotlib==3.0.2 29 | mock==3.0.5 30 | nltk==3.4 31 | numpy==1.16.4 32 | path.py==11.5.0 33 | Pillow==5.4.1 34 | protobuf==3.7.1 35 | pyparsing==2.3.1 36 | pyrouge==0.1.3 37 | pyrsistent==0.14.11 38 | python-dateutil==2.8.0 39 | pytorch-pretrained-bert==0.4.0 40 | pytorch-transformers==1.0.0 41 | PyYAML==3.13 42 | pyzmq==18.1.1 43 | regex==2019.6.8 44 | requests==2.22.0 45 | s3transfer==0.2.1 46 | scikit-learn==0.20.2 47 | scipy==1.3.2 48 | sentencepiece==0.1.82 49 | singledispatch==3.4.0.3 50 | six==1.12.0 51 | sklearn==0.0 52 | smart-open==1.8.3 53 | tensorboard==1.12.2 54 | tensorboardX==1.8 55 | 
tensorflow-estimator==1.13.0 56 | tensorflow-gpu==1.12.0 57 | tensorflow-hub==0.4.0 58 | termcolor==1.1.0 59 | torch==1.1.0 60 | torchvision==0.2.1 61 | tqdm==4.32.2 62 | uritools==2.2.0 63 | urlextract==0.9 64 | urllib3==1.25.3 65 | Werkzeug==0.15.4 66 | xlwt==1.3.0 67 | zipp==0.3.3 68 | -------------------------------------------------------------------------------- /src/frame/bert_qa/qa_config.py: -------------------------------------------------------------------------------- 1 | import utils.config_loader as config 2 | import frame.ir.ir_config as ir_config 3 | 4 | # set following macro vars 5 | QUERY_TYPE = config.NARR 6 | 7 | # IR configs: the method should be consistent 8 | IR_MODEL_NAME = ir_config.IR_MODEL_NAME_TF # for building sid2sent for contextual QA models 9 | # IR_RECORDS_DIR_NAME: sentence lookup, IR_MODEL_NAME_TF (full), IR_RECORDS_DIR_NAME_TF (retrieved) 10 | IR_RECORDS_DIR_NAME = ir_config.IR_MODEL_NAME_TF 11 | 12 | if config.meta_model_name == 'bert_qa': 13 | BERT_TYPE = 'bert' 14 | elif config.meta_model_name == 'bert_base': 15 | BERT_TYPE = 'bert_base' 16 | elif config.meta_model_name == 'bert_passage': 17 | BERT_TYPE = 'bert_passage-{}'.format(config.bert_passage_iter) 18 | else: 19 | raise ValueError('Invalid mode_name: {}'.format(config.meta_model_name)) 20 | 21 | QA_MODEL_NAME_BERT = 'qa-{}-{}-{}'.format(BERT_TYPE, QUERY_TYPE, IR_RECORDS_DIR_NAME) 22 | 23 | RELEVANCE_SCORE_DIR_NAME = QA_MODEL_NAME_BERT 24 | 25 | # filter config 26 | FILTER = 'topK' # topK, conf 27 | 28 | CONF_THRESHOLD_QA = 0.95 29 | TOP_NUM_QA =90 # 90: sentence, 110: passage 30 | 31 | if FILTER == 'conf': 32 | FILTER_VAR = CONF_THRESHOLD_QA 33 | elif FILTER == 'topK': 34 | FILTER_VAR = TOP_NUM_QA 35 | else: 36 | raise ValueError('Invalid FILTER: {}'.format(FILTER)) 37 | 38 | QA_RECORD_DIR_NAME_PATTERN = 'qa_records-{}-{}_qa_{}' # model name, conf 39 | QA_RECORD_DIR_NAME_BERT = QA_RECORD_DIR_NAME_PATTERN.format(QA_MODEL_NAME_BERT, FILTER_VAR, FILTER) 40 | 41 | 
QA_TUNE_DIR_NAME_PATTERN = 'qa_tune-{}' 42 | QA_TUNE_DIR_NAME_BERT = QA_TUNE_DIR_NAME_PATTERN.format(QA_MODEL_NAME_BERT) 43 | -------------------------------------------------------------------------------- /src/baselines/lead_tqdfs.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import io 3 | import os 4 | from os import listdir 5 | from os.path import join, dirname, abspath, isfile, exists, isdir 6 | sys_path = dirname(dirname(abspath(__file__))) 7 | parent_sys_path = dirname(sys_path) 8 | 9 | if sys_path not in sys.path: 10 | sys.path.insert(0, sys_path) 11 | if parent_sys_path not in sys.path: 12 | sys.path.insert(0, parent_sys_path) 13 | 14 | from utils.config_loader import logger, path_parser 15 | import summ.compute_rouge as rouge 16 | 17 | MODEL_NAME = 'lead-tqdfs' 18 | cids = [cid for cid in listdir(path_parser.data_tdqfs_sentences) if isdir(join(path_parser.data_tdqfs_sentences, cid))] 19 | 20 | text_dp = join(path_parser.summary_text, MODEL_NAME) 21 | assert not exists(text_dp), f'{text_dp} exists!' 
22 | os.mkdir(text_dp) 23 | 24 | 25 | def _get_lines(cid): 26 | sents = [] 27 | for doc_idx in range(10): 28 | if len(sents) >= 50: 29 | return sents 30 | fp = join(path_parser.data_tdqfs_sentences, cid, str(doc_idx)) 31 | lines = [line.strip('\n') for line in io.open(fp).readlines()] 32 | sents.extend([line for line in lines if line]) 33 | return sents 34 | 35 | def select(): 36 | for cid in cids: 37 | lines = _get_lines(cid) 38 | logger.info(f'{cid}: {len(lines)}') 39 | io.open(join(text_dp, cid), mode='a').write('\n'.join(lines)) 40 | 41 | 42 | def compute_rouge(): 43 | rouge_parmas = { 44 | 'text_dp': text_dp, 45 | 'ref_dp': path_parser.data_tdqfs_summary_targets, 46 | 'length': 250, 47 | } 48 | output = rouge.compute_rouge_for_tdqfs(**rouge_parmas) 49 | return output 50 | 51 | 52 | if __name__ == "__main__": 53 | select() 54 | compute_rouge() 55 | -------------------------------------------------------------------------------- /src/baselines/lead.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from os import listdir 4 | from os.path import join, dirname, abspath, isfile, exists 5 | sys_path = dirname(dirname(abspath(__file__))) 6 | parent_sys_path = dirname(sys_path) 7 | 8 | if sys_path not in sys.path: 9 | sys.path.insert(0, sys_path) 10 | if parent_sys_path not in sys.path: 11 | sys.path.insert(0, parent_sys_path) 12 | 13 | import shutil 14 | import re 15 | import utils.config_loader as config 16 | from utils.config_loader import logger, path_parser 17 | import summ.compute_rouge as rouge 18 | 19 | sys.path.insert(0, dirname(dirname(abspath(__file__)))) 20 | 21 | MODEL_NAME = 'lead-{}'.format(config.test_year) 22 | 23 | def extract_lead_summaries(): 24 | if config.test_year == '2006': 25 | dn = '2006/NIST/NISTeval/ROUGE/peers' 26 | elif config.test_year == '2007': 27 | dn = '2007/mainEval/ROUGE/peers' 28 | else: 29 | raise ValueError('Invalid test_year: {}'.format(config.test_year)) 30 | 
31 | lead_dp = join(path_parser.data_summary_results, dn) 32 | fns = [fn for fn in listdir(lead_dp) if isfile(join(lead_dp, fn))] 33 | ref_pat = re.compile('[\S]+.M.250.\D.1$') 34 | ref_fns = [fn for fn in fns if re.search(ref_pat, fn)] 35 | print(ref_fns) 36 | 37 | out_dp = join(path_parser.summary_text, MODEL_NAME) 38 | if exists(out_dp): 39 | raise ValueError('out_dp exists: {}'.format(out_dp)) 40 | os.mkdir(out_dp) 41 | 42 | for fn in ref_fns: 43 | out_fn = '_'.join((config.test_year, fn.split('.')[0] + fn.split('.')[-2])) 44 | shutil.copyfile(join(lead_dp, fn), join(out_dp, out_fn)) 45 | 46 | def compute_rouge(): 47 | rouge.compute_rouge(model_name=MODEL_NAME) 48 | 49 | 50 | if __name__ == '__main__': 51 | extract_lead_summaries() 52 | compute_rouge() 53 | -------------------------------------------------------------------------------- /src/frame/bert_passage/passage_obj.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | from os.path import join, dirname, abspath, exists 4 | 5 | sys_path = dirname(dirname(abspath(__file__))) 6 | parent_sys_path = dirname(sys_path) 7 | 8 | if sys_path not in sys.path: 9 | sys.path.insert(0, sys_path) 10 | if parent_sys_path not in sys.path: 11 | sys.path.insert(0, parent_sys_path) 12 | 13 | import utils.config_loader as config 14 | from utils.config_loader import path_parser 15 | import dill 16 | 17 | class SentObj: 18 | def __init__(self, sid, original_sent, proc_sent): 19 | self.sid = sid # config.SEP.join([cid, str(sent_idx)]); for score gathering in QA module 20 | self.original_sent = original_sent 21 | self.proc_sent = proc_sent 22 | 23 | 24 | class PassageObj: 25 | def __init__(self, pid, query, narr, sent_objs): 26 | self.pid = pid # config.SEP.join([cid, str(passage_idx)]) 27 | self.query = query 28 | self.narr = narr 29 | 30 | self.sent_objs = sent_objs 31 | self.size = len(self.sent_objs) 32 | 33 | self.ir_score = None 34 | 35 | def 
get_original_sents(self): 36 | return [so.original_sent for so in self.sent_objs] 37 | 38 | def get_proc_sents(self): 39 | return [so.proc_sent for so in self.sent_objs] 40 | 41 | def get_proc_passage(self): 42 | return ' '.join(self.get_proc_sents()) 43 | 44 | def get_original_passage(self): 45 | return ' '.join(self.get_original_sents()) 46 | 47 | 48 | def pid2obj(cid, pid, use_tdqfs): 49 | if use_tdqfs: 50 | fp = join(path_parser.data_tdqfs_passages, cid, pid) 51 | else: 52 | year, _ = cid.split(config.SEP) 53 | fp = join(path_parser.data_passages, year, cid, pid) 54 | 55 | with open(fp, 'rb') as f: 56 | po = dill.load(f) 57 | return po 58 | -------------------------------------------------------------------------------- /src/frame/bert_ensemble/ensemble_config.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from os.path import join, dirname, abspath 3 | 4 | sys_path = dirname(dirname(abspath(__file__))) 5 | parent_sys_path = dirname(sys_path) 6 | 7 | if sys_path not in sys.path: 8 | sys.path.insert(0, sys_path) 9 | if parent_sys_path not in sys.path: 10 | sys.path.insert(0, parent_sys_path) 11 | 12 | import utils.config_loader as config 13 | from utils.config_loader import path_parser 14 | 15 | assert config.grain == 'sent' # rank and its records are saved under sent directory 16 | 17 | SENT_QA_REC_DIR_NAME = f'qa_records-qa-bert-narr-ir_records-ir-tf-{config.test_year}-0.75_ir_conf-90_qa_topK' 18 | PASSAGE_QA_REC_DIR_NAME = f'qa_records-qa-bert_passage-12000-narr-ir_records-ir-passage-tf-{config.test_year}-0.75_ir_conf-90_qa_topK' 19 | SENT_QA_RECORD_DP = join(path_parser.proj_root, 'rank', SENT_QA_REC_DIR_NAME) 20 | PASSAGE_QA_RECORD_DP = join(path_parser.proj_root, 'rank_passage', PASSAGE_QA_REC_DIR_NAME) 21 | 22 | # avg: avg two scores; if there is only one score, keep it. 23 | # sqrt; sqrt two scores; if there is only one score, keep it. 
24 | # avg_global: avg two scores; if there is only one score, halve it. 25 | # sqrt_global: sqrt two scores; if there is only one score, sqrt it. 26 | # weight_avg: (1- \mu) * sent_score + \mu * span_score 27 | # weight_avg_sent_only: (1- \mu) * sent_score + \mu * span_score; use only records from sent model. 28 | ENSEMBLE_MODE = 'weight_avg_sent_only' 29 | IS_ENSEMBLE_GLOBAL = ENSEMBLE_MODE.endswith('global') 30 | IS_SENT_REC_ONLY = ENSEMBLE_MODE.endswith('sent_only') 31 | 32 | MODEL_NAME = 'bert_ensemble-{}-{}_mode'.format(config.test_year, ENSEMBLE_MODE) 33 | 34 | FILTER = 'topK' 35 | FILTER_VAR = 90 36 | QA_RECORD_DIR_NAME_PATTERN = 'qa_records-{}-{}_qa_{}' 37 | QA_RECORD_DIR_NAME = QA_RECORD_DIR_NAME_PATTERN.format(MODEL_NAME, FILTER_VAR, FILTER) 38 | 39 | SPAN_REC_WEIGHT = 0.0 # 0.05 40 | SPAN_AFFIX = '-{}_span_weight'.format(SPAN_REC_WEIGHT) 41 | if ENSEMBLE_MODE.startswith('weight_avg'): 42 | MODEL_NAME += SPAN_AFFIX 43 | QA_RECORD_DIR_NAME += SPAN_AFFIX 44 | -------------------------------------------------------------------------------- /src/frame/centrality/centrality_config.py: -------------------------------------------------------------------------------- 1 | import utils.config_loader as config 2 | import frame.bert_qa.qa_config as qa_config 3 | 4 | # set following macro vars 5 | QUERY_TYPE = config.NARR 6 | LINK_TYPE = 'uniform' # inter, intra, uniform 7 | 8 | IR_RECORD_DIR_NAME = qa_config.IR_RECORDS_DIR_NAME # sentence lookup for using rel_vec 9 | 10 | # sentence lookup for using hard filtering 11 | # change to QA_DIR_NAME_BERT when using all sentences (TBC) 12 | QA_RECORD_DIR_NAME = qa_config.QA_RECORD_DIR_NAME_BERT 13 | 14 | QA_RELEVANCE_SCORE_DIR_NAME = qa_config.RELEVANCE_SCORE_DIR_NAME 15 | 16 | DAMP = 0.85 # 1.0 (w/o query bias), 0.85 (w/ query bias) 17 | 18 | DIVERSITY_ALGORITHM = 'wan' 19 | 20 | # for DUC, set it to 4 for sentence, and 2 for passage 21 | # for TD-QFS, we do not use wan's diversity algorithm and set it to 0 22 | OMEGA = 4 
23 | 24 | COS_THRESHOLD = 0.6 25 | 26 | BIAS_TYPE = 'hard' # hard, soft 27 | 28 | if DIVERSITY_ALGORITHM == 'wan': 29 | DIVERSITY_PARAM_TUPLE = (OMEGA, DIVERSITY_ALGORITHM) 30 | else: 31 | raise ValueError('Invalid DIVERSITY_ALGORITHM: {}'.format(DIVERSITY_ALGORITHM)) 32 | 33 | if BIAS_TYPE == 'soft': 34 | QA_SCORE_DIR_NAME = QA_RELEVANCE_SCORE_DIR_NAME 35 | elif BIAS_TYPE == 'hard': 36 | QA_SCORE_DIR_NAME = QA_RECORD_DIR_NAME 37 | else: 38 | raise ValueError('Invalid BIAS_TYPE: {}'.format(BIAS_TYPE)) 39 | 40 | CENTRALITY_MODEL_NAME_BASIC = 'centrality-{}_bias-{}_damp-{}'.format(BIAS_TYPE, DAMP, QA_SCORE_DIR_NAME) 41 | REL_VEC_NORM = None 42 | if config.grain == 'passage': 43 | REL_VEC_NORM = 'sqrt_tanh' # None, sqrt_tanh 44 | CENTRALITY_MODEL_NAME_BASIC += '-{}_vec_norm'.format(REL_VEC_NORM) 45 | 46 | 47 | # following is for ablation study of QA module; only IR is used for centrality. 48 | CENTRALITY_MODEL_NAME_wo_QA = 'centrality-{}_bias-{}_damp-{}'.format(BIAS_TYPE, DAMP, IR_RECORD_DIR_NAME) 49 | CENTRALITY_TUNE_DIR_NAME_BASIC = 'centrality_tune-{0}-{1}_cos-{2}'.format(CENTRALITY_MODEL_NAME_BASIC, COS_THRESHOLD, DIVERSITY_ALGORITHM) 50 | 51 | # LENGTH_BUDGET_TUPLE = ('ns', 7) 52 | LENGTH_BUDGET_TUPLE = ('nw', 250) 53 | -------------------------------------------------------------------------------- /src/summ/rank_sent.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from os.path import dirname, abspath 3 | import sys 4 | sys.path.insert(0, dirname(dirname(abspath(__file__)))) 5 | 6 | import utils.config_loader as config 7 | import utils.config_loader as config 8 | import utils.tools as tools 9 | 10 | import torch 11 | import shutil 12 | 13 | 14 | versions = ['sl', 'alpha'] 15 | para_org = True 16 | for vv in versions: 17 | if config.meta_model_name.endswith(vv): 18 | para_org = False 19 | 20 | 21 | def sort_sid2score(sid2score): 22 | sid_score_list = sorted(sid2score.items(), key=lambda item: 
item[1], reverse=True) 23 | return sid_score_list 24 | 25 | 26 | def get_rank_records(sid_score_list, sents=None, flat_sents=False): 27 | """ 28 | optional: display sentence in record 29 | :param sid_score_list: 30 | :param sents: 31 | :param flat_sents: if True, iterate sent directly; if False, need use sid to get doc_idx and sent_idx. 32 | :return: 33 | """ 34 | rank_records = [] 35 | for sid, score in sid_score_list: 36 | items = [sid, str(score)] 37 | if sents: 38 | if flat_sents: 39 | sent = sents[len(rank_records)] # the current point 40 | else: 41 | doc_idx, sent_idx = tools.get_sent_info(sid) 42 | sent = sents[doc_idx][sent_idx] 43 | items.append(sent) 44 | record = '\t'.join(items) 45 | rank_records.append(record) 46 | return rank_records 47 | 48 | 49 | def dump_rank_records(rank_records, out_fp, with_rank_idx): 50 | """ 51 | each line is 52 | ranking sid score 53 | 54 | sid: config.SEP.join((doc_idx, para_idx, sent_idx)) 55 | :param sid_score_list: 56 | :param out_fp: 57 | :return: 58 | """ 59 | lines = rank_records 60 | if with_rank_idx: 61 | lines = ['\t'.join((str(rank), record)) for rank, record in enumerate(rank_records)] 62 | 63 | with open(out_fp, mode='a', encoding='utf-8') as f: 64 | f.write('\n'.join(lines)) 65 | 66 | return len(lines) 67 | -------------------------------------------------------------------------------- /src/data/data_tools.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | from os.path import join, isdir, dirname, abspath 4 | import sys 5 | sys.path.insert(0, dirname(dirname(abspath(__file__)))) 6 | sys.path.insert(0, dirname(dirname(dirname(abspath(__file__))))) 7 | 8 | import utils.config_loader as config 9 | from utils.config_loader import path_parser, logger, config_model, config_meta 10 | 11 | def get_bert_in_func(): 12 | if config.meta_model_name == 'bert_qa': 13 | from frame.bert_qa import bert_input 14 | bert_in_func = 
bert_input.build_bert_x 15 | else: 16 | from data import bert_input_sep 17 | bert_in_func = bert_input_sep.build_bert_x_sep 18 | 19 | logger.info('Using bert_in_func: {}'.format(config.meta_model_name)) 20 | return bert_in_func 21 | 22 | 23 | class ToTensor(object): 24 | """ 25 | Convert ndarrays in sample to Tensors. 26 | """ 27 | 28 | def __call__(self, numpy_dict): 29 | for (k, v) in numpy_dict.items(): 30 | # logger.info('[BEFORE TO TENSOR] type of {0}: {1}'.format(k, v.dtype)) 31 | 32 | if k.endswith('_ids'): 33 | v = v.type(torch.LongTensor) # for embedding look up 34 | # logger.info('[TO LONG TENSOR] convert {0} => {1}'.format(k, v.dtype)) 35 | 36 | if 'placement' in config_meta and config_meta['placement'] != 'cpu': 37 | # origin_type = v.type() 38 | v = v.cuda() 39 | # logger.info('[TO CUDA TENSOR] {0}: {1} => {2}'.format(k, origin_type, v.type())) 40 | 41 | numpy_dict[k] = v 42 | 43 | # logger.info('is cuda available: {}'.format(torch.cuda.is_available())) 44 | # for (k, v) in numpy_dict.items(): 45 | # # logger.info('[BEFORE TO TENSOR] type of {0}: {1}'.format(k, v.dtype)) 46 | # if type(v) == np.ndarray: 47 | # v = torch.from_numpy(v) 48 | # 49 | # if k.endswith('_ids'): 50 | # v = v.type(torch.LongTensor) # for embedding look up 51 | # logger.info('[TO TENSOR] convert {0} => {1}'.format(k, v.dtype)) 52 | # 53 | # if config.placement in ('auto', 'single'): 54 | # v = v.cuda() 55 | # logger.info('[TO CUDA] type of {0}: {1}'.format(k, v.dtype)) 56 | # 57 | # numpy_dict[k] = v 58 | 59 | return numpy_dict 60 | -------------------------------------------------------------------------------- /src/baselines/lexrank/lexrank_tfidf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | import os 4 | from os.path import join, dirname, abspath, exists 5 | 6 | sys_path = dirname(dirname(abspath(__file__))) 7 | parent_sys_path = dirname(sys_path) 8 | 9 | if sys_path not in sys.path: 10 | 
sys.path.insert(0, sys_path) 11 | if parent_sys_path not in sys.path: 12 | sys.path.insert(0, parent_sys_path) 13 | 14 | import utils.config_loader as config 15 | from utils.config_loader import logger 16 | from data.dataset_parser import dataset_parser 17 | import utils.tools as tools 18 | import summ.rank_sent as rank_sent 19 | import summ.select_sent as select_sent 20 | 21 | from lexrank import STOPWORDS, LexRank 22 | import itertools 23 | from tqdm import tqdm 24 | 25 | assert config.grain == 'sent' 26 | MODEL_NAME = 'lexrank-{}'.format(config.test_year) 27 | COS_THRESHOLD = 1.0 28 | 29 | def _lexrank(cid): 30 | """ 31 | Run LexRank on all sentences from all documents in a cluster. 32 | 33 | :param cid: 34 | :return: rank_records 35 | """ 36 | _, processed_sents = dataset_parser.cid2sents(cid) # 2d lists, docs => sents 37 | flat_processed_sents = list(itertools.chain(*processed_sents)) # 1d sent list 38 | 39 | lxr = LexRank(processed_sents, stopwords=STOPWORDS['en']) 40 | scores = lxr.rank_sentences(flat_processed_sents, threshold=None, fast_power_method=True) 41 | 42 | sid2score = dict() 43 | abs_idx = 0 44 | for doc_idx, doc in enumerate(processed_sents): 45 | for sent_idx, sent in enumerate(doc): 46 | sid = config.SEP.join((str(doc_idx), str(sent_idx))) 47 | score = scores[abs_idx] 48 | sid2score[sid] = score 49 | 50 | abs_idx += 1 51 | 52 | sid_score_list = rank_sent.sort_sid2score(sid2score) 53 | rank_records = rank_sent.get_rank_records(sid_score_list, sents=processed_sents, flat_sents=False) 54 | return rank_records 55 | 56 | 57 | def rank_e2e(): 58 | rank_dp = tools.get_rank_dp(model_name=MODEL_NAME) 59 | 60 | if exists(rank_dp): 61 | raise ValueError('rank_dp exists: {}'.format(rank_dp)) 62 | os.mkdir(rank_dp) 63 | 64 | cc_ids = tools.get_test_cc_ids() 65 | for cid in tqdm(cc_ids): 66 | rank_records = _lexrank(cid) 67 | rank_sent.dump_rank_records(rank_records, out_fp=join(rank_dp, cid), with_rank_idx = False) 68 | 69 | logger.info('Successfully dumped 
rankings to: {}'.format(rank_dp)) 70 | 71 | 72 | def select_e2e(): 73 | """ 74 | No redundancy removal is applied here. 75 | """ 76 | params = { 77 | 'model_name': MODEL_NAME, 78 | 'cos_threshold': COS_THRESHOLD, 79 | } 80 | select_sent.select_end2end(**params) 81 | 82 | if __name__ == '__main__': 83 | rank_e2e() 84 | select_e2e() 85 | -------------------------------------------------------------------------------- /src/data/clip_and_mask_sl.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, abspath 2 | import sys 3 | from utils.config_loader import config_model 4 | import numpy as np 5 | 6 | sys.path.insert(0, dirname(dirname(abspath(__file__)))) 7 | 8 | """ 9 | for doc: clip_doc_sents, mask_doc_sents, clip_and_mask_doc_sents 10 | """ 11 | 12 | 13 | def _len2mask(lens, mask_shape, offset): 14 | """ 15 | could be applied to: 16 | [1] paragraph masks: pooling paragraph instance scores to document bag score 17 | [2] sentence masks: pooling word representations to sentence representation 18 | :param lens: n_sents or n_paras 19 | :param mask_shape: [max_n_sents, max_words] or [max_n_docs, max_n_paras] 20 | :return: 21 | """ 22 | mask = np.zeros(mask_shape, dtype=np.float32) 23 | if type(lens) != list: 24 | raise ValueError('Invalid lens type: {}'.format(type(lens))) 25 | 26 | if len(mask_shape) == 1: # mask a document with its paras 27 | mask[offset:offset + lens[0]] = [1] * lens[0] 28 | return mask 29 | 30 | elif len(mask_shape) == 2: # mask sentences of a para/query with their words 31 | for idx, ll in enumerate(lens): 32 | end = offset + ll 33 | mask[idx, offset:end] = [1] * ll 34 | offset = end 35 | return mask 36 | 37 | else: 38 | raise ValueError('Invalid mask dim: {}'.format(len(mask_shape))) 39 | 40 | 41 | def mask_para(n_sents, max_n_sents): 42 | """ 43 | 44 | :param n_sents: an int. 
45 | :param max_n_sents: 46 | :return: 47 | """ 48 | mask_shape = [max_n_sents, ] 49 | return _len2mask([n_sents], mask_shape=mask_shape, offset=0) 50 | 51 | 52 | def clip_doc_sents(sents): 53 | """ 54 | For QueryNetSL 55 | :param sents: 56 | :return: 57 | """ 58 | 59 | words = [ss[:config_model['max_nw_sent']] for ss in sents[:config_model['max_ns_doc']]] 60 | n_words = [len(ss) for ss in words] 61 | 62 | res = { 63 | 'words': words, 64 | 'n_words_by_sents': n_words, 65 | } 66 | return res 67 | 68 | 69 | def mask_doc_sents(n_words): 70 | """ 71 | mask sentences of a doc with their words. 72 | 73 | :param n_words: an int list of sentence sizes in words. 74 | """ 75 | mask_shape = (config_model['max_ns_doc'], config_model['max_nw_sent']) 76 | # logger.info('mask shape: {}'.format(mask_shape)) 77 | return _len2mask(n_words, mask_shape=mask_shape, offset=0) 78 | 79 | 80 | def clip_and_mask_doc_sents(sents): 81 | """ 82 | For QueryNetSL. 83 | 84 | :param sents: 85 | :param offset: 86 | :return: 87 | """ 88 | clipped_res = clip_doc_sents(sents) 89 | doc_masks = mask_para(len(clipped_res['n_words_by_sents']), 90 | max_n_sents=config_model['max_ns_doc']) 91 | 92 | res = { 93 | 'sents': clipped_res['words'], 94 | 'doc_masks': doc_masks, # max_ns_doc, 95 | } 96 | return res 97 | -------------------------------------------------------------------------------- /src/frame/bert_passage/build_passage.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | 4 | import os 5 | from os.path import join, dirname, abspath, exists 6 | 7 | sys.path.insert(0, dirname(dirname(abspath(__file__)))) 8 | sys.path.insert(0, dirname(dirname(dirname(abspath(__file__))))) 9 | 10 | import utils.config_loader as config 11 | import utils.tools as tools 12 | from utils.config_loader import logger, path_parser, config_model 13 | from data.dataset_parser import dataset_parser 14 | from tqdm import tqdm 15 | import dill 16 | from 
frame.bert_passage.passage_obj import SentObj, PassageObj 17 | 18 | 19 | def _passage_core(cid, query, narr, passage_size, stride): 20 | original_sents, processed_sents = dataset_parser.cid2sents(cid, max_ns_doc=None) # 2d lists, docs => sents 21 | logger.info('#doc: {}'.format(len(original_sents))) 22 | 23 | # build sent_objs 24 | sent_objs = [] # organized by doc 25 | sent_idx = 0 26 | for doc_idx in range(len(original_sents)): 27 | sent_objs_doc = [] 28 | for original_s, proc_s in zip(original_sents[doc_idx], processed_sents[doc_idx]): 29 | sid = config.SEP.join([cid, str(sent_idx)]) 30 | so = SentObj(sid=sid, original_sent=original_s, proc_sent=proc_s) 31 | sent_objs_doc.append(so) 32 | sent_idx += 1 33 | 34 | sent_objs.append(sent_objs_doc) 35 | 36 | # build passage objs 37 | passage_objs = [] 38 | for sent_objs_doc in sent_objs: 39 | start = 0 40 | # make sure the last sentence whose length < stride will be discarded 41 | while start + stride < len(sent_objs_doc): 42 | pid = config.SEP.join([cid, str(len(passage_objs))]) 43 | 44 | target_sent_objs = sent_objs_doc[start:start+passage_size] 45 | po = PassageObj(pid=pid, query=query, narr=narr, sent_objs=target_sent_objs) 46 | passage_objs.append(po) 47 | 48 | start += stride 49 | 50 | return passage_objs 51 | 52 | 53 | def _dump_passages(year, cid, passage_objs): 54 | cc_dp = join(path_parser.data_passages, year, cid) 55 | if not exists(cc_dp): # remove previous output 56 | os.mkdir(cc_dp) 57 | 58 | for po in passage_objs: 59 | with open(join(cc_dp, po.pid), 'wb') as f: 60 | dill.dump(po, f) 61 | 62 | logger.info('[_dump_passages] Dump {} passage objects to {}'.format(len(passage_objs), cc_dp)) 63 | 64 | 65 | def build_passages(passage_size, stride, year=None): 66 | """ 67 | 68 | :param passage_size: max number of sentences in a passage 69 | :param use_stride: 70 | :return: 71 | """ 72 | query_info = dataset_parser.get_cid2query(tokenize_narr=False) 73 | narr_info = dataset_parser.get_cid2narr() 74 | 75 | if 
year: 76 | years = [year] 77 | else: 78 | years = config.years 79 | 80 | for year in years: 81 | cc_ids = tools.get_cc_ids(year, model_mode='test') 82 | 83 | for cid in tqdm(cc_ids): 84 | core_params = { 85 | 'cid': cid, 86 | 'query': query_info[cid], 87 | 'narr': narr_info[cid], 88 | 'passage_size': passage_size, 89 | 'stride': stride, 90 | } 91 | passage_objs = _passage_core(**core_params) 92 | 93 | _dump_passages(year, cid, passage_objs) 94 | 95 | 96 | if __name__ == '__main__': 97 | build_passages(passage_size=config_model['ns_passage'], stride=config_model['stride'], year='2007') 98 | -------------------------------------------------------------------------------- /src/data/bert_input_sl.py: -------------------------------------------------------------------------------- 1 | import utils.config_loader as config 2 | from utils.config_loader import logger, config_model 3 | import utils.tools as tools 4 | from data.dataset_parser import dataset_parser 5 | import numpy as np 6 | 7 | 8 | def get_max_n_tokens(is_para): 9 | if is_para: 10 | return config_model['max_n_para_tokens'] 11 | 12 | return config_model['max_n_query_tokens'] 13 | 14 | 15 | def _build_token_ids(words, max_n_tokens): 16 | token_ids = np.zeros([max_n_tokens, ], dtype=np.int32) 17 | # tokens = tools.flatten(sent_tokens) 18 | 19 | tokens = ['[CLS]'] + words 20 | token_id_list = config.bert_tokenizer.convert_tokens_to_ids(tokens) 21 | n_tokens = len(token_id_list) 22 | # logger.info('n_tokens: {}'.format(n_tokens)) 23 | token_ids[:n_tokens] = token_id_list 24 | 25 | return token_ids, n_tokens 26 | 27 | 28 | def _build_token_masks(n_tokens, max_n_tokens): 29 | token_masks = np.zeros([max_n_tokens, ]) 30 | token_masks[:n_tokens] = [1] * n_tokens 31 | return token_masks 32 | 33 | 34 | def _build_seg_ids(max_n_tokens): 35 | seg_ids = np.zeros([max_n_tokens, ], dtype=np.int32) 36 | return seg_ids 37 | 38 | 39 | def _build_bert_tokens(words, max_n_tokens): 40 | token_ids, n_tokens = 
_build_token_ids(words, max_n_tokens) 41 | token_masks = _build_token_masks(n_tokens, max_n_tokens) 42 | seg_ids = _build_seg_ids(max_n_tokens) 43 | 44 | res = { 45 | 'token_ids': token_ids, 46 | 'seg_ids': seg_ids, 47 | 'token_masks': token_masks, 48 | } 49 | 50 | return res 51 | 52 | 53 | def build_bert_x_doc_sl(doc_fp): 54 | # build para x 55 | doc_res = dataset_parser.parse_doc2sents(doc_fp) 56 | # init paras arrays 57 | max_ns_doc = config_model['max_ns_doc'] 58 | max_nt = config_model['max_nt_sent'] 59 | basic_doc_size = [max_ns_doc, max_nt] 60 | 61 | doc_token_ids = np.zeros(basic_doc_size, dtype=np.int32) 62 | doc_seg_ids = np.zeros(basic_doc_size, dtype=np.int32) 63 | doc_token_masks = np.zeros(basic_doc_size) 64 | 65 | # build para 66 | word_list = doc_res['words'] 67 | # logger.info('word list from doc: {}'.format(word_list)) 68 | for s_idx, words_s in enumerate(word_list): 69 | # logger.info('{}: {}'.format(s_idx, words_s)) 70 | sent_bert_in = _build_bert_tokens(words=words_s, max_n_tokens=max_nt) 71 | doc_token_ids[s_idx] = sent_bert_in['token_ids'] 72 | doc_seg_ids[s_idx] = sent_bert_in['seg_ids'] 73 | doc_token_masks[s_idx] = sent_bert_in['token_masks'] 74 | 75 | xx = { 76 | 'doc_token_ids': doc_token_ids, 77 | 'doc_seg_ids': doc_seg_ids, 78 | 'doc_token_masks': doc_token_masks, 79 | 'doc_masks': doc_res['doc_mask'], # max_ns_doc, 80 | } 81 | 82 | return xx 83 | 84 | 85 | def build_bert_x_trigger_sl(trigger): 86 | max_nt = config_model['max_nt_trigger'] 87 | trigger_res = dataset_parser.parse_trigger2words(trigger) 88 | trigger_bert_in = _build_bert_tokens(words=trigger_res['words'], max_n_tokens=max_nt) 89 | 90 | xx = { 91 | 'trigger_token_ids': trigger_bert_in['token_ids'], 92 | 'trigger_seg_ids': trigger_bert_in['seg_ids'], 93 | 'trigger_token_masks': trigger_bert_in['token_masks'], 94 | # 'trigger_masks': trigger_res['trigger_mask'], 95 | } 96 | 97 | return xx 98 | 99 | 100 | def build_bert_x_sl(trigger, doc_fp): 101 | xx_doc = 
build_bert_x_doc_sl(doc_fp) 102 | xx_trigger = build_bert_x_trigger_sl(trigger) 103 | 104 | return { 105 | **xx_doc, 106 | **xx_trigger, 107 | } 108 | -------------------------------------------------------------------------------- /src/data/bert_input.py: -------------------------------------------------------------------------------- 1 | import utils.config_loader as config 2 | from utils.config_loader import logger, config_model 3 | from data.dataset_parser import dataset_parser 4 | import numpy as np 5 | 6 | 7 | def _build_bert_tokens_for_para(query_words, para_words): 8 | token_ids = np.zeros([config_model['max_n_tokens'], ], dtype=np.int32) 9 | seg_ids = np.zeros([config_model['max_n_tokens'], ], dtype=np.int32) 10 | token_masks = np.zeros([config_model['max_n_tokens'], ]) 11 | 12 | # logger.info('shape of token_ids: {}'.format(token_ids.shape)) 13 | 14 | tokens = ['[CLS]'] + query_words + ['[SEP]'] + para_words 15 | token_id_list = config.bert_tokenizer.convert_tokens_to_ids(tokens) 16 | n_tokens = len(token_id_list) 17 | 18 | # logger.info('tokens: {}'.format(tokens)) 19 | # logger.info('token_id_list: {}'.format(token_id_list)) 20 | 21 | token_ids[:n_tokens] = token_id_list 22 | seg_ids[len(query_words) + 2:n_tokens] = [1] * len(para_words) 23 | token_masks[:n_tokens] = [1] * n_tokens 24 | 25 | para_in = { 26 | 'token_ids': token_ids, 27 | 'seg_ids': seg_ids, 28 | 'token_masks': token_masks, 29 | } 30 | 31 | return para_in 32 | 33 | 34 | def build_bert_x(query, doc_fp): 35 | # prep resources: query and document 36 | query_res = dataset_parser.parse_query(query) 37 | 38 | para_offset = len(query_res['words']) + 2 # 2 additional tokens for CLS and SEP 39 | doc_res = dataset_parser.parse_doc(doc_fp, concat_paras=False, offset=para_offset) 40 | 41 | # init arrays 42 | # token_ids = np.zeros([config_model['max_n_article_paras'], config_model['max_n_tokens']], dtype=np.float32) 43 | # seg_ids = np.zeros([config_model['max_n_article_paras'], 
config_model['max_n_tokens']], dtype=np.float32) 44 | # token_masks = np.zeros([config_model['max_n_article_paras'], config_model['max_n_tokens']], dtype=np.float32) 45 | 46 | token_ids = np.zeros([config_model['max_n_article_paras'], config_model['max_n_tokens']], dtype=np.int32) 47 | seg_ids = np.zeros([config_model['max_n_article_paras'], config_model['max_n_tokens']], dtype=np.int32) 48 | token_masks = np.zeros([config_model['max_n_article_paras'], config_model['max_n_tokens']], dtype=np.float32) 49 | 50 | query_sent_masks = np.zeros( 51 | [config_model['max_n_article_paras'], config_model['max_n_query_sents'], config_model['max_n_tokens']], 52 | dtype=np.float32) 53 | 54 | para_sent_masks = np.zeros( 55 | [config_model['max_n_article_paras'], config_model['max_n_para_sents'], config_model['max_n_tokens']], 56 | dtype=np.float32) 57 | 58 | para_masks = np.zeros([config_model['max_n_article_paras'], config_model['max_n_para_sents']], dtype=np.float32) 59 | 60 | # concat paras with query 61 | for para_idx, para_res in enumerate(doc_res['paras']): 62 | # input tokens 63 | para_in = _build_bert_tokens_for_para(query_words=query_res['words'], para_words=para_res['words']) 64 | token_ids[para_idx] = para_in['token_ids'] 65 | seg_ids[para_idx] = para_in['seg_ids'] 66 | token_masks[para_idx] = para_in['token_masks'] 67 | 68 | # masks 69 | query_sent_masks[para_idx] = query_res['sent_mask'] 70 | para_sent_masks[para_idx] = para_res['sent_mask'] 71 | para_masks[para_idx] = para_res['para_mask'] 72 | 73 | xx = { 74 | 'token_ids': token_ids, 75 | 'seg_ids': seg_ids, 76 | 'token_masks': token_masks, 77 | 'query_sent_masks': query_sent_masks, 78 | 'query_masks': query_res['para_mask'], 79 | 'para_sent_masks': para_sent_masks, 80 | 'para_masks': para_masks, 81 | 'doc_masks': doc_res['doc_masks'], 82 | } 83 | 84 | return xx 85 | -------------------------------------------------------------------------------- /src/frame/bert_qa/bert_input.py: 
-------------------------------------------------------------------------------- 1 | import utils.config_loader as config 2 | from utils.config_loader import config_model 3 | from data.dataset_parser import dataset_parser 4 | import numpy as np 5 | 6 | 7 | def _build_bert_tokens_for_sent(query_tokens, instance_tokens): 8 | in_size = [config_model['max_n_tokens'], ] 9 | 10 | token_ids = np.zeros(in_size, dtype=np.int32) 11 | seg_ids = np.zeros(in_size, dtype=np.int32) 12 | token_masks = np.zeros(in_size) 13 | 14 | # logger.info('shape of token_ids: {}'.format(token_ids.shape)) 15 | tokens = ['[CLS]'] + query_tokens + ['[SEP]'] + instance_tokens + ['[SEP]'] 16 | # logger.info('tokens: {}'.format(tokens)) 17 | token_id_list = config.bert_tokenizer.convert_tokens_to_ids(tokens) 18 | n_tokens = len(token_id_list) 19 | 20 | token_ids[:n_tokens] = token_id_list 21 | seg_ids[len(query_tokens) + 2:n_tokens] = [1] * (len(instance_tokens) + 1) 22 | token_masks[:n_tokens] = [1] * n_tokens 23 | 24 | sent_in = { 25 | 'token_ids': token_ids, 26 | 'seg_ids': seg_ids, 27 | 'token_masks': token_masks, 28 | } 29 | 30 | return sent_in 31 | 32 | 33 | def build_instance_tokens_with_context(sent_idx, doc_sents, window): 34 | if window <= 0: 35 | raise ValueError('Invalid window: {}'.format(window)) 36 | 37 | n_sent = len(doc_sents) 38 | context = [] 39 | 40 | context_idx = 0 41 | context_token_pat = '[unused{}] ' 42 | 43 | for i in range(window): 44 | # preceding 45 | idx_a = sent_idx - i - 1 46 | context_idx += 1 47 | context_token = context_token_pat.format(context_idx) 48 | 49 | if idx_a >= 0: 50 | context.append(context_token + doc_sents[idx_a]) 51 | else: 52 | context.append(context_token) 53 | 54 | # subsequent 55 | idx_b = sent_idx + i + 1 56 | context_idx += 1 57 | context_token = context_token_pat.format(context_idx) 58 | if idx_b < n_sent: 59 | context.append(context_token + doc_sents[idx_b]) 60 | else: 61 | context.append(context_token) 62 | 63 | sent = doc_sents[sent_idx] 64 
| context.insert(0, sent) 65 | sent = ' '.join(context) 66 | 67 | return sent 68 | 69 | 70 | def build_bert_x(query, doc_fp, window=None): 71 | # prep resources: query and document 72 | query_tokens = dataset_parser.parse_query(query) 73 | 74 | doc_res = dataset_parser.parse_doc2sents(doc_fp) 75 | in_size = [config_model['max_ns_doc'], config_model['max_n_tokens']] 76 | token_ids = np.zeros(in_size, dtype=np.int32) 77 | seg_ids = np.zeros(in_size, dtype=np.int32) 78 | token_masks = np.zeros(in_size, dtype=np.float32) 79 | 80 | # concat sentence with query 81 | for sent_idx in range(doc_res['sents']): 82 | instance_tokens = build_instance_tokens_with_context(sent_idx, 83 | doc_sents=doc_res['sents'], 84 | window=window) 85 | sent_in = _build_bert_tokens_for_sent(query_tokens=query_tokens, 86 | instance_tokens=instance_tokens) 87 | token_ids[sent_idx] = sent_in['token_ids'] 88 | seg_ids[sent_idx] = sent_in['seg_ids'] 89 | token_masks[sent_idx] = sent_in['token_masks'] 90 | 91 | xx = { 92 | 'token_ids': token_ids, 93 | 'seg_ids': seg_ids, 94 | 'token_masks': token_masks, 95 | 'doc_masks': doc_res['doc_masks'], 96 | } 97 | 98 | return xx 99 | 100 | 101 | def build_bert_sentence_x(query, sentence): 102 | query_tokens = dataset_parser.parse_query(query) 103 | instance_tokens = dataset_parser.sent2words(sentence)[:config_model['max_nw_sent']] 104 | return _build_bert_tokens_for_sent(query_tokens, instance_tokens) 105 | -------------------------------------------------------------------------------- /src/data/bert_input_sep.py: -------------------------------------------------------------------------------- 1 | import utils.config_loader as config 2 | from utils.config_loader import logger, config_model 3 | import utils.tools as tools 4 | from data.dataset_parser import dataset_parser 5 | import numpy as np 6 | 7 | 8 | def get_max_n_tokens(is_para): 9 | if is_para: 10 | return config_model['max_n_para_tokens'] 11 | 12 | return config_model['max_n_query_tokens'] 13 | 14 | 
15 | def _build_token_ids(words, max_n_tokens): 16 | token_ids = np.zeros([max_n_tokens, ], dtype=np.int32) 17 | # tokens = tools.flatten(sent_tokens) 18 | 19 | tokens = ['[CLS]'] + words 20 | token_id_list = config.bert_tokenizer.convert_tokens_to_ids(tokens) 21 | n_tokens = len(token_id_list) 22 | # logger.info('n_tokens: {}'.format(n_tokens)) 23 | token_ids[:n_tokens] = token_id_list 24 | 25 | return token_ids, n_tokens 26 | 27 | 28 | def _build_token_masks(n_tokens, max_n_tokens): 29 | token_masks = np.zeros([max_n_tokens, ]) 30 | token_masks[:n_tokens] = [1] * n_tokens 31 | return token_masks 32 | 33 | 34 | def _build_seg_ids(max_n_tokens): 35 | seg_ids = np.zeros([max_n_tokens, ], dtype=np.int32) 36 | return seg_ids 37 | 38 | 39 | def _build_bert_tokens(words, max_n_tokens): 40 | token_ids, n_tokens = _build_token_ids(words, max_n_tokens) 41 | token_masks = _build_token_masks(n_tokens, max_n_tokens) 42 | seg_ids = _build_seg_ids(max_n_tokens) 43 | 44 | res = { 45 | 'token_ids': token_ids, 46 | 'seg_ids': seg_ids, 47 | 'token_masks': token_masks, 48 | } 49 | 50 | return res 51 | 52 | 53 | def build_bert_x_sep(query, doc_fp): 54 | # todo: move initial sentence masks here for query and paras 55 | # build query x 56 | max_n_query_tokens = config_model['max_n_query_tokens'] 57 | query_res = dataset_parser.parse_query(query) 58 | query_bert_in = _build_bert_tokens(words=query_res['words'], max_n_tokens=max_n_query_tokens) 59 | 60 | # build para x 61 | doc_res = dataset_parser.parse_doc(doc_fp, concat_paras=False, offset=1) 62 | # init paras arrays 63 | max_n_article_paras = config_model['max_n_article_paras'] 64 | max_n_para_sents = config_model['max_n_para_sents'] 65 | max_n_para_tokens = config_model['max_n_para_tokens'] 66 | basic_para_size = [max_n_article_paras, max_n_para_tokens] 67 | 68 | para_token_ids = np.zeros(basic_para_size, dtype=np.int32) 69 | para_seg_ids = np.zeros(basic_para_size, dtype=np.int32) 70 | para_token_masks = np.zeros(basic_para_size) 
71 | 72 | # init sentence and para masks 73 | para_sent_masks = np.zeros([max_n_article_paras, max_n_para_sents, max_n_para_tokens], dtype=np.float32) 74 | para_masks = np.zeros([max_n_article_paras, max_n_para_sents], dtype=np.float32) 75 | 76 | # build para 77 | for para_idx, para_res in enumerate(doc_res['paras']): 78 | # bert inputs 79 | para_bert_in = _build_bert_tokens(words=para_res['words'], max_n_tokens=max_n_para_tokens) 80 | para_token_ids[para_idx] = para_bert_in['token_ids'] 81 | para_seg_ids[para_idx] = para_bert_in['seg_ids'] 82 | para_token_masks[para_idx] = para_bert_in['token_masks'] 83 | # masks 84 | para_sent_masks[para_idx] = para_res['sent_mask'] 85 | para_masks[para_idx] = para_res['para_mask'] 86 | 87 | xx = { 88 | 'query_token_ids': query_bert_in['token_ids'], 89 | 'query_seg_ids': query_bert_in['seg_ids'], 90 | 'query_token_masks': query_bert_in['token_masks'], 91 | 'query_sent_masks': query_res['sent_mask'], 92 | 'query_masks': query_res['para_mask'], 93 | 94 | 'para_token_ids': para_token_ids, 95 | 'para_seg_ids': para_seg_ids, 96 | 'para_token_masks': para_token_masks, 97 | 'para_sent_masks': para_sent_masks, 98 | 'para_masks': para_masks, 99 | 'doc_masks': doc_res['doc_masks'], 100 | } 101 | 102 | return xx 103 | -------------------------------------------------------------------------------- /src/summ/build_summary_targets.py: -------------------------------------------------------------------------------- 1 | import io 2 | import sys 3 | from os import listdir 4 | from os.path import join, dirname, abspath, isfile 5 | import shutil 6 | from shutil import copyfile 7 | 8 | import utils.config_loader as config 9 | from utils.config_loader import logger, path_parser 10 | import utils.tools as tools 11 | 12 | sys.path.insert(0, dirname(dirname(abspath(__file__)))) 13 | 14 | 15 | def match_summary_fn_with_cid(summary_fn, cid): 16 | year, cc = cid.split(config.SEP) 17 | 18 | if year == '2005': # handle cluster naming differences in 2005 
data 19 | summary_fn = summary_fn.lower() 20 | 21 | is_a_match = cc.startswith(summary_fn.split('.')[0]) 22 | 23 | return is_a_match 24 | 25 | 26 | def retrieve_refs_fps_with_cid(cc_ids, ref_dp): 27 | cid2ref_fps = dict() 28 | for cid in cc_ids: 29 | ref_fns = [ref_fn for ref_fn in listdir(ref_dp) if isfile(join(ref_dp, ref_fn))] 30 | ref_fps = [join(ref_dp, fn) for fn in ref_fns if match_summary_fn_with_cid(fn, cid)] 31 | cid2ref_fps[cid] = ref_fps 32 | 33 | return cid2ref_fps 34 | 35 | 36 | def build_summary_targets_single_file(cid2ref_fps, out_dp): 37 | for cid, ref_fps in cid2ref_fps.items(): 38 | content = [] 39 | for ref_fp in ref_fps: # handle multiple refs 40 | summary_words = [] 41 | with io.open(ref_fp, encoding='latin1') as ref_f: 42 | for line in ref_f: 43 | words = config.bert_tokenizer.tokenize(line.rstrip('\n')) 44 | summary_words.extend(words) 45 | content.append(' '.join(summary_words)) 46 | 47 | out_fp = join(out_dp, cid) 48 | with io.open(out_fp, mode='a', encoding='utf-8') as out_f: 49 | out_f.write('\n'.join(content)) 50 | logger.info('[BUILD SUMMARY TARGETS] successfully dumped {0} refs for {1}'.format(len(content), cid)) 51 | 52 | 53 | def build_summary_targets(cid2ref_fps, out_dp, tokenize=True): 54 | for cid, ref_fps in cid2ref_fps.items(): 55 | for ref_idx, ref_fp in enumerate(ref_fps, start=1): # handle multiple refs 56 | out_fp = join(out_dp, config.SEP.join((cid, str(ref_idx)))) 57 | if not tokenize: 58 | shutil.copy(ref_fp, out_fp) 59 | else: 60 | summary_sents = [] 61 | with io.open(ref_fp, encoding='latin1') as ref_f: 62 | for line in ref_f: 63 | words = config.bert_tokenizer.tokenize(line.rstrip('\n')) 64 | summary_sents.append(' '.join(words)) 65 | 66 | ref = '\n'.join(summary_sents) 67 | with io.open(out_fp, mode='a', encoding='utf-8') as out_f: 68 | out_f.write(ref) 69 | 70 | 71 | def build_summary_targets_annually(year, single_file, tokenize, manual=False): 72 | cc_ids = tools.get_cc_ids(year, model_mode='test') 73 | if 
tokenize: 74 | out_dir = '{}_tokenized'.format(year) 75 | elif manual: 76 | out_dir = '{}_manual'.format(year) 77 | else: 78 | out_dir = year 79 | out_dp = join(path_parser.data_summary_targets, out_dir) 80 | 81 | if manual: 82 | in_dir = '{}_manual'.format(year) 83 | else: 84 | in_dir = year 85 | ref_dp = join(path_parser.data_summary_refs, in_dir) 86 | cid2ref_fps = retrieve_refs_fps_with_cid(cc_ids, ref_dp) 87 | 88 | if single_file: 89 | build_summary_targets_single_file(cid2ref_fps, out_dp) 90 | else: 91 | build_summary_targets(cid2ref_fps, out_dp, tokenize=tokenize) 92 | 93 | 94 | def build_summary_targets_end2end(single_file, tokenize, manual=False): 95 | for year in config.years: 96 | build_summary_targets_annually(year, single_file, tokenize, manual) 97 | 98 | 99 | if __name__ == '__main__': 100 | # build_summary_targets(year='2005') 101 | build_summary_targets_annually(year='2007', single_file=False, tokenize=False, manual=True) 102 | -------------------------------------------------------------------------------- /src/tools/general_tools.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | from os.path import join, dirname, abspath 4 | 5 | sys_path = dirname(dirname(abspath(__file__))) 6 | parent_sys_path = dirname(sys_path) 7 | 8 | if sys_path not in sys.path: 9 | sys.path.insert(0, sys_path) 10 | if parent_sys_path not in sys.path: 11 | sys.path.insert(0, parent_sys_path) 12 | 13 | import utils.config_loader as config 14 | from utils.config_loader import logger, path_parser, config_meta 15 | from data.dataset_parser import dataset_parser 16 | import utils.tools as tools 17 | import io 18 | 19 | 20 | def build_test_cid_query_dicts(tokenize_narr, concat_title_narr, query_type=None): 21 | """ 22 | 23 | :param tokenize_narr: bool 24 | :param concat_title_narr: bool 25 | :return: 26 | """ 27 | query_info = dict() 28 | for year in config.years: 29 | query_params = { 30 | 'year': 
year, 31 | 'tokenize_narr': tokenize_narr, 32 | 'concat_title_narr': concat_title_narr, 33 | } 34 | 35 | annual_query_info = dataset_parser.build_query_info(**query_params) 36 | query_info = { 37 | **annual_query_info, 38 | **query_info, 39 | } 40 | 41 | cids = tools.get_test_cc_ids() 42 | test_cid_query_dicts = [] 43 | 44 | for cid in cids: 45 | query = tools.get_query_w_cid(query_info, cid=cid) 46 | 47 | if query_type: 48 | query = query[query_type] 49 | 50 | print('query: {}'.format(query)) 51 | test_cid_query_dicts.append({ 52 | 'cid': cid, 53 | 'query': query, 54 | }) 55 | 56 | return test_cid_query_dicts 57 | 58 | 59 | def build_tdqfs_cid_query_dicts(query_fp, proc=True): 60 | """ 61 | :return: 62 | """ 63 | assert config_meta['test_year'] == 'tdqfs' 64 | lines = io.open(query_fp).readlines() 65 | cid_query_dicts = [] 66 | 67 | items = config_meta['test_year'].split('-') 68 | for line in lines: 69 | cid, dom, query = line.rstrip('\n').split('\t') 70 | if proc: 71 | query = dataset_parser._proc_sent(query, rm_dialog=False, rm_stop=False, stem=True, rm_short=None) 72 | cid_query_dicts.append({ 73 | 'cid': cid, 74 | 'query': query, 75 | }) 76 | return cid_query_dicts 77 | 78 | 79 | def build_tdqfs_oracle_test_cid_query_dicts(query_fp): 80 | def _get_ref(cid): 81 | REF_DP = path_parser.data_tdqfs_summary_targets 82 | fp = join(REF_DP, '{}_{}'.format(cid, 0)) 83 | ref = '' 84 | # for fn in fns: 85 | lines = io.open(fp, encoding='utf-8').readlines() 86 | for line in lines: 87 | ref += line.rstrip('\n') 88 | 89 | return ref 90 | 91 | assert config_meta['test_year'] == 'tdqfs' 92 | lines = io.open(query_fp).readlines() 93 | cids = [line.rstrip('\n').split('\t')[0] for line in lines] 94 | test_cid_query_dicts = [] 95 | for cid in cids: 96 | ref = _get_ref(cid) 97 | logger.info('cid {}: {}'.format(cid, ref)) 98 | 99 | test_cid_query_dicts.append({ 100 | 'cid': cid, 101 | 'query': ref, 102 | }) 103 | return test_cid_query_dicts 104 | 105 | 106 | def 
build_oracle_test_cid_query_dicts(): 107 | def _get_ref(cid): 108 | REF_DP = join(path_parser.data_summary_targets, config.test_year) 109 | fp = join(REF_DP, '{}_{}'.format(cid, 1)) 110 | ref = '' 111 | lines = io.open(fp, encoding='utf-8').readlines() 112 | for line in lines: 113 | ref += line.rstrip('\n') 114 | 115 | return ref 116 | 117 | test_cid_query_dicts = [] 118 | cids = tools.get_test_cc_ids() 119 | 120 | for cid in cids: 121 | ref = _get_ref(cid) 122 | logger.info('cid {}: {}'.format(cid, ref)) 123 | 124 | test_cid_query_dicts.append({ 125 | 'cid': cid, 126 | 'query': ref, 127 | }) 128 | 129 | return test_cid_query_dicts 130 | 131 | 132 | if __name__ == '__main__': 133 | build_test_cid_query_dicts(tokenize_narr=None, concat_title_narr=True) -------------------------------------------------------------------------------- /src/frame/bert_passage/build_passage_tdqfs.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | import io 4 | import os 5 | from os import listdir 6 | from os.path import join, dirname, abspath, exists 7 | 8 | sys.path.insert(0, dirname(dirname(abspath(__file__)))) 9 | sys.path.insert(0, dirname(dirname(dirname(abspath(__file__))))) 10 | 11 | import tools.general_tools as general_tools 12 | import utils.config_loader as config 13 | import utils.tools as tools 14 | from utils.config_loader import logger, path_parser, config_model 15 | 16 | from data.dataset_parser import dataset_parser 17 | from tqdm import tqdm 18 | import dill 19 | import itertools 20 | from frame.bert_passage.passage_obj import SentObj, PassageObj 21 | 22 | 23 | sentence_dp = path_parser.data_tdqfs_sentences 24 | passages_dp = path_parser.data_tdqfs_passages 25 | query_fp = path_parser.data_tdqfs_queries 26 | test_cid_query_dicts = general_tools.build_tdqfs_cid_query_dicts(query_fp=query_fp, proc=True) 27 | 28 | 29 | def get_sentences(cid): 30 | cc_dp = join(sentence_dp, cid) 31 | fns = [fn 
for fn in listdir(cc_dp)] 32 | lines = itertools.chain(*[io.open(join(cc_dp, fn)).readlines() for fn in fns]) 33 | sentences = [line.strip('\n') for line in lines] 34 | 35 | original_sents = [] 36 | processed_sents = [] 37 | for ss in sentences: 38 | ss_origin = dataset_parser._proc_sent(ss, rm_dialog=False, rm_stop=False, stem=False) 39 | ss_proc = dataset_parser._proc_sent(ss, rm_dialog=False, rm_stop=True, stem=True) 40 | 41 | if ss_proc: # make sure the sent is not removed, i.e., is not empty and is not in a dialog 42 | original_sents.append(ss_origin) 43 | processed_sents.append(ss_proc) 44 | 45 | return [original_sents], [processed_sents] 46 | 47 | 48 | def _passage_core(cid, query, passage_size, stride): 49 | original_sents, processed_sents = get_sentences(cid) # 2d lists, docs => sents 50 | # logger.info('#doc: {}'.format(len(original_sents))) 51 | 52 | # build sent_objs 53 | sent_objs = [] # organized by doc 54 | sent_idx = 0 55 | for doc_idx in range(len(original_sents)): 56 | sent_objs_doc = [] 57 | for original_s, proc_s in zip(original_sents[doc_idx], processed_sents[doc_idx]): 58 | sid = config.SEP.join([cid, str(sent_idx)]) 59 | so = SentObj(sid=sid, original_sent=original_s, proc_sent=proc_s) 60 | sent_objs_doc.append(so) 61 | sent_idx += 1 62 | 63 | sent_objs.append(sent_objs_doc) 64 | 65 | # build passage objs 66 | passage_objs = [] 67 | for sent_objs_doc in sent_objs: 68 | start = 0 69 | # make sure the last sentence whose length < stride will be discarded 70 | while start + stride < len(sent_objs_doc): 71 | pid = config.SEP.join([cid, str(len(passage_objs))]) 72 | 73 | target_sent_objs = sent_objs_doc[start:start+passage_size] 74 | po = PassageObj(pid=pid, query=query, narr=query, sent_objs=target_sent_objs) 75 | passage_objs.append(po) 76 | 77 | start += stride 78 | 79 | return passage_objs 80 | 81 | 82 | def _dump_passages(cid, passage_objs): 83 | cc_dp = join(passages_dp, cid) 84 | if not exists(cc_dp): # remove previous output 85 | 
os.mkdir(cc_dp) 86 | 87 | for po in passage_objs: 88 | with open(join(cc_dp, po.pid), 'wb') as f: 89 | dill.dump(po, f) 90 | 91 | logger.info('[_dump_passages] Dump {} passage objects to {}'.format(len(passage_objs), cc_dp)) 92 | 93 | 94 | def build_passages(passage_size, stride): 95 | """ 96 | 97 | :param passage_size: max number of sentences in a passage 98 | :param use_stride: 99 | :return: 100 | """ 101 | for cid_query_dict in tqdm(test_cid_query_dicts): 102 | core_params = { 103 | **cid_query_dict, 104 | 'passage_size': passage_size, 105 | 'stride': stride, 106 | } 107 | passage_objs = _passage_core(**core_params) 108 | _dump_passages(cid=cid_query_dict['cid'], passage_objs=passage_objs) 109 | 110 | 111 | if __name__ == '__main__': 112 | build_passages(passage_size=config_model['ns_passage'], 113 | stride=config_model['stride']) 114 | -------------------------------------------------------------------------------- /src/baselines/human.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from os import listdir 4 | from os.path import join, dirname, abspath, isfile, exists 5 | sys_path = dirname(dirname(abspath(__file__))) 6 | parent_sys_path = dirname(sys_path) 7 | 8 | if sys_path not in sys.path: 9 | sys.path.insert(0, sys_path) 10 | if parent_sys_path not in sys.path: 11 | sys.path.insert(0, parent_sys_path) 12 | 13 | import shutil 14 | import utils.config_loader as config 15 | from utils.config_loader import logger, path_parser 16 | from pyrouge import Rouge155 17 | 18 | """ 19 | This module computes Gold scores which represents human intra-agreement. 
20 | """ 21 | 22 | MODEL_DP = join(path_parser.data_summary_targets, config.test_year) 23 | MODEL_NAME_TEMP = 'human_model' 24 | SYSTEM_NAME_TEMP = 'human_system' 25 | 26 | MODEL_DP_TEMP = join(path_parser.summary_text, MODEL_NAME_TEMP) 27 | SYSTEM_DP_TEMP = join(path_parser.summary_text, SYSTEM_NAME_TEMP) 28 | fns = [fn for fn in listdir(MODEL_DP) if isfile(join(MODEL_DP, fn))] 29 | ROUGE_METRICS = ['1', '2', 'SU4'] 30 | N_REFS = 4 31 | 32 | def build_eval_dirs(summary_index): 33 | system_fns = [] 34 | model_fns = [] 35 | for fn in fns: 36 | if fn.endswith(str(summary_index)): 37 | system_fns.append(fn) 38 | else: 39 | model_fns.append(fn) 40 | 41 | assert len(model_fns)/len(system_fns) == N_REFS-1 42 | 43 | # remove previous output 44 | for temp_dp in (MODEL_DP_TEMP, SYSTEM_DP_TEMP): 45 | if exists(temp_dp): 46 | shutil.rmtree(temp_dp) 47 | os.mkdir(temp_dp) 48 | 49 | for fn in system_fns: 50 | shutil.copyfile(join(MODEL_DP, fn), join(SYSTEM_DP_TEMP, fn[:-2])) 51 | 52 | for fn in model_fns: 53 | shutil.copyfile(join(MODEL_DP, fn), join(MODEL_DP_TEMP, fn)) 54 | 55 | def proc_output(output): 56 | start_pat = '1 ROUGE-{} Average' 57 | 58 | output = '\n'.join(output.split('\n')[1:]) 59 | inter_breaker = '\n---------------------------------------------\n' 60 | intra_breaker = '\n.............................................\n' 61 | 62 | tg2ck = {} 63 | for ck in output.split(inter_breaker): 64 | ck = ck.strip('\n') 65 | if ck: 66 | ck = ck.split(intra_breaker)[0] 67 | for tg in ROUGE_METRICS: 68 | if ck.startswith(start_pat.format(tg)): 69 | tg2ck[tg] = ck 70 | break 71 | 72 | num_idx = 3 73 | tg2recall = {} 74 | tg2f1 = {} 75 | 76 | for tg, ck in tg2ck.items(): 77 | lines = ck.split('\n') 78 | recall = float(lines[0].split(' ')[num_idx]) * 100 79 | f1 = float(lines[2].split(' ')[num_idx]) * 100 80 | 81 | tg2recall[tg] = recall 82 | tg2f1[tg] = f1 83 | 84 | return tg2recall, tg2f1 85 | 86 | def compute_rouge_for_human(system_dp, model_dp): 87 | rouge_args = '-a -l 250 
-n 2 -m -2 4 -u -c 95 -r 1000 -f A -p 0.5 -t 0 -d -e {} -x'.format( 88 | path_parser.rouge_dir) 89 | 90 | r = Rouge155(rouge_args=rouge_args) 91 | r.system_dir = system_dp 92 | r.model_dir = model_dp 93 | 94 | gen_sys_file_pat = '(\w*)' 95 | gen_model_file_pat = '#ID#_[\d]' 96 | r.system_filename_pattern = gen_sys_file_pat 97 | r.model_filename_pattern = gen_model_file_pat 98 | 99 | output = r.convert_and_evaluate() 100 | tg2recall, tg2f1 = proc_output(output) 101 | return tg2recall, tg2f1 102 | 103 | def compute_rouge(): 104 | tg2recall_all = {} 105 | tg2f1_all = {} 106 | for idx in range(N_REFS): 107 | build_eval_dirs(summary_index=idx+1) 108 | tg2recall, tg2f1 = compute_rouge_for_human(system_dp=SYSTEM_DP_TEMP, model_dp=MODEL_DP_TEMP) 109 | logger.info('tg2f1: {}'.format(tg2f1)) 110 | for metric in ROUGE_METRICS: 111 | if metric in tg2recall_all: 112 | tg2recall_all[metric] += tg2recall[metric] 113 | else: 114 | tg2recall_all[metric] = tg2recall[metric] 115 | 116 | if metric in tg2f1_all: 117 | tg2f1_all[metric] += tg2f1[metric] 118 | else: 119 | tg2f1_all[metric] = tg2f1[metric] 120 | 121 | for metric in ROUGE_METRICS: 122 | tg2recall_all[metric] /= N_REFS 123 | tg2f1_all[metric] /= N_REFS 124 | 125 | recall_str = 'Recall:\t{}'.format('\t'.join(['{0:.2f}'.format(tg2recall_all[metric]) for metric in ROUGE_METRICS])) 126 | f1_str = 'F1:\t{}'.format('\t'.join(['{0:.2f}'.format(tg2f1_all[metric]) for metric in ROUGE_METRICS])) 127 | 128 | output = '\n' + '\n'.join((f1_str, recall_str)) 129 | logger.info(output) 130 | 131 | if __name__ == '__main__': 132 | compute_rouge() 133 | -------------------------------------------------------------------------------- /src/frame/bert_passage/bert_input.py: -------------------------------------------------------------------------------- 1 | import utils.config_loader as config 2 | from utils.config_loader import config_model 3 | from data.dataset_parser import dataset_parser 4 | import numpy as np 5 | 6 | 7 | def 
_build_bert_tokens_for_passage(query_tokens, instance_tokens): 8 | in_size = [config_model['max_n_tokens'], ] 9 | 10 | token_ids = np.zeros(in_size, dtype=np.int32) 11 | seg_ids = np.zeros(in_size, dtype=np.int32) 12 | token_masks = np.zeros(in_size) 13 | 14 | # logger.info('shape of token_ids: {}'.format(token_ids.shape)) 15 | tokens = ['[CLS]'] + query_tokens + ['[SEP]'] + instance_tokens + ['[SEP]'] 16 | # logger.info('tokens: {}'.format(tokens)) 17 | token_id_list = config.bert_tokenizer.convert_tokens_to_ids(tokens) 18 | n_tokens = len(token_id_list) 19 | 20 | token_ids[:n_tokens] = token_id_list 21 | seg_ids[len(query_tokens) + 2:n_tokens] = [1] * (len(instance_tokens) + 1) 22 | token_masks[:n_tokens] = [1] * n_tokens 23 | 24 | passage_in = { 25 | 'token_ids': token_ids, 26 | 'seg_ids': seg_ids, 27 | 'token_masks': token_masks, 28 | } 29 | 30 | return passage_in 31 | 32 | 33 | def build_instance_tokens_with_context(sent_idx, doc_sents, window): 34 | if window <= 0: 35 | raise ValueError('Invalid window: {}'.format(window)) 36 | 37 | n_sent = len(doc_sents) 38 | context = [] 39 | 40 | context_idx = 0 41 | context_token_pat = '[unused{}] ' 42 | 43 | for i in range(window): 44 | # preceding 45 | idx_a = sent_idx - i - 1 46 | context_idx += 1 47 | context_token = context_token_pat.format(context_idx) 48 | 49 | if idx_a >= 0: 50 | context.append(context_token + doc_sents[idx_a]) 51 | else: 52 | context.append(context_token) 53 | 54 | # subsequent 55 | idx_b = sent_idx + i + 1 56 | context_idx += 1 57 | context_token = context_token_pat.format(context_idx) 58 | if idx_b < n_sent: 59 | context.append(context_token + doc_sents[idx_b]) 60 | else: 61 | context.append(context_token) 62 | 63 | sent = doc_sents[sent_idx] 64 | context.insert(0, sent) 65 | sent = ' '.join(context) 66 | 67 | return sent 68 | 69 | 70 | def build_bert_x(query, doc_fp, window=None): 71 | # prep resources: query and document 72 | query_tokens = dataset_parser.parse_query(query) 73 | 74 | doc_res 
= dataset_parser.parse_doc2sents(doc_fp) 75 | in_size = [config_model['max_ns_doc'], config_model['max_n_tokens']] 76 | token_ids = np.zeros(in_size, dtype=np.int32) 77 | seg_ids = np.zeros(in_size, dtype=np.int32) 78 | token_masks = np.zeros(in_size, dtype=np.float32) 79 | 80 | # concat sentence with query 81 | for sent_idx in range(doc_res['sents']): 82 | instance_tokens = build_instance_tokens_with_context(sent_idx, 83 | doc_sents=doc_res['sents'], 84 | window=window) 85 | sent_in = _build_bert_tokens_for_passage(query_tokens=query_tokens, instance_tokens=instance_tokens) 86 | token_ids[sent_idx] = sent_in['token_ids'] 87 | seg_ids[sent_idx] = sent_in['seg_ids'] 88 | token_masks[sent_idx] = sent_in['token_masks'] 89 | 90 | xx = { 91 | 'token_ids': token_ids, 92 | 'seg_ids': seg_ids, 93 | 'token_masks': token_masks, 94 | 'doc_masks': doc_res['doc_masks'], 95 | } 96 | 97 | return xx 98 | 99 | 100 | def build_bert_passage_x(query, passage): 101 | """ 102 | 103 | :param query: 104 | :param passage: a list of sentences 105 | :return: 106 | """ 107 | query_tokens = dataset_parser.parse_query(query) 108 | 109 | passage_tokens = [] 110 | 111 | seg_indices = - np.ones((config_model['ns_passage'], 2), dtype=np.int32) 112 | start = len(query_tokens) + 2 113 | 114 | if len(passage) > config_model['ns_passage']: 115 | raise ValueError('Invalid #sents: {}'.format(len(passage))) 116 | 117 | for sent_idx, sent in enumerate(passage): 118 | sent_tokens = dataset_parser.sent2words(sent)[:config_model['max_nw_sent']] 119 | passage_tokens.extend(sent_tokens) 120 | 121 | end = start + len(sent_tokens) 122 | seg_indices[sent_idx, 0] = start 123 | seg_indices[sent_idx, 1] = end 124 | start = end 125 | 126 | seg_indices = np.array(seg_indices, dtype=np.int32) 127 | passage_in = _build_bert_tokens_for_passage(query_tokens, passage_tokens) 128 | 129 | res = { 130 | **passage_in, 131 | 'seg_indices': seg_indices, 132 | } 133 | 134 | return res 135 | 
-------------------------------------------------------------------------------- /src/baselines/lexrank/grsum.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | from os.path import join, dirname, abspath, exists 4 | 5 | sys_path = dirname(dirname(abspath(__file__))) 6 | parent_sys_path = dirname(sys_path) 7 | 8 | if sys_path not in sys.path: 9 | sys.path.insert(0, sys_path) 10 | if parent_sys_path not in sys.path: 11 | sys.path.insert(0, parent_sys_path) 12 | 13 | from tqdm import tqdm 14 | import utils.config_loader as config 15 | from utils.config_loader import logger, config_meta 16 | import utils.graph_io as graph_io 17 | import utils.graph_tools as graph_tools 18 | 19 | import summ.select_sent as select_sent 20 | import tools.tfidf_tools as tfidf_tools 21 | import tools.general_tools as general_tools 22 | import tools.vec_tools as vec_tools 23 | 24 | # MODEL_NAME = 'grsum-{}'.format(config.test_year) 25 | MODEL_NAME = 'grsum-{}'.format(config_meta['test_year']) 26 | DIVERSITY_PARAM_TUPLE = (10, 'wan') 27 | COS_THRESHOLD = 1.0 28 | DAMP = 0.85 29 | RM_DIALOG = False 30 | 31 | QUERY_TYPE = None # config.NARR, config.TITLE, None 32 | CONCAT_TITLE_NARR = False if QUERY_TYPE else True 33 | 34 | def _build_components(cid, query): 35 | sim_items = tfidf_tools.build_sim_items_e2e_tfidf_with_lexrank(cid, query, rm_dialog=RM_DIALOG) 36 | 37 | sim_mat = vec_tools.norm_sim_mat(sim_mat=sim_items['doc_sim_mat'], max_min_scale=False) 38 | rel_vec = vec_tools.norm_rel_scores(rel_scores=sim_items['rel_scores'], max_min_scale=False) 39 | logger.info('rel_vec: {}'.format(rel_vec)) 40 | 41 | if len(rel_vec) != len(sim_mat): 42 | raise ValueError('Incompatible sim_mat size: {} and rel_vec size: {} for cid: {}'.format( 43 | sim_mat.shape, rel_vec.shape, cid)) 44 | 45 | processed_sents = sim_items['processed_sents'] 46 | sid2abs = {} 47 | sid_abs = 0 48 | for doc_idx, doc in enumerate(processed_sents): 
49 | for sent_idx, sent in enumerate(doc): 50 | sid = config.SEP.join((str(doc_idx), str(sent_idx))) 51 | sid2abs[sid] = sid_abs 52 | sid_abs += 1 53 | 54 | components = { 55 | 'sim_mat': sim_mat, 56 | 'rel_vec': rel_vec, 57 | 'sid2abs': sid2abs, 58 | } 59 | 60 | return components 61 | 62 | 63 | def build_components_e2e(): 64 | dp_params = { 65 | 'model_name': MODEL_NAME, 66 | 'n_iter': None, 67 | 'mode': 'w', 68 | } 69 | 70 | summ_comp_root = graph_io.get_summ_comp_root(**dp_params) 71 | sim_mat_dp = graph_io.get_sim_mat_dp(summ_comp_root, mode='w') 72 | rel_vec_dp = graph_io.get_rel_vec_dp(summ_comp_root, mode='w') 73 | sid2abs_dp = graph_io.get_sid2abs_dp(summ_comp_root, mode='w') 74 | 75 | logger.info('sim_mat_dp: {}'.format(sim_mat_dp)) 76 | logger.info('rel_vec_dp: {}'.format(rel_vec_dp)) 77 | logger.info('sid2abs_dp: {}'.format(sid2abs_dp)) 78 | 79 | test_cid_query_dicts = general_tools.build_test_cid_query_dicts(tokenize_narr=False, 80 | concat_title_narr=CONCAT_TITLE_NARR, 81 | query_type=QUERY_TYPE) 82 | 83 | for params in tqdm(test_cid_query_dicts): 84 | logger.info('cid: {}'.format(params['cid'])) 85 | 86 | components = _build_components(**params) 87 | graph_io.dump_sim_mat(sim_mat=components['sim_mat'], sim_mat_dp=sim_mat_dp, cid=params['cid']) 88 | graph_io.dump_rel_vec(rel_vec=components['rel_vec'], rel_vec_dp=rel_vec_dp, cid=params['cid']) 89 | graph_io.dump_sid2abs(sid2abs=components['sid2abs'], sid2abs_dp=sid2abs_dp, cid=params['cid']) 90 | 91 | 92 | def score_e2e(): 93 | if DAMP == 1.0: 94 | damp = 0.85 95 | use_rel_vec = False 96 | else: 97 | damp = DAMP 98 | use_rel_vec = True 99 | 100 | graph_tools.score_end2end(model_name=MODEL_NAME, 101 | damp=damp, 102 | use_rel_vec=use_rel_vec, 103 | rm_dialog=RM_DIALOG) 104 | 105 | 106 | def rank_e2e(): 107 | graph_tools.rank_end2end(model_name=MODEL_NAME, 108 | diversity_param_tuple=DIVERSITY_PARAM_TUPLE, 109 | retrieved_dp=None, 110 | rm_dialog=RM_DIALOG) 111 | 112 | 113 | def select_e2e(): 114 | params 
= { 115 | 'model_name': MODEL_NAME, 116 | 'diversity_param_tuple': DIVERSITY_PARAM_TUPLE, 117 | 'cos_threshold': COS_THRESHOLD, 118 | 'rm_dialog': RM_DIALOG, 119 | } 120 | select_sent.select_end2end(**params) 121 | 122 | 123 | if __name__ == '__main__': 124 | build_components_e2e() 125 | score_e2e() 126 | rank_e2e() 127 | select_e2e() 128 | -------------------------------------------------------------------------------- /src/utils/graph_io.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | from os.path import join, dirname, abspath, exists 4 | import sys 5 | import json 6 | import numpy as np 7 | from utils.config_loader import path_parser 8 | import utils.tools as tools 9 | 10 | 11 | sys.path.insert(0, dirname(dirname(abspath(__file__)))) 12 | 13 | 14 | def get_summ_comp_root(model_name, n_iter, mode): 15 | """ 16 | 17 | :param model_name: 18 | :param attn_weigh: 19 | :param doc_weigh: 20 | :param n_iter: 21 | :param mode: w or r 22 | :return: 23 | """ 24 | dn_items = tools.get_dir_name_items(model_name, n_iter) 25 | root_dp = join(path_parser.graph, '-'.join(dn_items)) 26 | 27 | if mode == 'r': 28 | if not exists(root_dp): 29 | raise ValueError('root_dp does not exists: {}'.format(root_dp)) 30 | elif mode == 'w': 31 | if exists(root_dp): 32 | raise ValueError('root_dp already exists: {}'.format(root_dp)) 33 | os.mkdir(root_dp) 34 | else: 35 | raise ValueError('Invalid mode: {}'.format(mode)) 36 | 37 | return root_dp 38 | 39 | 40 | def get_sim_mat_dp(summ_comp_root, mode): 41 | sim_mat_dp = join(summ_comp_root, 'sim_mat') 42 | 43 | if mode == 'r': 44 | if not exists(sim_mat_dp): 45 | raise ValueError('sim_mat_dp does not exists: {}'.format(sim_mat_dp)) 46 | elif mode == 'w': 47 | if exists(sim_mat_dp): 48 | raise ValueError('sim_mat_dp already exists: {}'.format(sim_mat_dp)) 49 | os.mkdir(sim_mat_dp) 50 | else: 51 | raise ValueError('Invalid mode: {}'.format(mode)) 52 | 53 | return sim_mat_dp 54 | 
55 | 56 | def get_rel_vec_dp(summ_comp_root, mode): 57 | rel_vec_dp = join(summ_comp_root, 'rel_vec') 58 | 59 | if mode == 'r': 60 | if not exists(rel_vec_dp): 61 | raise ValueError('rel_vec_dp does not exists: {}'.format(rel_vec_dp)) 62 | elif mode == 'w': 63 | if exists(rel_vec_dp): 64 | raise ValueError('rel_vec_dp already exists: {}'.format(rel_vec_dp)) 65 | os.mkdir(rel_vec_dp) 66 | else: 67 | raise ValueError('Invalid mode: {}'.format(mode)) 68 | 69 | return rel_vec_dp 70 | 71 | 72 | def get_sid2abs_dp(summ_comp_root, mode): 73 | sid2abs_dp = join(summ_comp_root, 'sid2abs') 74 | 75 | if mode == 'r': 76 | if not exists(sid2abs_dp): 77 | raise ValueError('sid2abs_dp does not exists: {}'.format(sid2abs_dp)) 78 | elif mode == 'w': 79 | if exists(sid2abs_dp): 80 | raise ValueError('sid2abs_dp already exists: {}'.format(sid2abs_dp)) 81 | os.mkdir(sid2abs_dp) 82 | else: 83 | raise ValueError('Invalid mode: {}'.format(mode)) 84 | 85 | return sid2abs_dp 86 | 87 | 88 | def get_sid2score_dp(summ_comp_root, mode): 89 | sid2score_dp = join(summ_comp_root, 'sid2score') 90 | 91 | if mode == 'r': 92 | if not exists(sid2score_dp): 93 | raise ValueError('sid2score_dp does not exists: {}'.format(sid2score_dp)) 94 | elif mode == 'w': 95 | if exists(sid2score_dp): 96 | raise ValueError('sid2score_dp already exists: {}'.format(sid2score_dp)) 97 | os.mkdir(sid2score_dp) 98 | else: 99 | raise ValueError('Invalid mode: {}'.format(mode)) 100 | 101 | return sid2score_dp 102 | 103 | 104 | def dump_sim_mat(sim_mat, sim_mat_dp, cid): 105 | np.save(join(sim_mat_dp, cid), sim_mat) 106 | 107 | 108 | def load_sim_mat(sim_mat_dp, cid): 109 | return np.load(join(sim_mat_dp, '{}.npy'.format(cid))) 110 | 111 | 112 | def dump_rel_vec(rel_vec, rel_vec_dp, cid): 113 | np.save(join(rel_vec_dp, cid), rel_vec) 114 | 115 | 116 | def load_rel_vec(rel_vec_dp, cid): 117 | return np.load(join(rel_vec_dp, '{}.npy'.format(cid))) 118 | 119 | 120 | def dump_sid2abs(sid2abs, sid2abs_dp, cid): 121 | fp = 
'{}.json'.format(join(sid2abs_dp, cid)) 122 | with io.open(fp, 'a') as f: 123 | json.dump(sid2abs, f) 124 | 125 | 126 | def load_sid2abs(sid2abs_dp, cid): 127 | fp = '{}.json'.format(join(sid2abs_dp, cid)) 128 | with io.open(fp, 'r') as f: 129 | return json.load(f) 130 | 131 | 132 | def dump_sid2score(sid2score, sid2score_dp, cid): 133 | fp = '{}.json'.format(join(sid2score_dp, cid)) 134 | with io.open(fp, 'a') as f: 135 | json.dump(sid2score, f) 136 | 137 | 138 | def load_sid2score(sid2score_dp, cid): 139 | fp = '{}.json'.format(join(sid2score_dp, cid)) 140 | with io.open(fp, 'r') as f: 141 | return json.load(f) 142 | 143 | 144 | def load_components(sim_mat_dp, rel_vec_dp, sid2abs_dp, cid): 145 | sim_mat = load_sim_mat(sim_mat_dp, cid) 146 | rel_vec = load_rel_vec(rel_vec_dp, cid) 147 | sid2abs = load_sid2abs(sid2abs_dp, cid) 148 | 149 | components = { 150 | 'sim_mat': sim_mat, 151 | 'rel_vec': rel_vec, 152 | 'sid2abs': sid2abs, 153 | } 154 | return components -------------------------------------------------------------------------------- /src/frame/bert_qa/data_pipe_cluster.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | from torch.utils.data import Dataset, DataLoader 4 | from os.path import dirname, abspath 5 | import sys 6 | sys.path.insert(0, dirname(dirname(abspath(__file__)))) 7 | 8 | import utils.config_loader as config 9 | from utils.config_loader import config_model 10 | from data.dataset_parser import dataset_parser 11 | import data.data_tools as data_tools 12 | 13 | import utils.tools as tools 14 | from frame.ir.ir_tools import load_retrieved_sentences 15 | from frame.bert_qa.bert_input import build_bert_sentence_x 16 | 17 | 18 | class ClusterDataset(Dataset): 19 | def __init__(self, cid, query, retrieve_dp, transform=None): 20 | super(ClusterDataset, self).__init__() 21 | original_sents, _ = load_retrieved_sentences(retrieved_dp=retrieve_dp, cid=cid) 22 | 
self.sentences = original_sents[0] 23 | 24 | self.query = query 25 | self.yy = 0.0 # 0.0 26 | 27 | self.transform = transform 28 | 29 | def __len__(self): 30 | return len(self.sentences) 31 | 32 | @staticmethod 33 | def _vec_label(yy): 34 | if yy == '-1.0': 35 | yy = 0.0 36 | return np.array([yy], dtype=np.float32) 37 | 38 | def __getitem__(self, index): 39 | """ 40 | get an item from self.doc_ids. 41 | 42 | return a sample: (xx, yy) 43 | """ 44 | # build xx 45 | xx = build_bert_sentence_x(self.query, sentence=self.sentences[index]) 46 | 47 | # build yy 48 | yy = self._vec_label(self.yy) 49 | 50 | sample = { 51 | **xx, 52 | 'yy': yy, 53 | } 54 | 55 | if self.transform: 56 | sample = self.transform(sample) 57 | 58 | return sample 59 | 60 | 61 | class ClusterDataLoader(DataLoader): 62 | def __init__(self, cid, query, retrieve_dp, transform=data_tools.ToTensor()): 63 | dataset = ClusterDataset(cid, query, retrieve_dp=retrieve_dp) 64 | self.transform = transform 65 | self.cid = cid 66 | 67 | super(ClusterDataLoader, self).__init__(dataset=dataset, 68 | batch_size=config_model['d_batch'], 69 | shuffle=False, 70 | num_workers=3, # 3 71 | drop_last=False) 72 | 73 | def _generator(self, super_iter): 74 | while True: 75 | batch = next(super_iter) 76 | batch = self.transform(batch) 77 | yield batch 78 | 79 | def __iter__(self): 80 | super_iter = super(ClusterDataLoader, self).__iter__() 81 | return self._generator(super_iter) 82 | 83 | 84 | class QSDataLoader: 85 | """ 86 | iter over all clusters. 87 | each cluster is handled with a separate data loader. 88 | 89 | tokenize_narr: whether tokenize query into sentences. 
90 | """ 91 | 92 | def __init__(self, tokenize_narr, query_type, retrieve_dp): 93 | if query_type == config.TITLE: 94 | query_dict = dataset_parser.get_cid2title() 95 | elif query_type == config.NARR: 96 | query_dict = dataset_parser.get_cid2narr() 97 | elif query_type == config.QUERY: 98 | query_dict = dataset_parser.get_cid2query(tokenize_narr) 99 | else: 100 | raise ValueError('Invalid query_type: {}'.format(query_type)) 101 | 102 | cids = tools.get_test_cc_ids() 103 | 104 | self.loader_init_params = [] 105 | for cid in cids: 106 | query = tools.get_query_w_cid(query_dict, cid=cid) 107 | # query = query_dict[cid] 108 | self.loader_init_params.append({ 109 | 'cid': cid, 110 | 'query': query, 111 | 'retrieve_dp': retrieve_dp, 112 | }) 113 | 114 | def _loader_generator(self): 115 | for params in self.loader_init_params: 116 | c_loader = ClusterDataLoader(**params) 117 | yield c_loader 118 | 119 | def __iter__(self): 120 | return self._loader_generator() 121 | 122 | 123 | class TdqfsQSDataLoader: 124 | """ 125 | iter over all clusters. 126 | each cluster is handled with a separate data loader. 
127 | """ 128 | 129 | def __init__(self, test_cid_query_dicts, retrieve_dp): 130 | self.loader_init_params = [] 131 | for cq_dict in test_cid_query_dicts: 132 | self.loader_init_params.append({ 133 | 'cid': cq_dict['cid'], 134 | 'query': cq_dict['query'], 135 | 'retrieve_dp': retrieve_dp, 136 | }) 137 | 138 | def _loader_generator(self): 139 | for params in self.loader_init_params: 140 | c_loader = ClusterDataLoader(**params) 141 | yield c_loader 142 | 143 | def __iter__(self): 144 | return self._loader_generator() 145 | -------------------------------------------------------------------------------- /src/frame/centrality/centrality_tfidf.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from os.path import join, dirname, abspath 3 | 4 | sys_path = dirname(dirname(abspath(__file__))) 5 | parent_sys_path = dirname(sys_path) 6 | 7 | if sys_path not in sys.path: 8 | sys.path.insert(0, sys_path) 9 | if parent_sys_path not in sys.path: 10 | sys.path.insert(0, parent_sys_path) 11 | 12 | import numpy as np 13 | from tqdm import tqdm 14 | 15 | import utils.config_loader as config 16 | from utils.config_loader import logger, path_parser 17 | import utils.graph_io as graph_io 18 | import utils.graph_tools as graph_tools 19 | import tools.tfidf_tools as tfidf_tools 20 | import tools.general_tools as general_tools 21 | import summ.compute_rouge as rouge 22 | import frame.centrality.centrality_config as centrality_config 23 | 24 | MODEL_NAME = 'centrality-tfidf-2007-{}_damp-{}_link'.format(centrality_config.DAMP, 25 | centrality_config.LINK_TYPE) 26 | 27 | if centrality_config.LINK_TYPE == 'inter': 28 | mask_intra = True 29 | else: 30 | mask_intra = False 31 | 32 | 33 | def _build_components(cid, query): 34 | sim_items = tfidf_tools.build_sim_items_e2e(cid, query, mask_intra=mask_intra) 35 | rel_scores = sim_items['rel_scores'] 36 | sim_mat = sim_items['doc_sim_mat'] 37 | processed_sents = sim_items['processed_sents'] 38 | 
39 | rel_vec = rel_scores / np.sum(rel_scores) # l1 norm to make a distribution 40 | 41 | np.fill_diagonal(sim_mat, 0.0) # avoid self-transition 42 | logger.info('sim_mat: {}'.format(sim_mat)) 43 | 44 | sid2abs = {} 45 | sid_abs = 0 46 | for doc_idx, doc in enumerate(processed_sents): 47 | for sent_idx, sent in enumerate(doc): 48 | sid = config.SEP.join((str(doc_idx), str(sent_idx))) 49 | sid2abs[sid] = sid_abs 50 | sid_abs += 1 51 | 52 | components = { 53 | 'sim_mat': sim_mat, 54 | 'rel_vec': rel_vec, 55 | 'sid2abs': sid2abs, 56 | } 57 | 58 | return components 59 | 60 | 61 | def build_components_e2e(): 62 | dp_params = { 63 | 'model_name': MODEL_NAME, 64 | 'n_iter': None, 65 | 'mode': 'w', 66 | } 67 | 68 | summ_comp_root = graph_io.get_summ_comp_root(**dp_params) 69 | sim_mat_dp = graph_io.get_sim_mat_dp(summ_comp_root, mode='w') 70 | rel_vec_dp = graph_io.get_rel_vec_dp(summ_comp_root, mode='w') 71 | sid2abs_dp = graph_io.get_sid2abs_dp(summ_comp_root, mode='w') 72 | 73 | logger.info('sim_mat_dp: {}'.format(sim_mat_dp)) 74 | logger.info('rel_vec_dp: {}'.format(rel_vec_dp)) 75 | logger.info('sid2abs_dp: {}'.format(sid2abs_dp)) 76 | 77 | test_cid_query_dicts = general_tools.build_test_cid_query_dicts(tokenize_narr=False, 78 | concat_title_narr=False, 79 | query_type=centrality_config.QUERY_TYPE) 80 | 81 | for params in tqdm(test_cid_query_dicts): 82 | components = _build_components(**params) 83 | 84 | graph_io.dump_sim_mat(sim_mat=components['sim_mat'], sim_mat_dp=sim_mat_dp, cid=params['cid']) 85 | graph_io.dump_rel_vec(rel_vec=components['rel_vec'], rel_vec_dp=rel_vec_dp, cid=params['cid']) 86 | graph_io.dump_sid2abs(sid2abs=components['sid2abs'], sid2abs_dp=sid2abs_dp, cid=params['cid']) 87 | 88 | logger.info('[BUILD GRAPH COMPONENT] dumping sim mat file to: {0}'.format(sim_mat_dp)) 89 | logger.info('[BUILD GRAPH COMPONENT] dumping rel vec file to: {0}'.format(rel_vec_dp)) 90 | logger.info('[BUILD GRAPH COMPONENT] dumping sid2abs file to: 
{0}'.format(sid2abs_dp)) 91 | 92 | 93 | def score_e2e(): 94 | damp = centrality_config.DAMP 95 | use_rel_vec = True 96 | if damp == 1.0: 97 | damp = 0.85 98 | use_rel_vec = False 99 | 100 | graph_tools.score_end2end(model_name=MODEL_NAME, 101 | damp=damp, 102 | use_rel_vec=use_rel_vec) 103 | 104 | 105 | def rank_e2e(omega=10, max_n_iter=100): 106 | graph_tools.rank_end2end(model_name=MODEL_NAME, omega=omega, max_n_iter=max_n_iter) 107 | 108 | 109 | def select_e2e(omega=10): 110 | graph_tools.select_end2end(model_name=MODEL_NAME, omega=omega) 111 | 112 | 113 | def compute_rouge(omega=10): 114 | params = { 115 | 'model_name': MODEL_NAME, 116 | 'n_iter': None, 117 | 'cos_threshold': 1.0, # do not pos cosine similarity criterion 118 | 'omega': omega, 119 | 'manual': True, 120 | } 121 | rouge.compute_rouge_end2end(**params) 122 | 123 | 124 | def search_optimum_omega(max_n_iter=100): 125 | out_fp = join(path_parser.rouge, 'omega_search-{}'.format(MODEL_NAME)) 126 | 127 | for omega in range(11, 21): 128 | graph_tools.rank_end2end(model_name=MODEL_NAME, omega=omega, max_n_iter=max_n_iter) 129 | graph_tools.select_end2end(model_name=MODEL_NAME, omega=omega, save_out_fp=out_fp) 130 | 131 | 132 | if __name__ == '__main__': 133 | build_components_e2e() 134 | score_e2e() 135 | rank_e2e() 136 | # select_e2e() 137 | # compute_rouge() 138 | -------------------------------------------------------------------------------- /src/baselines/lexrank/lexrank_tfidf_tdqfs.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | import io 4 | import os 5 | from os import listdir 6 | from os.path import join, dirname, abspath, exists 7 | 8 | sys_path = dirname(dirname(abspath(__file__))) 9 | parent_sys_path = dirname(sys_path) 10 | 11 | if sys_path not in sys.path: 12 | sys.path.insert(0, sys_path) 13 | if parent_sys_path not in sys.path: 14 | sys.path.insert(0, parent_sys_path) 15 | 16 | import utils.config_loader as 
config 17 | from utils.config_loader import logger, path_parser 18 | from data.dataset_parser import dataset_parser 19 | import utils.tools as tools 20 | import summ.rank_sent as rank_sent 21 | import summ.select_sent as select_sent 22 | 23 | from lexrank import STOPWORDS, LexRank 24 | import itertools 25 | from tqdm import tqdm 26 | 27 | import tools.general_tools as general_tools 28 | from utils.tools import get_text_dp_for_tdqfs 29 | import summ.compute_rouge as rouge 30 | 31 | assert config.grain == 'sent' 32 | MODEL_NAME = 'lexrank-{}'.format(config.test_year) 33 | COS_THRESHOLD = 1.0 34 | 35 | assert 'tdqfs' in config.test_year 36 | sentence_dp = path_parser.data_tdqfs_sentences 37 | query_fp = path_parser.data_tdqfs_queries 38 | tdqfs_summary_target_dp = path_parser.data_tdqfs_summary_targets 39 | 40 | test_cid_query_dicts = general_tools.build_tdqfs_cid_query_dicts(query_fp=query_fp, proc=True) 41 | cc_ids = [cq_dict['cid'] for cq_dict in test_cid_query_dicts] 42 | LENGTH_BUDGET_TUPLE = ('nw', 250) 43 | 44 | def _get_sentences(cid): 45 | cc_dp = join(sentence_dp, cid) 46 | fns = [fn for fn in listdir(cc_dp)] 47 | lines = itertools.chain(*[io.open(join(cc_dp, fn)).readlines() for fn in fns]) 48 | sentences = [line.strip('\n') for line in lines] 49 | 50 | original_sents = [] 51 | processed_sents = [] 52 | for ss in sentences: 53 | ss_origin = dataset_parser._proc_sent(ss, rm_dialog=False, rm_stop=False, stem=False) 54 | ss_proc = dataset_parser._proc_sent(ss, rm_dialog=False, rm_stop=True, stem=True) 55 | 56 | if ss_proc: # make sure the sent is not removed, i.e., is not empty and is not in a dialog 57 | original_sents.append(ss_origin) 58 | processed_sents.append(ss_proc) 59 | 60 | return [original_sents], [processed_sents] 61 | 62 | 63 | def _lexrank(cid): 64 | """ 65 | Run LexRank on all sentences from all documents in a cluster. 
66 | 67 | :param cid: 68 | :return: rank_records 69 | """ 70 | _, processed_sents = dataset_parser.cid2sents_tdqfs(cid) # 2d lists, docs => sents 71 | flat_processed_sents = list(itertools.chain(*processed_sents)) # 1d sent list 72 | 73 | lxr = LexRank(processed_sents, stopwords=STOPWORDS['en']) 74 | scores = lxr.rank_sentences(flat_processed_sents, threshold=None, fast_power_method=True) 75 | 76 | sid2score = dict() 77 | abs_idx = 0 78 | for doc_idx, doc in enumerate(processed_sents): 79 | for sent_idx, sent in enumerate(doc): 80 | sid = config.SEP.join((str(doc_idx), str(sent_idx))) 81 | score = scores[abs_idx] 82 | sid2score[sid] = score 83 | 84 | abs_idx += 1 85 | 86 | sid_score_list = rank_sent.sort_sid2score(sid2score) 87 | rank_records = rank_sent.get_rank_records(sid_score_list, sents=processed_sents, flat_sents=False) 88 | return rank_records 89 | 90 | 91 | def rank_e2e(): 92 | rank_dp = tools.get_rank_dp(model_name=MODEL_NAME) 93 | if exists(rank_dp): 94 | raise ValueError('rank_dp exists: {}'.format(rank_dp)) 95 | os.mkdir(rank_dp) 96 | 97 | for cid in tqdm(cc_ids): 98 | rank_records = _lexrank(cid) 99 | rank_sent.dump_rank_records(rank_records, out_fp=join(rank_dp, cid), with_rank_idx = False) 100 | 101 | logger.info('Successfully dumped rankings to: {}'.format(rank_dp)) 102 | 103 | 104 | def select_e2e(): 105 | """ 106 | No redundancy removal is applied here. 
107 | """ 108 | params = { 109 | 'model_name': MODEL_NAME, 110 | 'cos_threshold': COS_THRESHOLD, 111 | } 112 | select_sent.select_end2end(**params) 113 | 114 | 115 | def select_e2e_tdqfs(): 116 | params = { 117 | 'model_name': MODEL_NAME, 118 | 'length_budget_tuple': LENGTH_BUDGET_TUPLE, 119 | 'cos_threshold': COS_THRESHOLD, 120 | 'cc_ids': cc_ids, 121 | } 122 | select_sent.select_end2end_for_tdqfs(**params) 123 | 124 | 125 | def compute_rouge_tdqfs(length): 126 | text_params = { 127 | 'model_name': MODEL_NAME, 128 | 'length_budget_tuple': LENGTH_BUDGET_TUPLE, 129 | 'cos_threshold': COS_THRESHOLD, 130 | } 131 | 132 | text_dp = get_text_dp_for_tdqfs(**text_params) 133 | 134 | rouge_parmas = { 135 | 'text_dp': text_dp, 136 | 'ref_dp': tdqfs_summary_target_dp, 137 | } 138 | if LENGTH_BUDGET_TUPLE[0] == 'nw': 139 | rouge_parmas['length'] = LENGTH_BUDGET_TUPLE[1] 140 | 141 | output = rouge.compute_rouge_for_tdqfs(**rouge_parmas) 142 | return output 143 | 144 | 145 | if __name__ == '__main__': 146 | rank_e2e() 147 | select_e2e_tdqfs() 148 | compute_rouge_tdqfs(length=None) 149 | -------------------------------------------------------------------------------- /src/frame/bert_qa/data_pipe_cluster_cosine.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import Dataset, DataLoader 5 | from os.path import dirname, abspath 6 | import sys 7 | 8 | sys.path.insert(0, dirname(dirname(abspath(__file__)))) 9 | 10 | import utils.config_loader as config 11 | from utils.config_loader import config_model 12 | from data.dataset_parser import dataset_parser 13 | import data.data_tools as data_tools 14 | 15 | import utils.tools as tools 16 | from frame.ir.ir_tools import load_retrieved_sentences 17 | from frame.bert_qa.bert_input_cosine import build_query, build_sentence 18 | 19 | 20 | class ClusterDataset(Dataset): 21 | def __init__(self, cid, retrieve_dp, 
transform=None): 22 | super(ClusterDataset, self).__init__() 23 | original_sents, _ = load_retrieved_sentences(retrieved_dp=retrieve_dp, cid=cid) 24 | self.sentences = original_sents[0] 25 | self.yy = 0.0 # 0.0 26 | self.transform = transform 27 | 28 | def __len__(self): 29 | return len(self.sentences) 30 | 31 | @staticmethod 32 | def _vec_label(yy): 33 | if yy == '-1.0': 34 | yy = 0.0 35 | return np.array([yy], dtype=np.float32) 36 | 37 | def __getitem__(self, index): 38 | """ 39 | get an item from self.doc_ids. 40 | 41 | return a sample: (xx, yy) 42 | """ 43 | # build xx 44 | sent = build_sentence(sentence=self.sentences[index]) 45 | 46 | # build yy 47 | yy = self._vec_label(self.yy) 48 | 49 | sample = { 50 | **sent, 51 | 'yy': yy, 52 | } 53 | 54 | if self.transform: 55 | sample = self.transform(sample) 56 | 57 | return sample 58 | 59 | 60 | class ClusterDataLoader(DataLoader): 61 | def __init__(self, cid, query, retrieve_dp, transform=data_tools.ToTensor()): 62 | self.cid = cid 63 | dataset = ClusterDataset(cid, retrieve_dp=retrieve_dp) 64 | self.transform = transform 65 | 66 | query_dict = build_query(query) 67 | for (k, v) in query_dict.items(): 68 | v = v.reshape(1, -1) 69 | query_dict[k] = torch.from_numpy(v) 70 | self.query_in = self.transform(query_dict) 71 | 72 | super(ClusterDataLoader, self).__init__(dataset=dataset, 73 | batch_size=config_model['d_batch'], 74 | shuffle=False, 75 | num_workers=3, # 3 76 | drop_last=False) 77 | 78 | def _generator(self, super_iter): 79 | while True: 80 | batch = next(super_iter) 81 | batch = self.transform(batch) 82 | yield batch 83 | 84 | def __iter__(self): 85 | super_iter = super(ClusterDataLoader, self).__iter__() 86 | return self._generator(super_iter) 87 | 88 | 89 | class QSDataLoader: 90 | """ 91 | iter over all clusters. 92 | each cluster is handled with a separate data loader. 93 | 94 | tokenize_narr: whether tokenize query into sentences. 
95 | """ 96 | 97 | def __init__(self, tokenize_narr, query_type, retrieve_dp): 98 | if query_type == config.TITLE: 99 | query_dict = dataset_parser.get_cid2title() 100 | elif query_type == config.NARR: 101 | query_dict = dataset_parser.get_cid2narr() 102 | elif query_type == config.QUERY: 103 | query_dict = dataset_parser.get_cid2query(tokenize_narr) 104 | else: 105 | raise ValueError('Invalid query_type: {}'.format(query_type)) 106 | 107 | cids = tools.get_test_cc_ids() 108 | 109 | self.loader_init_params = [] 110 | for cid in cids: 111 | query = tools.get_query_w_cid(query_dict, cid=cid) 112 | # query = query_dict[cid] 113 | self.loader_init_params.append({ 114 | 'cid': cid, 115 | 'query': query, 116 | 'retrieve_dp': retrieve_dp, 117 | }) 118 | 119 | def _loader_generator(self): 120 | for params in self.loader_init_params: 121 | c_loader = ClusterDataLoader(**params) 122 | yield c_loader 123 | 124 | def __iter__(self): 125 | return self._loader_generator() 126 | 127 | 128 | class TdqfsQSDataLoader: 129 | """ 130 | iter over all clusters. 131 | each cluster is handled with a separate data loader. 
132 | """ 133 | 134 | def __init__(self, test_cid_query_dicts, retrieve_dp): 135 | self.loader_init_params = [] 136 | for cq_dict in test_cid_query_dicts: 137 | self.loader_init_params.append({ 138 | 'cid': cq_dict['cid'], 139 | 'query': cq_dict['query'], 140 | 'retrieve_dp': retrieve_dp, 141 | }) 142 | 143 | def _loader_generator(self): 144 | for params in self.loader_init_params: 145 | c_loader = ClusterDataLoader(**params) 146 | yield c_loader 147 | 148 | def __iter__(self): 149 | return self._loader_generator() 150 | -------------------------------------------------------------------------------- /src/frame/bert_passage/data_pipe_cluster.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | from torch.utils.data import Dataset, DataLoader 4 | from os.path import join, dirname, abspath 5 | import sys 6 | 7 | import utils.config_loader as config 8 | from utils.config_loader import config_model 9 | from data.dataset_parser import dataset_parser 10 | import data.data_tools as data_tools 11 | 12 | import utils.tools as tools 13 | from frame.ir.ir_tools import load_retrieved_passages 14 | from frame.bert_passage.bert_input import build_bert_passage_x 15 | 16 | sys.path.insert(0, dirname(dirname(abspath(__file__)))) 17 | 18 | 19 | class ClusterDataset(Dataset): 20 | def __init__(self, cid, query, retrieve_dp, tdqfs_data=False, transform=None): 21 | super(ClusterDataset, self).__init__() 22 | original_passages, _, passage_ids = load_retrieved_passages(retrieved_dp=retrieve_dp, 23 | cid=cid, 24 | get_sents=True, 25 | tdqfs_data=tdqfs_data) 26 | self.passages = original_passages 27 | self.passage_ids = passage_ids 28 | 29 | self.query = query 30 | self.yy = 0.0 # 0.0 31 | 32 | self.transform = transform 33 | 34 | def __len__(self): 35 | return len(self.passages) 36 | 37 | @staticmethod 38 | def _vec_label(yy): 39 | if yy == '-1.0': 40 | yy = 0.0 41 | return np.array([yy], 
dtype=np.float32) 42 | 43 | def __getitem__(self, index): 44 | """ 45 | get an item from self.doc_ids. 46 | 47 | return a sample: (xx, yy) 48 | """ 49 | # build xx 50 | xx = build_bert_passage_x(self.query, passage=self.passages[index]) 51 | 52 | # build yy 53 | yy = self._vec_label(self.yy) 54 | 55 | sample = { 56 | **xx, 57 | 'yy': yy, 58 | } 59 | 60 | if self.transform: 61 | sample = self.transform(sample) 62 | 63 | return sample 64 | 65 | 66 | class ClusterDataLoader(DataLoader): 67 | def __init__(self, cid, query, retrieve_dp, tdqfs_data=False, transform=data_tools.ToTensor()): 68 | dataset = ClusterDataset(cid, query, retrieve_dp=retrieve_dp, tdqfs_data=tdqfs_data) 69 | self.transform = transform 70 | self.cid = cid 71 | self.passage_ids = dataset.passage_ids 72 | 73 | super(ClusterDataLoader, self).__init__(dataset=dataset, 74 | batch_size=8, 75 | shuffle=False, 76 | num_workers=3, # 3 77 | drop_last=False) 78 | 79 | def _generator(self, super_iter): 80 | while True: 81 | batch = next(super_iter) 82 | batch = self.transform(batch) 83 | yield batch 84 | 85 | def __iter__(self): 86 | super_iter = super(ClusterDataLoader, self).__iter__() 87 | return self._generator(super_iter) 88 | 89 | 90 | class QSDataLoader: 91 | """ 92 | iter over all clusters. 93 | each cluster is handled with a separate data loader. 94 | 95 | tokenize_narr: whether tokenize query into sentences. 
96 | """ 97 | 98 | def __init__(self, tokenize_narr, query_type, retrieve_dp): 99 | if query_type == config.TITLE: 100 | query_dict = dataset_parser.get_cid2title() 101 | elif query_type == config.NARR: 102 | query_dict = dataset_parser.get_cid2narr() 103 | elif query_type == config.QUERY: 104 | query_dict = dataset_parser.get_cid2query(tokenize_narr) 105 | else: 106 | raise ValueError('Invalid query_type: {}'.format(query_type)) 107 | 108 | cids = tools.get_test_cc_ids() 109 | 110 | self.loader_init_params = [] 111 | for cid in cids: 112 | query = query_dict[cid] 113 | self.loader_init_params.append({ 114 | 'cid': cid, 115 | 'query': query, 116 | 'retrieve_dp': retrieve_dp, 117 | }) 118 | 119 | def _loader_generator(self): 120 | for params in self.loader_init_params: 121 | c_loader = ClusterDataLoader(**params) 122 | yield c_loader 123 | 124 | def __iter__(self): 125 | return self._loader_generator() 126 | 127 | 128 | class TdqfsQSDataLoader: 129 | """ 130 | iter over all clusters. 131 | each cluster is handled with a separate data loader. 
132 | """ 133 | 134 | def __init__(self, test_cid_query_dicts, retrieve_dp): 135 | self.loader_init_params = [] 136 | for cq_dict in test_cid_query_dicts: 137 | self.loader_init_params.append({ 138 | 'cid': cq_dict['cid'], 139 | 'query': cq_dict['query'], 140 | 'retrieve_dp': retrieve_dp, 141 | 'tdqfs_data': True, 142 | }) 143 | 144 | def _loader_generator(self): 145 | for params in self.loader_init_params: 146 | c_loader = ClusterDataLoader(**params) 147 | yield c_loader 148 | 149 | def __iter__(self): 150 | return self._loader_generator() 151 | -------------------------------------------------------------------------------- /src/data/data_pipe_cluster.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from os import listdir 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import Dataset, DataLoader 6 | from os.path import join, isdir, dirname, abspath 7 | import sys 8 | 9 | import utils.config_loader as config 10 | from utils.config_loader import path_parser, logger, config_model, config_meta 11 | from data.dataset_parser import dataset_parser 12 | import data.data_tools as data_tools 13 | 14 | import utils.tools as tools 15 | 16 | sys.path.insert(0, dirname(dirname(abspath(__file__)))) 17 | 18 | 19 | class ClusterDataset(Dataset): 20 | def __init__(self, cid, query, transform=None): 21 | super(ClusterDataset, self).__init__() 22 | self.doc_ids = tools.get_doc_ids(cid, remove_illegal=True) # remove empty docs with no preprocessed sents 23 | self.query = query 24 | self.yy = 0.0 # 0.0 25 | 26 | self.transform = transform 27 | self.bert_in_func = data_tools.get_bert_in_func() 28 | 29 | def __len__(self): 30 | return len(self.doc_ids) 31 | 32 | @staticmethod 33 | def _vec_label(yy): 34 | if yy == '-1.0': 35 | yy = 0.0 36 | return np.array([yy], dtype=np.float32) 37 | 38 | def __getitem__(self, index): 39 | """ 40 | get an item from self.doc_ids. 
41 | 42 | return a sample: (xx, yy) 43 | """ 44 | year, cc, fn = self.doc_ids[index].split(config.SEP) 45 | 46 | doc_fp = join(path_parser.data_docs, year, cc, fn) 47 | 48 | # build xx 49 | xx = self.bert_in_func(self.query, doc_fp=doc_fp) 50 | 51 | # build yy 52 | yy = self._vec_label(self.yy) 53 | 54 | sample = { 55 | **xx, 56 | 'yy': yy, 57 | } 58 | 59 | if self.transform: 60 | sample = self.transform(sample) 61 | 62 | return sample 63 | 64 | 65 | class ClusterDataLoader(DataLoader): 66 | def __init__(self, cid, query, transform=data_tools.ToTensor()): 67 | dataset = ClusterDataset(cid, query) 68 | self.transform = transform 69 | self.cid = cid 70 | 71 | super(ClusterDataLoader, self).__init__(dataset=dataset, 72 | # batch_size=config_model['d_batch'], 73 | batch_size=1, 74 | shuffle=False, 75 | num_workers=3, # 3 76 | drop_last=False) 77 | 78 | def _generator(self, super_iter): 79 | while True: 80 | # try: 81 | batch = next(super_iter) 82 | # for func in self.transform: 83 | batch = self.transform(batch) 84 | yield batch 85 | 86 | def __iter__(self): 87 | super_iter = super(ClusterDataLoader, self).__iter__() 88 | return self._generator(super_iter) 89 | 90 | 91 | class QSDataLoaderOneClusterABatch: 92 | """ 93 | iter over all clusters. 94 | each cluster is handled with a separate data loader. 
95 | """ 96 | 97 | def __init__(self): 98 | pos_narr_dict = dataset_parser.build_query_info(config_meta['test_year'], tokenize=None) 99 | # pos_headline_dict = dataset_parser.build_headline_info(config_meta['test_year'], tokenize=None, silent=True) 100 | 101 | self.loader_init_params = [] 102 | 103 | for cid in pos_narr_dict: 104 | narr = pos_narr_dict[cid][config.NARR] 105 | self.loader_init_params.append({ 106 | 'cid': cid, 107 | 'narr': narr, 108 | }) 109 | 110 | def _loader_generator(self): 111 | for params in self.loader_init_params: 112 | c_loader = ClusterDataLoader(**params) 113 | for batch_idx, batch in enumerate(c_loader): # should be only one batch / $n_docs$ batches in a loader 114 | # logger.info('batch_idx: {}'.format(batch_idx)) 115 | yield { 116 | 'cid': params['cid'], 117 | 'batch': batch, 118 | } 119 | 120 | def __iter__(self): 121 | return self._loader_generator() 122 | 123 | 124 | class QSDataLoader: 125 | """ 126 | iter over all clusters. 127 | each cluster is handled with a separate data loader. 128 | 129 | tokenize_narr: whether tokenize query into sentences. 130 | """ 131 | 132 | def __init__(self, tokenize_narr, query_type=None): 133 | # fixme: this class may not work right; check type(narr). 
134 | if query_type == config.TITLE: 135 | query_dict = dataset_parser.get_cid2title() 136 | elif query_type == config.NARR: 137 | query_dict = dataset_parser.get_cid2narr() 138 | else: 139 | query_dict = dataset_parser.get_cid2query(tokenize_narr) 140 | 141 | cids = tools.get_test_cc_ids() 142 | 143 | self.loader_init_params = [] 144 | for cid in cids: 145 | query = query_dict[cid] 146 | self.loader_init_params.append({ 147 | 'cid': cid, 148 | 'query': query, 149 | }) 150 | 151 | def _loader_generator(self): 152 | for params in self.loader_init_params: 153 | c_loader = ClusterDataLoader(**params) 154 | yield c_loader 155 | 156 | def __iter__(self): 157 | return self._loader_generator() 158 | -------------------------------------------------------------------------------- /src/frame/ir/ir_passage_tf_tdqfs.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | 4 | import os 5 | from os.path import join, dirname, abspath, exists 6 | 7 | sys.path.insert(0, dirname(dirname(abspath(__file__)))) 8 | sys.path.insert(0, dirname(dirname(dirname(abspath(__file__))))) 9 | 10 | import utils.config_loader as config 11 | from utils.config_loader import logger, path_parser 12 | import summ.rank_sent as rank_sent 13 | import utils.tools as tools 14 | import tools.tfidf_tools as tfidf_tools 15 | import tools.general_tools as general_tools 16 | import frame.ir.ir_tools as ir_tools 17 | 18 | from tqdm import tqdm 19 | import shutil 20 | import frame.ir.ir_config as ir_config 21 | import summ.compute_rouge as rouge 22 | from frame.ir.ir_tools import load_retrieved_passages 23 | import numpy as np 24 | 25 | if not config.grain.startswith('passage'): 26 | raise ValueError('Invalid grain: {}'.format(config.grain)) 27 | assert ir_config.test_year.startswith('tdqfs'), f'set ir_config.test_year to tdqfs! 
now: {ir_config.test_year}' 28 | 29 | query_fp = path_parser.data_tdqfs_queries 30 | test_cid_query_dicts = general_tools.build_tdqfs_cid_query_dicts(query_fp=query_fp, proc=True) 31 | cids = [cq_dict['cid'] for cq_dict in test_cid_query_dicts] 32 | 33 | 34 | def _rank(cid, query): 35 | pid2score = tfidf_tools.build_rel_scores_tf_passage(cid, query, tdqfs_data=True) 36 | # rank scores 37 | sid_score_list = rank_sent.sort_sid2score(pid2score) 38 | # include sentences in records 39 | rank_records = rank_sent.get_rank_records(sid_score_list, sents=None) 40 | # rank_records = rank_sent.get_rank_records(sid_score_list) 41 | 42 | return rank_records 43 | 44 | 45 | def rank_e2e(): 46 | """ 47 | 48 | :param pool_func: avg, max, or None (for integrated query). 49 | :return: 50 | """ 51 | rank_dp = join(path_parser.summary_rank, ir_config.IR_MODEL_NAME_TF) 52 | 53 | if exists(rank_dp): 54 | raise ValueError('rank_dp exists: {}'.format(rank_dp)) 55 | os.mkdir(rank_dp) 56 | 57 | for cid_query_dict in tqdm(test_cid_query_dicts): 58 | rank_records = _rank(**cid_query_dict) 59 | rank_sent.dump_rank_records(rank_records, out_fp=join(rank_dp, cid_query_dict['cid']), with_rank_idx=False) 60 | logger.info('Successfully dumped rankings to: {}'.format(rank_dp)) 61 | 62 | 63 | def ir_rank2records(): 64 | ir_rec_dp = join(path_parser.summary_rank, ir_config.IR_RECORDS_DIR_NAME_TF) 65 | 66 | if exists(ir_rec_dp): 67 | raise ValueError('qa_rec_dp exists: {}'.format(ir_rec_dp)) 68 | os.mkdir(ir_rec_dp) 69 | 70 | for cid in tqdm(cids): 71 | retrieval_params = { 72 | 'model_name': ir_config.IR_MODEL_NAME_TF, 73 | 'cid': cid, 74 | 'filter_var': ir_config.FILTER_VAR, 75 | 'filter': ir_config.FILTER, 76 | 'deduplicate': ir_config.DEDUPLICATE, 77 | 'prune': True, 78 | } 79 | 80 | retrieved_items = ir_tools.retrieve(**retrieval_params) 81 | ir_tools.dump_retrieval(fp=join(ir_rec_dp, cid), retrieved_items=retrieved_items) 82 | 83 | 84 | def tune(): 85 | """ 86 | Tune IR confidence / compression 
rate based on Recall Rouge 2. 87 | :return: 88 | """ 89 | if ir_config.FILTER == 'conf': 90 | tune_range = np.arange(0.05, 1.05, 0.05) 91 | else: 92 | interval = 10 93 | tune_range = range(interval, 500 + interval, interval) 94 | 95 | ir_tune_dp = join(path_parser.summary_rank, ir_config.IR_TUNE_DIR_NAME_TF) 96 | ir_tune_result_fp = join(path_parser.tune, ir_config.IR_TUNE_DIR_NAME_TF) 97 | with open(ir_tune_result_fp, mode='a', encoding='utf-8') as out_f: 98 | headline = 'Filter\tRecall\tF1\n' 99 | out_f.write(headline) 100 | 101 | cids = tools.get_test_cc_ids() 102 | for filter_var in tune_range: 103 | if exists(ir_tune_dp): # remove previous output 104 | shutil.rmtree(ir_tune_dp) 105 | os.mkdir(ir_tune_dp) 106 | 107 | for cid in tqdm(cids): 108 | retrieval_params = { 109 | 'model_name': ir_config.IR_MODEL_NAME_TF, 110 | 'cid': cid, 111 | 'filter_var': filter_var, 112 | 'filter': ir_config.FILTER, 113 | 'deduplicate': ir_config.DEDUPLICATE, 114 | 'prune': True, 115 | } 116 | 117 | retrieved_items = ir_tools.retrieve(**retrieval_params) # pid, score 118 | 119 | passage_ids = [item[0] for item in retrieved_items] 120 | original_passages, _, _ = load_retrieved_passages(cid=cid, 121 | get_sents=True, 122 | passage_ids=passage_ids) 123 | passages = ['\n'.join(sents) for sents in original_passages] 124 | summary = '\n'.join(passages) 125 | print(summary) 126 | # print(summary) 127 | with open(join(ir_tune_dp, cid), mode='a', encoding='utf-8') as out_f: 128 | out_f.write(summary) 129 | 130 | performance = rouge.compute_rouge_for_dev(ir_tune_dp, tune_centrality=False) 131 | with open(ir_tune_result_fp, mode='a', encoding='utf-8') as out_f: 132 | if ir_config.FILTER == 'conf': 133 | rec = '{0:.2f}\t{1}\n'.format(filter_var, performance) 134 | else: 135 | rec = '{0}\t{1}\n'.format(filter_var, performance) 136 | 137 | out_f.write(rec) 138 | 139 | 140 | if __name__ == '__main__': 141 | rank_e2e() 142 | ir_rank2records() 143 | # tune() 144 | 
-------------------------------------------------------------------------------- /src/frame/ir/ir_passage_tf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | 4 | import os 5 | from os.path import join, dirname, abspath, exists 6 | 7 | sys.path.insert(0, dirname(dirname(abspath(__file__)))) 8 | sys.path.insert(0, dirname(dirname(dirname(abspath(__file__))))) 9 | 10 | import utils.config_loader as config 11 | from utils.config_loader import logger, path_parser 12 | import summ.rank_sent as rank_sent 13 | import utils.tools as tools 14 | import tools.tfidf_tools as tfidf_tools 15 | import tools.general_tools as general_tools 16 | import frame.ir.ir_tools as ir_tools 17 | 18 | from tqdm import tqdm 19 | import shutil 20 | import frame.ir.ir_config as ir_config 21 | import summ.compute_rouge as rouge 22 | from frame.ir.ir_tools import load_retrieved_passages 23 | import numpy as np 24 | 25 | if config.grain != 'passage': 26 | raise ValueError('Invalid grain: {}'.format(config.grain)) 27 | 28 | 29 | def _rank(cid, query): 30 | pid2score = tfidf_tools.build_rel_scores_tf_passage(cid, query) 31 | # rank scores 32 | sid_score_list = rank_sent.sort_sid2score(pid2score) 33 | # include sentences in records 34 | rank_records = rank_sent.get_rank_records(sid_score_list, sents=None) 35 | # rank_records = rank_sent.get_rank_records(sid_score_list) 36 | 37 | return rank_records 38 | 39 | 40 | def rank_e2e(): 41 | """ 42 | 43 | :param pool_func: avg, max, or None (for integrated query). 
44 | :return: 45 | """ 46 | rank_dp = join(path_parser.summary_rank, ir_config.IR_MODEL_NAME_TF) 47 | test_cid_query_dicts = general_tools.build_test_cid_query_dicts(tokenize_narr=False, 48 | concat_title_narr=ir_config.CONCAT_TITLE_NARR, 49 | query_type=ir_config.QUERY_TYPE) 50 | 51 | if exists(rank_dp): 52 | raise ValueError('rank_dp exists: {}'.format(rank_dp)) 53 | os.mkdir(rank_dp) 54 | 55 | for cid_query_dict in tqdm(test_cid_query_dicts): 56 | params = { 57 | **cid_query_dict, 58 | } 59 | rank_records = _rank(**params) 60 | rank_sent.dump_rank_records(rank_records, out_fp=join(rank_dp, params['cid']), with_rank_idx=False) 61 | 62 | logger.info('Successfully dumped rankings to: {}'.format(rank_dp)) 63 | 64 | 65 | def ir_rank2records(): 66 | ir_rec_dp = join(path_parser.summary_rank, ir_config.IR_RECORDS_DIR_NAME_TF) 67 | 68 | if exists(ir_rec_dp): 69 | raise ValueError('qa_rec_dp exists: {}'.format(ir_rec_dp)) 70 | os.mkdir(ir_rec_dp) 71 | 72 | cids = tools.get_test_cc_ids() 73 | for cid in tqdm(cids): 74 | retrieval_params = { 75 | 'model_name': ir_config.IR_MODEL_NAME_TF, 76 | 'cid': cid, 77 | 'filter_var': ir_config.FILTER_VAR, 78 | 'filter': ir_config.FILTER, 79 | 'deduplicate': ir_config.DEDUPLICATE, 80 | 'prune': True, 81 | } 82 | 83 | retrieved_items = ir_tools.retrieve(**retrieval_params) 84 | ir_tools.dump_retrieval(fp=join(ir_rec_dp, cid), retrieved_items=retrieved_items) 85 | 86 | 87 | def tune(): 88 | """ 89 | Tune IR confidence / compression rate based on Recall Rouge 2. 
90 | :return: 91 | """ 92 | if ir_config.FILTER == 'conf': 93 | tune_range = np.arange(0.05, 1.05, 0.05) 94 | else: 95 | interval = 10 96 | tune_range = range(interval, 500 + interval, interval) 97 | 98 | ir_tune_dp = join(path_parser.summary_rank, ir_config.IR_TUNE_DIR_NAME_TF) 99 | ir_tune_result_fp = join(path_parser.tune, ir_config.IR_TUNE_DIR_NAME_TF) 100 | with open(ir_tune_result_fp, mode='a', encoding='utf-8') as out_f: 101 | headline = 'Filter\tRecall\tF1\n' 102 | out_f.write(headline) 103 | 104 | cids = tools.get_test_cc_ids() 105 | for filter_var in tune_range: 106 | if exists(ir_tune_dp): # remove previous output 107 | shutil.rmtree(ir_tune_dp) 108 | os.mkdir(ir_tune_dp) 109 | 110 | for cid in tqdm(cids): 111 | retrieval_params = { 112 | 'model_name': ir_config.IR_MODEL_NAME_TF, 113 | 'cid': cid, 114 | 'filter_var': filter_var, 115 | 'filter': ir_config.FILTER, 116 | 'deduplicate': ir_config.DEDUPLICATE, 117 | 'prune': True, 118 | } 119 | 120 | retrieved_items = ir_tools.retrieve(**retrieval_params) # pid, score 121 | 122 | passage_ids = [item[0] for item in retrieved_items] 123 | original_passages, _, _ = load_retrieved_passages(cid=cid, 124 | get_sents=True, 125 | passage_ids=passage_ids) 126 | passages = ['\n'.join(sents) for sents in original_passages] 127 | summary = '\n'.join(passages) 128 | print(summary) 129 | # print(summary) 130 | with open(join(ir_tune_dp, cid), mode='a', encoding='utf-8') as out_f: 131 | out_f.write(summary) 132 | 133 | performance = rouge.compute_rouge_for_dev(ir_tune_dp, tune_centrality=False) 134 | with open(ir_tune_result_fp, mode='a', encoding='utf-8') as out_f: 135 | if ir_config.FILTER == 'conf': 136 | rec = '{0:.2f}\t{1}\n'.format(filter_var, performance) 137 | else: 138 | rec = '{0}\t{1}\n'.format(filter_var, performance) 139 | 140 | out_f.write(rec) 141 | 142 | 143 | if __name__ == '__main__': 144 | rank_e2e() 145 | # ir_rank2records() 146 | tune() 147 | 
-------------------------------------------------------------------------------- /src/scripts/proc_tdqfs.py: -------------------------------------------------------------------------------- 1 | import io 2 | import sys 3 | from os import mkdir, listdir 4 | from os.path import join, dirname, abspath, exists 5 | import json 6 | from random import choice 7 | from tqdm import tqdm 8 | from nltk.tokenize import sent_tokenize 9 | import itertools 10 | from multiprocessing import Pool 11 | 12 | sys_path = dirname(dirname(abspath(__file__))) 13 | parent_sys_path = dirname(sys_path) 14 | 15 | if sys_path not in sys.path: 16 | sys.path.insert(0, sys_path) 17 | if parent_sys_path not in sys.path: 18 | sys.path.insert(0, parent_sys_path) 19 | 20 | 21 | data_root = '~/querysum/data/tdqfs' 22 | 23 | 24 | def build_summary_targets(): 25 | src_dp = join(data_root, 'manual_summaries') 26 | cids = [dn for dn in listdir(src_dp)] 27 | 28 | dump_dp = join(data_root, 'summary_targets') 29 | for cid in tqdm(cids): 30 | src_cc_dp = join(src_dp, cid) 31 | fns = [fn for fn in listdir(src_cc_dp)] 32 | 33 | dump_cc_dp = join(dump_dp, cid) 34 | mkdir(dump_cc_dp) 35 | for fn in fns: 36 | summ = io.open(join(src_cc_dp, fn)).read().strip('\n') 37 | sentences = sent_tokenize(summ) 38 | 39 | proc_sentences = [] 40 | for ss in sentences: 41 | ss = ss.strip().strip('\n') 42 | if ss: 43 | proc_sentences.append(ss) 44 | assert proc_sentences 45 | io.open(join(dump_cc_dp, fn), mode='a').write('\n'.join(proc_sentences)) 46 | 47 | 48 | def read_data(json_fp): 49 | with open(json_fp) as f: 50 | data = json.load(f) 51 | questions = [] 52 | answers = [] 53 | supports = [] 54 | for d in data: 55 | questions.append(d["question"].strip()) 56 | answers.append(d["answer"].strip()) 57 | supports.append(d["document"].strip()) 58 | assert(len(questions) == len(answers) == len(supports)) 59 | 60 | index = choice(range(len(questions))) 61 | # print(f'loaded {len(questions)} samples!\n======Question======\nq: 
{questions[index]}\n======Answer======\na: {answers[index]}\n======Document======\nd: {supports[index]}') 62 | return questions, answers, supports 63 | 64 | 65 | def build_question_files(questions, data_type, simplified=False): 66 | lines = [] 67 | for idx, qq in enumerate(tqdm(questions)): 68 | if simplified: 69 | assert qq.split('--T--')[0].strip(), f'empty q after proc: {qq}' 70 | qq = qq.split('--T--')[0].strip() 71 | 72 | cid = f'{data_type}_{idx}' 73 | record = f'{cid}\t{qq}' 74 | lines.append(record) 75 | 76 | if simplified: 77 | fn = f'{data_type}_question_simplified.txt' 78 | else: 79 | fn = f'{data_type}_question_raw.txt' 80 | 81 | io.open(join(data_root, fn), mode='a').write('\n'.join(lines)) 82 | 83 | 84 | def build_answer_files(answers, data_type): 85 | dp = join(data_root, f'{data_type}_answers') 86 | if not exists(dp): 87 | mkdir(dp) 88 | for idx, ans in enumerate(tqdm(answers)): 89 | cid = f'{data_type}_{idx}' 90 | io.open(join(dp, f'{cid}.txt'), mode='a').write(ans) 91 | 92 | 93 | def build_document_files(documents, data_type): 94 | dp = join(data_root, f'{data_type}_documents') 95 | if not exists(dp): 96 | mkdir(dp) 97 | for idx, doc in enumerate(tqdm(documents)): 98 | cid = f'{data_type}_{idx}' 99 | io.open(join(dp, f'{cid}.txt'), mode='a').write(doc) 100 | 101 | 102 | def dump_segments(fp, segs, cid): 103 | lines = [f'{"_".join((cid, str(idx)))}\t{seg}' for idx, seg in enumerate(segs)] 104 | io.open(fp, mode='a').write('\n'.join(lines)) 105 | 106 | 107 | def proc_doc_into_sentences(data_type): 108 | doc_dp = join(data_root, f'{data_type}_documents') 109 | assert exists(doc_dp), f'{doc_dp} does not exist!' 
110 | cids = [fn.split('.')[0] for fn in listdir(doc_dp)] 111 | 112 | sentence_dp = join(data_root, f'{data_type}_sentences') 113 | if not exists(sentence_dp): 114 | mkdir(sentence_dp) 115 | 116 | passage_dp = join(data_root, f'{data_type}_passages') 117 | if not exists(passage_dp): 118 | mkdir(passage_dp) 119 | 120 | for idx, cid in enumerate(tqdm(cids)): 121 | fn = f'{cid}.txt' 122 | doc = io.open(join(doc_dp, fn)).read() 123 | 124 | passages = doc.split('

') 125 | passages = [psg.strip() for psg in passages if psg.strip()] 126 | sentences = [sent_tokenize(psg) for psg in passages] 127 | dump_segments(join(passage_dp, fn), segs=passages, cid=cid) 128 | dump_segments(join(sentence_dp, fn), segs=list(itertools.chain(*sentences)), cid=cid) 129 | 130 | 131 | def build_summary_files(answers, data_type): 132 | dp = join(data_root, f'{data_type}_summaries') 133 | if not exists(dp): 134 | mkdir(dp) 135 | for idx, ans in enumerate(tqdm(answers)): 136 | cid = f'{data_type}_{idx}' 137 | io.open(join(dp, f'{cid}.txt'), mode='a').write(ans) 138 | 139 | 140 | def _proc_answer_into_summaries(cid, ans_dp, summary_dp): 141 | """for multiproc""" 142 | doc = io.open(join(ans_dp, f'{cid}.txt')).read().strip('\n') 143 | sentences = sent_tokenize(doc) 144 | io.open(join(summary_dp, cid), mode='a').write('\n'.join(sentences)) 145 | 146 | 147 | def proc_answer_into_summaries(data_type): 148 | ans_dp = join(data_root, f'{data_type}_answers') 149 | assert exists(ans_dp), f'{ans_dp} does not exist!' 150 | cids = [fn.split('.')[0] for fn in listdir(ans_dp)] 151 | 152 | summary_dp = join(data_root, f'{data_type}_summaries') 153 | if not exists(summary_dp): 154 | mkdir(summary_dp) 155 | for cid in tqdm(cids): 156 | doc = io.open(join(ans_dp, f'{cid}.txt')).read().strip('\n') 157 | sentences = sent_tokenize(doc) 158 | io.open(join(summary_dp, cid), mode='a').write('\n'.join(sentences)) 159 | 160 | 161 | if __name__ == "__main__": 162 | build_summary_targets() 163 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # querysum 2 | 3 | 4 | This repository releases the code for Coarse-to-Fine Query Focused Multi-Document Summarization. 5 | Please cite the following paper [[bib]](https://www.aclweb.org/anthology/2020.emnlp-main.296.bib) if you use this code, 6 | 7 | Xu, Yumo, and Mirella Lapata. 
"Coarse-to-Fine Query Focused Multi-Document Summarization." In Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP), pp. 3632-3645. 2020. 8 | 9 | > We consider the problem of better modeling query-cluster interactions to facilitate query focused multi-document summarization. Due to the lack of training data, existing work relies heavily on retrieval-style methods for assembling query relevant summaries. We propose a coarse-to-fine modeling framework which employs progressively more accurate modules for estimating whether text segments are relevant, likely to contain an answer, and central. The modules can be independently developed and leverage training data if available. We present an instantiation of this framework with a trained evidence estimator which relies on distant supervision from question answering (where various resources exist) to identify segments which are likely to answer the query and should be included in the summary. Our framework is robust across domains and query types (i.e., long vs short) and outperforms strong comparison systems on benchmark datasets. 10 | 11 | Should you have any query please contact me at yumo.xu@ed.ac.uk. 
12 | 13 | ## Project structure 14 | ```bash 15 | querysum 16 | └───requirements.txt 17 | └───README.md 18 | └───log # logging files 19 | └───src # source files 20 | └───data # test sets 21 | └───graph # graph components for centrality estimation, e.g., sim matrix and relevance vector 22 | └───model # QA models for inference 23 | └───rank # ranking lists from sentence-level model 24 | └───text # predicted summaries from sentence-level model 25 | └───rank_passage # ranking lists from passage-level model 26 | └───text_passage # predicted summaries from passage-level model 27 | ``` 28 | 29 | After cloning this project, use the following command to initialize the structure: 30 | ```bash 31 | mkdir log data graph model rank text rank_passage text_passage 32 | ``` 33 | 34 | Trained models used in the paper can be downloaded [here](https://drive.google.com/file/d/1lOb9ECZa_fsYCI7Q41xMQjL0fzFvpkkD/view?usp=sharing). 35 | Please put them under `querysum/model`. 36 | 37 | ## Create environment 38 | ```bash 39 | cd .. 40 | virtualenv -p python3.6 querysum 41 | cd querysum 42 | . bin/activate 43 | pip install -r requirements.txt 44 | ``` 45 | You need to install apex: 46 | ```bash 47 | cd .. 48 | git clone https://www.github.com/nvidia/apex 49 | cd apex 50 | python3 setup.py install 51 | ``` 52 | 53 | Also, you need to set up ROUGE evaluation if you have not done so yet. Please refer to [this](https://github.com/bheinzerling/pyrouge) repository. After finishing the setup, specify the ROUGE path in `src/utils/config_loader.py` as an attribute of `PathParser`: 54 | ```python 55 | self.rouge_dir = '~/ROUGE-1.5.5/data' # specify your ROUGE dir 56 | ``` 57 | 58 | Lastly, to run the centrality module with query injection, replace the `algorithms` package under the standard `lexrank` library with our code. You can download our code [here](https://drive.google.com/file/d/1w1voXYfiKjCb6iHjtdJK1xrkZ3ixURS3/view?usp=sharing) and put it under `~/querysum`.
Then: 59 | ```bash 60 | rm -rf ~/querysum/lib/python3.6/site-packages/lexrank/algorithms 61 | unzip algorithms.zip -d ~/querysum/lib/python3.6/site-packages/lexrank/ 62 | rm algorithms.zip 63 | ``` 64 | 65 | ## Prepare benchmark data 66 | Since we are not allowed to distribute DUC data, you can request DUC 2005-2007 from [NIST](https://www-nlpir.nist.gov/projects/duc/data.html). 67 | After acquiring the data, gather each year's clusters, summaries, and queries under `data/docs`, `data/summary_targets` and `data/topics`, respectively. For instance, DUC 2006's clusters, queries, and summaries should be found under `data/docs/2006/`, `data/topics/2006.sgml` and `data/summary_targets/2006/`, respectively. 68 | 69 | TD-QFS data can be downloaded from [here](https://talbaumel.github.io/TD-QFS/files/TD-QFS.zip). 70 | You can also use the processed version [here](https://drive.google.com/file/d/1X1rKKP5SrUoU9-ki0urrlhO_L35Vl2oO/view?usp=sharing). 71 | 72 | After data preparation, you should have the following directory structure with the right files under each folder: 73 | ```bash 74 | querysum 75 | └───data 76 | │ └───docs # DUC clusters 77 | │ └───passages # DUC passage objects for passage-level QA (generated by our code) 78 | │ └───topics # DUC queries 79 | │ └───summary_targets # DUC reference summaries 80 | │ └───tdqfs # documents, queries and reference summaries in TD-QFS 81 | ``` 82 | 83 | 84 | ## Run experiments 85 | We go through the three stages in the QuerySum pipeline: retrieval, answering, and summarization. 86 | 87 | ### Retrieval 88 | In `src/frame/ir/ir_tf.py`, 89 | ```python 90 | rank_e2e() # rank all sentences 91 | ir_rank2records() # filter sentences based on IR scores 92 | ``` 93 | Specifically, `rank_e2e` builds a directory under `rank`, e.g., `rank/ir-tf-2007`, which stores ranking files for each cluster.
On top of it, `ir_rank2records` builds a filtered record for each cluster, e.g., under `rank/ir_records-ir-tf-2007-0.75_ir_conf` where `0.75` is the accumulated confidence. 94 | 95 | You can specify IR configs in `src/frame/ir/ir_config.py`. 96 | 97 | ### Answering 98 | For sentence-level QA (i.e., answer sentence selection), in `src/frame/bert_qa/main.py`, 99 | ```python 100 | run(prep=True, mode='rec') 101 | ``` 102 | Note that you only need to set `prep=True` the first time; it calculates and saves query relevance scores for the sentences retrieved by the previous module. 103 | The scores are then converted into a ranking list, from which the top K sentences are selected. 104 | 105 | You can specify QA configs in `src/frame/bert_qa/qa_config.py`. 106 | 107 | For passage-level QA (i.e., MRC), use `src/frame/bert_passage/infer.py`. 108 | 109 | ### Summarization 110 | In `src/frame/centrality/centrality_qa_tfidf_hard.py`, run the following methods in order (or at once): 111 | 112 | ```python 113 | build_components_e2e() # build graph components 114 | score_e2e() # run Markov Chain 115 | rank_e2e() # rank sentences 116 | select_e2e() # compose summary 117 | ``` 118 | Specifically, `build_components_e2e` builds the similarity matrix and the query relevance vector for the sentences selected in the previous step. `score_e2e` runs a query-focused Markov Chain until convergence. 119 | `rank_e2e` ranks sentences considering both salience (the stationary distribution) and redundancy. Finally, `select_e2e` composes the summary from the ranking. 120 | 121 | You can specify centrality configs in `src/frame/centrality/centrality_config.py`.
122 | -------------------------------------------------------------------------------- /src/summ/compute_rouge.py: -------------------------------------------------------------------------------- 1 | from os.path import join, dirname, abspath 2 | import sys 3 | 4 | sys.path.insert(0, dirname(dirname(abspath(__file__)))) 5 | from pyrouge import Rouge155 6 | import utils.config_loader as config 7 | from utils.config_loader import path_parser, logger 8 | import utils.tools as tools 9 | import logging 10 | 11 | 12 | def proc_output_for_tune(output): 13 | start_pat = '1 ROUGE-2 Average' 14 | 15 | output = '\n'.join(output.split('\n')[1:]) 16 | inter_breaker = '\n---------------------------------------------\n' 17 | intra_breaker = '\n.............................................\n' 18 | 19 | target_ck = None 20 | for ck in output.split(inter_breaker): 21 | ck = ck.strip('\n') 22 | if ck: 23 | ck = ck.split(intra_breaker)[0] 24 | if ck.startswith(start_pat): 25 | target_ck = ck 26 | break 27 | 28 | if not target_ck: 29 | raise ValueError('Not found record!') 30 | 31 | num_idx = 3 32 | lines = target_ck.split('\n') 33 | recall = '{0:.2f}'.format(float(lines[0].split(' ')[num_idx]) * 100) 34 | f1 = '{0:.2f}'.format(float(lines[2].split(' ')[num_idx]) * 100) 35 | 36 | return '\t'.join((recall, f1)) 37 | 38 | 39 | def proc_output(output, target=['1', '2', 'SU4']): 40 | start_pat = '1 ROUGE-{} Average' 41 | 42 | output = '\n'.join(output.split('\n')[1:]) 43 | inter_breaker = '\n---------------------------------------------\n' 44 | intra_breaker = '\n.............................................\n' 45 | 46 | tg2ck = {} 47 | for ck in output.split(inter_breaker): 48 | ck = ck.strip('\n') 49 | if ck: 50 | ck = ck.split(intra_breaker)[0] 51 | for tg in target: 52 | if ck.startswith(start_pat.format(tg)): 53 | tg2ck[tg] = ck 54 | break 55 | 56 | num_idx = 3 57 | 58 | tg2recall = {} 59 | tg2f1 = {} 60 | 61 | for tg, ck in tg2ck.items(): 62 | lines = ck.split('\n') 63 | recall = 
'{0:.2f}'.format(float(lines[0].split(' ')[num_idx]) * 100) 64 | f1 = '{0:.2f}'.format(float(lines[2].split(' ')[num_idx]) * 100) 65 | 66 | tg2recall[tg] = recall 67 | tg2f1[tg] = f1 68 | 69 | recall_str = 'Recall:\t{}'.format('\t'.join([tg2recall[tg] for tg in target])) 70 | f1_str = 'F1:\t{}'.format('\t'.join([tg2f1[tg] for tg in target])) 71 | 72 | output = '\n' + '\n'.join((f1_str, recall_str)) 73 | return output 74 | 75 | 76 | def compute_rouge_mix(model_name, n_iter, cos_threshold, extra): 77 | for year in config.years: 78 | compute_rouge(model_name, n_iter=n_iter, 79 | cos_threshold=cos_threshold, 80 | year=year, 81 | extra=extra) 82 | return None 83 | 84 | 85 | def compute_rouge_for_dev(text_dp, tune_centrality): 86 | rouge_args = '-a -n 2 -m -c 95 -r 1000 -f A -p 0.5 -t 0 -d -e {} -x'.format( 87 | path_parser.rouge_dir) 88 | 89 | if tune_centrality: # summary length requirement 90 | rouge_args += ' -l 250' 91 | 92 | r = Rouge155(rouge_args=rouge_args) 93 | r.system_dir = text_dp 94 | r.model_dir = join(path_parser.data_summary_targets, config.test_year) 95 | 96 | gen_sys_file_pat = '(\w*)' 97 | gen_model_file_pat = '#ID#_[\d]' 98 | r.system_filename_pattern = gen_sys_file_pat 99 | r.model_filename_pattern = gen_model_file_pat 100 | 101 | output = r.convert_and_evaluate() 102 | output = proc_output_for_tune(output) 103 | logger.info(output) 104 | return output 105 | 106 | 107 | def compute_rouge_for_ablation_study(text_dp, ref_dp=None): 108 | rouge_args = '-a -l 250 -n 2 -m -2 4 -u -c 95 -r 1000 -f A -p 0.5 -t 0 -d -e {} -x'.format( 109 | path_parser.rouge_dir) 110 | 111 | r = Rouge155(rouge_args=rouge_args) 112 | r.system_dir = text_dp 113 | if not ref_dp: 114 | ref_dp = join(path_parser.data_summary_targets, config.test_year) 115 | r.model_dir = ref_dp 116 | 117 | gen_sys_file_pat = '(\w*)' 118 | gen_model_file_pat = '#ID#_[\d]' 119 | r.system_filename_pattern = gen_sys_file_pat 120 | r.model_filename_pattern = gen_model_file_pat 121 | 122 | output = 
r.convert_and_evaluate() 123 | output = proc_output(output) 124 | logger.info(output) 125 | return output 126 | 127 | 128 | def compute_rouge(model_name, n_iter=None, diversity_param_tuple=None, cos_threshold=None, extra=None): 129 | rouge_args = '-a -l 250 -n 2 -m -2 4 -u -c 95 -r 1000 -f A -p 0.5 -t 0 -d -e {} -x'.format( 130 | path_parser.rouge_dir) 131 | 132 | r = Rouge155(rouge_args=rouge_args) 133 | 134 | baselines_wo_config = ['lead', 'lead-2006', 'lead-2007', 'lead_2007'] 135 | if model_name in baselines_wo_config or model_name.startswith('duc'): 136 | text_dp = join(path_parser.summary_text, model_name) 137 | else: 138 | text_dp = tools.get_text_dp(model_name, 139 | cos_threshold=cos_threshold, 140 | n_iter=n_iter, 141 | diversity_param_tuple=diversity_param_tuple, 142 | extra=extra) 143 | 144 | r.system_dir = text_dp 145 | r.model_dir = join(path_parser.data_summary_targets, config.test_year) 146 | gen_sys_file_pat = '(\w*)' 147 | gen_model_file_pat = '#ID#_[\d]' 148 | 149 | r.system_filename_pattern = gen_sys_file_pat 150 | r.model_filename_pattern = gen_model_file_pat 151 | 152 | output = r.convert_and_evaluate() 153 | output = proc_output(output) 154 | logger.info(output) 155 | return output 156 | 157 | 158 | def compute_rouge_end2end(model_name, n_iter, cos_threshold=None, diversity_param_tuple=None, extra=None): 159 | rouge_paras = { 160 | 'model_name': model_name, 161 | 'n_iter': n_iter, 162 | 'cos_threshold': cos_threshold, 163 | 'diversity_param_tuple': diversity_param_tuple, 164 | 'extra': extra, 165 | } 166 | return compute_rouge(**rouge_paras) 167 | 168 | 169 | def compute_rouge_for_tdqfs(text_dp, ref_dp, length): 170 | rouge_args = f'-a -n 2 -m -2 4 -u -c 95 -r 1000 -f A -p 0.5 -t 0 -d -e {path_parser.rouge_dir} -x' 171 | if length: 172 | rouge_args += f' -l {length}' 173 | r = Rouge155(rouge_dir=str(path_parser.remote_root / 'ROUGE-1.5.5'), 174 | rouge_args=rouge_args, 175 | log_level=logging.WARNING, 176 | 
config_parent_dir=str(path_parser.remote_root)) 177 | 178 | 179 | r.system_dir = text_dp 180 | r.model_dir = ref_dp 181 | 182 | gen_sys_file_pat = '(\w*)' 183 | gen_model_file_pat = '#ID#_[\d]' 184 | r.system_filename_pattern = gen_sys_file_pat 185 | r.model_filename_pattern = gen_model_file_pat 186 | 187 | output = r.convert_and_evaluate() 188 | output = proc_output(output) 189 | logger.info(output) 190 | return output 191 | -------------------------------------------------------------------------------- /src/frame/ir/ir_tf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | 4 | import os 5 | from os.path import join, dirname, abspath, exists 6 | 7 | sys.path.insert(0, dirname(dirname(abspath(__file__)))) 8 | sys.path.insert(0, dirname(dirname(dirname(abspath(__file__))))) 9 | 10 | import utils.config_loader as config 11 | from utils.config_loader import logger, path_parser 12 | import summ.rank_sent as rank_sent 13 | import utils.tools as tools 14 | import tools.tfidf_tools as tfidf_tools 15 | import tools.general_tools as general_tools 16 | import frame.ir.ir_tools as ir_tools 17 | 18 | from tqdm import tqdm 19 | import shutil 20 | import frame.ir.ir_config as ir_config 21 | import summ.compute_rouge as rouge 22 | import numpy as np 23 | 24 | 25 | if config.grain != 'sent': 26 | raise ValueError('Invalid grain: {}'.format(config.grain)) 27 | 28 | 29 | def _rank(cid, query): 30 | res = tfidf_tools.build_rel_scores_tf(cid, query) 31 | rel_scores = res['rel_scores'] 32 | processed_sents = res['processed_sents'] 33 | original_sents = res['original_sents'] 34 | 35 | # get sid2score 36 | sid2score = dict() 37 | abs_idx = 0 38 | for doc_idx, doc in enumerate(processed_sents): 39 | for sent_idx, sent in enumerate(doc): 40 | sid = config.SEP.join((str(doc_idx), str(sent_idx))) 41 | score = rel_scores[abs_idx] 42 | sid2score[sid] = score 43 | 44 | abs_idx += 1 45 | 46 | # rank scores 47 | 
sid_score_list = rank_sent.sort_sid2score(sid2score) 48 | # include sentences in records 49 | rank_records = rank_sent.get_rank_records(sid_score_list, sents=original_sents) 50 | # rank_records = rank_sent.get_rank_records(sid_score_list) 51 | 52 | return rank_records 53 | 54 | 55 | def rank_e2e(): 56 | """ 57 | 58 | :param pool_func: avg, max, or None (for integrated query). 59 | :return: 60 | """ 61 | rank_dp = join(path_parser.summary_rank, ir_config.IR_MODEL_NAME_TF) 62 | 63 | if ir_config.QUERY_TYPE == 'REF': 64 | test_cid_query_dicts = general_tools.build_oracle_test_cid_query_dicts() 65 | else: 66 | test_cid_query_dicts = general_tools.build_test_cid_query_dicts(tokenize_narr=False, 67 | concat_title_narr=ir_config.CONCAT_TITLE_NARR, 68 | query_type=ir_config.QUERY_TYPE) 69 | 70 | if exists(rank_dp): 71 | raise ValueError('rank_dp exists: {}'.format(rank_dp)) 72 | os.mkdir(rank_dp) 73 | 74 | for cid_query_dict in tqdm(test_cid_query_dicts): 75 | params = { 76 | **cid_query_dict, 77 | } 78 | rank_records = _rank(**params) 79 | rank_sent.dump_rank_records(rank_records, out_fp=join(rank_dp, params['cid']), with_rank_idx=False) 80 | 81 | logger.info('Successfully dumped rankings to: {}'.format(rank_dp)) 82 | 83 | 84 | def ir_rank2records(): 85 | ir_rec_dp = join(path_parser.summary_rank, ir_config.IR_RECORDS_DIR_NAME_TF) 86 | 87 | if exists(ir_rec_dp): 88 | raise ValueError('ir_rec_dp exists: {}'.format(ir_rec_dp)) 89 | os.mkdir(ir_rec_dp) 90 | 91 | cids = tools.get_test_cc_ids() 92 | for cid in tqdm(cids): 93 | retrieval_params = { 94 | 'model_name': ir_config.IR_MODEL_NAME_TF, 95 | 'cid': cid, 96 | 'filter_var': ir_config.FILTER_VAR, 97 | 'filter': ir_config.FILTER, 98 | 'deduplicate': ir_config.DEDUPLICATE, 99 | 'prune': True, 100 | } 101 | 102 | retrieved_items = ir_tools.retrieve(**retrieval_params) 103 | ir_tools.dump_retrieval(fp=join(ir_rec_dp, cid), retrieved_items=retrieved_items) 104 | 105 | 106 | def tune(): 107 | """ 108 | Tune IR confidence / 
compression rate based on Recall Rouge 2. 109 | :return: 110 | """ 111 | if ir_config.FILTER == 'conf': 112 | tune_range = np.arange(0.05, 1.05, 0.05) 113 | else: 114 | interval = 10 115 | tune_range = range(interval, 500+interval, interval) 116 | 117 | ir_tune_dp = join(path_parser.summary_rank, ir_config.IR_TUNE_DIR_NAME_TF) 118 | ir_tune_result_fp = join(path_parser.tune, ir_config.IR_TUNE_DIR_NAME_TF) 119 | with open(ir_tune_result_fp, mode='a', encoding='utf-8') as out_f: 120 | headline = 'Filter\tRecall\tF1\n' 121 | out_f.write(headline) 122 | 123 | cids = tools.get_test_cc_ids() 124 | for filter_var in tune_range: 125 | if exists(ir_tune_dp): # remove previous output 126 | shutil.rmtree(ir_tune_dp) 127 | os.mkdir(ir_tune_dp) 128 | 129 | for cid in tqdm(cids): 130 | retrieval_params = { 131 | 'model_name': ir_config.IR_MODEL_NAME_TF, 132 | 'cid': cid, 133 | 'filter_var': filter_var, 134 | 'filter': ir_config.FILTER, 135 | 'deduplicate': ir_config.DEDUPLICATE, 136 | 'prune': True, 137 | } 138 | 139 | retrieved_items = ir_tools.retrieve(**retrieval_params) 140 | 141 | summary = '\n'.join([item[-1] for item in retrieved_items]) 142 | # print(summary) 143 | with open(join(ir_tune_dp, cid), mode='a', encoding='utf-8') as out_f: 144 | out_f.write(summary) 145 | 146 | performance = rouge.compute_rouge_for_dev(ir_tune_dp, tune_centrality=False) 147 | with open(ir_tune_result_fp, mode='a', encoding='utf-8') as out_f: 148 | if ir_config.FILTER == 'conf': 149 | rec = '{0:.2f}\t{1}\n'.format(filter_var, performance) 150 | else: 151 | rec = '{0}\t{1}\n'.format(filter_var, performance) 152 | 153 | out_f.write(rec) 154 | 155 | 156 | def compute_rouge_for_oracle(): 157 | """ 158 | The rec dp for oracle saves text for comparing against refecence. 
159 | 160 | :return: 161 | """ 162 | ir_rec_dp = join(path_parser.summary_rank, ir_config.IR_RECORDS_DIR_NAME_TF) 163 | 164 | if exists(ir_rec_dp): 165 | raise ValueError('ir_rec_dp exists: {}'.format(ir_rec_dp)) 166 | os.mkdir(ir_rec_dp) 167 | 168 | cids = tools.get_test_cc_ids() 169 | for cid in tqdm(cids): 170 | retrieval_params = { 171 | 'model_name': ir_config.IR_MODEL_NAME_TF, 172 | 'cid': cid, 173 | 'filter_var': ir_config.FILTER_VAR, 174 | 'filter': ir_config.FILTER, 175 | 'deduplicate': ir_config.DEDUPLICATE, 176 | 'prune': True, 177 | } 178 | 179 | retrieved_items = ir_tools.retrieve(**retrieval_params) 180 | summary = '\n'.join([item[-1] for item in retrieved_items]) 181 | with open(join(ir_rec_dp, cid), mode='a', encoding='utf-8') as out_f: 182 | out_f.write(summary) 183 | 184 | performance = rouge.compute_rouge_for_ablation_study(ir_rec_dp) 185 | logger.info(performance) 186 | 187 | 188 | if __name__ == '__main__': 189 | # rank_e2e() 190 | ir_rank2records() 191 | # compute_rouge_for_oracle() 192 | # tune() 193 | -------------------------------------------------------------------------------- /src/tools/vec_tools.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | from os.path import join, dirname, abspath, exists 4 | import sys 5 | 6 | sys.path.insert(0, dirname(dirname(abspath(__file__)))) 7 | import sklearn 8 | import numpy as np 9 | from scipy.stats import logistic 10 | import math 11 | 12 | 13 | def softmax(x): 14 | e_x = np.exp(x - np.max(x)) 15 | return e_x / e_x.sum(axis=0) 16 | 17 | 18 | def _compute_sim_mat(doc_sent_reps, trigger_sent_reps, score_func): 19 | """ 20 | 21 | :param doc_sent_reps: d_batch * max_ns_doc * d_embed 22 | :param trigger_sent_reps: d_batch * max_ns_trigger * d_embed 23 | :param method: cosine, angular, sigmoid 24 | :return: 25 | sim_mat: d_batch * max_ns_doc * max_ns_trigger 26 | """ 27 | if score_func in ('cosine', 'angular'): 28 | d_batch = 
doc_sent_reps.shape[0] 29 | 30 | sim_mats = [sklearn.metrics.pairwise.cosine_similarity(doc_sent_reps[sample_idx], trigger_sent_reps[sample_idx]) 31 | for sample_idx in range(d_batch)] 32 | 33 | sim_mats = np.stack(sim_mats, axis=0) 34 | if score_func == 'cosine': 35 | return sim_mats 36 | 37 | sim_mats = 1 - np.arccos(sim_mats) / math.pi # angular 38 | return sim_mats 39 | 40 | elif score_func == 'tanh': 41 | d_embed = doc_sent_reps.shape[-1] 42 | query_sent_reps = np.transpose(trigger_sent_reps, [0, 2, 1]) # d_batch * d_embed * max_ns_trigger 43 | score_in = np.matmul(doc_sent_reps, query_sent_reps) # d_batch * max_ns_doc * max_ns_trigger 44 | 45 | score_in /= math.sqrt(d_embed) 46 | sent_scores = np.tanh(score_in) 47 | 48 | return sent_scores 49 | 50 | elif score_func == 'sigmoid': 51 | d_embed = doc_sent_reps.shape[-1] 52 | query_sent_reps = np.transpose(trigger_sent_reps, [0, 2, 1]) # d_batch * d_embed * max_ns_trigger 53 | score_in = np.matmul(doc_sent_reps, query_sent_reps) # d_batch * max_ns_doc * max_ns_trigger 54 | 55 | score_in /= math.sqrt(d_embed) 56 | sent_scores = logistic.cdf(score_in) 57 | return sent_scores 58 | 59 | else: 60 | raise ValueError('Invalid method: {}'.format(score_func)) 61 | 62 | 63 | def _mask_sim_mat(sim_mat, doc_masks, trigger_masks): 64 | """ 65 | 66 | :param sim_mat: d_batch * max_ns_doc * max_ns_trigger 67 | :param doc_masks: d_batch * max_ns_doc 68 | :param trigger_masks: d_batch * max_ns_trigger 69 | :return: 70 | """ 71 | doc_masks = np.expand_dims(doc_masks, axis=-1) # d_batch * max_ns_doc * 1 72 | trigger_masks = np.expand_dims(trigger_masks, axis=1) # d_batch * 1 * max_ns_trigger 73 | 74 | sim_score_masks = np.matmul(doc_masks, trigger_masks) # d_batch * max_ns_doc * max_ns_trigger 75 | 76 | return sim_mat * sim_score_masks 77 | 78 | 79 | def _compute_relv_scores(sim_mat, pool_func, trigger_masks=None): 80 | """ 81 | 82 | :param sim_mat: d_batch * max_ns_doc * max_ns_trigger 83 | :param trigger_mask: d_batch * 
max_ns_trigger 84 | :return: 85 | relv_scores: d_batch * max_ns_doc 86 | """ 87 | if pool_func == 'avg': 88 | n_query_sents = np.sum(trigger_masks, axis=-1, keepdims=True) # d_batch * 1 89 | nom = np.sum(sim_mat, axis=-1) # d_batch * max_ns_doc 90 | 91 | relv_scores = nom / n_query_sents 92 | return relv_scores 93 | 94 | elif pool_func == 'max': 95 | relv_scores = np.max(sim_mat, axis=-1) 96 | return relv_scores 97 | else: 98 | raise ValueError('Invalid pool_func: {}'.format(pool_func)) 99 | 100 | 101 | def _mask_relv_scores(relv_scores, doc_masks): 102 | """ 103 | fill padded sentences with score of -np.inf for ranking purpose. 104 | 105 | :param relv_scores: d_batch * max_ns_doc 106 | :param doc_masks: d_batch * max_ns_doc 107 | :return: 108 | d_batch * max_ns_doc 109 | """ 110 | 111 | relv_mask = np.full_like(relv_scores, -np.inf) 112 | relv_scores = np.where(doc_masks, relv_scores, relv_mask) # d_batch * max_ns_doc 113 | return relv_scores 114 | 115 | 116 | def get_relv_scores(sent_embeds, trigger_embeds, doc_masks, trigger_masks, score_func, pool_func): 117 | sim_mat = _compute_sim_mat(sent_embeds, trigger_embeds, score_func=score_func) 118 | sim_mat = _mask_sim_mat(sim_mat, doc_masks=doc_masks, trigger_masks=trigger_masks) 119 | 120 | # relv_scores: d_batch * max_ns_doc 121 | relv_scores = _compute_relv_scores(sim_mat, pool_func=pool_func, trigger_masks=trigger_masks) 122 | relv_scores = _mask_relv_scores(relv_scores, doc_masks=doc_masks) 123 | 124 | return relv_scores 125 | 126 | 127 | def max_min_scale(scores): 128 | """ 129 | 130 | :param scores: could be a vector or a matrix. 131 | :return: 132 | """ 133 | min_v = np.min(scores) 134 | max_v = np.max(scores) 135 | denom = max_v - min_v 136 | scores = (scores - min_v) / denom 137 | return scores 138 | 139 | 140 | def norm_rel_scores(rel_scores, max_min_scale, passage_proc=False): 141 | """ 142 | The transition matrix requires all elements to be in [0, 1] and all rows sum to 1. 
143 | 144 | We use max-min scale + l1 norm for both sim_mat and rel_vec, to achieve this. 145 | 146 | Note: 147 | max-min scale was not adopted by Wan's paper on query based text summarization; we believe it is a flaw. 148 | 149 | 150 | :param rel_scores: in [0, 1] 151 | :param passage_proc: for bert_passage only 152 | :return: 153 | """ 154 | if max_min_scale: 155 | rel_scores = max_min_scale(rel_scores) 156 | 157 | if passage_proc: 158 | rel_scores = np.sqrt(rel_scores) 159 | rel_scores = np.tanh(rel_scores) 160 | 161 | rel_scores = rel_scores / np.sum(rel_scores) # l1 norm to make a distribution 162 | return rel_scores 163 | 164 | 165 | def norm_sim_mat(sim_mat, max_min_scale): 166 | """ 167 | The transition matrix requires all elements to be in [0, 1] and all rows sum to 1. 168 | 169 | We use max-min scale + l1 norm for both sim_mat and rel_vec, to achieve this. 170 | 171 | Note: 172 | (1) max-min scale was not adopted by Wan's paper on query based text summarization; we believe it is a flaw. 173 | (2) l1 norm for sim_mat is implemented in lexrank package; here we just max-min scale it. 174 | 175 | 176 | :param sim_mat: 177 | :return: 178 | """ 179 | np.fill_diagonal(sim_mat, 0.0) # avoid self-transition 180 | # deal with row_sum == 0.0. Not needed as per Wan's paper. 
181 | # set a uniform number 182 | # row_sum = sim_mat.sum(axis=1, keepdims=True) 183 | # cond = row_sum == 0.0 184 | # zeros = np.zeros_like(row_sum) 185 | # fill_number = 1.0/(len(sim_mat)-1) # make non-self elements add to 1.0 186 | # fill = np.full_like(row_sum, fill_number) 187 | # extra = np.where(cond, fill, zeros) 188 | # sim_mat += extra 189 | # np.fill_diagonal(sim_mat, 0.0) # remove self 190 | if max_min_scale: # todo: check if scale should come first than fill_diagonal 191 | sim_mat = max_min_scale(sim_mat) 192 | 193 | return sim_mat 194 | -------------------------------------------------------------------------------- /src/frame/bert_ensemble/ensemble.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from os.path import join, dirname, abspath, exists 3 | 4 | sys_path = dirname(dirname(abspath(__file__))) 5 | parent_sys_path = dirname(sys_path) 6 | 7 | if sys_path not in sys.path: 8 | sys.path.insert(0, sys_path) 9 | if parent_sys_path not in sys.path: 10 | sys.path.insert(0, parent_sys_path) 11 | 12 | import io 13 | from tqdm import tqdm 14 | import math 15 | import numpy as np 16 | import os 17 | 18 | import utils.config_loader as config 19 | from utils.config_loader import logger, path_parser 20 | import utils.tools as tools 21 | import tools.general_tools as general_tools 22 | import frame.ir.ir_tools as ir_tools 23 | import frame.bert_ensemble.ensemble_config as ensemble_config 24 | import summ.rank_sent as rank_sent 25 | 26 | use_tdqfs = 'tdqfs' == config.test_year 27 | if use_tdqfs: 28 | query_fp = path_parser.data_tdqfs_queries 29 | test_cid_query_dicts = general_tools.build_tdqfs_cid_query_dicts(query_fp=query_fp, proc=True) 30 | cids = [cq_dict['cid'] for cq_dict in test_cid_query_dicts] 31 | else: 32 | cids = tools.get_test_cc_ids() 33 | 34 | 35 | def _read_records(cid): 36 | def _get_sent2score(qa_record_fp, proc_passage): 37 | qa_record = io.open(qa_record_fp).readlines() 38 | sent2score 
= {} 39 | for line in qa_record: 40 | _, score, sent = line.strip('\n').split('\t') 41 | score = float(score) 42 | 43 | if proc_passage: 44 | score = math.tanh(math.sqrt(score)) 45 | 46 | if not 0 <= score <= 1: 47 | raise ValueError('Invalid score: {}'.format(score)) 48 | 49 | if sent in sent2score and score < sent2score[sent]: 50 | continue 51 | sent2score[sent] = score 52 | return sent2score 53 | 54 | sent2score_s = _get_sent2score(qa_record_fp=join(ensemble_config.SENT_QA_RECORD_DP, cid), proc_passage=False) 55 | sent2score_p = _get_sent2score(qa_record_fp=join(ensemble_config.PASSAGE_QA_RECORD_DP, cid), proc_passage=True) 56 | 57 | return sent2score_s, sent2score_p 58 | 59 | 60 | def _proc_one_side_score(score): 61 | if ensemble_config.IS_SENT_REC_ONLY: # then this score is from sent records 62 | if ensemble_config.ENSEMBLE_MODE == 'weight_avg_sent_only': 63 | return (1 - ensemble_config.SPAN_REC_WEIGHT) * score 64 | else: 65 | raise ValueError('Corrupted conditions!') 66 | 67 | if not ensemble_config.IS_ENSEMBLE_GLOBAL: 68 | return score 69 | elif ensemble_config.ENSEMBLE_MODE == 'sqrt_global': 70 | return 0.001 * score 71 | elif ensemble_config.ENSEMBLE_MODE == 'avg_global': 72 | return score / 2 73 | else: 74 | raise ValueError('Corrupted conditions!') 75 | 76 | 77 | def _proc_two_side_scores(score_s, score_p): 78 | if ensemble_config.ENSEMBLE_MODE in ('sqrt', 'sqrt_global'): 79 | return math.sqrt(score_s * score_p) 80 | elif ensemble_config.ENSEMBLE_MODE in ('avg', 'avg_global'): 81 | return (score_s + score_p) / 2 82 | elif ensemble_config.ENSEMBLE_MODE == 'weight_avg_sent_only': 83 | return (1 - ensemble_config.SPAN_REC_WEIGHT) * score_s + ensemble_config.SPAN_REC_WEIGHT * score_p 84 | else: 85 | raise ValueError('Invalid ENSEMBLE_MODE: {}'.format(ensemble_config.ENSEMBLE_MODE)) 86 | 87 | 88 | def _ensemble_records(cid): 89 | sent2score_s, sent2score_p = _read_records(cid) 90 | 91 | # n_sents_p = len(sent2score_p) 92 | sent2score_ensemble = {} 93 | for 
sent, score_s in sent2score_s.items(): 94 | if sent not in sent2score_p: 95 | score_s = _proc_one_side_score(score_s) 96 | if score_s: 97 | sent2score_ensemble[sent] = score_s 98 | continue 99 | 100 | score_p = sent2score_p[sent] 101 | sent2score_ensemble[sent] = _proc_two_side_scores(score_s, score_p) 102 | del sent2score_p[sent] 103 | 104 | # logger.info('n_sents_p: {} -> {}'.format(n_sents_p, len(sent2score_p))) 105 | if not ensemble_config.IS_SENT_REC_ONLY: 106 | for sent, score_p in sent2score_p.items(): 107 | score_p = _proc_one_side_score(score_p) 108 | if score_p: 109 | sent2score_ensemble[sent] = score_p 110 | 111 | return sent2score_ensemble 112 | 113 | 114 | def _rank(cid): 115 | sent2score_ensemble = _ensemble_records(cid) 116 | sent_score_list = sorted(sent2score_ensemble.items(), key=lambda item: item[1], reverse=True) 117 | 118 | records = [] 119 | for sid, sent_score in enumerate(sent_score_list): 120 | rec = ('0_{}'.format(sid), str(sent_score[1]), sent_score[0]) 121 | records.append('\t'.join(rec)) 122 | 123 | return records 124 | 125 | 126 | def rank(): 127 | rank_dp = join(path_parser.summary_rank, ensemble_config.MODEL_NAME) 128 | if exists(rank_dp): 129 | raise ValueError('rank_dp exists: {}'.format(rank_dp)) 130 | os.mkdir(rank_dp) 131 | 132 | for cid in tqdm(cids): 133 | rank_records = _rank(cid) 134 | n_sents = rank_sent.dump_rank_records(rank_records=rank_records, out_fp=join(rank_dp, cid), with_rank_idx=False) 135 | logger.info('Dump {} ranking records'.format(n_sents)) 136 | 137 | 138 | def rank2records(): 139 | rec_dp = join(path_parser.summary_rank, ensemble_config.QA_RECORD_DIR_NAME) 140 | 141 | if exists(rec_dp): 142 | raise ValueError('rec_dp exists: {}'.format(rec_dp)) 143 | os.mkdir(rec_dp) 144 | 145 | for cid in tqdm(cids): 146 | retrieval_params = { 147 | 'model_name': ensemble_config.MODEL_NAME, 148 | 'cid': cid, 149 | 'filter_var': ensemble_config.FILTER_VAR, 150 | 'filter': ensemble_config.FILTER, 151 | 'deduplicate': None, 
152 | } 153 | 154 | retrieved_items = ir_tools.retrieve(**retrieval_params) 155 | ir_tools.dump_retrieval(fp=join(rec_dp, cid), retrieved_items=retrieved_items) 156 | 157 | 158 | def rank2records_in_batch(): 159 | interval = 10 160 | start = 20 161 | end = 150 + interval 162 | filter_var_range = range(start, end, interval) 163 | 164 | for filter_var in tqdm(filter_var_range): 165 | qa_rec_dn = ensemble_config.QA_RECORD_DIR_NAME_PATTERN.format(ensemble_config.MODEL_NAME, 166 | filter_var, 167 | ensemble_config.FILTER) 168 | qa_rec_dp = join(path_parser.summary_rank, qa_rec_dn) 169 | 170 | if exists(qa_rec_dp): 171 | raise ValueError('qa_rec_dp exists: {}'.format(qa_rec_dp)) 172 | os.mkdir(qa_rec_dp) 173 | 174 | for cid in cids: 175 | retrieval_params = { 176 | 'model_name': ensemble_config.MODEL_NAME, 177 | 'cid': cid, 178 | 'filter_var': filter_var, 179 | 'filter': ensemble_config.FILTER, 180 | 'deduplicate': None, 181 | } 182 | 183 | retrieved_items = ir_tools.retrieve(**retrieval_params) 184 | ir_tools.dump_retrieval(fp=join(qa_rec_dp, cid), retrieved_items=retrieved_items) 185 | 186 | 187 | if __name__ == '__main__': 188 | rank() 189 | rank2records() 190 | # rank2records_in_batch() 191 | -------------------------------------------------------------------------------- /src/frame/centrality/centrality_tfidf_records.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from os.path import join, dirname, abspath, exists 3 | sys_path = dirname(dirname(abspath(__file__))) 4 | parent_sys_path = dirname(sys_path) 5 | 6 | if sys_path not in sys.path: 7 | sys.path.insert(0, sys_path) 8 | if parent_sys_path not in sys.path: 9 | sys.path.insert(0, parent_sys_path) 10 | 11 | import io 12 | import utils.config_loader as config 13 | from utils.config_loader import logger, path_parser 14 | import utils.graph_io as graph_io 15 | import utils.graph_tools as graph_tools 16 | import tools.tfidf_tools as tfidf_tools 17 | import 
tools.general_tools as general_tools 18 | import tools.vec_tools as vec_tools 19 | import summ.select_sent as select_sent 20 | import summ.compute_rouge as rouge 21 | 22 | import frame.centrality.centrality_config as centrality_config 23 | from tqdm import tqdm 24 | import numpy as np 25 | from utils.tools import get_test_cc_ids, get_text_dp, get_text_dp_for_tdqfs 26 | 27 | """ 28 | This file is for the ablation study of ``w/o Answering'', a combination of the following components: 29 | 1. Retrieval model: TF; score and filter 30 | 3. Summarization model: MRW (sentence rep: TFIDF); with relevance score from IR 31 | 32 | The relevance vector is produced via loading relevance scores and normalization. 33 | """ 34 | 35 | ir_record_dp = join(path_parser.summary_rank, centrality_config.IR_RECORD_DIR_NAME) 36 | model_name = centrality_config.CENTRALITY_MODEL_NAME_wo_QA 37 | 38 | use_tdqfs = 'tdqfs' in centrality_config.QA_RECORD_DIR_NAME 39 | 40 | if use_tdqfs: 41 | sentence_dp = path_parser.data_tdqfs_sentences 42 | query_fp = path_parser.data_tdqfs_queries 43 | tdqfs_summary_target_dp = path_parser.data_tdqfs_summary_targets 44 | 45 | test_cid_query_dicts = general_tools.build_tdqfs_cid_query_dicts(query_fp=query_fp, proc=True) 46 | cc_ids = [cq_dict['cid'] for cq_dict in test_cid_query_dicts] 47 | else: 48 | test_cid_query_dicts = general_tools.build_test_cid_query_dicts(tokenize_narr=False, 49 | concat_title_narr=False, 50 | query_type=centrality_config.QUERY_TYPE) 51 | cc_ids = get_test_cc_ids() 52 | 53 | 54 | def _load_rel_scores(cid, ir_record_dp): 55 | ir_record_fp = join(ir_record_dp, cid) 56 | ir_records = io.open(ir_record_fp, encoding='utf-8').readlines() 57 | ir_scores = [float(line.split('\t')[1]) for line in ir_records] 58 | ir_rel_scores = np.array(ir_scores) 59 | return ir_rel_scores 60 | 61 | 62 | def _build_components(cid, query): 63 | sim_items = tfidf_tools.build_sim_items_e2e(cid, 64 | query, 65 | mask_intra=None, 66 | max_ns_doc=None, 67 | 
retrieved_dp=ir_record_dp, 68 | sentence_rep='tfidf') 69 | 70 | sim_mat = vec_tools.norm_sim_mat(sim_mat=sim_items['doc_sim_mat'], max_min_scale=False) 71 | # logger.info('sim_mat: {}'.format(sim_mat)) 72 | 73 | rel_scores = _load_rel_scores(cid, ir_record_dp=ir_record_dp) 74 | rel_vec = vec_tools.norm_rel_scores(rel_scores=rel_scores, max_min_scale=False) 75 | # logger.info('rel_vec: {}'.format(rel_vec)) 76 | 77 | if len(rel_vec) != len(sim_mat): 78 | raise ValueError('Incompatible sim_mat size: {} and rel_vec size: {} for cid: {}'.format( 79 | sim_mat.shape, rel_vec.shape, cid)) 80 | 81 | processed_sents = sim_items['processed_sents'] 82 | sid2abs = {} 83 | sid_abs = 0 84 | for doc_idx, doc in enumerate(processed_sents): 85 | for sent_idx, sent in enumerate(doc): 86 | sid = config.SEP.join((str(doc_idx), str(sent_idx))) 87 | sid2abs[sid] = sid_abs 88 | sid_abs += 1 89 | 90 | components = { 91 | 'sim_mat': sim_mat, 92 | 'rel_vec': rel_vec, 93 | 'sid2abs': sid2abs, 94 | } 95 | 96 | return components 97 | 98 | 99 | def build_components_e2e(): 100 | dp_params = { 101 | 'model_name': model_name, 102 | 'n_iter': None, 103 | 'mode': 'w', 104 | } 105 | 106 | summ_comp_root = graph_io.get_summ_comp_root(**dp_params) 107 | sim_mat_dp = graph_io.get_sim_mat_dp(summ_comp_root, mode='w') 108 | rel_vec_dp = graph_io.get_rel_vec_dp(summ_comp_root, mode='w') 109 | sid2abs_dp = graph_io.get_sid2abs_dp(summ_comp_root, mode='w') 110 | 111 | for params in tqdm(test_cid_query_dicts): 112 | components = _build_components(**params) 113 | 114 | graph_io.dump_sim_mat(sim_mat=components['sim_mat'], sim_mat_dp=sim_mat_dp, cid=params['cid']) 115 | graph_io.dump_rel_vec(rel_vec=components['rel_vec'], rel_vec_dp=rel_vec_dp, cid=params['cid']) 116 | graph_io.dump_sid2abs(sid2abs=components['sid2abs'], sid2abs_dp=sid2abs_dp, cid=params['cid']) 117 | 118 | 119 | def score_e2e(): 120 | if centrality_config.DAMP == 1.0: 121 | damp = 0.85 122 | use_rel_vec = False 123 | else: 124 | damp = 
centrality_config.DAMP 125 | use_rel_vec = True 126 | 127 | graph_tools.score_end2end(model_name=model_name, 128 | damp=damp, 129 | use_rel_vec=use_rel_vec, 130 | cc_ids=cc_ids) 131 | 132 | 133 | def rank_e2e(): 134 | graph_tools.rank_end2end(model_name=model_name, 135 | diversity_param_tuple=centrality_config.DIVERSITY_PARAM_TUPLE, 136 | retrieved_dp=ir_record_dp, 137 | cc_ids=cc_ids) 138 | 139 | 140 | def select_e2e(): 141 | params = { 142 | 'model_name': model_name, 143 | 'diversity_param_tuple': centrality_config.DIVERSITY_PARAM_TUPLE, 144 | 'cos_threshold': centrality_config.COS_THRESHOLD, # do not pos cosine similarity criterion? 145 | 'retrieved_dp': ir_record_dp, 146 | } 147 | select_sent.select_end2end(**params) 148 | 149 | 150 | def select_e2e_tdqfs(): 151 | params = { 152 | 'model_name': model_name, 153 | 'n_iter': None, 154 | 'length_budget_tuple': centrality_config.LENGTH_BUDGET_TUPLE, 155 | 'diversity_param_tuple': centrality_config.DIVERSITY_PARAM_TUPLE, 156 | 'cos_threshold': centrality_config.COS_THRESHOLD, # do not pos cosine similarity criterion? 
157 | 'retrieved_dp': ir_record_dp, 158 | 'cc_ids': cc_ids, 159 | } 160 | select_sent.select_end2end_for_tdqfs(**params) 161 | 162 | 163 | def compute_rouge_tdqfs(): 164 | text_params = { 165 | 'model_name': model_name, 166 | 'n_iter': None, 167 | 'length_budget_tuple': centrality_config.LENGTH_BUDGET_TUPLE, 168 | 'diversity_param_tuple': centrality_config.DIVERSITY_PARAM_TUPLE, 169 | 'cos_threshold': centrality_config.COS_THRESHOLD, 170 | 'extra': None, 171 | } 172 | text_dp = get_text_dp_for_tdqfs(**text_params) 173 | 174 | rouge_parmas = { 175 | 'text_dp': text_dp, 176 | 'ref_dp': tdqfs_summary_target_dp, 177 | } 178 | if centrality_config.LENGTH_BUDGET_TUPLE[0] == 'nw': 179 | rouge_parmas['length'] = centrality_config.LENGTH_BUDGET_TUPLE[1] 180 | 181 | output = rouge.compute_rouge_for_tdqfs(**rouge_parmas) 182 | return output 183 | 184 | 185 | if __name__ == '__main__': 186 | build_components_e2e() 187 | score_e2e() 188 | rank_e2e() 189 | # select_e2e() 190 | 191 | select_e2e_tdqfs() 192 | compute_rouge_tdqfs() 193 | -------------------------------------------------------------------------------- /src/utils/config_loader.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import logging.config 3 | import yaml 4 | from io import open 5 | import os 6 | from os.path import join, dirname, abspath 7 | import warnings 8 | import sys 9 | from pytorch_transformers import BertModel, BertTokenizer, BertForQuestionAnswering 10 | import torch 11 | from pathlib import Path 12 | 13 | sys.path.insert(0, dirname(dirname(abspath(__file__)))) 14 | 15 | 16 | def deprecated(func): 17 | """ 18 | This is a decorator which can be used to mark functions 19 | as deprecated. It will result in a warning being emitted 20 | when the function is used. 
21 | """ 22 | 23 | def new_func(*args, **kwargs): 24 | warnings.warn("Call to deprecated function {}.".format(func.__name__), 25 | category=DeprecationWarning) 26 | return func(*args, **kwargs) 27 | 28 | new_func.__name__ = func.__name__ 29 | new_func.__doc__ = func.__doc__ 30 | new_func.__dict__.update(func.__dict__) 31 | return new_func 32 | 33 | 34 | class PathParser: 35 | def __init__(self, proj_root): 36 | self.proj_root = proj_root 37 | self.log = join(self.proj_root, 'log') 38 | 39 | # set data 40 | self.data = join(self.proj_root, 'data') 41 | self.squad = join(self.data, 'squad') 42 | self.squad_raw = join(self.squad, 'raw') 43 | self.squad_proc = join(self.squad, 'proc') 44 | self.data_docs = join(self.data, 'docs') 45 | self.data_passages = join(self.data, 'passages') 46 | self.data_topics = join(self.data, 'topics') 47 | 48 | # tdqfs 49 | self.data_tdqfs = join(self.data, 'tdqfs') 50 | self.data_tdqfs_sentences = join(self.data_tdqfs, 'sentences') 51 | self.data_tdqfs_passages = join(self.data_tdqfs, 'passages') 52 | self.data_tdqfs_queries = join(self.data_tdqfs, 'query_info.txt') 53 | self.data_tdqfs_summary_targets = join(self.data_tdqfs, 'summary_targets') 54 | 55 | self.data_summary_results = join(self.data, 'summary_results') 56 | self.data_summary_refs = join(self.data, 'summary_refs') 57 | self.data_summary_targets = join(self.data, 'summary_targets') 58 | 59 | self.res = join(self.proj_root, 'res') 60 | self.model_save = join(self.proj_root, 'model') 61 | self.bert_qa = join(self.model_save, 'qa_sentence') 62 | 63 | # bert passage 64 | self.bert_passage_root = join(self.model_save, 'squad_passage') 65 | self.bert_passage_tokenizer = join(self.bert_passage_root, 'passage_tokenizer') 66 | self.bert_passage_checkpoint_root = join(self.bert_passage_root, 'squad-epoch_5') 67 | self.bert_passage_checkpoint = join(self.bert_passage_checkpoint_root, 'checkpoint-{}') 68 | self.bert_passage_model = join(self.bert_passage_checkpoint, 'pytorch_model.bin') 
69 | self.bert_passage_config = join(self.bert_passage_checkpoint, 'config.json') 70 | 71 | self.pred = join(self.proj_root, 'pred') 72 | 73 | if config_meta['grain'] == 'sent': 74 | self.summary_rank = join(self.proj_root, 'rank') 75 | self.summary_text = join(self.proj_root, 'text') 76 | self.graph = join(self.proj_root, 'graph') 77 | else: 78 | self.summary_rank = join(self.proj_root, 'rank_{}'.format(config_meta['grain'])) 79 | self.summary_text = join(self.proj_root, 'text_{}'.format(config_meta['grain'])) 80 | self.graph = join(self.proj_root, 'graph_{}'.format(config_meta['grain'])) 81 | 82 | self.graph_rel_scores = join(self.graph, 'rel_scores') # for dumping relevance scores 83 | self.graph_token_logits = join(self.graph, 'token_logits') # for dumping relevance scores 84 | 85 | self.rouge = join(self.proj_root, 'rouge') 86 | 87 | self.tune = join(self.proj_root, 'tune') 88 | self.rouge_dir = '~/ROUGE-1.5.5/data' # specify your ROUGE dir 89 | 90 | src_root = os.path.dirname(os.path.dirname(__file__)) 91 | proj_root = os.path.dirname(src_root) 92 | config_root = join(src_root, 'config') 93 | 94 | # meta 95 | config_meta_fp = os.path.join(config_root, 'config_meta.yml') 96 | config_meta = yaml.load(open(config_meta_fp, 'r', encoding='utf-8')) 97 | path_parser = PathParser(proj_root=proj_root) 98 | 99 | # model 100 | meta_model_name = config_meta['model_name'] 101 | config_model_fn = 'config_model_{0}.yml'.format(meta_model_name) 102 | config_model_fp = os.path.join(config_root, config_model_fn) 103 | config_model = yaml.load(open(config_model_fp, 'r')) 104 | 105 | test_year = config_meta['test_year'] 106 | grain = config_meta['grain'] 107 | remove_dialog = config_meta['remove_dialog'] 108 | 109 | if meta_model_name == 'bert_passage': 110 | from frame.bert_passage.config_name import model_name 111 | elif meta_model_name in ('bert_qa', 'bert_base'): 112 | from frame.bert_qa.config_name import model_name 113 | else: 114 | model_name = 'MiscModel' 115 | 116 | 
logger = logging.getLogger('my_logger') 117 | logger.setLevel(logging.DEBUG) 118 | file_handler = logging.FileHandler('log/{0}.log'.format(model_name)) 119 | console_handler = logging.StreamHandler(sys.stdout) 120 | formatter = logging.Formatter("[%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s") 121 | file_handler.setFormatter(formatter) 122 | console_handler.setFormatter(formatter) 123 | logger.addHandler(file_handler) 124 | logger.addHandler(console_handler) 125 | 126 | logger.info(f'model name: {model_name}') 127 | 128 | NARR = 'narr' 129 | TITLE = 'title' 130 | QUERY = 'query' 131 | NONE = 'None' 132 | SEP = '_' 133 | years = ['2005', '2006', '2007'] 134 | 135 | def load_bert_qa(): 136 | print('Load PyTorch model from {}'.format(path_parser.bert_qa)) 137 | if not torch.cuda.is_available(): # cpu 138 | state = torch.load(path_parser.bert_qa, map_location='cpu') 139 | else: 140 | state = torch.load(path_parser.bert_qa) 141 | 142 | return state['epoch'], state['model'], state['tokenizer'], state['scores'] 143 | 144 | bert_passage_iter = 12000 145 | def load_bert_passage(): 146 | if meta_model_name != 'bert_passage': 147 | raise ValueError('Invalid meta_model_name: {}'.format(meta_model_name)) 148 | 149 | tokenizer_dir = path_parser.bert_passage_tokenizer 150 | checkpoint_dir = path_parser.bert_passage_checkpoint.format(bert_passage_iter) 151 | 152 | print('Load PyTorch model from {}, vocab: {}'.format(checkpoint_dir, tokenizer_dir)) 153 | 154 | model_params = { 155 | 'pretrained_model_name_or_path': checkpoint_dir, 156 | } 157 | bert_model = BertForQuestionAnswering.from_pretrained(**model_params) 158 | 159 | tokenizer = BertTokenizer.from_pretrained(tokenizer_dir, 160 | do_lower_case=True, 161 | do_basic_tokenize=True) 162 | 163 | return bert_model, tokenizer 164 | 165 | 166 | preload_model_tokenizer = config_meta['preload_model_tokenizer'] 167 | if preload_model_tokenizer: 168 | if meta_model_name == 'bert_qa' and config_model['fine_tune'] == 'qa': 169 
| logger.info('building BERT model and tokenizer: {}'.format(config_model['fine_tune'])) 170 | _, bert_model, bert_tokenizer, _ = load_bert_qa() 171 | 172 | elif meta_model_name == 'bert_passage' and config_model['fine_tune'] == 'passage': 173 | logger.info('building BERT model and tokenizer: {}'.format(config_model['fine_tune'])) 174 | bert_model, bert_tokenizer = load_bert_passage() 175 | 176 | else: 177 | logger.info('building BERT tokenizer') 178 | bert_model = BertModel.from_pretrained('bert-base-uncased') 179 | bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True) 180 | 181 | mode = config_meta['mode'] 182 | if mode == 'rank_sent': 183 | config_model['d_batch'] = 50 184 | -------------------------------------------------------------------------------- /src/frame/ir/ir_tools.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | from os.path import join, dirname, abspath, exists 4 | 5 | sys_path = dirname(dirname(abspath(__file__))) 6 | parent_sys_path = dirname(sys_path) 7 | 8 | if sys_path not in sys.path: 9 | sys.path.insert(0, sys_path) 10 | if parent_sys_path not in sys.path: 11 | sys.path.insert(0, parent_sys_path) 12 | 13 | import summ.rank_sent as rank_sent 14 | from utils.config_loader import logger, path_parser 15 | import utils.tools as tools 16 | from data.dataset_parser import dataset_parser 17 | 18 | import io 19 | import numpy as np 20 | import tools.vec_tools as vec_tools 21 | import dill 22 | 23 | 24 | def load_retrieved_sentences(retrieved_dp, cid): 25 | """ 26 | For downstream components, e.g., QA model or centrality model. 
27 | 28 | :param retrieved_dp: 29 | :param cid: 30 | :return: 31 | """ 32 | if not exists(retrieved_dp): 33 | raise ValueError('retrieved_dp does not exist: {}'.format(retrieved_dp)) 34 | 35 | fp = join(retrieved_dp, cid) 36 | with io.open(fp, encoding='utf-8') as f: 37 | content = f.readlines() 38 | 39 | original_sents = [ll.rstrip('\n').split('\t')[-1] for ll in content] 40 | 41 | processed_sents = [dataset_parser._proc_sent(ss, rm_dialog=False, rm_stop=True, stem=True) 42 | for ss in original_sents] 43 | 44 | return [original_sents], [processed_sents] # for compatibility of document organization for similarity calculation 45 | 46 | 47 | def load_retrieved_passages(cid, get_sents, retrieved_dp=None, passage_ids=None, tdqfs_data=False): 48 | """ 49 | You can retrieve passages from passage_ids 50 | [1] in retrieved files by setting retrieved_dp OR 51 | [2] from target passage_ids by setting passage_ids OR 52 | [3] from all passages by leaving both retrieved_dp and passage_ids to None 53 | 54 | :param cid: 55 | :param get_sents: 56 | :param retrieved_dp: 57 | :param passage_ids: 58 | :return: 59 | """ 60 | # if not (retrieved_dp or passage_ids): 61 | # raise ValueError('Specify retrieved_dp or passage_ids!') 62 | if tdqfs_data: 63 | GET_PASSAGE_FPS = tools.get_passage_fps_for_tdqfs 64 | else: 65 | GET_PASSAGE_FPS = tools.get_passage_fps 66 | 67 | passage_ids, passage_fps = GET_PASSAGE_FPS(cid, retrieved_dp=retrieved_dp, passage_ids=passage_ids) 68 | 69 | original_passages, proc_passages = [], [] 70 | for fp in passage_fps: 71 | with open(fp, 'rb') as f: 72 | po = dill.load(f) 73 | 74 | if get_sents: 75 | original_p = po.get_original_sents() 76 | proc_p = po.get_proc_sents() 77 | else: 78 | original_p = po.get_original_passage() 79 | proc_p = po.get_proc_passage() 80 | 81 | original_passages.append(original_p) 82 | proc_passages.append(proc_p) 83 | 84 | return original_passages, proc_passages, passage_ids 85 | 86 | 87 | def load_retrieved_paragraphs(retrieved_dp, 
cid): 88 | """ 89 | For downstream components, e.g., QA model or centrality model. 90 | fixme: this function has been archived. 91 | 92 | :param retrieved_dp: 93 | :param cid: 94 | :return: 95 | """ 96 | 97 | if not exists(retrieved_dp): 98 | raise ValueError('retrieved_dp does not exist: {}'.format(retrieved_dp)) 99 | 100 | fp = join(retrieved_dp, cid) 101 | with io.open(fp, encoding='utf-8') as f: 102 | content = f.readlines() 103 | 104 | original_paras = [ll.rstrip('\n').split('\t')[-1] for ll in content] 105 | 106 | para_tuples = [dataset_parser._proc_para(pp, rm_dialog=False, rm_stop=True, stem=True, to_str=True) 107 | for pp in original_paras] 108 | 109 | original_paras, processed_paras = list(zip(*para_tuples)) # using the new "original" to keep consistency 110 | 111 | return original_paras, processed_paras # for compatibility of document organization for similarity calculation 112 | 113 | 114 | def load_rank_items(model_name, cid): 115 | rank_dp = join(path_parser.summary_rank, model_name) 116 | rank_fp = join(rank_dp, cid) 117 | 118 | with io.open(rank_fp, encoding='utf-8') as f: 119 | content = f.readlines() 120 | 121 | rank_items = [ll.rstrip('\n').split('\t') for ll in content] 122 | return rank_items 123 | 124 | 125 | def _deduplicate(rank_items): 126 | new_rank_items = [] 127 | sents = [] 128 | for items in rank_items: 129 | sent = items[-1] 130 | if sent in sents: 131 | continue 132 | sents.append(sent) 133 | new_rank_items.append(items) 134 | 135 | return new_rank_items 136 | 137 | 138 | def _norm(rank_items): 139 | score_list = [float(items[1]) for items in rank_items] 140 | 141 | scores = vec_tools.max_min_scale(scores=np.array(score_list)) 142 | score_list = scores.tolist() 143 | for i in range(len(rank_items)): 144 | rank_items[i][1] = str(score_list[i]) 145 | 146 | return rank_items, score_list 147 | 148 | 149 | def _retrieve_from_rank_items_via_conf(rank_items, 150 | conf_threshold, 151 | deduplicate, 152 | min_ns, 153 | norm=False): 154 | if 
deduplicate: 155 | rank_items = _deduplicate(rank_items) 156 | 157 | conf = 0.0 158 | if len(rank_items[0]) not in (2, 3): # 2: w/o sentence; 3: with sentence 159 | raise ValueError('Corrupted item format: {}'.format(rank_items[0])) 160 | 161 | score_list = [float(items[1]) for items in rank_items] 162 | 163 | if norm: 164 | rank_items, score_list = _norm(rank_items) 165 | 166 | total = sum(score_list) 167 | retrieved_items = [] 168 | 169 | if min_ns: 170 | n_threshold = min(min_ns, len(rank_items)) 171 | else: 172 | n_threshold = None 173 | 174 | for items in rank_items: 175 | retrieved_items.append(items) 176 | conf += float(items[1]) / total 177 | 178 | if conf >= conf_threshold and (not n_threshold or len(retrieved_items) >= n_threshold): 179 | break 180 | 181 | return retrieved_items 182 | 183 | 184 | def _retrieve_from_rank_items_via_top_k(rank_items, k, deduplicate): 185 | if deduplicate: 186 | rank_items = _deduplicate(rank_items) 187 | 188 | return rank_items[:min(len(rank_items), k)] 189 | 190 | 191 | def _prune_rank_items(rank_items, threshold=1e-10): 192 | if float(rank_items[-1][1]) > threshold: 193 | logger.info('Prune ratio: 0.00') 194 | return rank_items 195 | 196 | for i in range(len(rank_items)): 197 | if float(rank_items[i][1]) <= threshold: 198 | logger.info('Prune ratio: {0:.2f}'.format(float(i) / len(rank_items))) 199 | return rank_items[:i] 200 | 201 | 202 | def retrieve(model_name, 203 | cid, 204 | filter_var, 205 | filter, 206 | deduplicate, 207 | min_ns=None, 208 | norm=False, 209 | prune=False): 210 | rank_items = load_rank_items(model_name, cid) 211 | 212 | logger.info('cid: {}, #rank_items: {}'.format(cid, len(rank_items))) 213 | 214 | if float(rank_items[0][1]) == 0.0: 215 | logger.info('retrieved {0}/{1} items for {2}'.format(len(rank_items), len(rank_items), cid)) 216 | return rank_items 217 | 218 | 219 | if prune: 220 | rank_items = _prune_rank_items(rank_items) 221 | 222 | if filter == 'conf': 223 | retrieved_items = 
_retrieve_from_rank_items_via_conf(rank_items, 224 | filter_var, 225 | deduplicate=deduplicate, 226 | min_ns=min_ns, 227 | norm=norm) 228 | 229 | elif filter == 'topK': 230 | retrieved_items = _retrieve_from_rank_items_via_top_k(rank_items, 231 | filter_var, 232 | deduplicate=deduplicate) 233 | 234 | else: 235 | raise ValueError('Invalid FILTER: {}'.format(filter)) 236 | 237 | logger.info('retrieved {0}/{1} items for {2}'.format(len(retrieved_items), len(rank_items), cid)) 238 | 239 | return retrieved_items 240 | 241 | 242 | def dump_retrieval(fp, retrieved_items): 243 | retrieve_records = ['\t'.join(items) for items in retrieved_items] 244 | n_sents = rank_sent.dump_rank_records(rank_records=retrieve_records, out_fp=fp, with_rank_idx=False) 245 | 246 | logger.info('successfully dumped {0} retrieved items to {1}'.format(n_sents, fp)) 247 | -------------------------------------------------------------------------------- /src/utils/graph_tools.py: -------------------------------------------------------------------------------- 1 | from os.path import exists, join, dirname, abspath 2 | import itertools 3 | import sys 4 | import copy 5 | import sklearn 6 | import os 7 | import io 8 | from tqdm import tqdm 9 | from lexrank import STOPWORDS, LexRank 10 | 11 | from data.dataset_parser import dataset_parser 12 | import utils.tools as tools 13 | import utils.graph_io as graph_io 14 | import utils.config_loader as config 15 | from utils.config_loader import logger 16 | import summ.rank_sent as rank_sent 17 | import summ.select_sent as select_sent 18 | from frame.ir.ir_tools import load_retrieved_sentences 19 | 20 | sys.path.insert(0, dirname(dirname(abspath(__file__)))) 21 | 22 | 23 | def _score_graph_initially(sim_mat, rel_vec, cid, damp, abs2sid=None): 24 | # todo: check if feeding placeholder documents to init LexRank does no harm 25 | # _, processed_sents = dataset_parser.cid2sents(cid, rm_dialog=rm_dialog) # 2d lists, docs => sents 26 | # lxr = LexRank(processed_sents, 
stopwords=STOPWORDS['en']) 27 | doc_place_holder = [['test sentence 1', 'test sentence 2'], ['test sentence 3']] 28 | lxr = LexRank(doc_place_holder, stopwords=STOPWORDS['en']) 29 | params = { 30 | 'similarity_matrix': sim_mat, 31 | 'threshold': None, 32 | 'fast_power_method': True, 33 | 'rel_vec': rel_vec, 34 | 'damp': damp, 35 | } 36 | scores = lxr.rank_sentences_with_sim_mat(**params) 37 | 38 | sid2score = dict() 39 | 40 | for abs, sc in enumerate(scores): 41 | sid2score[abs2sid[abs]] = sc 42 | 43 | return sid2score 44 | 45 | 46 | def _rank_with_diversity_penalty_wan(sid2score, sid2abs, sim_mat, omega=10, original_sents=None): 47 | """ 48 | 49 | :param sid2score: 50 | :param sid2abs: 51 | :param sim_mat: 52 | :param omega: 53 | :param original_sents: 54 | 55 | :return: 56 | """ 57 | if omega < 0: 58 | raise ValueError('Invalid omega: {}'.format(omega)) 59 | 60 | # norm sim_mat 61 | sim_mat_normed = sklearn.preprocessing.normalize(sim_mat, axis=1, norm='l1') 62 | sid_score_list_selected = [] 63 | # n_iter = 0 64 | 65 | sid2score_ar = copy.deepcopy(sid2score) 66 | 67 | # while sid2score_ar and n_iter <= max_n_iter: 68 | while sid2score_ar: 69 | sid_score_list = rank_sent.sort_sid2score(sid2score=sid2score_ar) 70 | sid_0, _ = sid_score_list[0] 71 | sid_score_list_selected.append(sid_score_list[0]) 72 | 73 | ii = sid2abs[sid_0] 74 | del sid2score_ar[sid_0] 75 | for sid_j in sid2score_ar: # penalize remaining sentences 76 | jj = sid2abs[sid_j] 77 | info_rich_ii = sid2score[sid_0] 78 | penalty = omega * sim_mat_normed[jj, ii] * info_rich_ii 79 | 80 | # logger.info('penalty: {}'.format(penalty)) 81 | sid2score_ar[sid_j] -= penalty 82 | 83 | # n_iter += 1 84 | rank_records = rank_sent.get_rank_records(sid_score_list=sid_score_list_selected, sents=original_sents) 85 | return rank_records 86 | 87 | 88 | def score_end2end(model_name, n_iter=None, damp=0.85, use_rel_vec=True, cc_ids=None): 89 | dp_mode = 'r' 90 | dp_params = { 91 | 'model_name': model_name, # one model has 
only one suit of summary components but different ranking sys 92 | 'n_iter': n_iter, 93 | 'mode': dp_mode, 94 | } 95 | 96 | summ_comp_root = graph_io.get_summ_comp_root(**dp_params) 97 | sim_mat_dp = graph_io.get_sim_mat_dp(summ_comp_root, mode=dp_mode) 98 | rel_vec_dp = graph_io.get_rel_vec_dp(summ_comp_root, mode=dp_mode) 99 | sid2abs_dp = graph_io.get_sid2abs_dp(summ_comp_root, mode=dp_mode) 100 | 101 | sid2score_dp = graph_io.get_sid2score_dp(summ_comp_root, mode='w') 102 | 103 | dps = { 104 | 'sim_mat_dp': sim_mat_dp, 105 | 'rel_vec_dp': rel_vec_dp, 106 | 'sid2abs_dp': sid2abs_dp, 107 | } 108 | 109 | if not cc_ids: 110 | cc_ids = tools.get_test_cc_ids() 111 | 112 | for cid in tqdm(cc_ids): 113 | comp_params = { 114 | **dps, 115 | 'cid': cid, 116 | } 117 | components = graph_io.load_components(**comp_params) 118 | # logger.info('[GRAPH RANK 1/2] successfully loaded components') 119 | 120 | abs2sid = {} 121 | for sid, abs in components['sid2abs'].items(): 122 | abs2sid[abs] = sid 123 | 124 | scoring_params = { 125 | 'sim_mat': components['sim_mat'], 126 | 'rel_vec': components['rel_vec'].transpose() if use_rel_vec else None, 127 | # 'rel_vec': components['rel_vec'] if use_rel_vec else None, 128 | 'cid': cid, 129 | 'damp': damp, 130 | 'abs2sid': abs2sid, 131 | # 'rm_dialog': rm_dialog, 132 | } 133 | 134 | sid2score = _score_graph_initially(**scoring_params) 135 | graph_io.dump_sid2score(sid2score=sid2score, sid2score_dp=sid2score_dp, cid=cid) 136 | 137 | # logger.info('[GRAPH RANK 2/2] successfully completed initial scoring') 138 | 139 | logger.info('[GRAPH RANK] Finished. 
Scores were dumped to: {}'.format(sid2score_dp)) 140 | 141 | 142 | def rank_end2end(model_name, 143 | diversity_param_tuple, 144 | component_name=None, 145 | n_iter=None, 146 | rank_dp=None, 147 | retrieved_dp=None, 148 | rm_dialog=True, 149 | cc_ids=None): 150 | """ 151 | 152 | :param model_name: 153 | :param diversity_param_tuple: 154 | :param component_name: 155 | :param n_iter: 156 | :param rank_dp: 157 | :param retrieved_dp: 158 | :param rm_dialog: only useful when retrieved_dp=None 159 | :return: 160 | """ 161 | dp_mode = 'r' 162 | dp_params = { 163 | 'n_iter': n_iter, 164 | 'mode': dp_mode, 165 | } 166 | 167 | diversity_weight, diversity_algorithm = diversity_param_tuple 168 | 169 | # todo: double check this condition; added later for avoiding bug for centrality-tfidf. 170 | # # one model has only one suit of summary components but different ranking sys 171 | if component_name: 172 | dp_params['model_name'] = component_name 173 | else: 174 | dp_params['model_name'] = model_name 175 | 176 | summ_comp_root = graph_io.get_summ_comp_root(**dp_params) 177 | sim_mat_dp = graph_io.get_sim_mat_dp(summ_comp_root, mode=dp_mode) 178 | rel_vec_dp = graph_io.get_rel_vec_dp(summ_comp_root, mode=dp_mode) 179 | sid2abs_dp = graph_io.get_sid2abs_dp(summ_comp_root, mode=dp_mode) 180 | sid2score_dp = graph_io.get_sid2score_dp(summ_comp_root, mode=dp_mode) 181 | 182 | if not rank_dp: 183 | rank_dp_params = { 184 | 'model_name': model_name, 185 | 'n_iter': n_iter, 186 | 'diversity_param_tuple': diversity_param_tuple, 187 | } 188 | 189 | rank_dp = tools.get_rank_dp(**rank_dp_params) 190 | 191 | if exists(rank_dp): 192 | raise ValueError('rank_dp exists: {}'.format(rank_dp)) 193 | os.mkdir(rank_dp) 194 | 195 | dps = { 196 | 'sim_mat_dp': sim_mat_dp, 197 | 'rel_vec_dp': rel_vec_dp, 198 | 'sid2abs_dp': sid2abs_dp, 199 | } 200 | 201 | if not cc_ids: 202 | cc_ids = tools.get_test_cc_ids() 203 | 204 | for cid in tqdm(cc_ids): 205 | # logger.info('cid: {}'.format(cid)) 206 | comp_params 
= { 207 | **dps, 208 | 'cid': cid, 209 | } 210 | components = graph_io.load_components(**comp_params) 211 | # logger.info('[GRAPH RANK 1/2] successfully loaded components') 212 | sid2score = graph_io.load_sid2score(sid2score_dp, cid) 213 | 214 | if retrieved_dp: 215 | original_sents, _ = load_retrieved_sentences(retrieved_dp=retrieved_dp, cid=cid) 216 | else: 217 | if 'tdqfs' in config.test_year: 218 | original_sents, _ = dataset_parser.cid2sents_tdqfs(cid) 219 | else: 220 | original_sents, _ = dataset_parser.cid2sents(cid, rm_dialog=rm_dialog) # 2d lists, docs => sents 221 | 222 | diversity_params = { 223 | 'sid2score': sid2score, 224 | 'sid2abs': components['sid2abs'], 225 | 'sim_mat': components['sim_mat'], 226 | 'original_sents': original_sents, 227 | } 228 | 229 | if diversity_algorithm == 'wan': 230 | diversity_params['omega'] = diversity_weight 231 | rank_records = _rank_with_diversity_penalty_wan(**diversity_params) 232 | else: 233 | raise ValueError('Invalid diversity_algorithm: {}'.format(diversity_algorithm)) 234 | 235 | logger.info('cid: {}, #rank_records: {}'.format(cid, len(rank_records))) 236 | rank_sent.dump_rank_records(rank_records, out_fp=join(rank_dp, cid), with_rank_idx=False) 237 | 238 | logger.info('[GRAPH RANK] Finished. 
Rankings were dumped to: {}'.format(rank_dp)) 239 | 240 | 241 | def select_end2end(model_name, n_iter=None, omega=10, save_out_fp=None): 242 | params = { 243 | 'model_name': model_name, 244 | 'n_iter': n_iter, 245 | 'cos_threshold': 1.0, # do not pos cosine similarity criterion 246 | 'omega': omega, 247 | } 248 | output = select_sent.select_end2end(**params) 249 | 250 | # output = rouge.compute_rouge_end2end(**params) 251 | if save_out_fp: 252 | content = '\t'.join((str(omega), output)) 253 | with io.open(save_out_fp, encoding='utf-8', mode='a') as f: 254 | f.write(content + '\n') 255 | -------------------------------------------------------------------------------- /src/frame/ir/ir_tf_tdqfs.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | 4 | import os 5 | from os import listdir 6 | from os.path import join, dirname, abspath, exists 7 | 8 | sys.path.insert(0, dirname(dirname(abspath(__file__)))) 9 | sys.path.insert(0, dirname(dirname(dirname(abspath(__file__))))) 10 | 11 | import utils.config_loader as config 12 | from utils.config_loader import logger, path_parser 13 | import summ.rank_sent as rank_sent 14 | import utils.tools as tools 15 | import tools.tfidf_tools as tfidf_tools 16 | import tools.general_tools as general_tools 17 | import frame.ir.ir_tools as ir_tools 18 | 19 | from tqdm import tqdm 20 | import shutil 21 | import frame.ir.ir_config as ir_config 22 | import summ.compute_rouge as rouge 23 | import numpy as np 24 | import io 25 | from data.dataset_parser import dataset_parser 26 | 27 | from multiprocessing import Pool 28 | import frame.centrality.centrality_config as centrality_config 29 | import summ.select_sent as select_sent 30 | import itertools 31 | 32 | assert config.grain == 'sent', f'Invalid grain: {config.grain}' 33 | assert ir_config.test_year.startswith('tdqfs'), f'set ir_config.test_year to tdqfs! 
now: {ir_config.test_year}' 34 | 35 | sentence_dp = path_parser.data_tdqfs_sentences 36 | query_fp = path_parser.data_tdqfs_queries 37 | summary_target_dp = path_parser.data_tdqfs_summary_targets 38 | 39 | 40 | if ir_config.QUERY_TYPE == 'REF': 41 | test_cid_query_dicts = general_tools.build_tdqfs_oracle_test_cid_query_dicts(query_fp=query_fp) 42 | else: 43 | test_cid_query_dicts = general_tools.build_tdqfs_cid_query_dicts(query_fp=query_fp, proc=True) 44 | 45 | cids = [cq_dict['cid'] for cq_dict in test_cid_query_dicts] 46 | 47 | def get_sentences(cid): 48 | cc_dp = join(sentence_dp, cid) 49 | fns = [fn for fn in listdir(cc_dp)] 50 | lines = itertools.chain(*[io.open(join(cc_dp, fn)).readlines() for fn in fns]) 51 | sentences = [line.strip('\n') for line in lines] 52 | 53 | original_sents = [] 54 | processed_sents = [] 55 | for ss in sentences: 56 | ss_origin = dataset_parser._proc_sent(ss, rm_dialog=False, rm_stop=False, stem=False) 57 | ss_proc = dataset_parser._proc_sent(ss, rm_dialog=False, rm_stop=True, stem=True) 58 | 59 | if ss_proc: # make sure the sent is not removed, i.e., is not empty and is not in a dialog 60 | original_sents.append(ss_origin) 61 | processed_sents.append(ss_proc) 62 | 63 | return [original_sents], [processed_sents] 64 | 65 | 66 | def _rank(cid, query): 67 | original_sents, processed_sents = get_sentences(cid) 68 | rel_scores = tfidf_tools._compute_rel_scores_tf_dot(processed_sents, query) 69 | 70 | # get sid2score 71 | sid2score = dict() 72 | abs_idx = 0 73 | for doc_idx, doc in enumerate(processed_sents): 74 | for sent_idx, sent in enumerate(doc): 75 | sid = config.SEP.join((str(doc_idx), str(sent_idx))) 76 | score = rel_scores[abs_idx] 77 | sid2score[sid] = score 78 | 79 | abs_idx += 1 80 | 81 | # rank scores 82 | sid_score_list = rank_sent.sort_sid2score(sid2score) 83 | # include sentences in records 84 | rank_records = rank_sent.get_rank_records(sid_score_list, sents=original_sents) 85 | # rank_records = 
rank_sent.get_rank_records(sid_score_list) 86 | 87 | return rank_records 88 | 89 | 90 | def rank_e2e(): 91 | rank_dp = join(path_parser.summary_rank, ir_config.IR_MODEL_NAME_TF) 92 | assert not exists(rank_dp), f'rank_dp exists: {rank_dp}' 93 | os.mkdir(rank_dp) 94 | 95 | for cid_query_dict in tqdm(test_cid_query_dicts): 96 | rank_records = _rank(**cid_query_dict) 97 | rank_sent.dump_rank_records(rank_records, out_fp=join(rank_dp, cid_query_dict['cid']), with_rank_idx=False) 98 | 99 | logger.info('Successfully dumped rankings to: {}'.format(rank_dp)) 100 | 101 | 102 | def _rank_core(cq_dict): 103 | cid = cq_dict['cid'] 104 | query = cq_dict['query'] 105 | rank_dp = join(path_parser.summary_rank, ir_config.IR_MODEL_NAME_TF) 106 | original_sents, processed_sents = get_sentences(cid) 107 | rel_scores = tfidf_tools._compute_rel_scores_tf_dot(processed_sents, query) 108 | 109 | # get sid2score 110 | sid2score = dict() 111 | abs_idx = 0 112 | for doc_idx, doc in enumerate(processed_sents): 113 | for sent_idx, sent in enumerate(doc): 114 | sid = config.SEP.join((str(doc_idx), str(sent_idx))) 115 | score = rel_scores[abs_idx] 116 | sid2score[sid] = score 117 | abs_idx += 1 118 | 119 | sid_score_list = rank_sent.sort_sid2score(sid2score) 120 | rank_records = rank_sent.get_rank_records(sid_score_list, sents=original_sents) 121 | rank_sent.dump_rank_records(rank_records, out_fp=join(rank_dp, cid), with_rank_idx=False) 122 | 123 | 124 | def rank_e2e_multiproc(): 125 | p = Pool(20) 126 | rank_dp = join(path_parser.summary_rank, ir_config.IR_MODEL_NAME_TF) 127 | 128 | assert not exists(rank_dp), f'rank_dp exists: {rank_dp}' 129 | os.mkdir(rank_dp) 130 | 131 | p.map(_rank_core, test_cid_query_dicts) 132 | 133 | 134 | def ir_rank2records(): 135 | ir_rec_dp = join(path_parser.summary_rank, ir_config.IR_RECORDS_DIR_NAME_TF) 136 | assert not exists(ir_rec_dp), f'ir_rec_dp exists: {ir_rec_dp}' 137 | os.mkdir(ir_rec_dp) 138 | 139 | # cids = tools.get_test_cc_ids() 140 | cids = 
[c_q_dict['cid'] for c_q_dict in test_cid_query_dicts] 141 | for cid in tqdm(cids): 142 | retrieval_params = { 143 | 'model_name': ir_config.IR_MODEL_NAME_TF, 144 | 'cid': cid, 145 | 'filter_var': ir_config.FILTER_VAR, 146 | 'filter': ir_config.FILTER, 147 | 'deduplicate': ir_config.DEDUPLICATE, 148 | # 'prune': True, 149 | 'prune': False, 150 | } 151 | 152 | retrieved_items = ir_tools.retrieve(**retrieval_params) 153 | ir_tools.dump_retrieval(fp=join(ir_rec_dp, cid), retrieved_items=retrieved_items) 154 | 155 | 156 | def tune(): 157 | """ 158 | Tune IR confidence / compression rate based on Recall Rouge 2. 159 | :return: 160 | """ 161 | if ir_config.FILTER == 'conf': 162 | tune_range = np.arange(0.05, 1.05, 0.05) 163 | else: 164 | interval = 10 165 | tune_range = range(interval, 500+interval, interval) 166 | 167 | ir_tune_dp = join(path_parser.summary_rank, ir_config.IR_TUNE_DIR_NAME_TF) 168 | ir_tune_result_fp = join(path_parser.tune, ir_config.IR_TUNE_DIR_NAME_TF) 169 | with open(ir_tune_result_fp, mode='a', encoding='utf-8') as out_f: 170 | headline = 'Filter\tRecall\tF1\n' 171 | out_f.write(headline) 172 | 173 | cids = tools.get_test_cc_ids() 174 | for filter_var in tune_range: 175 | if exists(ir_tune_dp): # remove previous output 176 | shutil.rmtree(ir_tune_dp) 177 | os.mkdir(ir_tune_dp) 178 | 179 | for cid in tqdm(cids): 180 | retrieval_params = { 181 | 'model_name': ir_config.IR_MODEL_NAME_TF, 182 | 'cid': cid, 183 | 'filter_var': filter_var, 184 | 'filter': ir_config.FILTER, 185 | 'deduplicate': ir_config.DEDUPLICATE, 186 | 'prune': True, 187 | } 188 | 189 | retrieved_items = ir_tools.retrieve(**retrieval_params) 190 | 191 | summary = '\n'.join([item[-1] for item in retrieved_items]) 192 | # print(summary) 193 | with open(join(ir_tune_dp, cid), mode='a', encoding='utf-8') as out_f: 194 | out_f.write(summary) 195 | 196 | performance = rouge.compute_rouge_for_dev(ir_tune_dp, tune_centrality=False) 197 | with open(ir_tune_result_fp, mode='a', 
encoding='utf-8') as out_f: 198 | if ir_config.FILTER == 'conf': 199 | rec = '{0:.2f}\t{1}\n'.format(filter_var, performance) 200 | else: 201 | rec = '{0}\t{1}\n'.format(filter_var, performance) 202 | 203 | out_f.write(rec) 204 | 205 | 206 | def compute_rouge_for_oracle(): 207 | """ 208 | The rec dp for oracle saves text for comparing against refecence. 209 | 210 | :return: 211 | """ 212 | ir_rec_dp = join(path_parser.summary_rank, ir_config.IR_RECORDS_DIR_NAME_TF) 213 | rouge_parmas = { 214 | 'text_dp': ir_rec_dp, 215 | 'ref_dp': summary_target_dp, 216 | } 217 | if centrality_config.LENGTH_BUDGET_TUPLE[0] == 'nw': 218 | rouge_parmas['length'] = centrality_config.LENGTH_BUDGET_TUPLE[1] 219 | 220 | output = rouge.compute_rouge_for_tdqfs(**rouge_parmas) 221 | return output 222 | 223 | 224 | def select_e2e(): 225 | params = { 226 | 'model_name': ir_config.IR_MODEL_NAME_TF, 227 | 'length_budget_tuple': ('nw', 250), 228 | 'cos_threshold': 1.0, 229 | # 'retrieved_dp': ir_config.IR_RECORDS_DIR_NAME_TF, 230 | 'retrieved_dp': join(path_parser.summary_rank, ir_config.IR_MODEL_NAME_TF), 231 | 'cc_ids': cids, 232 | } 233 | select_sent.select_end2end_for_tdqfs(**params) 234 | 235 | 236 | def compute_rouge(): 237 | text_params = { 238 | 'model_name': ir_config.IR_MODEL_NAME_TF, 239 | 'length_budget_tuple': ('nw', 250), 240 | 'cos_threshold': 1.0, # do not pos cosine similarity criterion? 
241 | } 242 | 243 | text_dp = tools.get_text_dp_for_tdqfs(**text_params) 244 | 245 | rouge_parmas = { 246 | 'text_dp': text_dp, 247 | 'ref_dp': summary_target_dp, 248 | } 249 | if centrality_config.LENGTH_BUDGET_TUPLE[0] == 'nw': 250 | rouge_parmas['length'] = centrality_config.LENGTH_BUDGET_TUPLE[1] 251 | 252 | output = rouge.compute_rouge_for_tdqfs(**rouge_parmas) 253 | return output 254 | 255 | 256 | if __name__ == '__main__': 257 | # rank_e2e_multiproc() 258 | # ir_rank2records() 259 | 260 | compute_rouge_for_oracle() 261 | # tune() 262 | 263 | # select_e2e() 264 | # compute_rouge() 265 | -------------------------------------------------------------------------------- /src/utils/tools.py: -------------------------------------------------------------------------------- 1 | import io 2 | import pickle 3 | import os 4 | from os import listdir 5 | import itertools 6 | from os.path import isfile, isdir, join, dirname, abspath, exists 7 | import sys 8 | 9 | import math 10 | from collections import Counter 11 | 12 | sys.path.insert(0, dirname(dirname(abspath(__file__)))) 13 | import utils.config_loader as config 14 | from utils.config_loader import config_meta, path_parser 15 | 16 | 17 | def flatten(list2d): 18 | return list(itertools.chain(*list2d)) 19 | 20 | 21 | def save_obj(obj, fp): 22 | with open(fp, 'wb') as f: 23 | pickle.dump(obj, f) 24 | 25 | 26 | def load_obj(fp): 27 | with open(fp, 'rb') as f: 28 | return pickle.load(f) 29 | 30 | 31 | def get_cc_ids(year, model_mode): 32 | root = join(path_parser.data_docs, year) 33 | 34 | all_cc_ids = [config.SEP.join((year, fn)) for fn in listdir(root) if isdir(join(root, fn))] 35 | assert model_mode == 'test' 36 | return all_cc_ids 37 | 38 | 39 | def get_test_cc_ids(): 40 | return get_cc_ids(config_meta['test_year'], model_mode='test') 41 | 42 | 43 | def get_cc_dp(cid): 44 | year, cc = cid.split(config.SEP) 45 | cc_dp = join(path_parser.data_docs, year, cc) 46 | return cc_dp 47 | 48 | 49 | def get_doc_ids(cid, 
remove_illegal): 50 | """ 51 | 52 | :param cid: 53 | :param remove_illegal: remove empty docs with no preprocessed sents, e.g., filled with dialogs or quotes. 54 | :return: 55 | """ 56 | cc_dp = get_cc_dp(cid) 57 | doc_ids = [config.SEP.join((cid, fn)) for fn in listdir(cc_dp) 58 | if isfile(join(cc_dp, fn)) and not join(cc_dp, fn).endswith('.swp')] 59 | 60 | illegal_doc_ids = ['2007_D0709B_APW19990124.0079', '2007_D0736H_APW19990311.0174'] 61 | if remove_illegal: 62 | for ill in illegal_doc_ids: 63 | if ill in doc_ids: 64 | doc_ids.remove(ill) 65 | 66 | return doc_ids 67 | 68 | 69 | def get_cid(did): 70 | return config.SEP.join(did.split(config.SEP)[:2]) 71 | 72 | 73 | def get_doc_fps(cid): 74 | cc_dp = get_cc_dp(cid) 75 | doc_fps = [join(cc_dp, fn) for fn in listdir(cc_dp) 76 | if isfile(join(cc_dp, fn)) and not join(cc_dp, fn).endswith('.swp')] 77 | return doc_fps 78 | 79 | 80 | def get_doc_info(did): 81 | return did.split(config.SEP) 82 | 83 | 84 | def get_cc_info(cid): 85 | return cid.split(config.SEP) 86 | 87 | 88 | def get_doc_fp(did): 89 | year, cc, fn = get_doc_info(did) 90 | return join(path_parser.data_docs, year, cc, fn) 91 | 92 | 93 | def get_doc_idx2id(cid): 94 | doc_ids = get_doc_ids(cid) 95 | doc_idx2id = dict() 96 | for doc_idx, doc_id in enumerate(doc_ids): 97 | doc_idx2id[doc_idx] = doc_id 98 | 99 | return doc_idx2id 100 | 101 | 102 | def get_doc_fps_yearly(year): 103 | cc_ids = get_cc_ids(year, model_mode='test') 104 | doc_fps = list(itertools.chain(*[get_doc_fps(cid) for cid in cc_ids])) 105 | return doc_fps 106 | 107 | 108 | def text_to_vec(sent_words): 109 | return Counter(sent_words) 110 | 111 | 112 | def get_n_refs(cid): 113 | year, cc = cid.split(config.SEP) 114 | if year != '2005': 115 | return 4 116 | 117 | ref_fp = join(path_parser.data_summary_targets, year, cid) 118 | with io.open(ref_fp, encoding='utf-8') as ref_f: 119 | n_refs = len(ref_f.readlines()) 120 | 121 | return n_refs 122 | 123 | 124 | def get_sent_info(sid): 125 | doc_idx, 
sent_idx = [int(idx) for idx in sid.split(config.SEP)] 126 | return doc_idx, sent_idx 127 | 128 | 129 | def get_sent(sents, sid): 130 | if type(sents[0]) != list: 131 | sents = [sents] 132 | 133 | doc_idx, sent_idx = get_sent_info(sid) 134 | 135 | if doc_idx >= len(sents): 136 | raise ValueError('Invalid doc_idx: {} for #doc: {}'.format(doc_idx, len(sents))) 137 | 138 | doc = sents[doc_idx] 139 | if sent_idx >= len(doc): 140 | raise ValueError('Invalid sent_idx: {} for #sents: {}'.format(sent_idx, len(doc))) 141 | 142 | sent = doc[sent_idx] 143 | return sent 144 | 145 | 146 | def compute_sent_cosine(sent_words_1, sent_words_2): 147 | vec_1 = text_to_vec(sent_words_1) 148 | vec_2 = text_to_vec(sent_words_2) 149 | 150 | intersection = set(vec_1.keys()) & set(vec_2.keys()) 151 | numerator = sum([vec_1[x] * vec_2[x] for x in intersection]) 152 | 153 | sum_1 = sum([vec_1[x] ** 2 for x in vec_1.keys()]) 154 | sum_2 = sum([vec_2[x] ** 2 for x in vec_2.keys()]) 155 | denom = math.sqrt(sum_1) * math.sqrt(sum_2) 156 | 157 | if not denom: 158 | return 0.0 159 | else: 160 | return float(numerator) / denom 161 | 162 | 163 | def add_extra(dn_items, extra): 164 | if extra is None: 165 | return dn_items 166 | 167 | if type(extra) is list: 168 | dn_items.extend([str(ex) for ex in extra]) 169 | else: 170 | dn_items.append(str(extra)) 171 | 172 | return dn_items 173 | 174 | 175 | def get_dir_name_items(model_name, n_iter=None, diversity_param_tuple=None, extra=None): 176 | dn_items = [model_name] 177 | if n_iter: 178 | dn_items.append('{}_iter'.format(n_iter)) 179 | 180 | if diversity_param_tuple: 181 | dn_items.append('_'.join([str(item) for item in diversity_param_tuple])) 182 | 183 | dn_items = add_extra(dn_items, extra=extra) 184 | return dn_items 185 | 186 | 187 | def get_rank_dp(model_name, n_iter=None, diversity_param_tuple=None, extra=None): 188 | dn_items = get_dir_name_items(model_name, n_iter, diversity_param_tuple=diversity_param_tuple, extra=extra) 189 | rank_dp = 
join(path_parser.summary_rank, '-'.join(dn_items)) 190 | return rank_dp 191 | 192 | 193 | def get_text_dp(model_name, 194 | cos_threshold, 195 | diversity_param_tuple=None, 196 | n_iter=None, 197 | extra=None): 198 | dn_items = [model_name, '{}_cos'.format(cos_threshold)] 199 | if n_iter: 200 | dn_items.append('{}_iter'.format(n_iter)) 201 | 202 | if diversity_param_tuple: 203 | dn_items.append('_'.join([str(item) for item in diversity_param_tuple])) 204 | 205 | dn_items = add_extra(dn_items, extra=extra) 206 | text_dp = join(path_parser.summary_text, '-'.join(dn_items)) 207 | 208 | return text_dp 209 | 210 | 211 | def get_text_dp_for_tdqfs(model_name, 212 | cos_threshold, 213 | diversity_param_tuple=None, 214 | length_budget_tuple=None, 215 | n_iter=None, 216 | extra=None): 217 | dn_items = [model_name, f'{cos_threshold}_cos'] 218 | if n_iter: 219 | dn_items.append(f'{n_iter}_iter') 220 | 221 | if diversity_param_tuple: 222 | dn_items.append('_'.join([str(item) for item in diversity_param_tuple])) 223 | 224 | if length_budget_tuple: 225 | dn_items.append('_'.join([str(item) for item in length_budget_tuple])) 226 | 227 | dn_items = add_extra(dn_items, extra=extra) 228 | text_dp = join(path_parser.summary_text, '-'.join(dn_items)) 229 | return text_dp 230 | 231 | 232 | def init_text_dp_for_tdqfs(model_name, cos_threshold, n_iter, diversity_param_tuple, length_budget_tuple, extra): 233 | text_dp = get_text_dp_for_tdqfs(model_name=model_name, 234 | cos_threshold=cos_threshold, 235 | diversity_param_tuple=diversity_param_tuple, 236 | length_budget_tuple=length_budget_tuple, 237 | n_iter=n_iter, 238 | extra=extra) 239 | 240 | if exists(text_dp): 241 | raise ValueError('text_dp exists: {}'.format(text_dp)) 242 | os.mkdir(text_dp) 243 | return text_dp 244 | 245 | 246 | def init_text_dp(model_name, cos_threshold, n_iter, diversity_param_tuple, extra): 247 | text_dp = get_text_dp(model_name=model_name, 248 | cos_threshold=cos_threshold, 249 | 
diversity_param_tuple=diversity_param_tuple, 250 | n_iter=n_iter, 251 | extra=extra) 252 | 253 | if exists(text_dp): 254 | raise ValueError('text_dp exists: {}'.format(text_dp)) 255 | os.mkdir(text_dp) 256 | 257 | return text_dp 258 | 259 | 260 | def get_passage_fps(cid, retrieved_dp, passage_ids=None): 261 | """ 262 | 263 | :param cid: 264 | :param retrieved_dp: 265 | :param passage_ids: 266 | :return: 267 | """ 268 | year, _ = cid.split(config.SEP) 269 | cc_dp = join(path_parser.data_passages, year, cid) 270 | 271 | if not passage_ids: 272 | if retrieved_dp: # get passage_ids from retrieval file 273 | if not exists(retrieved_dp): 274 | raise ValueError('retrieved_dp does not exist: {}'.format(retrieved_dp)) 275 | 276 | fp = join(retrieved_dp, cid) 277 | with io.open(fp, encoding='utf-8') as f: 278 | content = f.readlines() 279 | 280 | passage_ids = [ll.rstrip('\n').split('\t')[0] for ll in content] 281 | 282 | else: # get passage_ids from data dir 283 | passage_ids = [pid for pid in listdir(cc_dp) if isfile(join(cc_dp, pid))] 284 | 285 | passage_fps = [join(cc_dp, pid) for pid in passage_ids] 286 | 287 | return passage_ids, passage_fps 288 | 289 | 290 | def get_passage_fps_for_tdqfs(cid, retrieved_dp, passage_ids=None): 291 | """ 292 | 293 | :param cid: 294 | :param retrieved_dp: 295 | :param passage_ids: 296 | :return: 297 | """ 298 | cc_dp = join(path_parser.data_tdqfs_passages, cid) 299 | 300 | if not passage_ids: 301 | if retrieved_dp: # get passage_ids from retrieval file 302 | if not exists(retrieved_dp): 303 | raise ValueError('retrieved_dp does not exist: {}'.format(retrieved_dp)) 304 | 305 | fp = join(retrieved_dp, cid) 306 | with io.open(fp, encoding='utf-8') as f: 307 | content = f.readlines() 308 | 309 | passage_ids = [ll.rstrip('\n').split('\t')[0] for ll in content] 310 | 311 | else: # get passage_ids from data dir 312 | passage_ids = [pid for pid in listdir(cc_dp) if isfile(join(cc_dp, pid))] 313 | 314 | passage_fps = [join(cc_dp, pid) for pid in 
passage_ids] 315 | 316 | return passage_ids, passage_fps 317 | 318 | 319 | def get_query_w_cid(query_info, cid): 320 | return query_info[cid] 321 | --------------------------------------------------------------------------------