├── src
│   ├── models
│   │   ├── __init__.py
│   │   ├── stats.py
│   │   ├── rnn.py
│   │   ├── model_builder_LAI.py
│   │   ├── encoder.py
│   │   ├── reporter.py
│   │   ├── data_loader.py
│   │   ├── neural.py
│   │   ├── optimizers.py
│   │   └── trainer.py
│   ├── others
│   │   ├── __init__.py
│   │   ├── logging.py
│   │   ├── utils.py
│   │   └── pyrouge.py
│   ├── prepro
│   │   ├── __init__.py
│   │   ├── utils.py
│   │   ├── smart_common_words.txt
│   │   └── data_builder_LAI.py
│   ├── preprocess_LAI.py
│   ├── distributed.py
│   └── train_LAI.py
├── logs
│   └── .gitignore
├── bert_data
│   └── .gitignore
├── models
│   └── .gitignore
├── raw_data
│   └── .gitignore
├── results
│   └── .gitignore
├── urls
├── json_data
├── bert_config.json
├── .gitignore
├── README.md
└── LICENSE

/src/models/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/others/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/prepro/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/logs/.gitignore:
--------------------------------------------------------------------------------
*
!.gitignore
--------------------------------------------------------------------------------
/bert_data/.gitignore:
--------------------------------------------------------------------------------
*
!.gitignore
--------------------------------------------------------------------------------
/models/.gitignore:
--------------------------------------------------------------------------------
*
!.gitignore
--------------------------------------------------------------------------------
/raw_data/.gitignore:
--------------------------------------------------------------------------------
*
!.gitignore
--------------------------------------------------------------------------------
/results/.gitignore:
--------------------------------------------------------------------------------
*
!.gitignore
--------------------------------------------------------------------------------
/bert_config.json:
--------------------------------------------------------------------------------
{
  "attention_probs_dropout_prob": 0.1,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "type_vocab_size": 2,
  "vocab_size": 21128
}
--------------------------------------------------------------------------------
/src/others/logging.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
from __future__ import absolute_import

import logging

logger = logging.getLogger()


def init_logger(log_file=None, log_file_level=logging.NOTSET):
    log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s")
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(log_format)
    logger.handlers = [console_handler]

    if log_file and log_file != '':
        file_handler = logging.FileHandler(log_file)
        file_handler.setLevel(log_file_level)
        file_handler.setFormatter(log_format)
        logger.addHandler(file_handler)

    return logger
--------------------------------------------------------------------------------
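A minimal usage sketch for `init_logger` (the log path below is a placeholder, not part of the repo): it resets the root logger with a console handler and, when a path is given, adds a file handler as well, so every module-level `logger` call goes to both sinks.

```
from others.logging import init_logger, logger

init_logger('../logs/example.log')    # placeholder path; any writable location works
logger.info('preprocessing started')  # written to console and to the log file
```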
26 | """ 27 | assert len(sentences) > 0 28 | assert n > 0 29 | 30 | # words = _split_into_words(sentences) 31 | 32 | words = sum(sentences, []) 33 | # words = [w for w in words if w not in stopwords] 34 | return _get_ngrams(n, words) 35 | -------------------------------------------------------------------------------- /src/preprocess_LAI.py: -------------------------------------------------------------------------------- 1 | #encoding=utf-8 2 | 3 | 4 | import argparse 5 | import time 6 | 7 | from others.logging import init_logger 8 | from prepro import data_builder_LAI 9 | 10 | 11 | def do_format_to_lines(args): 12 | print(time.clock()) 13 | data_builder_LAI.format_to_lines(args) 14 | print(time.clock()) 15 | 16 | def do_format_to_bert(args): 17 | print(time.clock()) 18 | data_builder_LAI.format_to_bert(args) 19 | print(time.clock()) 20 | 21 | 22 | def str2bool(v): 23 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 24 | return True 25 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 26 | return False 27 | else: 28 | raise argparse.ArgumentTypeError('Boolean value expected.') 29 | 30 | 31 | if __name__ == '__main__': 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument("-mode", default='', type=str, help='format_raw, format_to_lines or format_to_bert') 34 | parser.add_argument("-oracle_mode", default='greedy', type=str, help='how to generate oracle summaries, greedy or combination, combination will generate more accurate oracles but take much longer time.') 35 | 36 | parser.add_argument("-raw_path") 37 | parser.add_argument("-save_path") 38 | 39 | parser.add_argument("-shard_size", default=16000, type=int) ###change from 2000 to 16000 40 | parser.add_argument('-min_nsents', default=3, type=int) 41 | parser.add_argument('-max_nsents', default=100, type=int) 42 | parser.add_argument('-min_src_ntokens', default=5, type=int) 43 | parser.add_argument('-max_src_ntokens', default=200, type=int) 44 | 45 | parser.add_argument('-log_file', default='../logs/LCSTS.log') 46 | 47 | parser.add_argument('-dataset', default='', help='train, valid or test, defaul will process all datasets') 48 | 49 | parser.add_argument('-n_cpus', default=2, type=int) 50 | 51 | 52 | args = parser.parse_args() 53 | init_logger(args.log_file) 54 | eval('data_builder_LAI.'+args.mode + '(args)') 55 | -------------------------------------------------------------------------------- /src/prepro/.ipynb_checkpoints/data_builder_LAI-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 14, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import gc\n", 10 | "import glob\n", 11 | "import hashlib\n", 12 | "import itertools\n", 13 | "import json\n", 14 | "import os\n", 15 | "import re\n", 16 | "import subprocess\n", 17 | "import time\n", 18 | "from os.path import join as pjoin\n", 19 | "\n", 20 | "import torch\n", 21 | "from multiprocess import Pool\n", 22 | "from pytorch_pretrained_bert import BertTokenizer\n", 23 | "\n", 24 | "import sys\n", 25 | "o_path = \"../\"\n", 26 | "sys.path.append(o_path)\n", 27 | "\n", 28 | "from others.logging import logger\n", 29 | "from others.utils import clean\n", 30 | "from prepro.utils import _get_word_ngrams" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 8, 36 | "metadata": {}, 37 | "outputs": [ 38 | { 39 | "data": { 40 | "text/plain": [ 41 | 
"'/Users/admin/Desktop/Mission/Experiment_Fine-tune_BERT_for_Extractive_Summarization/BertSum-master_Chinese/src/prepro'" 42 | ] 43 | }, 44 | "execution_count": 8, 45 | "metadata": {}, 46 | "output_type": "execute_result" 47 | } 48 | ], 49 | "source": [ 50 | "\"/Users/admin/Desktop/Mission/Experiment_Fine-tune_BERT_for_Extractive_Summarization/BertSum-master_Chinese/src\"" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [] 59 | } 60 | ], 61 | "metadata": { 62 | "kernelspec": { 63 | "display_name": "Python 3", 64 | "language": "python", 65 | "name": "python3" 66 | }, 67 | "language_info": { 68 | "codemirror_mode": { 69 | "name": "ipython", 70 | "version": 3 71 | }, 72 | "file_extension": ".py", 73 | "mimetype": "text/x-python", 74 | "name": "python", 75 | "nbconvert_exporter": "python", 76 | "pygments_lexer": "ipython3", 77 | "version": "3.7.3" 78 | } 79 | }, 80 | "nbformat": 4, 81 | "nbformat_minor": 2 82 | } 83 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | led / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints/ 78 | *.ipynb 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
#Pipfile.lock

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# Mac
.DS_Store
--------------------------------------------------------------------------------
/src/others/utils.py:
--------------------------------------------------------------------------------
import os
import re
import shutil
import time

from others import pyrouge

REMAP = {"-lrb-": "(", "-rrb-": ")", "-lcb-": "{", "-rcb-": "}",
         "-lsb-": "[", "-rsb-": "]", "``": '"', "''": '"'}


def clean(x):
    return re.sub(
        r"-lrb-|-rrb-|-lcb-|-rcb-|-lsb-|-rsb-|``|''",
        lambda m: REMAP.get(m.group()), x)


def process(params):
    temp_dir, data = params
    candidates, references, pool_id = data
    cnt = len(candidates)
    current_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
    tmp_dir = os.path.join(temp_dir, "rouge-tmp-{}-{}".format(current_time, pool_id))
    if not os.path.isdir(tmp_dir):
        os.mkdir(tmp_dir)
        os.mkdir(tmp_dir + "/candidate")
        os.mkdir(tmp_dir + "/reference")
    try:
        for i in range(cnt):
            if len(references[i]) < 1:
                continue
            with open(tmp_dir + "/candidate/cand.{}.txt".format(i), "w",
                      encoding="utf-8") as f:
                f.write(candidates[i])
            with open(tmp_dir + "/reference/ref.{}.txt".format(i), "w",
                      encoding="utf-8") as f:
                f.write(references[i])
        r = pyrouge.Rouge155(temp_dir=temp_dir)
        r.model_dir = tmp_dir + "/reference/"
        r.system_dir = tmp_dir + "/candidate/"
        r.model_filename_pattern = 'ref.#ID#.txt'
        r.system_filename_pattern = r'cand.(\d+).txt'
        rouge_results = r.convert_and_evaluate()
        print(rouge_results)
        results_dict = r.output_to_dict(rouge_results)
    finally:
        # Remove the scratch directory even if ROUGE fails.
        if os.path.isdir(tmp_dir):
            shutil.rmtree(tmp_dir)
    return results_dict


def test_rouge(temp_dir, cand, ref):
    candidates = [line.strip() for line in open(cand, encoding='utf-8')]
    references = [line.strip() for line in open(ref, encoding='utf-8')]
    print(len(candidates))
    print(len(references))
    assert len(candidates) == len(references)

    cnt = len(candidates)
    current_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
    tmp_dir = os.path.join(temp_dir, "rouge-tmp-{}".format(current_time))
    if not os.path.isdir(tmp_dir):
        os.mkdir(tmp_dir)
        os.mkdir(tmp_dir + "/candidate")
        os.mkdir(tmp_dir + "/reference")
    try:
        for i in range(cnt):
            if len(references[i]) < 1:
                continue
            with open(tmp_dir + "/candidate/cand.{}.txt".format(i), "w",
                      encoding="utf-8") as f:
                f.write(candidates[i])
            with open(tmp_dir + "/reference/ref.{}.txt".format(i), "w",
                      encoding="utf-8") as f:
                f.write(references[i])
        r = pyrouge.Rouge155(temp_dir=temp_dir)
        r.model_dir = tmp_dir + "/reference/"
        r.system_dir = tmp_dir + "/candidate/"
        r.model_filename_pattern = 'ref.#ID#.txt'
        r.system_filename_pattern = r'cand.(\d+).txt'
        rouge_results = r.convert_and_evaluate()
        print(rouge_results)
        results_dict = r.output_to_dict(rouge_results)
    finally:
        # Remove the scratch directory even if ROUGE fails.
        if os.path.isdir(tmp_dir):
            shutil.rmtree(tmp_dir)
    return results_dict


def rouge_results_to_str(results_dict):
    return ">> ROUGE-F(1/2/l): {:.2f}/{:.2f}/{:.2f}\nROUGE-R(1/2/l): {:.2f}/{:.2f}/{:.2f}\n".format(
        results_dict["rouge_1_f_score"] * 100,
        results_dict["rouge_2_f_score"] * 100,
        results_dict["rouge_l_f_score"] * 100,
        results_dict["rouge_1_recall"] * 100,
        results_dict["rouge_2_recall"] * 100,
        results_dict["rouge_l_recall"] * 100
    )
--------------------------------------------------------------------------------
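A hedged usage sketch for `test_rouge` (assumes a working pyrouge / ROUGE-1.5.5 installation; the file paths are placeholders). Both files hold one summary per line, aligned by line index:

```
from others.utils import test_rouge, rouge_results_to_str

results = test_rouge('../temp', '../results/cand.txt', '../results/ref.txt')
print(rouge_results_to_str(results))  # ">> ROUGE-F(1/2/l): ..." summary string
```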
/src/models/stats.py:
--------------------------------------------------------------------------------
""" Statistics calculation utility """
from __future__ import division

import sys
import time

from others.logging import logger


class Statistics(object):
    """
    Accumulator for loss statistics.
    Currently calculates:

    * cross-entropy
    * elapsed time
    """

    def __init__(self, loss=0, n_docs=0, n_correct=0):
        # n_correct is kept for API compatibility; it is not used here.
        self.loss = loss
        self.n_docs = n_docs
        self.start_time = time.time()

    @staticmethod
    def all_gather_stats(stat, max_size=4096):
        """
        Gather a `Statistics` object across multiple processes/nodes

        Args:
            stat(:obj:Statistics): the statistics object to gather
                across all processes/nodes
            max_size(int): max buffer size to use

        Returns:
            `Statistics`, the updated stats object
        """
        stats = Statistics.all_gather_stats_list([stat], max_size=max_size)
        return stats[0]

    @staticmethod
    def all_gather_stats_list(stat_list, max_size=4096):
        """
        Gather a `Statistics` list across all processes/nodes

        Args:
            stat_list(list([`Statistics`])): list of statistics objects to
                gather across all processes/nodes
            max_size(int): max buffer size to use

        Returns:
            our_stats(list([`Statistics`])): list of updated stats
        """
        from torch.distributed import get_rank
        from distributed import all_gather_list

        # Get a list of world_size lists with len(stat_list) Statistics objects
        all_stats = all_gather_list(stat_list, max_size=max_size)

        our_rank = get_rank()
        our_stats = all_stats[our_rank]
        for other_rank, stats in enumerate(all_stats):
            if other_rank == our_rank:
                continue
            for i, stat in enumerate(stats):
                our_stats[i].update(stat, update_n_src_words=True)
        return our_stats

    def update(self, stat, update_n_src_words=False):
        """
        Update statistics by summing values with another `Statistics` object

        Args:
            stat: another statistic object
            update_n_src_words(bool): whether to update (sum) `n_src_words`
                or not

        """
        self.loss += stat.loss

        self.n_docs += stat.n_docs

    def xent(self):
        """ compute cross entropy """
        if (self.n_docs == 0):
            return 0
        return self.loss / self.n_docs

    def elapsed_time(self):
        """ compute elapsed time """
        return time.time() - self.start_time

    def output(self, step, num_steps, learning_rate, start):
        """Write out statistics to stdout.

        Args:
            step (int): current step
            num_steps (int): total number of steps
            learning_rate (float): current learning rate
            start (int): start time of step.
101 | """ 102 | t = self.elapsed_time() 103 | step_fmt = "%2d" % step 104 | if num_steps > 0: 105 | step_fmt = "%s/%5d" % (step_fmt, num_steps) 106 | logger.info( 107 | ("Step %s; xent: %4.2f; " + 108 | "lr: %7.7f; %3.0f docs/s; %6.0f sec") 109 | % (step_fmt, 110 | self.xent(), 111 | learning_rate, 112 | self.n_docs / (t + 1e-5), 113 | time.time() - start)) 114 | sys.stdout.flush() 115 | 116 | def log_tensorboard(self, prefix, writer, learning_rate, step): 117 | """ display statistics to tensorboard """ 118 | t = self.elapsed_time() 119 | writer.add_scalar(prefix + "/xent", self.xent(), step) 120 | writer.add_scalar(prefix + "/lr", learning_rate, step) 121 | -------------------------------------------------------------------------------- /src/models/rnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | 6 | class LayerNormLSTMCell(nn.LSTMCell): 7 | 8 | def __init__(self, input_size, hidden_size, bias=True): 9 | super().__init__(input_size, hidden_size, bias) 10 | 11 | self.ln_ih = nn.LayerNorm(4 * hidden_size) 12 | self.ln_hh = nn.LayerNorm(4 * hidden_size) 13 | self.ln_ho = nn.LayerNorm(hidden_size) 14 | 15 | def forward(self, input, hidden=None): 16 | self.check_forward_input(input) 17 | if hidden is None: 18 | hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False) 19 | cx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False) 20 | else: 21 | hx, cx = hidden 22 | self.check_forward_hidden(input, hx, '[0]') 23 | self.check_forward_hidden(input, cx, '[1]') 24 | 25 | gates = self.ln_ih(F.linear(input, self.weight_ih, self.bias_ih)) \ 26 | + self.ln_hh(F.linear(hx, self.weight_hh, self.bias_hh)) 27 | i, f, o = gates[:, :(3 * self.hidden_size)].sigmoid().chunk(3, 1) 28 | g = gates[:, (3 * self.hidden_size):].tanh() 29 | 30 | cy = (f * cx) + (i * g) 31 | hy = o * self.ln_ho(cy).tanh() 32 | return hy, cy 33 | 34 | 35 | class LayerNormLSTM(nn.Module): 36 | 37 | def __init__(self, input_size, hidden_size, num_layers=1, bias=True, bidirectional=False): 38 | super().__init__() 39 | self.input_size = input_size 40 | self.hidden_size = hidden_size 41 | self.num_layers = num_layers 42 | self.bidirectional = bidirectional 43 | 44 | num_directions = 2 if bidirectional else 1 45 | self.hidden0 = nn.ModuleList([ 46 | LayerNormLSTMCell(input_size=(input_size if layer == 0 else hidden_size * num_directions), 47 | hidden_size=hidden_size, bias=bias) 48 | for layer in range(num_layers) 49 | ]) 50 | 51 | if self.bidirectional: 52 | self.hidden1 = nn.ModuleList([ 53 | LayerNormLSTMCell(input_size=(input_size if layer == 0 else hidden_size * num_directions), 54 | hidden_size=hidden_size, bias=bias) 55 | for layer in range(num_layers) 56 | ]) 57 | 58 | def forward(self, input, hidden=None): 59 | seq_len, batch_size, hidden_size = input.size() # supports TxNxH only 60 | num_directions = 2 if self.bidirectional else 1 61 | if hidden is None: 62 | hx = input.new_zeros(self.num_layers * num_directions, batch_size, self.hidden_size, requires_grad=False) 63 | cx = input.new_zeros(self.num_layers * num_directions, batch_size, self.hidden_size, requires_grad=False) 64 | else: 65 | hx, cx = hidden 66 | 67 | ht = [[None, ] * (self.num_layers * num_directions)] * seq_len 68 | ct = [[None, ] * (self.num_layers * num_directions)] * seq_len 69 | 70 | if self.bidirectional: 71 | xs = input 72 | for l, (layer0, layer1) in enumerate(zip(self.hidden0, self.hidden1)): 73 | l0, 
/src/models/rnn.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch import nn


class LayerNormLSTMCell(nn.LSTMCell):

    def __init__(self, input_size, hidden_size, bias=True):
        super().__init__(input_size, hidden_size, bias)

        self.ln_ih = nn.LayerNorm(4 * hidden_size)
        self.ln_hh = nn.LayerNorm(4 * hidden_size)
        self.ln_ho = nn.LayerNorm(hidden_size)

    def forward(self, input, hidden=None):
        self.check_forward_input(input)
        if hidden is None:
            hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
            cx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
        else:
            hx, cx = hidden
        self.check_forward_hidden(input, hx, '[0]')
        self.check_forward_hidden(input, cx, '[1]')

        gates = self.ln_ih(F.linear(input, self.weight_ih, self.bias_ih)) \
            + self.ln_hh(F.linear(hx, self.weight_hh, self.bias_hh))
        i, f, o = gates[:, :(3 * self.hidden_size)].sigmoid().chunk(3, 1)
        g = gates[:, (3 * self.hidden_size):].tanh()

        cy = (f * cx) + (i * g)
        hy = o * self.ln_ho(cy).tanh()
        return hy, cy


class LayerNormLSTM(nn.Module):

    def __init__(self, input_size, hidden_size, num_layers=1, bias=True, bidirectional=False):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional

        num_directions = 2 if bidirectional else 1
        self.hidden0 = nn.ModuleList([
            LayerNormLSTMCell(input_size=(input_size if layer == 0 else hidden_size * num_directions),
                              hidden_size=hidden_size, bias=bias)
            for layer in range(num_layers)
        ])

        if self.bidirectional:
            self.hidden1 = nn.ModuleList([
                LayerNormLSTMCell(input_size=(input_size if layer == 0 else hidden_size * num_directions),
                                  hidden_size=hidden_size, bias=bias)
                for layer in range(num_layers)
            ])

    def forward(self, input, hidden=None):
        seq_len, batch_size, hidden_size = input.size()  # supports TxNxH only
        num_directions = 2 if self.bidirectional else 1
        if hidden is None:
            hx = input.new_zeros(self.num_layers * num_directions, batch_size, self.hidden_size, requires_grad=False)
            cx = input.new_zeros(self.num_layers * num_directions, batch_size, self.hidden_size, requires_grad=False)
        else:
            hx, cx = hidden

        # One fresh list per time step; `[[None] * k] * seq_len` would alias a
        # single row across all steps and corrupt the outputs.
        ht = [[None, ] * (self.num_layers * num_directions) for _ in range(seq_len)]
        ct = [[None, ] * (self.num_layers * num_directions) for _ in range(seq_len)]

        if self.bidirectional:
            xs = input
            for l, (layer0, layer1) in enumerate(zip(self.hidden0, self.hidden1)):
                l0, l1 = 2 * l, 2 * l + 1
                h0, c0, h1, c1 = hx[l0], cx[l0], hx[l1], cx[l1]
                for t, (x0, x1) in enumerate(zip(xs, reversed(xs))):
                    ht[t][l0], ct[t][l0] = layer0(x0, (h0, c0))
                    h0, c0 = ht[t][l0], ct[t][l0]
                    t = seq_len - 1 - t
                    ht[t][l1], ct[t][l1] = layer1(x1, (h1, c1))
                    h1, c1 = ht[t][l1], ct[t][l1]
                xs = [torch.cat((h[l0], h[l1]), dim=1) for h in ht]
            y = torch.stack(xs)
            hy = torch.stack(ht[-1])
            cy = torch.stack(ct[-1])
        else:
            h, c = hx, cx
            for t, x in enumerate(input):
                for l, layer in enumerate(self.hidden0):
                    ht[t][l], ct[t][l] = layer(x, (h[l], c[l]))
                    x = ht[t][l]
                h, c = ht[t], ct[t]
            y = torch.stack([h[-1] for h in ht])
            hy = torch.stack(ht[-1])
            cy = torch.stack(ct[-1])

        return y, (hy, cy)
--------------------------------------------------------------------------------
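A shape sketch for `LayerNormLSTM` (sizes are illustrative, not from the repo). Input is time-major, `(seq_len, batch, input_size)`, as the `supports TxNxH only` comment says; a bidirectional run concatenates the two directions on the feature axis:

```
import torch
from models.rnn import LayerNormLSTM

lstm = LayerNormLSTM(input_size=768, hidden_size=384, bidirectional=True)
x = torch.randn(20, 4, 768)      # (seq_len=20, batch=4, features=768)
y, (h, c) = lstm(x)
print(y.shape)                   # torch.Size([20, 4, 768]): 2 directions * 384
```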
/src/models/model_builder_LAI.py:
--------------------------------------------------------------------------------

import torch
import torch.nn as nn
from pytorch_pretrained_bert import BertModel, BertConfig
from torch.nn.init import xavier_uniform_

from models.encoder import TransformerInterEncoder, Classifier, RNNEncoder
from models.optimizers import Optimizer


def build_optim(args, model, checkpoint):
    """ Build optimizer """
    saved_optimizer_state_dict = None

    if args.train_from != '':
        optim = checkpoint['optim']
        saved_optimizer_state_dict = optim.optimizer.state_dict()
    else:
        optim = Optimizer(
            args.optim, args.lr, args.max_grad_norm,
            beta1=args.beta1, beta2=args.beta2,
            decay_method=args.decay_method,
            warmup_steps=args.warmup_steps)

    optim.set_parameters(list(model.named_parameters()))

    if args.train_from != '':
        optim.optimizer.load_state_dict(saved_optimizer_state_dict)
        if args.visible_gpus != '-1':
            for state in optim.optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.cuda()

        if (optim.method == 'adam') and (len(optim.optimizer.state) < 1):
            raise RuntimeError(
                "Error: loaded Adam optimizer from existing model" +
                " but optimizer state is empty")

    return optim


class Bert(nn.Module):
    def __init__(self, temp_dir, load_pretrained_bert, bert_config):
        super(Bert, self).__init__()
        if (load_pretrained_bert):
            self.model = BertModel.from_pretrained('bert-base-chinese', cache_dir=temp_dir)
        else:
            self.model = BertModel(bert_config)

    def forward(self, x, segs, mask):
        encoded_layers, _ = self.model(x, segs, attention_mask=mask)
        top_vec = encoded_layers[-1]
        return top_vec


class Summarizer(nn.Module):
    def __init__(self, args, device, load_pretrained_bert=False, bert_config=None):
        super(Summarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.temp_dir, load_pretrained_bert, bert_config)
        if (args.encoder == 'classifier'):
            self.encoder = Classifier(self.bert.model.config.hidden_size)
        elif (args.encoder == 'transformer'):
            self.encoder = TransformerInterEncoder(self.bert.model.config.hidden_size, args.ff_size, args.heads,
                                                   args.dropout, args.inter_layers)
        elif (args.encoder == 'rnn'):
            self.encoder = RNNEncoder(bidirectional=True, num_layers=1,
                                      input_size=self.bert.model.config.hidden_size, hidden_size=args.rnn_size,
                                      dropout=args.dropout)
        elif (args.encoder == 'baseline'):
            bert_config = BertConfig(self.bert.model.config.vocab_size, hidden_size=args.hidden_size,
                                     num_hidden_layers=6, num_attention_heads=8, intermediate_size=args.ff_size)
            self.bert.model = BertModel(bert_config)
            self.encoder = Classifier(self.bert.model.config.hidden_size)

        if args.param_init != 0.0:
            for p in self.encoder.parameters():
                p.data.uniform_(-args.param_init, args.param_init)
        if args.param_init_glorot:
            for p in self.encoder.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)

        self.to(device)

    def load_cp(self, pt):
        self.load_state_dict(pt['model'], strict=True)

    def forward(self, x, segs, clss, mask, mask_cls, sentence_range=None):
        top_vec = self.bert(x, segs, mask)
        sents_vec = top_vec[torch.arange(top_vec.size(0)).unsqueeze(1), clss]
        sents_vec = sents_vec * mask_cls[:, :, None].float()
        sent_scores = self.encoder(sents_vec, mask_cls).squeeze(-1)
        return sent_scores, mask_cls
--------------------------------------------------------------------------------
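A hedged sketch of exercising the `Bert` wrapper in isolation, with randomly initialized (untrained) weights built from the `bert_config.json` shown earlier; the tensor sizes are illustrative, and `BertConfig.from_json_file` is the pytorch_pretrained_bert loader this sketch assumes:

```
import torch
from pytorch_pretrained_bert import BertConfig
from models.model_builder_LAI import Bert

config = BertConfig.from_json_file('../bert_config.json')
bert = Bert(temp_dir='../temp', load_pretrained_bert=False, bert_config=config)

ids = torch.randint(0, config.vocab_size, (2, 128))  # dummy token ids
segs = torch.zeros_like(ids)                         # all segment-A tokens
mask = torch.ones_like(ids)                          # no padding
top_vec = bert(ids, segs, mask)                      # (2, 128, 768) top-layer states
```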
/src/distributed.py:
--------------------------------------------------------------------------------
""" Pytorch Distributed utils
    This piece of code was heavily inspired by the equivalent of Fairseq-py
    https://github.com/pytorch/fairseq
"""


from __future__ import print_function

import math
import pickle

import torch.distributed

from others.logging import logger


def is_master(gpu_ranks, device_id):
    return gpu_ranks[device_id] == 0


def multi_init(device_id, world_size, gpu_ranks):
    print(gpu_ranks)
    dist_init_method = 'tcp://localhost:10000'
    dist_world_size = world_size
    torch.distributed.init_process_group(
        backend='nccl', init_method=dist_init_method,
        world_size=dist_world_size, rank=gpu_ranks[device_id])
    gpu_rank = torch.distributed.get_rank()
    if not is_master(gpu_ranks, device_id):
        # print('not master')
        logger.disabled = True

    return gpu_rank


def all_reduce_and_rescale_tensors(tensors, rescale_denom,
                                   buffer_size=10485760):
    """All-reduce and rescale tensors in chunks of the specified size.

    Args:
        tensors: list of Tensors to all-reduce
        rescale_denom: denominator for rescaling summed Tensors
        buffer_size: all-reduce chunk size in bytes
    """
    # buffer size in bytes, determine equiv. # of elements based on data type
    buffer_t = tensors[0].new(
        math.ceil(buffer_size / tensors[0].element_size())).zero_()
    buffer = []

    def all_reduce_buffer():
        # copy tensors into buffer_t
        offset = 0
        for t in buffer:
            numel = t.numel()
            buffer_t[offset:offset + numel].copy_(t.view(-1))
            offset += numel

        # all-reduce and rescale
        torch.distributed.all_reduce(buffer_t[:offset])
        buffer_t.div_(rescale_denom)

        # copy all-reduced buffer back into tensors
        offset = 0
        for t in buffer:
            numel = t.numel()
            t.view(-1).copy_(buffer_t[offset:offset + numel])
            offset += numel

    filled = 0
    for t in tensors:
        sz = t.numel() * t.element_size()
        if sz > buffer_size:
            # tensor is bigger than buffer, all-reduce and rescale directly
            torch.distributed.all_reduce(t)
            t.div_(rescale_denom)
        elif filled + sz > buffer_size:
            # buffer is full, all-reduce and replace buffer with grad
            all_reduce_buffer()
            buffer = [t]
            filled = sz
        else:
            # add tensor to buffer
            buffer.append(t)
            filled += sz

    if len(buffer) > 0:
        all_reduce_buffer()


def all_gather_list(data, max_size=4096):
    """Gathers arbitrary data from all nodes into a list."""
    world_size = torch.distributed.get_world_size()
    if not hasattr(all_gather_list, '_in_buffer') or \
            max_size != all_gather_list._in_buffer.size():
        all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size)
        all_gather_list._out_buffers = [
            torch.cuda.ByteTensor(max_size)
            for i in range(world_size)
        ]
    in_buffer = all_gather_list._in_buffer
    out_buffers = all_gather_list._out_buffers

    enc = pickle.dumps(data)
    enc_size = len(enc)
    if enc_size + 2 > max_size:
        raise ValueError(
            'encoded data exceeds max_size: {}'.format(enc_size + 2))
    assert max_size < 255 * 256
    in_buffer[0] = enc_size // 255  # this encoding works for max_size < 65k
    in_buffer[1] = enc_size % 255
    in_buffer[2:enc_size + 2] = torch.ByteTensor(list(enc))

    torch.distributed.all_gather(out_buffers, in_buffer.cuda())

    results = []
    for i in range(world_size):
        out_buffer = out_buffers[i]
        size = (255 * out_buffer[0].item()) + out_buffer[1].item()

        bytes_list = bytes(out_buffer[2:size + 2].tolist())
        result = pickle.loads(bytes_list)
        results.append(result)
    return results
--------------------------------------------------------------------------------
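The two-byte length header in `all_gather_list` stores the pickle size as `255 * buf[0] + buf[1]`, which is why `max_size` must stay below `255 * 256`. A worked check with an assumed payload length of 1000 bytes:

```
enc_size = 1000                    # example payload length, not from the repo
hi, lo = enc_size // 255, enc_size % 255
assert (hi, lo) == (3, 235)
assert 255 * hi + lo == enc_size   # the receiver decodes 255 * 3 + 235 == 1000
```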
/src/models/encoder.py:
--------------------------------------------------------------------------------
import math

import torch
import torch.nn as nn

from models.neural import MultiHeadedAttention, PositionwiseFeedForward
from models.rnn import LayerNormLSTM


class Classifier(nn.Module):
    def __init__(self, hidden_size):
        super(Classifier, self).__init__()
        self.linear1 = nn.Linear(hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, mask_cls):
        h = self.linear1(x).squeeze(-1)
        sent_scores = self.sigmoid(h) * mask_cls.float()
        return sent_scores


class PositionalEncoding(nn.Module):

    def __init__(self, dropout, dim, max_len=5000):
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) *
                              -(math.log(10000.0) / dim)))
        pe[:, 0::2] = torch.sin(position.float() * div_term)
        pe[:, 1::2] = torch.cos(position.float() * div_term)
        pe = pe.unsqueeze(0)
        super(PositionalEncoding, self).__init__()
        self.register_buffer('pe', pe)
        self.dropout = nn.Dropout(p=dropout)
        self.dim = dim

    def forward(self, emb, step=None):
        emb = emb * math.sqrt(self.dim)
        if (step):
            emb = emb + self.pe[:, step][:, None, :]
        else:
            emb = emb + self.pe[:, :emb.size(1)]
        emb = self.dropout(emb)
        return emb

    def get_emb(self, emb):
        return self.pe[:, :emb.size(1)]


class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, heads, d_ff, dropout):
        super(TransformerEncoderLayer, self).__init__()

        self.self_attn = MultiHeadedAttention(
            heads, d_model, dropout=dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, iter, query, inputs, mask):
        if (iter != 0):
            input_norm = self.layer_norm(inputs)
        else:
            input_norm = inputs

        mask = mask.unsqueeze(1)
        context = self.self_attn(input_norm, input_norm, input_norm,
                                 mask=mask)
        out = self.dropout(context) + inputs
        return self.feed_forward(out)


class TransformerInterEncoder(nn.Module):
    def __init__(self, d_model, d_ff, heads, dropout, num_inter_layers=0):
        super(TransformerInterEncoder, self).__init__()
        self.d_model = d_model
        self.num_inter_layers = num_inter_layers
        self.pos_emb = PositionalEncoding(dropout, d_model)
        self.transformer_inter = nn.ModuleList(
            [TransformerEncoderLayer(d_model, heads, d_ff, dropout)
             for _ in range(num_inter_layers)])
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.wo = nn.Linear(d_model, 1, bias=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, top_vecs, mask):
        """ See :obj:`EncoderBase.forward()`"""
        batch_size, n_sents = top_vecs.size(0), top_vecs.size(1)
        pos_emb = self.pos_emb.pe[:, :n_sents]
        x = top_vecs * mask[:, :, None].float()
        x = x + pos_emb

        for i in range(self.num_inter_layers):
            x = self.transformer_inter[i](i, x, x, 1 - mask)  # all_sents * max_tokens * dim

        x = self.layer_norm(x)
        sent_scores = self.sigmoid(self.wo(x))
        sent_scores = sent_scores.squeeze(-1) * mask.float()

        return sent_scores


class RNNEncoder(nn.Module):

    def __init__(self, bidirectional, num_layers, input_size,
                 hidden_size, dropout=0.0):
        super(RNNEncoder, self).__init__()
        num_directions = 2 if bidirectional else 1
        assert hidden_size % num_directions == 0
        hidden_size = hidden_size // num_directions

        self.rnn = LayerNormLSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bidirectional=bidirectional)

        self.wo = nn.Linear(num_directions * hidden_size, 1, bias=True)
        self.dropout = nn.Dropout(dropout)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, mask):
        """See :func:`EncoderBase.forward()`"""
        x = torch.transpose(x, 1, 0)
        memory_bank, _ = self.rnn(x)
        memory_bank = self.dropout(memory_bank) + x
        memory_bank = torch.transpose(memory_bank, 1, 0)

        sent_scores = self.sigmoid(self.wo(memory_bank))
        sent_scores = sent_scores.squeeze(-1) * mask.float()
        return sent_scores
--------------------------------------------------------------------------------
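A shape sketch for `Classifier`, the simplest of the three summarization heads (tensor sizes are illustrative): one sigmoid score per sentence vector, with padded positions zeroed by the mask.

```
import torch
from models.encoder import Classifier

clf = Classifier(hidden_size=768)
sents_vec = torch.randn(2, 10, 768)   # (batch, n_sents, hidden), e.g. [CLS] vectors
mask_cls = torch.ones(2, 10)          # 1 = real sentence, 0 = padding
scores = clf(sents_vec, mask_cls)     # (2, 10), each score in [0, 1]
```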
/README.md:
--------------------------------------------------------------------------------
# BERTSUM experiments on Chinese data

Experiments on a Chinese dataset, adapted from the methodology and source code of the paper `Fine-tune BERT for Extractive Summarization`.

Author's homepage (with paper PDF & source code links): [http://nlp-yang.github.io/](http://nlp-yang.github.io/)



## Dataset

* Chinese dataset: **LCSTS 2.0** (A Large Scale Chinese Short Text Summarization Dataset)

* Source: Intelligent Computing Research Center, Harbin Institute of Technology, Shenzhen Graduate School

* How to apply for access: [http://icrc.hitsz.edu.cn/Article/show/139.html](http://icrc.hitsz.edu.cn/Article/show/139.html)



## Preprocessing

### Step 1: Download the raw data

Download the LCSTS 2.0 raw data (see the application link above) and put all **PART_*.txt** files from `LCSTS2.0/DATA` into `BertSum-master_Chinese/raw_data`.

### Step 2: Convert the raw files to JSON

Under `BertSum-master_Chinese/src`, run:

```
python preprocess_LAI.py -mode format_raw -raw_path ../raw_data -save_path ../raw_data -log_file ../logs/preprocess.log
```

### Step 3: Sentence splitting & tokenization, file sharding, format simplification

* Sentence splitting: first split on ['。', '!', '?']; if that yields fewer than 2 sentences, split further on [',', ';'] (an illustrative sketch follows Step 4)

* File sharding: the training set is too large for one file, so it is split into smaller files to ease training. **Each shard holds at most 16,000 records**

Under `BertSum-master_Chinese/src`, run:

```
python preprocess_LAI.py -mode format_to_lines -raw_path ../raw_data -save_path ../json_data/LCSTS -log_file ../logs/preprocess.log
```

### Step 4: Sentence labeling & final preprocessing

* Sentence labeling: find the n sentences closest to the reference summary (similarity measured by ROUGE) and label them 1 (i.e., part of the summary)

```
python preprocess_LAI.py -mode format_to_bert -raw_path ../json_data -save_path ../bert_data -oracle_mode greedy -n_cpus 2 -log_file ../logs/preprocess.log
```
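A minimal sketch of the Step 3 splitting rule (illustrative only; the actual logic lives in `prepro/data_builder_LAI.py` and may differ in detail, e.g. in whether fullwidth or ASCII punctuation appears in the data):

```
import re

def split_sentences(text):
    # Primary split after 。！？ (and ASCII ! ?); the delimiter stays attached.
    sents = [s for s in re.split(r'(?<=[。！？!?])', text) if s.strip()]
    if len(sents) < 2:
        # Fall back to commas/semicolons when fewer than 2 sentences result.
        sents = [s for s in re.split(r'(?<=[。！？!?，；,;])', text) if s.strip()]
    return sents
```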
## Training

**Note** — **first run**: use a single GPU the first time, so the code can download the BERT model. Change ``-visible_gpus 0,1,2 -gpu_ranks 0,1,2 -world_size 3`` to ``-visible_gpus 0 -gpu_ranks 0 -world_size 1``; after the download finishes, you can kill the process and rerun the code with multiple GPUs.



Under `BertSum-master_Chinese/src`, run one of the following three commands.

**They differ only in the value of `-encoder` (classifier, transformer, rnn), i.e., in which of the three summarization layers sits on top of BERT.**

BERT+Classifier model:

```
python train_LAI.py -mode train -encoder classifier -dropout 0.1 -bert_data_path ../bert_data/LCSTS -model_path ../models/bert_classifier -lr 2e-3 -visible_gpus 1 -gpu_ranks 0 -world_size 1 -report_every 50 -save_checkpoint_steps 1000 -batch_size 3000 -decay_method noam -train_steps 30000 -accum_count 2 -log_file ../logs/bert_classifier -use_interval true -warmup_steps 10000
```

BERT+Transformer model:
```
python train_LAI.py -mode train -encoder transformer -dropout 0.1 -bert_data_path ../bert_data/LCSTS -model_path ../models/bert_transformer -lr 2e-3 -visible_gpus 1 -gpu_ranks 0 -world_size 1 -report_every 50 -save_checkpoint_steps 1000 -batch_size 3000 -decay_method noam -train_steps 30000 -accum_count 2 -log_file ../logs/bert_transformer -use_interval true -warmup_steps 10000 -ff_size 2048 -inter_layers 2 -heads 8
```

BERT+RNN model:
```
python train_LAI.py -mode train -encoder rnn -dropout 0.1 -bert_data_path ../bert_data/LCSTS -model_path ../models/bert_rnn -lr 2e-3 -visible_gpus 1 -gpu_ranks 0 -world_size 1 -report_every 50 -save_checkpoint_steps 1000 -batch_size 3000 -decay_method noam -train_steps 30000 -accum_count 2 -log_file ../logs/bert_rnn -use_interval true -warmup_steps 10000 -rnn_size 768 -dropout 0.1
```



**Note**: if training is interrupted unexpectedly, you can resume from a checkpoint (`-save_checkpoint_steps` saves the model periodically).

The following resumes training from the model saved at step 20,000 (shown with `-encoder transformer`; classifier & rnn work the same way):

```
python train_LAI.py -mode train -encoder transformer -dropout 0.1 -bert_data_path ../bert_data/LCSTS -model_path ../models/bert_transformer -lr 2e-3 -visible_gpus 1 -gpu_ranks 0 -world_size 1 -report_every 50 -save_checkpoint_steps 1000 -batch_size 3000 -decay_method noam -train_steps 30000 -accum_count 2 -log_file ../logs/bert_transformer -use_interval true -warmup_steps 10000 -ff_size 2048 -inter_layers 2 -heads 8 -train_from ../models/bert_transformer/model_step_20000.pt
```



## Evaluation

After training, under `BertSum-master_Chinese/src`, run:

```
python train_LAI.py -mode test -bert_data_path ../bert_data/LCSTS -model_path MODEL_PATH -visible_gpus 1 -gpu_ranks 0 -batch_size 30000 -log_file LOG_FILE -result_path ../results/LCSTS -test_all -block_trigram False -test_from ../models/bert_transformer/model_step_30000.pt
```

- `MODEL_PATH` is the directory where checkpoints are stored
- `RESULT_PATH` is where you want to put decoded summaries (default `../results/LCSTS`)



## Generating Oracle summaries

Oracle summary: greedily pick the n sentences of the source closest to the reference summary (the code sets n=3; adjust as needed). A conceptual sketch follows at the end of this README.



To adjust the summary size, in `BertSum-master_Chinese/src/prepro/`:

data_builder_LAI.py: line 204 - oracle_ids = greedy_selection(source, tgt, **3**)



Under `BertSum-master_Chinese/src`, run:

```
python train_LAI.py -mode oracle -bert_data_path ../bert_data/LCSTS -visible_gpus -1 -batch_size 30000 -log_file LCSTS_oracle -result_path ../results/LCSTS_oracle -block_trigram false
```



## Training on a new dataset

To use BERTSUM on a new dataset, simply:

* Format the raw data like the records in `BertSum-master_Chinese/raw_data/LCSTS_test.json`
* Adjust the corresponding file/path names accordingly, e.g. `-bert_data_path ../bert_data/LCSTS`, `-log_file LCSTS_oracle` (replace LCSTS with the new dataset's name)
* Then run the preprocessing from **Step 3** onward
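A conceptual sketch of greedy oracle selection (the real implementation is `greedy_selection()` in `prepro/data_builder_LAI.py`, which scores with ROUGE; this sketch approximates it with unigram overlap via the repo's `_get_word_ngrams` helper, and `greedy_oracle` is a hypothetical name):

```
from prepro.utils import _get_word_ngrams

def greedy_oracle(doc_sents, ref_sents, n_select=3):
    # doc_sents / ref_sents: lists of token lists.
    ref = _get_word_ngrams(1, ref_sents)
    selected, covered = [], 0
    for _ in range(n_select):
        best, best_cov = None, covered
        for i in range(len(doc_sents)):
            if i in selected:
                continue
            picked = [doc_sents[s] for s in selected + [i]]
            cov = len(_get_word_ngrams(1, picked) & ref)  # reference unigrams covered
            if cov > best_cov:
                best, best_cov = i, cov
        if best is None:       # no sentence improves coverage any further
            break
        selected.append(best)
        covered = best_cov
    return sorted(selected)
```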
--------------------------------------------------------------------------------
/src/models/reporter.py:
--------------------------------------------------------------------------------
""" Report manager utility """
from __future__ import print_function

import time
from datetime import datetime

from models.stats import Statistics
from others.logging import logger


def build_report_manager(opt):
    if opt.tensorboard:
        from tensorboardX import SummaryWriter
        tensorboard_log_dir = opt.tensorboard_log_dir

        if not opt.train_from:
            tensorboard_log_dir += datetime.now().strftime("/%b-%d_%H-%M-%S")

        writer = SummaryWriter(tensorboard_log_dir,
                               comment="Unmt")
    else:
        writer = None

    report_mgr = ReportMgr(opt.report_every, start_time=-1,
                           tensorboard_writer=writer)
    return report_mgr


class ReportMgrBase(object):
    """
    Report Manager Base class
    Inherited classes should override:
        * `_report_training`
        * `_report_step`
    """

    def __init__(self, report_every, start_time=-1.):
        """
        Args:
            report_every(int): Report status every this many sentences
            start_time(float): manually set report start time. Negative values
                mean that you will need to set it later or use `start()`
        """
        self.report_every = report_every
        self.progress_step = 0
        self.start_time = start_time

    def start(self):
        self.start_time = time.time()

    def log(self, *args, **kwargs):
        logger.info(*args, **kwargs)

    def report_training(self, step, num_steps, learning_rate,
                        report_stats, multigpu=False):
        """
        This is the user-defined batch-level training progress
        report function.

        Args:
            step(int): current step count.
            num_steps(int): total number of batches.
            learning_rate(float): current learning rate.
            report_stats(Statistics): old Statistics instance.
        Returns:
            report_stats(Statistics): updated Statistics instance.
67 | """ 68 | if self.start_time < 0: 69 | raise ValueError("""ReportMgr needs to be started 70 | (set 'start_time' or use 'start()'""") 71 | 72 | if step % self.report_every == 0: 73 | if multigpu: 74 | report_stats = \ 75 | Statistics.all_gather_stats(report_stats) 76 | self._report_training( 77 | step, num_steps, learning_rate, report_stats) 78 | self.progress_step += 1 79 | return Statistics() 80 | else: 81 | return report_stats 82 | 83 | def _report_training(self, *args, **kwargs): 84 | """ To be overridden """ 85 | raise NotImplementedError() 86 | 87 | def report_step(self, lr, step, train_stats=None, valid_stats=None): 88 | """ 89 | Report stats of a step 90 | 91 | Args: 92 | train_stats(Statistics): training stats 93 | valid_stats(Statistics): validation stats 94 | lr(float): current learning rate 95 | """ 96 | self._report_step( 97 | lr, step, train_stats=train_stats, valid_stats=valid_stats) 98 | 99 | def _report_step(self, *args, **kwargs): 100 | raise NotImplementedError() 101 | 102 | 103 | class ReportMgr(ReportMgrBase): 104 | def __init__(self, report_every, start_time=-1., tensorboard_writer=None): 105 | """ 106 | A report manager that writes statistics on standard output as well as 107 | (optionally) TensorBoard 108 | 109 | Args: 110 | report_every(int): Report status every this many sentences 111 | tensorboard_writer(:obj:`tensorboard.SummaryWriter`): 112 | The TensorBoard Summary writer to use or None 113 | """ 114 | super(ReportMgr, self).__init__(report_every, start_time) 115 | self.tensorboard_writer = tensorboard_writer 116 | 117 | def maybe_log_tensorboard(self, stats, prefix, learning_rate, step): 118 | if self.tensorboard_writer is not None: 119 | stats.log_tensorboard( 120 | prefix, self.tensorboard_writer, learning_rate, step) 121 | 122 | def _report_training(self, step, num_steps, learning_rate, 123 | report_stats): 124 | """ 125 | See base class method `ReportMgrBase.report_training`. 126 | """ 127 | report_stats.output(step, num_steps, 128 | learning_rate, self.start_time) 129 | 130 | # Log the progress using the number of batches on the x-axis. 131 | self.maybe_log_tensorboard(report_stats, 132 | "progress", 133 | learning_rate, 134 | self.progress_step) 135 | report_stats = Statistics() 136 | 137 | return report_stats 138 | 139 | def _report_step(self, lr, step, train_stats=None, valid_stats=None): 140 | """ 141 | See base class method `ReportMgrBase.report_step`. 
142 | """ 143 | if train_stats is not None: 144 | self.log('Train xent: %g' % train_stats.xent()) 145 | 146 | self.maybe_log_tensorboard(train_stats, 147 | "train", 148 | lr, 149 | step) 150 | 151 | if valid_stats is not None: 152 | self.log('Validation xent: %g at step %d' % (valid_stats.xent(), step)) 153 | 154 | self.maybe_log_tensorboard(valid_stats, 155 | "valid", 156 | lr, 157 | step) 158 | -------------------------------------------------------------------------------- /src/prepro/smart_common_words.txt: -------------------------------------------------------------------------------- 1 | rrb 2 | llb 3 | lsb 4 | rsb 5 | reuters 6 | ap 7 | jan 8 | feb 9 | mar 10 | apr 11 | may 12 | jun 13 | jul 14 | aug 15 | sep 16 | oct 17 | nov 18 | dec 19 | tech 20 | news 21 | index 22 | mon 23 | tue 24 | wed 25 | thu 26 | fri 27 | sat 28 | 's 29 | a 30 | a's 31 | able 32 | about 33 | above 34 | according 35 | accordingly 36 | across 37 | actually 38 | after 39 | afterwards 40 | again 41 | against 42 | ain't 43 | all 44 | allow 45 | allows 46 | almost 47 | alone 48 | along 49 | already 50 | also 51 | although 52 | always 53 | am 54 | amid 55 | among 56 | amongst 57 | an 58 | and 59 | another 60 | any 61 | anybody 62 | anyhow 63 | anyone 64 | anything 65 | anyway 66 | anyways 67 | anywhere 68 | apart 69 | appear 70 | appreciate 71 | appropriate 72 | are 73 | aren't 74 | around 75 | as 76 | aside 77 | ask 78 | asking 79 | associated 80 | at 81 | available 82 | away 83 | awfully 84 | b 85 | be 86 | became 87 | because 88 | become 89 | becomes 90 | becoming 91 | been 92 | before 93 | beforehand 94 | behind 95 | being 96 | believe 97 | below 98 | beside 99 | besides 100 | best 101 | better 102 | between 103 | beyond 104 | both 105 | brief 106 | but 107 | by 108 | c 109 | c'mon 110 | c's 111 | came 112 | can 113 | can't 114 | cannot 115 | cant 116 | cause 117 | causes 118 | certain 119 | certainly 120 | changes 121 | clearly 122 | co 123 | com 124 | come 125 | comes 126 | concerning 127 | consequently 128 | consider 129 | considering 130 | contain 131 | containing 132 | contains 133 | corresponding 134 | could 135 | couldn't 136 | course 137 | currently 138 | d 139 | definitely 140 | described 141 | despite 142 | did 143 | didn't 144 | different 145 | do 146 | does 147 | doesn't 148 | doing 149 | don't 150 | done 151 | down 152 | downwards 153 | during 154 | e 155 | each 156 | edu 157 | eg 158 | e.g. 159 | eight 160 | either 161 | else 162 | elsewhere 163 | enough 164 | entirely 165 | especially 166 | et 167 | etc 168 | etc. 169 | even 170 | ever 171 | every 172 | everybody 173 | everyone 174 | everything 175 | everywhere 176 | ex 177 | exactly 178 | example 179 | except 180 | f 181 | far 182 | few 183 | fifth 184 | five 185 | followed 186 | following 187 | follows 188 | for 189 | former 190 | formerly 191 | forth 192 | four 193 | from 194 | further 195 | furthermore 196 | g 197 | get 198 | gets 199 | getting 200 | given 201 | gives 202 | go 203 | goes 204 | going 205 | gone 206 | got 207 | gotten 208 | greetings 209 | h 210 | had 211 | hadn't 212 | happens 213 | hardly 214 | has 215 | hasn't 216 | have 217 | haven't 218 | having 219 | he 220 | he's 221 | hello 222 | help 223 | hence 224 | her 225 | here 226 | here's 227 | hereafter 228 | hereby 229 | herein 230 | hereupon 231 | hers 232 | herself 233 | hi 234 | him 235 | himself 236 | his 237 | hither 238 | hopefully 239 | how 240 | howbeit 241 | however 242 | i 243 | i'd 244 | i'll 245 | i'm 246 | i've 247 | ie 248 | i.e. 
if
ignored
immediate
in
inasmuch
inc
indeed
indicate
indicated
indicates
inner
insofar
instead
into
inward
is
isn't
it
it'd
it'll
it's
its
itself
j
just
k
keep
keeps
kept
know
knows
known
l
lately
later
latter
latterly
least
less
lest
let
let's
like
liked
likely
little
look
looking
looks
ltd
m
mainly
many
may
maybe
me
mean
meanwhile
merely
might
more
moreover
most
mostly
mr.
ms.
much
must
my
myself
n
namely
nd
near
nearly
necessary
need
needs
neither
never
nevertheless
new
next
nine
no
nobody
non
none
noone
nor
normally
not
nothing
novel
now
nowhere
o
obviously
of
off
often
oh
ok
okay
old
on
once
one
ones
only
onto
or
other
others
otherwise
ought
our
ours
ourselves
out
outside
over
overall
own
p
particular
particularly
per
perhaps
placed
please
plus
possible
presumably
probably
provides
q
que
quite
qv
r
rather
rd
re
really
reasonably
regarding
regardless
regards
relatively
respectively
right
s
said
same
saw
say
saying
says
second
secondly
see
seeing
seem
seemed
seeming
seems
seen
self
selves
sensible
sent
serious
seriously
seven
several
shall
she
should
shouldn't
since
six
so
some
somebody
somehow
someone
something
sometime
sometimes
somewhat
somewhere
soon
sorry
specified
specify
specifying
still
sub
such
sup
sure
t
t's
take
taken
tell
tends
th
than
thank
thanks
thanx
that
that's
thats
the
their
theirs
them
themselves
then
thence
there
there's
thereafter
thereby
therefore
therein
theres
thereupon
these
they
they'd
they'll
they're
they've
think
third
this
thorough
thoroughly
those
though
three
through
throughout
thru
thus
to
together
too
took
toward
towards
tried
tries
truly
try
trying
twice
two
u
un
under
unfortunately
unless
unlikely
until
unto
up
upon
us
use
used
useful
uses
using
usually
uucp
v
value
various
very
via
viz
vs
w
want
wants
was
wasn't
way
we
543 | we'd
544 | we'll
545 | we're
546 | we've
547 | welcome
548 | well
549 | went
550 | were
551 | weren't
552 | what
553 | what's
554 | whatever
555 | when
556 | whence
557 | whenever
558 | where
559 | where's
560 | whereafter
561 | whereas
562 | whereby
563 | wherein
564 | whereupon
565 | wherever
566 | whether
567 | which
568 | while
569 | whither
570 | who
571 | who's
572 | whoever
573 | whole
574 | whom
575 | whose
576 | why
577 | will
578 | willing
579 | wish
580 | with
581 | within
582 | without
583 | won't
584 | wonder
585 | would
586 | would
587 | wouldn't
588 | x
589 | y
590 | yes
591 | yet
592 | you
593 | you'd
594 | you'll
595 | you're
596 | you've
597 | your
598 | yours
599 | yourself
600 | yourselves
601 | z
602 | zero
603 |
--------------------------------------------------------------------------------
/src/models/data_loader.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import glob
3 | import random
4 |
5 | import torch
6 |
7 | from others.logging import logger
8 |
9 |
10 |
11 | class Batch(object):
12 |     def _pad(self, data, pad_id, width=-1):
13 |         if (width == -1):
14 |             width = max(len(d) for d in data)
15 |         rtn_data = [d + [pad_id] * (width - len(d)) for d in data]
16 |         return rtn_data
17 |
18 |     def __init__(self, data=None, device=None, is_test=False):
19 |         """Create a Batch from a list of examples."""
20 |         if data is not None:
21 |             self.batch_size = len(data)
22 |             pre_src = [x[0] for x in data]
23 |             pre_labels = [x[1] for x in data]
24 |             pre_segs = [x[2] for x in data]
25 |             pre_clss = [x[3] for x in data]
26 |
27 |             src = torch.tensor(self._pad(pre_src, 0))
28 |
29 |             labels = torch.tensor(self._pad(pre_labels, 0))
30 |             segs = torch.tensor(self._pad(pre_segs, 0))
31 |             mask = 1 - (src == 0)
32 |
33 |             clss = torch.tensor(self._pad(pre_clss, -1))
34 |             mask_cls = 1 - (clss == -1)
35 |             clss[clss == -1] = 0
36 |
37 |             setattr(self, 'clss', clss.to(device))
38 |             setattr(self, 'mask_cls', mask_cls.to(device))
39 |             setattr(self, 'src', src.to(device))
40 |             setattr(self, 'labels', labels.to(device))
41 |             setattr(self, 'segs', segs.to(device))
42 |             setattr(self, 'mask', mask.to(device))
43 |
44 |             if (is_test):
45 |                 src_str = [x[-2] for x in data]
46 |                 setattr(self, 'src_str', src_str)
47 |                 tgt_str = [x[-1] for x in data]
48 |                 setattr(self, 'tgt_str', tgt_str)
49 |
50 |     def __len__(self):
51 |         return self.batch_size
52 |
53 |
54 | def batch(data, batch_size):
55 |     """Yield elements from data in chunks of batch_size."""
56 |     minibatch, size_so_far = [], 0
57 |     for ex in data:
58 |         minibatch.append(ex)
59 |         size_so_far = simple_batch_size_fn(ex, len(minibatch))
60 |         if size_so_far == batch_size:
61 |             yield minibatch
62 |             minibatch, size_so_far = [], 0
63 |         elif size_so_far > batch_size:
64 |             yield minibatch[:-1]
65 |             minibatch, size_so_far = minibatch[-1:], simple_batch_size_fn(ex, 1)
66 |     if minibatch:
67 |         yield minibatch
68 |
69 |
70 | def load_dataset(args, corpus_type, shuffle):
71 |     """
72 |     Dataset generator. Don't do extra work here (such as printing),
73 |     because it will be deferred until the first time a shard is loaded.
74 |
75 |     Args:
76 |         corpus_type: 'train', 'valid' or 'test'
77 |     Returns:
78 |         A generator of datasets; each dataset (shard) is lazily loaded.
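    Example (editor's sketch, not part of the original docs; assumes
    `args.bert_data_path` points at shard files such as
    `../bert_data/LCSTS.train.0.pt`):

        for dataset in load_dataset(args, 'train', shuffle=True):
            # each yielded `dataset` is a plain list of example dicts
            print(len(dataset))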
79 | """ 80 | assert corpus_type in ["train", "valid", "test"] 81 | 82 | def _lazy_dataset_loader(pt_file, corpus_type): 83 | dataset = torch.load(pt_file) 84 | logger.info('Loading %s dataset from %s, number of examples: %d' % 85 | (corpus_type, pt_file, len(dataset))) 86 | return dataset 87 | 88 | # Sort the glob output by file name (by increasing indexes). 89 | pts = sorted(glob.glob(args.bert_data_path + '.' + corpus_type + '.[0-9]*.pt')) 90 | if pts: 91 | if (shuffle): 92 | random.shuffle(pts) 93 | 94 | for pt in pts: 95 | yield _lazy_dataset_loader(pt, corpus_type) 96 | else: 97 | # Only one inputters.*Dataset, simple! 98 | pt = args.bert_data_path + '.' + corpus_type + '.pt' 99 | yield _lazy_dataset_loader(pt, corpus_type) 100 | 101 | 102 | def simple_batch_size_fn(new, count): 103 | src, labels = new[0], new[1] 104 | global max_n_sents, max_n_tokens, max_size 105 | if count == 1: 106 | max_size = 0 107 | max_n_sents=0 108 | max_n_tokens=0 109 | max_n_sents = max(max_n_sents, len(src)) 110 | max_size = max(max_size, max_n_sents) 111 | src_elements = count * max_size 112 | return src_elements 113 | 114 | 115 | class Dataloader(object): 116 | def __init__(self, args, datasets, batch_size, 117 | device, shuffle, is_test): 118 | self.args = args 119 | self.datasets = datasets 120 | self.batch_size = batch_size 121 | self.device = device 122 | self.shuffle = shuffle 123 | self.is_test = is_test 124 | self.cur_iter = self._next_dataset_iterator(datasets) 125 | 126 | assert self.cur_iter is not None 127 | 128 | def __iter__(self): 129 | dataset_iter = (d for d in self.datasets) 130 | while self.cur_iter is not None: 131 | for batch in self.cur_iter: 132 | yield batch 133 | self.cur_iter = self._next_dataset_iterator(dataset_iter) 134 | 135 | 136 | def _next_dataset_iterator(self, dataset_iter): 137 | try: 138 | # Drop the current dataset for decreasing memory 139 | if hasattr(self, "cur_dataset"): 140 | self.cur_dataset = None 141 | gc.collect() 142 | del self.cur_dataset 143 | gc.collect() 144 | 145 | self.cur_dataset = next(dataset_iter) 146 | except StopIteration: 147 | return None 148 | 149 | return DataIterator(args = self.args, 150 | dataset=self.cur_dataset, batch_size=self.batch_size, 151 | device=self.device, shuffle=self.shuffle, is_test=self.is_test) 152 | 153 | 154 | class DataIterator(object): 155 | def __init__(self, args, dataset, batch_size, device=None, is_test=False, 156 | shuffle=True): 157 | self.args = args 158 | self.batch_size, self.is_test, self.dataset = batch_size, is_test, dataset 159 | self.iterations = 0 160 | self.device = device 161 | self.shuffle = shuffle 162 | 163 | self.sort_key = lambda x: len(x[1]) 164 | 165 | self._iterations_this_epoch = 0 166 | 167 | def data(self): 168 | if self.shuffle: 169 | random.shuffle(self.dataset) 170 | xs = self.dataset 171 | return xs 172 | 173 | 174 | def preprocess(self, ex, is_test): 175 | src = ex['src'] 176 | if('labels' in ex): 177 | labels = ex['labels'] 178 | else: 179 | labels = ex['src_sent_labels'] 180 | 181 | segs = ex['segs'] 182 | if(not self.args.use_interval): 183 | segs=[0]*len(segs) 184 | clss = ex['clss'] 185 | src_txt = ex['src_txt'] 186 | tgt_txt = ex['tgt_txt'] 187 | 188 | if(is_test): 189 | return src,labels,segs, clss, src_txt, tgt_txt 190 | else: 191 | return src,labels,segs, clss 192 | 193 | def batch_buffer(self, data, batch_size): 194 | minibatch, size_so_far = [], 0 195 | for ex in data: 196 | if(len(ex['src'])==0): 197 | continue 198 | ex = self.preprocess(ex, self.is_test) 199 | if(ex is 
None):
200 |                 continue
201 |             minibatch.append(ex)
202 |             size_so_far = simple_batch_size_fn(ex, len(minibatch))
203 |             if size_so_far == batch_size:
204 |                 yield minibatch
205 |                 minibatch, size_so_far = [], 0
206 |             elif size_so_far > batch_size:
207 |                 yield minibatch[:-1]
208 |                 minibatch, size_so_far = minibatch[-1:], simple_batch_size_fn(ex, 1)
209 |         if minibatch:
210 |             yield minibatch
211 |
212 |     def create_batches(self):
213 |         """ Create batches """
214 |         data = self.data()
215 |         for buffer in self.batch_buffer(data, self.batch_size * 50):
216 |
217 |             p_batch = sorted(buffer, key=lambda x: len(x[3]))
218 |             p_batch = batch(p_batch, self.batch_size)
219 |
220 |             p_batch = list(p_batch)
221 |             if (self.shuffle):
222 |                 random.shuffle(p_batch)
223 |             for b in p_batch:
224 |                 yield b
225 |
226 |     def __iter__(self):
227 |         while True:
228 |             self.batches = self.create_batches()
229 |             for idx, minibatch in enumerate(self.batches):
230 |                 # fast-forward if loaded from state
231 |                 if self._iterations_this_epoch > idx:
232 |                     continue
233 |                 self.iterations += 1
234 |                 self._iterations_this_epoch += 1
235 |                 batch = Batch(minibatch, self.device, self.is_test)
236 |
237 |                 yield batch
238 |             return
239 |
--------------------------------------------------------------------------------
/src/models/neural.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | def gelu(x):
8 |     return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
9 |
10 |
11 | class PositionwiseFeedForward(nn.Module):
12 |     """ A two-layer Feed-Forward-Network with residual layer norm.
13 |
14 |     Args:
15 |         d_model (int): the size of input for the first layer of the FFN.
16 |         d_ff (int): the hidden layer size of the second layer
17 |             of the FFN.
18 |         dropout (float): dropout probability in :math:`[0, 1)`.
19 |     """
20 |
21 |     def __init__(self, d_model, d_ff, dropout=0.1):
22 |         super(PositionwiseFeedForward, self).__init__()
23 |         self.w_1 = nn.Linear(d_model, d_ff)
24 |         self.w_2 = nn.Linear(d_ff, d_model)
25 |         self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
26 |         self.actv = gelu
27 |         self.dropout_1 = nn.Dropout(dropout)
28 |         self.dropout_2 = nn.Dropout(dropout)
29 |
30 |     def forward(self, x):
31 |         inter = self.dropout_1(self.actv(self.w_1(self.layer_norm(x))))
32 |         output = self.dropout_2(self.w_2(inter))
33 |         return output + x
34 |
35 |
36 | class MultiHeadedAttention(nn.Module):
37 |     """
38 |     Multi-Head Attention module from
39 |     "Attention is All You Need"
40 |     :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`.
41 |
42 |     Similar to standard `dot` attention but uses
43 |     multiple attention distributions simultaneously
44 |     to select relevant items.
45 |
46 |     .. mermaid::
47 |
48 |        graph BT
49 |           A[key]
50 |           B[value]
51 |           C[query]
52 |           O[output]
53 |           subgraph Attn
54 |             D[Attn 1]
55 |             E[Attn 2]
56 |             F[Attn N]
57 |           end
58 |           A --> D
59 |           C --> D
60 |           A --> E
61 |           C --> E
62 |           A --> F
63 |           C --> F
64 |           D --> O
65 |           E --> O
66 |           F --> O
67 |           B --> O
68 |
69 |     Also includes several additional tricks.
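    A minimal self-attention sketch (editor's illustration; the sizes below
    are arbitrary assumptions, not values mandated by this repo):

        attn = MultiHeadedAttention(head_count=8, model_dim=512)
        x = torch.rand(2, 10, 512)   # [batch, seq_len, model_dim]
        out = attn(x, x, x)          # context vectors, [2, 10, 512]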
70 | 71 | Args: 72 | head_count (int): number of parallel heads 73 | model_dim (int): the dimension of keys/values/queries, 74 | must be divisible by head_count 75 | dropout (float): dropout parameter 76 | """ 77 | 78 | def __init__(self, head_count, model_dim, dropout=0.1, use_final_linear=True): 79 | assert model_dim % head_count == 0 80 | self.dim_per_head = model_dim // head_count 81 | self.model_dim = model_dim 82 | 83 | super(MultiHeadedAttention, self).__init__() 84 | self.head_count = head_count 85 | 86 | self.linear_keys = nn.Linear(model_dim, 87 | head_count * self.dim_per_head) 88 | self.linear_values = nn.Linear(model_dim, 89 | head_count * self.dim_per_head) 90 | self.linear_query = nn.Linear(model_dim, 91 | head_count * self.dim_per_head) 92 | self.softmax = nn.Softmax(dim=-1) 93 | self.dropout = nn.Dropout(dropout) 94 | self.use_final_linear = use_final_linear 95 | if (self.use_final_linear): 96 | self.final_linear = nn.Linear(model_dim, model_dim) 97 | 98 | def forward(self, key, value, query, mask=None, 99 | layer_cache=None, type=None, predefined_graph_1=None): 100 | """ 101 | Compute the context vector and the attention vectors. 102 | 103 | Args: 104 | key (`FloatTensor`): set of `key_len` 105 | key vectors `[batch, key_len, dim]` 106 | value (`FloatTensor`): set of `key_len` 107 | value vectors `[batch, key_len, dim]` 108 | query (`FloatTensor`): set of `query_len` 109 | query vectors `[batch, query_len, dim]` 110 | mask: binary mask indicating which keys have 111 | non-zero attention `[batch, query_len, key_len]` 112 | Returns: 113 | (`FloatTensor`, `FloatTensor`) : 114 | 115 | * output context vectors `[batch, query_len, dim]` 116 | * one of the attention vectors `[batch, query_len, key_len]` 117 | """ 118 | 119 | # CHECKS 120 | # batch, k_len, d = key.size() 121 | # batch_, k_len_, d_ = value.size() 122 | # aeq(batch, batch_) 123 | # aeq(k_len, k_len_) 124 | # aeq(d, d_) 125 | # batch_, q_len, d_ = query.size() 126 | # aeq(batch, batch_) 127 | # aeq(d, d_) 128 | # aeq(self.model_dim % 8, 0) 129 | # if mask is not None: 130 | # batch_, q_len_, k_len_ = mask.size() 131 | # aeq(batch_, batch) 132 | # aeq(k_len_, k_len) 133 | # aeq(q_len_ == q_len) 134 | # END CHECKS 135 | 136 | batch_size = key.size(0) 137 | dim_per_head = self.dim_per_head 138 | head_count = self.head_count 139 | key_len = key.size(1) 140 | query_len = query.size(1) 141 | 142 | def shape(x): 143 | """ projection """ 144 | return x.view(batch_size, -1, head_count, dim_per_head) \ 145 | .transpose(1, 2) 146 | 147 | def unshape(x): 148 | """ compute context """ 149 | return x.transpose(1, 2).contiguous() \ 150 | .view(batch_size, -1, head_count * dim_per_head) 151 | 152 | # 1) Project key, value, and query. 
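        # (Editor's note, illustrative numbers only: with batch_size=2,
        #  head_count=8 and model_dim=512, `shape` turns a [2, seq, 512]
        #  projection into [2, 8, seq, 64], so each head attends over its own
        #  64-dimensional slice; `unshape` reverses this before the final
        #  linear layer.)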
153 | if layer_cache is not None: 154 | if type == "self": 155 | query, key, value = self.linear_query(query), \ 156 | self.linear_keys(query), \ 157 | self.linear_values(query) 158 | 159 | key = shape(key) 160 | value = shape(value) 161 | 162 | if layer_cache is not None: 163 | device = key.device 164 | if layer_cache["self_keys"] is not None: 165 | key = torch.cat( 166 | (layer_cache["self_keys"].to(device), key), 167 | dim=2) 168 | if layer_cache["self_values"] is not None: 169 | value = torch.cat( 170 | (layer_cache["self_values"].to(device), value), 171 | dim=2) 172 | layer_cache["self_keys"] = key 173 | layer_cache["self_values"] = value 174 | elif type == "context": 175 | query = self.linear_query(query) 176 | if layer_cache is not None: 177 | if layer_cache["memory_keys"] is None: 178 | key, value = self.linear_keys(key), \ 179 | self.linear_values(value) 180 | key = shape(key) 181 | value = shape(value) 182 | else: 183 | key, value = layer_cache["memory_keys"], \ 184 | layer_cache["memory_values"] 185 | layer_cache["memory_keys"] = key 186 | layer_cache["memory_values"] = value 187 | else: 188 | key, value = self.linear_keys(key), \ 189 | self.linear_values(value) 190 | key = shape(key) 191 | value = shape(value) 192 | else: 193 | key = self.linear_keys(key) 194 | value = self.linear_values(value) 195 | query = self.linear_query(query) 196 | key = shape(key) 197 | value = shape(value) 198 | 199 | query = shape(query) 200 | 201 | key_len = key.size(2) 202 | query_len = query.size(2) 203 | 204 | # 2) Calculate and scale scores. 205 | query = query / math.sqrt(dim_per_head) 206 | scores = torch.matmul(query, key.transpose(2, 3)) 207 | 208 | if mask is not None: 209 | mask = mask.unsqueeze(1).expand_as(scores) 210 | scores = scores.masked_fill(mask, -1e18) 211 | 212 | # 3) Apply attention dropout and compute context vectors. 
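        # (Editor's note: masked positions were filled with -1e18 above, so
        #  the softmax below assigns them ~0 probability; dropout is then
        #  applied to the attention weights themselves, as in the original
        #  Transformer.)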
213 |
214 |         attn = self.softmax(scores)
215 |
216 |         if (predefined_graph_1 is not None):
217 |             attn_masked = attn[:, -1] * predefined_graph_1
218 |             attn_masked = attn_masked / (torch.sum(attn_masked, 2).unsqueeze(2) + 1e-9)
219 |
220 |             attn = torch.cat([attn[:, :-1], attn_masked.unsqueeze(1)], 1)
221 |
222 |         drop_attn = self.dropout(attn)
223 |         if (self.use_final_linear):
224 |             context = unshape(torch.matmul(drop_attn, value))
225 |             output = self.final_linear(context)
226 |             return output
227 |         else:
228 |             context = torch.matmul(drop_attn, value)
229 |             return context
230 |
231 |         # CHECK
232 |         # batch_, q_len_, d_ = output.size()
233 |         # aeq(q_len, q_len_)
234 |         # aeq(batch, batch_)
235 |         # aeq(d, d_)
236 |
237 |         # Return one attn
238 |
239 |
--------------------------------------------------------------------------------
/src/models/optimizers.py:
--------------------------------------------------------------------------------
1 | """ Optimizers class """
2 | import torch
3 | import torch.optim as optim
4 | from torch.nn.utils import clip_grad_norm_
5 |
6 |
7 | # from onmt.utils import use_gpu
8 |
9 |
10 | def use_gpu(opt):
11 |     """
12 |     Return True if a GPU is configured to be used.
13 |     """
14 |     return (hasattr(opt, 'gpu_ranks') and len(opt.gpu_ranks) > 0) or \
15 |            (hasattr(opt, 'gpu') and opt.gpu > -1)
16 |
17 | def build_optim(model, opt, checkpoint):
18 |     """ Build optimizer """
19 |     saved_optimizer_state_dict = None
20 |
21 |     if opt.train_from:
22 |         optim = checkpoint['optim']
23 |         # We need to save a copy of optim.optimizer.state_dict() for setting
24 |         # the optimizer state later on in Stage 2 of this method, since
25 |         # the call optim.set_parameters(model.parameters()) will overwrite
26 |         # optim.optimizer, and with it the values stored in
27 |         # optim.optimizer.state_dict()
28 |         saved_optimizer_state_dict = optim.optimizer.state_dict()
29 |     else:
30 |         optim = Optimizer(
31 |             opt.optim, opt.learning_rate, opt.max_grad_norm,
32 |             lr_decay=opt.learning_rate_decay,
33 |             start_decay_steps=opt.start_decay_steps,
34 |             decay_steps=opt.decay_steps,
35 |             beta1=opt.adam_beta1,
36 |             beta2=opt.adam_beta2,
37 |             adagrad_accum=opt.adagrad_accumulator_init,
38 |             decay_method=opt.decay_method,
39 |             warmup_steps=opt.warmup_steps)
40 |
41 |     # Stage 1:
42 |     # Essentially optim.set_parameters (re-)creates an optimizer using
43 |     # model.parameters() as the parameters that will be stored in the
44 |     # optim.optimizer.param_groups field of the torch optimizer class.
45 |     # Importantly, this method does not yet load the optimizer state, as
46 |     # essentially it builds a new optimizer with empty optimizer state and
47 |     # parameters from the model.
48 |     optim.set_parameters(model.named_parameters())
49 |
50 |     if opt.train_from:
51 |         # Stage 2: In this stage, which is only performed when loading an
52 |         # optimizer from a checkpoint, we load the saved_optimizer_state_dict
53 |         # into the re-created optimizer, to set the optim.optimizer.state
54 |         # field, which was previously empty. For this we use the optimizer
55 |         # state saved in the "saved_optimizer_state_dict" variable
56 |         # earlier in this method.
57 |         # See also: https://github.com/pytorch/pytorch/issues/2830
58 |         optim.optimizer.load_state_dict(saved_optimizer_state_dict)
59 |         # Convert back the state values to cuda type if applicable
60 |         if use_gpu(opt):
61 |             for state in optim.optimizer.state.values():
62 |                 for k, v in state.items():
63 |                     if torch.is_tensor(v):
64 |                         state[k] = v.cuda()
65 |
66 |         # We want to make sure that indeed we have a non-empty optimizer state
67 |         # when we loaded an existing model. This should be at least the case
68 |         # for Adam, which saves "exp_avg" and "exp_avg_sq" state
69 |         # (Exponential moving average of gradient and squared gradient values)
70 |         if (optim.method == 'adam') and (len(optim.optimizer.state) < 1):
71 |             raise RuntimeError(
72 |                 "Error: loaded Adam optimizer from existing model" +
73 |                 " but optimizer state is empty")
74 |
75 |     return optim
76 |
77 |
78 | class MultipleOptimizer(object):
79 |     """ Implement multiple optimizers needed for sparse adam """
80 |
81 |     def __init__(self, op):
82 |         """ Store the list of wrapped optimizers """
83 |         self.optimizers = op
84 |
85 |     def zero_grad(self):
86 |         """ Zero the gradients of every wrapped optimizer """
87 |         for op in self.optimizers:
88 |             op.zero_grad()
89 |
90 |     def step(self):
91 |         """ Step every wrapped optimizer """
92 |         for op in self.optimizers:
93 |             op.step()
94 |
95 |     @property
96 |     def state(self):
97 |         """ Merged parameter state of all wrapped optimizers """
98 |         return {k: v for op in self.optimizers for k, v in op.state.items()}
99 |
100 |     def state_dict(self):
101 |         """ Return one state dict per wrapped optimizer """
102 |         return [op.state_dict() for op in self.optimizers]
103 |
104 |     def load_state_dict(self, state_dicts):
105 |         """ Load one saved state dict into each wrapped optimizer """
106 |         assert len(state_dicts) == len(self.optimizers)
107 |         for i in range(len(state_dicts)):
108 |             self.optimizers[i].load_state_dict(state_dicts[i])
109 |
110 |
111 | class Optimizer(object):
112 |     """
113 |     Controller class for optimization. Mostly a thin
114 |     wrapper for `optim`, but also useful for implementing
115 |     rate scheduling beyond what is currently available.
116 |     Also implements necessary methods for training RNNs such
117 |     as grad manipulations.
118 |
119 |     Args:
120 |         method (:obj:`str`): one of [sgd, adagrad, adadelta, adam, sparseadam]
121 |         lr (float): learning rate
122 |         lr_decay (float, optional): learning rate decay multiplier
123 |         start_decay_steps (int, optional): step to start learning rate decay
124 |         beta1, beta2 (float, optional): parameters for adam
125 |         adagrad_accum (float, optional): initialization parameter for adagrad
126 |         decay_method (str, optional): custom decay options
127 |         warmup_steps (int, optional): parameter for `noam` decay
128 |
129 |     We use the default parameters for Adam that are suggested by
130 |     the original paper https://arxiv.org/pdf/1412.6980.pdf
131 |     These values are also used by other established implementations,
132 |     e.g. https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer
133 |     https://keras.io/optimizers/
134 |     More recently, slightly different values were used in the paper
135 |     "Attention is all you need"
136 |     https://arxiv.org/pdf/1706.03762.pdf; in particular, the value beta2=0.98
137 |     was used there. However, beta2=0.999 is arguably still the more
138 |     established value, so we use that here as well.
139 |     """
140 |
141 |     def __init__(self, method, learning_rate, max_grad_norm,
142 |                  lr_decay=1, start_decay_steps=None, decay_steps=None,
143 |                  beta1=0.9, beta2=0.999,
144 |                  adagrad_accum=0.0,
145 |                  decay_method=None,
146 |                  warmup_steps=4000
147 |                  ):
148 |         self.last_ppl = None
149 |         self.learning_rate = learning_rate
150 |         self.original_lr = learning_rate
151 |         self.max_grad_norm = max_grad_norm
152 |         self.method = method
153 |         self.lr_decay = lr_decay
154 |         self.start_decay_steps = start_decay_steps
155 |         self.decay_steps = decay_steps
156 |         self.start_decay = False
157 |         self._step = 0
158 |         self.betas = [beta1, beta2]
159 |         self.adagrad_accum = adagrad_accum
160 |         self.decay_method = decay_method
161 |         self.warmup_steps = warmup_steps
162 |
163 |     def set_parameters(self, params):
164 |         """ Split params into dense and sparse groups and build the optimizer(s) """
165 |         self.params = []
166 |         self.sparse_params = []
167 |         for k, p in params:
168 |             if p.requires_grad:
169 |                 if self.method != 'sparseadam' or "embed" not in k:
170 |                     self.params.append(p)
171 |                 else:
172 |                     self.sparse_params.append(p)
173 |         if self.method == 'sgd':
174 |             self.optimizer = optim.SGD(self.params, lr=self.learning_rate)
175 |         elif self.method == 'adagrad':
176 |             self.optimizer = optim.Adagrad(self.params, lr=self.learning_rate)
177 |             for group in self.optimizer.param_groups:
178 |                 for p in group['params']:
179 |                     self.optimizer.state[p]['sum'] = self.optimizer\
180 |                         .state[p]['sum'].fill_(self.adagrad_accum)
181 |         elif self.method == 'adadelta':
182 |             self.optimizer = optim.Adadelta(self.params, lr=self.learning_rate)
183 |         elif self.method == 'adam':
184 |             self.optimizer = optim.Adam(self.params, lr=self.learning_rate,
185 |                                         betas=self.betas, eps=1e-9)
186 |         elif self.method == 'sparseadam':
187 |             self.optimizer = MultipleOptimizer(
188 |                 [optim.Adam(self.params, lr=self.learning_rate,
189 |                             betas=self.betas, eps=1e-8),
190 |                  optim.SparseAdam(self.sparse_params, lr=self.learning_rate,
191 |                                   betas=self.betas, eps=1e-8)])
192 |         else:
193 |             raise RuntimeError("Invalid optim method: " + self.method)
194 |
195 |     def _set_rate(self, learning_rate):
196 |         self.learning_rate = learning_rate
197 |         if self.method != 'sparseadam':
198 |             self.optimizer.param_groups[0]['lr'] = self.learning_rate
199 |         else:
200 |             for op in self.optimizer.optimizers:
201 |                 op.param_groups[0]['lr'] = self.learning_rate
202 |
203 |     def step(self):
204 |         """Update the model parameters based on current gradients.
205 |
206 |         Optionally, will employ gradient modification or update learning
207 |         rate.
208 |         """
209 |         self._step += 1
210 |
211 |         # Decay method used in tensor2tensor.
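        # (Editor's note, worked example: "noam" sets
        #  lr = original_lr * min(step ** -0.5, step * warmup_steps ** -1.5),
        #  i.e. linear warmup followed by inverse-square-root decay. With
        #  warmup_steps=8000 the two terms meet at step 8000, where the
        #  multiplier is 8000 ** -0.5 ≈ 0.0112.)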
212 | if self.decay_method == "noam": 213 | self._set_rate( 214 | self.original_lr * 215 | 216 | min(self._step ** (-0.5), 217 | self._step * self.warmup_steps**(-1.5))) 218 | 219 | # self._set_rate(self.original_lr *self.model_size ** (-0.5) *min(1.0, self._step / self.warmup_steps)*max(self._step, self.warmup_steps)**(-0.5)) 220 | # Decay based on start_decay_steps every decay_steps 221 | else: 222 | if ((self.start_decay_steps is not None) and ( 223 | self._step >= self.start_decay_steps)): 224 | self.start_decay = True 225 | if self.start_decay: 226 | if ((self._step - self.start_decay_steps) 227 | % self.decay_steps == 0): 228 | self.learning_rate = self.learning_rate * self.lr_decay 229 | 230 | if self.method != 'sparseadam': 231 | self.optimizer.param_groups[0]['lr'] = self.learning_rate 232 | 233 | if self.max_grad_norm: 234 | clip_grad_norm_(self.params, self.max_grad_norm) 235 | self.optimizer.step() 236 | 237 | 238 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /src/prepro/data_builder_LAI.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import glob 3 | import hashlib 4 | import itertools 5 | import json 6 | import os 7 | import re 8 | import subprocess 9 | import time 10 | from os.path import join as pjoin 11 | 12 | import torch 13 | from multiprocess import Pool 14 | from pytorch_pretrained_bert import BertTokenizer 15 | 16 | from others.logging import logger 17 | from prepro.utils import _get_word_ngrams 18 | import emoji 19 | 20 | 21 | def cal_rouge(evaluated_ngrams, reference_ngrams): 22 | reference_count = len(reference_ngrams) 23 | evaluated_count = len(evaluated_ngrams) 24 | 25 | overlapping_ngrams = evaluated_ngrams.intersection(reference_ngrams) 26 | overlapping_count = len(overlapping_ngrams) 27 | 28 | if evaluated_count == 0: 29 | precision = 0.0 30 | else: 31 | precision = overlapping_count / evaluated_count 32 | 33 | if reference_count == 0: 34 | recall = 0.0 35 | else: 36 | recall = overlapping_count / reference_count 37 | 38 | f1_score = 2.0 * ((precision * recall) / (precision + recall + 1e-8)) 39 | return {"f": f1_score, "p": precision, "r": recall} 40 | 41 | 42 | def combination_selection(doc_sent_list, abstract_sent_list, summary_size): 43 | def _rouge_clean(s): 44 | return re.sub(r'[^a-zA-Z0-9 ]', '', s) 45 | 46 | max_rouge = 0.0 47 | max_idx = (0, 0) 48 | abstract = sum(abstract_sent_list, []) 49 | abstract = _rouge_clean(' '.join(abstract)).split() 50 | sents = [_rouge_clean(' '.join(s)).split() for s in doc_sent_list] 51 | evaluated_1grams = [_get_word_ngrams(1, [sent]) for sent in sents] 52 | reference_1grams = _get_word_ngrams(1, [abstract]) 53 | evaluated_2grams = [_get_word_ngrams(2, [sent]) for sent in sents] 54 | reference_2grams = _get_word_ngrams(2, [abstract]) 55 | 56 | impossible_sents = [] 57 | for s in range(summary_size + 1): 58 | combinations = itertools.combinations([i for i in range(len(sents)) if i not in impossible_sents], s + 1) 59 | for c in combinations: 60 | candidates_1 = [evaluated_1grams[idx] for idx in c] 61 | candidates_1 = set.union(*map(set, candidates_1)) 62 | candidates_2 = [evaluated_2grams[idx] for idx in c] 63 | candidates_2 = set.union(*map(set, candidates_2)) 64 | rouge_1 = cal_rouge(candidates_1, reference_1grams)['f'] 65 | rouge_2 = cal_rouge(candidates_2, reference_2grams)['f'] 66 | 67 | rouge_score = rouge_1 + rouge_2 68 | if (s == 0 and rouge_score == 0): 69 | impossible_sents.append(c[0]) 70 | if rouge_score > max_rouge: 71 | max_idx = c 72 | max_rouge = rouge_score 73 | return sorted(list(max_idx)) 74 | 75 | 76 | def greedy_selection(doc_sent_list, abstract_sent_list, summary_size): 77 | def _rouge_clean(s): 78 | return re.sub(r'[^a-zA-Z0-9 ]', '', s) 79 | 80 | max_rouge = 0.0 81 | abstract = sum(abstract_sent_list, []) 82 | abstract = _rouge_clean(' '.join(abstract)).split() 83 | sents = [_rouge_clean(' '.join(s)).split() for s in doc_sent_list] 84 | evaluated_1grams = [_get_word_ngrams(1, [sent]) for sent in sents] 85 | reference_1grams = _get_word_ngrams(1, [abstract]) 86 | evaluated_2grams = [_get_word_ngrams(2, [sent]) for sent in sents] 87 | reference_2grams = _get_word_ngrams(2, [abstract]) 88 | 89 | selected = [] 90 | for s in range(summary_size): 91 | cur_max_rouge = max_rouge 92 | cur_id = -1 93 | for i in range(len(sents)): 94 | if (i in selected): 95 | continue 96 | c = selected + [i] 97 | candidates_1 = [evaluated_1grams[idx] for 
idx in c] 98 | candidates_1 = set.union(*map(set, candidates_1)) 99 | candidates_2 = [evaluated_2grams[idx] for idx in c] 100 | candidates_2 = set.union(*map(set, candidates_2)) 101 | rouge_1 = cal_rouge(candidates_1, reference_1grams)['f'] 102 | rouge_2 = cal_rouge(candidates_2, reference_2grams)['f'] 103 | rouge_score = rouge_1 + rouge_2 104 | if rouge_score > cur_max_rouge: 105 | cur_max_rouge = rouge_score 106 | cur_id = i 107 | if (cur_id == -1): 108 | return selected 109 | selected.append(cur_id) 110 | max_rouge = cur_max_rouge 111 | 112 | return sorted(selected) 113 | 114 | 115 | class BertData(): 116 | def __init__(self, args): 117 | self.args = args 118 | self.tokenizer = BertTokenizer.from_pretrained('../../bert-base-chinese-vocab.txt', do_lower_case=True) ### change from 'bert-base-uncased' to 'bert-base-chinese' 119 | self.sep_vid = self.tokenizer.vocab['[SEP]'] 120 | self.cls_vid = self.tokenizer.vocab['[CLS]'] 121 | self.pad_vid = self.tokenizer.vocab['[PAD]'] 122 | 123 | def preprocess(self, src, tgt, oracle_ids): 124 | 125 | if (len(src) == 0): 126 | return None 127 | 128 | original_src_txt = [' '.join(s) for s in src] 129 | 130 | labels = [0] * len(src) 131 | for l in oracle_ids: 132 | labels[l] = 1 133 | 134 | idxs = [i for i, s in enumerate(src) if (len(s) > self.args.min_src_ntokens)] 135 | 136 | src = [src[i][:self.args.max_src_ntokens] for i in idxs] 137 | labels = [labels[i] for i in idxs] 138 | src = src[:self.args.max_nsents] 139 | labels = labels[:self.args.max_nsents] 140 | 141 | if (len(src) < self.args.min_nsents): 142 | return None 143 | if (len(labels) == 0): 144 | return None 145 | 146 | src_txt = [' '.join(sent) for sent in src] 147 | # text = [' '.join(ex['src_txt'][i].split()[:self.args.max_src_ntokens]) for i in idxs] 148 | # text = [_clean(t) for t in text] 149 | text = ' [SEP] [CLS] '.join(src_txt) 150 | src_subtokens = self.tokenizer.tokenize(text) 151 | src_subtokens = src_subtokens[:510] 152 | src_subtokens = ['[CLS]'] + src_subtokens + ['[SEP]'] 153 | 154 | src_subtoken_idxs = self.tokenizer.convert_tokens_to_ids(src_subtokens) 155 | _segs = [-1] + [i for i, t in enumerate(src_subtoken_idxs) if t == self.sep_vid] 156 | segs = [_segs[i] - _segs[i - 1] for i in range(1, len(_segs))] 157 | segments_ids = [] 158 | for i, s in enumerate(segs): 159 | if (i % 2 == 0): 160 | segments_ids += s * [0] 161 | else: 162 | segments_ids += s * [1] 163 | cls_ids = [i for i, t in enumerate(src_subtoken_idxs) if t == self.cls_vid] 164 | labels = labels[:len(cls_ids)] 165 | 166 | tgt_txt = ''.join([' '.join(tt) for tt in tgt]) 167 | src_txt = [original_src_txt[i] for i in idxs] 168 | return src_subtoken_idxs, labels, segments_ids, cls_ids, src_txt, tgt_txt 169 | 170 | 171 | def format_to_bert(args): 172 | if (args.dataset != ''): 173 | datasets = [args.dataset] 174 | else: 175 | datasets = ['train', 'valid', 'test'] 176 | for corpus_type in datasets: 177 | a_lst = [] 178 | for json_f in glob.glob(pjoin(args.raw_path, '*' + corpus_type + '.*.json')): 179 | real_name = json_f.split('/')[-1] 180 | a_lst.append((json_f, args, pjoin(args.save_path, real_name.replace('json', 'bert.pt')))) 181 | print(a_lst) 182 | pool = Pool(args.n_cpus) 183 | for d in pool.imap(_format_to_bert, a_lst): 184 | pass 185 | 186 | pool.close() 187 | pool.join() 188 | 189 | 190 | def _format_to_bert(params): 191 | json_file, args, save_file = params 192 | if (os.path.exists(save_file)): 193 | logger.info('Ignore %s' % save_file) 194 | return 195 | 196 | bert = BertData(args) 197 | 198 | 
logger.info('Processing %s' % json_file) 199 | jobs = json.load(open(json_file)) 200 | datasets = [] 201 | for d in jobs: 202 | source, tgt = d['src'], d['tgt'] 203 | if (args.oracle_mode == 'greedy'): 204 | oracle_ids = greedy_selection(source, tgt, 3) 205 | elif (args.oracle_mode == 'combination'): 206 | oracle_ids = combination_selection(source, tgt, 3) 207 | b_data = bert.preprocess(source, tgt, oracle_ids) 208 | if (b_data is None): 209 | continue 210 | indexed_tokens, labels, segments_ids, cls_ids, src_txt, tgt_txt = b_data 211 | b_data_dict = {"src": indexed_tokens, "labels": labels, "segs": segments_ids, 'clss': cls_ids, 212 | 'src_txt': src_txt, "tgt_txt": tgt_txt} 213 | datasets.append(b_data_dict) 214 | logger.info('Saving to %s' % save_file) 215 | torch.save(datasets, save_file) 216 | datasets = [] 217 | gc.collect() 218 | 219 | 220 | def format_to_lines(args): 221 | train_files, valid_files, test_files = [], [], [] 222 | for f in glob.glob(pjoin(args.raw_path, '*.json')): 223 | real_name = f.split('/')[-1].split('.')[0] 224 | with open(f, "r") as read_json: 225 | data_file = json.load(read_json) 226 | 227 | if ('valid' in real_name): 228 | valid_files = data_file 229 | elif ('test' in real_name): 230 | test_files = data_file 231 | elif ('train' in real_name): 232 | train_files = data_file 233 | 234 | corpora = {'train': train_files, 'valid': valid_files, 'test': test_files} 235 | 236 | for corpus_type in ['train', 'valid', 'test']: 237 | dataset = [] 238 | p_ct = 0 239 | 240 | for d in corpora[corpus_type]: 241 | d_formated = _format_to_lines(d) 242 | dataset.append(d_formated) 243 | if (len(dataset) > args.shard_size-1): 244 | pt_file = "{:s}.{:s}.{:d}.json".format(args.save_path, corpus_type, p_ct) 245 | with open(pt_file, 'w') as save: 246 | # save.write('\n'.join(dataset)) 247 | save.write(json.dumps(dataset)) 248 | p_ct += 1 249 | dataset = [] 250 | 251 | if (len(dataset) > 0): 252 | pt_file = "{:s}.{:s}.{:d}.json".format(args.save_path, corpus_type, p_ct) 253 | with open(pt_file, 'w') as save: 254 | # save.write('\n'.join(dataset)) 255 | save.write(json.dumps(dataset)) 256 | p_ct += 1 257 | dataset = [] 258 | 259 | def _format_to_lines(json_element): 260 | json_element_split = {'src':sent_token_split(json_element['src']), 'tgt':sent_token_split(json_element['tgt'],True)} 261 | return json_element_split 262 | 263 | 264 | def format_raw(args): 265 | for i in glob.glob(pjoin(args.raw_path, 'PART_*.txt')): 266 | is_train = True if "PART_I." in i else False 267 | is_valid = True if "PART_II." 
in i else False 268 | 269 | raw_formated = _format_raw(i, is_train=is_train) 270 | 271 | if is_train: 272 | file_name = "LCSTS_train.json" 273 | elif is_valid: 274 | file_name = "LCSTS_valid.json" 275 | else: 276 | file_name = "LCSTS_test.json" 277 | 278 | json.dump(raw_formated, open(pjoin(args.raw_path,file_name),"w")) 279 | 280 | 281 | def _format_raw(raw_LCSTS_path, is_train=True): 282 | raw_LCSTS_file = open(raw_LCSTS_path, "r") 283 | raw_LCSTS_str = raw_LCSTS_file.read() 284 | raw_LCSTS_str_list = raw_LCSTS_str.split("\n") 285 | 286 | num_line_el = 8 if is_train else 9 287 | extract_line = [0, 2, 5] if is_train else [0, 3, 6] 288 | num_el = len(raw_LCSTS_str_list)//num_line_el 289 | 290 | json_list = [] 291 | for i in range(num_el): 292 | doc = {"id": raw_LCSTS_str_list[i*num_line_el+extract_line[0]].strip(), 293 | "tgt": raw_LCSTS_str_list[i*num_line_el+extract_line[1]].strip(), 294 | "src": raw_LCSTS_str_list[i*num_line_el+extract_line[2]].strip()} 295 | 296 | json_list.append(doc) 297 | 298 | for i in json_list: 299 | num = re.findall(r'\d+', i['id']) 300 | doc_id = int(num[0]) 301 | i['id'] = doc_id 302 | 303 | return json_list 304 | 305 | def sent_token_split(doc, is_short_summary = False): 306 | doc_modified = re.sub(r' ', "", doc) 307 | doc_modified = re.sub(r':\w+:', "", emoji.demojize(doc_modified)) 308 | 309 | ### if the doc is a very short summary, just don't split sentence 310 | if is_short_summary: 311 | doc_split = [list(doc_modified)] 312 | return doc_split 313 | 314 | doc_modified = re.sub(r'。', "。 ", doc_modified) 315 | doc_modified = re.sub(r'!', "! ", doc_modified) 316 | doc_modified = re.sub(r'?', "? ", doc_modified) 317 | 318 | doc_split = re.split(r' ', doc_modified) 319 | doc_split = [i for i in doc_split if len(i)>=2] 320 | 321 | if len(doc_split)<2: 322 | doc_modified = re.sub(r',', ", ", doc_modified) 323 | doc_modified = re.sub(r';', "; ", doc_modified) 324 | doc_split = re.split(r' ', doc_modified) 325 | doc_split = [i for i in doc_split if len(i)>=2] 326 | 327 | doc_split = [list(i) for i in doc_split] 328 | 329 | return doc_split 330 | -------------------------------------------------------------------------------- /src/train_LAI.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Main training workflow 4 | """ 5 | from __future__ import division 6 | 7 | import argparse 8 | import glob 9 | import os 10 | import random 11 | import signal 12 | import time 13 | 14 | import torch 15 | from pytorch_pretrained_bert import BertConfig 16 | 17 | import distributed 18 | from models import data_loader, model_builder 19 | from models.data_loader import load_dataset 20 | from models.model_builder import Summarizer 21 | from models.trainer import build_trainer 22 | from others.logging import logger, init_logger 23 | 24 | model_flags = ['hidden_size', 'ff_size', 'heads', 'inter_layers','encoder','ff_actv', 'use_interval','rnn_size'] 25 | 26 | 27 | def str2bool(v): 28 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 29 | return True 30 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 31 | return False 32 | else: 33 | raise argparse.ArgumentTypeError('Boolean value expected.') 34 | 35 | 36 | 37 | def multi_main(args): 38 | """ Spawns 1 process per GPU """ 39 | init_logger() 40 | 41 | nb_gpu = args.world_size 42 | mp = torch.multiprocessing.get_context('spawn') 43 | 44 | # Create a thread to listen for errors in the child processes. 
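    # (Editor's note: on failure, each child pushes (gpu_rank, traceback) onto
    #  this queue; ErrorHandler's listener thread then raises SIGUSR1 in the
    #  parent, whose handler SIGINTs the remaining children and re-raises the
    #  original traceback instead of letting the job hang. See ErrorHandler
    #  below.)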
45 | error_queue = mp.SimpleQueue() 46 | error_handler = ErrorHandler(error_queue) 47 | 48 | # Train with multiprocessing. 49 | procs = [] 50 | for i in range(nb_gpu): 51 | device_id = i 52 | procs.append(mp.Process(target=run, args=(args, 53 | device_id, error_queue,), daemon=True)) 54 | procs[i].start() 55 | logger.info(" Starting process pid: %d " % procs[i].pid) 56 | error_handler.add_child(procs[i].pid) 57 | for p in procs: 58 | p.join() 59 | 60 | 61 | 62 | def run(args, device_id, error_queue): 63 | 64 | """ run process """ 65 | setattr(args, 'gpu_ranks', [int(i) for i in args.gpu_ranks]) 66 | 67 | try: 68 | gpu_rank = distributed.multi_init(device_id, args.world_size, args.gpu_ranks) 69 | print('gpu_rank %d' %gpu_rank) 70 | if gpu_rank != args.gpu_ranks[device_id]: 71 | raise AssertionError("An error occurred in \ 72 | Distributed initialization") 73 | 74 | train(args,device_id) 75 | except KeyboardInterrupt: 76 | pass # killed by parent, do nothing 77 | except Exception: 78 | # propagate exception to parent process, keeping original traceback 79 | import traceback 80 | error_queue.put((args.gpu_ranks[device_id], traceback.format_exc())) 81 | 82 | 83 | class ErrorHandler(object): 84 | """A class that listens for exceptions in children processes and propagates 85 | the tracebacks to the parent process.""" 86 | 87 | def __init__(self, error_queue): 88 | """ init error handler """ 89 | import signal 90 | import threading 91 | self.error_queue = error_queue 92 | self.children_pids = [] 93 | self.error_thread = threading.Thread( 94 | target=self.error_listener, daemon=True) 95 | self.error_thread.start() 96 | signal.signal(signal.SIGUSR1, self.signal_handler) 97 | 98 | def add_child(self, pid): 99 | """ error handler """ 100 | self.children_pids.append(pid) 101 | 102 | def error_listener(self): 103 | """ error listener """ 104 | (rank, original_trace) = self.error_queue.get() 105 | self.error_queue.put((rank, original_trace)) 106 | os.kill(os.getpid(), signal.SIGUSR1) 107 | 108 | def signal_handler(self, signalnum, stackframe): 109 | """ signal handler """ 110 | for pid in self.children_pids: 111 | os.kill(pid, signal.SIGINT) # kill children processes 112 | (rank, original_trace) = self.error_queue.get() 113 | msg = """\n\n-- Tracebacks above this line can probably 114 | be ignored --\n\n""" 115 | msg += original_trace 116 | raise Exception(msg) 117 | 118 | 119 | 120 | def wait_and_validate(args, device_id): 121 | 122 | timestep = 0 123 | if (args.test_all): 124 | cp_files = sorted(glob.glob(os.path.join(args.model_path, 'model_step_*.pt'))) 125 | cp_files.sort(key=os.path.getmtime) 126 | xent_lst = [] 127 | for i, cp in enumerate(cp_files): 128 | step = int(cp.split('.')[-2].split('_')[-1]) 129 | xent = validate(args, device_id, cp, step) 130 | xent_lst.append((xent, cp)) 131 | max_step = xent_lst.index(min(xent_lst)) 132 | if (i - max_step > 10): 133 | break 134 | xent_lst = sorted(xent_lst, key=lambda x: x[0])[:3] 135 | logger.info('PPL %s' % str(xent_lst)) 136 | for xent, cp in xent_lst: 137 | step = int(cp.split('.')[-2].split('_')[-1]) 138 | test(args, device_id, cp, step) 139 | else: 140 | while (True): 141 | cp_files = sorted(glob.glob(os.path.join(args.model_path, 'model_step_*.pt'))) 142 | cp_files.sort(key=os.path.getmtime) 143 | if (cp_files): 144 | cp = cp_files[-1] 145 | time_of_cp = os.path.getmtime(cp) 146 | if (not os.path.getsize(cp) > 0): 147 | time.sleep(60) 148 | continue 149 | if (time_of_cp > timestep): 150 | timestep = time_of_cp 151 | step = 
int(cp.split('.')[-2].split('_')[-1]) 152 | validate(args, device_id, cp, step) 153 | test(args, device_id, cp, step) 154 | 155 | cp_files = sorted(glob.glob(os.path.join(args.model_path, 'model_step_*.pt'))) 156 | cp_files.sort(key=os.path.getmtime) 157 | if (cp_files): 158 | cp = cp_files[-1] 159 | time_of_cp = os.path.getmtime(cp) 160 | if (time_of_cp > timestep): 161 | continue 162 | else: 163 | time.sleep(300) 164 | 165 | 166 | def validate(args, device_id, pt, step): 167 | device = "cpu" if args.visible_gpus == '-1' else "cuda" 168 | if (pt != ''): 169 | test_from = pt 170 | else: 171 | test_from = args.test_from 172 | logger.info('Loading checkpoint from %s' % test_from) 173 | checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage) 174 | opt = vars(checkpoint['opt']) 175 | for k in opt.keys(): 176 | if (k in model_flags): 177 | setattr(args, k, opt[k]) 178 | print(args) 179 | 180 | config = BertConfig.from_json_file(args.bert_config_path) 181 | model = Summarizer(args, device, load_pretrained_bert=False, bert_config = config) 182 | model.load_cp(checkpoint) 183 | model.eval() 184 | 185 | valid_iter =data_loader.Dataloader(args, load_dataset(args, 'valid', shuffle=False), 186 | args.batch_size, device, 187 | shuffle=False, is_test=False) 188 | trainer = build_trainer(args, device_id, model, None) 189 | stats = trainer.validate(valid_iter, step) 190 | return stats.xent() 191 | 192 | def test(args, device_id, pt, step): 193 | 194 | device = "cpu" if args.visible_gpus == '-1' else "cuda" 195 | if (pt != ''): 196 | test_from = pt 197 | else: 198 | test_from = args.test_from 199 | logger.info('Loading checkpoint from %s' % test_from) 200 | checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage) 201 | opt = vars(checkpoint['opt']) 202 | for k in opt.keys(): 203 | if (k in model_flags): 204 | setattr(args, k, opt[k]) 205 | print(args) 206 | 207 | config = BertConfig.from_json_file(args.bert_config_path) 208 | model = Summarizer(args, device, load_pretrained_bert=False, bert_config = config) 209 | model.load_cp(checkpoint) 210 | model.eval() 211 | 212 | test_iter =data_loader.Dataloader(args, load_dataset(args, 'test', shuffle=False), 213 | args.batch_size, device, 214 | shuffle=False, is_test=True) 215 | trainer = build_trainer(args, device_id, model, None) 216 | trainer.test(test_iter,step) 217 | 218 | 219 | def baseline(args, cal_lead=False, cal_oracle=False): 220 | 221 | test_iter =data_loader.Dataloader(args, load_dataset(args, 'test', shuffle=False), 222 | args.batch_size, device, 223 | shuffle=False, is_test=True) 224 | 225 | trainer = build_trainer(args, device_id, None, None) 226 | # 227 | if (cal_lead): 228 | trainer.test(test_iter, 0, cal_lead=True) 229 | elif (cal_oracle): 230 | trainer.test(test_iter, 0, cal_oracle=True) 231 | 232 | 233 | def train(args, device_id): 234 | init_logger(args.log_file) 235 | 236 | device = "cpu" if args.visible_gpus == '-1' else "cuda" 237 | logger.info('Device ID %d' % device_id) 238 | logger.info('Device %s' % device) 239 | torch.manual_seed(args.seed) 240 | random.seed(args.seed) 241 | torch.backends.cudnn.deterministic = True 242 | 243 | if device_id >= 0: 244 | torch.cuda.set_device(device_id) 245 | torch.cuda.manual_seed(args.seed) 246 | 247 | 248 | torch.manual_seed(args.seed) 249 | random.seed(args.seed) 250 | torch.backends.cudnn.deterministic = True 251 | 252 | def train_iter_fct(): 253 | return data_loader.Dataloader(args, load_dataset(args, 'train', shuffle=True), args.batch_size, device, 254 | 
shuffle=True, is_test=False) 255 | 256 | model = Summarizer(args, device, load_pretrained_bert=True) 257 | if args.train_from != '': 258 | logger.info('Loading checkpoint from %s' % args.train_from) 259 | checkpoint = torch.load(args.train_from, 260 | map_location=lambda storage, loc: storage) 261 | opt = vars(checkpoint['opt']) 262 | for k in opt.keys(): 263 | if (k in model_flags): 264 | setattr(args, k, opt[k]) 265 | model.load_cp(checkpoint) 266 | optim = model_builder.build_optim(args, model, checkpoint) 267 | else: 268 | optim = model_builder.build_optim(args, model, None) 269 | 270 | logger.info(model) 271 | trainer = build_trainer(args, device_id, model, optim) 272 | trainer.train(train_iter_fct, args.train_steps) 273 | 274 | 275 | 276 | if __name__ == '__main__': 277 | parser = argparse.ArgumentParser() 278 | 279 | 280 | 281 | parser.add_argument("-encoder", default='classifier', type=str, choices=['classifier','transformer','rnn','baseline']) 282 | parser.add_argument("-mode", default='train', type=str, choices=['train','validate','test']) 283 | parser.add_argument("-bert_data_path", default='../bert_data/LCSTS') 284 | parser.add_argument("-model_path", default='../models/') 285 | parser.add_argument("-result_path", default='../results/LCSTS') 286 | parser.add_argument("-temp_dir", default='../temp') 287 | parser.add_argument("-bert_config_path", default='../bert_config.json') 288 | 289 | parser.add_argument("-batch_size", default=1000, type=int) 290 | 291 | parser.add_argument("-use_interval", type=str2bool, nargs='?',const=True,default=True) 292 | parser.add_argument("-hidden_size", default=128, type=int) 293 | parser.add_argument("-ff_size", default=512, type=int) 294 | parser.add_argument("-heads", default=4, type=int) 295 | parser.add_argument("-inter_layers", default=2, type=int) 296 | parser.add_argument("-rnn_size", default=512, type=int) 297 | 298 | parser.add_argument("-param_init", default=0, type=float) 299 | parser.add_argument("-param_init_glorot", type=str2bool, nargs='?',const=True,default=True) 300 | parser.add_argument("-dropout", default=0.1, type=float) 301 | parser.add_argument("-optim", default='adam', type=str) 302 | parser.add_argument("-lr", default=1, type=float) 303 | parser.add_argument("-beta1", default= 0.9, type=float) 304 | parser.add_argument("-beta2", default=0.999, type=float) 305 | parser.add_argument("-decay_method", default='', type=str) 306 | parser.add_argument("-warmup_steps", default=8000, type=int) 307 | parser.add_argument("-max_grad_norm", default=0, type=float) 308 | 309 | parser.add_argument("-save_checkpoint_steps", default=5, type=int) 310 | parser.add_argument("-accum_count", default=1, type=int) 311 | parser.add_argument("-world_size", default=1, type=int) 312 | parser.add_argument("-report_every", default=1, type=int) 313 | parser.add_argument("-train_steps", default=1000, type=int) 314 | parser.add_argument("-recall_eval", type=str2bool, nargs='?',const=True,default=False) 315 | 316 | 317 | parser.add_argument('-visible_gpus', default='-1', type=str) 318 | parser.add_argument('-gpu_ranks', default='0', type=str) 319 | parser.add_argument('-log_file', default='../logs/LCSTS.log') 320 | parser.add_argument('-dataset', default='') 321 | parser.add_argument('-seed', default=666, type=int) 322 | 323 | parser.add_argument("-test_all", type=str2bool, nargs='?',const=True,default=False) 324 | parser.add_argument("-test_from", default='') 325 | parser.add_argument("-train_from", default='') 326 | parser.add_argument("-report_rouge", 
327 |     parser.add_argument("-block_trigram", type=str2bool, nargs='?', const=True, default=True)
328 | 
329 |     args = parser.parse_args()
330 |     args.gpu_ranks = [int(i) for i in args.gpu_ranks.split(',')]
331 |     os.environ["CUDA_VISIBLE_DEVICES"] = args.visible_gpus
332 | 
333 |     init_logger(args.log_file)
334 |     device = "cpu" if args.visible_gpus == '-1' else "cuda"
335 |     device_id = 0 if device == "cuda" else -1
336 | 
337 |     if (args.world_size > 1):
338 |         multi_main(args)
339 |     elif (args.mode == 'train'):
340 |         train(args, device_id)
341 |     elif (args.mode == 'validate'):
342 |         wait_and_validate(args, device_id)
343 |     elif (args.mode == 'lead'):
344 |         baseline(args, cal_lead=True)
345 |     elif (args.mode == 'oracle'):
346 |         baseline(args, cal_oracle=True)
347 |     elif (args.mode == 'test'):
348 |         cp = args.test_from
349 |         try:
350 |             step = int(cp.split('.')[-2].split('_')[-1])
351 |         except (IndexError, ValueError):
352 |             step = 0
353 |         test(args, device_id, cp, step)
354 | 
--------------------------------------------------------------------------------
/src/models/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import numpy as np
4 | import torch
5 | from tensorboardX import SummaryWriter
6 | 
7 | import distributed
8 | # import onmt
9 | from models.reporter import ReportMgr
10 | from models.stats import Statistics
11 | from others.logging import logger
12 | from others.utils import test_rouge, rouge_results_to_str
13 | 
14 | 
15 | def _tally_parameters(model):
16 |     n_params = sum([p.nelement() for p in model.parameters()])
17 |     return n_params
18 | 
19 | 
20 | def build_trainer(args, device_id, model,
21 |                   optim):
22 |     """
23 |     Simplify `Trainer` creation based on user `opt`s
24 |     Args:
25 |         opt (:obj:`Namespace`): user options (usually from argument parsing)
26 |         model (:obj:`onmt.models.NMTModel`): the model to train
27 |         fields (dict): dict of fields
28 |         optim (:obj:`onmt.utils.Optimizer`): optimizer used during training
29 |         data_type (str): string describing the type of data
30 |             e.g. "text", "img", "audio"
31 |         model_saver(:obj:`onmt.models.ModelSaverBase`): the utility object
32 |             used to save the model
33 |     """
34 |     device = "cpu" if args.visible_gpus == '-1' else "cuda"
35 | 
36 | 
37 |     grad_accum_count = args.accum_count
38 |     n_gpu = args.world_size
39 | 
40 |     if device_id >= 0:
41 |         gpu_rank = int(args.gpu_ranks[device_id])
42 |     else:
43 |         gpu_rank = 0
44 |         n_gpu = 0
45 | 
46 |     print('gpu_rank %d' % gpu_rank)
47 | 
48 |     tensorboard_log_dir = args.model_path
49 | 
50 |     writer = SummaryWriter(tensorboard_log_dir, comment="Unmt")
51 | 
52 |     report_manager = ReportMgr(args.report_every, start_time=-1, tensorboard_writer=writer)
53 | 
54 |     trainer = Trainer(args, model, optim, grad_accum_count, n_gpu, gpu_rank, report_manager)
55 | 
56 |     # print(tr)
57 |     if (model):
58 |         n_params = _tally_parameters(model)
59 |         logger.info('* number of parameters: %d' % n_params)
60 | 
61 |     return trainer
62 | 
63 | 
64 | class Trainer(object):
65 |     """
66 |     Class that controls the training process.
67 | 
68 |     Args:
69 |         model(:py:class:`onmt.models.model.NMTModel`): translation model
70 |             to train
71 |         train_loss(:obj:`onmt.utils.loss.LossComputeBase`):
72 |             training loss computation
73 |         valid_loss(:obj:`onmt.utils.loss.LossComputeBase`):
74 |             validation loss computation
75 |         optim(:obj:`onmt.utils.optimizers.Optimizer`):
76 |             the optimizer responsible for update
77 |         trunc_size(int): length of truncated back propagation through time
78 |         shard_size(int): compute loss in shards of this size for efficiency
79 |         data_type(string): type of the source input: [text|img|audio]
80 |         norm_method(string): normalization methods: [sents|tokens]
81 |         grad_accum_count(int): accumulate gradients this many times.
82 |         report_manager(:obj:`onmt.utils.ReportMgrBase`):
83 |             the object that creates reports, or None
84 |         model_saver(:obj:`onmt.models.ModelSaverBase`): the saver is
85 |             used to save a checkpoint.
86 |             Thus nothing will be saved if this parameter is None
87 |     """
88 | 
89 |     def __init__(self, args, model, optim,
90 |                  grad_accum_count=1, n_gpu=1, gpu_rank=1,
91 |                  report_manager=None):
92 |         # Basic attributes.
93 |         self.args = args
94 |         self.save_checkpoint_steps = args.save_checkpoint_steps
95 |         self.model = model
96 |         self.optim = optim
97 |         self.grad_accum_count = grad_accum_count
98 |         self.n_gpu = n_gpu
99 |         self.gpu_rank = gpu_rank
100 |         self.report_manager = report_manager
101 | 
102 |         self.loss = torch.nn.BCELoss(reduction='none')
103 |         assert grad_accum_count > 0
104 |         # Set model in training mode.
105 |         if (model):
106 |             self.model.train()
107 | 
108 |     def train(self, train_iter_fct, train_steps, valid_iter_fct=None, valid_steps=-1):
109 |         """
110 |         The main training loop: iterate over the training data
111 |         (i.e. `train_iter_fct`) and optionally run validation
112 |         (i.e. iterate over `valid_iter_fct`).
113 | 
114 |         Args:
115 |             train_iter_fct(function): a function that returns the train
116 |                 iterator. e.g. something like
117 |                 train_iter_fct = lambda: generator(*args, **kwargs)
118 |             valid_iter_fct(function): same as train_iter_fct, for valid data
119 |             train_steps(int):
120 |             valid_steps(int):
121 |             save_checkpoint_steps(int):
122 | 
123 |         Return:
124 |             total_stats (:obj:`Statistics`): accumulated training statistics
125 |         """
126 |         logger.info('Start training...')
127 | 
128 | 
129 |         step = self.optim._step + 1
130 |         true_batchs = []
131 |         accum = 0
132 |         normalization = 0
133 |         train_iter = train_iter_fct()
134 | 
135 |         total_stats = Statistics()
136 |         report_stats = Statistics()
137 |         self._start_report_manager(start_time=total_stats.start_time)
138 | 
139 |         while step <= train_steps:
140 | 
141 |             reduce_counter = 0
142 |             for i, batch in enumerate(train_iter):
143 |                 if self.n_gpu == 0 or (i % self.n_gpu == self.gpu_rank):
144 | 
145 |                     true_batchs.append(batch)
146 |                     normalization += batch.batch_size
147 |                     accum += 1
148 |                     if accum == self.grad_accum_count:
149 |                         reduce_counter += 1
150 |                         if self.n_gpu > 1:
151 |                             normalization = sum(distributed
152 |                                                 .all_gather_list
153 |                                                 (normalization))
154 | 
155 |                         self._gradient_accumulation(
156 |                             true_batchs, normalization, total_stats,
157 |                             report_stats)
158 | 
159 |                         report_stats = self._maybe_report_training(
160 |                             step, train_steps,
161 |                             self.optim.learning_rate,
162 |                             report_stats)
163 | 
164 |                         true_batchs = []
165 |                         accum = 0
166 |                         normalization = 0
167 |                         if (step % self.save_checkpoint_steps == 0 and self.gpu_rank == 0):
168 |                             self._save(step)
169 | 
170 |                         step += 1
171 |                         if step > train_steps:
172 |                             break
173 |             train_iter = train_iter_fct()
174 | 
175 |         return total_stats
176 | 
177 |     def validate(self, valid_iter, step=0):
178 |         """ Validate model.
179 |             valid_iter: validate data iterator
180 |         Returns:
181 |             :obj:`nmt.Statistics`: validation loss statistics
182 |         """
183 |         # Set model in validating mode.
184 |         self.model.eval()
185 |         stats = Statistics()
186 | 
187 |         with torch.no_grad():
188 |             for batch in valid_iter:
189 | 
190 |                 src = batch.src
191 |                 labels = batch.labels
192 |                 segs = batch.segs
193 |                 clss = batch.clss
194 |                 mask = batch.mask
195 |                 mask_cls = batch.mask_cls
196 | 
197 |                 sent_scores, mask = self.model(src, segs, clss, mask, mask_cls)
198 | 
199 | 
200 |                 loss = self.loss(sent_scores, labels.float())
201 |                 loss = (loss * mask.float()).sum()
202 |                 batch_stats = Statistics(float(loss.cpu().data.numpy()), len(labels))
203 |                 stats.update(batch_stats)
204 |             self._report_step(0, step, valid_stats=stats)
205 |             return stats
206 | 
207 |     def test(self, test_iter, step, cal_lead=False, cal_oracle=False):
208 |         """ Test model.
209 |             test_iter: test data iterator
210 |         Returns:
211 |             :obj:`nmt.Statistics`: test loss statistics
212 |         """
213 |         # Helper functions for trigram blocking.
214 |         def _get_ngrams(n, text):
215 |             ngram_set = set()
216 |             text_length = len(text)
217 |             max_index_ngram_start = text_length - n
218 |             for i in range(max_index_ngram_start + 1):
219 |                 ngram_set.add(tuple(text[i:i + n]))
220 |             return ngram_set
221 | 
222 |         def _block_tri(c, p):
223 |             tri_c = _get_ngrams(3, c.split())
224 |             for s in p:
225 |                 tri_s = _get_ngrams(3, s.split())
226 |                 if len(tri_c.intersection(tri_s)) > 0:
227 |                     return True
228 |             return False
229 | 
230 |         if (not cal_lead and not cal_oracle):
231 |             self.model.eval()
232 |         stats = Statistics()
233 | 
234 |         can_path = '%s_step%d.candidate' % (self.args.result_path, step)
235 |         gold_path = '%s_step%d.gold' % (self.args.result_path, step)
236 |         with open(can_path, 'w') as save_pred:
237 |             with open(gold_path, 'w') as save_gold:
238 |                 with torch.no_grad():
239 |                     for batch in test_iter:
240 |                         src = batch.src
241 |                         labels = batch.labels
242 |                         segs = batch.segs
243 |                         clss = batch.clss
244 |                         mask = batch.mask
245 |                         mask_cls = batch.mask_cls
246 | 
247 | 
248 |                         gold = []
249 |                         pred = []
250 | 
251 |                         if (cal_lead):
252 |                             selected_ids = [list(range(batch.clss.size(1)))] * batch.batch_size
253 |                         elif (cal_oracle):
254 |                             selected_ids = [[j for j in range(batch.clss.size(1)) if labels[i][j] == 1] for i in
255 |                                             range(batch.batch_size)]
256 |                         else:
257 |                             sent_scores, mask = self.model(src, segs, clss, mask, mask_cls)
258 | 
259 |                             loss = self.loss(sent_scores, labels.float())
260 |                             loss = (loss * mask.float()).sum()
261 |                             batch_stats = Statistics(float(loss.cpu().data.numpy()), len(labels))
262 |                             stats.update(batch_stats)
263 | 
264 |                             sent_scores = sent_scores + mask.float()
265 |                             sent_scores = sent_scores.cpu().data.numpy()
266 |                             selected_ids = np.argsort(-sent_scores, 1)
267 |                         # selected_ids = np.sort(selected_ids,1)
268 |                         for i, idx in enumerate(selected_ids):
269 |                             _pred = []
270 |                             if (len(batch.src_str[i]) == 0):
271 |                                 continue
272 |                             for j in selected_ids[i][:len(batch.src_str[i])]:
273 |                                 if (j >= len(batch.src_str[i])):
274 |                                     continue
275 |                                 candidate = batch.src_str[i][j].strip()
276 |                                 if (self.args.block_trigram):
277 |                                     if (not _block_tri(candidate, _pred)):
278 |                                         _pred.append(candidate)
279 |                                 else:
280 |                                     _pred.append(candidate)
281 | 
282 |                                 if ((not cal_oracle) and (not self.args.recall_eval) and len(_pred) == 3):
283 |                                     break
284 | 
285 |                             _pred = ''.join(_pred)
286 |                             if (self.args.recall_eval):
287 |                                 _pred = ' '.join(_pred.split()[:len(batch.tgt_str[i].split())])
288 | 
289 |                             pred.append(_pred)
290 |                             gold.append(batch.tgt_str[i])
291 | 
292 |                         for i in range(len(gold)):
293 |                             save_gold.write(gold[i].strip() + '\n')
294 |                         for i in range(len(pred)):
295 |                             save_pred.write(pred[i].strip() + '\n')
296 |         if (step != -1 and self.args.report_rouge):
297 |             rouges = test_rouge(self.args.temp_dir, can_path, gold_path)
298 |             logger.info('Rouges at step %d \n%s' % (step, rouge_results_to_str(rouges)))
299 |         self._report_step(0, step, valid_stats=stats)
300 | 
301 |         return stats
302 | 
303 | 
304 | 
305 |     def _gradient_accumulation(self, true_batchs, normalization, total_stats,
306 |                                report_stats):
307 |         if self.grad_accum_count > 1:
308 |             self.model.zero_grad()
309 | 
310 |         for batch in true_batchs:
311 |             if self.grad_accum_count == 1:
312 |                 self.model.zero_grad()
313 | 
314 |             src = batch.src
315 |             labels = batch.labels
316 |             segs = batch.segs
317 |             clss = batch.clss
318 |             mask = batch.mask
319 |             mask_cls = batch.mask_cls
320 | 
321 |             sent_scores, mask = self.model(src, segs, clss, mask, mask_cls)
322 | 
323 |             loss = self.loss(sent_scores, labels.float())
324 |             loss = (loss * mask.float()).sum()
325 |             (loss / loss.numel()).backward()
326 |             # loss.div(float(normalization)).backward()
327 | 
328 |             batch_stats = Statistics(float(loss.cpu().data.numpy()), normalization)
329 | 
330 | 
331 |             total_stats.update(batch_stats)
332 |             report_stats.update(batch_stats)
333 | 
334 |             # Update the parameters and statistics.
335 |             if self.grad_accum_count == 1:
336 |                 # Multi GPU gradient gather
337 |                 if self.n_gpu > 1:
338 |                     grads = [p.grad.data for p in self.model.parameters()
339 |                              if p.requires_grad
340 |                              and p.grad is not None]
341 |                     distributed.all_reduce_and_rescale_tensors(
342 |                         grads, float(1))
343 |                 self.optim.step()
344 | 
345 |         # in case of multi step gradient accumulation,
346 |         # update only after accum batches
347 |         if self.grad_accum_count > 1:
348 |             if self.n_gpu > 1:
349 |                 grads = [p.grad.data for p in self.model.parameters()
350 |                          if p.requires_grad
351 |                          and p.grad is not None]
352 |                 distributed.all_reduce_and_rescale_tensors(
353 |                     grads, float(1))
354 |             self.optim.step()
355 | 
356 |     def _save(self, step):
357 |         real_model = self.model
358 |         # real_generator = (self.generator.module
359 |         #                   if isinstance(self.generator, torch.nn.DataParallel)
360 |         #                   else self.generator)
361 | 
362 |         model_state_dict = real_model.state_dict()
363 |         # generator_state_dict = real_generator.state_dict()
364 |         checkpoint = {
365 |             'model': model_state_dict,
366 |             # 'generator': generator_state_dict,
367 |             'opt': self.args,
368 |             'optim': self.optim,
369 |         }
370 |         checkpoint_path = os.path.join(self.args.model_path, 'model_step_%d.pt' % step)
371 |         logger.info("Saving checkpoint %s" % checkpoint_path)
372 |         # checkpoint_path = '%s_step_%d.pt' % (FLAGS.model_path, step)
373 |         if (not os.path.exists(checkpoint_path)):
374 |             torch.save(checkpoint, checkpoint_path)
375 |             return checkpoint, checkpoint_path
376 | 
377 |     def _start_report_manager(self, start_time=None):
378 |         """
379 |         Simple function to start report manager (if any)
380 |         """
381 |         if self.report_manager is not None:
382 |             if start_time is None:
383 |                 self.report_manager.start()
384 |             else:
385 |                 self.report_manager.start_time = start_time
386 | 
387 |     def _maybe_gather_stats(self, stat):
388 |         """
389 |         Gather statistics in multi-processes cases
390 | 
391 |         Args:
392 |             stat(:obj:onmt.utils.Statistics): a Statistics object to gather
393 |                 or None (it returns None in this case)
394 | 
395 |         Returns:
396 |             stat: the updated (or unchanged) stat object
397 |         """
398 |         if stat is not None and self.n_gpu > 1:
399 |             return Statistics.all_gather_stats(stat)
400 |         return stat
401 | 
402 |     def _maybe_report_training(self, step, num_steps, learning_rate,
403 |                                report_stats):
404 |         """
405 |         Simple function to report training stats (if report_manager is set)
406 |         see `onmt.utils.ReportManagerBase.report_training` for doc
407 |         """
408 |         if self.report_manager is not None:
409 |             return self.report_manager.report_training(
410 |                 step, num_steps, learning_rate, report_stats,
411 |                 multigpu=self.n_gpu > 1)
412 | 
413 |     def _report_step(self, learning_rate, step, train_stats=None,
414 |                      valid_stats=None):
415 |         """
416 |         Simple function to report stats (if report_manager is set)
417 |         see `onmt.utils.ReportManagerBase.report_step` for doc
418 |         """
419 |         if self.report_manager is not None:
420 |             return self.report_manager.report_step(
421 |                 learning_rate, step, train_stats=train_stats,
422 |                 valid_stats=valid_stats)
423 | 
424 |     def _maybe_save(self, step):
425 |         """
426 |         Save the model if a model saver is set (note: this Trainer never sets `model_saver`; checkpoints are written by `_save` instead)
427 |         """
428 |         if self.model_saver is not None:
429 |             self.model_saver.maybe_save(step)
430 | 
--------------------------------------------------------------------------------
/src/others/pyrouge.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, unicode_literals, division
2 | 
3 | import os
4 | import re
5 | import codecs
6 | import platform
7 | 
8 | from subprocess import check_output
9 | from tempfile import mkdtemp
10 | from functools import partial
11 | 
12 | try:
13 |     from configparser import ConfigParser
14 | except ImportError:
15 |     from ConfigParser import ConfigParser
16 | 
17 | from pyrouge.utils import log
18 | from pyrouge.utils.file_utils import verify_dir
19 | 
20 | 
21 | REMAP = {"-lrb-": "(", "-rrb-": ")", "-lcb-": "{", "-rcb-": "}",
22 |          "-lsb-": "[", "-rsb-": "]", "``": '"', "''": '"'}
23 | 
24 | 
25 | def clean(x):
26 |     return re.sub(
27 |         r"-lrb-|-rrb-|-lcb-|-rcb-|-lsb-|-rsb-|``|''",
28 |         lambda m: REMAP.get(m.group()), x)
29 | 
30 | 
31 | class DirectoryProcessor:
32 | 
33 |     @staticmethod
34 |     def process(input_dir, output_dir, function):
35 |         """
36 |         Apply function to all files in input_dir and save the resulting output
37 |         files in output_dir.
38 | 
39 |         """
40 |         if not os.path.exists(output_dir):
41 |             os.makedirs(output_dir)
42 |         logger = log.get_global_console_logger()
43 |         logger.info("Processing files in {}.".format(input_dir))
44 |         input_file_names = os.listdir(input_dir)
45 |         for input_file_name in input_file_names:
46 |             input_file = os.path.join(input_dir, input_file_name)
47 |             with codecs.open(input_file, "r", encoding="UTF-8") as f:
48 |                 input_string = f.read()
49 |             output_string = function(input_string)
50 |             output_file = os.path.join(output_dir, input_file_name)
51 |             with codecs.open(output_file, "w", encoding="UTF-8") as f:
52 |                 f.write(clean(output_string.lower()))
53 |         logger.info("Saved processed files to {}.".format(output_dir))
54 | 
55 | 
56 | class Rouge155(object):
57 |     """
58 |     This is a wrapper for the ROUGE 1.5.5 summary evaluation package.
59 |     This class is designed to simplify the evaluation process by:
60 | 
61 |         1) Converting summaries into a format ROUGE understands.
62 |         2) Generating the ROUGE configuration file automatically based
63 |            on filename patterns.
64 | 
65 |     This class can be used within Python like this:
66 | 
67 |     rouge = Rouge155()
68 |     rouge.system_dir = 'test/systems'
69 |     rouge.model_dir = 'test/models'
70 | 
71 |     # The system filename pattern should contain one group that
72 |     # matches the document ID.
73 |     rouge.system_filename_pattern = 'SL.P.10.R.11.SL062003-(\d+).html'
74 | 
75 |     # The model filename pattern has '#ID#' as a placeholder for the
76 |     # document ID. If there are multiple model summaries, pyrouge
77 |     # will use the provided regex to automatically match them with
78 |     # the corresponding system summary. Here, [A-Z] matches
79 |     # multiple model summaries for a given #ID#.
80 |     rouge.model_filename_pattern = 'SL.P.10.R.[A-Z].SL062003-#ID#.html'
81 | 
82 |     rouge_output = rouge.evaluate()
83 |     print(rouge_output)
84 |     output_dict = rouge.output_to_dict(rouge_output)
85 |     print(output_dict)
86 |     ->    {'rouge_1_f_score': 0.95652,
87 |           'rouge_1_f_score_cb': 0.95652,
88 |           'rouge_1_f_score_ce': 0.95652,
89 |           'rouge_1_precision': 0.95652,
90 |           [...]
91 | 
92 | 
93 |     To evaluate multiple systems:
94 | 
95 |         rouge = Rouge155()
96 |         rouge.system_dir = '/PATH/TO/systems'
97 |         rouge.model_dir = '/PATH/TO/models'
98 |         for system_id in ['id1', 'id2', 'id3']:
99 |             rouge.system_filename_pattern = \
100 |                 'SL.P.10.R.{}.SL062003-(\d+).html'.format(system_id)
101 |             rouge.model_filename_pattern = \
102 |                 'SL.P.10.R.[A-Z].SL062003-#ID#.html'
103 |             rouge_output = rouge.evaluate(system_id)
104 |             print(rouge_output)
105 | 
106 |     """
107 | 
108 |     def __init__(self, rouge_dir=None, rouge_args=None, temp_dir=None):
109 |         """
110 |         Create a Rouge155 object.
111 | 
112 |             rouge_dir:  Directory containing Rouge-1.5.5.pl
113 |             rouge_args: Arguments to pass through to ROUGE if you
114 |                         don't want to use the default pyrouge
115 |                         arguments.
116 | 
117 |         """
118 |         self.temp_dir = temp_dir
119 |         self.log = log.get_global_console_logger()
120 |         self.__set_dir_properties()
121 |         self._config_file = None
122 |         self._settings_file = self.__get_config_path()
123 |         self.__set_rouge_dir(rouge_dir)
124 |         self.args = self.__clean_rouge_args(rouge_args)
125 |         self._system_filename_pattern = None
126 |         self._model_filename_pattern = None
127 | 
128 |     def save_home_dir(self):
129 |         config = ConfigParser()
130 |         section = 'pyrouge settings'
131 |         config.add_section(section)
132 |         config.set(section, 'home_dir', self._home_dir)
133 |         with open(self._settings_file, 'w') as f:
134 |             config.write(f)
135 |         self.log.info("Set ROUGE home directory to {}.".format(self._home_dir))
136 | 
137 |     @property
138 |     def settings_file(self):
139 |         """
140 |         Path of the settings file, which stores the ROUGE home dir.
141 | 
142 |         """
143 |         return self._settings_file
144 | 
145 |     @property
146 |     def bin_path(self):
147 |         """
148 |         The full path of the ROUGE binary (although it's technically
149 |         a script), i.e. rouge_home_dir/ROUGE-1.5.5.pl
150 | 
151 |         """
152 |         if self._bin_path is None:
153 |             raise Exception(
154 |                 "ROUGE path not set. Please set the ROUGE home directory "
155 |                 "and ensure that ROUGE-1.5.5.pl exists in it.")
156 |         return self._bin_path
157 | 
158 |     @property
159 |     def system_filename_pattern(self):
160 |         """
161 |         The regular expression pattern for matching system summary
162 |         filenames. The regex string.
163 | 
164 |         E.g. "SL.P.10.R.11.SL062003-(\d+).html" will match the system
165 |         filenames in the SPL2003/system folder of the ROUGE SPL example
166 |         in the "sample-test" folder.
167 | 
168 |         Currently, there is no support for multiple systems.
169 | 
170 |         """
171 |         return self._system_filename_pattern
172 | 
173 |     @system_filename_pattern.setter
174 |     def system_filename_pattern(self, pattern):
175 |         self._system_filename_pattern = pattern
176 | 
177 |     @property
178 |     def model_filename_pattern(self):
179 |         """
180 |         The regular expression pattern for matching model summary
181 |         filenames. The pattern needs to contain the string "#ID#",
182 |         which is a placeholder for the document ID.
183 | 
184 |         E.g. "SL.P.10.R.[A-Z].SL062003-#ID#.html" will match the model
185 |         filenames in the SPL2003/system folder of the ROUGE SPL
186 |         example in the "sample-test" folder.
187 | 
188 |         "#ID#" is a placeholder for the document ID which has been
189 |         matched by the "(\d+)" part of the system filename pattern.
190 |         The different model summaries for a given document ID are
191 |         matched by the "[A-Z]" part.
192 | 
193 |         """
194 |         return self._model_filename_pattern
195 | 
196 |     @model_filename_pattern.setter
197 |     def model_filename_pattern(self, pattern):
198 |         self._model_filename_pattern = pattern
199 | 
200 |     @property
201 |     def config_file(self):
202 |         return self._config_file
203 | 
204 |     @config_file.setter
205 |     def config_file(self, path):
206 |         config_dir, _ = os.path.split(path)
207 |         verify_dir(config_dir, "configuration file")
208 |         self._config_file = path
209 | 
210 |     def split_sentences(self):
211 |         """
212 |         ROUGE requires texts split into sentences. In case the texts
213 |         are not already split, this method can be used.
214 | 
215 |         """
216 |         from pyrouge.utils.sentence_splitter import PunktSentenceSplitter
217 |         self.log.info("Splitting sentences.")
218 |         ss = PunktSentenceSplitter()
219 |         sent_split_to_string = lambda s: "\n".join(ss.split(s))
220 |         process_func = partial(
221 |             DirectoryProcessor.process, function=sent_split_to_string)
222 |         self.__process_summaries(process_func)
223 | 
224 |     @staticmethod
225 |     def convert_summaries_to_rouge_format(input_dir, output_dir):
226 |         """
227 |         Convert all files in input_dir into a format ROUGE understands
228 |         and saves the files to output_dir. The input files are assumed
229 |         to be plain text with one sentence per line.
230 | 
231 |             input_dir:  Path of directory containing the input files.
232 |             output_dir: Path of directory in which the converted files
233 |                         will be saved.
234 | 
235 |         """
236 |         DirectoryProcessor.process(
237 |             input_dir, output_dir, Rouge155.convert_text_to_rouge_format)
238 | 
239 |     @staticmethod
240 |     def convert_text_to_rouge_format(text, title="dummy title"):
241 |         """
242 |         Convert a text to a format ROUGE understands. The text is
243 |         assumed to contain one sentence per line.
244 | 
245 |             text:  The text to convert, containing one sentence per line.
246 |             title: Optional title for the text. The title will appear
247 |                    in the converted file, but doesn't seem to have
248 |                    any other relevance.
249 | 
250 |         Returns: The converted text as string.
251 | 
252 |         """
253 |         # sentences = text.split("\n")
254 |         sentences = text.split("<q>")
255 |         sent_elems = [
256 |             "<a name=\"{i}\">[{i}]</a> "
257 |             "<a href=\"#{i}\" id={i}>{text}</a>".format(i=i, text=sent)
258 |             for i, sent in enumerate(sentences, start=1)]
259 |         html = """<html>
260 | <head>
261 | <title>{title}</title>
262 | </head>
263 | <body bgcolor="white">
264 | {elems}
265 | </body>
266 | </html>""".format(title=title, elems="\n".join(sent_elems))
267 | 
268 |         return html
269 | 
270 |     @staticmethod
271 |     def write_config_static(system_dir, system_filename_pattern,
272 |                             model_dir, model_filename_pattern,
273 |                             config_file_path, system_id=None):
274 |         """
275 |         Write the ROUGE configuration file, which is basically a list
276 |         of system summary files and their corresponding model summary
277 |         files.
278 | 
279 |         pyrouge uses regular expressions to automatically find the
280 |         matching model summary files for a given system summary file
281 |         (cf. docstrings for system_filename_pattern and
282 |         model_filename_pattern).
283 | 
284 |             system_dir:              Path of directory containing
285 |                                      system summaries.
286 |             system_filename_pattern: Regex string for matching
287 |                                      system summary filenames.
288 |             model_dir:               Path of directory containing
289 |                                      model summaries.
290 |             model_filename_pattern:  Regex string for matching model
291 |                                      summary filenames.
292 |             config_file_path:        Path of the configuration file.
293 |             system_id:               Optional system ID string which
294 |                                      will appear in the ROUGE output.
295 | 
296 |         """
297 |         system_filenames = [f for f in os.listdir(system_dir)]
298 |         system_models_tuples = []
299 | 
300 |         system_filename_pattern = re.compile(system_filename_pattern)
301 |         for system_filename in sorted(system_filenames):
302 |             match = system_filename_pattern.match(system_filename)
303 |             if match:
304 |                 id = match.groups(0)[0]
305 |                 model_filenames = [model_filename_pattern.replace('#ID#', id)]
306 |                 # model_filenames = Rouge155.__get_model_filenames_for_id(
307 |                 #     id, model_dir, model_filename_pattern)
308 |                 system_models_tuples.append(
309 |                     (system_filename, sorted(model_filenames)))
310 |         if not system_models_tuples:
311 |             raise Exception(
312 |                 "Did not find any files matching the pattern {} "
313 |                 "in the system summaries directory {}.".format(
314 |                     system_filename_pattern.pattern, system_dir))
315 | 
316 |         with codecs.open(config_file_path, 'w', encoding='utf-8') as f:
317 |             f.write('<ROUGE-EVAL version="1.55">')
318 |             for task_id, (system_filename, model_filenames) in enumerate(
319 |                     system_models_tuples, start=1):
320 | 
321 |                 eval_string = Rouge155.__get_eval_string(
322 |                     task_id, system_id,
323 |                     system_dir, system_filename,
324 |                     model_dir, model_filenames)
325 |                 f.write(eval_string)
326 |             f.write("</ROUGE-EVAL>")
327 | 
328 |     def write_config(self, config_file_path=None, system_id=None):
329 |         """
330 |         Write the ROUGE configuration file, which is basically a list
331 |         of system summary files and their matching model summary files.
332 | 
333 |         This is a non-static version of write_config_file_static().
334 | 
335 |             config_file_path: Path of the configuration file.
336 |             system_id:        Optional system ID string which will
337 |                               appear in the ROUGE output.
338 | 
339 |         """
340 |         if not system_id:
341 |             system_id = 1
342 |         if (not config_file_path) or (not self._config_dir):
343 |             self._config_dir = mkdtemp(dir=self.temp_dir)
344 |             config_filename = "rouge_conf.xml"
345 |         else:
346 |             config_dir, config_filename = os.path.split(config_file_path)
347 |             verify_dir(config_dir, "configuration file")
348 |         self._config_file = os.path.join(self._config_dir, config_filename)
349 |         Rouge155.write_config_static(
350 |             self._system_dir, self._system_filename_pattern,
351 |             self._model_dir, self._model_filename_pattern,
352 |             self._config_file, system_id)
353 |         self.log.info(
354 |             "Written ROUGE configuration to {}".format(self._config_file))
355 | 
356 |     def evaluate(self, system_id=1, rouge_args=None):
357 |         """
358 |         Run ROUGE to evaluate the system summaries in system_dir against
359 |         the model summaries in model_dir. The summaries are assumed to
360 |         be in the one-sentence-per-line HTML format ROUGE understands.
361 | 
362 |             system_id: Optional system ID which will be printed in
363 |                        ROUGE's output.
364 | 
365 |         Returns: Rouge output as string.
366 | 
367 |         """
368 |         self.write_config(system_id=system_id)
369 |         options = self.__get_options(rouge_args)
370 |         command = [self._bin_path] + options
371 |         self.log.info(
372 |             "Running ROUGE with command {}".format(" ".join(command)))
373 |         rouge_output = check_output(command).decode("UTF-8")
374 |         return rouge_output
375 | 
376 |     def convert_and_evaluate(self, system_id=1,
377 |                              split_sentences=False, rouge_args=None):
378 |         """
379 |         Convert plain text summaries to ROUGE format and run ROUGE to
380 |         evaluate the system summaries in system_dir against the model
381 |         summaries in model_dir. Optionally split texts into sentences
382 |         in case they aren't already.
383 | 
384 |         This is just a convenience method combining
385 |         convert_summaries_to_rouge_format() and evaluate().
386 | 
387 |             split_sentences: Optional argument specifying if
388 |                              sentences should be split.
389 |             system_id:       Optional system ID which will be printed
390 |                              in ROUGE's output.
391 | 
392 |         Returns: ROUGE output as string.
393 | 
394 |         """
395 |         if split_sentences:
396 |             self.split_sentences()
397 |         self.__write_summaries()
398 |         rouge_output = self.evaluate(system_id, rouge_args)
399 |         return rouge_output
400 | 
401 |     def output_to_dict(self, output):
402 |         """
403 |         Convert the ROUGE output into python dictionary for further
404 |         processing.
405 | 
406 |         """
407 |         # 0 ROUGE-1 Average_R: 0.02632 (95%-conf.int. 0.02632 - 0.02632)
408 |         pattern = re.compile(
409 |             r"(\d+) (ROUGE-\S+) (Average_\w): (\d.\d+) "
410 |             r"\(95%-conf.int. (\d.\d+) - (\d.\d+)\)")
411 |         results = {}
412 |         for line in output.split("\n"):
413 |             match = pattern.match(line)
414 |             if match:
415 |                 sys_id, rouge_type, measure, result, conf_begin, conf_end = \
416 |                     match.groups()
417 |                 measure = {
418 |                     'Average_R': 'recall',
419 |                     'Average_P': 'precision',
420 |                     'Average_F': 'f_score'
421 |                 }[measure]
422 |                 rouge_type = rouge_type.lower().replace("-", '_')
423 |                 key = "{}_{}".format(rouge_type, measure)
424 |                 results[key] = float(result)
425 |                 results["{}_cb".format(key)] = float(conf_begin)
426 |                 results["{}_ce".format(key)] = float(conf_end)
427 |         return results
428 | 
429 |     ###################################################################
430 |     # Private methods
431 | 
432 |     def __set_rouge_dir(self, home_dir=None):
433 |         """
434 |         Verify presence of ROUGE-1.5.5.pl and data folder, and set
435 |         those paths.
436 | 
437 |         """
438 |         if not home_dir:
439 |             self._home_dir = self.__get_rouge_home_dir_from_settings()
440 |         else:
441 |             self._home_dir = home_dir
442 |             self.save_home_dir()
443 |         self._bin_path = os.path.join(self._home_dir, 'ROUGE-1.5.5.pl')
444 |         self.data_dir = os.path.join(self._home_dir, 'data')
445 |         if not os.path.exists(self._bin_path):
446 |             raise Exception(
447 |                 "ROUGE binary not found at {}. Please set the "
448 |                 "correct path by running pyrouge_set_rouge_path "
449 |                 "/path/to/rouge/home.".format(self._bin_path))
450 | 
451 |     def __get_rouge_home_dir_from_settings(self):
452 |         config = ConfigParser()
453 |         with open(self._settings_file) as f:
454 |             if hasattr(config, "read_file"):
455 |                 config.read_file(f)
456 |             else:
457 |                 # use deprecated python 2.x method
458 |                 config.readfp(f)
459 |         rouge_home_dir = config.get('pyrouge settings', 'home_dir')
460 |         return rouge_home_dir
461 | 
462 |     @staticmethod
463 |     def __get_eval_string(
464 |             task_id, system_id,
465 |             system_dir, system_filename,
466 |             model_dir, model_filenames):
467 |         """
468 |         ROUGE can evaluate several system summaries for a given text
469 |         against several model summaries, i.e. there is an m-to-n
470 |         relation between system and model summaries. The system
471 |         summaries are listed in the <PEERS> tag and the model summaries
472 |         in the <MODELS> tag. pyrouge currently only supports one system
473 |         summary per text, i.e. it assumes a 1-to-n relation between
474 |         system and model summaries.
475 | 
476 |         """
477 |         peer_elems = "<P ID=\"{id}\">{name}</P>".format(
478 |             id=system_id, name=system_filename)
479 | 
480 |         model_elems = ["<M ID=\"{id}\">{name}</M>".format(
481 |             id=chr(65 + i), name=name)
482 |             for i, name in enumerate(model_filenames)]
483 | 
484 |         model_elems = "\n\t\t\t".join(model_elems)
485 |         eval_string = """
486 |     <EVAL ID="{task_id}">
487 |         <MODEL-ROOT>{model_root}</MODEL-ROOT>
488 |         <PEER-ROOT>{peer_root}</PEER-ROOT>
489 |         <INPUT-FORMAT TYPE="SEE">
490 |         </INPUT-FORMAT>
491 |         <PEERS>
492 |             {peer_elems}
493 |         </PEERS>
494 |         <MODELS>
495 |             {model_elems}
496 |         </MODELS>
497 |     </EVAL>
498 | """.format(
499 |             task_id=task_id,
500 |             model_root=model_dir, model_elems=model_elems,
501 |             peer_root=system_dir, peer_elems=peer_elems)
502 |         return eval_string
503 | 
504 |     def __process_summaries(self, process_func):
505 |         """
506 |         Helper method that applies process_func to the files in the
507 |         system and model folders and saves the resulting files to new
508 |         system and model folders.
509 | 
510 |         """
511 |         temp_dir = mkdtemp(dir=self.temp_dir)
512 |         new_system_dir = os.path.join(temp_dir, "system")
513 |         os.mkdir(new_system_dir)
514 |         new_model_dir = os.path.join(temp_dir, "model")
515 |         os.mkdir(new_model_dir)
516 |         self.log.info(
517 |             "Processing summaries. Saving system files to {} and "
518 |             "model files to {}.".format(new_system_dir, new_model_dir))
519 |         process_func(self._system_dir, new_system_dir)
520 |         process_func(self._model_dir, new_model_dir)
521 |         self._system_dir = new_system_dir
522 |         self._model_dir = new_model_dir
523 | 
524 |     def __write_summaries(self):
525 |         self.log.info("Writing summaries.")
526 |         self.__process_summaries(self.convert_summaries_to_rouge_format)
527 | 
528 |     @staticmethod
529 |     def __get_model_filenames_for_id(id, model_dir, model_filenames_pattern):
530 |         pattern = re.compile(model_filenames_pattern.replace('#ID#', id))
531 |         model_filenames = [
532 |             f for f in os.listdir(model_dir) if pattern.match(f)]
533 |         if not model_filenames:
534 |             raise Exception(
535 |                 "Could not find any model summaries for the system"
536 |                 " summary with ID {}. Specified model filename pattern was: "
537 |                 "{}".format(id, model_filenames_pattern))
538 |         return model_filenames
539 | 
540 |     def __get_options(self, rouge_args=None):
541 |         """
542 |         Get supplied command line arguments for ROUGE or use default
543 |         ones.
544 | 
545 |         """
546 |         if self.args:
547 |             options = self.args.split()
548 |         elif rouge_args:
549 |             options = rouge_args.split()
550 |         else:
551 |             options = [
552 |                 '-e', self._data_dir,
553 |                 '-c', 95,
554 |                 # '-2',
555 |                 # '-1',
556 |                 # '-U',
557 |                 '-m',
558 |                 # '-v',
559 |                 '-r', 1000,
560 |                 '-n', 2,
561 |                 # '-w', 1.2,
562 |                 '-a',
563 |             ]
564 |             options = list(map(str, options))
565 | 
566 | 
567 | 
568 | 
569 |         options = self.__add_config_option(options)
570 |         return options
571 | 
572 |     def __create_dir_property(self, dir_name, docstring):
573 |         """
574 |         Generate getter and setter for a directory property.
575 | 
576 |         """
577 |         property_name = "{}_dir".format(dir_name)
578 |         private_name = "_" + property_name
579 |         setattr(self, private_name, None)
580 | 
581 |         def fget(self):
582 |             return getattr(self, private_name)
583 | 
584 |         def fset(self, path):
585 |             verify_dir(path, dir_name)
586 |             setattr(self, private_name, path)
587 | 
588 |         p = property(fget=fget, fset=fset, doc=docstring)
589 |         setattr(self.__class__, property_name, p)
590 | 
591 |     def __set_dir_properties(self):
592 |         """
593 |         Automatically generate the properties for directories.
594 | 
595 |         """
596 |         directories = [
597 |             ("home", "The ROUGE home directory."),
598 |             ("data", "The path of the ROUGE 'data' directory."),
599 |             ("system", "Path of the directory containing system summaries."),
600 |             ("model", "Path of the directory containing model summaries."),
601 |         ]
602 |         for (dirname, docstring) in directories:
603 |             self.__create_dir_property(dirname, docstring)
604 | 
605 |     def __clean_rouge_args(self, rouge_args):
606 |         """
607 |         Remove enclosing quotation marks, if any.
608 | 
609 |         """
610 |         if not rouge_args:
611 |             return
612 |         quot_mark_pattern = re.compile('"(.+)"')
613 |         match = quot_mark_pattern.match(rouge_args)
614 |         if match:
615 |             cleaned_args = match.group(1)
616 |             return cleaned_args
617 |         else:
618 |             return rouge_args
619 | 
620 |     def __add_config_option(self, options):
621 |         return options + [self._config_file]
622 | 
623 |     def __get_config_path(self):
624 |         if platform.system() == "Windows":
625 |             parent_dir = os.getenv("APPDATA")
626 |             config_dir_name = "pyrouge"
627 |         elif os.name == "posix":
628 |             parent_dir = os.path.expanduser("~")
629 |             config_dir_name = ".pyrouge"
630 |         else:
631 |             parent_dir = os.path.dirname(__file__)
632 |             config_dir_name = ""
633 |         config_dir = os.path.join(parent_dir, config_dir_name)
634 |         if not os.path.exists(config_dir):
635 |             os.makedirs(config_dir)
636 |         return os.path.join(config_dir, 'settings.ini')
637 | 
638 | 
639 | if __name__ == "__main__":
640 |     import argparse
641 |     from utils.argparsers import rouge_path_parser
642 | 
643 |     parser = argparse.ArgumentParser(parents=[rouge_path_parser])
644 |     args = parser.parse_args()
645 | 
646 |     rouge = Rouge155(args.rouge_home)
647 |     rouge.save_home_dir()
648 | 
--------------------------------------------------------------------------------
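
A note on the checkpoint-naming convention in src/train_LAI.py above: both `wait_and_validate` and the `test` branch of `__main__` recover the training step by parsing the checkpoint filename produced by `Trainer._save` (`model_step_<N>.pt`). A minimal sketch of that parsing, assuming that naming; `parse_step` is a hypothetical helper, not part of the repo:

# Sketch: mirrors `int(cp.split('.')[-2].split('_')[-1])` used in train_LAI.py.
def parse_step(checkpoint_path):
    # '../models/model_step_20000.pt'.split('.') ->
    #   ['', '', '/models/model_step_20000', 'pt']
    stem = checkpoint_path.split('.')[-2]   # '/models/model_step_20000'
    return int(stem.split('_')[-1])         # 20000

assert parse_step('../models/model_step_20000.pt') == 20000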
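The loss handling throughout Trainer (`validate`, `test`, and `_gradient_accumulation`) uses `BCELoss(reduction='none')` so that per-sentence losses are kept and padded sentence slots can be masked out before summing. A toy sketch of that masking step, with made-up tensors:

# Sketch: why Trainer keeps per-sentence losses and masks before summing.
import torch

loss_fn = torch.nn.BCELoss(reduction='none')
sent_scores = torch.tensor([[0.9, 0.2, 0.5]])   # per-sentence scores, incl. one padded slot
labels = torch.tensor([[1.0, 0.0, 0.0]])
mask = torch.tensor([[1.0, 1.0, 0.0]])          # third sentence is padding

loss = loss_fn(sent_scores, labels)             # shape (1, 3): one loss per sentence
loss = (loss * mask).sum()                      # padded slot contributes nothing
print(loss.item())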
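The `_block_tri` helper inside `Trainer.test` implements trigram blocking: a candidate sentence is skipped if it shares any trigram with the sentences already selected. A standalone sketch of the same rule (the `get_ngrams`/`blocked` names are hypothetical); note that, as in the original, tokenization is whitespace-based:

# Sketch of the trigram-blocking rule from Trainer.test.
def get_ngrams(n, tokens):
    return {tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)}

def blocked(candidate, selected):
    tri_c = get_ngrams(3, candidate.split())
    return any(len(tri_c & get_ngrams(3, s.split())) > 0 for s in selected)

picked = []
for sent in ["the cat sat on the mat", "the cat sat on a mat", "dogs bark loudly"]:
    if not blocked(sent, picked):
        picked.append(sent)
print(picked)  # ['the cat sat on the mat', 'dogs bark loudly']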
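Finally, `Rouge155.output_to_dict` in src/others/pyrouge.py converts ROUGE's textual report into a dictionary by regex. A small sketch of what the pattern captures, using the sample line quoted in the comment inside `output_to_dict`:

# Sketch: the regex from output_to_dict applied to one sample ROUGE line.
import re

pattern = re.compile(
    r"(\d+) (ROUGE-\S+) (Average_\w): (\d.\d+) "
    r"\(95%-conf.int. (\d.\d+) - (\d.\d+)\)")
line = "0 ROUGE-1 Average_R: 0.02632 (95%-conf.int. 0.02632 - 0.02632)"
m = pattern.match(line)
print(m.groups())
# ('0', 'ROUGE-1', 'Average_R', '0.02632', '0.02632', '0.02632')
# -> stored under 'rouge_1_recall', 'rouge_1_recall_cb', 'rouge_1_recall_ce'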