├── blanc
│   ├── __version__.py
│   ├── __init__.py
│   ├── shannon.py
│   ├── __main__.py
│   ├── utils.py
│   ├── estime.py
│   └── blanc.py
├── .gitignore
├── requirements.txt
├── data
│   ├── single.json
│   ├── pairs.json
│   ├── doc-summaries.json
│   └── README.md
├── shannon
│   ├── README.md
│   └── summeval_score.py
├── LICENSE
├── setup.py
├── SECURITY.md
├── .github
│   └── workflows
│       └── codeql-analysis.yml
├── estime
│   └── README.md
└── README.md
/blanc/__version__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.3.4"
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | *.sw*
3 | ._*
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | nltk>=3.1,<4.0
2 | numpy>=1.0,<2.0
3 | scipy>=1.6.0
4 | torch>=1.0,<2.0
5 | tqdm>=4.27.0,<5.0
6 | transformers>=2.4.0
7 |
--------------------------------------------------------------------------------
/blanc/__init__.py:
--------------------------------------------------------------------------------
1 | from .__version__ import __version__
2 | from .blanc import BlancHelp, BlancTune
3 | from .estime import Estime
4 | from .shannon import Shannon
5 |
--------------------------------------------------------------------------------
/data/single.json:
--------------------------------------------------------------------------------
1 | {
2 | "doc": "Jack drove his minivan to the bazaar to purchase milk and honey for his large family.",
3 | "summary": "Jack bought milk and honey."
4 | }
5 |
--------------------------------------------------------------------------------
/data/pairs.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "doc": "Jack drove his minivan to the bazaar to purchase milk and honey for his large family.",
4 | "summary": "Jack bought milk and honey."
5 | },
6 | {
7 | "doc": "As Jill started taking a walk in the park, she certainly noticed that the trees were extra green this year.",
8 | "summary": "Jill saw green trees in the park."
9 | }
10 | ]
11 |
--------------------------------------------------------------------------------
/shannon/README.md:
--------------------------------------------------------------------------------
1 | # Shannon Game
2 |
3 | The Shannon Score and Information Difference metrics of summary quality are defined in [Play the Shannon Game With Language Models: A Human-Free Approach to Summary Evaluation](https://arxiv.org/abs/2103.10918), published in the [Proceedings of AAAI 2022](https://ojs.aaai.org/index.php/AAAI/article/view/21304). Our implementation of these metrics is [here](https://github.com/PrimerAI/blanc/blob/master/blanc/shannon.py).
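
A minimal usage sketch, based on the return values of `Shannon.go()` in `blanc/shannon.py` and the formulas used in `shannon/summeval_score.py` (variable names here are illustrative, not an official API):

```python
from blanc import Shannon

shannon = Shannon()  # GPT-2 by default; uses CUDA when available, otherwise CPU
doc = "Jack drove his minivan to the bazaar to purchase milk and honey for his large family."
summ = "Jack bought milk and honey."

# go() returns summed log-likelihoods of the document tokens with no help (base),
# with the summary as help, and with the sentence itself as help (full),
# plus the 2x2 success-count matrix S and the token counts.
ll_base, ll_help, ll_full, S, num_doc_tokens, num_summ_tokens = shannon.go(doc, summ)

shannon_score = (ll_help - ll_base) / (ll_full - ll_base)
info_diff = ll_help - ll_base
blanc_shannon = (S[0][1] - S[1][0]) / sum(sum(S, []))
```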
4 |
--------------------------------------------------------------------------------
/data/doc-summaries.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "doc": "Jack drove his minivan to the bazaar to purchase milk and honey for his large family.",
4 | "summaries": [
5 | "Jack bought milk and honey.",
6 | "Jack drove to the bazaar in a minivan."
7 | ]
8 | },
9 | {
10 | "doc": "As Jill started taking a walk in the park, she certainly noticed that the trees were extra green this year.",
11 | "summaries": [
12 | "Jill saw green trees in the park.",
13 | "The trees were green."
14 | ]
15 | }
16 | ]
17 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2020 Primer AI
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4 |
5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6 |
7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
8 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import json
2 | import setuptools
3 |
4 | version = "0.3.4"
5 |
6 | with open("README.md", encoding="utf-8") as reader:
7 | long_description = reader.read()
8 |
9 | with open("requirements.txt") as reader:
10 | requirements = [line.strip() for line in reader]
11 |
12 | setuptools.setup(
13 | name="blanc",
14 | version=version,
15 | author="Primer AI",
16 | author_email="blanc@primer.ai",
17 | description="Human-free quality estimation of document summaries",
18 | long_description=long_description,
19 | long_description_content_type="text/markdown",
20 | url="https://github.com/PrimerAI/blanc",
21 | packages=setuptools.find_packages(),
22 | install_requires=requirements,
23 | include_package_data=True,
24 | classifiers=[
25 | "Programming Language :: Python :: 3",
26 | "License :: OSI Approved :: MIT License",
27 | "Operating System :: OS Independent",
28 | ],
29 | python_requires=">=3.6",
30 | entry_points={
31 | "console_scripts": ["blanc=blanc.__main__:main"],
32 | },
33 | )
34 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | # Datasets
2 | The datasets are as described and used in [Fill in the BLANC: Human-free quality estimation of document summaries](https://www.aclweb.org/anthology/2020.eval4nlp-1.2/) and in [Sensitivity of BLANC to human-scored qualities of text summaries](https://arxiv.org/abs/2010.06716). The human scores for the summaries were assigned by 10 annotators from Odetta.ai, who were provided with detailed instructions and trained on trial tasks. Each dataset preserves the IDs of the annotators along with their individual scores.
3 |
4 | The datasets:
5 | 1. CNN_DailyMail_555: 555 text-summary pairs, built from 100 texts with human summaries taken randomly from the CNN / Daily Mail dataset [Hermann et al., 2015](https://proceedings.neurips.cc/paper/2015/file/afdec7005cc9f14302cd0474fd0f3c96-Paper.pdf) and complemented with generated summaries. The single human score describes the generic quality of the summary.
6 | 2. DailyNews_300: 300 text-summary pairs, created from 100 texts taken randomly from daily news of different sources. Three summaries for each text were generated by extractive, abstractive and semi-abstractive models. The single human score describes the generic quality of the summary.
7 | 3. DailyNews_300_aspects: the same 300 text-summary pairs as above, but with 5 human quality scores, according to how fluent, understandable, informative, compact and overall good the summary is.
8 |
9 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Primer takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organization, [PrimerAI](https://github.com/PrimerAI).
6 |
7 | If you believe you have found a security vulnerability in any Primer-owned repository, please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Primer Security Team at:
14 | [security@primer.ai](mailto:security@primer.ai)
15 |
16 |
17 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
18 |
19 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
20 | * Full paths of source file(s) related to the manifestation of the issue
21 | * The location of the affected source code (tag/branch/commit or direct URL)
22 | * Any special configuration required to reproduce the issue
23 | * Step-by-step instructions to reproduce the issue
24 | * Proof-of-concept or exploit code (if possible)
25 | * Impact of the issue, including how an attacker might exploit the issue
26 |
27 | This information will help us triage your report more quickly.
28 |
29 |
30 | ## Preferred Languages
31 |
32 | We prefer all communications to be in English.
33 |
34 | ## Policy
35 |
36 | Primer follows the principle of [Coordinated Vulnerability Disclosure](https://resources.sei.cmu.edu/asset_files/SpecialReport/2017_003_001_503340.pdf).
37 |
38 |
39 |
--------------------------------------------------------------------------------
/.github/workflows/codeql-analysis.yml:
--------------------------------------------------------------------------------
1 | # For most projects, this workflow file will not need changing; you simply need
2 | # to commit it to your repository.
3 | #
4 | # You may wish to alter this file to override the set of languages analyzed,
5 | # or to provide custom queries or build logic.
6 | #
7 | # ******** NOTE ********
8 | # We have attempted to detect the languages in your repository. Please check
9 | # the `language` matrix defined below to confirm you have the correct set of
10 | # supported CodeQL languages.
11 | #
12 | name: "CodeQL"
13 |
14 | on:
15 | push:
16 | branches: [ master ]
17 | pull_request:
18 | # The branches below must be a subset of the branches above
19 | branches: [ master ]
20 | schedule:
21 | - cron: '31 17 * * 2'
22 |
23 | jobs:
24 | analyze:
25 | name: Analyze
26 | runs-on: ubuntu-latest
27 | permissions:
28 | actions: read
29 | contents: read
30 | security-events: write
31 |
32 | strategy:
33 | fail-fast: false
34 | matrix:
35 | language: [ 'python' ]
36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python' ]
37 | # Learn more:
38 | # https://docs.github.com/en/free-pro-team@latest/github/finding-security-vulnerabilities-and-errors-in-your-code/configuring-code-scanning#changing-the-languages-that-are-analyzed
39 |
40 | steps:
41 | - name: Checkout repository
42 | uses: actions/checkout@v2
43 |
44 | # Initializes the CodeQL tools for scanning.
45 | - name: Initialize CodeQL
46 | uses: github/codeql-action/init@v2
47 | with:
48 | languages: ${{ matrix.language }}
49 | # If you wish to specify custom queries, you can do so here or in a config file.
50 | # By default, queries listed here will override any specified in a config file.
51 | # Prefix the list here with "+" to use these queries and those in the config file.
52 | # queries: ./path/to/local/query, your-org/your-repo/queries@main
53 |
54 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
55 | # If this step fails, then you should remove it and run the build manually (see below)
56 | - name: Autobuild
57 | uses: github/codeql-action/autobuild@v2
58 |
59 | # ℹ️ Command-line programs to run using the OS shell.
60 | # 📚 https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun
61 |
62 | # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines
63 | # and modify them (or add more) to build your code if your project
64 | # uses a compiled language
65 |
66 | #- run: |
67 | # make bootstrap
68 | # make release
69 |
70 | - name: Perform CodeQL Analysis
71 | uses: github/codeql-action/analyze@v2
72 |
--------------------------------------------------------------------------------
/shannon/summeval_score.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import json
3 | import numpy as np
4 | from scipy.stats import pearsonr, spearmanr, kendalltau
5 | import argparse
6 |
7 | factors = ['coherence', 'consistency', 'fluency', 'relevance']
8 | levels = ['expert']
9 | human_cols = sum([[f'{level}_{factor}' for factor in factors] for level in levels], [])
10 | method_cols = ['shannon_score', 'info_diff', 'blanc_shannon']
11 | models = ['small', 'medium', 'large', 'xl', 'gpt1', 'xlnet', 'transformerxl']
12 |
13 | def load(input_file):
14 | with open(input_file) as reader:
15 | result_blobs = [json.loads(line) for line in reader]
16 |
17 | for blob in result_blobs:
18 | S = blob['S']
19 | blob['blanc_shannon'] = (S[0][1] - S[1][0]) / sum(sum(S, []))
20 | blob['s_impr'] = S[0][1] / (S[0][0] + S[1][1] + S[0][1])
21 | del blob['S']
22 |
23 | result_df = pd.DataFrame.from_records(result_blobs)
24 |
25 | with open('/nfs/data/summeval/docs-dynamicmix.json') as reader:
26 | input_blobs = json.load(reader)
27 |
28 | annotation_blobs = []
29 | for blob in input_blobs:
30 | annotation_blob = {'id': blob['id']}
31 | for level in levels:
32 | annotations = blob.get(f'{level}_annotations')
33 | if annotations is None:
34 | continue
35 | for factor in factors:
36 | mean = sum([annotation[factor] for annotation in annotations]) / len(annotations)
37 | annotation_blob[f'{level}_{factor}'] = mean
38 | annotation_blob['lower'] = all([c.lower() == c for c in blob['summ']])
39 | annotation_blobs.append(annotation_blob)
40 |
41 | annotation_df = pd.DataFrame.from_records(annotation_blobs)
42 | df = result_df.join(annotation_df)
43 | df['shannon_score'] = (df.ll_help - df.ll_base) / (df.ll_full - df.ll_base)
44 | df['info_decr'] = (df.ll_base - df.ll_help) / df.ll_base
45 | df['compression'] = df.num_summ_tokens / df.num_doc_tokens
46 | df['cond_lik'] = df.ll_help
47 | df['info_diff'] = df.ll_help - df.ll_base
48 | df['avg_cond_lik'] = df.ll_help / df.num_doc_tokens
49 |
50 | systems = df.groupby('system').mean()
51 | return df, systems
52 |
53 | def eval_many(models, filenames):
54 | shannon_scores, dfs = {}, []
55 | for model, filename in zip(models, filenames):
56 | df, systems = load(filename)
57 | shannon_scores[model] = df.shannon_score
58 | print(f'{model} system kendall-b')
59 | print(systems.corr(method='kendall')[human_cols].loc[method_cols])
60 | df['model_name'] = model
61 | dfs.append(df)
62 | shannon_scores = pd.DataFrame(shannon_scores)
63 | combined = pd.concat(dfs).reset_index()
64 | print(shannon_scores.corr())
65 |
66 | parser = argparse.ArgumentParser()
67 | parser.add_argument('--model', type=str, default=None)
68 | parser.add_argument('--upstream', action='store_true')
69 | args = parser.parse_args()
70 |
71 | if args.upstream:
72 | eval_many(range(5), [f'out/context-{context}-small.jsonl' for context in range(5)])
73 | elif args.model is None:
74 | eval_many(models, [f'out/17-{model}.jsonl' for model in models])
75 | else:
76 | df, systems = load(f'out/17-{args.model}.jsonl')
77 | print('Overall Pearson')
78 | print(df.corr(method='pearson')[human_cols].loc[method_cols])
79 | print('Overall Spearman')
80 | print(df.corr(method='spearman')[human_cols].loc[method_cols])
81 | print('Overall Kendall-b')
82 | print(df.corr(method='kendall')[human_cols].loc[method_cols])
83 |
84 | print('System Pearson')
85 | print(systems.corr(method='pearson')[human_cols].loc[method_cols])
86 | print('System Spearman')
87 | print(systems.corr(method='spearman')[human_cols].loc[method_cols])
88 | print('System Kendall-b')
89 | print(systems.corr(method='kendall')[human_cols].loc[method_cols])
90 |
--------------------------------------------------------------------------------
/estime/README.md:
--------------------------------------------------------------------------------
1 | # ESTIME
2 |
3 | ESTIME as the 'number of alarms' was defined in [ESTIME: Estimation of Summary-to-Text Inconsistency by Mismatched Embeddings](https://aclanthology.org/2021.eval4nlp-1.10/).
4 | ESTIME-soft and ESTIME-coherence were defined in [Consistency and Coherence from Points of Contextual Similarity](https://arxiv.org/abs/2112.11638). Source: [estime](https://github.com/PrimerAI/blanc/blob/master/blanc/estime.py).
5 |
6 | ESTIME is a reference-free estimator of summary quality with emphasis on factual consistency. It can be used for filtering generated summaries, or for estimating improvement of a generation system.
7 |
8 | Usage is simple: create an `Estime` instance and use `evaluate_claims`. When creating `Estime`, specify the list of names of the measures to obtain for each claim. Basic usage:
9 |
10 | ```python
11 | >>> from blanc import Estime
12 | >>> estimator = Estime()
13 | >>> text = """In Kander’s telling, Mandel called him up out of the blue a decade or so ago to pitch a project. It made sense why. The two men had similar profiles: Jewish combat veterans in their early 30s. New statewide officeholders in the Midwest."""
14 | >>> summary = """Kander and Mandel had similar profiles, and it makes sense."""
15 | >>> estimator.evaluate_claims(text, [summary])
16 | [[5]]
17 | ```
18 |
19 | The default `device` in `Estime()` is `'cpu'`; it can be set to `device='cuda'`.
20 |
21 | In the example above only one summary is given for the text, and hence the list of results contains only one element [5], the scores for this summary. The scores list contains only a single score, 5, because by default the list of measures contains only one measure, 'alarms'. More measures can be included: 'alarms', 'alarms_adjusted', 'alarms_alltokens', 'soft', 'coherence'. For example:
22 |
23 | ```
24 | >>> estimator = Estime(output=['alarms', 'alarms_adjusted', 'soft', 'coherence'])
25 | >>> estimator.evaluate_claims(text, [summary])
26 | [[5, 7.5, 0.502, -0.25]]
27 | ```
28 | The results appear in the same order as the names given in `output`. The measures 'alarms' (the original ESTIME), 'soft' and 'coherence' are as defined in the papers. The only difference is that when there is no token overlap between the claim and the text, 'alarms' is set to the number of tokens in the summary. Unlike 'soft', the original ESTIME does not give a good estimate when the number of overlapping tokens is much smaller than the total number of summary tokens. Starting from version 0.3.3, the measure 'alarms_adjusted' can be added. It is defined as `alarms_adjusted = alarms * N / M`, where M is the number of overlapping tokens and N is the total number of summary tokens; it thus extrapolates 'alarms' to the total number of summary tokens. When M=0, 'alarms_adjusted' is set to N. Out of curiosity (not recommended), 'alarms_alltokens' can also be added; it is defined as `alarms_alltokens = alarms + N - M`, meaning that every non-overlapping token is counted as an alarm.
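
A minimal sketch of these adjustments (the helper names are illustrative, not the package's internal API):

```python
# Illustrative sketch of the adjusted measures described above.
def adjusted_alarms(alarms: int, n_summ_tokens: int, n_overlap_tokens: int) -> float:
    """Extrapolate the raw 'alarms' count to all summary tokens: alarms * N / M (N when M == 0)."""
    if n_overlap_tokens == 0:
        return float(n_summ_tokens)
    return alarms * n_summ_tokens / n_overlap_tokens


def alarms_alltokens(alarms: int, n_summ_tokens: int, n_overlap_tokens: int) -> int:
    """Count every non-overlapping summary token as an alarm: alarms + N - M."""
    return alarms + n_summ_tokens - n_overlap_tokens
```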
29 |
30 | For more options, see comments in the source [estime](https://github.com/PrimerAI/blanc/blob/master/blanc/estime.py), or see [estime](https://github.com/PrimerAI/primer-research/tree/main/estime).
31 |
32 | The table below is made in the same way as Table 1 in [ESTIME](https://aclanthology.org/2021.eval4nlp-1.10/), except that the number of systems here is updated from 16 to 17, following the later version of [SummEval](https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00373/100686/SummEval-Re-evaluating-Summarization-Evaluation). This means that the correlations here are taken between arrays of length 1700 (100 texts x 17 summary generation systems).
33 |
34 | |model|consistency Spearman|consistency Kendall|relevance Spearman|relevance Kendall|coherence Spearman|coherence Kendall|fluency Spearman|fluency Kendall|
35 | |:--|--:|--:|--:|--:|--:|--:|--:|--:|
36 | BLANC-AXXL|0.19|0.09|0.21|0.15|0.11|0.08|0.10|0.06|
37 | BLANC-BLU|0.20|0.10|0.18|0.13|0.10|0.07|0.11|0.06|
38 | BLANC|0.19|0.10|0.28|0.20|0.22|0.16|0.13|0.07|
39 | ESTIME-12|0.36|0.18|0.10|0.07|0.20|0.14|0.32|0.19|
40 | ESTIME-21|**0.39**|**0.19**|0.15|0.11|0.27|0.19|**0.38**|**0.22**|
41 | ESTIME-24|0.34|0.17|0.08|0.06|0.16|0.11|0.34|0.20|
42 | Jensen-Shannon|0.18|0.09|0.39|0.28|0.29|0.21|0.11|0.06|
43 | SummaQA-F|0.17|0.08|0.14|0.10|0.08|0.06|0.12|0.07|
44 | SummaQA-P|0.19|0.09|0.17|0.12|0.10|0.08|0.12|0.07|
45 | SUPERT|0.28|0.14|0.26|0.19|0.20|0.15|0.17|0.10|
46 | (r) BERTScore-F|0.10|0.05|0.38|0.28|**0.39**|**0.28**|0.13|0.07|
47 | (r) BERTScore-P|0.05|0.03|0.29|0.21|0.34|0.25|0.11|0.06|
48 | (r) BERTScore-R|0.15|0.08|**0.41**|**0.30**|0.34|0.249|0.11|0.06|
49 | (r) BLEU|0.09|0.04|0.23|0.17|0.19|0.14|0.12|0.07|
50 | (r) ROUGE-L|0.12|0.06|0.23|0.16|0.16|0.11|0.08|0.04|
51 | (r) ROUGE-1|0.13|0.07|0.28|0.20|0.17|0.12|0.07|0.04|
52 | (r) ROUGE-2|0.12|0.06|0.23|0.16|0.14|0.10|0.06|0.04|
53 | (r) ROUGE-3|0.15|0.07|0.23|0.17|0.15|0.11|0.06|0.04|
54 |
55 | (r): These measures need human-written reference summaries to evaluate a summary.
56 | The ESTIME and Jensen-Shannon scores are negated.
57 | The third row is the default version of BLANC.
58 |
59 | The numbers have slightly changed here compared to the 16-system data reported in [ESTIME](https://aclanthology.org/2021.eval4nlp-1.10/); the trends and the top correlations are the same.
60 | Notice that for consistency, every reference-free measure outperforms all reference-based measures.
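
As in `shannon/summeval_score.py`, such correlations can be computed with `scipy.stats`. A toy sketch with made-up numbers (in the table above each array has length 1700, and the ESTIME scores are negated so that higher is better):

```python
from scipy.stats import kendalltau, spearmanr

expert_consistency = [4.7, 3.0, 4.3, 2.5, 4.9]  # hypothetical expert scores
estime_negated = [-2, -9, -4, -11, -1]          # hypothetical negated alarm counts

rho, rho_p = spearmanr(expert_consistency, estime_negated)
tau, tau_p = kendalltau(expert_consistency, estime_negated)
print(f"Spearman {rho:.2f} (p={rho_p:.2f}), Kendall {tau:.2f} (p={tau_p:.2f})")
```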
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
--------------------------------------------------------------------------------
/blanc/shannon.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import torch
4 | import torch.nn.functional as F
5 | from tqdm import tqdm
6 | from transformers import (
7 | GPT2LMHeadModel,
8 | GPT2Tokenizer,
9 | OpenAIGPTLMHeadModel,
10 | OpenAIGPTTokenizer,
11 | XLNetLMHeadModel,
12 | XLNetTokenizer,
13 | TransfoXLLMHeadModel,
14 | TransfoXLTokenizer,
15 | ReformerModelWithLMHead,
16 | ReformerTokenizer,
17 | XLMWithLMHeadModel,
18 | XLMTokenizer,
19 | )
20 | import numpy as np
21 | from nltk import sent_tokenize
22 |
23 | def get_model(name, size, device='cuda'):
24 | if device == 'cuda':
25 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
26 | if name == 'gpt2':
27 | t = GPT2Tokenizer.from_pretrained('gpt2')
28 | if size == 'base':
29 | g = GPT2LMHeadModel.from_pretrained('gpt2')
30 | else:
31 | g = GPT2LMHeadModel.from_pretrained(f'gpt2-{size}')
32 | eos = g.config.eos_token_id # '<|endoftext|>'
33 | max_input = 1024
34 | elif name == 'gpt1':
35 | t = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
36 | g = OpenAIGPTLMHeadModel.from_pretrained('openai-gpt')
37 | eos = 1 # .
38 | max_input = 512
39 | elif name == 'xlnet':
40 | t = XLNetTokenizer.from_pretrained(f'xlnet-{size}-cased')
41 | g = XLNetLMHeadModel.from_pretrained(f'xlnet-{size}-cased')
42 | eos = g.config.eos_token_id #
43 | max_input = 1024
44 | elif name == 'transformerxl':
45 | t = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
46 | g = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103')
47 | eos = g.config.eos_token_id #
48 | max_input = 1024
49 | elif name == 'reformer':
50 | t = ReformerTokenizer.from_pretrained('google/reformer-crime-and-punishment')
51 | g = ReformerModelWithLMHead.from_pretrained('google/reformer-crime-and-punishment')
52 | eos = g.config.eos_token_id # (empty string)
53 | max_input = 1024
54 | elif name == 'xlm':
55 | t = XLMTokenizer.from_pretrained('xlm-clm-ende-1024')
56 | g = XLMWithLMHeadModel.from_pretrained('xlm-clm-ende-1024')
57 | g.config.lang_id = g.config.lang2id['en']
58 | eos = 4 #
59 | max_input = 1024
60 |
61 | g = g.to(device)
62 | g.config.return_dict = False
63 | g.eval()
64 | return g, t, eos, max_input
65 |
66 | def prepare_inputs_for_generation(input_ids, past=None, **kwargs):
67 | """Copied from gpt2 of huggingface transformers, but using it here separately,
68 |     because in some versions this function worked differently, which caused errors."""
69 | if past: # only last token for inputs_ids if past is defined in kwargs
70 | input_ids = input_ids[:, -1].unsqueeze(-1)
71 |
72 | attention_mask = kwargs.get("attention_mask", None)
73 | position_ids = kwargs.get("position_ids", None)
74 |
75 | if attention_mask is not None and position_ids is None:
76 | # create position_ids on the fly for batch generation
77 | position_ids = attention_mask.long().cumsum(-1) - 1
78 | position_ids.masked_fill_(attention_mask == 0, 1)
79 | if past:
80 | position_ids = position_ids[:, -1].unsqueeze(-1)
81 | else:
82 | position_ids = None
83 | return {
84 | "input_ids": input_ids,
85 | "past_key_values": past,
86 | "use_cache": kwargs.get("use_cache"),
87 | "position_ids": position_ids,
88 | "attention_mask": attention_mask,
89 | }
90 |
91 | class Shannon:
92 | def __init__(
93 | self,
94 | verbose=False,
95 | language_model='gpt2',
96 | model_size='base',
97 | num_upstream=0,
98 | return_token_lls=False,
99 | device='cuda'
100 | ):
101 | self.verbose = verbose
102 | self.language_model = language_model
103 | self.num_upstream = num_upstream
104 | self.return_token_lls = return_token_lls
105 | self.device = device
106 | if self.device == 'cuda':
107 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
108 | self.g, self.t, self.eos, self.max_input = get_model(
109 | language_model, model_size, self.device)
110 |
111 | def measure(self, doc_tokens, prompt):
112 | eos = torch.LongTensor([self.eos]).to(self.device)
113 | if prompt is None or (prompt.dim() == 1 and len(prompt) == 0):
114 | prompt = torch.LongTensor([]).to(self.device)
115 |
116 | token_lls = []
117 | success = []
118 | past = None
119 | for i, token in enumerate(doc_tokens):
120 | upstream = doc_tokens[:i]
121 | if len(upstream) + len(prompt) + 1 > self.max_input:
122 | upstream = upstream[-(self.max_input - 1 - len(prompt)):]
123 | if past is not None:
124 | past = [t[:, :, :, 1:, :] for t in past]
125 |
126 | prefix = torch.cat([eos, prompt, upstream]).unsqueeze(0)
127 | inputs = prepare_inputs_for_generation(prefix, past=past, use_cache=True, use_mems=True)
128 | with torch.no_grad():
129 | out = self.g(**inputs)
130 |
131 |             if self.language_model == 'gpt2':
132 | logits, past = out
133 | elif self.language_model in ['gpt1', 'reformer']:
134 | logits, = out
135 | logits = logits[0, -1, :]
136 | elif self.language_model == 'xlnet':
137 | logits, = out
138 | elif self.language_model == 'transformerxl':
139 | logits, past = out
140 | logits = logits[0, -1, :]
141 | elif self.language_model == 'xlm':
142 | logits, = out
143 | logits = logits[0, -1, :]
144 | probs = F.softmax(logits, dim=-1).view(-1)
145 | prob = probs[token].item()
146 |
147 | log_prob = np.log(prob)
148 | token_lls.append(log_prob)
149 | success.append(int(token == probs.argmax()))
150 |
151 | true_token = self.t.decode([token])
152 | try:
153 | pred_token = self.t.decode([probs.argmax()])
154 | except:
155 | pred_token = None
156 | info = -log_prob / np.log(2)
157 | self.log(f'{true_token},{info}')
158 |
159 | return token_lls, success
160 |
161 | def go(self, doc, summ, measure_t=False, measure_summ=False):
162 | sents = sent_tokenize(doc)
163 | encode_args = {'return_tensors': 'pt'}
164 | if self.language_model == 'transformerxl':
165 | encode_args['add_space_before_punct_symbol'] = True
166 | if self.language_model in ['xlnet', 'xlm']:
167 | encode_args['add_special_tokens'] = False
168 |
169 | sents_tokens = [self.t.encode(sent, **encode_args).to(self.device).view(-1) for sent in sents]
170 | summ_tokens = self.t.encode(summ, **encode_args).to(self.device).view(-1)
171 | sents_tokens = [sent_tokens[:self.max_input - 1 - len(summ_tokens)] for sent_tokens in sents_tokens]
172 | doc_tokens = torch.cat(sents_tokens, dim=-1)
173 |
174 | if measure_t:
175 |             ll, success = [], []
176 |             for sent_tokens in sents_tokens:
177 |                 # self.measure() returns per-token log-likelihoods and per-token success flags
178 |                 sent_ll, sent_success = self.measure(sent_tokens, sent_tokens)
179 |                 ll += sent_ll
180 |                 success += sent_success
181 |             return ll, success
182 |
183 | elif measure_summ:
184 | summ_ll, summ_success = self.measure(summ_tokens, None)
185 | return summ_ll
186 |
187 | else:
188 | token_lls_base, token_lls_help, token_lls_full = [], [], []
189 |             S = [[0, 0], [0, 0]]  # S[base_correct][help_correct]: per-token prediction success counts
190 | for sent_idx in range(len(sents_tokens)):
191 | sent_tokens = sents_tokens[sent_idx]
192 | upstream_tensors = sents_tokens[sent_idx-self.num_upstream:sent_idx]
193 | if len(upstream_tensors) > 0:
194 | upstream_context = torch.cat(upstream_tensors)
195 | else:
196 | upstream_context = torch.LongTensor([])
197 | if self.device == 'cuda':
198 | upstream_context = upstream_context.cuda()
199 |
200 | base_prompt = upstream_context
201 | help_prompt = torch.cat([summ_tokens, upstream_context])
202 | full_prompt = torch.cat([upstream_context, sent_tokens, upstream_context])
203 |
204 | base_sent_lls, base_sent_success = self.measure(sent_tokens, base_prompt)
205 | help_sent_lls, help_sent_success = self.measure(sent_tokens, help_prompt)
206 | full_sent_lls, full_sent_success = self.measure(sent_tokens, full_prompt)
207 |
208 | token_lls_base += base_sent_lls
209 | token_lls_help += help_sent_lls
210 | token_lls_full += full_sent_lls
211 |
212 | for b, h in zip(base_sent_success, help_sent_success):
213 | S[b][h] += 1
214 |
215 | if self.return_token_lls:
216 | doc_tokens = self.t.convert_ids_to_tokens(doc_tokens)
217 | summ_tokens = self.t.convert_ids_to_tokens(summ_tokens)
218 | return token_lls_base, token_lls_help, token_lls_full, doc_tokens, summ_tokens
219 | else:
220 | ll_base = sum(token_lls_base)
221 | ll_help = sum(token_lls_help)
222 | ll_full = sum(token_lls_full)
223 |
224 | self.log(f'Shannon Score: {(ll_help - ll_base) / (ll_full - ll_base)}')
225 | self.log(f'Info Diff: {ll_help - ll_base}')
226 | self.log(f'BLANC: {(S[0][1] - S[1][0]) / (S[0][0] + S[0][1] + S[1][0] + S[1][1])}')
227 |
228 | return ll_base, ll_help, ll_full, S, len(doc_tokens), len(summ_tokens)
229 |
230 | def log(self, s=None):
231 | if self.verbose:
232 | print(s)
233 |
234 | if __name__ == '__main__':
235 | parser = argparse.ArgumentParser()
236 | parser.add_argument('--simple', action='store_true')
237 | parser.add_argument('--verbose', action='store_true')
238 | parser.add_argument('--measure_t', action='store_true')
239 | parser.add_argument('--measure_summ', action='store_true')
240 | parser.add_argument('--input_file', type=str)
241 | parser.add_argument('--eval', type=str, choices=['6', '47', '23'], default=None)
242 | parser.add_argument('--system', type=str, default=None)
243 | parser.add_argument('--lm', type=str, default='gpt2')
244 | parser.add_argument('--model_size', type=str, default='base')
245 | parser.add_argument('--num_upstream', type=int, default=0)
246 | parser.add_argument('--start', type=int, default=0)
247 | args = parser.parse_args()
248 |
249 | s = Shannon(args.verbose, args.lm, args.model_size, args.num_upstream)
250 |
251 | if args.simple:
252 | doc = 'Jack drove his minivan to the bazaar to purchase milk and honey for his large family'
253 | summ = 'Jack bought milk and honey from the bazaar'
254 | results = s.go(doc, summ, measure_t=args.measure_t, measure_summ=args.measure_summ)
255 |
256 | print(results)
257 | else:
258 | with open(args.input_file) as reader:
259 | if args.input_file.endswith('.jsonl'):
260 | data = [json.loads(line) for line in reader]
261 | else:
262 | data = json.load(reader)
263 |
264 | selection = data[args.start:]
265 | if args.eval is not None:
266 | selection = [record for record in data if record['eval'] in args.eval]
267 | if args.system is not None:
268 | selection = [record for record in selection if record['model_id'] == args.system]
269 |
270 | for record in tqdm(selection):
271 | if args.measure_t or args.measure_summ:
272 | ll = s.go(
273 | record['doc'], record['summ'],
274 | measure_summ=args.measure_summ, measure_t=args.measure_t
275 | )
276 | print(json.dumps({
277 | 'doc_id': record['id'],
278 | 'system': record['model_id'],
279 | 'll_summ': ll,
280 | }))
281 |
282 | else:
283 | ll_base, ll_help, ll_full, S, num_doc_tokens, num_summ_tokens = s.go(
284 | record['doc'], record['summ']
285 | )
286 | print(json.dumps({
287 | 'doc_id': record['id'],
288 | 'system': record['model_id'],
289 | 'll_base': ll_base,
290 | 'll_help': ll_help,
291 | 'll_full': ll_full,
292 | 'num_doc_tokens': num_doc_tokens,
293 | 'num_summ_tokens': num_summ_tokens,
294 | 'S': S,
295 | }))
296 |
--------------------------------------------------------------------------------
/blanc/__main__.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import random
4 |
5 | import numpy as np
6 | import torch
7 |
8 | from blanc import BlancHelp, BlancTune
9 | from blanc.utils import Defaults
10 |
11 |
12 | def main():
13 | parser = argparse.ArgumentParser(
14 | prog='blanc', formatter_class=argparse.ArgumentDefaultsHelpFormatter
15 | )
16 |
17 | required_parser = parser.add_argument_group('required arguments')
18 | required_parser.add_argument(
19 | 'type', type=str, choices=['help', 'tune'], help='BLANC-help or BLANC-tune'
20 | )
21 |
22 | input_parser = parser.add_argument_group('input arguments')
23 | input_parser.add_argument('--doc', type=str, help='single input document')
24 | input_parser.add_argument('--summary', type=str, help='single input summary')
25 | input_parser.add_argument(
26 | '--single_json',
27 | type=str,
28 | help='filename for single document summary pair',
29 | metavar='FILENAME',
30 | )
31 | input_parser.add_argument(
32 | '--pairs_json',
33 | type=str,
34 | help='filename for list of document summary pairs',
35 | metavar='FILENAME',
36 | )
37 | input_parser.add_argument(
38 | '--doc_summaries_json',
39 | type=str,
40 | help='filename for list of documents, each with a list of summaries',
41 | metavar='FILENAME',
42 | )
43 | input_parser.add_argument(
44 | '--doc_key',
45 | type=str,
46 | help='json key for the input document',
47 | metavar='KEY',
48 | default=Defaults.doc_key,
49 | )
50 | input_parser.add_argument(
51 | '--summary_key',
52 | type=str,
53 | help='json key for the input summary (single_json or pairs_json input)',
54 | metavar='KEY',
55 | default=Defaults.summary_key,
56 | )
57 | input_parser.add_argument(
58 | '--summaries_key',
59 | type=str,
60 | help='json key for the input summaries (doc_summaries_json input)',
61 | metavar='KEY',
62 | default=Defaults.summaries_key,
63 | )
64 | input_parser.add_argument(
65 | '--output_json',
66 | type=str,
67 | help='filename for output file, or None to print to STDOUT',
68 | metavar='FILENAME',
69 | )
70 |
71 | blanc_parser = parser.add_argument_group('arguments for BLANC-help and BLANC-tune')
72 | blanc_parser.add_argument(
73 | '--model_name',
74 | type=str,
75 | choices=['bert-base-cased', 'bert-base-uncased', 'bert-large-cased', 'bert-large-uncased'],
76 | help='BERT model type',
77 | default=Defaults.model_name,
78 | metavar='NAME',
79 | )
80 | blanc_parser.add_argument(
81 | '--measure',
82 | type=str,
83 | choices=['improve', 'relative'],
84 | help='measure improve or relative, as defined in the paper',
85 | default=Defaults.measure,
86 | )
87 | blanc_parser.add_argument(
88 | '--gap',
89 | type=int,
90 | help='distance between words to mask during inference',
91 | default=Defaults.gap,
92 | )
93 | blanc_parser.add_argument(
94 | '--gap_mask',
95 | type=int,
96 | help='number of tokens to mask at each designated position during inference',
97 | default=Defaults.gap_mask,
98 | )
99 | blanc_parser.add_argument(
100 | '--gap_tune',
101 | type=int,
102 | help='distance between words to mask during finetuning',
103 | default=Defaults.gap,
104 | )
105 | blanc_parser.add_argument(
106 | '--gap_mask_tune',
107 | type=int,
108 | help='number of tokens to mask at each designated position during finetuning',
109 | default=Defaults.gap_mask,
110 | )
111 | blanc_parser.add_argument(
112 | '--min_token_length_normal',
113 | type=int,
114 | help=(
115 | 'minimum number of chars in normal tokens to mask, where a normal token is '
116 | 'a whole word'
117 | ),
118 | default=Defaults.min_token_length_normal,
119 | metavar='LEN',
120 | )
121 | blanc_parser.add_argument(
122 | '--min_token_length_lead',
123 | type=int,
124 | help='minimum number of chars in lead token to mask, where a lead token begins a word',
125 | default=Defaults.min_token_length_lead,
126 | metavar='LEN',
127 | )
128 | blanc_parser.add_argument(
129 | '--min_token_length_followup',
130 | type=int,
131 | help=(
132 | 'minimum number of chars in followup token to mask, where a followup token '
133 | 'continues a word'
134 | ),
135 | default=Defaults.min_token_length_followup,
136 | metavar='LEN',
137 | )
138 | blanc_parser.add_argument(
139 | '--min_token_length_normal_tune',
140 | type=int,
141 | help=(
142 | 'minimum number of chars in normal tokens to mask at tuning, where a normal token is '
143 | 'a whole word'
144 | ),
145 | default=Defaults.min_token_length_normal_tune,
146 | metavar='LEN',
147 | )
148 | blanc_parser.add_argument(
149 | '--min_token_length_lead_tune',
150 | type=int,
151 | help='minimum number of chars in lead token to mask at tuning, where a lead token begins a word',
152 | default=Defaults.min_token_length_lead_tune,
153 | metavar='LEN',
154 | )
155 | blanc_parser.add_argument(
156 | '--min_token_length_followup_tune',
157 | type=int,
158 | help=(
159 | 'minimum number of chars in followup token to mask at tuning, where a followup token '
160 | 'continues a word'
161 | ),
162 | default=Defaults.min_token_length_followup_tune,
163 | metavar='LEN',
164 | )
165 | blanc_parser.add_argument(
166 | '--device', type=str, help='cpu or cuda device', default=Defaults.device,
167 | )
168 | blanc_parser.add_argument(
169 | '--random_seed',
170 | type=int,
171 | help='random seed for python and torch',
172 | default=Defaults.random_seed,
173 | metavar='SEED',
174 | )
175 | blanc_parser.add_argument(
176 | '--inference_batch_size',
177 | type=int,
178 | help='batch size to use during inference',
179 | default=Defaults.inference_batch_size,
180 | metavar='SIZE',
181 | )
182 | blanc_parser.add_argument(
183 | '--inference_mask_evenly',
184 | type=bool,
185 | help=(
186 |             'when True, mask every `gap` tokens (`gap_mask` tokens at once) that are longer than `min_token_length` '
187 |             'during inference; when False, randomly mask tokens with probability 0.15'
188 | ),
189 | default=Defaults.inference_mask_evenly,
190 | metavar='MASK_EVENLY',
191 | )
192 |
193 | help_parser = parser.add_argument_group('BLANC-help arguments')
194 | help_parser.add_argument(
195 | '--filler_token',
196 | type=str,
197 | help='token to use as filler in lieu of summary',
198 | default=Defaults.filler_token,
199 | metavar='TOKEN',
200 | )
201 | help_parser.add_argument(
202 | '--help_sep',
203 | type=str,
204 | help=(
205 | "token to use to separate the summary or filler from the sentence, "
206 | "or '' for no separator"
207 | ),
208 | default=Defaults.help_sep,
209 | metavar='SEP',
210 | )
211 |
212 | tune_parser = parser.add_argument_group('BLANC-tune arguments')
213 | tune_parser.add_argument(
214 | '--finetune_batch_size',
215 | type=int,
216 | help='batch size to use when finetuning on summary',
217 | default=Defaults.finetune_batch_size,
218 | metavar='SIZE',
219 | )
220 | tune_parser.add_argument(
221 | '--finetune_epochs',
222 | type=int,
223 | help='number of epochs to train for when finetuning on summary',
224 | default=Defaults.finetune_epochs,
225 | metavar='EPOCHS',
226 | )
227 | tune_parser.add_argument(
228 | '--finetune_mask_evenly',
229 | type=bool,
230 | help=(
231 |             'when True, mask every `gap` tokens (`gap_mask` tokens at once) that are longer than `min_token_length` '
232 |             'during finetuning; when False, randomly mask tokens with probability 0.15'
233 | ),
234 | default=Defaults.finetune_mask_evenly,
235 | metavar='MASK_EVENLY',
236 | )
237 | tune_parser.add_argument(
238 | '--finetune_chunk_size',
239 | type=int,
240 | help='number of summary tokens to use at a time when finetuning',
241 | default=Defaults.finetune_chunk_size,
242 | metavar='SIZE',
243 | )
244 | tune_parser.add_argument(
245 | '--finetune_chunk_stride',
246 | type=int,
247 | help='number of tokens between summary chunks for finetuning',
248 | default=Defaults.finetune_chunk_stride,
249 | metavar='STRIDE',
250 | )
251 | tune_parser.add_argument(
252 | '--learning_rate',
253 | type=float,
254 | help='learning rate when finetuning on summary',
255 | default=Defaults.learning_rate,
256 | metavar='LR',
257 | )
258 | tune_parser.add_argument(
259 | '--warmup_steps',
260 | type=int,
261 | help='warmup steps when finetuning on summary',
262 | default=Defaults.warmup_steps,
263 | metavar='STEPS',
264 | )
265 |
266 | args = parser.parse_args()
267 |
268 | random.seed(args.random_seed)
269 | np.random.seed(args.random_seed)
270 | torch.manual_seed(args.random_seed)
271 |
272 | if args.type == 'help':
273 | model = BlancHelp(
274 | model_name=args.model_name,
275 | measure=args.measure,
276 | gap=args.gap,
277 | gap_mask=args.gap_mask,
278 | gap_tune=args.gap_tune,
279 | gap_mask_tune=args.gap_mask_tune,
280 | min_token_length_normal=args.min_token_length_normal,
281 | min_token_length_lead=args.min_token_length_lead,
282 | min_token_length_followup=args.min_token_length_followup,
283 | min_token_length_normal_tune=args.min_token_length_normal_tune,
284 | min_token_length_lead_tune=args.min_token_length_lead_tune,
285 | min_token_length_followup_tune=args.min_token_length_followup_tune,
286 | device=args.device,
287 | inference_batch_size=args.inference_batch_size,
288 | inference_mask_evenly=args.inference_mask_evenly,
289 | filler_token=args.filler_token,
290 | help_sep=args.help_sep,
291 | )
292 | elif args.type == 'tune':
293 | model = BlancTune(
294 | model_name=args.model_name,
295 | measure=args.measure,
296 | gap=args.gap,
297 | gap_mask=args.gap_mask,
298 | gap_tune=args.gap_tune,
299 | gap_mask_tune=args.gap_mask_tune,
300 | min_token_length_normal=args.min_token_length_normal,
301 | min_token_length_lead=args.min_token_length_lead,
302 | min_token_length_followup=args.min_token_length_followup,
303 | min_token_length_normal_tune=args.min_token_length_normal_tune,
304 | min_token_length_lead_tune=args.min_token_length_lead_tune,
305 | min_token_length_followup_tune=args.min_token_length_followup_tune,
306 | device=args.device,
307 | inference_batch_size=args.inference_batch_size,
308 | inference_mask_evenly=args.inference_mask_evenly,
309 | finetune_batch_size=args.finetune_batch_size,
310 | finetune_epochs=args.finetune_epochs,
311 | finetune_mask_evenly=args.finetune_mask_evenly,
312 | finetune_chunk_size=args.finetune_chunk_size,
313 | finetune_chunk_stride=args.finetune_chunk_stride,
314 | learning_rate=args.learning_rate,
315 | warmup_steps=args.warmup_steps,
316 | )
317 |
318 | key = f"blanc-{args.type}-measure-{args.measure}"
319 | if args.doc is not None:
320 | result = model.eval_once(args.doc, args.summary)
321 | result_json = {key: result}
322 | elif args.single_json is not None:
323 | with open(args.single_json) as reader:
324 | data = json.load(reader)
325 |
326 | result = model.eval_once(data[args.doc_key], data[args.summary_key])
327 | result_json = {key: result}
328 | elif args.pairs_json is not None:
329 | with open(args.pairs_json) as reader:
330 | data = json.load(reader)
331 | docs = [pair[args.doc_key] for pair in data]
332 | summaries = [pair[args.summary_key] for pair in data]
333 |
334 | result = model.eval_pairs(docs, summaries)
335 | result_json = [{key: score} for score in result]
336 | elif args.doc_summaries_json is not None:
337 | with open(args.doc_summaries_json) as reader:
338 | data = json.load(reader)
339 | docs = [doc_summary[args.doc_key] for doc_summary in data]
340 | doc_summaries = [doc_summary[args.summaries_key] for doc_summary in data]
341 |
342 | result = model.eval_summaries_for_docs(docs, doc_summaries)
343 | result_json = [{key: scores} for scores in result]
344 | else:
345 | raise ValueError('Please provide an input document and summary')
346 |
347 | if args.output_json is None:
348 | print(result)
349 | else:
350 | with open(args.output_json, 'w') as writer:
351 | json.dump(result_json, writer)
352 |
353 |
354 | if __name__ == '__main__':
355 | main()
356 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Evaluation measures
2 |
3 | This repository contains reference implementations and explanations to accompany [Primer.ai](https://primer.ai) research and publications related to evaluation measures, mostly for the purpose of summary evaluation.
4 |
5 | These evaluation measures include:
6 |
7 | * BLANC-help (or simply 'BLANC'), BLANC-tune
8 | * blanc.py
9 |   * All the info is on this page
10 | * Shannon Score, Information Difference, BLANC-Shannon
11 | * shannon.py
12 | * Info: [Shannon Score and Information Difference](https://github.com/PrimerAI/blanc/tree/master/shannon)
13 | * ESTIME, ESTIME-soft, ESTIME-coherence
14 | * estime.py
15 | * Info: [ESTIME (hard, soft and coherence)](https://github.com/PrimerAI/blanc/tree/master/estime)
16 |
17 | Annotated summary quality datasets: [data](https://github.com/PrimerAI/blanc/tree/master/data)
18 |
19 |
20 | ## Setup
21 | 1. Install Python 3.6 or higher
22 | 2. Install with `pip install blanc`
23 |
24 |
25 | ## BLANC
26 | This is the reference implementation of BLANC-help and BLANC-tune as defined in [Fill in the BLANC: Human-free quality estimation of document summaries](https://www.aclweb.org/anthology/2020.eval4nlp-1.2/).
27 |
28 | BLANC is a reference-free approach to the automatic estimation of document summary quality. Our goal is to measure the functional performance of a summary with an objective, reproducible, and fully automated method. Our approach achieves this by measuring the performance boost gained by a pre-trained language model with access to a document summary while carrying out its language understanding task on the document's text. Unlike ROUGE, BLANC does not require human-written reference summaries, allowing for fully human-free summary quality estimation.
29 |
30 | Two types of BLANC scores were introduced in the paper and are available in this repo: BLANC-help and BLANC-tune. BLANC-help is faster to calculate (around 30% faster on CUDA with default settings), but BLANC-tune is more theoretically principled. They are around 90% correlated with each other, so either one can be used in most cases.
31 | BLANC-help with gap=2 on average correlates best with human scores ([Sensitivity of BLANC to human-scored qualities of text summaries](https://arxiv.org/abs/2010.06716)), so it is now the default; the original paper used gap=6. Optimal parameters for BLANC-help and BLANC-tune are found with the 'max-help' criterion, without relying on human summaries or human scores, in [Is Human Scoring the Best Criteria for Summary Evaluation?](https://aclanthology.org/2021.findings-acl.192) (the paper points to a possible bias of human experts).
32 |
33 |
34 | ## Python Usage
35 | Basic usage:
36 | ```python
37 | >>> from blanc import BlancHelp, BlancTune
38 | >>> document = "Jack drove his minivan to the bazaar to purchase milk and honey for his large family."
39 | >>> summary = "Jack bought milk and honey."
40 | >>> blanc_help = BlancHelp()
41 | >>> blanc_tune = BlancTune(finetune_mask_evenly=False, show_progress_bar=False)
42 | >>> blanc_help.eval_once(document, summary)
43 | 0.2222222222222222
44 | >>> blanc_tune.eval_once(document, summary)
45 | 0.3333333333333333
46 | ```
47 |
48 | By default, BLANC is run on the CPU. Using CUDA with batching is much faster:
49 | ```python
50 | blanc_help = BlancHelp(device='cuda', inference_batch_size=128)
51 | blanc_tune = BlancTune(device='cuda', inference_batch_size=24, finetune_mask_evenly=False, finetune_batch_size=24)
52 | ```
53 | With these batch sizes, BLANC-help takes around 1.4 sec per summary and BLANC-tune takes around 1.8 sec per summary on an NVIDIA V100. In addition to the parameters controlling device and batch sizes, BlancHelp and BlancTune take several other parameters controlling how the BLANC scores are calculated, and the default values for those parameters reproduce the results of the paper. BlancTune results may vary if random_seed is not set.
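
For reproducible BLANC-tune scores, the random seeds can be fixed the same way the CLI does in `blanc/__main__.py` (a minimal sketch):

```python
import random

import numpy as np
import torch

from blanc import BlancTune

# Fix the seeds before constructing the scorer, mirroring blanc/__main__.py.
random.seed(1)
np.random.seed(1)
torch.manual_seed(1)

blanc_tune = BlancTune(device='cuda', inference_batch_size=24, finetune_mask_evenly=False, finetune_batch_size=24)
```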
54 |
55 | If you want to compute the BLANC scores of many documents and summaries at once, you can use `eval_pairs()` or `eval_summaries_for_docs()`. `eval_pairs()` is useful when you have many documents, each with a single summary:
56 | ```python
57 | >>> documents = ["Jack drove his minivan to the bazaar to purchase milk and honey for his large family.", "As Jill started taking a walk in the park, she certainly noticed that the trees were extra green this year."]
58 | >>> summaries = ["Jack bought milk and honey.", "Jill saw green trees in the park."]
59 | >>> blanc_help.eval_pairs(documents, summaries)
60 | [0.2222222222222222, 0.0]
61 | ```
62 |
63 | `eval_summaries_for_docs()` is useful when you have many documents, each with many summaries:
64 | ```python
65 | >>> doc_summaries = [["Jack bought milk and honey.", "Jack drove to the bazaar in a minivan"], ["Jill saw green trees in the park.", "The trees were green."]]
66 | >>> blanc_tune.eval_summaries_for_docs(documents, doc_summaries)
67 | [[0.2222222222222222, 0.2222222222222222], [-0.07142857142857142, -0.14285714285714285]]
68 | ```
69 |
70 | ## CLI Usage
71 | A CLI for computing BLANC scores is provided for convenience.
72 | ```
73 | $ blanc help --gap 6 --doc "Jack drove his minivan to the bazaar to purchase milk and honey for his large family." --summary "Jack bought milk and honey."
74 | 0.1111111111111111
75 | ```
76 |
77 | Input data can also be provided in JSON format, with sample JSON input provided in `data/`
78 | ```
79 | $ blanc help --single_json data/single.json --gap 6
80 | 0.1111111111111111
81 | $ blanc tune --pairs_json data/pairs.json --gap 6 --finetune_mask_evenly False
82 | [0.2222222222222222, 0.14285714285714285]
83 | $ blanc tune --doc_summaries_json data/doc-summaries.json --gap 6 --finetune_mask_evenly False
84 | [[0.2222222222222222, 0.2222222222222222], [0.14285714285714285, 0.07142857142857142]]
85 | ```
86 |
87 | The `single_json` input format expects a single JSON blob with keys `doc` and `summary`. The `pairs_json` input format expects a list of JSON blobs, each with a `doc` and a `summary`. The `doc_summaries_json` input format expects a list of JSON blobs, each with keys `doc` and `summaries`, where `summaries` is a list of strings. These keys are customizable with the `doc_key`, `summary_key`, and `summaries_key` arguments. By default, the output is printed to STDOUT, but it can be written to a JSON file provided with the `output_json` argument.
88 |
89 | Full documentation is available with `blanc --help`:
90 | ```
91 | required arguments:
92 | {help,tune} BLANC-help or BLANC-tune
93 |
94 | input arguments:
95 | --doc DOC single input document (default: None)
96 | --summary SUMMARY single input summary (default: None)
97 | --single_json FILENAME
98 | filename for single document summary pair (default:
99 | None)
100 | --pairs_json FILENAME
101 | filename for list of document summary pairs (default:
102 | None)
103 | --doc_summaries_json FILENAME
104 | filename for list of documents, each with a list of
105 | summaries (default: None)
106 | --doc_key KEY json key for the input document (default: doc)
107 | --summary_key KEY json key for the input summary (single_json or
108 | pairs_json input) (default: summary)
109 | --summaries_key KEY json key for the input summaries (doc_summaries_json
110 | input) (default: summaries)
111 |
112 | arguments for BLANC-help and BLANC-tune:
113 | --model_name NAME BERT model type (default: bert-base-uncased)
114 | --measure {improve,relative}
115 | measure improve or relative, as defined in the paper
116 | (default: relative)
117 | --gap GAP distance between words to mask during inference
118 | (default: 2)
119 | --gap_mask NUM number of tokens to mask during inference at each
120 | gap-defined position
121 | (default: 1)
122 | --min_token_length_normal LEN
123 | minimum number of chars in normal tokens to mask,
124 | where a normal token is a whole word (default: 4)
125 | --min_token_length_lead LEN
126 | minimum number of chars in lead token to mask, where a
127 | lead token begins a word (default: 2)
128 | --min_token_length_followup LEN
129 | minimum number of chars in followup token to mask,
130 | where a followup token continues a word (default: 100)
131 | --device DEVICE cpu or cuda device (default: cpu)
132 | --random_seed SEED random seed for python and torch (default: 1)
133 | --inference_batch_size SIZE
134 | batch size to use during inference (default: 1)
135 | --inference_mask_evenly MASK_EVENLY
136 | when True, mask every `gap` tokens that are longer
137 |                         than `min_token_length` during inference, when False
138 | randomly mask tokens with probability 0.15 (default:
139 | True)
140 |
141 | BLANC-help arguments:
142 | --filler_token TOKEN token to use as filler in lieu of summary (default: .)
143 | --help_sep SEP token to use to separate the summary or filler from
144 | the sentence, or '' for no separator (default: )
145 |
146 | BLANC-tune arguments:
147 | --finetune_batch_size SIZE
148 | batch size to use when finetuning on summary (default:
149 | 1)
150 | --finetune_epochs EPOCHS
151 | number of epochs to train for when finetuning on
152 | summary (default: 10)
153 | --finetune_mask_evenly MASK_EVENLY
154 | when True, mask every `gap` tokens that are longer
155 |                         than `min_token_length` during finetuning, when False
156 | randomly mask tokens with probability 0.15 (default:
157 | False)
158 | --finetune_chunk_size SIZE
159 | number of summary tokens to use at a time when
160 | finetuning (default: 64)
161 | --finetune_chunk_stride STRIDE
162 | number of tokens between summary chunks for finetuning
163 | (default: 32)
164 | --learning_rate LR learning rate when finetuning on summary (default:
165 | 5e-05)
166 | --warmup_steps STEPS warmup steps when finetuning on summary (default: 0)
167 | ```
168 |
169 | ## BLANC on [SummEval](https://github.com/Yale-LILY/SummEval) dataset
170 | BLANC can run on top of any pretrained BERT or ALBERT model (more will be added). The table below lists correlations of BLANC with human scores on the human-annotated [SummEval](https://github.com/Yale-LILY/SummEval) dataset (described in [SummEval: Re-evaluating Summarization Evaluation](https://arxiv.org/abs/2007.12626v4)). The dataset contains 1600 text-summary pairs (100 texts x 16 systems). We show the correlation (Spearman and Kendall's Tau-c) between BLANC-help and the experts' average scores for each quality of the summary (coherence, consistency, fluency, relevance):
171 |
172 | |quality|model|Spearman|Kendall|
173 | |:---------------|:-----------|-----:|-----:|
174 | |coherence|bbu|0.122|0.09|
175 | |coherence|bbc|0.197|0.142|
176 | |coherence|blu|0.116|0.085|
177 | |coherence|blc|0.226|0.165|
178 | |coherence|bluw|0.083|0.06|
179 | |coherence|blcw|0.196|0.142|
180 | |coherence|ab|0.168|0.125|
181 | |coherence|al|0.152|0.111|
182 | |coherence|axl|0.15|0.11|
183 | |coherence|axxl|0.127|0.093|
184 | |consistency|bbu|0.19|0.094|
185 | |consistency|bbc|0.19|0.094|
186 | |consistency|blu|0.207|0.102|
187 | |consistency|blc|0.204|0.1|
188 | |consistency|bluw|0.167|0.082|
189 | |consistency|blcw|0.18|0.089|
190 | |consistency|ab|0.192|0.095|
191 | |consistency|al|0.199|0.098|
192 | |consistency|axl|0.179|0.088|
193 | |consistency|axxl|0.2|0.098|
194 | |fluency|bbu|0.089|0.051|
195 | |fluency|bbc|0.108|0.062|
196 | |fluency|blu|0.112|0.065|
197 | |fluency|blc|0.113|0.064|
198 | |fluency|bluw|0.107|0.061|
199 | |fluency|blcw|0.121|0.069|
200 | |fluency|ab|0.124|0.072|
201 | |fluency|al|0.132|0.076|
202 | |fluency|axl|0.119|0.069|
203 | |fluency|axxl|0.115|0.066|
204 | |relevance|bbu|0.216|0.156|
205 | |relevance|bbc|0.278|0.201|
206 | |relevance|blu|0.217|0.156|
207 | |relevance|blc|0.306|0.223|
208 | |relevance|bluw|0.194|0.14|
209 | |relevance|blcw|0.258|0.188|
210 | |relevance|ab|0.27|0.193|
211 | |relevance|al|0.267|0.192|
212 | |relevance|axl|0.245|0.176|
213 | |relevance|axxl|0.246|0.179|
214 |
215 | The [transformers](https://huggingface.co/transformers/pretrained_models.html) models are: bert-base-uncased (bbu), bert-base-cased (bbc), bert-large-uncased (blu), bert-large-cased (blc), bert-large-uncased-whole-word-masking (bluw), bert-large-cased-whole-word-masking (blcw), albert-base-v2 (ab), albert-large-v2 (al), albert-xlarge-v2 (axl), albert-xxlarge-v2 (axxl). The BLANC-help was used with the current default settings (gap=2, min_token_length_normal=4, min_token_length_lead=2, min_token_length_followup=100). All the p-values above are of order 10^-5 or lower.
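
These settings can also be passed explicitly together with one of the listed checkpoints; a sketch assuming the constructor accepts the same arguments the CLI exposes in `blanc/__main__.py`:

```python
from blanc import BlancHelp

# 'albert-large-v2' is the checkpoint abbreviated as (al) in the table above.
blanc_help_al = BlancHelp(
    model_name='albert-large-v2',
    gap=2,
    min_token_length_normal=4,
    min_token_length_lead=2,
    min_token_length_followup=100,
)
```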
216 |
217 | The system-level correlations (correlations between the 16-dimensional score vectors obtained by averaging each system's scores over the 100 texts) mostly have p-values that are too high. The table below shows only the correlations with p-values < 0.05:
218 |
219 | |quality|model|Spearman|p|Kendall|p|
220 | |:---------------|:-----------|-----:|-----:|-----:|-----:|
221 | |consistency|bbu|0.738|0.001|0.567|0.002|
222 | |consistency|bbc|0.759|0.001|0.533|0.003|
223 | |consistency|blu|0.724|0.002|0.567|0.002|
224 | |consistency|blc|0.788|0.0|0.567|0.002|
225 | |consistency|bluw|0.771|0.0|0.617|0.001|
226 | |consistency|blcw|0.791|0.0|0.6|0.001|
227 | |consistency|ab|0.724|0.002|0.583|0.001|
228 | |consistency|al|0.774|0.0|0.6|0.001|
229 | |consistency|axl|0.706|0.002|0.517|0.005|
230 | |consistency|axxl|0.812|0.0|0.617|0.001|
231 | |fluency|bbc|0.558|0.025|0.444|0.017|
232 | |fluency|blc|0.549|0.028|0.444|0.017|
233 | |fluency|bluw|0.525|0.037|0.377|0.043|
234 | |fluency|blcw|0.595|0.015|0.477|0.01|
235 | |fluency|al|0.518|0.04|0.393|0.034|
236 | |fluency|axxl|0.534|0.033|0.41|0.027|
237 | |relevance|bbc| | |0.467|0.011|
238 | |relevance|blc| | |0.467|0.011|
239 | |relevance|blcw|0.515|0.041|0.467|0.011|
240 |
--------------------------------------------------------------------------------
/blanc/utils.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import random
3 | import unicodedata
4 | import copy
5 |
6 | import torch
7 | from torch.nn.utils.rnn import pad_sequence
8 |
9 | # prefix used by the wordpiece tokenizer to indicate that the token continues the previous word
10 | WORDPIECE_PREFIX = '##'
11 | # a range of reasonable token ids to use for replacement during model training
12 | TOKEN_REPLACE_RANGE = (1000, 29999)
13 | # attention mask value that tells the model to not ignore the token
14 | NOT_MASKED = 1
15 | # token type for a single input sequence
16 | TOKEN_TYPE_A = 0
17 | # padding value used for attention_mask
18 | MASK_PAD = 0
19 | # padding value used for token_type_ids
20 | TOKEN_TYPE_PAD = 1
21 | # label to use for tokens to ignore when computing masked language modeling training loss
22 | LABEL_IGNORE = -100
23 | # maximum number of input tokens BERT supports
24 | BERT_MAX_TOKENS = 512
25 |
26 | # used to represent inputs to the BERT model
27 | BertInput = namedtuple(
28 | typename='BertInput',
29 | field_names=['input_ids', 'attention_mask', 'token_type_ids', 'labels', 'masked_idxs'],
30 | )
31 |
32 | # all the configuration options
33 | Config = namedtuple(
34 | 'Config',
35 | [
36 | 'doc_key',
37 | 'summary_key',
38 | 'summaries_key',
39 | 'model_name',
40 | 'measure',
41 | 'gap',
42 | 'gap_mask',
43 | 'gap_tune',
44 | 'gap_mask_tune',
45 | 'min_token_length_normal',
46 | 'min_token_length_lead',
47 | 'min_token_length_followup',
48 | 'min_token_length_normal_tune',
49 | 'min_token_length_lead_tune',
50 | 'min_token_length_followup_tune',
51 | 'device',
52 | 'random_seed',
53 | 'inference_batch_size',
54 | 'inference_mask_evenly',
55 | 'len_sent_allow_cut',
56 | 'filler_token',
57 | 'help_sep',
58 | 'finetune_batch_size',
59 | 'finetune_epochs',
60 | 'finetune_mask_evenly',
61 | 'finetune_chunk_size',
62 | 'finetune_chunk_stride',
63 | 'finetune_top_fully',
64 | 'id_layer_freeze_below',
65 | 'id_layer_freeze_above',
66 | 'show_progress_bar',
67 | 'p_mask',
68 | 'p_token_replace',
69 | 'p_token_original',
70 | 'learning_rate',
71 | 'warmup_steps',
72 | ],
73 | )
74 |
75 | # the default configuration options that don't require a GPU
76 | # We found gap=2 to work the best. To reproduce the original paper results use gap=6
77 | Defaults = Config(
78 | doc_key='doc',
79 | summary_key='summary',
80 | summaries_key='summaries',
81 | model_name='bert-base-uncased',
82 | measure='relative',
83 | gap=2,
84 | gap_mask=1,
85 | gap_tune=-1,
86 | gap_mask_tune=-1,
87 | min_token_length_normal=4,
88 | min_token_length_lead=2,
89 | min_token_length_followup=100,
90 | min_token_length_normal_tune=-1,
91 | min_token_length_lead_tune=-1,
92 | min_token_length_followup_tune=-1,
93 | device='cpu',
94 | random_seed=0,
95 | inference_batch_size=1,
96 | inference_mask_evenly=True,
97 | len_sent_allow_cut=100,
98 | filler_token='.',
99 | help_sep='',
100 | finetune_batch_size=1,
101 | finetune_epochs=10,
102 | finetune_chunk_size=64,
103 | finetune_chunk_stride=32,
104 | finetune_top_fully=True,
105 | id_layer_freeze_below=-1,
106 | id_layer_freeze_above=-1,
107 | show_progress_bar=True,
108 | p_mask=0.15,
109 | p_token_replace=0.1,
110 | p_token_original=0.1,
111 | learning_rate=5e-5,
112 | finetune_mask_evenly=True,
113 | warmup_steps=0,
114 | )
115 |
116 |
117 | def set_seed(seed_value):
118 | random.seed(seed_value)
119 | torch.manual_seed(seed_value)
120 |
121 |
122 | def batch_data(data, batch_size):
123 | """Given a list, batch that list into chunks of size batch_size
124 |
125 | Args:
126 | data (List): list to be batched
127 | batch_size (int): size of each batch
128 |
129 | Returns:
130 | batches (List[List]): a list of lists, each inner list of size batch_size except possibly
131 | the last one.
132 | """
133 | batches = [data[i : i + batch_size] for i in range(0, len(data), batch_size)]
134 | return batches
135 |
136 |
137 | def is_token_large_enough(token, next_token, min_token_lengths):
138 | """Determine if a token is large enough according to min_token_lengths
139 |
140 | Args:
141 | token (str): a wordpiece token
142 | next_token (str): the next wordpiece token in the sequence
143 | min_token_lengths (Tuple[int, int, int]): minimum token lengths for normal tokens, lead
144 | tokens, and followup tokens
145 |
146 | Returns:
147 | large_enough (bool): whether or not the token is large enough
148 | """
149 | min_normal, min_lead, min_followup = min_token_lengths
150 | token_size = len(token)
151 |
152 | if token.startswith(WORDPIECE_PREFIX):
153 | token_size -= len(WORDPIECE_PREFIX)
154 | return token_size >= min_followup
155 | elif next_token.startswith(WORDPIECE_PREFIX):
156 | return token_size >= min_lead
157 | else:
158 | return token_size >= min_normal
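# For example (illustrative), with the default min_token_lengths=(4, 2, 100):
#   is_token_large_enough('his', 'was', (4, 2, 100))   -> False (normal token, 3 < 4)
#   is_token_large_enough('go', '##ing', (4, 2, 100))  -> True  (lead token, 2 >= 2)
#   is_token_large_enough('##ing', '.', (4, 2, 100))   -> False (followup token, 3 < 100)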
159 |
160 |
161 | def mask_tokens_evenly(tokens, gap, min_token_lengths, mask_token, gap_mask=1):
162 | """Produce several maskings for the given tokens where each masking is created by masking every
163 | "gap" tokens, as long as the token is large enough according to min_token_lengths.
164 |
165 | Args:
166 | tokens (List[str]): a sequence of wordpiece tokens
167 | gap (int): the spacing in-between masked tokens
168 | min_token_lengths (Tuple[int, int, int]): minimum token lengths for normal tokens, lead
169 | tokens, and followup tokens
170 | mask_token (str): wordpiece token to use for masking
171 |
172 | Returns:
173 | masked_inputs (List[List[str]]): a list of token sequences, where each token sequence
174 | contains masked tokens separated by "gap" tokens.
175 | all_answers (List[Dict[int, str]]): a list of "answer" dicts, where each answer dict maps
176 | token indices corresponding to masked tokens back to their original token.
177 | """
178 | gap = min(gap, len(tokens))
179 | masked_inputs = []
180 | all_answers = []
181 | for modulus in range(gap):
182 | masked_input = []
183 | answers = {}
184 | for idx, token in enumerate(tokens):
185 | next_token = '' if idx + 1 == len(tokens) else tokens[idx + 1]
186 | large_enough = is_token_large_enough(token, next_token, min_token_lengths)
187 |
188 | idx_off = idx % gap
189 | if gap == 1:
190 | can_mask = True
191 | elif modulus + gap_mask >= gap:
192 | can_mask = idx_off >= modulus or idx_off < (modulus + gap_mask)%gap
193 | else:
194 | can_mask = idx_off >= modulus and idx_off < modulus + gap_mask
195 | if can_mask and large_enough:
196 | masked_input.append(mask_token)
197 | answers[idx] = token
198 | else:
199 | masked_input.append(token)
200 |
201 | if len(answers) > 0:
202 | masked_inputs.append(masked_input)
203 | all_answers.append(answers)
204 |
205 | return masked_inputs, all_answers
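# A small worked example (illustrative): with gap=2, gap_mask=1 (gap_mask is the number of
# consecutive tokens masked at each offset) and min_token_lengths=(4, 2, 100), the tokens
# ['jack', 'drove', 'his', 'minivan'] produce two maskings:
#   ['[MASK]', 'drove', 'his', 'minivan'] with answers {0: 'jack'}
#   ['jack', '[MASK]', 'his', '[MASK]']   with answers {1: 'drove', 3: 'minivan'}
# 'his' is never masked because it is shorter than the normal minimum of 4 characters.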
206 |
207 |
208 | def mask_tokens_randomly(tokens, min_token_lengths, mask_token, p_mask):
209 | """Produce several maskings for the given tokens by randomly choosing tokens to mask
210 |
211 | Args:
212 | tokens (List[str]): a sequence of wordpiece tokens
213 | min_token_lengths (Tuple[int, int, int]): minimum token lengths for normal tokens, lead
214 | tokens, and followup tokens
215 | mask_token (str): wordpiece token to use for masking
216 |
217 | Returns:
218 | masked_inputs (List[List[str]]): a list of token sequences, where each token sequence
219 | contains masked tokens chosen randomly.
220 | all_answers (List[Dict[int, str]]): a list of "answer" dicts, where each answer dict maps
221 | token indices corresponding to masked tokens back to their original token.
222 | """
223 | n_mask = max(int(len(tokens) * p_mask), 1)
224 |
225 | token_positions = []
226 | for idx, token in enumerate(tokens):
227 | next_token = '' if idx + 1 == len(tokens) else tokens[idx + 1]
228 | if is_token_large_enough(token, next_token, min_token_lengths):
229 | token_positions.append(idx)
230 | random.shuffle(token_positions)
231 |
232 | all_inputs, all_answers = [], []
233 | while len(token_positions) > 0:
234 | positions_to_mask = token_positions[:n_mask]
235 | token_positions = token_positions[n_mask:]
236 |
237 | inputs, answers = [], {}
238 | for idx, token in enumerate(tokens):
239 | if idx in positions_to_mask:
240 | inputs.append(mask_token)
241 | answers[idx] = token
242 | else:
243 | inputs.append(token)
244 |
245 | all_inputs.append(inputs)
246 | all_answers.append(answers)
247 |
248 | return all_inputs, all_answers
249 |
250 |
251 | def stack_tensor(input_list, pad_value, device):
252 | """Given a batch of inputs, stack them into a single tensor on the given device, padding them
253 | at the back with pad_value to make sure they are all the same length.
254 |
255 | Args:
256 | input_list (List[List[int]]): a list of input sequences
257 | pad_value (int): the value to use for padding input sequences to make them the same length
258 | device (str): torch device (usually "cpu" or "cuda")
259 |
260 | Returns:
261 | stacked_tensor (torch.LongTensor): a tensor of dimensions (batch size) x (seq length)
262 | """
263 | tensor_list = [torch.LongTensor(inputs) for inputs in input_list]
264 | stacked_tensor = pad_sequence(
265 | sequences=tensor_list, batch_first=True, padding_value=pad_value
266 | ).to(device)
267 |
268 | return stacked_tensor
269 |
270 |
271 | def get_input_tensors(input_batch, device, tokenizer):
272 | """Given a list of BertInputs, return the relevant tensors that are fed into BERT.
273 |
274 | Args:
275 | input_batch (List[BertInput]): a batch of model inputs
276 | device (str): torch device (usually "cpu" or "cuda")
277 | tokenizer (BertTokenizer): the wordpiece tokenizer used for BERT
278 |
279 | Returns:
280 | input_ids (torch.LongTensor): ids corresponding to input tokens
281 | attention_mask (torch.LongTensor): tells BERT about parts of the input to ignore
282 | token_type_ids (torch.LongTensor): used to differentiate input segments
283 | labels (torch.LongTensor): contains the original token ids for tokens that were masked
284 | """
285 | input_ids_list = [inputs.input_ids for inputs in input_batch]
286 | attention_mask_list = [inputs.attention_mask for inputs in input_batch]
287 | token_type_ids_list = [inputs.token_type_ids for inputs in input_batch]
288 | labels_list = [inputs.labels for inputs in input_batch]
289 |
290 | (id_pad,) = tokenizer.convert_tokens_to_ids([tokenizer.pad_token])
291 | input_ids = stack_tensor(input_ids_list, pad_value=id_pad, device=device)
292 | attention_mask = stack_tensor(attention_mask_list, pad_value=MASK_PAD, device=device)
293 | token_type_ids = stack_tensor(token_type_ids_list, pad_value=TOKEN_TYPE_PAD, device=device)
294 |
295 | if labels_list[0] is not None:
296 | labels = stack_tensor(labels_list, pad_value=LABEL_IGNORE, device=device)
297 | else:
298 | labels = None
299 |
300 | return input_ids, attention_mask, token_type_ids, labels
301 |
302 |
303 | def determine_correctness(outputs, answers):
304 | """Given dicts corresponding to predicted tokens and actual tokens at different indices, return
305 | a list of bools for whether or not those predictions were correct.
306 |
307 | Args:
308 | outputs (List[Dict[int, str]]): each list represents a different input masking, and each
309 | dict maps indices to model predictions
310 | answers (List[Dict[int, str]]): each list represents a different input masking, and each
311 | dict maps indices to original tokens
312 |
313 | Returns:
314 | correctness (List[bool]): a list of values that are True if the model made a correct
315 | prediction and False otherwise
316 | """
317 | correctness = []
318 | for output, answer in zip(outputs, answers):
319 | for idx, actual_token in answer.items():
320 | predicted_token = output[idx]
321 | correctness.append(predicted_token == actual_token)
322 |
323 | return correctness
324 |
325 |
326 | def measure_relative(S):
327 | """Calculate the "measure-relative" score as defined in the paper
328 |
329 | Args:
330 | S (List[List[int]]): accuracy counts as defined in the paper
331 |
332 | Returns:
333 | score (float): measure-relative score
334 | """
335 | denom = S[0][0] + S[1][1] + S[0][1] + S[1][0]
336 | if denom == 0:
337 | return 0
338 | return (S[0][1] - S[1][0]) / denom
339 |
340 |
341 | def measure_improve(S):
342 | """Calculate the "measure-improve" score as defined in the paper
343 |
344 | Args:
345 | S (List[List[int]]): accuracy counts as defined in the paper
346 |
347 | Returns:
348 | score (float): measure-improve score
349 | """
350 | denom = S[0][0] + S[1][1] + S[0][1]
351 | if denom == 0:
352 | return 0
353 | return S[0][1] / denom
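# A worked example (illustrative): with counts S = [[10, 5], [2, 8]], where S[base][assisted]
# counts inferences that were wrong/right (0/1) without and with the summary's help,
# measure_relative(S) = (5 - 2) / 25 = 0.12 and measure_improve(S) = 5 / 23 ≈ 0.217.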
354 |
355 |
356 | def clean_text(text):
357 | """Return a cleaned version of the input text
358 |
359 | Args:
360 | text (str): dirty text
361 |
362 | Returns:
363 | text (str): cleaned text
364 | """
365 | text = unicodedata.normalize('NFKD', text)
366 | return text
367 |
368 |
369 | def truncate_sentence_and_summary(
370 | sent, summary, len_sep=0, len_sent_allow_cut=0, truncate_bottom=True,
371 | ):
372 | """Cut summary+sentence to allowed input size. 2 more tokens: [CLS], [SEP]
373 | The summary must have at least one sublist (can be empty)
374 | The sentence is cut by tokens from the bottom.
375 | The summary is cut by sentences. Last sentence is cut by tokens.
376 |
377 | Args:
378 | sent (List[str]): Sentence as a list of tokens
379 | summary (List[List[str]]): Summary as list of sentences, each sentence is list of tokens
380 | len_sep (int): Number of tokens in a separator used between the summary and the sentence
381 | len_sent_allow_cut (int): Allowed size of truncated sentence before cutting summary
382 | truncate_bottom (bool): Indicator how to cut the summary
383 |
384 | Returns:
385 | sent (List[str]): Truncated (if necessary) sentence as a list of tokens
386 | summary_tokens (List[str]): Truncated (if necessary) summary as a list of tokens
387 | """
388 | summary_tokens = [t for sublist in summary for t in sublist]
389 | len_input_estimate = 2 + len(summary_tokens) + len_sep + len(sent)
390 | len_excess = len_input_estimate - BERT_MAX_TOKENS
391 | if len_excess > 0:
392 | len_cut_sent = min(len_excess, len(sent) - len_sent_allow_cut)
393 | len_sent_new = len(sent) - len_cut_sent
394 | sent = sent[:len_sent_new]
395 | assert len_excess <= len_cut_sent or summary[0]
396 | if len_excess > len_cut_sent:
397 | len_summary_max = BERT_MAX_TOKENS - 2 - len_sep - len(sent)
398 | summary_truncated = truncate_list_of_lists(
399 | sents_tokenized=summary, num_max=len_summary_max, truncate_bottom=truncate_bottom,
400 | )
401 | summary_tokens = [t for sublist in summary_truncated for t in sublist]
402 | assert len(sent) + len(summary_tokens) + len_sep + 2 <= BERT_MAX_TOKENS
403 | return sent, summary_tokens
404 |
405 |
406 | def truncate_list_of_lists(sents_tokenized, num_max, truncate_bottom=True):
407 | """Return a truncated list, with summ of tokens not exceeding maximum.
408 | Truncate by lists. If single left list is still too long, truncate it by tokens.
409 | In our context each element of sents_tokenized is a sentence represented as a list of tokens.
410 |
411 | Args:
412 | sents_tokenized (List[List[str]]): List, each element is a list.
413 | num_max (int): maximal allowed number of tokens.
414 | truncate_bottom (bool): truncate starting from bottom lists.
415 |
416 | Returns:
417 |         sents_truncated (List[List[str]]): truncated list of lists
418 | """
419 | sents_truncated = []
420 | if truncate_bottom:
421 | len_truncated = 0
422 | # Cut by sentences:
423 | for sent in sents_tokenized:
424 | len_truncated_maybe = len_truncated + len(sent)
425 | if len_truncated_maybe > num_max:
426 | break
427 | len_truncated = len_truncated_maybe
428 | sents_truncated.append(sent)
429 | if len_truncated == num_max:
430 | break
431 | else:
432 | sents_truncated = copy.deepcopy(sents_tokenized)
433 | len_truncated = sum([len(s) for s in sents_tokenized])
434 | # Cut by sentences:
435 | for sent in sents_tokenized:
436 | if len_truncated <= num_max:
437 | break
438 | sents_truncated = sents_truncated[1:]
439 | len_truncated = len_truncated - len(sent)
440 | if not sents_truncated:
441 | sent_use = sents_tokenized[0] if truncate_bottom else sents_tokenized[-1]
442 | sents_truncated = [copy.deepcopy(sent_use)]
443 | len_truncated = len(sents_truncated[0])
444 | # Cut by tokens - always from the top:
445 | if len_truncated > num_max:
446 | len_remove = len_truncated - num_max
447 | sents_truncated[0] = sents_truncated[0][len_remove:]
448 | assert sum([len(s) for s in sents_truncated]) <= num_max
449 | return sents_truncated
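# A small worked example (illustrative): for sentences [['a', 'b', 'c'], ['d', 'e'], ['f', 'g', 'h']]
# and num_max=4, truncate_bottom=True keeps [['a', 'b', 'c']] (adding the next sentence would
# exceed the limit), while truncate_bottom=False drops sentences from the top and keeps
# [['f', 'g', 'h']].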
450 |
--------------------------------------------------------------------------------
/blanc/estime.py:
--------------------------------------------------------------------------------
1 | import unicodedata
2 | import copy
3 | import scipy
4 | from numpy import dot
5 | from numpy.linalg import norm
6 |
7 | import logging
8 | logging.getLogger('transformers').setLevel(level=logging.WARNING)
9 |
10 | import nltk
11 | from nltk.tokenize import word_tokenize
12 |
13 | import torch
14 | from transformers import BertForMaskedLM, BertTokenizer, BertModel
15 |
16 |
17 | class Estime:
18 | """Estimator of factual inconsistencies between summaries (or other textual claims)
19 | and text. Usage: create `Estime`, and use `evaluate_claims()`.
20 | In creating Estime, specify the names of the desired measures in the list 'output'.
21 | The function evaluate_claims() will return (for each claim) the list of results in
22 |     the same order. The list 'output' can include:
23 | 'alarms': the original ESTIME
24 | 'alarms_adjusted': the original ESTIME, extrapolated to non-overlapping tokens
25 | 'alarms_alltokens': ESTIME on all (not only overlapped) summary tokens
26 | 'soft': the soft ESTIME
27 | 'coherence': measure of summary coherence
28 | """
29 | def __init__(
30 | self,
31 | path_mdl='bert-large-uncased-whole-word-masking',
32 | path_mdl_raw='bert-base-uncased',
33 | i_layer_context=21,
34 | device='cpu',
35 | output=['alarms'],
36 | tags_check=None,
37 | tags_exclude=None,
38 | input_size_max=450,
39 | margin=50,
40 | distance_word_min=8):
41 | """
42 | Args:
43 | path_mdl (str): model for embeddings of masked tokens
44 | path_mdl_raw (str): model for raw embeddings
45 | i_layer_context (int): index of layer to take contextual embeddings from
46 | device (str): 'cpu' or 'cuda'
47 | tags_check (List[str]): list of parts of speech to use, each is a tag
48 | by NLTK notations. If None, then all words will be used,
49 | no matter which parts of speech they are.
50 |             tags_exclude (List[str]): List or set of part-of-speech tags that should
51 | not be considered. Has priority over tags_check.
52 | input_size_max (int): length of input for the model as number of tokens
53 | margin (int): number of tokens on input edges not to take embeddings from
54 | distance_word_min (int): minimal distance between the masked tokens
55 |                 from which embeddings are to be taken in a single input
56 | output (List[str]): list of names of measures to get in claim evaluations.
57 | The list must be nonempty and can include:
58 | 'alarms', 'alarms_adjusted', 'alarms_alltokens', 'soft', 'coherence'.
59 | The function evaluate_claims() will return (for each claim) the list of
60 | results in the same order.
61 | """
62 | self.i_layer_context = i_layer_context
63 | self.device = device
64 | assert output
65 | self.output = output
66 | self.ESTIME_ALARMS = 'alarms' # original estime: by response of tokens to similar contexts
67 | self.ESTIME_ALLTOKENS = 'alarms_alltokens' # estime on all (not only overlapped) summary tokens
68 | self.ESTIME_ADJUSTED = 'alarms_adjusted' # original estime, extrapolated to non-overlapping tokens
69 | self.ESTIME_SOFT = 'soft' # soft estime by response of similarity between embeddings
70 | self.ESTIME_COHERENCE = 'coherence' # estimation of summary coherence
71 | self.get_estime_adjusted = self.ESTIME_ADJUSTED in self.output
72 | self.get_estime_soft = self.ESTIME_SOFT in self.output
73 | self.get_estime_coherence = self.ESTIME_COHERENCE in self.output
74 | self.tags_check = tags_check
75 | self.tags_exclude = tags_exclude
76 | self.input_size_max = input_size_max
77 | self.margin = margin
78 | self.distance_word_min = distance_word_min
79 | self.model_tokenizer = None
80 | self.model = None
81 | self.model_tokenizer = BertTokenizer.from_pretrained(path_mdl)
82 | self.model = BertForMaskedLM.from_pretrained(path_mdl, output_hidden_states = True).to(self.device)
83 | self.model.eval()
84 | if self.get_estime_soft:
85 | self.model_raw = BertModel.from_pretrained(path_mdl_raw)
86 | self.model_raw.eval()
87 | for param in self.model_raw.parameters():
88 | param.requires_grad = False
89 | self.embeddings_raw = self.model_raw.get_input_embeddings()
90 | # Convenient data (including tokenized summary and text):
91 | self.text_map_wordscheck = None
92 | self.summ_toks = None
93 | self.text_toks = None
94 | self.embs_mask_text = None
95 | self.embs_raw_text = None
96 |
97 |
98 | def evaluate_claims(self, text, claims):
99 | """
100 | Given a text, and a list of claims (e.g. summaries, or other texts),
101 | estimates how many likely factual inconsistencies each claim contains
102 | (the inconsistencies are with respect to the text).
103 | Returns for each claim whatever is specified in self.output
104 | Args:
105 | text (str): the main text
106 | claims (List[str]): texts ('claims') to be checked for consistency
107 | with the main text. Each claim is preferably a shorter text,
108 | like a claim/statement/summary.
109 | Returns:
110 | claims_info: list of the same length as claims; each element is a
111 | consistency info for the corresponding claim. The info is a list
112 |             according to the names in self.output.
113 | """
114 | # Text:
115 | text = unicodedata.normalize('NFKD', text)
116 | text_words = word_tokenize(text)
117 | text_tagged = nltk.pos_tag(text_words)
118 | # Find all words of interest in the text, tokenize:
119 | self.text_map_wordscheck = self._get_map_words_intext(text_tagged)
120 | text_iwordscheck = sorted([item for sublist in self.text_map_wordscheck.values() for item in sublist])
121 | self.text_toks, text_map_iword_itoks = self._translate_words_to_tokens(text_words)
122 | # All embeddings of interest in the text:
123 | self.embs_mask_text = self._get_embeddings(
124 | tokens=self.text_toks,
125 | ixs_words=text_iwordscheck,
126 | map_words_to_tokens=text_map_iword_itoks)
127 | self.embs_raw_text = self._get_embeddings_raw(tokens=self.text_toks)
128 | # Get the consistency info for each claim:
129 | claims_info = []
130 | for claim in claims:
131 | claim = unicodedata.normalize('NFKD', claim)
132 | claim_info = self._evaluate_claim(claim)
133 | claims_info.append(claim_info)
134 | self.summ_toks = None
135 | self.text_toks = None
136 | return claims_info
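    # A minimal usage sketch (illustrative; assumes the default models on CPU):
    #     estimator = Estime(output=['alarms', 'soft', 'coherence'])
    #     text = "Jack drove his minivan to the bazaar to purchase milk and honey."
    #     claims = ["Jack bought milk and honey.", "Jack sold his minivan."]
    #     results = estimator.evaluate_claims(text, claims)
    #     # results[i] is the list [alarms, soft, coherence] for claims[i]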
137 |
138 |
139 | def _evaluate_claim(self, claim, words_check=None):
140 | """
141 | Text is already processed, its embeddings can be used for the claim.
142 | Args:
143 | claim (str): claim, e.g. summary or short text - not the main text
144 | words_check (Set{str}): a set or map where the keys are the words
145 | of interest in the main text
146 | Returns:
147 | estime_info (List[float]): a list with results corresponding to the
148 | names of measures specified in self.output.
149 | """
150 | summ_words = word_tokenize(claim)
151 | summ_tagged = nltk.pos_tag(summ_words)
152 | summ_iwordscheck, summ_iwords_overlap = [],[] # Find all words of interest in the summary
153 | for i, (w, t) in enumerate(summ_tagged):
154 | if not words_check or w in words_check: # if required, checking only what exists in the text
155 | summ_iwordscheck.append(i)
156 | if not self.text_map_wordscheck or w in self.text_map_wordscheck:
157 | summ_iwords_overlap.append(i)
158 | self.summ_toks, summ_map_iword_itoks = self._translate_words_to_tokens(summ_words)
159 | embs_mask_summ = self._get_embeddings(
160 | tokens=self.summ_toks, ixs_words=summ_iwordscheck, map_words_to_tokens=summ_map_iword_itoks)
161 | embs_raw_summ = self._get_embeddings_raw(tokens=self.summ_toks)
162 | summ_itoksoverlap = set()
163 | for iword in summ_iwords_overlap:
164 | itok = summ_map_iword_itoks[iword][0] # only first token of each word
165 | summ_itoksoverlap.add(itok)
166 | estime_info = self._evaluate(
167 | embs_mask_summ, self.embs_mask_text,
168 | embs_raw_summ, self.embs_raw_text, summ_itokscheck=summ_itoksoverlap)
169 | return estime_info
170 |
171 |
172 | def _get_embeddings_raw(self, tokens):
173 | """Simply gets raw embeddings. Needed only for estime-soft."""
174 | if not self.get_estime_soft:
175 | return None
176 | toks_ids = self.model_tokenizer.convert_tokens_to_ids(tokens)
177 | input_tensor = torch.LongTensor([toks_ids])
178 | word_embeddings = self.embeddings_raw(input_tensor)[0].numpy()
179 | return word_embeddings
180 |
181 |
182 | def _get_embeddings(self, tokens, ixs_words, map_words_to_tokens):
183 | """
184 | Finds embeddings for all tokens of all words of interest. The embeddings
185 | are obtained one group of words at a time; each group contains well
186 |         separated indexes, so that masked indexes have enough context around them.
187 | Args:
188 | tokens (List[str]): List of tokens, as strings. Represents the text.
189 | ixs_words (List[int]): List of indexes of words to check
190 | map_words_to_tokens Dict[int:(int,int)]:
191 | Maps each index of word in the text to its range of tokens, the
192 | range is: index of first token, index of one-past-last token.
193 | Returns:
194 | map_itok_embeddings (Dict{int: ndarray[float]}):
195 | Map of token index (in the text) to its obtained embedding
196 | """
197 | # groups of well-separated words, represented by their indexes:
198 | groups = self._group_indexes_separated(ixs=ixs_words)
199 | map_itok_embeddings = {}
200 | for group in groups:
201 | map_itok_embeds = self._get_embeddings_of_sparse_words(
202 | tokens=tokens,
203 | ixs_words=group,
204 | map_words_to_tokens=map_words_to_tokens)
205 | map_itok_embeddings = {**map_itok_embeddings, **map_itok_embeds}
206 | return map_itok_embeddings
207 |
208 |
209 | def _get_embeddings_of_sparse_words(self, tokens, ixs_words, map_words_to_tokens):
210 | """Gets results for the get_embeddings function, which combines the results
211 | from tokens from groups of sparsely spread words.
212 | Here the result is obtained for one group of sparsely separated words.
213 | Args:
214 | tokens (List[str]): List of tokens, as strings. Represents the text.
215 | ixs_words (List[int]): List of indexes of words to check
216 | map_words_to_tokens Dict[int:(int,int)]:
217 | Maps each index of word in the text to its range of tokens, the
218 | range is: index of first token, index of one-past-last token.
219 | Returns:
220 | map_itok_embeds (Dict{int: ndarray[float]}): Map of token index (in text)
221 | to its obtained embedding
222 | """
223 | map_itok_embeddings = {}
224 | toks_mask = [map_words_to_tokens[i] for i in ixs_words] # indexes (beg,end) of tokens to mask
225 | while toks_mask:
226 | i_tok_first = toks_mask[0][0] # first token allowed to mask
227 | ix_toks_input_beg = max(0, i_tok_first - self.margin) # input starts here
228 | ix_toks_input_end = min(len(tokens), ix_toks_input_beg + self.input_size_max) # input ends here
229 | i_tok_allowed_last = ix_toks_input_beg + self.input_size_max - self.margin # last token allowed to mask
230 | toks_mask_input = [] # tokens to be masked in the input
231 | for word_toks in toks_mask:
232 | if word_toks[0] >= i_tok_first and word_toks[1]-1 <= i_tok_allowed_last:
233 | toks_mask_input.append(word_toks)
234 | if word_toks[0] > i_tok_allowed_last:
235 | break
236 | # for preparing next input:
237 | n_words_used = len(toks_mask_input)
238 | toks_mask = toks_mask[n_words_used:]
239 | # get embeddings for the input:
240 | map_itok_embeds = self._get_embeddings_from_input(
241 | tokens,
242 | ix_toks_input_beg=ix_toks_input_beg,
243 | ix_toks_input_end=ix_toks_input_end,
244 | toks_mask_input=toks_mask_input)
245 | map_itok_embeddings = {**map_itok_embeddings, **map_itok_embeds} # from all inputs so far
246 | return map_itok_embeddings
247 |
248 |
249 | def _get_embeddings_from_input(self, tokens, ix_toks_input_beg, ix_toks_input_end, toks_mask_input):
250 | """
251 | Gets embeddings for one specific input window.
252 | Returns embeddings for first tokens of each word of interest, while all
253 | tokens of the word are masked.
254 | Args:
255 | tokens (List[str]): Tokens of a summary or text
256 | ix_toks_input_beg (int): Index of first token of the input window
257 | ix_toks_input_end (int): Index of the end of the input window
258 | toks_mask_input (List[(int,int)]): Indexes of all tokens to mask
259 |                 in the input window, given as a list of duples, each duple
260 |                 is the index of the first and the one-past-last token to mask.
261 | Returns:
262 | map_itok_embeds (Dict{int: ndarray[float]}): Map of token indexes
263 | as in text tokens to their embeddings. Covers first tokens
264 | of all words of interest.
265 | """
266 | # Ids of tokens in input for taking embedding
267 | input_toks = copy.deepcopy(tokens[ix_toks_input_beg:ix_toks_input_end])
268 | map_itok_iembed = {}
269 | # Do masking, and also keep record of where to take embeddings from:
270 | for word_toks in toks_mask_input:
271 | i_tok_first = word_toks[0] # first token of the word
272 | map_itok_iembed[i_tok_first] = 1 + i_tok_first - ix_toks_input_beg # shift=1 by first [CLS]
273 | for i in range(word_toks[0], word_toks[1]):
274 | input_toks[i - ix_toks_input_beg] = '[MASK]'
275 | # Get embeddings of interest for this input:
276 | toks_ids = self.model_tokenizer.convert_tokens_to_ids(['[CLS]'] + input_toks + ['[SEP]'])
277 | input_tensor = torch.LongTensor([toks_ids]).to(self.device)
278 | outputs = self.model(input_tensor)
279 | # Contextual embedding:
280 | emb_all = outputs[1][self.i_layer_context][0] # all embeddings (for all tokens, at this layer)
281 | map_itok_embed = {}
282 | for itok, iembed in map_itok_iembed.items(): # itok is id of token exactly as in tokens
283 | map_itok_embed[itok] = emb_all[iembed].cpu().detach().numpy()
284 | return map_itok_embed
285 |
286 |
287 | def _evaluate(self, embs_summ, embs_text, embs_summ_raw, embs_text_raw, summ_itokscheck=None):
288 | """
289 | Args:
290 | embs_summ (Dict{int: List[float]}): Map of token indexes
291 | (in summary) to the corresponding embeddings
292 | embs_text (Dict{int: List[float]}): Map of token indexes
293 | (in text) to the corresponding embeddings
294 | embs_summ_raw and embs_text_raw - the same as above, but with the raw embeddings
295 | summ_itokscheck (Set[int]): Set of indexes of tokens (in summary)
296 | that must be verified for alarms.
297 | This is needed for calculating the original and 'adjusted' ESTIME.
298 | Returns:
299 |             List[float] - List of results in the order specified in self.output
300 | """
301 | # estime standard, adjusted, all-tokens and soft:
302 | n_alarms, n_alarms_alltoks, cos_raw_avg = 0, 0, 0
303 | itoks_similar = []
304 | for itok_summ, emb_summ in embs_summ.items():
305 | tok_summ = self.summ_toks[itok_summ]
306 | itok_text_best = -1
307 | sim_best = -1.0e30
308 | for itok_text, emb_text in embs_text.items():
309 | sim = dot(emb_summ, emb_text)
310 | if sim > sim_best:
311 | sim_best = sim
312 | itok_text_best = itok_text
313 | tok_text_best = self.text_toks[itok_text_best]
314 | itoks_similar.append((itok_summ, itok_text_best, sim_best))
315 | if tok_text_best != tok_summ:
316 | n_alarms_alltoks += 1
317 | if not summ_itokscheck or itok_summ in summ_itokscheck:
318 | n_alarms += 1
319 | # Soft estime:
320 | if self.get_estime_soft:
321 | emb_summ_nomask = embs_summ_raw[itok_summ]
322 | emb_text_nomask = embs_text_raw[itok_text_best]
323 | prod = dot(emb_summ_nomask, emb_text_nomask)
324 | norm_summ, norm_text = norm(emb_summ_nomask), norm(emb_text_nomask)
325 | cos_raw_avg += prod / (norm_summ * norm_text)
326 | if self.get_estime_soft:
327 | cos_raw_avg /= len(embs_summ)
328 | # estime-alarms-adjusted:
329 | if self.get_estime_adjusted:
330 | if not summ_itokscheck:
331 | n_alarms_adj = len(embs_summ)
332 | else:
333 | n_alarms_adj = n_alarms * len(embs_summ) / len(summ_itokscheck)
334 | # Coherence:
335 | if self.get_estime_coherence:
336 | itoks_summ = [a[0] for a in itoks_similar]
337 | itoks_text = [a[1] for a in itoks_similar]
338 | coherence = scipy.stats.kendalltau(itoks_summ, itoks_text, variant='c').correlation
339 | result = []
340 | for out_name in self.output:
341 | if out_name == self.ESTIME_ALARMS:
342 | result.append(n_alarms)
343 | elif out_name == self.ESTIME_ADJUSTED:
344 | result.append(n_alarms_adj)
345 | elif out_name == self.ESTIME_ALLTOKENS:
346 | result.append(n_alarms_alltoks)
347 | elif out_name == self.ESTIME_SOFT:
348 | result.append(cos_raw_avg)
349 | elif out_name == self.ESTIME_COHERENCE:
350 | result.append(coherence)
351 | return result
352 |
353 |
354 | def _select_indexes_separated(self, ixs):
355 | """Given a list of sorted integers, starts with the first and selects next ones
356 |         in such a way that the difference between neighbors is not smaller than the given
357 | value. Meaning: the integers are the indexes of words in a text.
358 | Args:
359 | ixs (List[int]): list of indexes
360 | Returns:
361 | ixs_select (List[int]): list of well-separated selected indexes
362 | ixs_remain (List[int]): list of all the remaining indexes
363 | """
364 | ixs_remain = []
365 | ixs_select = []
366 | ix_prev = -1000000
367 | for ix in ixs:
368 | if ix - ix_prev >= self.distance_word_min:
369 | ixs_select.append(ix)
370 | ix_prev = ix
371 | else:
372 | ixs_remain.append(ix)
373 | return ixs_select, ixs_remain
374 |
375 |
376 | def _group_indexes_separated(self, ixs):
377 | """Splits a sorted list of indexes (of words in a text) to groups (lists)
378 | of indexes, such that indexes in each groups are separated by the given
379 | minimal distance.
380 | Args:
381 | ixs (List[int]): list of indexes
382 | Returns:
383 | groups (List[List[int]]): list of lists of indexes. Each list of indexes
384 | contains well-separated indexes, satisfying the distance_word_min.
385 | """
386 | groups = []
387 | ixs_remain = copy.deepcopy(ixs)
388 | while ixs_remain:
389 | ixs_select, ixs_remain = self._select_indexes_separated(ixs_remain)
390 | groups.append(ixs_select)
391 | return groups
392 |
393 |
394 | def _get_map_words_intext(self, text_tagged):
395 | """
396 | Creates dictionary of words in the text, with all occurrences for each word
397 | Args:
398 |             text_tagged (List[(str,str)]): List of duples, each is a word and its
399 |                 part-of-speech tag. The list is the result of the nltk.pos_tag function.
400 | Returns:
401 | map_words_text Dict{str:List[int]}: Dictionary, key is word from the text,
402 | value is List[int] - list of all word occurrence indexes in the text
403 | """
404 | map_words_text = {}
405 | for i, (w, t) in enumerate(text_tagged):
406 | if self.tags_check and t not in self.tags_check:
407 | continue
408 | if self.tags_exclude and t in self.tags_exclude:
409 | continue
410 | if w not in map_words_text:
411 | map_words_text[w] = [i]
412 | else:
413 | map_words_text[w].append(i)
414 | return map_words_text
415 |
416 |
417 | def _translate_words_to_tokens(self, text_words):
418 | """Tokenizes text by model tokenizer.
419 | Keeps map of indexes of words into indexes of tokens.
420 | Args:
421 | text_words (List[str]): Text given as a list of words
422 | Returns:
423 | text_tokens (List[str]): Text given as list of tokens
424 | map_iword_itoks (Dict[int:(int,int)]): Dictionary of the same length
425 | as text_words. Word index points to duple of token indexes:
426 | index of the first token, and index of the end (one-past-last) token.
427 | """
428 | text_tokens = []
429 | map_iword_itoks = {}
430 | i_tok = 0
431 | for ix_word, word in enumerate(text_words):
432 | toks = self.model_tokenizer.tokenize(word)
433 | text_tokens.extend(toks)
434 | i_tok_end = i_tok + len(toks)
435 | map_iword_itoks[ix_word] = (i_tok, i_tok_end)
436 | i_tok = i_tok_end
437 | return text_tokens, map_iword_itoks
438 |
--------------------------------------------------------------------------------
/blanc/blanc.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 | import random
4 |
5 | logging.getLogger('transformers').setLevel(level=logging.WARNING)
6 |
7 | from nltk.tokenize import sent_tokenize
8 | import torch
9 | from torch.nn.utils.rnn import pad_sequence
10 | import tqdm
11 | from transformers import BertForMaskedLM, BertTokenizer, AdamW, get_linear_schedule_with_warmup
12 | from transformers import AlbertForMaskedLM, AlbertTokenizer
13 |
14 | from blanc.utils import (
15 | BertInput,
16 | Defaults,
17 | batch_data,
18 | mask_tokens_evenly,
19 | mask_tokens_randomly,
20 | get_input_tensors,
21 | determine_correctness,
22 | measure_relative,
23 | measure_improve,
24 | clean_text,
25 | truncate_list_of_lists,
26 | truncate_sentence_and_summary,
27 | set_seed,
28 | NOT_MASKED,
29 | TOKEN_TYPE_A,
30 | LABEL_IGNORE,
31 | BERT_MAX_TOKENS,
32 | TOKEN_REPLACE_RANGE,
33 | )
34 |
35 |
36 | class Blanc:
37 | """An abstract superclass containing shared functionality between BlancHelp and BlancTune.
38 | measure ('relative' or 'improve') is a choice of how the success of inference is measured.
39 |     Add '-counts' to also return the counts: 'relative-counts' or 'improve-counts'.
40 | """
41 |
42 | def __init__(
43 | self,
44 | model_name=Defaults.model_name,
45 | measure=Defaults.measure,
46 | gap=Defaults.gap,
47 | gap_mask=Defaults.gap_mask,
48 | gap_tune=Defaults.gap_tune,
49 | gap_mask_tune=Defaults.gap_mask_tune,
50 | min_token_length_normal=Defaults.min_token_length_normal,
51 | min_token_length_lead=Defaults.min_token_length_lead,
52 | min_token_length_followup=Defaults.min_token_length_followup,
53 | min_token_length_normal_tune=Defaults.min_token_length_normal_tune,
54 | min_token_length_lead_tune=Defaults.min_token_length_lead_tune,
55 | min_token_length_followup_tune=Defaults.min_token_length_followup_tune,
56 | device=Defaults.device,
57 | inference_batch_size=Defaults.inference_batch_size,
58 | inference_mask_evenly=Defaults.inference_mask_evenly,
59 | len_sent_allow_cut=Defaults.len_sent_allow_cut,
60 | p_mask=Defaults.p_mask,
61 | show_progress_bar=Defaults.show_progress_bar,
62 | ):
63 | """This class should not be instantiated directly: instead use BlancHelp or BlancTune"""
64 | self.model_name = model_name
65 | self.measure = measure
66 | self.gap = gap
67 | self.gap_mask = gap_mask
68 | self.gap_tune = gap_tune
69 | self.gap_mask_tune = gap_mask_tune
70 | self.min_token_length_normal = min_token_length_normal
71 | self.min_token_length_lead = min_token_length_lead
72 | self.min_token_length_followup = min_token_length_followup
73 | self.min_token_length_normal_tune = min_token_length_normal_tune
74 | self.min_token_length_lead_tune = min_token_length_lead_tune
75 | self.min_token_length_followup_tune = min_token_length_followup_tune
76 | self.device = device
77 | self.inference_batch_size = inference_batch_size
78 | self.inference_mask_evenly = inference_mask_evenly
79 | self.len_sent_allow_cut = len_sent_allow_cut
80 | self.p_mask = p_mask
81 | self.show_progress_bar = show_progress_bar
82 |
83 |         # Tune-specific values are intentionally not given their own defaults; fall back to the inference values:
84 | self.gap_tune = self.gap if self.gap_tune < 0 else self.gap_tune
85 | self.gap_mask_tune = self.gap_mask if self.gap_mask_tune < 0 else self.gap_mask_tune
86 |
87 | if self.model_name.lower().find('albert') >= 0:
88 | self.model_tokenizer = AlbertTokenizer.from_pretrained(model_name)
89 | else:
90 | self.model_tokenizer = BertTokenizer.from_pretrained(model_name)
91 |
92 | def eval_once(self, doc, summary):
93 | """Calculate the BLANC score for a single doc with a single summary.
94 |
95 | Args:
96 | doc (str): The input document
97 | summary (str): The input summary for the input document
98 |
99 | Returns:
100 | score (float): The BLANC score for the input
101 | """
102 | (doc_score,) = self.eval_summaries_for_docs([doc], [[summary]])
103 | (score,) = doc_score
104 | return score
105 |
106 | def eval_pairs(self, docs, summaries):
107 | """Calculate the BLANC score for multiple docs, each with a single summary
108 |
109 | Args:
110 | docs (List[str]): A list of input documents
111 | summaries (List[str]): The input summary for each input document
112 |
113 | Returns:
114 | scores (List[float]): The BLANC scores for the inputs
115 | """
116 | doc_summaries = [[summary] for summary in summaries]
117 | full_scores = self.eval_summaries_for_docs(docs, doc_summaries)
118 | scores = [score for score, in full_scores]
119 | return scores
120 |
121 | def eval_summaries_for_docs(self, docs, doc_summaries):
122 | """Calculate the BLANC score for multiple docs, each with multiple summaries
123 |
124 | Args:
125 | docs (List[str]): A list of input documents
126 | doc_summaries (List[List[str]]): A list of summaries for every input document
127 |
128 | Returns:
129 | scores (List[List[float]]): A list of blanc scores corresponding to each summary for
130 | each document
131 | """
132 | raise NotImplementedError()
133 |
134 | def get_inputs_for_sentence(self, sent_tokens, summary_tokens):
135 | """Used by subclasses to specify inference inputs corresponding to a sentence
136 |
137 | Args:
138 | sent_tokens (List[str]): list of tokens corresponding to sentence
139 | summary_tokens (List[str]): list of tokens corresponding to a summary
141 |
142 | Returns:
143 | inputs (List[BertInput]): a list of masked token inputs to BERT
144 | answers (List[Dict[int, str]]): a list of "answer" dicts, where each answer dict maps
145 | token indices corresponding to masked tokens back to their original token.
146 | """
147 | raise NotImplementedError()
148 |
149 | def mask_and_infer(self, model, docs, doc_summaries, sep=None):
150 | """Run the given model on masked versions of the provided doc_summaries and collect model
151 | output
152 |
153 | Args:
154 | model (BertForMaskedLM): a BERT for masked language modeling torch model
155 | docs (List[str]): A list of input documents
156 | doc_summaries (List[List[str]]): A list of summaries for every input document
157 | sep (str): Separator between the inference help (summary) and a sentence from the doc
158 |
159 | Returns:
160 | all_outputs (List[List[List[Dict[int, str]]]]): for each doc, for each summary for the
161 | doc, for each input sequence for the summary, we have a dict mapping indices to
162 | model predictions
163 | all_answers (List[List[List[Dict[int, str]]]]): for each doc, for each summary for the
164 | doc, for each input sequence for the summary, we have a dict mapping indices to
165 | original tokens
166 | """
167 |
168 | # Prepare inputs
169 | all_inputs, all_answers = [], []
170 | for doc, summaries in zip(docs, doc_summaries):
171 | doc_inputs, doc_answers = [], []
172 | for summary in summaries:
173 | summary_inputs, summary_answers = self.get_inference_inputs(doc, summary, sep)
174 | doc_inputs.append(summary_inputs)
175 | doc_answers.append(summary_answers)
176 | all_inputs.append(doc_inputs)
177 | all_answers.append(doc_answers)
178 |
179 | # Run inference in batches
180 | inputs_per_summary_per_doc = [
181 | [len(inputs) for inputs in summary_input] for summary_input in all_inputs
182 | ]
183 | collapsed_inputs = sum(sum(all_inputs, []), [])
184 | batched_inputs = batch_data(collapsed_inputs, self.inference_batch_size)
185 |
186 | iterator = tqdm.tqdm(batched_inputs, disable=not self.show_progress_bar)
187 | batched_outputs = [self.run_inference_batch(model, batch) for batch in iterator]
188 | collapsed_outputs = sum(batched_outputs, [])
189 |
190 | # Regroup outputs
191 | i = 0
192 | all_outputs = []
193 | for inputs_per_summary in inputs_per_summary_per_doc:
194 | doc_outputs = []
195 | for num_inputs in inputs_per_summary:
196 | doc_outputs.append(collapsed_outputs[i : i + num_inputs])
197 | i += num_inputs
198 | all_outputs.append(doc_outputs)
199 |
200 | return all_outputs, all_answers
201 |
202 | def get_inference_inputs(self, doc, summary=None, sep=None):
203 | """Get the inference inputs for a document, which possibly includes a summary
204 |
205 | Args:
206 | doc (str): an input document
207 | summary (str): an optional input summary
208 | sep (str): Separator between the inference help (summary) and a sentence from the doc
209 |
210 | Returns:
211 | summary_inputs (List[BertInput]): a list of BertInputs for inference
212 | summary_answers (List[Dict[int, str]]): each dict maps token indices back to their
213 | original token
214 | """
215 | doc = clean_text(doc)
216 | doc_sents = sent_tokenize(doc)
217 | doc_sent_tokens = [self.model_tokenizer.tokenize(sent) for sent in doc_sents]
218 |
219 | summary_sent_tokens = None
220 | if summary:
221 | summary = clean_text(summary)
222 | summary_sents = sent_tokenize(summary)
223 | summary_sent_tokens = [self.model_tokenizer.tokenize(sent) for sent in summary_sents]
224 | if not summary_sent_tokens:
225 | summary_sent_tokens = [[]]
226 |
227 | len_sep = 0
228 | if sep:
229 | len_sep = len(sep)
230 |
231 | summary_inputs, summary_answers = [], []
232 |         half_num_sents = len(doc_sent_tokens) // 2
233 | truncate_bottom = True
234 | for i_sent, sent_tokens in enumerate(doc_sent_tokens):
235 | if i_sent > half_num_sents:
236 | truncate_bottom = False
237 | sent_tokens, summary_tokens = truncate_sentence_and_summary(
238 | sent=sent_tokens,
239 | summary=summary_sent_tokens,
240 | len_sep=len_sep,
241 | len_sent_allow_cut=self.len_sent_allow_cut,
242 | truncate_bottom=truncate_bottom,
243 | )
244 | # now it is assured that everything fits into the allowed input size:
245 | assert len(sent_tokens) + len(summary_tokens) + len_sep + 2 <= BERT_MAX_TOKENS
246 | inputs, answers = self.get_inputs_for_sentence(sent_tokens, summary_tokens)
247 | summary_inputs += inputs
248 | summary_answers += answers
249 | return summary_inputs, summary_answers
250 |
251 | def assemble_inference_input(self, answers, sent_tokens, help_tokens=None, help_sep=None):
252 | """Given input tokens, assemble them into the tensors used by the model for inference
253 |
254 | Args:
255 | answers (Dict[int, str]): a mapping of input token indices to their original value
256 | sent_tokens (List[str]): tokens corresponding to an input sentence
257 | help_tokens (List[str]): tokens corresponding to an input summary or filler
258 | help_sep (List[str]): tokens to put between the summary/filler and the sentence
259 |
260 | Returns:
261 | model_input (BertInput): an input to the BERT model
262 | shifted_answers (Dict[int, str]): the input answers but with shifted indices that take
263 | into account the summary/filler and starting CLS token
264 |
265 | Raises:
266 | ValueError: if the sentence itself is longer than the BERT_MAX_TOKENS limit, we raise
267 | this error as opposed to truncating the sentence
268 | """
269 | if not help_tokens:
270 | help_tokens = []
271 | if not help_sep:
272 | help_sep = []
273 |
274 | all_tokens = (
275 | [self.model_tokenizer.cls_token]
276 | + help_tokens
277 | + help_sep
278 | + sent_tokens
279 | + [self.model_tokenizer.sep_token]
280 | )
281 |
282 | input_ids = self.model_tokenizer.convert_tokens_to_ids(all_tokens)
283 | token_type_ids = [TOKEN_TYPE_A] * len(all_tokens)
284 | attention_mask = [NOT_MASKED] * len(all_tokens)
285 |
286 | offset = 1 + len(help_tokens) + len(help_sep)
287 | shifted_answers = {}
288 | for idx, token in answers.items():
289 | shifted_answers[idx + offset] = token
290 |
291 | model_input = BertInput(
292 | input_ids=input_ids,
293 | token_type_ids=token_type_ids,
294 | attention_mask=attention_mask,
295 | labels=None,
296 | masked_idxs=list(shifted_answers.keys()),
297 | )
298 |
299 | return model_input, shifted_answers
300 |
301 | def run_inference_batch(self, model, batch):
302 | """Run an inference batch through the provided model
303 |
304 | Args:
305 | model (BertForMaskedLM): a BERT for masked language modeling torch model
306 | batch (List[BertInput]): the input batch to run through the model
307 |
308 | Returns:
309 | all_predictions (List[Dict[int, str]]): predicted tokens for every masked token in
310 | the inputs
311 | """
312 | input_ids, attention_mask, token_type_ids, _ = get_input_tensors(
313 | batch, device=self.device, tokenizer=self.model_tokenizer,
314 | )
315 |
316 | with torch.no_grad():
317 | (model_output_batch,) = model(
318 | input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids,
319 | )
320 |
321 | all_predictions = []
322 | for model_input, model_output in zip(batch, model_output_batch):
323 | predictions = {}
324 | for idx in model_input.masked_idxs:
325 | predicted_id = model_output[idx].argmax()
326 | (predicted_token,) = self.model_tokenizer.convert_ids_to_tokens([predicted_id])
327 | predictions[idx] = predicted_token
328 | all_predictions.append(predictions)
329 |
330 | return all_predictions
331 |
332 | def mask_input_tokens(self, tokens, is_finetune):
333 | """Given a list of tokens, produce maskings for them
334 |
335 | Args:
336 | tokens (List[str]): a sequence of wordpiece tokens
337 | is_finetune (bool): whether or not these tokens are going to be used for finetuning
338 |
339 | Returns:
340 | masked_inputs (List[List[str]]): a list of token sequences, where each token sequence
341 | contains some masked tokens.
342 | all_answers (List[Dict[int, str]]): a list of "answer" dicts, where each answer dict maps
343 | token indices corresponding to masked tokens back to their original token.
344 | """
345 | if is_finetune:
346 | even_masking = self.finetune_mask_evenly
347 | else:
348 | even_masking = self.inference_mask_evenly
349 |
350 | min_token_lengths = (
351 | self.min_token_length_normal,
352 | self.min_token_length_lead,
353 | self.min_token_length_followup,
354 | )
355 | if is_finetune:
356 | min_token_lengths = (
357 | self.min_token_length_normal_tune,
358 | self.min_token_length_lead_tune,
359 | self.min_token_length_followup_tune,
360 | )
361 |
362 | if even_masking:
363 | gap_use = self.gap
364 | gap_mask_use = self.gap_mask
365 | if is_finetune:
366 | gap_use = self.gap_tune
367 | gap_mask_use = self.gap_mask_tune
368 | return mask_tokens_evenly(
369 | tokens=tokens,
370 | gap=gap_use,
371 | min_token_lengths=min_token_lengths,
372 | mask_token=self.model_tokenizer.mask_token,
373 | gap_mask=gap_mask_use,
374 | )
375 | else:
376 | return mask_tokens_randomly(
377 | tokens=tokens,
378 | min_token_lengths=min_token_lengths,
379 | mask_token=self.model_tokenizer.mask_token,
380 | p_mask=self.p_mask,
381 | )
382 |
383 | def judge_output(self, base_output, assisted_output, base_answers, assisted_answers):
384 | """Given a model's predicted tokens with and without assistance, as well as the correct
385 | token predictions, produce the BLANC score
386 |
387 | Args:
388 | base_outputs (List[Dict[int, str]]): outputs without using "help" or "tune." Each list
389 | represents a different input masking, and each dict maps indices to model
390 | predictions.
391 | assisted_outputs (List[Dict[int, str]]): outputs using "help" or "tune." Each list
392 | represents a different input masking, and each dict maps indices to model
393 | predictions.
394 | base_answers (List[Dict[int, str]]): answers without using "help" or "tune." Each list
395 | represents a different input masking, and each dict maps indices to original
396 | tokens.
397 | assisted_answers (List[Dict[int, str]]): answers using "help" or "tune." Each
398 | list represents a different input masking, and each dict maps indices to original
399 | tokens.
400 |
401 | Returns:
402 | score (float): the BLANC score, if the measure is 'relative' or 'improve'.
403 | score, S (tuple of float and list): the BLANC score and counts,
404 | if the measure is 'relative-counts' or 'improve-counts'.
405 | """
406 | base_correctness = determine_correctness(base_output, base_answers)
407 | assisted_correctness = determine_correctness(assisted_output, assisted_answers)
408 |
409 | S = [[0, 0], [0, 0]]
410 | for base_correct, assisted_correct in zip(base_correctness, assisted_correctness):
411 | S[int(base_correct)][int(assisted_correct)] += 1
412 |
413 | measure_split = self.measure.split("-")
414 | if measure_split[0] == 'relative':
415 | result = measure_relative(S)
416 | if self.measure == 'relative-counts':
417 | result = result, S
418 | elif measure_split[0] == 'improve':
419 | result = measure_improve(S)
420 | if self.measure == 'improve-counts':
421 | result = result, S
422 | else:
423 | raise NotImplementedError()
424 |
425 | return result
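    # For example (illustrative): with measure='relative-counts', each result is a (score, S)
    # tuple, where S = [[both wrong, helped], [hurt, both right]] are the counts accumulated above.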
426 |
427 | def init_model(self, device):
428 | """Initialize the language model and send it to the given device
429 | Note: Transformers v.4 and higher made default return_dict=True.
430 | Args:
431 | device (str): torch device (usually "cpu" or "cuda")
432 |
433 | Returns:
434 | model: a model for masked language modeling torch model
435 | """
436 | model = None
437 | if self.model_name.lower().find('albert') >= 0:
438 | try:
439 | model = AlbertForMaskedLM.from_pretrained(self.model_name, return_dict=False).to(device)
440 | except:
441 | model = AlbertForMaskedLM.from_pretrained(self.model_name).to(device)
442 | else:
443 | try:
444 | model = BertForMaskedLM.from_pretrained(self.model_name, return_dict=False).to(device)
445 | except:
446 | model = BertForMaskedLM.from_pretrained(self.model_name).to(device)
447 | model.eval()
448 | return model
449 |
450 |
451 | class BlancHelp(Blanc):
452 | """BLANC-help, as defined in the BLANC paper."""
453 |
454 | def __init__(
455 | self,
456 | model_name=Defaults.model_name,
457 | measure=Defaults.measure,
458 | gap=Defaults.gap,
459 | gap_mask=Defaults.gap_mask,
460 | gap_tune=Defaults.gap_tune,
461 | gap_mask_tune=Defaults.gap_mask_tune,
462 | min_token_length_normal=Defaults.min_token_length_normal,
463 | min_token_length_lead=Defaults.min_token_length_lead,
464 | min_token_length_followup=Defaults.min_token_length_followup,
465 | min_token_length_normal_tune=Defaults.min_token_length_normal_tune,
466 | min_token_length_lead_tune=Defaults.min_token_length_lead_tune,
467 | min_token_length_followup_tune=Defaults.min_token_length_followup_tune,
468 | device=Defaults.device,
469 | inference_batch_size=Defaults.inference_batch_size,
470 | inference_mask_evenly=Defaults.inference_mask_evenly,
471 | len_sent_allow_cut=Defaults.len_sent_allow_cut,
472 | filler_token=Defaults.filler_token,
473 | help_sep=Defaults.help_sep,
474 | p_mask=Defaults.p_mask,
475 | show_progress_bar=Defaults.show_progress_bar,
476 | ):
477 | """See CLI documentation (blanc --help) for information about each arg"""
478 | super().__init__(
479 | model_name=model_name,
480 | measure=measure,
481 | gap=gap,
482 | gap_mask=gap_mask,
483 | gap_tune=gap_tune,
484 | gap_mask_tune=gap_mask_tune,
485 | min_token_length_normal=min_token_length_normal,
486 | min_token_length_lead=min_token_length_lead,
487 | min_token_length_followup=min_token_length_followup,
488 | min_token_length_normal_tune=min_token_length_normal_tune,
489 | min_token_length_lead_tune=min_token_length_lead_tune,
490 | min_token_length_followup_tune=min_token_length_followup_tune,
491 | device=device,
492 | inference_batch_size=inference_batch_size,
493 | inference_mask_evenly=inference_mask_evenly,
494 | len_sent_allow_cut=len_sent_allow_cut,
495 | p_mask=p_mask,
496 | show_progress_bar=show_progress_bar,
497 | )
498 |
499 | self.filler_token = filler_token
500 | self.help_sep = self.model_tokenizer.tokenize(help_sep)
501 | self.model = self.init_model(self.device)
502 |
503 | def eval_summaries_for_docs(self, docs, doc_summaries):
504 | """Calculate the BLANC score for multiple docs, each with multiple summaries.
505 | See documentation in superclass.
506 | """
507 | all_outputs, all_answers = self.mask_and_infer(
508 | self.model, docs, doc_summaries, sep=self.help_sep
509 | )
510 |
511 | all_scores = []
512 | for doc_outputs, doc_answers in zip(all_outputs, all_answers):
513 | doc_scores = []
514 | for summary_output, summary_answers in zip(doc_outputs, doc_answers):
515 | help_output = [out for i, out in enumerate(summary_output) if i % 2 == 0]
516 | filler_output = [out for i, out in enumerate(summary_output) if i % 2 == 1]
517 | help_answers = [answer for i, answer in enumerate(summary_answers) if i % 2 == 0]
518 | filler_answers = [answer for i, answer in enumerate(summary_answers) if i % 2 == 1]
519 |
520 | score = self.judge_output(filler_output, help_output, filler_answers, help_answers)
521 | doc_scores.append(score)
522 | all_scores.append(doc_scores)
523 |
524 | return all_scores
525 |
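    | # Illustrative note (sketch only): get_inputs_for_sentence below appends inputs in
    | # (help_input, filler_input) pairs, so for each summary the flattened outputs alternate
    | # summary-assisted (even indices 0, 2, 4, ...) and filler-assisted (odd indices 1, 3, 5, ...).
    | # judge_output(filler_output, help_output, ...) then treats the filler run as the "base"
    | # pass and the summary-assisted run as the "assisted" pass.
    |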
526 | def get_inputs_for_sentence(self, sent_tokens, summary_tokens):
527 | """Get inference inputs corresponding to a given sentence. For BLANC-help, we get several
528 | maskings for each sentence, and for each of these maskings we have an input with the
529 | summary prepended, and an input with a filler prepended. See documentation in superclass.
530 | """
531 | sent_maskings, init_answers = self.mask_input_tokens(sent_tokens, is_finetune=False)
532 |
533 | filler_tokens = [self.filler_token] * len(summary_tokens)
534 | inputs, final_answers = [], []
535 | for sent_masking, init_answer in zip(sent_maskings, init_answers):
536 | help_input, help_answers = self.assemble_inference_input(
537 | answers=init_answer,
538 | sent_tokens=sent_masking,
539 | help_tokens=summary_tokens,
540 | help_sep=self.help_sep,
541 | )
542 |
543 | filler_input, filler_answers = self.assemble_inference_input(
544 | answers=init_answer,
545 | sent_tokens=sent_masking,
546 | help_tokens=filler_tokens,
547 | help_sep=self.help_sep,
548 | )
549 |
550 | inputs += [help_input, filler_input]
551 | final_answers += [help_answers, filler_answers]
552 |
553 | return inputs, final_answers
554 |
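    | # Sketch of the intent above (the exact assembly lives in assemble_inference_input in the
    | # superclass): the "help" input prepends the real summary tokens to the masked sentence,
    | # while the "filler" input prepends the same number of filler_token copies, so the two
    | # inputs have identical length and positions and differ only in the content of the help text.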
555 |
556 | class BlancTune(Blanc):
557 | """BLANC-tune, as defined in the BLANC paper."""
558 |
559 | def __init__(
560 | self,
561 | model_name=Defaults.model_name,
562 | measure=Defaults.measure,
563 | gap=Defaults.gap,
564 | gap_mask=Defaults.gap_mask,
565 | gap_tune=Defaults.gap_tune,
566 | gap_mask_tune=Defaults.gap_mask_tune,
567 | min_token_length_normal=Defaults.min_token_length_normal,
568 | min_token_length_lead=Defaults.min_token_length_lead,
569 | min_token_length_followup=Defaults.min_token_length_followup,
570 | device=Defaults.device,
571 | min_token_length_normal_tune=Defaults.min_token_length_normal_tune,
572 | min_token_length_lead_tune=Defaults.min_token_length_lead_tune,
573 | min_token_length_followup_tune=Defaults.min_token_length_followup_tune,
574 | inference_batch_size=Defaults.inference_batch_size,
575 | inference_mask_evenly=Defaults.inference_mask_evenly,
576 | finetune_batch_size=Defaults.finetune_batch_size,
577 | finetune_epochs=Defaults.finetune_epochs,
578 | finetune_mask_evenly=Defaults.finetune_mask_evenly,
579 | len_sent_allow_cut=Defaults.len_sent_allow_cut,
580 | finetune_chunk_size=Defaults.finetune_chunk_size,
581 | finetune_chunk_stride=Defaults.finetune_chunk_stride,
582 | finetune_top_fully=Defaults.finetune_top_fully,
583 | id_layer_freeze_below=Defaults.id_layer_freeze_below,
584 | id_layer_freeze_above=Defaults.id_layer_freeze_above,
585 | show_progress_bar=Defaults.show_progress_bar,
586 | p_mask=Defaults.p_mask,
587 | p_token_replace=Defaults.p_token_replace,
588 | p_token_original=Defaults.p_token_original,
589 | learning_rate=Defaults.learning_rate,
590 | warmup_steps=Defaults.warmup_steps,
591 | random_seed=Defaults.random_seed,
592 | ):
593 | """See CLI documentation (blanc --help) for information about each arg"""
594 | super().__init__(
595 | model_name=model_name,
596 | measure=measure,
597 | gap=gap,
598 | gap_mask=gap_mask,
599 | gap_tune=gap_tune,
600 | gap_mask_tune=gap_mask_tune,
601 | min_token_length_normal=min_token_length_normal,
602 | min_token_length_lead=min_token_length_lead,
603 | min_token_length_followup=min_token_length_followup,
604 | min_token_length_normal_tune=min_token_length_normal_tune,
605 | min_token_length_lead_tune=min_token_length_lead_tune,
606 | min_token_length_followup_tune=min_token_length_followup_tune,
607 | device=device,
608 | inference_batch_size=inference_batch_size,
609 | inference_mask_evenly=inference_mask_evenly,
610 | len_sent_allow_cut=len_sent_allow_cut,
611 | )
612 |
613 | self.finetune_batch_size = finetune_batch_size
614 | self.finetune_epochs = finetune_epochs
615 | self.finetune_mask_evenly = finetune_mask_evenly
616 | self.finetune_chunk_size = finetune_chunk_size
617 | self.finetune_chunk_stride = finetune_chunk_stride
618 | self.finetune_top_fully = finetune_top_fully
619 | self.id_layer_freeze_below = id_layer_freeze_below
620 | self.id_layer_freeze_above = id_layer_freeze_above
621 | self.show_progress_bar = show_progress_bar
622 | self.p_mask = p_mask
623 | self.p_token_replace = p_token_replace
624 | self.p_token_original = p_token_original
625 | self.learning_rate = learning_rate
626 | self.warmup_steps = warmup_steps
627 | self.random_seed = random_seed
628 |
629 | # If a tune-specific value was not given (negative), fall back to the inference value:
630 | self.gap_tune = self.gap if self.gap_tune < 0 else self.gap_tune
631 | self.gap_mask_tune = self.gap_mask if self.gap_mask_tune < 0 else self.gap_mask_tune
632 | self.min_token_length_normal_tune = self.min_token_length_normal if self.min_token_length_normal_tune < 0 else self.min_token_length_normal_tune
633 | self.min_token_length_lead_tune = self.min_token_length_lead if self.min_token_length_lead_tune < 0 else self.min_token_length_lead_tune
634 | self.min_token_length_followup_tune = self.min_token_length_followup if self.min_token_length_followup_tune < 0 else self.min_token_length_followup_tune
635 |
636 | self.base_model = self.init_model(self.device)
637 |
638 | def eval_summaries_for_docs(self, docs, doc_summaries):
639 | """Calculate the BLANC score for multiple docs, each with multiple summaries.
640 | See documentation in superclass.
641 | Note that the summary must not be used in any way to obtain the base outputs and base answers.
642 | When the finetuned model is used, the summary could also be prepended, which would combine the 'help' and 'tune' versions of BLANC.
643 | """
644 |
645 | doc_summaries_use = [[None for s in summs] for summs in doc_summaries]
646 | base_outputs, base_answers = self.mask_and_infer(self.base_model, docs, doc_summaries_use)
647 |
648 | finetuned_outputs, finetuned_answers = [], []
649 | model_cpu = self.init_model(device='cpu')
650 | for doc, summaries in tqdm.tqdm(zip(docs, doc_summaries), total=len(docs), disable=not self.show_progress_bar):
651 | finetuned_doc_outputs, finetuned_doc_answers = [], []
652 | for summary in summaries:
653 | model_copy = copy.deepcopy(model_cpu)
654 | finetuned_model = model_copy.to(self.device)
655 | self.finetune(finetuned_model, summary)
656 |
657 | (finetuned_summary_output,), (finetuned_summary_answer,) = self.mask_and_infer(
658 | finetuned_model, [doc], [[None]]
659 | )
660 | finetuned_doc_outputs += finetuned_summary_output
661 | finetuned_doc_answers += finetuned_summary_answer
662 |
663 | del finetuned_model
664 | torch.cuda.empty_cache()
665 |
666 | finetuned_outputs.append(finetuned_doc_outputs)
667 | finetuned_answers.append(finetuned_doc_answers)
668 |
669 | all_scores = [
670 | [
671 | self.judge_output(
672 | base_summary_output,
673 | finetuned_summary_output,
674 | base_summary_answers,
675 | finetuned_summary_answers,
676 | )
677 | for (
678 | base_summary_output,
679 | base_summary_answers,
680 | finetuned_summary_output,
681 | finetuned_summary_answers,
682 | ) in zip(
683 | base_doc_output, base_doc_answers, finetuned_doc_output, finetuned_doc_answers,
684 | )
685 | ]
686 | for (
687 | base_doc_output,
688 | base_doc_answers,
689 | finetuned_doc_output,
690 | finetuned_doc_answers,
691 | ) in zip(
692 | base_outputs, base_answers, finetuned_outputs, finetuned_answers,
693 | )
694 | ]
695 |
696 | return all_scores
697 |
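    | # Illustrative summary of the loop above (not additional logic): for every (doc, summary)
    | # pair, a pristine copy of the CPU model is deep-copied, moved to the device, finetuned on
    | # the summary alone, and then asked to fill the same masked document tokens as the base
    | # model; judge_output compares the two passes token by token to produce the BLANC-tune score.
    |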
698 | def get_inputs_for_sentence(self, sent_tokens, summary_tokens):
699 | """Get inference inputs corresponding to a given sentence. For BLANC-tune, we get several
700 | maskings for each sentence, and each masking is a single input. See documentation in
701 | superclass.
702 | """
703 | sent_maskings, init_answers = self.mask_input_tokens(sent_tokens, is_finetune=False)
704 | inputs, final_answers = [], []
705 | for sent_masking, init_answer in zip(sent_maskings, init_answers):
706 | input_, answers = self.assemble_inference_input(
707 | answers=init_answer, sent_tokens=sent_masking,
708 | )
709 |
710 | inputs.append(input_)
711 | final_answers.append(answers)
712 |
713 | return inputs, final_answers
714 |
715 | def finetune(self, model, summary):
716 | """Finetune the given model on a "dataset" produced from chunks of the given summary.
717 |
718 | Args:
719 | model (BertForMaskedLM): a BERT for masked language modeling torch model
720 | summary (str): the summary to finetune on
721 | """
722 | if self.random_seed > 0:
723 | set_seed(self.random_seed)
724 | model.train()
725 | n_params = len(list(model.parameters()))
726 | # Freeze the lowest or the highest layers (by parameter index), if requested:
727 | if self.id_layer_freeze_below > 0 or self.id_layer_freeze_above > 0:
728 | for i, param in enumerate(model.parameters()):
729 | if i < self.id_layer_freeze_below:
730 | param.requires_grad = False
731 | elif self.id_layer_freeze_above < 0:
732 | break
733 | elif n_params - i < self.id_layer_freeze_above:
734 | param.requires_grad = False
735 | all_inputs = self.prepare_finetuning_data(summary)
736 | input_batches = batch_data(all_inputs, self.finetune_batch_size)
737 |
738 | no_decay = ["bias", "LayerNorm.weight"]
739 | optimizer_grouped_parameters = [
740 | {
741 | "params": [
742 | p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)
743 | ],
744 | "weight_decay": 1e-2,
745 | },
746 | {
747 | "params": [
748 | p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)
749 | ],
750 | "weight_decay": 0.0,
751 | },
752 | ]
753 | optimizer = AdamW(optimizer_grouped_parameters, lr=self.learning_rate, eps=1e-8)
754 | scheduler = get_linear_schedule_with_warmup(
755 | optimizer,
756 | num_warmup_steps=self.warmup_steps,
757 | num_training_steps=len(input_batches) * self.finetune_epochs,
758 | )
759 | for epoch in range(self.finetune_epochs):
760 | for input_batch in input_batches:
761 | input_ids, attention_mask, token_type_ids, labels = get_input_tensors(
762 | input_batch, device=self.device, tokenizer=self.model_tokenizer,
763 | )
764 | model.zero_grad()
765 | optimizer.zero_grad()
766 | try:  # masked_lm_labels was deprecated and replaced by labels in transformers v4.x
767 | loss, _ = model(
768 | input_ids=input_ids,
769 | attention_mask=attention_mask,
770 | token_type_ids=token_type_ids,
771 | labels=labels,
772 | )
773 | except Exception:  # fall back to the pre-v4 masked_lm_labels argument
774 | loss, _ = model(
775 | input_ids=input_ids,
776 | attention_mask=attention_mask,
777 | token_type_ids=token_type_ids,
778 | masked_lm_labels=labels,
779 | )
780 | loss.backward()
781 | optimizer.step()
782 | scheduler.step()
783 | model.eval()
784 |
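    | # Illustrative note on the schedule above (sketch with hypothetical numbers): weight decay
    | # (1e-2) is applied to all parameters except biases and LayerNorm weights, and the linear
    | # warmup schedule ramps the learning rate up for warmup_steps optimizer steps and then
    | # decays it linearly to zero over len(input_batches) * finetune_epochs total steps.
    | # For example, with 3 batches and 10 epochs, num_training_steps would be 30.
    |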
785 | def prepare_finetuning_data(self, summary):
786 | """Create a finetuning dataset using chunks of the given summary
787 | The finetune_top_fully=True compensate finetuning of top tokens, which
788 | otherwise get less tuning than tokens at further strides.
789 |
790 | Args:
791 | summary (str): the input summary to finetune on
792 |
793 | Returns:
794 | model_inputs (List[BertInput]): a list of inputs to use as the finetuning dataset
795 | """
796 | summary_tokens = self.model_tokenizer.tokenize(summary)
797 | model_inputs = []
798 | for start_token in range(0, len(summary_tokens), self.finetune_chunk_stride):
799 | end_token = start_token + self.finetune_chunk_size
800 | chunk_tokens = summary_tokens[start_token:end_token]
801 | model_inputs += self.assemble_finetuning_input(chunk_tokens)
802 | if self.finetune_top_fully and start_token > 0 and start_token < self.finetune_chunk_size:
803 | chunk_tokens = summary_tokens[:start_token]
804 | model_inputs += self.assemble_finetuning_input(chunk_tokens)
805 | return model_inputs
806 |
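    | # Worked example (hypothetical sizes, for illustration only): with 48 summary tokens,
    | # finetune_chunk_size=24 and finetune_chunk_stride=12, the loop produces chunks
    | # [0:24], [12:36], [24:48], [36:48]; with finetune_top_fully=True it also adds [0:12]
    | # at start_token=12, so every token ends up covered by the same number of chunks.
    |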
807 | def assemble_finetuning_input(self, chunk_tokens):
808 | """Given input tokens, assemble them into the tensors used by the model for finetuning
809 |
810 | Args:
811 | chunk_tokens (List[str]): a token sequence
812 |
813 | Returns:
814 | model_inputs (List[BertInput]): BertInputs corresponding to different maskings of
815 | chunk_tokens
816 | """
817 | all_input_tokens, all_answers = self.mask_input_tokens(chunk_tokens, is_finetune=True)
818 |
819 | all_input_tokens = [
820 | [self.model_tokenizer.cls_token] + tokens + [self.model_tokenizer.sep_token]
821 | for tokens in all_input_tokens
822 | ]
823 | all_input_ids = [
824 | self.model_tokenizer.convert_tokens_to_ids(tokens) for tokens in all_input_tokens
825 | ]
826 | all_labels = [[LABEL_IGNORE] * len(tokens) for tokens in all_input_tokens]
827 |
828 | model_inputs = []
829 | for input_ids, answers, labels in zip(all_input_ids, all_answers, all_labels):
830 | for original_idx, token in answers.items():
831 | idx = original_idx + 1 # accounting for starting CLS token
832 | (original_token_id,) = self.model_tokenizer.convert_tokens_to_ids([token])
833 | labels[idx] = original_token_id
834 |
835 | random_number = random.random()
836 | if random_number < self.p_token_replace:
837 | # replace with a random token
838 | input_ids[idx] = random.randint(*TOKEN_REPLACE_RANGE)
839 | elif random_number < self.p_token_original + self.p_token_replace:
840 | # use original token
841 | input_ids[idx] = original_token_id
842 |
843 | attention_mask = [NOT_MASKED] * len(input_ids)
844 | token_type_ids = [TOKEN_TYPE_A] * len(input_ids)
845 | model_input = BertInput(
846 | input_ids=input_ids,
847 | attention_mask=attention_mask,
848 | token_type_ids=token_type_ids,
849 | labels=labels,
850 | masked_idxs=None,
851 | )
852 | model_inputs.append(model_input)
853 |
854 | return model_inputs
855 |
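    | # Illustrative note on the corruption above (a sketch, not a specification of the defaults):
    | # positions already chosen by mask_input_tokens carry a mask token; for each such position,
    | # with probability p_token_replace the mask is swapped for a random token id, with probability
    | # p_token_original the original token is restored, and otherwise the mask is kept.
    | # labels stores the original token id at these positions and LABEL_IGNORE everywhere else,
    | # which is the standard BERT-style masked-LM finetuning recipe.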
--------------------------------------------------------------------------------