├── .github
│   └── dependabot.yml
├── .gitignore
├── LICENSE
├── README.md
├── example-full-named-entity-evaluation.ipynb
├── ner_evaluation
│   ├── .coverage
│   ├── __init__.py
│   ├── ner_eval.py
│   └── tests
│       ├── test_evaluator.py
│       └── test_ner_evaluation.py
├── requirements.txt
└── setup.cfg
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | # To start with Dependabot version updates, you'll need to specify which
2 | # package ecosystems to update and where the package manifests are located.
3 | # Please see the documentation for all configuration options:
4 | # https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file
5 | 
6 | version: 2
7 | updates:
8 |   - package-ecosystem: "pip" # See documentation for possible values
9 |     directory: "/" # Location of package manifests
10 |     schedule:
11 |       interval: "weekly"
12 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .coverage
2 | .mypy_cache/
3 | ner_evaluation/tests/__pycache__/
4 | ner_evaluation/__pycache__/
5 | .ipynb_checkpoints/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 David S. Batista
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Named Entity Evaluation as in SemEval 2013 task 9.1
2 | 
3 | My own implementation, with lots of input from [Matt Upson](https://github.com/ivyleavedtoadflax), of the Named-Entity Recognition evaluation metrics as defined by SemEval 2013 Task 9.1.
4 | 
5 | These evaluation metrics go beyond a simple token/tag-based schema, and consider different scenarios based on whether all the tokens that belong to a named entity were classified or not, and also whether the correct entity type was assigned.
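In practice the evaluation is driven by the `Evaluator` class. Below is a minimal sketch of the API; the tag sequences are made up for illustration, and the full walkthrough is in the notebook linked under "Example" further down:

```python
from ner_evaluation.ner_eval import Evaluator

# One list of BIO tags per document; true and pred must be aligned.
true = [['O', 'O', 'B-PER', 'I-PER', 'O']]
pred = [['O', 'O', 'B-PER', 'I-PER', 'O']]

evaluator = Evaluator(true, pred, tags=['LOC', 'PER'])

# Overall results plus results aggregated per entity type, each reported
# under the four schemes: 'strict', 'exact', 'partial' and 'ent_type'.
results, results_by_tag = evaluator.evaluate()

print(results['strict']['precision'], results['strict']['recall'])
```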
6 | 
7 | You can find a more detailed explanation in the following blog post:
8 | 
9 | * http://www.davidsbatista.net/blog/2018/05/09/Named_Entity_Evaluation/
10 | 
11 | 
12 | ## Notes:
13 | 
14 | In scenarios IV and VI the entity types of `true` and `pred` do not match; in both cases we score only against the true entity, not the predicted one. One could argue that the predicted entity should also be scored as spurious, but according to the definition of `spurious`:
15 | 
16 | * Spurious (SPU): the system produces a response which doesn't exist in the golden annotation;
17 | 
18 | In this case an annotation does exist, just with a different entity type, so we count it only as incorrect. For example, if the gold annotation is an `ORG` span and the system predicts the exact same span as `LOC`, scenario IV counts it as incorrect rather than spurious.
19 | 
20 | 
21 | ## Example:
22 | 
23 | You can see a working example in the following notebook:
24 | 
25 | - [example-full-named-entity-evaluation.ipynb](example-full-named-entity-evaluation.ipynb)
26 | 
27 | Note that in order to run that example you need to have installed:
28 | 
29 | - sklearn
30 | - nltk
31 | - sklearn_crfsuite
32 | 
33 | For testing you will need:
34 | 
35 | - pytest
36 | - coverage
37 | 
38 | These dependencies can be installed by running `pip3 install -r requirements.txt`
39 | 
40 | ## Code tests and test coverage:
41 | 
42 | To run tests:
43 | 
44 | `coverage run --rcfile=setup.cfg -m pytest`
45 | 
46 | To produce a coverage report:
47 | 
48 | `coverage report`
49 | 
--------------------------------------------------------------------------------
/example-full-named-entity-evaluation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": 1,
6 |    "metadata": {},
7 |    "outputs": [],
8 |    "source": [
9 |     "import nltk\n",
10 |     "import sklearn_crfsuite\n",
11 |     "\n",
12 |     "from copy import deepcopy\n",
13 |     "from collections import defaultdict\n",
14 |     "\n",
15 |     "from sklearn_crfsuite.metrics import flat_classification_report\n",
16 |     "\n",
17 |     "from ner_evaluation.ner_eval import collect_named_entities\n",
18 |     "from ner_evaluation.ner_eval import compute_metrics\n",
19 |     "from ner_evaluation.ner_eval import compute_precision_recall_wrapper"
20 |    ]
21 |   },
22 |   {
23 |    "cell_type": "markdown",
24 |    "metadata": {},
25 |    "source": [
26 |     "## Train a CRF on the CoNLL 2002 NER Spanish data"
27 |    ]
28 |   },
29 |   {
30 |    "cell_type": "code",
31 |    "execution_count": 2,
32 |    "metadata": {},
33 |    "outputs": [],
34 |    "source": [
35 |     "nltk.corpus.conll2002.fileids()\n",
36 |     "train_sents = list(nltk.corpus.conll2002.iob_sents('esp.train'))\n",
37 |     "test_sents = list(nltk.corpus.conll2002.iob_sents('esp.testb'))"
38 |    ]
39 |   },
40 |   {
41 |    "cell_type": "code",
42 |    "execution_count": 3,
43 |    "metadata": {},
44 |    "outputs": [],
45 |    "source": [
46 |     "def word2features(sent, i):\n",
47 |     "    word = sent[i][0]\n",
48 |     "    postag = sent[i][1]\n",
49 |     "\n",
50 |     "    features = {\n",
51 |     "        'bias': 1.0,\n",
52 |     "        'word.lower()': word.lower(),\n",
53 |     "        'word[-3:]': word[-3:],\n",
54 |     "        'word[-2:]': word[-2:],\n",
55 |     "        'word.isupper()': word.isupper(),\n",
56 |     "        'word.istitle()': word.istitle(),\n",
57 |     "        'word.isdigit()': word.isdigit(),\n",
58 |     "        'postag': postag,\n",
59 |     "        'postag[:2]': postag[:2],\n",
60 |     "    }\n",
61 |     "    if i > 0:\n",
62 |     "        word1 = sent[i-1][0]\n",
63 |     "        postag1 = sent[i-1][1]\n",
64 |     "        features.update({\n",
65 |     "            '-1:word.lower()': word1.lower(),\n",
66 |     "            '-1:word.istitle()': word1.istitle(),\n",
67 |     "            '-1:word.isupper()': word1.isupper(),\n",
68 |     "            '-1:postag': postag1,\n",
69 |     "
'-1:postag[:2]': postag1[:2],\n", 70 | " })\n", 71 | " else:\n", 72 | " features['BOS'] = True\n", 73 | "\n", 74 | " if i < len(sent)-1:\n", 75 | " word1 = sent[i+1][0]\n", 76 | " postag1 = sent[i+1][1]\n", 77 | " features.update({\n", 78 | " '+1:word.lower()': word1.lower(),\n", 79 | " '+1:word.istitle()': word1.istitle(),\n", 80 | " '+1:word.isupper()': word1.isupper(),\n", 81 | " '+1:postag': postag1,\n", 82 | " '+1:postag[:2]': postag1[:2],\n", 83 | " })\n", 84 | " else:\n", 85 | " features['EOS'] = True\n", 86 | "\n", 87 | " return features\n", 88 | "\n", 89 | "\n", 90 | "def sent2features(sent):\n", 91 | " return [word2features(sent, i) for i in range(len(sent))]\n", 92 | "\n", 93 | "\n", 94 | "def sent2labels(sent):\n", 95 | " return [label for token, postag, label in sent]\n", 96 | "\n", 97 | "\n", 98 | "def sent2tokens(sent):\n", 99 | " return [token for token, postag, label in sent]" 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "metadata": {}, 105 | "source": [ 106 | "## Feature Extraction" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 4, 112 | "metadata": {}, 113 | "outputs": [ 114 | { 115 | "name": "stdout", 116 | "output_type": "stream", 117 | "text": [ 118 | "CPU times: user 761 ms, sys: 48.4 ms, total: 809 ms\n", 119 | "Wall time: 809 ms\n" 120 | ] 121 | } 122 | ], 123 | "source": [ 124 | "%%time\n", 125 | "X_train = [sent2features(s) for s in train_sents]\n", 126 | "y_train = [sent2labels(s) for s in train_sents]\n", 127 | "\n", 128 | "X_test = [sent2features(s) for s in test_sents]\n", 129 | "y_test = [sent2labels(s) for s in test_sents]" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "metadata": {}, 135 | "source": [ 136 | "## Training" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 5, 142 | "metadata": {}, 143 | "outputs": [ 144 | { 145 | "name": "stdout", 146 | "output_type": "stream", 147 | "text": [ 148 | "CPU times: user 28.7 s, sys: 36.4 ms, total: 28.7 s\n", 149 | "Wall time: 28.7 s\n" 150 | ] 151 | } 152 | ], 153 | "source": [ 154 | "%%time\n", 155 | "crf = sklearn_crfsuite.CRF(\n", 156 | " algorithm='lbfgs',\n", 157 | " c1=0.1,\n", 158 | " c2=0.1,\n", 159 | " max_iterations=100,\n", 160 | " all_possible_transitions=True\n", 161 | ")\n", 162 | "crf.fit(X_train, y_train)" 163 | ] 164 | }, 165 | { 166 | "cell_type": "markdown", 167 | "metadata": {}, 168 | "source": [ 169 | "## Performance per label type per token" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": 25, 175 | "metadata": {}, 176 | "outputs": [ 177 | { 178 | "name": "stdout", 179 | "output_type": "stream", 180 | "text": [ 181 | " precision recall f1-score support\n", 182 | "\n", 183 | " B-LOC 0.810 0.784 0.797 1084\n", 184 | " I-LOC 0.690 0.637 0.662 325\n", 185 | " B-MISC 0.731 0.569 0.640 339\n", 186 | " I-MISC 0.699 0.589 0.639 557\n", 187 | " B-ORG 0.807 0.832 0.820 1400\n", 188 | " I-ORG 0.852 0.786 0.818 1104\n", 189 | " B-PER 0.850 0.884 0.867 735\n", 190 | " I-PER 0.893 0.943 0.917 634\n", 191 | "\n", 192 | " micro avg 0.813 0.787 0.799 6178\n", 193 | " macro avg 0.791 0.753 0.770 6178\n", 194 | "weighted avg 0.809 0.787 0.796 6178\n", 195 | "\n" 196 | ] 197 | } 198 | ], 199 | "source": [ 200 | "y_pred = crf.predict(X_test)\n", 201 | "labels = list(crf.classes_)\n", 202 | "labels.remove('O') # remove 'O' label from evaluation\n", 203 | "sorted_labels = sorted(labels,key=lambda name: (name[1:], name[0])) # group B and I results\n", 204 | "print(flat_classification_report(y_test, 
y_pred, labels=sorted_labels, digits=3))" 205 | ] 206 | }, 207 | { 208 | "cell_type": "markdown", 209 | "metadata": {}, 210 | "source": [ 211 | "## Performance over full named-entity" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": 26, 217 | "metadata": {}, 218 | "outputs": [], 219 | "source": [ 220 | "test_sents_labels = []\n", 221 | "for sentence in test_sents:\n", 222 | " sentence = [token[2] for token in sentence]\n", 223 | " test_sents_labels.append(sentence)" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": 27, 229 | "metadata": {}, 230 | "outputs": [], 231 | "source": [ 232 | "index = 2\n", 233 | "true = collect_named_entities(test_sents_labels[index])\n", 234 | "pred = collect_named_entities(y_pred[index])" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": 28, 240 | "metadata": {}, 241 | "outputs": [ 242 | { 243 | "data": { 244 | "text/plain": [ 245 | "[Entity(e_type='MISC', start_offset=12, end_offset=12),\n", 246 | " Entity(e_type='LOC', start_offset=15, end_offset=15),\n", 247 | " Entity(e_type='PER', start_offset=37, end_offset=39),\n", 248 | " Entity(e_type='ORG', start_offset=45, end_offset=46)]" 249 | ] 250 | }, 251 | "execution_count": 28, 252 | "metadata": {}, 253 | "output_type": "execute_result" 254 | } 255 | ], 256 | "source": [ 257 | "true" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": 29, 263 | "metadata": {}, 264 | "outputs": [ 265 | { 266 | "data": { 267 | "text/plain": [ 268 | "[Entity(e_type='MISC', start_offset=12, end_offset=12),\n", 269 | " Entity(e_type='LOC', start_offset=15, end_offset=15),\n", 270 | " Entity(e_type='PER', start_offset=37, end_offset=39),\n", 271 | " Entity(e_type='LOC', start_offset=45, end_offset=46)]" 272 | ] 273 | }, 274 | "execution_count": 29, 275 | "metadata": {}, 276 | "output_type": "execute_result" 277 | } 278 | ], 279 | "source": [ 280 | "pred" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": 30, 286 | "metadata": {}, 287 | "outputs": [], 288 | "source": [ 289 | "index = 2\n", 290 | "true = collect_named_entities(test_sents_labels[index])\n", 291 | "pred = collect_named_entities(y_pred[index])" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": 31, 297 | "metadata": {}, 298 | "outputs": [ 299 | { 300 | "data": { 301 | "text/plain": [ 302 | "[Entity(e_type='MISC', start_offset=12, end_offset=12),\n", 303 | " Entity(e_type='LOC', start_offset=15, end_offset=15),\n", 304 | " Entity(e_type='PER', start_offset=37, end_offset=39),\n", 305 | " Entity(e_type='ORG', start_offset=45, end_offset=46)]" 306 | ] 307 | }, 308 | "execution_count": 31, 309 | "metadata": {}, 310 | "output_type": "execute_result" 311 | } 312 | ], 313 | "source": [ 314 | "true" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": 32, 320 | "metadata": {}, 321 | "outputs": [ 322 | { 323 | "data": { 324 | "text/plain": [ 325 | "[Entity(e_type='MISC', start_offset=12, end_offset=12),\n", 326 | " Entity(e_type='LOC', start_offset=15, end_offset=15),\n", 327 | " Entity(e_type='PER', start_offset=37, end_offset=39),\n", 328 | " Entity(e_type='LOC', start_offset=45, end_offset=46)]" 329 | ] 330 | }, 331 | "execution_count": 32, 332 | "metadata": {}, 333 | "output_type": "execute_result" 334 | } 335 | ], 336 | "source": [ 337 | "pred" 338 | ] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "execution_count": 33, 343 | "metadata": {}, 344 | "outputs": [ 345 | { 346 | "data": { 347 | 
"text/plain": [ 348 | "({'strict': {'correct': 3,\n", 349 | " 'incorrect': 1,\n", 350 | " 'partial': 0,\n", 351 | " 'missed': 0,\n", 352 | " 'spurious': 0,\n", 353 | " 'precision': 0,\n", 354 | " 'recall': 0,\n", 355 | " 'actual': 4,\n", 356 | " 'possible': 4},\n", 357 | " 'ent_type': {'correct': 3,\n", 358 | " 'incorrect': 1,\n", 359 | " 'partial': 0,\n", 360 | " 'missed': 0,\n", 361 | " 'spurious': 0,\n", 362 | " 'precision': 0,\n", 363 | " 'recall': 0,\n", 364 | " 'actual': 4,\n", 365 | " 'possible': 4},\n", 366 | " 'partial': {'correct': 4,\n", 367 | " 'incorrect': 0,\n", 368 | " 'partial': 0,\n", 369 | " 'missed': 0,\n", 370 | " 'spurious': 0,\n", 371 | " 'precision': 0,\n", 372 | " 'recall': 0,\n", 373 | " 'actual': 4,\n", 374 | " 'possible': 4},\n", 375 | " 'exact': {'correct': 4,\n", 376 | " 'incorrect': 0,\n", 377 | " 'partial': 0,\n", 378 | " 'missed': 0,\n", 379 | " 'spurious': 0,\n", 380 | " 'precision': 0,\n", 381 | " 'recall': 0,\n", 382 | " 'actual': 4,\n", 383 | " 'possible': 4}},\n", 384 | " {'LOC': {'strict': {'correct': 1,\n", 385 | " 'incorrect': 0,\n", 386 | " 'partial': 0,\n", 387 | " 'missed': 0,\n", 388 | " 'spurious': 0,\n", 389 | " 'precision': 0,\n", 390 | " 'recall': 0,\n", 391 | " 'actual': 1,\n", 392 | " 'possible': 1},\n", 393 | " 'ent_type': {'correct': 1,\n", 394 | " 'incorrect': 0,\n", 395 | " 'partial': 0,\n", 396 | " 'missed': 0,\n", 397 | " 'spurious': 0,\n", 398 | " 'precision': 0,\n", 399 | " 'recall': 0,\n", 400 | " 'actual': 1,\n", 401 | " 'possible': 1},\n", 402 | " 'partial': {'correct': 1,\n", 403 | " 'incorrect': 0,\n", 404 | " 'partial': 0,\n", 405 | " 'missed': 0,\n", 406 | " 'spurious': 0,\n", 407 | " 'precision': 0,\n", 408 | " 'recall': 0,\n", 409 | " 'actual': 1,\n", 410 | " 'possible': 1},\n", 411 | " 'exact': {'correct': 1,\n", 412 | " 'incorrect': 0,\n", 413 | " 'partial': 0,\n", 414 | " 'missed': 0,\n", 415 | " 'spurious': 0,\n", 416 | " 'precision': 0,\n", 417 | " 'recall': 0,\n", 418 | " 'actual': 1,\n", 419 | " 'possible': 1}},\n", 420 | " 'MISC': {'strict': {'correct': 1,\n", 421 | " 'incorrect': 0,\n", 422 | " 'partial': 0,\n", 423 | " 'missed': 0,\n", 424 | " 'spurious': 0,\n", 425 | " 'precision': 0,\n", 426 | " 'recall': 0,\n", 427 | " 'actual': 1,\n", 428 | " 'possible': 1},\n", 429 | " 'ent_type': {'correct': 1,\n", 430 | " 'incorrect': 0,\n", 431 | " 'partial': 0,\n", 432 | " 'missed': 0,\n", 433 | " 'spurious': 0,\n", 434 | " 'precision': 0,\n", 435 | " 'recall': 0,\n", 436 | " 'actual': 1,\n", 437 | " 'possible': 1},\n", 438 | " 'partial': {'correct': 1,\n", 439 | " 'incorrect': 0,\n", 440 | " 'partial': 0,\n", 441 | " 'missed': 0,\n", 442 | " 'spurious': 0,\n", 443 | " 'precision': 0,\n", 444 | " 'recall': 0,\n", 445 | " 'actual': 1,\n", 446 | " 'possible': 1},\n", 447 | " 'exact': {'correct': 1,\n", 448 | " 'incorrect': 0,\n", 449 | " 'partial': 0,\n", 450 | " 'missed': 0,\n", 451 | " 'spurious': 0,\n", 452 | " 'precision': 0,\n", 453 | " 'recall': 0,\n", 454 | " 'actual': 1,\n", 455 | " 'possible': 1}},\n", 456 | " 'PER': {'strict': {'correct': 1,\n", 457 | " 'incorrect': 0,\n", 458 | " 'partial': 0,\n", 459 | " 'missed': 0,\n", 460 | " 'spurious': 0,\n", 461 | " 'precision': 0,\n", 462 | " 'recall': 0,\n", 463 | " 'actual': 1,\n", 464 | " 'possible': 1},\n", 465 | " 'ent_type': {'correct': 1,\n", 466 | " 'incorrect': 0,\n", 467 | " 'partial': 0,\n", 468 | " 'missed': 0,\n", 469 | " 'spurious': 0,\n", 470 | " 'precision': 0,\n", 471 | " 'recall': 0,\n", 472 | " 'actual': 1,\n", 473 | " 'possible': 1},\n", 474 | " 
'partial': {'correct': 1,\n", 475 | " 'incorrect': 0,\n", 476 | " 'partial': 0,\n", 477 | " 'missed': 0,\n", 478 | " 'spurious': 0,\n", 479 | " 'precision': 0,\n", 480 | " 'recall': 0,\n", 481 | " 'actual': 1,\n", 482 | " 'possible': 1},\n", 483 | " 'exact': {'correct': 1,\n", 484 | " 'incorrect': 0,\n", 485 | " 'partial': 0,\n", 486 | " 'missed': 0,\n", 487 | " 'spurious': 0,\n", 488 | " 'precision': 0,\n", 489 | " 'recall': 0,\n", 490 | " 'actual': 1,\n", 491 | " 'possible': 1}},\n", 492 | " 'ORG': {'strict': {'correct': 0,\n", 493 | " 'incorrect': 1,\n", 494 | " 'partial': 0,\n", 495 | " 'missed': 0,\n", 496 | " 'spurious': 0,\n", 497 | " 'precision': 0,\n", 498 | " 'recall': 0,\n", 499 | " 'actual': 1,\n", 500 | " 'possible': 1},\n", 501 | " 'ent_type': {'correct': 0,\n", 502 | " 'incorrect': 1,\n", 503 | " 'partial': 0,\n", 504 | " 'missed': 0,\n", 505 | " 'spurious': 0,\n", 506 | " 'precision': 0,\n", 507 | " 'recall': 0,\n", 508 | " 'actual': 1,\n", 509 | " 'possible': 1},\n", 510 | " 'partial': {'correct': 1,\n", 511 | " 'incorrect': 0,\n", 512 | " 'partial': 0,\n", 513 | " 'missed': 0,\n", 514 | " 'spurious': 0,\n", 515 | " 'precision': 0,\n", 516 | " 'recall': 0,\n", 517 | " 'actual': 1,\n", 518 | " 'possible': 1},\n", 519 | " 'exact': {'correct': 1,\n", 520 | " 'incorrect': 0,\n", 521 | " 'partial': 0,\n", 522 | " 'missed': 0,\n", 523 | " 'spurious': 0,\n", 524 | " 'precision': 0,\n", 525 | " 'recall': 0,\n", 526 | " 'actual': 1,\n", 527 | " 'possible': 1}}})" 528 | ] 529 | }, 530 | "execution_count": 33, 531 | "metadata": {}, 532 | "output_type": "execute_result" 533 | } 534 | ], 535 | "source": [ 536 | "compute_metrics(true, pred, ['LOC', 'MISC', 'PER', 'ORG'])" 537 | ] 538 | }, 539 | { 540 | "cell_type": "code", 541 | "execution_count": 34, 542 | "metadata": {}, 543 | "outputs": [], 544 | "source": [ 545 | "to_test = [2,4,12,14]" 546 | ] 547 | }, 548 | { 549 | "cell_type": "code", 550 | "execution_count": 35, 551 | "metadata": {}, 552 | "outputs": [], 553 | "source": [ 554 | "index = 2\n", 555 | "true_named_entities_type = defaultdict(list)\n", 556 | "pred_named_entities_type = defaultdict(list)\n", 557 | "\n", 558 | "for true in collect_named_entities(test_sents_labels[index]):\n", 559 | " true_named_entities_type[true.e_type].append(true)\n", 560 | "\n", 561 | "for pred in collect_named_entities(y_pred[index]):\n", 562 | " pred_named_entities_type[pred.e_type].append(pred)" 563 | ] 564 | }, 565 | { 566 | "cell_type": "code", 567 | "execution_count": 36, 568 | "metadata": {}, 569 | "outputs": [ 570 | { 571 | "data": { 572 | "text/plain": [ 573 | "defaultdict(list,\n", 574 | " {'MISC': [Entity(e_type='MISC', start_offset=12, end_offset=12)],\n", 575 | " 'LOC': [Entity(e_type='LOC', start_offset=15, end_offset=15)],\n", 576 | " 'PER': [Entity(e_type='PER', start_offset=37, end_offset=39)],\n", 577 | " 'ORG': [Entity(e_type='ORG', start_offset=45, end_offset=46)]})" 578 | ] 579 | }, 580 | "execution_count": 36, 581 | "metadata": {}, 582 | "output_type": "execute_result" 583 | } 584 | ], 585 | "source": [ 586 | "true_named_entities_type" 587 | ] 588 | }, 589 | { 590 | "cell_type": "code", 591 | "execution_count": 37, 592 | "metadata": {}, 593 | "outputs": [ 594 | { 595 | "data": { 596 | "text/plain": [ 597 | "defaultdict(list,\n", 598 | " {'MISC': [Entity(e_type='MISC', start_offset=12, end_offset=12)],\n", 599 | " 'LOC': [Entity(e_type='LOC', start_offset=15, end_offset=15),\n", 600 | " Entity(e_type='LOC', start_offset=45, end_offset=46)],\n", 601 | " 'PER': 
[Entity(e_type='PER', start_offset=37, end_offset=39)]})" 602 | ] 603 | }, 604 | "execution_count": 37, 605 | "metadata": {}, 606 | "output_type": "execute_result" 607 | } 608 | ], 609 | "source": [ 610 | "pred_named_entities_type" 611 | ] 612 | }, 613 | { 614 | "cell_type": "code", 615 | "execution_count": 38, 616 | "metadata": {}, 617 | "outputs": [ 618 | { 619 | "data": { 620 | "text/plain": [ 621 | "[Entity(e_type='LOC', start_offset=15, end_offset=15)]" 622 | ] 623 | }, 624 | "execution_count": 38, 625 | "metadata": {}, 626 | "output_type": "execute_result" 627 | } 628 | ], 629 | "source": [ 630 | "true_named_entities_type['LOC']" 631 | ] 632 | }, 633 | { 634 | "cell_type": "code", 635 | "execution_count": 39, 636 | "metadata": {}, 637 | "outputs": [ 638 | { 639 | "data": { 640 | "text/plain": [ 641 | "[Entity(e_type='LOC', start_offset=15, end_offset=15),\n", 642 | " Entity(e_type='LOC', start_offset=45, end_offset=46)]" 643 | ] 644 | }, 645 | "execution_count": 39, 646 | "metadata": {}, 647 | "output_type": "execute_result" 648 | } 649 | ], 650 | "source": [ 651 | "pred_named_entities_type['LOC']" 652 | ] 653 | }, 654 | { 655 | "cell_type": "code", 656 | "execution_count": 40, 657 | "metadata": {}, 658 | "outputs": [ 659 | { 660 | "data": { 661 | "text/plain": [ 662 | "({'strict': {'correct': 1,\n", 663 | " 'incorrect': 0,\n", 664 | " 'partial': 0,\n", 665 | " 'missed': 0,\n", 666 | " 'spurious': 1,\n", 667 | " 'precision': 0,\n", 668 | " 'recall': 0,\n", 669 | " 'actual': 2,\n", 670 | " 'possible': 1},\n", 671 | " 'ent_type': {'correct': 1,\n", 672 | " 'incorrect': 0,\n", 673 | " 'partial': 0,\n", 674 | " 'missed': 0,\n", 675 | " 'spurious': 1,\n", 676 | " 'precision': 0,\n", 677 | " 'recall': 0,\n", 678 | " 'actual': 2,\n", 679 | " 'possible': 1},\n", 680 | " 'partial': {'correct': 1,\n", 681 | " 'incorrect': 0,\n", 682 | " 'partial': 0,\n", 683 | " 'missed': 0,\n", 684 | " 'spurious': 1,\n", 685 | " 'precision': 0,\n", 686 | " 'recall': 0,\n", 687 | " 'actual': 2,\n", 688 | " 'possible': 1},\n", 689 | " 'exact': {'correct': 1,\n", 690 | " 'incorrect': 0,\n", 691 | " 'partial': 0,\n", 692 | " 'missed': 0,\n", 693 | " 'spurious': 1,\n", 694 | " 'precision': 0,\n", 695 | " 'recall': 0,\n", 696 | " 'actual': 2,\n", 697 | " 'possible': 1}},\n", 698 | " {'LOC': {'strict': {'correct': 1,\n", 699 | " 'incorrect': 0,\n", 700 | " 'partial': 0,\n", 701 | " 'missed': 0,\n", 702 | " 'spurious': 1,\n", 703 | " 'precision': 0,\n", 704 | " 'recall': 0,\n", 705 | " 'actual': 2,\n", 706 | " 'possible': 1},\n", 707 | " 'ent_type': {'correct': 1,\n", 708 | " 'incorrect': 0,\n", 709 | " 'partial': 0,\n", 710 | " 'missed': 0,\n", 711 | " 'spurious': 1,\n", 712 | " 'precision': 0,\n", 713 | " 'recall': 0,\n", 714 | " 'actual': 2,\n", 715 | " 'possible': 1},\n", 716 | " 'partial': {'correct': 1,\n", 717 | " 'incorrect': 0,\n", 718 | " 'partial': 0,\n", 719 | " 'missed': 0,\n", 720 | " 'spurious': 1,\n", 721 | " 'precision': 0,\n", 722 | " 'recall': 0,\n", 723 | " 'actual': 2,\n", 724 | " 'possible': 1},\n", 725 | " 'exact': {'correct': 1,\n", 726 | " 'incorrect': 0,\n", 727 | " 'partial': 0,\n", 728 | " 'missed': 0,\n", 729 | " 'spurious': 1,\n", 730 | " 'precision': 0,\n", 731 | " 'recall': 0,\n", 732 | " 'actual': 2,\n", 733 | " 'possible': 1}},\n", 734 | " 'MISC': {'strict': {'correct': 0,\n", 735 | " 'incorrect': 0,\n", 736 | " 'partial': 0,\n", 737 | " 'missed': 0,\n", 738 | " 'spurious': 1,\n", 739 | " 'precision': 0,\n", 740 | " 'recall': 0,\n", 741 | " 'actual': 1,\n", 742 | " 'possible': 
0},\n", 743 | " 'ent_type': {'correct': 0,\n", 744 | " 'incorrect': 0,\n", 745 | " 'partial': 0,\n", 746 | " 'missed': 0,\n", 747 | " 'spurious': 1,\n", 748 | " 'precision': 0,\n", 749 | " 'recall': 0,\n", 750 | " 'actual': 1,\n", 751 | " 'possible': 0},\n", 752 | " 'partial': {'correct': 0,\n", 753 | " 'incorrect': 0,\n", 754 | " 'partial': 0,\n", 755 | " 'missed': 0,\n", 756 | " 'spurious': 1,\n", 757 | " 'precision': 0,\n", 758 | " 'recall': 0,\n", 759 | " 'actual': 1,\n", 760 | " 'possible': 0},\n", 761 | " 'exact': {'correct': 0,\n", 762 | " 'incorrect': 0,\n", 763 | " 'partial': 0,\n", 764 | " 'missed': 0,\n", 765 | " 'spurious': 1,\n", 766 | " 'precision': 0,\n", 767 | " 'recall': 0,\n", 768 | " 'actual': 1,\n", 769 | " 'possible': 0}},\n", 770 | " 'PER': {'strict': {'correct': 0,\n", 771 | " 'incorrect': 0,\n", 772 | " 'partial': 0,\n", 773 | " 'missed': 0,\n", 774 | " 'spurious': 1,\n", 775 | " 'precision': 0,\n", 776 | " 'recall': 0,\n", 777 | " 'actual': 1,\n", 778 | " 'possible': 0},\n", 779 | " 'ent_type': {'correct': 0,\n", 780 | " 'incorrect': 0,\n", 781 | " 'partial': 0,\n", 782 | " 'missed': 0,\n", 783 | " 'spurious': 1,\n", 784 | " 'precision': 0,\n", 785 | " 'recall': 0,\n", 786 | " 'actual': 1,\n", 787 | " 'possible': 0},\n", 788 | " 'partial': {'correct': 0,\n", 789 | " 'incorrect': 0,\n", 790 | " 'partial': 0,\n", 791 | " 'missed': 0,\n", 792 | " 'spurious': 1,\n", 793 | " 'precision': 0,\n", 794 | " 'recall': 0,\n", 795 | " 'actual': 1,\n", 796 | " 'possible': 0},\n", 797 | " 'exact': {'correct': 0,\n", 798 | " 'incorrect': 0,\n", 799 | " 'partial': 0,\n", 800 | " 'missed': 0,\n", 801 | " 'spurious': 1,\n", 802 | " 'precision': 0,\n", 803 | " 'recall': 0,\n", 804 | " 'actual': 1,\n", 805 | " 'possible': 0}},\n", 806 | " 'ORG': {'strict': {'correct': 0,\n", 807 | " 'incorrect': 0,\n", 808 | " 'partial': 0,\n", 809 | " 'missed': 0,\n", 810 | " 'spurious': 1,\n", 811 | " 'precision': 0,\n", 812 | " 'recall': 0,\n", 813 | " 'actual': 1,\n", 814 | " 'possible': 0},\n", 815 | " 'ent_type': {'correct': 0,\n", 816 | " 'incorrect': 0,\n", 817 | " 'partial': 0,\n", 818 | " 'missed': 0,\n", 819 | " 'spurious': 1,\n", 820 | " 'precision': 0,\n", 821 | " 'recall': 0,\n", 822 | " 'actual': 1,\n", 823 | " 'possible': 0},\n", 824 | " 'partial': {'correct': 0,\n", 825 | " 'incorrect': 0,\n", 826 | " 'partial': 0,\n", 827 | " 'missed': 0,\n", 828 | " 'spurious': 1,\n", 829 | " 'precision': 0,\n", 830 | " 'recall': 0,\n", 831 | " 'actual': 1,\n", 832 | " 'possible': 0},\n", 833 | " 'exact': {'correct': 0,\n", 834 | " 'incorrect': 0,\n", 835 | " 'partial': 0,\n", 836 | " 'missed': 0,\n", 837 | " 'spurious': 1,\n", 838 | " 'precision': 0,\n", 839 | " 'recall': 0,\n", 840 | " 'actual': 1,\n", 841 | " 'possible': 0}}})" 842 | ] 843 | }, 844 | "execution_count": 40, 845 | "metadata": {}, 846 | "output_type": "execute_result" 847 | } 848 | ], 849 | "source": [ 850 | "compute_metrics(true_named_entities_type['LOC'], pred_named_entities_type['LOC'], ['LOC', 'MISC', 'PER', 'ORG'])" 851 | ] 852 | }, 853 | { 854 | "cell_type": "markdown", 855 | "metadata": {}, 856 | "source": [ 857 | "## results over all messages" 858 | ] 859 | }, 860 | { 861 | "cell_type": "code", 862 | "execution_count": 42, 863 | "metadata": { 864 | "scrolled": false 865 | }, 866 | "outputs": [], 867 | "source": [ 868 | "metrics_results = {'correct': 0, 'incorrect': 0, 'partial': 0,\n", 869 | " 'missed': 0, 'spurious': 0, 'possible': 0, 'actual': 0, 'precision': 0, 'recall': 0}\n", 870 | "\n", 871 | "# overall results\n", 872 
| "results = {'strict': deepcopy(metrics_results),\n", 873 | " 'ent_type': deepcopy(metrics_results),\n", 874 | " 'partial':deepcopy(metrics_results),\n", 875 | " 'exact':deepcopy(metrics_results)\n", 876 | " }\n", 877 | "\n", 878 | "\n", 879 | "# results aggregated by entity type\n", 880 | "evaluation_agg_entities_type = {e: deepcopy(results) for e in ['PER', 'LOC', 'MISC', 'ORG']}\n", 881 | "\n", 882 | "for true_ents, pred_ents in zip(test_sents_labels, y_pred):\n", 883 | " \n", 884 | " # compute results for one message\n", 885 | " tmp_results, tmp_agg_results = compute_metrics(\n", 886 | " collect_named_entities(true_ents), collect_named_entities(pred_ents), ['LOC', 'MISC', 'PER', 'ORG']\n", 887 | " )\n", 888 | " \n", 889 | " #print(tmp_results)\n", 890 | "\n", 891 | " # aggregate overall results\n", 892 | " for eval_schema in results.keys():\n", 893 | " for metric in metrics_results.keys():\n", 894 | " results[eval_schema][metric] += tmp_results[eval_schema][metric]\n", 895 | " \n", 896 | " # Calculate global precision and recall\n", 897 | " \n", 898 | " results = compute_precision_recall_wrapper(results)\n", 899 | "\n", 900 | "\n", 901 | " # aggregate results by entity type\n", 902 | " \n", 903 | " for e_type in ['PER', 'LOC', 'MISC', 'ORG']:\n", 904 | "\n", 905 | " for eval_schema in tmp_agg_results[e_type]:\n", 906 | "\n", 907 | " for metric in tmp_agg_results[e_type][eval_schema]:\n", 908 | " \n", 909 | " evaluation_agg_entities_type[e_type][eval_schema][metric] += tmp_agg_results[e_type][eval_schema][metric]\n", 910 | " \n", 911 | " # Calculate precision recall at the individual entity level\n", 912 | " \n", 913 | " evaluation_agg_entities_type[e_type] = compute_precision_recall_wrapper(evaluation_agg_entities_type[e_type])\n", 914 | " \n", 915 | " " 916 | ] 917 | }, 918 | { 919 | "cell_type": "code", 920 | "execution_count": 43, 921 | "metadata": {}, 922 | "outputs": [ 923 | { 924 | "data": { 925 | "text/plain": [ 926 | "{'ent_type': {'correct': 2860,\n", 927 | " 'incorrect': 523,\n", 928 | " 'partial': 0,\n", 929 | " 'missed': 176,\n", 930 | " 'spurious': 139,\n", 931 | " 'possible': 3559,\n", 932 | " 'actual': 3522,\n", 933 | " 'precision': 0.8120386144236229,\n", 934 | " 'recall': 0.8035965158752458},\n", 935 | " 'partial': {'correct': 3278,\n", 936 | " 'incorrect': 0,\n", 937 | " 'partial': 105,\n", 938 | " 'missed': 176,\n", 939 | " 'spurious': 139,\n", 940 | " 'possible': 3559,\n", 941 | " 'actual': 3522,\n", 942 | " 'precision': 0.9456274843838728,\n", 943 | " 'recall': 0.9357965720708064},\n", 944 | " 'strict': {'correct': 2783,\n", 945 | " 'incorrect': 600,\n", 946 | " 'partial': 0,\n", 947 | " 'missed': 176,\n", 948 | " 'spurious': 139,\n", 949 | " 'possible': 3559,\n", 950 | " 'actual': 3522,\n", 951 | " 'precision': 0.7901760363429869,\n", 952 | " 'recall': 0.78196122506322},\n", 953 | " 'exact': {'correct': 3278,\n", 954 | " 'incorrect': 105,\n", 955 | " 'partial': 0,\n", 956 | " 'missed': 176,\n", 957 | " 'spurious': 139,\n", 958 | " 'possible': 3559,\n", 959 | " 'actual': 3522,\n", 960 | " 'precision': 0.9307211811470755,\n", 961 | " 'recall': 0.9210452374262433}}" 962 | ] 963 | }, 964 | "execution_count": 43, 965 | "metadata": {}, 966 | "output_type": "execute_result" 967 | } 968 | ], 969 | "source": [ 970 | "results" 971 | ] 972 | }, 973 | { 974 | "cell_type": "code", 975 | "execution_count": 44, 976 | "metadata": { 977 | "scrolled": false 978 | }, 979 | "outputs": [ 980 | { 981 | "data": { 982 | "text/plain": [ 983 | "{'PER': {'ent_type': {'correct': 651,\n", 
984 | " 'incorrect': 67,\n", 985 | " 'partial': 0,\n", 986 | " 'missed': 17,\n", 987 | " 'spurious': 139,\n", 988 | " 'possible': 735,\n", 989 | " 'actual': 857,\n", 990 | " 'precision': 0.7596266044340724,\n", 991 | " 'recall': 0.8857142857142857},\n", 992 | " 'partial': {'correct': 711,\n", 993 | " 'incorrect': 0,\n", 994 | " 'partial': 7,\n", 995 | " 'missed': 17,\n", 996 | " 'spurious': 139,\n", 997 | " 'possible': 735,\n", 998 | " 'actual': 857,\n", 999 | " 'precision': 0.8337222870478413,\n", 1000 | " 'recall': 0.972108843537415},\n", 1001 | " 'strict': {'correct': 646,\n", 1002 | " 'incorrect': 72,\n", 1003 | " 'partial': 0,\n", 1004 | " 'missed': 17,\n", 1005 | " 'spurious': 139,\n", 1006 | " 'possible': 735,\n", 1007 | " 'actual': 857,\n", 1008 | " 'precision': 0.7537922987164527,\n", 1009 | " 'recall': 0.8789115646258503},\n", 1010 | " 'exact': {'correct': 711,\n", 1011 | " 'incorrect': 7,\n", 1012 | " 'partial': 0,\n", 1013 | " 'missed': 17,\n", 1014 | " 'spurious': 139,\n", 1015 | " 'possible': 735,\n", 1016 | " 'actual': 857,\n", 1017 | " 'precision': 0.8296382730455076,\n", 1018 | " 'recall': 0.9673469387755103}},\n", 1019 | " 'LOC': {'ent_type': {'correct': 855,\n", 1020 | " 'incorrect': 180,\n", 1021 | " 'partial': 0,\n", 1022 | " 'missed': 49,\n", 1023 | " 'spurious': 139,\n", 1024 | " 'possible': 1084,\n", 1025 | " 'actual': 1174,\n", 1026 | " 'precision': 0.7282793867120954,\n", 1027 | " 'recall': 0.7887453874538746},\n", 1028 | " 'partial': {'correct': 1016,\n", 1029 | " 'incorrect': 0,\n", 1030 | " 'partial': 19,\n", 1031 | " 'missed': 49,\n", 1032 | " 'spurious': 139,\n", 1033 | " 'possible': 1084,\n", 1034 | " 'actual': 1174,\n", 1035 | " 'precision': 0.8735093696763203,\n", 1036 | " 'recall': 0.9460332103321033},\n", 1037 | " 'strict': {'correct': 844,\n", 1038 | " 'incorrect': 191,\n", 1039 | " 'partial': 0,\n", 1040 | " 'missed': 49,\n", 1041 | " 'spurious': 139,\n", 1042 | " 'possible': 1084,\n", 1043 | " 'actual': 1174,\n", 1044 | " 'precision': 0.7189097103918228,\n", 1045 | " 'recall': 0.7785977859778598},\n", 1046 | " 'exact': {'correct': 1016,\n", 1047 | " 'incorrect': 19,\n", 1048 | " 'partial': 0,\n", 1049 | " 'missed': 49,\n", 1050 | " 'spurious': 139,\n", 1051 | " 'possible': 1084,\n", 1052 | " 'actual': 1174,\n", 1053 | " 'precision': 0.8654173764906303,\n", 1054 | " 'recall': 0.9372693726937269}},\n", 1055 | " 'MISC': {'ent_type': {'correct': 200,\n", 1056 | " 'incorrect': 89,\n", 1057 | " 'partial': 0,\n", 1058 | " 'missed': 51,\n", 1059 | " 'spurious': 139,\n", 1060 | " 'possible': 340,\n", 1061 | " 'actual': 428,\n", 1062 | " 'precision': 0.4672897196261682,\n", 1063 | " 'recall': 0.5882352941176471},\n", 1064 | " 'partial': {'correct': 257,\n", 1065 | " 'incorrect': 0,\n", 1066 | " 'partial': 32,\n", 1067 | " 'missed': 51,\n", 1068 | " 'spurious': 139,\n", 1069 | " 'possible': 340,\n", 1070 | " 'actual': 428,\n", 1071 | " 'precision': 0.6378504672897196,\n", 1072 | " 'recall': 0.8029411764705883},\n", 1073 | " 'strict': {'correct': 173,\n", 1074 | " 'incorrect': 116,\n", 1075 | " 'partial': 0,\n", 1076 | " 'missed': 51,\n", 1077 | " 'spurious': 139,\n", 1078 | " 'possible': 340,\n", 1079 | " 'actual': 428,\n", 1080 | " 'precision': 0.40420560747663553,\n", 1081 | " 'recall': 0.5088235294117647},\n", 1082 | " 'exact': {'correct': 257,\n", 1083 | " 'incorrect': 32,\n", 1084 | " 'partial': 0,\n", 1085 | " 'missed': 51,\n", 1086 | " 'spurious': 139,\n", 1087 | " 'possible': 340,\n", 1088 | " 'actual': 428,\n", 1089 | " 'precision': 
0.6004672897196262,\n", 1090 | " 'recall': 0.7558823529411764}},\n", 1091 | " 'ORG': {'ent_type': {'correct': 1154,\n", 1092 | " 'incorrect': 187,\n", 1093 | " 'partial': 0,\n", 1094 | " 'missed': 59,\n", 1095 | " 'spurious': 139,\n", 1096 | " 'possible': 1400,\n", 1097 | " 'actual': 1480,\n", 1098 | " 'precision': 0.7797297297297298,\n", 1099 | " 'recall': 0.8242857142857143},\n", 1100 | " 'partial': {'correct': 1294,\n", 1101 | " 'incorrect': 0,\n", 1102 | " 'partial': 47,\n", 1103 | " 'missed': 59,\n", 1104 | " 'spurious': 139,\n", 1105 | " 'possible': 1400,\n", 1106 | " 'actual': 1480,\n", 1107 | " 'precision': 0.8902027027027027,\n", 1108 | " 'recall': 0.9410714285714286},\n", 1109 | " 'strict': {'correct': 1120,\n", 1110 | " 'incorrect': 221,\n", 1111 | " 'partial': 0,\n", 1112 | " 'missed': 59,\n", 1113 | " 'spurious': 139,\n", 1114 | " 'possible': 1400,\n", 1115 | " 'actual': 1480,\n", 1116 | " 'precision': 0.7567567567567568,\n", 1117 | " 'recall': 0.8},\n", 1118 | " 'exact': {'correct': 1294,\n", 1119 | " 'incorrect': 47,\n", 1120 | " 'partial': 0,\n", 1121 | " 'missed': 59,\n", 1122 | " 'spurious': 139,\n", 1123 | " 'possible': 1400,\n", 1124 | " 'actual': 1480,\n", 1125 | " 'precision': 0.8743243243243243,\n", 1126 | " 'recall': 0.9242857142857143}}}" 1127 | ] 1128 | }, 1129 | "execution_count": 44, 1130 | "metadata": {}, 1131 | "output_type": "execute_result" 1132 | } 1133 | ], 1134 | "source": [ 1135 | "evaluation_agg_entities_type" 1136 | ] 1137 | }, 1138 | { 1139 | "cell_type": "code", 1140 | "execution_count": 45, 1141 | "metadata": {}, 1142 | "outputs": [], 1143 | "source": [ 1144 | "from ner_evaluation.ner_eval import Evaluator" 1145 | ] 1146 | }, 1147 | { 1148 | "cell_type": "code", 1149 | "execution_count": 46, 1150 | "metadata": {}, 1151 | "outputs": [], 1152 | "source": [ 1153 | "evaluator = Evaluator(test_sents_labels, y_pred, ['LOC', 'MISC', 'PER', 'ORG'])" 1154 | ] 1155 | }, 1156 | { 1157 | "cell_type": "code", 1158 | "execution_count": 47, 1159 | "metadata": {}, 1160 | "outputs": [ 1161 | { 1162 | "name": "stderr", 1163 | "output_type": "stream", 1164 | "text": [ 1165 | "2019-03-12 12:00:31 root INFO: Imported 1517 predictions for 1517 true examples\n" 1166 | ] 1167 | } 1168 | ], 1169 | "source": [ 1170 | "results, results_agg = evaluator.evaluate()" 1171 | ] 1172 | }, 1173 | { 1174 | "cell_type": "code", 1175 | "execution_count": 48, 1176 | "metadata": {}, 1177 | "outputs": [ 1178 | { 1179 | "data": { 1180 | "text/plain": [ 1181 | "{'ent_type': {'correct': 2860,\n", 1182 | " 'incorrect': 523,\n", 1183 | " 'partial': 0,\n", 1184 | " 'missed': 176,\n", 1185 | " 'spurious': 139,\n", 1186 | " 'possible': 3559,\n", 1187 | " 'actual': 3522,\n", 1188 | " 'precision': 0.8120386144236229,\n", 1189 | " 'recall': 0.8035965158752458},\n", 1190 | " 'partial': {'correct': 3278,\n", 1191 | " 'incorrect': 0,\n", 1192 | " 'partial': 105,\n", 1193 | " 'missed': 176,\n", 1194 | " 'spurious': 139,\n", 1195 | " 'possible': 3559,\n", 1196 | " 'actual': 3522,\n", 1197 | " 'precision': 0.9456274843838728,\n", 1198 | " 'recall': 0.9357965720708064},\n", 1199 | " 'strict': {'correct': 2783,\n", 1200 | " 'incorrect': 600,\n", 1201 | " 'partial': 0,\n", 1202 | " 'missed': 176,\n", 1203 | " 'spurious': 139,\n", 1204 | " 'possible': 3559,\n", 1205 | " 'actual': 3522,\n", 1206 | " 'precision': 0.7901760363429869,\n", 1207 | " 'recall': 0.78196122506322},\n", 1208 | " 'exact': {'correct': 3278,\n", 1209 | " 'incorrect': 105,\n", 1210 | " 'partial': 0,\n", 1211 | " 'missed': 176,\n", 1212 | " 
'spurious': 139,\n", 1213 | " 'possible': 3559,\n", 1214 | " 'actual': 3522,\n", 1215 | " 'precision': 0.9307211811470755,\n", 1216 | " 'recall': 0.9210452374262433}}" 1217 | ] 1218 | }, 1219 | "execution_count": 48, 1220 | "metadata": {}, 1221 | "output_type": "execute_result" 1222 | } 1223 | ], 1224 | "source": [ 1225 | "results" 1226 | ] 1227 | }, 1228 | { 1229 | "cell_type": "code", 1230 | "execution_count": 49, 1231 | "metadata": {}, 1232 | "outputs": [ 1233 | { 1234 | "data": { 1235 | "text/plain": [ 1236 | "{'LOC': {'ent_type': {'correct': 855,\n", 1237 | " 'incorrect': 180,\n", 1238 | " 'partial': 0,\n", 1239 | " 'missed': 49,\n", 1240 | " 'spurious': 139,\n", 1241 | " 'possible': 1084,\n", 1242 | " 'actual': 1174,\n", 1243 | " 'precision': 0.7282793867120954,\n", 1244 | " 'recall': 0.7887453874538746},\n", 1245 | " 'partial': {'correct': 1016,\n", 1246 | " 'incorrect': 0,\n", 1247 | " 'partial': 19,\n", 1248 | " 'missed': 49,\n", 1249 | " 'spurious': 139,\n", 1250 | " 'possible': 1084,\n", 1251 | " 'actual': 1174,\n", 1252 | " 'precision': 0.8735093696763203,\n", 1253 | " 'recall': 0.9460332103321033},\n", 1254 | " 'strict': {'correct': 844,\n", 1255 | " 'incorrect': 191,\n", 1256 | " 'partial': 0,\n", 1257 | " 'missed': 49,\n", 1258 | " 'spurious': 139,\n", 1259 | " 'possible': 1084,\n", 1260 | " 'actual': 1174,\n", 1261 | " 'precision': 0.7189097103918228,\n", 1262 | " 'recall': 0.7785977859778598},\n", 1263 | " 'exact': {'correct': 1016,\n", 1264 | " 'incorrect': 19,\n", 1265 | " 'partial': 0,\n", 1266 | " 'missed': 49,\n", 1267 | " 'spurious': 139,\n", 1268 | " 'possible': 1084,\n", 1269 | " 'actual': 1174,\n", 1270 | " 'precision': 0.8654173764906303,\n", 1271 | " 'recall': 0.9372693726937269}},\n", 1272 | " 'MISC': {'ent_type': {'correct': 200,\n", 1273 | " 'incorrect': 89,\n", 1274 | " 'partial': 0,\n", 1275 | " 'missed': 51,\n", 1276 | " 'spurious': 139,\n", 1277 | " 'possible': 340,\n", 1278 | " 'actual': 428,\n", 1279 | " 'precision': 0.4672897196261682,\n", 1280 | " 'recall': 0.5882352941176471},\n", 1281 | " 'partial': {'correct': 257,\n", 1282 | " 'incorrect': 0,\n", 1283 | " 'partial': 32,\n", 1284 | " 'missed': 51,\n", 1285 | " 'spurious': 139,\n", 1286 | " 'possible': 340,\n", 1287 | " 'actual': 428,\n", 1288 | " 'precision': 0.6378504672897196,\n", 1289 | " 'recall': 0.8029411764705883},\n", 1290 | " 'strict': {'correct': 173,\n", 1291 | " 'incorrect': 116,\n", 1292 | " 'partial': 0,\n", 1293 | " 'missed': 51,\n", 1294 | " 'spurious': 139,\n", 1295 | " 'possible': 340,\n", 1296 | " 'actual': 428,\n", 1297 | " 'precision': 0.40420560747663553,\n", 1298 | " 'recall': 0.5088235294117647},\n", 1299 | " 'exact': {'correct': 257,\n", 1300 | " 'incorrect': 32,\n", 1301 | " 'partial': 0,\n", 1302 | " 'missed': 51,\n", 1303 | " 'spurious': 139,\n", 1304 | " 'possible': 340,\n", 1305 | " 'actual': 428,\n", 1306 | " 'precision': 0.6004672897196262,\n", 1307 | " 'recall': 0.7558823529411764}},\n", 1308 | " 'PER': {'ent_type': {'correct': 651,\n", 1309 | " 'incorrect': 67,\n", 1310 | " 'partial': 0,\n", 1311 | " 'missed': 17,\n", 1312 | " 'spurious': 139,\n", 1313 | " 'possible': 735,\n", 1314 | " 'actual': 857,\n", 1315 | " 'precision': 0.7596266044340724,\n", 1316 | " 'recall': 0.8857142857142857},\n", 1317 | " 'partial': {'correct': 711,\n", 1318 | " 'incorrect': 0,\n", 1319 | " 'partial': 7,\n", 1320 | " 'missed': 17,\n", 1321 | " 'spurious': 139,\n", 1322 | " 'possible': 735,\n", 1323 | " 'actual': 857,\n", 1324 | " 'precision': 0.8337222870478413,\n", 1325 | " 'recall': 
0.972108843537415},\n", 1326 | " 'strict': {'correct': 646,\n", 1327 | " 'incorrect': 72,\n", 1328 | " 'partial': 0,\n", 1329 | " 'missed': 17,\n", 1330 | " 'spurious': 139,\n", 1331 | " 'possible': 735,\n", 1332 | " 'actual': 857,\n", 1333 | " 'precision': 0.7537922987164527,\n", 1334 | " 'recall': 0.8789115646258503},\n", 1335 | " 'exact': {'correct': 711,\n", 1336 | " 'incorrect': 7,\n", 1337 | " 'partial': 0,\n", 1338 | " 'missed': 17,\n", 1339 | " 'spurious': 139,\n", 1340 | " 'possible': 735,\n", 1341 | " 'actual': 857,\n", 1342 | " 'precision': 0.8296382730455076,\n", 1343 | " 'recall': 0.9673469387755103}},\n", 1344 | " 'ORG': {'ent_type': {'correct': 1154,\n", 1345 | " 'incorrect': 187,\n", 1346 | " 'partial': 0,\n", 1347 | " 'missed': 59,\n", 1348 | " 'spurious': 139,\n", 1349 | " 'possible': 1400,\n", 1350 | " 'actual': 1480,\n", 1351 | " 'precision': 0.7797297297297298,\n", 1352 | " 'recall': 0.8242857142857143},\n", 1353 | " 'partial': {'correct': 1294,\n", 1354 | " 'incorrect': 0,\n", 1355 | " 'partial': 47,\n", 1356 | " 'missed': 59,\n", 1357 | " 'spurious': 139,\n", 1358 | " 'possible': 1400,\n", 1359 | " 'actual': 1480,\n", 1360 | " 'precision': 0.8902027027027027,\n", 1361 | " 'recall': 0.9410714285714286},\n", 1362 | " 'strict': {'correct': 1120,\n", 1363 | " 'incorrect': 221,\n", 1364 | " 'partial': 0,\n", 1365 | " 'missed': 59,\n", 1366 | " 'spurious': 139,\n", 1367 | " 'possible': 1400,\n", 1368 | " 'actual': 1480,\n", 1369 | " 'precision': 0.7567567567567568,\n", 1370 | " 'recall': 0.8},\n", 1371 | " 'exact': {'correct': 1294,\n", 1372 | " 'incorrect': 47,\n", 1373 | " 'partial': 0,\n", 1374 | " 'missed': 59,\n", 1375 | " 'spurious': 139,\n", 1376 | " 'possible': 1400,\n", 1377 | " 'actual': 1480,\n", 1378 | " 'precision': 0.8743243243243243,\n", 1379 | " 'recall': 0.9242857142857143}}}" 1380 | ] 1381 | }, 1382 | "execution_count": 49, 1383 | "metadata": {}, 1384 | "output_type": "execute_result" 1385 | } 1386 | ], 1387 | "source": [ 1388 | "results_agg" 1389 | ] 1390 | }, 1391 | { 1392 | "cell_type": "code", 1393 | "execution_count": null, 1394 | "metadata": {}, 1395 | "outputs": [], 1396 | "source": [] 1397 | } 1398 | ], 1399 | "metadata": { 1400 | "kernelspec": { 1401 | "display_name": "NER-Evaluation", 1402 | "language": "python", 1403 | "name": "ner-evaluation" 1404 | }, 1405 | "language_info": { 1406 | "codemirror_mode": { 1407 | "name": "ipython", 1408 | "version": 3 1409 | }, 1410 | "file_extension": ".py", 1411 | "mimetype": "text/x-python", 1412 | "name": "python", 1413 | "nbconvert_exporter": "python", 1414 | "pygments_lexer": "ipython3", 1415 | "version": "3.6.6" 1416 | } 1417 | }, 1418 | "nbformat": 4, 1419 | "nbformat_minor": 2 1420 | } 1421 | -------------------------------------------------------------------------------- /ner_evaluation/.coverage: -------------------------------------------------------------------------------- 1 | !coverage.py: This is a private format, don't read it directly!{"lines":{}} -------------------------------------------------------------------------------- /ner_evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidsbatista/NER-Evaluation/054d52a981505a6dbe8751b2e223e6c760e812e6/ner_evaluation/__init__.py -------------------------------------------------------------------------------- /ner_evaluation/ner_eval.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections 
import namedtuple
3 | from copy import deepcopy
4 | 
5 | logging.basicConfig(
6 |     format="%(asctime)s %(name)s %(levelname)s: %(message)s",
7 |     datefmt="%Y-%m-%d %H:%M:%S",
8 |     level="DEBUG",
9 | )
10 | 
11 | Entity = namedtuple("Entity", "e_type start_offset end_offset")
12 | 
13 | class Evaluator:
14 | 
15 |     def __init__(self, true, pred, tags):
16 |         """true and pred are lists of tag sequences, one per document;
17 |         tags is the list of entity types to evaluate."""
18 | 
19 |         if len(true) != len(pred):
20 |             raise ValueError("Number of predicted documents does not equal the number of true documents")
21 | 
22 |         self.true = true
23 |         self.pred = pred
24 |         self.tags = tags
25 | 
26 |         # Set up the dict into which metrics will be stored.
27 | 
28 |         self.metrics_results = {
29 |             'correct': 0,
30 |             'incorrect': 0,
31 |             'partial': 0,
32 |             'missed': 0,
33 |             'spurious': 0,
34 |             'possible': 0,
35 |             'actual': 0,
36 |             'precision': 0,
37 |             'recall': 0,
38 |         }
39 | 
40 |         # Copy results dict to cover the four schemes.
41 | 
42 |         self.results = {
43 |             'strict': deepcopy(self.metrics_results),
44 |             'ent_type': deepcopy(self.metrics_results),
45 |             'partial': deepcopy(self.metrics_results),
46 |             'exact': deepcopy(self.metrics_results),
47 |         }
48 | 
49 |         # Create an accumulator to store results
50 | 
51 |         self.evaluation_agg_entities_type = {e: deepcopy(self.results) for e in tags}
52 | 
53 | 
54 |     def evaluate(self):
55 | 
56 |         logging.info(
57 |             "Imported %s predictions for %s true examples",
58 |             len(self.pred), len(self.true)
59 |         )
60 | 
61 |         for true_ents, pred_ents in zip(self.true, self.pred):
62 | 
63 |             # Check that the length of the true and predicted examples are the
64 |             # same. This must be checked here, because another error may not
65 |             # be thrown if the lengths do not match.
66 | 
67 |             if len(true_ents) != len(pred_ents):
68 |                 raise ValueError("Prediction length does not match true example length")
69 | 
70 |             # Compute results for one message
71 | 
72 |             tmp_results, tmp_agg_results = compute_metrics(
73 |                 collect_named_entities(true_ents),
74 |                 collect_named_entities(pred_ents),
75 |                 self.tags
76 |             )
77 | 
78 |             # Cycle through each result and accumulate
79 | 
80 |             # TODO: Combine these loops below:
81 | 
82 |             for eval_schema in self.results:
83 | 
84 |                 for metric in self.results[eval_schema]:
85 | 
86 |                     self.results[eval_schema][metric] += tmp_results[eval_schema][metric]
87 | 
88 |             # Calculate global precision and recall
89 | 
90 |             self.results = compute_precision_recall_wrapper(self.results)
91 | 
92 |             # Aggregate results by entity type
93 | 
94 |             for e_type in self.tags:
95 | 
96 |                 for eval_schema in tmp_agg_results[e_type]:
97 | 
98 |                     for metric in tmp_agg_results[e_type][eval_schema]:
99 | 
100 |                         self.evaluation_agg_entities_type[e_type][eval_schema][metric] += tmp_agg_results[e_type][eval_schema][metric]
101 | 
102 |                 # Calculate precision recall at the individual entity level
103 | 
104 |                 self.evaluation_agg_entities_type[e_type] = compute_precision_recall_wrapper(self.evaluation_agg_entities_type[e_type])
105 | 
106 |         return self.results, self.evaluation_agg_entities_type
107 | 
108 | 
109 | def collect_named_entities(tokens):
110 |     """
111 |     Creates a list of Entity named-tuples, storing the entity type and the start and end
112 |     offsets of the entity.
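    For example (illustrative BIO input):

    >>> collect_named_entities(['O', 'B-PER', 'I-PER', 'O'])
    [Entity(e_type='PER', start_offset=1, end_offset=2)]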
113 | 114 | :param tokens: a list of tags 115 | :return: a list of Entity named-tuples 116 | """ 117 | 118 | named_entities = [] 119 | start_offset = None 120 | end_offset = None 121 | ent_type = None 122 | 123 | for offset, token_tag in enumerate(tokens): 124 | 125 | if token_tag == 'O': 126 | if ent_type is not None and start_offset is not None: 127 | end_offset = offset - 1 128 | named_entities.append(Entity(ent_type, start_offset, end_offset)) 129 | start_offset = None 130 | end_offset = None 131 | ent_type = None 132 | 133 | elif ent_type is None: 134 | ent_type = token_tag[2:] 135 | start_offset = offset 136 | 137 | elif ent_type != token_tag[2:] or (ent_type == token_tag[2:] and token_tag[:1] == 'B'): 138 | 139 | end_offset = offset - 1 140 | named_entities.append(Entity(ent_type, start_offset, end_offset)) 141 | 142 | # start of a new entity 143 | ent_type = token_tag[2:] 144 | start_offset = offset 145 | end_offset = None 146 | 147 | # catches an entity that goes up until the last token 148 | 149 | if ent_type is not None and start_offset is not None and end_offset is None: 150 | named_entities.append(Entity(ent_type, start_offset, len(tokens)-1)) 151 | 152 | return named_entities 153 | 154 | 155 | def compute_metrics(true_named_entities, pred_named_entities, tags): 156 | 157 | 158 | eval_metrics = {'correct': 0, 'incorrect': 0, 'partial': 0, 'missed': 0, 'spurious': 0, 'precision': 0, 'recall': 0} 159 | 160 | # overall results 161 | 162 | evaluation = { 163 | 'strict': deepcopy(eval_metrics), 164 | 'ent_type': deepcopy(eval_metrics), 165 | 'partial': deepcopy(eval_metrics), 166 | 'exact': deepcopy(eval_metrics) 167 | } 168 | 169 | # results by entity type 170 | 171 | evaluation_agg_entities_type = {e: deepcopy(evaluation) for e in tags} 172 | 173 | # keep track of entities that overlapped 174 | 175 | true_which_overlapped_with_pred = [] 176 | 177 | # Subset into only the tags that we are interested in. 178 | # NOTE: we remove the tags we don't want from both the predicted and the 179 | # true entities. This covers the two cases where mismatches can occur: 180 | # 181 | # 1) Where the model predicts a tag that is not present in the true data 182 | # 2) Where there is a tag in the true data that the model is not capable of 183 | # predicting. 184 | 185 | true_named_entities = [ent for ent in true_named_entities if ent.e_type in tags] 186 | pred_named_entities = [ent for ent in pred_named_entities if ent.e_type in tags] 187 | 188 | # go through each predicted named-entity 189 | 190 | for pred in pred_named_entities: 191 | found_overlap = False 192 | 193 | # Check each of the potential scenarios in turn. See 194 | # http://www.davidsbatista.net/blog/2018/05/09/Named_Entity_Evaluation/ 195 | # for scenario explanation. 196 | 197 | # Scenario I: Exact match between true and pred 198 | 199 | if pred in true_named_entities: 200 | true_which_overlapped_with_pred.append(pred) 201 | evaluation['strict']['correct'] += 1 202 | evaluation['ent_type']['correct'] += 1 203 | evaluation['exact']['correct'] += 1 204 | evaluation['partial']['correct'] += 1 205 | 206 | # for the agg. 
by e_type results
207 |             evaluation_agg_entities_type[pred.e_type]['strict']['correct'] += 1
208 |             evaluation_agg_entities_type[pred.e_type]['ent_type']['correct'] += 1
209 |             evaluation_agg_entities_type[pred.e_type]['exact']['correct'] += 1
210 |             evaluation_agg_entities_type[pred.e_type]['partial']['correct'] += 1
211 | 
212 |         else:
213 | 
214 |             # check for overlaps with any of the true entities
215 | 
216 |             for true in true_named_entities:
217 | 
218 |                 pred_range = range(pred.start_offset, pred.end_offset + 1)  # +1: end_offset is inclusive
219 |                 true_range = range(true.start_offset, true.end_offset + 1)  # +1: end_offset is inclusive
220 | 
221 |                 # Scenario IV: Offsets match, but entity type is wrong
222 | 
223 |                 if true.start_offset == pred.start_offset and pred.end_offset == true.end_offset \
224 |                         and true.e_type != pred.e_type:
225 | 
226 |                     # overall results
227 |                     evaluation['strict']['incorrect'] += 1
228 |                     evaluation['ent_type']['incorrect'] += 1
229 |                     evaluation['partial']['correct'] += 1
230 |                     evaluation['exact']['correct'] += 1
231 | 
232 |                     # aggregated by entity type results
233 |                     evaluation_agg_entities_type[true.e_type]['strict']['incorrect'] += 1
234 |                     evaluation_agg_entities_type[true.e_type]['ent_type']['incorrect'] += 1
235 |                     evaluation_agg_entities_type[true.e_type]['partial']['correct'] += 1
236 |                     evaluation_agg_entities_type[true.e_type]['exact']['correct'] += 1
237 | 
238 |                     true_which_overlapped_with_pred.append(true)
239 |                     found_overlap = True
240 | 
241 |                     break
242 | 
243 |                 # check for an overlap, i.e. a non-exact boundary match, with true entities
244 | 
245 |                 elif find_overlap(true_range, pred_range):
246 | 
247 |                     true_which_overlapped_with_pred.append(true)
248 | 
249 |                     # Scenario V: There is an overlap (but offsets do not match
250 |                     # exactly), and the entity type is the same.
251 |                     # 2.1 overlaps with the same entity type
252 | 
253 |                     if pred.e_type == true.e_type:
254 | 
255 |                         # overall results
256 |                         evaluation['strict']['incorrect'] += 1
257 |                         evaluation['ent_type']['correct'] += 1
258 |                         evaluation['partial']['partial'] += 1
259 |                         evaluation['exact']['incorrect'] += 1
260 | 
261 |                         # aggregated by entity type results
262 |                         evaluation_agg_entities_type[true.e_type]['strict']['incorrect'] += 1
263 |                         evaluation_agg_entities_type[true.e_type]['ent_type']['correct'] += 1
264 |                         evaluation_agg_entities_type[true.e_type]['partial']['partial'] += 1
265 |                         evaluation_agg_entities_type[true.e_type]['exact']['incorrect'] += 1
266 | 
267 |                         found_overlap = True
268 | 
269 |                         break
270 | 
271 |                     # Scenario VI: Entities overlap, but the entity type is
272 |                     # different.
273 | 
274 |                     else:
275 |                         # overall results
276 |                         evaluation['strict']['incorrect'] += 1
277 |                         evaluation['ent_type']['incorrect'] += 1
278 |                         evaluation['partial']['partial'] += 1
279 |                         evaluation['exact']['incorrect'] += 1
280 | 
281 |                         # aggregated by entity type results
282 |                         # Results against the true entity
283 | 
284 |                         evaluation_agg_entities_type[true.e_type]['strict']['incorrect'] += 1
285 |                         evaluation_agg_entities_type[true.e_type]['partial']['partial'] += 1
286 |                         evaluation_agg_entities_type[true.e_type]['ent_type']['incorrect'] += 1
287 |                         evaluation_agg_entities_type[true.e_type]['exact']['incorrect'] += 1
288 | 
289 |                         # Results against the predicted entity
290 | 
291 |                         # evaluation_agg_entities_type[pred.e_type]['strict']['spurious'] += 1
292 | 
293 |                         found_overlap = True
294 | 
295 |                         break
296 | 
297 |         # Scenario II: Entities are spurious (i.e., over-generated).
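        # A prediction reaches this point only when it matched no true entity,
        # exactly or partially, so it is counted as spurious under all four
        # evaluation schemes.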
298 | 
299 |         if not found_overlap:
300 | 
301 |             # Overall results
302 | 
303 |             evaluation['strict']['spurious'] += 1
304 |             evaluation['ent_type']['spurious'] += 1
305 |             evaluation['partial']['spurious'] += 1
306 |             evaluation['exact']['spurious'] += 1
307 | 
308 |             # Aggregated by entity type results
309 | 
310 |             # NOTE: when pred.e_type is not found in tags
311 |             # or when it simply does not appear in the test set, then it is
312 |             # spurious, but it is not clear where to assign it at the tag
313 |             # level. In this case, it is applied to all target_tags
314 |             # found in this example. This will mean that the sum of the
315 |             # evaluation_agg_entities will not equal evaluation.
316 | 
317 |             for true in tags:
318 | 
319 |                 evaluation_agg_entities_type[true]['strict']['spurious'] += 1
320 |                 evaluation_agg_entities_type[true]['ent_type']['spurious'] += 1
321 |                 evaluation_agg_entities_type[true]['partial']['spurious'] += 1
322 |                 evaluation_agg_entities_type[true]['exact']['spurious'] += 1
323 | 
324 |     # Scenario III: Entity was missed entirely.
325 | 
326 |     for true in true_named_entities:
327 |         if true in true_which_overlapped_with_pred:
328 |             continue
329 |         else:
330 |             # overall results
331 |             evaluation['strict']['missed'] += 1
332 |             evaluation['ent_type']['missed'] += 1
333 |             evaluation['partial']['missed'] += 1
334 |             evaluation['exact']['missed'] += 1
335 | 
336 |             # for the agg. by e_type
337 |             evaluation_agg_entities_type[true.e_type]['strict']['missed'] += 1
338 |             evaluation_agg_entities_type[true.e_type]['ent_type']['missed'] += 1
339 |             evaluation_agg_entities_type[true.e_type]['partial']['missed'] += 1
340 |             evaluation_agg_entities_type[true.e_type]['exact']['missed'] += 1
341 | 
342 |     # Compute 'possible', 'actual' according to SemEval-2013 Task 9.1 on the
343 |     # overall results, and use these to calculate precision and recall.
344 | 
345 |     for eval_type in evaluation:
346 |         evaluation[eval_type] = compute_actual_possible(evaluation[eval_type])
347 | 
348 |     # Compute 'possible', 'actual', and precision and recall on entity level
349 |     # results. Start by cycling through the accumulated results.
350 | 
351 |     for entity_type, entity_level in evaluation_agg_entities_type.items():
352 | 
353 |         # Cycle through the evaluation types for each dict containing entity
354 |         # level results.
355 | 
356 |         for eval_type in entity_level:
357 | 
358 |             evaluation_agg_entities_type[entity_type][eval_type] = compute_actual_possible(
359 |                 entity_level[eval_type]
360 |             )
361 | 
362 |     return evaluation, evaluation_agg_entities_type
363 | 
364 | 
365 | def find_overlap(true_range, pred_range):
366 |     """Find the overlap between two ranges
367 | 
368 |     Return the overlapping values as a set if present, else return an empty
369 |     set().
370 | 
371 |     Examples:
372 | 
373 |     >>> find_overlap((1, 2), (2, 3))
374 |     {2}
375 |     >>> find_overlap((1, 2), (3, 4))
376 |     set()
377 |     """
378 | 
379 |     true_set = set(true_range)
380 |     pred_set = set(pred_range)
381 | 
382 |     overlaps = true_set.intersection(pred_set)
383 | 
384 |     return overlaps
385 | 
386 | 
387 | def compute_actual_possible(results):
388 |     """
389 |     Takes a result dict that has been output by compute metrics.
390 |     Returns the results dict with actual, possible populated.
391 | 
392 |     possible = correct + incorrect + partial + missed, i.e. the number of
393 |     annotations in the gold standard; actual = correct + incorrect + partial
394 |     + spurious, i.e. the number of annotations produced by the system.
395 | """ 396 | 397 | correct = results['correct'] 398 | incorrect = results['incorrect'] 399 | partial = results['partial'] 400 | missed = results['missed'] 401 | spurious = results['spurious'] 402 | 403 | # Possible: number of annotations in the gold-standard which contribute to 404 | # the final score 405 | 406 | possible = correct + incorrect + partial + missed 407 | 408 | # Actual: number of annotations produced by the NER system 409 | 410 | actual = correct + incorrect + partial + spurious 411 | 412 | results["actual"] = actual 413 | results["possible"] = possible 414 | 415 | return results 416 | 417 | 418 | def compute_precision_recall(results, partial_or_type=False): 419 | """ 420 | Takes a result dict that has been output by compute_metrics. 421 | Returns the results dict with precision and recall populated. 422 | 423 | When the results dict is from the partial or ent_type metrics, pass 424 | partial_or_type=True to ensure the right calculation is used for 425 | calculating precision and recall. 426 | """ 427 | 428 | actual = results["actual"] 429 | possible = results["possible"] 430 | partial = results['partial'] 431 | correct = results['correct'] 432 | 433 | if partial_or_type: 434 | precision = (correct + 0.5 * partial) / actual if actual > 0 else 0 435 | recall = (correct + 0.5 * partial) / possible if possible > 0 else 0 436 | 437 | else: 438 | precision = correct / actual if actual > 0 else 0 439 | recall = correct / possible if possible > 0 else 0 440 | 441 | results["precision"] = precision 442 | results["recall"] = recall 443 | 444 | return results 445 | 446 | 447 | def compute_precision_recall_wrapper(results): 448 | """ 449 | Wraps the compute_precision_recall function and runs it on a dict of results 450 | """ 451 | 452 | results_a = {key: compute_precision_recall(value, True) for key, value in results.items() if 453 | key in ['partial', 'ent_type']} 454 | results_b = {key: compute_precision_recall(value) for key, value in results.items() if 455 | key in ['strict', 'exact']} 456 | 457 | results = {**results_a, **results_b} 458 | 459 | return results 460 | 461 | -------------------------------------------------------------------------------- /ner_evaluation/tests/test_evaluator.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from ner_evaluation.ner_eval import Evaluator 4 | 5 | 6 | def test_evaluator_simple_case(): 7 | 8 | true = [ 9 | ['O', 'O', 'B-PER', 'I-PER', 'O'], 10 | ['O', 'B-LOC', 'I-LOC', 'B-LOC', 'I-LOC', 'O'], 11 | ] 12 | 13 | pred = [ 14 | ['O', 'O', 'B-PER', 'I-PER', 'O'], 15 | ['O', 'B-LOC', 'I-LOC', 'B-LOC', 'I-LOC', 'O'], 16 | ] 17 | 18 | evaluator = Evaluator(true, pred, tags=['LOC', 'PER']) 19 | 20 | results, results_agg = evaluator.evaluate() 21 | 22 | expected = { 23 | 'strict': { 24 | 'correct': 3, 25 | 'incorrect': 0, 26 | 'partial': 0, 27 | 'missed': 0, 28 | 'spurious': 0, 29 | 'possible': 3, 30 | 'actual': 3, 31 | 'precision': 1.0, 32 | 'recall': 1.0 33 | }, 34 | 'ent_type': { 35 | 'correct': 3, 36 | 'incorrect': 0, 37 | 'partial': 0, 38 | 'missed': 0, 39 | 'spurious': 0, 40 | 'possible': 3, 41 | 'actual': 3, 42 | 'precision': 1.0, 43 | 'recall': 1.0 44 | }, 45 | 'partial': { 46 | 'correct': 3, 47 | 'incorrect': 0, 48 | 'partial': 0, 49 | 'missed': 0, 50 | 'spurious': 0, 51 | 'possible': 3, 52 | 'actual': 3, 53 | 'precision': 1.0, 54 | 'recall': 1.0 55 | }, 56 | 'exact': { 57 | 'correct': 3, 58 | 'incorrect': 0, 59 | 'partial': 0, 60 | 'missed': 0, 61 | 'spurious': 0, 62 | 'possible': 3, 63 | 
'actual': 3, 64 | 'precision': 1.0, 65 | 'recall': 1.0 66 | } 67 | } 68 | 69 | assert results['strict'] == expected['strict'] 70 | assert results['ent_type'] == expected['ent_type'] 71 | assert results['partial'] == expected['partial'] 72 | assert results['exact'] == expected['exact'] 73 | 74 | def test_evaluator_simple_case_filtered_tags(): 75 | """ 76 | Check that tags can be excluded by passing the tags argument 77 | 78 | """ 79 | 80 | true = [ 81 | ['O', 'O', 'B-PER', 'I-PER', 'O'], 82 | ['O', 'B-LOC', 'I-LOC', 'B-LOC', 'I-LOC', 'O'], 83 | ['O', 'B-MISC', 'I-MISC', 'O', 'O', 'O'], 84 | ] 85 | 86 | pred = [ 87 | ['O', 'O', 'B-PER', 'I-PER', 'O'], 88 | ['O', 'B-LOC', 'I-LOC', 'B-LOC', 'I-LOC', 'O'], 89 | ['O', 'B-MISC', 'I-MISC', 'O', 'O', 'O'], 90 | ] 91 | 92 | evaluator = Evaluator(true, pred, tags=['PER', 'LOC']) 93 | 94 | results, results_agg = evaluator.evaluate() 95 | 96 | expected = { 97 | 'strict': { 98 | 'correct': 3, 99 | 'incorrect': 0, 100 | 'partial': 0, 101 | 'missed': 0, 102 | 'spurious': 0, 103 | 'possible': 3, 104 | 'actual': 3, 105 | 'precision': 1.0, 106 | 'recall': 1.0 107 | }, 108 | 'ent_type': { 109 | 'correct': 3, 110 | 'incorrect': 0, 111 | 'partial': 0, 112 | 'missed': 0, 113 | 'spurious': 0, 114 | 'possible': 3, 115 | 'actual': 3, 116 | 'precision': 1.0, 117 | 'recall': 1.0 118 | }, 119 | 'partial': { 120 | 'correct': 3, 121 | 'incorrect': 0, 122 | 'partial': 0, 123 | 'missed': 0, 124 | 'spurious': 0, 125 | 'possible': 3, 126 | 'actual': 3, 127 | 'precision': 1.0, 128 | 'recall': 1.0 129 | }, 130 | 'exact': { 131 | 'correct': 3, 132 | 'incorrect': 0, 133 | 'partial': 0, 134 | 'missed': 0, 135 | 'spurious': 0, 136 | 'possible': 3, 137 | 'actual': 3, 138 | 'precision': 1.0, 139 | 'recall': 1.0 140 | } 141 | } 142 | 143 | assert results['strict'] == expected['strict'] 144 | assert results['ent_type'] == expected['ent_type'] 145 | assert results['partial'] == expected['partial'] 146 | assert results['exact'] == expected['exact'] 147 | 148 | 149 | def test_evaluator_extra_classes(): 150 | """ 151 | Case when model predicts a class that is not in the gold (true) data 152 | """ 153 | 154 | true = [ 155 | ['O', 'B-ORG', 'I-ORG', 'I-ORG', 'O', 'O'], 156 | ] 157 | 158 | pred = [ 159 | ['O', 'B-FOO', 'I-FOO', 'I-FOO', 'O', 'O'], 160 | ] 161 | 162 | evaluator = Evaluator(true, pred, tags=['ORG', 'FOO']) 163 | 164 | results, results_agg = evaluator.evaluate() 165 | 166 | expected = { 167 | 'strict': { 168 | 'correct': 0, 169 | 'incorrect': 1, 170 | 'partial': 0, 171 | 'missed': 0, 172 | 'spurious': 0, 173 | 'possible': 1, 174 | 'actual': 1, 175 | 'precision': 0, 176 | 'recall': 0.0 177 | }, 178 | 'ent_type': { 179 | 'correct': 0, 180 | 'incorrect': 1, 181 | 'partial': 0, 182 | 'missed': 0, 183 | 'spurious': 0, 184 | 'possible': 1, 185 | 'actual': 1, 186 | 'precision': 0, 187 | 'recall': 0.0 188 | }, 189 | 'partial': { 190 | 'correct': 1, 191 | 'incorrect': 0, 192 | 'partial': 0, 193 | 'missed': 0, 194 | 'spurious': 0, 195 | 'possible': 1, 196 | 'actual': 1, 197 | 'precision': 1.0, 198 | 'recall': 1.0 199 | }, 200 | 'exact': { 201 | 'correct': 1, 202 | 'incorrect': 0, 203 | 'partial': 0, 204 | 'missed': 0, 205 | 'spurious': 0, 206 | 'possible': 1, 207 | 'actual': 1, 208 | 'precision': 1.0, 209 | 'recall': 1.0 210 | } 211 | } 212 | 213 | assert results['strict'] == expected['strict'] 214 | assert results['ent_type'] == expected['ent_type'] 215 | assert results['partial'] == expected['partial'] 216 | assert results['exact'] == expected['exact'] 217 | 218 | def 
test_evaluator_no_entities_in_prediction(): 219 | """ 220 | Case when the model predicts no entities at all 221 | """ 222 | 223 | true = [ 224 | ['O', 'O', 'B-PER', 'I-PER', 'O', 'O'], 225 | ] 226 | 227 | pred = [ 228 | ['O', 'O', 'O', 'O', 'O', 'O'], 229 | ] 230 | 231 | evaluator = Evaluator(true, pred, tags=['PER']) 232 | 233 | results, results_agg = evaluator.evaluate() 234 | 235 | expected = { 236 | 'strict': { 237 | 'correct': 0, 238 | 'incorrect': 0, 239 | 'partial': 0, 240 | 'missed': 1, 241 | 'spurious': 0, 242 | 'possible': 1, 243 | 'actual': 0, 244 | 'precision': 0, 245 | 'recall': 0 246 | }, 247 | 'ent_type': { 248 | 'correct': 0, 249 | 'incorrect': 0, 250 | 'partial': 0, 251 | 'missed': 1, 252 | 'spurious': 0, 253 | 'possible': 1, 254 | 'actual': 0, 255 | 'precision': 0, 256 | 'recall': 0 257 | }, 258 | 'partial': { 259 | 'correct': 0, 260 | 'incorrect': 0, 261 | 'partial': 0, 262 | 'missed': 1, 263 | 'spurious': 0, 264 | 'possible': 1, 265 | 'actual': 0, 266 | 'precision': 0, 267 | 'recall': 0 268 | }, 269 | 'exact': { 270 | 'correct': 0, 271 | 'incorrect': 0, 272 | 'partial': 0, 273 | 'missed': 1, 274 | 'spurious': 0, 275 | 'possible': 1, 276 | 'actual': 0, 277 | 'precision': 0, 278 | 'recall': 0 279 | } 280 | } 281 | 282 | assert results['strict'] == expected['strict'] 283 | assert results['ent_type'] == expected['ent_type'] 284 | assert results['partial'] == expected['partial'] 285 | assert results['exact'] == expected['exact'] 286 | 287 | def test_evaluator_compare_results_and_results_agg(): 288 | """ 289 | Check that the label level results match the total results. 290 | """ 291 | 292 | true = [ 293 | ['O', 'O', 'B-PER', 'I-PER', 'O', 'O'], 294 | ] 295 | 296 | pred = [ 297 | ['O', 'O', 'B-PER', 'I-PER', 'O', 'O'], 298 | ] 299 | 300 | evaluator = Evaluator(true, pred, tags=['PER']) 301 | 302 | results, results_agg = evaluator.evaluate() 303 | 304 | expected = { 305 | 'strict': { 306 | 'correct': 1, 307 | 'incorrect': 0, 308 | 'partial': 0, 309 | 'missed': 0, 310 | 'spurious': 0, 311 | 'possible': 1, 312 | 'actual': 1, 313 | 'precision': 1, 314 | 'recall': 1 315 | }, 316 | 'ent_type': { 317 | 'correct': 1, 318 | 'incorrect': 0, 319 | 'partial': 0, 320 | 'missed': 0, 321 | 'spurious': 0, 322 | 'possible': 1, 323 | 'actual': 1, 324 | 'precision': 1, 325 | 'recall': 1 326 | }, 327 | 'partial': { 328 | 'correct': 1, 329 | 'incorrect': 0, 330 | 'partial': 0, 331 | 'missed': 0, 332 | 'spurious': 0, 333 | 'possible': 1, 334 | 'actual': 1, 335 | 'precision': 1, 336 | 'recall': 1 337 | }, 338 | 'exact': { 339 | 'correct': 1, 340 | 'incorrect': 0, 341 | 'partial': 0, 342 | 'missed': 0, 343 | 'spurious': 0, 344 | 'possible': 1, 345 | 'actual': 1, 346 | 'precision': 1, 347 | 'recall': 1 348 | } 349 | } 350 | 351 | expected_agg = { 352 | 'PER': { 353 | 'strict': { 354 | 'correct': 1, 355 | 'incorrect': 0, 356 | 'partial': 0, 357 | 'missed': 0, 358 | 'spurious': 0, 359 | 'possible': 1, 360 | 'actual': 1, 361 | 'precision': 1, 362 | 'recall': 1 363 | }, 364 | 'ent_type': { 365 | 'correct': 1, 366 | 'incorrect': 0, 367 | 'partial': 0, 368 | 'missed': 0, 369 | 'spurious': 0, 370 | 'possible': 1, 371 | 'actual': 1, 372 | 'precision': 1, 373 | 'recall': 1 374 | }, 375 | 'partial': { 376 | 'correct': 1, 377 | 'incorrect': 0, 378 | 'partial': 0, 379 | 'missed': 0, 380 | 'spurious': 0, 381 | 'possible': 1, 382 | 'actual': 1, 383 | 'precision': 1, 384 | 'recall': 1 385 | }, 386 | 'exact': { 387 | 'correct': 1, 388 | 'incorrect': 0, 389 | 'partial': 0, 390 | 'missed': 0, 391 | 
'spurious': 0, 392 | 'possible': 1, 393 | 'actual': 1, 394 | 'precision': 1, 395 | 'recall': 1 396 | } 397 | } 398 | } 399 | 400 | assert results_agg["PER"]["strict"] == expected_agg["PER"]["strict"] 401 | assert results_agg["PER"]["ent_type"] == expected_agg["PER"]["ent_type"] 402 | assert results_agg["PER"]["partial"] == expected_agg["PER"]["partial"] 403 | assert results_agg["PER"]["exact"] == expected_agg["PER"]["exact"] 404 | 405 | assert results['strict'] == expected['strict'] 406 | assert results['ent_type'] == expected['ent_type'] 407 | assert results['partial'] == expected['partial'] 408 | assert results['exact'] == expected['exact'] 409 | 410 | assert results['strict'] == expected_agg['PER']['strict'] 411 | assert results['ent_type'] == expected_agg['PER']['ent_type'] 412 | assert results['partial'] == expected_agg['PER']['partial'] 413 | assert results['exact'] == expected_agg['PER']['exact'] 414 | 415 | def test_evaluator_compare_results_and_results_agg_1(): 416 | """ 417 | Test case when model predicts a label not in the test data. 418 | """ 419 | 420 | true = [ 421 | ['O', 'O', 'O', 'O', 'O', 'O'], 422 | ['O', 'O', 'B-ORG', 'I-ORG', 'O', 'O'], 423 | ['O', 'O', 'B-MISC', 'I-MISC', 'O', 'O'], 424 | ] 425 | 426 | pred = [ 427 | ['O', 'O', 'B-PER', 'I-PER', 'O', 'O'], 428 | ['O', 'O', 'B-ORG', 'I-ORG', 'O', 'O'], 429 | ['O', 'O', 'B-MISC', 'I-MISC', 'O', 'O'], 430 | ] 431 | 432 | evaluator = Evaluator(true, pred, tags=['PER', 'ORG', 'MISC']) 433 | 434 | results, results_agg = evaluator.evaluate() 435 | 436 | expected = { 437 | 'strict': { 438 | 'correct': 2, 439 | 'incorrect': 0, 440 | 'partial': 0, 441 | 'missed': 0, 442 | 'spurious': 1, 443 | 'possible': 2, 444 | 'actual': 3, 445 | 'precision': 0.6666666666666666, 446 | 'recall': 1.0, 447 | }, 448 | 'ent_type': { 449 | 'correct': 2, 450 | 'incorrect': 0, 451 | 'partial': 0, 452 | 'missed': 0, 453 | 'spurious': 1, 454 | 'possible': 2, 455 | 'actual': 3, 456 | 'precision': 0.6666666666666666, 457 | 'recall': 1.0, 458 | }, 459 | 'partial': { 460 | 'correct': 2, 461 | 'incorrect': 0, 462 | 'partial': 0, 463 | 'missed': 0, 464 | 'spurious': 1, 465 | 'possible': 2, 466 | 'actual': 3, 467 | 'precision': 0.6666666666666666, 468 | 'recall': 1.0, 469 | }, 470 | 'exact': { 471 | 'correct': 2, 472 | 'incorrect': 0, 473 | 'partial': 0, 474 | 'missed': 0, 475 | 'spurious': 1, 476 | 'possible': 2, 477 | 'actual': 3, 478 | 'precision': 0.6666666666666666, 479 | 'recall': 1.0, 480 | } 481 | } 482 | 483 | expected_agg = { 484 | 'ORG': { 485 | 'strict': { 486 | 'correct': 1, 487 | 'incorrect': 0, 488 | 'partial': 0, 489 | 'missed': 0, 490 | 'spurious': 1, 491 | 'possible': 1, 492 | 'actual': 2, 493 | 'precision': 0.5, 494 | 'recall': 1 495 | }, 496 | 'ent_type': { 497 | 'correct': 1, 498 | 'incorrect': 0, 499 | 'partial': 0, 500 | 'missed': 0, 501 | 'spurious': 1, 502 | 'possible': 1, 503 | 'actual': 2, 504 | 'precision': 0.5, 505 | 'recall': 1 506 | }, 507 | 'partial': { 508 | 'correct': 1, 509 | 'incorrect': 0, 510 | 'partial': 0, 511 | 'missed': 0, 512 | 'spurious': 1, 513 | 'possible': 1, 514 | 'actual': 2, 515 | 'precision': 0.5, 516 | 'recall': 1 517 | }, 518 | 'exact': { 519 | 'correct': 1, 520 | 'incorrect': 0, 521 | 'partial': 0, 522 | 'missed': 0, 523 | 'spurious': 1, 524 | 'possible': 1, 525 | 'actual': 2, 526 | 'precision': 0.5, 527 | 'recall': 1 528 | } 529 | }, 530 | 'MISC': { 531 | 'strict': { 532 | 'correct': 1, 533 | 'incorrect': 0, 534 | 'partial': 0, 535 | 'missed': 0, 536 | 'spurious': 1, 537 | 'possible': 1, 538 | 'actual': 
2, 539 | 'precision': 0.5, 540 | 'recall': 1 541 | }, 542 | 'ent_type': { 543 | 'correct': 1, 544 | 'incorrect': 0, 545 | 'partial': 0, 546 | 'missed': 0, 547 | 'spurious': 1, 548 | 'possible': 1, 549 | 'actual': 2, 550 | 'precision': 0.5, 551 | 'recall': 1 552 | }, 553 | 'partial': { 554 | 'correct': 1, 555 | 'incorrect': 0, 556 | 'partial': 0, 557 | 'missed': 0, 558 | 'spurious': 1, 559 | 'possible': 1, 560 | 'actual': 2, 561 | 'precision': 0.5, 562 | 'recall': 1 563 | }, 564 | 'exact': { 565 | 'correct': 1, 566 | 'incorrect': 0, 567 | 'partial': 0, 568 | 'missed': 0, 569 | 'spurious': 1, 570 | 'possible': 1, 571 | 'actual': 2, 572 | 'precision': 0.5, 573 | 'recall': 1 574 | } 575 | } 576 | } 577 | 578 | assert results_agg["ORG"]["strict"] == expected_agg["ORG"]["strict"] 579 | assert results_agg["ORG"]["ent_type"] == expected_agg["ORG"]["ent_type"] 580 | assert results_agg["ORG"]["partial"] == expected_agg["ORG"]["partial"] 581 | assert results_agg["ORG"]["exact"] == expected_agg["ORG"]["exact"] 582 | 583 | assert results_agg["MISC"]["strict"] == expected_agg["MISC"]["strict"] 584 | assert results_agg["MISC"]["ent_type"] == expected_agg["MISC"]["ent_type"] 585 | assert results_agg["MISC"]["partial"] == expected_agg["MISC"]["partial"] 586 | assert results_agg["MISC"]["exact"] == expected_agg["MISC"]["exact"] 587 | 588 | assert results['strict'] == expected['strict'] 589 | assert results['ent_type'] == expected['ent_type'] 590 | assert results['partial'] == expected['partial'] 591 | assert results['exact'] == expected['exact'] 592 | 593 | def test_evaluator_wrong_prediction_length(): 594 | 595 | true = [ 596 | ['O', 'B-ORG', 'I-ORG', 'O', 'O'], 597 | ] 598 | 599 | pred = [ 600 | ['O', 'B-MISC', 'I-MISC', 'O'], 601 | ] 602 | 603 | evaluator = Evaluator(true, pred, tags=['PER', 'MISC']) 604 | 605 | with pytest.raises(ValueError): 606 | evaluator.evaluate() 607 | 608 | def test_evaluator_non_matching_corpus_length(): 609 | 610 | true = [ 611 | ['O', 'B-ORG', 'I-ORG', 'O', 'O'], 612 | ['O', 'O', 'O', 'O'] 613 | ] 614 | 615 | pred = [ 616 | ['O', 'B-MISC', 'I-MISC', 'O'], 617 | ] 618 | 619 | with pytest.raises(ValueError): 620 | evaluator = Evaluator(true, pred, tags=['PER', 'MISC']) 621 | 622 | -------------------------------------------------------------------------------- /ner_evaluation/tests/test_ner_evaluation.py: -------------------------------------------------------------------------------- 1 | from ner_evaluation.ner_eval import Entity 2 | from ner_evaluation.ner_eval import compute_metrics 3 | from ner_evaluation.ner_eval import collect_named_entities 4 | from ner_evaluation.ner_eval import find_overlap 5 | from ner_evaluation.ner_eval import compute_actual_possible 6 | from ner_evaluation.ner_eval import compute_precision_recall 7 | from ner_evaluation.ner_eval import compute_precision_recall_wrapper 8 | 9 | 10 | def test_collect_named_entities_same_type_in_sequence(): 11 | tags = ['O', 'B-LOC', 'I-LOC', 'B-LOC', 'I-LOC', 'O'] 12 | result = collect_named_entities(tags) 13 | expected = [Entity(e_type='LOC', start_offset=1, end_offset=2), 14 | Entity(e_type='LOC', start_offset=3, end_offset=4)] 15 | assert result == expected 16 | 17 | 18 | def test_collect_named_entities_entity_goes_until_last_token(): 19 | tags = ['O', 'B-LOC', 'I-LOC', 'B-LOC', 'I-LOC'] 20 | result = collect_named_entities(tags) 21 | expected = [Entity(e_type='LOC', start_offset=1, end_offset=2), 22 | Entity(e_type='LOC', start_offset=3, end_offset=4)] 23 | assert result == expected 24 | 25 | 26 | def 
test_collect_named_entities_sequence_has_only_one_entity(): 27 | tags = ['B-LOC', 'I-LOC', 'I-LOC'] 28 | result = collect_named_entities(tags) 29 | expected = [Entity(e_type='LOC', start_offset=0, end_offset=2)] 30 | assert result == expected 31 | 32 | 33 | def test_collect_named_entities_no_entity(): 34 | tags = ['O', 'O', 'O', 'O', 'O'] 35 | result = collect_named_entities(tags) 36 | expected = [] 37 | assert result == expected 38 | 39 | 40 | def test_compute_metrics_case_1(): 41 | true_named_entities = [ 42 | Entity('PER', 59, 69), 43 | Entity('LOC', 127, 134), 44 | Entity('LOC', 164, 174), 45 | Entity('LOC', 197, 205), 46 | Entity('LOC', 208, 219), 47 | Entity('MISC', 230, 240) 48 | ] 49 | 50 | pred_named_entities = [ 51 | Entity('PER', 24, 30), 52 | Entity('LOC', 124, 134), 53 | Entity('PER', 164, 174), 54 | Entity('LOC', 197, 205), 55 | Entity('LOC', 208, 219), 56 | Entity('LOC', 225, 243) 57 | ] 58 | 59 | results, results_agg = compute_metrics( 60 | true_named_entities, pred_named_entities, ['PER', 'LOC', 'MISC'] 61 | ) 62 | 63 | results = compute_precision_recall_wrapper(results) 64 | 65 | expected = {'strict': {'correct': 2, 66 | 'incorrect': 3, 67 | 'partial': 0, 68 | 'missed': 1, 69 | 'spurious': 1, 70 | 'possible': 6, 71 | 'actual': 6, 72 | 'precision': 0.3333333333333333, 73 | 'recall': 0.3333333333333333}, 74 | 'ent_type': {'correct': 3, 75 | 'incorrect': 2, 76 | 'partial': 0, 77 | 'missed': 1, 78 | 'spurious': 1, 79 | 'possible': 6, 80 | 'actual': 6, 81 | 'precision': 0.5, 82 | 'recall': 0.5}, 83 | 'partial': {'correct': 3, 84 | 'incorrect': 0, 85 | 'partial': 2, 86 | 'missed': 1, 87 | 'spurious': 1, 88 | 'possible': 6, 89 | 'actual': 6, 90 | 'precision': 0.6666666666666666, 91 | 'recall': 0.6666666666666666}, 92 | 'exact': {'correct': 3, 93 | 'incorrect': 2, 94 | 'partial': 0, 95 | 'missed': 1, 96 | 'spurious': 1, 97 | 'possible': 6, 98 | 'actual': 6, 99 | 'precision': 0.5, 100 | 'recall': 0.5} 101 | } 102 | 103 | assert results == expected 104 | 105 | 106 | def test_compute_metrics_agg_scenario_3(): 107 | 108 | true_named_entities = [Entity('PER', 59, 69)] 109 | 110 | pred_named_entities = [] 111 | 112 | results, results_agg = compute_metrics( 113 | true_named_entities, pred_named_entities, ['PER'] 114 | ) 115 | 116 | expected_agg = { 117 | 'PER': { 118 | 'strict': { 119 | 'correct': 0, 120 | 'incorrect': 0, 121 | 'partial': 0, 122 | 'missed': 1, 123 | 'spurious': 0, 124 | 'actual': 0, 125 | 'possible': 1, 126 | 'precision': 0, 127 | 'recall': 0, 128 | }, 129 | 'ent_type': { 130 | 'correct': 0, 131 | 'incorrect': 0, 132 | 'partial': 0, 133 | 'missed': 1, 134 | 'spurious': 0, 135 | 'actual': 0, 136 | 'possible': 1, 137 | 'precision': 0, 138 | 'recall': 0, 139 | }, 140 | 'partial': { 141 | 'correct': 0, 142 | 'incorrect': 0, 143 | 'partial': 0, 144 | 'missed': 1, 145 | 'spurious': 0, 146 | 'actual': 0, 147 | 'possible': 1, 148 | 'precision': 0, 149 | 'recall': 0, 150 | }, 151 | 'exact': { 152 | 'correct': 0, 153 | 'incorrect': 0, 154 | 'partial': 0, 155 | 'missed': 1, 156 | 'spurious': 0, 157 | 'actual': 0, 158 | 'possible': 1, 159 | 'precision': 0, 160 | 'recall': 0, 161 | } 162 | } 163 | } 164 | 165 | assert results_agg['PER']['strict'] == expected_agg['PER']['strict'] 166 | assert results_agg['PER']['ent_type'] == expected_agg['PER']['ent_type'] 167 | assert results_agg['PER']['partial'] == expected_agg['PER']['partial'] 168 | assert results_agg['PER']['exact'] == expected_agg['PER']['exact'] 169 | 170 | 171 | def test_compute_metrics_agg_scenario_2(): 172 | 173 | 
true_named_entities = [] 174 | 175 | pred_named_entities = [Entity('PER', 59, 69)] 176 | 177 | results, results_agg = compute_metrics( 178 | true_named_entities, pred_named_entities, ['PER'] 179 | ) 180 | 181 | expected_agg = { 182 | 'PER': { 183 | 'strict': { 184 | 'correct': 0, 185 | 'incorrect': 0, 186 | 'partial': 0, 187 | 'missed': 0, 188 | 'spurious': 1, 189 | 'actual': 1, 190 | 'possible': 0, 191 | 'precision': 0, 192 | 'recall': 0, 193 | }, 194 | 'ent_type': { 195 | 'correct': 0, 196 | 'incorrect': 0, 197 | 'partial': 0, 198 | 'missed': 0, 199 | 'spurious': 1, 200 | 'actual': 1, 201 | 'possible': 0, 202 | 'precision': 0, 203 | 'recall': 0, 204 | }, 205 | 'partial': { 206 | 'correct': 0, 207 | 'incorrect': 0, 208 | 'partial': 0, 209 | 'missed': 0, 210 | 'spurious': 1, 211 | 'actual': 1, 212 | 'possible': 0, 213 | 'precision': 0, 214 | 'recall': 0, 215 | }, 216 | 'exact': { 217 | 'correct': 0, 218 | 'incorrect': 0, 219 | 'partial': 0, 220 | 'missed': 0, 221 | 'spurious': 1, 222 | 'actual': 1, 223 | 'possible': 0, 224 | 'precision': 0, 225 | 'recall': 0, 226 | } 227 | } 228 | } 229 | 230 | assert results_agg['PER']['strict'] == expected_agg['PER']['strict'] 231 | assert results_agg['PER']['ent_type'] == expected_agg['PER']['ent_type'] 232 | assert results_agg['PER']['partial'] == expected_agg['PER']['partial'] 233 | assert results_agg['PER']['exact'] == expected_agg['PER']['exact'] 234 | 235 | 236 | def test_compute_metrics_agg_scenario_5(): 237 | 238 | true_named_entities = [Entity('PER', 59, 69)] 239 | 240 | pred_named_entities = [Entity('PER', 57, 69)] 241 | 242 | results, results_agg = compute_metrics( 243 | true_named_entities, pred_named_entities, ['PER'] 244 | ) 245 | 246 | expected_agg = { 247 | 'PER': { 248 | 'strict': { 249 | 'correct': 0, 250 | 'incorrect': 1, 251 | 'partial': 0, 252 | 'missed': 0, 253 | 'spurious': 0, 254 | 'actual': 1, 255 | 'possible': 1, 256 | 'precision': 0, 257 | 'recall': 0, 258 | }, 259 | 'ent_type': { 260 | 'correct': 1, 261 | 'incorrect': 0, 262 | 'partial': 0, 263 | 'missed': 0, 264 | 'spurious': 0, 265 | 'actual': 1, 266 | 'possible': 1, 267 | 'precision': 0, 268 | 'recall': 0, 269 | }, 270 | 'partial': { 271 | 'correct': 0, 272 | 'incorrect': 0, 273 | 'partial': 1, 274 | 'missed': 0, 275 | 'spurious': 0, 276 | 'actual': 1, 277 | 'possible': 1, 278 | 'precision': 0, 279 | 'recall': 0, 280 | }, 281 | 'exact': { 282 | 'correct': 0, 283 | 'incorrect': 1, 284 | 'partial': 0, 285 | 'missed': 0, 286 | 'spurious': 0, 287 | 'actual': 1, 288 | 'possible': 1, 289 | 'precision': 0, 290 | 'recall': 0, 291 | } 292 | } 293 | } 294 | 295 | assert results_agg['PER']['strict'] == expected_agg['PER']['strict'] 296 | assert results_agg['PER']['ent_type'] == expected_agg['PER']['ent_type'] 297 | assert results_agg['PER']['partial'] == expected_agg['PER']['partial'] 298 | assert results_agg['PER']['exact'] == expected_agg['PER']['exact'] 299 | 300 | 301 | def test_compute_metrics_agg_scenario_4(): 302 | 303 | true_named_entities = [Entity('PER', 59, 69)] 304 | 305 | pred_named_entities = [Entity('LOC', 59, 69)] 306 | 307 | results, results_agg = compute_metrics( 308 | true_named_entities, pred_named_entities, ['PER', 'LOC'] 309 | ) 310 | 311 | expected_agg = { 312 | 'PER': { 313 | 'strict': { 314 | 'correct': 0, 315 | 'incorrect': 1, 316 | 'partial': 0, 317 | 'missed': 0, 318 | 'spurious': 0, 319 | 'actual': 1, 320 | 'possible': 1, 321 | 'precision': 0, 322 | 'recall': 0, 323 | }, 324 | 'ent_type': { 325 | 'correct': 0, 326 | 'incorrect': 1, 327 | 'partial': 0, 328 
| 'missed': 0, 329 | 'spurious': 0, 330 | 'actual': 1, 331 | 'possible': 1, 332 | 'precision': 0, 333 | 'recall': 0, 334 | }, 335 | 'partial': { 336 | 'correct': 1, 337 | 'incorrect': 0, 338 | 'partial': 0, 339 | 'missed': 0, 340 | 'spurious': 0, 341 | 'actual': 1, 342 | 'possible': 1, 343 | 'precision': 0, 344 | 'recall': 0, 345 | }, 346 | 'exact': { 347 | 'correct': 1, 348 | 'incorrect': 0, 349 | 'partial': 0, 350 | 'missed': 0, 351 | 'spurious': 0, 352 | 'actual': 1, 353 | 'possible': 1, 354 | 'precision': 0, 355 | 'recall': 0, 356 | } 357 | }, 358 | 'LOC': { 359 | 'strict': { 360 | 'correct': 0, 361 | 'incorrect': 0, 362 | 'partial': 0, 363 | 'missed': 0, 364 | 'spurious': 0, 365 | 'actual': 0, 366 | 'possible': 0, 367 | 'precision': 0, 368 | 'recall': 0, 369 | }, 370 | 'ent_type': { 371 | 'correct': 0, 372 | 'incorrect': 0, 373 | 'partial': 0, 374 | 'missed': 0, 375 | 'spurious': 0, 376 | 'actual': 0, 377 | 'possible': 0, 378 | 'precision': 0, 379 | 'recall': 0, 380 | }, 381 | 'partial': { 382 | 'correct': 0, 383 | 'incorrect': 0, 384 | 'partial': 0, 385 | 'missed': 0, 386 | 'spurious': 0, 387 | 'actual': 0, 388 | 'possible': 0, 389 | 'precision': 0, 390 | 'recall': 0, 391 | }, 392 | 'exact': { 393 | 'correct': 0, 394 | 'incorrect': 0, 395 | 'partial': 0, 396 | 'missed': 0, 397 | 'spurious': 0, 398 | 'actual': 0, 399 | 'possible': 0, 400 | 'precision': 0, 401 | 'recall': 0, 402 | } 403 | } 404 | } 405 | 406 | assert results_agg['PER']['strict'] == expected_agg['PER']['strict'] 407 | assert results_agg['PER']['ent_type'] == expected_agg['PER']['ent_type'] 408 | assert results_agg['PER']['partial'] == expected_agg['PER']['partial'] 409 | assert results_agg['PER']['exact'] == expected_agg['PER']['exact'] 410 | 411 | assert results_agg['LOC'] == expected_agg['LOC'] 412 | 413 | 414 | def test_compute_metrics_agg_scenario_1(): 415 | 416 | true_named_entities = [Entity('PER', 59, 69)] 417 | 418 | pred_named_entities = [Entity('PER', 59, 69)] 419 | 420 | results, results_agg = compute_metrics( 421 | true_named_entities, pred_named_entities, ['PER'] 422 | ) 423 | 424 | expected_agg = { 425 | 'PER': { 426 | 'strict': { 427 | 'correct': 1, 428 | 'incorrect': 0, 429 | 'partial': 0, 430 | 'missed': 0, 431 | 'spurious': 0, 432 | 'actual': 1, 433 | 'possible': 1, 434 | 'precision': 0, 435 | 'recall': 0, 436 | }, 437 | 'ent_type': { 438 | 'correct': 1, 439 | 'incorrect': 0, 440 | 'partial': 0, 441 | 'missed': 0, 442 | 'spurious': 0, 443 | 'actual': 1, 444 | 'possible': 1, 445 | 'precision': 0, 446 | 'recall': 0, 447 | }, 448 | 'partial': { 449 | 'correct': 1, 450 | 'incorrect': 0, 451 | 'partial': 0, 452 | 'missed': 0, 453 | 'spurious': 0, 454 | 'actual': 1, 455 | 'possible': 1, 456 | 'precision': 0, 457 | 'recall': 0, 458 | }, 459 | 'exact': { 460 | 'correct': 1, 461 | 'incorrect': 0, 462 | 'partial': 0, 463 | 'missed': 0, 464 | 'spurious': 0, 465 | 'actual': 1, 466 | 'possible': 1, 467 | 'precision': 0, 468 | 'recall': 0, 469 | } 470 | } 471 | } 472 | 473 | assert results_agg['PER']['strict'] == expected_agg['PER']['strict'] 474 | assert results_agg['PER']['ent_type'] == expected_agg['PER']['ent_type'] 475 | assert results_agg['PER']['partial'] == expected_agg['PER']['partial'] 476 | assert results_agg['PER']['exact'] == expected_agg['PER']['exact'] 477 | 478 | 479 | def test_compute_metrics_agg_scenario_6(): 480 | 481 | true_named_entities = [Entity('PER', 59, 69)] 482 | 483 | pred_named_entities = [Entity('LOC', 54, 69)] 484 | 485 | results, results_agg = compute_metrics( 486 | 
true_named_entities, pred_named_entities, ['PER', 'LOC'] 487 | ) 488 | 489 | expected_agg = { 490 | 'PER': { 491 | 'strict': { 492 | 'correct': 0, 493 | 'incorrect': 1, 494 | 'partial': 0, 495 | 'missed': 0, 496 | 'spurious': 0, 497 | 'actual': 1, 498 | 'possible': 1, 499 | 'precision': 0, 500 | 'recall': 0, 501 | }, 502 | 'ent_type': { 503 | 'correct': 0, 504 | 'incorrect': 1, 505 | 'partial': 0, 506 | 'missed': 0, 507 | 'spurious': 0, 508 | 'actual': 1, 509 | 'possible': 1, 510 | 'precision': 0, 511 | 'recall': 0, 512 | }, 513 | 'partial': { 514 | 'correct': 0, 515 | 'incorrect': 0, 516 | 'partial': 1, 517 | 'missed': 0, 518 | 'spurious': 0, 519 | 'actual': 1, 520 | 'possible': 1, 521 | 'precision': 0, 522 | 'recall': 0, 523 | }, 524 | 'exact': { 525 | 'correct': 0, 526 | 'incorrect': 1, 527 | 'partial': 0, 528 | 'missed': 0, 529 | 'spurious': 0, 530 | 'actual': 1, 531 | 'possible': 1, 532 | 'precision': 0, 533 | 'recall': 0, 534 | } 535 | }, 536 | 'LOC': { 537 | 'strict': { 538 | 'correct': 0, 539 | 'incorrect': 0, 540 | 'partial': 0, 541 | 'missed': 0, 542 | 'spurious': 0, 543 | 'actual': 0, 544 | 'possible': 0, 545 | 'precision': 0, 546 | 'recall': 0, 547 | }, 548 | 'ent_type': { 549 | 'correct': 0, 550 | 'incorrect': 0, 551 | 'partial': 0, 552 | 'missed': 0, 553 | 'spurious': 0, 554 | 'actual': 0, 555 | 'possible': 0, 556 | 'precision': 0, 557 | 'recall': 0, 558 | }, 559 | 'partial': { 560 | 'correct': 0, 561 | 'incorrect': 0, 562 | 'partial': 0, 563 | 'missed': 0, 564 | 'spurious': 0, 565 | 'actual': 0, 566 | 'possible': 0, 567 | 'precision': 0, 568 | 'recall': 0, 569 | }, 570 | 'exact': { 571 | 'correct': 0, 572 | 'incorrect': 0, 573 | 'partial': 0, 574 | 'missed': 0, 575 | 'spurious': 0, 576 | 'actual': 0, 577 | 'possible': 0, 578 | 'precision': 0, 579 | 'recall': 0, 580 | } 581 | } 582 | } 583 | 584 | assert results_agg['PER']['strict'] == expected_agg['PER']['strict'] 585 | assert results_agg['PER']['ent_type'] == expected_agg['PER']['ent_type'] 586 | assert results_agg['PER']['partial'] == expected_agg['PER']['partial'] 587 | assert results_agg['PER']['exact'] == expected_agg['PER']['exact'] 588 | 589 | assert results_agg["LOC"] == expected_agg["LOC"] 590 | 591 | 592 | def test_compute_metrics_extra_tags_in_prediction(): 593 | 594 | true_named_entities = [ 595 | Entity('PER', 50, 52), 596 | Entity('ORG', 59, 69), 597 | Entity('ORG', 71, 72), 598 | ] 599 | 600 | pred_named_entities = [ 601 | Entity('LOC', 50, 52), # Wrong type 602 | Entity('ORG', 59, 69), # Correct 603 | Entity('MISC', 71, 72), # Wrong type 604 | ] 605 | 606 | results, results_agg = compute_metrics( 607 | true_named_entities, pred_named_entities, ['PER', 'LOC', 'ORG'] 608 | ) 609 | 610 | expected = { 611 | 'strict': { 612 | 'correct': 1, 613 | 'incorrect': 1, 614 | 'partial': 0, 615 | 'missed': 1, 616 | 'spurious': 0, 617 | 'actual': 2, 618 | 'possible': 3, 619 | 'precision': 0, 620 | 'recall': 0, 621 | }, 622 | 'ent_type': { 623 | 'correct': 1, 624 | 'incorrect': 1, 625 | 'partial': 0, 626 | 'missed': 1, 627 | 'spurious': 0, 628 | 'actual': 2, 629 | 'possible': 3, 630 | 'precision': 0, 631 | 'recall': 0, 632 | }, 633 | 'partial': { 634 | 'correct': 2, 635 | 'incorrect': 0, 636 | 'partial': 0, 637 | 'missed': 1, 638 | 'spurious': 0, 639 | 'actual': 2, 640 | 'possible': 3, 641 | 'precision': 0, 642 | 'recall': 0, 643 | }, 644 | 'exact': { 645 | 'correct': 2, 646 | 'incorrect': 0, 647 | 'partial': 0, 648 | 'missed': 1, 649 | 'spurious': 0, 650 | 'actual': 2, 651 | 'possible': 3, 652 | 'precision': 0, 653 | 
'recall': 0, 654 | } 655 | } 656 | 657 | assert results['strict'] == expected['strict'] 658 | assert results['ent_type'] == expected['ent_type'] 659 | assert results['partial'] == expected['partial'] 660 | assert results['exact'] == expected['exact'] 661 | 662 | 663 | def test_compute_metrics_extra_tags_in_true(): 664 | 665 | true_named_entities = [ 666 | Entity('PER', 50, 52), 667 | Entity('ORG', 59, 69), 668 | Entity('MISC', 71, 72), 669 | ] 670 | 671 | pred_named_entities = [ 672 | Entity('LOC', 50, 52), # Wrong type 673 | Entity('ORG', 59, 69), # Correct 674 | Entity('ORG', 71, 72), # Spurious 675 | ] 676 | 677 | results, results_agg = compute_metrics( 678 | true_named_entities, pred_named_entities, ['PER', 'LOC', 'ORG'] 679 | ) 680 | 681 | expected = { 682 | 'strict': { 683 | 'correct': 1, 684 | 'incorrect': 1, 685 | 'partial': 0, 686 | 'missed': 0, 687 | 'spurious': 1, 688 | 'actual': 3, 689 | 'possible': 2, 690 | 'precision': 0, 691 | 'recall': 0, 692 | }, 693 | 'ent_type': { 694 | 'correct': 1, 695 | 'incorrect': 1, 696 | 'partial': 0, 697 | 'missed': 0, 698 | 'spurious': 1, 699 | 'actual': 3, 700 | 'possible': 2, 701 | 'precision': 0, 702 | 'recall': 0, 703 | }, 704 | 'partial': { 705 | 'correct': 2, 706 | 'incorrect': 0, 707 | 'partial': 0, 708 | 'missed': 0, 709 | 'spurious': 1, 710 | 'actual': 3, 711 | 'possible': 2, 712 | 'precision': 0, 713 | 'recall': 0, 714 | }, 715 | 'exact': { 716 | 'correct': 2, 717 | 'incorrect': 0, 718 | 'partial': 0, 719 | 'missed': 0, 720 | 'spurious': 1, 721 | 'actual': 3, 722 | 'possible': 2, 723 | 'precision': 0, 724 | 'recall': 0, 725 | } 726 | } 727 | 728 | assert results['strict'] == expected['strict'] 729 | assert results['ent_type'] == expected['ent_type'] 730 | assert results['partial'] == expected['partial'] 731 | assert results['exact'] == expected['exact'] 732 | 733 | 734 | def test_compute_metrics_no_predictions(): 735 | 736 | true_named_entities = [ 737 | Entity('PER', 50, 52), 738 | Entity('ORG', 59, 69), 739 | Entity('MISC', 71, 72), 740 | ] 741 | 742 | pred_named_entities = [] 743 | 744 | results, results_agg = compute_metrics( 745 | true_named_entities, pred_named_entities, ['PER', 'ORG', 'MISC'] 746 | ) 747 | 748 | expected = { 749 | 'strict': { 750 | 'correct': 0, 751 | 'incorrect': 0, 752 | 'partial': 0, 753 | 'missed': 3, 754 | 'spurious': 0, 755 | 'actual': 0, 756 | 'possible': 3, 757 | 'precision': 0, 758 | 'recall': 0, 759 | }, 760 | 'ent_type': { 761 | 'correct': 0, 762 | 'incorrect': 0, 763 | 'partial': 0, 764 | 'missed': 3, 765 | 'spurious': 0, 766 | 'actual': 0, 767 | 'possible': 3, 768 | 'precision': 0, 769 | 'recall': 0, 770 | }, 771 | 'partial': { 772 | 'correct': 0, 773 | 'incorrect': 0, 774 | 'partial': 0, 775 | 'missed': 3, 776 | 'spurious': 0, 777 | 'actual': 0, 778 | 'possible': 3, 779 | 'precision': 0, 780 | 'recall': 0, 781 | }, 782 | 'exact': { 783 | 'correct': 0, 784 | 'incorrect': 0, 785 | 'partial': 0, 786 | 'missed': 3, 787 | 'spurious': 0, 788 | 'actual': 0, 789 | 'possible': 3, 790 | 'precision': 0, 791 | 'recall': 0, 792 | } 793 | } 794 | 795 | assert results['strict'] == expected['strict'] 796 | assert results['ent_type'] == expected['ent_type'] 797 | assert results['partial'] == expected['partial'] 798 | assert results['exact'] == expected['exact'] 799 | 800 | def test_find_overlap_no_overlap(): 801 | 802 | pred_entity = Entity('LOC', 1, 10) 803 | true_entity = Entity('LOC', 11, 20) 804 | 805 | pred_range = range(pred_entity.start_offset, pred_entity.end_offset) 806 | true_range = 
range(true_entity.start_offset, true_entity.end_offset) 807 | 808 | pred_set = set(pred_range) 809 | true_set = set(true_range) 810 | 811 | intersect = find_overlap(pred_set, true_set) 812 | 813 | assert not intersect 814 | 815 | 816 | def test_find_overlap_total_overlap(): 817 | 818 | pred_entity = Entity('LOC', 10, 22) 819 | true_entity = Entity('LOC', 11, 20) 820 | 821 | pred_range = range(pred_entity.start_offset, pred_entity.end_offset) 822 | true_range = range(true_entity.start_offset, true_entity.end_offset) 823 | 824 | pred_set = set(pred_range) 825 | true_set = set(true_range) 826 | 827 | intersect = find_overlap(pred_set, true_set) 828 | 829 | assert intersect 830 | 831 | 832 | def test_find_overlap_start_overlap(): 833 | 834 | pred_entity = Entity('LOC', 5, 12) 835 | true_entity = Entity('LOC', 11, 20) 836 | 837 | pred_range = range(pred_entity.start_offset, pred_entity.end_offset) 838 | true_range = range(true_entity.start_offset, true_entity.end_offset) 839 | 840 | pred_set = set(pred_range) 841 | true_set = set(true_range) 842 | 843 | intersect = find_overlap(pred_set, true_set) 844 | 845 | assert intersect 846 | 847 | 848 | def test_find_overlap_end_overlap(): 849 | 850 | pred_entity = Entity('LOC', 15, 25) 851 | true_entity = Entity('LOC', 11, 20) 852 | 853 | pred_range = range(pred_entity.start_offset, pred_entity.end_offset) 854 | true_range = range(true_entity.start_offset, true_entity.end_offset) 855 | 856 | pred_set = set(pred_range) 857 | true_set = set(true_range) 858 | 859 | intersect = find_overlap(pred_set, true_set) 860 | 861 | assert intersect 862 | 863 | 864 | def test_compute_actual_possible(): 865 | 866 | results = { 867 | 'correct': 6, 868 | 'incorrect': 3, 869 | 'partial': 2, 870 | 'missed': 4, 871 | 'spurious': 2, 872 | } 873 | 874 | expected = { 875 | 'correct': 6, 876 | 'incorrect': 3, 877 | 'partial': 2, 878 | 'missed': 4, 879 | 'spurious': 2, 880 | 'possible': 15, 881 | 'actual': 13, 882 | } 883 | 884 | out = compute_actual_possible(results) 885 | 886 | assert out == expected 887 | 888 | 889 | def test_compute_precision_recall(): 890 | 891 | results = { 892 | 'correct': 6, 893 | 'incorrect': 3, 894 | 'partial': 2, 895 | 'missed': 4, 896 | 'spurious': 2, 897 | 'possible': 15, 898 | 'actual': 13, 899 | } 900 | 901 | expected = { 902 | 'correct': 6, 903 | 'incorrect': 3, 904 | 'partial': 2, 905 | 'missed': 4, 906 | 'spurious': 2, 907 | 'possible': 15, 908 | 'actual': 13, 909 | 'precision': 0.46153846153846156, 910 | 'recall': 0.4 911 | } 912 | 913 | out = compute_precision_recall(results) 914 | 915 | assert out == expected 916 | 917 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scikit-learn==1.5.0 2 | nltk 3 | sklearn_crfsuite 4 | pytest 5 | coverage 6 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [coverage:run] 2 | source=ner_evaluation/ 3 | omit = *tests*, *__init__* 4 | 5 | [coverage:report] 6 | show_missing=True 7 | precision=2 8 | sort=Miss 9 | --------------------------------------------------------------------------------
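For reference, a minimal end-to-end sketch of the evaluation flow implemented above — an illustration rather than a file from the repository, assuming the `ner_evaluation` package is importable:

from ner_evaluation.ner_eval import Evaluator

true = [['O', 'B-PER', 'I-PER', 'O'], ['O', 'B-LOC', 'I-LOC', 'O']]
pred = [['O', 'B-PER', 'I-PER', 'O'], ['O', 'B-LOC', 'I-LOC', 'I-LOC']]

# PER matches exactly; the predicted LOC span runs one token too long,
# so it overlaps the true LOC without an exact boundary match (Scenario V).
evaluator = Evaluator(true, pred, tags=['PER', 'LOC'])
results, results_agg = evaluator.evaluate()

# 'strict' requires boundaries and type to match: 1 correct, 1 incorrect
# -> precision 0.5. 'partial' gives the overlapping LOC half credit
# -> precision (1 + 0.5 * 1) / 2 = 0.75.
print(results['strict']['precision'], results['partial']['precision'])
print(results_agg['LOC']['partial'])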
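And a worked check of the score bookkeeping in compute_actual_possible and compute_precision_recall, reusing the counts from test_compute_precision_recall under the same importability assumption:

from ner_evaluation.ner_eval import compute_actual_possible, compute_precision_recall

counts = {'correct': 6, 'incorrect': 3, 'partial': 2, 'missed': 4, 'spurious': 2}
counts = compute_actual_possible(counts)
# possible = correct + incorrect + partial + missed   = 6 + 3 + 2 + 4 = 15
# actual   = correct + incorrect + partial + spurious = 6 + 3 + 2 + 2 = 13
assert counts['possible'] == 15 and counts['actual'] == 13

# strict/exact schema: precision = correct / actual = 6 / 13 ~= 0.4615
exact_style = compute_precision_recall(dict(counts))

# partial/ent_type schema: precision = (correct + 0.5 * partial) / actual
# = 7 / 13 ~= 0.5385, i.e. half credit for each partial match
partial_style = compute_precision_recall(dict(counts), partial_or_type=True)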