├── .flake8 ├── .github └── workflows │ └── pr-checks.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── .vscode ├── extensions.json ├── launch.json └── settings.json ├── CODE_OF_CONDUCT.md ├── Dockerfile_no_conda ├── LICENSE ├── Makefile ├── README.md ├── RadFact.png ├── SECURITY.md ├── SUPPORT.md ├── configs ├── default.yaml ├── endpoints │ ├── azure_chat_openai.yaml │ └── chat_openai.yaml ├── radfact.yaml └── report_to_phrases.yaml ├── dev_environment.yaml ├── examples ├── findings_generation_examples.csv ├── grounded_reporting_examples.json └── test_examples.json ├── getting_started.ipynb ├── mypy.ini ├── primary_deps.yaml ├── pyproject.toml ├── setup_packages.py ├── src └── radfact │ ├── __init__.py │ ├── azure_utils │ ├── __init__.py │ ├── auth.py │ └── bearer_token_provider.py │ ├── cli │ ├── run_radfact.py │ ├── run_radfact_test_examples.py │ └── run_report_to_phrases.py │ ├── data_utils │ ├── __init__.py │ └── grounded_phrase_list.py │ ├── llm_utils │ ├── __init__.py │ ├── endpoint.py │ ├── engine │ │ ├── __init__.py │ │ ├── arguments.py │ │ ├── data_subset.py │ │ ├── endpoint_utils.py │ │ ├── engine.py │ │ └── redis_cache.py │ ├── nli │ │ ├── __init__.py │ │ ├── processor.py │ │ ├── prompts │ │ │ ├── few_shot_examples.json │ │ │ └── system_message_ev_singlephrase.txt │ │ └── schema.py │ ├── processor │ │ ├── __init__.py │ │ ├── base_processor.py │ │ └── structured_processor.py │ ├── report_to_phrases │ │ ├── __init__.py │ │ ├── processor.py │ │ ├── prompts │ │ │ ├── few_shot_examples.json │ │ │ └── system_message.txt │ │ └── schema.py │ └── text_utils.py │ ├── metric │ ├── __init__.py │ ├── bootstrapping.py │ ├── box_metrics.py │ ├── print_utils.py │ ├── radfact.py │ └── schema.py │ └── paths.py └── tests ├── data_utils └── test_grounded_phrase_list.py ├── llm_utils ├── engine │ ├── test_engine.py │ └── test_redis_cache.py ├── report_to_phrases │ └── test_schema.py └── test_text_utils.py └── metric ├── test_box_metrics.py └── test_radfact.py /.flake8: 
-------------------------------------------------------------------------------- 1 | [flake8] 2 | exclude = 3 | build, 4 | max-line-length = 120 5 | max-complexity = 25 6 | extend-select = B950 7 | extend-ignore = 8 | E731, 9 | W503, 10 | E203, 11 | E501, 12 | E402, 13 | -------------------------------------------------------------------------------- /.github/workflows/pr-checks.yaml: -------------------------------------------------------------------------------- 1 | name: PR checks 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | types: [opened, synchronize, reopened] 9 | branches: 10 | - main 11 | workflow_dispatch: 12 | 13 | # Cancel previous runs of this workflow that are still in progress. 14 | concurrency: 15 | group: ${{ github.ref }}/checks 16 | cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} 17 | 18 | permissions: 19 | # This is required for actions/checkout 20 | contents: read 21 | 22 | jobs: 23 | run_code_quality: 24 | runs-on: ubuntu-20.04 25 | steps: 26 | - uses: actions/checkout@v3 27 | 28 | - uses: conda-incubator/setup-miniconda@v3 29 | with: 30 | activate-environment: radfact 31 | environment-file: dev_environment.yaml 32 | 33 | - name: Install repo packages 34 | run: make setup_packages 35 | shell: bash -el {0} 36 | 37 | - name: Run Flake8 38 | if: ${{ always() }} 39 | run: make flake8 40 | shell: bash -el {0} 41 | 42 | - name: Run Black 43 | if: ${{ always() }} 44 | run: make blackcheck 45 | shell: bash -el {0} 46 | 47 | - name: Run Mypy 48 | if: ${{ always() }} 49 | run: make mypy 50 | shell: bash -el {0} 51 | 52 | run_pytest: 53 | runs-on: ubuntu-20.04 54 | steps: 55 | - uses: actions/checkout@v3 56 | with: 57 | lfs: true 58 | 59 | - uses: conda-incubator/setup-miniconda@v3 60 | with: 61 | activate-environment: radfact 62 | environment-file: dev_environment.yaml 63 | 64 | - name: Install repo packages 65 | run: make setup_packages 66 | shell: bash -el {0} 67 | 68 | - name: Run pytest 69 | run: pytest 70 | shell: 
bash -el {0} 71 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | 164 | config.json 165 | outputs/* 166 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.4.0 4 | hooks: 5 | - id: check-added-large-files 6 | args: ["--maxkb", "50"] 7 | 8 | - repo: https://github.com/psf/black 9 | rev: 23.3.0 10 | hooks: 11 | - id: black 12 | args: ["--config=./pyproject.toml", "--check"] 13 | 14 | - repo: https://github.com/pycqa/flake8 15 | rev: 6.0.0 16 | hooks: 17 | - id: flake8 18 | args: ["--config=./.flake8"] 19 | -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | // See https://go.microsoft.com/fwlink/?LinkId=827846 to learn about workspace recommendations. 3 | // Extension identifier format: ${publisher}.${name}. Example: vscode.csharp 4 | // List of extensions which should be recommended for users of this workspace. 5 | "recommendations": [ 6 | "ms-python.black-formatter", 7 | "ms-python.isort", 8 | "ms-python.python", 9 | "ms-python.vscode-pylance", 10 | "njpwerner.autodocstring", 11 | "usernamehw.errorlens", 12 | "ms-python.mypy-type-checker", 13 | "ms-python.flake8", 14 | "ms-python.pylint" 15 | ], 16 | // List of extensions recommended by VS Code that should not be recommended for users of this workspace. 17 | "unwantedRecommendations": [] 18 | } 19 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: Debug Tests", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "${userHome}/miniconda3/envs/radfact/bin/pytest", 12 | "purpose": [ 13 | "debug-test" 14 | ], 15 | "console": "integratedTerminal", 16 | "justMyCode": true 17 | }, 18 | { 19 | "name": "Python: Current File", 20 | "type": "python", 21 | "request": "launch", 22 | "program": "${file}", 23 | "console": "integratedTerminal", 24 | "justMyCode": true 25 | }, 26 | { 27 | "name": "Debug LLM pipeline for report_to_phrases", 28 | "type": "python", 29 | "request": "launch", 30 | "program": "src/radfact/cli/run_report_to_phrases.py", 31 | "console": "integratedTerminal", 32 | "justMyCode": false, 33 | "args": [ 34 | "processing.start_index=0", 35 | "processing.end_index=10", 36 | "processing.batch_size=2", 37 | "dataset.csv_path=", 38 | ] 39 | }, 40 | { 41 | "name": "Debug run_radfact", 42 | "type": "python", 43 | "request": "launch", 44 | "program": "src/radfact/cli/run_radfact.py", 45 | "console": "integratedTerminal", 46 | "justMyCode": false, 47 | "args": [ 48 | "--input_path=examples/findings_generation_examples.csv", 49 | "--is_narrative_text", 50 | ] 51 | }, 52 | ] 53 | } 54 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "autoDocstring.docstringFormat": "sphinx-notypes", 3 | "files.trimTrailingWhitespace": true, 4 | "files.insertFinalNewline": true, 5 | "files.trimFinalNewlines": true, 6 | "files.watcherExclude": { 7 | "**/.git/objects/**": true, 8 | "**/.git/subtree-cache/**": true, 9 | "**/.mypy_cache/**": true, 10 | "**/.pytest_cache/**": true 11 | }, 12 | "python.testing.unittestEnabled": false, 13 | "python.testing.pytestEnabled": true, 14 | "flake8.args": [ 15 | 
"--config=${workspaceFolder}/.flake8", 16 | ], 17 | "isort.args": [ 18 | "--src=${workspaceFolder}", 19 | "--settings=${workspaceFolder}/pyproject.toml", 20 | "-l=120", 21 | ], 22 | "[python]": { 23 | "editor.rulers": [ 24 | 120 25 | ], 26 | "editor.formatOnSave": true, 27 | "editor.codeActionsOnSave": { 28 | "source.organizeImports": "explicit", 29 | "source.unusedImports": "explicit" 30 | }, 31 | "editor.defaultFormatter": "ms-python.black-formatter", 32 | }, 33 | "black-formatter.args": [ 34 | "--line-length=120", 35 | "--config=${workspaceFolder}/pyproject.toml", 36 | ], 37 | "black-formatter.importStrategy": "fromEnvironment", 38 | "rewrap.wrappingColumn": 120, 39 | "mypy-type-checker.args": [ 40 | "--config-file=${workspaceFolder}/mypy.ini", 41 | ], 42 | "mypy-type-checker.importStrategy": "fromEnvironment", 43 | "mypy-type-checker.preferDaemon": false, 44 | "mypy-type-checker.reportingScope": "workspace", 45 | "python.testing.pytestArgs": [ 46 | "." 47 | ], 48 | } 49 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /Dockerfile_no_conda: -------------------------------------------------------------------------------- 1 | # This Dockerfile specifies the basic image, and installs miniconda. 
2 | 3 | FROM nvcr.io/nvidia/pytorch:22.10-py3 4 | 5 | # Remove that full Anaconda folder to save space 6 | RUN rm -rf /opt/conda 7 | 8 | ARG user=radfact 9 | ARG userhome=/home/${user} 10 | 11 | # Add a user so that we don't run as root in the dev container 12 | RUN useradd -m ${user} && chown -R ${user} ${userhome} 13 | USER ${user} 14 | WORKDIR ${userhome} 15 | 16 | # Add miniconda. 17 | RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ 18 | bash Miniconda3-latest-Linux-x86_64.sh -b -p ${userhome}/miniconda && \ 19 | rm Miniconda3-latest-Linux-x86_64.sh && \ 20 | PATH=${userhome}/miniconda/bin:$PATH && \ 21 | conda init bash zsh && \ 22 | source "${userhome}/miniconda/bin/activate" && \ 23 | conda install -y -n base conda-libmamba-solver && \ 24 | conda config --set solver libmamba 25 | 26 | ENV PATH ${userhome}/miniconda/bin:$PATH 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # A temporary file for environment generation 2 | TEMP_ENV_YAML := _temp_environment.yaml 3 | DEV_ENV_YAML := dev_environment.yaml 4 | BLACK_ARGS := --config pyproject.toml . 5 | # Folder where all downloaded files are cached 6 | CACHE := $(HOME)/.cache 7 | ENV_NAME := radfact 8 | REPO_NAME := $(shell basename `pwd`) 9 | 10 | # Get latest Miniconda 11 | miniconda: 12 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh 13 | bash Miniconda3-latest-Linux-x86_64.sh 14 | 15 | # Install the fast Conda mamba solver 16 | mamba: 17 | conda update -y -n base conda 18 | conda install -y -n base conda-libmamba-solver 19 | conda config --set solver libmamba 20 | 21 | # Create the Conda environment 22 | env: 23 | @echo Removing current Conda environment 24 | conda env remove -n $(ENV_NAME) 25 | @echo Re-building Conda environment from $(DEV_ENV_YAML) 26 | conda env create -n $(ENV_NAME) -f $(DEV_ENV_YAML) 27 | @echo Add all local packages in editable mode 28 | $(MAKE) setup_packages_with_deps 29 | 30 | # Remove and re-create the current environment, based on primary_deps.yaml. 31 | # Then add all packages in editable mode, without installing the dependencies. 32 | env_from_primary: 33 | @echo Removing current Conda environment 34 | conda env remove -n $(ENV_NAME) 35 | @echo Re-building Conda environment from primary_deps.yaml 36 | conda env create -f primary_deps.yaml -n $(ENV_NAME) 37 | 38 | # Run the command that adds all subpackages in editable mode, without adding the dependencies. 
39 | # Dependencies are already installed in the Conda environment. 40 | setup_packages: 41 | conda run -n $(ENV_NAME) --no-capture-output python setup_packages.py --no-deps 42 | 43 | # Run the command that adds all subpackages in editable mode, including the dependencies. This should be used when 44 | # creating the locked environment file. 45 | setup_packages_with_deps: 46 | conda run -n $(ENV_NAME) --no-capture-output python setup_packages.py --add-dev-deps 47 | 48 | # Export the Conda environment to _temp_environment.yaml 49 | # Exclude all the local packages that are possibly installed. 50 | _env_export: 51 | rm -f $(TEMP_ENV_YAML) 52 | @echo "# DO NOT MANUALLY EDIT THIS FILE, USE 'make env_lock' INSTEAD" > $(TEMP_ENV_YAML) 53 | conda env export --no-builds -n $(ENV_NAME) | \ 54 | grep -v "^prefix" | \ 55 | grep -Ev "\s(azure-utils|llm-utils|radfact)==" \ 56 | >> $(TEMP_ENV_YAML); \ 57 | 58 | # Export the Conda environment to dev_environment.yaml 59 | # We need to call Make recursively here for _env_export because in environment locking this will be called twice 60 | # and would be skipped if specified as 'env_export: _env_export' 61 | env_export: 62 | $(MAKE) _env_export 63 | mv $(TEMP_ENV_YAML) $(DEV_ENV_YAML) 64 | 65 | # This make target will re-build the Conda environment based on primary_deps.yaml, install all packages, 66 | # export the locked environment 67 | env_lock_in_container: env_from_primary setup_packages_with_deps env_export 68 | 69 | # Full Docker-based solution to re-create the environments: 70 | env_lock: 71 | @echo If something goes wrong here, you may need to execute 'docker rm -f env-container' afterwards to clean up 72 | @echo Building the base docker image - this will take several minutes 73 | docker build -t devcont -f Dockerfile_no_conda . 
74 | @echo Kill any container that may still be running from previous failed attempts 75 | docker rm -f env-container || true 76 | @echo Start a container with that image, with a sleep command to prevent it from terminating immediately 77 | docker run -d --name env-container devcont sleep infinity 78 | @echo Copy the repository to the container 79 | docker cp `pwd` env-container:/home/radfact 80 | @echo Change the copied files to belong to the "radfact" user that all commands are running under 81 | docker exec -u root env-container chown -R radfact:radfact /home/radfact/$(REPO_NAME) 82 | @echo Starting environment re-build - this will take several minutes 83 | docker exec -it env-container bash -c "cd $(REPO_NAME) && make env_lock_in_container" 84 | @echo Copying updated environment files out of the container 85 | docker cp env-container:/home/radfact/$(REPO_NAME)/$(DEV_ENV_YAML) . 86 | @echo Stopping container 87 | docker rm -f env-container 88 | 89 | 90 | lfs: 91 | @if [ -z "$$(which git-lfs)" ]; then \ 92 | echo "Installing git-lfs"; \ 93 | sudo apt-get install git-lfs ; \ 94 | else \ 95 | echo "git-lfs already installed" ; \ 96 | fi 97 | git lfs install 98 | git lfs pull 99 | 100 | # Run black and reformat all files as necessary 101 | black: 102 | black $(BLACK_ARGS) 103 | 104 | # Run black, but do not reformat files 105 | blackcheck: 106 | black --check $(BLACK_ARGS) 107 | 108 | # Run flake8 109 | flake8: 110 | flake8 --config .flake8 . 111 | 112 | # Run mypy 113 | mypy: 114 | mypy . 
115 | 116 | # Run black to reformat all files, then flake8 to find issues beyond formatting, then mypy 117 | check: black flake8 mypy 118 | -------------------------------------------------------------------------------- /RadFact.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/RadFact/72463ac34442647886f13fad7d04958ab6c7c294/RadFact.png -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp). 16 | 17 | You should receive a response within 24 hours. 
If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # Support 2 | 3 | ## How to file issues and get help 4 | 5 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing issues before filing new issues to avoid duplicates. For new issues, file your bug or feature request as a new Issue. 6 | 7 | For help and questions about using this project, please open a GitHub issue first. 
If necessary, you can email any of the authors of the [MAIRA-2](https://aka.ms/maira-2) paper. 8 | 9 | ## Microsoft Support Policy 10 | 11 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 12 | -------------------------------------------------------------------------------- /configs/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - endpoints: chat_openai 3 | - _self_ 4 | 5 | llm_api: 6 | api_version: "2023-07-01-preview" 7 | max_retries: 10 8 | temperature: 0.0 9 | max_tokens: 1024 10 | n_completions: 1 11 | top_p: 1.0 12 | frequency_penalty: 0.0 13 | presence_penalty: 0.0 14 | stop: null 15 | timeout: null 16 | 17 | processing: 18 | index_col: study_id 19 | batch_size: 100 20 | start_index: 0 21 | end_index: null 22 | output_filename: "outputs.json" 23 | 24 | # The type of cache that should be set for langchain. This can be either "redis" or "sqlite". 25 | # Sqlite cache is useful for local development, it will be written to ~/.langchain.db 26 | # Redis cache is useful to share state across many evaluation runs in AzureML 27 | langchain_cache_type: null 28 | -------------------------------------------------------------------------------- /configs/endpoints/azure_chat_openai.yaml: -------------------------------------------------------------------------------- 1 | ENDPOINT_0: 2 | type: "AZURE_CHAT_OPENAI" # this specifies which type of langchain model to use 3 | url: "" # change this with your own endpoint URL 4 | deployment_name: "gpt-4" 5 | 6 | # we support different types of authentication to the endpoint. 7 | # 1. You can set the api_key as an environment variable, in the case of multiple endpoints, you can set the 8 | # api_key_env_var_name to the name of the environment variable. 9 | api_key_env_var_name: "API_KEY_AZURE_0" 10 | # 2. 
Alternatively, you can set the api_key as a secret in Azure KeyVault, update the keyvault_secret_name to the 11 | # name of the secret in Azure KeyVault. 12 | # keyvault_secret_name: "" 13 | # 3. If none of the above are set, we fall back to creating a token provider assuming you have the necessary azure 14 | # credentials set up. 15 | 16 | # If you have access to multiple endpoints, you can vary the speed_factor to control for the amount of data assigned 17 | # to each endpoint. This is used for efficient data sharing across multiple endpoints with different throughput. 18 | speed_factor: 1.0 19 | 20 | # Alternatively, you can set the num_parallel_processes to control the number of parallel processes that can be 21 | # spawned to handle requests to this endpoint. This is useful to parallelize requests to a single endpoint when 22 | # the endpoint has a high throughput. By default, all calls are sequential. This parameter enables parallelism. 23 | num_parallel_processes: 1 24 | 25 | # If you have access to multiple endpoints, you can add more endpoints as follows. Change the speed factors as needed 26 | # to control the amount of data assigned to each endpoint. 
27 | 28 | # ENDPOINT_1: 29 | # type: "AZURE_CHAT_OPENAI" 30 | # url: "" # change this with your own endpoint URL 31 | # deployment_name: "gpt-4" 32 | # api_key_env_var_name: "API_KEY_AZURE_1" 33 | # speed_factor: 1.33 34 | # ENDPOINT_2: 35 | # type: "AZURE_CHAT_OPENAI" 36 | # # url: "" # change this with your own endpoint URL 37 | # deployment_name: "gpt-4" 38 | # api_key_env_var_name: "API_KEY_AZURE_2" 39 | # speed_factor: 3.66 40 | -------------------------------------------------------------------------------- /configs/endpoints/chat_openai.yaml: -------------------------------------------------------------------------------- 1 | ENDPOINT_0: 2 | type: "CHAT_OPENAI" # this specifies which type of langchain model to use 3 | url: "" # change this with your own endpoint URL 4 | deployment_name: "llama3-70b" 5 | 6 | # we support different types of authentication to the endpoint. 7 | # 1. You can set the api_key as an environment variable, in the case of multiple endpoints, you can set the 8 | # api_key_env_var_name to the name of the environment variable. 9 | api_key_env_var_name: "API_KEY_CHAT_0" 10 | # 2. Alternatively, you can set the api_key as a secret in Azure KeyVault, update the keyvault_secret_name to the 11 | # name of the secret in Azure KeyVault. 12 | # keyvault_secret_name: "" 13 | # 3. If none of the above are set, we fall back to creating a token provider assuming you have the necessary azure 14 | # credentials set up. 15 | 16 | # If you have access to multiple endpoints, you can vary the speed_factor to control for the amount of data assigned 17 | # to each endpoint. This is used for efficient data sharing across multiple endpoints with different throughput. 18 | speed_factor: 1.0 19 | 20 | # Alternatively, you can set the num_parallel_processes to control the number of parallel processes that can be 21 | # spawned to handle requests to this endpoint. 
This is useful to parallelize requests to a single endpoint when 22 | # the endpoint has a high throughput. By default, all calls are sequential. This parameter enables parallelism. 23 | num_parallel_processes: 10 24 | 25 | # If you have access to multiple endpoints, you can add more endpoints as follows. Change the speed factors as needed 26 | # to control the amount of data assigned to each endpoint. 27 | 28 | # ENDPOINT_1: 29 | # type: "CHAT_OPENAI" 30 | # url: "" # change this with your own endpoint URL 31 | # deployment_name: "llama3-70b" 32 | # api_key_env_var_name: "API_KEY_CHAT_1" 33 | # speed_factor: 1.33 34 | # ENDPOINT_2: 35 | # type: "CHAT_OPENAI" 36 | # # url: "" # change this with your own endpoint URL 37 | # deployment_name: "llama3-70b" 38 | # api_key_env_var_name: "API_KEY_CHAT_2" 39 | # speed_factor: 3.66 40 | -------------------------------------------------------------------------------- /configs/radfact.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - default 5 | - override endpoints: chat_openai 6 | - _self_ 7 | -------------------------------------------------------------------------------- /configs/report_to_phrases.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - default 5 | - override endpoints: azure_chat_openai 6 | - _self_ 7 | 8 | dataset: 9 | name: reports_to_phrases 10 | csv_path: "" 11 | -------------------------------------------------------------------------------- /dev_environment.yaml: -------------------------------------------------------------------------------- 1 | # DO NOT MANUALLY EDIT THIS FILE, USE 'make env_lock' INSTEAD 2 | name: radfact 3 | channels: 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1 8 | - _openmp_mutex=4.5 9 | - bzip2=1.0.8 10 | - ca-certificates=2024.8.30 11 | - ld_impl_linux-64=2.43 12 | - libffi=3.4.2 13 
| - libgcc=14.2.0 14 | - libgcc-ng=14.2.0 15 | - libgomp=14.2.0 16 | - libnsl=2.0.1 17 | - libsqlite=3.47.0 18 | - libuuid=2.38.1 19 | - libxcrypt=4.4.36 20 | - libzlib=1.3.1 21 | - ncurses=6.5 22 | - openssl=3.4.0 23 | - pip=23.0.1 24 | - python=3.10.15 25 | - readline=8.2 26 | - setuptools=69.5.1 27 | - tk=8.6.13 28 | - wheel=0.45.0 29 | - xz=5.2.6 30 | - pip: 31 | - absl-py==2.1.0 32 | - adal==1.2.7 33 | - aiohappyeyeballs==2.4.3 34 | - aiohttp==3.11.7 35 | - aiosignal==1.3.1 36 | - alembic==1.14.0 37 | - antlr4-python3-runtime==4.9.3 38 | - anyio==4.6.2.post1 39 | - applicationinsights==0.11.10 40 | - argcomplete==3.1.6 41 | - asttokens==2.4.1 42 | - async-timeout==4.0.3 43 | - attrs==24.2.0 44 | - azure-ai-ml==1.22.4 45 | - azure-appconfiguration==1.1.1 46 | - azure-batch==14.0.0 47 | - azure-cli==2.53.1 48 | - azure-cli-core==2.53.1 49 | - azure-cli-telemetry==1.1.0 50 | - azure-common==1.1.28 51 | - azure-core==1.32.0 52 | - azure-cosmos==3.2.0 53 | - azure-data-tables==12.4.0 54 | - azure-datalake-store==0.0.53 55 | - azure-graphrbac==0.60.0 56 | - azure-identity==1.14.1 57 | - azure-keyvault==4.2.0 58 | - azure-keyvault-administration==4.3.0 59 | - azure-keyvault-certificates==4.7.0 60 | - azure-keyvault-keys==4.8.0b2 61 | - azure-keyvault-secrets==4.7.0 62 | - azure-loganalytics==0.1.1 63 | - azure-mgmt-advisor==9.0.0 64 | - azure-mgmt-apimanagement==4.0.0 65 | - azure-mgmt-appconfiguration==3.0.0 66 | - azure-mgmt-appcontainers==2.0.0 67 | - azure-mgmt-applicationinsights==1.0.0 68 | - azure-mgmt-authorization==4.0.0 69 | - azure-mgmt-batch==17.0.0 70 | - azure-mgmt-batchai==7.0.0b1 71 | - azure-mgmt-billing==6.0.0 72 | - azure-mgmt-botservice==2.0.0 73 | - azure-mgmt-cdn==12.0.0 74 | - azure-mgmt-cognitiveservices==13.5.0 75 | - azure-mgmt-compute==30.0.0 76 | - azure-mgmt-containerinstance==10.1.0 77 | - azure-mgmt-containerregistry==10.1.0 78 | - azure-mgmt-containerservice==26.0.0 79 | - azure-mgmt-core==1.5.0 80 | - azure-mgmt-cosmosdb==9.2.0 81 | - 
azure-mgmt-databoxedge==1.0.0 82 | - azure-mgmt-datalake-nspkg==3.0.1 83 | - azure-mgmt-datalake-store==0.5.0 84 | - azure-mgmt-datamigration==10.0.0 85 | - azure-mgmt-devtestlabs==4.0.0 86 | - azure-mgmt-dns==8.0.0 87 | - azure-mgmt-eventgrid==10.2.0b2 88 | - azure-mgmt-eventhub==10.1.0 89 | - azure-mgmt-extendedlocation==1.0.0b2 90 | - azure-mgmt-hdinsight==9.0.0 91 | - azure-mgmt-imagebuilder==1.2.0 92 | - azure-mgmt-iotcentral==10.0.0b2 93 | - azure-mgmt-iothub==2.3.0 94 | - azure-mgmt-iothubprovisioningservices==1.1.0 95 | - azure-mgmt-keyvault==10.2.3 96 | - azure-mgmt-kusto==0.3.0 97 | - azure-mgmt-loganalytics==13.0.0b4 98 | - azure-mgmt-managedservices==1.0.0 99 | - azure-mgmt-managementgroups==1.0.0 100 | - azure-mgmt-maps==2.0.0 101 | - azure-mgmt-marketplaceordering==1.1.0 102 | - azure-mgmt-media==9.0.0 103 | - azure-mgmt-monitor==5.0.1 104 | - azure-mgmt-msi==7.0.0 105 | - azure-mgmt-netapp==10.1.0 106 | - azure-mgmt-network==28.0.0 107 | - azure-mgmt-nspkg==3.0.2 108 | - azure-mgmt-policyinsights==1.1.0b4 109 | - azure-mgmt-privatedns==1.0.0 110 | - azure-mgmt-rdbms==10.2.0b18 111 | - azure-mgmt-recoveryservices==2.5.0 112 | - azure-mgmt-recoveryservicesbackup==7.0.0 113 | - azure-mgmt-redhatopenshift==1.3.0 114 | - azure-mgmt-redis==14.1.0 115 | - azure-mgmt-resource==23.1.0b2 116 | - azure-mgmt-search==9.1.0 117 | - azure-mgmt-security==5.0.0 118 | - azure-mgmt-servicebus==8.2.1 119 | - azure-mgmt-servicefabric==1.0.0 120 | - azure-mgmt-servicefabricmanagedclusters==1.0.0 121 | - azure-mgmt-servicelinker==1.2.0b1 122 | - azure-mgmt-signalr==1.1.0 123 | - azure-mgmt-sql==4.0.0b12 124 | - azure-mgmt-sqlvirtualmachine==1.0.0b5 125 | - azure-mgmt-storage==21.0.0 126 | - azure-mgmt-synapse==2.1.0b5 127 | - azure-mgmt-trafficmanager==1.0.0 128 | - azure-mgmt-web==7.0.0 129 | - azure-multiapi-storage==1.2.0 130 | - azure-nspkg==3.0.2 131 | - azure-storage-blob==12.19.0 132 | - azure-storage-common==1.4.2 133 | - azure-storage-file-datalake==12.14.0 134 | 
- azure-storage-file-share==12.20.0 135 | - azure-synapse-accesscontrol==0.5.0 136 | - azure-synapse-artifacts==0.17.0 137 | - azure-synapse-managedprivateendpoints==0.4.0 138 | - azure-synapse-spark==0.2.0 139 | - azureml-automl-core==1.58.0 140 | - azureml-core==1.58.0 141 | - azureml-dataprep==5.1.6 142 | - azureml-dataprep-native==41.0.0 143 | - azureml-dataprep-rslex==2.22.5 144 | - azureml-dataset-runtime==1.58.0 145 | - azureml-mlflow==1.58.0.post3 146 | - azureml-pipeline==1.58.0 147 | - azureml-pipeline-core==1.58.0 148 | - azureml-pipeline-steps==1.58.0 149 | - azureml-sdk==1.58.0 150 | - azureml-telemetry==1.58.0 151 | - azureml-tensorboard==1.58.0 152 | - azureml-train-automl-client==1.58.0 153 | - azureml-train-core==1.58.0 154 | - azureml-train-restclients-hyperdrive==1.58.0 155 | - backports-tempfile==1.0 156 | - backports-weakref==1.0.post1 157 | - bcrypt==4.2.1 158 | - black==24.10.0 159 | - blinker==1.9.0 160 | - cachetools==5.5.0 161 | - certifi==2024.8.30 162 | - cffi==1.17.1 163 | - cfgv==3.4.0 164 | - chardet==3.0.4 165 | - charset-normalizer==3.4.0 166 | - click==8.1.7 167 | - cloudpickle==2.2.1 168 | - colorama==0.4.6 169 | - comm==0.2.2 170 | - conda-merge==0.3.0 171 | - contextlib2==21.6.0 172 | - contourpy==1.3.1 173 | - cryptography==43.0.3 174 | - cycler==0.12.1 175 | - databricks-sdk==0.38.0 176 | - dataclasses-json==0.6.7 177 | - debugpy==1.8.8 178 | - decorator==5.1.1 179 | - deprecated==1.2.15 180 | - distlib==0.3.9 181 | - distro==1.9.0 182 | - docker==7.1.0 183 | - entrypoints==0.4 184 | - exceptiongroup==1.2.2 185 | - executing==2.1.0 186 | - fabric==2.7.1 187 | - filelock==3.16.1 188 | - flake8==7.1.1 189 | - flask==3.1.0 190 | - fonttools==4.55.0 191 | - frozenlist==1.5.0 192 | - fusepy==3.0.1 193 | - gitdb==4.0.11 194 | - gitpython==3.1.43 195 | - google-api-core==2.23.0 196 | - google-auth==2.36.0 197 | - googleapis-common-protos==1.66.0 198 | - graphene==3.4.3 199 | - graphql-core==3.2.5 200 | - graphql-relay==3.2.0 201 | - 
greenlet==3.1.1 202 | - grpcio==1.68.0 203 | - gunicorn==22.0.0 204 | - h11==0.14.0 205 | - hi-ml-azure==0.5.2 206 | - httpcore==1.0.7 207 | - httpx==0.27.2 208 | - humanfriendly==10.0 209 | - hydra-core==1.3.2 210 | - identify==2.6.2 211 | - idna==3.10 212 | - importlib-metadata==7.2.1 213 | - importlib-resources==6.4.0 214 | - iniconfig==2.0.0 215 | - invoke==1.7.3 216 | - ipykernel==6.29.5 217 | - ipython==8.29.0 218 | - isodate==0.7.2 219 | - itsdangerous==2.2.0 220 | - javaproperties==0.5.2 221 | - jedi==0.19.2 222 | - jeepney==0.8.0 223 | - jinja2==3.1.4 224 | - jiter==0.7.1 225 | - jmespath==1.0.1 226 | - joblib==1.4.2 227 | - jsondiff==2.0.0 228 | - jsonpatch==1.33 229 | - jsonpickle==3.4.2 230 | - jsonpointer==3.0.0 231 | - jsonschema==4.23.0 232 | - jsonschema-specifications==2024.10.1 233 | - jupyter-client==8.6.3 234 | - jupyter-core==5.7.2 235 | - kiwisolver==1.4.7 236 | - knack==0.11.0 237 | - langchain==0.2.17 238 | - langchain-community==0.2.19 239 | - langchain-core==0.2.43 240 | - langchain-openai==0.1.25 241 | - langchain-text-splitters==0.2.4 242 | - langsmith==0.1.144 243 | - mako==1.3.6 244 | - markdown==3.7 245 | - markupsafe==3.0.2 246 | - marshmallow==3.23.1 247 | - matplotlib==3.9.2 248 | - matplotlib-inline==0.1.7 249 | - mccabe==0.7.0 250 | - mlflow==2.15.1 251 | - mlflow-skinny==2.15.1 252 | - mock==5.1.0 253 | - msal==1.24.0b2 254 | - msal-extensions==1.0.0 255 | - msrest==0.7.1 256 | - msrestazure==0.6.4.post1 257 | - multidict==6.1.0 258 | - mypy==1.13.0 259 | - mypy-extensions==1.0.0 260 | - ndg-httpsclient==0.5.1 261 | - nest-asyncio==1.6.0 262 | - nodeenv==1.9.1 263 | - numpy==1.23.5 264 | - oauthlib==3.2.2 265 | - omegaconf==2.3.0 266 | - openai==1.55.0 267 | - opencensus==0.11.4 268 | - opencensus-context==0.1.3 269 | - opencensus-ext-azure==1.1.13 270 | - opencensus-ext-logging==0.1.1 271 | - opentelemetry-api==1.28.2 272 | - opentelemetry-sdk==1.28.2 273 | - opentelemetry-semantic-conventions==0.49b2 274 | - orjson==3.10.11 
275 | - packaging==24.2 276 | - pandas==2.2.3 277 | - pandas-stubs==2.2.3.241009 278 | - param==1.13.0 279 | - paramiko==3.5.0 280 | - parso==0.8.4 281 | - pathlib2==2.3.7.post1 282 | - pathspec==0.12.1 283 | - pexpect==4.9.0 284 | - pillow==11.0.0 285 | - pkginfo==1.11.2 286 | - platformdirs==4.3.6 287 | - pluggy==1.5.0 288 | - portalocker==2.10.1 289 | - pre-commit==4.0.1 290 | - prompt-toolkit==3.0.48 291 | - propcache==0.2.0 292 | - proto-plus==1.25.0 293 | - protobuf==3.20.3 294 | - psutil==5.9.8 295 | - ptyprocess==0.7.0 296 | - pure-eval==0.2.3 297 | - pyarrow==15.0.2 298 | - pyasn1==0.6.1 299 | - pyasn1-modules==0.4.1 300 | - pycodestyle==2.12.1 301 | - pycomposefile==0.0.32 302 | - pycparser==2.22 303 | - pydantic==1.10.17 304 | - pydash==8.0.4 305 | - pyflakes==3.2.0 306 | - pygithub==1.59.1 307 | - pygments==2.18.0 308 | - pyjwt==2.10.0 309 | - pynacl==1.5.0 310 | - pyopenssl==24.2.1 311 | - pyparsing==3.2.0 312 | - pysocks==1.7.1 313 | - pytest==8.3.3 314 | - pytest-lazy-fixture==0.6.3 315 | - python-dateutil==2.9.0.post0 316 | - pytz==2024.2 317 | - pyyaml==6.0.2 318 | - pyzmq==26.2.0 319 | - querystring-parser==1.2.4 320 | - redis==5.2.0 321 | - referencing==0.35.1 322 | - regex==2024.11.6 323 | - requests==2.32.3 324 | - requests-oauthlib==2.0.0 325 | - requests-toolbelt==1.0.0 326 | - rpds-py==0.21.0 327 | - rsa==4.9 328 | - ruamel-yaml==0.18.6 329 | - ruamel-yaml-clib==0.2.12 330 | - scikit-learn==1.5.2 331 | - scipy==1.14.1 332 | - scp==0.13.6 333 | - secretstorage==3.3.3 334 | - semver==2.13.0 335 | - six==1.16.0 336 | - smmap==5.0.1 337 | - sniffio==1.3.1 338 | - sqlalchemy==2.0.36 339 | - sqlparse==0.5.2 340 | - sshtunnel==0.1.5 341 | - stack-data==0.6.3 342 | - strictyaml==1.7.3 343 | - tabulate==0.9.0 344 | - tenacity==8.5.0 345 | - tensorboard==2.18.0 346 | - tensorboard-data-server==0.7.2 347 | - threadpoolctl==3.5.0 348 | - tiktoken==0.8.0 349 | - tomli==2.1.0 350 | - tornado==6.4.1 351 | - tqdm==4.67.0 352 | - traitlets==5.14.3 353 | - 
types-pillow==10.2.0.20240822 354 | - types-pytz==2024.2.0.20241003 355 | - types-pyyaml==6.0.12.20240917 356 | - types-requests==2.32.0.20241016 357 | - types-tqdm==4.67.0.20241119 358 | - typing-extensions==4.12.2 359 | - typing-inspect==0.9.0 360 | - tzdata==2024.2 361 | - urllib3==2.2.3 362 | - virtualenv==20.27.1 363 | - wcwidth==0.2.13 364 | - websocket-client==1.3.3 365 | - werkzeug==3.1.3 366 | - wrapt==1.16.0 367 | - xmltodict==0.14.2 368 | - yarl==1.18.0 369 | - zipp==3.21.0 370 | -------------------------------------------------------------------------------- /examples/findings_generation_examples.csv: -------------------------------------------------------------------------------- 1 | example_id,prediction,target 2 | 0,The lungs are well expanded. Noted presence of the large nodular density in the right midlung field measuring 3.5 x 3.5 cm. This is a new finding. The cardiac size is mildly enlarged. With mild atherosclerotic changes of the aorta. Pulmonary vascularity is normal. The pulmonary interstitial markings are normal.,The bony structures are intact. Degenerative changes spine noted. There is indistinctness of the right hemidiaphragm suggesting subpulmonic effusion. Ovoid density superimposed lung fields in the mid and lower lung fields. These measure in the region to half to 3 cm in size and pulmonary nodule or pleural-based abnormalities. Cardiac size is enlarged. Ventricular fullness hilar vascular fullness. Unfolding aorta. There is no hilar or mediastinal adenopathy. 3 | 1,The patient is status post median sternotomy and CABG. The heart size and pulmonary vasculature are normal. No consolidations or effusions are present. The bony thorax is otherwise unremarkable.,Patient is status post median sternotomy and CABG. The lungs are well aerated without focal consolidation or discrete pulmonary mass. No pleural effusion or pneumothorax is identified. The cardiomediastinal silhouette and pulmonary vasculature appear within normal limits. 
No acute or suspicious osseous lesion is identified. 4 | 2,"The lungs are adequately inflated. No focal airspace opacity, pleural effusion or pneumothorax. Unchanged appearance of the small airways. The cardiac silhouette is at the upper limit of normal for size. Atherosclerotic calcifications are present at the aorta. Surgical clips project over the right upper quadrant. Degenerative changes are present at the spine.",There is mild hyperinflation with mild peribronchial cuffing and airway wall thickening without consolidation. Cardiomediastinal silhouette and pulmonary vasculature within normal limits. Osseous structures are unremarkable. 5 | 3,"The heart, lungs and mediastinal structures are within normal limits. There are no osseous abnormalities.","The heart, lungs and mediastinal structures are within normal limits. There are no osseous abnormalities." 6 | 4,Cardiac size and pulmonary vasculature are within normal limits. Lungs are clear and well expanded. Visualized osseous structures are unremarkable.,Cardiac size and pulmonary vasculature are within normal limits. Lungs are clear and well expanded. Visualized osseous structures are unremarkable. 7 | 5,"Since the previous study, there has been interval resolution of the right lower lobe infiltrate. No active infiltrate or consolidation is demonstrated. Pulmonary vascularity is normal. The hila are not enlarged. The mediastinum is not widened. There is tortuosity of the thoracic aorta with calcification. The heart size is normal. The bones appear intact.",There is mild hyperinflation with prominence of the interstitial markings in both lung fields with parenchymal scarring right lower lobe unchanged since previous study. No active infiltrate or consolidation is demonstrated. Pulmonary vascularity is normal. The hila are not enlarged. There is tortuosity aorta with calcification aortic knob. The mediastinum is not widened. The heart size is slightly enlarged. The bones appear intact. 
8 | 6,The bony structures are intact. There are mild degenerative changes of the spine. The lungs show no acute infiltrate or mass. There is no effusion. The pulmonary interstitial markings are normal. The diaphragm is smooth. The cardiac size is mildly enlarged. There is uncoiling of the aorta compatible with hypertension. A solitary pacemaker is noted with the tip in the right ventricle. There is no CHF. There is no hilar or mediastinal adenopathy.,The bony structures are intact. The lungs show no acute infiltrate or mass. There is no effusion. The pulmonary interstitial markings are normal. The diaphragm is smooth. The cardiac size is mildly enlarged. There is uncoiling the aorta indicative of hypertension. Pacemakers are noted in the right atrium and ventricle. A defibrillator is also present. There is no CHF. There is no hilar or mediastinal adenopathy. 9 | 7,"The bony structures are intact. There is again noted pleural scarring changes at the left base with blunting left costophrenic angle. This appears to be a chronic process. Postoperative changes in the left hemithorax again noted. The lungs show no active infiltrate, mass or effusion. Diaphragms are sharp. Cardiac size is normal. There is no hilar or mediastinal adenopathy.",The lungs are well expanded. Previously noted infiltrative processes in both lung bases are clear. At this time there is a linear density left lung base representing platelike atelectasis. The cardiac size is normal. With mild atherosclerotic changes of the aorta. Pulmonary vascularity is normal. . The pulmonary interstitial markings are normal. No active infiltrate or consolidation is demonstrated. 10 | 8,Cardiac silhouette: The cardiac silhouette is not enlarged. Mediastinal contours: The mediastinal contours are unremarkable. Lung fields: Clear. Visualized osseous structures: No significant osseous lesion is identified. Other: None.,Cardiac silhouette: The cardiac silhouette is not enlarged. 
Mediastinal contours: Tortuous and calcified aorta. Lung fields: No pleural effusions or focal consolidations. Visualized osseous structures: No significant osseous lesion is identified. Other: None. 11 | 9,The bony structures are intact. There are degenerative changes of the spine. The lungs show no acute infiltrate or mass. There is no effusion. The pulmonary interstitial markings are normal. The diaphragm is smooth. The cardiac size is mildly enlarged. There is uncoiling of the aorta with calcification. There is no hilar or mediastinal adenopathy.,The bony structures are demineralized. There are extensive degenerative changes of the thoracic spine. The lungs show no acute infiltrate or mass. There is no effusion. The pulmonary interstitial markings are mildly accentuated. The diaphragm is smooth. The cardiac size is mildly enlarged. There is uncoiling of the aorta with calcification. There is no CHF. There is no hilar or mediastinal adenopathy. 12 | 10,There is linear scarring versus subsegmental atelectasis in the right lower lobe. No active infiltrate or consolidation is demonstrated. Pulmonary vascularity is normal. The hila are not enlarged. The mediastinum is not widened. The heart size is normal. The bones appear intact.,There is scarring versus subsegmental atelectasis in the right lower lobe. No active infiltrate or consolidation is demonstrated. Pulmonary vascularity is normal. The hila are not enlarged. The mediastinum is not widened. The heart size is normal. The bones appear intact. 13 | 11,Two views of the chest are compared to the previous study. The cardiomediastinal silhouette appears unremarkable. There are abnormal increased interstitial markings throughout both lungs which are stable.,"The cardiomediastinal silhouette is within normal limits. Calcified parenchymal granulomas are noted. Mildly prominent perihilar interstitial markings are noted and are nonspecific. 
A few scattered tiny nodular densities are noted including within the lateral aspect of the left upper lobe. No lobar consolidation, effusion, or pneumothorax is present. Bones are normal for the patient's age aside from pectus oxygen vitamin deformity and mild scoliosis." 14 | 12,"Postsurgical changes/Catheters and Support Devices: Right-sided Port-A-Cath terminates at the level of the right atrium. Left-sided cardiac pacing device. Lungs and Pleural Spaces: There is no pneumonia, edema or pneumothorax. Heart/Mediastinum: Evaluation the cardiac silhouette is limited on this projection. Other: The visualized osseous structures are unremarkable.","Postsurgical changes/Catheters and Support Devices: Right subclavian Port-A-Cath. Left chest wall pacemaker. Tendon anchor in the right shoulder. Lungs and Pleural Spaces: There is no pneumonia, edema or pneumothorax. Heart/Mediastinum: Mediastinum is not adequately evaluated on portable chest radiography. Other: No other significant findings." 15 | 13,The bony structures are intact. There are degenerative changes of the spine. The lungs show no acute infiltrate or mass. There is no effusion. The pulmonary interstitial markings are normal. The diaphragm is smooth. The cardiac size is mildly enlarged. There is uncoiling of the aorta compatible with hypertension. There is no hilar or mediastinal adenopathy.,The bony structures are intact. There are mild degenerative changes of the spine. The lungs show no acute infiltrate or mass. There is no effusion. The pulmonary interstitial markings are normal. The diaphragm is smooth. The cardiac size is mildly enlarged. Uncoiling the aorta is indicative of hypertension. There is calcification within the knob. There is no hilar or mediastinal adenopathy. 16 | 14,The bony structures are intact. There are degenerative changes of the spine. The lungs show no acute infiltrate or mass. There is no effusion. The pulmonary interstitial markings are normal. The diaphragm is smooth. 
The cardiac size is normal. There is uncoiling the aorta indicative of hypertension. There is no hilar or mediastinal adenopathy.,The bony structures are intact. There is a healed fracture of the right clavicle. There are degenerative changes of the thoracic spine. The lungs show no acute infiltrate or mass. There is no effusion. The pulmonary interstitial markings are mildly accentuated. The diaphragm is smooth. The cardiac size is normal. There is mild uncoiling of the aorta suggesting hypertension. There is no hilar or mediastinal adenopathy. 17 | 15,Two views of the chest are provided for interpretation. Heart size is normal. There is a nodular density overlying the right lower lung field measuring up to about 1.6 cm. Lungs are otherwise clear.,"Cardiac and mediastinal silhouettes are stable. A small nodular pulmonary opacity is present in the peripheral right lower lung not seen on the previous study. There is prominence the pulmonary interstitium. Chain staples are present in the right lower chest. No confluent infiltrates, pleural effusions, or pneumothorax seen." 18 | 16,Cardiac silhouette is normal in size. Hilar contours are unremarkable. Vascularity is normal. No infiltrate or pleural fluid is evident.,Cardiac silhouette is normal in size. Hilar contours are unremarkable. Vascularity is normal. No infiltrate or pleural fluid is evident. No pneumothorax is noted. 19 | 17,"A left-sided dual lead cardiac device is present. Surgical clips project over the lateral left chest wall. Surgical clips are seen in the right upper abdomen, likely status post cholecystectomy. The lungs are hyperlucent and hyperinflated with flattening of the diaphragms. There is no consolidation. There is no pleural effusion. No pneumothorax seen. The cardiomediastinal silhouette is enlarged. The visualized osseous structures demonstrates multilevel thoracolumbar spine degenerative disease.","Surgical clips projects over the upper, lateral chest wall. 
A dual lead cardiac device projects over the upper left hemithorax. Bibasilar subsegmental atelectasis is present. There is no pleural effusion. No pneumothorax seen. The cardiomediastinal silhouette is within normal limits. The visualized osseous structures demonstrates multilevel thoracolumbar spine degenerative disease." 20 | 18,"The bony structures are intact. Degenerative changes spine The lungs show no active infiltrate, mass or effusion. Diaphragms are sharp. Cardiac size is normal. There are pacemakers in the right atrium and ventricle. One of these pacemakers is new. There is no hilar or mediastinal adenopathy.","The bony structures are intact. The lungs show no active infiltrate, mass or effusion. Diaphragms are sharp. Cardiac size is normal. The previously noted cardiomegaly has resolved. Pacemaker is now present with electrodes at right atrium and right ventricle. There is no hilar or mediastinal adenopathy." 21 | 19,The bony structures are osteoporotic. There are degenerative changes of the spine. The lungs show diffuse interstitial lung disease with no acute infiltrate or mass. There is no effusion. The diaphragm is smooth. The cardiac size is moderately enlarged. There is evidence of coronary bypass graft. There is uncoiling the aorta with calcification. There is no CHF. There is no hilar or mediastinal adenopathy.,"Lung parenchyma/pleura: There is increased interstitial density with reticulonodular pattern, mostly peripheral, worse in the right upper lobe. The hemidiaphragms become less distinctive. There is no pleural effusion. Central airways are patent. Mediastinum/hilum: Mediastinal and hilar densities are within normal limits. Cardiac: There is interval development of cardiomegaly. Heart size measures as wide as 17 cm. Skeletal: There is progressive loss of vertebral body heights or mild compressions in the lower dorsal vertebral bodies." 
22 | -------------------------------------------------------------------------------- /getting_started.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# RadFact Example\n", 8 | "\n", 9 | "Here we show an example of running RadFact for evaluation of either findings generation or grounded reporting evaluation.\n", 10 | "\n", 11 | "## Endpoint setup\n", 12 | "\n", 13 | "RadFact scores in the MAIRA-2 paper are computed using `Llama-3-70b-Instruct` for entailment verification and GPT-4 for report to phrase conversion.\n", 14 | "\n", 15 | "* Edit [`configs/endpoints/azure_chat_openai.yaml`](configs/endpoints/azure_chat_openai.yaml) to configure the endpoints for the Azure Chat API. This will be used by default for parsing the reports into phrases.\n", 16 | "* Edit [`configs/endpoints/chat_openai.yaml`](configs/endpoints/chat_openai.yaml) to configure the endpoints for the Chat API. This will be used by default for entailment verification. \n", 17 | "* Set env variable `API_KEY` if you want to use key-based authentication for these endpoints. In case you're using multiple endpoints, use different env variables for each endpoint, e.g., `API_KEY_CHAT_OPENAI` and `API_KEY_AZURE_CHAT_OPENAI`. Make sure to update the corresponding endpoint config files to use these env variable names in `api_key_env_var_name`.\n", 18 | "* Update `endpoints` in [`configs/radfact.yaml`](configs/radfact.yaml) and [`configs/report_to_phrases.yaml`](configs/report_to_phrases.yaml) to use either `ChatOpenAI` or `AzureChatOpenAI` endpoints as available.\n", 19 | "\n", 20 | "See the [README](README.md#2-endpoint-llm-setup) for more detailed setup instructions."
21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "import json\n", 30 | "import pandas as pd\n", 31 | "from radfact.data_utils.grounded_phrase_list import GroundedPhraseList\n", 32 | "from radfact.metric.radfact import RadFactMetric\n", 33 | "from radfact.metric.bootstrapping import MetricBootstrapper\n", 34 | "from radfact.metric.print_utils import print_bootstrap_results, print_results\n", 35 | "from radfact.paths import EXAMPLES_DIR" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "metadata": {}, 41 | "source": [ 42 | "## Findings Generation Evaluation\n", 43 | "\n", 44 | "We provide an example csv in [`findings_generation_examples.csv`](examples/findings_generation_examples.csv) with columns `example_id`, `prediction` (model generation), `target` (ground truth).\n", 45 | "\n", 46 | "RadFact expects `candidates` (generations) and `references` (ground truths) in a dictionary where keys are an identifier, typically the study id. We use `example_id` here. `candidates` and `references` are expected to be strings corresponding to the predicted and target findings sections. They will first get converted into phrases using the report to phrase conversion prompts and then undergo entailment verification to get RadFact scores." 
47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "findings_generation_examples = pd.read_csv(EXAMPLES_DIR / 'findings_generation_examples.csv')\n", 56 | "display(findings_generation_examples.head(2))\n", 57 | "candidates_fg = findings_generation_examples.set_index(\"example_id\")[\"prediction\"].to_dict()\n", 58 | "references_fg = findings_generation_examples.set_index(\"example_id\")[\"target\"].to_dict()" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": {}, 64 | "source": [ 65 | "For findings generation, when we initialise the metric we set `is_narrative_text=True` to instruct it to first perform report-to-phrase conversion." 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "radfact_metric_for_fg = RadFactMetric(is_narrative_text=True)" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "`logical_f1_fg` and `radfact_scores_fg` can directly be obtained using the [`compute_metric_score`](radfact/src/radfact/metric/radfact.py#L369) method as shown below.\n", 82 | "\n", 83 | "```python \n", 84 | "logical_f1_fg, radfact_scores_fg = radfact_metric_for_fg.compute_metric_score(candidates, references)\n", 85 | "```\n", 86 | "This calls [`compute_results_per_sample`](radfact/src/radfact/metric/radfact.py#L284) and [`aggregate_results`](radfact/src/radfact/metric/radfact.py#L355) under the hood. However, we break it down explicitly in this example to be able to reuse the per sample results for bootstrapping."
87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "results_per_sample_fg = radfact_metric_for_fg.compute_results_per_sample(candidates_fg, references_fg)\n", 96 | "logical_f1_fg, radfact_scores_fg = radfact_metric_for_fg.aggregate_results(results_per_sample_fg)\n", 97 | "logical_f1_fg" 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "metadata": {}, 103 | "source": [ 104 | "We can now look at the results. The only relevant scores for finding generation are logical_precision, logical_recall and logical_f1 since there are no boxes associated with findings to compute the other grounding and spatial scores." 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "print(\"Findings generation RadFact scores:\")\n", 114 | "print_results(radfact_scores_fg, metrics=[\"logical_precision\", \"logical_recall\", \"logical_f1\"])" 115 | ] 116 | }, 117 | { 118 | "cell_type": "markdown", 119 | "metadata": {}, 120 | "source": [ 121 | "You can also compute the bootstrap confidence intervals for the scores as shown below.\n", 122 | "\n", 123 | "We set the number of bootstrap samples (`num_samples`) to 10 here because our example dataset is quite small." 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "bootstrapper = MetricBootstrapper(metric=radfact_metric_for_fg, num_samples=10, seed=42)\n", 133 | "radfact_scores_fg_with_cis = bootstrapper.compute_bootstrap_metrics(results_per_sample=results_per_sample_fg)" 134 | ] 135 | }, 136 | { 137 | "cell_type": "markdown", 138 | "metadata": {}, 139 | "source": [ 140 | "We can now inspect the results with the confidence intervals." 
141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": null, 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [ 149 | "print(\"Findings generation RadFact scores (95% CI):\")\n", 150 | "print_bootstrap_results(radfact_scores_fg_with_cis, metrics=[\"logical_precision\", \"logical_recall\", \"logical_f1\"])" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": {}, 156 | "source": [ 157 | "## Grounded Reporting Evaluation\n", 158 | "\n", 159 | "For grounded reporting, it's easiest to store model generations and ground truth in JSON format to accommodate both text and boxes. Each grounded report is represented as a list of dicts representing individual sentences, each with `text` and `boxes` keys. The `boxes` are `None` for non-grounded sentences. As for findings generation, the model generations are under `prediction` and the ground truth is under `target`.\n", 160 | "\n", 161 | "Refer to the [grounded_reporting_examples.json](examples/grounded_reporting_examples.json) for examples of the expected JSON format.\n", 162 | "\n", 163 | "From this JSON we can parse examples easily into `GroundedPhraseList`, which is expected by RadFact." 
164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "with open(EXAMPLES_DIR / 'grounded_reporting_examples.json', \"r\", encoding=\"utf-8\") as f:\n", 173 | " grounded_reporting_examples = json.load(f)\n", 174 | "candidates_gr = {\n", 175 | " example[\"example_id\"]: GroundedPhraseList.from_list_of_dicts(example[\"prediction\"])\n", 176 | " for example in grounded_reporting_examples\n", 177 | "}\n", 178 | "references_gr = {\n", 179 | " example[\"example_id\"]: GroundedPhraseList.from_list_of_dicts(example[\"target\"])\n", 180 | " for example in grounded_reporting_examples\n", 181 | "}\n", 182 | "print(\"Loaded\", len(grounded_reporting_examples), \"grounded reporting examples\")" 183 | ] 184 | }, 185 | { 186 | "cell_type": "markdown", 187 | "metadata": {}, 188 | "source": [ 189 | "When operating on grounded reports, represented as `GroundedPhraseList`, we do not need to set `is_narrative_text=True` in the metric. With already-parsed reports, no step to convert reports into phrases is required. `is_narrative_text` is set to `False` by default." 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": null, 195 | "metadata": {}, 196 | "outputs": [], 197 | "source": [ 198 | "radfact_metric_for_gr = RadFactMetric()" 199 | ] 200 | }, 201 | { 202 | "cell_type": "markdown", 203 | "metadata": {}, 204 | "source": [ 205 | "Similarly to findings generation, we can compute the metric scores and confidence intervals for grounded reporting.\n", 206 | "\n", 207 | "We also break down the computation to be able to reuse the per sample results for bootstrapping." 
208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": null, 213 | "metadata": {}, 214 | "outputs": [], 215 | "source": [ 216 | "results_per_sample_gr = radfact_metric_for_gr.compute_results_per_sample(candidates_gr, references_gr)\n", 217 | "logical_f1_gr, radfact_scores_gr = radfact_metric_for_gr.aggregate_results(results_per_sample_gr)\n", 218 | "logical_f1_gr" 219 | ] 220 | }, 221 | { 222 | "cell_type": "markdown", 223 | "metadata": {}, 224 | "source": [ 225 | "Since this is grounded reporting, we look at all the metrics returned by RadFact including grounding and spatial scores." 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": null, 231 | "metadata": {}, 232 | "outputs": [], 233 | "source": [ 234 | "metrics = [\n", 235 | " \"logical_precision\",\n", 236 | " \"logical_recall\",\n", 237 | " \"logical_f1\",\n", 238 | " \"spatial_precision\",\n", 239 | " \"spatial_recall\",\n", 240 | " \"spatial_f1\",\n", 241 | " \"grounding_precision\",\n", 242 | " \"grounding_recall\",\n", 243 | " \"grounding_f1\",\n", 244 | "]\n", 245 | "print(\"Grounded reporting RadFact scores:\")\n", 246 | "print_results(radfact_scores_gr, metrics=metrics)\n" 247 | ] 248 | }, 249 | { 250 | "cell_type": "markdown", 251 | "metadata": {}, 252 | "source": [ 253 | "We can compute the bootstrap confidence intervals for the scores similarly." 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": null, 259 | "metadata": {}, 260 | "outputs": [], 261 | "source": [ 262 | "bootstrapper = MetricBootstrapper(metric=radfact_metric_for_gr, num_samples=10, seed=42)\n", 263 | "radfact_scores_gr_with_cis = bootstrapper.compute_bootstrap_metrics(results_per_sample=results_per_sample_gr)" 264 | ] 265 | }, 266 | { 267 | "cell_type": "markdown", 268 | "metadata": {}, 269 | "source": [ 270 | "We can now inspect the metrics with the confidence intervals." 
271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": null, 276 | "metadata": {}, 277 | "outputs": [], 278 | "source": [ 279 | "print(\"Grounded reporting RadFact scores (95% CI):\")\n", 280 | "print_bootstrap_results(radfact_scores_gr_with_cis, metrics)" 281 | ] 282 | } 283 | ], 284 | "metadata": { 285 | "kernelspec": { 286 | "display_name": "radfact", 287 | "language": "python", 288 | "name": "python3" 289 | }, 290 | "language_info": { 291 | "codemirror_mode": { 292 | "name": "ipython", 293 | "version": 3 294 | }, 295 | "file_extension": ".py", 296 | "mimetype": "text/x-python", 297 | "name": "python", 298 | "nbconvert_exporter": "python", 299 | "pygments_lexer": "ipython3", 300 | "version": "3.10.14" 301 | } 302 | }, 303 | "nbformat": 4, 304 | "nbformat_minor": 2 305 | } 306 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | follow_imports_for_stubs = True 3 | strict_optional = True 4 | strict = True 5 | disallow_untyped_calls=False 6 | disallow_untyped_defs=True 7 | disallow_subclassing_any = False 8 | show_traceback=True 9 | plugins = numpy.typing.mypy_plugin 10 | 11 | [mypy-azureml.*] 12 | ignore_missing_imports = True 13 | 14 | [mypy-health_azure.*] 15 | ignore_missing_imports = True 16 | 17 | [mypy-tqdm.*] 18 | ignore_missing_imports = True 19 | 20 | [mypy-mock.*] 21 | ignore_missing_imports = True 22 | 23 | [mypy-langchain.*] 24 | ignore_missing_imports = True 25 | 26 | [mypy-requests.*] 27 | ignore_missing_imports = True 28 | 29 | [mypy-azure.ai.*] 30 | ignore_errors = True 31 | 32 | -------------------------------------------------------------------------------- /primary_deps.yaml: -------------------------------------------------------------------------------- 1 | # Do not add packages here unless required for PR builds, setup scripts, etc. 
2 | name: radfact 3 | channels: 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - setuptools==69.5.1 8 | - pip=23.0.1 9 | - python=3.10 10 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | 2 | [tool.black] 3 | line-length = 120 4 | target-version = ["py310"] 5 | skip-string-normalization = true 6 | 7 | [tool.isort] 8 | profile = "black" 9 | known_first_party = ["radfact"] 10 | sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"] 11 | 12 | [tool.pylint] 13 | max-line-length = 120 14 | disable = [ 15 | "logging-fstring-interpolation", 16 | "missing-module-docstring", 17 | "missing-class-docstring", 18 | "missing-function-docstring", 19 | "no-value-for-parameter", 20 | "no-name-in-module", 21 | "no-member", 22 | ] 23 | 24 | [build-system] 25 | requires = ["setuptools==69.5.1"] 26 | build-backend = "setuptools.build_meta" 27 | 28 | [project] 29 | name = "radfact" 30 | version = "1.0.0" 31 | description = "RadFact: Code for the entailment verification metric RadFact" 32 | readme = "README.md" 33 | requires-python = ">=3.10" 34 | authors = [ 35 | { name = "Biomedical Imaging Team - Health Futures UK", email = "BioMedicalImaging@microsoft.com" }, 36 | ] 37 | maintainers = [ 38 | { name = "Biomedical Imaging Team - Health Futures UK", email = "BioMedicalImaging@microsoft.com" }, 39 | ] 40 | 41 | dependencies = [ 42 | # azure utils 43 | "azureml-sdk", 44 | "azureml-tensorboard", 45 | "azure-keyvault", 46 | "hi-ml-azure", 47 | 48 | # llm utils 49 | "hydra-core==1.3.2", 50 | "langchain==0.2.17", 51 | "langchain-community==0.2.19", 52 | "langchain-openai==0.1.25", 53 | "openai==1.55.0", 54 | "pydantic==1.10.17", 55 | "redis", 56 | 57 | # radfact other 58 | "numpy", 59 | "pandas", 60 | ] 61 | 62 | [project.optional-dependencies] 63 | dev = [ 64 | # linting, formatting, type checking 65 | "black", 66 | "flake8", 67 | 
"ipykernel", 68 | "mypy", 69 | "pre-commit", 70 | "pyyaml", 71 | "pandas-stubs", 72 | "types-Pillow", 73 | "types-PyYAML", 74 | "types-tqdm", 75 | ] 76 | 77 | test = [ 78 | "mock", 79 | "pandas-stubs", 80 | "pytest", 81 | "pytest-lazy-fixture", 82 | ] 83 | 84 | [project.urls] 85 | repository = "https://github.com/microsoft/RadFact" 86 | [project.scripts] 87 | run_radfact = "radfact.cli.run_radfact:main" 88 | run_report_to_phrases = "radfact.cli.run_report_to_phrases:main" 89 | -------------------------------------------------------------------------------- /setup_packages.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 4 | # ------------------------------------------------------------------------------------------ 5 | 6 | from __future__ import annotations 7 | 8 | import argparse 9 | import logging 10 | import subprocess 11 | import sys 12 | from enum import Enum 13 | from pathlib import Path 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | REPOSITORY_ROOT_DIR = Path(__file__).absolute().parent 18 | 19 | 20 | class InstallMode(str, Enum): 21 | EDITABLE = "editable" 22 | SYSTEM_PATH = "system_path" 23 | 24 | @classmethod 25 | def get_members(cls) -> list[InstallMode]: 26 | return [member for member in cls] 27 | 28 | 29 | def add_package_to_path(package_src: str) -> None: 30 | """Adds the given string at the start of sys.path. 31 | 32 | Adding at the start of sys.path, rather than appending, 33 | is important when working with multiple copies of the codebase. They would otherwise 34 | pick up code not from the copy, but from the main repository that has been installed via `pip -e`. 
35 | """ 36 | sys.path.insert(0, str(package_src)) 37 | logger.info(f"Added {package_src} to sys.path") 38 | 39 | 40 | def get_package_src_dir() -> str: 41 | return str(REPOSITORY_ROOT_DIR / "src") 42 | 43 | 44 | def editable_install_packages(no_deps: bool, add_dev_deps: bool) -> None: 45 | """Installs the given packages in editable mode via pip.""" 46 | pip_command = "pip install -e" 47 | if add_dev_deps: 48 | if no_deps: 49 | raise ValueError("Both no_deps and add_test_deps cannot be true") 50 | pip_command += ".[dev,test] " 51 | else: 52 | pip_command += ". " 53 | if no_deps: 54 | pip_command += "--no-deps --no-build-isolation " 55 | logger.info(f"Installing packages: {pip_command}") 56 | subprocess.run(f"{pip_command}", shell=True, check=True) 57 | 58 | 59 | def add_packages_to_path() -> None: 60 | package_src = get_package_src_dir() 61 | add_package_to_path(str(package_src)) 62 | 63 | 64 | def setup_packages(install_mode: InstallMode, no_deps: bool = False, add_dev_deps: bool = False) -> None: 65 | """Adds local packages to the Python environment, either by calling 'pip install' or by adding to sys.path. 66 | 67 | :param install_mode: Should the packages be installed via 'pip' or via sys.path 68 | :param no_deps: When installing via 'pip', should package dependencies be installed too? Defaults to False 69 | :param add_dev_deps: When installing via 'pip', should optional dev and tes dependencies be installed too? 
Defaults 70 | to False 71 | :raises ValueError: If the install_mode is not recognised 72 | """ 73 | if install_mode == "editable": 74 | editable_install_packages(no_deps=no_deps, add_dev_deps=add_dev_deps) 75 | elif install_mode == "system_path": 76 | add_packages_to_path() 77 | else: 78 | raise ValueError(f"Unknown install mode {install_mode}") 79 | 80 | 81 | if __name__ == "__main__": 82 | argparser = argparse.ArgumentParser(description="Install packges in editable mode") 83 | argparser.add_argument("--no-deps", action="store_true", default=False) 84 | argparser.add_argument("--add-dev-deps", action="store_true", default=False) 85 | args = argparser.parse_args() 86 | 87 | setup_packages(install_mode=InstallMode.EDITABLE, no_deps=args.no_deps, add_dev_deps=args.add_dev_deps) 88 | -------------------------------------------------------------------------------- /src/radfact/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/RadFact/72463ac34442647886f13fad7d04958ab6c7c294/src/radfact/__init__.py -------------------------------------------------------------------------------- /src/radfact/azure_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/RadFact/72463ac34442647886f13fad7d04958ab6c7c294/src/radfact/azure_utils/__init__.py -------------------------------------------------------------------------------- /src/radfact/azure_utils/auth.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 
# ------------------------------------------------------------------------------------------

import base64
import json
import logging
import os
from pathlib import Path
from typing import Any, Callable

from azure.identity import AzureCliCredential, DefaultAzureCredential
from azureml._restclient.models.error_response import ErrorResponseException
from health_azure import get_workspace

from radfact.azure_utils.bearer_token_provider import get_bearer_token_provider

logger = logging.getLogger(__name__)

# The default scope for the Azure Cognitive Services. Tokens are retrieved for this scope, and later used instead
# of the API key.
AZURE_COGNITIVE_SERVICES = "https://cognitiveservices.azure.com"


def get_from_vault(secret_name: str, workspace_config_path: Path | None = None) -> str:
    """Reads a secret from the keyvault given the secret name.

    :param secret_name: The name of the secret in the keyvault.
    :param workspace_config_path: The path to the workspace configuration file.
    :return: Requested value.
    :raises ErrorResponseException: If the secret cannot be read from the keyvault.
    """
    workspace = get_workspace(workspace_config_path=workspace_config_path)
    try:
        keyvault = workspace.get_default_keyvault()
        secret_value = str(keyvault.get_secret(name=secret_name))
        return secret_value
    except ErrorResponseException:
        # Log before re-raising so the failure is visible even if a caller swallows the exception.
        logger.warning("Unable to retrieve secret key from keyvault.")
        raise


def get_from_env_or_vault(
    env_var_name: str = "", secret_name: str = "", workspace_config_path: Path | None = None
) -> str:
    """Reads a value from an environment variable if possible.
    Otherwise, tries to read it from the keyvault given the secret name.

    :param env_var_name: The name of the environment variable.
    :param secret_name: The name of the secret in the keyvault.
    :param workspace_config_path: The path to the workspace configuration file.
    :return: Requested value from the environment variable or the keyvault.
    :raises ValueError: If neither env_var_name nor secret_name is provided, or if the environment variable
        is unset and no secret name was given.
    """
    if not env_var_name and not secret_name:
        raise ValueError("Either env_var_name or secret_name must be provided.")
    value = os.environ.get(env_var_name, None)
    if value is not None:
        return value
    if not secret_name:
        raise ValueError("Secret name must be provided if the environment variable is not set.")
    value = get_from_vault(secret_name, workspace_config_path)
    return value
def get_credential() -> AzureCliCredential | DefaultAzureCredential:
    """Get the appropriate Azure credential based on the environment. If the Azure CLI is installed and logged in,
    the Azure CLI credential is returned. Otherwise, the default Azure credential is returned."""
    try:
        # NOTE(review): constructing AzureCliCredential does not itself contact the CLI, so this rarely raises;
        # failures may only surface later when get_token is called. Confirm the fallback is reachable as intended.
        return AzureCliCredential()
    except Exception:
        logger.info("Failed to get Azure CLI credential. Trying default Azure credential.")
        return DefaultAzureCredential()


def get_azure_token_provider() -> Callable[[], str]:
    """Get a token provider for Azure Cognitive Services. The bearer token provider gets authentication tokens and
    refreshes them automatically upon expiry.
    """
    credential = get_credential()
    # Fetch one token eagerly so that authentication problems surface here, and log the identity in use.
    token = credential.get_token(AZURE_COGNITIVE_SERVICES)
    logger.info(f"Credentials: {print_token_details(token.token)}")
    return get_bearer_token_provider(credential, AZURE_COGNITIVE_SERVICES)


def token_to_json(token: str) -> Any:
    """Converts an Azure access token to its underlying JSON structure.

    :param token: The access token.
    :return: The JSON object that is stored in the token.
    """
    # This is code to dissect the token, taken from
    # https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/identity/azure-identity/azure/identity/_internal/decorators.py#L38
    # The payload is the second dot-separated part of the JWT. JWTs carry unpadded base64, so padding is
    # appended; base64.decodebytes tolerates any excess padding characters in its default (non-strict) mode.
    base64_meta_data = token.split(".")[1].encode("utf-8") + b"=="
    json_bytes = base64.decodebytes(base64_meta_data)
    json_string = json_bytes.decode("utf-8")
    return json.loads(json_string)


def print_token_details(token: str) -> str:
    """Creates a human-readable string with details stored in the Azure token.

    :param token: The access token.
    :return: A string with information about the identity that was given the access token.
    """
    json_dict = token_to_json(token)
    NOT_PRESENT = "(not available)"
    # Guard against a payload that is valid JSON but not an object, matching the old "fall back for any
    # failure" behaviour of the previous try/except blocks.
    claims = json_dict if isinstance(json_dict, dict) else {}
    # Each claim is optional in the token payload, hence the per-key fallback.
    oid = claims.get("oid", NOT_PRESENT)
    upn = claims.get("upn", NOT_PRESENT)
    name = claims.get("name", NOT_PRESENT)
    appid = claims.get("appid", NOT_PRESENT)
    return f"EntraID object ID {oid}, user principal name (upn) {upn}, name {name}, appid {appid}"


def extract_object_id_from_token(token: str) -> str:
    """Extracts the object ID from an access token.
    The object ID is the unique identifier for the user or service principal in Azure Active Directory.

    :param token: The access token.
    :return: The object ID of the token.
    """
    json_dict = token_to_json(token)
    return json_dict["oid"]  # type: ignore
136 | """ 137 | json_dict = token_to_json(token) 138 | return json_dict["oid"] # type: ignore 139 | -------------------------------------------------------------------------------- /src/radfact/azure_utils/bearer_token_provider.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 4 | # ------------------------------------------------------------------------------------------ 5 | 6 | # This is copied from https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/identity/azure-identity/azure/identity/_bearer_token_provider.py because we are unable to build the environment 7 | # to use the latest version of azure-identity. This is a temporary solution until we can update the environment. 8 | # This code is licensed under the MIT License. https://github.com/Azure/azure-sdk-for-python/blob/main/LICENSE 9 | from typing import Callable 10 | 11 | from azure.core.credentials import TokenCredential 12 | from azure.core.pipeline import PipelineContext, PipelineRequest 13 | from azure.core.pipeline.policies import BearerTokenCredentialPolicy 14 | from azure.core.rest import HttpRequest 15 | 16 | 17 | def _make_request() -> PipelineRequest[HttpRequest]: 18 | return PipelineRequest(HttpRequest("CredentialWrapper", "https://fakeurl"), PipelineContext(None)) 19 | 20 | 21 | def get_bearer_token_provider(credential: TokenCredential, *scopes: str) -> Callable[[], str]: 22 | """Returns a callable that provides a bearer token. 23 | 24 | It can be used for instance to write code like: 25 | 26 | .. 
code-block:: python 27 | 28 | from azure.identity import DefaultAzureCredential, get_bearer_token_provider 29 | 30 | credential = DefaultAzureCredential() 31 | bearer_token_provider = get_bearer_token_provider(credential, "https://cognitiveservices.azure.com/.default") 32 | 33 | # Usage 34 | request.headers["Authorization"] = "Bearer " + bearer_token_provider() 35 | 36 | :param credential: The credential used to authenticate the request. 37 | :type credential: ~azure.core.credentials.TokenCredential 38 | :param str scopes: The scopes required for the bearer token. 39 | :rtype: callable 40 | :return: A callable that returns a bearer token. 41 | """ 42 | 43 | policy = BearerTokenCredentialPolicy(credential, *scopes) 44 | 45 | def wrapper() -> str: 46 | request = _make_request() 47 | policy.on_request(request) 48 | return request.http_request.headers["Authorization"][len("Bearer ") :] 49 | 50 | return wrapper 51 | -------------------------------------------------------------------------------- /src/radfact/cli/run_radfact.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 
# ------------------------------------------------------------------------------------------

import argparse
import json
import logging
from pathlib import Path

import pandas as pd

from radfact.data_utils.grounded_phrase_list import GroundedPhraseList
from radfact.llm_utils.report_to_phrases.processor import StudyIdType
from radfact.metric.bootstrapping import MetricBootstrapper
from radfact.metric.print_utils import print_bootstrap_results, print_results
from radfact.metric.radfact import InputDict, RadFactMetric
from radfact.paths import CONFIGS_DIR

logger = logging.getLogger(__name__)


def validate_config_file(config_name: str | None) -> None:
    """Check that the named config file exists in the `configs` directory.

    :param config_name: Name of the config file (including extension), or None to use the default config.
    :raises FileNotFoundError: If a name was given but no such file exists in the `configs` directory.
    """
    if config_name is not None:
        config_path = CONFIGS_DIR / f"{config_name}"
        if not config_path.exists():
            message = (
                f"Config file {config_name} does not exist. "
                "Make sure the config file is saved in the `configs` directory."
            )
            raise FileNotFoundError(message)


def get_candidates_and_references_from_csv(csv_path: Path) -> tuple[dict[StudyIdType, str], dict[StudyIdType, str]]:
    """Reads the csv file containing the samples to compute RadFact for and returns the candidates and references in
    the expected format.

    :param csv_path: Path to a csv with columns 'example_id', 'prediction' and 'target'.
    :return: Two dicts keyed by example id: candidate texts and reference texts.
    """
    findings_generation_samples = pd.read_csv(csv_path)
    logger.info(f"Loaded {len(findings_generation_samples)} samples from {csv_path}")
    candidates = findings_generation_samples.set_index("example_id")["prediction"].to_dict()
    references = findings_generation_samples.set_index("example_id")["target"].to_dict()
    return candidates, references


def get_candidates_and_references_from_json(
    json_path: Path,
) -> tuple[dict[StudyIdType, GroundedPhraseList], dict[StudyIdType, GroundedPhraseList]]:
    """Reads the json file containing the samples to compute RadFact for and returns the candidates and references in
    the expected format.

    :param json_path: Path to a json file shaped like `examples/grounded_reporting_examples.json`.
    :return: Two dicts keyed by example id: candidate and reference grounded phrase lists.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        grounded_reporting_samples = json.load(f)
    logger.info(f"Loaded {len(grounded_reporting_samples)} samples from {json_path}")
    candidates = {
        example["example_id"]: GroundedPhraseList.from_list_of_dicts(example["prediction"])
        for example in grounded_reporting_samples
    }
    references = {
        example["example_id"]: GroundedPhraseList.from_list_of_dicts(example["target"])
        for example in grounded_reporting_samples
    }
    return candidates, references


def compute_radfact_scores(
    radfact_config_name: str | None,
    phrases_config_name: str | None,
    candidates: InputDict,
    references: InputDict,
    is_narrative_text: bool,
    bootstrap_samples: int,
) -> dict[str, float]:
    """Compute RadFact scores, optionally with bootstrap confidence intervals.

    :param radfact_config_name: Custom config name for RadFact processing, or None for the default config.
    :param phrases_config_name: Custom config name for report-to-phrases conversion, or None for the default.
    :param candidates: Mapping of example id to candidate (generated) report.
    :param references: Mapping of example id to reference (ground-truth) report.
    :param is_narrative_text: Whether the inputs are narrative text (True) or grounded phrases (False).
    :param bootstrap_samples: Number of bootstrap samples; 0 disables bootstrapping.
    :return: Dictionary mapping metric names to values (with percentile entries when bootstrapping).
    """
    radfact_metric = RadFactMetric(
        nli_config_name=radfact_config_name,
        phrase_config_name=phrases_config_name,
        is_narrative_text=is_narrative_text,
    )
    if bootstrap_samples == 0:
        _, results = radfact_metric.compute_metric_score(candidates, references)
        return results
    # Bug fix: honour the requested number of bootstrap samples instead of a hard-coded 10.
    bootstrapper = MetricBootstrapper(metric=radfact_metric, num_samples=bootstrap_samples, seed=42)
    results_per_sample = radfact_metric.compute_results_per_sample(candidates, references)
    return bootstrapper.compute_bootstrap_metrics(results_per_sample=results_per_sample)


def main() -> None:
    """Parse CLI arguments, compute RadFact for the given samples, print and save the results."""
    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s: %(message)s")
    parser = argparse.ArgumentParser(
        description="Compute RadFact metric for a set of samples and saves the results to a json file."
    )
    parser.add_argument(
        "--input_path",
        type=str,
        help="The path to the csv or json file containing the samples to compute RadFact for. For finding generation "
        "samples, the csv file should have columns 'example_id', 'prediction', and 'target' similar to the example in "
        "`examples/findings_generation_examples.csv`. For grounded reporting samples, provide a json file in the same "
        "format as `examples/grounded_reporting_examples.json`.",
        required=True,
    )
    parser.add_argument(
        "--is_narrative_text",
        action="store_true",
        help="Whether the input samples are narrative text or not. If true, the input samples are expected to be "
        "narrative text, otherwise they are expected to be grounded phrases.",
    )
    parser.add_argument(
        "--radfact_config_name",
        type=str,
        help="The name of the config file for RadFact processing. We use the default config file but you can provide a "
        "custom config. Make sure the config follows the same structure as `configs/radfact.yaml` and is saved in the "
        "`configs` directory. This is necessary for hydra initialization from the `configs` directory.",
        default=None,
    )
    parser.add_argument(
        "--phrases_config_name",
        type=str,
        help="The name of the config file for reports to phrases conversion. We use the default config file but you "
        "can provide a custom config. Make sure the config follows the same structure as "
        "`configs/report_to_phrases.yaml` and is saved in the `configs` directory. This is necessary for hydra "
        "initialization from the `configs` directory.",
        default=None,
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        help="Path to the directory where the results will be saved as a json file.",
        default="outputs",
    )
    parser.add_argument(
        "--bootstrap_samples",
        type=int,
        help="Number of bootstrap samples to use for computing the confidence intervals. Set to 0 to disable "
        "bootstrapping.",
        default=500,
    )

    args = parser.parse_args()
    input_path = Path(args.input_path)
    output_dir = Path(args.output_dir)
    is_narrative_text = args.is_narrative_text
    radfact_config_name = args.radfact_config_name
    phrases_config_name = args.phrases_config_name
    bootstrap_samples = args.bootstrap_samples

    assert input_path.suffix in [".csv", ".json"], "Input file must be a csv or json file."
    assert input_path.suffix == ".csv" or not is_narrative_text, (
        "Input file must be a json file for grounded phrases and is_narrative_text must be False. For narrative text, "
        "input file must be a csv file and is_narrative_text must be True."
    )
    validate_config_file(radfact_config_name)
    validate_config_file(phrases_config_name)

    candidates: InputDict
    references: InputDict

    if is_narrative_text:
        candidates, references = get_candidates_and_references_from_csv(input_path)
    else:
        candidates, references = get_candidates_and_references_from_json(input_path)

    results = compute_radfact_scores(
        radfact_config_name=radfact_config_name,
        phrases_config_name=phrases_config_name,
        candidates=candidates,
        references=references,
        is_narrative_text=is_narrative_text,
        bootstrap_samples=bootstrap_samples,
    )

    print_fn = print_results if bootstrap_samples == 0 else print_bootstrap_results
    if is_narrative_text:
        print("RadFact scores for narrative text samples")
        print_fn(results=results, metrics=["logical_precision", "logical_recall", "logical_f1", "num_llm_failures"])
    else:
        print("RadFact scores for grounded phrases samples")
        print_fn(
            results=results,
            metrics=[
                "logical_precision",
                "logical_recall",
                "logical_f1",
                "spatial_precision",
                "spatial_recall",
                "spatial_f1",
                "grounding_precision",
                "grounding_recall",
                "grounding_f1",
                "num_llm_failures",
            ],
        )

    # Robustness fix: the default "outputs" directory may not exist yet; create it before writing.
    output_dir.mkdir(parents=True, exist_ok=True)
    output_path = output_dir / f"radfact_scores_{input_path.stem}.json"
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    logger.info(f"Results saved to {output_path}")


if __name__ == "__main__":
    main()
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

import json
import logging

from radfact.data_utils.grounded_phrase_list import GroundedPhrase, GroundedPhraseList
from radfact.metric.bootstrapping import MetricBootstrapper
from radfact.metric.print_utils import print_bootstrap_results
from radfact.metric.radfact import InputDict, RadFactMetric
from radfact.paths import EXAMPLES_DIR

logger = logging.getLogger(__name__)


def read_examples() -> tuple[InputDict, InputDict]:
    """Load the bundled test examples and return (candidates, references) keyed by example id."""
    json_path = EXAMPLES_DIR / "test_examples.json"
    with open(json_path, "r", encoding="utf-8") as f:
        examples_json = json.load(f)
    candidates = {
        example["example_id"]: GroundedPhraseList([GroundedPhrase(phrase) for phrase in example["input"]["phrases_A"]])
        for example in examples_json
    }
    references = {
        example["example_id"]: GroundedPhraseList([GroundedPhrase(phrase) for phrase in example["input"]["phrases_B"]])
        for example in examples_json
    }
    return candidates, references


def run_radfact() -> None:
    """Run RadFact on the bundled test examples and print results next to the expected ranges."""
    candidates, references = read_examples()
    metric = RadFactMetric()
    logger.info(f"Computing RadFact metric for {len(candidates)} examples")
    results_per_sample = metric.compute_results_per_sample(candidates, references)
    metric.is_narrative_text = True  # to avoid computing box metrics that are not relevant for this test
    bootstrapper = MetricBootstrapper(metric=metric, num_samples=500, seed=42)
    results_with_error_bars = bootstrapper.compute_bootstrap_metrics(results_per_sample=results_per_sample)

    metrics = ["logical_precision", "logical_recall", "logical_f1", "num_llm_failures"]
    print("RadFact results using Llama-3-70b-Instruct model")
    print_bootstrap_results(results_with_error_bars, metrics)

    expected_results = {
        "logical_precision/median": 0.3554,
        "logical_precision/p2.5th": 0.2745,
        "logical_precision/p97.5th": 0.4327,
        "logical_recall/median": 0.3211,
        "logical_recall/p2.5th": 0.2328,
        "logical_recall/p97.5th": 0.4093,
        "logical_f1/median": 0.3385,
        "logical_f1/p2.5th": 0.2607,
        "logical_f1/p97.5th": 0.4140,
        "num_llm_failures/median": 0.0,
        "num_llm_failures/p2.5th": 0.0,
        "num_llm_failures/p97.5th": 0.0,
    }
    print("Expected results range")
    # You should expect the results to be within the range printed here - double check num_llm_failures if you notice
    # major discrepancies. The results may vary if you encounter many LLM failures or if you're using a different model.
    print_bootstrap_results(expected_results, metrics)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s: %(message)s")
    run_radfact()
# ------------------------------------------------------------------------------------------

import logging

import hydra
import pandas as pd
from omegaconf import DictConfig

from radfact.llm_utils.report_to_phrases.processor import FINDINGS_SECTION, get_report_to_phrases_engine
from radfact.paths import CONFIGS_DIR

logger = logging.getLogger(__name__)
# Silence per-request logging from the httpx HTTP client; only warnings and above are shown.
http_logger = logging.getLogger("httpx")
http_logger.setLevel(logging.WARNING)


def _validate_column(df: pd.DataFrame, col_name: str) -> None:
    """Raise a ValueError if `col_name` is not a column of `df`."""
    if col_name not in df.columns:
        raise ValueError(f"Column {col_name} not found in dataset. Available columns: {df.columns}")


def get_dataset_dataframe(cfg: DictConfig) -> pd.DataFrame:
    """Read the dataset dataframe and drop duplicates based on the index column and findings column.

    :param cfg: Hydra config providing `dataset.csv_path`, `dataset.name` and `processing.index_col`.
    :return: Dataframe with duplicate rows and rows with null findings removed.
    :raises ValueError: If the findings column or the index column is missing from the csv.
    """
    dataset_dataframe_path = cfg.dataset.csv_path
    dataset_name = cfg.dataset.name
    df = pd.read_csv(dataset_dataframe_path)
    findings_col = FINDINGS_SECTION
    _validate_column(df, findings_col)
    id_col = cfg.processing.index_col
    _validate_column(df, id_col)
    # A (study id, findings) pair appearing twice would be processed twice by the engine; keep one copy.
    df.drop_duplicates(subset=[id_col, findings_col], inplace=True)
    logger.info(f"Loaded {len(df)} rows from {dataset_name} dataset")
    df.dropna(subset=[findings_col], inplace=True)
    logger.info(f"Processing {len(df)} rows with non-null findings")
    return df


@hydra.main(version_base=None, config_path=str(CONFIGS_DIR), config_name="report_to_phrases")
def main(cfg: DictConfig) -> None:
    """Convert the findings section of each report into phrases using the report-to-phrases engine."""
    dataset_df = get_dataset_dataframe(cfg)
    engine = get_report_to_phrases_engine(cfg=cfg, dataset_df=dataset_df)
    engine.run()


if __name__ == "__main__":
    main()
https://raw.githubusercontent.com/microsoft/RadFact/72463ac34442647886f13fad7d04958ab6c7c294/src/radfact/data_utils/__init__.py -------------------------------------------------------------------------------- /src/radfact/data_utils/grounded_phrase_list.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 4 | # ------------------------------------------------------------------------------------------ 5 | 6 | from __future__ import annotations 7 | 8 | import dataclasses 9 | from dataclasses import dataclass 10 | from typing import Generic, Iterator, Mapping, TypeVar 11 | 12 | TypeT = TypeVar("TypeT", float, int) 13 | 14 | 15 | @dataclass(frozen=True) 16 | class GenericBox(Generic[TypeT]): 17 | """Bounding box class with coordinates of type TypeT. 
Allows for looping and unpacking of coordinates.""" 18 | 19 | x_min: TypeT 20 | y_min: TypeT 21 | x_max: TypeT 22 | y_max: TypeT 23 | 24 | def __post_init__(self) -> None: 25 | if not 0 <= self.x_min <= self.x_max: 26 | raise ValueError(f"Invalid x coordinates: {self}") 27 | if not 0 <= self.y_min <= self.y_max: 28 | raise ValueError(f"Invalid y coordinates: {self}") 29 | 30 | def __iter__(self) -> Iterator[TypeT]: 31 | yield from (self.x_min, self.y_min, self.x_max, self.y_max) 32 | 33 | def has_zero_area(self) -> bool: 34 | return self.x_min == self.x_max or self.y_min == self.y_max 35 | 36 | 37 | @dataclass(frozen=True) 38 | class NormalizedBox(GenericBox[float]): 39 | """Bounding box normalized to the image size, with coordinates in the range [0, 1].""" 40 | 41 | def __post_init__(self) -> None: 42 | super().__post_init__() 43 | if not self.x_max <= 1: 44 | raise ValueError(f"Invalid x coordinates: {self}") 45 | if not self.y_max <= 1: 46 | raise ValueError(f"Invalid y coordinates: {self}") 47 | 48 | 49 | BoxDictType = dict[str, float] 50 | GroundedPhraseDictType = Mapping[str, str | list[BoxDictType]] 51 | 52 | 53 | @dataclass(frozen=True) 54 | class GroundedPhrase: 55 | """A grounded phrase consists of a string with an (optional) list of normalized bounding boxes.""" 56 | 57 | text: str 58 | boxes: list[NormalizedBox] | None = None 59 | 60 | def __post_init__(self) -> None: 61 | if self.boxes is not None and len(self.boxes) == 0: 62 | raise ValueError(f"Empty boxes for grounded text: {self}, this should be set to None") 63 | 64 | @classmethod 65 | def from_dict(cls, grounded_phrase_dict: GroundedPhraseDictType) -> GroundedPhrase: 66 | text = grounded_phrase_dict["text"] 67 | if not isinstance(text, str): 68 | raise ValueError(f"text is not a string: {text}") 69 | box_list = grounded_phrase_dict["boxes"] 70 | if box_list is None: 71 | return cls(text=text, boxes=None) 72 | if isinstance(box_list, list): 73 | return cls(text=text, boxes=[NormalizedBox(**box) 
for box in box_list]) 74 | else: 75 | raise ValueError(f"boxes is not a list: {box_list}") 76 | 77 | 78 | GroundedPhraseListType = list[str | NormalizedBox | GroundedPhrase] 79 | 80 | GroundedPhraseListDictType = list[Mapping[str, str | BoxDictType | list[BoxDictType]]] 81 | 82 | 83 | class GroundedPhraseList(GroundedPhraseListType): 84 | def get_all_text(self, sep: str = " ") -> str: 85 | """Extract all text segments from the sequence as a continuous string. 86 | 87 | :param sep: Separator between joined substrings, defaults to " ". 88 | :return: Single string containing whitespace-stripped text segments joined by `sep`. 89 | """ 90 | text_parts = [] 91 | for part in self: 92 | if isinstance(part, str): 93 | text_parts.append(part.strip()) 94 | elif isinstance(part, GroundedPhrase): 95 | text_parts.append(part.text.strip()) 96 | return sep.join(text_parts) 97 | 98 | def get_all_boxes(self, fail_if_non_box: bool = False) -> list[NormalizedBox]: 99 | """Extract all bounding boxes from the sequence as a single list.""" 100 | box_list = [] 101 | for part in self: 102 | if isinstance(part, NormalizedBox): 103 | box_list.append(part) 104 | elif fail_if_non_box: 105 | raise ValueError(f"Encountered a non-box while extracting boxes: {part}") 106 | if isinstance(part, GroundedPhrase) and part.boxes is not None: 107 | box_list.extend(part.boxes) 108 | 109 | return box_list 110 | 111 | def get_all_grounded_phrases(self, fail_if_non_grounded_phrase: bool = False) -> list[GroundedPhrase]: 112 | """Extract all GroundedPhrase from the sequence as a single list. 113 | 114 | If there are any non-GroundedPhrase in the sequence, it will raise a ValueError if fail_if_non_phrase is True. 115 | This can occur if we expect a sequence to contain only GroundedPhrase. 
116 | """ 117 | phrases = [part for part in self if isinstance(part, GroundedPhrase)] 118 | if fail_if_non_grounded_phrase and len(phrases) != len(self): 119 | raise ValueError(f"Encountered a non-GroundedPhrase while extracting phrases: {self}") 120 | return phrases 121 | 122 | def to_list_of_dicts(self) -> GroundedPhraseListDictType: 123 | """Convert the sequence to a list of dictionaries.""" 124 | 125 | list_of_dicts: GroundedPhraseListDictType = [] 126 | for part in self: 127 | if isinstance(part, str): 128 | list_of_dicts.append({"text": part}) 129 | elif isinstance(part, NormalizedBox): 130 | box_as_dict: dict[str, float] = dataclasses.asdict(part) 131 | list_of_dicts.append({"box": box_as_dict}) 132 | elif isinstance(part, GroundedPhrase): 133 | list_of_dicts.append(dataclasses.asdict(part)) 134 | else: 135 | raise ValueError(f"Unknown member of grounded phrase list: {part}") 136 | return list_of_dicts 137 | 138 | @classmethod 139 | def from_list_of_dicts(cls, list_of_dicts: GroundedPhraseListDictType) -> GroundedPhraseList: 140 | """Convert a list of dictionaries to a grounded phrase list. 141 | 142 | :param list_of_dicts: List of dictionaries. 
143 | """ 144 | if not isinstance(list_of_dicts, list): 145 | raise ValueError(f"Expected list of dictionaries, got: {list_of_dicts}") 146 | grounded_phrase_list: GroundedPhraseListType = [] 147 | for part in list_of_dicts: 148 | if not isinstance(part, dict): 149 | raise ValueError(f"Expected dictionary, got: {part}") 150 | part_keys = part.keys() 151 | if part_keys == {"text"}: 152 | assert isinstance(part["text"], str), f"Expected string, got: {part['text']}" 153 | grounded_phrase_list.append(part["text"]) 154 | elif part_keys == {"box"}: 155 | box = part["box"] 156 | grounded_phrase_list.append(NormalizedBox(**box)) 157 | elif part_keys == {"text", "boxes"}: 158 | grounded_phrase_list.append(GroundedPhrase.from_dict(part)) 159 | else: 160 | raise ValueError(f"Unknown member of grounded phrase list: {part}") 161 | return cls(grounded_phrase_list) 162 | -------------------------------------------------------------------------------- /src/radfact/llm_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/RadFact/72463ac34442647886f13fad7d04958ab6c7c294/src/radfact/llm_utils/__init__.py -------------------------------------------------------------------------------- /src/radfact/llm_utils/endpoint.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 
# ------------------------------------------------------------------------------------------

from dataclasses import dataclass, field
from enum import Enum
from typing import Callable

from radfact.azure_utils.auth import get_azure_token_provider, get_from_env_or_vault
from radfact.paths import WORKSPACE_CONFIG_PATH

# The default name under which an endpoint API Key is stored in environment variables.
ENV_API_KEY = "API_KEY"


class EndpointType(Enum):
    AZURE_CHAT_OPENAI = "azure_chat_openai"
    CHAT_OPENAI = "chat_openai"


@dataclass(frozen=False)
class Endpoint:
    """Connection details for a single LLM endpoint, with lazily resolved credentials."""

    url: str
    deployment_name: str
    type: EndpointType = EndpointType.AZURE_CHAT_OPENAI
    speed_factor: float = 1.0
    num_parallel_processes: int = 1
    api_key_env_var_name: str = ENV_API_KEY
    keyvault_secret_name: str = ""
    # The name of the Redis cache for this endpoint. If empty, no cache is used. Make sure to update the cache
    # location if the model type changes significantly, and we expect different responses.
    redis_cache: str = ""
    # Lazily populated credential caches; never passed to the constructor.
    _api_key: str | None = field(default=None, init=False)
    _token_provider: Callable[[], str] | None = field(default=None, init=False)

    @property
    def api_key(self) -> str:
        """API key for this endpoint, fetched on first access (env var, then keyvault) and cached."""
        key = self._api_key
        if key is None:
            key = get_from_env_or_vault(
                env_var_name=self.api_key_env_var_name,
                secret_name=self.keyvault_secret_name,
                workspace_config_path=WORKSPACE_CONFIG_PATH,
            )
            self._api_key = key
        return key

    @property
    def token_provider(self) -> Callable[[], str]:
        """Azure AD token provider, created on first access and cached."""
        provider = self._token_provider
        if provider is None:
            provider = get_azure_token_provider()
            self._token_provider = provider
        return provider
4 | # ------------------------------------------------------------------------------------------ 5 | 6 | import logging 7 | from abc import ABCMeta, abstractmethod 8 | from dataclasses import dataclass, field 9 | from typing import Any 10 | 11 | from azureml._restclient.models.error_response import ErrorResponseException 12 | from langchain_openai import AzureChatOpenAI, ChatOpenAI 13 | 14 | from radfact.llm_utils.endpoint import Endpoint, EndpointType 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | @dataclass 20 | class OpenaiAPIArguments(metaclass=ABCMeta): 21 | """Base class for OpenAI API models.""" 22 | 23 | endpoint: Endpoint | None = field(default=None) 24 | api_version: str = field(default="2023-06-01-preview") 25 | max_retries: int = field(default=10) 26 | timeout: float | None = field(default=None) 27 | 28 | def set_endpoint(self, endpoint: Endpoint) -> None: 29 | """Set the endpoint for the API.""" 30 | self.endpoint = endpoint 31 | 32 | def get_params(self) -> dict[str, Any]: 33 | """Get LLM params as a dict. The dict keys match the expected arguments of the API. Check the OpenAI API 34 | documentation for more details. 35 | 36 | :raises ValueError: If the endpoint type is not supported. 37 | :return: The LLM params as a dict. 38 | """ 39 | if self.endpoint is None: 40 | raise ValueError("Endpoint must be set before getting the params.") 41 | match self.endpoint.type: 42 | case EndpointType.AZURE_CHAT_OPENAI: 43 | params = dict( 44 | deployment_name=self.endpoint.deployment_name, 45 | azure_endpoint=self.endpoint.url, 46 | openai_api_version=self.api_version, 47 | max_retries=self.max_retries, 48 | request_timeout=self.timeout, 49 | ) 50 | try: 51 | params["openai_api_key"] = self.endpoint.api_key 52 | except (ValueError, ErrorResponseException): 53 | logger.info( 54 | "Could not find API key in environment variables nor in the keyvault... Trying token provider." 
55 | ) 56 | params["azure_ad_token_provider"] = self.endpoint.token_provider 57 | return params 58 | case EndpointType.CHAT_OPENAI: 59 | return dict( 60 | model=self.endpoint.deployment_name, 61 | base_url=self.endpoint.url, 62 | openai_api_key=self.endpoint.api_key, 63 | max_retries=self.max_retries, 64 | request_timeout=self.timeout, 65 | ) 66 | case _: 67 | raise ValueError(f"Unsupported endpoint type {self.endpoint.type}") 68 | 69 | @abstractmethod 70 | def get_model(self) -> ChatOpenAI | AzureChatOpenAI: 71 | """Returns the chat model.""" 72 | raise NotImplementedError(f"get_model() must be implemented in a subclass {self.__class__.__name__}") 73 | 74 | 75 | @dataclass 76 | class LLMAPIArguments(OpenaiAPIArguments): 77 | """Chat API for an LLM expects arguments to match ChatOpenAI or AzureChatOpenAI.""" 78 | 79 | temperature: float = field(default=0.0) 80 | max_tokens: int = field(default=1024) 81 | top_p: float = field(default=0.95) 82 | frequency_penalty: float = field(default=0.0) 83 | presence_penalty: float = field(default=0.0) 84 | stop: list[str] | None = field(default=None) 85 | n_completions: int = field(default=1) 86 | 87 | def get_chat_completion_params(self) -> dict[str, Any]: 88 | """Get the chat completion params. Note that these params are specific to the chat completion API, the dict 89 | keys match the expected arguments of the API. Check the OpenAI API documentation for more details. 
90 | https://api.python.langchain.com/en/stable/chat_models/langchain_openai.chat_models.azure.AzureChatOpenAI.html#langchain_openai.chat_models.azure.AzureChatOpenAI 91 | """ 92 | return dict( 93 | temperature=self.temperature, 94 | max_tokens=self.max_tokens, 95 | n=self.n_completions, 96 | top_p=self.top_p, 97 | frequency_penalty=self.frequency_penalty, 98 | presence_penalty=self.presence_penalty, 99 | stop=self.stop, 100 | ) 101 | 102 | def get_params(self) -> dict[str, Any]: 103 | """Get LLM params as a dict.""" 104 | params = super().get_params() 105 | params.update(self.get_chat_completion_params()) 106 | return params 107 | 108 | def get_model(self) -> ChatOpenAI | AzureChatOpenAI: 109 | assert self.endpoint is not None # for mypy 110 | match self.endpoint.type: 111 | case EndpointType.AZURE_CHAT_OPENAI: 112 | return AzureChatOpenAI(**self.get_params()) 113 | case EndpointType.CHAT_OPENAI: 114 | return ChatOpenAI(**self.get_params()) 115 | case _: 116 | raise ValueError(f"Unsupported endpoint type {self.endpoint.type}") 117 | 118 | 119 | @dataclass 120 | class LLMEngineArguments: 121 | """Arguments for the LLM engine wrapper around a processor. 122 | 123 | :param index_col: The name of the index column in the dataset. 124 | :param batch_size: The batch size for processing the dataset. 125 | :param start_index: The start index for processing the dataset. 126 | :param end_index: The end index for processing the dataset. 127 | :param output_filename: The name of the output file. 
128 | """ 129 | 130 | index_col: str 131 | batch_size: int = 100 132 | start_index: int = 0 133 | dataset_name: str | None = field(default=None) 134 | end_index: int | None = field(default=None) 135 | output_filename: str = "output.json" 136 | -------------------------------------------------------------------------------- /src/radfact/llm_utils/engine/data_subset.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 4 | # ------------------------------------------------------------------------------------------ 5 | 6 | import logging 7 | from dataclasses import dataclass, field 8 | from pathlib import Path 9 | 10 | import pandas as pd 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | @dataclass 16 | class DataSubset: 17 | """Class to represent a subset of a dataset. 
Includes methods to save progress and skipped IDs.""" 18 | 19 | df: pd.DataFrame 20 | start_index: int 21 | end_index: int 22 | index_col: str 23 | output_folder: Path 24 | processed_ids: set[str] = field(default_factory=set) 25 | skipped_ids: set[str] = field(default_factory=set) 26 | 27 | def __post_init__(self) -> None: 28 | assert ( 29 | len(self.df) == self.end_index - self.start_index 30 | ), f"Dataframe {len(self.df)=} does not match {(self.end_index - self.start_index)=}" 31 | self.progress_folder = self.output_folder / "progress" 32 | self.progress_folder.mkdir(parents=True, exist_ok=True) 33 | self.skip_folder = self.output_folder / "skipped" 34 | self.skip_folder.mkdir(parents=True, exist_ok=True) 35 | 36 | @property 37 | def filename(self) -> str: 38 | """Return the file stem for this subset of the dataset.""" 39 | return f"subset_{self.start_index}_{self.end_index}.csv" 40 | 41 | @property 42 | def progress_file(self) -> Path: 43 | """Status file for this subset of the dataset saving the progress so far.""" 44 | return self.progress_folder / self.filename 45 | 46 | @property 47 | def skipped_file(self) -> Path: 48 | """File for IDs that have been skipped.""" 49 | return self.skip_folder / self.filename 50 | 51 | @property 52 | def relative_progress(self) -> float: 53 | """Return the relative progress of this subset of the dataset.""" 54 | return len(self.processed_ids) / len(self.df) 55 | 56 | @property 57 | def progress_stats(self) -> dict[str, str | float]: 58 | """Return the name of the progress metric for this subset of the dataset.""" 59 | return {"name": f"progress_{self.start_index}_{self.end_index}", "value": self.relative_progress} 60 | 61 | @property 62 | def skipped_stats(self) -> dict[str, str | int]: 63 | """Return the name of the skipped metric for this subset of the dataset.""" 64 | return {"name": f"skipped_{self.start_index}_{self.end_index}", "value": len(self.skipped_ids)} 65 | 66 | @property 67 | def indices(self) -> set[str]: 68 | 
"""Return the indices of this subset of the dataset.""" 69 | return set(self.df[self.index_col]) 70 | 71 | def save_progress(self) -> None: 72 | """Save progress to progress file""" 73 | if len(self.processed_ids) > 0: 74 | processed_df = pd.DataFrame({self.index_col: list(self.processed_ids)}) 75 | processed_df.to_csv(self.progress_file, index=False) 76 | 77 | def save_skipped(self) -> None: 78 | """Save skipped IDs to skipped file""" 79 | if len(self.skipped_ids) > 0: 80 | skipped_df = pd.DataFrame({self.index_col: list(self.skipped_ids)}) 81 | skipped_df.to_csv(self.skipped_file, index=False) 82 | 83 | def __len__(self) -> int: 84 | return len(self.df) 85 | -------------------------------------------------------------------------------- /src/radfact/llm_utils/engine/endpoint_utils.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 4 | # ------------------------------------------------------------------------------------------ 5 | 6 | import logging 7 | 8 | from omegaconf import DictConfig 9 | 10 | from radfact.llm_utils.endpoint import ENV_API_KEY, Endpoint, EndpointType 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def get_endpoints_dict_sorted_by_speed( 16 | cfg: DictConfig, 17 | descending: bool = False, 18 | default_num_parallel_processes: int = 1, 19 | ) -> dict[str, Endpoint]: 20 | """Return a dictionary of endpoints sorted by speed factor in the order specified by the descending parameter. 21 | (Default: ascending order, slowest first, fastest last) 22 | 23 | :param cfg: The OmegaConf configuration object for the whole engine 24 | :param descending: If True, return the endpoints in descending speed (Fastest first, slowest last). 
def get_endpoints_dict_sorted_by_speed(
    cfg: DictConfig,
    descending: bool = False,
    default_num_parallel_processes: int = 1,
) -> dict[str, Endpoint]:
    """Build Endpoint objects from the config and return them keyed by URL, sorted by speed factor.

    :param cfg: The OmegaConf configuration object for the whole engine.
    :param descending: If True, fastest endpoints come first; by default slowest first, fastest last.
    :param default_num_parallel_processes: Fallback number of parallel processes per endpoint, used
        when the endpoint config does not specify one.
    :return: A dictionary of endpoint objects sorted by speed factor.
    """
    endpoints: list[Endpoint] = []
    for endpoint_name, endpoint_cfg in cfg.endpoints.items():
        assert isinstance(endpoint_cfg, DictConfig)
        logger.info(f"Creating Endpoint object for {endpoint_name}")
        endpoints.append(
            Endpoint(
                url=endpoint_cfg.get("url"),
                type=EndpointType[endpoint_cfg.get("type")],
                api_key_env_var_name=endpoint_cfg.get("api_key_env_var_name", ENV_API_KEY),
                keyvault_secret_name=endpoint_cfg.get("keyvault_secret_name", ""),
                deployment_name=endpoint_cfg.get("deployment_name"),
                speed_factor=endpoint_cfg.get("speed_factor", 1.0),
                num_parallel_processes=endpoint_cfg.get("num_parallel_processes", default_num_parallel_processes),
            )
        )
    endpoints.sort(key=lambda e: e.speed_factor, reverse=descending)
    # All endpoints must serve the same deployment; lazy generator keeps this a no-op for an empty list.
    if any(e.deployment_name != endpoints[0].deployment_name for e in endpoints):
        raise ValueError(
            f"All endpoints must be of the same type but got {[endpoint.deployment_name for endpoint in endpoints]}"
        )
    return {e.url: e for e in endpoints}


def replicate_same_endpoint_n_times(endpoints: dict[str, Endpoint]) -> dict[str, Endpoint]:
    """Expand every endpoint with num_parallel_processes > 1 into one dictionary entry per process.

    :param endpoints: A dictionary of endpoint objects keyed by URL.
    :return: A dictionary of endpoint objects with replicated endpoints.
    """
    replicated: dict[str, Endpoint] = {}
    for url, endpoint in endpoints.items():
        n_processes = endpoint.num_parallel_processes
        if n_processes <= 1:
            replicated[url] = endpoint
            continue
        logger.info(f"Replicating endpoint {url} {endpoint.num_parallel_processes} times")
        for i in range(n_processes):
            # Distinct keys give the illusion of separate endpoints for dataset sharing, but
            # endpoint.url is unchanged, so n processes send parallel requests to the same endpoint.
            replicated[f"{url}_{i}"] = endpoint
    return replicated
# ------------------------------------------------------------------------------------------

import logging
import re

import redis
from langchain_community.cache import RedisCache

from radfact.azure_utils.auth import extract_object_id_from_token, get_credential

logger = logging.getLogger(__name__)

# azure_endpoint is used by AzureOpenAI, api_endpoint is used by ChatOpenAI.
_ENDPOINT_ENTRY_PATTERN = re.compile(r'"(azure|api)_endpoint": "https://[^"]+"')


def remove_endpoint_from_json_string(llm_string: str) -> str:
    """Remove the endpoint URL from the llm_string, using either the format used by AzureOpenAI or ChatOpenAI."""
    return _ENDPOINT_ENTRY_PATTERN.sub("", llm_string)


class RedisCacheWithoutEndpoint(RedisCache):
    """A RedisCache object that ignores the 'azure_endpoint' field when querying the cache."""

    @staticmethod
    def _key(prompt: str, llm_string: str) -> str:
        """Compute the cache key from prompt and llm_string, ignoring the endpoint entry.

        Stripping the entry leaves the llm_string as invalid JSON, which is harmless here:
        the string is only hashed for cache lookup, never parsed.
        """
        return RedisCache._key(prompt, remove_endpoint_from_json_string(llm_string))


def get_redis_cache(redis_cache_name: str) -> RedisCache:
    """Get a RedisCache object pointing at the given Redis cache in Azure.

    When running in AzureML, the cache is accessed using the cluster managed identity,
    otherwise it uses the current user's default Azure credentials.

    :param redis_cache_name: The name of the Redis cache in Azure, without the .redis.cache.windows.net suffix.
    :return: A RedisCache object that points to the given Redis cache object in Azure.
    """
    aad_token = get_credential().get_token("https://redis.azure.com/.default").token
    # The Redis username is the object id of the managed identity, read out from the OID field of the token.
    username = extract_object_id_from_token(aad_token)
    host = f"{redis_cache_name}.redis.cache.windows.net"
    logger.info(f"Connecting to Redis cache {host} with AAD object id {username}")
    client = redis.Redis(
        host=host,
        port=6380,
        password=aad_token,
        username=username,
        ssl=True,
    )
    # Set a simple test key to check authentication as early as possible.
    logger.info("Testing Redis connection")
    client.set("testkey", "testvalue")
    logger.info("Redis connection successful")
    return RedisCacheWithoutEndpoint(client)
# ------------------------------------------------------------------------------------------

import logging
from enum import Enum
from functools import partial
from pathlib import Path
from typing import Any, Callable

import pandas as pd
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import BaseMessage
from omegaconf import DictConfig
from pydantic import ValidationError

from radfact.data_utils.grounded_phrase_list import GroundedPhraseList
from radfact.llm_utils.engine.engine import LLMEngine, get_subfolder
from radfact.llm_utils.nli.schema import (
    ComparisonQuerySinglePhrase,
    DirectionOptions,
    EvidencedPhrase,
    EVState,
    NLIQuerySample,
    NLIQuerySampleSinglePhrase,
    NLISample,
    NLISampleSinglePhrase,
    load_examples_from_json,
)
from radfact.llm_utils.processor.base_processor import BaseProcessor
from radfact.llm_utils.processor.structured_processor import (
    FormatStyleOptions,
    ProcessorStats,
    StructuredProcessor,
    simple_formatter,
)
from radfact.paths import OUTPUT_DIR, get_prompts_dir

logger = logging.getLogger(__name__)
PARSING_TASK = "nli"
PROMPTS_DIR = get_prompts_dir(task=PARSING_TASK)
RADFACT_SUBFOLDER = "radfact"


class MetricDataframeKeys(str, Enum):
    """Column names of the dataframe consumed by the NLI engine (see `get_report_nli_engine`)."""

    CANDIDATE = "candidate"
    REFERENCE = "reference"
    STUDY_ID = "study_id"


def get_ev_processor_singlephrase(log_dir: Path) -> StructuredProcessor[ComparisonQuerySinglePhrase, EvidencedPhrase]:
    """
    Helper function to load the NLI processor with the correct system prompt and few-shot examples.

    The setting here is to classify a SINGLE PHRASE at a time given the reference report.
    Further, we do entailment verification, aka the binary version of NLI.

    :param log_dir: Directory to save logs.
    :return: Processor for entailment verification.
    """

    system_prompt_path = PROMPTS_DIR / "system_message_ev_singlephrase.txt"
    few_shot_examples_path = PROMPTS_DIR / "few_shot_examples.json"
    system_prompt = system_prompt_path.read_text()
    # binary=True collapses the NLI labels to entailment vs. not-entailment.
    few_shot_examples = load_examples_from_json(json_path=few_shot_examples_path, binary=True)
    # The few-shots are in the bidirectional format, we need to convert them to single-phrase.
    few_shot_examples_single_phrase: list[NLISampleSinglePhrase] = []
    for few_shot_example in few_shot_examples:
        one_way_dict = NLISampleSinglePhrase.from_nli_sample(few_shot_example)
        for single_phrase_sample in one_way_dict.values():
            few_shot_examples_single_phrase.extend(single_phrase_sample)

    # Queries and outputs are serialized as YAML in the prompt.
    formatter = partial(simple_formatter, style=FormatStyleOptions.YAML)

    processor = StructuredProcessor(
        query_type=ComparisonQuerySinglePhrase,
        result_type=EvidencedPhrase,
        system_prompt=system_prompt,
        few_shot_examples=few_shot_examples_single_phrase,
        format_query_fn=formatter,
        format_output_fn=formatter,
        log_dir=log_dir,
        use_schema_in_system_prompt=False,
        output_format=FormatStyleOptions.YAML,
    )
    logger.info("Initialized the processor for single phrase entailment verification.")
    return processor


class ReportGroundingNLIProcessor(BaseProcessor[NLIQuerySample, NLISample]):
    """Runs single-phrase entailment verification for every phrase of a query, in both directions."""

    # Keys of the statistics dict returned by `get_processor_stats`.
    NUM_LLM_FAILURES = "num_llm_failures"
    NUM_LLM_SUCCESS = "num_llm_success"
    NUM_LLM_PHRASE_REWRITES = "num_llm_phrase_rewrites"

    def __init__(self, format_query_fn: Callable[..., Any] | None = None) -> None:
        """
        :param format_query_fn: Optional converter applied to each raw query (e.g. a dataframe row)
            to turn it into an NLIQuerySample before processing.
        """
        super().__init__()
        self.format_query_fn = format_query_fn
        self.phrase_processor = get_ev_processor_singlephrase(log_dir=OUTPUT_DIR / "ev_processor_logs")
        # Logging errors
        self.num_llm_failures = 0
        self.num_llm_success = 0
        self.num_llm_phrase_rewrites = 0  # Sometimes it rewrites the phrase, which is not ideal
        # NOTE(review): instance state reused across run() calls; entries from a previous query
        # persist until overwritten, which assumes both directions are always produced per query
        # by NLIQuerySampleSinglePhrase.from_nli_query_sample — TODO confirm.
        self.response_dict: dict[DirectionOptions, list[EvidencedPhrase]] = {}

    def run_processor_on_single_phrase(
        self, single_phrase: NLIQuerySampleSinglePhrase, query_id: str
    ) -> EvidencedPhrase:
        """
        Run the processor on a single phrase.

        If LLM fails to respond, we return a default NOT_ENTAILMENT with no evidence.
        If LLM tries to rephrase the input, we log a warning and correct it.
        """
        single_response = self.phrase_processor.run(query=single_phrase.input, query_id=query_id)

        if single_response is None:
            # Fail closed: an unanswered phrase is counted as not entailed with no evidence.
            logger.warning(f"WARNING: No response for example {query_id}. Setting as NOT ENTAILED.")
            single_response = EvidencedPhrase(
                phrase=single_phrase.input.hypothesis, status=EVState.NOT_ENTAILMENT, evidence=[]
            )
            self.num_llm_failures += 1
        else:
            self.num_llm_success += 1
            # There is a chance that the LLM rewrites the original input phrase somehow, so we need to check.
            # If it does rewrite it, we log a warning and correct it.
            phrase_from_llm = single_response.phrase
            if phrase_from_llm != single_phrase.input.hypothesis:
                self.num_llm_phrase_rewrites += 1
                logger.warning(
                    "WARNING: LLM has rewritten the input phrase. "
                    f"Original: '{single_phrase.input.hypothesis}' Rewritten: '{phrase_from_llm}'"
                )
                # NOTE(review): .copy(update=...) is the pydantic v1 API — presumably pinned; verify
                # against the installed pydantic version before upgrading.
                single_response = single_response.copy(update={"phrase": single_phrase.input.hypothesis})
        return single_response

    def set_model(self, model: BaseLanguageModel[str] | BaseLanguageModel[BaseMessage]) -> None:
        """Forward the LLM to the underlying single-phrase processor."""
        self.phrase_processor.set_model(model)

    def run(self, query: NLIQuerySample | Any, query_id: str) -> NLISample | None:
        """Classify every phrase of the query in both directions and assemble a bidirectional NLISample.

        Returns None if the assembled sample fails schema validation.
        """
        if self.format_query_fn is not None:
            query = self.format_query_fn(query)
        assert isinstance(
            query, NLIQuerySample
        ), f"Query must be an NLIQuerySample, got {type(query)}. Provide a format_query_fn to convert it."
        # Split the bidirectional query into per-direction lists of single-phrase queries.
        phrase_level_examples = NLIQuerySampleSinglePhrase.from_nli_query_sample(query)
        for direction, phrase_list in phrase_level_examples.items():
            processed_list: list[EvidencedPhrase] = []
            for single_phrase in phrase_list:
                single_response = self.run_processor_on_single_phrase(single_phrase, query_id=query_id)
                processed_list.append(single_response)
            self.response_dict[direction] = processed_list
        try:
            output = NLISample.from_pair_of_unidirectional_lists(
                example_id=query_id,
                A_to_B=self.response_dict[DirectionOptions.A_TO_B],
                B_to_A=self.response_dict[DirectionOptions.B_TO_A],
            )
            return output
        except ValidationError as e:
            logger.warning(f"WARNING: Validation error for example {query_id}. Skipping.")
            logger.warning(e)
            return None

    def get_processor_stats(self) -> ProcessorStats:
        """Return the success/failure/rewrite counters accumulated by this processor instance."""
        return {
            self.NUM_LLM_FAILURES: self.num_llm_failures,
            self.NUM_LLM_SUCCESS: self.num_llm_success,
            self.NUM_LLM_PHRASE_REWRITES: self.num_llm_phrase_rewrites,
        }

    def aggregate_processor_stats(self, stats_per_processor: dict[str, ProcessorStats]) -> ProcessorStats:
        """Sum the per-processor statistics, key by key, across all parallel processors."""
        result: ProcessorStats = {}
        for _, stats in stats_per_processor.items():
            for key, value in stats.items():
                result[key] = result.get(key, 0) + value
        return result


def format_row_to_nli_query_sample(row: "pd.Series[Any]") -> NLIQuerySample:
    """Convert one dataframe row (study_id, candidate, reference) into an NLIQuerySample."""
    return NLIQuerySample.from_grounded_phrases_list_pair(
        example_id=row[MetricDataframeKeys.STUDY_ID],
        candidate=row[MetricDataframeKeys.CANDIDATE],
        reference=row[MetricDataframeKeys.REFERENCE],
    )


def get_report_nli_engine(
    cfg: DictConfig, candidates: dict[str, GroundedPhraseList], references: dict[str, GroundedPhraseList]
) -> LLMEngine:
    """Build an LLMEngine that runs report-grounding NLI over paired candidate/reference reports.

    :param cfg: The OmegaConf configuration for the engine.
    :param candidates: Candidate grounded-phrase lists keyed by study ID.
    :param references: Reference grounded-phrase lists keyed by the same study IDs.
    :return: A configured LLMEngine writing progress and final outputs to the radfact subfolder.
    """
    output_folder = get_subfolder(root=OUTPUT_DIR, subfolder=RADFACT_SUBFOLDER)
    nli_report_processor = ReportGroundingNLIProcessor(format_query_fn=format_row_to_nli_query_sample)
    # One row per study; references are looked up by the candidate's study ID.
    dataset_df = pd.DataFrame(
        {
            MetricDataframeKeys.STUDY_ID: study_id,
            MetricDataframeKeys.CANDIDATE: candidates[study_id],
            MetricDataframeKeys.REFERENCE: references[study_id],
        }
        for study_id in candidates.keys()
    )
    engine = LLMEngine(
        cfg=cfg,
        processor=nli_report_processor,
        dataset_df=dataset_df,
        progress_output_folder=output_folder,
        final_output_folder=output_folder,
    )
    return engine
You are an AI radiology assistant. Your task is to assess whether a statement about a chest X-ray (the "hypothesis") is true or not, given a reference report about the chest X-ray. This task is known as entailment verification. If the statement is true ("entailed") according to the reference, provide the evidence to support it.

# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

import json
from enum import Enum
from pathlib import Path
from typing import Any
import logging
from pydantic import BaseModel, Field, root_validator

from radfact.data_utils.grounded_phrase_list import GroundedPhraseList
from radfact.llm_utils.processor.base_processor import BaseModelWithId
from radfact.llm_utils.text_utils import normalise_text_for_comparison

logger = logging.getLogger(__name__)


class DirectionOptions(str, Enum):
    """
    Keys for the direction of the comparison.
    """

    B_TO_A = "B_to_A"  # Given B, is A true?
    A_TO_B = "A_to_B"  # Given A, is B true?


class NLIState(str, Enum):
    """State of phrase from report (three-way NLI labels)."""

    ENTAILMENT = "entailment"
    CONTRADICTION = "contradiction"
    NEUTRAL = "neutral"


class EVState(str, Enum):
    """Entailment verification state (binary collapse of NLIState)."""

    ENTAILMENT = "entailment"
    NOT_ENTAILMENT = "not_entailment"

    @classmethod
    def from_nli_state(cls, nli_state: NLIState) -> "EVState":
        """
        Convert an NLIState to an EVState.
        ENTAILMENT --> ENTAILMENT, else NOT_ENTAILMENT.

        :raises ValueError: if the NLIState is not recognized.
        """
        if nli_state == NLIState.ENTAILMENT:
            return EVState.ENTAILMENT
        elif nli_state in [NLIState.CONTRADICTION, NLIState.NEUTRAL]:
            return EVState.NOT_ENTAILMENT
        else:
            raise ValueError(f"Unrecognized NLIState: {nli_state}.")


class ComparisonQuery(BaseModel):
    # Bidirectional query: both reports are compared against each other.
    phrases_A: list[str] = Field(description="Phrases from report A.")
    phrases_B: list[str] = Field(description="Phrases from report B.")


class ComparisonQuerySinglePhrase(BaseModel):
    # One-directional query: a single hypothesis assessed against a full reference report.
    reference: list[str] = Field(description="Reference report.")
    hypothesis: str = Field(description="Phrase to assess.")


class EvidencedPhrase(BaseModel):
    phrase: str = Field(description="Phrase from report.")
    evidence: list[str] = Field(description="Phrase(s) from reference report supporting the logical state.")
    # Note that the status could either be NLIState or EVState, hence the plain str type.
    status: str = Field(description="Logical state of phrase given reference report.")

    def convert_to_binary(self) -> "EvidencedPhrase":
        """Convert the status to binary (EVState), leaving already-binary statuses unchanged."""
        try:
            _ = EVState(self.status)
            # If this succeeds, we are already binary.
            return self
        except ValueError:
            new_status = EVState.from_nli_state(NLIState(self.status))
            return self.copy(update={"status": new_status.value})

    @root_validator
    @classmethod
    def evidence_exists_or_not(cls, values: dict[str, Any]) -> dict[str, Any]:
        """
        Entailed phrases always need evidence.
        Neutral phrases should have no evidence.
        Contradicted phrases should have evidence.
        Not-entailed phrases can either have or not have evidence (since they are contradiction OR neutral).
        """
        status = values["status"]
        evidence = values["evidence"]
        # Both enums are str subclasses, so `status` (a plain str) compares equal to members by value.
        # NLIState.ENTAILMENT and EVState.ENTAILMENT share the value "entailment".
        # Entailment --> evidence
        if status == NLIState.ENTAILMENT or status == EVState.ENTAILMENT:
            if len(evidence) == 0:
                raise ValueError(f"Entailed phrases should have evidence. {values['phrase']=}")
        # Neutral --> no evidence
        elif status == NLIState.NEUTRAL:
            if len(evidence) > 0:
                raise ValueError(f"Neutral phrases should not have evidence. {values['phrase']=}")
        # Contradiction --> evidence
        elif status == NLIState.CONTRADICTION:
            if len(evidence) == 0:
                raise ValueError(f"Contradicted phrases should have evidence. {values['phrase']=}")
        elif status == EVState.NOT_ENTAILMENT:
            # Not-entailed phrases can either have or not have evidence (since they are contradiction OR neutral).
            pass
        else:
            raise ValueError(f"Unrecognized status: {status}.")
        return values

    def pretty_format(self) -> str:
        # Fixed-width status column, then phrase and evidence, pipe-separated.
        return f"{self.status.ljust(15)}|{self.phrase}|{self.evidence}"


class BidirectionalEvidence(BaseModel):
    # Evidenced phrases for both directions of a report pair.
    phrases_A_evidenced: list[EvidencedPhrase] = Field(
        description="Phrases from report A with logical state and supporting evidence."
    )
    phrases_B_evidenced: list[EvidencedPhrase] = Field(
        description="Phrases from report B with logical state and supporting evidence."
    )

    def pretty_format(self) -> str:
        """Return a human-readable two-section dump of both directions."""

        def _format_single_direction(evidenced_phrases: list[EvidencedPhrase]) -> str:
            return "\n".join(phrase.pretty_format() for phrase in evidenced_phrases)

        overall_output = "=== Phrases A ===\n"
        overall_output += _format_single_direction(self.phrases_A_evidenced)
        overall_output += "\n=== Phrases B ===\n"
        overall_output += _format_single_direction(self.phrases_B_evidenced)
        return overall_output


class NLIQuerySample(BaseModelWithId):
    """
    A single sample for the NLI task.
    Does not include an output as it is not necessarily for demonstration/testing.
    """

    example_id: str
    input: ComparisonQuery

    @classmethod
    def from_grounded_phrases_list_pair(
        cls, example_id: str, candidate: GroundedPhraseList, reference: GroundedPhraseList
    ) -> "NLIQuerySample":
        """
        Create an NLIQuerySample from a pair of `GroundedPhraseList` instances (candidate and reference).

        Candidate phrases become phrases_A, reference phrases become phrases_B. Raises if any
        element of either list is not a grounded phrase (fail_if_non_grounded_phrase=True).
        """
        candidate_grounded_phrase = [
            phrase.text for phrase in candidate.get_all_grounded_phrases(fail_if_non_grounded_phrase=True)
        ]
        reference_grounded_phrase = [
            phrase.text for phrase in reference.get_all_grounded_phrases(fail_if_non_grounded_phrase=True)
        ]
        return NLIQuerySample(
            example_id=example_id,
            input=ComparisonQuery(phrases_A=candidate_grounded_phrase, phrases_B=reference_grounded_phrase),
        )


class NLISample(NLIQuerySample):
    """
    A single sample with output for the NLI task.
    Enables validation of the output.
    """

    output: BidirectionalEvidence

    @classmethod
    def from_pair_of_unidirectional_lists(
        cls, example_id: str, A_to_B: list[EvidencedPhrase], B_to_A: list[EvidencedPhrase]
    ) -> "NLISample":
        """
        Create an NLISample from a pair of one-way samples.

        Note the crossing: the hypotheses of the B_to_A direction are phrases from report A,
        so they populate phrases_A (and vice versa). Constructing the NLISample runs all
        root validators below.
        """
        return NLISample(
            example_id=example_id,
            input=ComparisonQuery(phrases_A=[x.phrase for x in B_to_A], phrases_B=[x.phrase for x in A_to_B]),
            output=BidirectionalEvidence(phrases_A_evidenced=B_to_A, phrases_B_evidenced=A_to_B),
        )

    @root_validator(skip_on_failure=True)
    @classmethod
    def input_and_output_phrases_match(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Every original phrase must appear in the output. Every evidenced phrase must appear in the input."""
        output: BidirectionalEvidence = values["output"]
        input: ComparisonQuery = values["input"]
        phrase_pairings = {
            "phrases_A": (input.phrases_A, output.phrases_A_evidenced),
            "phrases_B": (input.phrases_B, output.phrases_B_evidenced),
        }
        missing_phrases: dict[str, list[str]] = {}
        added_phrases: dict[str, list[str]] = {}
        for phrase_list_name, (phrases, evidenced_phrases) in phrase_pairings.items():
            # Membership is checked on exact (non-normalised) text in both directions.
            missing_phrases[phrase_list_name] = [x for x in phrases if x not in [y.phrase for y in evidenced_phrases]]
            added_phrases[phrase_list_name] = [x.phrase for x in evidenced_phrases if x.phrase not in phrases]
        if any(missing_phrases.values()):
            raise ValueError(f"Phrases should have been classified but are not in output: {missing_phrases}.")
        if any(added_phrases.values()):
            raise ValueError(f"Phrases should not have been classified but are in output: {added_phrases}.")
        return values

    @root_validator(skip_on_failure=True)
    @classmethod
    def evidence_from_correct_report(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Check that all evidence comes from the appropriate report."""
        output: BidirectionalEvidence = values["output"]
        input: ComparisonQuery = values["input"]
        # Evidence is compared after text normalisation, unlike the exact matching above.
        cleaned_phrases_A = [normalise_text_for_comparison(phrase) for phrase in input.phrases_A]
        cleaned_phrases_B = [normalise_text_for_comparison(phrase) for phrase in input.phrases_B]

        def _confirm_evidence_is_in_expected_report(
            phrase_with_evidence: EvidencedPhrase, report_phrases: list[str]
        ) -> None:
            """All the evidence from a phrase should be in the expected report.
            If phrase is from A, evidence is from B.
            """
            for supporting_phrase in phrase_with_evidence.evidence:
                if normalise_text_for_comparison(supporting_phrase) not in report_phrases:
                    raise ValueError(
                        f"Evidence for {phrase_with_evidence.phrase} comes from {supporting_phrase}, "
                        f"which is not in the expected report. Expected it to be one of {report_phrases}."
                    )

        for phrase_with_evidence in output.phrases_A_evidenced:
            _confirm_evidence_is_in_expected_report(phrase_with_evidence, cleaned_phrases_B)
        for phrase_with_evidence in output.phrases_B_evidenced:
            _confirm_evidence_is_in_expected_report(phrase_with_evidence, cleaned_phrases_A)
        return values

    @root_validator
    @classmethod
    def evidenced_phrases_exist_in_original_list(cls, values: dict[str, Any]) -> dict[str, Any]:
        """All evidenced phrases must be found in the original phrase list."""
        # This validator has no skip_on_failure, so it runs even if a previous validator failed
        # and "output" may be absent from values; hence the explicit guard.
        if "output" not in values:
            raise ValueError("No BidirectionalEvidence? Must have failed validation.")
        output: BidirectionalEvidence = values["output"]
        input: ComparisonQuery = values["input"]
        for phrase_with_evidence in output.phrases_A_evidenced:
            if phrase_with_evidence.phrase not in input.phrases_A:
                raise ValueError(
                    f"Evidenced phrase {phrase_with_evidence.phrase} not found in phrase A list ({input.phrases_A})."
                )
        for phrase_with_evidence in output.phrases_B_evidenced:
            if phrase_with_evidence.phrase not in input.phrases_B:
                raise ValueError(
                    f"Evidenced phrase {phrase_with_evidence.phrase} not found in phrase B list ({input.phrases_B})."
                )
        return values


class NLIQuerySampleSinglePhrase(BaseModel):
    """A single sample for the NLI task, assuming we only go one phrase at a time."""

    example_id: str
    input: ComparisonQuerySinglePhrase

    @classmethod
    def from_nli_query_sample(
        cls, nli_query_sample: NLIQuerySample
    ) -> dict[DirectionOptions, list["NLIQuerySampleSinglePhrase"]]:
        """
        Create a dict list of NLIQuerySampleSinglePhrase instances from an NLIQuerySample.
        We return a dict so we can do A-to-B and B-to-A separately.
        """
        # B_TO_A: each phrase of A is a hypothesis, the whole of B is the reference.
        B_to_A = [
            NLIQuerySampleSinglePhrase(
                example_id=nli_query_sample.example_id,
                input=ComparisonQuerySinglePhrase(reference=nli_query_sample.input.phrases_B, hypothesis=phrase_A),
            )
            for phrase_A in nli_query_sample.input.phrases_A
        ]
        # A_TO_B: the mirror image.
        A_to_B = [
            NLIQuerySampleSinglePhrase(
                example_id=nli_query_sample.example_id,
                input=ComparisonQuerySinglePhrase(reference=nli_query_sample.input.phrases_A, hypothesis=phrase_B),
            )
            for phrase_B in nli_query_sample.input.phrases_B
        ]
        return {DirectionOptions.B_TO_A: B_to_A, DirectionOptions.A_TO_B: A_to_B}


class NLISampleSinglePhrase(BaseModel):
    """A single sample with output for the NLI task, assuming we only go one phrase at a time."""

    example_id: str
    input: ComparisonQuerySinglePhrase
    output: EvidencedPhrase

    @classmethod
    def from_nli_sample(cls, nli_sample: NLISample) -> dict[DirectionOptions, list["NLISampleSinglePhrase"]]:
        """
        Create a dict list of NLISampleSinglePhrase instances from an NLISample.
        We return a dict so we can do A-to-B and B-to-A separately.
        """
        B_to_A = [
            NLISampleSinglePhrase(
                example_id=nli_sample.example_id,
                input=ComparisonQuerySinglePhrase(
                    reference=nli_sample.input.phrases_B, hypothesis=evidenced_phrase_A.phrase
                ),
                output=evidenced_phrase_A,
            )
            for evidenced_phrase_A in nli_sample.output.phrases_A_evidenced
        ]
        A_to_B = [
            NLISampleSinglePhrase(
                example_id=nli_sample.example_id,
                input=ComparisonQuerySinglePhrase(
                    reference=nli_sample.input.phrases_A, hypothesis=evidenced_phrase_B.phrase
                ),
                output=evidenced_phrase_B,
            )
            for evidenced_phrase_B in nli_sample.output.phrases_B_evidenced
        ]
        return {DirectionOptions.B_TO_A: B_to_A, DirectionOptions.A_TO_B: A_to_B}


def load_examples_from_json(json_path: Path, binary: bool = True) -> list[NLISample]:
    """
    Helper function to load NLISamples from a json.

    :param json_path: Path to the json we wish to load from.
    :param binary: Whether to convert NLIState to EVState (ENTAILMENT v. NOT_ENTAILMENT).
    """
    with open(json_path, "r", encoding="utf-8") as f:
        examples_json = json.load(f)
    if binary:
        # Convert NLIState to EVState (ENTAILMENT v. NOT_ENTAILMENT)
        samples = []
        for example_json in examples_json:
            nli_sample = NLISample.parse_obj(example_json)
            nli_sample.output.phrases_A_evidenced = [
                phrase.convert_to_binary() for phrase in nli_sample.output.phrases_A_evidenced
            ]
            nli_sample.output.phrases_B_evidenced = [
                phrase.convert_to_binary() for phrase in nli_sample.output.phrases_B_evidenced
            ]
            samples.append(nli_sample)
            # Sanity check: every status must now parse as an EVState (raises ValueError otherwise).
            for evidenced_phrase in nli_sample.output.phrases_A_evidenced:
                _ = EVState(evidenced_phrase.status)
            for evidenced_phrase in nli_sample.output.phrases_B_evidenced:
                _ = EVState(evidenced_phrase.status)
    else:
        samples = [NLISample.parse_obj(example_json) for example_json in examples_json]
    logger.info(f"Loaded {len(samples)} examples.")
    return samples
# ------------------------------------------------------------------------------------------

from abc import ABCMeta, abstractmethod
from typing import Any, Generic, TypeVar

from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import BaseMessage
from pydantic import BaseModel

QueryT = TypeVar("QueryT")  # Type of the input query a processor accepts.
BaseResultT = TypeVar("BaseResultT")  # Type of the result a processor produces.


class BaseModelWithId(BaseModel):
    """Base class for models that have an ID."""

    # Optional identifier; None when no id has been assigned.
    id: int | str | None = None


class BaseProcessor(Generic[QueryT, BaseResultT], metaclass=ABCMeta):
    """Base class for processors that interact with language models."""

    @abstractmethod
    def set_model(self, model: BaseLanguageModel[str] | BaseLanguageModel[BaseMessage]) -> None:
        """Set the language model that subsequent `run` calls will use."""
        raise NotImplementedError

    @abstractmethod
    def run(self, query: QueryT, query_id: str) -> BaseResultT | None:
        """Process a single query, returning the result or None on failure."""
        raise NotImplementedError

    def get_processor_stats(self) -> Any:
        """Return statistics that the processor collects.

        Default implementation collects nothing and returns None.
        """
        return None

    def aggregate_processor_stats(self, stats_per_processor: dict[str, Any]) -> Any:
        """Aggregate statistics from multiple processors.

        :param stats_per_processor: A dictionary of statistics from multiple processors. The dictionary key is the
            processor ID (usually the endpoint name), the dictionary value is the statistics returned by
            `get_processor_stats()`.
        :return: The aggregated statistics, which should be of the same type as returned by `get_processor_stats()`
        """
        return None
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

import logging
from enum import Enum
from functools import partial
from pathlib import Path
from typing import Any, Callable, Generic, Iterable, Protocol, TypeVar

import yaml
from langchain.output_parsers import PydanticOutputParser, YamlOutputParser
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.prompts import BaseChatPromptTemplate
from pydantic import BaseModel

from radfact.llm_utils.processor.base_processor import BaseProcessor, QueryT

logger = logging.getLogger(__name__)

_QUERY_KEY = "query"  # Name of the prompt-template input variable holding the query.
ResultT = TypeVar("ResultT", bound=BaseModel)  # Structured (Pydantic) result type.
ProcessorStats = dict[str, int]  # Counter-name -> count.


class FormatStyleOptions(str, Enum):
    """Options for formatting the output of a class."""

    JSON = "json"
    YAML = "yaml"


def simple_formatter(obj: BaseModel, style: FormatStyleOptions) -> str:
    """Format a BaseModel instance for output.

    :param obj: The Pydantic object to format.
    :param style: The desired output format.
    :return: The formatted string.
41 | """ 42 | match style: 43 | case FormatStyleOptions.JSON: 44 | return str(obj.json()) 45 | case FormatStyleOptions.YAML: 46 | return yaml.dump(obj.dict(), sort_keys=False) 47 | case _: 48 | raise ValueError(f"Unrecognized format style: {style}.") 49 | 50 | 51 | class Example(Protocol, Generic[QueryT, ResultT]): 52 | """Interface for any object with `input` and `output` attributes representing a processed example.""" 53 | 54 | input: QueryT 55 | output: ResultT 56 | 57 | 58 | class QueryTemplate(BaseChatPromptTemplate, Generic[QueryT, ResultT]): 59 | """Query template for a structured processor.""" 60 | 61 | system_prompt: str 62 | format_query_fn: Callable[[QueryT], str] 63 | format_output_fn: Callable[[ResultT], str] 64 | examples: Iterable[Example[QueryT, ResultT]] | None 65 | 66 | def __init__( 67 | self, 68 | system_prompt: str, 69 | query_type: type[QueryT], 70 | format_query_fn: Callable[[QueryT], str], 71 | format_output_fn: Callable[[ResultT], str], 72 | examples: Iterable[Example[QueryT, ResultT]] | None = None, 73 | ) -> None: 74 | super().__init__( # type: ignore 75 | input_variables=[_QUERY_KEY], 76 | input_types={_QUERY_KEY: query_type}, 77 | system_prompt=system_prompt, 78 | format_query_fn=format_query_fn, 79 | format_output_fn=format_output_fn, 80 | ) 81 | self.examples = examples 82 | 83 | def prepare_few_shot_examples(self, examples: Iterable[Example[QueryT, ResultT]] | None) -> list[BaseMessage]: 84 | """Prepare few shot examples as human-assistant message pairs. 85 | 86 | :param examples: List of few shot examples. 87 | :return: List of messages (HumanMessage, AIMessage) for the chat prompt. 
88 | """ 89 | few_shot_messages: list[BaseMessage] = [] 90 | if not examples: 91 | return few_shot_messages 92 | for example in examples: 93 | human_query = self.format_query_fn(example.input) 94 | ai_response = self.format_output_fn(example.output) 95 | few_shot_messages.append(HumanMessage(content=human_query)) 96 | few_shot_messages.append(AIMessage(content=ai_response)) 97 | return few_shot_messages 98 | 99 | def format_messages(self, **kwargs: Any) -> list[BaseMessage]: 100 | """Format the final chat prompt messages including the system prompt, few-shot examples, and the query.""" 101 | query: QueryT = kwargs[_QUERY_KEY] 102 | formatted_query = self.format_query_fn(query) 103 | 104 | few_shot_messages = self.prepare_few_shot_examples( 105 | examples=self.examples 106 | ) # Not yet sure how to precompute this 107 | return ( 108 | [SystemMessage(content=self.system_prompt)] 109 | + few_shot_messages # Note this is empty if none are provided 110 | + [HumanMessage(content=formatted_query)] 111 | ) 112 | 113 | 114 | class StructuredProcessor(BaseProcessor[QueryT, ResultT]): 115 | """Processor for structured queries and results.""" 116 | 117 | NUM_FAILURES = "num_failures" 118 | NUM_SUCCESS = "num_success" 119 | 120 | def __init__( 121 | self, 122 | query_type: type[QueryT], 123 | result_type: type[ResultT], 124 | system_prompt: str | Path, 125 | format_query_fn: Callable[[QueryT], str], 126 | format_output_fn: Callable[[ResultT], str] | None = None, 127 | model: BaseLanguageModel[str] | BaseLanguageModel[BaseMessage] | None = None, 128 | log_dir: Path | None = None, 129 | few_shot_examples: Iterable[Example[QueryT, ResultT]] | None = None, 130 | validate_result_fn: Callable[[QueryT, ResultT], None] | None = None, 131 | use_schema_in_system_prompt: bool = True, 132 | output_format: FormatStyleOptions = FormatStyleOptions.JSON, 133 | ) -> None: 134 | """ 135 | :param query_type: Type of the input query, e.g. `str`, `pd.Series`, or a Pydantic `BaseModel`. 
136 | :param result_type: Type of the structured Pydantic output. 137 | :param system_prompt: The part of system message describing the desired model behaviour. This will be 138 | complemented by a description of the output JSON schema for `result_type`. 139 | :param format_query_fn: A function to format the query into a "human" chat message for the model. 140 | :param format_output_fn: A function to format the (expected) output into an "AI" chat message for the model. 141 | If not provided, the output will be formatted as JSON. 142 | :param model: Langchain model to use for processing. 143 | :param log_dir: If given, directory where to save error log files containing details of each failed query, named 144 | `error_{query_id}.txt`. Otherwise, errors will only be printed to stdout. 145 | :param few_shot_examples: Optional list of few-shot examples to be included in the prompt as human-assistant 146 | message pairs. Each example can be any Python object with `input` and `output` attributes of `query_type` 147 | and `result_type`, respectively. 148 | :param validate_result_fn: Optional function to validate the result of each query. It should take the query and 149 | the result and raise an exception if the result is invalid. For example, this may leverage Pydantic's 150 | validation mechanics. 151 | :param use_schema_in_system_prompt: If True, the system prompt will also include a description of the 152 | expected output schema for `result_type`, as generated by the output parser (either YAML or JSON). 153 | :param output_format: The format of the output. Either "json" or "yaml". 
154 | """ 155 | self.query_type = query_type 156 | self.result_type = result_type 157 | 158 | self.system_prompt = system_prompt if isinstance(system_prompt, str) else system_prompt.read_text() 159 | self.format_query_fn = format_query_fn 160 | self.output_format = output_format 161 | if format_output_fn is None: 162 | # Use the generic formatter based on the output format style 163 | self.format_output_fn: Callable[[ResultT], str] = partial(simple_formatter, style=self.output_format) 164 | else: 165 | self.format_output_fn = format_output_fn 166 | 167 | self.validate_result_fn = validate_result_fn 168 | 169 | self.parser: PydanticOutputParser[ResultT] | YamlOutputParser[ResultT] 170 | match self.output_format: 171 | case FormatStyleOptions.YAML: 172 | self.parser = YamlOutputParser(pydantic_object=result_type) 173 | case FormatStyleOptions.JSON: 174 | self.parser = PydanticOutputParser(pydantic_object=result_type) 175 | case _: 176 | raise ValueError( 177 | f"Unrecognized output format: {self.output_format}. Should be one of {FormatStyleOptions}." 
178 | ) 179 | 180 | if use_schema_in_system_prompt: 181 | self.system_prompt += "\n\n" + self.parser.get_format_instructions() 182 | 183 | self.query_template = QueryTemplate( 184 | system_prompt=self.system_prompt, 185 | query_type=query_type, 186 | format_query_fn=self.format_query_fn, 187 | format_output_fn=self.format_output_fn, 188 | examples=few_shot_examples, 189 | ) 190 | self.log_dir = log_dir 191 | self.model = model 192 | # For logging 193 | self.num_failures: int = 0 194 | self.num_success: int = 0 195 | 196 | def set_model(self, model: BaseLanguageModel[str] | BaseLanguageModel[BaseMessage]) -> None: 197 | self.model = model 198 | 199 | def _write_error(self, ex: Exception, query: QueryT, query_id: str) -> None: 200 | formatted_query = self.format_query_fn(query) 201 | error_message = f"{ex}\n----\nQuery {query_id=}:\n{formatted_query=}" 202 | if self.log_dir: 203 | self.log_dir.mkdir(exist_ok=True, parents=True) 204 | error_log_path = self.log_dir / f"error_{query_id}.txt" 205 | error_log_path.write_text(error_message) 206 | logger.info(f"Error details saved to {error_log_path}.") 207 | 208 | def run(self, query: QueryT, query_id: str) -> ResultT | None: 209 | assert self.model, "Model not set. Call `set_model` first." 
210 | chain = self.query_template | self.model | self.parser 211 | try: 212 | response: ResultT = chain.invoke({_QUERY_KEY: query}) 213 | if self.validate_result_fn: 214 | self.validate_result_fn(query, response) 215 | self.num_success += 1 216 | return response 217 | except Exception as ex: 218 | self._write_error(ex, query, query_id) 219 | self.num_failures += 1 220 | return None 221 | 222 | def get_processor_stats(self) -> ProcessorStats: 223 | return { 224 | self.NUM_FAILURES: self.num_failures, 225 | self.NUM_SUCCESS: self.num_success, 226 | } 227 | 228 | def aggregate_processor_stats(self, stats_per_processor: dict[str, ProcessorStats]) -> ProcessorStats: 229 | result: ProcessorStats = {} 230 | for _, stats in stats_per_processor.items(): 231 | for key, value in stats.items(): 232 | result[key] = result.get(key, 0) + value 233 | return result 234 | -------------------------------------------------------------------------------- /src/radfact/llm_utils/report_to_phrases/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/RadFact/72463ac34442647886f13fad7d04958ab6c7c294/src/radfact/llm_utils/report_to_phrases/__init__.py -------------------------------------------------------------------------------- /src/radfact/llm_utils/report_to_phrases/processor.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 
# ------------------------------------------------------------------------------------------

from pathlib import Path
from typing import Any

import pandas as pd
from omegaconf import DictConfig

from radfact.llm_utils.engine.engine import LLMEngine, get_subfolder
from radfact.llm_utils.processor.structured_processor import StructuredProcessor
from radfact.llm_utils.report_to_phrases.schema import ParsedReport, load_examples_from_json
from radfact.paths import OUTPUT_DIR, get_prompts_dir

FINDINGS_SECTION = "FINDINGS"  # Dataframe column holding the findings text.
PARSING_TASK = "report_to_phrases"
PROMPTS_DIR = get_prompts_dir(task=PARSING_TASK)
StudyIdType = str | int


def get_report_to_phrases_processor(log_dir: Path | None = None) -> StructuredProcessor[str, ParsedReport]:
    """Return a processor for converting reports to phrases.

    The system prompt and few-shot examples are loaded from the task's prompts directory.

    :param log_dir: The directory to save logs.
    :return: The processor for report to phrase conversion.
    """
    system_message_path = PROMPTS_DIR / "system_message.txt"
    few_shot_examples_path = PROMPTS_DIR / "few_shot_examples.json"
    system_prompt = system_message_path.read_text()
    few_shot_examples = load_examples_from_json(few_shot_examples_path)
    processor = StructuredProcessor(
        query_type=str,
        result_type=ParsedReport,
        system_prompt=system_prompt,
        format_query_fn=lambda x: x,  # Our query is simply the findings text
        few_shot_examples=few_shot_examples,  # type: ignore[arg-type]
        log_dir=log_dir,
    )
    return processor


def get_findings_from_row(row: "pd.Series[Any]") -> str:
    """Get the findings from a row in a DataFrame."""
    findings = row[FINDINGS_SECTION]
    assert isinstance(findings, str), f"Findings should be a string, got {findings}"
    return findings


def get_report_to_phrases_engine(cfg: DictConfig, dataset_df: pd.DataFrame) -> LLMEngine:
    """
    Create the processing engine for converting reports to phrases.

    Output is written under OUTPUT_DIR/report_to_phrases/<cfg.dataset.name>; error logs go to
    a sibling "logs" subfolder.

    :param cfg: The configuration for the processing engine.
    :param dataset_df: The dataset DataFrame. Must contain the columns `cfg.processing.index_col`
        and FINDINGS.
    :return: The processing engine.
    """
    subfolder = cfg.dataset.name
    root = OUTPUT_DIR / PARSING_TASK
    # Progress and final output are deliberately written to the same subfolder.
    output_folder = get_subfolder(root, subfolder)
    final_output_folder = get_subfolder(root, subfolder)
    log_dir = get_subfolder(root, "logs")

    report_to_phrases_processor = get_report_to_phrases_processor(log_dir=log_dir)
    id_col = cfg.processing.index_col
    # Keep only the columns the engine needs: the id and the findings text.
    dataset_df = dataset_df[[id_col, FINDINGS_SECTION]]
    engine = LLMEngine(
        cfg=cfg,
        processor=report_to_phrases_processor,
        dataset_df=dataset_df,
        progress_output_folder=output_folder,
        final_output_folder=final_output_folder,
        row_to_query_fn=get_findings_from_row,
    )
    return engine
You are an AI radiology assistant. You are helping process reports from chest X-rays.

Please extract phrases from the radiology report which refer to objects, findings, or anatomies visible in a chest X-ray, or the absence of such.

Rules:
- If a sentence describes multiple findings, split them up into separate sentences.
- Exclude clinical speculation or interpretation (e.g. "... highly suggestive of pneumonia").
- Exclude recommendations (e.g. "Recommend a CT").
- Exclude comments on the technical quality of the X-ray (e.g. "there are low lung volumes").
- Include mentions of change (e.g. "Pleural effusion has increased") because change is visible when we compare two X-rays.
- If consecutive sentences are closely linked such that one sentence can't be understood without the other one, process them together.

The objective is to extract phrases which refer to things which can be located on a chest X-ray, or confirmed not to be present.
14 | -------------------------------------------------------------------------------- /src/radfact/llm_utils/report_to_phrases/schema.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 4 | # ------------------------------------------------------------------------------------------ 5 | 6 | import json 7 | from pathlib import Path 8 | from typing import Any, Callable, Dict, List 9 | 10 | from pydantic import BaseModel, root_validator 11 | 12 | from radfact.data_utils.grounded_phrase_list import GroundedPhrase, GroundedPhraseList 13 | from radfact.llm_utils.processor.base_processor import BaseModelWithId 14 | 15 | 16 | class SentenceWithRephrases(BaseModel): 17 | """Dataclass for a sentence with rephrases. The source sentence is 'orig' and the rephrased sentences are 'new'.""" 18 | 19 | orig: str 20 | new: list[str] 21 | 22 | 23 | class ParsedReport(BaseModelWithId): 24 | """ 25 | How we represent a parsed report. This applies to the model output and the examples. 26 | The `ParsedReport` is a list of `SentenceWithRephrases`. 27 | Each `SentenceWithRephrases` has an 'orig' sentence and a list of 'new' rephrased ('phrasified') sentences. 28 | """ 29 | 30 | sentence_list: list[SentenceWithRephrases] 31 | 32 | def phrases_as_list(self) -> List[str]: 33 | """Collect all the rephrased sentences from a model output. 34 | 35 | :return: List of rephrased sentences. 
36 | """ 37 | rephrases = [] 38 | for sentence in self.sentence_list: 39 | for new_sentence in sentence.new: 40 | if len(new_sentence) > 0: 41 | rephrases.append(new_sentence) 42 | return rephrases 43 | 44 | def pretty_print_rephrased(self, print_fn: Callable[[str], None] = print) -> None: 45 | """Pretty print the rephrased sentences.""" 46 | for sentence in self.sentence_list: 47 | print_fn(f"{sentence.orig}") 48 | for new_sentence in sentence.new: 49 | print_fn(f" -->\t{new_sentence}") 50 | 51 | def get_sentence_mappings(self) -> Dict[str, List[str]]: 52 | """Return a dictionary mapping original sentences to rephrased sentences. 53 | 54 | :return: Dictionary with keys being original sentences and values corresponding to rephrased sentences. 55 | """ 56 | sentence_mappings = {x.orig: x.new for x in self.sentence_list} 57 | return sentence_mappings 58 | 59 | def to_grounded_phrases_list(self, rephrased: bool = True) -> GroundedPhraseList: 60 | """Convert the parsed report to a `GroundedPhraseList`. Specifically a list of `GroundedPhrase` objects. 61 | If rephrased (default), we use the 'new' phrases. Otherwise we use 'orig'. 
62 | """ 63 | sequence = GroundedPhraseList() 64 | if rephrased: 65 | rephrased_sentences = self.phrases_as_list() 66 | for sentence in rephrased_sentences: 67 | sequence.append(GroundedPhrase(text=sentence)) 68 | else: 69 | for sentence_with_rephrases in self.sentence_list: 70 | sequence.append(GroundedPhrase(text=sentence_with_rephrases.orig)) 71 | return sequence 72 | 73 | 74 | class PhraseParsingExample(BaseModel): 75 | """Dataclass for a single example.""" 76 | 77 | example_id: int | str 78 | findings_text: str 79 | parsed_report: ParsedReport | None = None 80 | study_id: str | None = None 81 | example_rationale: str | None = None 82 | 83 | @property 84 | def input(self) -> str: 85 | return self.findings_text 86 | 87 | @property 88 | def output(self) -> ParsedReport | None: 89 | return self.parsed_report 90 | 91 | @root_validator 92 | @classmethod 93 | def no_unnecessarily_duplicated_sentences(cls, values: dict[str, Any]) -> dict[str, Any]: 94 | """Make sure the same 'orig' sentence doesn't appear twice.""" 95 | if values["parsed_report"] is None: 96 | # Nothing to do here 97 | return values 98 | sentence_list = values["parsed_report"].sentence_list 99 | findings_text = values["findings_text"] 100 | duplicated_sentences = [sentence.orig for sentence in sentence_list if sentence_list.count(sentence) > 1] 101 | # A duplicated sentence is acceptable if it also appears in the original report twice. 102 | real_duplications = [] 103 | for duplication_candidate in duplicated_sentences: 104 | if findings_text.count(duplication_candidate) == 1: 105 | real_duplications.append(duplication_candidate) 106 | if len(real_duplications) > 0: 107 | raise ValueError( 108 | "Duplicate sentences found in ParsedReport." 109 | f"Duplicated sentences: {real_duplications}. Original report: {findings_text}." 
110 | ) 111 | return values 112 | 113 | 114 | def load_examples_from_json(file_path: Path) -> list[PhraseParsingExample]: 115 | """ 116 | Given a path to a json file, load the examples into a list of PhraseParsingExample objects. 117 | 118 | This is implemented to be "backwards compatible" with the old json format, where the parsed_report was a list of 119 | strings. 120 | """ 121 | if file_path.suffix != ".json": 122 | file_path = file_path.with_suffix(".json") 123 | examples = json.load(open(file_path, "r", encoding="utf-8")) 124 | examples_list: list[PhraseParsingExample] = [] 125 | for example in examples: 126 | parsed_example = PhraseParsingExample.parse_obj(example) 127 | examples_list.append(parsed_example) 128 | return examples_list 129 | -------------------------------------------------------------------------------- /src/radfact/llm_utils/text_utils.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 4 | # ------------------------------------------------------------------------------------------ 5 | 6 | import logging 7 | import re 8 | from difflib import SequenceMatcher 9 | 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def normalise_text_for_comparison(text: str) -> str: 15 | """ 16 | Normalise a string for comparison by removing whitespace, dashes, and punctuation. 17 | 18 | :param text: The text to normalise. 19 | :return: The normalised text. 
20 | """ 21 | text = text.lower().strip() # Normalise case and trim whitespace 22 | text = re.sub(r"\s+", " ", text) # Remove duplicated spaces 23 | text = re.sub(r"[\-–—]", "", text) # Remove dashes 24 | text = re.sub(r"\s*(?<=[\.\:\!\?])", "", text) # Remove spaces before punctuation 25 | text = re.sub(r"[\.\:\!\?]$", "", text) # Remove final punctuation 26 | return text 27 | 28 | 29 | def find_best_match(text: str, candidate_texts: list[str]) -> tuple[int, str]: 30 | """ 31 | Given "text" and a list of possible matches ("candidate_texts"), return the 32 | index of the best match and the match itself. We assume there is a match. 33 | 34 | :param text: The text to match. 35 | :param candidate_texts: The list of candidate texts to match against. 36 | :return: A tuple of the index of the best match and the match itself. 37 | """ 38 | candidates_normalised = [normalise_text_for_comparison(candidate) for candidate in candidate_texts] 39 | text_normalised = normalise_text_for_comparison(text) 40 | 41 | # Case 1: Exact match, with or without normalisation 42 | for i, candidate in enumerate(candidates_normalised): 43 | if candidate == text_normalised: 44 | return i, candidate_texts[i] 45 | 46 | # Case 2: Not exact match. Use the longest common substring to select. 47 | logger.info("No good match! Using substring matching...") 48 | best_match = "" 49 | match_index = -1 50 | best_substring_length = 0 51 | for i, candidate in enumerate(candidates_normalised): 52 | substring_match = SequenceMatcher(None, text_normalised, candidate).find_longest_match( 53 | 0, len(text_normalised), 0, len(candidate) 54 | ) 55 | substring = text[substring_match.a : substring_match.a + substring_match.size] 56 | if len(substring) > best_substring_length: 57 | best_match = candidate_texts[i] 58 | match_index = i 59 | best_substring_length = len(substring) 60 | assert match_index != -1, f"Failed to find a match for {text} in {candidate_texts}." 
61 | return match_index, best_match 62 | -------------------------------------------------------------------------------- /src/radfact/metric/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 4 | # ------------------------------------------------------------------------------------------ 5 | -------------------------------------------------------------------------------- /src/radfact/metric/bootstrapping.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 
# ------------------------------------------------------------------------------------------

import logging
from typing import Generator

import numpy as np
from tqdm import tqdm

from radfact.metric.radfact import InputDict, PerSampleResultType, RadFactMetric, ReturnType

logger = logging.getLogger(__name__)


def _collate_list_of_dicts(dicts: list[dict[str, float]]) -> dict[str, list[float]]:
    """Convert a list of dictionaries into a dictionary of lists, preserving order.

    :param dicts: Non-empty list of dictionaries that must all share the same keys.
    :raises ValueError: If the dictionaries do not all have identical keys.
    """
    keys = dicts[0].keys()
    if not all(elem.keys() == keys for elem in dicts[1:]):
        raise ValueError("Every dict in the list should have the same keys")
    return {key: [elem[key] for elem in dicts] for key in keys}


class MetricBootstrapper:
    """Utility to report bootstrapping statistics for RadFact metric."""

    def __init__(self, metric: RadFactMetric, num_samples: int, seed: int | None = None) -> None:
        """
        :param metric: The metric for which to compute bootstrap statistics (e.g. RadFactMetric).
        :param num_samples: Number of bootstrap samples to generate, ideally in the hundreds.
        :param seed: RNG seed for reproducibility. By default (`None`), will give different results every time.
        """
        self.metric = metric
        self.num_samples = num_samples
        self.seed = seed

    def _generate_bootstrap_results(
        self,
        candidates: InputDict | None = None,
        references: InputDict | None = None,
        results_per_sample: PerSampleResultType | None = None,
    ) -> Generator[ReturnType, None, None]:
        """Compute the bootstrap results for a radfact metric by drawing `num_samples` samples with replacement,
        and re-computing the metric.

        :param candidates: The set of candidate reports.
        :param references: The set of reference reports.
        :param results_per_sample: Intermediate per-sample results. This can be used to pass in pre-computed
            per-sample results. If provided, the arguments `candidates` and `references` will be ignored for
            sample wise metrics.
        :yield: A generator of bootstrap results, where each result is of type `ReturnType`.
        """
        if results_per_sample is None:
            assert candidates is not None and references is not None
            results_per_sample = self.metric.compute_results_per_sample(candidates=candidates, references=references)
        assert results_per_sample is not None  # for mypy
        num_records = len(results_per_sample)
        if num_records == 0:
            logger.warning("No samples to bootstrap. Metrics will all be NaN.")
        rng = np.random.default_rng(seed=self.seed)
        for _ in tqdm(range(self.num_samples), total=self.num_samples):
            # Draw indices with replacement, then re-aggregate the metric on the resampled records.
            # (The original built throwaway (None, indices) tuples in a generator expression and
            # immediately discarded the None; this loop is equivalent without the dead values.)
            boot_indices = rng.choice(num_records, size=num_records, replace=True)
            boot_results_per_sample = self.metric.reindex_results_per_sample(results_per_sample, boot_indices)
            yield self.metric.aggregate_results(boot_results_per_sample)

    @staticmethod
    def _compute_bootstrap_stats(values: list[float]) -> dict[str, float]:
        """Summarise the bootstrap distribution of a statistic, ignoring NaNs.

        `stderr` is the standard deviation of the bootstrap distribution, i.e. the bootstrap
        estimate of the statistic's standard error.
        """
        q025, q25, median, q75, q975 = np.nanquantile(np.asarray(values), [0.025, 0.25, 0.5, 0.75, 0.975], axis=0)
        numpy_stats = {
            "mean": np.nanmean(values, axis=0),
            "stderr": np.nanstd(values, axis=0),
            "p2.5th": q025,
            "p25th": q25,
            "median": median,
            "p75th": q75,
            "p97.5th": q975,
        }
        # tolist() converts numpy scalars to plain Python floats.
        return {stat_name: value.tolist() for stat_name, value in numpy_stats.items()}

    def compute_bootstrap_metrics(
        self,
        candidates: InputDict | None = None,
        references: InputDict | None = None,
        results_per_sample: PerSampleResultType | None = None,
    ) -> dict[str, float]:
        """Calculate bootstrap statistics for RadFact metric that has intermediate per-sample results.

        :param candidates: The set of candidate reports to bootstrap.
        :param references: The set of reference reports to bootstrap.
        :param results_per_sample: Intermediate per-sample results. This can be used to pass in pre-computed
            per-sample results. If provided, the arguments `candidates` and `references` will be ignored for
            sample wise metrics.
        :return: A dictionary of bootstrap statistics, containing the mean (`mean`), standard error (`stderr`), 95%
            confidence interval (`p2.5th` and `p97.5th`), quartiles (`p25th` and `p75th`), and median (`median`) of the
            bootstrap distribution. If `metric` returns detailed submetrics, bootstrap statistics will also be included,
            e.g. `submetric/mean`, `submetric/stderr`, etc.
        """
        boot_results_generator = self._generate_bootstrap_results(
            candidates=candidates, references=references, results_per_sample=results_per_sample
        )
        boot_main_scores: list[float] = []
        boot_detailed_scores_dicts: list[dict[str, float]] = []
        for boot_results in boot_results_generator:
            assert isinstance(boot_results, tuple)
            main_score, detailed_scores_dict = boot_results
            boot_main_scores.append(main_score)
            boot_detailed_scores_dicts.append(detailed_scores_dict)

        stats_dict = self._compute_bootstrap_stats(boot_main_scores)
        collated_detailed_scores_dict = _collate_list_of_dicts(boot_detailed_scores_dicts)
        for submetric_name, values in collated_detailed_scores_dict.items():
            submetric_stats_dict = self._compute_bootstrap_stats(values)
            for stat_name, stat_value in submetric_stats_dict.items():
                stats_dict[f"{submetric_name}/{stat_name}"] = stat_value
        return stats_dict
------------------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 4 | # ------------------------------------------------------------------------------------------ 5 | 6 | import logging 7 | 8 | import numpy as np 9 | import numpy.typing as npt 10 | 11 | from radfact.data_utils.grounded_phrase_list import NormalizedBox 12 | 13 | 14 | IOU = "iou" 15 | PRECISION = "precision" 16 | RECALL = "recall" 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | def get_mask_from_boxes(boxes: list[NormalizedBox], mask_size: int = 224) -> npt.NDArray[np.bool_]: 22 | """Gets a pixel mask from a list of boxes. 23 | 24 | It creates a numpy array of zeros with the shape of (mask_size, mask_size). It then iterates over 25 | the boxes and sets the corresponding pixels to True in the mask. It returns the mask as a numpy array. 26 | 27 | :param boxes: A list of Box objects to convert into a mask. 28 | :param mask_size: The image size of the mask. Defaults to 224. 29 | :returns: A numpy array of boolean values representing the pixel mask. 30 | """ 31 | mask = np.zeros((mask_size, mask_size), dtype=np.bool_) 32 | for box in boxes: 33 | box_coord = (box.x_min, box.y_min, box.x_max, box.y_max) 34 | x1, y1, x2, y2 = (np.array(box_coord) * mask_size).astype(int) 35 | mask[x1:x2, y1:y2] = True 36 | return mask 37 | 38 | 39 | def compute_box_metrics( 40 | pred_boxes: list[NormalizedBox], true_boxes: list[NormalizedBox], mask_size: int 41 | ) -> dict[str, float]: 42 | """Computes the IOU, precision, and recall scores for a pair of box lists. 43 | It converts the boxes into pixel masks and calculates the prediction area, ground truth area, intersection area 44 | and union area. It then returns a dictionary of IOU, precision, and recall scores based on these areas. 
45 | 46 | :param pred_boxes: A list of Box objects for the prediction boxes. 47 | :param true_boxes: A list of Box objects for the ground truth boxes. 48 | :param mask_size: The image size of the masks. 49 | 50 | :returns: A dictionary of IOU, precision, and recall scores. 51 | """ 52 | pred_mask = get_mask_from_boxes(pred_boxes, mask_size) 53 | true_mask = get_mask_from_boxes(true_boxes, mask_size) 54 | 55 | pred_area = pred_mask.sum() 56 | true_area = true_mask.sum() 57 | if true_area <= 0: 58 | logger.warning(f"WARNING: True area is not positive, {true_area=}. {true_boxes=}, {pred_boxes=}") 59 | intersection_area = (pred_mask & true_mask).sum() 60 | union_area = (pred_mask | true_mask).sum() 61 | iou = intersection_area / union_area 62 | if pred_area > 0: 63 | precision = intersection_area / pred_area 64 | else: 65 | precision = np.nan 66 | if true_area > 0: 67 | recall = intersection_area / true_area 68 | else: 69 | recall = np.nan 70 | return { 71 | IOU: iou, 72 | PRECISION: precision, 73 | RECALL: recall, 74 | } 75 | -------------------------------------------------------------------------------- /src/radfact/metric/print_utils.py: -------------------------------------------------------------------------------- 1 | def print_bootstrap_results(results: dict[str, float], metrics: list[str]) -> None: 2 | for metric_name in metrics: 3 | median = results[f"{metric_name}/median"] * 100 4 | p025 = results[f"{metric_name}/p2.5th"] * 100 5 | p975 = results[f"{metric_name}/p97.5th"] * 100 6 | print(f"{metric_name}: {median:0.2f} (95% CI: [{p025:0.2f}, {p975:0.2f}])") 7 | 8 | 9 | def print_results(results: dict[str, float], metrics: list[str]) -> None: 10 | for metric_name in metrics: 11 | metric = results[f"{metric_name}"] * 100 12 | print(f"{metric_name}: {metric:0.2f}") 13 | -------------------------------------------------------------------------------- /src/radfact/metric/schema.py: -------------------------------------------------------------------------------- 1 
| # ------------------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 4 | # ------------------------------------------------------------------------------------------ 5 | 6 | import logging 7 | from dataclasses import dataclass, field 8 | from enum import Enum 9 | 10 | import numpy as np 11 | 12 | from radfact.data_utils.grounded_phrase_list import GroundedPhrase, NormalizedBox 13 | from radfact.llm_utils.nli.schema import EvidencedPhrase 14 | from radfact.llm_utils.text_utils import find_best_match, normalise_text_for_comparison 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | @dataclass(frozen=True) 20 | class OneWayNLIFractions: 21 | entailed_fraction: float 22 | full_box_fraction: float 23 | entailed_box_fraction: float 24 | num_phrases: int 25 | num_phrases_with_boxes: int 26 | 27 | 28 | @dataclass(frozen=True) 29 | class RadFactScore: 30 | logical_precision: float 31 | logical_recall: float 32 | spatial_precision: float 33 | spatial_recall: float 34 | grounding_precision: float 35 | grounding_recall: float 36 | num_candidate_phrases: int | float 37 | num_reference_phrases: int | float 38 | num_candidate_phrases_with_boxes: int | float 39 | num_reference_phrases_with_boxes: int | float 40 | 41 | @staticmethod 42 | def _compute_f1_score(precision: float, recall: float) -> float: 43 | return 2 * (precision * recall) / (precision + recall) 44 | 45 | @property 46 | def logical_f1(self) -> float: 47 | return self._compute_f1_score(self.logical_precision, self.logical_recall) 48 | 49 | @property 50 | def spatial_f1(self) -> float: 51 | return self._compute_f1_score(self.spatial_precision, self.spatial_recall) 52 | 53 | @property 54 | def grounding_f1(self) -> float: 55 | return self._compute_f1_score(self.grounding_precision, self.grounding_recall) 56 | 57 | 
@classmethod 58 | def from_candidate_and_reference_fractions( 59 | cls, candidate: OneWayNLIFractions, reference: OneWayNLIFractions 60 | ) -> "RadFactScore": 61 | """Create a score from the candidate and reference fractions.""" 62 | return cls( 63 | logical_precision=candidate.entailed_fraction, 64 | logical_recall=reference.entailed_fraction, 65 | spatial_precision=candidate.full_box_fraction, 66 | spatial_recall=reference.full_box_fraction, 67 | grounding_precision=candidate.entailed_box_fraction, 68 | grounding_recall=reference.entailed_box_fraction, 69 | num_candidate_phrases=candidate.num_phrases, 70 | num_reference_phrases=reference.num_phrases, 71 | num_candidate_phrases_with_boxes=candidate.num_phrases_with_boxes, 72 | num_reference_phrases_with_boxes=reference.num_phrases_with_boxes, 73 | ) 74 | 75 | @classmethod 76 | def from_aggregate(cls, scores: list["RadFactScore"], only_factual_scores: bool = False) -> "RadFactScore": 77 | """Aggregate the scores from a list of samples. If only_factual_scores is True, we only aggregate the logical 78 | scores. The spatial and grounding scores are set to 0.0. 79 | """ 80 | 81 | def _nanmean(values: list[float | int]) -> float: 82 | """ 83 | Compute the mean of the values, ignoring NaNs. 84 | This is mostly for mypy convenience. 
85 | """ 86 | return float(np.nanmean(values)) 87 | 88 | n = len(scores) 89 | if n == 0: 90 | return cls( 91 | logical_precision=0.0, 92 | logical_recall=0.0, 93 | spatial_precision=0.0, 94 | spatial_recall=0.0, 95 | grounding_precision=0.0, 96 | grounding_recall=0.0, 97 | num_candidate_phrases=0.0, 98 | num_reference_phrases=0.0, 99 | num_candidate_phrases_with_boxes=0.0, 100 | num_reference_phrases_with_boxes=0.0, 101 | ) 102 | return cls( 103 | # If no predicted or reference phrases, these can be NaN 104 | logical_precision=_nanmean([x.logical_precision for x in scores]), 105 | logical_recall=_nanmean([x.logical_recall for x in scores]), 106 | # Box metrics can be NaN if there are no boxes, either direction 107 | spatial_precision=0.0 if only_factual_scores else _nanmean([x.spatial_precision for x in scores]), 108 | spatial_recall=0.0 if only_factual_scores else _nanmean([x.spatial_recall for x in scores]), 109 | grounding_precision=0.0 if only_factual_scores else _nanmean([x.grounding_precision for x in scores]), 110 | grounding_recall=0.0 if only_factual_scores else _nanmean([x.grounding_recall for x in scores]), 111 | # Numbers of phrases etc. 
should never have NaN 112 | num_candidate_phrases=sum(x.num_candidate_phrases for x in scores) / n, 113 | num_reference_phrases=sum(x.num_reference_phrases for x in scores) / n, 114 | # These can be nan if we are running the metric on data without boxes so we set it to 0.0 when 115 | # only_factual_scores is True 116 | num_candidate_phrases_with_boxes=( 117 | 0.0 if only_factual_scores else _nanmean([x.num_candidate_phrases_with_boxes for x in scores]) 118 | ), 119 | num_reference_phrases_with_boxes=( 120 | 0.0 if only_factual_scores else _nanmean([x.num_reference_phrases_with_boxes for x in scores]) 121 | ), 122 | ) 123 | 124 | 125 | class SpatialEntailmentStatus(str, Enum): 126 | NO_BOXES = "no_boxes" 127 | SPATIAL_ENTAILMENT = "spatial_entailment" 128 | NO_SPATIAL_ENTAILMENT = "no_spatial_entailment" 129 | 130 | 131 | @dataclass(frozen=True, kw_only=True) 132 | class GroundedPhraseEvidenced(GroundedPhrase): 133 | status: str 134 | spatial_entailment_status: SpatialEntailmentStatus | None = None 135 | evidence: list[GroundedPhrase] 136 | evidence_indices: list[int] | None = None 137 | 138 | def __post_init__(self) -> None: 139 | super().__post_init__() 140 | if self.evidence_indices is not None: 141 | assert len(self.evidence_indices) == len(self.evidence) 142 | 143 | def get_all_evidence_boxes(self) -> list[NormalizedBox]: 144 | all_evidence_boxes = [] 145 | for premise in self.evidence: 146 | if premise.boxes is not None: 147 | all_evidence_boxes.extend(premise.boxes) 148 | return all_evidence_boxes 149 | 150 | @staticmethod 151 | def attach_evidence_to_hypothesis( 152 | *, 153 | evidenced_phrase: EvidencedPhrase, 154 | hypothesis_grounded_phrase: GroundedPhrase, 155 | premise_grounded_phrases: list[GroundedPhrase], 156 | ) -> 'GroundedPhraseEvidenced': 157 | """ 158 | Attach the evidence to the hypothesis phrase based on NLI output. 159 | 160 | We need to do this because `GroundedPhrase` includes boxes, whereas the NLI processor operates only on strings. 
161 | 162 | :param evidenced_phrase: `EvidencedPhrase` for the given hypothesis, as generated by the NLI processor. 163 | :param hypothesis_grounded_phrase: `GroundedPhrase` corresponding to the hypothesis. 164 | :param premise_grounded_phrases: List of `GroundedPhrase` premises, containing at least the evidence phrases. 165 | :return: `GroundedPhraseEvidenced` corresponding to the hypothesis with NLI status and evidence. 166 | """ 167 | if normalise_text_for_comparison(hypothesis_grounded_phrase.text) != normalise_text_for_comparison( 168 | evidenced_phrase.phrase 169 | ): 170 | raise ValueError( 171 | f"Evidenced phrase ({evidenced_phrase.phrase}) does not match " 172 | f"hypothesis ({hypothesis_grounded_phrase.text})." 173 | ) 174 | evidence_indices = [ 175 | find_best_match(premise, [x.text for x in premise_grounded_phrases])[0] 176 | for premise in evidenced_phrase.evidence 177 | ] 178 | evidence_grounded_phrases = [premise_grounded_phrases[i] for i in evidence_indices] 179 | return GroundedPhraseEvidenced( 180 | text=hypothesis_grounded_phrase.text, 181 | boxes=hypothesis_grounded_phrase.boxes, 182 | status=evidenced_phrase.status, 183 | evidence=evidence_grounded_phrases, 184 | evidence_indices=evidence_indices, 185 | ) 186 | 187 | @staticmethod 188 | def attach_evidence_to_all_hypotheses( 189 | *, 190 | evidenced_phrases: list[EvidencedPhrase], 191 | hypothesis_grounded_phrases: list[GroundedPhrase], 192 | premise_grounded_phrases: list[GroundedPhrase], 193 | ) -> list['GroundedPhraseEvidenced']: 194 | """ 195 | Attach evidence to all hypothesis phrase based on NLI output. 196 | 197 | We need to do this because `GroundedPhrase` includes boxes, whereas the NLI processor operates only on strings. 198 | 199 | :param evidenced_phrases: List of `EvidencedPhrase` as generated by the NLI processor. 200 | All phrases and evidence must be contained in the given hypothesis and premise lists, respectively. 
201 | :param hypothesis_grounded_phrases: List of `GroundedPhrase` hypotheses. 202 | :param premise_grounded_phrases: List of `GroundedPhrase` premises. 203 | :return: List of `GroundedPhraseEvidenced` corresponding to the hypotheses with NLI status and evidence. 204 | """ 205 | 206 | def retrieve_evidenced_phrase(phrase: str) -> EvidencedPhrase: 207 | """Given the phrase (text), retrieve the EvidencedPhrase object from the list.""" 208 | phrase_idx, _ = find_best_match(phrase, [x.phrase for x in evidenced_phrases]) 209 | return evidenced_phrases[phrase_idx] 210 | 211 | return [ 212 | GroundedPhraseEvidenced.attach_evidence_to_hypothesis( 213 | hypothesis_grounded_phrase=hypothesis_grounded_phrase, 214 | evidenced_phrase=retrieve_evidenced_phrase(hypothesis_grounded_phrase.text), 215 | premise_grounded_phrases=premise_grounded_phrases, 216 | ) 217 | for hypothesis_grounded_phrase in hypothesis_grounded_phrases 218 | ] 219 | 220 | 221 | @dataclass(frozen=True) 222 | class PerSampleNLIResult: 223 | study_id: str 224 | scores: RadFactScore | None = None 225 | candidate_phrases: list[GroundedPhraseEvidenced] = field(default_factory=list) 226 | reference_phrases: list[GroundedPhraseEvidenced] = field(default_factory=list) 227 | -------------------------------------------------------------------------------- /src/radfact/paths.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 
4 | # ------------------------------------------------------------------------------------------ 5 | 6 | from pathlib import Path 7 | 8 | REPOSITORY_ROOT_DIR = Path(__file__).absolute().parents[2] 9 | RADFACT_ROOT_DIR = REPOSITORY_ROOT_DIR / "src" / "radfact" 10 | OUTPUT_DIR = REPOSITORY_ROOT_DIR / "outputs" 11 | LLM_UTILS_DIR = RADFACT_ROOT_DIR / "llm_utils" 12 | CONFIGS_DIR = REPOSITORY_ROOT_DIR / "configs" 13 | EXAMPLES_DIR = REPOSITORY_ROOT_DIR / "examples" 14 | WORKSPACE_CONFIG_PATH = REPOSITORY_ROOT_DIR / "config.json" 15 | 16 | 17 | def get_prompts_dir(task: str) -> Path: 18 | return LLM_UTILS_DIR / task / "prompts" 19 | -------------------------------------------------------------------------------- /tests/data_utils/test_grounded_phrase_list.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 4 | # ------------------------------------------------------------------------------------------ 5 | 6 | import pytest 7 | 8 | from radfact.data_utils.grounded_phrase_list import GroundedPhrase, GroundedPhraseList, NormalizedBox 9 | 10 | 11 | def test_box() -> None: 12 | box = NormalizedBox(0.1, 0.2, 0.3, 0.4) 13 | assert box.x_min == 0.1 14 | assert box.y_min == 0.2 15 | assert box.x_max == 0.3 16 | assert box.y_max == 0.4 17 | assert tuple(box) == (0.1, 0.2, 0.3, 0.4) 18 | 19 | 20 | def test_box_invalid_coords() -> None: 21 | for invalid_box_coords in [ 22 | (-0.1, 0.2, 0.3, 0.4), # x_min < 0 23 | (0.6, 0.2, 0.3, 0.4), # x_min > x_max 24 | (0.1, 0.2, 1.3, 0.4), # x_max > 1 25 | (0.1, -0.2, 0.3, 0.4), # y_min < 0 26 | (0.1, 0.6, 0.8, 0.4), # y_min > y_max 27 | (0.1, 0.2, 0.3, 1.4), # y_max > 1 28 | ]: 29 | with pytest.raises(ValueError, match="Invalid . 
def test_get_all_boxes() -> None:
    """get_all_boxes collects boxes from grounded phrases; fail_if_non_box rejects any non-box member."""
    box = NormalizedBox(0.1, 0.2, 0.3, 0.4)
    mixed_list = GroundedPhraseList(["Plain str", GroundedPhrase(text="Grounded str", boxes=[box])])
    assert mixed_list.get_all_boxes() == [box]
    # The plain string member makes strict extraction fail.
    with pytest.raises(ValueError, match="Encountered a non-box while extracting boxes"):
        mixed_list.get_all_boxes(fail_if_non_box=True)

    no_box_list = GroundedPhraseList([GroundedPhrase(text="Phrase str with no box", boxes=None)])
    assert no_box_list.get_all_boxes() == []
    # A phrase without boxes is still a non-box member under strict extraction.
    with pytest.raises(ValueError, match="Encountered a non-box while extracting boxes"):
        assert no_box_list.get_all_boxes(fail_if_non_box=True) == []

    boxes_only = GroundedPhraseList([box])
    assert boxes_only.get_all_boxes() == [box]
    assert boxes_only.get_all_boxes(fail_if_non_box=True) == [box]
| [ 81 | "Plain str", 82 | GroundedPhrase( 83 | text="Grounded str", 84 | boxes=[NormalizedBox(0.1, 0.2, 0.3, 0.4)], 85 | ), 86 | ] 87 | ) 88 | assert grounded_phrase_list.get_all_grounded_phrases() == [ 89 | GroundedPhrase(text="Grounded str", boxes=[NormalizedBox(0.1, 0.2, 0.3, 0.4)]) 90 | ] 91 | with pytest.raises(ValueError, match="Encountered a non-GroundedPhrase while extracting phrases"): 92 | grounded_phrase_list.get_all_grounded_phrases(fail_if_non_grounded_phrase=True) 93 | 94 | phrase_only_sequence = GroundedPhraseList([GroundedPhrase(text="Phrase str with no box", boxes=None)]) 95 | assert phrase_only_sequence.get_all_grounded_phrases() == [ 96 | GroundedPhrase(text="Phrase str with no box", boxes=None) 97 | ] 98 | assert phrase_only_sequence.get_all_grounded_phrases(fail_if_non_grounded_phrase=True) == [ 99 | GroundedPhrase(text="Phrase str with no box", boxes=None) 100 | ] 101 | 102 | 103 | def test_grounded_phrase_from_dict() -> None: 104 | grounded_phrase_dict = { 105 | "text": "Grounded str", 106 | "boxes": [ 107 | {"x_min": 0.1, "y_min": 0.2, "x_max": 0.3, "y_max": 0.4}, 108 | ], 109 | } 110 | grounded_phrase = GroundedPhrase.from_dict(grounded_phrase_dict) # type: ignore 111 | assert grounded_phrase.text == "Grounded str" 112 | assert grounded_phrase.boxes == [NormalizedBox(0.1, 0.2, 0.3, 0.4)] 113 | 114 | grounded_phrase_dict = { 115 | "text": "Grounded str", 116 | "boxes": [ 117 | {"x_min": 0.1, "y_min": 0.2, "x_max": 0.3, "y_max": 0.4}, 118 | {"x_min": 0.5, "y_min": 0.6, "x_max": 0.7, "y_max": 0.8}, 119 | ], 120 | } 121 | grounded_phrase = GroundedPhrase.from_dict(grounded_phrase_dict) # type: ignore 122 | assert grounded_phrase.text == "Grounded str" 123 | assert grounded_phrase.boxes == [NormalizedBox(0.1, 0.2, 0.3, 0.4), NormalizedBox(0.5, 0.6, 0.7, 0.8)] 124 | 125 | # malformed dicts 126 | with pytest.raises(KeyError, match="text"): 127 | GroundedPhrase.from_dict({"boxes": []}) 128 | with pytest.raises(KeyError, match="boxes"): 129 | 
def test_grounded_phrase_list_from_dict() -> None:
    """Deserialisation tests for GroundedPhraseList.from_list_of_dicts.

    Covers plain text, ungrounded phrases, grounded phrases with one or two boxes,
    and malformed inputs. A verbatim duplicate of the two-box case was removed.
    """
    # Plain text plus an ungrounded phrase (boxes=None).
    list_of_dicts = [
        {"text": "Plain str"},
        {"text": "Ungrounded str", "boxes": None},
    ]
    grounded_phrase_list = GroundedPhraseList.from_list_of_dicts(list_of_dicts)  # type: ignore
    assert grounded_phrase_list == GroundedPhraseList(
        ["Plain str", GroundedPhrase(text="Ungrounded str", boxes=None)]
    )

    # Grounded phrase with a single box.
    list_of_dicts = [
        {"text": "Plain str"},
        {
            "text": "Grounded str",
            "boxes": [{"x_min": 0.1, "y_min": 0.2, "x_max": 0.3, "y_max": 0.4}],
        },
    ]
    grounded_phrase_list = GroundedPhraseList.from_list_of_dicts(list_of_dicts)  # type: ignore
    assert grounded_phrase_list == GroundedPhraseList(
        ["Plain str", GroundedPhrase(text="Grounded str", boxes=[NormalizedBox(0.1, 0.2, 0.3, 0.4)])]
    )

    # Grounded phrase with multiple boxes.
    list_of_dicts = [
        {"text": "Plain str"},
        {
            "text": "Grounded str",
            "boxes": [
                {"x_min": 0.1, "y_min": 0.2, "x_max": 0.3, "y_max": 0.4},
                {"x_min": 0.5, "y_min": 0.6, "x_max": 0.7, "y_max": 0.8},
            ],
        },
    ]
    grounded_phrase_list = GroundedPhraseList.from_list_of_dicts(list_of_dicts)  # type: ignore
    assert grounded_phrase_list == GroundedPhraseList(
        [
            "Plain str",
            GroundedPhrase(
                text="Grounded str",
                boxes=[NormalizedBox(0.1, 0.2, 0.3, 0.4), NormalizedBox(0.5, 0.6, 0.7, 0.8)],
            ),
        ]
    )

    # Malformed: members must be dictionaries, not raw strings or GroundedPhrase objects.
    list_of_dicts = [
        "Plain str",
        GroundedPhrase(text="Grounded str", boxes=[NormalizedBox(0.1, 0.2, 0.3, 0.4)]),
    ]
    with pytest.raises(ValueError, match="Expected dictionary, got:"):
        GroundedPhraseList.from_list_of_dicts(list_of_dicts)  # type: ignore

    # Malformed: unexpected keys are rejected.
    list_of_dicts = [
        {
            "text": "Grounded str",
            "boxes": [{"x_min": 0.1, "y_min": 0.2, "x_max": 0.3, "y_max": 0.4}],
            "other_key": True,
        },
    ]
    with pytest.raises(ValueError, match="Unknown member of grounded phrase list"):
        GroundedPhraseList.from_list_of_dicts(list_of_dicts)  # type: ignore
def mock_dataset_sharding(n_data: int, speed_factors: list[float]) -> list[int]:
    """Shard a dummy dataset of ``n_data`` rows across endpoints with the given speed factors.

    :param n_data: Number of rows in the dummy dataset.
    :param speed_factors: One relative speed factor per endpoint.
    :return: Number of rows assigned to each endpoint, in endpoint order.
    """

    class _StubEndpoint:
        """Minimal endpoint stand-in exposing only the ``speed_factor`` attribute."""

        def __init__(self, speed_factor: float) -> None:
            self.speed_factor = speed_factor

    engine_stub = Mock(spec=LLMEngine)
    engine_stub.dataset_df = pd.DataFrame([{"dummy": idx} for idx in range(n_data)])
    engine_stub.endpoints = {
        str(idx): _StubEndpoint(speed_factor=factor) for idx, factor in enumerate(speed_factors)
    }

    # Call the real splitting logic on the stubbed engine.
    splits = LLMEngine.get_weighted_splits(engine_stub)
    return list(splits.values())
def test_remove_endpoint() -> None:
    """remove_endpoint_from_json_string strips quoted endpoint entries and leaves everything else intact."""
    cases = [
        # (input, expected output)
        ("foo", "foo"),  # no endpoint present
        ('"azure_endpoint": "https://foo"', ""),  # full string matches
        ('"azure_endpoint": https://foo', '"azure_endpoint": https://foo'),  # URL not quoted, should not match
        ('1, "azure_endpoint": "https://foo", 2', "1, , 2"),  # should match
        ('1, "api_endpoint": "https://foo", 2', "1, , 2"),  # api_endpoint is the format used by ChatOpenAI
    ]
    for json_string, expected in cases:
        assert remove_endpoint_from_json_string(json_string) == expected
def test_to_grounded_phrases_list() -> None:
    """to_grounded_phrases_list keeps only the rephrased sentences, each as an ungrounded phrase."""
    parsed_report = ParsedReport(
        sentence_list=[
            SentenceWithRephrases(orig="Original sentence", new=["Rephrased sentence 1", "Rephrased sentence 2"]),
            SentenceWithRephrases(orig="Another original sentence", new=["Another rephrased sentence"]),
        ]
    )

    actual = parsed_report.to_grounded_phrases_list()

    # The original sentences are dropped; only the rephrasings survive, in order.
    expected = GroundedPhraseList(
        [
            GroundedPhrase(text="Rephrased sentence 1"),
            GroundedPhrase(text="Rephrased sentence 2"),
            GroundedPhrase(text="Another rephrased sentence"),
        ]
    )
    assert actual == expected
def test_find_best_match() -> None:
    """find_best_match prefers exact matches (after normalisation), otherwise the longest common substring."""
    options = ["There is a small pleural effusion", "No acute cardiopulmonary process", "Widespread consolidation"]
    cases = [
        # (query, expected index into options)
        ("There is a small pleural effusion", 0),  # exact match, with or without normalisation
        ("there is a small pleural effusion.", 0),  # matches after normalisation
        ("No acute cardio", 1),  # not exact: longest common substring selects
        ("There is consolidation seen throughout the lungs.", 2),  # query longer than the match
    ]
    for query, expected_index in cases:
        index, match = find_best_match(query, options)
        assert index == expected_index
        assert match == options[expected_index]
4 | # ------------------------------------------------------------------------------------------ 5 | 6 | import numpy as np 7 | import numpy.typing as npt 8 | 9 | from radfact.metric.box_metrics import IOU, PRECISION, RECALL, compute_box_metrics, get_mask_from_boxes 10 | from radfact.data_utils.grounded_phrase_list import NormalizedBox 11 | 12 | 13 | def assert_score(metric_name: str, expected_score: float, actual_score: float) -> None: 14 | assert isinstance(actual_score, float) 15 | is_close = np.isclose(expected_score, actual_score, equal_nan=True) 16 | message = "metric: {}, expected: {}, actual: {}".format(metric_name, expected_score, actual_score) 17 | assert is_close, message 18 | 19 | 20 | def _get_box_names() -> list[str]: 21 | return ["box1", "box2", "box3", "box4", "box5", "box6", "box7", "box8", "box9", "box10", "box10"] 22 | 23 | 24 | def _get_box_pairs() -> list[tuple[str, str]]: 25 | return [ 26 | ("box5", "box6"), # we assume (gt_box, pred_box) 27 | ("box2", "box7"), 28 | ("box3", "box8"), 29 | ("box4", "box4"), 30 | ("box5", "box7"), 31 | ("box9", "box10"), 32 | ("box10", "box11"), 33 | ] 34 | 35 | 36 | def _get_box(key: str) -> NormalizedBox: 37 | box1 = NormalizedBox(0.1, 0.1, 0.3, 0.3) 38 | box2 = NormalizedBox(0.2, 0.2, 0.4, 0.4) 39 | box3 = NormalizedBox(0.5, 0.5, 0.7, 0.7) 40 | box4 = NormalizedBox(0.6, 0.6, 0.8, 0.8) 41 | box5 = NormalizedBox(0.3, 0.3, 0.5, 0.5) 42 | box6 = NormalizedBox(0.4, 0.4, 0.6, 0.6) 43 | box7 = NormalizedBox(0.1, 0.1, 0.4, 0.4) 44 | box8 = NormalizedBox(0.7, 0.7, 0.9, 0.9) 45 | box9 = NormalizedBox(0.3, 0.6, 0.5, 0.8) 46 | box10 = NormalizedBox(0.3, 0.65, 0.5, 0.85) 47 | box11 = NormalizedBox(0.0, 0.0, 0.0, 0.0) 48 | return { 49 | "box1": box1, 50 | "box2": box2, 51 | "box3": box3, 52 | "box4": box4, 53 | "box5": box5, 54 | "box6": box6, 55 | "box7": box7, 56 | "box8": box8, 57 | "box9": box9, 58 | "box10": box10, 59 | "box11": box11, 60 | }[key] 61 | 62 | 63 | def _get_mask_for_box(key: str) -> 
def get_iou_precision_recall_for_box(key1: str, key2: str) -> dict[str, float]:
    """Reference IoU/precision/recall computed directly from rasterised masks.

    :param key1: Name of the predicted box.
    :param key2: Name of the ground-truth box.
    :return: Mapping from metric name (IOU/PRECISION/RECALL) to its value.
    """
    pred_mask = _get_mask_for_box(key1)
    gt_mask = _get_mask_for_box(key2)
    pred_area = pred_mask.sum()
    true_area = gt_mask.sum()
    inter_area = (pred_mask & gt_mask).sum()
    union_area = (pred_mask | gt_mask).sum()
    # Precision is undefined (NaN) when the prediction covers no pixels.
    precision = inter_area / pred_area if pred_area > 0 else np.nan
    return {
        IOU: inter_area / union_area,
        PRECISION: precision,
        RECALL: inter_area / true_area,
    }
has two boxes and perfect overlap 112 | actual = compute_box_metrics( 113 | [_get_box(box_pairs_list[0][0]), _get_box(box_pairs_list[0][1])], 114 | [_get_box(box_pairs_list[0][0]), _get_box(box_pairs_list[0][1])], 115 | mask_size=224, 116 | ) 117 | assert_score(IOU, expected_score=1.0, actual_score=round(actual[IOU], 2)) 118 | assert_score(PRECISION, expected_score=1.0, actual_score=round(actual[PRECISION], 2)) 119 | assert_score(RECALL, expected_score=1.0, actual_score=round(actual[RECALL], 2)) 120 | 121 | # Test for cases where the pred. and gt. has two boxes and zero overlap 122 | actual = compute_box_metrics( 123 | [_get_box(box_pairs_list[2][0]), _get_box(box_pairs_list[2][0])], 124 | [_get_box(box_pairs_list[2][1]), _get_box(box_pairs_list[2][1])], 125 | mask_size=224, 126 | ) 127 | assert_score(IOU, expected_score=0, actual_score=round(actual[IOU], 2)) 128 | assert_score(PRECISION, expected_score=0, actual_score=round(actual[PRECISION], 2)) 129 | assert_score(RECALL, expected_score=0, actual_score=round(actual[RECALL], 2)) 130 | 131 | # Test for cases where the pred. and gt. 
def test_compute_box_metrics_zero_area_box() -> None:
    """Zero-area boxes: precision is NaN for an empty prediction, recall is NaN for an empty ground truth."""
    zero_area_box = _get_box("box11")
    regular_box = _get_box("box1")

    # Prediction has zero area: nothing can be precise, but recall is a well-defined 0.
    metrics = compute_box_metrics(pred_boxes=[zero_area_box], true_boxes=[regular_box], mask_size=224)
    assert_score(IOU, expected_score=0, actual_score=round(metrics[IOU], 2))
    assert_score(PRECISION, expected_score=np.nan, actual_score=round(metrics[PRECISION], 2))
    assert_score(RECALL, expected_score=0, actual_score=round(metrics[RECALL], 2))

    # Ground truth has zero area: precision is a well-defined 0, recall is undefined.
    metrics = compute_box_metrics(pred_boxes=[regular_box], true_boxes=[zero_area_box], mask_size=224)
    assert_score(IOU, expected_score=0, actual_score=round(metrics[IOU], 2))
    assert_score(PRECISION, expected_score=0, actual_score=round(metrics[PRECISION], 2))
    assert_score(RECALL, expected_score=np.nan, actual_score=round(metrics[RECALL], 2))