├── .gitchangelog.rc ├── .github ├── pull_request_template.md └── workflows │ └── CI-checks.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CHANGELOG.rst ├── CITATION.cff ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── compare_versions.py ├── examples ├── example_no_loader.py └── run_example.sh ├── pyproject.toml ├── src └── nervaluate │ ├── __init__.py │ ├── entities.py │ ├── evaluator.py │ ├── loaders.py │ ├── strategies.py │ └── utils.py └── tests ├── __init__.py ├── test_entities.py ├── test_evaluator.py ├── test_loaders.py ├── test_strategies.py └── test_utils.py /.gitchangelog.rc: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8; mode: python -*- 2 | ## 3 | ## Format 4 | ## 5 | ## ACTION: [AUDIENCE:] COMMIT_MSG [!TAG ...] 6 | ## 7 | ## Description 8 | ## 9 | ## ACTION is one of 'chg', 'fix', 'new' 10 | ## 11 | ## Is WHAT the change is about. 12 | ## 13 | ## 'chg' is for refactor, small improvement, cosmetic changes... 14 | ## 'fix' is for bug fixes 15 | ## 'new' is for new features, big improvement 16 | ## 17 | ## AUDIENCE is optional and one of 'dev', 'usr', 'pkg', 'test', 'doc' 18 | ## 19 | ## Is WHO is concerned by the change. 20 | ## 21 | ## 'dev' is for developpers (API changes, refactors...) 22 | ## 'usr' is for final users (UI changes) 23 | ## 'pkg' is for packagers (packaging changes) 24 | ## 'test' is for testers (test only related changes) 25 | ## 'doc' is for doc guys (doc only changes) 26 | ## 27 | ## COMMIT_MSG is ... well ... the commit message itself. 28 | ## 29 | ## TAGs are additionnal adjective as 'refactor' 'minor' 'cosmetic' 30 | ## 31 | ## They are preceded with a '!' or a '@' (prefer the former, as the 32 | ## latter is wrongly interpreted in github.) Commonly used tags are: 33 | ## 34 | ## 'refactor' is obviously for refactoring code only 35 | ## 'minor' is for a very meaningless change (a typo, adding a comment) 36 | ## 'cosmetic' is for cosmetic driven change (re-indentation, 80-col...) 37 | ## 'wip' is for partial functionality but complete subfunctionality. 38 | ## 39 | ## Example: 40 | ## 41 | ## new: usr: support of bazaar implemented 42 | ## chg: re-indentend some lines !cosmetic 43 | ## new: dev: updated code to be compatible with last version of killer lib. 44 | ## fix: pkg: updated year of licence coverage. 45 | ## new: test: added a bunch of test around user usability of feature X. 46 | ## fix: typo in spelling my name in comment. !minor 47 | ## 48 | ## Please note that multi-line commit message are supported, and only the 49 | ## first line will be considered as the "summary" of the commit message. So 50 | ## tags, and other rules only applies to the summary. The body of the commit 51 | ## message will be displayed in the changelog without reformatting. 52 | 53 | 54 | ## 55 | ## ``ignore_regexps`` is a line of regexps 56 | ## 57 | ## Any commit having its full commit message matching any regexp listed here 58 | ## will be ignored and won't be reported in the changelog. 
59 | ## 60 | ignore_regexps = [ 61 | r'@minor', r'!minor', 62 | r'@cosmetic', r'!cosmetic', 63 | r'@refactor', r'!refactor', 64 | r'@wip', r'!wip', 65 | r'^([cC]hg|[fF]ix|[nN]ew)\s*:\s*[p|P]kg:', 66 | r'^([cC]hg|[fF]ix|[nN]ew)\s*:\s*[d|D]ev:', 67 | r'^(.{3,3}\s*:)?\s*[fF]irst commit.?\s*$', 68 | r'^$', ## ignore commits with empty messages 69 | ] 70 | 71 | 72 | ## ``section_regexps`` is a list of 2-tuples associating a string label and a 73 | ## list of regexp 74 | ## 75 | ## Commit messages will be classified in sections thanks to this. Section 76 | ## titles are the label, and a commit is classified under this section if any 77 | ## of the regexps associated is matching. 78 | ## 79 | ## Please note that ``section_regexps`` will only classify commits and won't 80 | ## make any changes to the contents. So you'll probably want to go check 81 | ## ``subject_process`` (or ``body_process``) to do some changes to the subject, 82 | ## whenever you are tweaking this variable. 83 | ## 84 | section_regexps = [ 85 | ('New', [ 86 | r'^[nN]ew\s*:\s*((dev|use?r|pkg|test|doc)\s*:\s*)?([^\n]*)$', 87 | ]), 88 | ('Changes', [ 89 | r'^[cC]hg\s*:\s*((dev|use?r|pkg|test|doc)\s*:\s*)?([^\n]*)$', 90 | ]), 91 | ('Fix', [ 92 | r'^[fF]ix\s*:\s*((dev|use?r|pkg|test|doc)\s*:\s*)?([^\n]*)$', 93 | ]), 94 | 95 | ('Other', None ## Match all lines 96 | ), 97 | 98 | ] 99 | 100 | 101 | ## ``body_process`` is a callable 102 | ## 103 | ## This callable will be given the original body and result will 104 | ## be used in the changelog. 105 | ## 106 | ## Available constructs are: 107 | ## 108 | ## - any python callable that take one txt argument and return txt argument. 109 | ## 110 | ## - ReSub(pattern, replacement): will apply regexp substitution. 111 | ## 112 | ## - Indent(chars=" "): will indent the text with the prefix 113 | ## Please remember that template engines gets also to modify the text and 114 | ## will usually indent themselves the text if needed. 115 | ## 116 | ## - Wrap(regexp=r"\n\n"): re-wrap text in separate paragraph to fill 80-Columns 117 | ## 118 | ## - noop: do nothing 119 | ## 120 | ## - ucfirst: ensure the first letter is uppercase. 121 | ## (usually used in the ``subject_process`` pipeline) 122 | ## 123 | ## - final_dot: ensure text finishes with a dot 124 | ## (usually used in the ``subject_process`` pipeline) 125 | ## 126 | ## - strip: remove any spaces before or after the content of the string 127 | ## 128 | ## - SetIfEmpty(msg="No commit message."): will set the text to 129 | ## whatever given ``msg`` if the current text is empty. 130 | ## 131 | ## Additionally, you can `pipe` the provided filters, for instance: 132 | #body_process = Wrap(regexp=r'\n(?=\w+\s*:)') | Indent(chars=" ") 133 | #body_process = Wrap(regexp=r'\n(?=\w+\s*:)') 134 | #body_process = noop 135 | body_process = ReSub(r'((^|\n)[A-Z]\w+(-\w+)*: .*(\n\s+.*)*)+$', r'') | strip 136 | 137 | 138 | ## ``subject_process`` is a callable 139 | ## 140 | ## This callable will be given the original subject and result will 141 | ## be used in the changelog. 142 | ## 143 | ## Available constructs are those listed in ``body_process`` doc. 144 | subject_process = (strip | 145 | ReSub(r'^([cC]hg|[fF]ix|[nN]ew)\s*:\s*((dev|use?r|pkg|test|doc)\s*:\s*)?([^\n@]*)(@[a-z]+\s+)*$', r'\4') | 146 | SetIfEmpty("No commit message.") | ucfirst | final_dot) 147 | 148 | 149 | ## ``tag_filter_regexp`` is a regexp 150 | ## 151 | ## Tags that will be used for the changelog must match this regexp. 
152 | ## 153 | tag_filter_regexp = r'^[0-9]+\.[0-9]+(\.[0-9]+)?$' 154 | 155 | 156 | ## ``unreleased_version_label`` is a string or a callable that outputs a string 157 | ## 158 | ## This label will be used as the changelog Title of the last set of changes 159 | ## between last valid tag and HEAD if any. 160 | unreleased_version_label = "(unreleased)" 161 | 162 | 163 | ## ``output_engine`` is a callable 164 | ## 165 | ## This will change the output format of the generated changelog file 166 | ## 167 | ## Available choices are: 168 | ## 169 | ## - rest_py 170 | ## 171 | ## Legacy pure python engine, outputs ReSTructured text. 172 | ## This is the default. 173 | ## 174 | ## - mustache() 175 | ## 176 | ## Template name could be any of the available templates in 177 | ## ``templates/mustache/*.tpl``. 178 | ## Requires python package ``pystache``. 179 | ## Examples: 180 | ## - mustache("markdown") 181 | ## - mustache("restructuredtext") 182 | ## 183 | ## - makotemplate() 184 | ## 185 | ## Template name could be any of the available templates in 186 | ## ``templates/mako/*.tpl``. 187 | ## Requires python package ``mako``. 188 | ## Examples: 189 | ## - makotemplate("restructuredtext") 190 | ## 191 | output_engine = rest_py 192 | #output_engine = mustache("restructuredtext") 193 | #output_engine = mustache("markdown") 194 | #output_engine = makotemplate("restructuredtext") 195 | 196 | 197 | ## ``include_merge`` is a boolean 198 | ## 199 | ## This option tells git-log whether to include merge commits in the log. 200 | ## The default is to include them. 201 | include_merge = True 202 | 203 | 204 | ## ``log_encoding`` is a string identifier 205 | ## 206 | ## This option tells gitchangelog what encoding is outputed by ``git log``. 207 | ## The default is to be clever about it: it checks ``git config`` for 208 | ## ``i18n.logOutputEncoding``, and if not found will default to git's own 209 | ## default: ``utf-8``. 210 | #log_encoding = 'utf-8' 211 | 212 | 213 | ## ``publish`` is a callable 214 | ## 215 | ## Sets what ``gitchangelog`` should do with the output generated by 216 | ## the output engine. ``publish`` is a callable taking one argument 217 | ## that is an interator on lines from the output engine. 218 | ## 219 | ## Some helper callable are provided: 220 | ## 221 | ## Available choices are: 222 | ## 223 | ## - stdout 224 | ## 225 | ## Outputs directly to standard output 226 | ## (This is the default) 227 | ## 228 | ## - FileInsertAtFirstRegexMatch(file, pattern, idx=lamda m: m.start()) 229 | ## 230 | ## Creates a callable that will parse given file for the given 231 | ## regex pattern and will insert the output in the file. 232 | ## ``idx`` is a callable that receive the matching object and 233 | ## must return a integer index point where to insert the 234 | ## the output in the file. Default is to return the position of 235 | ## the start of the matched string. 236 | ## 237 | ## - FileRegexSubst(file, pattern, replace, flags) 238 | ## 239 | ## Apply a replace inplace in the given file. Your regex pattern must 240 | ## take care of everything and might be more complex. Check the README 241 | ## for a complete copy-pastable example. 
242 | ## 243 | # publish = FileInsertIntoFirstRegexMatch( 244 | # "CHANGELOG.rst", 245 | # r'/(?P[0-9]+\.[0-9]+(\.[0-9]+)?)\s+\([0-9]+-[0-9]{2}-[0-9]{2}\)\n--+\n/', 246 | # idx=lambda m: m.start(1) 247 | # ) 248 | #publish = stdout 249 | 250 | 251 | ## ``revs`` is a list of callable or a list of string 252 | ## 253 | ## callable will be called to resolve as strings and allow dynamical 254 | ## computation of these. The result will be used as revisions for 255 | ## gitchangelog (as if directly stated on the command line). This allows 256 | ## to filter exaclty which commits will be read by gitchangelog. 257 | ## 258 | ## To get a full documentation on the format of these strings, please 259 | ## refer to the ``git rev-list`` arguments. There are many examples. 260 | ## 261 | ## Using callables is especially useful, for instance, if you 262 | ## are using gitchangelog to generate incrementally your changelog. 263 | ## 264 | ## Some helpers are provided, you can use them:: 265 | ## 266 | ## - FileFirstRegexMatch(file, pattern): will return a callable that will 267 | ## return the first string match for the given pattern in the given file. 268 | ## If you use named sub-patterns in your regex pattern, it'll output only 269 | ## the string matching the regex pattern named "rev". 270 | ## 271 | ## - Caret(rev): will return the rev prefixed by a "^", which is a 272 | ## way to remove the given revision and all its ancestor. 273 | ## 274 | ## Please note that if you provide a rev-list on the command line, it'll 275 | ## replace this value (which will then be ignored). 276 | ## 277 | ## If empty, then ``gitchangelog`` will act as it had to generate a full 278 | ## changelog. 279 | ## 280 | ## The default is to use all commits to make the changelog. 281 | #revs = ["^1.0.3", ] 282 | #revs = [ 283 | # Caret( 284 | # FileFirstRegexMatch( 285 | # "CHANGELOG.rst", 286 | # r"(?P[0-9]+\.[0-9]+(\.[0-9]+)?)\s+\([0-9]+-[0-9]{2}-[0-9]{2}\)\n--+\n")), 287 | # "HEAD" 288 | #] 289 | revs = [] 290 | 291 | include_merge = False 292 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ### Related Issues 2 | 3 | - fixes #issue-number 4 | 5 | ### Proposed Changes: 6 | 7 | 8 | 9 | 10 | ### How did you test it? 11 | 12 | 13 | 14 | ### Notes for the reviewer 15 | 16 | 17 | 18 | ### Checklist 19 | 20 | - I have read the [contributors guidelines](https://github.com/deepset-ai/haystack/blob/main/CONTRIBUTING.md) and the [code of conduct](https://github.com/deepset-ai/haystack/blob/main/code_of_conduct.txt) 21 | - I have updated the related issue with new insights and changes 22 | - I added unit tests and updated the docstrings 23 | - I've used one of the [conventional commit types](https://www.conventionalcommits.org/en/v1.0.0/) for my PR title: `fix:`, `feat:`, `build:`, `chore:`, `ci:`, `docs:`, `style:`, `refactor:`, `perf:`, `test:` and added `!` in case the PR includes breaking changes. 
24 | - I documented my code 25 | - I ran [pre-commit hooks](https://github.com/deepset-ai/haystack/blob/main/CONTRIBUTING.md#installation) and fixed any issue 26 | -------------------------------------------------------------------------------- /.github/workflows/CI-checks.yml: -------------------------------------------------------------------------------- 1 | name: Linting, Type Checking, and Testing 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | test: 11 | runs-on: ${{ matrix.os }} 12 | strategy: 13 | matrix: 14 | os: [ubuntu-latest, windows-latest, macos-latest] 15 | python-version: ["3.11"] 16 | 17 | steps: 18 | - uses: actions/checkout@v4 19 | 20 | - name: Set up Python ${{ matrix.python-version }} 21 | uses: actions/setup-python@v5 22 | with: 23 | python-version: ${{ matrix.python-version }} 24 | 25 | - name: Install Hatch 26 | run: | 27 | python -m pip install --upgrade pip 28 | pip install hatch 29 | 30 | - name: Running linters 31 | run: | 32 | hatch -e dev run lint 33 | 34 | - name: Type checking with mypy 35 | run: | 36 | hatch -e dev run typing 37 | 38 | - name: Running tests 39 | run: | 40 | hatch -e dev run test -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **.coverage 2 | **.ipynb_checkpoints/ 3 | **.mypy_cache/ 4 | **/.python-version 5 | **__pycache__/ 6 | .tox/ 7 | .venv/ 8 | build/ 9 | coverage.xml 10 | dist/ 11 | nervaluate.egg-info/ 12 | **/.DS_Store 13 | .idea 14 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | 3 | - repo: https://github.com/pre-commit/pre-commit-hooks 4 | rev: v4.1.0 5 | hooks: 6 | - id: check-yaml 7 | 8 | - repo: https://github.com/psf/black 9 | rev: 22.3.0 10 | hooks: 11 | - id: black 12 | args: [-t, py38, -l 120] 13 | 14 | - repo: local 15 | hooks: 16 | - id: pylint 17 | name: pylint 18 | entry: pylint 19 | language: system 20 | types: [ python ] 21 | args: [--rcfile=pylint.cfg] 22 | 23 | - repo: local 24 | hooks: 25 | - id: flake8 26 | name: flake8 27 | entry: flake8 28 | language: system 29 | types: [ python ] 30 | args: [--config=setup.cfg] 31 | 32 | - repo: local 33 | hooks: 34 | - id: mypy 35 | name: mypy 36 | entry: mypy 37 | language: python 38 | language_version: python3.8 39 | types: [python] 40 | exclude: examples|tests 41 | require_serial: true # use require_serial so that script is only called once per commit 42 | verbose: true # print the number of files as a sanity-check 43 | args: [--config, setup.cfg] -------------------------------------------------------------------------------- /CHANGELOG.rst: -------------------------------------------------------------------------------- 1 | Changelog 2 | ========= 3 | 4 | 5 | (unreleased) 6 | ------------ 7 | - Fixing pandas dependency. [David S. Batista] 8 | 9 | 10 | 0.3.0 (2025-06-05) 11 | ------------------ 12 | 13 | Changes 14 | ~~~~~~~ 15 | - Update changelog for 0.2.0 release. [Matthew Upson] 16 | 17 | Fix 18 | ~~~ 19 | - Mypy configuration error. [angelo-digian] 20 | - Typo in type annotation. [angelo-digian] 21 | - Switched order of imports. [angelo-digian] 22 | 23 | Other 24 | ~~~~~ 25 | - 0.3.0 release. [David S. Batista] 26 | - Adding deprecation warnings. [David S. Batista] 27 | - Create pull_request_template.md. [David S. 
Batista] 28 | - Upgrading dev tools versions. [David S. Batista] 29 | - Initial import. [David S. Batista] 30 | - Adding scenario type for summary report. [David S. Batista] 31 | - Update README.md. [David S. Batista] 32 | - Updating README.MD. [David S. Batista] 33 | - Removing unused variable. [David S. Batista] 34 | - Update src/nervaluate/reporting.py. [Copilot, David S. Batista] 35 | - Update src/nervaluate/reporting.py. [Copilot, David S. Batista] 36 | - Removing Makefile. [David S. Batista] 37 | - Drafting CONTRIBUTE.md. [David S. Batista] 38 | - Drafting CONTRIBUTE.md. [David S. Batista] 39 | - Removing flake8. [David S. Batista] 40 | - Removing old config files. [David S. Batista] 41 | - Running on ubuntu, windows and macos. [David S. Batista] 42 | - Reverting to ubuntu only. [David S. Batista] 43 | - Adding new file. [David S. Batista] 44 | - Removing old workflow file. [David S. Batista] 45 | - Adding windows and macos to CI. [David S. Batista] 46 | - Streamlining CI checks. [David S. Batista] 47 | - Disabling old github workflow and triggering new one. [David S. 48 | Batista] 49 | - Changing github workflow. [David S. Batista] 50 | - Fixing linting and typing issues. [David S. Batista] 51 | - Adding pytest-cov as dependency. [David S. Batista] 52 | - Adding hatch as project manager; linting and typing. [David S. 53 | Batista] 54 | - Fixing type hints. [David S. Batista] 55 | - Wip. [David S. Batista] 56 | - Adding docstrings. [David S. Batista] 57 | - Adding more tests. [David S. Batista] 58 | - Adding more tests. [David S. Batista] 59 | - Adding docstrings and increasing test coverage. [David S. Batista] 60 | - Removing requirements_dev.txt. [David S. Batista] 61 | - Blackening for py311. [David S. Batista] 62 | - Fixing pyprojec.toml dependencies. [David S. Batista] 63 | - Fixing pyprojec.toml dependencies. [David S. Batista] 64 | - Fixing pyprojec.toml dependencies. [David S. Batista] 65 | - Fixing pyprojec.toml dependencies. [David S. Batista] 66 | - Fixing pyprojec.toml dependencies. [David S. Batista] 67 | - Refactor: move dev dependencies to pyproject.toml and update CI 68 | workflow. [David S. Batista] 69 | - Adding wrongly removed pre-commit. [David S. Batista] 70 | - Fixing type hints. [David S. Batista] 71 | - Removing unused imports and mutuable default arguments. [David S. 72 | Batista] 73 | - Update README.md. [Tim Miller] 74 | - Update README.md. [adgianv] 75 | - Update README.md - change the pdf link. [adgianv] 76 | - Added type annotations to functions. [angelo-digian] 77 | - Pandas version downgraded to 2.0.1 because incompatible with python 78 | version. [angelo-digian] 79 | - Fixed pandas version to 2.2.1. [angelo-digian] 80 | - Add pandas as a dependency in pyproject.toml. [angelo-digian] 81 | - Adding pandas in the requirements file. [angelo-digian] 82 | - Update tests/test_evaluator.py. [David S. Batista] 83 | - Modified results_to_df method and added test. [angelo-digian] 84 | - Expanded evaluator class: added method to return results of the nested 85 | dictionary as a dataframe. [angelo-digian] 86 | 87 | 88 | 0.2.0 (2024-04-10) 89 | ------------------ 90 | 91 | New 92 | ~~~ 93 | - Add pre-commit. [Matthew Upson] 94 | - Add CITATION.cff file. [Matthew Upson] 95 | - Upload artefacts to codecov. [Matthew Upson] 96 | - Run tests on windows instance. [Matthew Upson] 97 | 98 | Changes 99 | ~~~~~~~ 100 | - Add codecov config. [Matthew Upson] 101 | - Remove .travis.yml. [Matthew Upson] 102 | - Update tox.ini. [Matthew Upson] 103 | - Update versions to test. 
[Matthew Upson] 104 | - Add tox tests as github action. [Matthew Upson] 105 | 106 | Fix 107 | ~~~ 108 | - Grant write permission to CICD workflow. [Matthew Upson] 109 | - Run on windows and linux matrix. [Matthew Upson] 110 | 111 | Other 112 | ~~~~~ 113 | - Updates README to reflect new functionality. [Jack Boylan] 114 | - Removes extra 'indices' printed. [Jack Boylan] 115 | - Bump black from 23.3.0 to 24.3.0. [dependabot[bot]] 116 | 117 | Bumps [black](https://github.com/psf/black) from 23.3.0 to 24.3.0. 118 | - [Release notes](https://github.com/psf/black/releases) 119 | - [Changelog](https://github.com/psf/black/blob/main/CHANGES.md) 120 | - [Commits](https://github.com/psf/black/compare/23.3.0...24.3.0) 121 | 122 | --- 123 | updated-dependencies: 124 | - dependency-name: black 125 | dependency-type: direct:development 126 | ... 127 | - Fixed Typo in README. [Giovanni Casari] 128 | - Reformats quotes in `test_nervaluate.py` [Jack Boylan] 129 | - Initial import. [David S. Batista] 130 | - Handles case when `predictions` is empty. [Jack Boylan] 131 | - Adds unit tests for evaluation indices output. [Jack Boylan] 132 | - Adds summary print functions for overall indices and per-entity 133 | indices results. [Jack Boylan] 134 | - Adds `within_instance_index` to evaluation indices outputs. [Jack 135 | Boylan] 136 | - Ensures compatibility with existing unit tests. [Jack Boylan] 137 | - Adheres to code quality checks. [Jack Boylan] 138 | - Adds more descriptive variable names. [Jack Boylan] 139 | - Adds correct indices to result indices output. [Jack Boylan] 140 | - Moves evaluation indices to separate data structures. [Jack Boylan] 141 | - Adds index lists to output for examples with incorrect, partial, 142 | spurious, and missed entities. [Jack Boylan] 143 | - Docs: fix typo "spurius" > "spurious" [DanShatford] 144 | - Added test for issue #40. [g.casari] 145 | - Solved issue #40. [g.casari] 146 | - Update README.md. [David S. Batista] 147 | - Cleaning README.MD. [David S. Batista] 148 | - Attending PR comments. [David S. Batista] 149 | - Fixing links on README.MD. [David S. Batista] 150 | - Updating pyproject.toml. [David S. Batista] 151 | - Updating pyproject.toml. [David S. Batista] 152 | - Updating README.MD and bumping version to 0.2.0. [David S. Batista] 153 | - Updating README.MD. [David S. Batista] 154 | - Reverting to Python 3.8. [David S. Batista] 155 | - Adding some badges to the README. [David S. Batista] 156 | - Initial commit. [David S. Batista] 157 | - Wip: adding poetry. [David S. Batista] 158 | - Full working example. [David S. Batista] 159 | - Nit. [David S. Batista] 160 | - Wip: adding summary report and examples. [David S. Batista] 161 | - Wip: adding summary report and examples. [David S. Batista] 162 | - Wip: adding summary report and examples. [David S. Batista] 163 | - Wip: adding summary report and examples. [David S. Batista] 164 | - Wip: adding summary report and examples. [David S. Batista] 165 | - Wip: adding summary report. [David S. Batista] 166 | - Wip: adding summary report. [David S. Batista] 167 | - Removed codecov from requirements.txt. [David S. Batista] 168 | - Removing duplicated code and fixing type hit. [David S. Batista] 169 | - Updated Makefile: install package in editable mode. [David S. Batista] 170 | - Updated name. [David S. Batista] 171 | - Minimum version Python 3.8. [David S. Batista] 172 | - Fixing Makefile and pre-commit. [David S. Batista] 173 | - Adding DS_Store and .idea to gitignore. [David S. Batista] 174 | - Updating Makefile. 
[David S. Batista] 175 | - WIP: pre-commit. [David S. Batista] 176 | - WIP: pre-commit. [David S. Batista] 177 | - WIP: pre-commit. [David S. Batista] 178 | - WIP: pre-commit. [David S. Batista] 179 | - WIP: pre-commit. [David S. Batista] 180 | - WIP: pre-commit. [David S. Batista] 181 | - WIP: pre-commit. [David S. Batista] 182 | - WIP: pre-commit. [David S. Batista] 183 | - Fixing types. [David S. Batista] 184 | - Finished adding type hints, some were skipped, code needs refactoring. 185 | [David S. Batista] 186 | - WIP: adding type hints. [David S. Batista] 187 | - WIP: adding type hints. [David S. Batista] 188 | - WIP: adding type hints. [David S. Batista] 189 | - WIP: adding type hints. [David S. Batista] 190 | - Adding some execptions, code needs refactoring. [David S. Batista] 191 | - Fixing pyling and flake8 issues. [David S. Batista] 192 | - Replaced setup.py with pyproject.toml. [David S. Batista] 193 | - Reverting utils import. [David S. Batista] 194 | - Fixing types and wrappint at 120 characters. [David S. Batista] 195 | - Update CITATION.cff. [David S. Batista] 196 | 197 | updating orcid 198 | - Fix recall formula readme. [fgh95] 199 | - Update LICENSE. [ivyleavedtoadflax] 200 | - Update LICENSE. [ivyleavedtoadflax] 201 | - Delete .python-version. [ivyleavedtoadflax] 202 | 203 | 204 | 0.1.8 (2020-10-16) 205 | ------------------ 206 | 207 | New 208 | ~~~ 209 | - Add test for whole span length entities (see #32) [Matthew Upson] 210 | - Summarise blog post in README. [Matthew Upson] 211 | 212 | Changes 213 | ~~~~~~~ 214 | - Bump version in setup.py. [Matthew Upson] 215 | - Update CHANGELOG (#36) [ivyleavedtoadflax] 216 | - Fix tests to match #32. [Matthew Upson] 217 | 218 | Fix 219 | ~~~ 220 | - Correct catch sequence of just one entity. [Matthew Upson] 221 | 222 | Incorporate edits in #28 but includes tests. 223 | 224 | Other 225 | ~~~~~ 226 | - Add code coverage. [ivyleavedtoadflax] 227 | - Crucial fixes for evaluation. [Alex Flückiger] 228 | - Update utils.py. [ivyleavedtoadflax] 229 | 230 | Tiny change to kick off CI 231 | - Fix to catch last entites Small change to catch entities that go up 232 | until last character when there is no tag. [pim] 233 | 234 | 235 | 0.1.7 (2019-12-07) 236 | ------------------ 237 | 238 | New 239 | ~~~ 240 | - Add tests. [Matthew Upson] 241 | 242 | * Linting 243 | * Rename existing tests to disambiguate 244 | - Add loaders to nervaluate. [Matthew Upson] 245 | 246 | * Add list and conll formats 247 | 248 | Changes 249 | ~~~~~~~ 250 | - Update README. [Matthew Upson] 251 | 252 | Fix 253 | ~~~ 254 | - Issue with setup.py. [Matthew Upson] 255 | 256 | * Add docstring to __version__.py 257 | 258 | 259 | 0.1.6 (2019-12-07) 260 | ------------------ 261 | 262 | New 263 | ~~~ 264 | - Add gitchangelog and Makefile recipe. [Matthew Upson] 265 | 266 | Changes 267 | ~~~~~~~ 268 | - Bump version to 0.1.6. [Matthew Upson] 269 | - Remove examples. [Matthew Upson] 270 | 271 | These are not accessible from the package in any case. 272 | - Add dev requirements. [Matthew Upson] 273 | 274 | 275 | 0.1.5 (2019-12-06) 276 | ------------------ 277 | 278 | Changes 279 | ~~~~~~~ 280 | - Bump version to 0.1.5. [Matthew Upson] 281 | - Update setup.py. [Matthew Upson] 282 | - Update package url to point at pypi. [Matthew Upson] 283 | 284 | 285 | 0.1.4 (2019-12-06) 286 | ------------------ 287 | 288 | New 289 | ~~~ 290 | - Add dist to .gitignore. [Matthew Upson] 291 | - Create pypi friendly README/long description. [Matthew Upson] 292 | - Clean entity dicts of extraneous keys. 
[Matthew Upson] 293 | 294 | * Failing to do this can cause problems in evaluations 295 | * Add tests 296 | 297 | Changes 298 | ~~~~~~~ 299 | - Bump version to 0.1.4. [Matthew Upson] 300 | - Make setup.py pypi compliant. [Matthew Upson] 301 | 302 | 303 | 0.1.2 (2019-12-04) 304 | ------------------ 305 | 306 | New 307 | ~~~ 308 | - Add missing prodigy format tests. [Matthew Upson] 309 | - Pass argument when using list. [Matthew Upson] 310 | - Setup module structure. [Matthew Upson] 311 | - Add get_tags() and tests. [Matthew Upson] 312 | 313 | Adds function to extract all the NER tags from a list of sentences. 314 | - Add Evaluator class. [Matthew Upson] 315 | 316 | * Add some logging statements 317 | * Add input checks on number of documents and tokens per document 318 | * Allow target labels to be passed as argument to compute_metrics. Note 319 | that if a label is predicted and it is not in this list, then it 320 | will be classed as spurious for the aggregated scores, and on each 321 | entity level result (because it is unclear where the spurious value 322 | should be applied, it is applied to all) 323 | * linting 324 | * Add many new tests 325 | - Don't evaluate precision and recall for each sentence. [Matthew Upson] 326 | 327 | Rather than automatically calculate precision and recall at the sentence 328 | level, this change adds a new function compute_precision_recall_wrapper 329 | which can be run after all the metrics whether for 1 document, or 1000, 330 | have been calculated. This has the benefit that we can reuse the same 331 | code for calculating precision/recall, and allows us to calculate entity 332 | level precision/recall if required. 333 | - Calculate entity level score. [Matthew Upson] 334 | - Add compute_actual_possible function. [Matthew Upson] 335 | - Record results for each entity type. [Matthew Upson] 336 | - Add scenario comments matching blog table. [Matthew Upson] 337 | - Test results at individual entity level. [Matthew Upson] 338 | - Add .gitinore file. [Matthew Upson] 339 | - Add requirements.txt. [Matthew Upson] 340 | 341 | Changes 342 | ~~~~~~~ 343 | - Bump version to 0.1.2. [Matthew Upson] 344 | - Bump version number to 0.1.1. [Matthew Upson] 345 | - Reduce logging verbosity. [ivyleavedtoadflax] 346 | - Add example to README.md. [Matthew Upson] 347 | - Create virtualenv recipe. [Matthew Upson] 348 | 349 | * Move example dependencies to requirements_example.txt 350 | * Add virtualenv recipe to Makefile 351 | * Update .gitignore 352 | - Remove unused dependencies. [Matthew Upson] 353 | 354 | * Dependencies for the examples should not be included in setup.py, instead 355 | move them to requirements_examples.txt 356 | - Update example notebook. [Matthew Upson] 357 | - Remove unwanted tags from pred_named_entities. [Matthew Upson] 358 | - Remove superfluous get_tags() function. [Matthew Upson] 359 | - Update notebook. [Matthew Upson] 360 | - Update notebook. [Matthew Upson] 361 | - Update tests. [Matthew Upson] 362 | - Update .gitignore. [Matthew Upson] 363 | - Replace spurius with spurious. [Matthew Upson] 364 | - Update README with requirements and test info. [Matthew Upson] 365 | - Update setup.cfg with source and omit paths. [Matthew Upson] 366 | - Use pytest instead of unittest. [Matthew Upson] 367 | 368 | Other 369 | ~~~~~ 370 | - Revert "Remove tox and use pytest" [Matthew Upson] 371 | 372 | * Better to keep tox for local testing in the Makefile and resolve 373 | issues running tox on the developers machine. 
374 | 375 | This reverts commit 8578795e62ca384adf054c1b85a1c1d7f0d089d5. 376 | - Remove tox and use pytest. [Elizabeth Gallagher] 377 | - Add f1 output to nervaluate and update all tests. [Elizabeth 378 | Gallagher] 379 | - Update .travis.yml. [ivyleavedtoadflax] 380 | - Update README.md. [Matt Upson] 381 | - Build(deps): bump nltk from 3.4.4 to 3.4.5. [dependabot[bot]] 382 | 383 | Bumps [nltk](https://github.com/nltk/nltk) from 3.4.4 to 3.4.5. 384 | - [Release notes](https://github.com/nltk/nltk/releases) 385 | - [Changelog](https://github.com/nltk/nltk/blob/develop/ChangeLog) 386 | - [Commits](https://github.com/nltk/nltk/compare/3.4.4...3.4.5) 387 | - Update __version__.py. [Matt Upson] 388 | - PEPed8 things a bit. [David Soares Batista] 389 | - Update README.md. [David S. Batista] 390 | - Update README.md. [David S. Batista] 391 | - Notebook. [David Soares Batista] 392 | - Updated notebook. [David Soares Batista] 393 | - Update README.md. [David S. Batista] 394 | - Update README.md. [David S. Batista] 395 | - Renamed notebook. [David Soares Batista] 396 | - Bug fixing. [David Soares Batista] 397 | - Test. [David Soares Batista] 398 | - Typo in comment. [David Soares Batista] 399 | - Use find_overlap to find all overlap cases. [Matthew Upson] 400 | 401 | Adds the find_overlap function which captures the three possible overlap 402 | scenarios (Total, Start, and End). This is examplained in graph below. 403 | 404 | Character Offset: | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 405 | True: | | | |LOC|LOC|LOC|LOC|LOC| | | 406 | Total Overlap: | | |LOC|LOC|LOC|LOC|LOC|LOC|LOC| | 407 | Start Overlap: | | |LOC|LOC|LOC| | | | | | 408 | End Overlap: | | | | | | |LOC|LOC|LOC| | 409 | - Removed debug stamt. [David Soares Batista] 410 | - Added partial and exact evaluation and tests. [David Soares Batista] 411 | - Update. [David Soares Batista] 412 | - Updated README. [David Soares Batista] 413 | - - fixed bugs and added tests - added pytest. [David Soares Batista] 414 | - Update ner_evaluation.py. [David S. Batista] 415 | - Redefined evaluation according to discussion here: 416 | https://github.com/davidsbatista/NER-Evaluation/issues/2. [David 417 | Soares Batista] 418 | - Fixed a BUG in collect_named_entites() issued by 419 | rjlotok.dblma@gmail.com. [David Soares Batista] 420 | - Update README.md. [David S. Batista] 421 | - Update README.md. [David S. Batista] 422 | - Major refactoring. [David Soares Batista] 423 | - Create README.md. [David S. Batista] 424 | - Initial import. [David Soares Batista] 425 | - Initial commit. [David S. Batista] 426 | 427 | 428 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 3 | title: "nervaluate" 4 | date-released: 2025-06-08 5 | url: "https://github.com/mantisnlp/nervaluate" 6 | version: 1.0.0 7 | authors: 8 | - family-names: "Batista" 9 | given-names: "David" 10 | orcid: "https://orcid.org/0000-0002-9324-5773" 11 | - family-names: "Upson" 12 | given-names: "Matthew Antony" 13 | orcid: "https://orcid.org/0000-0002-1040-8048" 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to `nervaluate` 2 | 3 | Thank you for your interest in contributing to `nervaluate`! 
This document provides guidelines and instructions for contributing to the project. 4 | 5 | ## Development Setup 6 | 7 | 1. Fork the repository 8 | 2. Clone your fork: 9 | ```bash 10 | git clone https://github.com/your-username/nervaluate.git 11 | cd nervaluate 12 | ``` 13 | 3. Make sure you have hatch installed, then create a virtual environment: 14 | # ToDo 15 | 16 | ## Adding Tests 17 | 18 | `nervaluate` uses pytest for testing. Here are the guidelines for adding tests: 19 | 20 | 1. All new features and bug fixes should include tests 21 | 2. Tests should be placed in the `tests/` directory 22 | 3. Test files should be named `test_*.py` 23 | 4. Test functions should be named `test_*` 24 | 5. Use pytest fixtures when appropriate for test setup and teardown 25 | 6. Run tests locally before submitting a pull request: 26 | ```bash 27 | hatch -e dev run test 28 | ``` 29 | 30 | 31 | ## Changelog Management 32 | 33 | `nervaluate` uses gitchangelog to maintain the CHANGELOG.rst file. Here's how to use it: 34 | 35 | 1. Make your changes in a new branch 36 | 2. Write your commit messages following these conventions: 37 | - Use present tense ("Add feature" not "Added feature") 38 | - Use imperative mood ("Move cursor to..." not "Moves cursor to...") 39 | - Limit the first line to 72 characters or less 40 | - Reference issues and pull requests liberally after the first line 41 | 42 | 3. The commit message format should be: 43 | ``` 44 | type(scope): subject 45 | 46 | body 47 | ``` 48 | 49 | Where type can be: 50 | - feat: A new feature 51 | - fix: A bug fix 52 | - docs: Documentation changes 53 | - style: Changes that do not affect the meaning of the code 54 | - refactor: A code change that neither fixes a bug nor adds a feature 55 | - perf: A code change that improves performance 56 | - test: Adding missing tests or correcting existing tests 57 | - chore: Changes to the build process or auxiliary tools 58 | 59 | 4. After committing your changes, you can generate the changelog: 60 | ```bash 61 | gitchangelog > CHANGELOG.rst 62 | ``` 63 | 64 | ## Pull Request Process 65 | 66 | 1. Update the README.md with details of changes if needed 67 | 2. Update the CHANGELOG.rst using gitchangelog 68 | 3. The PR will be merged once you have the sign-off of at least one other developer 69 | 4. Make sure all tests pass and there are no linting errors 70 | 71 | ## Code Style 72 | 73 | - Follow PEP 8 guidelines 74 | - Use type hints 75 | 76 | ## Questions? 77 | 78 | Feel free to open an issue if you have any questions about contributing to `nervaluate`. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 David S. Batista and Matthew A. Upson 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![python](https://img.shields.io/badge/Python-3.11-3776AB.svg?style=flat&logo=python&logoColor=white)](https://www.python.org) 2 |   3 | [![Checked with mypy](http://www.mypy-lang.org/static/mypy_badge.svg)](http://mypy-lang.org/) 4 |   5 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 6 |   7 | ![GitHub](https://img.shields.io/github/license/ivyleavedtoadflax/nervaluate) 8 |   9 | ![Pull Requests Welcome](https://img.shields.io/badge/pull%20requests-welcome-brightgreen.svg) 10 |   11 | ![PyPI](https://img.shields.io/pypi/v/nervaluate) 12 | 13 | # nervaluate 14 | 15 | `nervaluate` is a module for evaluating Named Entity Recognition (NER) models, as defined in SemEval 2013 Task 9.1. 16 | 17 | The evaluation metrics output by nervaluate go beyond a simple token/tag based schema, and consider different scenarios 18 | based on whether all the tokens that belong to a named entity were identified, and whether the correct 19 | entity type was assigned. 20 | 21 | This full problem is described in detail in the [original blog](http://www.davidsbatista.net/blog/2018/05/09/Named_Entity_Evaluation/) 22 | post by [David Batista](https://github.com/davidsbatista), and this package extends the code in the [original repository](https://github.com/davidsbatista/NER-Evaluation) 23 | which accompanied the blog post. 24 | 25 | The code draws heavily on the papers: 26 | 27 | * [SemEval-2013 Task 9 : Extraction of Drug-Drug Interactions from Biomedical Texts (DDIExtraction 2013)](https://www.aclweb.org/anthology/S13-2056) 28 | 29 | * [SemEval-2013 Task 9.1 - Evaluation Metrics](https://davidsbatista.net/assets/documents/others/semeval_2013-task-9_1-evaluation-metrics.pdf) 30 | 31 | # Usage example 32 | 33 | ``` 34 | pip install nervaluate 35 | ``` 36 | 37 | A possible input format is lists of NER labels, where each list corresponds to a sentence and each label is a token label. 38 | Initialize the `Evaluator` class with the true labels and predicted labels, and specify the entity types we want to evaluate.
39 | 40 | ```python 41 | from nervaluate.evaluator import Evaluator 42 | 43 | true = [ 44 | ['O', 'B-PER', 'I-PER', 'O', 'O', 'O', 'B-ORG', 'I-ORG'], # "The John Smith who works at Google Inc" 45 | ['O', 'B-LOC', 'B-PER', 'I-PER', 'O', 'O', 'B-DATE'], # "In Paris Marie Curie lived in 1895" 46 | ] 47 | 48 | pred = [ 49 | ['O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'B-ORG', 'I-ORG'], 50 | ['O', 'B-LOC', 'I-LOC', 'B-PER', 'O', 'O', 'B-DATE'], 51 | ] 52 | 53 | evaluator = Evaluator(true, pred, tags=['PER', 'ORG', 'LOC', 'DATE'], loader="list") 54 | ``` 55 | 56 | Print the summary report for the evaluation, which will show the metrics for each entity type and evaluation scenario: 57 | 58 | ```python 59 | 60 | print(evaluator.summary_report()) 61 | 62 | Scenario: all 63 | 64 | correct incorrect partial missed spurious precision recall f1-score 65 | 66 | ent_type 5 0 0 0 0 1.00 1.00 1.00 67 | exact 2 3 0 0 0 0.40 0.40 0.40 68 | partial 2 0 3 0 0 0.40 0.40 0.40 69 | strict 2 3 0 0 0 0.40 0.40 0.40 70 | ``` 71 | 72 | or aggregated by entity type under a specific evaluation scenario: 73 | 74 | ```python 75 | print(evaluator.summary_report(mode='entities')) 76 | 77 | Scenario: strict 78 | 79 | correct incorrect partial missed spurious precision recall f1-score 80 | 81 | DATE 1 0 0 0 0 1.00 1.00 1.00 82 | LOC 0 1 0 0 0 0.00 0.00 0.00 83 | ORG 1 0 0 0 0 1.00 1.00 1.00 84 | PER 0 2 0 0 0 0.00 0.00 0.00 85 | ``` 86 | 87 | # Evaluation Scenarios 88 | 89 | ## Token level evaluation for NER is too simplistic 90 | 91 | When running machine learning models for NER, it is common to report metrics at the individual token level. This may 92 | not be the best approach, as a named entity can be made up of multiple tokens, so a full-entity accuracy would be 93 | desirable. 94 | 95 | When comparing the gold standard annotations with the output of a NER system, different scenarios might occur: 96 | 97 | __I. Surface string and entity type match__ 98 | 99 | | Token | Gold | Prediction | 100 | |-------|-------|------------| 101 | | in | O | O | 102 | | New | B-LOC | B-LOC | 103 | | York | I-LOC | I-LOC | 104 | | . | O | O | 105 | 106 | __II. System hypothesized an incorrect entity__ 107 | 108 | | Token | Gold | Prediction | 109 | |----------|------|------------| 110 | | an | O | O | 111 | | Awful | O | B-ORG | 112 | | Headache | O | I-ORG | 113 | | in | O | O | 114 | 115 | __III. System misses an entity__ 116 | 117 | | Token | Gold | Prediction | 118 | |-------|-------|------------| 119 | | in | O | O | 120 | | Palo | B-LOC | O | 121 | | Alto | I-LOC | O | 122 | | , | O | O | 123 | 124 | Based on these three scenarios, we have a simple classification task that can be measured in terms of true 125 | positives, false positives and false negatives, from which precision, recall and 126 | F1-score can subsequently be computed for each named-entity type. 127 | 128 | However, this simple schema ignores the possibility of partial matches, as well as scenarios where the NER system gets 129 | the named-entity surface string correct but the type wrong. We might also want to evaluate these scenarios 130 | at the full-entity level. 131 | 132 | For example: 133 | 134 | __IV. System identifies the surface string but assigns the wrong entity type__ 135 | 136 | | Token | Gold | Prediction | 137 | |-------|-------|------------| 138 | | I | O | O | 139 | | live | O | O | 140 | | in | O | O | 141 | | Palo | B-LOC | B-ORG | 142 | | Alto | I-LOC | I-ORG | 143 | | , | O | O | 144 | 145 | __V.
System gets the boundaries of the surface string wrong__ 146 | 147 | | Token | Gold | Prediction | 148 | |---------|-------|------------| 149 | | Unless | O | B-PER | 150 | | Karl | B-PER | I-PER | 151 | | Smith | I-PER | I-PER | 152 | | resigns | O | O | 153 | 154 | __VI. System gets the boundaries and entity type wrong__ 155 | 156 | | Token | Gold | Prediction | 157 | |---------|-------|------------| 158 | | Unless | O | B-ORG | 159 | | Karl | B-PER | I-ORG | 160 | | Smith | I-PER | I-ORG | 161 | | resigns | O | O | 162 | 163 | 164 | ## Defining evaluation metrics 165 | 166 | How can we incorporate the scenarios described above into evaluation metrics? See the [original blog](http://www.davidsbatista.net/blog/2018/05/09/Named_Entity_Evaluation/) 167 | for a full explanation; a summary is included here. 168 | 169 | We can define the following five categories of errors: 170 | 171 | | Error type | Explanation | 172 | |-----------------|--------------------------------------------------------------------------| 173 | | Correct (COR) | the system output and the golden annotation are the same | 174 | | Incorrect (INC) | the output of a system and the golden annotation don’t match | 175 | | Partial (PAR) | the system output and the golden annotation overlap, but the match is not exact | 176 | | Missing (MIS) | a golden annotation is not captured by a system | 177 | | Spurious (SPU) | system produces a response which doesn’t exist in the golden annotation | 178 | 179 | These five error categories can be measured under four different evaluation schemas: 180 | 181 | | Evaluation schema | Explanation | 182 | |-------------------|-----------------------------------------------------------------------------------| 183 | | Strict | exact boundary surface string match and entity type | 184 | | Exact | exact boundary match over the surface string, regardless of the type | 185 | | Partial | partial boundary match over the surface string, regardless of the type | 186 | | Type | some overlap between the system tagged entity and the gold annotation is required | 187 | 188 | These five error categories and the four evaluation schemas interact in the following ways: 189 | 190 | | Scenario | Gold entity | Gold string | Pred entity | Pred string | Type | Partial | Exact | Strict | 191 | |----------|-------------|----------------|-------------|---------------------|------|---------|-------|--------| 192 | | III | BRAND | tikosyn | | | MIS | MIS | MIS | MIS | 193 | | II | | | BRAND | healthy | SPU | SPU | SPU | SPU | 194 | | V | DRUG | warfarin | DRUG | of warfarin | COR | PAR | INC | INC | 195 | | IV | DRUG | propranolol | BRAND | propranolol | INC | COR | COR | INC | 196 | | I | DRUG | phenytoin | DRUG | phenytoin | COR | COR | COR | COR | 197 | | VI | GROUP | contraceptives | DRUG | oral contraceptives | INC | PAR | INC | INC | 198 | 199 | Precision, recall and F1-score are then calculated for each evaluation schema. In order to compute them, 200 | two more quantities need to be calculated: 201 | 202 | ``` 203 | POSSIBLE (POS) = COR + INC + PAR + MIS = TP + FN 204 | ACTUAL (ACT) = COR + INC + PAR + SPU = TP + FP 205 | ``` 206 | 207 | We can then compute precision, recall and F1-score where, roughly speaking, precision is the percentage of 208 | named-entities found by the NER system that are correct, and recall is the percentage of the named-entities in the golden annotations 209 | that are retrieved by the NER system.
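To make the arithmetic concrete, here is a minimal, self-contained sketch of how the five counts combine into precision, recall and F1 (the helper name `prf1` and its `partial_weight` argument are illustrative only, not part of the `nervaluate` API); the two weighting schemes it supports are described just below.

```python
def prf1(cor: int, inc: int, par: int, mis: int, spu: int, partial_weight: float = 0.0) -> tuple:
    """Compute precision, recall and F1 from the five error counts.

    Use partial_weight=0.0 for the 'strict' and 'exact' schemas, and
    partial_weight=0.5 for the 'partial' and 'type' schemas.
    """
    possible = cor + inc + par + mis  # POS = TP + FN
    actual = cor + inc + par + spu    # ACT = TP + FP
    weighted_correct = cor + partial_weight * par
    precision = weighted_correct / actual if actual else 0.0
    recall = weighted_correct / possible if possible else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# Example: the 'Partial' column of the summary table further below
# (COR=3, INC=0, PAR=2, MIS=1, SPU=1) gives precision = recall ≈ 0.66.
print(prf1(cor=3, inc=0, par=2, mis=1, spu=1, partial_weight=0.5))
```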
210 | 211 | This is computed in two different ways depending on whether we want an exact match (i.e., strict and exact) or a 212 | partial match (i.e., partial and type) scenario: 213 | 214 | __Exact Match (i.e., strict and exact)__ 215 | ``` 216 | Precision = (COR / ACT) = TP / (TP + FP) 217 | Recall = (COR / POS) = TP / (TP+FN) 218 | ``` 219 | 220 | __Partial Match (i.e., partial and type)__ 221 | ``` 222 | Precision = (COR + 0.5 × PAR) / ACT 223 | Recall = (COR + 0.5 × PAR) / POS 224 | ``` 225 | 226 | __Putting it all together:__ 227 | 228 | | Measure | Type | Partial | Exact | Strict | 229 | |-----------|------|---------|-------|--------| 230 | | Correct | 3 | 3 | 3 | 2 | 231 | | Incorrect | 2 | 0 | 2 | 3 | 232 | | Partial | 0 | 2 | 0 | 0 | 233 | | Missed | 1 | 1 | 1 | 1 | 234 | | Spurious | 1 | 1 | 1 | 1 | 235 | | Precision | 0.5 | 0.66 | 0.5 | 0.33 | 236 | | Recall | 0.5 | 0.66 | 0.5 | 0.33 | 237 | | F1 | 0.5 | 0.66 | 0.5 | 0.33 | 238 | 239 | 240 | ## Notes: 241 | 242 | In scenarios IV and VI the entity type of `true` and `pred` does not match; in both cases we only score against 243 | the true entity, not the predicted one. You can argue that the predicted entity could also be scored as spurious, 244 | but according to the definition of `spurious`: 245 | 246 | * Spurious (SPU): system produces a response which does not exist in the golden annotation; 247 | 248 | In this case an annotation does exist, but with a different entity type, so we count it only as incorrect. 249 | 250 | 251 | ## Contributing to the `nervaluate` package 252 | 253 | ### Extending the package to accept more formats 254 | 255 | The `Evaluator` accepts the following formats: 256 | 257 | * Nested lists containing NER labels 258 | * CoNLL style tab delimited strings 259 | * [prodi.gy](https://prodi.gy) style lists of spans 260 | 261 | Additional formats can easily be added by creating a new loader class in `nervaluate/loaders.py`. The loader class 262 | should inherit from the `DataLoader` base class and implement the `load` method. 263 | 264 | The `load` method should return a list of entity lists, where each entity is represented as a dictionary 265 | with `label`, `start`, and `end` keys. 266 | 267 | The new loader can then be added to the `_setup_loaders` method in the `Evaluator` class, and can be selected with the 268 | `loader` argument when instantiating the `Evaluator` class. 269 | 270 | Here is a list of formats we intend to [include](https://github.com/MantisAI/nervaluate/issues/3). 271 | 272 | ### General Contributing 273 | 274 | Improvements, adding new features and bug fixes are welcome. If you wish to participate in the development of `nervaluate`, 275 | please read the guidelines in the [CONTRIBUTING.md](CONTRIBUTING.md) file. 276 | 277 | --- 278 | 279 | Give a ⭐️ if this project helped you! 280 | -------------------------------------------------------------------------------- /compare_versions.py: -------------------------------------------------------------------------------- 1 | from nervaluate.evaluator import Evaluator as NewEvaluator 2 | from nervaluate import Evaluator as OldEvaluator 3 | from nervaluate.reporting import summary_report_overall_indices, summary_report_ents_indices, summary_report 4 | 5 | def list_to_dict_format(data): 6 | """ 7 | Convert list format data to dictionary format.
8 | 9 | Args: 10 | data: List of lists containing BIO tags 11 | 12 | Returns: 13 | List of lists containing dictionaries with label, start, and end keys 14 | """ 15 | result = [] 16 | for doc in data: 17 | doc_entities = [] 18 | current_entity = None 19 | 20 | for i, tag in enumerate(doc): 21 | if tag.startswith('B-'): 22 | # If we were tracking an entity, add it to the list 23 | if current_entity is not None: 24 | doc_entities.append(current_entity) 25 | # Start tracking a new entity 26 | current_entity = { 27 | 'label': tag[2:], # Remove 'B-' prefix 28 | 'start': i, 29 | 'end': i 30 | } 31 | elif tag.startswith('I-'): 32 | # Continue tracking the current entity 33 | if current_entity is not None: 34 | current_entity['end'] = i 35 | else: # 'O' tag 36 | # If we were tracking an entity, add it to the list 37 | if current_entity is not None: 38 | doc_entities.append(current_entity) 39 | current_entity = None 40 | 41 | # Don't forget to add the last entity if there was one 42 | if current_entity is not None: 43 | doc_entities.append(current_entity) 44 | 45 | result.append(doc_entities) 46 | 47 | return result 48 | 49 | 50 | def generate_synthetic_data(tags, num_samples, min_length=5, max_length=15): 51 | """ 52 | Generate synthetic NER data with ground truth and predictions. 53 | 54 | Args: 55 | tags (list): List of entity tags to use (e.g., ['PER', 'ORG', 'LOC', 'DATE']) 56 | num_samples (int): Number of samples to generate 57 | min_length (int): Minimum sequence length 58 | max_length (int): Maximum sequence length 59 | 60 | Returns: 61 | tuple: (true_sequences, pred_sequences) 62 | """ 63 | import random 64 | 65 | def generate_sequence(length): 66 | sequence = ['O'] * length 67 | # Randomly decide if we'll add an entity 68 | if random.random() < 0.7: # 70% chance to add an entity 69 | # Choose random tag 70 | tag = random.choice(tags) 71 | # Choose random start position 72 | start = random.randint(0, length - 2) 73 | # Choose random length (1 or 2 tokens) 74 | entity_length = random.randint(1, 2) 75 | if start + entity_length <= length: 76 | sequence[start] = f'B-{tag}' 77 | for i in range(1, entity_length): 78 | sequence[start + i] = f'I-{tag}' 79 | return sequence 80 | 81 | def generate_prediction(true_sequence): 82 | pred_sequence = true_sequence.copy() 83 | # Randomly modify some predictions 84 | for i in range(len(pred_sequence)): 85 | if random.random() < 0.2: # 20% chance to modify each token 86 | if pred_sequence[i] == 'O': 87 | # Sometimes predict an entity where there isn't one 88 | if random.random() < 0.3: 89 | tag = random.choice(tags) 90 | pred_sequence[i] = f'B-{tag}' 91 | else: 92 | # Sometimes change the entity type or boundary 93 | if random.random() < 0.3: 94 | tag = random.choice(tags) 95 | if pred_sequence[i].startswith('B-'): 96 | pred_sequence[i] = f'B-{tag}' 97 | elif pred_sequence[i].startswith('I-'): 98 | pred_sequence[i] = f'I-{tag}' 99 | elif random.random() < 0.3: 100 | # Sometimes predict O instead of an entity 101 | pred_sequence[i] = 'O' 102 | return pred_sequence 103 | 104 | true_sequences = [] 105 | pred_sequences = [] 106 | 107 | for _ in range(num_samples): 108 | length = random.randint(min_length, max_length) 109 | true_sequence = generate_sequence(length) 110 | pred_sequence = generate_prediction(true_sequence) 111 | true_sequences.append(true_sequence) 112 | pred_sequences.append(pred_sequence) 113 | 114 | return true_sequences, pred_sequences 115 | 116 | 117 | def overall_report(true, pred): 118 | 119 | new_evaluator = NewEvaluator(true, pred, 
tags=['PER', 'ORG', 'LOC', 'DATE'], loader="list") 120 | print(new_evaluator.summary_report()) 121 | 122 | print("-"*100) 123 | 124 | old_evaluator = OldEvaluator(true, pred, tags=['PER', 'ORG', 'LOC', 'DATE'], loader="list") 125 | results = old_evaluator.evaluate()[0] # Get the first element which contains the overall results 126 | print(summary_report(results)) 127 | 128 | 129 | def entities_report(true, pred): 130 | 131 | new_evaluator = NewEvaluator(true, pred, tags=['PER', 'ORG', 'LOC', 'DATE'], loader="list") 132 | 133 | # entities - strict, exact, partial, ent_type 134 | print(new_evaluator.summary_report(mode="entities", scenario="strict")) 135 | print(new_evaluator.summary_report(mode="entities", scenario="exact")) 136 | print(new_evaluator.summary_report(mode="entities", scenario="partial")) 137 | print(new_evaluator.summary_report(mode="entities", scenario="ent_type")) 138 | 139 | print("-"*100) 140 | 141 | old_evaluator = OldEvaluator(true, pred, tags=['PER', 'ORG', 'LOC', 'DATE'], loader="list") 142 | _, results_agg_entities_type, _, _ = old_evaluator.evaluate() # Get the second element which contains the entity-specific results 143 | print(summary_report(results_agg_entities_type, mode="entities", scenario="strict")) 144 | print(summary_report(results_agg_entities_type, mode="entities", scenario="exact")) 145 | print(summary_report(results_agg_entities_type, mode="entities", scenario="partial")) 146 | print(summary_report(results_agg_entities_type, mode="entities", scenario="ent_type")) 147 | 148 | 149 | def indices_report_overall(true, pred): 150 | 151 | new_evaluator = NewEvaluator(true, pred, tags=['PER', 'LOC', 'DATE'], loader="list") 152 | print(new_evaluator.summary_report_indices(colors=True, mode="overall", scenario="strict")) 153 | print(new_evaluator.summary_report_indices(colors=True, mode="overall", scenario="exact")) 154 | print(new_evaluator.summary_report_indices(colors=True, mode="overall", scenario="partial")) 155 | print(new_evaluator.summary_report_indices(colors=True, mode="overall", scenario="ent_type")) 156 | 157 | old_evaluator = OldEvaluator(true, pred, tags=['LOC', 'PER', 'DATE'], loader="list") 158 | _, _, result_indices, _ = old_evaluator.evaluate() 159 | pred_dict = list_to_dict_format(pred) # convert predictions to dictionary format for reporting 160 | print(summary_report_overall_indices(evaluation_indices=result_indices, error_schema='strict', preds=pred_dict)) 161 | print(summary_report_overall_indices(evaluation_indices=result_indices, error_schema='exact', preds=pred_dict)) 162 | print(summary_report_overall_indices(evaluation_indices=result_indices, error_schema='partial', preds=pred_dict)) 163 | print(summary_report_overall_indices(evaluation_indices=result_indices, error_schema='ent_type', preds=pred_dict)) 164 | 165 | 166 | def indices_report_entities(true, pred): 167 | 168 | new_evaluator = NewEvaluator(true, pred, tags=['PER', 'LOC', 'DATE'], loader="list") 169 | print(new_evaluator.summary_report_indices(colors=True, mode="entities", scenario="strict")) 170 | print(new_evaluator.summary_report_indices(colors=True, mode="entities", scenario="exact")) 171 | print(new_evaluator.summary_report_indices(colors=True, mode="entities", scenario="partial")) 172 | print(new_evaluator.summary_report_indices(colors=True, mode="entities", scenario="ent_type")) 173 | 174 | old_evaluator = OldEvaluator(true, pred, tags=['LOC', 'PER', 'DATE'], loader="list") 175 | _, _, _, result_indices_by_tag = old_evaluator.evaluate() 176 | pred_dict = 
list_to_dict_format(pred) # convert predictions to dictionary format for reporting 177 | print(summary_report_ents_indices(evaluation_agg_indices=result_indices_by_tag, error_schema='strict', preds=pred_dict)) 178 | print(summary_report_ents_indices(evaluation_agg_indices=result_indices_by_tag, error_schema='exact', preds=pred_dict)) 179 | print(summary_report_ents_indices(evaluation_agg_indices=result_indices_by_tag, error_schema='partial', preds=pred_dict)) 180 | print(summary_report_ents_indices(evaluation_agg_indices=result_indices_by_tag, error_schema='ent_type', preds=pred_dict)) 181 | 182 | 183 | if __name__ == "__main__": 184 | tags = ['PER', 'ORG', 'LOC', 'DATE'] 185 | true, pred = generate_synthetic_data(tags, num_samples=10) 186 | 187 | overall_report(true, pred) 188 | print("\n\n" + "="*100 + "\n\n") 189 | entities_report(true, pred) 190 | print("\n\n" + "="*100 + "\n\n") 191 | indices_report_overall(true, pred) 192 | print("\n\n" + "="*100 + "\n\n") 193 | indices_report_entities(true, pred) 194 | print("\n\n" + "="*100 + "\n\n") -------------------------------------------------------------------------------- /examples/example_no_loader.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | import sklearn_crfsuite 3 | from sklearn.metrics import classification_report 4 | 5 | from nervaluate import Evaluator, collect_named_entities, summary_report_ent, summary_report_overall 6 | 7 | 8 | def word2features(sent, i): 9 | word = sent[i][0] 10 | postag = sent[i][1] 11 | 12 | features = { 13 | "bias": 1.0, 14 | "word.lower()": word.lower(), 15 | "word[-3:]": word[-3:], 16 | "word[-2:]": word[-2:], 17 | "word.isupper()": word.isupper(), 18 | "word.istitle()": word.istitle(), 19 | "word.isdigit()": word.isdigit(), 20 | "postag": postag, 21 | "postag[:2]": postag[:2], 22 | } 23 | if i > 0: 24 | word1 = sent[i - 1][0] 25 | postag1 = sent[i - 1][1] 26 | features.update( 27 | { 28 | "-1:word.lower()": word1.lower(), 29 | "-1:word.istitle()": word1.istitle(), 30 | "-1:word.isupper()": word1.isupper(), 31 | "-1:postag": postag1, 32 | "-1:postag[:2]": postag1[:2], 33 | } 34 | ) 35 | else: 36 | features["BOS"] = True 37 | 38 | if i < len(sent) - 1: 39 | word1 = sent[i + 1][0] 40 | postag1 = sent[i + 1][1] 41 | features.update( 42 | { 43 | "+1:word.lower()": word1.lower(), 44 | "+1:word.istitle()": word1.istitle(), 45 | "+1:word.isupper()": word1.isupper(), 46 | "+1:postag": postag1, 47 | "+1:postag[:2]": postag1[:2], 48 | } 49 | ) 50 | else: 51 | features["EOS"] = True 52 | 53 | return features 54 | 55 | 56 | def sent2features(sent): 57 | return [word2features(sent, i) for i in range(len(sent))] 58 | 59 | 60 | def sent2labels(sent): 61 | return [label for token, postag, label in sent] 62 | 63 | 64 | def sent2tokens(sent): 65 | return [token for token, postag, label in sent] 66 | 67 | 68 | def main(): 69 | print("Loading CoNLL 2002 NER Spanish data") 70 | nltk.corpus.conll2002.fileids() 71 | train_sents = list(nltk.corpus.conll2002.iob_sents("esp.train")) 72 | test_sents = list(nltk.corpus.conll2002.iob_sents("esp.testb")) 73 | 74 | x_train = [sent2features(s) for s in train_sents] 75 | y_train = [sent2labels(s) for s in train_sents] 76 | 77 | x_test = [sent2features(s) for s in test_sents] 78 | y_test = [sent2labels(s) for s in test_sents] 79 | 80 | print("Train a CRF on the CoNLL 2002 NER Spanish data") 81 | crf = sklearn_crfsuite.CRF(algorithm="lbfgs", c1=0.1, c2=0.1, max_iterations=10, all_possible_transitions=True) 82 | try: 83 | 
crf.fit(x_train, y_train) 84 | except AttributeError: 85 | pass 86 | 87 | y_pred = crf.predict(x_test) 88 | labels = list(crf.classes_) 89 | labels.remove("O") # remove 'O' label from evaluation 90 | sorted_labels = sorted(labels, key=lambda name: (name[1:], name[0])) # group B- and I- results 91 | y_test_flat = [y for msg in y_test for y in msg] 92 | y_pred_flat = [y for msg in y_pred for y in msg] 93 | print(classification_report(y_test_flat, y_pred_flat, labels=sorted_labels)) 94 | 95 | test_sents_labels = [] 96 | for sentence in test_sents: 97 | sentence = [token[2] for token in sentence] 98 | test_sents_labels.append(sentence) 99 | 100 | pred_collected = [collect_named_entities(msg) for msg in y_pred] 101 | test_collected = [collect_named_entities(msg) for msg in y_test] 102 | 103 | evaluator = Evaluator(test_collected, pred_collected, ["LOC", "MISC", "PER", "ORG"]) 104 | results, results_agg = evaluator.evaluate() 105 | 106 | print("\n\nOverall") 107 | print(summary_report_overall(results)) 108 | print("\n\n'Strict'") 109 | print(summary_report_ent(results_agg, scenario="strict")) 110 | print("\n\n'Ent_Type'") 111 | print(summary_report_ent(results_agg, scenario="ent_type")) 112 | print("\n\n'Partial'") 113 | print(summary_report_ent(results_agg, scenario="partial")) 114 | print("\n\n'Exact'") 115 | print(summary_report_ent(results_agg, scenario="exact")) 116 | 117 | 118 | if __name__ == "__main__": 119 | main() 120 | -------------------------------------------------------------------------------- /examples/run_example.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | pip install nltk 4 | pip install scikit-learn 5 | pip install sklearn_crfsuite 6 | python -m nltk.downloader conll2002 7 | python example_no_loader.py 8 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "setuptools-scm"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "nervaluate" 7 | version = "0.3.1" 8 | authors = [ 9 | { name="David S.
Batista"}, 10 | { name="Matthew Upson"} 11 | ] 12 | description = "NER evaluation considering partial match scoring" 13 | readme = "README.md" 14 | requires-python = ">=3.11" 15 | keywords = ["named-entity-recognition", "ner", "evaluation-metrics", "partial-match-scoring", "nlp"] 16 | license = {text = "MIT License"} 17 | classifiers = [ 18 | "Programming Language :: Python :: 3", 19 | "Operating System :: OS Independent" 20 | ] 21 | 22 | dependencies = [ 23 | "pandas>=2.3.0" 24 | ] 25 | 26 | [project.optional-dependencies] 27 | dev = [ 28 | "black>=25.1.0", 29 | "coverage>=7.8.0", 30 | "gitchangelog", 31 | "mypy>=1.15.0", 32 | "pre-commit==3.3.1", 33 | "pylint>=3.3.7", 34 | "pytest>=8.3.5", 35 | "pytest-cov>=6.1.1", 36 | ] 37 | 38 | [project.urls] 39 | "Homepage" = "https://github.com/MantisAI/nervaluate" 40 | "Bug Tracker" = "https://github.com/MantisAI/nervaluate/issues" 41 | 42 | [tool.pytest.ini_options] 43 | testpaths = ["tests"] 44 | python_files = ["test_*.py"] 45 | addopts = "--cov=nervaluate --cov-report=term-missing" 46 | 47 | [tool.coverage.run] 48 | source = ["nervaluate"] 49 | omit = ["*__init__*"] 50 | 51 | [tool.coverage.report] 52 | show_missing = true 53 | precision = 2 54 | sort = "Miss" 55 | 56 | [tool.black] 57 | line-length = 120 58 | target-version = ["py311"] 59 | 60 | [tool.pylint.messages_control] 61 | disable = [ 62 | "C0111", # missing-docstring 63 | "C0103", # invalid-name 64 | "W0511", # fixme 65 | "W0603", # global-statement 66 | "W1202", # logging-format-interpolation 67 | "W1203", # logging-fstring-interpolation 68 | "E1126", # invalid-sequence-index 69 | "E1137", # invalid-slice-index 70 | "I0011", # bad-option-value 71 | "I0020", # bad-option-value 72 | "R0801", # duplicate-code 73 | "W9020", # bad-option-value 74 | "W0621", # redefined-outer-name 75 | ] 76 | 77 | [tool.pylint.'DESIGN'] 78 | max-args = 38 # Default is 5 79 | max-attributes = 28 # Default is 7 80 | max-branches = 14 # Default is 12 81 | max-locals = 45 # Default is 15 82 | max-module-lines = 2468 # Default is 1000 83 | max-nested-blocks = 9 # Default is 5 84 | max-statements = 206 # Default is 50 85 | min-public-methods = 1 # Allow classes with just one public method 86 | 87 | [tool.pylint.format] 88 | max-line-length = 120 89 | 90 | [tool.pylint.basic] 91 | accept-no-param-doc = true 92 | accept-no-raise-doc = true 93 | accept-no-return-doc = true 94 | accept-no-yields-doc = true 95 | default-docstring-type = "numpy" 96 | 97 | [tool.pylint.master] 98 | load-plugins = ["pylint.extensions.docparams"] 99 | ignore-paths = ["./examples/.*"] 100 | 101 | [tool.mypy] 102 | python_version = "3.11" 103 | ignore_missing_imports = true 104 | disallow_any_unimported = true 105 | disallow_untyped_defs = true 106 | warn_redundant_casts = true 107 | warn_unused_ignores = true 108 | warn_unused_configs = true 109 | 110 | [[tool.mypy.overrides]] 111 | module = "examples.*" 112 | follow_imports = "skip" 113 | 114 | [tool.hatch.envs.dev] 115 | dependencies = [ 116 | "black==24.3.0", 117 | "coverage==7.2.5", 118 | "gitchangelog", 119 | "mypy==1.3.0", 120 | "pre-commit==3.3.1", 121 | "pylint==2.17.4", 122 | "pytest==7.3.1", 123 | "pytest-cov==4.1.0", 124 | ] 125 | 126 | [tool.hatch.envs.dev.scripts] 127 | lint = [ 128 | "black -t py311 -l 120 src tests", 129 | "pylint src tests" 130 | ] 131 | typing = "mypy src" 132 | test = "pytest" 133 | clean = "rm -rf dist src/nervaluate.egg-info .coverage .mypy_cache .pytest_cache" 134 | changelog = "gitchangelog > CHANGELOG.rst" 135 | all = [ 136 | "clean", 137 | "lint", 
138 | "typing", 139 | "test" 140 | ] 141 | -------------------------------------------------------------------------------- /src/nervaluate/__init__.py: -------------------------------------------------------------------------------- 1 | from .evaluator import Evaluator 2 | from .utils import collect_named_entities, conll_to_spans, list_to_spans, split_list 3 | -------------------------------------------------------------------------------- /src/nervaluate/entities.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Tuple 3 | 4 | 5 | @dataclass 6 | class Entity: 7 | """Represents a named entity with its position and label.""" 8 | 9 | label: str 10 | start: int 11 | end: int 12 | 13 | def __eq__(self, other: object) -> bool: 14 | if not isinstance(other, Entity): 15 | return NotImplemented 16 | return self.label == other.label and self.start == other.start and self.end == other.end 17 | 18 | def __hash__(self) -> int: 19 | return hash((self.label, self.start, self.end)) 20 | 21 | 22 | @dataclass 23 | class EvaluationResult: 24 | """Represents the evaluation metrics for a single entity type or overall.""" 25 | 26 | correct: int = 0 27 | incorrect: int = 0 28 | partial: int = 0 29 | missed: int = 0 30 | spurious: int = 0 31 | precision: float = 0.0 32 | recall: float = 0.0 33 | f1: float = 0.0 34 | actual: int = 0 35 | possible: int = 0 36 | 37 | def compute_metrics(self, partial_or_type: bool = False) -> None: 38 | """Compute precision, recall and F1 score.""" 39 | self.actual = self.correct + self.incorrect + self.partial + self.spurious 40 | self.possible = self.correct + self.incorrect + self.partial + self.missed 41 | 42 | if partial_or_type: 43 | precision = (self.correct + 0.5 * self.partial) / self.actual if self.actual > 0 else 0 44 | recall = (self.correct + 0.5 * self.partial) / self.possible if self.possible > 0 else 0 45 | else: 46 | precision = self.correct / self.actual if self.actual > 0 else 0 47 | recall = self.correct / self.possible if self.possible > 0 else 0 48 | 49 | self.precision = precision 50 | self.recall = recall 51 | self.f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0 52 | 53 | 54 | @dataclass 55 | class EvaluationIndices: 56 | """Represents the indices of entities in different evaluation categories.""" 57 | 58 | correct_indices: List[Tuple[int, int]] = None # type: ignore 59 | incorrect_indices: List[Tuple[int, int]] = None # type: ignore 60 | partial_indices: List[Tuple[int, int]] = None # type: ignore 61 | missed_indices: List[Tuple[int, int]] = None # type: ignore 62 | spurious_indices: List[Tuple[int, int]] = None # type: ignore 63 | 64 | def __post_init__(self) -> None: 65 | if self.correct_indices is None: 66 | self.correct_indices = [] 67 | if self.incorrect_indices is None: 68 | self.incorrect_indices = [] 69 | if self.partial_indices is None: 70 | self.partial_indices = [] 71 | if self.missed_indices is None: 72 | self.missed_indices = [] 73 | if self.spurious_indices is None: 74 | self.spurious_indices = [] 75 | -------------------------------------------------------------------------------- /src/nervaluate/evaluator.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Any, Union 2 | import pandas as pd 3 | 4 | from .entities import EvaluationResult, EvaluationIndices 5 | from .strategies import ( 6 | EvaluationStrategy, 7 | StrictEvaluation, 8 | 
PartialEvaluation, 9 | EntityTypeEvaluation, 10 | ExactEvaluation, 11 | ) 12 | from .loaders import DataLoader, ConllLoader, ListLoader, DictLoader 13 | from .entities import Entity 14 | 15 | 16 | class Evaluator: 17 | """Main evaluator class for NER evaluation.""" 18 | 19 | def __init__(self, true: Any, pred: Any, tags: List[str], loader: str = "default") -> None: 20 | """ 21 | Initialize the evaluator. 22 | 23 | Args: 24 | true: True entities in any supported format 25 | pred: Predicted entities in any supported format 26 | tags: List of valid entity tags 27 | loader: Name of the loader to use 28 | """ 29 | self.tags = tags 30 | self._setup_loaders() 31 | self._load_data(true, pred, loader) 32 | self._setup_evaluation_strategies() 33 | 34 | def _setup_loaders(self) -> None: 35 | """Setup available data loaders.""" 36 | self.loaders: Dict[str, DataLoader] = {"conll": ConllLoader(), "list": ListLoader(), "dict": DictLoader()} 37 | 38 | def _setup_evaluation_strategies(self) -> None: 39 | """Setup evaluation strategies.""" 40 | self.strategies: Dict[str, EvaluationStrategy] = { 41 | "strict": StrictEvaluation(), 42 | "partial": PartialEvaluation(), 43 | "ent_type": EntityTypeEvaluation(), 44 | "exact": ExactEvaluation(), 45 | } 46 | 47 | def _load_data(self, true: Any, pred: Any, loader: str) -> None: 48 | """Load the true and predicted data.""" 49 | if loader == "default": 50 | # Try to infer the loader based on input type 51 | if isinstance(true, str): 52 | loader = "conll" 53 | elif isinstance(true, list) and true and isinstance(true[0], list): 54 | if isinstance(true[0][0], dict): 55 | loader = "dict" 56 | else: 57 | loader = "list" 58 | else: 59 | raise ValueError("Could not infer loader from input type") 60 | 61 | if loader not in self.loaders: 62 | raise ValueError(f"Unknown loader: {loader}") 63 | 64 | # For list loader, check document lengths before loading 65 | if loader == "list": 66 | if len(true) != len(pred): 67 | raise ValueError("Number of predicted documents does not equal true") 68 | 69 | # Check that each document has the same length 70 | for i, (true_doc, pred_doc) in enumerate(zip(true, pred)): 71 | if len(true_doc) != len(pred_doc): 72 | raise ValueError(f"Document {i} has different lengths: true={len(true_doc)}, pred={len(pred_doc)}") 73 | 74 | self.true = self.loaders[loader].load(true) 75 | self.pred = self.loaders[loader].load(pred) 76 | 77 | if len(self.true) != len(self.pred): 78 | raise ValueError("Number of predicted documents does not equal true") 79 | 80 | def evaluate(self) -> Dict[str, Any]: 81 | """ 82 | Run the evaluation. 
83 | 84 | Returns: 85 | Dictionary containing evaluation results for each strategy and entity type 86 | """ 87 | results = {} 88 | # Get unique tags that appear in either true or predicted data 89 | used_tags = set() # type: ignore 90 | for doc in self.true: 91 | used_tags.update(e.label for e in doc) 92 | for doc in self.pred: 93 | used_tags.update(e.label for e in doc) 94 | # Only keep tags that are both used and in the allowed tags list 95 | used_tags = used_tags.intersection(set(self.tags)) 96 | 97 | entity_results: Dict[str, Dict[str, EvaluationResult]] = {tag: {} for tag in used_tags} 98 | indices = {} 99 | entity_indices: Dict[str, Dict[str, EvaluationIndices]] = {tag: {} for tag in used_tags} 100 | 101 | # Evaluate each document 102 | for doc_idx, (true_doc, pred_doc) in enumerate(zip(self.true, self.pred)): 103 | # Filter entities by valid tags 104 | true_doc = [e for e in true_doc if e.label in self.tags] 105 | pred_doc = [e for e in pred_doc if e.label in self.tags] 106 | 107 | # Evaluate with each strategy 108 | for strategy_name, strategy in self.strategies.items(): 109 | result, doc_indices = strategy.evaluate(true_doc, pred_doc, self.tags, doc_idx) 110 | 111 | # Update overall results 112 | if strategy_name not in results: 113 | results[strategy_name] = result 114 | indices[strategy_name] = doc_indices 115 | else: 116 | self._merge_results(results[strategy_name], result) 117 | self._merge_indices(indices[strategy_name], doc_indices) 118 | 119 | # Update entity-specific results 120 | for tag in used_tags: 121 | # Filter entities for this specific tag 122 | true_tag_doc = [e for e in true_doc if e.label == tag] 123 | pred_tag_doc = [e for e in pred_doc if e.label == tag] 124 | 125 | # Evaluate only entities of this tag 126 | tag_result, tag_indices = strategy.evaluate(true_tag_doc, pred_tag_doc, [tag], doc_idx) 127 | 128 | if tag not in entity_results: 129 | entity_results[tag] = {} 130 | entity_indices[tag] = {} 131 | if strategy_name not in entity_results[tag]: 132 | entity_results[tag][strategy_name] = tag_result 133 | entity_indices[tag][strategy_name] = tag_indices 134 | else: 135 | self._merge_results(entity_results[tag][strategy_name], tag_result) 136 | self._merge_indices(entity_indices[tag][strategy_name], tag_indices) 137 | 138 | return { 139 | "overall": results, 140 | "entities": entity_results, 141 | "overall_indices": indices, 142 | "entity_indices": entity_indices, 143 | } 144 | 145 | @staticmethod 146 | def _merge_results(target: EvaluationResult, source: EvaluationResult) -> None: 147 | """Merge two evaluation results.""" 148 | target.correct += source.correct 149 | target.incorrect += source.incorrect 150 | target.partial += source.partial 151 | target.missed += source.missed 152 | target.spurious += source.spurious 153 | target.compute_metrics() 154 | 155 | @staticmethod 156 | def _merge_indices(target: EvaluationIndices, source: EvaluationIndices) -> None: 157 | """Merge two evaluation indices.""" 158 | target.correct_indices.extend(source.correct_indices) 159 | target.incorrect_indices.extend(source.incorrect_indices) 160 | target.partial_indices.extend(source.partial_indices) 161 | target.missed_indices.extend(source.missed_indices) 162 | target.spurious_indices.extend(source.spurious_indices) 163 | 164 | def results_to_dataframe(self) -> Any: 165 | """Convert results to a pandas DataFrame.""" 166 | results = self.evaluate() 167 | 168 | # Flatten the results structure 169 | flat_results = {} 170 | for category, category_results in results.items(): 171 | 
for strategy, strategy_results in category_results.items(): 172 | for metric, value in getattr(strategy_results, "__dict__", {}).items(): # nested per-entity results are plain dicts with no __dict__, so they are skipped here 173 | key = f"{category}.{strategy}.{metric}" 174 | flat_results[key] = value 175 | 176 | return pd.DataFrame([flat_results]) 177 | 178 | def summary_report(self, mode: str = "overall", scenario: str = "strict", digits: int = 2) -> str: 179 | """ 180 | Generate a summary report of the evaluation results. 181 | 182 | Args: 183 | mode: Either 'overall' for overall metrics or 'entities' for per-entity metrics. 184 | scenario: The scenario to report on. Only used when mode is 'entities'. 185 | Must be one of: 186 | - 'strict': exact boundary match over the surface string and correct entity type; 187 | - 'exact': exact boundary match over the surface string, regardless of the entity type; 188 | - 'partial': partial boundary match over the surface string, regardless of the entity type; 189 | - 'ent_type': some overlap between the boundaries and the correct entity type; 190 | digits: The number of digits to round the results to. 191 | 192 | Returns: 193 | A string containing the summary report. 194 | 195 | Raises: 196 | ValueError: If the scenario or mode is invalid. 197 | """ 198 | valid_scenarios = {"strict", "ent_type", "partial", "exact"} 199 | valid_modes = {"overall", "entities"} 200 | 201 | if mode not in valid_modes: 202 | raise ValueError(f"Invalid mode: must be one of {valid_modes}") 203 | 204 | if mode == "entities" and scenario not in valid_scenarios: 205 | raise ValueError(f"Invalid scenario: must be one of {valid_scenarios}") 206 | 207 | headers = ["correct", "incorrect", "partial", "missed", "spurious", "precision", "recall", "f1-score"] 208 | rows = [headers] 209 | 210 | results = self.evaluate() 211 | if mode == "overall": 212 | # Process overall results - show all scenarios 213 | results_data = results["overall"] 214 | for eval_schema in sorted(valid_scenarios): # Sort to ensure consistent order 215 | if eval_schema not in results_data: 216 | continue 217 | results_schema = results_data[eval_schema] 218 | rows.append( 219 | [ 220 | eval_schema, 221 | results_schema.correct, 222 | results_schema.incorrect, 223 | results_schema.partial, 224 | results_schema.missed, 225 | results_schema.spurious, 226 | results_schema.precision, 227 | results_schema.recall, 228 | results_schema.f1, 229 | ] 230 | ) 231 | else: 232 | # Process entity-specific results for the specified scenario only 233 | results_data = results["entities"] 234 | target_names = sorted(results_data.keys()) 235 | for ent_type in target_names: 236 | if scenario not in results_data[ent_type]: 237 | continue # Skip if scenario not available for this entity type 238 | 239 | results_ent = results_data[ent_type][scenario] 240 | rows.append( 241 | [ 242 | ent_type, 243 | results_ent.correct, 244 | results_ent.incorrect, 245 | results_ent.partial, 246 | results_ent.missed, 247 | results_ent.spurious, 248 | results_ent.precision, 249 | results_ent.recall, 250 | results_ent.f1, 251 | ] 252 | ) 253 | 254 | # Format the report 255 | name_width = max(len(str(row[0])) for row in rows) 256 | width = max(name_width, digits) 257 | head_fmt = "{:>{width}s} " + " {:>11}" * len(headers) 258 | report = f"Scenario: {scenario if mode == 'entities' else 'all'}\n\n" + head_fmt.format( 259 | "", *headers, width=width 260 | ) 261 | report += "\n\n" 262 | row_fmt = "{:>{width}s} " + " {:>11}" * 5 + " {:>11.{digits}f}" * 3 + "\n" 263 | 264 | for row in rows[1:]: 265 | report += row_fmt.format(*row, width=width, digits=digits) 266 | 267 | return report 268 | 269 |
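    # Illustrative usage sketch (not part of the original module): with two small
    # BIO-tagged documents and the list loader, the reports defined in this class
    # can be produced as follows. Only names defined above are used; the data is
    # made up for illustration.
    #
    #     evaluator = Evaluator(
    #         true=[["O", "B-PER", "I-PER", "O"], ["B-LOC", "O"]],
    #         pred=[["O", "B-PER", "O", "O"], ["B-LOC", "O"]],
    #         tags=["PER", "LOC"],
    #         loader="list",
    #     )
    #     print(evaluator.summary_report(mode="overall"))
    #     print(evaluator.summary_report(mode="entities", scenario="partial"))
    #     print(evaluator.summary_report_indices(mode="entities", scenario="strict"))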
def summary_report_indices( # pylint: disable=too-many-branches 270 | self, mode: str = "overall", scenario: str = "strict", colors: bool = False 271 | ) -> str: 272 | """ 273 | Generate a summary report of the evaluation indices. 274 | 275 | Args: 276 | mode: Either 'overall' for overall metrics or 'entities' for per-entity metrics. 277 | scenario: The scenario to report on. Must be one of: 'strict', 'ent_type', 'partial', 'exact'. 278 | Only used when mode is 'entities'. Defaults to 'strict'. 279 | colors: Whether to use colors in the output. Defaults to False. 280 | 281 | Returns: 282 | A string containing the summary report of indices. 283 | 284 | Raises: 285 | ValueError: If the scenario or mode is invalid. 286 | """ 287 | valid_scenarios = {"strict", "ent_type", "partial", "exact"} 288 | valid_modes = {"overall", "entities"} 289 | 290 | if mode not in valid_modes: 291 | raise ValueError(f"Invalid mode: must be one of {valid_modes}") 292 | 293 | if mode == "entities" and scenario not in valid_scenarios: 294 | raise ValueError(f"Invalid scenario: must be one of {valid_scenarios}") 295 | 296 | # ANSI color codes 297 | COLORS = { 298 | "reset": "\033[0m", 299 | "bold": "\033[1m", 300 | "red": "\033[91m", 301 | "green": "\033[92m", 302 | "yellow": "\033[93m", 303 | "blue": "\033[94m", 304 | "magenta": "\033[95m", 305 | "cyan": "\033[96m", 306 | "white": "\033[97m", 307 | } 308 | 309 | def colorize(text: str, color: str) -> str: 310 | """Helper function to colorize text if colors are enabled.""" 311 | if colors: 312 | return f"{COLORS[color]}{text}{COLORS['reset']}" 313 | return text 314 | 315 | def get_prediction_info(pred: Union[Entity, str]) -> str: 316 | """Helper function to get prediction info based on pred type.""" 317 | if isinstance(pred, Entity): 318 | return f"Label={pred.label}, Start={pred.start}, End={pred.end}" 319 | # String (BIO tag) 320 | return f"Tag={pred}" 321 | 322 | results = self.evaluate() 323 | report = "" 324 | 325 | # Create headers for the table 326 | headers = ["Category", "Instance", "Entity", "Details"] 327 | header_fmt = "{:<20} {:<10} {:<8} {:<25}" 328 | row_fmt = "{:<20} {:<10} {:<8} {:<10}" 329 | 330 | if mode == "overall": 331 | # Get the indices from the overall results 332 | indices_data = results["overall_indices"][scenario] 333 | report += f"\n{colorize('Indices for error schema', 'bold')} '{colorize(scenario, 'cyan')}':\n\n" 334 | report += colorize(header_fmt.format(*headers), "bold") + "\n" 335 | report += colorize("-" * 78, "white") + "\n" 336 | 337 | for category, indices in indices_data.__dict__.items(): 338 | if not category.endswith("_indices"): 339 | continue 340 | category_name = category.replace("_indices", "").replace("_", " ").capitalize() 341 | 342 | # Color mapping for categories 343 | category_colors = { 344 | "Correct": "green", 345 | "Incorrect": "red", 346 | "Partial": "yellow", 347 | "Missed": "magenta", 348 | "Spurious": "blue", 349 | } 350 | 351 | if indices: 352 | for instance_index, entity_index in indices: 353 | if self.pred != [[]]: 354 | pred = self.pred[instance_index][entity_index] 355 | prediction_info = get_prediction_info(pred) 356 | report += ( 357 | row_fmt.format( 358 | colorize(category_name, category_colors.get(category_name, "white")), 359 | f"{instance_index}", 360 | f"{entity_index}", 361 | prediction_info, 362 | ) 363 | + "\n" 364 | ) 365 | else: 366 | report += ( 367 | row_fmt.format( 368 | colorize(category_name, category_colors.get(category_name, "white")), 369 | f"{instance_index}", 370 | 
f"{entity_index}", 371 | "No prediction info", 372 | ) 373 | + "\n" 374 | ) 375 | else: 376 | report += ( 377 | row_fmt.format( 378 | colorize(category_name, category_colors.get(category_name, "white")), "-", "-", "None" 379 | ) 380 | + "\n" 381 | ) 382 | else: 383 | # Get the indices from the entity-specific results 384 | for entity_type, entity_results in results["entity_indices"].items(): 385 | report += f"\n{colorize('Entity Type', 'bold')}: {colorize(entity_type, 'cyan')}\n" 386 | report += f"{colorize('Error Schema', 'bold')}: '{colorize(scenario, 'cyan')}'\n\n" 387 | report += colorize(header_fmt.format(*headers), "bold") + "\n" 388 | report += colorize("-" * 78, "white") + "\n" 389 | 390 | error_data = entity_results[scenario] 391 | for category, indices in error_data.__dict__.items(): 392 | if not category.endswith("_indices"): 393 | continue 394 | category_name = category.replace("_indices", "").replace("_", " ").capitalize() 395 | 396 | # Color mapping for categories 397 | category_colors = { 398 | "Correct": "green", 399 | "Incorrect": "red", 400 | "Partial": "yellow", 401 | "Missed": "magenta", 402 | "Spurious": "blue", 403 | } 404 | 405 | if indices: 406 | for instance_index, entity_index in indices: 407 | if self.pred != [[]]: 408 | pred = self.pred[instance_index][entity_index] 409 | prediction_info = get_prediction_info(pred) 410 | report += ( 411 | row_fmt.format( 412 | colorize(category_name, category_colors.get(category_name, "white")), 413 | f"{instance_index}", 414 | f"{entity_index}", 415 | prediction_info, 416 | ) 417 | + "\n" 418 | ) 419 | else: 420 | report += ( 421 | row_fmt.format( 422 | colorize(category_name, category_colors.get(category_name, "white")), 423 | f"{instance_index}", 424 | f"{entity_index}", 425 | "No prediction info", 426 | ) 427 | + "\n" 428 | ) 429 | else: 430 | report += ( 431 | row_fmt.format( 432 | colorize(category_name, category_colors.get(category_name, "white")), "-", "-", "None" 433 | ) 434 | + "\n" 435 | ) 436 | 437 | return report 438 | -------------------------------------------------------------------------------- /src/nervaluate/loaders.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import List, Dict, Any 3 | 4 | from .entities import Entity 5 | 6 | 7 | class DataLoader(ABC): 8 | """Abstract base class for data loaders.""" 9 | 10 | @abstractmethod 11 | def load(self, data: Any) -> List[List[Entity]]: 12 | """Load data into a list of entity lists.""" 13 | 14 | 15 | class ConllLoader(DataLoader): 16 | """Loader for CoNLL format data.""" 17 | 18 | def load(self, data: str) -> List[List[Entity]]: # pylint: disable=too-many-branches 19 | """Load CoNLL format data into a list of Entity lists.""" 20 | if not isinstance(data, str): 21 | raise ValueError("ConllLoader expects string input") 22 | 23 | if not data: 24 | return [] 25 | 26 | result: List[List[Entity]] = [] 27 | # Strip trailing whitespace and newlines to avoid empty documents 28 | documents = data.rstrip().split("\n\n") 29 | 30 | for doc in documents: 31 | if not doc.strip(): 32 | result.append([]) 33 | continue 34 | 35 | current_doc = [] 36 | start_offset = None 37 | end_offset = None 38 | ent_type = None 39 | has_entities = False 40 | 41 | for offset, line in enumerate(doc.split("\n")): 42 | if not line.strip(): 43 | continue 44 | 45 | parts = line.split("\t") 46 | if len(parts) < 2: 47 | raise ValueError(f"Invalid CoNLL format: line '{line}' does not contain a tab separator") 48 | 
49 | token_tag = parts[1] 50 | 51 | if token_tag == "O": 52 | if ent_type is not None and start_offset is not None: 53 | end_offset = offset - 1 54 | if isinstance(start_offset, int) and isinstance(end_offset, int): 55 | current_doc.append(Entity(label=ent_type, start=start_offset, end=end_offset)) 56 | start_offset = None 57 | end_offset = None 58 | ent_type = None 59 | 60 | elif ent_type is None: 61 | if not (token_tag.startswith("B-") or token_tag.startswith("I-")): 62 | raise ValueError(f"Invalid tag format: {token_tag}") 63 | ent_type = token_tag[2:] # Remove B- or I- prefix 64 | start_offset = offset 65 | has_entities = True 66 | 67 | elif ent_type != token_tag[2:] or (ent_type == token_tag[2:] and token_tag[:1] == "B"): 68 | end_offset = offset - 1 69 | if isinstance(start_offset, int) and isinstance(end_offset, int): 70 | current_doc.append(Entity(label=ent_type, start=start_offset, end=end_offset)) 71 | 72 | # start of a new entity 73 | if not (token_tag.startswith("B-") or token_tag.startswith("I-")): 74 | raise ValueError(f"Invalid tag format: {token_tag}") 75 | ent_type = token_tag[2:] 76 | start_offset = offset 77 | end_offset = None 78 | has_entities = True 79 | 80 | # Catches an entity that goes up until the last token 81 | if ent_type is not None and start_offset is not None and end_offset is None: 82 | if isinstance(start_offset, int): 83 | current_doc.append(Entity(label=ent_type, start=start_offset, end=len(doc.split("\n")) - 1)) 84 | has_entities = True 85 | 86 | result.append(current_doc if has_entities else []) 87 | 88 | return result 89 | 90 | 91 | class ListLoader(DataLoader): 92 | """Loader for list format data.""" 93 | 94 | def load(self, data: List[List[str]]) -> List[List[Entity]]: # pylint: disable=too-many-branches 95 | """Load list format data into a list of entity lists.""" 96 | if not isinstance(data, list): 97 | raise ValueError("ListLoader expects list input") 98 | 99 | if not data: 100 | return [] 101 | 102 | result = [] 103 | 104 | for doc in data: 105 | if not isinstance(doc, list): 106 | raise ValueError("Each document must be a list of tags") 107 | 108 | current_doc = [] 109 | start_offset = None 110 | end_offset = None 111 | ent_type = None 112 | 113 | for offset, token_tag in enumerate(doc): 114 | if not isinstance(token_tag, str): 115 | raise ValueError(f"Invalid tag type: {type(token_tag)}") 116 | 117 | if token_tag == "O": 118 | if ent_type is not None and start_offset is not None: 119 | end_offset = offset - 1 120 | if isinstance(start_offset, int) and isinstance(end_offset, int): 121 | current_doc.append(Entity(label=ent_type, start=start_offset, end=end_offset)) 122 | start_offset = None 123 | end_offset = None 124 | ent_type = None 125 | 126 | elif ent_type is None: 127 | if not (token_tag.startswith("B-") or token_tag.startswith("I-")): 128 | raise ValueError(f"Invalid tag format: {token_tag}") 129 | ent_type = token_tag[2:] # Remove B- or I- prefix 130 | start_offset = offset 131 | 132 | elif ent_type != token_tag[2:] or (ent_type == token_tag[2:] and token_tag[:1] == "B"): 133 | end_offset = offset - 1 134 | if isinstance(start_offset, int) and isinstance(end_offset, int): 135 | current_doc.append(Entity(label=ent_type, start=start_offset, end=end_offset)) 136 | 137 | # start of a new entity 138 | if not (token_tag.startswith("B-") or token_tag.startswith("I-")): 139 | raise ValueError(f"Invalid tag format: {token_tag}") 140 | ent_type = token_tag[2:] 141 | start_offset = offset 142 | end_offset = None 143 | 144 | # Catches an entity that 
goes up until the last token 145 | if ent_type is not None and start_offset is not None and end_offset is None: 146 | if isinstance(start_offset, int): 147 | current_doc.append(Entity(label=ent_type, start=start_offset, end=len(doc) - 1)) 148 | 149 | result.append(current_doc) 150 | 151 | return result 152 | 153 | 154 | class DictLoader(DataLoader): 155 | """Loader for dictionary format data.""" 156 | 157 | def load(self, data: List[List[Dict[str, Any]]]) -> List[List[Entity]]: 158 | """Load dictionary format data into a list of entity lists.""" 159 | if not isinstance(data, list): 160 | raise ValueError("DictLoader expects list input") 161 | 162 | if not data: 163 | return [] 164 | 165 | result = [] 166 | 167 | for doc in data: 168 | if not isinstance(doc, list): 169 | raise ValueError("Each document must be a list of entity dictionaries") 170 | 171 | current_doc = [] 172 | for entity in doc: 173 | if not isinstance(entity, dict): 174 | raise ValueError(f"Invalid entity type: {type(entity)}") 175 | 176 | required_keys = {"label", "start", "end"} 177 | if not all(key in entity for key in required_keys): 178 | raise ValueError(f"Entity missing required keys: {required_keys}") 179 | 180 | if not isinstance(entity["label"], str): 181 | raise ValueError("Entity label must be a string") 182 | 183 | if not isinstance(entity["start"], int) or not isinstance(entity["end"], int): 184 | raise ValueError("Entity start and end must be integers") 185 | 186 | current_doc.append(Entity(label=entity["label"], start=entity["start"], end=entity["end"])) 187 | result.append(current_doc) 188 | 189 | return result 190 | -------------------------------------------------------------------------------- /src/nervaluate/strategies.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import List, Tuple 3 | 4 | from .entities import Entity, EvaluationResult, EvaluationIndices 5 | 6 | 7 | class EvaluationStrategy(ABC): 8 | """Abstract base class for evaluation strategies.""" 9 | 10 | @abstractmethod 11 | def evaluate( 12 | self, true_entities: List[Entity], pred_entities: List[Entity], tags: List[str], instance_index: int = 0 13 | ) -> Tuple[EvaluationResult, EvaluationIndices]: 14 | """Evaluate the predicted entities against the true entities.""" 15 | 16 | 17 | class StrictEvaluation(EvaluationStrategy): 18 | """ 19 | Strict evaluation strategy - entities must match exactly. 20 | 21 | If there's a predicted entity that perfectly matches a true entity and they have the same label, 22 | we mark it as correct. 23 | If there's a predicted entity that doesn't match or overlap any true entity, we mark it as spurious. 24 | If there's a true entity that isn't matched or overlapped by any predicted entity, we mark it as missed. 25 | All other cases are marked as incorrect. 26 | """ 27 | 28 | def evaluate( 29 | self, true_entities: List[Entity], pred_entities: List[Entity], tags: List[str], instance_index: int = 0 30 | ) -> Tuple[EvaluationResult, EvaluationIndices]: 31 | """ 32 | Evaluate the predicted entities against the true entities using strict matching.
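        Example (illustrative sketch, not in the original source; it only uses the Entity
        dataclass and this class as defined above):

        >>> true = [Entity(label="PER", start=0, end=1)]
        >>> pred = [Entity(label="LOC", start=0, end=1)]
        >>> result, _ = StrictEvaluation().evaluate(true, pred, tags=["PER", "LOC"])
        >>> (result.correct, result.incorrect, result.spurious, result.missed)
        (0, 1, 0, 0)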
33 | """ 34 | result = EvaluationResult() 35 | indices = EvaluationIndices() 36 | matched_true = set() 37 | 38 | for pred_idx, pred in enumerate(pred_entities): 39 | found_match = False 40 | found_incorrect = False 41 | 42 | for true_idx, true in enumerate(true_entities): 43 | if true_idx in matched_true: 44 | continue 45 | 46 | # Check for perfect match (same boundaries and label) 47 | if pred.label == true.label and pred.start == true.start and pred.end == true.end: 48 | result.correct += 1 49 | indices.correct_indices.append((instance_index, pred_idx)) 50 | matched_true.add(true_idx) 51 | found_match = True 52 | break 53 | # Check for any overlap 54 | if pred.start <= true.end and pred.end >= true.start: 55 | result.incorrect += 1 56 | indices.incorrect_indices.append((instance_index, pred_idx)) 57 | matched_true.add(true_idx) 58 | found_incorrect = True 59 | break 60 | 61 | if not found_match and not found_incorrect: 62 | result.spurious += 1 63 | indices.spurious_indices.append((instance_index, pred_idx)) 64 | 65 | for true_idx, true in enumerate(true_entities): 66 | if true_idx not in matched_true: 67 | result.missed += 1 68 | indices.missed_indices.append((instance_index, true_idx)) 69 | 70 | result.compute_metrics() 71 | return result, indices 72 | 73 | 74 | class PartialEvaluation(EvaluationStrategy): 75 | """ 76 | Partial evaluation strategy - allows for partial matches. 77 | 78 | If there's a predicted entity that perfectly matches a true entity, we mark it as correct. 79 | If there's a predicted entity that has some minimum overlap with a true entity, we mark it as partial. 80 | If there's a predicted entity that doesn't match any true entity, we mark it as spurious. 81 | If there's a true entity that doesn't match any predicted entity, we mark it as missed. 82 | 83 | There's never entity type/label checking in this strategy, and there's never an entity marked as incorrect. 84 | """ 85 | 86 | def evaluate( 87 | self, true_entities: List[Entity], pred_entities: List[Entity], tags: List[str], instance_index: int = 0 88 | ) -> Tuple[EvaluationResult, EvaluationIndices]: 89 | result = EvaluationResult() 90 | indices = EvaluationIndices() 91 | matched_true = set() 92 | 93 | for pred_idx, pred in enumerate(pred_entities): 94 | found_match = False 95 | 96 | for true_idx, true in enumerate(true_entities): 97 | if true_idx in matched_true: 98 | continue 99 | 100 | # Check for overlap 101 | if pred.start <= true.end and pred.end >= true.start: 102 | if pred.start == true.start and pred.end == true.end: 103 | result.correct += 1 104 | indices.correct_indices.append((instance_index, pred_idx)) 105 | else: 106 | result.partial += 1 107 | indices.partial_indices.append((instance_index, pred_idx)) 108 | matched_true.add(true_idx) 109 | found_match = True 110 | break 111 | 112 | if not found_match: 113 | result.spurious += 1 114 | indices.spurious_indices.append((instance_index, pred_idx)) 115 | 116 | for true_idx, true in enumerate(true_entities): 117 | if true_idx not in matched_true: 118 | result.missed += 1 119 | indices.missed_indices.append((instance_index, true_idx)) 120 | 121 | result.compute_metrics(partial_or_type=True) 122 | return result, indices 123 | 124 | 125 | class EntityTypeEvaluation(EvaluationStrategy): 126 | """ 127 | Entity type evaluation strategy - only checks entity types. 128 | 129 | In this strategy, we check for overlap between the predicted entity and the true entity.
130 | 131 | If there's a predicted entity that perfectly matches or has at least some overlap with a 132 | true entity, and the same label, we mark it as correct. 133 | If there's a predicted entity that has some minimum overlap or perfectly matches but has 134 | the wrong label, we mark it as incorrect. 135 | If there's a predicted entity that doesn't match any true entity, we mark it as spurious. 136 | If there's a true entity that doesn't match any predicted entity, we mark it as missed. 137 | 138 | # ToDo: define a minimum overlap threshold - see: https://github.com/MantisAI/nervaluate/pull/83 139 | """ 140 | 141 | def evaluate( 142 | self, true_entities: List[Entity], pred_entities: List[Entity], tags: List[str], instance_index: int = 0 143 | ) -> Tuple[EvaluationResult, EvaluationIndices]: 144 | result = EvaluationResult() 145 | indices = EvaluationIndices() 146 | matched_true = set() 147 | 148 | for pred_idx, pred in enumerate(pred_entities): 149 | found_match = False 150 | found_incorrect = False 151 | 152 | for true_idx, true in enumerate(true_entities): 153 | if true_idx in matched_true: 154 | continue 155 | 156 | # Check for any overlap (perfect or minimum) 157 | if pred.start <= true.end and pred.end >= true.start: 158 | if pred.label == true.label: 159 | result.correct += 1 160 | indices.correct_indices.append((instance_index, pred_idx)) 161 | matched_true.add(true_idx) 162 | found_match = True 163 | else: 164 | result.incorrect += 1 165 | indices.incorrect_indices.append((instance_index, pred_idx)) 166 | matched_true.add(true_idx) 167 | found_incorrect = True 168 | break 169 | 170 | if not found_match and not found_incorrect: 171 | result.spurious += 1 172 | indices.spurious_indices.append((instance_index, pred_idx)) 173 | 174 | for true_idx, true in enumerate(true_entities): 175 | if true_idx not in matched_true: 176 | result.missed += 1 177 | indices.missed_indices.append((instance_index, true_idx)) 178 | 179 | result.compute_metrics(partial_or_type=True) 180 | return result, indices 181 | 182 | 183 | class ExactEvaluation(EvaluationStrategy): 184 | """ 185 | Exact evaluation strategy - exact boundary match over the surface string, regardless of the type. 186 | 187 | If there's a predicted entity that perfectly matches a true entity, regardless of the label, we mark it as correct. 188 | If there's a predicted entity that has only some minimum overlap with a true entity, we mark it as incorrect. 189 | If there's a predicted entity that doesn't match any true entity, we mark it as spurious. 190 | If there's a true entity that doesn't match any predicted entity, we mark it as missed. 191 | """ 192 | 193 | def evaluate( 194 | self, true_entities: List[Entity], pred_entities: List[Entity], tags: List[str], instance_index: int = 0 195 | ) -> Tuple[EvaluationResult, EvaluationIndices]: 196 | """ 197 | Evaluate the predicted entities against the true entities using exact boundary matching. 198 | Entity type is not considered in the matching.
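        Example (illustrative sketch, not in the original source; note that, unlike strict
        matching, the mismatched label does not matter here):

        >>> true = [Entity(label="PER", start=0, end=1)]
        >>> pred = [Entity(label="LOC", start=0, end=1)]
        >>> result, _ = ExactEvaluation().evaluate(true, pred, tags=["PER", "LOC"])
        >>> (result.correct, result.incorrect, result.spurious, result.missed)
        (1, 0, 0, 0)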
199 | """ 200 | result = EvaluationResult() 201 | indices = EvaluationIndices() 202 | matched_true = set() 203 | 204 | for pred_idx, pred in enumerate(pred_entities): 205 | found_match = False 206 | found_incorrect = False 207 | 208 | for true_idx, true in enumerate(true_entities): 209 | if true_idx in matched_true: 210 | continue 211 | 212 | # Check for exact boundary match (regardless of label) 213 | if pred.start == true.start and pred.end == true.end: 214 | result.correct += 1 215 | indices.correct_indices.append((instance_index, pred_idx)) 216 | matched_true.add(true_idx) 217 | found_match = True 218 | break 219 | # Check for any overlap 220 | if pred.start <= true.end and pred.end >= true.start: 221 | result.incorrect += 1 222 | indices.incorrect_indices.append((instance_index, pred_idx)) 223 | matched_true.add(true_idx) 224 | found_incorrect = True 225 | break 226 | 227 | if not found_match and not found_incorrect: 228 | result.spurious += 1 229 | indices.spurious_indices.append((instance_index, pred_idx)) 230 | 231 | for true_idx, true in enumerate(true_entities): 232 | if true_idx not in matched_true: 233 | result.missed += 1 234 | indices.missed_indices.append((instance_index, true_idx)) 235 | 236 | result.compute_metrics() 237 | return result, indices 238 | -------------------------------------------------------------------------------- /src/nervaluate/utils.py: -------------------------------------------------------------------------------- 1 | def split_list(token: list[str], split_chars: list[str] | None = None) -> list[list[str]]: 2 | """ 3 | Split a list into sublists based on a list of split characters. 4 | 5 | If split_chars is None, the list is split on empty strings. 6 | 7 | :param token: The list to split. 8 | :param split_chars: The characters to split on. 9 | 10 | :returns: 11 | A list of lists. 12 | """ 13 | if split_chars is None: 14 | split_chars = [""] 15 | out = [] 16 | chunk = [] 17 | for i, item in enumerate(token): 18 | if item not in split_chars: 19 | chunk.append(item) 20 | if i + 1 == len(token): 21 | out.append(chunk) 22 | else: 23 | out.append(chunk) 24 | chunk = [] 25 | return out 26 | 27 | 28 | def conll_to_spans(doc: str) -> list[list[dict]]: 29 | """ 30 | Convert a CoNLL-formatted string to a list of spans. 31 | 32 | :param doc: The CoNLL-formatted string. 33 | 34 | :returns: 35 | A list of spans. 36 | """ 37 | out = [] 38 | doc_parts = split_list(doc.split("\n"), split_chars=None) 39 | 40 | for example in doc_parts: 41 | labels = [] 42 | for token in example: 43 | token_parts = token.split("\t") 44 | label = token_parts[1] 45 | labels.append(label) 46 | out.append(labels) 47 | 48 | spans = list_to_spans(out) 49 | 50 | return spans 51 | 52 | 53 | def list_to_spans(doc: list[list[str]]) -> list[list[dict]]: 54 | """ 55 | Convert a list of tags to a list of spans. 56 | 57 | :param doc: The list of tags. 58 | 59 | :returns: 60 | A list of spans. 61 | """ 62 | spans = [collect_named_entities(tokens) for tokens in doc] 63 | return spans 64 | 65 | 66 | def collect_named_entities(tokens: list[str]) -> list[dict]: 67 | """ 68 | Creates a list of entity dictionaries, storing the entity type and the start and end offsets of the entity. 69 | 70 | :param tokens: a list of tags 71 | 72 | :returns: 73 | A list of entity dictionaries, each with 'label', 'start' and 'end' keys.
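    Example (illustrative, not in the original docstring; it follows the doctest style
    already used by find_overlap below):
        >>> collect_named_entities(["O", "B-PER", "I-PER", "O"])
        [{'label': 'PER', 'start': 1, 'end': 2}]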
74 | """ 75 | 76 | named_entities = [] 77 | start_offset = None 78 | end_offset = None 79 | ent_type = None 80 | 81 | for offset, token_tag in enumerate(tokens): 82 | if token_tag == "O": 83 | if ent_type is not None and start_offset is not None: 84 | end_offset = offset - 1 85 | named_entities.append({"label": ent_type, "start": start_offset, "end": end_offset}) 86 | start_offset = None 87 | end_offset = None 88 | ent_type = None 89 | 90 | elif ent_type is None: 91 | ent_type = token_tag[2:] 92 | start_offset = offset 93 | 94 | elif ent_type != token_tag[2:] or (ent_type == token_tag[2:] and token_tag[:1] == "B"): 95 | end_offset = offset - 1 96 | named_entities.append({"label": ent_type, "start": start_offset, "end": end_offset}) 97 | 98 | # start of a new entity 99 | ent_type = token_tag[2:] 100 | start_offset = offset 101 | end_offset = None 102 | 103 | # Catches an entity that goes up until the last token 104 | if ent_type is not None and start_offset is not None and end_offset is None: 105 | named_entities.append({"label": ent_type, "start": start_offset, "end": len(tokens) - 1}) 106 | 107 | return named_entities 108 | 109 | 110 | def find_overlap(true_range: range, pred_range: range) -> set: 111 | """ 112 | Find the overlap between two ranges. 113 | 114 | :param true_range: The true range. 115 | :param pred_range: The predicted range. 116 | 117 | :returns: 118 | A set of overlapping values. 119 | 120 | Examples: 121 | >>> find_overlap(range(1, 3), range(2, 4)) 122 | {2} 123 | >>> find_overlap(range(1, 3), range(3, 5)) 124 | set() 125 | """ 126 | 127 | true_set = set(true_range) 128 | pred_set = set(pred_range) 129 | overlaps = true_set.intersection(pred_set) 130 | 131 | return overlaps 132 | 133 | 134 | def clean_entities(ent: dict) -> dict: 135 | """ 136 | Returns just the useful keys if additional keys are present in the entity 137 | dict. 138 | 139 | This may happen if passing a list of spans directly from prodigy, which 140 | typically may include 'token_start' and 'token_end'. 
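    Example (illustrative, not in the original docstring; 'token_start' stands in for any
    extra key a span dict might carry):
        >>> clean_entities({"label": "PER", "start": 0, "end": 1, "token_start": 0})
        {'start': 0, 'end': 1, 'label': 'PER'}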
141 | """ 142 | return {"start": ent["start"], "end": ent["end"], "label": ent["label"]} 143 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append("../src/nervaluate") 4 | -------------------------------------------------------------------------------- /tests/test_entities.py: -------------------------------------------------------------------------------- 1 | from nervaluate.entities import Entity, EvaluationResult 2 | 3 | 4 | def test_entity_equality(): 5 | """Test Entity equality comparison.""" 6 | entity1 = Entity(label="PER", start=0, end=1) 7 | entity2 = Entity(label="PER", start=0, end=1) 8 | entity3 = Entity(label="ORG", start=0, end=1) 9 | 10 | assert entity1 == entity2 11 | assert entity1 != entity3 12 | assert entity1 != "not an entity" 13 | 14 | 15 | def test_entity_hash(): 16 | """Test Entity hashing.""" 17 | entity1 = Entity(label="PER", start=0, end=1) 18 | entity2 = Entity(label="PER", start=0, end=1) 19 | entity3 = Entity(label="ORG", start=0, end=1) 20 | 21 | assert hash(entity1) == hash(entity2) 22 | assert hash(entity1) != hash(entity3) 23 | 24 | 25 | def test_evaluation_result_compute_metrics(): 26 | """Test computation of evaluation metrics.""" 27 | result = EvaluationResult(correct=5, incorrect=2, partial=1, missed=1, spurious=1) 28 | 29 | # Test strict metrics 30 | result.compute_metrics(partial_or_type=False) 31 | assert result.precision == 5 / 9 # 5/(5+2+1+1) 32 | assert result.recall == 5 / (5 + 2 + 1 + 1) 33 | 34 | # Test partial metrics 35 | result.compute_metrics(partial_or_type=True) 36 | assert result.precision == 5.5 / 9 # (5+0.5*1)/(5+2+1+1) 37 | assert result.recall == (5 + 0.5 * 1) / (5 + 2 + 1 + 1) 38 | 39 | 40 | def test_evaluation_result_zero_cases(): 41 | """Test evaluation metrics with zero values.""" 42 | result = EvaluationResult() 43 | result.compute_metrics() 44 | assert result.precision == 0 45 | assert result.recall == 0 46 | assert result.f1 == 0 47 | -------------------------------------------------------------------------------- /tests/test_evaluator.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from nervaluate.evaluator import Evaluator 3 | 4 | 5 | @pytest.fixture 6 | def sample_data(): 7 | true = [ 8 | ["O", "B-PER", "O", "B-ORG", "I-ORG", "B-LOC"], 9 | ["O", "B-PER", "O", "B-ORG"], 10 | ] 11 | 12 | pred = [ 13 | ["O", "B-PER", "O", "B-ORG", "O", "B-PER"], 14 | ["O", "B-PER", "O", "B-LOC"], 15 | ] 16 | 17 | return true, pred 18 | 19 | 20 | def test_evaluator_initialization(sample_data): 21 | """Test evaluator initialization.""" 22 | true, pred = sample_data 23 | evaluator = Evaluator(true, pred, ["PER", "ORG", "LOC"], loader="list") 24 | 25 | assert len(evaluator.true) == 2 26 | assert len(evaluator.pred) == 2 27 | assert evaluator.tags == ["PER", "ORG", "LOC"] 28 | 29 | 30 | def test_evaluator_evaluation(sample_data): 31 | """Test evaluation process.""" 32 | true, pred = sample_data 33 | evaluator = Evaluator(true, pred, ["PER", "ORG", "LOC"], loader="list") 34 | results = evaluator.evaluate() 35 | 36 | # Check that we have results for all strategies 37 | assert "overall" in results 38 | assert "entities" in results 39 | assert "strict" in results["overall"] 40 | assert "partial" in results["overall"] 41 | assert "ent_type" in results["overall"] 42 | 43 | # Check that we have results for each entity type 44 | 
for entity in ["PER", "ORG", "LOC"]: 45 | assert entity in results["entities"] 46 | assert "strict" in results["entities"][entity] 47 | assert "partial" in results["entities"][entity] 48 | assert "ent_type" in results["entities"][entity] 49 | 50 | 51 | def test_evaluator_with_invalid_tags(sample_data): 52 | """Test evaluator with invalid tags.""" 53 | true, pred = sample_data 54 | evaluator = Evaluator(true, pred, ["INVALID"], loader="list") 55 | results = evaluator.evaluate() 56 | 57 | for strategy in ["strict", "partial", "ent_type"]: 58 | assert results["overall"][strategy].correct == 0 59 | assert results["overall"][strategy].incorrect == 0 60 | assert results["overall"][strategy].partial == 0 61 | assert results["overall"][strategy].missed == 0 62 | assert results["overall"][strategy].spurious == 0 63 | 64 | 65 | def test_evaluator_different_document_lengths(): 66 | """Test that Evaluator raises ValueError when documents have different lengths.""" 67 | true = [ 68 | ["O", "B-PER", "I-PER", "O", "O", "O", "B-ORG", "I-ORG"], # 8 tokens 69 | ["O", "B-LOC", "B-PER", "I-PER", "O", "O", "B-DATE"], # 7 tokens 70 | ] 71 | pred = [ 72 | ["O", "B-PER", "I-PER", "O", "O", "O", "B-ORG", "I-ORG"], # 8 tokens 73 | ["O", "B-LOC", "I-LOC", "O", "B-PER", "I-PER", "O", "B-DATE", "I-DATE", "O"], # 10 tokens 74 | ] 75 | tags = ["PER", "ORG", "LOC", "DATE"] 76 | 77 | # Test that ValueError is raised 78 | with pytest.raises(ValueError, match="Document 1 has different lengths: true=7, pred=10"): 79 | evaluator = Evaluator(true=true, pred=pred, tags=tags, loader="list") 80 | evaluator.evaluate() 81 | -------------------------------------------------------------------------------- /tests/test_loaders.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from nervaluate.loaders import ConllLoader, ListLoader, DictLoader 4 | 5 | 6 | def test_conll_loader(): 7 | """Test CoNLL format loader.""" 8 | true_conll = ( 9 | "word\tO\nword\tO\nword\tO\nword\tO\nword\tO\nword\tO\n\n" 10 | "word\tO\nword\tO\nword\tB-ORG\nword\tI-ORG\nword\tO\nword\tO\n\n" 11 | "word\tO\nword\tO\nword\tB-MISC\nword\tI-MISC\nword\tO\nword\tO\n\n" 12 | "word\tB-MISC\nword\tI-MISC\nword\tI-MISC\nword\tI-MISC\nword\tI-MISC\nword\tI-MISC\n" 13 | ) 14 | 15 | pred_conll = ( 16 | "word\tO\nword\tO\nword\tB-PER\nword\tI-PER\nword\tO\nword\tO\n\n" 17 | "word\tO\nword\tO\nword\tB-ORG\nword\tI-ORG\nword\tO\nword\tO\n\n" 18 | "word\tO\nword\tO\nword\tB-MISC\nword\tI-MISC\nword\tO\nword\tO\n\n" 19 | "word\tB-MISC\nword\tI-MISC\nword\tI-MISC\nword\tI-MISC\nword\tI-MISC\nword\tI-MISC\n" 20 | ) 21 | 22 | loader = ConllLoader() 23 | true_entities = loader.load(true_conll) 24 | pred_entities = loader.load(pred_conll) 25 | 26 | # Test true entities 27 | assert len(true_entities) == 4 # Four documents 28 | assert len(true_entities[0]) == 0 # First document has no entities (all O tags) 29 | assert len(true_entities[1]) == 1 # Second document has 1 entity (ORG) 30 | assert len(true_entities[2]) == 1 # Third document has 1 entity (MISC) 31 | assert len(true_entities[3]) == 1 # Fourth document has 1 entity (MISC) 32 | 33 | # Check first entity in second document 34 | assert true_entities[1][0].label == "ORG" 35 | assert true_entities[1][0].start == 2 36 | assert true_entities[1][0].end == 3 37 | 38 | # Test pred entities 39 | assert len(pred_entities) == 4 # Four documents 40 | assert len(pred_entities[0]) == 1 # First document has 1 entity (PER) 41 | assert len(pred_entities[1]) == 1 # Second document has 1 
entity (ORG) 42 | assert len(pred_entities[2]) == 1 # Third document has 1 entity (MISC) 43 | assert len(pred_entities[3]) == 1 # Fourth document has 1 entity (MISC) 44 | 45 | # Check first entity in first document 46 | assert pred_entities[0][0].label == "PER" 47 | assert pred_entities[0][0].start == 2 48 | assert pred_entities[0][0].end == 3 49 | 50 | # Test empty document handling 51 | empty_doc = "word\tO\nword\tO\nword\tO\n\n" 52 | empty_entities = loader.load(empty_doc) 53 | assert len(empty_entities) == 1 # One document 54 | assert len(empty_entities[0]) == 0 # Empty list for document with only O tags 55 | 56 | 57 | def test_list_loader(): 58 | """Test list format loader.""" 59 | true_list = [ 60 | ["O", "O", "O", "O", "O", "O"], 61 | ["O", "O", "B-ORG", "I-ORG", "O", "O"], 62 | ["O", "O", "B-MISC", "I-MISC", "O", "O"], 63 | ["B-MISC", "I-MISC", "I-MISC", "I-MISC", "I-MISC", "I-MISC"], 64 | ] 65 | 66 | pred_list = [ 67 | ["O", "O", "B-PER", "I-PER", "O", "O"], 68 | ["O", "O", "B-ORG", "I-ORG", "O", "O"], 69 | ["O", "O", "B-MISC", "I-MISC", "O", "O"], 70 | ["B-MISC", "I-MISC", "I-MISC", "I-MISC", "I-MISC", "I-MISC"], 71 | ] 72 | 73 | loader = ListLoader() 74 | true_entities = loader.load(true_list) 75 | pred_entities = loader.load(pred_list) 76 | 77 | # Test true entities 78 | assert len(true_entities) == 4 # Four documents 79 | assert len(true_entities[0]) == 0 # First document has no entities (all O tags) 80 | assert len(true_entities[1]) == 1 # Second document has 1 entity (ORG) 81 | assert len(true_entities[2]) == 1 # Third document has 1 entity (MISC) 82 | assert len(true_entities[3]) == 1 # Fourth document has 1 entity (MISC) 83 | 84 | # Check no entities in the first document 85 | assert len(true_entities[0]) == 0 86 | 87 | # Check first entity in second document 88 | assert true_entities[1][0].label == "ORG" 89 | assert true_entities[1][0].start == 2 90 | assert true_entities[1][0].end == 3 91 | 92 | # Check only entity in the last document 93 | assert true_entities[3][0].label == "MISC" 94 | assert true_entities[3][0].start == 0 95 | assert true_entities[3][0].end == 5 96 | 97 | # Test pred entities 98 | assert len(pred_entities) == 4 # Four documents 99 | assert len(pred_entities[0]) == 1 # First document has 1 entity (PER) 100 | assert len(pred_entities[1]) == 1 # Second document has 1 entity (ORG) 101 | assert len(pred_entities[2]) == 1 # Third document has 1 entity (MISC) 102 | assert len(pred_entities[3]) == 1 # Fourth document has 1 entity (MISC) 103 | 104 | # Check first entity in first document 105 | assert pred_entities[0][0].label == "PER" 106 | assert pred_entities[0][0].start == 2 107 | assert pred_entities[0][0].end == 3 108 | 109 | # Test empty document handling 110 | empty_doc = [["O", "O", "O"]] 111 | empty_entities = loader.load(empty_doc) 112 | assert len(empty_entities) == 1 # One document 113 | assert len(empty_entities[0]) == 0 # Empty list for document with only O tags 114 | 115 | 116 | def test_dict_loader(): 117 | """Test dictionary format loader.""" 118 | true_prod = [ 119 | [], 120 | [{"label": "ORG", "start": 2, "end": 3}], 121 | [{"label": "MISC", "start": 2, "end": 3}], 122 | [{"label": "MISC", "start": 0, "end": 5}], 123 | ] 124 | 125 | pred_prod = [ 126 | [{"label": "PER", "start": 2, "end": 3}], 127 | [{"label": "ORG", "start": 2, "end": 3}], 128 | [{"label": "MISC", "start": 2, "end": 3}], 129 | [{"label": "MISC", "start": 0, "end": 5}], 130 | ] 131 | 132 | loader = DictLoader() 133 | true_entities = loader.load(true_prod) 134 | pred_entities 
= loader.load(pred_prod) 135 | 136 | # Test true entities 137 | assert len(true_entities) == 4 # Four documents 138 | assert len(true_entities[0]) == 0 # First document has no entities 139 | assert len(true_entities[1]) == 1 # Second document has 1 entity (ORG) 140 | assert len(true_entities[2]) == 1 # Third document has 1 entity (MISC) 141 | assert len(true_entities[3]) == 1 # Fourth document has 1 entity (MISC) 142 | 143 | # Check first entity in second document 144 | assert true_entities[1][0].label == "ORG" 145 | assert true_entities[1][0].start == 2 146 | assert true_entities[1][0].end == 3 147 | 148 | # Check only entity in the last document 149 | assert true_entities[3][0].label == "MISC" 150 | assert true_entities[3][0].start == 0 151 | assert true_entities[3][0].end == 5 152 | 153 | # Test pred entities 154 | assert len(pred_entities) == 4 # Four documents 155 | assert len(pred_entities[0]) == 1 # First document has 1 entity (PER) 156 | assert len(pred_entities[1]) == 1 # Second document has 1 entity (ORG) 157 | assert len(pred_entities[2]) == 1 # Third document has 1 entity (MISC) 158 | assert len(pred_entities[3]) == 1 # Fourth document has 1 entity (MISC) 159 | 160 | # Check first entity in first document 161 | assert pred_entities[0][0].label == "PER" 162 | assert pred_entities[0][0].start == 2 163 | assert pred_entities[0][0].end == 3 164 | 165 | # Test empty document handling 166 | empty_doc = [[]] 167 | empty_entities = loader.load(empty_doc) 168 | assert len(empty_entities) == 1 # One document 169 | assert len(empty_entities[0]) == 0 # Empty list for empty document 170 | 171 | 172 | def test_loader_with_empty_input(): 173 | """Test loaders with empty input.""" 174 | # Test ConllLoader with empty string 175 | conll_loader = ConllLoader() 176 | entities = conll_loader.load("") 177 | assert len(entities) == 0 178 | 179 | # Test ListLoader with empty list 180 | list_loader = ListLoader() 181 | entities = list_loader.load([]) 182 | assert len(entities) == 0 183 | 184 | # Test DictLoader with empty list 185 | dict_loader = DictLoader() 186 | entities = dict_loader.load([]) 187 | assert len(entities) == 0 188 | 189 | 190 | def test_loader_with_invalid_data(): 191 | """Test loaders with invalid data.""" 192 | with pytest.raises(Exception): 193 | ConllLoader().load("invalid\tdata") 194 | 195 | with pytest.raises(Exception): 196 | ListLoader().load([["invalid"]]) 197 | 198 | with pytest.raises(Exception): 199 | DictLoader().load([[{"invalid": "data"}]]) 200 | -------------------------------------------------------------------------------- /tests/test_strategies.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from nervaluate.entities import Entity 3 | from nervaluate.strategies import EntityTypeEvaluation, ExactEvaluation, PartialEvaluation, StrictEvaluation 4 | 5 | 6 | def create_entities_from_bio(bio_tags): 7 | """Helper function to create entities from BIO tags.""" 8 | entities = [] 9 | current_entity = None 10 | 11 | for i, tag in enumerate(bio_tags): 12 | if tag == "O": 13 | continue 14 | 15 | if tag.startswith("B-"): 16 | if current_entity: 17 | entities.append(current_entity) 18 | current_entity = Entity(tag[2:], i, i + 1) 19 | elif tag.startswith("I-"): 20 | if current_entity: 21 | current_entity.end = i + 1 22 | else: 23 | # Handle case where I- tag appears without B- 24 | current_entity = Entity(tag[2:], i, i + 1) 25 | 26 | if current_entity: 27 | entities.append(current_entity) 28 | 29 | return entities 30 | 31 | 
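# NOTE: create_entities_from_bio() above is only a test helper. It builds Entity objects positionally as Entity(label, start, end), where end is set to one past the last matching tag's index. For the base_sequence fixture defined just below it therefore yields Entity("PER", 1, 3) and Entity("ORG", 6, 8), which is what the strategy tests compare against. A minimal sanity check (assuming Entity exposes label/start/end attributes, as the assertions in this file do):
#
#     tags = ["O", "B-PER", "I-PER", "O", "O", "O", "B-ORG", "I-ORG"]
#     ents = create_entities_from_bio(tags)
#     assert [(e.label, e.start, e.end) for e in ents] == [("PER", 1, 3), ("ORG", 6, 8)]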
32 | @pytest.fixture 33 | def base_sequence(): 34 | """Base sequence: 'The John Smith who works at Google Inc'""" 35 | return ["O", "B-PER", "I-PER", "O", "O", "O", "B-ORG", "I-ORG"] 36 | 37 | 38 | class TestStrictEvaluation: 39 | """Test cases for strict evaluation strategy.""" 40 | 41 | def test_perfect_match(self, base_sequence): 42 | """Test case: Perfect match of all entities.""" 43 | true = create_entities_from_bio(base_sequence) 44 | pred = create_entities_from_bio(base_sequence) 45 | 46 | evaluator = StrictEvaluation() 47 | result, result_indices = evaluator.evaluate(true, pred, ["PER", "ORG"]) 48 | 49 | assert result.correct == 2 50 | assert result.incorrect == 0 51 | assert result.partial == 0 52 | assert result.missed == 0 53 | assert result.spurious == 0 54 | assert result_indices.correct_indices == [(0, 0), (0, 1)] 55 | assert result_indices.incorrect_indices == [] 56 | assert result_indices.partial_indices == [] 57 | assert result_indices.missed_indices == [] 58 | assert result_indices.spurious_indices == [] 59 | 60 | def test_missed_entity(self, base_sequence): 61 | """Test case: One entity is missed in prediction.""" 62 | true = create_entities_from_bio(base_sequence) 63 | pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "O", "O", "O", "O"]) 64 | 65 | evaluator = StrictEvaluation() 66 | result, result_indices = evaluator.evaluate(true, pred, ["PER", "ORG"]) 67 | 68 | assert result.correct == 1 69 | assert result.incorrect == 0 70 | assert result.partial == 0 71 | assert result.missed == 1 72 | assert result.spurious == 0 73 | assert result_indices.correct_indices == [(0, 0)] 74 | assert result_indices.incorrect_indices == [] 75 | assert result_indices.partial_indices == [] 76 | assert result_indices.missed_indices == [(0, 1)] 77 | assert result_indices.spurious_indices == [] 78 | 79 | def test_wrong_label(self, base_sequence): 80 | """Test case: Entity with wrong label.""" 81 | true = create_entities_from_bio(base_sequence) 82 | pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "O", "O", "B-LOC", "I-LOC"]) 83 | 84 | evaluator = StrictEvaluation() 85 | result, result_indices = evaluator.evaluate(true, pred, ["PER", "ORG", "LOC"]) 86 | 87 | assert result.correct == 1 88 | assert result.incorrect == 1 89 | assert result.partial == 0 90 | assert result.missed == 0 91 | assert result.spurious == 0 92 | assert result_indices.correct_indices == [(0, 0)] 93 | assert result_indices.incorrect_indices == [(0, 1)] 94 | assert result_indices.partial_indices == [] 95 | assert result_indices.missed_indices == [] 96 | assert result_indices.spurious_indices == [] 97 | 98 | def test_wrong_boundary(self, base_sequence): 99 | """Test case: Entity with wrong boundary.""" 100 | true = create_entities_from_bio(base_sequence) 101 | pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "O", "O", "B-LOC", "O"]) 102 | 103 | evaluator = StrictEvaluation() 104 | result, result_indices = evaluator.evaluate(true, pred, ["PER", "ORG", "LOC"]) 105 | 106 | assert result.correct == 1 107 | assert result.incorrect == 1 108 | assert result.partial == 0 109 | assert result.missed == 0 110 | assert result.spurious == 0 111 | assert result_indices.correct_indices == [(0, 0)] 112 | assert result_indices.incorrect_indices == [(0, 1)] 113 | assert result_indices.partial_indices == [] 114 | assert result_indices.missed_indices == [] 115 | assert result_indices.spurious_indices == [] 116 | 117 | def test_shifted_boundary(self, base_sequence): 118 | """Test case: Entity with shifted 
boundary.""" 119 | true = create_entities_from_bio(base_sequence) 120 | pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "O", "O", "O", "B-LOC"]) 121 | 122 | evaluator = StrictEvaluation() 123 | result, result_indices = evaluator.evaluate(true, pred, ["PER", "ORG", "LOC"]) 124 | 125 | assert result.correct == 1 126 | assert result.incorrect == 1 127 | assert result.partial == 0 128 | assert result.missed == 0 129 | assert result.spurious == 0 130 | assert result_indices.correct_indices == [(0, 0)] 131 | assert result_indices.incorrect_indices == [(0, 1)] 132 | assert result_indices.partial_indices == [] 133 | assert result_indices.missed_indices == [] 134 | assert result_indices.spurious_indices == [] 135 | 136 | def test_extra_entity(self, base_sequence): 137 | """Test case: Extra entity in prediction.""" 138 | true = create_entities_from_bio(base_sequence) 139 | pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "B-PER", "O", "B-LOC", "I-LOC"]) 140 | 141 | evaluator = StrictEvaluation() 142 | result, result_indices = evaluator.evaluate(true, pred, ["PER", "ORG", "LOC"]) 143 | 144 | assert result.correct == 1 145 | assert result.incorrect == 1 146 | assert result.partial == 0 147 | assert result.missed == 0 148 | assert result.spurious == 1 149 | assert result_indices.correct_indices == [(0, 0)] 150 | assert result_indices.incorrect_indices == [(0, 2)] 151 | assert result_indices.partial_indices == [] 152 | assert result_indices.missed_indices == [] 153 | assert result_indices.spurious_indices == [(0, 1)] 154 | 155 | 156 | class TestEntityTypeEvaluation: 157 | """Test cases for entity type evaluation strategy.""" 158 | 159 | def test_perfect_match(self, base_sequence): 160 | """Test case: Perfect match of all entities.""" 161 | true = create_entities_from_bio(base_sequence) 162 | pred = create_entities_from_bio(base_sequence) 163 | 164 | evaluator = EntityTypeEvaluation() 165 | result, result_indices = evaluator.evaluate(true, pred, ["PER", "ORG"]) 166 | 167 | assert result.correct == 2 168 | assert result.incorrect == 0 169 | assert result.partial == 0 170 | assert result.missed == 0 171 | assert result.spurious == 0 172 | assert result_indices.correct_indices == [(0, 0), (0, 1)] 173 | assert result_indices.incorrect_indices == [] 174 | assert result_indices.partial_indices == [] 175 | assert result_indices.missed_indices == [] 176 | assert result_indices.spurious_indices == [] 177 | 178 | def test_missed_entity(self, base_sequence): 179 | """Test case: One entity is missed in prediction.""" 180 | true = create_entities_from_bio(base_sequence) 181 | pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "O", "O", "O", "O"]) 182 | 183 | evaluator = EntityTypeEvaluation() 184 | result, result_indices = evaluator.evaluate(true, pred, ["PER", "ORG"]) 185 | 186 | assert result.correct == 1 187 | assert result.incorrect == 0 188 | assert result.partial == 0 189 | assert result.missed == 1 190 | assert result.spurious == 0 191 | assert result_indices.correct_indices == [(0, 0)] 192 | assert result_indices.incorrect_indices == [] 193 | assert result_indices.partial_indices == [] 194 | assert result_indices.missed_indices == [(0, 1)] 195 | assert result_indices.spurious_indices == [] 196 | 197 | def test_wrong_label(self, base_sequence): 198 | """Test case: Entity with wrong label.""" 199 | true = create_entities_from_bio(base_sequence) 200 | pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "O", "O", "B-LOC", "I-LOC"]) 201 | 202 | evaluator = 
EntityTypeEvaluation() 203 | result, _ = evaluator.evaluate(true, pred, ["PER", "ORG", "LOC"]) 204 | 205 | assert result.correct == 1 206 | assert result.incorrect == 1 207 | assert result.partial == 0 208 | assert result.missed == 0 209 | assert result.spurious == 0 210 | 211 | def test_wrong_boundary(self, base_sequence): 212 | """Test case: Entity with wrong boundary.""" 213 | true = create_entities_from_bio(base_sequence) 214 | pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "O", "O", "B-LOC", "O"]) 215 | 216 | evaluator = EntityTypeEvaluation() 217 | result, result_indices = evaluator.evaluate(true, pred, ["PER", "ORG", "LOC"]) 218 | 219 | assert result.correct == 1 220 | assert result.incorrect == 1 221 | assert result.partial == 0 222 | assert result.missed == 0 223 | assert result.spurious == 0 224 | assert result_indices.correct_indices == [(0, 0)] 225 | assert result_indices.incorrect_indices == [(0, 1)] 226 | assert result_indices.partial_indices == [] 227 | assert result_indices.missed_indices == [] 228 | assert result_indices.spurious_indices == [] 229 | 230 | def test_shifted_boundary(self, base_sequence): 231 | """Test case: Entity with shifted boundary.""" 232 | true = create_entities_from_bio(base_sequence) 233 | pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "O", "O", "O", "B-LOC"]) 234 | 235 | evaluator = EntityTypeEvaluation() 236 | result, result_indices = evaluator.evaluate(true, pred, ["PER", "ORG", "LOC"]) 237 | 238 | assert result.correct == 1 239 | assert result.incorrect == 1 240 | assert result.partial == 0 241 | assert result.missed == 0 242 | assert result.spurious == 0 243 | assert result_indices.correct_indices == [(0, 0)] 244 | assert result_indices.incorrect_indices == [(0, 1)] 245 | assert result_indices.partial_indices == [] 246 | assert result_indices.missed_indices == [] 247 | assert result_indices.spurious_indices == [] 248 | 249 | def test_extra_entity(self, base_sequence): 250 | """Test case: Extra entity in prediction.""" 251 | true = create_entities_from_bio(base_sequence) 252 | pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "B-PER", "O", "B-LOC", "I-LOC"]) 253 | 254 | evaluator = EntityTypeEvaluation() 255 | result, result_indices = evaluator.evaluate(true, pred, ["PER", "ORG", "LOC"]) 256 | 257 | assert result.correct == 1 258 | assert result.incorrect == 1 259 | assert result.partial == 0 260 | assert result.missed == 0 261 | assert result.spurious == 1 262 | assert result_indices.correct_indices == [(0, 0)] 263 | assert result_indices.incorrect_indices == [(0, 2)] 264 | assert result_indices.partial_indices == [] 265 | assert result_indices.missed_indices == [] 266 | assert result_indices.spurious_indices == [(0, 1)] 267 | 268 | class TestExactEvaluation: 269 | """Test cases for exact evaluation strategy.""" 270 | 271 | def test_perfect_match(self, base_sequence): 272 | """Test case: Perfect match of all entities.""" 273 | true = create_entities_from_bio(base_sequence) 274 | pred = create_entities_from_bio(base_sequence) 275 | 276 | evaluator = ExactEvaluation() 277 | result, result_indices = evaluator.evaluate(true, pred, ["PER", "ORG"]) 278 | 279 | assert result.correct == 2 280 | assert result.incorrect == 0 281 | assert result.partial == 0 282 | assert result.missed == 0 283 | assert result.spurious == 0 284 | assert result_indices.correct_indices == [(0, 0), (0, 1)] 285 | assert result_indices.incorrect_indices == [] 286 | assert result_indices.partial_indices == [] 287 | assert 
result_indices.missed_indices == [] 288 | assert result_indices.spurious_indices == [] 289 | 290 | def test_missed_entity(self, base_sequence): 291 | """Test case: One entity is missed in prediction.""" 292 | true = create_entities_from_bio(base_sequence) 293 | pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "O", "O", "O", "O"]) 294 | 295 | evaluator = ExactEvaluation() 296 | result, result_indices = evaluator.evaluate(true, pred, ["PER", "ORG"]) 297 | 298 | assert result.correct == 1 299 | assert result.incorrect == 0 300 | assert result.partial == 0 301 | assert result.missed == 1 302 | assert result.spurious == 0 303 | assert result_indices.correct_indices == [(0, 0)] 304 | assert result_indices.incorrect_indices == [] 305 | assert result_indices.partial_indices == [] 306 | assert result_indices.missed_indices == [(0, 1)] 307 | assert result_indices.spurious_indices == [] 308 | 309 | def test_wrong_label(self, base_sequence): 310 | """Test case: Entity with wrong label.""" 311 | true = create_entities_from_bio(base_sequence) 312 | pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "O", "O", "B-LOC", "I-LOC"]) 313 | 314 | evaluator = ExactEvaluation() 315 | result, result_indices = evaluator.evaluate(true, pred, ["PER", "ORG", "LOC"]) 316 | 317 | assert result.correct == 2 318 | assert result.incorrect == 0 319 | assert result.partial == 0 320 | assert result.missed == 0 321 | assert result.spurious == 0 322 | assert result_indices.correct_indices == [(0, 0), (0, 1)] 323 | assert result_indices.incorrect_indices == [] 324 | assert result_indices.partial_indices == [] 325 | assert result_indices.missed_indices == [] 326 | assert result_indices.spurious_indices == [] 327 | 328 | def test_wrong_boundary(self, base_sequence): 329 | """Test case: Entity with wrong boundary.""" 330 | true = create_entities_from_bio(base_sequence) 331 | pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "O", "O", "B-LOC", "O"]) 332 | 333 | evaluator = ExactEvaluation() 334 | result, result_indices = evaluator.evaluate(true, pred, ["PER", "ORG", "LOC"]) 335 | 336 | assert result.correct == 1 337 | assert result.incorrect == 1 338 | assert result.partial == 0 339 | assert result.missed == 0 340 | assert result.spurious == 0 341 | assert result_indices.correct_indices == [(0, 0)] 342 | assert result_indices.incorrect_indices == [(0, 1)] 343 | assert result_indices.partial_indices == [] 344 | assert result_indices.missed_indices == [] 345 | assert result_indices.spurious_indices == [] 346 | 347 | def test_shifted_boundary(self, base_sequence): 348 | """Test case: Entity with shifted boundary.""" 349 | true = create_entities_from_bio(base_sequence) 350 | pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "O", "O", "O", "B-LOC"]) 351 | 352 | evaluator = ExactEvaluation() 353 | result, result_indices = evaluator.evaluate(true, pred, ["PER", "ORG", "LOC"]) 354 | 355 | assert result.correct == 1 356 | assert result.incorrect == 1 357 | assert result.partial == 0 358 | assert result.missed == 0 359 | assert result.spurious == 0 360 | assert result_indices.correct_indices == [(0, 0)] 361 | assert result_indices.incorrect_indices == [(0, 1)] 362 | assert result_indices.partial_indices == [] 363 | assert result_indices.missed_indices == [] 364 | assert result_indices.spurious_indices == [] 365 | 366 | def test_extra_entity(self, base_sequence): 367 | """Test case: Extra entity in prediction.""" 368 | true = create_entities_from_bio(base_sequence) 369 | pred = 
create_entities_from_bio(["O", "B-PER", "I-PER", "O", "B-PER", "O", "B-LOC", "I-LOC"]) 370 | 371 | evaluator = ExactEvaluation() 372 | result, result_indices = evaluator.evaluate(true, pred, ["PER", "ORG", "LOC"]) 373 | 374 | assert result.correct == 2 375 | assert result.incorrect == 0 376 | assert result.partial == 0 377 | assert result.missed == 0 378 | assert result.spurious == 1 379 | assert result_indices.correct_indices == [(0, 0), (0, 2)] 380 | assert result_indices.incorrect_indices == [] 381 | assert result_indices.partial_indices == [] 382 | assert result_indices.missed_indices == [] 383 | assert result_indices.spurious_indices == [(0, 1)] 384 | 385 | 386 | class TestPartialEvaluation: 387 | """Test cases for partial evaluation strategy.""" 388 | 389 | def test_perfect_match(self, base_sequence): 390 | """Test case: Perfect match of all entities.""" 391 | true = create_entities_from_bio(base_sequence) 392 | pred = create_entities_from_bio(base_sequence) 393 | 394 | evaluator = PartialEvaluation() 395 | result, result_indices = evaluator.evaluate(true, pred, ["PER", "ORG"]) 396 | 397 | assert result.correct == 2 398 | assert result.incorrect == 0 399 | assert result.partial == 0 400 | assert result.missed == 0 401 | assert result.spurious == 0 402 | assert result_indices.correct_indices == [(0, 0), (0, 1)] 403 | assert result_indices.incorrect_indices == [] 404 | assert result_indices.partial_indices == [] 405 | assert result_indices.missed_indices == [] 406 | assert result_indices.spurious_indices == [] 407 | 408 | def test_missed_entity(self, base_sequence): 409 | """Test case: One entity is missed in prediction.""" 410 | true = create_entities_from_bio(base_sequence) 411 | pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "O", "O", "O", "O"]) 412 | 413 | evaluator = PartialEvaluation() 414 | result, result_indices = evaluator.evaluate(true, pred, ["PER", "ORG"]) 415 | 416 | assert result.correct == 1 417 | assert result.incorrect == 0 418 | assert result.partial == 0 419 | assert result.missed == 1 420 | assert result.spurious == 0 421 | assert result_indices.correct_indices == [(0, 0)] 422 | assert result_indices.incorrect_indices == [] 423 | assert result_indices.partial_indices == [] 424 | assert result_indices.missed_indices == [(0, 1)] 425 | assert result_indices.spurious_indices == [] 426 | 427 | def test_wrong_label(self, base_sequence): 428 | """Test case: Entity with wrong label.""" 429 | true = create_entities_from_bio(base_sequence) 430 | pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "O", "O", "B-LOC", "I-LOC"]) 431 | 432 | evaluator = PartialEvaluation() 433 | result, result_indices = evaluator.evaluate(true, pred, ["PER", "ORG", "LOC"]) 434 | 435 | assert result.correct == 2 436 | assert result.incorrect == 0 437 | assert result.partial == 0 438 | assert result.missed == 0 439 | assert result.spurious == 0 440 | assert result_indices.correct_indices == [(0, 0), (0, 1)] 441 | assert result_indices.incorrect_indices == [] 442 | assert result_indices.partial_indices == [] 443 | assert result_indices.missed_indices == [] 444 | assert result_indices.spurious_indices == [] 445 | 446 | def test_wrong_boundary(self, base_sequence): 447 | """Test case: Entity with wrong boundary.""" 448 | true = create_entities_from_bio(base_sequence) 449 | pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "O", "O", "B-LOC", "O"]) 450 | 451 | evaluator = PartialEvaluation() 452 | result, result_indices = evaluator.evaluate(true, pred, ["PER", "ORG", 
"LOC"]) 453 | 454 | assert result.correct == 1 455 | assert result.incorrect == 0 456 | assert result.partial == 1 457 | assert result.missed == 0 458 | assert result.spurious == 0 459 | assert result_indices.correct_indices == [(0, 0)] 460 | assert result_indices.incorrect_indices == [] 461 | assert result_indices.partial_indices == [(0, 1)] 462 | assert result_indices.missed_indices == [] 463 | assert result_indices.spurious_indices == [] 464 | 465 | def test_shifted_boundary(self, base_sequence): 466 | """Test case: Entity with shifted boundary.""" 467 | true = create_entities_from_bio(base_sequence) 468 | pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "O", "O", "O", "B-LOC"]) 469 | 470 | evaluator = PartialEvaluation() 471 | result, result_indices = evaluator.evaluate(true, pred, ["PER", "ORG", "LOC"]) 472 | 473 | assert result.correct == 1 474 | assert result.incorrect == 0 475 | assert result.partial == 1 476 | assert result.missed == 0 477 | assert result.spurious == 0 478 | assert result_indices.correct_indices == [(0, 0)] 479 | assert result_indices.incorrect_indices == [] 480 | assert result_indices.partial_indices == [(0, 1)] 481 | assert result_indices.missed_indices == [] 482 | assert result_indices.spurious_indices == [] 483 | 484 | def test_extra_entity(self, base_sequence): 485 | """Test case: Extra entity in prediction.""" 486 | true = create_entities_from_bio(base_sequence) 487 | pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "B-PER", "O", "B-LOC", "I-LOC"]) 488 | 489 | evaluator = PartialEvaluation() 490 | result, result_indices = evaluator.evaluate(true, pred, ["PER", "ORG", "LOC"]) 491 | 492 | assert result.correct == 2 493 | assert result.incorrect == 0 494 | assert result.partial == 0 495 | assert result.missed == 0 496 | assert result.spurious == 1 497 | assert result_indices.correct_indices == [(0, 0), (0, 2)] 498 | assert result_indices.incorrect_indices == [] 499 | assert result_indices.partial_indices == [] 500 | assert result_indices.missed_indices == [] 501 | assert result_indices.spurious_indices == [(0, 1)] 502 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | from nervaluate import ( 2 | collect_named_entities, 3 | conll_to_spans, 4 | list_to_spans, 5 | split_list, 6 | ) 7 | 8 | 9 | def test_list_to_spans(): 10 | before = [ 11 | ["O", "B-LOC", "I-LOC", "B-LOC", "I-LOC", "O"], 12 | ["O", "B-GPE", "I-GPE", "B-GPE", "I-GPE", "O"], 13 | ] 14 | 15 | expected = [ 16 | [ 17 | {"label": "LOC", "start": 1, "end": 2}, 18 | {"label": "LOC", "start": 3, "end": 4}, 19 | ], 20 | [ 21 | {"label": "GPE", "start": 1, "end": 2}, 22 | {"label": "GPE", "start": 3, "end": 4}, 23 | ], 24 | ] 25 | 26 | result = list_to_spans(before) 27 | 28 | assert result == expected 29 | 30 | 31 | def test_list_to_spans_1(): 32 | before = [ 33 | ["O", "O", "O", "O", "O", "O"], 34 | ["O", "O", "B-ORG", "I-ORG", "O", "O"], 35 | ["O", "O", "B-MISC", "I-MISC", "O", "O"], 36 | ] 37 | 38 | expected = [ 39 | [], 40 | [{"label": "ORG", "start": 2, "end": 3}], 41 | [{"label": "MISC", "start": 2, "end": 3}], 42 | ] 43 | 44 | actual = list_to_spans(before) 45 | 46 | assert actual == expected 47 | 48 | 49 | def test_conll_to_spans(): 50 | before = ( 51 | ",\tO\n" 52 | "Davos\tB-PER\n" 53 | "2018\tO\n" 54 | ":\tO\n" 55 | "Soros\tB-PER\n" 56 | "accuses\tO\n" 57 | "Trump\tB-PER\n" 58 | "of\tO\n" 59 | "wanting\tO\n" 60 | "\n" 61 | "foo\tO\n" 62 
| ) 63 | 64 | after = [ 65 | [ 66 | {"label": "PER", "start": 1, "end": 1}, 67 | {"label": "PER", "start": 4, "end": 4}, 68 | {"label": "PER", "start": 6, "end": 6}, 69 | ], 70 | [], 71 | ] 72 | 73 | out = conll_to_spans(before) 74 | 75 | assert after == out 76 | 77 | 78 | def test_conll_to_spans_1(): 79 | before = ( 80 | "word\tO\nword\tO\nword\tO\nword\tO\nword\tO\nword\tO\n\n" 81 | "word\tO\nword\tO\nword\tB-ORG\nword\tI-ORG\nword\tO\nword\tO\n\n" 82 | "word\tO\nword\tO\nword\tB-MISC\nword\tI-MISC\nword\tO\nword\tO\n" 83 | ) 84 | 85 | expected = [ 86 | [], 87 | [{"label": "ORG", "start": 2, "end": 3}], 88 | [{"label": "MISC", "start": 2, "end": 3}], 89 | ] 90 | 91 | actual = conll_to_spans(before) 92 | 93 | assert actual == expected 94 | 95 | 96 | def test_split_list(): 97 | before = ["aa", "bb", "cc", "", "dd", "ee", "ff"] 98 | expected = [["aa", "bb", "cc"], ["dd", "ee", "ff"]] 99 | out = split_list(before) 100 | 101 | assert expected == out 102 | 103 | 104 | def test_collect_named_entities_same_type_in_sequence(): 105 | tags = ["O", "B-LOC", "I-LOC", "B-LOC", "I-LOC", "O"] 106 | result = collect_named_entities(tags) 107 | expected = [ 108 | {"label": "LOC", "start": 1, "end": 2}, 109 | {"label": "LOC", "start": 3, "end": 4}, 110 | ] 111 | assert result == expected 112 | 113 | 114 | def test_collect_named_entities_sequence_has_only_one_entity(): 115 | tags = ["B-LOC", "I-LOC"] 116 | result = collect_named_entities(tags) 117 | expected = [{"label": "LOC", "start": 0, "end": 1}] 118 | assert result == expected 119 | 120 | 121 | def test_collect_named_entities_entity_goes_until_last_token(): 122 | tags = ["O", "B-LOC", "I-LOC", "B-LOC", "I-LOC"] 123 | result = collect_named_entities(tags) 124 | expected = [ 125 | {"label": "LOC", "start": 1, "end": 2}, 126 | {"label": "LOC", "start": 3, "end": 4}, 127 | ] 128 | assert result == expected 129 | 130 | 131 | def test_collect_named_entities_no_entity(): 132 | tags = ["O", "O", "O", "O", "O"] 133 | result = collect_named_entities(tags) 134 | expected = [] 135 | assert result == expected 136 | --------------------------------------------------------------------------------
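Taken together, these tests sketch the intended workflow: a loader turns raw annotations (CoNLL text, nested tag lists, or span dicts) into per-document lists of Entity objects, and an evaluation strategy compares a document's true and predicted entities against a label set, returning aggregate counts plus the index pairs behind them. Below is a minimal end-to-end sketch, not a file from the repository; it assumes ListLoader is importable from nervaluate.loaders (matching the src/nervaluate/loaders.py layout) and that StrictEvaluation.evaluate accepts one document's entity list and the label list exactly as the tests above call it:

from nervaluate.loaders import ListLoader           # assumed import path, per src/nervaluate/loaders.py
from nervaluate.strategies import StrictEvaluation  # import path confirmed by test_strategies.py

true_tags = [["O", "O", "B-ORG", "I-ORG", "O", "O"]]
pred_tags = [["O", "O", "B-PER", "I-PER", "O", "O"]]

loader = ListLoader()
true_docs = loader.load(true_tags)   # one document -> a list containing one list of Entity objects
pred_docs = loader.load(pred_tags)

evaluator = StrictEvaluation()
# Evaluate the first (and only) document; result carries correct/incorrect/partial/
# missed/spurious counts, and result_indices records which entity positions fell
# into each bucket as index pairs like (0, 1).
result, result_indices = evaluator.evaluate(true_docs[0], pred_docs[0], ["ORG", "PER"])
print(result.correct, result.incorrect, result.spurious)

Following the pattern in test_wrong_label for strict matching (same span, different label), this would be expected to print 0 1 0: the PER prediction over the ORG span counts as incorrect rather than as a missed plus a spurious entity.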