├── .github
│   └── workflows
│       ├── ci.yml
│       └── publish-to-pypi.yml
├── .gitignore
├── .travis.yml
├── LICENSE
├── README.md
├── compare_mt
│   ├── __init__.py
│   ├── align_utils.py
│   ├── arg_utils.py
│   ├── bucketers.py
│   ├── cache_utils.py
│   ├── compare_ll_main.py
│   ├── compare_mt_main.py
│   ├── corpus_utils.py
│   ├── formatting.py
│   ├── ngram_utils.py
│   ├── print_utils.py
│   ├── reporters.py
│   ├── rouge
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── io.py
│   │   ├── requirements.txt
│   │   ├── rouge.py
│   │   ├── rouge_scorer.py
│   │   ├── run.sh
│   │   ├── scoring.py
│   │   └── tokenize.py
│   ├── scorers.py
│   ├── sign_utils.py
│   ├── stat_utils.py
│   └── version_info.py
├── example
│   ├── ll_test.sys1.likelihood
│   ├── ll_test.sys2.likelihood
│   ├── ll_test.tag
│   ├── ll_test.txt
│   ├── multited.ref.jpn
│   ├── multited.ref.jpn.tag
│   ├── multited.sys1.jpn
│   ├── multited.sys1.jpn.tag
│   ├── multited.sys2.jpn
│   ├── multited.sys2.jpn.tag
│   ├── sum.ref.eng
│   ├── sum.sys1.eng
│   ├── sum.sys2.eng
│   ├── ted.orig.slk
│   ├── ted.ref.align
│   ├── ted.ref.detok.eng
│   ├── ted.ref.eng
│   ├── ted.ref.eng.rptag
│   ├── ted.ref.eng.tag
│   ├── ted.sys1.align
│   ├── ted.sys1.detok.eng
│   ├── ted.sys1.eng
│   ├── ted.sys1.eng.rptag
│   ├── ted.sys1.eng.senttag
│   ├── ted.sys1.eng.tag
│   ├── ted.sys2.align
│   ├── ted.sys2.detok.eng
│   ├── ted.sys2.eng
│   ├── ted.sys2.eng.rptag
│   ├── ted.sys2.eng.senttag
│   ├── ted.sys2.eng.tag
│   ├── ted.train.counts
│   └── ted.train.eng
├── pytest.ini
├── requirements.txt
├── scripts
│   ├── count.py
│   ├── interleave.py
│   ├── postag.py
│   └── relativepositiontag.py
├── setup.py
└── tests
    ├── __init__.py
    ├── test_cache.py
    └── test_scorers.py
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 | on: [push]
3 |
4 | jobs:
5 | build:
6 | runs-on: ubuntu-latest
7 | steps:
8 | - uses: actions/checkout@v2
9 | - name: Install Python 3
10 | uses: actions/setup-python@v1
11 | with:
12 | python-version: 3.9
13 | - name: Install dependencies
14 | run: |
15 | python -m pip install --upgrade pip
16 | pip install .
17 | - name: Run tests with unittest
18 | run: python -m unittest
19 |
--------------------------------------------------------------------------------
/.github/workflows/publish-to-pypi.yml:
--------------------------------------------------------------------------------
1 | name: Publish Python 🐍 distributions 📦 to PyPI and TestPyPI
2 |
3 | on: push
4 |
5 | jobs:
6 | build-n-publish:
7 | name: Build and publish Python 🐍 distributions 📦 to PyPI and TestPyPI
8 | runs-on: ubuntu-18.04
9 | steps:
10 | - uses: actions/checkout@master
11 | - name: Set up Python 3.9
12 | uses: actions/setup-python@v1
13 | with:
14 | python-version: 3.9
15 | - name: Install pypa/build
16 | run: >-
17 | python -m
18 | pip install
19 | build
20 | --user
21 | - name: Build a binary wheel and a source tarball
22 | run: >-
23 | python -m
24 | build
25 | --sdist
26 | --wheel
27 | --outdir dist/
28 | .
29 | - name: Publish distribution 📦 to Test PyPI
30 | uses: pypa/gh-action-pypi-publish@master
31 | with:
32 | skip_existing: true
33 | password: ${{ secrets.TEST_PYPI_API_KEY }}
34 | repository_url: https://test.pypi.org/legacy/
35 | - name: Publish distribution 📦 to PyPI
36 | if: startsWith(github.ref, 'refs/tags')
37 | uses: pypa/gh-action-pypi-publish@master
38 | with:
39 | password: ${{ secrets.PYPI_API_KEY }}
40 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # PyCharm
2 | .idea
3 | __pycache__
4 | # vim
5 | *.swp
6 | # VS code
7 | .vscode/
8 | # Mac
9 | .DS_Store
10 | # setup.py build artifacts
11 | *.egg-info
12 | dist/
13 | build/
14 | # Virtualenv for developing
15 | env/
16 | # Outputs
17 | output/
18 | outputs/
19 | .pytest_cache
20 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 | python:
3 | - '3.6'
4 | - 3.7-dev
5 | install:
6 | - pip install -r requirements.txt
7 | - pip install -U setuptools
8 | - python setup.py install
9 | script:
10 | - pytest
11 | - compare-mt example/ted.ref.eng example/ted.sys1.eng example/ted.sys2.eng --decimals 2 --output_directory output
12 | - compare-mt example/ted.ref.eng example/ted.sys1.eng example/ted.sys2.eng --compare_scores score_type=bleu,bootstrap=10,prob_thresh=0.05 --output_directory output
13 | - compare-mt example/ted.ref.eng example/ted.sys1.eng example/ted.sys2.eng --compare_word_accuracies bucket_type=freq,freq_corpus_file=example/ted.train.eng,bucket_cutoffs=1:2:3:5:10 bucket_type=freq,freq_count_file=example/ted.train.counts,bucket_cutoffs=1:2:3:5:10 --output_directory output
14 | - compare-mt example/ted.ref.eng example/ted.sys1.eng example/ted.sys2.eng --compare_word_accuracies bucket_type=case bucket_type=label,ref_labels=example/ted.ref.eng.tag,out_labels="example/ted.sys1.eng.tag;example/ted.sys2.eng.tag",label_set=CC+DT+IN+JJ+NN+NNP+NNS+PRP+RB+TO+VB+VBP+VBZ bucket_type=numlabel,ref_labels=example/ted.ref.eng.rptag,out_labels="example/ted.sys1.eng.rptag;example/ted.sys2.eng.rptag" --compare_ngrams compare_type=match,ref_labels=example/ted.ref.eng.tag,out_labels="example/ted.sys1.eng.tag;example/ted.sys2.eng.tag" --output_directory output
15 | - compare-mt example/ted.ref.eng example/ted.sys1.eng example/ted.sys2.eng --src_file example/ted.orig.slk --compare_src_word_accuracies ref_align_file=example/ted.ref.align --output_directory output
16 | - compare-ll --ref example/ll_test.txt --ll-files example/ll_test.sys1.likelihood example/ll_test.sys2.likelihood --compare-word-likelihoods bucket_type=freq,freq_corpus_file=example/ll_test.txt --decimals 2
17 | - compare-ll --ref example/ll_test.txt --ll-files example/ll_test.sys1.likelihood example/ll_test.sys2.likelihood --compare-word-likelihoods bucket_type=label,label_corpus=example/ll_test.tag,label_set=CC+DT+IN+JJ+NN+NNP+NNS+PRP+RB+TO+VB+VBP+VBZ
18 | - compare-mt example/sum.ref.eng example/sum.sys1.eng example/sum.sys2.eng --compare_scores 'score_type=rouge1' 'score_type=rouge2' 'score_type=rougeL' --output_directory output
19 | - python compare_mt/compare_mt_main.py example/ted.ref.eng example/ted.sys1.eng example/ted.sys2.eng --output_directory output
20 | deploy:
21 | provider: pypi
22 | user: pmichel31415
23 | skip_existing: true
24 | password:
25 | secure: fGKIZDGfu5L2WGiGlIidPI5uBi2P2TIytEIDerK8sJWKdIM6CSLnzVVXHst5VIujIhF2/TP7YMniLvMEflW5HY7Bu5fb2dBMQnyQJiE8SE9ih/Oq35W3fHJCEiAYnWo3CKLYlwUyJC9VZn8w0JrU2MBWfLCIli3Fuh9sbRyVNvjRq4kc2IGIjcxwQvM0Hml9G/89UwWYKUbxi53tFfUr5qu9WyuPdy/i2bcHaYMB6FgXbTn47MmOgVDvLjLjePpMsF+fNQDkkN035ngPRLDfHfBM74ag2ycVUhjT8nsMOfKGMpmbk/CeyKOYT9TW6Fp/MALQ5nJ9qF4q49mOpz7lh0JfogTCxweU76cpPsi9j99BvYULTYy1SnjOP9ZqglobosWq2fUtw8Pf6KE57Y0ultfh+CAgXWhX7rBFGj9PrYW6+P8Y2p5+MQuXRZp+6TOXgpELh0SiXUAFQA5B77Kw8+tPw5DJL1b5oGXBTp94sttHxNXeV9bm9AwKB18rUcKKA0AHFP5FgjvdtfZKnjydSg/hFn82UA/0g0ubcSuqdoSRgk49NT4RasODiqnfqXseJ/q1vWm5eiW60QzXuHZrK6EN8vzKxFH7DYjAZTOQsAdoCgAQvSXABOKum/Pm3HWU+BfD0xZH9cJEn9YvSKD5qMNikmMK1LR2cgRHmbhjUXQ=
26 | on:
27 | branch: master
28 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2019 Graham Neubig
2 |
3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
4 |
5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
6 |
7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
8 |
9 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
10 |
11 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
12 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # compare-mt
2 | by [NeuLab](http://www.cs.cmu.edu/~neulab/) @ [CMU LTI](https://lti.cs.cmu.edu), and other contributors
3 |
4 | [CI](.github/workflows/ci.yml)
5 |
6 | `compare-mt` (for "compare my text") is a program to compare the output of multiple systems for language generation,
7 | including machine translation, summarization, dialog response generation, etc.
8 | To use it you need to have, in text format, a "correct" reference, and the output of two different systems.
9 | Based on this, `compare-mt` will run a number of analyses that attempt to pick out salient differences between
10 | the systems, which will make it easier for you to figure out what things one system is doing better than another.
11 |
12 | ## Basic Usage
13 |
14 | First, you need to install the package:
15 |
16 | ```bash
17 | # Requirements
18 | pip install -r requirements.txt
19 | # Install the package
20 | python setup.py install
21 | ```
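
Alternatively, since releases of this package are published to PyPI (see `.github/workflows/publish-to-pypi.yml`), installing the latest released version directly should also work:

```bash
pip install compare-mt
```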
22 |
23 | Then, as an example, you can run this over two included system outputs.
24 |
25 | ```bash
26 | compare-mt --output_directory output/ example/ted.ref.eng example/ted.sys1.eng example/ted.sys2.eng
27 | ```
28 |
29 | This will output some statistics to the command line, and also write a formatted HTML report to `output/`.
30 | Here, system 1 and system 2 are the baseline phrase-based and neural Slovak-English systems from our
31 | [EMNLP 2018 paper](http://aclweb.org/anthology/D18-1103). This will print out a number of statistics including:
32 |
33 | * **Aggregate Scores:** A report on overall BLEU scores and length ratios
34 | * **Word Accuracy Analysis:** A report on the F-measure of words by frequency bucket
35 | * **Sentence Bucket Analysis:** Bucket sentences by various statistics (e.g. sentence BLEU, length difference with the
36 | reference, overall length), and calculate statistics by bucket (e.g. number of sentences, BLEU score per bucket)
37 | * **N-gram Difference Analysis:** Calculate which n-grams one system is consistently translating better
38 | * **Sentence Examples:** Find sentences where one system is doing better than the other according to sentence BLEU
39 |
40 | You can see an example of running this analysis (as well as the more advanced analysis below) either through a
41 | [generated HTML report here](http://phontron.com/compare-mt/output/), or in the following narrated video:
42 |
43 | [](https://www.youtube.com/watch?v=K-MNPOGKnDQ)
44 |
45 | To summarize the results that immediately stick out from the basic analysis:
46 |
47 | * From the *aggregate scores* we can see that the BLEU of neural MT is higher, but its sentences are slightly shorter.
48 | * From the *word accuracy analysis* we can see that phrase-based MT is better at low-frequency words.
49 | * From the *sentence bucket analysis* we can see that neural seems to be better at translating shorter sentences.
50 | * From the *n-gram difference analysis* we can see that there are a few words that neural MT is not good at
51 | but phrase-based MT gets right (e.g. "phantom"), while there are a few long phrases that neural MT does better with
52 | (e.g. "going to show you").
53 |
54 | If you run on your own data, you might be able to find more interesting things about your own systems. Try comparing
55 | your modified system with your baseline and seeing what you find!
56 |
57 | ## Other Options
58 |
59 | There are many options that can be used to do different types of analysis.
60 | If you want to find all the different types of analysis supported, the most comprehensive way to do so is by
61 | taking a look at the command-line help for `compare-mt`, which is documented relatively well and gives examples.
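
For instance, the following should print the full list of options along with their documentation:

```bash
compare-mt --help
```
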
62 | We do highlight a few particularly useful and common types of analysis below:
63 |
64 | ### Significance Tests
65 |
66 | The script allows you to perform statistical significance tests for scores based on [bootstrap resampling](https://aclanthology.org/W04-3250.pdf). You can set
67 | the number of samples manually. Here is an example using the example data:
68 |
69 |
70 | ```bash
71 | compare-mt example/ted.ref.eng example/ted.sys1.eng example/ted.sys2.eng --compare_scores score_type=bleu,bootstrap=1000,prob_thresh=0.05
72 | ```
73 |
74 | One important thing to note is that bootstrap resampling as implemented in compare-mt only tests for variance due to data sampling, approximately answering the question "if I ran the same system on a different, similarly sampled dataset, would I be likely to get the same result?".
75 | It does not say anything about whether a system will perform better on another dataset in a different domain, and it [does not control for training-time factors](https://aclanthology.org/P11-2031/) such as selection of the random seed, so it cannot say if another training run of the same model would yield the same result.
76 |
77 | ### Using Training Set Frequency
78 |
79 | One useful piece of analysis is the "word accuracy by frequency" analysis. By default this frequency is the frequency
80 | in the *test set*, but arguably it is more informative to know accuracy by frequency in the *training set* as this
81 | demonstrates the models' robustness to words they haven't seen much, or at all, in the training data. To change the
82 | corpus used to calculate word frequency and use the training set (or some other set), you can set the `freq_corpus_file`
83 | option to the appropriate corpus.
84 |
85 |
86 | ```bash
87 | compare-mt example/ted.ref.eng example/ted.sys1.eng example/ted.sys2.eng \
88 |     --compare_word_accuracies bucket_type=freq,freq_corpus_file=example/ted.train.eng
89 | ```
90 |
91 | In addition, because training sets may be very big, you can also calculate the counts on the file beforehand,
92 |
93 | ```bash
94 | python scripts/count.py < example/ted.train.eng > example/ted.train.counts
95 | ```
96 |
97 | and then use these counts directly to improve efficiency.
98 |
99 | ```bash
100 | compare-mt example/ted.ref.eng example/ted.sys1.eng example/ted.sys2.eng \
101 |     --compare_word_accuracies bucket_type=freq,freq_count_file=example/ted.train.counts
102 | ```
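
The counts file itself is plain text with one word and its count per line, separated by a tab (this is the format `freq_count_file` expects; see the `FreqWordBucketer` documentation in `compare_mt/bucketers.py`). A quick way to inspect it:

```bash
head -3 example/ted.train.counts
```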
103 |
104 |
105 | ### Incorporating Word/Sentence Labels
106 |
107 | If you're interested in performing aggregate analysis over labels for each word/sentence instead of the words/sentences themselves, it
108 | is possible to do so. As an example, we've included POS tags for each of the example outputs. You can use these in
109 | aggregate analysis, or n-gram-based analysis. The following gives an example:
110 |
111 |
112 | ```bash
113 | compare-mt example/ted.ref.eng example/ted.sys1.eng example/ted.sys2.eng \
114 |     --compare_word_accuracies bucket_type=label,ref_labels=example/ted.ref.eng.tag,out_labels="example/ted.sys1.eng.tag;example/ted.sys2.eng.tag",label_set=CC+DT+IN+JJ+NN+NNP+NNS+PRP+RB+TO+VB+VBP+VBZ \
115 |     --compare_ngrams compare_type=match,ref_labels=example/ted.ref.eng.tag,out_labels="example/ted.sys1.eng.tag;example/ted.sys2.eng.tag"
116 | ```
117 |
118 | This will calculate word accuracies and n-gram matches by POS bucket, allowing you to see things like the fact
119 | that the phrase-based MT system is better at translating content words such as nouns and verbs, while neural MT
120 | is doing better at translating function words.
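
The `.tag` files are parallel to their corresponding text files, with one label per word. A quick way to eyeball this alignment (assuming a POSIX shell with process substitution):

```bash
paste <(head -1 example/ted.ref.eng) <(head -1 example/ted.ref.eng.tag)
```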
121 |
122 | We also give an example of aggregate analysis with multiple labels per word/sentence, where each word's group of labels is a string separated by '+'s:
123 |
124 | ```bash
125 | compare-mt example/multited.ref.jpn example/multited.sys1.jpn example/multited.sys2.jpn \
126 |     --compare_word_accuracies bucket_type=multilabel,ref_labels=example/multited.ref.jpn.tag,out_labels="example/multited.sys1.jpn.tag;example/multited.sys2.jpn.tag",label_set=lexical+formality+pronouns+ellipsis
127 | ```
128 |
129 | It is also possible to create labels that represent numerical values. For example, `scripts/relativepositiontag.py` calculates the relative position of words in the sentence, where 0 is the first word in the sentence, 0.5 is a word in the middle, and 1.0 is the last word. These numerical values can then be bucketed. Here is an example:
130 |
131 | ```bash
132 | compare-mt example/ted.ref.eng example/ted.sys1.eng example/ted.sys2.eng \
133 |     --compare_word_accuracies bucket_type=numlabel,ref_labels=example/ted.ref.eng.rptag,out_labels="example/ted.sys1.eng.rptag;example/ted.sys2.eng.rptag"
134 | ```
135 |
136 | From this particular analysis we can discover that NMT does worse than PBMT at the end of the sentence, and of course other varieties of numerical labels could be used to measure different properties of words.
137 |
138 | You can also perform analysis over labels for sentences. Here is an example:
139 |
140 | ```bash
141 | compare-mt example/ted.ref.eng example/ted.sys1.eng example/ted.sys2.eng \
142 |     --compare_sentence_buckets 'bucket_type=label,out_labels=example/ted.sys1.eng.senttag;example/ted.sys2.eng.senttag,label_set=0+10+20+30+40+50+60+70+80+90+100,statistic_type=score,score_measure=bleu'
143 | ```
144 |
145 |
146 | ### Analyzing Source Words
147 |
148 | If you have a source corpus that is aligned to the target, you can also analyze accuracies according to features of the
149 | source language words, which would allow you to examine whether, for example, infrequent words on the source side are
150 | hard to output properly. Here is an example using the example data:
151 |
152 | ```bash
153 | compare-mt example/ted.ref.eng example/ted.sys1.eng example/ted.sys2.eng --src_file example/ted.orig.slk --compare_src_word_accuracies ref_align_file=example/ted.ref.align
154 | ```
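
The alignment file likely follows the standard Pharaoh-style format, with one line per sentence of whitespace-separated `src-trg` index pairs (the bucketers unpack each entry as a `(src, trg)` pair in `compare_mt/bucketers.py`). To check what the example data looks like:

```bash
head -1 example/ted.ref.align
```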
155 |
156 | ### Analyzing Word Likelihoods
157 |
158 | If you wish to analyze the word log likelihoods produced by two systems on the target corpus, you can use the following:
159 |
160 | ```bash
161 | compare-ll --ref example/ll_test.txt --ll-files example/ll_test.sys1.likelihood example/ll_test.sys2.likelihood --compare-word-likelihoods bucket_type=freq,freq_corpus_file=example/ll_test.txt
162 | ```
163 |
164 | You can analyze the word log likelihoods over labels for each word instead of the words themselves:
165 |
166 | ```bash
167 | compare-ll --ref example/ll_test.txt --ll-files example/ll_test.sys1.likelihood example/ll_test.sys2.likelihood --compare-word-likelihoods bucket_type=label,label_corpus=example/ll_test.tag,label_set=CC+DT+IN+JJ+NN+NNP+NNS+PRP+RB+TO+VB+VBP+VBZ
168 | ```
169 |
170 | NOTE: You can also use the above to analyze the word likelihoods produced by two language models.
171 |
172 | ### Analyzing Other Language Generation Systems
173 |
174 | You can also analyze other language generation systems using the script. Here is an example of comparing two text summarization systems.
175 |
176 | ```bash
177 | compare-mt example/sum.ref.eng example/sum.sys1.eng example/sum.sys2.eng --compare_scores 'score_type=rouge1' 'score_type=rouge2' 'score_type=rougeL'
178 | ```
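
Under the hood, each `score_type` corresponds to a scorer created by `compare_mt.scorers.create_scorer_from_profile`, the same internal API the bucketers use (see `compare_mt/bucketers.py`). If you want to poke at a metric directly from Python, here is a minimal sketch assuming that internal interface (it is not necessarily a stable public API):

```python
from compare_mt import scorers

# Create a scorer from a profile string, as the bucketers do internally.
scorer = scorers.create_scorer_from_profile('bleu')

ref = 'this is a small test'.split()  # reference, as a list of tokens
out = 'this is a tiny test'.split()   # system output, as a list of tokens

# score_sentence returns a tuple whose first element is the score
# (see its use in ScoreSentenceBucketer.calc_bucket).
score = scorer.score_sentence(ref, out)[0]
print(scorer.name(), score)
```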
179 |
180 | ### Evaluating on COMET
181 |
182 | It is possible to use [COMET](https://unbabel.github.io/COMET/html/index.html) as a metric.
183 | To do so, you need to install it first by running
184 |
185 | ```bash
186 | pip install unbabel-comet
187 | ```
188 |
189 | To run it, pass the source file via `--src_file` and select the appropriate score type. Here is an example:
190 | ```bash
191 | compare-mt example/ted.ref.eng example/ted.sys1.eng example/ted.sys2.eng --src_file example/ted.orig.slk \
192 | --compare_scores score_type=comet \
193 | --compare_sentence_buckets bucket_type=score,score_measure=sentcomet
194 | ```
195 |
196 | Note that COMET runs on top of XLM-R, so it's highly recommended you use a GPU with it.
197 |
198 | ## Citation/References
199 |
200 | If you use compare-mt, we'd appreciate it if you cite the [paper](http://arxiv.org/abs/1903.07926) about it!
201 |
202 | @article{DBLP:journals/corr/abs-1903-07926,
203 | author = {Graham Neubig and Zi{-}Yi Dou and Junjie Hu and Paul Michel and Danish Pruthi and Xinyi Wang and John Wieting},
204 | title = {compare-mt: {A} Tool for Holistic Comparison of Language Generation Systems},
205 | journal = {CoRR},
206 | volume = {abs/1903.07926},
207 | year = {2019},
208 | url = {http://arxiv.org/abs/1903.07926},
209 | }
210 |
211 | There is an extensive literature review included in the paper above, but some key papers that it borrows ideas from are below:
212 |
213 | * **Automatic Error Analysis:**
214 | Popovic and Ney "[Towards Automatic Error Analysis of Machine Translation Output](https://www.mitpressjournals.org/doi/pdf/10.1162/COLI_a_00072)" Computational Linguistics 2011.
215 | * **POS-based Analysis:**
216 | Chiang et al. "[The Hiero Machine Translation System](http://aclweb.org/anthology/H05-1098)" EMNLP 2005.
217 | * **n-gram Difference Analysis:**
218 | Akabe et al. "[Discriminative Language Models as a Tool for Machine Translation Error Analysis](http://www.phontron.com/paper/akabe14coling.pdf)" COLING 2014.
219 |
220 | There is also other good software for automatic comparison or error analysis of MT systems:
221 |
222 | * **[MT-ComparEval](https://github.com/choko/MT-ComparEval):** Very nice for visualization of individual examples, but
223 | not as focused on aggregate analysis as `compare-mt`. Also has more software dependencies and requires using a web
224 | browser, while `compare-mt` can be used as a command-line tool.
225 |
--------------------------------------------------------------------------------
/compare_mt/__init__.py:
--------------------------------------------------------------------------------
1 | import compare_mt.ngram_utils
2 | import compare_mt.stat_utils
3 | import compare_mt.corpus_utils
4 | import compare_mt.sign_utils
5 | import compare_mt.scorers
6 | import compare_mt.bucketers
7 | import compare_mt.reporters
8 | import compare_mt.arg_utils
9 | import compare_mt.print_utils
10 | import compare_mt.version_info
11 |
12 | __version__ = compare_mt.version_info.__version__
13 |
--------------------------------------------------------------------------------
/compare_mt/align_utils.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from compare_mt import corpus_utils
3 |
4 | def _count_ngram(sent, order):
5 | gram_pos = dict()
6 | for i in range(order):
7 | gram_pos[i+1] = defaultdict(lambda: [])
8 | for i, word in enumerate(sent):
9 | for j in range(min(i+1, order)):
10 | gram_pos[j+1][word].append(i-j)
11 | word = sent[i-j-1] + ' ' + word
12 | return gram_pos
13 |
14 | def ngram_context_align(ref, out, order=-1, case_insensitive=False):
15 | """
16 | Calculate the word alignment between a reference sentence and an output sentence.
17 | Proposed in the following paper:
18 |
19 | Automatic Evaluation of Translation Quality for Distant Language Pairs
20 | Hideki Isozaki, Tsutomu Hirao, Kevin Duh, Katsuhito Sudoh, Hajime Tsukada
21 | http://www.anthology.aclweb.org/D/D10/D10-1092.pdf
22 |
23 | Args:
24 | ref: A reference sentence
25 | out: An output sentence
26 | order: The highest order of grams we want to consider (-1=inf)
27 | case_insensitive: A boolean specifying whether to turn on the case insensitive option
28 |
29 | Returns:
30 | The word alignment, represented as a list of integers.
31 | """
32 |
33 | if case_insensitive:
34 | ref = corpus_utils.lower(ref)
35 | out = corpus_utils.lower(out)
36 |
37 | order = len(ref) if order == -1 else order
38 |
39 | ref_gram_pos = _count_ngram(ref, order)
40 | out_gram_pos = _count_ngram(out, order)
41 |
42 | worder = []
43 | for i, word in enumerate(out):
44 | if len(ref_gram_pos[1][word]) == 0:
45 | continue
46 | if len(ref_gram_pos[1][word]) == len(out_gram_pos[1][word]) == 1:
47 | worder.append(ref_gram_pos[1][word][0])
48 | else:
49 | word_forward = word
50 | word_backward = word
51 | for j in range(1, order):
52 | if i - j >= 0:
53 | word_backward = out[i-j] + ' ' + word_backward
54 | if len(ref_gram_pos[j+1][word_backward]) == len(out_gram_pos[j+1][word_backward]) == 1:
55 | worder.append(ref_gram_pos[j+1][word_backward][0]+j)
56 | break
57 |
58 | if i + j < len(out):
59 | word_forward = word_forward + ' ' + out[i+j]
60 | if len(ref_gram_pos[j+1][word_forward]) == len(out_gram_pos[j+1][word_forward]) == 1:
61 | worder.append(ref_gram_pos[j+1][word_forward][0])
62 | break
63 |
64 | return worder
65 |
--------------------------------------------------------------------------------
/compare_mt/arg_utils.py:
--------------------------------------------------------------------------------
1 | def parse_profile(profile):
2 | kargs = {}
3 | try:
4 | for kv in profile.split(','):
5 | k, v = kv.split('=')
6 | kargs[k] = v
7 | except ValueError:
8 | # more informative error message
9 | raise ValueError(
10 | f"Failed to parse profile: {profile}. The expected format is:"
11 | " \"key1=value1,key2=value2,[...]\""
12 | )
13 | return kargs
14 |
15 | def parse_compare_directions(compare_directions):
16 | direcs = []
17 | try:
18 | for direc in compare_directions.split(';'):
19 | left, right = direc.split('-')
20 | left, right = int(left), int(right)
21 | direcs.append((left, right))
22 | except ValueError:
23 | # more informative error message
24 | raise ValueError(
25 | f"Failed to parse directions: {compare_directions}."
26 | " The expected format is: \"left1-right1;left2-right2;[...]\""
27 | )
28 | return direcs
29 |
30 | def parse_files(filenames):
31 | files = []
32 | for f in filenames.split(';'):
33 | files.append(f)
34 | return files
35 |
36 | def parse_intfloat(s):
37 | try:
38 | return int(s)
39 | except ValueError:
40 | return float(s)
--------------------------------------------------------------------------------
/compare_mt/bucketers.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import itertools
3 | import numpy as np
4 | from collections import defaultdict
5 |
6 | from compare_mt import corpus_utils
7 | from compare_mt import scorers
8 | from compare_mt import arg_utils
9 |
10 | class Bucketer:
11 |
12 | def set_bucket_cutoffs(self, bucket_cutoffs, num_type='int'):
13 | self.bucket_cutoffs = bucket_cutoffs
14 | self.bucket_strs = []
15 | for i, x in enumerate(bucket_cutoffs):
16 | if i == 0:
17 | self.bucket_strs.append(f'<{x}')
18 | elif num_type == 'int' and x-1 == bucket_cutoffs[i-1]:
19 | self.bucket_strs.append(f'{x-1}')
20 | else:
21 | self.bucket_strs.append(f'[{bucket_cutoffs[i-1]},{x})')
22 | self.bucket_strs.append(f'>={x}')
23 |
24 | def cutoff_into_bucket(self, value):
25 | for i, v in enumerate(self.bucket_cutoffs):
26 | if value < v:
27 | return i
28 | return len(self.bucket_cutoffs)
29 |
30 | class WordBucketer(Bucketer):
31 |
32 | def calc_bucket(self, val, label=None):
33 | """
34 | Calculate the bucket for a particular word
35 |
36 | Args:
37 | val: The word to calculate the bucket for
38 | label: If there's a label on the target word, add it
39 |
40 | Returns:
41 | An integer ID of the bucket
42 | """
43 | raise NotImplementedError('calc_bucket must be implemented in subclasses of WordBucketer')
44 |
45 | def _calc_trg_matches(self, ref_sent, out_sents):
46 | ref_pos = defaultdict(lambda: [])
47 | out_matches = [[-1 for _ in s] for s in out_sents]
48 | ref_matches = [[-1 for _ in ref_sent] for _ in out_sents]
49 | for ri, ref_word in enumerate(ref_sent):
50 | ref_pos[ref_word].append(ri)
51 | for oai, out_sent in enumerate(out_sents):
52 | out_word_cnts = {}
53 | for oi, out_word in enumerate(out_sent):
54 | ref_poss = ref_pos.get(out_word, None)
55 | if ref_poss:
56 | out_word_cnt = out_word_cnts.get(out_word, 0)
57 | if out_word_cnt < len(ref_poss):
58 | out_matches[oai][oi] = ref_poss[out_word_cnt]
59 | ref_matches[oai][ref_poss[out_word_cnt]] = oi
60 | out_word_cnts[out_word] = out_word_cnt + 1
61 | return out_matches, ref_matches
62 |
63 | def _calc_trg_buckets_and_matches(self, ref_sent, ref_label, out_sents, out_labels):
64 | # Initial setup for special cases
65 | if self.case_insensitive:
66 | ref_sent = [corpus_utils.lower(w) for w in ref_sent]
67 | out_sents = [[corpus_utils.lower(w) for w in out_sent] for out_sent in out_sents]
68 | if not ref_label:
69 | ref_label = []
70 | out_labels = [[] for _ in out_sents]
71 | # Get matches
72 | out_matches, _ = self._calc_trg_matches(ref_sent, out_sents)
73 | # Process the reference, getting the bucket
74 | ref_buckets = [self.calc_bucket(w, label=l) for (w,l) in itertools.zip_longest(ref_sent, ref_label)]
75 | # Process each of the outputs, finding matches
76 | out_buckets = [[] for _ in out_sents]
77 | for oai, (out_sent, out_label, match, out_buck) in \
78 | enumerate(itertools.zip_longest(out_sents, out_labels, out_matches, out_buckets)):
79 | for oi, (w, l, m) in enumerate(itertools.zip_longest(out_sent, out_label, match)):
80 | out_buck.append(self.calc_bucket(w, label=l) if m < 0 else ref_buckets[m])
81 | # Calculate totals for each sentence
82 | num_buckets = len(self.bucket_strs)
83 | num_outs = len(out_sents)
84 |     my_ref_total = np.zeros(num_buckets, dtype=int)
85 |     my_out_totals = np.zeros((num_outs, num_buckets), dtype=int)
86 |     my_out_matches = np.zeros((num_outs, num_buckets), dtype=int)
87 | for b in ref_buckets:
88 | if isinstance(b, list):
89 | for bi in b:
90 | my_ref_total[bi] += 1
91 | else:
92 | my_ref_total[b] += 1
93 | for oi, (obs, ms) in enumerate(zip(out_buckets, out_matches)):
94 | for b, m in zip(obs, ms):
95 | if isinstance(b, list):
96 | for bi in b:
97 | my_out_totals[oi,bi] += 1
98 | if m >= 0:
99 | my_out_matches[oi,bi] += 1
100 | else:
101 | my_out_totals[oi,b] += 1
102 | if m >= 0:
103 | my_out_matches[oi,b] += 1
104 | return my_ref_total, my_out_totals, my_out_matches, ref_buckets, out_buckets, out_matches
105 |
106 | def _calc_src_buckets_and_matches(self, src_sent, src_label, ref_sent, ref_aligns, out_sents):
107 | # Initial setup for special cases
108 | if self.case_insensitive:
109 | src_sent = [corpus_utils.lower(w) for w in src_sent]
110 | ref_sent = [corpus_utils.lower(w) for w in ref_sent]
111 | out_sents = [[corpus_utils.lower(w) for w in out_sent] for out_sent in out_sents]
112 | if not src_label:
113 | src_label = []
114 | # Get matches
115 | _, ref_matches = self._calc_trg_matches(ref_sent, out_sents)
116 | # Process the source, getting the bucket
117 | src_buckets = [self.calc_bucket(w, label=l) for (w,l) in itertools.zip_longest(src_sent, src_label)]
118 | # For each source word, find the reference words that need to be correct
119 | src_aligns = [[] for _ in src_sent]
120 | for src, trg in ref_aligns:
121 | src_aligns[src].append(trg)
122 | # Calculate totals for each sentence
123 | num_buckets = len(self.bucket_strs)
124 | num_outs = len(out_sents)
125 |     my_ref_total = np.zeros(num_buckets, dtype=int)
126 |     my_out_matches = np.zeros((num_outs, num_buckets), dtype=int)
127 | for src_bucket in src_buckets:
128 | my_ref_total[src_bucket] += 1
129 | my_out_totals = np.broadcast_to(np.reshape(my_ref_total, (1, num_buckets)), (num_outs, num_buckets))
130 | for oai, (out_sent, ref_match) in enumerate(zip(out_sents, ref_matches)):
131 | for src_bucket, src_align in zip(src_buckets, src_aligns):
132 | if len(src_align) != 0:
133 | if all([ref_match[x] >= 0 for x in src_align]):
134 | my_out_matches[oai,src_bucket] += 1
135 | return my_ref_total, my_out_totals, my_out_matches, src_buckets, src_aligns, ref_matches
136 |
137 | def calc_statistics(self, ref, outs,
138 | src=None,
139 | ref_labels=None, out_labels=None,
140 | ref_aligns=None, src_labels=None):
141 | """
142 | Calculate match statistics, bucketed by the type of word we have, and IDs of example sentences to show.
143 | This must be used with a subclass that has self.bucket_strs defined, and self.calc_bucket(word) implemented.
144 |
145 | Args:
146 | ref: The reference corpus
147 | outs: A list of output corpora
148 | src: Source sentences.
149 | If src is set, it will use ref_aligns, out_aligns, and src_labels.
150 | Otherwise, it will use ref_labels and out_labels.
151 | ref_labels: Labels of the reference corpus (optional)
152 | out_labels: Labels of the output corpora (should be specified iff ref_labels is)
153 |
154 | Returns:
155 | statistics: containing a list of equal length to out, containing for each system
156 | both_tot: the frequency of a particular bucket appearing in both output and reference
157 | ref_tot: the frequency of a particular bucket appearing in just reference
158 | out_tot: the frequency of a particular bucket appearing in just output
159 | rec: recall of the bucket
160 | prec: precision of the bucket
161 | fmeas: f1-measure of the bucket
162 | my_ref_total_list: containing a list of statistics of the reference
163 | my_out_matches_list: containing a list of statistics of the outputs
164 | """
165 | if not hasattr(self, 'case_insensitive'):
166 | self.case_insensitive = False
167 |
168 | # Dimensions
169 | num_buckets = len(self.bucket_strs)
170 | num_outs = len(outs)
171 |
172 | # Initialize the sufficient statistics for prec/rec/fmeas
173 | ref_total = np.zeros(num_buckets, dtype=int)
174 |     out_totals = np.zeros((num_outs, num_buckets), dtype=int)
175 |     out_matches = np.zeros((num_outs, num_buckets), dtype=int)
176 |
177 | my_ref_total_list = []
178 | my_out_totals_list = []
179 | my_out_matches_list = []
180 |
181 | # Step through the sentences
182 | for rsi, (ref_sent, ref_label) in enumerate(itertools.zip_longest(ref, ref_labels if ref_labels else [])):
183 | if src:
184 | my_ref_total, my_out_totals, my_out_matches, _, _, _ = \
185 | self._calc_src_buckets_and_matches(src[rsi],
186 | src_labels[rsi] if src_labels else None,
187 | ref_sent,
188 | ref_aligns[rsi],
189 | [x[rsi] for x in outs])
190 | else:
191 | my_ref_total, my_out_totals, my_out_matches, _, _, _ = \
192 | self._calc_trg_buckets_and_matches(ref_sent,
193 | ref_label,
194 | [x[rsi] for x in outs],
195 | [x[rsi] for x in out_labels] if out_labels else None)
196 | ref_total += my_ref_total
197 | out_totals += my_out_totals
198 | out_matches += my_out_matches
199 |
200 | my_ref_total_list.append(my_ref_total)
201 | my_out_totals_list.append(my_out_totals)
202 | my_out_matches_list.append(my_out_matches)
203 |
204 | # Calculate statistics
205 | statistics = [[] for _ in range(num_outs)]
206 | for oi, ostatistics in enumerate(statistics):
207 | for bi in range(num_buckets):
208 | mcnt, ocnt, rcnt = out_matches[oi,bi], out_totals[oi,bi], ref_total[bi]
209 | if mcnt == 0:
210 | rec, prec, fmeas = 0.0, 0.0, 0.0
211 | else:
212 | rec = mcnt / float(rcnt)
213 | prec = mcnt / float(ocnt)
214 | fmeas = 2 * prec * rec / (prec + rec)
215 | ostatistics.append( (mcnt, rcnt, ocnt, rec, prec, fmeas) )
216 |
217 | return statistics, my_ref_total_list, my_out_totals_list, my_out_matches_list
218 |
219 | def calc_bucket_details(self, my_ref_total_list, my_out_totals_list, my_out_matches_list, num_samples=1000, sample_ratio=0.5):
220 |
221 | ref_total = np.array(my_ref_total_list).sum(0)
222 |
223 | num_outs, num_buckets = my_out_totals_list[0].shape
224 | n = len(my_ref_total_list)
225 | ids = list(range(n))
226 | sample_size = int(np.ceil(n*sample_ratio))
227 | rt_arr = np.array(my_ref_total_list)
228 | ot_arr = np.array(my_out_totals_list)
229 | om_arr = np.array(my_out_matches_list)
230 | statistics = [[ [] for __ in range(num_buckets) ] for _ in range(num_outs)]
231 | for _ in range(num_samples):
232 | reduced_ids = np.random.choice(ids, size=sample_size, replace=True)
233 |       reduced_ref_total, reduced_out_totals, reduced_out_matches = rt_arr[reduced_ids].sum(0), ot_arr[reduced_ids].sum(0), om_arr[reduced_ids].sum(0)
234 | # Calculate accuracy on the reduced sample and save stats
235 | for oi in range(num_outs):
236 | for bi in range(num_buckets):
237 | mcnt, ocnt, rcnt = reduced_out_matches[oi,bi], reduced_out_totals[oi,bi], reduced_ref_total[bi]
238 | if mcnt == 0:
239 | rec, prec, fmeas = 0.0, 0.0, 0.0
240 | else:
241 | rec = mcnt / float(rcnt)
242 | prec = mcnt / float(ocnt)
243 | fmeas = 2 * prec * rec / (prec + rec)
244 | statistics[oi][bi].append( (mcnt, rcnt, ocnt, rec, prec, fmeas) )
245 |
246 | intervals = [[] for _ in range(num_outs)]
247 | for oi in range(num_outs):
248 | for bi in range(num_buckets):
249 | if len(statistics[oi][bi]) > 0:
250 | _, _, _, recs, precs, fmeas = zip(*statistics[oi][bi])
251 | else:
252 | recs, precs, fmeas = [0.0], [0.0], [0.0]
253 | # The first three elements (intervals of mcnt, ocnt and rcnt) are None
254 | bounds = [None, None, None]
255 | for x in [recs, precs, fmeas]:
256 | x = list(x)
257 | x.sort()
258 | lower_bound = x[int(num_samples * 0.025)]
259 | upper_bound = x[int(num_samples * 0.975)]
260 | bounds.append( (lower_bound, upper_bound) )
261 | intervals[oi].append(bounds)
262 |
263 | return ref_total, intervals
264 |
265 | def calc_examples(self, num_sents, num_outs,
266 | statistics,
267 | my_ref_total_list, my_out_matches_list,
268 | num_examples=5):
269 | """
270 | Calculate examples based the computed statistics.
271 |
272 | Args:
273 | num_sents: number of sentences
274 | num_outs: number of outputs
275 | statistics: containing a list of equal length to out, containing for each system
276 | both_tot: the frequency of a particular bucket appearing in both output and reference
277 | ref_tot: the frequency of a particular bucket appearing in just reference
278 | out_tot: the frequency of a particular bucket appearing in just output
279 | rec: recall of the bucket
280 | prec: precision of the bucket
281 | fmeas: f1-measure of the bucket
282 | my_ref_total_list: containing a list of statistics of the reference
283 | my_out_matches_list: containing a list of statistics of the outputs
284 | num_examples: number of examples to print
285 |
286 | Returns:
287 | example: containing a list of examples to print
288 | """
289 | num_buckets = len(self.bucket_strs)
290 | num_examp_feats = 3
291 | example_scores = np.zeros( (num_sents, num_examp_feats, num_buckets) )
292 |
293 | # Step through the sentences
294 | for rsi, (my_ref_total, my_out_matches) in enumerate(zip(my_ref_total_list, my_out_matches_list)):
295 |
296 | # Scoring of examples across different dimensions:
297 | # 0: overall variance of matches
298 | example_scores[rsi,0] = (my_out_matches / (my_ref_total+1e-10).reshape( (1, num_buckets) )).std(axis=0)
299 | # 1: overall percentage of matches
300 | example_scores[rsi,1] = my_out_matches.sum(axis=0) / (my_ref_total*num_outs+1e-10)
301 | # 2: overall percentage of misses
302 | example_scores[rsi,2] = (my_ref_total*num_outs-my_out_matches.sum(axis=0)) / (my_ref_total*num_outs+1e-10)
303 |
304 | # Calculate statistics
305 | # Find top-5 examples of each class
306 | examples = [[('Examples where some systems were good, some were bad', []),
307 | ('Examples where all systems were good', []),
308 | ('Examples where all systems were bad', [])] for _ in range(num_buckets)]
309 | # NOTE: This could be made faster with argpartition, but the complexity is probably not worth it
310 | topn = np.argsort(-example_scores, axis=0)
311 | for bi, bexamples in enumerate(examples):
312 | for fi, (_, fexamples) in enumerate(bexamples):
313 | for si in topn[:num_examples,fi,bi]:
314 | if example_scores[si,fi,bi] > 0:
315 | fexamples.append(si)
316 |
317 | return examples
318 |
319 | def calc_source_bucketed_matches(self, src, ref, out, ref_aligns, out_aligns, src_labels=None):
320 | """
321 | Calculate the number of matches, bucketed by the type of word we have
322 | This must be used with a subclass that has self.bucket_strs defined, and self.calc_bucket(word) implemented.
323 |
324 | Args:
325 | src: The source corpus
326 | ref: The reference corpus
327 | out: The output corpus
328 | ref_aligns: Alignments of the reference corpus
329 | out_aligns: Alignments of the output corpus
330 | src_labels: Labels of the source corpus (optional)
331 |
332 | Returns:
333 | A tuple containing:
334 | both_tot: the frequency of a particular bucket appearing in both output and reference
335 | ref_tot: the frequency of a particular bucket appearing in just reference
336 | out_tot: the frequency of a particular bucket appearing in just output
337 | rec: recall of the bucket
338 | prec: precision of the bucket
339 | fmeas: f1-measure of the bucket
340 | """
341 | if not hasattr(self, 'case_insensitive'):
342 | self.case_insensitive = False
343 |
344 | src_labels = src_labels if src_labels else []
345 | matches = [[0, 0, 0] for x in self.bucket_strs]
346 | for src_sent, ref_sent, out_sent, ref_align, out_align, src_lab in itertools.zip_longest(src, ref, out, ref_aligns, out_aligns, src_labels):
347 | ref_cnt = defaultdict(lambda: 0)
348 | for i, word in enumerate(ref_sent):
349 | if self.case_insensitive:
350 | word = corpus_utils.lower(word)
351 | ref_cnt[word] += 1
352 | for i, (src_index, trg_index) in enumerate(out_align):
353 | src_word = src_sent[src_index]
354 | word = out_sent[trg_index]
355 | if self.case_insensitive:
356 | word = corpus_utils.lower(word)
357 | bucket = self.calc_bucket(src_word,
358 | label=src_lab[src_index] if src_lab else None)
359 | if ref_cnt[word] > 0:
360 | ref_cnt[word] -= 1
361 | matches[bucket][0] += 1
362 | matches[bucket][2] += 1
363 | for i, (src_index, trg_index) in enumerate(ref_align):
364 | src_word = src_sent[src_index]
365 | bucket = self.calc_bucket(src_word,
366 | label=src_lab[src_index] if src_lab else None)
367 | matches[bucket][1] += 1
368 |
369 | for both_tot, ref_tot, out_tot in matches:
370 | if both_tot == 0:
371 | rec, prec, fmeas = 0.0, 0.0, 0.0
372 | else:
373 | rec = both_tot / float(ref_tot)
374 | prec = both_tot / float(out_tot)
375 | fmeas = 2 * prec * rec / (prec + rec)
376 | yield both_tot, ref_tot, out_tot, rec, prec, fmeas
377 |
378 | def calc_bucketed_likelihoods(self, corpus, likelihoods):
379 | """
380 | Calculate the average of log likelihoods, bucketed by the type of word/label we have
381 | This must be used with a subclass that has self.bucket_strs defined, and self.calc_bucket(word) implemented.
382 |
383 | Args:
384 | corpus: The text/label corpus over which we compute the likelihoods
385 | likelihoods: The log-likelihoods corresponding to each word/label in the corpus
386 |
387 | Returns:
388 | the average log-likelihood bucketed by the type of word/label we have
389 | """
390 | if not hasattr(self, 'case_insensitive'):
391 | self.case_insensitive = False
392 |
393 | if type(corpus) == str:
394 | corpus = corpus_utils.load_tokens(corpus)
395 | bucketed_likelihoods = [[0.0, 0] for _ in self.bucket_strs]
396 | if len(corpus) != len(likelihoods):
397 | raise ValueError("Corpus and likelihoods should have the same size.")
398 | for sent, list_of_likelihoods in zip(corpus, likelihoods):
399 | if len(sent) != len(list_of_likelihoods):
400 | raise ValueError("Each sentence of the corpus should have likelihood value for each word")
401 |
402 | for word, ll in zip(sent, list_of_likelihoods):
403 | if self.case_insensitive:
404 | word = corpus_utils.lower(word)
405 | bucket = self.calc_bucket(word, label=word)
406 | bucketed_likelihoods[bucket][0] += ll
407 | bucketed_likelihoods[bucket][1] += 1
408 |
409 | for ll, count in bucketed_likelihoods:
410 | if count != 0:
411 | yield ll/float(count)
412 | else:
413 | yield "NA" # not applicable
414 |
415 |
416 | class FreqWordBucketer(WordBucketer):
417 |
418 | def __init__(self,
419 | freq_counts=None, freq_count_file=None, freq_corpus_file=None, freq_data=None,
420 | bucket_cutoffs=None,
421 | case_insensitive=False):
422 | """
423 | A bucketer that buckets words by their frequency.
424 |
425 | Args:
426 | freq_counts: A dictionary containing word/count data.
427 | freq_count_file: A file containing counts for each word in tab-separated word, count format.
428 | Ignored if freq_counts exists.
429 | freq_corpus_file: A file with a corpus used for collecting counts. Ignored if freq_count_file exists.
430 | freq_data: A tokenized corpus from which counts can be calculated. Ignored if freq_corpus_file exists.
431 | bucket_cutoffs: Cutoffs for each bucket.
432 |         The first bucket will be range(0, bucket_cutoffs[0]).
433 |         Middle buckets will be range(bucket_cutoffs[i-1], bucket_cutoffs[i]).
434 |         The final bucket will be everything greater than or equal to bucket_cutoffs[-1].
435 | case_insensitive: A boolean specifying whether to turn on the case insensitive option.
436 | """
437 | self.case_insensitive = case_insensitive
438 | if not freq_counts:
439 | freq_counts = defaultdict(lambda: 0)
440 | if freq_count_file != None:
441 | print(f'Reading frequency from "{freq_count_file}"')
442 | with open(freq_count_file, "r") as f:
443 | for line in f:
444 | cols = line.strip().split('\t')
445 | if len(cols) != 2:
446 | print(f'Bad line in counts file {freq_count_file}, ignoring:\n{line}')
447 | else:
448 | word, freq = cols
449 | if self.case_insensitive:
450 | word = corpus_utils.lower(word)
451 | freq_counts[word] = int(freq)
452 | elif freq_corpus_file:
453 | print(f'Reading frequency from "{freq_corpus_file}"')
454 | for words in corpus_utils.iterate_tokens(freq_corpus_file):
455 | for word in words:
456 | if self.case_insensitive:
457 | word = corpus_utils.lower(word)
458 | freq_counts[word] += 1
459 | elif freq_data:
460 | print('Reading frequency from the reference')
461 | for words in freq_data:
462 | for word in words:
463 | if self.case_insensitive:
464 | word = corpus_utils.lower(word)
465 | freq_counts[word] += 1
466 | else:
467 | raise ValueError('Must have at least one source of frequency counts for FreqWordBucketer')
468 | self.freq_counts = freq_counts
469 |
470 | if bucket_cutoffs is None:
471 | bucket_cutoffs = [1, 2, 3, 4, 5, 10, 100, 1000]
472 | self.set_bucket_cutoffs(bucket_cutoffs)
473 |
474 | def calc_bucket(self, word, label=None):
475 | if self.case_insensitive:
476 | word = corpus_utils.lower(word)
477 | return self.cutoff_into_bucket(self.freq_counts.get(word, 0))
478 |
479 | def name(self):
480 | return "frequency"
481 |
482 | def idstr(self):
483 | return "freq"
484 |
485 | class CaseWordBucketer(WordBucketer):
486 |
487 | def __init__(self):
488 | """
489 |     A bucketer that buckets words by whether they're all lower-case (lower), all upper-case (upper),
490 |     title case (title), or other.
491 | """
492 | self.bucket_strs = ['lower', 'upper', 'title', 'other']
493 |
494 | def calc_bucket(self, word, label=None):
495 | if word.islower():
496 | return 0
497 | elif word.isupper():
498 | return 1
499 | elif word.istitle():
500 | return 2
501 | else:
502 | return 3
503 |
504 | def name(self):
505 | return "case"
506 |
507 | def idstr(self):
508 | return "case"
509 |
510 | class LabelWordBucketer(WordBucketer):
511 |
512 | def __init__(self,
513 | label_set=None):
514 | """
515 | A bucketer that buckets words by their labels.
516 |
517 | Args:
518 | label_set: The set of labels to use as buckets. This can be a list, or a string separated by '+'s.
519 | """
520 | if type(label_set) == str:
521 | label_set = label_set.split('+')
522 | self.bucket_strs = label_set + ['other']
523 | label_set_len = len(label_set)
524 | self.bucket_map = defaultdict(lambda: label_set_len)
525 | for i, l in enumerate(label_set):
526 | self.bucket_map[l] = i
527 |
528 | def calc_bucket(self, word, label=None):
529 | if not label:
530 |       raise ValueError('When calculating buckets by label, a label must be provided')
531 | return self.bucket_map[label]
532 |
533 | def name(self):
534 | return "labels"
535 |
536 | def idstr(self):
537 | return "labels"
538 |
539 | class MultiLabelWordBucketer(WordBucketer):
540 |
541 | def __init__(self,
542 | label_set=None):
543 | """
544 | A bucketer that buckets words by one or multiple labels.
545 |
546 | Args:
547 | label_set: The set of labels to use as buckets. This can be a list, or a string separated by '+'s.
548 | """
549 | if type(label_set) == str:
550 | label_set = label_set.split('+')
551 | self.bucket_strs = label_set + ['other']
552 | label_set_len = len(label_set)
553 | self.bucket_map = defaultdict(lambda: label_set_len)
554 | for i, l in enumerate(label_set):
555 | self.bucket_map[l] = i
556 |
557 | def calc_bucket(self, word, label=None):
558 | if not label:
559 |       raise ValueError('When calculating buckets by label, a label must be provided')
560 | label = label.split('+')
561 | return [self.bucket_map[l] for l in label]
562 |
563 | def name(self):
564 | return "multilabels"
565 |
566 | def idstr(self):
567 | return "multilabels"
568 |
569 | class NumericalLabelWordBucketer(WordBucketer):
570 |
571 | def __init__(self,
572 | bucket_cutoffs=None):
573 | """
574 | A bucketer that buckets words by labels that are numerical values.
575 |
576 | Args:
577 | bucket_cutoffs: Cutoffs for each bucket.
578 |         The first bucket will be range(0, bucket_cutoffs[0]).
579 |         Middle buckets will be range(bucket_cutoffs[i-1], bucket_cutoffs[i]).
580 |         The final bucket will be everything greater than or equal to bucket_cutoffs[-1].
581 | """
582 | if bucket_cutoffs is None:
583 | bucket_cutoffs = [0.25, 0.5, 0.75]
584 | self.set_bucket_cutoffs(bucket_cutoffs)
585 |
586 | def calc_bucket(self, word, label=None):
587 | if label:
588 | return self.cutoff_into_bucket(float(label))
589 | else:
590 |       raise ValueError('When calculating buckets by label, a label must be provided')
591 |
592 | def name(self):
593 | return "numerical labels"
594 |
595 | def idstr(self):
596 | return "numlabels"
597 |
598 | class SentenceBucketer(Bucketer):
599 |
600 | def calc_bucket(self, val, ref=None, src=None, out_label=None, ref_label=None):
601 | """
602 | Calculate the bucket for a particular sentence
603 |
604 | Args:
605 | val: The sentence to calculate the bucket for
606 | ref: The reference sentence, if it exists
607 | src: The source sentence, if it exists
608 |       ref_label: The label of the reference sentence, if it exists
609 |       out_label: The label of the output sentence, if it exists
610 |
611 | Returns:
612 | An integer ID of the bucket
613 | """
614 | raise NotImplementedError('calc_bucket must be implemented in subclasses of SentenceBucketer')
615 |
616 | def create_bucketed_corpus(self, out, ref=None, src=None, ref_labels=None, out_labels=None):
617 |     bucketed_corpus = [([], [], []) for _ in self.bucket_strs]  # (out, ref, src) lists per bucket; ref slot must be a list since ref defaults to out below and is always appended
618 | if ref is None:
619 | ref = out
620 |
621 | if ref_labels is None:
622 | ref_labels = out_labels
623 |
624 | src = [None for _ in out] if src is None else src
625 |
626 | for i, (out_words, ref_words, src_words) in enumerate(zip(out, ref, src)):
627 | bucket = self.calc_bucket(out_words, ref_words, src_words, label=(ref_labels[i][0] if ref_labels else None))
628 |
629 | bucketed_corpus[bucket][0].append(out_words)
630 | bucketed_corpus[bucket][1].append(ref_words)
631 | bucketed_corpus[bucket][2].append(src_words)
632 |
633 | return bucketed_corpus
634 |
635 |
636 | class ScoreSentenceBucketer(SentenceBucketer):
637 | """
638 | Bucket sentences by some score (e.g. BLEU)
639 | """
640 |
641 | def __init__(self, score_type, bucket_cutoffs=None, case_insensitive=False):
642 | self.score_type = score_type
643 | self.scorer = scorers.create_scorer_from_profile(score_type)
644 | if bucket_cutoffs is None:
645 | bucket_cutoffs = [x * self.scorer.scale / 10.0 for x in range(1,10)]
646 | self.set_bucket_cutoffs(bucket_cutoffs, num_type='float')
647 | self.case_insensitive = case_insensitive
648 |
649 | def calc_bucket(self, val, ref=None, src=None, label=None):
650 | if self.case_insensitive:
651 |       return self.cutoff_into_bucket(self.scorer.score_sentence(corpus_utils.lower(ref), corpus_utils.lower(val), src)[0])
652 | else:
653 | return self.cutoff_into_bucket(self.scorer.score_sentence(ref, val, src)[0])
654 |
655 | def name(self):
656 | return self.scorer.name()
657 |
658 | def idstr(self):
659 | return self.scorer.idstr()
660 |
661 | class LengthSentenceBucketer(SentenceBucketer):
662 | """
663 | Bucket sentences by length
664 | """
665 |
666 | def __init__(self, bucket_cutoffs=None):
667 | if bucket_cutoffs is None:
668 | bucket_cutoffs = [10, 20, 30, 40, 50, 60]
669 | self.set_bucket_cutoffs(bucket_cutoffs, num_type='int')
670 |
671 | def calc_bucket(self, val, ref=None, src=None, label=None):
672 | return self.cutoff_into_bucket(len(ref))
673 |
674 | def name(self):
675 | return "length"
676 |
677 | def idstr(self):
678 | return "length"
679 |
680 | class LengthDiffSentenceBucketer(SentenceBucketer):
681 | """
682 |   Bucket sentences by the difference in length between output and reference
683 | """
684 |
685 | def __init__(self, bucket_cutoffs=None):
686 | if bucket_cutoffs is None:
687 | bucket_cutoffs = [-20, -10, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 11, 21]
688 | self.set_bucket_cutoffs(bucket_cutoffs, num_type='int')
689 |
690 | def calc_bucket(self, val, ref=None, src=None, label=None):
691 | return self.cutoff_into_bucket(len(val) - len(ref))
692 |
693 | def name(self):
694 | return "len(output)-len(reference)"
695 |
696 | def idstr(self):
697 | return "lengthdiff"
698 |
699 | class LabelSentenceBucketer(SentenceBucketer):
700 |
701 | def __init__(self, label_set=None):
702 | """
703 | A bucketer that buckets sentences by their labels.
704 |
705 | Args:
706 | label_set: The set of labels to use as buckets. This can be a list, or a string separated by '+'s.
707 | """
708 | if type(label_set) == str:
709 | label_set = label_set.split('+')
710 | self.bucket_strs = label_set + ['other']
711 | label_set_len = len(label_set)
712 | self.bucket_map = defaultdict(lambda: label_set_len)
713 | for i, l in enumerate(label_set):
714 | self.bucket_map[l] = i
715 |
716 | def calc_bucket(self, val, ref=None, src=None, label=None):
717 | return self.bucket_map[label]
718 |
719 | def name(self):
720 | return "labels"
721 |
722 | def idstr(self):
723 | return "labels"
724 |
725 | class MultiLabelSentenceBucketer(SentenceBucketer):
726 |
727 | def __init__(self, label_set=None):
728 | """
729 | A bucketer that buckets sentences by their labels.
730 |
731 | Args:
732 | label_set: The set of labels to use as buckets. This can be a list, or a string separated by '+'s.
733 | """
734 | if type(label_set) == str:
735 | label_set = label_set.split('+')
736 | self.bucket_strs = label_set + ['other']
737 | label_set_len = len(label_set)
738 | self.bucket_map = defaultdict(lambda: label_set_len)
739 | for i, l in enumerate(label_set):
740 | self.bucket_map[l] = i
741 |
742 | def calc_bucket(self, val, ref=None, src=None, label=None):
743 | label = label.split('+')
744 | return [self.bucket_map[l] for l in label]
745 |
746 | def name(self):
747 | return "multilabels"
748 |
749 | def idstr(self):
750 | return "multilabels"
751 |
752 | class NumericalLabelSentenceBucketer(SentenceBucketer):
753 |
754 | def __init__(self, bucket_cutoffs=None):
755 | """
756 | A bucketer that buckets sentences by labels that are numerical values.
757 |
758 | Args:
759 | bucket_cutoffs: Cutoffs for each bucket.
760 |         The first bucket will be range(0, bucket_cutoffs[0]).
761 |         Middle buckets will be range(bucket_cutoffs[i-1], bucket_cutoffs[i]).
762 |         The final bucket will be everything greater than or equal to bucket_cutoffs[-1].
763 | """
764 | if bucket_cutoffs is None:
765 | bucket_cutoffs = [0.25, 0.5, 0.75]
766 | self.set_bucket_cutoffs(bucket_cutoffs)
767 |
768 | def calc_bucket(self, val, ref=None, src=None, label=None):
769 | return self.cutoff_into_bucket(float(label))
770 |
771 | def name(self):
772 | return "numerical labels"
773 |
774 | def idstr(self):
775 | return "numlabels"
776 |
777 | def create_word_bucketer_from_profile(bucket_type,
778 | freq_counts=None, freq_count_file=None, freq_corpus_file=None, freq_data=None,
779 | label_set=None,
780 | bucket_cutoffs=None,
781 | case_insensitive=False):
782 | if type(bucket_cutoffs) == str:
783 | bucket_cutoffs = [arg_utils.parse_intfloat(x) for x in bucket_cutoffs.split(':')]
784 | if bucket_type == 'freq':
785 | return FreqWordBucketer(
786 | freq_counts=freq_counts,
787 | freq_count_file=freq_count_file,
788 | freq_corpus_file=freq_corpus_file,
789 | freq_data=freq_data,
790 | bucket_cutoffs=bucket_cutoffs,
791 | case_insensitive=case_insensitive)
792 | if bucket_type == 'case':
793 | return CaseWordBucketer()
794 | elif bucket_type == 'label':
795 | return LabelWordBucketer(
796 | label_set=label_set)
797 | elif bucket_type == 'multilabel':
798 | return MultiLabelWordBucketer(
799 | label_set=label_set)
800 | elif bucket_type == 'numlabel':
801 | return NumericalLabelWordBucketer(
802 | bucket_cutoffs=bucket_cutoffs)
803 | else:
804 | raise ValueError(f'Illegal bucket type {bucket_type}')
805 |
806 | def create_sentence_bucketer_from_profile(bucket_type,
807 | score_type=None,
808 | bucket_cutoffs=None,
809 | label_set=None,
810 | case_insensitive=False):
811 | if type(bucket_cutoffs) == str:
812 | bucket_cutoffs = [arg_utils.parse_intfloat(x) for x in bucket_cutoffs.split(':')]
813 | if bucket_type == 'score':
814 | return ScoreSentenceBucketer(score_type, bucket_cutoffs=bucket_cutoffs, case_insensitive=case_insensitive)
815 | elif bucket_type == 'length':
816 | return LengthSentenceBucketer(bucket_cutoffs=bucket_cutoffs)
817 | elif bucket_type == 'lengthdiff':
818 | return LengthDiffSentenceBucketer(bucket_cutoffs=bucket_cutoffs)
819 | elif bucket_type == 'label':
820 | return LabelSentenceBucketer(label_set=label_set)
821 | elif bucket_type == 'multilabel':
822 | return MultiLabelSentenceBucketer(
823 | label_set=label_set)
824 | elif bucket_type == 'numlabel':
825 | return NumericalLabelSentenceBucketer(bucket_cutoffs=bucket_cutoffs)
826 | else:
827 |     raise ValueError(f'Illegal bucket type {bucket_type}')
828 |
--------------------------------------------------------------------------------
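A quick usage sketch for the two factory functions above (not part of the package; the cutoff string and sentences are invented, and it assumes `cutoff_into_bucket`/`bucket_strs` behave as the bucketer base classes earlier in this file define):

    from compare_mt import bucketers

    # Bucket sentences by output/reference length difference; colon-separated
    # cutoff strings are parsed by the factory exactly as shown above.
    bucketer = bucketers.create_sentence_bucketer_from_profile(
        'lengthdiff', bucket_cutoffs='-5:0:5')
    ref = 'this is the reference sentence'.split()
    out = 'this is a somewhat longer system output'.split()
    bucket_index = bucketer.calc_bucket(out, ref=ref)  # len(out) - len(ref) = 2
    print(bucketer.bucket_strs[bucket_index])
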
/compare_mt/cache_utils.py:
--------------------------------------------------------------------------------
1 | from functools import lru_cache
2 | from nltk.stem.porter import PorterStemmer
3 |
4 | def extract_cache_dicts(cache_dicts, key_list, num_out):
5 | if cache_dicts is not None:
6 | if len(cache_dicts) != num_out:
7 |       raise ValueError(f'Length of cache_dicts ({len(cache_dicts)}) should equal the number of output files ({num_out})!')
8 | if len(key_list) == 1:
9 | return [c[key_list[0]] for c in cache_dicts]
10 | return zip(*[[c[k] for k in key_list] for c in cache_dicts])
11 |
12 | return [None]*len(key_list)
13 |
14 | def return_cache_dict(key_list, value_list):
15 | for v in value_list:
16 | if len(v) != 1:
17 |       raise ValueError('Caching is only supported for one system at a time!')
18 | cache_dict = {k:v[0] for (k, v) in zip(key_list, value_list)}
19 | return cache_dict
20 |
21 | class CachedPorterStemmer(PorterStemmer):
22 | """A wrapper class for PorterStemmer that uses LRU cache to reduce latency"""
23 | def __init__(self, mode=PorterStemmer.NLTK_EXTENSIONS):
24 | super().__init__(mode)
25 |
26 | @lru_cache(maxsize=50000)
27 | def stem(self, word, to_lowercase=True):
28 | return super().stem(word, to_lowercase)
--------------------------------------------------------------------------------
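A round-trip sketch of the two cache helpers above (hypothetical values): `return_cache_dict` packs single-system results under string keys, and `extract_cache_dicts` unpacks them on a later call.

    from compare_mt.cache_utils import extract_cache_dicts, return_cache_dict

    key_list = ['sent_scores']
    # Caching supports one system at a time, so each value is a 1-element list.
    cache = return_cache_dict(key_list, [[[0.1, 0.7, 0.4]]])
    (scores,) = extract_cache_dicts([cache], key_list, num_out=1)
    print(scores)  # [0.1, 0.7, 0.4]
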
/compare_mt/compare_ll_main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | # In-package imports
4 | from compare_mt import corpus_utils
5 | from compare_mt import bucketers
6 | from compare_mt import arg_utils
7 | from compare_mt import print_utils
8 | from compare_mt import formatting
9 |
10 | def print_word_likelihood_report(ref, lls, bucket_type='freq', bucket_cutoffs=None,
11 | freq_count_file=None, freq_corpus_file=None,
12 | label_corpus=None, label_set=None,
13 | case_insensitive=False):
14 | """
15 | Print a report comparing the word log likelihood.
16 |
17 | Args:
18 |     ref: the reference corpus of words over which the likelihoods are computed
19 | lls: likelihoods corresponding to each word in ref from the systems
20 | bucket_type: A string specifying the way to bucket words together to calculate average likelihood
21 | bucket_cutoffs: The boundaries between buckets, specified as a colon-separated string.
22 | freq_corpus_file: When using "freq" as a bucketer, which corpus to use to calculate frequency.
23 | freq_count_file: An alternative to freq_corpus that uses a count file in "word\tfreq" format.
24 | label_corpus: When using "label" as bucket type, the corpus containing the labels
25 | corresponding to each word in the corpus
26 | label_set: the permissible set of labels when using "label" as a bucket type
27 | case_insensitive: A boolean specifying whether to turn on the case insensitive option
28 | """
29 |   case_insensitive = (case_insensitive == 'True') if isinstance(case_insensitive, str) else bool(case_insensitive)
30 |
31 | bucketer = bucketers.create_word_bucketer_from_profile(bucket_type=bucket_type,
32 | bucket_cutoffs=bucket_cutoffs,
33 | freq_count_file=freq_count_file,
34 | freq_corpus_file=freq_corpus_file,
35 | label_set=label_set,
36 | case_insensitive=case_insensitive)
37 |
38 | if type(label_corpus) == str:
39 | label_corpus = corpus_utils.load_tokens(label_corpus)
40 |
41 | if label_corpus is not None:
42 | ref = label_corpus
43 |
44 |   lls_out = [list(bucketer.calc_bucketed_likelihoods(ref, ll)) for ll in lls]
45 |
46 | print(f'--- average word log likelihood by {bucketer.name()} bucket')
47 | for i, bucket_str in enumerate(bucketer.bucket_strs):
48 |     print(bucket_str + "\t", end='')
49 | for ll_out in lls_out:
50 | print(f"{formatting.fmt(ll_out[i])}\t", end="")
51 | print()
52 |
53 | def main():
54 | parser = argparse.ArgumentParser(
55 |     description='Program to compare the word log likelihoods assigned by different systems',
56 | )
57 | parser.add_argument('--ref-file', type=str, dest='ref_file',
58 | help='A path to a reference file over which the likelihoods are being computed/compared')
59 | parser.add_argument('--ll-files', type=str, nargs='+', dest='ll_files',
60 | help='A path to file containing log likelihoods for ref-file generated by systems')
61 | parser.add_argument('--compare-word-likelihoods', type=str, dest='compare_word_likelihoods', nargs='*',
62 | default=['bucket_type=freq'],
63 | help="""
64 | Compare word log likelihoods by buckets. Can specify arguments in 'arg1=val1,arg2=val2,...' format.
65 | See documentation for 'print_word_likelihood_report' to see which arguments are available.
66 | """)
67 | parser.add_argument('--decimals', type=int, default=4,
68 | help="Number of decimals to print for floating point numbers")
69 |
70 | args = parser.parse_args()
71 |
72 |   # Set formatting
73 |   formatting.fmt.set_decimals(args.decimals)
74 | 
75 |   ref = corpus_utils.load_tokens(args.ref_file)
76 |   lls = [corpus_utils.load_nums(x) for x in args.ll_files]
77 | 
78 |   # Word likelihood analysis
79 |   if args.compare_word_likelihoods:
80 |     print_utils.print_header('Word Likelihood Analysis')
81 |     for profile in args.compare_word_likelihoods:
82 |       kargs = arg_utils.parse_profile(profile)
83 |       print_word_likelihood_report(ref, lls, **kargs)
84 |       print()
85 | 
86 | 
87 | if __name__ == '__main__':
88 |   main()
89 | 
--------------------------------------------------------------------------------
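A hedged sketch of driving the report function above directly with this repo's example/ data, bypassing argparse; passing freq_corpus_file explicitly is an assumption here, supplying the frequency source the 'freq' bucketer requires.

    from compare_mt import corpus_utils
    from compare_mt.compare_ll_main import print_word_likelihood_report

    ref = corpus_utils.load_tokens('example/ll_test.txt')
    lls = [corpus_utils.load_nums('example/ll_test.sys1.likelihood'),
           corpus_utils.load_nums('example/ll_test.sys2.likelihood')]
    print_word_likelihood_report(ref, lls, bucket_type='freq',
                                 freq_corpus_file='example/ll_test.txt')
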
/compare_mt/corpus_utils.py:
--------------------------------------------------------------------------------
1 | def iterate_tokens(filename):
2 | with open(filename, "r", encoding="utf-8") as f:
3 | for line in f:
4 | yield line.strip().split(' ')
5 |
6 | def load_tokens(filename):
7 | return list(iterate_tokens(filename))
8 |
9 | def iterate_nums(filename):
10 | with open(filename, "r", encoding="utf-8") as f:
11 | for line in f:
12 | yield [float(i) for i in line.strip().split(' ')]
13 |
14 | def load_nums(filename):
15 | return list(iterate_nums(filename))
16 |
17 | def iterate_alignments(filename):
18 | with open(filename, "r", encoding="utf-8") as f:
19 | for line in f:
20 | try:
21 | yield [(int(src),int(trg)) for (src,trg) in [x.split('-') for x in line.strip().split(' ')]]
22 |       except ValueError:
23 | raise ValueError(f'Poorly formed alignment line in {filename}:\n{line}')
24 |
25 | def load_alignments(filename):
26 | return list(iterate_alignments(filename))
27 |
28 | def lower(inp):
29 | return inp.lower() if type(inp) == str else [lower(x) for x in inp]
30 |
31 | def list2str(l):
32 |   return ' '.join(str(s) for s in l)
33 | 
34 | def write_tokens(filename, ls):
35 |   with open(filename, 'w', encoding='utf-8') as f:
36 |     f.write('\n'.join(list2str(l) for l in ls))
37 | 
--------------------------------------------------------------------------------
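The loaders above expect one space-tokenized sentence per line, one number per token for likelihood files, and 'src-trg' integer index pairs for alignment files. A small sketch against this repo's example/ data:

    from compare_mt import corpus_utils

    toks = corpus_utils.load_tokens('example/ted.ref.eng')          # [['word', ...], ...]
    aligns = corpus_utils.load_alignments('example/ted.ref.align')  # [[(src, trg), ...], ...]
    assert all(isinstance(s, int) and isinstance(t, int) for (s, t) in aligns[0])
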
/compare_mt/formatting.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | class Formatter(object):
4 |
5 | latex_substitutions = {
6 |     re.compile(r"\["): "{[}",
7 |     re.compile(r"\]"): "{]}",
8 | re.compile("<"): r"\\textless",
9 | re.compile(">"): r"\\textgreater"
10 | }
11 |
12 | def __init__(self, decimals=4):
13 | self.set_decimals(decimals)
14 |
15 | def set_decimals(self, decimals):
16 | self.decimals = decimals
17 |
18 | def escape_latex(self, x):
19 | """Adds escape sequences wherever needed to make the output
20 |     LaTeX compatible"""
21 | for pat, replace_with in self.latex_substitutions.items():
22 | x = pat.sub(replace_with, x)
23 | return x
24 |
25 | def __call__(self, x, latex=True):
26 | """Convert object to string with controlled decimals"""
27 | if isinstance(x, str):
28 | return self.escape_latex(x) if latex else x
29 | elif isinstance(x, int):
30 | return f"{x:d}"
31 | elif isinstance(x, float):
32 | return f"{x:.{self.decimals}f}"
33 | else:
34 |       return str(x)
35 |
36 | fmt = Formatter(decimals=4)
37 |
--------------------------------------------------------------------------------
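A short demonstration of the shared `fmt` singleton above; note that set_decimals mutates global state, and strings are LaTeX-escaped unless latex=False is passed.

    from compare_mt.formatting import fmt

    fmt.set_decimals(2)
    print(fmt(3.14159))             # '3.14'  (floats use the configured decimals)
    print(fmt(42))                  # '42'    (ints are printed exactly)
    print(fmt('a<b'))               # 'a\textlessb'  (escaped for LaTeX by default)
    print(fmt('a<b', latex=False))  # 'a<b'
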
/compare_mt/ngram_utils.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import itertools
3 |
4 | def sent_ngrams_list(words, n):
5 | """
6 | Create a list with all the n-grams in a sentence
7 |
8 | Arguments:
9 | words: A list of strings representing a sentence
10 | n: The ngram length to consider
11 |
12 | Returns:
13 | A list of n-grams in the sentence
14 | """
15 | word_ngram = []
16 | for i in range(len(words) - n + 1):
17 | ngram = tuple(words[i:i + n])
18 | word_ngram.append(ngram)
19 | return word_ngram
20 |
21 | def iterate_sent_ngrams(words, labels=None, min_length=1, max_length=4):
22 | """
23 |   Iterate over all the n-grams in a sentence
24 |
25 | Arguments:
26 | words: A list of strings representing a sentence
27 | labels: A list of labels on each word in the sentence, optional (will use `words` if not specified)
28 | min_length: The minimum ngram length to consider
29 | max_length: The maximum ngram length to consider
30 |
31 | Returns:
32 | An iterator over n-grams in the sentence with both words and labels
33 | """
34 | if labels is not None and len(labels) != len(words):
35 | raise ValueError(f'length of labels and sentence must be the same but got'
36 | f' {len(words)} != {len(labels)} at\n{words}\n{labels}')
37 | for n in range(min_length-1, max_length):
38 | for i in range(len(words) - n):
39 | word_ngram = tuple(words[i:i + n + 1])
40 | label_ngram = tuple(labels[i:i + n + 1]) if (labels is not None) else word_ngram
41 | yield word_ngram, label_ngram
42 |
43 | def compare_ngrams(ref, out, ref_labels=None, out_labels=None, min_length=1, max_length=4):
44 | """
45 | Compare n-grams appearing in the reference sentences and output
46 |
47 | Args:
48 | ref: A list of reference sentences
49 | out: A list of output sentences
50 | ref_labels: Alternative labels for reference words (e.g. POS tags) to use when aggregating counts
51 | out_labels: Alternative labels for output words (e.g. POS tags) to use when aggregating counts
52 | min_length: The minimum length of n-grams to consider
53 | max_length: The maximum length of n-grams to consider
54 |
55 | Returns:
56 | A tuple of dictionaries including
57 | total: the total number of n-grams in the output
58 | match: the total number of matched n-grams appearing in both output and reference
59 | over: the total number of over-generated n-grams appearing in output but not reference
60 |       under: the total number of under-generated n-grams appearing in the reference but not the output
61 | """
62 | if (ref_labels is None) != (out_labels is None):
63 | raise ValueError('ref_labels or out_labels must both be either None or not None')
64 | total, match, over, under = [defaultdict(lambda: 0) for _ in range(4)]
65 | if ref_labels is None: ref_labels = []
66 | if out_labels is None: out_labels = []
67 | for ref_sent, out_sent, ref_lab, out_lab in itertools.zip_longest(ref, out, ref_labels, out_labels):
68 | # Find the number of reference n-grams (on a word level)
69 | ref_ngrams = list(iterate_sent_ngrams(ref_sent, labels=ref_lab, min_length=min_length, max_length=max_length))
70 | ref_word_counts = defaultdict(lambda: 0)
71 | for ref_w, ref_l in ref_ngrams:
72 | ref_word_counts[ref_w] += 1
73 | # Step through the output ngrams and find matched and overproduced ones
74 | for out_w, out_l in iterate_sent_ngrams(out_sent, labels=out_lab, min_length=min_length, max_length=max_length):
75 | total[out_l] += 1
76 | if ref_word_counts[out_w] > 0:
77 | match[out_l] += 1
78 | ref_word_counts[out_w] -= 1
79 | else:
80 | over[out_l] += 1
81 | # Remaining ones are underproduced
82 | # (do reverse order just to make ordering consistent for over and under, shouldn't matter much)
83 | for ref_w, ref_l in reversed(ref_ngrams):
84 | if ref_word_counts[ref_w] > 0:
85 | under[ref_l] += 1
86 | ref_word_counts[ref_w] -= 1
87 | return total, match, over, under
88 |
--------------------------------------------------------------------------------
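A tiny worked example of `compare_ngrams` on made-up sentences, showing where each n-gram lands among the four counters (keys are word tuples because no labels are given):

    from compare_mt.ngram_utils import compare_ngrams

    ref = [['the', 'cat', 'sat']]
    out = [['the', 'cat', 'ate']]
    total, match, over, under = compare_ngrams(ref, out, min_length=1, max_length=2)
    print(match[('the', 'cat')])  # 1: bigram appears in both output and reference
    print(over[('ate',)])         # 1: produced by the system, absent from the reference
    print(under[('sat',)])        # 1: in the reference but missed by the system
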
/compare_mt/print_utils.py:
--------------------------------------------------------------------------------
1 | def print_header(header):
2 | print(f'********************** {header} ************************')
--------------------------------------------------------------------------------
/compare_mt/reporters.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | matplotlib.use('agg')
3 | from matplotlib import pyplot as plt
4 | from cycler import cycler
5 | plt.rcParams['font.family'] = 'sans-serif'
6 | plt.rcParams['axes.prop_cycle'] = cycler(color=["#7293CB", "#E1974C", "#84BA5B", "#D35E60", "#808585", "#9067A7", "#AB6857", "#CCC210"])
7 | import numpy as np
8 | import os
9 | import itertools
10 | from compare_mt.formatting import fmt
11 |
12 | from functools import partial
13 | from http.server import SimpleHTTPRequestHandler, HTTPServer
14 | import socket
15 | from pathlib import Path
16 | import logging as log
17 |
18 | log.basicConfig(level=log.INFO)
19 |
20 | # Global variables used by all reporters. These are set by compare_mt_main.py
21 | sys_names = None
22 | fig_size = None
23 |
24 | # The CSS style file to use
25 | css_style = """
26 | html {
27 | font-family: sans-serif;
28 | }
29 |
30 | table, th, td {
31 | border: 1px solid black;
32 | }
33 |
34 | th, td {
35 | padding: 2px;
36 | }
37 |
38 | tr:hover {background-color: #f5f5f5;}
39 |
40 | tr:nth-child(even) {background-color: #f2f2f2;}
41 |
42 | th {
43 | background-color: #396AB1;
44 | color: white;
45 | }
46 |
47 | em {
48 | font-weight: bold;
49 | }
50 |
51 | caption {
52 | font-size: 14pt;
53 | font-weight: bold;
54 | }
55 |
56 | table {
57 | border-collapse: collapse;
58 | }
59 | """
60 |
61 | # The Javascript header to use
62 | javascript_style = """
63 | function showhide(elem) {
64 | var x = document.getElementById(elem);
65 | if (x.style.display === "none") {
66 | x.style.display = "block";
67 | } else {
68 | x.style.display = "none";
69 | }
70 | }
71 | """
72 |
73 | fig_counter, tab_counter = 0, 0
74 | def next_fig_id():
75 | global fig_counter
76 | fig_counter += 1
77 | return f'{fig_counter:03d}'
78 | def next_tab_id():
79 | global tab_counter
80 | tab_counter += 1
81 | return f'{tab_counter:03d}'
82 |
83 | def make_bar_chart(datas,
84 | output_directory, output_fig_file, output_fig_format='png',
85 | errs=None, title=None, xlabel=None, xticklabels=None, ylabel=None):
86 | fig, ax = plt.subplots(figsize=fig_size)
87 | ind = np.arange(len(datas[0]))
88 | width = 0.7/len(datas)
89 | bars = []
90 | for i, data in enumerate(datas):
91 |     err = errs[i] if errs is not None else None
92 | bars.append(ax.bar(ind+i*width, data, width, bottom=0, yerr=err))
93 | # Set axis/title labels
94 | if title is not None:
95 | ax.set_title(title)
96 | if xlabel is not None:
97 | ax.set_xlabel(xlabel)
98 | if ylabel is not None:
99 | ax.set_ylabel(ylabel)
100 | if xticklabels is not None:
101 | ax.set_xticks(ind + width / 2)
102 | ax.set_xticklabels(xticklabels)
103 | plt.xticks(rotation=70)
104 | else:
105 | ax.xaxis.set_visible(False)
106 |
107 | ax.legend(bars, sys_names)
108 | ax.autoscale_view()
109 |
110 | if not os.path.exists(output_directory):
111 | os.makedirs(output_directory)
112 | out_file = os.path.join(output_directory, f'{output_fig_file}.{output_fig_format}')
113 | plt.savefig(out_file, format=output_fig_format, bbox_inches='tight')
114 |
115 | def html_img_reference(fig_file, title):
116 | latex_code_pieces = [r"\begin{figure}[h]",
117 | r" \centering",
118 | r" \includegraphics{" + fig_file + ".pdf}",
119 | r" \caption{" + title + "}",
120 | r" \label{fig:" + fig_file + "}",
121 | r"\end{figure}"]
122 | latex_code = "\n".join(latex_code_pieces)
123 |   return (f'<img src="{fig_file}.png" alt="{title}"> <br/>' +
124 |           f'<button onclick="showhide(\'{fig_file}_latex\')">Show/Hide LaTeX</button> <br/>' +
125 |           f'<pre id="{fig_file}_latex" style="display:none">{latex_code}</pre>')
126 |
127 | class Report:
128 | # def __init__(self, iterable=(), **kwargs):
129 | # # Initialize a report by a dictionary which contains all the statistics
130 | # self.__dict__.update(iterable, **kwargs)
131 |
132 | def print(self):
133 | raise NotImplementedError('print must be implemented in subclasses of Report')
134 |
135 | def plot(self, output_directory, output_fig_file, output_fig_type):
136 | raise NotImplementedError('plot must be implemented in subclasses of Report')
137 |
138 | def print_header(self, header):
139 | print(f'********************** {header} ************************')
140 |
141 | def print_tabbed_table(self, tab):
142 | for x in tab:
143 |       print('\t'.join([fmt(y, latex=False) if y is not None else '' for y in x]))
144 | print()
145 |
146 | def generate_report(self, output_fig_file=None, output_fig_format=None, output_directory=None):
147 | self.print()
148 |
149 | class ScoreReport(Report):
150 | def __init__(self, scorer, scores, strs,
151 | wins=None, sys_stats=None, prob_thresh=0.05,
152 | title=None):
153 | self.scorer = scorer
154 | self.scores = scores
155 | self.strs = [f'{fmt(x)} ({y})' if y else fmt(x) for (x,y) in zip(scores,strs)]
156 | self.wins = wins
157 | self.sys_stats = sys_stats
158 | self.output_fig_file = f'{next_fig_id()}-score-{scorer.idstr()}'
159 | self.prob_thresh = prob_thresh
160 | self.title = scorer.name() if not title else title
161 |
162 | def winstr_pval(self, my_wins):
163 | if 1-my_wins[0] < self.prob_thresh:
164 | winstr = 's1>s2'
165 | elif 1-my_wins[1] < self.prob_thresh:
166 | winstr = 's2>s1'
167 | else:
168 | winstr = '-'
169 | pval = 1-(my_wins[0] if my_wins[0] > my_wins[1] else my_wins[1])
170 | return winstr, pval
171 |
172 | def scores_to_tables(self):
173 | if self.wins is None:
174 | # Single table with just scores
175 | return [[""]+sys_names, [self.scorer.name()]+self.strs], None
176 | elif len(self.scores) == 1:
177 | # Single table with scores for one system
178 | return [
179 | [""]+sys_names,
180 | [self.scorer.name()]+self.strs,
181 | [""]+[f'[{fmt(x["lower_bound"])},{fmt(x["upper_bound"])}]' for x in self.sys_stats]
182 | ], None
183 | elif len(self.scores) == 2:
184 | # Single table with scores and wins for two systems
185 | winstr, pval = self.winstr_pval(self.wins[0][1])
186 | return [
187 | [""]+sys_names+["Win?"],
188 | [self.scorer.name()]+self.strs+[winstr],
189 | [""]+[f'[{fmt(x["lower_bound"])},{fmt(x["upper_bound"])}]' for x in self.sys_stats]+[f'p={fmt(pval)}']
190 | ], None
191 | else:
192 | # Table with scores, and separate one with wins for multiple systems
193 | wptable = [['v s1 / s2 ->'] + [sys_names[i] for i in range(1,len(self.scores))]]
194 | for i in range(0, len(self.scores)-1):
195 | wptable.append([sys_names[i]] + [""] * (len(self.scores)-1))
196 | for (left,right), my_wins in self.wins:
197 | winstr, pval = self.winstr_pval(my_wins)
198 | wptable[left+1][right] = f'{winstr} (p={fmt(pval)})'
199 | return [[""]+sys_names, [self.scorer.name()]+self.strs], wptable
200 |
201 | def print(self):
202 | aggregate_table, win_table = self.scores_to_tables()
203 | self.print_header('Aggregate Scores')
204 | print(f'{self.title}:')
205 | self.print_tabbed_table(aggregate_table)
206 | if win_table:
207 | self.print_tabbed_table(win_table)
208 |
209 | def plot(self, output_directory, output_fig_file, output_fig_format='pdf'):
210 | sys = [[score] for score in self.scores]
211 | if self.wins:
212 | sys_errs = [np.array([ [score-stat['lower_bound']], [stat['upper_bound']-score] ]) for (score,stat) in zip(self.scores, self.sys_stats)]
213 | else:
214 | sys_errs = None
215 | xticklabels = None
216 |
217 | make_bar_chart(sys,
218 | output_directory, output_fig_file,
219 | output_fig_format=output_fig_format,
220 | errs=sys_errs, ylabel=self.scorer.name(),
221 | xticklabels=xticklabels)
222 |
223 | def html_content(self, output_directory):
224 | aggregate_table, win_table = self.scores_to_tables()
225 | html = html_table(aggregate_table, title=self.title)
226 | if win_table:
227 | html += html_table(win_table, title=f'{self.scorer.name()} Wins')
228 | for ext in ('png', 'pdf'):
229 | self.plot(output_directory, self.output_fig_file, ext)
230 | html += html_img_reference(self.output_fig_file, 'Score Comparison')
231 | return html
232 |
233 | class WordReport(Report):
234 | def __init__(self, bucketer, statistics,
235 | acc_type, header,
236 | examples=None,
237 | bucket_cnts=None,
238 | bucket_intervals=None,
239 | src_sents=None,
240 | ref_sents=None, ref_labels=None,
241 | out_sents=None, out_labels=None,
242 | src_labels=None, ref_aligns=None,
243 | title=None):
244 | self.bucketer = bucketer
245 | self.statistics = [[s for s in stat] for stat in statistics]
246 | self.examples = examples
247 | self.bucket_cnts = bucket_cnts
248 | self.bucket_intervals = bucket_intervals
249 | self.src_sents = src_sents
250 | self.ref_sents = ref_sents
251 | self.ref_labels = ref_labels
252 | self.out_sents = out_sents
253 | self.out_labels = out_labels
254 | self.src_labels = src_labels
255 | self.ref_aligns = ref_aligns
256 | self.acc_type = acc_type
257 | self.header = header
258 | self.acc_type_map = {'prec': 3, 'rec': 4, 'fmeas': 5}
259 | self.output_fig_file = f'{next_fig_id()}-wordacc-{bucketer.name()}'
260 | self.title = title if title else f'word {acc_type} by {bucketer.name()} bucket'
261 |
262 | def print(self):
263 | acc_type_map = self.acc_type_map
264 | bucketer, statistics, acc_type, header = self.bucketer, self.statistics, self.acc_type, self.header
265 | self.print_header(header)
266 | acc_types = acc_type.split('+')
267 | for at in acc_types:
268 | if at not in acc_type_map:
269 | raise ValueError(f'Unknown accuracy type {at}')
270 | aid = acc_type_map[at]
271 | print(f'--- {self.title}')
272 | # first line
273 | print(f'{bucketer.name()}', end='')
274 | if self.bucket_cnts is not None:
275 | print(f'\t# words', end='')
276 | for sn in sys_names:
277 | print(f'\t{sn}', end='')
278 | print()
279 | # stats
280 | for i, bucket_str in enumerate(bucketer.bucket_strs):
281 | print(f'{bucket_str}', end='')
282 | if self.bucket_cnts is not None:
283 | print(f'\t{self.bucket_cnts[i]}', end='')
284 | for j, match in enumerate(statistics):
285 | print(f'\t{fmt(match[i][aid])}', end='')
286 | if self.bucket_intervals is not None:
287 | low, up = self.bucket_intervals[j][i][aid]
288 | print(f' [{fmt(low)}, {fmt(up)}]', end='')
289 | print()
290 | print()
291 |
292 | def plot(self, output_directory, output_fig_file, output_fig_format='pdf'):
293 | acc_types = self.acc_type.split('+')
294 | for at in acc_types:
295 | if at not in self.acc_type_map:
296 | raise ValueError(f'Unknown accuracy type {at}')
297 | aid = self.acc_type_map[at]
298 | sys = [[m[aid] for m in match] for match in self.statistics]
299 | xticklabels = [s for s in self.bucketer.bucket_strs]
300 |
301 | if self.bucket_intervals:
302 | errs = []
303 | for i, match in enumerate(sys):
304 | lows, ups = [], []
305 | for j, score in enumerate(match):
306 | low, up = self.bucket_intervals[i][j][aid]
307 | lows.append(score-low)
308 | ups.append(up-score)
309 | errs.append(np.array([lows, ups]) )
310 | else:
311 | errs = None
312 |
313 | make_bar_chart(sys,
314 | output_directory, output_fig_file,
315 | output_fig_format=output_fig_format,
316 | errs=errs,
317 | xlabel=self.bucketer.name(), ylabel=at,
318 | xticklabels=xticklabels)
319 |
320 | def highlight_words(self, sent, hls=None):
321 | if not hls:
322 | return ' '.join(sent)
323 |     return ' '.join([f'<em>{w}</em>' if hl else w for (w,hl) in zip(sent, hls)])
324 |
325 | def write_examples(self, title, output_directory):
326 | # Create separate examples HTML file
327 | html = ''
328 | for bi, bucket_examples in enumerate(self.examples):
329 | html += f''
330 | html += tag_str('h3', f'Examples for Bucket {self.bucketer.bucket_strs[bi]}')
331 | for tag, examp_ids in bucket_examples:
332 | # Skip ones with no examples
333 | if len(examp_ids) == 0:
334 | continue
335 | html += tag_str('h4', tag)
336 | for eid in examp_ids:
337 | table = [['', 'Output']]
338 | # Find buckets for the examples if it's on the source side (will have alignments in this case)
339 | if self.ref_aligns:
340 | _, _, _, src_buckets, ref_aligns, ref_matches = \
341 | self.bucketer._calc_src_buckets_and_matches(self.src_sents[eid],
342 | self.src_labels[eid] if self.src_labels else None,
343 | self.ref_sents[eid],
344 | self.ref_aligns[eid],
345 | [x[eid] for x in self.out_sents])
346 | src_hls = [x == bi for x in src_buckets]
347 | table.append(['Src', self.highlight_words(self.src_sents[eid], src_hls)])
348 | ref_hls = [False for _ in self.ref_sents[eid]]
349 | out_hls = [[False for _ in x[eid]] for x in self.out_sents]
350 | for sid, tid in self.ref_aligns[eid]:
351 | if src_hls[sid]:
352 | ref_hls[tid] = True
353 | for rm, ohls in zip(ref_matches, out_hls):
354 | if rm[tid] >= 0:
355 | ohls[rm[tid]] = True
356 | # Find buckets for the examples if it's on the target side
357 | else:
358 | _, _, _, ref_buckets, out_buckets, out_matches = \
359 | self.bucketer._calc_trg_buckets_and_matches(self.ref_sents[eid],
360 | self.ref_labels[eid] if self.ref_labels else None,
361 | [x[eid] for x in self.out_sents],
362 | [x[eid] for x in self.out_labels] if self.out_labels else None)
363 | ref_hls = [x == bi for x in ref_buckets]
364 | out_hls = [[(b == bi and m >= 0) for (b,m) in zip(ob, om)] for (ob, om) in zip(out_buckets, out_matches)]
365 | table.append(['Ref', self.highlight_words(self.ref_sents[eid], ref_hls)])
366 | for sn, oss, ohl in itertools.zip_longest(sys_names, self.out_sents, out_hls):
367 | table.append([sn, self.highlight_words(oss[eid], ohl)])
368 | html += html_table(table, None)
369 | with open(f'{output_directory}/{self.output_fig_file}.html', 'w') as example_stream:
370 | example_stream.write(styled_html_message(title, html))
371 |
372 | def html_content(self, output_directory):
373 | acc_type_map = self.acc_type_map
374 | bucketer, matches, acc_type, header = self.bucketer, self.statistics, self.acc_type, self.header
375 | acc_types = acc_type.split('+')
376 |
377 | title = f'Word {acc_type} by {bucketer.name()} bucket' if not self.title else self.title
378 |
379 | if self.examples:
380 | self.write_examples(title, output_directory)
381 |
382 | # Create main HTML content
383 | html = ''
384 | for at in acc_types:
385 | if at not in acc_type_map:
386 | raise ValueError(f'Unknown accuracy type {at}')
387 | aid = acc_type_map[at]
388 | line = [bucketer.name()]
389 | if self.bucket_cnts is not None:
390 | line.append('# words')
391 | line += sys_names
392 | table = [line]
393 | if self.examples:
394 | table[0].append('Examples')
395 | for i, bs in enumerate(bucketer.bucket_strs):
396 | line = [bs]
397 | if self.bucket_cnts is not None:
398 | line.append(f'{self.bucket_cnts[i]}')
399 | for j, match in enumerate(matches):
400 | line.append(f'{fmt(match[i][aid])}')
401 | if self.bucket_intervals is not None:
402 | low, up = self.bucket_intervals[j][i][aid]
403 | line[-1] += f' [{fmt(low)}, {fmt(up)}]'
404 | if self.examples:
405 |         line.append(f'<a href="{self.output_fig_file}.html">Examples</a>')
406 | table += [line]
407 | html += html_table(table, title, latex_ignore_cols={3})
408 | img_name = f'{self.output_fig_file}-{at}'
409 | for ext in ('png', 'pdf'):
410 | self.plot(output_directory, img_name, ext)
411 | html += html_img_reference(img_name, self.header)
412 | return html
413 |
414 | class NgramReport(Report):
415 | def __init__(self, scorelist, report_length, min_ngram_length, max_ngram_length,
416 | matches, compare_type, alpha, compare_directions=[(0, 1)], label_files=None, title=None):
417 | self.scorelist = scorelist
418 | self.report_length = report_length
419 | self.min_ngram_length = min_ngram_length
420 | self.max_ngram_length = max_ngram_length
421 | self.matches = matches
422 | self.compare_type = compare_type
423 | self.label_files = label_files
424 | self.alpha = alpha
425 | self.compare_directions = compare_directions
426 | self.title = title
427 |
428 | def print(self):
429 | report_length = self.report_length
430 | self.print_header('N-gram Difference Analysis')
431 | if self.title:
432 | print(f'--- {self.title}')
433 | else:
434 | print(f'--- min_ngram_length={self.min_ngram_length}, max_ngram_length={self.max_ngram_length}')
435 | print(f' report_length={report_length}, alpha={self.alpha}, compare_type={self.compare_type}')
436 |
437 | if self.label_files is not None:
438 | print(self.label_files)
439 |
440 | for i, (left, right) in enumerate(self.compare_directions):
441 | print(f'--- {report_length} n-grams where {sys_names[left]}>{sys_names[right]} in {self.compare_type}')
442 | for k, v in self.scorelist[i][:report_length]:
443 | print(f"{' '.join(k)}\t{fmt(v)} (sys{left+1}={self.matches[left][k]}, sys{right+1}={self.matches[right][k]})")
444 | print()
445 | print(f'--- {report_length} n-grams where {sys_names[right]}>{sys_names[left]} in {self.compare_type}')
446 | for k, v in reversed(self.scorelist[i][-report_length:]):
447 | print(f"{' '.join(k)}\t{fmt(v)} (sys{left+1}={self.matches[left][k]}, sys{right+1}={self.matches[right][k]})")
448 | print()
449 |
450 | def plot(self, output_directory, output_fig_file, output_fig_format='pdf'):
451 | raise NotImplementedError('Plotting is not implemented for n-gram reports')
452 |
453 | def html_content(self, output_directory=None):
454 | report_length = self.report_length
455 | if self.title:
456 | html = tag_str('p', self.title)
457 | else:
458 | html = tag_str('p', f'min_ngram_length={self.min_ngram_length}, max_ngram_length={self.max_ngram_length}')
459 | html += tag_str('p', f'report_length={report_length}, alpha={self.alpha}, compare_type={self.compare_type}')
460 | if self.label_files is not None:
461 | html += tag_str('p', self.label_files)
462 |
463 | for i, (left, right) in enumerate(self.compare_directions):
464 | title = f'{report_length} n-grams where {sys_names[left]}>{sys_names[right]} in {self.compare_type}'
465 | table = [['n-gram', self.compare_type, f'{sys_names[left]}', f'{sys_names[right]}']]
466 | table.extend([[' '.join(k), fmt(v), self.matches[left][k], self.matches[right][k]] for k, v in self.scorelist[i][:report_length]])
467 | html += html_table(table, title)
468 |
469 | title = f'{report_length} n-grams where {sys_names[right]}>{sys_names[left]} in {self.compare_type}'
470 | table = [['n-gram', self.compare_type, f'{sys_names[left]}', f'{sys_names[right]}']]
471 | table.extend([[' '.join(k), fmt(v), self.matches[left][k], self.matches[right][k]] for k, v in reversed(self.scorelist[i][-report_length:])])
472 | html += html_table(table, title)
473 | return html
474 |
475 | class SentenceReport(Report):
476 |
477 | def __init__(self, bucketer=None, sys_stats=None, statistic_type=None, scorer=None, bucket_cnts=None, bucket_intervals=None, title=None):
478 | self.bucketer = bucketer
479 | self.sys_stats = [[s for s in stat] for stat in sys_stats]
480 | self.statistic_type = statistic_type
481 | self.scorer = scorer
482 | self.bucket_cnts = bucket_cnts
483 | self.bucket_intervals = bucket_intervals
484 | self.yname = scorer.name() if statistic_type == 'score' else statistic_type
485 | self.yidstr = scorer.idstr() if statistic_type == 'score' else statistic_type
486 | self.output_fig_file = f'{next_fig_id()}-sent-{bucketer.idstr()}-{self.yidstr}'
487 | if title:
488 | self.title = title
489 | elif scorer:
490 | self.title = f'bucket type: {bucketer.name()}, statistic type: {scorer.name()}'
491 | else:
492 | self.title = f'bucket type: {bucketer.name()}, statistic type: {statistic_type}'
493 |
494 | def print(self):
495 | self.print_header('Sentence Bucket Analysis')
496 | print(f'--- {self.title}')
497 | # first line
498 | print(f'{self.bucketer.idstr()}', end='')
499 | if self.bucket_cnts is not None:
500 | print(f'\t# sents', end='')
501 | for sn in sys_names:
502 | print(f'\t{sn}', end='')
503 | print()
504 | for i, bs in enumerate(self.bucketer.bucket_strs):
505 | print(f'{bs}', end='')
506 | if self.bucket_cnts is not None:
507 | print(f'\t{self.bucket_cnts[i]}', end='')
508 | for j, stat in enumerate(self.sys_stats):
509 | print(f'\t{fmt(stat[i])}', end='')
510 | if self.bucket_intervals is not None:
511 | interval = self.bucket_intervals[j][i]
512 | low, up = interval['lower_bound'], interval['upper_bound']
513 | print(f' [{fmt(low)}, {fmt(up)}]', end='')
514 | print()
515 | print()
516 |
517 | def plot(self, output_directory='outputs', output_fig_file='word-acc', output_fig_format='pdf'):
518 | sys = self.sys_stats
519 | xticklabels = [s for s in self.bucketer.bucket_strs]
520 |
521 | if self.bucket_intervals:
522 | errs = []
523 | for i, stat in enumerate(sys):
524 | lows, ups = [], []
525 | for j, score in enumerate(stat):
526 | interval = self.bucket_intervals[i][j]
527 | low, up = interval['lower_bound'], interval['upper_bound']
528 | lows.append(score-low)
529 | ups.append(up-score)
530 | errs.append(np.array([lows, ups]) )
531 | else:
532 | errs = None
533 |
534 | make_bar_chart(sys,
535 | output_directory, output_fig_file,
536 | output_fig_format=output_fig_format,
537 | errs=errs,
538 | xlabel=self.bucketer.name(), ylabel=self.yname,
539 | xticklabels=xticklabels)
540 |
541 | def html_content(self, output_directory=None):
542 | line = [self.bucketer.idstr()]
543 | if self.bucket_cnts is not None:
544 | line.append('# sents')
545 | line += sys_names
546 | table = [ line ]
547 | for i, bs in enumerate(self.bucketer.bucket_strs):
548 | line = [bs]
549 | if self.bucket_cnts is not None:
550 |         line.append(f'{self.bucket_cnts[i]}')
551 | for j, stat in enumerate(self.sys_stats):
552 | line.append(fmt(stat[i]))
553 | if self.bucket_intervals is not None:
554 | interval = self.bucket_intervals[j][i]
555 | low, up = interval['lower_bound'], interval['upper_bound']
556 | line[-1] += f' [{fmt(low)}, {fmt(up)}]'
557 | table.extend([line])
558 | html = html_table(table, self.title)
559 | for ext in ('png', 'pdf'):
560 | self.plot(output_directory, self.output_fig_file, ext)
561 | html += html_img_reference(self.output_fig_file, 'Sentence Bucket Analysis')
562 | return html
563 |
564 | class SentenceExampleReport(Report):
565 |
566 | def __init__(self, report_length=None, scorediff_lists=None, scorer=None, ref=None, outs=None, src=None, compare_directions=[(0, 1)], title=None):
567 | self.report_length = report_length
568 | self.scorediff_lists = scorediff_lists
569 | self.scorer = scorer
570 | self.ref = ref
571 | self.outs = outs
572 | self.src = src
573 | self.compare_directions = compare_directions
574 | self.title = title
575 |
576 | def print(self):
577 | self.print_header('Sentence Examples Analysis')
578 | report_length = self.report_length
579 | for cnt, (left, right) in enumerate(self.compare_directions):
580 | ref, out1, out2 = self.ref, self.outs[left], self.outs[right]
581 | sleft, sright = sys_names[left], sys_names[right]
582 | print(f'--- {report_length} sentences where {sleft}>{sright} at {self.scorer.name()}')
583 | for bdiff, s1, s2, str1, str2, i in self.scorediff_lists[cnt][:report_length]:
584 | print(f"{sleft}-{sright}={fmt(-bdiff)}, {sleft}={fmt(s1)}, {sright}={fmt(s2)}")
585 | if self.src and self.src[i]:
586 | print(f"Src: {' '.join(self.src[i])}")
587 |         print(
588 | f"Ref: {' '.join(ref[i])}\n"
589 | f"{sleft}: {' '.join(out1[i])}\n"
590 | f"{sright}: {' '.join(out2[i])}\n"
591 | )
592 |
593 | print(f'--- {report_length} sentences where {sright}>{sleft} at {self.scorer.name()}')
594 | for bdiff, s1, s2, str1, str2, i in self.scorediff_lists[cnt][-report_length:]:
595 | print(f"{sleft}-{sright}={fmt(-bdiff)}, {sleft}={fmt(s1)}, {sright}={fmt(s2)}")
596 | if self.src and self.src[i]:
597 | print(f"Src: {' '.join(self.src[i])}")
598 |         print(
599 | f"Ref: {' '.join(ref[i])}\n"
600 | f"{sleft}: {' '.join(out1[i])}\n"
601 | f"{sright}: {' '.join(out2[i])}\n"
602 | )
603 |
604 | def plot(self, output_directory, output_fig_file, output_fig_format='pdf'):
605 | pass
606 |
607 | def html_content(self, output_directory=None):
608 | report_length = self.report_length
609 | for cnt, (left, right) in enumerate(self.compare_directions):
610 | sleft, sright = sys_names[left], sys_names[right]
611 | ref, out1, out2 = self.ref, self.outs[left], self.outs[right]
612 | html = tag_str('h4', f'{report_length} sentences where {sleft}>{sright} at {self.scorer.name()}')
613 | for bdiff, s1, s2, str1, str2, i in self.scorediff_lists[cnt][:report_length]:
614 | table = [['', 'Output', f'{self.scorer.idstr()}']]
615 | if self.src and self.src[i]:
616 | table.append(['Src', ' '.join(self.src[i]), ''])
617 | table += [
618 | ['Ref', ' '.join(ref[i]), ''],
619 | [f'{sleft}', ' '.join(out1[i]), fmt(s1)],
620 | [f'{sright}', ' '.join(out2[i]), fmt(s2)]
621 | ]
622 |
623 | html += html_table(table, None)
624 |
625 | html += tag_str('h4', f'{report_length} sentences where {sright}>{sleft} at {self.scorer.name()}')
626 | for bdiff, s1, s2, str1, str2, i in self.scorediff_lists[cnt][-report_length:]:
627 | table = [['', 'Output', f'{self.scorer.idstr()}']]
628 | if self.src and self.src[i]:
629 | table.append(['Src', ' '.join(self.src[i]), ''])
630 | table += [
631 | ['Ref', ' '.join(ref[i]), ''],
632 | [f'{sleft}', ' '.join(out1[i]), fmt(s1)],
633 | [f'{sright}', ' '.join(out2[i]), fmt(s2)]
634 | ]
635 |
636 | html += html_table(table, None)
637 |
638 | return html
639 |
640 |
641 | def tag_str(tag, text, new_line=''):
642 |   return f'<{tag}>{new_line} {text} {new_line}</{tag}>'
643 |
644 | def html_table(table, title=None, bold_rows=1, bold_cols=1, latex_ignore_cols={}):
645 |   html = '<table>\n'
646 | if title is not None:
647 | html += tag_str('caption', title)
648 | for i, row in enumerate(table):
649 | tag_type = 'th' if (i < bold_rows) else 'td'
650 | table_row = '\n '.join(tag_str('th' if j < bold_cols else tag_type, rdata) for (j, rdata) in enumerate(row))
651 | html += tag_str('tr', table_row)
652 |   html += '\n</table>\n'
653 |
654 | tab_id = next_tab_id()
655 | latex_code = "\\begin{table}[t]\n \\centering\n"
656 | cs = ['c'] * len(table[0])
657 | if bold_cols != 0:
658 | cs[bold_cols-1] = 'c||'
659 | latex_code += " \\begin{tabular}{"+''.join(cs)+"}\n"
660 | for i, row in enumerate(table):
661 | latex_code += ' & '.join([fmt(x) for c_i, x in enumerate(row) if c_i not in latex_ignore_cols]) + (' \\\\\n' if i != bold_rows-1 else ' \\\\ \\hline \\hline\n')
662 | latex_code += " \\end{tabular}\n \\caption{Caption}\n \\label{tab:table"+tab_id+"}\n\\end{table}"
663 |
664 |   html += (f'<button onclick="showhide(\'{tab_id}_latex\')">Show/Hide LaTeX</button> <br/>' +
665 |            f'<pre id="{tab_id}_latex" style="display:none">