├── .github ├── actions │ └── setup-poetry-environment │ │ └── action.yml └── workflows │ └── tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── dev-setup.sh ├── images ├── embeddings.gif ├── encord.png ├── eval-usage.svg ├── logo_v1.webp └── tti-eval-banner.png ├── notebooks ├── tti_eval_Bring_Your_Own_Dataset_From_Encord_Quickstart.ipynb ├── tti_eval_Bring_Your_Own_Model_From_Hugging_Face_Quickstart.ipynb ├── tti_eval_CLI_Quickstart.ipynb └── tti_eval_Python_Quickstart.ipynb ├── poetry.lock ├── pyproject.toml ├── sources ├── dataset-definition-schema.json ├── datasets │ ├── Alzheimer-MRI.json │ ├── LungCancer4Types.json │ ├── chest-xray-classification.json │ ├── plants.json │ ├── rsicd.json │ ├── skin-cancer.json │ └── sports-classification.json ├── model-definition-schema.json └── models │ ├── apple.json │ ├── bioclip.json │ ├── clip.json │ ├── eva-clip.json │ ├── fashion.json │ ├── plip.json │ ├── pubmed.json │ ├── rsicd-encord.json │ ├── rsicd.json │ ├── siglip_large.json │ ├── siglip_small.json │ ├── street.json │ ├── tinyclip.json │ └── vit-b-32-laion2b.json ├── tests ├── common │ └── test_embedding_definition.py └── evaluation │ └── test_knn.py └── tti_eval ├── cli ├── __init__.py ├── main.py └── utils.py ├── common ├── __init__.py ├── base.py ├── numpy_types.py └── string_utils.py ├── compute.py ├── constants.py ├── dataset ├── __init__.py ├── base.py ├── provider.py ├── types │ ├── __init__.py │ ├── encord_ds.py │ └── hugging_face.py └── utils.py ├── evaluation ├── __init__.py ├── base.py ├── evaluator.py ├── image_retrieval.py ├── knn.py ├── linear_probe.py ├── utils.py └── zero_shot.py ├── model ├── __init__.py ├── base.py ├── provider.py └── types │ ├── __init__.py │ ├── hugging_face.py │ ├── local_clip_model.py │ └── open_clip_model.py ├── plotting ├── __init__.py ├── animation.py └── reduction.py └── utils.py /.github/actions/setup-poetry-environment/action.yml: -------------------------------------------------------------------------------- 1 | name: "Setup test environment" 2 | description: "Sets up Python, Poetry and dependencies" 3 | 4 | inputs: 5 | python: 6 | description: "Python version to use" 7 | default: "3.11" 8 | required: false 9 | poetry: 10 | description: "Poetry version to use" 11 | default: 1.7.1 12 | required: false 13 | 14 | runs: 15 | using: "composite" 16 | 17 | steps: 18 | - uses: actions/setup-python@v5 19 | with: 20 | python-version: ${{ inputs.python }} 21 | 22 | - uses: snok/install-poetry@v1 23 | with: 24 | version: ${{ inputs.poetry }} 25 | virtualenvs-create: true 26 | virtualenvs-in-project: true 27 | 28 | - name: Load cached venv 29 | id: cached-poetry-dependencies 30 | uses: actions/cache@v4 31 | with: 32 | path: .venv 33 | key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }} 34 | 35 | - if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' 36 | run: | 37 | poetry lock --no-update 38 | poetry install --no-interaction 39 | shell: bash 40 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | push: 8 | branches: 9 | - main 10 | workflow_dispatch: 11 | 12 | env: 13 | PYTHON_VERSION: 3.11 14 | 15 | jobs: 16 | pre-commit: 17 | name: Linting and type checking 18 | runs-on: ubuntu-latest 19 | steps: 20 | - uses: actions/checkout@v4 21 | - 
name: Setup poetry environment 22 | uses: ./.github/actions/setup-poetry-environment 23 | - name: Run linting, type checking and testing 24 | uses: pre-commit/action@v3.0.0 25 | with: 26 | extra_args: "--all-files --hook-stage=push" 27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.toptal.com/developers/gitignore/api/vim,python 3 | # Edit at https://www.toptal.com/developers/gitignore?templates=vim,python 4 | 5 | ### Python ### 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # poetry 103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 104 | # This is especially recommended for binary packages to ensure reproducibility, and is more 105 | # commonly ignored for libraries. 106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 107 | #poetry.lock 108 | 109 | # pdm 110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 111 | #pdm.lock 112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 113 | # in version control. 114 | # https://pdm.fming.dev/#use-with-ide 115 | .pdm.toml 116 | 117 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 165 | .idea/ 166 | 167 | #Vscode 168 | .vscode 169 | 170 | ### Python Patch ### 171 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 172 | poetry.toml 173 | 174 | # ruff 175 | .ruff_cache/ 176 | 177 | # LSP config files 178 | pyrightconfig.json 179 | 180 | ### Vim ### 181 | # Swap 182 | [._]*.s[a-v][a-z] 183 | !*.svg # comment out if you don't need vector files 184 | [._]*.sw[a-p] 185 | [._]s[a-rt-v][a-z] 186 | [._]ss[a-gi-z] 187 | [._]sw[a-p] 188 | 189 | # Session 190 | Session.vim 191 | Sessionx.vim 192 | 193 | # Temporary 194 | .netrwhist 195 | *~ 196 | # Auto-generated tag files 197 | tags 198 | # Persistent undo 199 | [._]*.un~ 200 | 201 | # End of https://www.toptal.com/developers/gitignore/api/vim,python 202 | output 203 | 204 | # Mac's specifics 205 | .DS_Store 206 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3.11 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: f71fa2c1f9cf5cb705f73dffe4b21f7c61470ba9 # 4.4.0 7 | hooks: 8 | - id: check-yaml 9 | - id: check-json 10 | - id: check-toml 11 | - id: end-of-file-fixer 12 | - id: trailing-whitespace 13 | - repo: https://github.com/pappasam/toml-sort 14 | rev: b9b6210da457c38122995e434b314f4c4a4a923e # 0.23.1 15 | hooks: 16 | - id: toml-sort-fix 17 | files: ^.+.toml$ 18 | - repo: https://github.com/astral-sh/ruff-pre-commit 19 | rev: v0.2.1 20 | hooks: 21 | - id: ruff 22 | args: 23 | - --fix 24 | - --exit-non-zero-on-fix 25 | # - --ignore=E501 # line-too-long 26 | # - --ignore=F631 # assert-tuple 27 | # - --ignore=E741 # ambiguous-variable-name 28 | - id: ruff-format 29 | files: ^src\/.+\.py$ 30 | # - repo: https://github.com/pre-commit/mirrors-mypy 31 | # rev: v1.8.0 32 | # hooks: 33 | # - id: mypy 34 | default_stages: [push] 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | CLI Quickstart In Colab 4 | Python Versions 5 | License 6 | 7 | PRs Welcome 8 | 9 | 10 | Encord Notebooks 11 | Join us on Slack 12 | Twitter Follow 13 | 14 | 15 | 16 | 17 | tti-eval logo banner 18 |
19 | 20 | Welcome to `tti-eval`, a repository for benchmarking text-to-image models **on your own data**! 21 | 22 | > Evaluate your (or HF) text-to-image embedding models like [CLIP][openai/clip-vit-large-patch14-336] from OpenAI against your (or HF) datasets to estimate how well the model will perform on your classification dataset. 23 | 24 | ## TLDR 25 | 26 | With this library, you can take an embedding model intended for jointly embedding images and text (like [CLIP][openai/clip-vit-large-patch14-336]) and compute metrics for how well such a model performs at classifying your custom dataset. 27 | What you will do is: 28 | 29 | 0. [Install](#installation) `tti-eval` 30 | 1. [Compute embeddings](#embeddings-generation) of a dataset with a model 31 | 2. Do an [evaluation](#model-evaluation) of the model against the dataset 32 | 33 | You can easily benchmark different models and datasets against each other. Here is an example: 34 | 35 |
36 | An animation showing how to use the CLI to evaluate embedding models 37 |
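For reference, the same flow can be scripted in Python. Below is a minimal sketch based on the helpers exercised in the repository's tests (`EmbeddingDefinition`, `compute_embeddings_from_definition` and the `Split` enum); see the Python Quickstart further down for the full walkthrough.

```python
from tti_eval.common import EmbeddingDefinition, Split
from tti_eval.compute import compute_embeddings_from_definition

# Pair a registered model with a registered dataset.
definition = EmbeddingDefinition(model="clip", dataset="Alzheimer-MRI")

# Compute embeddings for the validation split and save them for later evaluation.
embeddings = compute_embeddings_from_definition(definition, Split.VALIDATION)
definition.save_embeddings(embeddings, split=Split.VALIDATION, overwrite=True)
```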
38 | 39 | ## Installation 40 | 41 | > `tti-eval` requires [Python 3.10+](https://www.python.org/downloads/release/python-31014/) and [Poetry](https://python-poetry.org/docs/#installation). 42 | 43 | 1. Clone the repository: 44 | ``` 45 | git clone https://github.com/encord-team/text-to-image-eval.git 46 | ``` 47 | 2. Navigate to the project directory: 48 | ``` 49 | cd text-to-image-eval 50 | ``` 51 | 3. Install the required dependencies: 52 | ``` 53 | poetry shell 54 | poetry install 55 | ``` 56 | 4. Add environment variables: 57 | ``` 58 | export TTI_EVAL_CACHE_PATH=$PWD/.cache 59 | export TTI_EVAL_OUTPUT_PATH=$PWD/output 60 | export ENCORD_SSH_KEY_PATH= 61 | ``` 62 | 63 | ## CLI Quickstart 64 | 65 | 66 | CLI Quickstart In Colab 67 | 68 | 69 | ### Embeddings Generation 70 | 71 | To build embeddings, run the CLI command `tti-eval build`. 72 | This command allows you to interactively select the model and dataset combinations on which to build the embeddings. 73 | 74 | Alternatively, you can choose known (model, dataset) pairs using the `--model-dataset` option. For example: 75 | 76 | ``` 77 | tti-eval build --model-dataset clip/Alzheimer-MRI --model-dataset bioclip/Alzheimer-MRI 78 | ``` 79 | 80 | ### Model Evaluation 81 | 82 | To evaluate models, use the CLI command `tti-eval evaluate`. 83 | This command enables interactive selection of model and dataset combinations for evaluation. 84 | 85 | Alternatively, you can specify known (model, dataset) pairs using the `--model-dataset` option. For example: 86 | 87 | ``` 88 | tti-eval evaluate --model-dataset clip/Alzheimer-MRI --model-dataset bioclip/Alzheimer-MRI 89 | ``` 90 | 91 | The evaluation results can be exported to a CSV file using the `--save` option. 92 | They will be saved in a folder at the location specified by the environment variable `TTI_EVAL_OUTPUT_PATH`. 93 | By default, exported evaluation results are stored in the `output/evaluations` folder within the repository. 94 | 95 | ### Embeddings Animation 96 | 97 | To create 2D animations of the embeddings, use the CLI command `tti-eval animate`. 98 | This command allows you to visualise the reduction of embeddings from two models on the same dataset. 99 | 100 | You have the option to interactively select two models and a dataset for visualization. 101 | Alternatively, you can specify the models and dataset as arguments. For example: 102 | 103 | ``` 104 | tti-eval animate clip bioclip Alzheimer-MRI 105 | ``` 106 | 107 | The animations will be saved in a folder at the location specified by the environment variable `TTI_EVAL_OUTPUT_PATH`. 108 | By default, animations are stored in the `output/animations` folder within the repository. 109 | To interactively explore the animation in a temporary session, use the `--interactive` flag. 110 | 111 |
112 | Transition between embedding plots 113 |
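The CLI commands above have Python counterparts. As an example, here is a minimal sketch of evaluating previously built embeddings with the weighted KNN classifier, mirroring how `WeightedKNNClassifier` is driven in the repository's test suite; the `num_classes` value is an illustrative assumption and should be set to the number of labels in your dataset.

```python
import numpy as np

from tti_eval.common import EmbeddingDefinition, Split
from tti_eval.evaluation.knn import WeightedKNNClassifier

definition = EmbeddingDefinition(model="clip", dataset="Alzheimer-MRI")

# Load embeddings previously built for this (model, dataset) pair.
train_embeddings = definition.load_embeddings(Split.TRAIN)
val_embeddings = definition.load_embeddings(Split.VALIDATION)
assert train_embeddings is not None and val_embeddings is not None

# `num_classes=4` is a placeholder; use your dataset's actual label count.
knn = WeightedKNNClassifier(train_embeddings, val_embeddings, num_classes=4)
probabilities, predictions = knn.predict()
print(f"Weighted KNN accuracy: {np.mean(predictions == val_embeddings.labels):.4f}")
```

The `tti-eval evaluate` command wraps this evaluator together with the zero-shot, linear probe and image retrieval evaluators, and reports the results for every selected (model, dataset) pair.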
114 | 115 | > ℹ️ You can also carry out these operations using Python. Explore our Python Quickstart guide for more details. 116 | > 117 | 118 | > Python Quickstart In Colab 119 | > 120 | 121 | ## Some Example Results 122 | 123 | One example of where `tti-eval` is useful is testing different open-source models against different open-source datasets within a specific domain. 124 | Below, we focus on the medical domain. We evaluate nine different models, three of which are domain-specific. 125 | The models are evaluated against four different medical datasets. Note that further down this page you will find links to all models and datasets. 126 | 127 |
128 | An animation showing how to use the CLI to evaluate embedding models 129 |
130 |

Figure 1: Linear probe accuracy across four different medical datasets. General purpose models are colored green while models trained for the medical domain are colored red. 131 |

132 |
133 | 134 |
135 | The raw numbers from the experiment 136 | 137 | ### Weighted KNN Accuracy 138 | 139 | | Model/Dataset | Alzheimer-MRI | LungCancer4Types | chest-xray-classification | skin-cancer | 140 | | :--------------- | :-----------: | :--------------: | :-----------------------: | :---------: | 141 | | apple | 0.6777 | 0.6633 | 0.9687 | 0.7985 | 142 | | bioclip | 0.8952 | 0.7800 | 0.9771 | 0.7961 | 143 | | clip | 0.6986 | 0.6867 | 0.9727 | 0.7891 | 144 | | plip | 0.8021 | 0.6767 | 0.9599 | 0.7860 | 145 | | pubmed | 0.8503 | 0.5767 | 0.9725 | 0.7637 | 146 | | siglip_large | 0.6908 | 0.6533 | 0.9695 | 0.7947 | 147 | | siglip_small | 0.6992 | 0.6267 | 0.9646 | 0.7780 | 148 | | tinyclip | 0.7389 | 0.5900 | 0.9673 | 0.7589 | 149 | | vit-b-32-laion2b | 0.7559 | 0.5967 | 0.9654 | 0.7738 | 150 | 151 | ### Zero-shot Accuracy 152 | 153 | | Model/Dataset | Alzheimer-MRI | LungCancer4Types | chest-xray-classification | skin-cancer | 154 | | :--------------- | :-----------: | :--------------: | :-----------------------: | :---------: | 155 | | apple | 0.4460 | 0.2367 | 0.7381 | 0.3594 | 156 | | bioclip | 0.3092 | 0.2200 | 0.7356 | 0.0431 | 157 | | clip | 0.4857 | 0.2267 | 0.7381 | 0.1955 | 158 | | plip | 0.0104 | 0.2267 | 0.3873 | 0.0797 | 159 | | pubmed | 0.3099 | 0.2867 | 0.7501 | 0.1127 | 160 | | siglip_large | 0.4876 | 0.3000 | 0.5950 | 0.0421 | 161 | | siglip_small | 0.4102 | 0.0767 | 0.7381 | 0.1541 | 162 | | tinyclip | 0.2526 | 0.2533 | 0.7313 | 0.1113 | 163 | | vit-b-32-laion2b | 0.3594 | 0.1533 | 0.7378 | 0.1228 | 164 | 165 | --- 166 | 167 | ### Image-to-image Retrieval 168 | 169 | | Model/Dataset | Alzheimer-MRI | LungCancer4Types | chest-xray-classification | skin-cancer | 170 | | :--------------- | :-----------: | :--------------: | :-----------------------: | :---------: | 171 | | apple | 0.4281 | 0.2786 | 0.8835 | 0.6437 | 172 | | bioclip | 0.4535 | 0.3496 | 0.8786 | 0.6278 | 173 | | clip | 0.4247 | 0.2812 | 0.8602 | 0.6347 | 174 | | plip | 0.4406 | 0.3174 | 0.8372 | 0.6289 | 175 | | pubmed | 0.4445 | 0.3022 | 0.8621 | 0.6228 | 176 | | siglip_large | 0.4232 | 0.2743 | 0.8797 | 0.6466 | 177 | | siglip_small | 0.4303 | 0.2613 | 0.8660 | 0.6348 | 178 | | tinyclip | 0.4361 | 0.2833 | 0.8379 | 0.6098 | 179 | | vit-b-32-laion2b | 0.4378 | 0.2934 | 0.8551 | 0.6189 | 180 | 181 | --- 182 | 183 | ### Linear Probe Accuracy 184 | 185 | | Model/Dataset | Alzheimer-MRI | LungCancer4Types | chest-xray-classification | skin-cancer | 186 | | :--------------- | :-----------: | :--------------: | :-----------------------: | :---------: | 187 | | apple | 0.5482 | 0.5433 | 0.9362 | 0.7662 | 188 | | bioclip | 0.6139 | 0.6600 | 0.9433 | 0.7933 | 189 | | clip | 0.5547 | 0.5700 | 0.9362 | 0.7704 | 190 | | plip | 0.5469 | 0.5267 | 0.9261 | 0.7630 | 191 | | pubmed | 0.5482 | 0.5400 | 0.9278 | 0.7269 | 192 | | siglip_large | 0.5286 | 0.5200 | 0.9496 | 0.7697 | 193 | | siglip_small | 0.5449 | 0.4967 | 0.9327 | 0.7606 | 194 | | tinyclip | 0.5651 | 0.5733 | 0.9280 | 0.7484 | 195 | | vit-b-32-laion2b | 0.5684 | 0.5933 | 0.9302 | 0.7578 | 196 | 197 | --- 198 | 199 |
200 | 201 | ## Datasets 202 | 203 | 204 | Datasets Quickstart In Colab 205 | 206 | 207 | This repository contains classification datasets sourced from [Hugging Face](https://huggingface.co/datasets) and [Encord](https://app.encord.com/projects). 208 | 209 | > ⚠️ Currently, only image and image groups datasets are supported, with potential for future expansion to include video datasets. 210 | 211 | | Dataset Title | Implementation | HF Dataset | 212 | | :------------------------ | :------------------------------ | :----------------------------------------------------------------------------------- | 213 | | Alzheimer-MRI | [Hugging Face][hf-dataset-impl] | [Falah/Alzheimer_MRI][Falah/Alzheimer_MRI] | 214 | | chest-xray-classification | [Hugging Face][hf-dataset-impl] | [trpakov/chest-xray-classification][trpakov/chest-xray-classification] | 215 | | LungCancer4Types | [Hugging Face][hf-dataset-impl] | [Kabil007/LungCancer4Types][Kabil007/LungCancer4Types] | 216 | | plants | [Hugging Face][hf-dataset-impl] | [sampath017/plants][sampath017/plants] | 217 | | skin-cancer | [Hugging Face][hf-dataset-impl] | [marmal88/skin_cancer][marmal88/skin_cancer] | 218 | | sports-classification | [Hugging Face][hf-dataset-impl] | [HES-XPLAIN/SportsImageClassification][HES-XPLAIN/SportsImageClassification] | 219 | | rsicd | [Encord][encord-dataset-impl] | \* Requires ssh key and access to the Encord project | 220 | 221 | ### Add a Dataset from a Known Source 222 | 223 | To register a dataset from a known source, you can include the dataset definition as a JSON file in the `sources/datasets` folder. 224 | The definition will be validated against the schema defined by the `tti_eval.dataset.base.DatasetDefinitionSpec` Pydantic class to ensure that it adheres to the required structure. 225 | You can find the explicit schema in `sources/dataset-definition-schema.json`. 226 | 227 | Check out the declarations of known sources at `tti_eval.dataset.types` and refer to the existing dataset definitions in the `sources/datasets` folder for guidance. 228 | Below is an example of a dataset definition for the [plants](https://huggingface.co/datasets/sampath017/plants) dataset sourced from Hugging Face: 229 | 230 | ```json 231 | { 232 | "dataset_type": "HFDataset", 233 | "title": "plants", 234 | "title_in_source": "sampath017/plants" 235 | } 236 | ``` 237 | 238 | In each dataset definition, the `dataset_type` and `title` fields are required. 239 | The `dataset_type` indicates the name of the class that represents the source, while `title` serves as a reference for the dataset on this platform. 240 | 241 | For Hugging Face datasets, the `title_in_source` field should store the title of the dataset as it appears on the Hugging Face website. 242 | 243 | For datasets sourced from Encord, other set of fields are required. These include `project_hash`, which contains the hash of the project, and `classification_hash`, which contains the hash of the radio-button (multiclass) classification used in the labels. 244 | 245 | ### Add a Dataset Source 246 | 247 | Expanding the dataset sources involves two key steps: 248 | 249 | 1. Create a dataset class that inherits from `tti_eval.dataset.Dataset` and specifies the input requirements for extracting data from the new source. 250 | This class should encapsulate the necessary logic for fetching and processing dataset elements. 251 | 2. Generate a dataset definition in JSON format and save it in the `sources/datasets` folder, following the guidelines outlined in the previous section. 
Ensure that the definition includes essential fields such as `dataset_type`, `title`, and `module_path`, which points to the file containing the dataset class implementation. 253 | 254 | > It's recommended to store the file containing the dataset class implementation in the `tti_eval/dataset/types` folder and add a reference to the class in the `__init__.py` file in the same folder. 255 | > This ensures that the new dataset type is accessible by default for all dataset definitions, eliminating the need to explicitly state the `module_path` field for datasets from such a source. 256 | 257 | ### Programmatically Add a Dataset 258 | 259 | Alternatively, you can programmatically add a dataset, which will be available only for the current session, using the `register_dataset()` method of the `tti_eval.dataset.DatasetProvider` class. 260 | 261 | Here is an example of how to register a dataset from Hugging Face using Python code: 262 | 263 | ```python 264 | from tti_eval.dataset import DatasetProvider, Split 265 | from tti_eval.dataset.types import HFDataset 266 | 267 | DatasetProvider.register_dataset(HFDataset, "plants", title_in_source="sampath017/plants") 268 | ds = DatasetProvider.get_dataset("plants", split=Split.ALL) 269 | print(len(ds)) # Returns: 219 270 | ``` 271 | 272 | ### Remove a Dataset 273 | 274 | To permanently remove a dataset, simply delete the corresponding JSON file stored in the `sources/datasets` folder. 275 | This action removes the dataset from the list of available datasets in the CLI, disabling the option to create any further embeddings using its data. 276 | However, all embeddings previously built on that dataset will remain intact and available for other tasks such as evaluation and animation. 277 | 278 | ## Models 279 | 280 | 281 | Models Quickstart In Colab 282 | 283 | 284 | This repository contains models sourced from [Hugging Face](https://huggingface.co/models), [OpenCLIP](https://github.com/mlfoundations/open_clip) and local implementations based on OpenCLIP models. 285 | 286 | _TODO_: Some more prose about the differences between the implementations.
287 | 288 | ### Hugging Face Models 289 | 290 | | Model Title | Implementation | HF Model | 291 | | :--------------- | :---------------------------- | :--------------------------------------------------------------------------------------------- | 292 | | apple | [OpenCLIP][open-model-impl] | [apple/DFN5B-CLIP-ViT-H-14][apple/DFN5B-CLIP-ViT-H-14] | 293 | | bioclip | [OpenCLIP][open-model-impl] | [imageomics/bioclip][imageomics/bioclip] | 294 | | eva-clip | [OpenCLIP][open-model-impl] | [BAAI/EVA-CLIP-8B-448][BAAI/EVA-CLIP-8B-448] | 295 | | vit-b-32-laion2b | [OpenCLIP][local-model-impl] | [laion/CLIP-ViT-B-32-laion2B-s34B-b79K][laion/CLIP-ViT-B-32-laion2B-s34B-b79K] | 296 | | clip | [Hugging Face][hf-model-impl] | [openai/clip-vit-large-patch14-336][openai/clip-vit-large-patch14-336] | 297 | | fashion | [Hugging Face][hf-model-impl] | [patrickjohncyh/fashion-clip][patrickjohncyh/fashion-clip] | 298 | | plip | [Hugging Face][hf-model-impl] | [vinid/plip][vinid/plip] | 299 | | pubmed | [Hugging Face][hf-model-impl] | [flaviagiammarino/pubmed-clip-vit-base-patch32][flaviagiammarino/pubmed-clip-vit-base-patch32] | 300 | | rsicd | [Hugging Face][hf-model-impl] | [flax-community/clip-rsicd][flax-community/clip-rsicd] | 301 | | siglip_large | [Hugging Face][hf-model-impl] | [google/siglip-large-patch16-256][google/siglip-large-patch16-256] | 302 | | siglip_small | [Hugging Face][hf-model-impl] | [google/siglip-base-patch16-224][google/siglip-base-patch16-224] | 303 | | street | [Hugging Face][hf-model-impl] | [geolocal/StreetCLIP][geolocal/StreetCLIP] | 304 | | tinyclip | [Hugging Face][hf-model-impl] | [wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M][wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M] | 305 | 306 | ### Locally Trained Models 307 | 308 | | Model Title | Implementation | Weights | 309 | | :----------- | :-------------------------------- | :------ | 310 | | rsicd-encord | [LocalOpenCLIP][local-model-impl] | - | 311 | 312 | ### Add a Model from a Known Source 313 | 314 | To register a model from a known source, you can include the model definition as a JSON file in the `sources/models` folder. 315 | The definition will be validated against the schema defined by the `tti_eval.model.base.ModelDefinitionSpec` Pydantic class to ensure that it adheres to the required structure. 316 | You can find the explicit schema in `sources/model-definition-schema.json`. 317 | 318 | Check out the declarations of known sources at `tti_eval.model.types` and refer to the existing model definitions in the `sources/models` folder for guidance. 319 | Below is an example of a model definition for the [clip](https://huggingface.co/openai/clip-vit-large-patch14-336) model sourced from Hugging Face: 320 | 321 | ```json 322 | { 323 | "model_type": "HFModel", 324 | "title": "clip", 325 | "title_in_source": "openai/clip-vit-large-patch14-336" 326 | } 327 | ``` 328 | 329 | In each model definition, the `model_type` and `title` fields are required. 330 | The `model_type` indicates the name of the class that represents the source, while `title` serves as a reference for the model on this platform. 331 | 332 | For non-local models, the `title_in_source` field should store the title of the model as it appears in the source. 333 | For model checkpoints in local storage, the `title_in_source` field should store the title of the model used to train it. 334 | Additionally, on models sourced from OpenCLIP the optional `pretrained` field may be needed. 
See the list of OpenCLIP models [here](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md). 335 | 336 | ### Add a Model Source 337 | 338 | Expanding the model sources involves two key steps: 339 | 340 | 1. Create a model class that inherits from `tti_eval.model.Model` and specifies the input requirements for loading models from the new source. 341 | This class should encapsulate the necessary logic for processing model elements and generating embeddings. 342 | 2. Generate a model definition in JSON format and save it in the `sources/models` folder, following the guidelines outlined in the previous section. 343 | Ensure that the definition includes essential fields such as `model_type`, `title`, and `module_path`, which points to the file containing the model class implementation. 344 | 345 | > It's recommended to store the file containing the model class implementation in the `tti_eval/model/types` folder and add a reference to the class in the `__init__.py` file in the same folder. 346 | > This ensures that the new model type is accessible by default for all model definitions, eliminating the need to explicitly state the `module_path` field for models from such source. 347 | 348 | ### Programmatically Add a Model 349 | 350 | Alternatively, you can programmatically add a model, which will be available only for the current session, using the `register_model()` method of the `tti_eval.model.ModelProvider` class. 351 | 352 | Here is an example of how to register a model from Hugging Face using Python code: 353 | 354 | ```python 355 | from tti_eval.model import ModelProvider 356 | from tti_eval.model.types import HFModel 357 | 358 | ModelProvider.register_model(HFModel, "clip", title_in_source="openai/clip-vit-large-patch14-336") 359 | model = ModelProvider.get_model("clip") 360 | print(model.title, model.title_in_source) # Returns: clip openai/clip-vit-large-patch14-336 361 | ``` 362 | 363 | ### Remove a Model 364 | 365 | To permanently remove a model, simply delete the corresponding JSON file stores in the `sources/models` folder. 366 | This action removes the model from the list of available models in the CLI, disabling the option to create any further embedding with it. 367 | However, all embeddings previously built with that model will remain intact and available for other tasks such as evaluation and animation. 368 | 369 | ## Set Up the Development Environment 370 | 371 | 1. Create the virtual environment, add dev dependencies and set up pre-commit hooks. 372 | ``` 373 | ./dev-setup.sh 374 | ``` 375 | 2. Add environment variables: 376 | ``` 377 | export TTI_EVAL_CACHE_PATH=$PWD/.cache 378 | export TTI_EVAL_OUTPUT_PATH=$PWD/output 379 | export ENCORD_SSH_KEY_PATH= 380 | ``` 381 | 382 | ## Contributing 383 | 384 | Contributions are welcome! 385 | Please feel free to open an issue or submit a pull request with your suggestions, bug fixes, or new features. 386 | 387 | ### Adding Dataset Sources 388 | 389 | To contribute by adding dataset sources, follow these steps: 390 | 391 | 1. Store the file containing the new dataset class implementation in the `tti_eval/dataset/types` folder. 392 | Don't forget to add a reference to the class in the `__init__.py` file in the same folder. 393 | This ensures that the new dataset type is accessible by default for all dataset definitions, eliminating the need to explicitly state the `module_path` field for datasets from such source. 394 | 2. Open a pull request with the necessary changes. 
Make sure to include tests validating that data retrieval, processing and usage are working as expected. 395 | 3. Document the addition of the dataset source, providing details on its structure, usage, and any specific considerations or instructions for integration. 396 | This ensures that users have clear guidance on how to leverage the new dataset source effectively. 397 | 398 | ### Adding Model Sources 399 | 400 | To contribute by adding model sources, follow these steps: 401 | 402 | 1. Store the file containing the new model class implementation in the `tti_eval/model/types` folder. 403 | Don't forget to add a reference to the class in the `__init__.py` file in the same folder. 404 | This ensures that the new model type is accessible by default for all model definitions, eliminating the need to explicitly state the `module_path` field for models from such source. 405 | 2. Open a pull request with the necessary changes. Make sure to include tests validating that model loading, processing and embedding generation are working as expected. 406 | 3. Document the addition of the model source, providing details on its structure, usage, and any specific considerations or instructions for integration. 407 | This ensures that users have clear guidance on how to leverage the new model source effectively. 408 | 409 | ## Known Issues 410 | 411 | 1. `autofaiss`: The project depends on the [autofaiss][autofaiss] library which can give some trouble on Windows. Please reach out or raise an issue with as many system and version details as possible if you encounter it. 412 | 413 | [Falah/Alzheimer_MRI]: https://huggingface.co/datasets/Falah/Alzheimer_MRI 414 | [trpakov/chest-xray-classification]: https://huggingface.co/datasets/trpakov/chest-xray-classification 415 | [Kabil007/LungCancer4Types]: https://huggingface.co/datasets/Kabil007/LungCancer4Types 416 | [sampath017/plants]: https://huggingface.co/datasets/sampath017/plants 417 | [marmal88/skin_cancer]: https://huggingface.co/datasets/marmal88/skin_cancer 418 | [HES-XPLAIN/SportsImageClassification]: https://huggingface.co/datasets/HES-XPLAIN/SportsImageClassification 419 | [apple/DFN5B-CLIP-ViT-H-14]: https://huggingface.co/apple/DFN5B-CLIP-ViT-H-14 420 | [imageomics/bioclip]: https://huggingface.co/imageomics/bioclip 421 | [openai/clip-vit-large-patch14-336]: https://huggingface.co/openai/clip-vit-large-patch14-336 422 | [BAAI/EVA-CLIP-8B-448]: https://huggingface.co/BAAI/EVA-CLIP-8B-448 423 | [patrickjohncyh/fashion-clip]: https://huggingface.co/patrickjohncyh/fashion-clip 424 | [vinid/plip]: https://huggingface.co/vinid/plip 425 | [flaviagiammarino/pubmed-clip-vit-base-patch32]: https://huggingface.co/flaviagiammarino/pubmed-clip-vit-base-patch32 426 | [flax-community/clip-rsicd]: https://huggingface.co/flax-community/clip-rsicd 427 | [google/siglip-large-patch16-256]: https://huggingface.co/google/siglip-large-patch16-256 428 | [google/siglip-base-patch16-224]: https://huggingface.co/google/siglip-base-patch16-224 429 | [geolocal/StreetCLIP]: https://huggingface.co/geolocal/StreetCLIP 430 | [wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M]: https://huggingface.co/wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M 431 | [laion/CLIP-ViT-B-32-laion2B-s34B-b79K]: https://huggingface.co/laion/CLIP-ViT-B-32-laion2B-s34B-b79K 432 | [open-model-impl]: https://github.com/encord-team/text-to-image-eval/blob/main/tti_eval/model/types/open_clip_model.py 433 | [hf-model-impl]: 
https://github.com/encord-team/text-to-image-eval/blob/main/tti_eval/model/types/hugging_face_clip.py 434 | [local-model-impl]: https://github.com/encord-team/text-to-image-eval/blob/main/tti_eval/model/types/local_clip_model.py 435 | [hf-dataset-impl]: https://github.com/encord-team/text-to-image-eval/blob/main/tti_eval/dataset/types/hugging_face.py 436 | [encord-dataset-impl]: https://github.com/encord-team/text-to-image-eval/blob/main/tti_eval/dataset/types/encord_ds.py 437 | [autofaiss]: https://pypi.org/project/autofaiss/ 438 | -------------------------------------------------------------------------------- /dev-setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/zsh 2 | poetry env use 3.11 # Create the virtual environment if it does not exist 3 | source $(poetry env info --path)/bin/activate # Activate and enter the virtual environment 4 | poetry install --with=dev # Install dev dependencies 5 | pre-commit install --install-hooks --overwrite -t pre-push # Set up pre-commit hooks 6 | -------------------------------------------------------------------------------- /images/embeddings.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/encord-team/text-to-image-eval/92ff9da010e4c00de76475a57bd52ed4d95b3716/images/embeddings.gif -------------------------------------------------------------------------------- /images/encord.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/encord-team/text-to-image-eval/92ff9da010e4c00de76475a57bd52ed4d95b3716/images/encord.png -------------------------------------------------------------------------------- /images/logo_v1.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/encord-team/text-to-image-eval/92ff9da010e4c00de76475a57bd52ed4d95b3716/images/logo_v1.webp -------------------------------------------------------------------------------- /images/tti-eval-banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/encord-team/text-to-image-eval/92ff9da010e4c00de76475a57bd52ed4d95b3716/images/tti-eval-banner.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | build-backend = "poetry.core.masonry.api" 3 | requires = ["poetry-core"] 4 | 5 | [tool.poetry] 6 | authors = ["Cord Technologies Limited "] 7 | description = "Evaluating text-to-image and image-to-text retrieval models." 
8 | name = "tti-eval" 9 | readme = "README.md" 10 | version = "0.1.0" 11 | packages = [{include = "tti_eval", from = "."}] 12 | 13 | [tool.poetry.dependencies] 14 | autofaiss = "^2.17.0" 15 | datasets = "^2.17.0" 16 | encord = "^0.1.108" 17 | pydantic = "^2.6.1" 18 | python = "^3.10" 19 | python-dotenv = "^1.0.1" 20 | scikit-learn = "^1.4" 21 | torch = "^2.2.0" 22 | torchvision = "^0.17.0" 23 | transformers = "^4.37.2" 24 | umap-learn = "^0.5.5" 25 | matplotlib = "^3.8.2" 26 | typer = "^0.9.0" 27 | inquirerpy = "^0.3.4" 28 | tabulate = "^0.9.0" 29 | open-clip-torch = "^2.24.0" 30 | natsort = "^8.4.0" 31 | 32 | [tool.poetry.group.dev.dependencies] 33 | mypy = "^1.8.0" 34 | ruff = "^0.2.1" 35 | toml-sort = "^0.23.1" 36 | pre-commit = "^3.6.1" 37 | pyfzf = "^0.3.1" 38 | ipython = "^8.22.1" 39 | ipdb = "^0.13.13" 40 | pytest = "^8.1.1" 41 | 42 | [tool.poetry.scripts] 43 | tti-eval = "tti_eval.cli.main:cli" 44 | 45 | [tool.ruff] 46 | line-length = 120 47 | target-version = "py311" 48 | 49 | [tool.ruff.lint] 50 | # ignore = [ 51 | # "B007", # Loop control variable {name} not used within loop body 52 | # "E501", # Checks for lines that exceed the specified maximum character length 53 | # "E741" # Ambiguous variable name: {name} (e.g. allow short names in one-line list comprehensions) 54 | # ] 55 | select = ["B", "E", "F", "I", "UP"] 56 | ignore = ["UP007"] 57 | 58 | [tool.ruff.lint.isort] 59 | known-first-party = ["src"] 60 | 61 | [tool.ruff.lint.per-file-ignores] 62 | "__init__.py" = ["F401"] 63 | -------------------------------------------------------------------------------- /sources/dataset-definition-schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "$defs": { 3 | "Split": { 4 | "enum": [ 5 | "train", 6 | "validation", 7 | "test", 8 | "all" 9 | ], 10 | "title": "Split", 11 | "type": "string" 12 | } 13 | }, 14 | "additionalProperties": true, 15 | "properties": { 16 | "dataset_type": { 17 | "title": "Dataset Type", 18 | "type": "string" 19 | }, 20 | "title": { 21 | "title": "Title", 22 | "type": "string" 23 | }, 24 | "module_path": { 25 | "default": "../../tti_eval/dataset/types/__init__.py", 26 | "format": "path", 27 | "title": "Module Path", 28 | "type": "string" 29 | }, 30 | "split": { 31 | "anyOf": [ 32 | { 33 | "$ref": "#/$defs/Split" 34 | }, 35 | { 36 | "type": "null" 37 | } 38 | ], 39 | "default": null 40 | }, 41 | "title_in_source": { 42 | "anyOf": [ 43 | { 44 | "type": "string" 45 | }, 46 | { 47 | "type": "null" 48 | } 49 | ], 50 | "default": null, 51 | "title": "Title In Source" 52 | }, 53 | "cache_dir": { 54 | "anyOf": [ 55 | { 56 | "format": "path", 57 | "type": "string" 58 | }, 59 | { 60 | "type": "null" 61 | } 62 | ], 63 | "default": null, 64 | "title": "Cache Dir" 65 | } 66 | }, 67 | "required": [ 68 | "dataset_type", 69 | "title" 70 | ], 71 | "title": "DatasetDefinitionSpec", 72 | "type": "object" 73 | } 74 | -------------------------------------------------------------------------------- /sources/datasets/Alzheimer-MRI.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_type": "HFDataset", 3 | "title": "Alzheimer-MRI", 4 | "title_in_source": "Falah/Alzheimer_MRI" 5 | } 6 | -------------------------------------------------------------------------------- /sources/datasets/LungCancer4Types.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_type": "HFDataset", 3 | "title": "LungCancer4Types", 4 | "title_in_source": 
"Kabil007/LungCancer4Types", 5 | "revision": "a1aab924c6bed6b080fc85552fd7b39724931605" 6 | } 7 | -------------------------------------------------------------------------------- /sources/datasets/chest-xray-classification.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_type": "HFDataset", 3 | "title": "chest-xray-classification", 4 | "title_in_source": "trpakov/chest-xray-classification", 5 | "name": "full", 6 | "target_feature": "labels" 7 | } 8 | -------------------------------------------------------------------------------- /sources/datasets/plants.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_type": "HFDataset", 3 | "title": "plants", 4 | "title_in_source": "sampath017/plants" 5 | } 6 | -------------------------------------------------------------------------------- /sources/datasets/rsicd.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_type": "EncordDataset", 3 | "title": "rsicd", 4 | "project_hash": "46ba913e-1428-48ef-be7f-2553e69bc1e6", 5 | "classification_hash": "4f6cf0c8" 6 | } 7 | -------------------------------------------------------------------------------- /sources/datasets/skin-cancer.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_type": "HFDataset", 3 | "title": "skin-cancer", 4 | "title_in_source": "marmal88/skin_cancer", 5 | "target_feature": "dx" 6 | } 7 | -------------------------------------------------------------------------------- /sources/datasets/sports-classification.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_type": "HFDataset", 3 | "title": "sports-classification", 4 | "title_in_source": "HES-XPLAIN/SportsImageClassification" 5 | } 6 | -------------------------------------------------------------------------------- /sources/model-definition-schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "additionalProperties": true, 3 | "properties": { 4 | "model_type": { 5 | "title": "Model Type", 6 | "type": "string" 7 | }, 8 | "title": { 9 | "title": "Title", 10 | "type": "string" 11 | }, 12 | "module_path": { 13 | "default": "../../tti_eval/model/types/__init__.py", 14 | "format": "path", 15 | "title": "Module Path", 16 | "type": "string" 17 | }, 18 | "device": { 19 | "anyOf": [ 20 | { 21 | "type": "string" 22 | }, 23 | { 24 | "type": "null" 25 | } 26 | ], 27 | "default": null, 28 | "title": "Device" 29 | }, 30 | "title_in_source": { 31 | "anyOf": [ 32 | { 33 | "type": "string" 34 | }, 35 | { 36 | "type": "null" 37 | } 38 | ], 39 | "default": null, 40 | "title": "Title In Source" 41 | }, 42 | "cache_dir": { 43 | "anyOf": [ 44 | { 45 | "format": "path", 46 | "type": "string" 47 | }, 48 | { 49 | "type": "null" 50 | } 51 | ], 52 | "default": null, 53 | "title": "Cache Dir" 54 | } 55 | }, 56 | "required": [ 57 | "model_type", 58 | "title" 59 | ], 60 | "title": "ModelDefinitionSpec", 61 | "type": "object" 62 | } 63 | -------------------------------------------------------------------------------- /sources/models/apple.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "OpenCLIPModel", 3 | "title": "apple", 4 | "title_in_source": "hf-hub:apple/DFN5B-CLIP-ViT-H-14" 5 | } 6 | -------------------------------------------------------------------------------- 
/sources/models/bioclip.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "OpenCLIPModel", 3 | "title": "bioclip", 4 | "title_in_source": "hf-hub:imageomics/bioclip" 5 | } 6 | -------------------------------------------------------------------------------- /sources/models/clip.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "HFModel", 3 | "title": "clip", 4 | "title_in_source": "openai/clip-vit-large-patch14-336" 5 | } 6 | -------------------------------------------------------------------------------- /sources/models/eva-clip.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "OpenCLIPModel", 3 | "title": "eva-clip", 4 | "title_in_source": "BAAI/EVA-CLIP-8B-448" 5 | } 6 | -------------------------------------------------------------------------------- /sources/models/fashion.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "HFModel", 3 | "title": "fashion", 4 | "title_in_source": "patrickjohncyh/fashion-clip" 5 | } 6 | -------------------------------------------------------------------------------- /sources/models/plip.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "HFModel", 3 | "title": "plip", 4 | "title_in_source": "vinid/plip" 5 | } 6 | -------------------------------------------------------------------------------- /sources/models/pubmed.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "HFModel", 3 | "title": "pubmed", 4 | "title_in_source": "flaviagiammarino/pubmed-clip-vit-base-patch32" 5 | } 6 | -------------------------------------------------------------------------------- /sources/models/rsicd-encord.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "LocalCLIPModel", 3 | "title": "rsicd-encord", 4 | "title_in_source": "ViT-B-32" 5 | } 6 | -------------------------------------------------------------------------------- /sources/models/rsicd.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "HFModel", 3 | "title": "rsicd", 4 | "title_in_source": "flax-community/clip-rsicd" 5 | } 6 | -------------------------------------------------------------------------------- /sources/models/siglip_large.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "HFModel", 3 | "title": "siglip_large", 4 | "title_in_source": "google/siglip-large-patch16-256" 5 | } 6 | -------------------------------------------------------------------------------- /sources/models/siglip_small.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "HFModel", 3 | "title": "siglip_small", 4 | "title_in_source": "google/siglip-base-patch16-224" 5 | } 6 | -------------------------------------------------------------------------------- /sources/models/street.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "HFModel", 3 | "title": "street", 4 | "title_in_source": "geolocal/StreetCLIP" 5 | } 6 | -------------------------------------------------------------------------------- /sources/models/tinyclip.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "model_type": "HFModel", 3 | "title": "tinyclip", 4 | "title_in_source": "wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M" 5 | } 6 | -------------------------------------------------------------------------------- /sources/models/vit-b-32-laion2b.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "OpenCLIPModel", 3 | "title": "vit-b-32-laion2b", 4 | "title_in_source": "ViT-B-32", 5 | "pretrained": "laion2b_e16" 6 | } 7 | -------------------------------------------------------------------------------- /tests/common/test_embedding_definition.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from tti_eval.common import EmbeddingDefinition, Embeddings, Split 4 | from tti_eval.compute import compute_embeddings_from_definition 5 | 6 | if __name__ == "__main__": 7 | def_ = EmbeddingDefinition( 8 | model="weird_this with / stuff \\whatever", 9 | dataset="hello there dataset", 10 | ) 11 | def_.embedding_path(Split.TRAIN).parent.mkdir(exist_ok=True, parents=True) 12 | 13 | images = np.random.randn(100, 20).astype(np.float32) 14 | labels = np.random.randint(0, 10, size=(100,)) 15 | classes = np.random.randn(10, 20).astype(np.float32) 16 | emb = Embeddings(images=images, labels=labels, classes=classes) 17 | # emb.to_file(def_.embedding_path(Split.TRAIN)) 18 | def_.save_embeddings(emb, split=Split.TRAIN, overwrite=True) 19 | new_emb = def_.load_embeddings(Split.TRAIN) 20 | 21 | assert new_emb is not None 22 | assert np.allclose(new_emb.images, images) 23 | assert np.allclose(new_emb.labels, labels) 24 | 25 | from pydantic import ValidationError 26 | 27 | try: 28 | Embeddings( 29 | images=np.random.randn(100, 20).astype(np.float32), 30 | labels=np.random.randint(0, 10, size=(100,)), 31 | classes=np.random.randn(10, 30).astype(np.float32), 32 | ) 33 | raise AssertionError() 34 | except ValidationError: 35 | pass 36 | 37 | def_ = EmbeddingDefinition( 38 | model="clip", 39 | dataset="LungCancer4Types", 40 | ) 41 | embeddings = compute_embeddings_from_definition(def_, Split.VALIDATION) 42 | def_.save_embeddings(embeddings, split=Split.VALIDATION, overwrite=True) 43 | -------------------------------------------------------------------------------- /tests/evaluation/test_knn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from tti_eval.common import ClassArray, Embeddings, ProbabilityArray 4 | from tti_eval.evaluation.knn import WeightedKNNClassifier 5 | from tti_eval.evaluation.utils import normalize, softmax 6 | 7 | 8 | def slow_knn_predict( 9 | train_embeddings: Embeddings, 10 | val_embeddings: Embeddings, 11 | num_classes: int, 12 | k: int, 13 | ) -> tuple[ProbabilityArray, ClassArray]: 14 | train_image_embeddings = normalize(train_embeddings.images) 15 | val_image_embeddings = normalize(val_embeddings.images) 16 | n = val_image_embeddings.shape[0] 17 | 18 | # Retrieve the classes and distances of the `k` nearest training embeddings to each validation embedding 19 | all_dists = np.linalg.norm(val_image_embeddings[:, np.newaxis] - train_image_embeddings[np.newaxis, :], axis=-1) 20 | nearest_indices = np.argsort(all_dists, axis=1)[:, :k] 21 | dists = all_dists[np.arange(n)[:, np.newaxis], nearest_indices] 22 | nearest_classes = np.take(train_embeddings.labels, nearest_indices) 23 | 24 | # Calculate class votes from the distances 
(avoiding division by zero) 25 | max_value = np.finfo(np.float32).max 26 | scores = np.divide(1, np.square(dists), out=np.full_like(dists, max_value), where=dists != 0) 27 | weighted_count = np.zeros((n, num_classes), dtype=np.float32) 28 | for cls in range(num_classes): 29 | mask = nearest_classes == cls 30 | weighted_count[:, cls] = np.ma.masked_array(scores, mask=~mask).sum(axis=1) 31 | probabilities = softmax(weighted_count) 32 | return probabilities, np.argmax(probabilities, axis=1) 33 | 34 | 35 | def test_weighted_knn_classifier(): 36 | np.random.seed(42) 37 | 38 | train_embeddings = Embeddings( 39 | images=np.random.randn(100, 10).astype(np.float32), 40 | labels=np.random.randint(0, 10, size=(100,)), 41 | ) 42 | val_embeddings = Embeddings( 43 | images=np.random.randn(20, 10).astype(np.float32), 44 | labels=np.random.randint(0, 10, size=(20,)), 45 | ) 46 | knn = WeightedKNNClassifier( 47 | train_embeddings, 48 | val_embeddings, 49 | num_classes=10, 50 | ) 51 | probs, pred_classes = knn.predict() 52 | 53 | test_probs, test_pred_classes = slow_knn_predict( 54 | train_embeddings, 55 | val_embeddings, 56 | num_classes=knn.num_classes, 57 | k=knn.k, 58 | ) 59 | 60 | assert (pred_classes == test_pred_classes).all() 61 | assert np.isclose(probs, test_probs).all() 62 | -------------------------------------------------------------------------------- /tti_eval/cli/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/encord-team/text-to-image-eval/92ff9da010e4c00de76475a57bd52ed4d95b3716/tti_eval/cli/__init__.py -------------------------------------------------------------------------------- /tti_eval/cli/main.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated, Optional 2 | 3 | import matplotlib.pyplot as plt 4 | import typer 5 | from typer import Argument, Option, Typer 6 | 7 | from tti_eval.common import Split 8 | from tti_eval.compute import compute_embeddings_from_definition 9 | from tti_eval.utils import read_all_cached_embeddings 10 | 11 | from .utils import ( 12 | parse_raw_embedding_definitions, 13 | select_existing_embedding_definitions, 14 | select_from_all_embedding_definitions, 15 | ) 16 | 17 | cli = Typer(name="tti-eval", no_args_is_help=True, rich_markup_mode="markdown") 18 | 19 | 20 | @cli.command( 21 | "build", 22 | help="""Build embeddings. 23 | If no arguments are given, you will be prompted to select the model and dataset combinations for generating embeddings. 24 | You can use [TAB] to select multiple combinations and execute them sequentially. 25 | """, 26 | ) 27 | def build_command( 28 | model_datasets: Annotated[ 29 | Optional[list[str]], 30 | Option( 31 | "--model-dataset", 32 | help="Specify a model and dataset combination. Can be used multiple times. " 33 | "(model, dataset) pairs must be presented as 'MODEL/DATASET'.", 34 | ), 35 | ] = None, 36 | include_existing: Annotated[ 37 | bool, 38 | Option(help="Show combinations for which the embeddings have already been computed."), 39 | ] = False, 40 | by_dataset: Annotated[ 41 | bool, 42 | Option(help="Select dataset first, then model. 
Will only work if `model_dataset` is not specified."), 43 | ] = False, 44 | ): 45 | if len(model_datasets) > 0: 46 | definitions = parse_raw_embedding_definitions(model_datasets) 47 | else: 48 | definitions = select_from_all_embedding_definitions( 49 | include_existing=include_existing, 50 | by_dataset=by_dataset, 51 | ) 52 | 53 | splits = [Split.TRAIN, Split.VALIDATION] 54 | for embd_defn in definitions: 55 | for split in splits: 56 | try: 57 | embeddings = compute_embeddings_from_definition(embd_defn, split) 58 | embd_defn.save_embeddings(embeddings=embeddings, split=split, overwrite=True) 59 | print(f"Embeddings saved successfully to file at `{embd_defn.embedding_path(split)}`") 60 | except Exception as e: 61 | print(f"Failed to build embeddings for {embd_defn} on the specified split {split}") 62 | print(e) 63 | import traceback 64 | 65 | traceback.print_exc() 66 | 67 | 68 | @cli.command( 69 | "evaluate", 70 | help="""Evaluate embeddings performance. 71 | If no arguments are given, you will be prompted to select the model and dataset combinations to evaluate. 72 | Only (model, dataset) pairs whose embeddings have been built will be available for evaluation. 73 | You can use [TAB] to select multiple combinations and execute them sequentially. 74 | """, 75 | ) 76 | def evaluate_embeddings( 77 | model_datasets: Annotated[ 78 | Optional[list[str]], 79 | Option( 80 | "--model-dataset", 81 | help="Specify a model and dataset combination. Can be used multiple times. " 82 | "(model, dataset) pairs must be presented as 'MODEL/DATASET'.", 83 | ), 84 | ] = None, 85 | all_: Annotated[bool, Option("--all", "-a", help="Evaluate all models.")] = False, 86 | save: Annotated[bool, Option("--save", "-s", help="Save evaluation results to a CSV file.")] = False, 87 | ): 88 | from tti_eval.evaluation import ( 89 | I2IRetrievalEvaluator, 90 | LinearProbeClassifier, 91 | WeightedKNNClassifier, 92 | ZeroShotClassifier, 93 | ) 94 | from tti_eval.evaluation.evaluator import export_evaluation_to_csv, run_evaluation 95 | 96 | model_datasets = model_datasets or [] 97 | 98 | if all_: 99 | definitions = read_all_cached_embeddings(as_list=True) 100 | elif len(model_datasets) > 0: 101 | definitions = parse_raw_embedding_definitions(model_datasets) 102 | else: 103 | definitions = select_existing_embedding_definitions() 104 | 105 | models = [ZeroShotClassifier, LinearProbeClassifier, WeightedKNNClassifier, I2IRetrievalEvaluator] 106 | performances = run_evaluation(models, definitions) 107 | if save: 108 | export_evaluation_to_csv(performances) 109 | 110 | 111 | @cli.command( 112 | "animate", 113 | help="""Animate 2D embeddings from two different models on the same dataset. 114 | The interface will prompt you to choose which embeddings you want to use. 
115 | """, 116 | ) 117 | def animate_embeddings( 118 | from_model: Annotated[Optional[str], Argument(help="Title of the model in the left side of the animation.")] = None, 119 | to_model: Annotated[Optional[str], Argument(help="Title of the model in the right side of the animation.")] = None, 120 | dataset: Annotated[Optional[str], Argument(help="Title of the dataset where the embeddings were computed.")] = None, 121 | interactive: Annotated[bool, Option(help="Interactive plot instead of animation.")] = False, 122 | reduction: Annotated[str, Option(help="Reduction type [pca, tsne, umap (default)].")] = "umap", 123 | ): 124 | from tti_eval.plotting.animation import EmbeddingDefinition, build_animation, save_animation_to_file 125 | 126 | all_none_input_args = from_model is None and to_model is None and dataset is None 127 | all_str_input_args = from_model is not None and to_model is not None and dataset is not None 128 | 129 | if not all_str_input_args and not all_none_input_args: 130 | typer.echo("Some arguments were provided. Please either provide all arguments or ignore them entirely.") 131 | raise typer.Abort() 132 | 133 | if all_none_input_args: 134 | defs = select_existing_embedding_definitions(by_dataset=True, count=2) 135 | from_def, to_def = defs[0], defs[1] 136 | else: 137 | from_def = EmbeddingDefinition(model=from_model, dataset=dataset) 138 | to_def = EmbeddingDefinition(model=to_model, dataset=dataset) 139 | 140 | res = build_animation(from_def, to_def, interactive=interactive, reduction=reduction) 141 | if res is None: 142 | plt.show() 143 | else: 144 | save_animation_to_file(res, from_def, to_def) 145 | 146 | 147 | @cli.command("list", help="List models and datasets. By default, only cached pairs are listed.") 148 | def list_models_datasets( 149 | all_: Annotated[ 150 | bool, 151 | Option("--all", "-a", help="List all models and datasets that are available via the tool."), 152 | ] = False, 153 | ): 154 | from tti_eval.dataset import DatasetProvider 155 | from tti_eval.model import ModelProvider 156 | 157 | if all_: 158 | datasets = DatasetProvider.list_dataset_titles() 159 | models = ModelProvider.list_model_titles() 160 | print(f"Available datasets are: {', '.join(datasets)}") 161 | print(f"Available models are: {', '.join(models)}") 162 | return 163 | 164 | defns = read_all_cached_embeddings(as_list=True) 165 | print(f"Available model_datasets pairs: {', '.join([str(defn) for defn in defns])}") 166 | 167 | 168 | if __name__ == "__main__": 169 | cli() 170 | -------------------------------------------------------------------------------- /tti_eval/cli/utils.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | from typing import Literal, overload 3 | 4 | from InquirerPy import inquirer as inq 5 | from InquirerPy.base.control import Choice 6 | from natsort import natsorted, ns 7 | 8 | from tti_eval.common import EmbeddingDefinition 9 | from tti_eval.dataset import DatasetProvider 10 | from tti_eval.model import ModelProvider 11 | from tti_eval.utils import read_all_cached_embeddings 12 | 13 | 14 | @overload 15 | def _do_embedding_definition_selection( 16 | defs: list[EmbeddingDefinition], allow_multiple: Literal[True] = True 17 | ) -> list[EmbeddingDefinition]: 18 | ... 19 | 20 | 21 | @overload 22 | def _do_embedding_definition_selection( 23 | defs: list[EmbeddingDefinition], allow_multiple: Literal[False] 24 | ) -> EmbeddingDefinition: 25 | ... 
26 | 27 | 28 | def _do_embedding_definition_selection( 29 | defs: list[EmbeddingDefinition], 30 | allow_multiple: bool = True, 31 | ) -> list[EmbeddingDefinition] | EmbeddingDefinition: 32 | sorted_defs = natsorted(defs, key=lambda x: (x.dataset, x.model), alg=ns.IGNORECASE) 33 | choices = [Choice(d, f"D: {d.dataset[:15]:18s} M: {d.model}") for d in sorted_defs] 34 | message = "Please select the desired pairs" if allow_multiple else "Please select a pair" 35 | definitions = inq.fuzzy( 36 | message, 37 | choices=choices, 38 | multiselect=allow_multiple, 39 | vi_mode=True, 40 | ).execute() # type: ignore 41 | return definitions 42 | 43 | 44 | def _by_dataset(defs: list[EmbeddingDefinition] | dict[str, list[EmbeddingDefinition]]) -> list[EmbeddingDefinition]: 45 | if isinstance(defs, list): 46 | defs_list = defs 47 | defs = {} 48 | for d in defs_list: 49 | defs.setdefault(d.dataset, []).append(d) 50 | 51 | choices = sorted( 52 | [Choice(v, f"D: {k[:15]:18s} M: {', '.join([d.model for d in v])}") for k, v in defs.items() if len(v)], 53 | key=lambda c: len(c.value), 54 | ) 55 | message = "Please select a dataset" 56 | definitions: list[EmbeddingDefinition] = inq.fuzzy( 57 | message, choices=choices, multiselect=False, vi_mode=True 58 | ).execute() # type: ignore 59 | return definitions 60 | 61 | 62 | def parse_raw_embedding_definitions(raw_embedding_definitions: list[str]) -> list[EmbeddingDefinition]: 63 | if not all([model_dataset.count("/") == 1 for model_dataset in raw_embedding_definitions]): 64 | raise ValueError("All (model, dataset) pairs must be presented as MODEL/DATASET") 65 | model_dataset_pairs = [model_dataset.split("/") for model_dataset in raw_embedding_definitions] 66 | return [ 67 | EmbeddingDefinition(model=model_dataset[0], dataset=model_dataset[1]) for model_dataset in model_dataset_pairs 68 | ] 69 | 70 | 71 | def select_existing_embedding_definitions( 72 | by_dataset: bool = False, 73 | count: int | None = None, 74 | ) -> list[EmbeddingDefinition]: 75 | defs = read_all_cached_embeddings(as_list=True) 76 | 77 | if by_dataset: 78 | # Subset definitions to specific dataset 79 | defs = _by_dataset(defs) 80 | 81 | if count is None: 82 | return _do_embedding_definition_selection(defs) 83 | else: 84 | return [_do_embedding_definition_selection(defs, allow_multiple=False) for _ in range(count)] 85 | 86 | 87 | def select_from_all_embedding_definitions( 88 | include_existing: bool = False, by_dataset: bool = False 89 | ) -> list[EmbeddingDefinition]: 90 | existing = set(read_all_cached_embeddings(as_list=True)) 91 | 92 | models = ModelProvider.list_model_titles() 93 | datasets = DatasetProvider.list_dataset_titles() 94 | 95 | defs = [EmbeddingDefinition(dataset=d, model=m) for d, m in product(datasets, models)] 96 | if not include_existing: 97 | defs = [d for d in defs if d not in existing] 98 | 99 | if by_dataset: 100 | defs = _by_dataset(defs) 101 | 102 | return _do_embedding_definition_selection(defs) 103 | -------------------------------------------------------------------------------- /tti_eval/common/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import EmbeddingDefinition, Embeddings, Split 2 | from .numpy_types import ClassArray, EmbeddingArray, ProbabilityArray, ReductionArray 3 | -------------------------------------------------------------------------------- /tti_eval/common/base.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from enum import Enum 3 | 
from pathlib import Path 4 | from typing import Annotated, Any 5 | 6 | import numpy as np 7 | from pydantic import BaseModel, model_validator 8 | from pydantic.functional_validators import AfterValidator 9 | 10 | from tti_eval.constants import NPZ_KEYS, PROJECT_PATHS 11 | 12 | from .numpy_types import ClassArray, EmbeddingArray 13 | from .string_utils import safe_str 14 | 15 | SafeName = Annotated[str, AfterValidator(safe_str)] 16 | logger = logging.getLogger("multiclips") 17 | 18 | 19 | class Split(str, Enum): 20 | TRAIN = "train" 21 | VALIDATION = "validation" 22 | TEST = "test" 23 | ALL = "all" 24 | 25 | 26 | class Embeddings(BaseModel): 27 | images: EmbeddingArray 28 | classes: EmbeddingArray | None = None 29 | labels: ClassArray 30 | 31 | @model_validator(mode="after") 32 | def validate_shapes(self) -> "Embeddings": 33 | N_images = self.images.shape[0] 34 | N_labels = self.labels.shape[0] 35 | if N_images != N_labels: 36 | raise ValueError(f"Differing number of images: {N_images} and labels: {N_labels}") 37 | 38 | if self.classes is None: 39 | return self 40 | 41 | *_, d1 = self.images.shape 42 | *_, d2 = self.classes.shape 43 | 44 | if d1 != d2: 45 | raise ValueError( 46 | f"Image {self.images.shape} and class {self.classes.shape} embeddings should have same dimensionality" 47 | ) 48 | return self 49 | 50 | @staticmethod 51 | def from_file(path: Path) -> "Embeddings": 52 | if not path.suffix == ".npz": 53 | raise ValueError(f"Embedding files should be `.npz` files not {path.suffix}") 54 | 55 | loaded = np.load(path) 56 | if NPZ_KEYS.IMAGE_EMBEDDINGS not in loaded and NPZ_KEYS.LABELS not in loaded: 57 | raise ValueError(f"Require both {NPZ_KEYS.IMAGE_EMBEDDINGS}, {NPZ_KEYS.LABELS} should be present in {path}") 58 | 59 | image_embeddings: EmbeddingArray = loaded[NPZ_KEYS.IMAGE_EMBEDDINGS] 60 | labels: ClassArray = loaded[NPZ_KEYS.LABELS] 61 | label_embeddings: EmbeddingArray | None = ( 62 | loaded[NPZ_KEYS.CLASS_EMBEDDINGS] if NPZ_KEYS.CLASS_EMBEDDINGS in loaded else None 63 | ) 64 | return Embeddings(images=image_embeddings, labels=labels, classes=label_embeddings) 65 | 66 | def to_file(self, path: Path) -> Path: 67 | if not path.suffix == ".npz": 68 | raise ValueError(f"Embedding files should be `.npz` files not {path.suffix}") 69 | to_store: dict[str, np.ndarray] = { 70 | NPZ_KEYS.IMAGE_EMBEDDINGS: self.images, 71 | NPZ_KEYS.LABELS: self.labels, 72 | } 73 | 74 | if self.classes is not None: 75 | to_store[NPZ_KEYS.CLASS_EMBEDDINGS] = self.classes 76 | 77 | np.savez_compressed( 78 | path, 79 | **to_store, 80 | ) 81 | return path 82 | 83 | class Config: 84 | arbitrary_types_allowed = True 85 | 86 | 87 | class EmbeddingDefinition(BaseModel): 88 | model: SafeName 89 | dataset: SafeName 90 | 91 | def _get_embedding_path(self, split: Split, suffix: str) -> Path: 92 | return Path(self.dataset) / f"{self.model}_{split.value}{suffix}" 93 | 94 | def embedding_path(self, split: Split) -> Path: 95 | return PROJECT_PATHS.EMBEDDINGS / self._get_embedding_path(split, ".npz") 96 | 97 | def get_reduction_path(self, reduction_name: str, split: Split): 98 | return PROJECT_PATHS.REDUCTIONS / self._get_embedding_path(split, f".{reduction_name}.2d.npy") 99 | 100 | def load_embeddings(self, split: Split) -> Embeddings | None: 101 | """ 102 | Load embeddings for embedding configuration or return None 103 | """ 104 | try: 105 | return Embeddings.from_file(self.embedding_path(split)) 106 | except ValueError: 107 | return None 108 | 109 | def save_embeddings(self, embeddings: Embeddings, split: Split, 
overwrite: bool = False) -> bool: 110 | """ 111 | Save embeddings associated to the embedding definition. 112 | Args: 113 | embeddings: The embeddings to store 114 | split: The dataset split that corresponds to the embeddings 115 | overwrite: If false, won't overwrite and will return False 116 | 117 | Returns: 118 | True iff file stored successfully 119 | 120 | """ 121 | if self.embedding_path(split).is_file() and not overwrite: 122 | logger.warning( 123 | f"Not saving embeddings to file `{self.embedding_path(split)}` as overwrite is False and file exists" 124 | ) 125 | return False 126 | self.embedding_path(split).parent.mkdir(exist_ok=True, parents=True) 127 | embeddings.to_file(self.embedding_path(split)) 128 | return True 129 | 130 | def __str__(self): 131 | return self.model + "_" + self.dataset 132 | 133 | def __eq__(self, other: Any) -> bool: 134 | return isinstance(other, EmbeddingDefinition) and self.model == other.model and self.dataset == other.dataset 135 | 136 | def __hash__(self): 137 | return hash((self.model, self.dataset)) 138 | -------------------------------------------------------------------------------- /tti_eval/common/numpy_types.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated, Literal, TypeVar 2 | 3 | import numpy as np 4 | import numpy.typing as npt 5 | 6 | DType = TypeVar("DType", bound=np.generic) 7 | 8 | NArray = Annotated[npt.NDArray[DType], Literal["N"]] 9 | NCArray = Annotated[npt.NDArray[DType], Literal["N", "C"]] 10 | NDArray = Annotated[npt.NDArray[DType], Literal["N", "D"]] 11 | N2Array = Annotated[npt.NDArray[DType], Literal["N", 2]] 12 | 13 | # Classes 14 | ProbabilityArray = NCArray[np.float32] 15 | ClassArray = NArray[np.int32] 16 | """ 17 | Numpy array of shape [N,C] probabilities where N is number of samples and C is number of classes. 18 | """ 19 | 20 | # Embeddings 21 | EmbeddingArray = NDArray[np.float32] 22 | """ 23 | Numpy array of shape [N,D] embeddings where N is number of samples and D is the embedding size. 24 | """ 25 | 26 | # Reductions 27 | ReductionArray = N2Array[np.float32] 28 | """ 29 | Numpy array of shape [N,2] reductions of embeddings. 
30 | """ 31 | -------------------------------------------------------------------------------- /tti_eval/common/string_utils.py: -------------------------------------------------------------------------------- 1 | KEEP_CHARACTERS = {".", "_", " ", "-"} 2 | REPLACE_CHARACTERS = {" ": "_"} 3 | 4 | 5 | def safe_str(unsafe: str) -> str: 6 | if not isinstance(unsafe, str): 7 | raise ValueError(f"{unsafe} ({type(unsafe)}) not a string") 8 | return "".join(REPLACE_CHARACTERS.get(c, c) for c in unsafe if c.isalnum() or c in KEEP_CHARACTERS).rstrip() 9 | -------------------------------------------------------------------------------- /tti_eval/compute.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | 3 | from tti_eval.common import EmbeddingDefinition, Embeddings, Split 4 | from tti_eval.dataset import Dataset, DatasetProvider 5 | from tti_eval.model import Model, ModelProvider 6 | 7 | 8 | def compute_embeddings(model: Model, dataset: Dataset, batch_size: int = 50) -> Embeddings: 9 | dataset.set_transform(model.get_transform()) 10 | dataloader = DataLoader(dataset, collate_fn=model.get_collate_fn(), batch_size=batch_size) 11 | 12 | image_embeddings, class_embeddings, labels = model.build_embedding(dataloader) 13 | embeddings = Embeddings(images=image_embeddings, classes=class_embeddings, labels=labels) 14 | return embeddings 15 | 16 | 17 | def compute_embeddings_from_definition(definition: EmbeddingDefinition, split: Split) -> Embeddings: 18 | model = ModelProvider.get_model(definition.model) 19 | dataset = DatasetProvider.get_dataset(definition.dataset, split) 20 | return compute_embeddings(model, dataset) 21 | -------------------------------------------------------------------------------- /tti_eval/constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | from dotenv import load_dotenv 5 | 6 | load_dotenv() 7 | 8 | # If the cache directory is not explicitly specified, use the `.cache` directory located in the project's root. 
9 | _TTI_EVAL_ROOT_DIR = Path(__file__).parent.parent 10 | CACHE_PATH = Path(os.environ.get("TTI_EVAL_CACHE_PATH", _TTI_EVAL_ROOT_DIR / ".cache")) 11 | _OUTPUT_PATH = Path(os.environ.get("TTI_EVAL_OUTPUT_PATH", _TTI_EVAL_ROOT_DIR / "output")) 12 | _SOURCES_PATH = _TTI_EVAL_ROOT_DIR / "sources" 13 | 14 | 15 | class PROJECT_PATHS: 16 | EMBEDDINGS = CACHE_PATH / "embeddings" 17 | MODELS = CACHE_PATH / "models" 18 | REDUCTIONS = CACHE_PATH / "reductions" 19 | 20 | 21 | class NPZ_KEYS: 22 | IMAGE_EMBEDDINGS = "image_embeddings" 23 | CLASS_EMBEDDINGS = "class_embeddings" 24 | LABELS = "labels" 25 | 26 | 27 | class OUTPUT_PATH: 28 | ANIMATIONS = _OUTPUT_PATH / "animations" 29 | EVALUATIONS = _OUTPUT_PATH / "evaluations" 30 | 31 | 32 | class SOURCES_PATH: 33 | DATASET_INSTANCE_DEFINITIONS = _SOURCES_PATH / "datasets" 34 | MODEL_INSTANCE_DEFINITIONS = _SOURCES_PATH / "models" 35 | -------------------------------------------------------------------------------- /tti_eval/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Dataset, Split 2 | from .provider import DatasetProvider 3 | -------------------------------------------------------------------------------- /tti_eval/dataset/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from os.path import relpath 3 | from pathlib import Path 4 | 5 | from pydantic import BaseModel, ConfigDict 6 | from torch.utils.data import Dataset as TorchDataset 7 | 8 | from tti_eval.common import Split 9 | from tti_eval.constants import CACHE_PATH, SOURCES_PATH 10 | 11 | DEFAULT_DATASET_TYPES_LOCATION = ( 12 | Path(relpath(str(__file__), SOURCES_PATH.DATASET_INSTANCE_DEFINITIONS)).parent / "types" / "__init__.py" 13 | ) 14 | 15 | 16 | class DatasetDefinitionSpec(BaseModel): 17 | dataset_type: str 18 | title: str 19 | module_path: Path = DEFAULT_DATASET_TYPES_LOCATION 20 | split: Split | None = None 21 | title_in_source: str | None = None 22 | cache_dir: Path | None = None 23 | 24 | # Allow additional dataset configuration fields 25 | model_config = ConfigDict(extra="allow") 26 | 27 | 28 | class Dataset(TorchDataset, ABC): 29 | def __init__( 30 | self, 31 | title: str, 32 | *, 33 | split: Split = Split.ALL, 34 | title_in_source: str | None = None, 35 | transform=None, 36 | cache_dir: str | None = None, 37 | **kwargs, 38 | ): 39 | self.transform = transform 40 | self.__title = title 41 | self.__title_in_source = title if title_in_source is None else title_in_source 42 | self.__split = split 43 | self.__class_names = [] 44 | if cache_dir is None: 45 | cache_dir = CACHE_PATH 46 | self._cache_dir: Path = Path(cache_dir).expanduser().resolve() / "datasets" / title 47 | 48 | @abstractmethod 49 | def __getitem__(self, idx): 50 | pass 51 | 52 | @abstractmethod 53 | def __len__(self): 54 | pass 55 | 56 | @property 57 | def split(self) -> Split: 58 | return self.__split 59 | 60 | @property 61 | def title(self) -> str: 62 | return self.__title 63 | 64 | @property 65 | def title_in_source(self) -> str: 66 | return self.__title_in_source 67 | 68 | @property 69 | def class_names(self) -> list[str]: 70 | return self.__class_names 71 | 72 | @class_names.setter 73 | def class_names(self, class_names: list[str]) -> None: 74 | self.__class_names = class_names 75 | 76 | def set_transform(self, transform): 77 | self.transform = transform 78 | 79 | @property 80 | def text_queries(self) -> list[str]: 81 | return [f"An image of a 
{class_name}" for class_name in self.class_names] 82 | 83 | @abstractmethod 84 | def _setup(self, **kwargs): 85 | pass 86 | -------------------------------------------------------------------------------- /tti_eval/dataset/provider.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from pathlib import Path 3 | from typing import Any 4 | 5 | from natsort import natsorted, ns 6 | 7 | from tti_eval.constants import CACHE_PATH, SOURCES_PATH 8 | 9 | from .base import Dataset, DatasetDefinitionSpec, Split 10 | from .utils import load_class_from_path 11 | 12 | 13 | class DatasetProvider: 14 | __instance = None 15 | __global_settings: dict[str, Any] = dict() 16 | __known_dataset_types: dict[tuple[Path, str], Any] = dict() 17 | 18 | def __init__(self): 19 | self._datasets = {} 20 | 21 | @classmethod 22 | def prepare(cls): 23 | if cls.__instance is None: 24 | cls.__instance = cls() 25 | cls.register_datasets_from_sources_dir(SOURCES_PATH.DATASET_INSTANCE_DEFINITIONS) 26 | # Global settings 27 | cls.__instance.add_global_setting("cache_dir", CACHE_PATH) 28 | return cls.__instance 29 | 30 | @classmethod 31 | def global_settings(cls) -> dict: 32 | return deepcopy(cls.__global_settings) 33 | 34 | @classmethod 35 | def add_global_setting(cls, name: str, value: Any) -> None: 36 | cls.__global_settings[name] = value 37 | 38 | @classmethod 39 | def remove_global_setting(cls, name: str) -> None: 40 | cls.__global_settings.pop(name, None) 41 | 42 | @classmethod 43 | def register_dataset(cls, source: type[Dataset], title: str, split: Split | None = None, **kwargs) -> None: 44 | instance = cls.prepare() 45 | if split is None: 46 | # One dataset with all the split definitions 47 | instance._datasets[title] = (source, kwargs) 48 | else: 49 | # One dataset is defined per split 50 | kwargs.update(split=split) 51 | instance._datasets[(title, split)] = (source, kwargs) 52 | 53 | @classmethod 54 | def register_dataset_from_json_definition(cls, json_definition: Path) -> None: 55 | spec = DatasetDefinitionSpec.model_validate_json(json_definition.read_text(encoding="utf-8")) 56 | if not spec.module_path.is_absolute(): # Handle relative module paths 57 | spec.module_path = (json_definition.parent / spec.module_path).resolve() 58 | if not spec.module_path.is_file(): 59 | raise FileNotFoundError( 60 | f"Could not find the specified module path at `{spec.module_path.as_posix()}` " 61 | f"when registering dataset `{spec.title}`" 62 | ) 63 | 64 | # Fetch the class of the dataset type stated in the definition 65 | dataset_type = cls.__known_dataset_types.get((spec.module_path, spec.dataset_type)) 66 | if dataset_type is None: 67 | dataset_type = load_class_from_path(spec.module_path.as_posix(), spec.dataset_type) 68 | if not issubclass(dataset_type, Dataset): 69 | raise ValueError( 70 | f"Dataset type specified in the JSON definition file `{json_definition.as_posix()}` " 71 | f"does not inherit from the base class `Dataset`" 72 | ) 73 | cls.__known_dataset_types[(spec.module_path, spec.dataset_type)] = dataset_type 74 | cls.register_dataset(dataset_type, **spec.model_dump(exclude={"module_path", "dataset_type"})) 75 | 76 | @classmethod 77 | def register_datasets_from_sources_dir(cls, source_dir: Path) -> None: 78 | for f in source_dir.glob("*.json"): 79 | cls.register_dataset_from_json_definition(f) 80 | 81 | @classmethod 82 | def get_dataset(cls, title: str, split: Split) -> Dataset: 83 | instance = cls.prepare() 84 | if (title, split) in instance._datasets: 
85 | # The split corresponds to fetching a whole dataset (one-to-one relationship) 86 | dict_key = (title, split) 87 | split = Split.ALL # Ensure to read the whole dataset 88 | elif title in instance._datasets: 89 | # The dataset internally knows how to determine the split 90 | dict_key = title 91 | else: 92 | raise ValueError(f"Unrecognized dataset: {title}") 93 | 94 | source, kwargs = instance._datasets[dict_key] 95 | # Apply global settings. Values of local settings take priority when local and global settings share keys. 96 | kwargs_with_global_settings = cls.global_settings() | kwargs 97 | return source(title, split=split, **kwargs_with_global_settings) 98 | 99 | @classmethod 100 | def list_dataset_titles(cls) -> list[str]: 101 | dataset_titles = [ 102 | dict_key[0] if isinstance(dict_key, tuple) else dict_key for dict_key in cls.prepare()._datasets.keys() 103 | ] 104 | return natsorted(set(dataset_titles), alg=ns.IGNORECASE) 105 | -------------------------------------------------------------------------------- /tti_eval/dataset/types/__init__.py: -------------------------------------------------------------------------------- 1 | from tti_eval.dataset.types.encord_ds import EncordDataset 2 | from tti_eval.dataset.types.hugging_face import HFDataset 3 | -------------------------------------------------------------------------------- /tti_eval/dataset/types/encord_ds.py: -------------------------------------------------------------------------------- 1 | import json 2 | import multiprocessing 3 | import os 4 | from collections.abc import Callable 5 | from concurrent.futures import ThreadPoolExecutor 6 | from dataclasses import dataclass 7 | from functools import partial 8 | from pathlib import Path 9 | from typing import Any 10 | 11 | from encord import EncordUserClient, Project 12 | from encord.common.constants import DATETIME_STRING_FORMAT 13 | from encord.objects import Classification, LabelRowV2 14 | from encord.objects.common import PropertyType 15 | from encord.orm.dataset import DataType, Image, Video 16 | from PIL import Image as PILImage 17 | from tqdm.auto import tqdm 18 | 19 | from tti_eval.dataset import Dataset, Split 20 | from tti_eval.dataset.utils import collect_async, download_file, simple_random_split 21 | 22 | 23 | class EncordDataset(Dataset): 24 | def __init__( 25 | self, 26 | title: str, 27 | project_hash: str, 28 | classification_hash: str, 29 | *, 30 | split: Split = Split.ALL, 31 | title_in_source: str | None = None, 32 | transform=None, 33 | cache_dir: str | None = None, 34 | ssh_key_path: str | None = None, 35 | **kwargs, 36 | ): 37 | super().__init__( 38 | title, 39 | split=split, 40 | title_in_source=title_in_source, 41 | transform=transform, 42 | cache_dir=cache_dir, 43 | ) 44 | self._setup(project_hash, classification_hash, ssh_key_path, **kwargs) 45 | 46 | def __getitem__(self, idx): 47 | frame_path = self._dataset_indices_info[idx].image_file 48 | img = PILImage.open(frame_path) 49 | label = self._dataset_indices_info[idx].label 50 | 51 | if self.transform is not None: 52 | _d = self.transform(dict(image=[img], label=[label])) 53 | res_item = dict(image=_d["image"][0], label=_d["label"][0]) 54 | else: 55 | res_item = dict(image=img, label=label) 56 | return res_item 57 | 58 | def __len__(self): 59 | return len(self._dataset_indices_info) 60 | 61 | def _get_frame_file(self, label_row: LabelRowV2, frame: int) -> Path: 62 | return get_frame_file( 63 | data_dir=self._cache_dir, 64 | project_hash=self._project.project_hash, 65 | label_row=label_row, 66 | 
frame=frame, 67 | ) 68 | 69 | def _get_label_row_annotations_file(self, label_row: LabelRowV2) -> Path: 70 | return get_label_row_annotations_file( 71 | data_dir=self._cache_dir, 72 | project_hash=self._project.project_hash, 73 | label_row_hash=label_row.label_hash, 74 | ) 75 | 76 | def _ensure_answers_availability(self) -> dict: 77 | lrs_info_file = get_label_rows_info_file(self._cache_dir, self._project.project_hash) 78 | label_rows_info: dict = json.loads(lrs_info_file.read_text(encoding="utf-8")) 79 | should_update_info = False 80 | class_name_to_idx = {name: idx for idx, name in enumerate(self.class_names)} # Fast lookup of class indices 81 | for label_row in self._label_rows: 82 | if "answers" not in label_rows_info[label_row.label_hash]: 83 | if not label_row.is_labelling_initialised: 84 | # Retrieve label row content from local storage 85 | anns_path = self._get_label_row_annotations_file(label_row) 86 | label_row.from_labels_dict(json.loads(anns_path.read_text(encoding="utf-8"))) 87 | 88 | answers = dict() 89 | for frame_view in label_row.get_frame_views(): 90 | clf_instances = frame_view.get_classification_instances(self._classification) 91 | # Skip frames where the input classification is missing 92 | if len(clf_instances) == 0: 93 | continue 94 | 95 | clf_instance = clf_instances[0] 96 | clf_answer = clf_instance.get_answer(self._attribute) 97 | # Skip frames where the input classification has no answer (probable annotation error) 98 | if clf_answer is None: 99 | continue 100 | 101 | answers[frame_view.frame] = { 102 | "image_file": self._get_frame_file(label_row, frame_view.frame).as_posix(), 103 | "label": class_name_to_idx[clf_answer.title], 104 | } 105 | label_rows_info[label_row.label_hash]["answers"] = answers 106 | should_update_info = True 107 | if should_update_info: 108 | lrs_info_file.write_text(json.dumps(label_rows_info), encoding="utf-8") 109 | return label_rows_info 110 | 111 | def _setup( 112 | self, 113 | project_hash: str, 114 | classification_hash: str, 115 | ssh_key_path: str | None = None, 116 | **kwargs, 117 | ): 118 | ssh_key_path = ssh_key_path or os.getenv("ENCORD_SSH_KEY_PATH") 119 | if ssh_key_path is None: 120 | raise ValueError( 121 | "The `ssh_key_path` parameter and the `ENCORD_SSH_KEY_PATH` environment variable are both missing. " 122 | "Please set one of them to proceed." 
123 | ) 124 | client = EncordUserClient.create_with_ssh_private_key(ssh_private_key_path=ssh_key_path) 125 | self._project = client.get_project(project_hash) 126 | 127 | self._classification = self._project.ontology_structure.get_child_by_hash( 128 | classification_hash, type_=Classification 129 | ) 130 | radio_attribute = self._classification.attributes[0] 131 | if radio_attribute.get_property_type() != PropertyType.RADIO: 132 | raise ValueError("Expected a classification hash with an attribute of type `Radio`") 133 | self._attribute = radio_attribute 134 | self.class_names = [o.title for o in self._attribute.options] 135 | 136 | # Fetch the label rows of the selected split 137 | splits_file = self._cache_dir / "splits.json" 138 | split_to_lr_hashes: dict[str, list[str]] 139 | if splits_file.exists(): 140 | split_to_lr_hashes = json.loads(splits_file.read_text(encoding="utf-8")) 141 | else: 142 | split_to_lr_hashes = simple_project_split(self._project) 143 | splits_file.parent.mkdir(parents=True, exist_ok=True) 144 | splits_file.write_text(json.dumps(split_to_lr_hashes), encoding="utf-8") 145 | self._label_rows = self._project.list_label_rows_v2(label_hashes=split_to_lr_hashes[self.split]) 146 | 147 | # Get data from source. Users may supply the `overwrite_annotations` keyword in the init to download everything 148 | download_data_from_project( 149 | self._project, 150 | self._cache_dir, 151 | self._label_rows, 152 | tqdm_desc=f"Downloading {self.split} data from Encord project `{self._project.title}`", 153 | **kwargs, 154 | ) 155 | 156 | # Prepare data for the __getitem__ method 157 | self._dataset_indices_info: list[EncordDataset.DatasetIndexInfo] = [] 158 | label_rows_info = self._ensure_answers_availability() 159 | for label_row in self._label_rows: 160 | answers: dict[int, Any] = label_rows_info[label_row.label_hash]["answers"] 161 | for frame_num in sorted(answers.keys()): 162 | self._dataset_indices_info.append(EncordDataset.DatasetIndexInfo(**answers[frame_num])) 163 | 164 | @dataclass 165 | class DatasetIndexInfo: 166 | image_file: Path | str 167 | label: int 168 | 169 | 170 | # ----------------------------------------------------------------------- 171 | # UTILITY FUNCTIONS 172 | # ----------------------------------------------------------------------- 173 | 174 | 175 | def _download_image(image_data: Image | Video, destination_dir: Path) -> Path: 176 | # TODO The type of `image_data` is also Video because of a SDK bug explained in `_download_label_row_image_data`. 
177 | file_name = get_frame_name(image_data.image_hash, image_data.title) 178 | destination_path = destination_dir / file_name 179 | if not destination_path.exists(): 180 | download_file(image_data.file_link, destination_path) 181 | return destination_path 182 | 183 | 184 | def _download_label_row_image_data(data_dir: Path, project: Project, label_row: LabelRowV2) -> list[Path]: 185 | label_row.initialise_labels() 186 | label_row_dir = get_label_row_dir(data_dir, project.project_hash, label_row.label_hash) 187 | label_row_dir.mkdir(parents=True, exist_ok=True) 188 | 189 | if label_row.data_type == DataType.IMAGE: 190 | # TODO This `if` is here because of a SDK bug, remove it when IMAGE data is stored in the proper image field [1] 191 | images_data = [project.get_data(label_row.data_hash, get_signed_url=True)[0]] 192 | # Missing field caused by the SDK bug 193 | images_data[0]["image_hash"] = label_row.data_hash 194 | else: 195 | images_data = project.get_data(label_row.data_hash, get_signed_url=True)[1] 196 | return collect_async( 197 | lambda image_data: _download_image(image_data, label_row_dir), 198 | images_data, 199 | max_workers=4, 200 | disable=True, 201 | ) 202 | 203 | 204 | def _download_label_row( 205 | label_row: LabelRowV2, 206 | project: Project, 207 | data_dir: Path, 208 | overwrite_annotations: bool, 209 | label_rows_info: dict[str, Any], 210 | update_pbar: Callable[[], Any], 211 | ): 212 | if label_row.data_type not in {DataType.IMAGE, DataType.IMG_GROUP}: 213 | return 214 | save_annotations = False 215 | # Trigger the images download if the label hash is not found or is None (never downloaded). 216 | if label_row.label_hash not in label_rows_info.keys(): 217 | _download_label_row_image_data(data_dir, project, label_row) 218 | save_annotations = True 219 | # Overwrite annotations only if `last_edited_at` values differ between the existing and new annotations. 
220 | elif ( 221 | overwrite_annotations 222 | and label_row.last_edited_at.strftime(DATETIME_STRING_FORMAT) 223 | != label_rows_info[label_row.label_hash]["last_edited_at"] 224 | ): 225 | label_row.initialise_labels() 226 | save_annotations = True 227 | 228 | if save_annotations: 229 | annotations_file = get_label_row_annotations_file(data_dir, project.project_hash, label_row.label_hash) 230 | annotations_file.write_text(json.dumps(label_row.to_encord_dict()), encoding="utf-8") 231 | label_rows_info[label_row.label_hash] = {"last_edited_at": label_row.last_edited_at} 232 | update_pbar() 233 | 234 | 235 | def _download_label_rows( 236 | project: Project, 237 | data_dir: Path, 238 | label_rows: list[LabelRowV2], 239 | overwrite_annotations: bool, 240 | label_rows_info: dict[str, Any], 241 | tqdm_desc: str | None, 242 | ): 243 | if tqdm_desc is None: 244 | tqdm_desc = f"Downloading data from Encord project `{project.title}`" 245 | 246 | pbar = tqdm(total=len(label_rows), desc=tqdm_desc) 247 | _do_download = partial( 248 | _download_label_row, 249 | project=project, 250 | data_dir=data_dir, 251 | overwrite_annotations=overwrite_annotations, 252 | label_rows_info=label_rows_info, 253 | update_pbar=lambda: pbar.update(1), 254 | ) 255 | 256 | with ThreadPoolExecutor(min(multiprocessing.cpu_count(), 24)) as exe: 257 | exe.map(_do_download, label_rows) 258 | 259 | 260 | def download_data_from_project( 261 | project: Project, 262 | data_dir: Path, 263 | label_rows: list[LabelRowV2] | None = None, 264 | overwrite_annotations: bool = False, 265 | tqdm_desc: str | None = None, 266 | ) -> None: 267 | """ 268 | Iterates through the images of the project and downloads their content, adhering to the following file structure: 269 | data_dir/ 270 | ├── project-hash/ 271 | │ ├── image-group-hash1/ 272 | │ │ ├── image1.jpg 273 | │ │ ├── image2.jpg 274 | │ │ └── ... 275 | │ │ └── annotations.json 276 | │ └── image-hash1/ 277 | │ │ ├── image1.jpeg 278 | │ │ ├── annotations.json 279 | │ └── ... 280 | └── ... 281 | :param project: The project containing the images with their annotations. 282 | :param data_dir: The directory where the project data will be downloaded. 283 | :param label_rows: The label rows that will be downloaded. If None, all label rows will be downloaded. 284 | :param overwrite_annotations: Flag that indicates whether to overwrite existing annotations if they exist. 285 | :param tqdm_desc: Optional description for tqdm progress bar. 
286 | Defaults to 'Downloading data from Encord project `{project.title}`' 287 | """ 288 | # Read file that tracks the downloaded data progress 289 | lrs_info_file = get_label_rows_info_file(data_dir, project.project_hash) 290 | lrs_info_file.parent.mkdir(parents=True, exist_ok=True) 291 | label_rows_info = json.loads(lrs_info_file.read_text(encoding="utf-8")) if lrs_info_file.is_file() else dict() 292 | 293 | # Retrieve only the unseen data if there is no explicit annotation update 294 | filtered_label_rows = ( 295 | label_rows 296 | if overwrite_annotations 297 | else [lr for lr in label_rows if lr.label_hash not in label_rows_info.keys()] 298 | ) 299 | if len(filtered_label_rows) == 0: 300 | return 301 | 302 | try: 303 | _download_label_rows( 304 | project, 305 | data_dir, 306 | filtered_label_rows, 307 | overwrite_annotations, 308 | label_rows_info, 309 | tqdm_desc=tqdm_desc, 310 | ) 311 | finally: 312 | # Save the current download progress in case of failure 313 | lrs_info_file.write_text(json.dumps(label_rows_info), encoding="utf-8") 314 | 315 | 316 | def get_frame_name(frame_hash: str, frame_title: str) -> str: 317 | file_extension = frame_title.rsplit(sep=".", maxsplit=1)[-1] 318 | return f"{frame_hash}.{file_extension}" 319 | 320 | 321 | def get_frame_file(data_dir: Path, project_hash: str, label_row: LabelRowV2, frame: int) -> Path: 322 | label_row_dir = get_label_row_dir(data_dir, project_hash, label_row.label_hash) 323 | frame_view = label_row.get_frame_view(frame) 324 | return label_row_dir / get_frame_name(frame_view.image_hash, frame_view.image_title) 325 | 326 | 327 | def get_frame_file_raw( 328 | data_dir: Path, 329 | project_hash: str, 330 | label_row_hash: str, 331 | frame_hash: str, 332 | frame_title: str, 333 | ) -> Path: 334 | return get_label_row_dir(data_dir, project_hash, label_row_hash) / get_frame_name(frame_hash, frame_title) 335 | 336 | 337 | def get_label_row_annotations_file(data_dir: Path, project_hash: str, label_row_hash: str) -> Path: 338 | return get_label_row_dir(data_dir, project_hash, label_row_hash) / "annotations.json" 339 | 340 | 341 | def get_label_row_dir(data_dir: Path, project_hash: str, label_row_hash: str) -> Path: 342 | return data_dir / project_hash / label_row_hash 343 | 344 | 345 | def get_label_rows_info_file(data_dir: Path, project_hash: str) -> Path: 346 | return data_dir / project_hash / "label_rows_info.json" 347 | 348 | 349 | def simple_project_split( 350 | project: Project, 351 | seed: int = 42, 352 | train_split: float = 0.7, 353 | validation_split: float = 0.15, 354 | ) -> dict[Split, list[str]]: 355 | """ 356 | Split the label rows of a project into training, validation, and test sets using simple random splitting. 357 | 358 | :param project: The project containing the label rows to split. 359 | :param seed: Random seed for reproducibility. Defaults to 42. 360 | :param train_split: Percentage of the dataset to allocate to the training set. Defaults to 0.7. 361 | :param validation_split: Percentage of the dataset to allocate to the validation set. Defaults to 0.15. 362 | :return: A dictionary containing lists with the label hashes of the data represented in the training, 363 | validation, and test sets. 364 | 365 | :raises ValueError: If the sum of `train_split` and `validation_split` is greater than 1, 366 | or if `train_split` or `validation_split` are less than 0. 
367 | """ 368 | label_rows = project.list_label_rows_v2() 369 | split_to_indices = simple_random_split(len(label_rows), seed, train_split, validation_split) 370 | enforce_label_rows_initialization(label_rows) # Ensure that all label rows have a label hash 371 | return {split: [label_rows[i].label_hash for i in indices] for split, indices in split_to_indices.items()} 372 | 373 | 374 | def enforce_label_rows_initialization(label_rows: list[LabelRowV2]): 375 | for lr in label_rows: 376 | if lr.label_hash is None: 377 | lr.initialise_labels() 378 | -------------------------------------------------------------------------------- /tti_eval/dataset/types/hugging_face.py: -------------------------------------------------------------------------------- 1 | from datasets import ClassLabel, DatasetDict, Sequence, Value, load_dataset 2 | from datasets import Dataset as _RemoteHFDataset 3 | 4 | from tti_eval.dataset import Dataset, Split 5 | 6 | 7 | class HFDataset(Dataset): 8 | def __init__( 9 | self, 10 | title: str, 11 | *, 12 | split: Split = Split.ALL, 13 | title_in_source: str | None = None, 14 | transform=None, 15 | cache_dir: str | None = None, 16 | target_feature: str = "label", 17 | **kwargs, 18 | ): 19 | super().__init__(title, split=split, title_in_source=title_in_source, transform=transform, cache_dir=cache_dir) 20 | self._target_feature = target_feature 21 | self._setup(**kwargs) 22 | 23 | def __getitem__(self, idx): 24 | return self._dataset[idx] 25 | 26 | def __len__(self): 27 | return len(self._dataset) 28 | 29 | def set_transform(self, transform): 30 | super().set_transform(transform) 31 | self._dataset.set_transform(transform) 32 | 33 | @staticmethod 34 | def _get_available_splits(dataset_dict: DatasetDict) -> list[Split]: 35 | return [Split(s) for s in dataset_dict.keys() if s in [_ for _ in Split]] + [Split.ALL] 36 | 37 | def _get_hf_dataset_split(self, **kwargs) -> _RemoteHFDataset: 38 | try: 39 | if self.split == Split.ALL: # Retrieve all the dataset data if the split is ALL 40 | return load_dataset(self.title_in_source, split="all", cache_dir=self._cache_dir.as_posix(), **kwargs) 41 | dataset_dict: DatasetDict = load_dataset( 42 | self.title_in_source, 43 | cache_dir=self._cache_dir.as_posix(), 44 | **kwargs, 45 | ) 46 | except Exception as e: 47 | raise ValueError(f"Failed to load dataset from Hugging Face: `{self.title_in_source}`") from e 48 | 49 | available_splits = HFDataset._get_available_splits(dataset_dict) 50 | missing_splits = [s for s in Split if s not in available_splits] 51 | 52 | # Return target dataset if it already exists and won't be modified 53 | if self.split in [Split.VALIDATION, Split.TEST] and self.split in available_splits: 54 | return dataset_dict[self.split] 55 | if self.split == Split.TRAIN: 56 | if self.split in missing_splits: 57 | # Train split must always exist 58 | raise AttributeError(f"Missing train split in Hugging Face dataset: `{self.title_in_source}`") 59 | if not missing_splits: 60 | # No need to split the train dataset, can be returned as a whole 61 | return dataset_dict[self.split] 62 | 63 | # Get a 15% of the train dataset for each missing split (VALIDATION, TEST or both) 64 | # This operation includes data shuffling to prevent splits with skewed class counts because of the input order 65 | split_percent = 0.15 * len(missing_splits) 66 | split_seed = 42 67 | # Split the original train dataset into two, the final train dataset and the missing splits dataset 68 | split_to_dataset = 
dataset_dict["train"].train_test_split(test_size=split_percent, seed=split_seed) 69 | if self.split == Split.TRAIN: 70 | return split_to_dataset[self.split] 71 | 72 | if len(missing_splits) == 1: 73 | # One missing split (either VALIDATION or TEST), so we return the 15% stored in "test" 74 | return split_to_dataset["test"] 75 | else: 76 | # Both VALIDATION and TEST splits are missing 77 | # Each one will take a half of the 30% stored in "test" 78 | if self.split == Split.VALIDATION: 79 | return split_to_dataset["test"].train_test_split(test_size=0.5, seed=split_seed)["train"] 80 | else: 81 | return split_to_dataset["test"].train_test_split(test_size=0.5, seed=split_seed)["test"] 82 | 83 | def _setup(self, **kwargs): 84 | self._dataset = self._get_hf_dataset_split(**kwargs) 85 | 86 | if self._target_feature not in self._dataset.features: 87 | raise ValueError(f"The dataset `{self.title}` does not have the target feature `{self._target_feature}`") 88 | # Encode the target feature if necessary (e.g. use integer instead of string in the label values) 89 | if isinstance(self._dataset.features[self._target_feature], Value): 90 | self._dataset = self._dataset.class_encode_column(self._target_feature) 91 | 92 | # Rename the target feature to `label` 93 | # TODO do not rename the target feature but use it in the embeddings computations instead of the `label` tag 94 | if self._target_feature != "label": 95 | self._dataset = self._dataset.rename_column(self._target_feature, "label") 96 | 97 | label_feature = self._dataset.features["label"] 98 | if isinstance(label_feature, Sequence): # Drop potential wrapper 99 | label_feature = label_feature.feature 100 | 101 | if isinstance(label_feature, ClassLabel): 102 | self.class_names = label_feature.names 103 | else: 104 | raise TypeError(f"Expected target feature of type `ClassLabel`, found `{type(label_feature).__name__}`") 105 | 106 | def __del__(self): 107 | for f in self._cache_dir.glob("**/*.lock"): 108 | f.unlink(missing_ok=True) 109 | -------------------------------------------------------------------------------- /tti_eval/dataset/utils.py: -------------------------------------------------------------------------------- 1 | import importlib.util 2 | import os 3 | from collections.abc import Callable 4 | from concurrent.futures import ThreadPoolExecutor as Executor 5 | from concurrent.futures import as_completed 6 | from pathlib import Path 7 | from typing import Any, TypeVar 8 | 9 | import numpy as np 10 | import requests 11 | from tqdm.auto import tqdm 12 | 13 | from tti_eval.common import Split 14 | 15 | T = TypeVar("T") 16 | G = TypeVar("G") 17 | 18 | _default_max_workers = min(10, (os.cpu_count() or 1) + 4) 19 | 20 | 21 | def collect_async( 22 | fn: Callable[[T], G], 23 | job_args: list[T], 24 | max_workers=_default_max_workers, 25 | **kwargs, 26 | ) -> list[G]: 27 | """ 28 | Distribute work across multiple workers. Good for, e.g., downloading data. 29 | Will return results in dictionary. 30 | :param fn: The function to be applied 31 | :param job_args: Arguments to `fn`. 32 | :param max_workers: Number of workers to distribute work over. 33 | :param kwargs: Arguments passed on to tqdm. 
34 | :return: List [fn(*job_args)] 35 | """ 36 | if len(job_args) == 0: 37 | tmp: list[G] = [] 38 | return tmp 39 | if not isinstance(job_args[0], tuple): 40 | _job_args: list[tuple[Any]] = [(j,) for j in job_args] 41 | else: 42 | _job_args = job_args # type: ignore 43 | 44 | results: list[G] = [] 45 | with tqdm(total=len(job_args), **kwargs) as pbar: 46 | with Executor(max_workers=max_workers) as exe: 47 | jobs = [exe.submit(fn, *args) for args in _job_args] 48 | for job in as_completed(jobs): 49 | result = job.result() 50 | if result is not None: 51 | results.append(result) 52 | pbar.update(1) 53 | return results 54 | 55 | 56 | def download_file( 57 | url: str, 58 | destination: Path, 59 | ) -> None: 60 | destination.parent.mkdir(parents=True, exist_ok=True) 61 | with open(destination, "wb") as f: 62 | r = requests.get(url, stream=True) 63 | if r.status_code != 200: 64 | raise ConnectionError(f"Something happened, couldn't download file from: {url}") 65 | 66 | for chunk in r.iter_content(chunk_size=1024): 67 | if chunk: 68 | f.write(chunk) 69 | f.flush() 70 | 71 | 72 | def load_class_from_path(module_path: str, class_name: str): 73 | spec = importlib.util.spec_from_file_location(module_path, module_path) 74 | module = importlib.util.module_from_spec(spec) 75 | spec.loader.exec_module(module) 76 | return getattr(module, class_name) 77 | 78 | 79 | def simple_random_split( 80 | dataset_size: int, 81 | seed: int = 42, 82 | train_split: float = 0.7, 83 | validation_split: float = 0.15, 84 | ) -> dict[Split, np.ndarray]: 85 | """ 86 | Split the dataset into training, validation, and test sets using simple random splitting. 87 | 88 | :param dataset_size: The total size of the dataset. 89 | :param seed: Random seed for reproducibility. Defaults to 42. 90 | :param train_split: Percentage of the dataset to allocate to the training set. Defaults to 0.7. 91 | :param validation_split: Percentage of the dataset to allocate to the validation set. Defaults to 0.15. 92 | :return: A dictionary containing arrays with the indices of the data represented in the training, 93 | validation, and test sets. 94 | 95 | :raises ValueError: If the sum of `train_split` and `validation_split` is greater than 1, 96 | or if `train_split` or `validation_split` are less than 0. 
97 | """ 98 | if dataset_size < 3: 99 | raise ValueError(f"Expected a dataset with size at least 3, got {dataset_size}") 100 | 101 | if train_split < 0 or validation_split < 0: 102 | raise ValueError(f"Expected positive splits, got ({train_split=}, {validation_split=})") 103 | if train_split + validation_split >= 1: 104 | raise ValueError( 105 | f"Expected `train_split` and `validation_split` sum between 0 and 1, got {train_split + validation_split}" 106 | ) 107 | rng = np.random.default_rng(seed) 108 | selection = rng.permutation(dataset_size) 109 | train_size = max(1, int(dataset_size * train_split)) 110 | validation_size = max(1, int(dataset_size * validation_split)) 111 | # Ensure that the TEST split has at least an element 112 | if train_size + validation_size == dataset_size: 113 | if train_size > 1: 114 | train_size -= 1 115 | if validation_size > 1: 116 | validation_size -= 1 117 | return { 118 | Split.TRAIN: selection[:train_size], 119 | Split.VALIDATION: selection[train_size : train_size + validation_size], 120 | Split.TEST: selection[train_size + validation_size :], 121 | Split.ALL: selection, 122 | } 123 | -------------------------------------------------------------------------------- /tti_eval/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import ClassificationModel, EvaluationModel 2 | from .image_retrieval import I2IRetrievalEvaluator 3 | from .knn import WeightedKNNClassifier 4 | from .linear_probe import LinearProbeClassifier 5 | from .zero_shot import ZeroShotClassifier 6 | -------------------------------------------------------------------------------- /tti_eval/evaluation/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any 3 | 4 | from tti_eval.common import ClassArray, Embeddings, ProbabilityArray 5 | 6 | from .utils import normalize 7 | 8 | 9 | class EvaluationModel(ABC): 10 | def __init__( 11 | self, 12 | train_embeddings: Embeddings, 13 | validation_embeddings: Embeddings, 14 | num_classes: int | None = None, 15 | **kwargs, 16 | ) -> None: 17 | super().__init__(**kwargs) 18 | # Preprocessing the embeddings 19 | train_embeddings.images = normalize(train_embeddings.images) 20 | validation_embeddings.images = normalize(validation_embeddings.images) 21 | if train_embeddings.classes is not None: 22 | train_embeddings.classes = normalize(train_embeddings.classes) 23 | if validation_embeddings.classes is not None: 24 | validation_embeddings.classes = normalize(validation_embeddings.classes) 25 | 26 | self._train_embeddings = train_embeddings 27 | self._val_embeddings = validation_embeddings 28 | self._check_dims() 29 | 30 | # Ensure that number of classes is always valid 31 | self._num_classes: int = max(train_embeddings.labels.max(), validation_embeddings.labels.max()).item() + 1 32 | if num_classes is not None: 33 | if self._num_classes > num_classes: 34 | raise ValueError("`num_classes` is lower than the number of classes found in the embeddings") 35 | self._num_classes = num_classes 36 | 37 | def _check_dims(self): 38 | """ 39 | Raises error if dimensions doesn't match. 
40 | """ 41 | train_images_s, train_labels_s = self._train_embeddings.images.shape, self._train_embeddings.labels.shape 42 | if train_labels_s[0] != train_images_s[0]: 43 | raise ValueError( 44 | f"Expected `{train_images_s[0]}` samples in train's label embeddings, got `{train_labels_s[0]}`" 45 | ) 46 | if self._train_embeddings.classes is not None and self._train_embeddings.classes.shape[-1] != self.dim: 47 | raise ValueError( 48 | f"Expected train's class embeddings with dimensions `{self.dim}`, " 49 | f"got `{self._train_embeddings.classes.shape[-1]}`" 50 | ) 51 | 52 | val_images_s, val_labels_s = self._val_embeddings.images.shape, self._val_embeddings.labels.shape 53 | if val_labels_s[0] != val_images_s[0]: 54 | raise ValueError( 55 | f"Expected `{val_images_s[0]}` samples in validation's label embeddings, got `{val_labels_s[0]}`" 56 | ) 57 | if val_images_s[-1] != self.dim: 58 | raise ValueError( 59 | f"Expected validation's image embeddings with dimensions `{self.dim}`, " 60 | f"got `{self._val_embeddings.images.shape[-1]}`" 61 | ) 62 | if self._val_embeddings.classes is not None and self._val_embeddings.classes.shape[-1] != self.dim: 63 | raise ValueError( 64 | f"Expected validation's class embeddings with dimensions `{self.dim}`, " 65 | f"got `{self._val_embeddings.classes.shape[-1]}`" 66 | ) 67 | 68 | @property 69 | def dim(self) -> int: 70 | """ 71 | Expected embedding dimensionality. 72 | """ 73 | return self._train_embeddings.images.shape[-1] 74 | 75 | @staticmethod 76 | def get_default_params() -> dict[str, Any]: 77 | return {} 78 | 79 | @property 80 | def num_classes(self) -> int: 81 | return self._num_classes 82 | 83 | @abstractmethod 84 | def evaluate(self) -> float: 85 | ... 86 | 87 | @classmethod 88 | @abstractmethod 89 | def title(cls) -> str: 90 | ... 91 | 92 | 93 | class ClassificationModel(EvaluationModel): 94 | @abstractmethod 95 | def predict(self) -> tuple[ProbabilityArray, ClassArray]: 96 | ... 
97 | 98 | def evaluate(self) -> float: 99 | _, y_hat = self.predict() 100 | acc: float = (y_hat == self._val_embeddings.labels).astype(float).mean().item() 101 | return acc 102 | -------------------------------------------------------------------------------- /tti_eval/evaluation/evaluator.py: -------------------------------------------------------------------------------- 1 | import csv 2 | from datetime import datetime 3 | 4 | from natsort import natsorted, ns 5 | from tabulate import tabulate 6 | from tqdm.auto import tqdm 7 | 8 | from tti_eval.common import EmbeddingDefinition, Split 9 | from tti_eval.constants import OUTPUT_PATH 10 | from tti_eval.evaluation import ( 11 | EvaluationModel, 12 | I2IRetrievalEvaluator, 13 | LinearProbeClassifier, 14 | WeightedKNNClassifier, 15 | ZeroShotClassifier, 16 | ) 17 | from tti_eval.utils import read_all_cached_embeddings 18 | 19 | 20 | def print_evaluation_results( 21 | results: dict[EmbeddingDefinition, dict[str, float]], 22 | evaluation_model_title: str, 23 | ): 24 | defs = list(results.keys()) 25 | model_names = natsorted(set(map(lambda d: d.model, defs)), alg=ns.IGNORECASE) 26 | dataset_names = natsorted(set(map(lambda d: d.dataset, defs)), alg=ns.IGNORECASE) 27 | 28 | table: list[list[float | str]] = [ 29 | ["Model/Dataset"] + dataset_names, 30 | ] + [[m] + ["-"] * len(dataset_names) for m in model_names] 31 | 32 | def set_score(d: EmbeddingDefinition, s: float): 33 | row = model_names.index(d.model) + 1 34 | col = dataset_names.index(d.dataset) + 1 35 | table[row][col] = f"{s:.4f}" 36 | 37 | for d, res in results.items(): 38 | s = res.get(evaluation_model_title) 39 | if s: 40 | set_score(d, res[evaluation_model_title]) 41 | 42 | print(f"{'='*5} {evaluation_model_title} {'=' * 5}") 43 | print(tabulate(table)) 44 | 45 | 46 | def run_evaluation( 47 | evaluators: list[type[EvaluationModel]], 48 | embedding_definitions: list[EmbeddingDefinition], 49 | ) -> dict[EmbeddingDefinition, dict[str, float]]: 50 | embeddings_performance: dict[EmbeddingDefinition, dict[str, float]] = {} 51 | used_evaluators: set[str] = set() 52 | 53 | for def_ in tqdm(embedding_definitions, desc="Evaluating models", leave=False): 54 | train_embeddings = def_.load_embeddings(Split.TRAIN) 55 | validation_embeddings = def_.load_embeddings(Split.VALIDATION) 56 | 57 | if train_embeddings is None: 58 | print(f"No train embeddings were found for {def_}") 59 | continue 60 | if validation_embeddings is None: 61 | print(f"No validation embeddings were found for {def_}") 62 | continue 63 | 64 | evaluator_performance: dict[str, float] = embeddings_performance.setdefault(def_, {}) 65 | for evaluator_type in evaluators: 66 | if evaluator_type == ZeroShotClassifier and train_embeddings.classes is None: 67 | continue 68 | evaluator = evaluator_type( 69 | train_embeddings=train_embeddings, 70 | validation_embeddings=validation_embeddings, 71 | ) 72 | evaluator_performance[evaluator.title()] = evaluator.evaluate() 73 | used_evaluators.add(evaluator.title()) 74 | 75 | for evaluator_type in evaluators: 76 | evaluator_title = evaluator_type.title() 77 | if evaluator_title in used_evaluators: 78 | print_evaluation_results(embeddings_performance, evaluator_title) 79 | return embeddings_performance 80 | 81 | 82 | def export_evaluation_to_csv(embeddings_performance: dict[EmbeddingDefinition, dict[str, float]]) -> None: 83 | ts = datetime.now() 84 | results_file = OUTPUT_PATH.EVALUATIONS / f"eval_{ts.isoformat()}.csv" 85 | results_file.parent.mkdir(parents=True, exist_ok=True) # Ensure that 
parent folder exists 86 | 87 | headers = ["Model", "Dataset", "Classifier", "Accuracy"] 88 | with open(results_file.as_posix(), "w", newline="") as csvfile: 89 | writer = csv.writer(csvfile) 90 | writer.writerow(headers) 91 | 92 | for def_, perf in embeddings_performance.items(): 93 | def_: EmbeddingDefinition 94 | for classifier_title, accuracy in perf.items(): 95 | writer.writerow([def_.model, def_.dataset, classifier_title, accuracy]) 96 | print(f"Evaluation results exported to `{results_file.as_posix()}`") 97 | 98 | 99 | if __name__ == "__main__": 100 | models = [ZeroShotClassifier, LinearProbeClassifier, WeightedKNNClassifier, I2IRetrievalEvaluator] 101 | defs = read_all_cached_embeddings(as_list=True) 102 | print(defs) 103 | performances = run_evaluation(models, defs) 104 | export_evaluation_to_csv(performances) 105 | print(performances) 106 | -------------------------------------------------------------------------------- /tti_eval/evaluation/image_retrieval.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any 3 | 4 | import numpy as np 5 | from faiss import IndexFlatL2 6 | 7 | from tti_eval.common import Embeddings 8 | from tti_eval.utils import disable_tqdm 9 | 10 | from .base import EvaluationModel 11 | 12 | logger = logging.getLogger("multiclips") 13 | 14 | 15 | class I2IRetrievalEvaluator(EvaluationModel): 16 | @classmethod 17 | def title(cls) -> str: 18 | return "I2IR" 19 | 20 | def __init__( 21 | self, 22 | train_embeddings: Embeddings, 23 | validation_embeddings: Embeddings, 24 | num_classes: int | None = None, 25 | k: int = 100, 26 | ) -> None: 27 | """ 28 | Image-to-Image retrieval evaluator. 29 | 30 | For each training embedding (e_i), this evaluator computes the percentage (accuracy) of its k nearest neighbors 31 | from the validation embeddings that share the same class. It returns the mean percentage of correct nearest 32 | neighbors across all training embeddings. 33 | 34 | :param train_embeddings: Training embeddings used for evaluation. 35 | :param validation_embeddings: Validation embeddings used for similarity search setup. 36 | :param num_classes: Number of classes. If not specified, it will be inferred from the training embeddings. 37 | :param k: Number of nearest neighbors. 38 | 39 | :raises ValueError: If the build of the faiss index for similarity search fails. 
40 | """ 41 | super().__init__(train_embeddings, validation_embeddings, num_classes) 42 | self.k = min(k, len(validation_embeddings.images)) 43 | 44 | class_ids, counts = np.unique(self._val_embeddings.labels, return_counts=True) 45 | self._class_counts = np.zeros(self.num_classes, dtype=np.int32) 46 | self._class_counts[class_ids] = counts 47 | 48 | disable_tqdm() # Disable tqdm progress bar when building the index 49 | d = self._val_embeddings.images.shape[-1] 50 | self._index = IndexFlatL2(d) 51 | self._index.add(self._val_embeddings.images) 52 | 53 | def evaluate(self) -> float: 54 | _, nearest_indices = self._index.search(self._train_embeddings.images, self.k) # type: ignore 55 | nearest_classes = self._val_embeddings.labels[nearest_indices] 56 | 57 | # To compute retrieval accuracy, we ensure that a maximum of Q elements per sample are retrieved, 58 | # where Q represents the size of the respective class in the validation embeddings 59 | top_nearest_per_class = np.where(self._class_counts < self.k, self._class_counts, self.k) 60 | top_nearest_per_sample = top_nearest_per_class[self._train_embeddings.labels] 61 | 62 | # Add a placeholder value for indices outside the retrieval scope 63 | nearest_classes[np.arange(self.k) >= top_nearest_per_sample[:, np.newaxis]] = -1 64 | 65 | # Count the number of neighbours that match the class of the sample and compute the mean accuracy 66 | matches_per_sample = np.sum( 67 | nearest_classes == np.array(self._train_embeddings.labels)[:, np.newaxis], 68 | axis=1, 69 | ) 70 | accuracies = np.divide( 71 | matches_per_sample, 72 | top_nearest_per_sample, 73 | out=np.zeros_like(matches_per_sample, dtype=np.float64), 74 | where=top_nearest_per_sample != 0, 75 | ) 76 | return accuracies.mean().item() 77 | 78 | @staticmethod 79 | def get_default_params() -> dict[str, Any]: 80 | return {"k": 100} 81 | 82 | 83 | if __name__ == "__main__": 84 | np.random.seed(42) 85 | train_embeddings = Embeddings( 86 | images=np.random.randn(80, 10).astype(np.float32), 87 | labels=np.random.randint(0, 10, size=(80,)), 88 | ) 89 | val_embeddings = Embeddings( 90 | images=np.random.randn(20, 10).astype(np.float32), 91 | labels=np.random.randint(0, 10, size=(20,)), 92 | ) 93 | mean_accuracy = I2IRetrievalEvaluator( 94 | train_embeddings, 95 | val_embeddings, 96 | num_classes=10, 97 | ).evaluate() 98 | print(mean_accuracy) 99 | -------------------------------------------------------------------------------- /tti_eval/evaluation/knn.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any 3 | 4 | import numpy as np 5 | from faiss import IndexFlatL2 6 | 7 | from tti_eval.common import ClassArray, Embeddings, ProbabilityArray 8 | from tti_eval.utils import disable_tqdm 9 | 10 | from .base import ClassificationModel 11 | from .utils import softmax 12 | 13 | logger = logging.getLogger("multiclips") 14 | 15 | 16 | class WeightedKNNClassifier(ClassificationModel): 17 | @classmethod 18 | def title(cls) -> str: 19 | return "wKNN" 20 | 21 | def __init__( 22 | self, 23 | train_embeddings: Embeddings, 24 | validation_embeddings: Embeddings, 25 | num_classes: int | None = None, 26 | k: int = 11, 27 | ) -> None: 28 | """ 29 | Weighted KNN classifier based on the provided embeddings and labels. 30 | Output "probabilities" is a softmax over weighted votes of class. 
31 | 32 | Given q, a sample to predict, this classifier identifies the k nearest neighbors (e_i) 33 | with corresponding classes (y_i) from `train_embeddings.images` and `train_embeddings.labels`, respectively, 34 | and assigns a class vote of 1/||e_i - q||_2^2 for class y_i. 35 | The class with the highest vote count will be chosen. 36 | 37 | :param train_embeddings: Embeddings and their labels used for setting up the search space. 38 | :param validation_embeddings: Embeddings and their labels used for evaluating the search space. 39 | :param num_classes: Number of classes. If not specified, it will be inferred from the train labels. 40 | :param k: Number of nearest neighbors. 41 | 42 | :raises ValueError: If the build of the faiss index for KNN fails. 43 | """ 44 | super().__init__(train_embeddings, validation_embeddings, num_classes) 45 | self.k = k 46 | disable_tqdm() # Disable tqdm progress bar when building the index 47 | d = train_embeddings.images.shape[-1] 48 | self._index = IndexFlatL2(d) 49 | self._index.add(train_embeddings.images) 50 | 51 | @staticmethod 52 | def get_default_params() -> dict[str, Any]: 53 | return {"k": 11} 54 | 55 | def predict(self) -> tuple[ProbabilityArray, ClassArray]: 56 | dists, nearest_indices = self._index.search(self._val_embeddings.images, self.k) # type: ignore 57 | nearest_classes = np.take(self._train_embeddings.labels, nearest_indices) 58 | 59 | # Calculate class votes from the distances (avoiding division by zero) 60 | # Note: Values stored in `dists` are the squared 2-norm values of the respective distance vectors 61 | max_value = np.finfo(np.float32).max 62 | scores = np.divide(1, dists, out=np.full_like(dists, max_value), where=dists != 0) 63 | # NOTE: if self.k and self.num_classes are both large, this array might become very large. 64 | # We could shave off a factor of self.k by counting the votes differently here. 65 | n = len(self._val_embeddings.images) 66 | weighted_count = np.zeros((n, self.num_classes, self.k), dtype=np.float32) 67 | weighted_count[ 68 | np.repeat(np.arange(n), self.k), # [0, 0, .., 0_k, 1, 1, .., 1_k, ..] 69 | nearest_classes.reshape(-1), # [class numbers] 70 | np.tile(np.arange(self.k), n), # [0, 1, .., k-1, 0, 1, .., k-1, ..] 71 | ] = scores.reshape(-1) 72 | probabilities = softmax(weighted_count.sum(-1)) 73 | return probabilities, np.argmax(probabilities, axis=1) 74 | -------------------------------------------------------------------------------- /tti_eval/evaluation/linear_probe.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any 3 | 4 | import numpy as np 5 | from sklearn.linear_model import LogisticRegression, LogisticRegressionCV 6 | 7 | from tti_eval.common import ClassArray, Embeddings, ProbabilityArray 8 | from tti_eval.evaluation.base import ClassificationModel 9 | 10 | logger = logging.getLogger("multiclips") 11 | 12 | 13 | class LinearProbeClassifier(ClassificationModel): 14 | @classmethod 15 | def title(cls) -> str: 16 | return "linear_probe" 17 | 18 | def __init__( 19 | self, 20 | train_embeddings: Embeddings, 21 | validation_embeddings: Embeddings, 22 | num_classes: int | None = None, 23 | log_reg_params: dict[str, Any] | None = None, 24 | use_cross_validation: bool = False, 25 | ) -> None: 26 | """ 27 | Logistic Regression model based on the provided embeddings and labels. 28 | 29 | :param train_embeddings: Embeddings and their labels used for setting up the model.
30 | :param validation_embeddings: Embeddings and their labels used for evaluating the model. 31 | :param num_classes: Number of classes. If not specified, it will be inferred from the train labels. 32 | :param log_reg_params: Parameters for the Logistic Regression model. 33 | :param use_cross_validation: Flag that indicates whether to use cross-validation when training the model. 34 | """ 35 | super().__init__(train_embeddings, validation_embeddings, num_classes) 36 | 37 | params = log_reg_params or {} 38 | self.classifier: LogisticRegressionCV | LogisticRegression 39 | if use_cross_validation: 40 | self.classifier = LogisticRegressionCV( 41 | Cs=[1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3], random_state=0, **params 42 | ).fit(self._train_embeddings.images, self._train_embeddings.labels) # type: ignore 43 | else: 44 | self.classifier = LogisticRegression(random_state=0, **params).fit( 45 | self._train_embeddings.images, self._train_embeddings.labels 46 | ) 47 | 48 | def predict(self) -> tuple[ProbabilityArray, ClassArray]: 49 | probabilities: ProbabilityArray = self.classifier.predict_proba(self._val_embeddings.images) # type: ignore 50 | return probabilities, np.argmax(probabilities, axis=1) 51 | 52 | 53 | if __name__ == "__main__": 54 | train_embeddings = Embeddings( 55 | images=np.random.randn(100, 10).astype(np.float32), 56 | labels=np.random.randint(0, 10, size=(100,)), 57 | ) 58 | val_embeddings = Embeddings( 59 | images=np.random.randn(2, 10).astype(np.float32), 60 | labels=np.random.randint(0, 20, size=(2,)).astype(np.float32), 61 | ) 62 | linear_probe = LinearProbeClassifier( 63 | train_embeddings, 64 | val_embeddings, 65 | num_classes=20, 66 | ) 67 | probs, pred_classes = linear_probe.predict() 68 | print(probs) 69 | print(pred_classes) 70 | -------------------------------------------------------------------------------- /tti_eval/evaluation/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numpy.typing as npt 3 | 4 | from tti_eval.common.numpy_types import DType 5 | 6 | 7 | def normalize(x: npt.NDArray[DType]) -> npt.NDArray[DType]: 8 | return x / np.linalg.norm(x, ord=2, axis=1, keepdims=True) 9 | 10 | 11 | def softmax(x: npt.NDArray[DType]) -> npt.NDArray[DType]: 12 | z = x - x.max(axis=1, keepdims=True) 13 | numerator = np.exp(z) 14 | return numerator / np.sum(numerator, axis=1, keepdims=True) 15 | -------------------------------------------------------------------------------- /tti_eval/evaluation/zero_shot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from tti_eval.common import ClassArray, Embeddings, ProbabilityArray 4 | from tti_eval.evaluation.base import ClassificationModel 5 | from tti_eval.evaluation.utils import softmax 6 | 7 | 8 | class ZeroShotClassifier(ClassificationModel): 9 | @classmethod 10 | def title(cls) -> str: 11 | return "zero_shot" 12 | 13 | def __init__( 14 | self, 15 | train_embeddings: Embeddings, 16 | validation_embeddings: Embeddings, 17 | num_classes: int | None = None, 18 | ) -> None: 19 | """ 20 | Zero-Shot classifier based on the provided embeddings and labels. 21 | 22 | :param train_embeddings: Embeddings and their labels used for setting up the search space. 23 | :param validation_embeddings: Embeddings and their labels used for evaluating the search space. 24 | :param num_classes: Number of classes. If not specified, it will be inferred from the train labels.
25 | """ 26 | super().__init__(train_embeddings, validation_embeddings, num_classes) 27 | if self._train_embeddings.classes is None: 28 | raise ValueError("Expected class embeddings in `train_embeddings`, got `None`") 29 | 30 | @property 31 | def dim(self) -> int: 32 | return self._train_embeddings.classes.shape[-1] 33 | 34 | def predict(self) -> tuple[ProbabilityArray, ClassArray]: 35 | inner_products = self._val_embeddings.images @ self._train_embeddings.classes.T 36 | probabilities = softmax(inner_products) 37 | return probabilities, np.argmax(probabilities, axis=1) 38 | 39 | 40 | if __name__ == "__main__": 41 | train_embeddings = Embeddings( 42 | images=np.random.randn(100, 10).astype(np.float32), 43 | labels=np.random.randint(0, 10, size=(100,)), 44 | classes=np.random.randn(20, 10).astype(np.float32), 45 | ) 46 | val_embeddings = Embeddings( 47 | images=np.random.randn(2, 10).astype(np.float32), 48 | labels=np.random.randint(0, 10, size=(2,)), 49 | ) 50 | zeroshot = ZeroShotClassifier( 51 | train_embeddings, 52 | val_embeddings, 53 | num_classes=20, 54 | ) 55 | probs, pred_classes = zeroshot.predict() 56 | print(probs) 57 | print(pred_classes) 58 | -------------------------------------------------------------------------------- /tti_eval/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Model 2 | from .provider import ModelProvider 3 | -------------------------------------------------------------------------------- /tti_eval/model/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from collections.abc import Callable 3 | from os.path import relpath 4 | from pathlib import Path 5 | from typing import Any 6 | 7 | import torch 8 | from pydantic import BaseModel, ConfigDict 9 | from torch.utils.data import DataLoader 10 | 11 | from tti_eval.common.numpy_types import ClassArray, EmbeddingArray 12 | from tti_eval.constants import CACHE_PATH, SOURCES_PATH 13 | 14 | DEFAULT_MODEL_TYPES_LOCATION = ( 15 | Path(relpath(str(__file__), SOURCES_PATH.MODEL_INSTANCE_DEFINITIONS)).parent / "types" / "__init__.py" 16 | ) 17 | 18 | 19 | class ModelDefinitionSpec(BaseModel): 20 | model_type: str 21 | title: str 22 | module_path: Path = DEFAULT_MODEL_TYPES_LOCATION 23 | device: str | None = None 24 | title_in_source: str | None = None 25 | cache_dir: Path | None = None 26 | 27 | # Allow additional model configuration fields 28 | # Also, silence spurious Pydantic UserWarning: Field "model_type" has conflict with protected namespace "model_" 29 | model_config = ConfigDict(protected_namespaces=(), extra="allow") 30 | 31 | 32 | class Model(ABC): 33 | def __init__( 34 | self, 35 | title: str, 36 | device: str | None = None, 37 | *, 38 | title_in_source: str | None = None, 39 | cache_dir: str | None = None, 40 | **kwargs, 41 | ) -> None: 42 | self.__title = title 43 | self.__title_in_source = title_in_source or title 44 | device = device or ("cuda" if torch.cuda.is_available() else "cpu") 45 | self._check_device(device) 46 | self.__device = torch.device(device) 47 | if cache_dir is None: 48 | cache_dir = CACHE_PATH 49 | self._cache_dir = Path(cache_dir).expanduser().resolve() / "models" / title 50 | 51 | @property 52 | def title(self) -> str: 53 | return self.__title 54 | 55 | @property 56 | def title_in_source(self) -> str: 57 | return self.__title_in_source 58 | 59 | @property 60 | def device(self) -> torch.device: 61 | return self.__device 62 | 63 | 
@abstractmethod 64 | def _setup(self, **kwargs) -> None: 65 | pass 66 | 67 | @abstractmethod 68 | def get_transform(self) -> Callable[[dict[str, Any]], dict[str, list[Any]]]: 69 | ... 70 | 71 | @abstractmethod 72 | def get_collate_fn(self) -> Callable[[Any], Any]: 73 | ... 74 | 75 | @abstractmethod 76 | def build_embedding(self, dataloader: DataLoader) -> tuple[EmbeddingArray, EmbeddingArray, ClassArray]: 77 | ... 78 | 79 | @staticmethod 80 | def _check_device(device: str): 81 | # Check if the input device exists and is available 82 | if device not in {"cuda", "cpu"}: 83 | raise ValueError(f"Unrecognized device: {device}") 84 | if not getattr(torch, device).is_available(): 85 | raise ValueError(f"Unavailable device: {device}") 86 | -------------------------------------------------------------------------------- /tti_eval/model/provider.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any 3 | 4 | from natsort import natsorted, ns 5 | 6 | from tti_eval.constants import SOURCES_PATH 7 | from tti_eval.dataset.utils import load_class_from_path 8 | 9 | from .base import Model, ModelDefinitionSpec 10 | 11 | 12 | class ModelProvider: 13 | __instance = None 14 | __known_model_types: dict[tuple[Path, str], Any] = dict() 15 | 16 | def __init__(self) -> None: 17 | self._models = {} 18 | 19 | @classmethod 20 | def prepare(cls): 21 | if cls.__instance is None: 22 | cls.__instance = cls() 23 | cls.register_models_from_sources_dir(SOURCES_PATH.MODEL_INSTANCE_DEFINITIONS) 24 | return cls.__instance 25 | 26 | @classmethod 27 | def register_model(cls, source: type[Model], title: str, **kwargs): 28 | cls.prepare()._models[title] = (source, kwargs) 29 | 30 | @classmethod 31 | def register_model_from_json_definition(cls, json_definition: Path) -> None: 32 | spec = ModelDefinitionSpec.model_validate_json(json_definition.read_text(encoding="utf-8")) 33 | if not spec.module_path.is_absolute(): # Handle relative module paths 34 | spec.module_path = (json_definition.parent / spec.module_path).resolve() 35 | if not spec.module_path.is_file(): 36 | raise FileNotFoundError( 37 | f"Could not find the specified module path at `{spec.module_path.as_posix()}` " 38 | f"when registering model `{spec.title}`" 39 | ) 40 | 41 | # Fetch the class of the model type stated in the definition 42 | model_type = cls.__known_model_types.get((spec.module_path, spec.model_type)) 43 | if model_type is None: 44 | model_type = load_class_from_path(spec.module_path.as_posix(), spec.model_type) 45 | if not issubclass(model_type, Model): 46 | raise ValueError( 47 | f"Model type specified in the JSON definition file `{json_definition.as_posix()}` " 48 | f"does not inherit from the base class `Model`" 49 | ) 50 | cls.__known_model_types[(spec.module_path, spec.model_type)] = model_type 51 | cls.register_model(model_type, **spec.model_dump(exclude={"module_path", "model_type"})) 52 | 53 | @classmethod 54 | def register_models_from_sources_dir(cls, source_dir: Path) -> None: 55 | for f in source_dir.glob("*.json"): 56 | cls.register_model_from_json_definition(f) 57 | 58 | @classmethod 59 | def get_model(cls, title: str) -> Model: 60 | instance = cls.prepare() 61 | if title not in instance._models: 62 | raise ValueError(f"Unrecognized model: {title}") 63 | source, kwargs = instance._models[title] 64 | return source(title, **kwargs) 65 | 66 | @classmethod 67 | def list_model_titles(cls) -> list[str]: 68 | return natsorted(cls.prepare()._models.keys(), 
alg=ns.IGNORECASE) 69 | -------------------------------------------------------------------------------- /tti_eval/model/types/__init__.py: -------------------------------------------------------------------------------- 1 | from tti_eval.model.types.hugging_face import HFModel 2 | from tti_eval.model.types.local_clip_model import LocalCLIPModel 3 | from tti_eval.model.types.open_clip_model import OpenCLIPModel 4 | -------------------------------------------------------------------------------- /tti_eval/model/types/hugging_face.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | from typing import Any 3 | 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import DataLoader 7 | from tqdm import tqdm 8 | from transformers import AutoModel as HF_AutoModel 9 | from transformers import AutoProcessor as HF_AutoProcessor 10 | from transformers import AutoTokenizer as HF_AutoTokenizer 11 | 12 | from tti_eval.common import ClassArray, EmbeddingArray 13 | from tti_eval.dataset import Dataset 14 | from tti_eval.model import Model 15 | 16 | 17 | class HFModel(Model): 18 | def __init__( 19 | self, 20 | title: str, 21 | device: str | None = None, 22 | *, 23 | title_in_source: str | None = None, 24 | cache_dir: str | None = None, 25 | **kwargs, 26 | ) -> None: 27 | super().__init__(title, device, title_in_source=title_in_source, cache_dir=cache_dir) 28 | self._setup(**kwargs) 29 | 30 | def get_transform(self) -> Callable[[dict[str, Any]], dict[str, list[Any]]]: 31 | def process_fn(batch) -> dict[str, list[Any]]: 32 | images = [i.convert("RGB") for i in batch["image"]] 33 | batch["image"] = [ 34 | self.processor(images=[i], return_tensors="pt").to(self.device).pixel_values.squeeze() for i in images 35 | ] 36 | return batch 37 | 38 | return process_fn 39 | 40 | def get_collate_fn(self) -> Callable[[Any], Any]: 41 | def collate_fn(examples) -> dict[str, torch.Tensor]: 42 | images = [] 43 | labels = [] 44 | for example in examples: 45 | images.append(example["image"]) 46 | labels.append(example["label"]) 47 | 48 | pixel_values = torch.stack(images) 49 | labels = torch.tensor(labels) 50 | return {"pixel_values": pixel_values, "labels": labels} 51 | 52 | return collate_fn 53 | 54 | def _setup(self, **kwargs) -> None: 55 | self.model = HF_AutoModel.from_pretrained(self.title_in_source, cache_dir=self._cache_dir).to(self.device) 56 | load_result = HF_AutoProcessor.from_pretrained(self.title_in_source, cache_dir=self._cache_dir) 57 | self.processor = load_result[0] if isinstance(load_result, tuple) else load_result 58 | self.tokenizer = HF_AutoTokenizer.from_pretrained(self.title_in_source, cache_dir=self._cache_dir) 59 | 60 | def build_embedding(self, dataloader: DataLoader) -> tuple[EmbeddingArray, EmbeddingArray, ClassArray]: 61 | all_image_embeddings = [] 62 | all_labels = [] 63 | with torch.inference_mode(): 64 | _dataset: Dataset = dataloader.dataset 65 | inputs = self.tokenizer(_dataset.text_queries, padding=True, return_tensors="pt").to(self.device) 66 | class_features = self.model.get_text_features(**inputs) 67 | normalized_class_features = class_features / class_features.norm(p=2, dim=-1, keepdim=True) 68 | class_embeddings = normalized_class_features.numpy(force=True) 69 | for batch in tqdm( 70 | dataloader, 71 | desc=f"Embedding ({_dataset.split}) {_dataset.title} dataset with {self.title}", 72 | ): 73 | image_features = self.model.get_image_features(pixel_values=batch["pixel_values"].to(self.device)) 74 | 
normalized_image_features = (image_features / image_features.norm(p=2, dim=-1, keepdim=True)).squeeze() 75 | all_image_embeddings.append(normalized_image_features) 76 | all_labels.append(batch["labels"]) 77 | image_embeddings = torch.concatenate(all_image_embeddings).numpy(force=True) 78 | labels = torch.concatenate(all_labels).numpy(force=True).astype(np.int32) 79 | return image_embeddings, class_embeddings, labels 80 | -------------------------------------------------------------------------------- /tti_eval/model/types/local_clip_model.py: -------------------------------------------------------------------------------- 1 | from tti_eval.model.types.open_clip_model import OpenCLIPModel 2 | 3 | 4 | class LocalCLIPModel(OpenCLIPModel): 5 | def __init__( 6 | self, 7 | title: str, 8 | device: str | None = None, 9 | *, 10 | title_in_source: str, 11 | cache_dir: str | None = None, 12 | **kwargs, 13 | ) -> None: 14 | super().__init__(title, device, title_in_source=title_in_source, cache_dir=cache_dir, **kwargs) 15 | 16 | def _setup(self, **kwargs) -> None: 17 | self.pretrained = (self._cache_dir / "checkpoint.pt").as_posix() 18 | super()._setup(**kwargs) 19 | -------------------------------------------------------------------------------- /tti_eval/model/types/open_clip_model.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | from typing import Any 3 | 4 | import numpy as np 5 | import open_clip 6 | import torch 7 | from torch.utils.data import DataLoader 8 | from tqdm import tqdm 9 | 10 | from tti_eval.common import ClassArray, EmbeddingArray 11 | from tti_eval.dataset import Dataset 12 | from tti_eval.model import Model 13 | 14 | 15 | class OpenCLIPModel(Model): 16 | def __init__( 17 | self, 18 | title: str, 19 | device: str | None = None, 20 | *, 21 | title_in_source: str, 22 | pretrained: str | None = None, 23 | cache_dir: str | None = None, 24 | **kwargs, 25 | ) -> None: 26 | self.pretrained = pretrained 27 | super().__init__(title, device, title_in_source=title_in_source, cache_dir=cache_dir, **kwargs) 28 | self._setup(**kwargs) 29 | 30 | def get_transform(self) -> Callable[[dict[str, Any]], dict[str, list[Any]]]: 31 | def process_fn(batch) -> dict[str, list[Any]]: 32 | images = [i.convert("RGB") for i in batch["image"]] 33 | batch["image"] = [self.processor(i) for i in images] 34 | return batch 35 | 36 | return process_fn 37 | 38 | def get_collate_fn(self) -> Callable[[Any], Any]: 39 | def collate_fn(examples) -> dict[str, torch.Tensor]: 40 | images = [] 41 | labels = [] 42 | for example in examples: 43 | images.append(example["image"]) 44 | labels.append(example["label"]) 45 | 46 | torch_images = torch.stack(images) 47 | labels = torch.tensor(labels) 48 | return {"image": torch_images, "labels": labels} 49 | 50 | return collate_fn 51 | 52 | def _setup(self, **kwargs) -> None: 53 | self.model, _, self.processor = open_clip.create_model_and_transforms( 54 | model_name=self.title_in_source, 55 | pretrained=self.pretrained, 56 | cache_dir=self._cache_dir.as_posix(), 57 | device=self.device, 58 | **kwargs, 59 | ) 60 | self.tokenizer = open_clip.get_tokenizer(model_name=self.title_in_source) 61 | 62 | def build_embedding(self, dataloader: DataLoader) -> tuple[EmbeddingArray, EmbeddingArray, ClassArray]: 63 | all_image_embeddings = [] 64 | all_labels = [] 65 | with torch.inference_mode(): 66 | _dataset: Dataset = dataloader.dataset 67 | text = self.tokenizer(_dataset.text_queries).to(self.device) 68 | 
class_embeddings = self.model.encode_text(text, normalize=True).numpy(force=True) 69 | for batch in tqdm( 70 | dataloader, 71 | desc=f"Embedding ({_dataset.split}) {_dataset.title} dataset with {self.title}", 72 | ): 73 | image_features = self.model.encode_image(batch["image"].to(self.device), normalize=True) 74 | all_image_embeddings.append(image_features) 75 | all_labels.append(batch["labels"]) 76 | image_embeddings = torch.concatenate(all_image_embeddings).numpy(force=True) 77 | labels = torch.concatenate(all_labels).numpy(force=True).astype(np.int32) 78 | return image_embeddings, class_embeddings, labels 79 | -------------------------------------------------------------------------------- /tti_eval/plotting/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/encord-team/text-to-image-eval/92ff9da010e4c00de76475a57bd52ed4d95b3716/tti_eval/plotting/__init__.py -------------------------------------------------------------------------------- /tti_eval/plotting/animation.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from pathlib import Path 3 | from typing import Literal, overload 4 | 5 | import matplotlib.animation as animation 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | from matplotlib.widgets import Slider 9 | from PIL import Image 10 | 11 | from tti_eval.common import EmbeddingDefinition 12 | from tti_eval.common.base import SafeName 13 | from tti_eval.common.numpy_types import ClassArray, N2Array 14 | from tti_eval.constants import OUTPUT_PATH 15 | from tti_eval.dataset import DatasetProvider, Split 16 | 17 | from .reduction import REDUCTIONS, reduction_from_string 18 | 19 | 20 | @overload 21 | def create_embedding_chart( 22 | x1: N2Array, 23 | x2: N2Array, 24 | labels: ClassArray, 25 | title1: str | SafeName, 26 | title2: str | SafeName, 27 | suptitle: str, 28 | label_names: list[str], 29 | *, 30 | interactive: Literal[False], 31 | ) -> animation.FuncAnimation: 32 | ... 33 | 34 | 35 | @overload 36 | def create_embedding_chart( 37 | x1: N2Array, 38 | x2: N2Array, 39 | labels: ClassArray, 40 | title1: str | SafeName, 41 | title2: str | SafeName, 42 | suptitle: str, 43 | label_names: list[str], 44 | *, 45 | interactive: Literal[True], 46 | ) -> None: 47 | ... 48 | 49 | 50 | @overload 51 | def create_embedding_chart( 52 | x1: N2Array, 53 | x2: N2Array, 54 | labels: ClassArray, 55 | title1: str | SafeName, 56 | title2: str | SafeName, 57 | suptitle: str, 58 | label_names: list[str], 59 | *, 60 | interactive: bool, 61 | ) -> animation.FuncAnimation | None: 62 | ... 
63 | 64 | 65 | def create_embedding_chart( 66 | x1: N2Array, 67 | x2: N2Array, 68 | labels: ClassArray, 69 | title1: str | SafeName, 70 | title2: str | SafeName, 71 | suptitle: str, 72 | label_names: list[str], 73 | *, 74 | interactive: bool = False, 75 | ) -> animation.FuncAnimation | None: 76 | fig, ax = plt.subplots(figsize=(12, 7)) 77 | fig.suptitle(suptitle) 78 | 79 | # Compute chart bounds 80 | all = np.concatenate([x1, x2], axis=0) 81 | xmin, ymin = all.min(0) 82 | xmax, ymax = all.max(0) 83 | dx = xmax - xmin 84 | dy = ymax - ymin 85 | xmin -= dx * 0.1 86 | xmax += dx * 0.1 87 | ymin -= dy * 0.1 88 | ymax += dy * 0.1 89 | 90 | # Initial plot 91 | points = plt.scatter(*x1.T, c=labels) 92 | 93 | ax.axis("off") 94 | ax.set_xlim([xmin, xmax]) 95 | ax.set_ylim([ymin, ymax]) 96 | 97 | handles, _ = points.legend_elements() 98 | ax.legend(handles, label_names, loc="upper right") 99 | 100 | # Position scatter and slider 101 | s_width, s_height = (0.4, 0.1) 102 | x_padding = (1 - s_width) / 2 103 | 104 | fig.subplots_adjust(bottom=s_height * 1.5) 105 | # Insert Logo 106 | logo = Image.open(Path(__file__).parents[2] / "images" / "encord.png") 107 | w, h = logo.size 108 | logo_padding_x = 0.02 109 | ax_w = x_padding - logo_padding_x * 2 110 | ax_h = ax_w * h / w 111 | ax_logo = fig.add_axes((0.02, 0.02, ax_w, ax_h)) 112 | ax_logo.axis("off") 113 | ax_logo.imshow(logo) 114 | 115 | # Build slider 116 | ax_interpolation = fig.add_axes((x_padding, s_height, s_width, s_height / 2)) 117 | init_time = 0 118 | 119 | def interpolate(t): 120 | t = max(min(t, 1), 0) 121 | return ((1 - t) * x1 + t * x2).T 122 | 123 | interpolate_slider = Slider( 124 | ax=ax_interpolation, 125 | label="", 126 | valmin=0.0, 127 | valmax=1.0, 128 | valinit=init_time, 129 | ) 130 | 131 | # Add model_titles 132 | fig.text(x_padding - 0.02, s_height * 1.25, title1, ha="right", va="center") 133 | fig.text(1 - x_padding + 0.06, 0.1 * 1.25, title2, ha="left", va="center") 134 | 135 | def update_from_slider(val): 136 | points.set_offsets(interpolate(interpolate_slider.val).T) 137 | 138 | interpolate_slider.on_changed(update_from_slider) 139 | 140 | if interactive: 141 | plt.show() 142 | return 143 | 144 | # Animation bit 145 | frames_left = np.linspace(0, 1, 20) ** 4 / 2 146 | frames = np.concatenate([frames_left, 1 - frames_left[::-1][1:]], axis=0) 147 | frames = np.concatenate([frames, frames[::-1]], axis=0) 148 | 149 | def update_from_animation(val): 150 | interpolate_slider.set_val(val) 151 | 152 | interpolate_slider.set_active(False) 153 | 154 | return animation.FuncAnimation(fig, update_from_animation, frames=frames) # type: ignore 155 | 156 | 157 | def standardize(a: N2Array) -> N2Array: 158 | mins = a.min(0, keepdims=True) 159 | maxs = a.max(0, keepdims=True) 160 | return (a - mins) / (maxs - mins) 161 | 162 | 163 | def rotate_to_target(source: N2Array, destination: N2Array) -> N2Array: 164 | from scipy.spatial.transform import Rotation as R 165 | 166 | source = np.pad(source, [(0, 0), (0, 1)], mode="constant", constant_values=0.0) 167 | destination = np.pad(destination, [(0, 0), (0, 1)], mode="constant", constant_values=0.0) 168 | 169 | rot, *_ = R.align_vectors(destination, source, return_sensitivity=True) 170 | out = rot.apply(source) 171 | return out[:, :2] 172 | 173 | 174 | @overload 175 | def build_animation( 176 | defn_1: EmbeddingDefinition, 177 | defn_2: EmbeddingDefinition, 178 | *, 179 | interactive: Literal[True], 180 | reduction: REDUCTIONS = "umap", 181 | ) -> None: 182 | ... 
183 | 184 | 185 | @overload 186 | def build_animation( 187 | defn_1: EmbeddingDefinition, 188 | defn_2: EmbeddingDefinition, 189 | *, 190 | reduction: REDUCTIONS = "umap", 191 | interactive: Literal[False], 192 | ) -> animation.FuncAnimation: 193 | ... 194 | 195 | 196 | @overload 197 | def build_animation( 198 | defn_1: EmbeddingDefinition, 199 | defn_2: EmbeddingDefinition, 200 | *, 201 | reduction: REDUCTIONS = "umap", 202 | interactive: bool, 203 | ) -> animation.FuncAnimation | None: 204 | ... 205 | 206 | 207 | def build_animation( 208 | defn_1: EmbeddingDefinition, 209 | defn_2: EmbeddingDefinition, 210 | *, 211 | split: Split = Split.VALIDATION, 212 | reduction: REDUCTIONS = "umap", 213 | interactive: bool = False, 214 | ) -> animation.FuncAnimation | None: 215 | dataset = DatasetProvider.get_dataset(defn_1.dataset, split) 216 | 217 | embeds = defn_1.load_embeddings(split) # FIXME: This is expensive to get just labels 218 | if embeds is None: 219 | raise ValueError("Empty embeddings") 220 | 221 | reducer = reduction_from_string(reduction) 222 | reduced_1 = standardize(reducer.get_reduction(defn_1, split)) 223 | reduced_2 = rotate_to_target(standardize(reducer.get_reduction(defn_2, split)), reduced_1) 224 | labels = embeds.labels 225 | 226 | if reduced_1.shape[0] > 2_000: 227 | selection = np.random.permutation(reduced_1.shape[0])[:2_000] 228 | reduced_1 = reduced_1[selection] 229 | reduced_2 = reduced_2[selection] 230 | labels = labels[selection] 231 | 232 | return create_embedding_chart( 233 | reduced_1, 234 | reduced_2, 235 | labels, 236 | defn_1.model, 237 | defn_2.model, 238 | suptitle=dataset.title, 239 | label_names=dataset.class_names, 240 | interactive=interactive, 241 | ) 242 | 243 | 244 | def save_animation_to_file( 245 | anim: animation.FuncAnimation, 246 | def1: EmbeddingDefinition, 247 | def2: EmbeddingDefinition, 248 | split: Split = Split.VALIDATION, 249 | ): 250 | date_code = datetime.now().strftime("%Y%m%d-%H%M%S") 251 | file_name = f"transition_{def1.dataset}_{def1.model}_{def2.model}_{split}_{date_code}.gif" 252 | animation_file = OUTPUT_PATH.ANIMATIONS / file_name 253 | animation_file.parent.mkdir(parents=True, exist_ok=True) # Ensure that parent folder exists 254 | anim.save(animation_file) 255 | print(f"Animation stored at `{animation_file.resolve().as_posix()}`") 256 | -------------------------------------------------------------------------------- /tti_eval/plotting/reduction.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Literal 3 | 4 | import numpy as np 5 | import umap 6 | from sklearn.decomposition import PCA 7 | from sklearn.manifold import TSNE 8 | 9 | from tti_eval.common import EmbeddingArray, EmbeddingDefinition, ReductionArray, Split 10 | 11 | 12 | class Reducer: 13 | @classmethod 14 | @abstractmethod 15 | def title(cls) -> str: 16 | raise NotImplementedError("This abstract method returns the title of the reducer implemented in this class.") 17 | 18 | @classmethod 19 | @abstractmethod 20 | def reduce(cls, embeddings: EmbeddingArray, **kwargs) -> ReductionArray: 21 | raise NotImplementedError("This abstract method contains the implementation for reducing embeddings.") 22 | 23 | @classmethod 24 | def get_reduction( 25 | cls, 26 | embedding_def: EmbeddingDefinition, 27 | split: Split, 28 | force_recompute: bool = False, 29 | save: bool = True, 30 | **kwargs, 31 | ) -> ReductionArray: 32 | reduction_file = embedding_def.get_reduction_path(cls.title(), 
split=split) 33 | if reduction_file.is_file() and not force_recompute: 34 | reduction: ReductionArray = np.load(reduction_file) 35 | return reduction 36 | 37 | elif not embedding_def.embedding_path(split).is_file(): 38 | raise ValueError( 39 | f"{repr(embedding_def)} does not have embeddings stored ({embedding_def.embedding_path(split)})" 40 | ) 41 | 42 | image_embeddings: EmbeddingArray = np.load(embedding_def.embedding_path(split))["image_embeddings"] 43 | reduction = cls.reduce(image_embeddings) 44 | if save: 45 | reduction_file.parent.mkdir(parents=True, exist_ok=True) 46 | np.save(reduction_file, reduction) 47 | return reduction 48 | 49 | 50 | class UMAPReducer(Reducer): 51 | @classmethod 52 | def title(cls) -> str: 53 | return "umap" 54 | 55 | @classmethod 56 | def reduce(cls, embeddings: EmbeddingArray, umap_seed: int | None = None, **kwargs) -> ReductionArray: 57 | reducer = umap.UMAP(random_state=umap_seed)  # Propagate the optional seed for reproducible layouts 58 | return reducer.fit_transform(embeddings) 59 | 60 | 61 | class TSNEReducer(Reducer): 62 | @classmethod 63 | def title(cls) -> str: 64 | return "tsne" 65 | 66 | @classmethod 67 | def reduce(cls, embeddings: EmbeddingArray, **kwargs) -> ReductionArray: 68 | reducer = TSNE() 69 | return reducer.fit_transform(embeddings) 70 | 71 | 72 | class PCAReducer(Reducer): 73 | @classmethod 74 | def title(cls) -> str: 75 | return "pca" 76 | 77 | @classmethod 78 | def reduce(cls, embeddings: EmbeddingArray, **kwargs) -> ReductionArray: 79 | reducer = PCA(n_components=2) 80 | return reducer.fit_transform(embeddings) 81 | 82 | 83 | __REDUCTIONS = { 84 | UMAPReducer.title(): UMAPReducer, 85 | TSNEReducer.title(): TSNEReducer, 86 | PCAReducer.title(): PCAReducer, 87 | } 88 | REDUCTIONS = Literal["umap"] | Literal["tsne"] | Literal["pca"] 89 | 90 | 91 | def reduction_from_string(name: str) -> type[UMAPReducer | TSNEReducer | PCAReducer]: 92 | if name not in __REDUCTIONS: 93 | raise KeyError(f"{name} not in set {set(__REDUCTIONS.keys())}") 94 | return __REDUCTIONS[name] 95 | 96 | 97 | if __name__ == "__main__": 98 | embeddings = np.random.randn(100, 20) 99 | 100 | for cls in [UMAPReducer, TSNEReducer, PCAReducer]: 101 | print(cls.reduce(embeddings).shape) 102 | -------------------------------------------------------------------------------- /tti_eval/utils.py: -------------------------------------------------------------------------------- 1 | from functools import partialmethod 2 | from itertools import chain 3 | from typing import Literal, overload 4 | 5 | from tqdm import tqdm 6 | 7 | from tti_eval.common import EmbeddingDefinition 8 | from tti_eval.constants import PROJECT_PATHS 9 | 10 | 11 | def disable_tqdm(): 12 | tqdm.__init__ = partialmethod(tqdm.__init__, disable=True) 13 | 14 | 15 | def enable_tqdm(): 16 | tqdm.__init__ = partialmethod(tqdm.__init__, disable=False) 17 | 18 | 19 | @overload 20 | def read_all_cached_embeddings(as_list: Literal[True]) -> list[EmbeddingDefinition]: 21 | ... 22 | 23 | 24 | @overload 25 | def read_all_cached_embeddings(as_list: Literal[False] = False) -> dict[str, list[EmbeddingDefinition]]: 26 | ... 27 | 28 | 29 | def read_all_cached_embeddings( 30 | as_list: bool = False, 31 | ) -> dict[str, list[EmbeddingDefinition]] | list[EmbeddingDefinition]: 32 | """ 33 | Reads existing embedding definitions from the cache directory. 34 | Returns: a dictionary keyed by dataset name, where each value lists the embedding definitions found for that dataset (one per model); a flat list is returned instead when `as_list` is True.
35 | """ 36 | if not PROJECT_PATHS.EMBEDDINGS.exists(): 37 | return dict() 38 | 39 | defs_dict = { 40 | d.name: list( 41 | { 42 | EmbeddingDefinition(dataset=d.name, model=m.stem.rsplit("_", maxsplit=1)[0]) 43 | for m in d.iterdir() 44 | if m.is_file() and m.suffix == ".npz" 45 | } 46 | ) 47 | for d in PROJECT_PATHS.EMBEDDINGS.iterdir() 48 | if d.is_dir() 49 | } 50 | if as_list: 51 | return list(chain(*defs_dict.values())) 52 | return defs_dict 53 | 54 | 55 | if __name__ == "__main__": 56 | print(read_all_cached_embeddings()) 57 | --------------------------------------------------------------------------------
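
The snippet below is a minimal sketch of how the evaluation pieces above can be driven from Python, mirroring the `__main__` block in tti_eval/evaluation/evaluator.py. It assumes that image embeddings have already been built and cached under PROJECT_PATHS.EMBEDDINGS for at least one model/dataset pair and that the optional faiss dependency is installed; the evaluator list shown is only illustrative and can be any subset of the classes exported by tti_eval.evaluation.

from tti_eval.evaluation import LinearProbeClassifier, WeightedKNNClassifier, ZeroShotClassifier
from tti_eval.evaluation.evaluator import export_evaluation_to_csv, run_evaluation
from tti_eval.utils import read_all_cached_embeddings

# Collect every (dataset, model) pair that already has cached embeddings on disk.
definitions = read_all_cached_embeddings(as_list=True)

# Score each pair with the chosen evaluators; one results table is printed per evaluator.
# Note that run_evaluation skips the zero-shot evaluator for pairs without class embeddings.
performances = run_evaluation([ZeroShotClassifier, LinearProbeClassifier, WeightedKNNClassifier], definitions)

# Persist the accuracies to a timestamped CSV under OUTPUT_PATH.EVALUATIONS.
export_evaluation_to_csv(performances)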