├── blanc ├── __version__.py ├── __init__.py ├── shannon.py ├── __main__.py ├── utils.py ├── estime.py └── blanc.py ├── .gitignore ├── requirements.txt ├── data ├── single.json ├── pairs.json ├── doc-summaries.json └── README.md ├── shannon ├── README.md └── summeval_score.py ├── LICENSE ├── setup.py ├── SECURITY.md ├── .github └── workflows │ └── codeql-analysis.yml ├── estime └── README.md └── README.md /blanc/__version__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.3.4" 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | *.sw* 3 | ._* 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | nltk>=3.1,<4.0 2 | numpy>=1.0,<2.0 3 | scipy>=1.6.0 4 | torch>=1.0,<2.0 5 | tqdm>=4.27.0,<5.0 6 | transformers>=2.4.0 7 | -------------------------------------------------------------------------------- /blanc/__init__.py: -------------------------------------------------------------------------------- 1 | from .__version__ import __version__ 2 | from .blanc import BlancHelp, BlancTune 3 | from .estime import Estime 4 | from .shannon import Shannon 5 | -------------------------------------------------------------------------------- /data/single.json: -------------------------------------------------------------------------------- 1 | { 2 | "doc": "Jack drove his minivan to the bazaar to purchase milk and honey for his large family.", 3 | "summary": "Jack bought milk and honey." 4 | } 5 | -------------------------------------------------------------------------------- /data/pairs.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "doc": "Jack drove his minivan to the bazaar to purchase milk and honey for his large family.", 4 | "summary": "Jack bought milk and honey." 5 | }, 6 | { 7 | "doc": "As Jill started taking a walk in the park, she certainly noticed that the trees were extra green this year.", 8 | "summary": "Jill saw green trees in the park." 9 | } 10 | ] 11 | -------------------------------------------------------------------------------- /shannon/README.md: -------------------------------------------------------------------------------- 1 | # Shannon Game 2 | 3 | Shannon Score and Information Difference metrics of summary quality are defined in [Play the Shannon Game With Language Models: A Human-Free Approach to Summary Evaluation](https://arxiv.org/abs/2103.10918), in [Proceedings AAAI 2022](https://ojs.aaai.org/index.php/AAAI/article/view/21304). Our implementation of these metrics is [here](https://github.com/PrimerAI/blanc/blob/master/blanc/shannon.py). 4 | -------------------------------------------------------------------------------- /data/doc-summaries.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "doc": "Jack drove his minivan to the bazaar to purchase milk and honey for his large family.", 4 | "summaries": [ 5 | "Jack bought milk and honey.", 6 | "Jack drove to the bazaar in a minivan." 
7 | ] 8 | }, 9 | { 10 | "doc": "As Jill started taking a walk in the park, she certainly noticed that the trees were extra green this year.", 11 | "summaries": [ 12 | "Jill saw green trees in the park.", 13 | "The trees were green." 14 | ] 15 | } 16 | ] 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2020 Primer AI 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import json 2 | import setuptools 3 | 4 | version = "0.3.4" 5 | 6 | with open("README.md", encoding="utf-8") as reader: 7 | long_description = reader.read() 8 | 9 | with open("requirements.txt") as reader: 10 | requirements = [line.strip() for line in reader] 11 | 12 | setuptools.setup( 13 | name="blanc", 14 | version=version, 15 | author="Primer AI", 16 | author_email="blanc@primer.ai", 17 | description="Human-free quality estimation of document summaries", 18 | long_description=long_description, 19 | long_description_content_type="text/markdown", 20 | url="https://github.com/PrimerAI/blanc", 21 | packages=setuptools.find_packages(), 22 | install_requires=requirements, 23 | include_package_data=True, 24 | classifiers=[ 25 | "Programming Language :: Python :: 3", 26 | "License :: OSI Approved :: MIT License", 27 | "Operating System :: OS Independent", 28 | ], 29 | python_requires=">=3.6", 30 | entry_points={ 31 | "console_scripts": ["blanc=blanc.__main__:main"], 32 | }, 33 | ) 34 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Datasets 2 | The datasets are as described and used in [Fill in the BLANC: Human-free quality estimation of document summaries](https://www.aclweb.org/anthology/2020.eval4nlp-1.2/) and in [Sensitivity of BLANC to human-scored qualities of text summaries](https://arxiv.org/abs/2010.06716). The human scores to summaries were assigned by 10 annotators from Odetta.ai, provided with detailed instructions and trained on trial tasks. Each dataset preserves Ids of the annotators with their individual scores. 3 | 4 | The datasets: 5 | 1. 
CNN_DailyMail_555: 555 text-summary pairs, with 100 texts with human summaries taken randomly from the CNN / Daily Mail dataset [Hermann et al., 2015](https://proceedings.neurips.cc/paper/2015/file/afdec7005cc9f14302cd0474fd0f3c96-Paper.pdf), and complemented with generated summaries. The single human score describes generic quality of the summary. 6 | 2. DailyNews_300: 300 text-summary pairs, created from 100 texts taken randomly from daily news of different sources. Three summaries for each text were generated by extractive, abstractive and semi-abstractive models. The single human score describes generic quality of the summary. 7 | 3. DailyNews_300_aspects: The same 300 text-summary pairs as above, but assigned 5 human quality scores, accordingly to how fluent, understandable, informative, compact and overall-good the summary is. 8 | 9 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Primer takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organization, [PrimerAI](https://github.com/PrimerAI). 6 | 7 | If you believe you have found a security vulnerability in any Primer-owned repository, please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Primer Security Team at: 14 | [security@primer.ai](mailto:security@primer.ai) 15 | 16 | 17 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 18 | 19 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 20 | * Full paths of source file(s) related to the manifestation of the issue 21 | * The location of the affected source code (tag/branch/commit or direct URL) 22 | * Any special configuration required to reproduce the issue 23 | * Step-by-step instructions to reproduce the issue 24 | * Proof-of-concept or exploit code (if possible) 25 | * Impact of the issue, including how an attacker might exploit the issue 26 | 27 | This information will help us triage your report more quickly. 28 | 29 | 30 | ## Preferred Languages 31 | 32 | We prefer all communications to be in English. 33 | 34 | ## Policy 35 | 36 | Primer follows the principle of [Coordinated Vulnerability Disclosure](https://resources.sei.cmu.edu/asset_files/SpecialReport/2017_003_001_503340.pdf). 37 | 38 | 39 | -------------------------------------------------------------------------------- /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 
11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ master ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ master ] 20 | schedule: 21 | - cron: '31 17 * * 2' 22 | 23 | jobs: 24 | analyze: 25 | name: Analyze 26 | runs-on: ubuntu-latest 27 | permissions: 28 | actions: read 29 | contents: read 30 | security-events: write 31 | 32 | strategy: 33 | fail-fast: false 34 | matrix: 35 | language: [ 'python' ] 36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python' ] 37 | # Learn more: 38 | # https://docs.github.com/en/free-pro-team@latest/github/finding-security-vulnerabilities-and-errors-in-your-code/configuring-code-scanning#changing-the-languages-that-are-analyzed 39 | 40 | steps: 41 | - name: Checkout repository 42 | uses: actions/checkout@v2 43 | 44 | # Initializes the CodeQL tools for scanning. 45 | - name: Initialize CodeQL 46 | uses: github/codeql-action/init@v2 47 | with: 48 | languages: ${{ matrix.language }} 49 | # If you wish to specify custom queries, you can do so here or in a config file. 50 | # By default, queries listed here will override any specified in a config file. 51 | # Prefix the list here with "+" to use these queries and those in the config file. 52 | # queries: ./path/to/local/query, your-org/your-repo/queries@main 53 | 54 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 55 | # If this step fails, then you should remove it and run the build manually (see below) 56 | - name: Autobuild 57 | uses: github/codeql-action/autobuild@v2 58 | 59 | # ℹ️ Command-line programs to run using the OS shell. 60 | # 📚 https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun 61 | 62 | # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines 63 | # and modify them (or add more) to build your code if your project 64 | # uses a compiled language 65 | 66 | #- run: | 67 | # make bootstrap 68 | # make release 69 | 70 | - name: Perform CodeQL Analysis 71 | uses: github/codeql-action/analyze@v2 72 | -------------------------------------------------------------------------------- /shannon/summeval_score.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import json 3 | import numpy as np 4 | from scipy.stats import pearsonr, spearmanr, kendalltau 5 | import argparse 6 | 7 | factors = ['coherence', 'consistency', 'fluency', 'relevance'] 8 | levels = ['expert'] 9 | human_cols = sum([[f'{level}_{factor}' for factor in factors] for level in levels], []) 10 | method_cols = ['shannon_score', 'info_diff', 'blanc_shannon'] 11 | models = ['small', 'medium', 'large', 'xl', 'gpt1', 'xlnet', 'transformerxl'] 12 | 13 | def load(input_file): 14 | with open(input_file) as reader: 15 | result_blobs = [json.loads(line) for line in reader] 16 | 17 | for blob in result_blobs: 18 | S = blob['S'] 19 | blob['blanc_shannon'] = (S[0][1] - S[1][0]) / sum(sum(S, [])) 20 | blob['s_impr'] = S[0][1] / (S[0][0] + S[1][1] + S[0][1]) 21 | del blob['S'] 22 | 23 | result_df = pd.DataFrame.from_records(result_blobs) 24 | 25 | with open('/nfs/data/summeval/docs-dynamicmix.json') as reader: 26 | input_blobs = json.load(reader) 27 | 28 | annotation_blobs = [] 29 | for blob in input_blobs: 30 | annotation_blob = {'id': blob['id']} 31 | for level in levels: 32 | annotations = blob.get(f'{level}_annotations') 33 | if annotations is None: 34 | continue 35 | for factor in factors: 36 | 
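                # For each quality factor, average the expert annotators' scores for this summary
                # (`annotations` is the list of per-annotator score dicts loaded from the SummEval blob above).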
mean = sum([annotation[factor] for annotation in annotations]) / len(annotations) 37 | annotation_blob[f'{level}_{factor}'] = mean 38 | annotation_blob['lower'] = all([c.lower() == c for c in blob['summ']]) 39 | annotation_blobs.append(annotation_blob) 40 | 41 | annotation_df = pd.DataFrame.from_records(annotation_blobs) 42 | df = result_df.join(annotation_df) 43 | df['shannon_score'] = (df.ll_help - df.ll_base) / (df.ll_full - df.ll_base) 44 | df['info_decr'] = (df.ll_base - df.ll_help) / df.ll_base 45 | df['compression'] = df.num_summ_tokens / df.num_doc_tokens 46 | df['cond_lik'] = df.ll_help 47 | df['info_diff'] = df.ll_help - df.ll_base 48 | df['avg_cond_lik'] = df.ll_help / df.num_doc_tokens 49 | 50 | systems = df.groupby('system').mean() 51 | return df, systems 52 | 53 | def eval_many(models, filenames): 54 | shannon_scores, dfs = {}, [] 55 | for model, filename in zip(models, filenames): 56 | df, systems = load(filename) 57 | shannon_scores[model] = df.shannon_score 58 | print(f'{model} system kendall-b') 59 | print(systems.corr(method='kendall')[human_cols].loc[method_cols]) 60 | df['model_name'] = model 61 | dfs.append(df) 62 | shannon_scores = pd.DataFrame(shannon_scores) 63 | combined = pd.concat(dfs).reset_index() 64 | print(shannon_scores.corr()) 65 | 66 | parser = argparse.ArgumentParser() 67 | parser.add_argument('--model', type=str, default=None) 68 | parser.add_argument('--upstream', action='store_true') 69 | args = parser.parse_args() 70 | 71 | if args.upstream: 72 | eval_many(range(5), [f'out/context-{context}-small.jsonl' for context in range(5)]) 73 | elif args.model is None: 74 | eval_many(models, [f'out/17-{model}.jsonl' for model in models]) 75 | else: 76 | df, systems = load(f'out/17-{args.model}.jsonl') 77 | print('Overall Pearson') 78 | print(df.corr(method='pearson')[human_cols].loc[method_cols]) 79 | print('Overall Spearman') 80 | print(df.corr(method='spearman')[human_cols].loc[method_cols]) 81 | print('Overall Kendall-b') 82 | print(df.corr(method='kendall')[human_cols].loc[method_cols]) 83 | 84 | print('System Pearson') 85 | print(systems.corr(method='pearson')[human_cols].loc[method_cols]) 86 | print('System Spearman') 87 | print(systems.corr(method='spearman')[human_cols].loc[method_cols]) 88 | print('System Kendall-b') 89 | print(systems.corr(method='kendall')[human_cols].loc[method_cols]) 90 | -------------------------------------------------------------------------------- /estime/README.md: -------------------------------------------------------------------------------- 1 | # ESTIME 2 | 3 | ESTIME as the 'number of alarms' was defined in [ESTIME: Estimation of Summary-to-Text Inconsistency by Mismatched Embeddings](https://aclanthology.org/2021.eval4nlp-1.10/). 4 | ESTIME-soft and ESTIME-coherence were defined in [Consistency and Coherence from Points of Contextual Similarity](https://arxiv.org/abs/2112.11638). Sourse: [estime](https://github.com/PrimerAI/blanc/blob/master/blanc/estime.py). 5 | 6 | ESTIME is a reference-free estimator of summary quality with emphasis on factual consistency. It can be used for filtering generated summaries, or for estimating improvement of a generation system. 7 | 8 | Usage is simple: create `Estime`, and use `evaluate_claims`. When creating Estime, specify the list of names of the measures to obtain for each claim. 
Basic usage:
9 |
10 | ```python
11 | >>> from blanc import Estime
12 | >>> estimator = Estime()
13 | >>> text = """In Kander’s telling, Mandel called him up out of the blue a decade or so ago to pitch a project. It made sense why. The two men had similar profiles: Jewish combat veterans in their early 30s. New statewide officeholders in the Midwest."""
14 | >>> summary = """Kander and Mandel had similar profiles, and it makes sense."""
15 | >>> estimator.evaluate_claims(text, [summary])
16 | [[5]]
17 | ```
18 |
19 | The default `device` in `Estime()` is 'cpu'; it can be set to `device`='cuda'.
20 |
21 | In the example above only one summary is given for the text, so the list of results contains only one element, [5] - the scores for this summary. The scores list contains a single score, 5, because by default the list of measures contains only one measure, 'alarms'. More measures can be included: 'alarms', 'alarms_adjusted', 'alarms_alltokens', 'soft', 'coherence'. For example:
22 |
23 | ```
24 | >>> estimator = Estime(output=['alarms', 'alarms_adjusted', 'soft', 'coherence'])
25 | >>> estimator.evaluate_claims(text, [summary])
26 | [[5, 7.5, 0.502, -0.25]]
27 | ```
28 | The results appear in the same order as the names given in `output`. The measures 'alarms' (the original ESTIME), 'soft' and 'coherence' are as defined in the papers. The only difference is that when there is no token overlap at all between the claim and the text, 'alarms' is set to the number of tokens in the summary. Unlike 'soft', the original ESTIME does not give a good estimate in cases where the number of overlapping tokens is much smaller than the total number of summary tokens. Starting from version 0.3.3, the measure 'alarms_adjusted' can be added. It is defined as `alarms_adjusted = alarms * N / M`, where M is the number of overlapping tokens and N is the total number of summary tokens. It thus serves as an extrapolation of 'alarms' to the total number of summary tokens. When M=0, 'alarms_adjusted' is set to N. Out of curiosity (not recommended), 'alarms_alltokens' can also be added; it is defined as `alarms_alltokens = alarms + N - M`, meaning that every non-overlapping token is counted as an alarm.
29 |
30 | For more options, see the comments in the source [estime](https://github.com/PrimerAI/blanc/blob/master/blanc/estime.py), or see [estime](https://github.com/PrimerAI/primer-research/tree/main/estime).
31 |
32 | The table below is made in the same way as Table 1 in [ESTIME](https://aclanthology.org/2021.eval4nlp-1.10/), except that the number of systems here is updated from 16 to 17, following the later version of [SummEval](https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00373/100686/SummEval-Re-evaluating-Summarization-Evaluation). This means that the correlations here are taken between arrays of length 1700 (100 texts x 17 summary generation systems).
33 |
34 | |model|consistency<br>Spearman|consistency<br>Kendall|relevance<br>Spearman|relevance<br>Kendall|coherence<br>Spearman|coherence<br>Kendall|fluency<br>Spearman|fluency<br>Kendall|
35 |:--|--:|--:|--:|--:|--:|--:|--:|--:|
36 | BLANC-AXXL|0.19|0.09|0.21|0.15|0.11|0.08|0.10|0.06|
37 | BLANC-BLU|0.20|0.10|0.18|0.13|0.10|0.07|0.11|0.06|
38 | BLANC|0.19|0.10|0.28|0.20|0.22|0.16|0.13|0.07|
39 | ESTIME-12|0.36|0.18|0.10|0.07|0.20|0.14|0.32|0.19|
40 | ESTIME-21|**0.39**|**0.19**|0.15|0.11|0.27|0.19|**0.38**|**0.22**|
41 | ESTIME-24|0.34|0.17|0.08|0.06|0.16|0.11|0.34|0.20|
42 | Jensen-Shannon|0.18|0.09|0.39|0.28|0.29|0.21|0.11|0.06|
43 | SummaQA-F|0.17|0.08|0.14|0.10|0.08|0.06|0.12|0.07|
44 | SummaQA-P|0.19|0.09|0.17|0.12|0.10|0.08|0.12|0.07|
45 | SUPERT|0.28|0.14|0.26|0.19|0.20|0.15|0.17|0.10|
46 | (r) BERTScore-F|0.10|0.05|0.38|0.28|**0.39**|**0.28**|0.13|0.07|
47 | (r) BERTScore-P|0.05|0.03|0.29|0.21|0.34|0.25|0.11|0.06|
48 | (r) BERTScore-R|0.15|0.08|**0.41**|**0.30**|0.34|0.249|0.11|0.06|
49 | (r) BLEU|0.09|0.04|0.23|0.17|0.19|0.14|0.12|0.07|
50 | (r) ROUGE-L|0.12|0.06|0.23|0.16|0.16|0.11|0.08|0.04|
51 | (r) ROUGE-1|0.13|0.07|0.28|0.20|0.17|0.12|0.07|0.04|
52 | (r) ROUGE-2|0.12|0.06|0.23|0.16|0.14|0.10|0.06|0.04|
53 | (r) ROUGE-3|0.15|0.07|0.23|0.17|0.15|0.11|0.06|0.04|
54 |
55 | (r): These measures need human-written reference summaries to evaluate a summary.
56 | The ESTIME and Jensen-Shannon scores are negated.
57 | The third row is the default version of BLANC.
58 |
59 | The numbers here differ slightly from the 16-system results reported in [ESTIME](https://aclanthology.org/2021.eval4nlp-1.10/), but the trends and the top correlations are the same.
60 | Notice that for consistency, every reference-free measure outperforms all of the reference-needed (r) measures.
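As a quick orientation, the sketch below shows how one column of the table could be reproduced once per-summary scores are available. It is not part of this repository: `metric_scores` and `expert_scores` are placeholder names for arrays of length 1700 (100 texts x 17 systems) that you would prepare yourself, and the exact Kendall variant should follow the paper.

```python
# Hedged sketch: correlating a quality measure with expert scores on SummEval.
# Placeholder inputs: metric_scores and expert_scores are aligned lists of
# 1700 floats (100 texts x 17 systems), prepared by the user.
from scipy.stats import kendalltau, spearmanr

def correlate(metric_scores, expert_scores, lower_is_better=False):
    # ESTIME and Jensen-Shannon would use lower_is_better=True, which
    # corresponds to the negation mentioned in the notes above.
    if lower_is_better:
        metric_scores = [-s for s in metric_scores]
    spearman, _ = spearmanr(metric_scores, expert_scores)
    kendall, _ = kendalltau(metric_scores, expert_scores)
    return spearman, kendall
```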
61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /blanc/shannon.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import torch 4 | import torch.nn.functional as F 5 | from tqdm import tqdm 6 | from transformers import ( 7 | GPT2LMHeadModel, 8 | GPT2Tokenizer, 9 | OpenAIGPTLMHeadModel, 10 | OpenAIGPTTokenizer, 11 | XLNetLMHeadModel, 12 | XLNetTokenizer, 13 | TransfoXLLMHeadModel, 14 | TransfoXLTokenizer, 15 | ReformerModelWithLMHead, 16 | ReformerTokenizer, 17 | XLMWithLMHeadModel, 18 | XLMTokenizer, 19 | ) 20 | import numpy as np 21 | from nltk import sent_tokenize 22 | 23 | def get_model(name, size, device='cuda'): 24 | if device == 'cuda': 25 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 26 | if name == 'gpt2': 27 | t = GPT2Tokenizer.from_pretrained('gpt2') 28 | if size == 'base': 29 | g = GPT2LMHeadModel.from_pretrained('gpt2') 30 | else: 31 | g = GPT2LMHeadModel.from_pretrained(f'gpt2-{size}') 32 | eos = g.config.eos_token_id # '<|endoftext|>' 33 | max_input = 1024 34 | elif name == 'gpt1': 35 | t = OpenAIGPTTokenizer.from_pretrained('openai-gpt') 36 | g = OpenAIGPTLMHeadModel.from_pretrained('openai-gpt') 37 | eos = 1 # . 38 | max_input = 512 39 | elif name == 'xlnet': 40 | t = XLNetTokenizer.from_pretrained(f'xlnet-{size}-cased') 41 | g = XLNetLMHeadModel.from_pretrained(f'xlnet-{size}-cased') 42 | eos = g.config.eos_token_id # 43 | max_input = 1024 44 | elif name == 'transformerxl': 45 | t = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103') 46 | g = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103') 47 | eos = g.config.eos_token_id # 48 | max_input = 1024 49 | elif name == 'reformer': 50 | t = ReformerTokenizer.from_pretrained('google/reformer-crime-and-punishment') 51 | g = ReformerModelWithLMHead.from_pretrained('google/reformer-crime-and-punishment') 52 | eos = g.config.eos_token_id # (empty string) 53 | max_input = 1024 54 | elif name == 'xlm': 55 | t = XLMTokenizer.from_pretrained('xlm-clm-ende-1024') 56 | g = XLMWithLMHeadModel.from_pretrained('xlm-clm-ende-1024') 57 | g.config.lang_id = g.config.lang2id['en'] 58 | eos = 4 # 59 | max_input = 1024 60 | 61 | g = g.to(device) 62 | g.config.return_dict = False 63 | g.eval() 64 | return g, t, eos, max_input 65 | 66 | def prepare_inputs_for_generation(input_ids, past=None, **kwargs): 67 | """Copied from gpt2 of huggingface transformers, but using it here separately, 68 | because in some versions this funciton worked differently, which caused errors.""" 69 | if past: # only last token for inputs_ids if past is defined in kwargs 70 | input_ids = input_ids[:, -1].unsqueeze(-1) 71 | 72 | attention_mask = kwargs.get("attention_mask", None) 73 | position_ids = kwargs.get("position_ids", None) 74 | 75 | if attention_mask is not None and position_ids is None: 76 | # create position_ids on the fly for batch generation 77 | position_ids = attention_mask.long().cumsum(-1) - 1 78 | position_ids.masked_fill_(attention_mask == 0, 1) 79 | if past: 80 | position_ids = position_ids[:, -1].unsqueeze(-1) 81 | else: 82 | position_ids = None 83 | return { 84 | "input_ids": input_ids, 85 | "past_key_values": past, 86 | "use_cache": kwargs.get("use_cache"), 87 | "position_ids": position_ids, 88 | "attention_mask": attention_mask, 89 | } 90 | 91 | class Shannon: 92 | def __init__( 93 | self, 94 | verbose=False, 95 | language_model='gpt2', 96 | model_size='base', 97 | 
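        # num_upstream: number of preceding document sentences prepended as extra context for each
        # measured sentence (see the slicing of sents_tokens in go() below); 0 means no extra context.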
num_upstream=0, 98 | return_token_lls=False, 99 | device='cuda' 100 | ): 101 | self.verbose = verbose 102 | self.language_model = language_model 103 | self.num_upstream = num_upstream 104 | self.return_token_lls = return_token_lls 105 | self.device = device 106 | if self.device == 'cuda': 107 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 108 | self.g, self.t, self.eos, self.max_input = get_model( 109 | language_model, model_size, self.device) 110 | 111 | def measure(self, doc_tokens, prompt): 112 | eos = torch.LongTensor([self.eos]).to(self.device) 113 | if prompt is None or (prompt.dim() == 1 and len(prompt) == 0): 114 | prompt = torch.LongTensor([]).to(self.device) 115 | 116 | token_lls = [] 117 | success = [] 118 | past = None 119 | for i, token in enumerate(doc_tokens): 120 | upstream = doc_tokens[:i] 121 | if len(upstream) + len(prompt) + 1 > self.max_input: 122 | upstream = upstream[-(self.max_input - 1 - len(prompt)):] 123 | if past is not None: 124 | past = [t[:, :, :, 1:, :] for t in past] 125 | 126 | prefix = torch.cat([eos, prompt, upstream]).unsqueeze(0) 127 | inputs = prepare_inputs_for_generation(prefix, past=past, use_cache=True, use_mems=True) 128 | with torch.no_grad(): 129 | out = self.g(**inputs) 130 | 131 | if self.language_model in 'gpt2': 132 | logits, past = out 133 | elif self.language_model in ['gpt1', 'reformer']: 134 | logits, = out 135 | logits = logits[0, -1, :] 136 | elif self.language_model == 'xlnet': 137 | logits, = out 138 | elif self.language_model == 'transformerxl': 139 | logits, past = out 140 | logits = logits[0, -1, :] 141 | elif self.language_model == 'xlm': 142 | logits, = out 143 | logits = logits[0, -1, :] 144 | probs = F.softmax(logits, dim=-1).view(-1) 145 | prob = probs[token].item() 146 | 147 | log_prob = np.log(prob) 148 | token_lls.append(log_prob) 149 | success.append(int(token == probs.argmax())) 150 | 151 | true_token = self.t.decode([token]) 152 | try: 153 | pred_token = self.t.decode([probs.argmax()]) 154 | except: 155 | pred_token = None 156 | info = -log_prob / np.log(2) 157 | self.log(f'{true_token},{info}') 158 | 159 | return token_lls, success 160 | 161 | def go(self, doc, summ, measure_t=False, measure_summ=False): 162 | sents = sent_tokenize(doc) 163 | encode_args = {'return_tensors': 'pt'} 164 | if self.language_model == 'transformerxl': 165 | encode_args['add_space_before_punct_symbol'] = True 166 | if self.language_model in ['xlnet', 'xlm']: 167 | encode_args['add_special_tokens'] = False 168 | 169 | sents_tokens = [self.t.encode(sent, **encode_args).to(self.device).view(-1) for sent in sents] 170 | summ_tokens = self.t.encode(summ, **encode_args).to(self.device).view(-1) 171 | sents_tokens = [sent_tokens[:self.max_input - 1 - len(summ_tokens)] for sent_tokens in sents_tokens] 172 | doc_tokens = torch.cat(sents_tokens, dim=-1) 173 | 174 | if measure_t: 175 | ll, tries, success = 0, 0, [] 176 | for sent_tokens in sents_tokens: 177 | sent_ll, sent_tries, sent_success = self.measure(sent_tokens, sent_tokens) 178 | ll += sent_ll 179 | tries += sent_tries 180 | success += sent_success 181 | return ll, tries, success 182 | 183 | elif measure_summ: 184 | summ_ll, summ_success = self.measure(summ_tokens, None) 185 | return summ_ll 186 | 187 | else: 188 | token_lls_base, token_lls_help, token_lls_full = [], [], [] 189 | S = [[0, 0], [0, 0]] 190 | for sent_idx in range(len(sents_tokens)): 191 | sent_tokens = sents_tokens[sent_idx] 192 | upstream_tensors = sents_tokens[sent_idx-self.num_upstream:sent_idx] 193 | if 
len(upstream_tensors) > 0: 194 | upstream_context = torch.cat(upstream_tensors) 195 | else: 196 | upstream_context = torch.LongTensor([]) 197 | if self.device == 'cuda': 198 | upstream_context = upstream_context.cuda() 199 | 200 | base_prompt = upstream_context 201 | help_prompt = torch.cat([summ_tokens, upstream_context]) 202 | full_prompt = torch.cat([upstream_context, sent_tokens, upstream_context]) 203 | 204 | base_sent_lls, base_sent_success = self.measure(sent_tokens, base_prompt) 205 | help_sent_lls, help_sent_success = self.measure(sent_tokens, help_prompt) 206 | full_sent_lls, full_sent_success = self.measure(sent_tokens, full_prompt) 207 | 208 | token_lls_base += base_sent_lls 209 | token_lls_help += help_sent_lls 210 | token_lls_full += full_sent_lls 211 | 212 | for b, h in zip(base_sent_success, help_sent_success): 213 | S[b][h] += 1 214 | 215 | if self.return_token_lls: 216 | doc_tokens = self.t.convert_ids_to_tokens(doc_tokens) 217 | summ_tokens = self.t.convert_ids_to_tokens(summ_tokens) 218 | return token_lls_base, token_lls_help, token_lls_full, doc_tokens, summ_tokens 219 | else: 220 | ll_base = sum(token_lls_base) 221 | ll_help = sum(token_lls_help) 222 | ll_full = sum(token_lls_full) 223 | 224 | self.log(f'Shannon Score: {(ll_help - ll_base) / (ll_full - ll_base)}') 225 | self.log(f'Info Diff: {ll_help - ll_base}') 226 | self.log(f'BLANC: {(S[0][1] - S[1][0]) / (S[0][0] + S[0][1] + S[1][0] + S[1][1])}') 227 | 228 | return ll_base, ll_help, ll_full, S, len(doc_tokens), len(summ_tokens) 229 | 230 | def log(self, s=None): 231 | if self.verbose: 232 | print(s) 233 | 234 | if __name__ == '__main__': 235 | parser = argparse.ArgumentParser() 236 | parser.add_argument('--simple', action='store_true') 237 | parser.add_argument('--verbose', action='store_true') 238 | parser.add_argument('--measure_t', action='store_true') 239 | parser.add_argument('--measure_summ', action='store_true') 240 | parser.add_argument('--input_file', type=str) 241 | parser.add_argument('--eval', type=str, choices=['6', '47', '23'], default=None) 242 | parser.add_argument('--system', type=str, default=None) 243 | parser.add_argument('--lm', type=str, default='gpt2') 244 | parser.add_argument('--model_size', type=str, default='base') 245 | parser.add_argument('--num_upstream', type=int, default=0) 246 | parser.add_argument('--start', type=int, default=0) 247 | args = parser.parse_args() 248 | 249 | s = Shannon(args.verbose, args.lm, args.model_size, args.num_upstream) 250 | 251 | if args.simple: 252 | doc = 'Jack drove his minivan to the bazaar to purchase milk and honey for his large family' 253 | summ = 'Jack bought milk and honey from the bazaar' 254 | results = s.go(doc, summ, measure_t=args.measure_t, measure_summ=args.measure_summ) 255 | 256 | print(results) 257 | else: 258 | with open(args.input_file) as reader: 259 | if args.input_file.endswith('.jsonl'): 260 | data = [json.loads(line) for line in reader] 261 | else: 262 | data = json.load(reader) 263 | 264 | selection = data[args.start:] 265 | if args.eval is not None: 266 | selection = [record for record in data if record['eval'] in args.eval] 267 | if args.system is not None: 268 | selection = [record for record in selection if record['model_id'] == args.system] 269 | 270 | for record in tqdm(selection): 271 | if args.measure_t or args.measure_summ: 272 | ll = s.go( 273 | record['doc'], record['summ'], 274 | measure_summ=args.measure_summ, measure_t=args.measure_t 275 | ) 276 | print(json.dumps({ 277 | 'doc_id': record['id'], 278 | 'system': 
record['model_id'], 279 | 'll_summ': ll, 280 | })) 281 | 282 | else: 283 | ll_base, ll_help, ll_full, S, num_doc_tokens, num_summ_tokens = s.go( 284 | record['doc'], record['summ'] 285 | ) 286 | print(json.dumps({ 287 | 'doc_id': record['id'], 288 | 'system': record['model_id'], 289 | 'll_base': ll_base, 290 | 'll_help': ll_help, 291 | 'll_full': ll_full, 292 | 'num_doc_tokens': num_doc_tokens, 293 | 'num_summ_tokens': num_summ_tokens, 294 | 'S': S, 295 | })) 296 | -------------------------------------------------------------------------------- /blanc/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import random 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from blanc import BlancHelp, BlancTune 9 | from blanc.utils import Defaults 10 | 11 | 12 | def main(): 13 | parser = argparse.ArgumentParser( 14 | prog='blanc', formatter_class=argparse.ArgumentDefaultsHelpFormatter 15 | ) 16 | 17 | required_parser = parser.add_argument_group('required arguments') 18 | required_parser.add_argument( 19 | 'type', type=str, choices=['help', 'tune'], help='BLANC-help or BLANC-tune' 20 | ) 21 | 22 | input_parser = parser.add_argument_group('input arguments') 23 | input_parser.add_argument('--doc', type=str, help='single input document') 24 | input_parser.add_argument('--summary', type=str, help='single input summary') 25 | input_parser.add_argument( 26 | '--single_json', 27 | type=str, 28 | help='filename for single document summary pair', 29 | metavar='FILENAME', 30 | ) 31 | input_parser.add_argument( 32 | '--pairs_json', 33 | type=str, 34 | help='filename for list of document summary pairs', 35 | metavar='FILENAME', 36 | ) 37 | input_parser.add_argument( 38 | '--doc_summaries_json', 39 | type=str, 40 | help='filename for list of documents, each with a list of summaries', 41 | metavar='FILENAME', 42 | ) 43 | input_parser.add_argument( 44 | '--doc_key', 45 | type=str, 46 | help='json key for the input document', 47 | metavar='KEY', 48 | default=Defaults.doc_key, 49 | ) 50 | input_parser.add_argument( 51 | '--summary_key', 52 | type=str, 53 | help='json key for the input summary (single_json or pairs_json input)', 54 | metavar='KEY', 55 | default=Defaults.summary_key, 56 | ) 57 | input_parser.add_argument( 58 | '--summaries_key', 59 | type=str, 60 | help='json key for the input summaries (doc_summaries_json input)', 61 | metavar='KEY', 62 | default=Defaults.summaries_key, 63 | ) 64 | input_parser.add_argument( 65 | '--output_json', 66 | type=str, 67 | help='filename for output file, or None to print to STDOUT', 68 | metavar='FILENAME', 69 | ) 70 | 71 | blanc_parser = parser.add_argument_group('arguments for BLANC-help and BLANC-tune') 72 | blanc_parser.add_argument( 73 | '--model_name', 74 | type=str, 75 | choices=['bert-base-cased', 'bert-base-uncased', 'bert-large-cased', 'bert-large-uncased'], 76 | help='BERT model type', 77 | default=Defaults.model_name, 78 | metavar='NAME', 79 | ) 80 | blanc_parser.add_argument( 81 | '--measure', 82 | type=str, 83 | choices=['improve', 'relative'], 84 | help='measure improve or relative, as defined in the paper', 85 | default=Defaults.measure, 86 | ) 87 | blanc_parser.add_argument( 88 | '--gap', 89 | type=int, 90 | help='distance between words to mask during inference', 91 | default=Defaults.gap, 92 | ) 93 | blanc_parser.add_argument( 94 | '--gap_mask', 95 | type=int, 96 | help='number of tokens to mask at each designated position during inference', 97 | default=Defaults.gap_mask, 
98 | ) 99 | blanc_parser.add_argument( 100 | '--gap_tune', 101 | type=int, 102 | help='distance between words to mask during finetuning', 103 | default=Defaults.gap, 104 | ) 105 | blanc_parser.add_argument( 106 | '--gap_mask_tune', 107 | type=int, 108 | help='number of tokens to mask at each designated position during finetuning', 109 | default=Defaults.gap_mask, 110 | ) 111 | blanc_parser.add_argument( 112 | '--min_token_length_normal', 113 | type=int, 114 | help=( 115 | 'minimum number of chars in normal tokens to mask, where a normal token is ' 116 | 'a whole word' 117 | ), 118 | default=Defaults.min_token_length_normal, 119 | metavar='LEN', 120 | ) 121 | blanc_parser.add_argument( 122 | '--min_token_length_lead', 123 | type=int, 124 | help='minimum number of chars in lead token to mask, where a lead token begins a word', 125 | default=Defaults.min_token_length_lead, 126 | metavar='LEN', 127 | ) 128 | blanc_parser.add_argument( 129 | '--min_token_length_followup', 130 | type=int, 131 | help=( 132 | 'minimum number of chars in followup token to mask, where a followup token ' 133 | 'continues a word' 134 | ), 135 | default=Defaults.min_token_length_followup, 136 | metavar='LEN', 137 | ) 138 | blanc_parser.add_argument( 139 | '--min_token_length_normal_tune', 140 | type=int, 141 | help=( 142 | 'minimum number of chars in normal tokens to mask at tuning, where a normal token is ' 143 | 'a whole word' 144 | ), 145 | default=Defaults.min_token_length_normal_tune, 146 | metavar='LEN', 147 | ) 148 | blanc_parser.add_argument( 149 | '--min_token_length_lead_tune', 150 | type=int, 151 | help='minimum number of chars in lead token to mask at tuning, where a lead token begins a word', 152 | default=Defaults.min_token_length_lead_tune, 153 | metavar='LEN', 154 | ) 155 | blanc_parser.add_argument( 156 | '--min_token_length_followup_tune', 157 | type=int, 158 | help=( 159 | 'minimum number of chars in followup token to mask at tuning, where a followup token ' 160 | 'continues a word' 161 | ), 162 | default=Defaults.min_token_length_followup_tune, 163 | metavar='LEN', 164 | ) 165 | blanc_parser.add_argument( 166 | '--device', type=str, help='cpu or cuda device', default=Defaults.device, 167 | ) 168 | blanc_parser.add_argument( 169 | '--random_seed', 170 | type=int, 171 | help='random seed for python and torch', 172 | default=Defaults.random_seed, 173 | metavar='SEED', 174 | ) 175 | blanc_parser.add_argument( 176 | '--inference_batch_size', 177 | type=int, 178 | help='batch size to use during inference', 179 | default=Defaults.inference_batch_size, 180 | metavar='SIZE', 181 | ) 182 | blanc_parser.add_argument( 183 | '--inference_mask_evenly', 184 | type=bool, 185 | help=( 186 | 'when True, mask every `gap` tokens (`gap_mask` tokens at once) that are longer than `min_token_length`' 187 | 'during finetuning, when False randomly mask tokens with probability 0.15' 188 | ), 189 | default=Defaults.inference_mask_evenly, 190 | metavar='MASK_EVENLY', 191 | ) 192 | 193 | help_parser = parser.add_argument_group('BLANC-help arguments') 194 | help_parser.add_argument( 195 | '--filler_token', 196 | type=str, 197 | help='token to use as filler in lieu of summary', 198 | default=Defaults.filler_token, 199 | metavar='TOKEN', 200 | ) 201 | help_parser.add_argument( 202 | '--help_sep', 203 | type=str, 204 | help=( 205 | "token to use to separate the summary or filler from the sentence, " 206 | "or '' for no separator" 207 | ), 208 | default=Defaults.help_sep, 209 | metavar='SEP', 210 | ) 211 | 212 | tune_parser = 
parser.add_argument_group('BLANC-tune arguments') 213 | tune_parser.add_argument( 214 | '--finetune_batch_size', 215 | type=int, 216 | help='batch size to use when finetuning on summary', 217 | default=Defaults.finetune_batch_size, 218 | metavar='SIZE', 219 | ) 220 | tune_parser.add_argument( 221 | '--finetune_epochs', 222 | type=int, 223 | help='number of epochs to train for when finetuning on summary', 224 | default=Defaults.finetune_epochs, 225 | metavar='EPOCHS', 226 | ) 227 | tune_parser.add_argument( 228 | '--finetune_mask_evenly', 229 | type=bool, 230 | help=( 231 | 'when True, mask every `gap` tokens (`gap_mask` tokens at once) that are longer than `min_token_length`' 232 | 'during finetuning, when False randomly mask tokens with probability 0.15' 233 | ), 234 | default=Defaults.finetune_mask_evenly, 235 | metavar='MASK_EVENLY', 236 | ) 237 | tune_parser.add_argument( 238 | '--finetune_chunk_size', 239 | type=int, 240 | help='number of summary tokens to use at a time when finetuning', 241 | default=Defaults.finetune_chunk_size, 242 | metavar='SIZE', 243 | ) 244 | tune_parser.add_argument( 245 | '--finetune_chunk_stride', 246 | type=int, 247 | help='number of tokens between summary chunks for finetuning', 248 | default=Defaults.finetune_chunk_stride, 249 | metavar='STRIDE', 250 | ) 251 | tune_parser.add_argument( 252 | '--learning_rate', 253 | type=float, 254 | help='learning rate when finetuning on summary', 255 | default=Defaults.learning_rate, 256 | metavar='LR', 257 | ) 258 | tune_parser.add_argument( 259 | '--warmup_steps', 260 | type=int, 261 | help='warmup steps when finetuning on summary', 262 | default=Defaults.warmup_steps, 263 | metavar='STEPS', 264 | ) 265 | 266 | args = parser.parse_args() 267 | 268 | random.seed(args.random_seed) 269 | np.random.seed(args.random_seed) 270 | torch.manual_seed(args.random_seed) 271 | 272 | if args.type == 'help': 273 | model = BlancHelp( 274 | model_name=args.model_name, 275 | measure=args.measure, 276 | gap=args.gap, 277 | gap_mask=args.gap_mask, 278 | gap_tune=args.gap_tune, 279 | gap_mask_tune=args.gap_mask_tune, 280 | min_token_length_normal=args.min_token_length_normal, 281 | min_token_length_lead=args.min_token_length_lead, 282 | min_token_length_followup=args.min_token_length_followup, 283 | min_token_length_normal_tune=args.min_token_length_normal_tune, 284 | min_token_length_lead_tune=args.min_token_length_lead_tune, 285 | min_token_length_followup_tune=args.min_token_length_followup_tune, 286 | device=args.device, 287 | inference_batch_size=args.inference_batch_size, 288 | inference_mask_evenly=args.inference_mask_evenly, 289 | filler_token=args.filler_token, 290 | help_sep=args.help_sep, 291 | ) 292 | elif args.type == 'tune': 293 | model = BlancTune( 294 | model_name=args.model_name, 295 | measure=args.measure, 296 | gap=args.gap, 297 | gap_mask=args.gap_mask, 298 | gap_tune=args.gap_tune, 299 | gap_mask_tune=args.gap_mask_tune, 300 | min_token_length_normal=args.min_token_length_normal, 301 | min_token_length_lead=args.min_token_length_lead, 302 | min_token_length_followup=args.min_token_length_followup, 303 | min_token_length_normal_tune=args.min_token_length_normal_tune, 304 | min_token_length_lead_tune=args.min_token_length_lead_tune, 305 | min_token_length_followup_tune=args.min_token_length_followup_tune, 306 | device=args.device, 307 | inference_batch_size=args.inference_batch_size, 308 | inference_mask_evenly=args.inference_mask_evenly, 309 | finetune_batch_size=args.finetune_batch_size, 310 | 
finetune_epochs=args.finetune_epochs, 311 | finetune_mask_evenly=args.finetune_mask_evenly, 312 | finetune_chunk_size=args.finetune_chunk_size, 313 | finetune_chunk_stride=args.finetune_chunk_stride, 314 | learning_rate=args.learning_rate, 315 | warmup_steps=args.warmup_steps, 316 | ) 317 | 318 | key = f"blanc-{args.type}-measure-{args.measure}" 319 | if args.doc is not None: 320 | result = model.eval_once(args.doc, args.summary) 321 | result_json = {key: result} 322 | elif args.single_json is not None: 323 | with open(args.single_json) as reader: 324 | data = json.load(reader) 325 | 326 | result = model.eval_once(data[args.doc_key], data[args.summary_key]) 327 | result_json = {key: result} 328 | elif args.pairs_json is not None: 329 | with open(args.pairs_json) as reader: 330 | data = json.load(reader) 331 | docs = [pair[args.doc_key] for pair in data] 332 | summaries = [pair[args.summary_key] for pair in data] 333 | 334 | result = model.eval_pairs(docs, summaries) 335 | result_json = [{key: score} for score in result] 336 | elif args.doc_summaries_json is not None: 337 | with open(args.doc_summaries_json) as reader: 338 | data = json.load(reader) 339 | docs = [doc_summary[args.doc_key] for doc_summary in data] 340 | doc_summaries = [doc_summary[args.summaries_key] for doc_summary in data] 341 | 342 | result = model.eval_summaries_for_docs(docs, doc_summaries) 343 | result_json = [{key: scores} for scores in result] 344 | else: 345 | raise ValueError('Please provide an input document and summary') 346 | 347 | if args.output_json is None: 348 | print(result) 349 | else: 350 | with open(args.output_json, 'w') as writer: 351 | json.dump(result_json, writer) 352 | 353 | 354 | if __name__ == '__main__': 355 | main() 356 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Evaluation measures 2 | 3 | This repositary contains reference implementations and explanations to accompany [Primer.ai](https://primer.ai) research and publications related to evaluation measures, mostly for the purpose of summary evaluation. 4 | 5 | These evaluation measures include: 6 | 7 | * BLANC-help (or simply 'BLANC'), BLANC-tune 8 | * blanc.py 9 | * All the info is in this page 10 | * Shannon Score, Information Difference, BLANC-Shannon 11 | * shannon.py 12 | * Info: [Shannon Score and Information Difference](https://github.com/PrimerAI/blanc/tree/master/shannon) 13 | * ESTIME, ESTIME-soft, ESTIME-coherence 14 | * estime.py 15 | * Info: [ESTIME (hard, soft and coherence)](https://github.com/PrimerAI/blanc/tree/master/estime) 16 | 17 | Annotated summary quality datasets: [data](https://github.com/PrimerAI/blanc/tree/master/data) 18 | 19 | 20 | ## Setup 21 | 1. Install Python 3.6 or higher 22 | 2. Install with `pip install blanc` 23 | 24 | 25 | ## BLANC 26 | This is the reference implementation of BLANC-help and BLANC-tune as defined in [Fill in the BLANC: Human-free quality estimation of document summaries](https://www.aclweb.org/anthology/2020.eval4nlp-1.2/). 27 | 28 | BLANC is a reference-free approach to the automatic estimation of document summary quality. Our goal is to measure the functional performance of a summary with an objective, reproducible, and fully automated method. Our approach achieves this by measuring the performance boost gained by a pre-trained language model with access to a document summary while carrying out its language understanding task on the document's text. 
Unlike ROUGE, BLANC does not require human-written reference summaries, allowing for fully human-free summary quality estimation. 29 | 30 | Two types of BLANC scores were introduced in the paper and are available in this repo: BLANC-help and BLANC-tune. BLANC-help is faster to calculate (around 30% faster on CUDA with default settings), but BLANC-tune is more theoretically principled. They are around 90% correlated with each other, so either one can be used in most cases.
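To make the mechanism concrete, the sketch below illustrates the BLANC-help idea in a deliberately simplified form; it is not the reference implementation (which lives in `blanc/blanc.py` and handles wordpiece tokenization, the masking rules, and the exact 'improve'/'relative' measures), and `unmask` is a hypothetical stand-in for the BERT inference step.

```python
# Simplified illustration of the BLANC-help idea -- not the reference implementation.
# `unmask(context, words, i)` is a hypothetical callable returning the masked
# language model's guess for words[i], given `context` followed by the sentence
# in which words[i] is hidden.
def blanc_help_sketch(doc_sentences, summary, unmask, filler=". . .", gap=2, min_len=4):
    help_ok = base_ok = total = 0
    for sentence in doc_sentences:
        words = sentence.split()
        for i in range(0, len(words), gap):      # mask every `gap`-th word...
            if len(words[i]) < min_len:          # ...skipping very short words
                continue
            total += 1
            help_ok += unmask(summary, words, i) == words[i]  # guess with the summary
            base_ok += unmask(filler, words, i) == words[i]   # guess with a filler
    # positive when the summary helps the model reconstruct more masked words
    return (help_ok - base_ok) / total if total else 0.0
```

The actual `BlancHelp` and `BlancTune` classes, shown in the Python Usage section below, do this with a real BERT model.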
31 | BLANC-help with gap=2 on average correlates the best with human scores [Sensitivity of BLANC to human-scored qualities of text summaries](https://arxiv.org/abs/2010.06716), it is now set as default. The original paper used gap=6. Optimal parameters for BLANC-help and for BLANC-tune are found by using 'max-help' criterion, without relying on human summaries or human scores, in [Is Human Scoring the Best Criteria for Summary Evaluation?](https://aclanthology.org/2021.findings-acl.192) (the paper points to the possible bias of human experts). 32 | 33 | 34 | ## Python Usage 35 | Basic usage: 36 | ```python 37 | >>> from blanc import BlancHelp, BlancTune 38 | >>> document = "Jack drove his minivan to the bazaar to purchase milk and honey for his large family." 39 | >>> summary = "Jack bought milk and honey." 40 | >>> blanc_help = BlancHelp() 41 | >>> blanc_tune = BlancTune(finetune_mask_evenly=False, show_progress_bar=False) 42 | >>> blanc_help.eval_once(document, summary) 43 | 0.2222222222222222 44 | >>> blanc_tune.eval_once(document, summary) 45 | 0.3333333333333333 46 | ``` 47 | 48 | By default, BLANC is run on the CPU. Using CUDA with batching is much faster: 49 | ```python 50 | blanc_help = BlancHelp(device='cuda', inference_batch_size=128) 51 | blanc_tune = BlancTune(device='cuda', inference_batch_size=24, finetune_mask_evenly=False, finetune_batch_size=24) 52 | ``` 53 | With these batch sizes, BLANC-help takes around 1.4 sec per summary and BLANC-tune takes around 1.8 sec per summary on an NVIDIA V100. In addition to the parameters controlling device and batch sizes, BlancHelp and BlancTune take several other parameters controlling how the BLANC scores are calculated, and the default values for those parameters reproduce the results of the paper. BlancTune results may vary if random_seed is not set. 54 | 55 | If you want to compute the BLANC scores of many documents and summaries at once, you can use `eval_pairs()` or `eval_summaries_for_docs()`. `eval_pairs()` is useful when you have many documents, each with a single summary: 56 | ```python 57 | >>> documents = ["Jack drove his minivan to the bazaar to purchase milk and honey for his large family.", "As Jill started taking a walk in the park, she certainly noticed that the trees were extra green this year."] 58 | >>> summaries = ["Jack bought milk and honey.", "Jill saw green trees in the park."] 59 | >>> blanc_help.eval_pairs(documents, summaries) 60 | [0.2222222222222222, 0.0] 61 | ``` 62 | 63 | `eval_summaries_for_docs()` is useful when you have many documents, each with many summaries: 64 | ```python 65 | >>> doc_summaries = [["Jack bought milk and honey.", "Jack drove to the bazaar in a minivan"], ["Jill saw green trees in the park.", "The trees were green."]] 66 | >>> blanc_tune.eval_summaries_for_docs(documents, doc_summaries) 67 | [[0.2222222222222222, 0.2222222222222222], [-0.07142857142857142, -0.14285714285714285]] 68 | ``` 69 | 70 | ## CLI Usage 71 | A CLI for computing BLANC scores is provided for convenience. 72 | ``` 73 | $ blanc help --gap 6 --doc "Jack drove his minivan to the bazaar to purchase milk and honey for his large family." --summary "Jack bought milk and honey." 
74 | 0.1111111111111111 75 | ``` 76 | 77 | Input data can also be provided in JSON format, with sample JSON input provided in `data/` 78 | ``` 79 | $ blanc help --single_json data/single.json --gap 6 80 | 0.1111111111111111 81 | $ blanc tune --pairs_json data/pairs.json --gap 6 --finetune_mask_evenly False 82 | [0.2222222222222222, 0.14285714285714285] 83 | $ blanc tune --doc_summaries_json data/doc-summaries.json --gap 6 --finetune_mask_evenly False 84 | [[0.2222222222222222, 0.2222222222222222], [0.14285714285714285, 0.07142857142857142]] 85 | ``` 86 | 87 | The `single_json` input format expects a single JSON blob with keys `document` and `summary`. The `pairs_json` input format expects a list of JSON blobs, each with a `document` and a `summary`. The `doc_summaries_json` input format expects a list of JSON blobs, each with keys `document` and `summaries`, where `summaries` is a list of strings. These keys are customizable with the `doc_key`, `summary_key`, and `summaries_key` arguments. By default, the output is printed to STDOUT, but it can be written to a JSON file provided with the `output_json` argument. 88 | 89 | Full documentation is available with `blanc --help`: 90 | ``` 91 | required arguments: 92 | {help,tune} BLANC-help or BLANC-tune 93 | 94 | input arguments: 95 | --doc DOC single input document (default: None) 96 | --summary SUMMARY single input summary (default: None) 97 | --single_json FILENAME 98 | filename for single document summary pair (default: 99 | None) 100 | --pairs_json FILENAME 101 | filename for list of document summary pairs (default: 102 | None) 103 | --doc_summaries_json FILENAME 104 | filename for list of documents, each with a list of 105 | summaries (default: None) 106 | --doc_key KEY json key for the input document (default: doc) 107 | --summary_key KEY json key for the input summary (single_json or 108 | pairs_json input) (default: summary) 109 | --summaries_key KEY json key for the input summaries (doc_summaries_json 110 | input) (default: summaries) 111 | 112 | arguments for BLANC-help and BLANC-tune: 113 | --model_name NAME BERT model type (default: bert-base-uncased) 114 | --measure {improve,relative} 115 | measure improve or relative, as defined in the paper 116 | (default: relative) 117 | --gap GAP distance between words to mask during inference 118 | (default: 2) 119 | --gap_mask NUM number of tokens to mask during inference at each 120 | gap-defined position 121 | (default: 1) 122 | --min_token_length_normal LEN 123 | minimum number of chars in normal tokens to mask, 124 | where a normal token is a whole word (default: 4) 125 | --min_token_length_lead LEN 126 | minimum number of chars in lead token to mask, where a 127 | lead token begins a word (default: 2) 128 | --min_token_length_followup LEN 129 | minimum number of chars in followup token to mask, 130 | where a followup token continues a word (default: 100) 131 | --device DEVICE cpu or cuda device (default: cpu) 132 | --random_seed SEED random seed for python and torch (default: 1) 133 | --inference_batch_size SIZE 134 | batch size to use during inference (default: 1) 135 | --inference_mask_evenly MASK_EVENLY 136 | when True, mask every `gap` tokens that are longer 137 | than `min_token_length` during finetuning, when False 138 | randomly mask tokens with probability 0.15 (default: 139 | True) 140 | 141 | BLANC-help arguments: 142 | --filler_token TOKEN token to use as filler in lieu of summary (default: .) 
143 | --help_sep SEP token to use to separate the summary or filler from 144 | the sentence, or '' for no separator (default: ) 145 | 146 | BLANC-tune arguments: 147 | --finetune_batch_size SIZE 148 | batch size to use when finetuning on summary (default: 149 | 1) 150 | --finetune_epochs EPOCHS 151 | number of epochs to train for when finetuning on 152 | summary (default: 10) 153 | --finetune_mask_evenly MASK_EVENLY 154 | when True, mask every `gap` tokens that are longer 155 | than `min_token_length`during finetuning, when False 156 | randomly mask tokens with probability 0.15 (default: 157 | False) 158 | --finetune_chunk_size SIZE 159 | number of summary tokens to use at a time when 160 | finetuning (default: 64) 161 | --finetune_chunk_stride STRIDE 162 | number of tokens between summary chunks for finetuning 163 | (default: 32) 164 | --learning_rate LR learning rate when finetuning on summary (default: 165 | 5e-05) 166 | --warmup_steps STEPS warmup steps when finetuning on summary (default: 0) 167 | ``` 168 | 169 | ## BLANC on [SummEval](https://github.com/Yale-LILY/SummEval) dataset 170 | BLANC can run on top of any pretrained BERT or AlBERT model (more will be added). The table below lists correlations of BLANC with human scores on the human-annotated [SummEval](https://github.com/Yale-LILY/SummEval) dataset (described in [SummEval: Re-evaluating Summarization Evaluation](https://arxiv.org/abs/2007.12626v4)). The dataset contains 1600 text-summary pairs by 100 texts x 16 systems. We show correlation (Spearman and Kendall's Tau-c) between BLANC-help and experts-average scores for each quality of the summary (coherence, consistency, fluency, relevance): 171 | 172 | |quality|model|Spearman|Kendall| 173 | |:---------------|:-----------|-----:|-----:| 174 | |coherence|bbu|0.122|0.09| 175 | |coherence|bbc|0.197|0.142| 176 | |coherence|blu|0.116|0.085| 177 | |coherence|blc|0.226|0.165| 178 | |coherence|bluw|0.083|0.06| 179 | |coherence|blcw|0.196|0.142| 180 | |coherence|ab|0.168|0.125| 181 | |coherence|al|0.152|0.111| 182 | |coherence|axl|0.15|0.11| 183 | |coherence|axxl|0.127|0.093| 184 | |consistency|bbu|0.19|0.094| 185 | |consistency|bbc|0.19|0.094| 186 | |consistency|blu|0.207|0.102| 187 | |consistency|blc|0.204|0.1| 188 | |consistency|bluw|0.167|0.082| 189 | |consistency|blcw|0.18|0.089| 190 | |consistency|ab|0.192|0.095| 191 | |consistency|al|0.199|0.098| 192 | |consistency|axl|0.179|0.088| 193 | |consistency|axxl|0.2|0.098| 194 | |fluency|bbu|0.089|0.051| 195 | |fluency|bbc|0.108|0.062| 196 | |fluency|blu|0.112|0.065| 197 | |fluency|blc|0.113|0.064| 198 | |fluency|bluw|0.107|0.061| 199 | |fluency|blcw|0.121|0.069| 200 | |fluency|ab|0.124|0.072| 201 | |fluency|al|0.132|0.076| 202 | |fluency|axl|0.119|0.069| 203 | |fluency|axxl|0.115|0.066| 204 | |relevance|bbu|0.216|0.156| 205 | |relevance|bbc|0.278|0.201| 206 | |relevance|blu|0.217|0.156| 207 | |relevance|blc|0.306|0.223| 208 | |relevance|bluw|0.194|0.14| 209 | |relevance|blcw|0.258|0.188| 210 | |relevance|ab|0.27|0.193| 211 | |relevance|al|0.267|0.192| 212 | |relevance|axl|0.245|0.176| 213 | |relevance|axxl|0.246|0.179| 214 | 215 | The [transformers](https://huggingface.co/transformers/pretrained_models.html) models are: bert-base-uncased (bbu), bert-base-cased (bbc), bert-large-uncased (blu), bert-large-cased (blc), bert-large-uncased-whole-word-masking (bluw), bert-large-cased-whole-word-masking (blcw), albert-base-v2 (ab), albert-large-v2 (al), albert-xlarge-v2 (axl), albert-xxlarge-v2 (axxl). 
The BLANC-help was used with the current default settings (gap=2, min_token_length_normal=4, min_token_length_lead=2, min_token_length_followup=100). All the p-values above are of order 10^-5 or lower. 216 | 217 | The system-level correlations (correlations between 16-dimensional scores after averaging each system scores over 100 texts) have too high p-values. The table below shows only the correlations with p-values <0.05: 218 | 219 | |quality|model|Spearman|p|Kendall|p| 220 | |:---------------|:-----------|-----:|-----:|-----:|-----:| 221 | |consistency|bbu|0.738|0.001|0.567|0.002| 222 | |consistency|bbc|0.759|0.001|0.533|0.003| 223 | |consistency|blu|0.724|0.002|0.567|0.002| 224 | |consistency|blc|0.788|0.0|0.567|0.002| 225 | |consistency|bluw|0.771|0.0|0.617|0.001| 226 | |consistency|blcw|0.791|0.0|0.6|0.001| 227 | |consistency|ab|0.724|0.002|0.583|0.001| 228 | |consistency|al|0.774|0.0|0.6|0.001| 229 | |consistency|axl|0.706|0.002|0.517|0.005| 230 | |consistency|axxl|0.812|0.0|0.617|0.001| 231 | |fluency|bbc|0.558|0.025|0.444|0.017| 232 | |fluency|blc|0.549|0.028|0.444|0.017| 233 | |fluency|bluw|0.525|0.037|0.377|0.043| 234 | |fluency|blcw|0.595|0.015|0.477|0.01| 235 | |fluency|al|0.518|0.04|0.393|0.034| 236 | |fluency|axxl|0.534|0.033|0.41|0.027| 237 | |relevance|bbc| | |0.467|0.011| 238 | |relevance|blc| | |0.467|0.011| 239 | |relevance|blcw|0.515|0.041|0.467|0.011| 240 | -------------------------------------------------------------------------------- /blanc/utils.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import random 3 | import unicodedata 4 | import copy 5 | 6 | import torch 7 | from torch.nn.utils.rnn import pad_sequence 8 | 9 | # prefix used by the wordpiece tokenizer to indicate that the token continues the previous word 10 | WORDPIECE_PREFIX = '##' 11 | # a range of reasonable token ids to use for replacement during model training 12 | TOKEN_REPLACE_RANGE = (1000, 29999) 13 | # attention mask value that tells the model to not ignore the token 14 | NOT_MASKED = 1 15 | # token type for a single input sequence 16 | TOKEN_TYPE_A = 0 17 | # padding value used for attention_mask 18 | MASK_PAD = 0 19 | # padding value used for token_type_ids 20 | TOKEN_TYPE_PAD = 1 21 | # label to use for tokens to ignore when computing masked language modeling training loss 22 | LABEL_IGNORE = -100 23 | # maximum number of input tokens BERT supports 24 | BERT_MAX_TOKENS = 512 25 | 26 | # used to represent inputs to the BERT model 27 | BertInput = namedtuple( 28 | typename='BertInput', 29 | field_names=['input_ids', 'attention_mask', 'token_type_ids', 'labels', 'masked_idxs'], 30 | ) 31 | 32 | # all the configuration options 33 | Config = namedtuple( 34 | 'Config', 35 | [ 36 | 'doc_key', 37 | 'summary_key', 38 | 'summaries_key', 39 | 'model_name', 40 | 'measure', 41 | 'gap', 42 | 'gap_mask', 43 | 'gap_tune', 44 | 'gap_mask_tune', 45 | 'min_token_length_normal', 46 | 'min_token_length_lead', 47 | 'min_token_length_followup', 48 | 'min_token_length_normal_tune', 49 | 'min_token_length_lead_tune', 50 | 'min_token_length_followup_tune', 51 | 'device', 52 | 'random_seed', 53 | 'inference_batch_size', 54 | 'inference_mask_evenly', 55 | 'len_sent_allow_cut', 56 | 'filler_token', 57 | 'help_sep', 58 | 'finetune_batch_size', 59 | 'finetune_epochs', 60 | 'finetune_mask_evenly', 61 | 'finetune_chunk_size', 62 | 'finetune_chunk_stride', 63 | 'finetune_top_fully', 64 | 'id_layer_freeze_below', 65 | 'id_layer_freeze_above', 66 | 
'show_progress_bar', 67 | 'p_mask', 68 | 'p_token_replace', 69 | 'p_token_original', 70 | 'learning_rate', 71 | 'warmup_steps', 72 | ], 73 | ) 74 | 75 | # the default configuration options that don't require a GPU 76 | # We found gap=2 to work the best. To reproduce the original paper results use gap=6 77 | Defaults = Config( 78 | doc_key='doc', 79 | summary_key='summary', 80 | summaries_key='summaries', 81 | model_name='bert-base-uncased', 82 | measure='relative', 83 | gap=2, 84 | gap_mask=1, 85 | gap_tune=-1, 86 | gap_mask_tune=-1, 87 | min_token_length_normal=4, 88 | min_token_length_lead=2, 89 | min_token_length_followup=100, 90 | min_token_length_normal_tune=-1, 91 | min_token_length_lead_tune=-1, 92 | min_token_length_followup_tune=-1, 93 | device='cpu', 94 | random_seed=0, 95 | inference_batch_size=1, 96 | inference_mask_evenly=True, 97 | len_sent_allow_cut=100, 98 | filler_token='.', 99 | help_sep='', 100 | finetune_batch_size=1, 101 | finetune_epochs=10, 102 | finetune_chunk_size=64, 103 | finetune_chunk_stride=32, 104 | finetune_top_fully=True, 105 | id_layer_freeze_below=-1, 106 | id_layer_freeze_above=-1, 107 | show_progress_bar=True, 108 | p_mask=0.15, 109 | p_token_replace=0.1, 110 | p_token_original=0.1, 111 | learning_rate=5e-5, 112 | finetune_mask_evenly=True, 113 | warmup_steps=0, 114 | ) 115 | 116 | 117 | def set_seed(seed_value): 118 | random.seed(seed_value) 119 | torch.manual_seed(seed_value) 120 | 121 | 122 | def batch_data(data, batch_size): 123 | """Given a list, batch that list into chunks of size batch_size 124 | 125 | Args: 126 | data (List): list to be batched 127 | batch_size (int): size of each batch 128 | 129 | Returns: 130 | batches (List[List]): a list of lists, each inner list of size batch_size except possibly 131 | the last one. 132 | """ 133 | batches = [data[i : i + batch_size] for i in range(0, len(data), batch_size)] 134 | return batches 135 | 136 | 137 | def is_token_large_enough(token, next_token, min_token_lengths): 138 | """Determine if a token is large enough according to min_token_lengths 139 | 140 | Args: 141 | token (str): a wordpiece token 142 | next_token (str): the next wordpiece token in the sequence 143 | min_token_lengths (Tuple[int, int, int]): minimum token lengths for normal tokens, lead 144 | tokens, and followup tokens 145 | 146 | Returns: 147 | large_enough (bool): whether or not the token is large enough 148 | """ 149 | min_normal, min_lead, min_followup = min_token_lengths 150 | token_size = len(token) 151 | 152 | if token.startswith(WORDPIECE_PREFIX): 153 | token_size -= len(WORDPIECE_PREFIX) 154 | return token_size >= min_followup 155 | elif next_token.startswith(WORDPIECE_PREFIX): 156 | return token_size >= min_lead 157 | else: 158 | return token_size >= min_normal 159 | 160 | 161 | def mask_tokens_evenly(tokens, gap, min_token_lengths, mask_token, gap_mask=1): 162 | """Produce several maskings for the given tokens where each masking is created by masking every 163 | "gap" tokens, as long as the token is large enough according to min_token_lengths. 164 | 165 | Args: 166 | tokens (List[str]): a sequence of wordpiece tokens 167 | gap (int): the spacing in-between masked tokens 168 | min_token_lengths (Tuple[int, int, int]): minimum token lengths for normal tokens, lead 169 | tokens, and followup tokens 170 | mask_token (str): wordpiece token to use for masking 171 | 172 | Returns: 173 | masked_inputs (List[List[str]]): a list of token sequences, where each token sequence 174 | contains masked tokens separated by "gap" tokens. 
175 | all_answers (List[Dict[int, str]]): a list of "answer" dicts, where each answer dict maps 176 | token indices corresponding to masked tokens back to their original token. 177 | """ 178 | gap = min(gap, len(tokens)) 179 | masked_inputs = [] 180 | all_answers = [] 181 | for modulus in range(gap): 182 | masked_input = [] 183 | answers = {} 184 | for idx, token in enumerate(tokens): 185 | next_token = '' if idx + 1 == len(tokens) else tokens[idx + 1] 186 | large_enough = is_token_large_enough(token, next_token, min_token_lengths) 187 | 188 | idx_off = idx % gap 189 | if gap == 1: 190 | can_mask = True 191 | elif modulus + gap_mask >= gap: 192 | can_mask = idx_off >= modulus or idx_off < (modulus + gap_mask)%gap 193 | else: 194 | can_mask = idx_off >= modulus and idx_off < modulus + gap_mask 195 | if can_mask and large_enough: 196 | masked_input.append(mask_token) 197 | answers[idx] = token 198 | else: 199 | masked_input.append(token) 200 | 201 | if len(answers) > 0: 202 | masked_inputs.append(masked_input) 203 | all_answers.append(answers) 204 | 205 | return masked_inputs, all_answers 206 | 207 | 208 | def mask_tokens_randomly(tokens, min_token_lengths, mask_token, p_mask): 209 | """Produce several maskings for the given tokens by randomly choosing tokens to mask 210 | 211 | Args: 212 | tokens (List[str]): a sequence of wordpiece tokens 213 | min_token_lengths (Tuple[int, int, int]): minimum token lengths for normal tokens, lead 214 | tokens, and followup tokens 215 | mask_token (str): wordpiece token to use for masking 216 | 217 | Returns: 218 | masked_inputs (List[List[str]]): a list of token sequences, where each token sequence 219 | contains masked tokens chosen randomly. 220 | all_answers (List[Dict[int, str]]): a list of "answer" dicts, where each answer dict maps 221 | token indices corresponding to masked tokens back to their original token. 222 | """ 223 | n_mask = max(int(len(tokens) * p_mask), 1) 224 | 225 | token_positions = [] 226 | for idx, token in enumerate(tokens): 227 | next_token = '' if idx + 1 == len(tokens) else tokens[idx + 1] 228 | if is_token_large_enough(token, next_token, min_token_lengths): 229 | token_positions.append(idx) 230 | random.shuffle(token_positions) 231 | 232 | all_inputs, all_answers = [], [] 233 | while len(token_positions) > 0: 234 | positions_to_mask = token_positions[:n_mask] 235 | token_positions = token_positions[n_mask:] 236 | 237 | inputs, answers = [], {} 238 | for idx, token in enumerate(tokens): 239 | if idx in positions_to_mask: 240 | inputs.append(mask_token) 241 | answers[idx] = token 242 | else: 243 | inputs.append(token) 244 | 245 | all_inputs.append(inputs) 246 | all_answers.append(answers) 247 | 248 | return all_inputs, all_answers 249 | 250 | 251 | def stack_tensor(input_list, pad_value, device): 252 | """Given a batch of inputs, stack them into a single tensor on the given device, padding them 253 | at the back with pad_value to make sure they are all the same length. 
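    For example (illustrative values), stack_tensor([[1, 2, 3], [4, 5]], pad_value=0, device='cpu')
    returns tensor([[1, 2, 3], [4, 5, 0]]): the shorter sequence is padded at the end.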
254 | 255 | Args: 256 | input_list (List[List[int]]): a list of input sequences 257 | pad_value (int): the value to use for padding input sequences to make them the same length 258 | device (str): torch device (usually "cpu" or "cuda") 259 | 260 | Returns: 261 | stacked_tensor (torch.LongTensor): a tensor of dimensions (batch size) x (seq length) 262 | """ 263 | tensor_list = [torch.LongTensor(inputs) for inputs in input_list] 264 | stacked_tensor = pad_sequence( 265 | sequences=tensor_list, batch_first=True, padding_value=pad_value 266 | ).to(device) 267 | 268 | return stacked_tensor 269 | 270 | 271 | def get_input_tensors(input_batch, device, tokenizer): 272 | """Given a list of BertInputs, return the relevant tensors that are fed into BERT. 273 | 274 | Args: 275 | input_batch (List[BertInput]): a batch of model inputs 276 | device (str): torch device (usually "cpu" or "cuda") 277 | tokenizer (BertTokenizer): the wordpiece tokenizer used for BERT 278 | 279 | Returns: 280 | input_ids (torch.LongTensor): ids corresponding to input tokens 281 | attention_mask (torch.LongTensor): tells BERT about parts of the input to ignore 282 | token_type_ids (torch.LongTensor): used to differentiate input segments 283 | labels (torch.LongTensor): contains the original token ids for tokens that were masked 284 | """ 285 | input_ids_list = [inputs.input_ids for inputs in input_batch] 286 | attention_mask_list = [inputs.attention_mask for inputs in input_batch] 287 | token_type_ids_list = [inputs.token_type_ids for inputs in input_batch] 288 | labels_list = [inputs.labels for inputs in input_batch] 289 | 290 | (id_pad,) = tokenizer.convert_tokens_to_ids([tokenizer.pad_token]) 291 | input_ids = stack_tensor(input_ids_list, pad_value=id_pad, device=device) 292 | attention_mask = stack_tensor(attention_mask_list, pad_value=MASK_PAD, device=device) 293 | token_type_ids = stack_tensor(token_type_ids_list, pad_value=TOKEN_TYPE_PAD, device=device) 294 | 295 | if labels_list[0] is not None: 296 | labels = stack_tensor(labels_list, pad_value=LABEL_IGNORE, device=device) 297 | else: 298 | labels = None 299 | 300 | return input_ids, attention_mask, token_type_ids, labels 301 | 302 | 303 | def determine_correctness(outputs, answers): 304 | """Given dicts corresponding to predicted tokens and actual tokens at different indices, return 305 | a list of bools for whether or not those predictions were correct. 
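    For example (illustrative values), outputs=[{5: 'milk', 9: 'the'}] with answers=[{5: 'milk', 9: 'a'}]
    gives [True, False].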
306 | 307 | Args: 308 | outputs (List[Dict[int, str]]): each list represents a different input masking, and each 309 | dict maps indices to model predictions 310 | answers (List[Dict[int, str]]): each list represents a different input masking, and each 311 | dict maps indices to original tokens 312 | 313 | Returns: 314 | correctness (List[bool]): a list of values that are True if the model made a correct 315 | prediction and False otherwise 316 | """ 317 | correctness = [] 318 | for output, answer in zip(outputs, answers): 319 | for idx, actual_token in answer.items(): 320 | predicted_token = output[idx] 321 | correctness.append(predicted_token == actual_token) 322 | 323 | return correctness 324 | 325 | 326 | def measure_relative(S): 327 | """Calculate the "measure-relative" score as defined in the paper 328 | 329 | Args: 330 | S (List[List[int]]): accuracy counts as defined in the paper 331 | 332 | Returns: 333 | score (float): measure-relative score 334 | """ 335 | denom = S[0][0] + S[1][1] + S[0][1] + S[1][0] 336 | if denom == 0: 337 | return 0 338 | return (S[0][1] - S[1][0]) / denom 339 | 340 | 341 | def measure_improve(S): 342 | """Calculate the "measure-improve" score as defined in the paper 343 | 344 | Args: 345 | S (List[List[int]]): accuracy counts as defined in the paper 346 | 347 | Returns: 348 | score (float): measure-improve score 349 | """ 350 | denom = S[0][0] + S[1][1] + S[0][1] 351 | if denom == 0: 352 | return 0 353 | return S[0][1] / denom 354 | 355 | 356 | def clean_text(text): 357 | """Return a cleaned version of the input text 358 | 359 | Args: 360 | text (str): dirty text 361 | 362 | Returns: 363 | text (str): cleaned text 364 | """ 365 | text = unicodedata.normalize('NFKD', text) 366 | return text 367 | 368 | 369 | def truncate_sentence_and_summary( 370 | sent, summary, len_sep=0, len_sent_allow_cut=0, truncate_bottom=True, 371 | ): 372 | """Cut summary+sentence to allowed input size. 2 more tokens: [CLS], [SEP] 373 | The summary must have at least one sublist (can be empty) 374 | The sentence is cut by tokens from the bottom. 375 | The summary is cut by sentences. Last sentence is cut by tokens. 
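    Illustrative example: with a 200-token sentence, a 350-token summary, no separator, and
    len_sent_allow_cut=100, the estimated input is 2 + 350 + 0 + 200 = 552 tokens; the 40 tokens
    over BERT_MAX_TOKENS=512 are cut from the end of the sentence, and the summary is left intact.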
376 | 377 | Args: 378 | sent (List[str]): Sentence as a list of tokens 379 | summary (List[List[str]]): Summary as list of sentences, each sentence is list of tokens 380 | len_sep (int): Number of tokens in a separator used between the summary and the sentence 381 | len_sent_allow_cut (int): Allowed size of truncated sentence before cutting summary 382 | truncate_bottom (bool): Indicator how to cut the summary 383 | 384 | Returns: 385 | sent (List[str]): Truncated (if necessary) sentence as a list of tokens 386 | summary_tokens (List[str]): Truncated (if necessary) summary as a list of tokens 387 | """ 388 | summary_tokens = [t for sublist in summary for t in sublist] 389 | len_input_estimate = 2 + len(summary_tokens) + len_sep + len(sent) 390 | len_excess = len_input_estimate - BERT_MAX_TOKENS 391 | if len_excess > 0: 392 | len_cut_sent = min(len_excess, len(sent) - len_sent_allow_cut) 393 | len_sent_new = len(sent) - len_cut_sent 394 | sent = sent[:len_sent_new] 395 | assert len_excess <= len_cut_sent or summary[0] 396 | if len_excess > len_cut_sent: 397 | len_summary_max = BERT_MAX_TOKENS - 2 - len_sep - len(sent) 398 | summary_truncated = truncate_list_of_lists( 399 | sents_tokenized=summary, num_max=len_summary_max, truncate_bottom=truncate_bottom, 400 | ) 401 | summary_tokens = [t for sublist in summary_truncated for t in sublist] 402 | assert len(sent) + len(summary_tokens) + len_sep + 2 <= BERT_MAX_TOKENS 403 | return sent, summary_tokens 404 | 405 | 406 | def truncate_list_of_lists(sents_tokenized, num_max, truncate_bottom=True): 407 | """Return a truncated list, with summ of tokens not exceeding maximum. 408 | Truncate by lists. If single left list is still too long, truncate it by tokens. 409 | In our context each element of sents_tokenized is a sentence represented as a list of tokens. 410 | 411 | Args: 412 | sents_tokenized (List[List[str]]): List, each element is a list. 413 | num_max (int): maximal allowed number of tokens. 414 | truncate_bottom (bool): truncate starting from bottom lists. 
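    Illustrative example: truncate_list_of_lists([['a', 'b', 'c'], ['d', 'e']], num_max=4,
    truncate_bottom=True) keeps only [['a', 'b', 'c']], since adding the second sentence
    would exceed num_max.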
415 | 416 | Returns: 417 | sents_tokenized (List[str]): truncated list 418 | """ 419 | sents_truncated = [] 420 | if truncate_bottom: 421 | len_truncated = 0 422 | # Cut by sentences: 423 | for sent in sents_tokenized: 424 | len_truncated_maybe = len_truncated + len(sent) 425 | if len_truncated_maybe > num_max: 426 | break 427 | len_truncated = len_truncated_maybe 428 | sents_truncated.append(sent) 429 | if len_truncated == num_max: 430 | break 431 | else: 432 | sents_truncated = copy.deepcopy(sents_tokenized) 433 | len_truncated = sum([len(s) for s in sents_tokenized]) 434 | # Cut by sentences: 435 | for sent in sents_tokenized: 436 | if len_truncated <= num_max: 437 | break 438 | sents_truncated = sents_truncated[1:] 439 | len_truncated = len_truncated - len(sent) 440 | if not sents_truncated: 441 | sent_use = sents_tokenized[0] if truncate_bottom else sents_tokenized[-1] 442 | sents_truncated = [copy.deepcopy(sent_use)] 443 | len_truncated = len(sents_truncated[0]) 444 | # Cut by tokens - always from the top: 445 | if len_truncated > num_max: 446 | len_remove = len_truncated - num_max 447 | sents_truncated[0] = sents_truncated[0][len_remove:] 448 | assert sum([len(s) for s in sents_truncated]) <= num_max 449 | return sents_truncated 450 | -------------------------------------------------------------------------------- /blanc/estime.py: -------------------------------------------------------------------------------- 1 | import unicodedata 2 | import copy 3 | import scipy 4 | from numpy import dot 5 | from numpy.linalg import norm 6 | 7 | import logging 8 | logging.getLogger('transformers').setLevel(level=logging.WARNING) 9 | 10 | import nltk 11 | from nltk.tokenize import word_tokenize 12 | 13 | import torch 14 | from transformers import BertForMaskedLM, BertTokenizer, BertModel 15 | 16 | 17 | class Estime: 18 | """Estimator of factual inconsistencies between summaries (or other textual claims) 19 | and text. Usage: create `Estime`, and use `evaluate_claims()`. 20 | In creating Estime, specify the names of the desired measures in the list 'output'. 21 | The function evaluate_claims() will return (for each claim) the list of results in 22 | the same order. The list 'output' can also include: 23 | 'alarms': the original ESTIME 24 | 'alarms_adjusted': the original ESTIME, extrapolated to non-overlapping tokens 25 | 'alarms_alltokens': ESTIME on all (not only overlapped) summary tokens 26 | 'soft': the soft ESTIME 27 | 'coherence': measure of summary coherence 28 | """ 29 | def __init__( 30 | self, 31 | path_mdl='bert-large-uncased-whole-word-masking', 32 | path_mdl_raw='bert-base-uncased', 33 | i_layer_context=21, 34 | device='cpu', 35 | output=['alarms'], 36 | tags_check=None, 37 | tags_exclude=None, 38 | input_size_max=450, 39 | margin=50, 40 | distance_word_min=8): 41 | """ 42 | Args: 43 | path_mdl (str): model for embeddings of masked tokens 44 | path_mdl_raw (str): model for raw embeddings 45 | i_layer_context (int): index of layer to take contextual embeddings from 46 | device (str): 'cpu' or 'cuda' 47 | tags_check (List[str]): list of parts of speech to use, each is a tag 48 | by NLTK notations. If None, then all words will be used, 49 | no matter which parts of speech they are. 50 | tags_exclude (List[str]): List or set of part-of-speach tags that should 51 | not be considered. Has priority over tags_check. 
52 | input_size_max (int): length of input for the model as number of tokens 53 | margin (int): number of tokens on input edges not to take embeddings from 54 | distance_word_min (int): minimal distance between the masked tokens 55 | from which embeddings are to ne taken in a single input 56 | output (List[str]): list of names of measures to get in claim evaluations. 57 | The list must be nonempty and can include: 58 | 'alarms', 'alarms_adjusted', 'alarms_alltokens', 'soft', 'coherence'. 59 | The function evaluate_claims() will return (for each claim) the list of 60 | results in the same order. 61 | """ 62 | self.i_layer_context = i_layer_context 63 | self.device = device 64 | assert output 65 | self.output = output 66 | self.ESTIME_ALARMS = 'alarms' # original estime: by response of tokens to similar contexts 67 | self.ESTIME_ALLTOKENS = 'alarms_alltokens' # estime on all (not only overlapped) summary tokens 68 | self.ESTIME_ADJUSTED = 'alarms_adjusted' # original estime, extrapolated to non-overlapping tokens 69 | self.ESTIME_SOFT = 'soft' # soft estime by response of similarity between embeddings 70 | self.ESTIME_COHERENCE = 'coherence' # estimation of summary coherence 71 | self.get_estime_adjusted = self.ESTIME_ADJUSTED in self.output 72 | self.get_estime_soft = self.ESTIME_SOFT in self.output 73 | self.get_estime_coherence = self.ESTIME_COHERENCE in self.output 74 | self.tags_check = tags_check 75 | self.tags_exclude = tags_exclude 76 | self.input_size_max = input_size_max 77 | self.margin = margin 78 | self.distance_word_min = distance_word_min 79 | self.model_tokenizer = None 80 | self.model = None 81 | self.model_tokenizer = BertTokenizer.from_pretrained(path_mdl) 82 | self.model = BertForMaskedLM.from_pretrained(path_mdl, output_hidden_states = True).to(self.device) 83 | self.model.eval() 84 | if self.get_estime_soft: 85 | self.model_raw = BertModel.from_pretrained(path_mdl_raw) 86 | self.model_raw.eval() 87 | for param in self.model_raw.parameters(): 88 | param.requires_grad = False 89 | self.embeddings_raw = self.model_raw.get_input_embeddings() 90 | # Convenient data (including tokenized summary and text): 91 | self.text_map_wordscheck = None 92 | self.summ_toks = None 93 | self.text_toks = None 94 | self.embs_mask_text = None 95 | self.embs_raw_text = None 96 | 97 | 98 | def evaluate_claims(self, text, claims): 99 | """ 100 | Given a text, and a list of claims (e.g. summaries, or other texts), 101 | estimates how many likely factual inconsistencies each claim contains 102 | (the inconsistencies are with respect to the text). 103 | Returns for each claim whatever is specified in self.output 104 | Args: 105 | text (str): the main text 106 | claims (List[str]): texts ('claims') to be checked for consistency 107 | with the main text. Each claim is preferably a shorter text, 108 | like a claim/statement/summary. 109 | Returns: 110 | claims_info: list of the same length as claims; each element is a 111 | consistency info for the corresponding claim. The info is a list 112 | accordingly to the names in self.output. 
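        A minimal usage sketch (the text and summary strings are arbitrary placeholders):

            estimator = Estime(output=['alarms', 'soft', 'coherence'])
            [info] = estimator.evaluate_claims(text, [summary])
            # info is a list such as [2, 0.91, 0.7]: alarm count, soft score, coherence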
113 | """ 114 | # Text: 115 | text = unicodedata.normalize('NFKD', text) 116 | text_words = word_tokenize(text) 117 | text_tagged = nltk.pos_tag(text_words) 118 | # Find all words of interest in the text, tokenize: 119 | self.text_map_wordscheck = self._get_map_words_intext(text_tagged) 120 | text_iwordscheck = sorted([item for sublist in self.text_map_wordscheck.values() for item in sublist]) 121 | self.text_toks, text_map_iword_itoks = self._translate_words_to_tokens(text_words) 122 | # All embeddings of interest in the text: 123 | self.embs_mask_text = self._get_embeddings( 124 | tokens=self.text_toks, 125 | ixs_words=text_iwordscheck, 126 | map_words_to_tokens=text_map_iword_itoks) 127 | self.embs_raw_text = self._get_embeddings_raw(tokens=self.text_toks) 128 | # Get the consistency info for each claim: 129 | claims_info = [] 130 | for claim in claims: 131 | claim = unicodedata.normalize('NFKD', claim) 132 | claim_info = self._evaluate_claim(claim) 133 | claims_info.append(claim_info) 134 | self.summ_toks = None 135 | self.text_toks = None 136 | return claims_info 137 | 138 | 139 | def _evaluate_claim(self, claim, words_check=None): 140 | """ 141 | Text is already processed, its embeddings can be used for the claim. 142 | Args: 143 | claim (str): claim, e.g. summary or short text - not the main text 144 | words_check (Set{str}): a set or map where the keys are the words 145 | of interest in the main text 146 | Returns: 147 | estime_info (List[float]): a list with results corresponding to the 148 | names of measures specified in self.output. 149 | """ 150 | summ_words = word_tokenize(claim) 151 | summ_tagged = nltk.pos_tag(summ_words) 152 | summ_iwordscheck, summ_iwords_overlap = [],[] # Find all words of interest in the summary 153 | for i, (w, t) in enumerate(summ_tagged): 154 | if not words_check or w in words_check: # if required, checking only what exists in the text 155 | summ_iwordscheck.append(i) 156 | if not self.text_map_wordscheck or w in self.text_map_wordscheck: 157 | summ_iwords_overlap.append(i) 158 | self.summ_toks, summ_map_iword_itoks = self._translate_words_to_tokens(summ_words) 159 | embs_mask_summ = self._get_embeddings( 160 | tokens=self.summ_toks, ixs_words=summ_iwordscheck, map_words_to_tokens=summ_map_iword_itoks) 161 | embs_raw_summ = self._get_embeddings_raw(tokens=self.summ_toks) 162 | summ_itoksoverlap = set() 163 | for iword in summ_iwords_overlap: 164 | itok = summ_map_iword_itoks[iword][0] # only first token of each word 165 | summ_itoksoverlap.add(itok) 166 | estime_info = self._evaluate( 167 | embs_mask_summ, self.embs_mask_text, 168 | embs_raw_summ, self.embs_raw_text, summ_itokscheck=summ_itoksoverlap) 169 | return estime_info 170 | 171 | 172 | def _get_embeddings_raw(self, tokens): 173 | """Simply gets raw embeddings. Needed only for estime-soft.""" 174 | if not self.get_estime_soft: 175 | return None 176 | toks_ids = self.model_tokenizer.convert_tokens_to_ids(tokens) 177 | input_tensor = torch.LongTensor([toks_ids]) 178 | word_embeddings = self.embeddings_raw(input_tensor)[0].numpy() 179 | return word_embeddings 180 | 181 | 182 | def _get_embeddings(self, tokens, ixs_words, map_words_to_tokens): 183 | """ 184 | Finds embeddings for all tokens of all words of interest. The embeddings 185 | are obtained one group of words at a time; each group contains well 186 | separated indexes, so that masked indexes do have enough context around. 187 | Args: 188 | tokens (List[str]): List of tokens, as strings. Represents the text. 
189 | ixs_words (List[int]): List of indexes of words to check 190 | map_words_to_tokens Dict[int:(int,int)]: 191 | Maps each index of word in the text to its range of tokens, the 192 | range is: index of first token, index of one-past-last token. 193 | Returns: 194 | map_itok_embeddings (Dict{int: ndarray[float]}): 195 | Map of token index (in the text) to its obtained embedding 196 | """ 197 | # groups of well-separated words, represented by their indexes: 198 | groups = self._group_indexes_separated(ixs=ixs_words) 199 | map_itok_embeddings = {} 200 | for group in groups: 201 | map_itok_embeds = self._get_embeddings_of_sparse_words( 202 | tokens=tokens, 203 | ixs_words=group, 204 | map_words_to_tokens=map_words_to_tokens) 205 | map_itok_embeddings = {**map_itok_embeddings, **map_itok_embeds} 206 | return map_itok_embeddings 207 | 208 | 209 | def _get_embeddings_of_sparse_words(self, tokens, ixs_words, map_words_to_tokens): 210 | """Gets results for the get_embeddings function, which combines the results 211 | from tokens from groups of sparsely spread words. 212 | Here the result is obtained for one group of sparsely separated words. 213 | Args: 214 | tokens (List[str]): List of tokens, as strings. Represents the text. 215 | ixs_words (List[int]): List of indexes of words to check 216 | map_words_to_tokens Dict[int:(int,int)]: 217 | Maps each index of word in the text to its range of tokens, the 218 | range is: index of first token, index of one-past-last token. 219 | Returns: 220 | map_itok_embeds (Dict{int: ndarray[float]}): Map of token index (in text) 221 | to its obtained embedding 222 | """ 223 | map_itok_embeddings = {} 224 | toks_mask = [map_words_to_tokens[i] for i in ixs_words] # indexes (beg,end) of tokens to mask 225 | while toks_mask: 226 | i_tok_first = toks_mask[0][0] # first token allowed to mask 227 | ix_toks_input_beg = max(0, i_tok_first - self.margin) # input starts here 228 | ix_toks_input_end = min(len(tokens), ix_toks_input_beg + self.input_size_max) # input ends here 229 | i_tok_allowed_last = ix_toks_input_beg + self.input_size_max - self.margin # last token allowed to mask 230 | toks_mask_input = [] # tokens to be masked in the input 231 | for word_toks in toks_mask: 232 | if word_toks[0] >= i_tok_first and word_toks[1]-1 <= i_tok_allowed_last: 233 | toks_mask_input.append(word_toks) 234 | if word_toks[0] > i_tok_allowed_last: 235 | break 236 | # for preparing next input: 237 | n_words_used = len(toks_mask_input) 238 | toks_mask = toks_mask[n_words_used:] 239 | # get embeddings for the input: 240 | map_itok_embeds = self._get_embeddings_from_input( 241 | tokens, 242 | ix_toks_input_beg=ix_toks_input_beg, 243 | ix_toks_input_end=ix_toks_input_end, 244 | toks_mask_input=toks_mask_input) 245 | map_itok_embeddings = {**map_itok_embeddings, **map_itok_embeds} # from all inputs so far 246 | return map_itok_embeddings 247 | 248 | 249 | def _get_embeddings_from_input(self, tokens, ix_toks_input_beg, ix_toks_input_end, toks_mask_input): 250 | """ 251 | Gets embeddings for one specific input window. 252 | Returns embeddings for first tokens of each word of interest, while all 253 | tokens of the word are masked. 
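        Illustrative example: with ix_toks_input_beg=950 and a word occupying tokens 1000..1001,
        both tokens are replaced by '[MASK]' in the input, and the embedding returned for index
        1000 is taken from output position 1 + 1000 - 950 = 51 (the shift of 1 accounts for the
        leading [CLS] token).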
254 | Args: 255 | tokens (List[str]): Tokens of a summary or text 256 | ix_toks_input_beg (int): Index of first token of the input window 257 | ix_toks_input_end (int): Index of the end of the input window 258 | toks_mask_input (List[(int,int)]): Indexes of all tokens to mask 259 | in the input window, given as a list if duples, each duple 260 | is index of first and one-past last tokens to mask. 261 | Returns: 262 | map_itok_embeds (Dict{int: ndarray[float]}): Map of token indexes 263 | as in text tokens to their embeddings. Covers first tokens 264 | of all words of interest. 265 | """ 266 | # Ids of tokens in input for taking embedding 267 | input_toks = copy.deepcopy(tokens[ix_toks_input_beg:ix_toks_input_end]) 268 | map_itok_iembed = {} 269 | # Do masking, and also keep record of where to take embeddings from: 270 | for word_toks in toks_mask_input: 271 | i_tok_first = word_toks[0] # first token of the word 272 | map_itok_iembed[i_tok_first] = 1 + i_tok_first - ix_toks_input_beg # shift=1 by first [CLS] 273 | for i in range(word_toks[0], word_toks[1]): 274 | input_toks[i - ix_toks_input_beg] = '[MASK]' 275 | # Get embeddings of interest for this input: 276 | toks_ids = self.model_tokenizer.convert_tokens_to_ids(['[CLS]'] + input_toks + ['[SEP]']) 277 | input_tensor = torch.LongTensor([toks_ids]).to(self.device) 278 | outputs = self.model(input_tensor) 279 | # Contextual embedding: 280 | emb_all = outputs[1][self.i_layer_context][0] # all embeddings (for all tokens, at this layer) 281 | map_itok_embed = {} 282 | for itok, iembed in map_itok_iembed.items(): # itok is id of token exactly as in tokens 283 | map_itok_embed[itok] = emb_all[iembed].cpu().detach().numpy() 284 | return map_itok_embed 285 | 286 | 287 | def _evaluate(self, embs_summ, embs_text, embs_summ_raw, embs_text_raw, summ_itokscheck=None): 288 | """ 289 | Args: 290 | embs_summ (Dict{int: List[float]}): Map of token indexes 291 | (in summary) to the corresponding embeddings 292 | embs_text (Dict{int: List[float]}): Map of token indexes 293 | (in text) to the corresponding embeddings 294 | embs_summ_raw and embs_text_raw - the same as above, but with the raw embeddings 295 | summ_itokscheck (Set[int]): Set of indexes of tokens (in summary) 296 | that must be verified for alarms. 297 | This is needed for calculating the original and 'adjusted' ESTIME. 
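        In short: for each summary token of interest, the most similar text token is found by the
        dot product of their masked-context embeddings; an alarm is raised when that best-matching
        text token differs from the summary token ('alarms_alltokens' counts all such tokens,
        'alarms' only those in summ_itokscheck).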
298 | Returns: 299 | List[float)] - List of results in order as specified in self.output 300 | """ 301 | # estime standard, adjusted, all-tokens and soft: 302 | n_alarms, n_alarms_alltoks, cos_raw_avg = 0, 0, 0 303 | itoks_similar = [] 304 | for itok_summ, emb_summ in embs_summ.items(): 305 | tok_summ = self.summ_toks[itok_summ] 306 | itok_text_best = -1 307 | sim_best = -1.0e30 308 | for itok_text, emb_text in embs_text.items(): 309 | sim = dot(emb_summ, emb_text) 310 | if sim > sim_best: 311 | sim_best = sim 312 | itok_text_best = itok_text 313 | tok_text_best = self.text_toks[itok_text_best] 314 | itoks_similar.append((itok_summ, itok_text_best, sim_best)) 315 | if tok_text_best != tok_summ: 316 | n_alarms_alltoks += 1 317 | if not summ_itokscheck or itok_summ in summ_itokscheck: 318 | n_alarms += 1 319 | # Soft estime: 320 | if self.get_estime_soft: 321 | emb_summ_nomask = embs_summ_raw[itok_summ] 322 | emb_text_nomask = embs_text_raw[itok_text_best] 323 | prod = dot(emb_summ_nomask, emb_text_nomask) 324 | norm_summ, norm_text = norm(emb_summ_nomask), norm(emb_text_nomask) 325 | cos_raw_avg += prod / (norm_summ * norm_text) 326 | if self.get_estime_soft: 327 | cos_raw_avg /= len(embs_summ) 328 | # estime-alarms-adjusted: 329 | if self.get_estime_adjusted: 330 | if not summ_itokscheck: 331 | n_alarms_adj = len(embs_summ) 332 | else: 333 | n_alarms_adj = n_alarms * len(embs_summ) / len(summ_itokscheck) 334 | # Coherence: 335 | if self.get_estime_coherence: 336 | itoks_summ = [a[0] for a in itoks_similar] 337 | itoks_text = [a[1] for a in itoks_similar] 338 | coherence = scipy.stats.kendalltau(itoks_summ, itoks_text, variant='c').correlation 339 | result = [] 340 | for out_name in self.output: 341 | if out_name == self.ESTIME_ALARMS: 342 | result.append(n_alarms) 343 | elif out_name == self.ESTIME_ADJUSTED: 344 | result.append(n_alarms_adj) 345 | elif out_name == self.ESTIME_ALLTOKENS: 346 | result.append(n_alarms_alltoks) 347 | elif out_name == self.ESTIME_SOFT: 348 | result.append(cos_raw_avg) 349 | elif out_name == self.ESTIME_COHERENCE: 350 | result.append(coherence) 351 | return result 352 | 353 | 354 | def _select_indexes_separated(self, ixs): 355 | """Given a list of sorted integers, starts with the first and selects next ones 356 | in such way that the difference between neighbors is not smaller than the given 357 | value. Meaning: the integers are the indexes of words in a text. 358 | Args: 359 | ixs (List[int]): list of indexes 360 | Returns: 361 | ixs_select (List[int]): list of well-separated selected indexes 362 | ixs_remain (List[int]): list of all the remaining indexes 363 | """ 364 | ixs_remain = [] 365 | ixs_select = [] 366 | ix_prev = -1000000 367 | for ix in ixs: 368 | if ix - ix_prev >= self.distance_word_min: 369 | ixs_select.append(ix) 370 | ix_prev = ix 371 | else: 372 | ixs_remain.append(ix) 373 | return ixs_select, ixs_remain 374 | 375 | 376 | def _group_indexes_separated(self, ixs): 377 | """Splits a sorted list of indexes (of words in a text) to groups (lists) 378 | of indexes, such that indexes in each groups are separated by the given 379 | minimal distance. 380 | Args: 381 | ixs (List[int]): list of indexes 382 | Returns: 383 | groups (List[List[int]]): list of lists of indexes. Each list of indexes 384 | contains well-separated indexes, satisfying the distance_word_min. 
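        For example, with distance_word_min=8 the indexes [0, 3, 10, 12, 20] are split into
        groups [[0, 10, 20], [3, 12]].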
385 | """ 386 | groups = [] 387 | ixs_remain = copy.deepcopy(ixs) 388 | while ixs_remain: 389 | ixs_select, ixs_remain = self._select_indexes_separated(ixs_remain) 390 | groups.append(ixs_select) 391 | return groups 392 | 393 | 394 | def _get_map_words_intext(self, text_tagged): 395 | """ 396 | Creates dictionary of words in the text, with all occurrences for each word 397 | Args: 398 | text_tagged (List[(str,str)]): List of duples, each is word and its 399 | part-of-speach tag. The list is result of nltk.pos_tag function. 400 | Returns: 401 | map_words_text Dict{str:List[int]}: Dictionary, key is word from the text, 402 | value is List[int] - list of all word occurrence indexes in the text 403 | """ 404 | map_words_text = {} 405 | for i, (w, t) in enumerate(text_tagged): 406 | if self.tags_check and t not in self.tags_check: 407 | continue 408 | if self.tags_exclude and t in self.tags_exclude: 409 | continue 410 | if w not in map_words_text: 411 | map_words_text[w] = [i] 412 | else: 413 | map_words_text[w].append(i) 414 | return map_words_text 415 | 416 | 417 | def _translate_words_to_tokens(self, text_words): 418 | """Tokenizes text by model tokenizer. 419 | Keeps map of indexes of words into indexes of tokens. 420 | Args: 421 | text_words (List[str]): Text given as a list of words 422 | Returns: 423 | text_tokens (List[str]): Text given as list of tokens 424 | map_iword_itoks (Dict[int:(int,int)]): Dictionary of the same length 425 | as text_words. Word index points to duple of token indexes: 426 | index of the first token, and index of the end (one-past-last) token. 427 | """ 428 | text_tokens = [] 429 | map_iword_itoks = {} 430 | i_tok = 0 431 | for ix_word, word in enumerate(text_words): 432 | toks = self.model_tokenizer.tokenize(word) 433 | text_tokens.extend(toks) 434 | i_tok_end = i_tok + len(toks) 435 | map_iword_itoks[ix_word] = (i_tok, i_tok_end) 436 | i_tok = i_tok_end 437 | return text_tokens, map_iword_itoks 438 | -------------------------------------------------------------------------------- /blanc/blanc.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import random 4 | 5 | logging.getLogger('transformers').setLevel(level=logging.WARNING) 6 | 7 | from nltk.tokenize import sent_tokenize 8 | import torch 9 | from torch.nn.utils.rnn import pad_sequence 10 | import tqdm 11 | from transformers import BertForMaskedLM, BertTokenizer, AdamW, get_linear_schedule_with_warmup 12 | from transformers import AlbertForMaskedLM, AlbertTokenizer 13 | 14 | from blanc.utils import ( 15 | BertInput, 16 | Defaults, 17 | batch_data, 18 | mask_tokens_evenly, 19 | mask_tokens_randomly, 20 | get_input_tensors, 21 | determine_correctness, 22 | measure_relative, 23 | measure_improve, 24 | clean_text, 25 | truncate_list_of_lists, 26 | truncate_sentence_and_summary, 27 | set_seed, 28 | NOT_MASKED, 29 | TOKEN_TYPE_A, 30 | LABEL_IGNORE, 31 | BERT_MAX_TOKENS, 32 | TOKEN_REPLACE_RANGE, 33 | ) 34 | 35 | 36 | class Blanc: 37 | """An abstract superclass containing shared functionality between BlancHelp and BlancTune. 38 | measure ('relative' or 'improve') is a choice of how the success of inference is measured. 39 | Add '-counts' to return also counts: 'relative-counts' or 'improve-counts'. 
40 | """ 41 | 42 | def __init__( 43 | self, 44 | model_name=Defaults.model_name, 45 | measure=Defaults.measure, 46 | gap=Defaults.gap, 47 | gap_mask=Defaults.gap_mask, 48 | gap_tune=Defaults.gap_tune, 49 | gap_mask_tune=Defaults.gap_mask_tune, 50 | min_token_length_normal=Defaults.min_token_length_normal, 51 | min_token_length_lead=Defaults.min_token_length_lead, 52 | min_token_length_followup=Defaults.min_token_length_followup, 53 | min_token_length_normal_tune=Defaults.min_token_length_normal_tune, 54 | min_token_length_lead_tune=Defaults.min_token_length_lead_tune, 55 | min_token_length_followup_tune=Defaults.min_token_length_followup_tune, 56 | device=Defaults.device, 57 | inference_batch_size=Defaults.inference_batch_size, 58 | inference_mask_evenly=Defaults.inference_mask_evenly, 59 | len_sent_allow_cut=Defaults.len_sent_allow_cut, 60 | p_mask=Defaults.p_mask, 61 | show_progress_bar=Defaults.show_progress_bar, 62 | ): 63 | """This class should not be instantiated directly: instead use BlancHelp or BlancTune""" 64 | self.model_name = model_name 65 | self.measure = measure 66 | self.gap = gap 67 | self.gap_mask = gap_mask 68 | self.gap_tune = gap_tune 69 | self.gap_mask_tune = gap_mask_tune 70 | self.min_token_length_normal = min_token_length_normal 71 | self.min_token_length_lead = min_token_length_lead 72 | self.min_token_length_followup = min_token_length_followup 73 | self.min_token_length_normal_tune = min_token_length_normal_tune 74 | self.min_token_length_lead_tune = min_token_length_lead_tune 75 | self.min_token_length_followup_tune = min_token_length_followup_tune 76 | self.device = device 77 | self.inference_batch_size = inference_batch_size 78 | self.inference_mask_evenly = inference_mask_evenly 79 | self.len_sent_allow_cut = len_sent_allow_cut 80 | self.p_mask = p_mask 81 | self.show_progress_bar = show_progress_bar 82 | 83 | # The same is intentionally not given: 84 | self.gap_tune = self.gap if self.gap_tune < 0 else self.gap_tune 85 | self.gap_mask_tune = self.gap_mask if self.gap_mask_tune < 0 else self.gap_mask_tune 86 | 87 | if self.model_name.lower().find('albert') >= 0: 88 | self.model_tokenizer = AlbertTokenizer.from_pretrained(model_name) 89 | else: 90 | self.model_tokenizer = BertTokenizer.from_pretrained(model_name) 91 | 92 | def eval_once(self, doc, summary): 93 | """Calculate the BLANC score for a single doc with a single summary. 
94 | 95 | Args: 96 | doc (str): The input document 97 | summary (str): The input summary for the input document 98 | 99 | Returns: 100 | score (float): The BLANC score for the input 101 | """ 102 | (doc_score,) = self.eval_summaries_for_docs([doc], [[summary]]) 103 | (score,) = doc_score 104 | return score 105 | 106 | def eval_pairs(self, docs, summaries): 107 | """Calculate the BLANC score for multiple docs, each with a single summary 108 | 109 | Args: 110 | docs (List[str]): A list of input documents 111 | summaries (List[str]): The input summary for each input document 112 | 113 | Returns: 114 | scores (List[float]): The BLANC scores for the inputs 115 | """ 116 | doc_summaries = [[summary] for summary in summaries] 117 | full_scores = self.eval_summaries_for_docs(docs, doc_summaries) 118 | scores = [score for score, in full_scores] 119 | return scores 120 | 121 | def eval_summaries_for_docs(self, docs, doc_summaries): 122 | """Calculate the BLANC score for multiple docs, each with multiple summaries 123 | 124 | Args: 125 | docs (List[str]): A list of input documents 126 | doc_summaries (List[List[str]]): A list of summaries for every input document 127 | 128 | Returns: 129 | scores (List[List[float]]): A list of blanc scores corresponding to each summary for 130 | each document 131 | """ 132 | raise NotImplementedError() 133 | 134 | def get_inputs_for_sentence(self, sent_tokens, summary_tokens): 135 | """Used by subclasses to specify inference inputs corresponding to a sentence 136 | 137 | Args: 138 | sent_tokens (List[str]): list of tokens corresponding to sentence 139 | summary_tokens (List[str]): list of tokens corresponding to a summary 140 | sep (List[str]): List of tokens corresponding to a separator between summary and sentence 141 | 142 | Returns: 143 | inputs (List[BertInput]): a list of masked token inputs to BERT 144 | answers (List[Dict[int, str]]): a list of "answer" dicts, where each answer dict maps 145 | token indices corresponding to masked tokens back to their original token. 
146 | """ 147 | raise NotImplementedError() 148 | 149 | def mask_and_infer(self, model, docs, doc_summaries, sep=None): 150 | """Run the given model on masked versions of the provided doc_summaries and collect model 151 | output 152 | 153 | Args: 154 | model (BertForMaskedLM): a BERT for masked language modeling torch model 155 | docs (List[str]): A list of input documents 156 | doc_summaries (List[List[str]]): A list of summaries for every input document 157 | sep (str): Separator between the inference help (summary) and a sentence from the doc 158 | 159 | Returns: 160 | all_outputs (List[List[List[Dict[int, str]]]]): for each doc, for each summary for the 161 | doc, for each input sequence for the summary, we have a dict mapping indices to 162 | model predictions 163 | all_answers (List[List[List[Dict[int, str]]]]): for each doc, for each summary for the 164 | doc, for each input sequence for the summary, we have a dict mapping indices to 165 | original tokens 166 | """ 167 | 168 | # Prepare inputs 169 | all_inputs, all_answers = [], [] 170 | for doc, summaries in zip(docs, doc_summaries): 171 | doc_inputs, doc_answers = [], [] 172 | for summary in summaries: 173 | summary_inputs, summary_answers = self.get_inference_inputs(doc, summary, sep) 174 | doc_inputs.append(summary_inputs) 175 | doc_answers.append(summary_answers) 176 | all_inputs.append(doc_inputs) 177 | all_answers.append(doc_answers) 178 | 179 | # Run inference in batches 180 | inputs_per_summary_per_doc = [ 181 | [len(inputs) for inputs in summary_input] for summary_input in all_inputs 182 | ] 183 | collapsed_inputs = sum(sum(all_inputs, []), []) 184 | batched_inputs = batch_data(collapsed_inputs, self.inference_batch_size) 185 | 186 | iterator = tqdm.tqdm(batched_inputs, disable=not self.show_progress_bar) 187 | batched_outputs = [self.run_inference_batch(model, batch) for batch in iterator] 188 | collapsed_outputs = sum(batched_outputs, []) 189 | 190 | # Regroup outputs 191 | i = 0 192 | all_outputs = [] 193 | for inputs_per_summary in inputs_per_summary_per_doc: 194 | doc_outputs = [] 195 | for num_inputs in inputs_per_summary: 196 | doc_outputs.append(collapsed_outputs[i : i + num_inputs]) 197 | i += num_inputs 198 | all_outputs.append(doc_outputs) 199 | 200 | return all_outputs, all_answers 201 | 202 | def get_inference_inputs(self, doc, summary=None, sep=None): 203 | """Get the inference inputs for a document, which possibly includes a summary 204 | 205 | Args: 206 | doc (str): an input document 207 | summary (str): an optional input summary 208 | sep (str): Separator between the inference help (summary) and a sentence from the doc 209 | 210 | Returns: 211 | summary_inputs (List[BertInput]): a list of BertInputs for inference 212 | summary_answers (List[Dict[int, str]]): each dict maps token indices back to their 213 | original token 214 | """ 215 | doc = clean_text(doc) 216 | doc_sents = sent_tokenize(doc) 217 | doc_sent_tokens = [self.model_tokenizer.tokenize(sent) for sent in doc_sents] 218 | 219 | summary_sent_tokens = None 220 | if summary: 221 | summary = clean_text(summary) 222 | summary_sents = sent_tokenize(summary) 223 | summary_sent_tokens = [self.model_tokenizer.tokenize(sent) for sent in summary_sents] 224 | if not summary_sent_tokens: 225 | summary_sent_tokens = [[]] 226 | 227 | len_sep = 0 228 | if sep: 229 | len_sep = len(sep) 230 | 231 | summary_inputs, summary_answers = [], [] 232 | half_num_sents = len(doc_sent_tokens) 233 | truncate_bottom = True 234 | for i_sent, sent_tokens in 
enumerate(doc_sent_tokens): 235 | if i_sent > half_num_sents: 236 | truncate_bottom = False 237 | sent_tokens, summary_tokens = truncate_sentence_and_summary( 238 | sent=sent_tokens, 239 | summary=summary_sent_tokens, 240 | len_sep=len_sep, 241 | len_sent_allow_cut=self.len_sent_allow_cut, 242 | truncate_bottom=truncate_bottom, 243 | ) 244 | # now it is assured that everything fits into the allowed input size: 245 | assert len(sent_tokens) + len(summary_tokens) + len_sep + 2 <= BERT_MAX_TOKENS 246 | inputs, answers = self.get_inputs_for_sentence(sent_tokens, summary_tokens) 247 | summary_inputs += inputs 248 | summary_answers += answers 249 | return summary_inputs, summary_answers 250 | 251 | def assemble_inference_input(self, answers, sent_tokens, help_tokens=None, help_sep=None): 252 | """Given input tokens, assemble them into the tensors used by the model for inference 253 | 254 | Args: 255 | answers (Dict[int, str]): a mapping of input token indices to their original value 256 | sent_tokens (List[str]): tokens corresponding to an input sentence 257 | help_tokens (List[str]): tokens corresponding to an input summary or filler 258 | help_sep (List[str]): tokens to put between the summary/filler and the sentence 259 | 260 | Returns: 261 | model_input (BertInput): an input to the BERT model 262 | shifted_answers (Dict[int, str]): the input answers but with shifted indices that take 263 | into account the summary/filler and starting CLS token 264 | 265 | Raises: 266 | ValueError: if the sentence itself is longer than the BERT_MAX_TOKENS limit, we raise 267 | this error as opposed to truncating the sentence 268 | """ 269 | if not help_tokens: 270 | help_tokens = [] 271 | if not help_sep: 272 | help_sep = [] 273 | 274 | all_tokens = ( 275 | [self.model_tokenizer.cls_token] 276 | + help_tokens 277 | + help_sep 278 | + sent_tokens 279 | + [self.model_tokenizer.sep_token] 280 | ) 281 | 282 | input_ids = self.model_tokenizer.convert_tokens_to_ids(all_tokens) 283 | token_type_ids = [TOKEN_TYPE_A] * len(all_tokens) 284 | attention_mask = [NOT_MASKED] * len(all_tokens) 285 | 286 | offset = 1 + len(help_tokens) + len(help_sep) 287 | shifted_answers = {} 288 | for idx, token in answers.items(): 289 | shifted_answers[idx + offset] = token 290 | 291 | model_input = BertInput( 292 | input_ids=input_ids, 293 | token_type_ids=token_type_ids, 294 | attention_mask=attention_mask, 295 | labels=None, 296 | masked_idxs=list(shifted_answers.keys()), 297 | ) 298 | 299 | return model_input, shifted_answers 300 | 301 | def run_inference_batch(self, model, batch): 302 | """Run an inference batch through the provided model 303 | 304 | Args: 305 | model (BertForMaskedLM): a BERT for masked language modeling torch model 306 | batch (List[BertInput]): the input batch to run through the model 307 | 308 | Returns: 309 | all_predictions (List[Dict[int, str]]): predicted tokens for every masked token in 310 | the inputs 311 | """ 312 | input_ids, attention_mask, token_type_ids, _ = get_input_tensors( 313 | batch, device=self.device, tokenizer=self.model_tokenizer, 314 | ) 315 | 316 | with torch.no_grad(): 317 | (model_output_batch,) = model( 318 | input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, 319 | ) 320 | 321 | all_predictions = [] 322 | for model_input, model_output in zip(batch, model_output_batch): 323 | predictions = {} 324 | for idx in model_input.masked_idxs: 325 | predicted_id = model_output[idx].argmax() 326 | (predicted_token,) = 
self.model_tokenizer.convert_ids_to_tokens([predicted_id]) 327 | predictions[idx] = predicted_token 328 | all_predictions.append(predictions) 329 | 330 | return all_predictions 331 | 332 | def mask_input_tokens(self, tokens, is_finetune): 333 | """Given a list of tokens, produce maskings for them 334 | 335 | Args: 336 | tokens (List[str]): a sequence of wordpiece tokens 337 | is_finetune (bool): whether or not these tokens are going to be used for finetuning 338 | 339 | Returns: 340 | masked_inputs (List[List[str]]): a list of token sequences, where each token sequence 341 | contains some masked tokens. 342 | all_answers (List[Dict[int, str]]): a list of "answer" dicts, where each answer dict maps 343 | token indices corresponding to masked tokens back to their original token. 344 | """ 345 | if is_finetune: 346 | even_masking = self.finetune_mask_evenly 347 | else: 348 | even_masking = self.inference_mask_evenly 349 | 350 | min_token_lengths = ( 351 | self.min_token_length_normal, 352 | self.min_token_length_lead, 353 | self.min_token_length_followup, 354 | ) 355 | if is_finetune: 356 | min_token_lengths = ( 357 | self.min_token_length_normal_tune, 358 | self.min_token_length_lead_tune, 359 | self.min_token_length_followup_tune, 360 | ) 361 | 362 | if even_masking: 363 | gap_use = self.gap 364 | gap_mask_use = self.gap_mask 365 | if is_finetune: 366 | gap_use = self.gap_tune 367 | gap_mask_use = self.gap_mask_tune 368 | return mask_tokens_evenly( 369 | tokens=tokens, 370 | gap=gap_use, 371 | min_token_lengths=min_token_lengths, 372 | mask_token=self.model_tokenizer.mask_token, 373 | gap_mask=gap_mask_use, 374 | ) 375 | else: 376 | return mask_tokens_randomly( 377 | tokens=tokens, 378 | min_token_lengths=min_token_lengths, 379 | mask_token=self.model_tokenizer.mask_token, 380 | p_mask=self.p_mask, 381 | ) 382 | 383 | def judge_output(self, base_output, assisted_output, base_answers, assisted_answers): 384 | """Given a model's predicted tokens with and without assistance, as well as the correct 385 | token predictions, produce the BLANC score 386 | 387 | Args: 388 | base_outputs (List[Dict[int, str]]): outputs without using "help" or "tune." Each list 389 | represents a different input masking, and each dict maps indices to model 390 | predictions. 391 | assisted_outputs (List[Dict[int, str]]): outputs using "help" or "tune." Each list 392 | represents a different input masking, and each dict maps indices to model 393 | predictions. 394 | base_answers (List[Dict[int, str]]): answers without using "help" or "tune." Each list 395 | represents a different input masking, and each dict maps indices to original 396 | tokens. 397 | assisted_answers (List[Dict[int, str]]): answers using "help" or "tune." Each 398 | list represents a different input masking, and each dict maps indices to original 399 | tokens. 400 | 401 | Returns: 402 | score (float): the BLANC score, if the measure is 'relative' or 'improve'. 403 | score, S (tuple of float and list): the BLANC score and counts, 404 | if the measure is 'relative-counts' or 'improve-counts'. 
405 | """ 406 | base_correctness = determine_correctness(base_output, base_answers) 407 | assisted_correctness = determine_correctness(assisted_output, assisted_answers) 408 | 409 | S = [[0, 0], [0, 0]] 410 | for base_correct, assisted_correct in zip(base_correctness, assisted_correctness): 411 | S[int(base_correct)][int(assisted_correct)] += 1 412 | 413 | measure_split = self.measure.split("-") 414 | if measure_split[0] == 'relative': 415 | result = measure_relative(S) 416 | if self.measure == 'relative-counts': 417 | result = result, S 418 | elif measure_split[0] == 'improve': 419 | result = measure_improve(S) 420 | if self.measure == 'improve-counts': 421 | result = result, S 422 | else: 423 | raise NotImplementedError() 424 | 425 | return result 426 | 427 | def init_model(self, device): 428 | """Initialize the language model and send it to the given device 429 | Note: Transformers v.4 and higher made default return_dict=True. 430 | Args: 431 | device (str): torch device (usually "cpu" or "cuda") 432 | 433 | Returns: 434 | model: a model for masked language modeling torch model 435 | """ 436 | model = None 437 | if self.model_name.lower().find('albert') >= 0: 438 | try: 439 | model = AlbertForMaskedLM.from_pretrained(self.model_name, return_dict=False).to(device) 440 | except: 441 | model = AlbertForMaskedLM.from_pretrained(self.model_name).to(device) 442 | else: 443 | try: 444 | model = BertForMaskedLM.from_pretrained(self.model_name, return_dict=False).to(device) 445 | except: 446 | model = BertForMaskedLM.from_pretrained(self.model_name).to(device) 447 | model.eval() 448 | return model 449 | 450 | 451 | class BlancHelp(Blanc): 452 | """BLANC-help, as defined in the BLANC paper.""" 453 | 454 | def __init__( 455 | self, 456 | model_name=Defaults.model_name, 457 | measure=Defaults.measure, 458 | gap=Defaults.gap, 459 | gap_mask=Defaults.gap_mask, 460 | gap_tune=Defaults.gap_tune, 461 | gap_mask_tune=Defaults.gap_mask_tune, 462 | min_token_length_normal=Defaults.min_token_length_normal, 463 | min_token_length_lead=Defaults.min_token_length_lead, 464 | min_token_length_followup=Defaults.min_token_length_followup, 465 | min_token_length_normal_tune=Defaults.min_token_length_normal_tune, 466 | min_token_length_lead_tune=Defaults.min_token_length_lead_tune, 467 | min_token_length_followup_tune=Defaults.min_token_length_followup_tune, 468 | device=Defaults.device, 469 | inference_batch_size=Defaults.inference_batch_size, 470 | inference_mask_evenly=Defaults.inference_mask_evenly, 471 | len_sent_allow_cut=Defaults.len_sent_allow_cut, 472 | filler_token=Defaults.filler_token, 473 | help_sep=Defaults.help_sep, 474 | p_mask=Defaults.p_mask, 475 | show_progress_bar=Defaults.show_progress_bar, 476 | ): 477 | """See CLI documentation (blanc --help) for information about each arg""" 478 | super().__init__( 479 | model_name=model_name, 480 | measure=measure, 481 | gap=gap, 482 | gap_mask=gap_mask, 483 | gap_tune=gap_tune, 484 | gap_mask_tune=gap_mask_tune, 485 | min_token_length_normal=min_token_length_normal, 486 | min_token_length_lead=min_token_length_lead, 487 | min_token_length_followup=min_token_length_followup, 488 | min_token_length_normal_tune=min_token_length_normal_tune, 489 | min_token_length_lead_tune=min_token_length_lead_tune, 490 | min_token_length_followup_tune=min_token_length_followup_tune, 491 | device=device, 492 | inference_batch_size=inference_batch_size, 493 | inference_mask_evenly=inference_mask_evenly, 494 | len_sent_allow_cut=len_sent_allow_cut, 495 | p_mask=p_mask, 496 | 
show_progress_bar=show_progress_bar, 497 | ) 498 | 499 | self.filler_token = filler_token 500 | self.help_sep = self.model_tokenizer.tokenize(help_sep) 501 | self.model = self.init_model(self.device) 502 | 503 | def eval_summaries_for_docs(self, docs, doc_summaries): 504 | """Calculate the BLANC score for multiple docs, each with multiple summaries. 505 | See documentation in superclass. 506 | """ 507 | all_outputs, all_answers = self.mask_and_infer( 508 | self.model, docs, doc_summaries, sep=self.help_sep 509 | ) 510 | 511 | all_scores = [] 512 | for doc_outputs, doc_answers in zip(all_outputs, all_answers): 513 | doc_scores = [] 514 | for summary_output, summary_answers in zip(doc_outputs, doc_answers): 515 | help_output = [out for i, out in enumerate(summary_output) if i % 2 == 0] 516 | filler_output = [out for i, out in enumerate(summary_output) if i % 2 == 1] 517 | help_answers = [answer for i, answer in enumerate(summary_answers) if i % 2 == 0] 518 | filler_answers = [answer for i, answer in enumerate(summary_answers) if i % 2 == 1] 519 | 520 | score = self.judge_output(filler_output, help_output, filler_answers, help_answers) 521 | doc_scores.append(score) 522 | all_scores.append(doc_scores) 523 | 524 | return all_scores 525 | 526 | def get_inputs_for_sentence(self, sent_tokens, summary_tokens): 527 | """Get inference inputs corresponding to a given sentence. For BLANC-help, we get several 528 | maskings for each sentence, and for each of these maskings we have an input with the 529 | summary prepended, and an input with a filler prepended. See documentation in superclass. 530 | """ 531 | sent_maskings, init_answers = self.mask_input_tokens(sent_tokens, is_finetune=False) 532 | 533 | filler_tokens = [self.filler_token] * len(summary_tokens) 534 | inputs, final_answers = [], [] 535 | for sent_masking, init_answer in zip(sent_maskings, init_answers): 536 | help_input, help_answers = self.assemble_inference_input( 537 | answers=init_answer, 538 | sent_tokens=sent_masking, 539 | help_tokens=summary_tokens, 540 | help_sep=self.help_sep, 541 | ) 542 | 543 | filler_input, filler_answers = self.assemble_inference_input( 544 | answers=init_answer, 545 | sent_tokens=sent_masking, 546 | help_tokens=filler_tokens, 547 | help_sep=self.help_sep, 548 | ) 549 | 550 | inputs += [help_input, filler_input] 551 | final_answers += [help_answers, filler_answers] 552 | 553 | return inputs, final_answers 554 | 555 | 556 | class BlancTune(Blanc): 557 | """BLANC-tune, as defined in the BLANC paper.""" 558 | 559 | def __init__( 560 | self, 561 | model_name=Defaults.model_name, 562 | measure=Defaults.measure, 563 | gap=Defaults.gap, 564 | gap_mask=Defaults.gap_mask, 565 | gap_tune=Defaults.gap_tune, 566 | gap_mask_tune=Defaults.gap_mask_tune, 567 | min_token_length_normal=Defaults.min_token_length_normal, 568 | min_token_length_lead=Defaults.min_token_length_lead, 569 | min_token_length_followup=Defaults.min_token_length_followup, 570 | device=Defaults.device, 571 | min_token_length_normal_tune=Defaults.min_token_length_normal_tune, 572 | min_token_length_lead_tune=Defaults.min_token_length_lead_tune, 573 | min_token_length_followup_tune=Defaults.min_token_length_followup_tune, 574 | inference_batch_size=Defaults.inference_batch_size, 575 | inference_mask_evenly=Defaults.inference_mask_evenly, 576 | finetune_batch_size=Defaults.finetune_batch_size, 577 | finetune_epochs=Defaults.finetune_epochs, 578 | finetune_mask_evenly=Defaults.finetune_mask_evenly, 579 | len_sent_allow_cut=Defaults.len_sent_allow_cut, 580 
| finetune_chunk_size=Defaults.finetune_chunk_size, 581 | finetune_chunk_stride=Defaults.finetune_chunk_stride, 582 | finetune_top_fully=Defaults.finetune_top_fully, 583 | id_layer_freeze_below=Defaults.id_layer_freeze_below, 584 | id_layer_freeze_above=Defaults.id_layer_freeze_above, 585 | show_progress_bar=Defaults.show_progress_bar, 586 | p_mask=Defaults.p_mask, 587 | p_token_replace=Defaults.p_token_replace, 588 | p_token_original=Defaults.p_token_original, 589 | learning_rate=Defaults.learning_rate, 590 | warmup_steps=Defaults.warmup_steps, 591 | random_seed=Defaults.random_seed, 592 | ): 593 | """See CLI documentation (blanc --help) for information about each arg""" 594 | super().__init__( 595 | model_name=model_name, 596 | measure=measure, 597 | gap=gap, 598 | gap_mask=gap_mask, 599 | gap_tune=gap_tune, 600 | gap_mask_tune=gap_mask_tune, 601 | min_token_length_normal=min_token_length_normal, 602 | min_token_length_lead=min_token_length_lead, 603 | min_token_length_followup=min_token_length_followup, 604 | min_token_length_normal_tune=min_token_length_normal_tune, 605 | min_token_length_lead_tune=min_token_length_lead_tune, 606 | min_token_length_followup_tune=min_token_length_followup_tune, 607 | device=device, 608 | inference_batch_size=inference_batch_size, 609 | inference_mask_evenly=inference_mask_evenly, 610 | len_sent_allow_cut=len_sent_allow_cut, 611 | ) 612 | 613 | self.finetune_batch_size = finetune_batch_size 614 | self.finetune_epochs = finetune_epochs 615 | self.finetune_mask_evenly = finetune_mask_evenly 616 | self.finetune_chunk_size = finetune_chunk_size 617 | self.finetune_chunk_stride = finetune_chunk_stride 618 | self.finetune_top_fully = finetune_top_fully 619 | self.id_layer_freeze_below = id_layer_freeze_below 620 | self.id_layer_freeze_above = id_layer_freeze_above 621 | self.show_progress_bar = show_progress_bar 622 | self.p_mask = p_mask 623 | self.p_token_replace = p_token_replace 624 | self.p_token_original = p_token_original 625 | self.learning_rate = learning_rate 626 | self.warmup_steps = warmup_steps 627 | self.random_seed = random_seed 628 | 629 | # Tune-specific settings that were not given (negative) fall back to the corresponding inference settings: 630 | self.gap_tune = self.gap if self.gap_tune < 0 else self.gap_tune 631 | self.gap_mask_tune = self.gap_mask if self.gap_mask_tune < 0 else self.gap_mask_tune 632 | self.min_token_length_normal_tune = self.min_token_length_normal if self.min_token_length_normal_tune < 0 else self.min_token_length_normal_tune 633 | self.min_token_length_lead_tune = self.min_token_length_lead if self.min_token_length_lead_tune < 0 else self.min_token_length_lead_tune 634 | self.min_token_length_followup_tune = self.min_token_length_followup if self.min_token_length_followup_tune < 0 else self.min_token_length_followup_tune 635 | 636 | self.base_model = self.init_model(self.device) 637 | 638 | def eval_summaries_for_docs(self, docs, doc_summaries): 639 | """Calculate the BLANC score for multiple docs, each with multiple summaries. 640 | See documentation in superclass. 641 | Note that the summary must not be used in any way for the base outputs and base answers. 642 | When the finetuned model is used, the summary could also be supplied as help, which would combine the 'help' and 'tune' versions of BLANC. 
643 | """ 644 | 645 | doc_summaries_use = [[None for s in summs] for summs in doc_summaries] 646 | base_outputs, base_answers = self.mask_and_infer(self.base_model, docs, doc_summaries_use) 647 | 648 | finetuned_outputs, finetuned_answers = [], [] 649 | model_cpu = self.init_model(device='cpu') 650 | for doc, summaries in tqdm.tqdm(zip(docs, doc_summaries), total=len(docs), disable=not self.show_progress_bar): 651 | finetuned_doc_outputs, finetuned_doc_answers = [], [] 652 | for summary in summaries: 653 | model_copy = copy.deepcopy(model_cpu) 654 | finetuned_model = model_copy.to(self.device) 655 | self.finetune(finetuned_model, summary) 656 | 657 | (finetuned_summary_output,), (finetuned_summary_answer,) = self.mask_and_infer( 658 | finetuned_model, [doc], [[None]] 659 | ) 660 | finetuned_doc_outputs += finetuned_summary_output 661 | finetuned_doc_answers += finetuned_summary_answer 662 | 663 | del finetuned_model 664 | torch.cuda.empty_cache() 665 | 666 | finetuned_outputs.append(finetuned_doc_outputs) 667 | finetuned_answers.append(finetuned_doc_answers) 668 | 669 | all_scores = [ 670 | [ 671 | self.judge_output( 672 | base_summary_output, 673 | finetuned_summary_output, 674 | base_summary_answers, 675 | finetuned_summary_answers, 676 | ) 677 | for ( 678 | base_summary_output, 679 | base_summary_answers, 680 | finetuned_summary_output, 681 | finetuned_summary_answers, 682 | ) in zip( 683 | base_doc_output, base_doc_answers, finetuned_doc_output, finetuned_doc_answers, 684 | ) 685 | ] 686 | for ( 687 | base_doc_output, 688 | base_doc_answers, 689 | finetuned_doc_output, 690 | finetuned_doc_answers, 691 | ) in zip( 692 | base_outputs, base_answers, finetuned_outputs, finetuned_answers, 693 | ) 694 | ] 695 | 696 | return all_scores 697 | 698 | def get_inputs_for_sentence(self, sent_tokens, summary_tokens): 699 | """Get inference inputs corresponding to a given sentence. For BLANC-tune, we get several 700 | maskings for each sentence, and each masking is a single input. See documentation in 701 | superclass. 702 | """ 703 | sent_maskings, init_answers = self.mask_input_tokens(sent_tokens, is_finetune=False) 704 | inputs, final_answers = [], [] 705 | for sent_idx, (sent_masking, init_answer) in enumerate(zip(sent_maskings, init_answers)): 706 | input_, answers = self.assemble_inference_input( 707 | answers=init_answer, sent_tokens=sent_masking, 708 | ) 709 | 710 | inputs.append(input_) 711 | final_answers.append(answers) 712 | 713 | return inputs, final_answers 714 | 715 | def finetune(self, model, summary): 716 | """Finetune the given model on a "dataset" produced from chunks of the given summary. 
717 | 718 | Args: 719 | model (BertForMaskedLM): a BERT torch model for masked language modeling 720 | summary (str): the summary to finetune on 721 | """ 722 | if self.random_seed > 0: 723 | set_seed(self.random_seed) 724 | model.train() 725 | n_params = len(list(model.parameters())) 726 | # Freeze the lowest few or the highest few layers, if requested: 727 | if self.id_layer_freeze_below > 0 or self.id_layer_freeze_above > 0: 728 | for i, param in enumerate(model.parameters()): 729 | if i < self.id_layer_freeze_below: 730 | param.requires_grad = False 731 | elif self.id_layer_freeze_above < 0: 732 | break 733 | elif n_params - i < self.id_layer_freeze_above: 734 | param.requires_grad = False 735 | all_inputs = self.prepare_finetuning_data(summary) 736 | input_batches = batch_data(all_inputs, self.finetune_batch_size) 737 | 738 | no_decay = ["bias", "LayerNorm.weight"] 739 | optimizer_grouped_parameters = [ 740 | { 741 | "params": [ 742 | p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) 743 | ], 744 | "weight_decay": 1e-2, 745 | }, 746 | { 747 | "params": [ 748 | p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) 749 | ], 750 | "weight_decay": 0.0, 751 | }, 752 | ] 753 | optimizer = AdamW(optimizer_grouped_parameters, lr=self.learning_rate, eps=1e-8) 754 | scheduler = get_linear_schedule_with_warmup( 755 | optimizer, 756 | num_warmup_steps=self.warmup_steps, 757 | num_training_steps=len(input_batches) * self.finetune_epochs, 758 | ) 759 | for epoch in range(self.finetune_epochs): 760 | for input_batch in input_batches: 761 | input_ids, attention_mask, token_type_ids, labels = get_input_tensors( 762 | input_batch, device=self.device, tokenizer=self.model_tokenizer, 763 | ) 764 | model.zero_grad() 765 | optimizer.zero_grad() 766 | try: # masked_lm_labels was deprecated and replaced by labels in transformers v4.x 767 | loss, _ = model( 768 | input_ids=input_ids, 769 | attention_mask=attention_mask, 770 | token_type_ids=token_type_ids, 771 | labels=labels, 772 | ) 773 | except Exception: # fall back for older transformers versions 774 | loss, _ = model( 775 | input_ids=input_ids, 776 | attention_mask=attention_mask, 777 | token_type_ids=token_type_ids, 778 | masked_lm_labels=labels, 779 | ) 780 | loss.backward() 781 | optimizer.step() 782 | scheduler.step() 783 | model.eval() 784 | 785 | def prepare_finetuning_data(self, summary): 786 | """Create a finetuning dataset using chunks of the given summary. 787 | With finetune_top_fully=True, the leading tokens get extra finetuning passes, since they 788 | would otherwise be covered by fewer strides than later tokens. 
789 | 790 | Args: 791 | summary (str): the input summary to finetune on 792 | 793 | Returns: 794 | model_inputs (List[BertInput]): a list of inputs to use as the finetuning dataset 795 | """ 796 | summary_tokens = self.model_tokenizer.tokenize(summary) 797 | model_inputs = [] 798 | for start_token in range(0, len(summary_tokens), self.finetune_chunk_stride): 799 | end_token = start_token + self.finetune_chunk_size 800 | chunk_tokens = summary_tokens[start_token:end_token] 801 | model_inputs += self.assemble_finetuning_input(chunk_tokens) 802 | if self.finetune_top_fully and start_token > 0 and start_token < self.finetune_chunk_size: 803 | chunk_tokens = summary_tokens[:start_token] 804 | model_inputs += self.assemble_finetuning_input(chunk_tokens) 805 | return model_inputs 806 | 807 | def assemble_finetuning_input(self, chunk_tokens): 808 | """Given input tokens, assemble them into the tensors used by the model for finetuning 809 | 810 | Args: 811 | chunk_tokens (List[str]): a token sequence 812 | 813 | Returns: 814 | model_inputs (List[BertInput]): BertInputs corresponding to different maskings of 815 | chunk_tokens 816 | """ 817 | all_input_tokens, all_answers = self.mask_input_tokens(chunk_tokens, is_finetune=True) 818 | 819 | all_input_tokens = [ 820 | [self.model_tokenizer.cls_token] + tokens + [self.model_tokenizer.sep_token] 821 | for tokens in all_input_tokens 822 | ] 823 | all_input_ids = [ 824 | self.model_tokenizer.convert_tokens_to_ids(tokens) for tokens in all_input_tokens 825 | ] 826 | all_labels = [[LABEL_IGNORE] * len(tokens) for tokens in all_input_tokens] 827 | 828 | model_inputs = [] 829 | for input_ids, answers, labels in zip(all_input_ids, all_answers, all_labels): 830 | for original_idx, token in answers.items(): 831 | idx = original_idx + 1 # accounting for starting CLS token 832 | (original_token_id,) = self.model_tokenizer.convert_tokens_to_ids([token]) 833 | labels[idx] = original_token_id 834 | 835 | random_number = random.random() 836 | if random_number < self.p_token_replace: 837 | # replace with a random token 838 | input_ids[idx] = random.randint(*TOKEN_REPLACE_RANGE) 839 | elif random_number < self.p_token_original + self.p_token_replace: 840 | # use original token 841 | input_ids[idx] = original_token_id 842 | 843 | attention_mask = [NOT_MASKED] * len(input_ids) 844 | token_type_ids = [TOKEN_TYPE_A] * len(input_ids) 845 | model_input = BertInput( 846 | input_ids=input_ids, 847 | attention_mask=attention_mask, 848 | token_type_ids=token_type_ids, 849 | labels=labels, 850 | masked_idxs=None, 851 | ) 852 | model_inputs.append(model_input) 853 | 854 | return model_inputs 855 | --------------------------------------------------------------------------------
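For orientation, below is a minimal usage sketch of the two estimators defined in blanc/blanc.py above. It is an illustration under assumptions, not part of the package sources: it assumes the package is installed (pip install blanc), it uses only the eval_summaries_for_docs method and constructor arguments shown above, the example document and summaries are invented, and the printed scores depend on the model and the chosen defaults.

from blanc import BlancHelp, BlancTune

# Hypothetical inputs: one document with two candidate summaries.
docs = ["The festival was postponed to October because heavy rain flooded the main square."]
doc_summaries = [["The festival was postponed because of rain.", "The main square was flooded."]]

# BLANC-help: each masked document sentence is unmasked once with the summary
# prepended and once with a same-length filler prepended.
blanc_help = BlancHelp(device="cpu", inference_batch_size=64)
print(blanc_help.eval_summaries_for_docs(docs, doc_summaries))

# BLANC-tune: a copy of the base model is finetuned on each summary before
# unmasking the document; scores come back nested as [doc][summary].
blanc_tune = BlancTune(device="cpu", finetune_mask_evenly=False, show_progress_bar=False)
print(blanc_tune.eval_summaries_for_docs(docs, doc_summaries))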