├── .github
│   ├── actions
│   │   └── setup-poetry-environment
│   │       └── action.yml
│   └── workflows
│       └── tests.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── dev-setup.sh
├── images
│   ├── embeddings.gif
│   ├── encord.png
│   ├── eval-usage.svg
│   ├── logo_v1.webp
│   └── tti-eval-banner.png
├── notebooks
│   ├── tti_eval_Bring_Your_Own_Dataset_From_Encord_Quickstart.ipynb
│   ├── tti_eval_Bring_Your_Own_Model_From_Hugging_Face_Quickstart.ipynb
│   ├── tti_eval_CLI_Quickstart.ipynb
│   └── tti_eval_Python_Quickstart.ipynb
├── poetry.lock
├── pyproject.toml
├── sources
│   ├── dataset-definition-schema.json
│   ├── datasets
│   │   ├── Alzheimer-MRI.json
│   │   ├── LungCancer4Types.json
│   │   ├── chest-xray-classification.json
│   │   ├── plants.json
│   │   ├── rsicd.json
│   │   ├── skin-cancer.json
│   │   └── sports-classification.json
│   ├── model-definition-schema.json
│   └── models
│       ├── apple.json
│       ├── bioclip.json
│       ├── clip.json
│       ├── eva-clip.json
│       ├── fashion.json
│       ├── plip.json
│       ├── pubmed.json
│       ├── rsicd-encord.json
│       ├── rsicd.json
│       ├── siglip_large.json
│       ├── siglip_small.json
│       ├── street.json
│       ├── tinyclip.json
│       └── vit-b-32-laion2b.json
├── tests
│   ├── common
│   │   └── test_embedding_definition.py
│   └── evaluation
│       └── test_knn.py
└── tti_eval
    ├── cli
    │   ├── __init__.py
    │   ├── main.py
    │   └── utils.py
    ├── common
    │   ├── __init__.py
    │   ├── base.py
    │   ├── numpy_types.py
    │   └── string_utils.py
    ├── compute.py
    ├── constants.py
    ├── dataset
    │   ├── __init__.py
    │   ├── base.py
    │   ├── provider.py
    │   ├── types
    │   │   ├── __init__.py
    │   │   ├── encord_ds.py
    │   │   └── hugging_face.py
    │   └── utils.py
    ├── evaluation
    │   ├── __init__.py
    │   ├── base.py
    │   ├── evaluator.py
    │   ├── image_retrieval.py
    │   ├── knn.py
    │   ├── linear_probe.py
    │   ├── utils.py
    │   └── zero_shot.py
    ├── model
    │   ├── __init__.py
    │   ├── base.py
    │   ├── provider.py
    │   └── types
    │       ├── __init__.py
    │       ├── hugging_face.py
    │       ├── local_clip_model.py
    │       └── open_clip_model.py
    ├── plotting
    │   ├── __init__.py
    │   ├── animation.py
    │   └── reduction.py
    └── utils.py
/.github/actions/setup-poetry-environment/action.yml:
--------------------------------------------------------------------------------
1 | name: "Setup test environment"
2 | description: "Sets up Python, Poetry and dependencies"
3 |
4 | inputs:
5 | python:
6 | description: "Python version to use"
7 | default: "3.11"
8 | required: false
9 | poetry:
10 | description: "Poetry version to use"
11 | default: 1.7.1
12 | required: false
13 |
14 | runs:
15 | using: "composite"
16 |
17 | steps:
18 | - uses: actions/setup-python@v5
19 | with:
20 | python-version: ${{ inputs.python }}
21 |
22 | - uses: snok/install-poetry@v1
23 | with:
24 | version: ${{ inputs.poetry }}
25 | virtualenvs-create: true
26 | virtualenvs-in-project: true
27 |
28 | - name: Load cached venv
29 | id: cached-poetry-dependencies
30 | uses: actions/cache@v4
31 | with:
32 | path: .venv
33 |         key: venv-${{ runner.os }}-${{ inputs.python }}-${{ hashFiles('**/poetry.lock') }}
34 |
35 | - if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
36 | run: |
37 | poetry lock --no-update
38 | poetry install --no-interaction
39 | shell: bash
40 |
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: Tests
2 |
3 | on:
4 | pull_request:
5 | branches:
6 | - main
7 | push:
8 | branches:
9 | - main
10 | workflow_dispatch:
11 |
12 | env:
13 | PYTHON_VERSION: 3.11
14 |
15 | jobs:
16 | pre-commit:
17 | name: Linting and type checking
18 | runs-on: ubuntu-latest
19 | steps:
20 | - uses: actions/checkout@v4
21 | - name: Setup poetry environment
22 | uses: ./.github/actions/setup-poetry-environment
23 | - name: Run linting, type checking and testing
24 | uses: pre-commit/action@v3.0.0
25 | with:
26 | extra_args: "--all-files --hook-stage=push"
27 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | # Created by https://www.toptal.com/developers/gitignore/api/vim,python
3 | # Edit at https://www.toptal.com/developers/gitignore?templates=vim,python
4 |
5 | ### Python ###
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | share/python-wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .nox/
48 | .coverage
49 | .coverage.*
50 | .cache
51 | nosetests.xml
52 | coverage.xml
53 | *.cover
54 | *.py,cover
55 | .hypothesis/
56 | .pytest_cache/
57 | cover/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 | db.sqlite3-journal
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | .pybuilder/
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # IPython
87 | profile_default/
88 | ipython_config.py
89 |
90 | # pyenv
91 | # For a library or package, you might want to ignore these files since the code is
92 | # intended to run in multiple environments; otherwise, check them in:
93 | # .python-version
94 |
95 | # pipenv
96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
99 | # install all needed dependencies.
100 | #Pipfile.lock
101 |
102 | # poetry
103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
104 | # This is especially recommended for binary packages to ensure reproducibility, and is more
105 | # commonly ignored for libraries.
106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
107 | #poetry.lock
108 |
109 | # pdm
110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
111 | #pdm.lock
112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
113 | # in version control.
114 | # https://pdm.fming.dev/#use-with-ide
115 | .pdm.toml
116 |
117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
118 | __pypackages__/
119 |
120 | # Celery stuff
121 | celerybeat-schedule
122 | celerybeat.pid
123 |
124 | # SageMath parsed files
125 | *.sage.py
126 |
127 | # Environments
128 | .env
129 | .venv
130 | env/
131 | venv/
132 | ENV/
133 | env.bak/
134 | venv.bak/
135 |
136 | # Spyder project settings
137 | .spyderproject
138 | .spyproject
139 |
140 | # Rope project settings
141 | .ropeproject
142 |
143 | # mkdocs documentation
144 | /site
145 |
146 | # mypy
147 | .mypy_cache/
148 | .dmypy.json
149 | dmypy.json
150 |
151 | # Pyre type checker
152 | .pyre/
153 |
154 | # pytype static type analyzer
155 | .pytype/
156 |
157 | # Cython debug symbols
158 | cython_debug/
159 |
160 | # PyCharm
161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
163 | # and can be added to the global gitignore or merged into this file. For a more nuclear
164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
165 | .idea/
166 |
167 | #Vscode
168 | .vscode
169 |
170 | ### Python Patch ###
171 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
172 | poetry.toml
173 |
174 | # ruff
175 | .ruff_cache/
176 |
177 | # LSP config files
178 | pyrightconfig.json
179 |
180 | ### Vim ###
181 | # Swap
182 | [._]*.s[a-v][a-z]
183 | !*.svg # comment out if you don't need vector files
184 | [._]*.sw[a-p]
185 | [._]s[a-rt-v][a-z]
186 | [._]ss[a-gi-z]
187 | [._]sw[a-p]
188 |
189 | # Session
190 | Session.vim
191 | Sessionx.vim
192 |
193 | # Temporary
194 | .netrwhist
195 | *~
196 | # Auto-generated tag files
197 | tags
198 | # Persistent undo
199 | [._]*.un~
200 |
201 | # End of https://www.toptal.com/developers/gitignore/api/vim,python
202 | output
203 |
204 | # Mac's specifics
205 | .DS_Store
206 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | default_language_version:
2 | python: python3.11
3 |
4 | repos:
5 | - repo: https://github.com/pre-commit/pre-commit-hooks
6 | rev: f71fa2c1f9cf5cb705f73dffe4b21f7c61470ba9 # 4.4.0
7 | hooks:
8 | - id: check-yaml
9 | - id: check-json
10 | - id: check-toml
11 | - id: end-of-file-fixer
12 | - id: trailing-whitespace
13 | - repo: https://github.com/pappasam/toml-sort
14 | rev: b9b6210da457c38122995e434b314f4c4a4a923e # 0.23.1
15 | hooks:
16 | - id: toml-sort-fix
17 | files: ^.+.toml$
18 | - repo: https://github.com/astral-sh/ruff-pre-commit
19 | rev: v0.2.1
20 | hooks:
21 | - id: ruff
22 | args:
23 | - --fix
24 | - --exit-non-zero-on-fix
25 | # - --ignore=E501 # line-too-long
26 | # - --ignore=F631 # assert-tuple
27 | # - --ignore=E741 # ambiguous-variable-name
28 | - id: ruff-format
29 |         files: ^tti_eval\/.+\.py$
30 | # - repo: https://github.com/pre-commit/mirrors-mypy
31 | # rev: v1.8.0
32 | # hooks:
33 | # - id: mypy
34 | default_stages: [push]
35 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
15 |
16 |
17 |
18 |
19 |
20 | Welcome to `tti-eval`, a repository for benchmarking text-to-image models **on your own data**!
21 |
22 | > Evaluate your (or HF) text-to-image embedding models like [CLIP][openai/clip-vit-large-patch14-336] from OpenAI against your (or HF) datasets to estimate how well the model will perform on your classification dataset.
23 |
24 | ## TLDR
25 |
26 | With this library, you can take an embedding model intended for jointly embedding images and text (like [CLIP][openai/clip-vit-large-patch14-336]) and compute metrics for how well such a model performs at classifying your custom dataset.
27 | What you will do is:
28 |
29 | 0. [Install](#installation) `tti-eval`
30 | 1. [Compute embeddings](#embeddings-generation) of a dataset with a model
31 | 2. Do an [evaluation](#model-evaluation) of the model against the dataset
32 |
33 | You can easily benchmark different models and datasets against each other. Here is an example:
34 |
35 |
36 |

37 |
38 |
39 | ## Installation
40 |
41 | > `tti-eval` requires [Python 3.10+](https://www.python.org/downloads/release/python-31014/) and [Poetry](https://python-poetry.org/docs/#installation).
42 |
43 | 1. Clone the repository:
44 | ```
45 | git clone https://github.com/encord-team/text-to-image-eval.git
46 | ```
47 | 2. Navigate to the project directory:
48 | ```
49 | cd text-to-image-eval
50 | ```
51 | 3. Install the required dependencies:
52 | ```
53 | poetry shell
54 | poetry install
55 | ```
56 | 4. Add environment variables:
57 | ```
58 | export TTI_EVAL_CACHE_PATH=$PWD/.cache
59 | export TTI_EVAL_OUTPUT_PATH=$PWD/output
60 | export ENCORD_SSH_KEY_PATH=
61 | ```
62 |
63 | ## CLI Quickstart
64 |
65 |
66 |
67 |
68 |
69 | ### Embeddings Generation
70 |
71 | To build embeddings, run the CLI command `tti-eval build`.
72 | This command allows you to interactively select the model and dataset combinations on which to build the embeddings.
73 |
74 | Alternatively, you can choose known (model, dataset) pairs using the `--model-dataset` option. For example:
75 |
76 | ```
77 | tti-eval build --model-dataset clip/Alzheimer-MRI --model-dataset bioclip/Alzheimer-MRI
78 | ```
79 |
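If you prefer Python over the CLI, the snippet below is a minimal sketch of the same step. The exact signature of `compute_embeddings_from_definition` is an assumption here; see the Python Quickstart notebook for the supported API.

```python
from tti_eval.common import EmbeddingDefinition, Split
from tti_eval.compute import compute_embeddings_from_definition

# Sketch only: compute_embeddings_from_definition is assumed to take a
# definition and a split and to return an Embeddings object.
definition = EmbeddingDefinition(model="clip", dataset="Alzheimer-MRI")
embeddings = compute_embeddings_from_definition(definition, Split.TRAIN)

# Persist the embeddings so they can be reused for evaluation and animation.
definition.save_embeddings(embeddings, split=Split.TRAIN, overwrite=True)
```
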
80 | ### Model Evaluation
81 |
82 | To evaluate models, use the CLI command `tti-eval evaluate`.
83 | This command enables interactive selection of model and dataset combinations for evaluation.
84 |
85 | Alternatively, you can specify known (model, dataset) pairs using the `--model-dataset` option. For example:
86 |
87 | ```
88 | tti-eval evaluate --model-dataset clip/Alzheimer-MRI --model-dataset bioclip/Alzheimer-MRI
89 | ```
90 |
91 | The evaluation results can be exported to a CSV file using the `--save` option.
92 | They will be saved in a folder at the location specified by the environment variable `TTI_EVAL_OUTPUT_PATH`.
93 | By default, exported evaluation results are stored in the `output/evaluations` folder within the repository.
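
For example, to run the evaluation above and export the results to CSV (assuming the flag is combined with the same options):

```
tti-eval evaluate --model-dataset clip/Alzheimer-MRI --model-dataset bioclip/Alzheimer-MRI --save
```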
94 |
95 | ### Embeddings Animation
96 |
97 | To create 2D animations of the embeddings, use the CLI command `tti-eval animate`.
98 | This command allows you to visualise the reduction of embeddings from two models on the same dataset.
99 |
100 | You have the option to interactively select two models and a dataset for visualization.
101 | Alternatively, you can specify the models and dataset as arguments. For example:
102 |
103 | ```
104 | tti-eval animate clip bioclip Alzheimer-MRI
105 | ```
106 |
107 | The animations will be saved in a folder at the location specified by the environment variable `TTI_EVAL_OUTPUT_PATH`.
108 | By default, animations are stored in the `output/animations` folder within the repository.
109 | To interactively explore the animation in a temporary session, use the `--interactive` flag.
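
For example, to explore the same animation interactively instead of saving it to disk:

```
tti-eval animate clip bioclip Alzheimer-MRI --interactive
```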
110 |
111 |
112 |

113 |
114 |
115 | > ℹ️ You can also carry out these operations using Python. Explore our Python Quickstart guide for more details.
116 | >
117 | >
118 | >
119 | >
120 |
121 | ## Some Example Results
122 |
123 | One example of where `tti-eval` is useful is testing different open-source models against different open-source datasets within a specific domain.
124 | Below, we focus on the medical domain. We evaluate nine different models, three of which are domain-specific.
125 | The models are evaluated against four different medical datasets. Further down this page, you will find links to all the models and datasets.
126 |
127 |
128 |

129 |
130 |
Figure 1: Linear probe accuracy across four different medical datasets. General purpose models are colored green while models trained for the medical domain are colored red.
131 |
132 |
133 |
134 |
135 | The raw numbers from the experiment
136 |
137 | ### Weighted KNN Accuracy
138 |
139 | | Model/Dataset | Alzheimer-MRI | LungCancer4Types | chest-xray-classification | skin-cancer |
140 | | :--------------- | :-----------: | :--------------: | :-----------------------: | :---------: |
141 | | apple | 0.6777 | 0.6633 | 0.9687 | 0.7985 |
142 | | bioclip | 0.8952 | 0.7800 | 0.9771 | 0.7961 |
143 | | clip | 0.6986 | 0.6867 | 0.9727 | 0.7891 |
144 | | plip | 0.8021 | 0.6767 | 0.9599 | 0.7860 |
145 | | pubmed | 0.8503 | 0.5767 | 0.9725 | 0.7637 |
146 | | siglip_large | 0.6908 | 0.6533 | 0.9695 | 0.7947 |
147 | | siglip_small | 0.6992 | 0.6267 | 0.9646 | 0.7780 |
148 | | tinyclip | 0.7389 | 0.5900 | 0.9673 | 0.7589 |
149 | | vit-b-32-laion2b | 0.7559 | 0.5967 | 0.9654 | 0.7738 |
150 |
151 | ### Zero-shot Accuracy
152 |
153 | | Model/Dataset | Alzheimer-MRI | LungCancer4Types | chest-xray-classification | skin-cancer |
154 | | :--------------- | :-----------: | :--------------: | :-----------------------: | :---------: |
155 | | apple | 0.4460 | 0.2367 | 0.7381 | 0.3594 |
156 | | bioclip | 0.3092 | 0.2200 | 0.7356 | 0.0431 |
157 | | clip | 0.4857 | 0.2267 | 0.7381 | 0.1955 |
158 | | plip | 0.0104 | 0.2267 | 0.3873 | 0.0797 |
159 | | pubmed | 0.3099 | 0.2867 | 0.7501 | 0.1127 |
160 | | siglip_large | 0.4876 | 0.3000 | 0.5950 | 0.0421 |
161 | | siglip_small | 0.4102 | 0.0767 | 0.7381 | 0.1541 |
162 | | tinyclip | 0.2526 | 0.2533 | 0.7313 | 0.1113 |
163 | | vit-b-32-laion2b | 0.3594 | 0.1533 | 0.7378 | 0.1228 |
164 |
165 | ---
166 |
167 | ### Image-to-image Retrieval
168 |
169 | | Model/Dataset | Alzheimer-MRI | LungCancer4Types | chest-xray-classification | skin-cancer |
170 | | :--------------- | :-----------: | :--------------: | :-----------------------: | :---------: |
171 | | apple | 0.4281 | 0.2786 | 0.8835 | 0.6437 |
172 | | bioclip | 0.4535 | 0.3496 | 0.8786 | 0.6278 |
173 | | clip | 0.4247 | 0.2812 | 0.8602 | 0.6347 |
174 | | plip | 0.4406 | 0.3174 | 0.8372 | 0.6289 |
175 | | pubmed | 0.4445 | 0.3022 | 0.8621 | 0.6228 |
176 | | siglip_large | 0.4232 | 0.2743 | 0.8797 | 0.6466 |
177 | | siglip_small | 0.4303 | 0.2613 | 0.8660 | 0.6348 |
178 | | tinyclip | 0.4361 | 0.2833 | 0.8379 | 0.6098 |
179 | | vit-b-32-laion2b | 0.4378 | 0.2934 | 0.8551 | 0.6189 |
180 |
181 | ---
182 |
183 | ### Linear Probe Accuracy
184 |
185 | | Model/Dataset | Alzheimer-MRI | LungCancer4Types | chest-xray-classification | skin-cancer |
186 | | :--------------- | :-----------: | :--------------: | :-----------------------: | :---------: |
187 | | apple | 0.5482 | 0.5433 | 0.9362 | 0.7662 |
188 | | bioclip | 0.6139 | 0.6600 | 0.9433 | 0.7933 |
189 | | clip | 0.5547 | 0.5700 | 0.9362 | 0.7704 |
190 | | plip | 0.5469 | 0.5267 | 0.9261 | 0.7630 |
191 | | pubmed | 0.5482 | 0.5400 | 0.9278 | 0.7269 |
192 | | siglip_large | 0.5286 | 0.5200 | 0.9496 | 0.7697 |
193 | | siglip_small | 0.5449 | 0.4967 | 0.9327 | 0.7606 |
194 | | tinyclip | 0.5651 | 0.5733 | 0.9280 | 0.7484 |
195 | | vit-b-32-laion2b | 0.5684 | 0.5933 | 0.9302 | 0.7578 |
196 |
197 | ---
198 |
199 |
200 |
201 | ## Datasets
202 |
203 |
204 |
205 |
206 |
207 | This repository contains classification datasets sourced from [Hugging Face](https://huggingface.co/datasets) and [Encord](https://app.encord.com/projects).
208 |
209 | > ⚠️ Currently, only image and image group datasets are supported, with potential for future expansion to include video datasets.
210 |
211 | | Dataset Title | Implementation | HF Dataset |
212 | | :------------------------ | :------------------------------ | :----------------------------------------------------------------------------------- |
213 | | Alzheimer-MRI | [Hugging Face][hf-dataset-impl] | [Falah/Alzheimer_MRI][Falah/Alzheimer_MRI] |
214 | | chest-xray-classification | [Hugging Face][hf-dataset-impl] | [trpakov/chest-xray-classification][trpakov/chest-xray-classification] |
215 | | LungCancer4Types | [Hugging Face][hf-dataset-impl] | [Kabil007/LungCancer4Types][Kabil007/LungCancer4Types] |
216 | | plants | [Hugging Face][hf-dataset-impl] | [sampath017/plants][sampath017/plants] |
217 | | skin-cancer | [Hugging Face][hf-dataset-impl] | [marmal88/skin_cancer][marmal88/skin_cancer] |
218 | | sports-classification | [Hugging Face][hf-dataset-impl] | [HES-XPLAIN/SportsImageClassification][HES-XPLAIN/SportsImageClassification] |
219 | | rsicd | [Encord][encord-dataset-impl] | \* Requires ssh key and access to the Encord project |
220 |
221 | ### Add a Dataset from a Known Source
222 |
223 | To register a dataset from a known source, you can include the dataset definition as a JSON file in the `sources/datasets` folder.
224 | The definition will be validated against the schema defined by the `tti_eval.dataset.base.DatasetDefinitionSpec` Pydantic class to ensure that it adheres to the required structure.
225 | You can find the explicit schema in `sources/dataset-definition-schema.json`.
226 |
227 | Check out the declarations of known sources at `tti_eval.dataset.types` and refer to the existing dataset definitions in the `sources/datasets` folder for guidance.
228 | Below is an example of a dataset definition for the [plants](https://huggingface.co/datasets/sampath017/plants) dataset sourced from Hugging Face:
229 |
230 | ```json
231 | {
232 | "dataset_type": "HFDataset",
233 | "title": "plants",
234 | "title_in_source": "sampath017/plants"
235 | }
236 | ```
237 |
238 | In each dataset definition, the `dataset_type` and `title` fields are required.
239 | The `dataset_type` indicates the name of the class that represents the source, while `title` serves as a reference for the dataset on this platform.
240 |
241 | For Hugging Face datasets, the `title_in_source` field should store the title of the dataset as it appears on the Hugging Face website.
242 |
243 | For datasets sourced from Encord, a different set of fields is required. These include `project_hash`, which contains the hash of the project, and `classification_hash`, which contains the hash of the radio-button (multiclass) classification used in the labels.
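
For reference, the `rsicd` dataset bundled with this repository is defined in `sources/datasets/rsicd.json` as follows:

```json
{
  "dataset_type": "EncordDataset",
  "title": "rsicd",
  "project_hash": "46ba913e-1428-48ef-be7f-2553e69bc1e6",
  "classification_hash": "4f6cf0c8"
}
```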
244 |
245 | ### Add a Dataset Source
246 |
247 | Expanding the dataset sources involves two key steps:
248 |
249 | 1. Create a dataset class that inherits from `tti_eval.dataset.Dataset` and specifies the input requirements for extracting data from the new source.
250 | This class should encapsulate the necessary logic for fetching and processing dataset elements.
251 | 2. Generate a dataset definition in JSON format and save it in the `sources/datasets` folder, following the guidelines outlined in the previous section.
252 | Ensure that the definition includes essential fields such as `dataset_type`, `title`, and `module_path`, which points to the file containing the dataset class implementation (a hypothetical example is shown below).
253 |
254 | > It's recommended to store the file containing the dataset class implementation in the `tti_eval/dataset/types` folder and add a reference to the class in the `__init__.py` file in the same folder.
255 | > This ensures that the new dataset type is accessible by default for all dataset definitions, eliminating the need to explicitly state the `module_path` field for datasets from that source.
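
As a purely hypothetical illustration (the class name and file path below are placeholders), a definition that points to a dataset class outside the default types folder could look like this:

```json
{
  "dataset_type": "MyCustomDataset",
  "title": "my-custom-dataset",
  "module_path": "path/to/my_custom_dataset.py"
}
```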
256 |
257 | ### Programmatically Add a Dataset
258 |
259 | Alternatively, you can programmatically add a dataset, which will be available only for the current session, using the `register_dataset()` method of the `tti_eval.dataset.DatasetProvider` class.
260 |
261 | Here is an example of how to register a dataset from Hugging Face using Python code:
262 |
263 | ```python
264 | from tti_eval.dataset import DatasetProvider, Split
265 | from tti_eval.dataset.types import HFDataset
266 |
267 | DatasetProvider.register_dataset(HFDataset, "plants", title_in_source="sampath017/plants")
268 | ds = DatasetProvider.get_dataset("plants", split=Split.ALL)
269 | print(len(ds)) # Returns: 219
270 | ```
271 |
272 | ### Remove a Dataset
273 |
274 | To permanently remove a dataset, simply delete the corresponding JSON file stored in the `sources/datasets` folder.
275 | This action removes the dataset from the list of available datasets in the CLI, disabling the option to create any further embeddings using its data.
276 | However, all embeddings previously built on that dataset will remain intact and available for other tasks such as evaluation and animation.
277 |
278 | ## Models
279 |
280 |
281 |
282 |
283 |
284 | This repository contains models sourced from [Hugging Face](https://huggingface.co/models), [OpenCLIP](https://github.com/mlfoundations/open_clip) and local implementations based on OpenCLIP models.
285 |
286 | _TODO_: Some more prose about what's the difference between implementations.
287 |
288 | ### Hugging Face Models
289 |
290 | | Model Title | Implementation | HF Model |
291 | | :--------------- | :---------------------------- | :--------------------------------------------------------------------------------------------- |
292 | | apple | [OpenCLIP][open-model-impl] | [apple/DFN5B-CLIP-ViT-H-14][apple/DFN5B-CLIP-ViT-H-14] |
293 | | bioclip | [OpenCLIP][open-model-impl] | [imageomics/bioclip][imageomics/bioclip] |
294 | | eva-clip | [OpenCLIP][open-model-impl] | [BAAI/EVA-CLIP-8B-448][BAAI/EVA-CLIP-8B-448] |
295 | | vit-b-32-laion2b | [OpenCLIP][open-model-impl] | [laion/CLIP-ViT-B-32-laion2B-s34B-b79K][laion/CLIP-ViT-B-32-laion2B-s34B-b79K] |
296 | | clip | [Hugging Face][hf-model-impl] | [openai/clip-vit-large-patch14-336][openai/clip-vit-large-patch14-336] |
297 | | fashion | [Hugging Face][hf-model-impl] | [patrickjohncyh/fashion-clip][patrickjohncyh/fashion-clip] |
298 | | plip | [Hugging Face][hf-model-impl] | [vinid/plip][vinid/plip] |
299 | | pubmed | [Hugging Face][hf-model-impl] | [flaviagiammarino/pubmed-clip-vit-base-patch32][flaviagiammarino/pubmed-clip-vit-base-patch32] |
300 | | rsicd | [Hugging Face][hf-model-impl] | [flax-community/clip-rsicd][flax-community/clip-rsicd] |
301 | | siglip_large | [Hugging Face][hf-model-impl] | [google/siglip-large-patch16-256][google/siglip-large-patch16-256] |
302 | | siglip_small | [Hugging Face][hf-model-impl] | [google/siglip-base-patch16-224][google/siglip-base-patch16-224] |
303 | | street | [Hugging Face][hf-model-impl] | [geolocal/StreetCLIP][geolocal/StreetCLIP] |
304 | | tinyclip | [Hugging Face][hf-model-impl] | [wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M][wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M] |
305 |
306 | ### Locally Trained Models
307 |
308 | | Model Title | Implementation | Weights |
309 | | :----------- | :-------------------------------- | :------ |
310 | | rsicd-encord | [LocalOpenCLIP][local-model-impl] | - |
311 |
312 | ### Add a Model from a Known Source
313 |
314 | To register a model from a known source, you can include the model definition as a JSON file in the `sources/models` folder.
315 | The definition will be validated against the schema defined by the `tti_eval.model.base.ModelDefinitionSpec` Pydantic class to ensure that it adheres to the required structure.
316 | You can find the explicit schema in `sources/model-definition-schema.json`.
317 |
318 | Check out the declarations of known sources at `tti_eval.model.types` and refer to the existing model definitions in the `sources/models` folder for guidance.
319 | Below is an example of a model definition for the [clip](https://huggingface.co/openai/clip-vit-large-patch14-336) model sourced from Hugging Face:
320 |
321 | ```json
322 | {
323 | "model_type": "HFModel",
324 | "title": "clip",
325 | "title_in_source": "openai/clip-vit-large-patch14-336"
326 | }
327 | ```
328 |
329 | In each model definition, the `model_type` and `title` fields are required.
330 | The `model_type` indicates the name of the class that represents the source, while `title` serves as a reference for the model on this platform.
331 |
332 | For non-local models, the `title_in_source` field should store the title of the model as it appears in the source.
333 | For model checkpoints in local storage, the `title_in_source` field should store the title of the model used to train it.
334 | Additionally, for models sourced from OpenCLIP, the optional `pretrained` field may be needed. See the list of OpenCLIP models [here](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md).
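
For example, the `vit-b-32-laion2b` model bundled with this repository is defined in `sources/models/vit-b-32-laion2b.json` using the `pretrained` field:

```json
{
  "model_type": "OpenCLIPModel",
  "title": "vit-b-32-laion2b",
  "title_in_source": "ViT-B-32",
  "pretrained": "laion2b_e16"
}
```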
335 |
336 | ### Add a Model Source
337 |
338 | Expanding the model sources involves two key steps:
339 |
340 | 1. Create a model class that inherits from `tti_eval.model.Model` and specifies the input requirements for loading models from the new source.
341 | This class should encapsulate the necessary logic for processing model elements and generating embeddings.
342 | 2. Generate a model definition in JSON format and save it in the `sources/models` folder, following the guidelines outlined in the previous section.
343 | Ensure that the definition includes essential fields such as `model_type`, `title`, and `module_path`, which points to the file containing the model class implementation.
344 |
345 | > It's recommended to store the file containing the model class implementation in the `tti_eval/model/types` folder and add a reference to the class in the `__init__.py` file in the same folder.
346 | > This ensures that the new model type is accessible by default for all model definitions, eliminating the need to explicitly state the `module_path` field for models from that source.
347 |
348 | ### Programmatically Add a Model
349 |
350 | Alternatively, you can programmatically add a model, which will be available only for the current session, using the `register_model()` method of the `tti_eval.model.ModelProvider` class.
351 |
352 | Here is an example of how to register a model from Hugging Face using Python code:
353 |
354 | ```python
355 | from tti_eval.model import ModelProvider
356 | from tti_eval.model.types import HFModel
357 |
358 | ModelProvider.register_model(HFModel, "clip", title_in_source="openai/clip-vit-large-patch14-336")
359 | model = ModelProvider.get_model("clip")
360 | print(model.title, model.title_in_source) # Returns: clip openai/clip-vit-large-patch14-336
361 | ```
362 |
363 | ### Remove a Model
364 |
365 | To permanently remove a model, simply delete the corresponding JSON file stored in the `sources/models` folder.
366 | This action removes the model from the list of available models in the CLI, disabling the option to create any further embeddings with it.
367 | However, all embeddings previously built with that model will remain intact and available for other tasks such as evaluation and animation.
368 |
369 | ## Set Up the Development Environment
370 |
371 | 1. Create the virtual environment, add dev dependencies and set up pre-commit hooks.
372 | ```
373 | ./dev-setup.sh
374 | ```
375 | 2. Add environment variables:
376 | ```
377 | export TTI_EVAL_CACHE_PATH=$PWD/.cache
378 | export TTI_EVAL_OUTPUT_PATH=$PWD/output
379 | export ENCORD_SSH_KEY_PATH=
380 | ```
381 |
382 | ## Contributing
383 |
384 | Contributions are welcome!
385 | Please feel free to open an issue or submit a pull request with your suggestions, bug fixes, or new features.
386 |
387 | ### Adding Dataset Sources
388 |
389 | To contribute by adding dataset sources, follow these steps:
390 |
391 | 1. Store the file containing the new dataset class implementation in the `tti_eval/dataset/types` folder.
392 | Don't forget to add a reference to the class in the `__init__.py` file in the same folder.
393 | This ensures that the new dataset type is accessible by default for all dataset definitions, eliminating the need to explicitly state the `module_path` field for datasets from that source.
394 | 2. Open a pull request with the necessary changes. Make sure to include tests validating that data retrieval, processing and usage are working as expected.
395 | 3. Document the addition of the dataset source, providing details on its structure, usage, and any specific considerations or instructions for integration.
396 | This ensures that users have clear guidance on how to leverage the new dataset source effectively.
397 |
398 | ### Adding Model Sources
399 |
400 | To contribute by adding model sources, follow these steps:
401 |
402 | 1. Store the file containing the new model class implementation in the `tti_eval/model/types` folder.
403 | Don't forget to add a reference to the class in the `__init__.py` file in the same folder.
404 | This ensures that the new model type is accessible by default for all model definitions, eliminating the need to explicitly state the `module_path` field for models from that source.
405 | 2. Open a pull request with the necessary changes. Make sure to include tests validating that model loading, processing and embedding generation are working as expected.
406 | 3. Document the addition of the model source, providing details on its structure, usage, and any specific considerations or instructions for integration.
407 | This ensures that users have clear guidance on how to leverage the new model source effectively.
408 |
409 | ## Known Issues
410 |
411 | 1. `autofaiss`: The project depends on the [autofaiss][autofaiss] library, which can give some trouble on Windows. Please reach out or raise an issue with as many system and version details as possible if you encounter problems.
412 |
413 | [Falah/Alzheimer_MRI]: https://huggingface.co/datasets/Falah/Alzheimer_MRI
414 | [trpakov/chest-xray-classification]: https://huggingface.co/datasets/trpakov/chest-xray-classification
415 | [Kabil007/LungCancer4Types]: https://huggingface.co/datasets/Kabil007/LungCancer4Types
416 | [sampath017/plants]: https://huggingface.co/datasets/sampath017/plants
417 | [marmal88/skin_cancer]: https://huggingface.co/datasets/marmal88/skin_cancer
418 | [HES-XPLAIN/SportsImageClassification]: https://huggingface.co/datasets/HES-XPLAIN/SportsImageClassification
419 | [apple/DFN5B-CLIP-ViT-H-14]: https://huggingface.co/apple/DFN5B-CLIP-ViT-H-14
420 | [imageomics/bioclip]: https://huggingface.co/imageomics/bioclip
421 | [openai/clip-vit-large-patch14-336]: https://huggingface.co/openai/clip-vit-large-patch14-336
422 | [BAAI/EVA-CLIP-8B-448]: https://huggingface.co/BAAI/EVA-CLIP-8B-448
423 | [patrickjohncyh/fashion-clip]: https://huggingface.co/patrickjohncyh/fashion-clip
424 | [vinid/plip]: https://huggingface.co/vinid/plip
425 | [flaviagiammarino/pubmed-clip-vit-base-patch32]: https://huggingface.co/flaviagiammarino/pubmed-clip-vit-base-patch32
426 | [flax-community/clip-rsicd]: https://huggingface.co/flax-community/clip-rsicd
427 | [google/siglip-large-patch16-256]: https://huggingface.co/google/siglip-large-patch16-256
428 | [google/siglip-base-patch16-224]: https://huggingface.co/google/siglip-base-patch16-224
429 | [geolocal/StreetCLIP]: https://huggingface.co/geolocal/StreetCLIP
430 | [wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M]: https://huggingface.co/wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M
431 | [laion/CLIP-ViT-B-32-laion2B-s34B-b79K]: https://huggingface.co/laion/CLIP-ViT-B-32-laion2B-s34B-b79K
432 | [open-model-impl]: https://github.com/encord-team/text-to-image-eval/blob/main/tti_eval/model/types/open_clip_model.py
433 | [hf-model-impl]: https://github.com/encord-team/text-to-image-eval/blob/main/tti_eval/model/types/hugging_face.py
434 | [local-model-impl]: https://github.com/encord-team/text-to-image-eval/blob/main/tti_eval/model/types/local_clip_model.py
435 | [hf-dataset-impl]: https://github.com/encord-team/text-to-image-eval/blob/main/tti_eval/dataset/types/hugging_face.py
436 | [encord-dataset-impl]: https://github.com/encord-team/text-to-image-eval/blob/main/tti_eval/dataset/types/encord_ds.py
437 | [autofaiss]: https://pypi.org/project/autofaiss/
438 |
--------------------------------------------------------------------------------
/dev-setup.sh:
--------------------------------------------------------------------------------
1 | #!/bin/zsh
2 | poetry env use 3.11 # Create the virtual environment if it does not exist
3 | source $(poetry env info --path)/bin/activate # Activate and enter the virtual environment
4 | poetry install --with=dev # Install dev dependencies
5 | pre-commit install --install-hooks --overwrite -t pre-push # Set up pre-commit hooks
6 |
--------------------------------------------------------------------------------
/images/embeddings.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/encord-team/text-to-image-eval/92ff9da010e4c00de76475a57bd52ed4d95b3716/images/embeddings.gif
--------------------------------------------------------------------------------
/images/encord.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/encord-team/text-to-image-eval/92ff9da010e4c00de76475a57bd52ed4d95b3716/images/encord.png
--------------------------------------------------------------------------------
/images/logo_v1.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/encord-team/text-to-image-eval/92ff9da010e4c00de76475a57bd52ed4d95b3716/images/logo_v1.webp
--------------------------------------------------------------------------------
/images/tti-eval-banner.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/encord-team/text-to-image-eval/92ff9da010e4c00de76475a57bd52ed4d95b3716/images/tti-eval-banner.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | build-backend = "poetry.core.masonry.api"
3 | requires = ["poetry-core"]
4 |
5 | [tool.poetry]
6 | authors = ["Cord Technologies Limited "]
7 | description = "Evaluating text-to-image and image-to-text retrieval models."
8 | name = "tti-eval"
9 | readme = "README.md"
10 | version = "0.1.0"
11 | packages = [{include = "tti_eval", from = "."}]
12 |
13 | [tool.poetry.dependencies]
14 | autofaiss = "^2.17.0"
15 | datasets = "^2.17.0"
16 | encord = "^0.1.108"
17 | pydantic = "^2.6.1"
18 | python = "^3.10"
19 | python-dotenv = "^1.0.1"
20 | scikit-learn = "^1.4"
21 | torch = "^2.2.0"
22 | torchvision = "^0.17.0"
23 | transformers = "^4.37.2"
24 | umap-learn = "^0.5.5"
25 | matplotlib = "^3.8.2"
26 | typer = "^0.9.0"
27 | inquirerpy = "^0.3.4"
28 | tabulate = "^0.9.0"
29 | open-clip-torch = "^2.24.0"
30 | natsort = "^8.4.0"
31 |
32 | [tool.poetry.group.dev.dependencies]
33 | mypy = "^1.8.0"
34 | ruff = "^0.2.1"
35 | toml-sort = "^0.23.1"
36 | pre-commit = "^3.6.1"
37 | pyfzf = "^0.3.1"
38 | ipython = "^8.22.1"
39 | ipdb = "^0.13.13"
40 | pytest = "^8.1.1"
41 |
42 | [tool.poetry.scripts]
43 | tti-eval = "tti_eval.cli.main:cli"
44 |
45 | [tool.ruff]
46 | line-length = 120
47 | target-version = "py311"
48 |
49 | [tool.ruff.lint]
50 | # ignore = [
51 | # "B007", # Loop control variable {name} not used within loop body
52 | # "E501", # Checks for lines that exceed the specified maximum character length
53 | # "E741" # Ambiguous variable name: {name} (e.g. allow short names in one-line list comprehensions)
54 | # ]
55 | select = ["B", "E", "F", "I", "UP"]
56 | ignore = ["UP007"]
57 |
58 | [tool.ruff.lint.isort]
59 | known-first-party = ["tti_eval"]
60 |
61 | [tool.ruff.lint.per-file-ignores]
62 | "__init__.py" = ["F401"]
63 |
--------------------------------------------------------------------------------
/sources/dataset-definition-schema.json:
--------------------------------------------------------------------------------
1 | {
2 | "$defs": {
3 | "Split": {
4 | "enum": [
5 | "train",
6 | "validation",
7 | "test",
8 | "all"
9 | ],
10 | "title": "Split",
11 | "type": "string"
12 | }
13 | },
14 | "additionalProperties": true,
15 | "properties": {
16 | "dataset_type": {
17 | "title": "Dataset Type",
18 | "type": "string"
19 | },
20 | "title": {
21 | "title": "Title",
22 | "type": "string"
23 | },
24 | "module_path": {
25 | "default": "../../tti_eval/dataset/types/__init__.py",
26 | "format": "path",
27 | "title": "Module Path",
28 | "type": "string"
29 | },
30 | "split": {
31 | "anyOf": [
32 | {
33 | "$ref": "#/$defs/Split"
34 | },
35 | {
36 | "type": "null"
37 | }
38 | ],
39 | "default": null
40 | },
41 | "title_in_source": {
42 | "anyOf": [
43 | {
44 | "type": "string"
45 | },
46 | {
47 | "type": "null"
48 | }
49 | ],
50 | "default": null,
51 | "title": "Title In Source"
52 | },
53 | "cache_dir": {
54 | "anyOf": [
55 | {
56 | "format": "path",
57 | "type": "string"
58 | },
59 | {
60 | "type": "null"
61 | }
62 | ],
63 | "default": null,
64 | "title": "Cache Dir"
65 | }
66 | },
67 | "required": [
68 | "dataset_type",
69 | "title"
70 | ],
71 | "title": "DatasetDefinitionSpec",
72 | "type": "object"
73 | }
74 |
--------------------------------------------------------------------------------
/sources/datasets/Alzheimer-MRI.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_type": "HFDataset",
3 | "title": "Alzheimer-MRI",
4 | "title_in_source": "Falah/Alzheimer_MRI"
5 | }
6 |
--------------------------------------------------------------------------------
/sources/datasets/LungCancer4Types.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_type": "HFDataset",
3 | "title": "LungCancer4Types",
4 | "title_in_source": "Kabil007/LungCancer4Types",
5 | "revision": "a1aab924c6bed6b080fc85552fd7b39724931605"
6 | }
7 |
--------------------------------------------------------------------------------
/sources/datasets/chest-xray-classification.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_type": "HFDataset",
3 | "title": "chest-xray-classification",
4 | "title_in_source": "trpakov/chest-xray-classification",
5 | "name": "full",
6 | "target_feature": "labels"
7 | }
8 |
--------------------------------------------------------------------------------
/sources/datasets/plants.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_type": "HFDataset",
3 | "title": "plants",
4 | "title_in_source": "sampath017/plants"
5 | }
6 |
--------------------------------------------------------------------------------
/sources/datasets/rsicd.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_type": "EncordDataset",
3 | "title": "rsicd",
4 | "project_hash": "46ba913e-1428-48ef-be7f-2553e69bc1e6",
5 | "classification_hash": "4f6cf0c8"
6 | }
7 |
--------------------------------------------------------------------------------
/sources/datasets/skin-cancer.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_type": "HFDataset",
3 | "title": "skin-cancer",
4 | "title_in_source": "marmal88/skin_cancer",
5 | "target_feature": "dx"
6 | }
7 |
--------------------------------------------------------------------------------
/sources/datasets/sports-classification.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_type": "HFDataset",
3 | "title": "sports-classification",
4 | "title_in_source": "HES-XPLAIN/SportsImageClassification"
5 | }
6 |
--------------------------------------------------------------------------------
/sources/model-definition-schema.json:
--------------------------------------------------------------------------------
1 | {
2 | "additionalProperties": true,
3 | "properties": {
4 | "model_type": {
5 | "title": "Model Type",
6 | "type": "string"
7 | },
8 | "title": {
9 | "title": "Title",
10 | "type": "string"
11 | },
12 | "module_path": {
13 | "default": "../../tti_eval/model/types/__init__.py",
14 | "format": "path",
15 | "title": "Module Path",
16 | "type": "string"
17 | },
18 | "device": {
19 | "anyOf": [
20 | {
21 | "type": "string"
22 | },
23 | {
24 | "type": "null"
25 | }
26 | ],
27 | "default": null,
28 | "title": "Device"
29 | },
30 | "title_in_source": {
31 | "anyOf": [
32 | {
33 | "type": "string"
34 | },
35 | {
36 | "type": "null"
37 | }
38 | ],
39 | "default": null,
40 | "title": "Title In Source"
41 | },
42 | "cache_dir": {
43 | "anyOf": [
44 | {
45 | "format": "path",
46 | "type": "string"
47 | },
48 | {
49 | "type": "null"
50 | }
51 | ],
52 | "default": null,
53 | "title": "Cache Dir"
54 | }
55 | },
56 | "required": [
57 | "model_type",
58 | "title"
59 | ],
60 | "title": "ModelDefinitionSpec",
61 | "type": "object"
62 | }
63 |
--------------------------------------------------------------------------------
/sources/models/apple.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_type": "OpenCLIPModel",
3 | "title": "apple",
4 | "title_in_source": "hf-hub:apple/DFN5B-CLIP-ViT-H-14"
5 | }
6 |
--------------------------------------------------------------------------------
/sources/models/bioclip.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_type": "OpenCLIPModel",
3 | "title": "bioclip",
4 | "title_in_source": "hf-hub:imageomics/bioclip"
5 | }
6 |
--------------------------------------------------------------------------------
/sources/models/clip.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_type": "HFModel",
3 | "title": "clip",
4 | "title_in_source": "openai/clip-vit-large-patch14-336"
5 | }
6 |
--------------------------------------------------------------------------------
/sources/models/eva-clip.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_type": "OpenCLIPModel",
3 | "title": "eva-clip",
4 | "title_in_source": "BAAI/EVA-CLIP-8B-448"
5 | }
6 |
--------------------------------------------------------------------------------
/sources/models/fashion.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_type": "HFModel",
3 | "title": "fashion",
4 | "title_in_source": "patrickjohncyh/fashion-clip"
5 | }
6 |
--------------------------------------------------------------------------------
/sources/models/plip.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_type": "HFModel",
3 | "title": "plip",
4 | "title_in_source": "vinid/plip"
5 | }
6 |
--------------------------------------------------------------------------------
/sources/models/pubmed.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_type": "HFModel",
3 | "title": "pubmed",
4 | "title_in_source": "flaviagiammarino/pubmed-clip-vit-base-patch32"
5 | }
6 |
--------------------------------------------------------------------------------
/sources/models/rsicd-encord.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_type": "LocalCLIPModel",
3 | "title": "rsicd-encord",
4 | "title_in_source": "ViT-B-32"
5 | }
6 |
--------------------------------------------------------------------------------
/sources/models/rsicd.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_type": "HFModel",
3 | "title": "rsicd",
4 | "title_in_source": "flax-community/clip-rsicd"
5 | }
6 |
--------------------------------------------------------------------------------
/sources/models/siglip_large.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_type": "HFModel",
3 | "title": "siglip_large",
4 | "title_in_source": "google/siglip-large-patch16-256"
5 | }
6 |
--------------------------------------------------------------------------------
/sources/models/siglip_small.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_type": "HFModel",
3 | "title": "siglip_small",
4 | "title_in_source": "google/siglip-base-patch16-224"
5 | }
6 |
--------------------------------------------------------------------------------
/sources/models/street.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_type": "HFModel",
3 | "title": "street",
4 | "title_in_source": "geolocal/StreetCLIP"
5 | }
6 |
--------------------------------------------------------------------------------
/sources/models/tinyclip.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_type": "HFModel",
3 | "title": "tinyclip",
4 | "title_in_source": "wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M"
5 | }
6 |
--------------------------------------------------------------------------------
/sources/models/vit-b-32-laion2b.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_type": "OpenCLIPModel",
3 | "title": "vit-b-32-laion2b",
4 | "title_in_source": "ViT-B-32",
5 | "pretrained": "laion2b_e16"
6 | }
7 |
--------------------------------------------------------------------------------
/tests/common/test_embedding_definition.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from tti_eval.common import EmbeddingDefinition, Embeddings, Split
4 | from tti_eval.compute import compute_embeddings_from_definition
5 |
6 | if __name__ == "__main__":
7 | def_ = EmbeddingDefinition(
8 | model="weird_this with / stuff \\whatever",
9 | dataset="hello there dataset",
10 | )
11 | def_.embedding_path(Split.TRAIN).parent.mkdir(exist_ok=True, parents=True)
12 |
13 | images = np.random.randn(100, 20).astype(np.float32)
14 | labels = np.random.randint(0, 10, size=(100,))
15 | classes = np.random.randn(10, 20).astype(np.float32)
16 | emb = Embeddings(images=images, labels=labels, classes=classes)
17 | # emb.to_file(def_.embedding_path(Split.TRAIN))
18 | def_.save_embeddings(emb, split=Split.TRAIN, overwrite=True)
19 | new_emb = def_.load_embeddings(Split.TRAIN)
20 |
21 | assert new_emb is not None
22 | assert np.allclose(new_emb.images, images)
23 | assert np.allclose(new_emb.labels, labels)
24 |
25 | from pydantic import ValidationError
26 |
27 | try:
28 | Embeddings(
29 | images=np.random.randn(100, 20).astype(np.float32),
30 | labels=np.random.randint(0, 10, size=(100,)),
31 | classes=np.random.randn(10, 30).astype(np.float32),
32 | )
33 | raise AssertionError()
34 | except ValidationError:
35 | pass
36 |
37 | def_ = EmbeddingDefinition(
38 | model="clip",
39 | dataset="LungCancer4Types",
40 | )
41 | embeddings = compute_embeddings_from_definition(def_, Split.VALIDATION)
42 | def_.save_embeddings(embeddings, split=Split.VALIDATION, overwrite=True)
43 |
--------------------------------------------------------------------------------
/tests/evaluation/test_knn.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from tti_eval.common import ClassArray, Embeddings, ProbabilityArray
4 | from tti_eval.evaluation.knn import WeightedKNNClassifier
5 | from tti_eval.evaluation.utils import normalize, softmax
6 |
7 |
8 | def slow_knn_predict(
9 | train_embeddings: Embeddings,
10 | val_embeddings: Embeddings,
11 | num_classes: int,
12 | k: int,
13 | ) -> tuple[ProbabilityArray, ClassArray]:
14 | train_image_embeddings = normalize(train_embeddings.images)
15 | val_image_embeddings = normalize(val_embeddings.images)
16 | n = val_image_embeddings.shape[0]
17 |
18 | # Retrieve the classes and distances of the `k` nearest training embeddings to each validation embedding
19 | all_dists = np.linalg.norm(val_image_embeddings[:, np.newaxis] - train_image_embeddings[np.newaxis, :], axis=-1)
20 | nearest_indices = np.argsort(all_dists, axis=1)[:, :k]
21 | dists = all_dists[np.arange(n)[:, np.newaxis], nearest_indices]
22 | nearest_classes = np.take(train_embeddings.labels, nearest_indices)
23 |
24 | # Calculate class votes from the distances (avoiding division by zero)
25 | max_value = np.finfo(np.float32).max
26 | scores = np.divide(1, np.square(dists), out=np.full_like(dists, max_value), where=dists != 0)
27 | weighted_count = np.zeros((n, num_classes), dtype=np.float32)
28 | for cls in range(num_classes):
29 | mask = nearest_classes == cls
30 | weighted_count[:, cls] = np.ma.masked_array(scores, mask=~mask).sum(axis=1)
31 | probabilities = softmax(weighted_count)
32 | return probabilities, np.argmax(probabilities, axis=1)
33 |
34 |
35 | def test_weighted_knn_classifier():
36 | np.random.seed(42)
37 |
38 | train_embeddings = Embeddings(
39 | images=np.random.randn(100, 10).astype(np.float32),
40 | labels=np.random.randint(0, 10, size=(100,)),
41 | )
42 | val_embeddings = Embeddings(
43 | images=np.random.randn(20, 10).astype(np.float32),
44 | labels=np.random.randint(0, 10, size=(20,)),
45 | )
46 | knn = WeightedKNNClassifier(
47 | train_embeddings,
48 | val_embeddings,
49 | num_classes=10,
50 | )
51 | probs, pred_classes = knn.predict()
52 |
53 | test_probs, test_pred_classes = slow_knn_predict(
54 | train_embeddings,
55 | val_embeddings,
56 | num_classes=knn.num_classes,
57 | k=knn.k,
58 | )
59 |
60 | assert (pred_classes == test_pred_classes).all()
61 | assert np.isclose(probs, test_probs).all()
62 |
--------------------------------------------------------------------------------
/tti_eval/cli/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/encord-team/text-to-image-eval/92ff9da010e4c00de76475a57bd52ed4d95b3716/tti_eval/cli/__init__.py
--------------------------------------------------------------------------------
/tti_eval/cli/main.py:
--------------------------------------------------------------------------------
1 | from typing import Annotated, Optional
2 |
3 | import matplotlib.pyplot as plt
4 | import typer
5 | from typer import Argument, Option, Typer
6 |
7 | from tti_eval.common import Split
8 | from tti_eval.compute import compute_embeddings_from_definition
9 | from tti_eval.utils import read_all_cached_embeddings
10 |
11 | from .utils import (
12 | parse_raw_embedding_definitions,
13 | select_existing_embedding_definitions,
14 | select_from_all_embedding_definitions,
15 | )
16 |
17 | cli = Typer(name="tti-eval", no_args_is_help=True, rich_markup_mode="markdown")
18 |
19 |
20 | @cli.command(
21 | "build",
22 | help="""Build embeddings.
23 | If no arguments are given, you will be prompted to select the model and dataset combinations for generating embeddings.
24 | You can use [TAB] to select multiple combinations and execute them sequentially.
25 | """,
26 | )
27 | def build_command(
28 | model_datasets: Annotated[
29 | Optional[list[str]],
30 | Option(
31 | "--model-dataset",
32 | help="Specify a model and dataset combination. Can be used multiple times. "
33 | "(model, dataset) pairs must be presented as 'MODEL/DATASET'.",
34 | ),
35 | ] = None,
36 | include_existing: Annotated[
37 | bool,
38 | Option(help="Show combinations for which the embeddings have already been computed."),
39 | ] = False,
40 | by_dataset: Annotated[
41 | bool,
42 | Option(help="Select dataset first, then model. Will only work if `model_dataset` is not specified."),
43 | ] = False,
44 | ):
45 |     if model_datasets:  # handles both None (option not given) and an empty list
46 | definitions = parse_raw_embedding_definitions(model_datasets)
47 | else:
48 | definitions = select_from_all_embedding_definitions(
49 | include_existing=include_existing,
50 | by_dataset=by_dataset,
51 | )
52 |
53 | splits = [Split.TRAIN, Split.VALIDATION]
54 | for embd_defn in definitions:
55 | for split in splits:
56 | try:
57 | embeddings = compute_embeddings_from_definition(embd_defn, split)
58 | embd_defn.save_embeddings(embeddings=embeddings, split=split, overwrite=True)
59 | print(f"Embeddings saved successfully to file at `{embd_defn.embedding_path(split)}`")
60 | except Exception as e:
61 | print(f"Failed to build embeddings for {embd_defn} on the specified split {split}")
62 | print(e)
63 | import traceback
64 |
65 | traceback.print_exc()
66 |
67 |
68 | @cli.command(
69 | "evaluate",
70 | help="""Evaluate embeddings performance.
71 | If no arguments are given, you will be prompted to select the model and dataset combinations to evaluate.
72 | Only (model, dataset) pairs whose embeddings have been built will be available for evaluation.
73 | You can use [TAB] to select multiple combinations and execute them sequentially.
74 | """,
75 | )
76 | def evaluate_embeddings(
77 | model_datasets: Annotated[
78 | Optional[list[str]],
79 | Option(
80 | "--model-dataset",
81 | help="Specify a model and dataset combination. Can be used multiple times. "
82 | "(model, dataset) pairs must be presented as 'MODEL/DATASET'.",
83 | ),
84 | ] = None,
85 | all_: Annotated[bool, Option("--all", "-a", help="Evaluate all models.")] = False,
86 | save: Annotated[bool, Option("--save", "-s", help="Save evaluation results to a CSV file.")] = False,
87 | ):
88 | from tti_eval.evaluation import (
89 | I2IRetrievalEvaluator,
90 | LinearProbeClassifier,
91 | WeightedKNNClassifier,
92 | ZeroShotClassifier,
93 | )
94 | from tti_eval.evaluation.evaluator import export_evaluation_to_csv, run_evaluation
95 |
96 | model_datasets = model_datasets or []
97 |
98 | if all_:
99 | definitions = read_all_cached_embeddings(as_list=True)
100 | elif len(model_datasets) > 0:
101 | definitions = parse_raw_embedding_definitions(model_datasets)
102 | else:
103 | definitions = select_existing_embedding_definitions()
104 |
105 | models = [ZeroShotClassifier, LinearProbeClassifier, WeightedKNNClassifier, I2IRetrievalEvaluator]
106 | performances = run_evaluation(models, definitions)
107 | if save:
108 | export_evaluation_to_csv(performances)
109 |
110 |
111 | @cli.command(
112 | "animate",
113 | help="""Animate 2D embeddings from two different models on the same dataset.
114 | The interface will prompt you to choose which embeddings you want to use.
115 | """,
116 | )
117 | def animate_embeddings(
118 | from_model: Annotated[Optional[str], Argument(help="Title of the model in the left side of the animation.")] = None,
119 | to_model: Annotated[Optional[str], Argument(help="Title of the model in the right side of the animation.")] = None,
120 | dataset: Annotated[Optional[str], Argument(help="Title of the dataset where the embeddings were computed.")] = None,
121 | interactive: Annotated[bool, Option(help="Interactive plot instead of animation.")] = False,
122 | reduction: Annotated[str, Option(help="Reduction type [pca, tsne, umap (default)].")] = "umap",
123 | ):
124 | from tti_eval.plotting.animation import EmbeddingDefinition, build_animation, save_animation_to_file
125 |
126 | all_none_input_args = from_model is None and to_model is None and dataset is None
127 | all_str_input_args = from_model is not None and to_model is not None and dataset is not None
128 |
129 | if not all_str_input_args and not all_none_input_args:
130 |         typer.echo("Only some arguments were provided. Please either provide all of them or none at all.")
131 | raise typer.Abort()
132 |
133 | if all_none_input_args:
134 | defs = select_existing_embedding_definitions(by_dataset=True, count=2)
135 | from_def, to_def = defs[0], defs[1]
136 | else:
137 | from_def = EmbeddingDefinition(model=from_model, dataset=dataset)
138 | to_def = EmbeddingDefinition(model=to_model, dataset=dataset)
139 |
140 | res = build_animation(from_def, to_def, interactive=interactive, reduction=reduction)
141 | if res is None:
142 | plt.show()
143 | else:
144 | save_animation_to_file(res, from_def, to_def)
145 |
146 |
147 | @cli.command("list", help="List models and datasets. By default, only cached pairs are listed.")
148 | def list_models_datasets(
149 | all_: Annotated[
150 | bool,
151 | Option("--all", "-a", help="List all models and datasets that are available via the tool."),
152 | ] = False,
153 | ):
154 | from tti_eval.dataset import DatasetProvider
155 | from tti_eval.model import ModelProvider
156 |
157 | if all_:
158 | datasets = DatasetProvider.list_dataset_titles()
159 | models = ModelProvider.list_model_titles()
160 | print(f"Available datasets are: {', '.join(datasets)}")
161 | print(f"Available models are: {', '.join(models)}")
162 | return
163 |
164 | defns = read_all_cached_embeddings(as_list=True)
165 | print(f"Available model_datasets pairs: {', '.join([str(defn) for defn in defns])}")
166 |
167 |
168 | if __name__ == "__main__":
169 | cli()
170 |
--------------------------------------------------------------------------------
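Beyond the shell, the Typer app above can also be driven programmatically, which is handy for smoke tests. A minimal sketch using Typer's test runner; the `clip/plants` pair is just an illustrative combination of titles from `sources/`:

from typer.testing import CliRunner

from tti_eval.cli.main import cli

runner = CliRunner()
# Equivalent to running: tti-eval build --model-dataset clip/plants
result = runner.invoke(cli, ["build", "--model-dataset", "clip/plants"])
print(result.exit_code)
print(result.output)
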
/tti_eval/cli/utils.py:
--------------------------------------------------------------------------------
1 | from itertools import product
2 | from typing import Literal, overload
3 |
4 | from InquirerPy import inquirer as inq
5 | from InquirerPy.base.control import Choice
6 | from natsort import natsorted, ns
7 |
8 | from tti_eval.common import EmbeddingDefinition
9 | from tti_eval.dataset import DatasetProvider
10 | from tti_eval.model import ModelProvider
11 | from tti_eval.utils import read_all_cached_embeddings
12 |
13 |
14 | @overload
15 | def _do_embedding_definition_selection(
16 | defs: list[EmbeddingDefinition], allow_multiple: Literal[True] = True
17 | ) -> list[EmbeddingDefinition]:
18 | ...
19 |
20 |
21 | @overload
22 | def _do_embedding_definition_selection(
23 | defs: list[EmbeddingDefinition], allow_multiple: Literal[False]
24 | ) -> EmbeddingDefinition:
25 | ...
26 |
27 |
28 | def _do_embedding_definition_selection(
29 | defs: list[EmbeddingDefinition],
30 | allow_multiple: bool = True,
31 | ) -> list[EmbeddingDefinition] | EmbeddingDefinition:
32 | sorted_defs = natsorted(defs, key=lambda x: (x.dataset, x.model), alg=ns.IGNORECASE)
33 | choices = [Choice(d, f"D: {d.dataset[:15]:18s} M: {d.model}") for d in sorted_defs]
34 | message = "Please select the desired pairs" if allow_multiple else "Please select a pair"
35 | definitions = inq.fuzzy(
36 | message,
37 | choices=choices,
38 | multiselect=allow_multiple,
39 | vi_mode=True,
40 | ).execute() # type: ignore
41 | return definitions
42 |
43 |
44 | def _by_dataset(defs: list[EmbeddingDefinition] | dict[str, list[EmbeddingDefinition]]) -> list[EmbeddingDefinition]:
45 | if isinstance(defs, list):
46 | defs_list = defs
47 | defs = {}
48 | for d in defs_list:
49 | defs.setdefault(d.dataset, []).append(d)
50 |
51 | choices = sorted(
52 | [Choice(v, f"D: {k[:15]:18s} M: {', '.join([d.model for d in v])}") for k, v in defs.items() if len(v)],
53 | key=lambda c: len(c.value),
54 | )
55 | message = "Please select a dataset"
56 | definitions: list[EmbeddingDefinition] = inq.fuzzy(
57 | message, choices=choices, multiselect=False, vi_mode=True
58 | ).execute() # type: ignore
59 | return definitions
60 |
61 |
62 | def parse_raw_embedding_definitions(raw_embedding_definitions: list[str]) -> list[EmbeddingDefinition]:
63 | if not all([model_dataset.count("/") == 1 for model_dataset in raw_embedding_definitions]):
64 | raise ValueError("All (model, dataset) pairs must be presented as MODEL/DATASET")
65 | model_dataset_pairs = [model_dataset.split("/") for model_dataset in raw_embedding_definitions]
66 | return [
67 | EmbeddingDefinition(model=model_dataset[0], dataset=model_dataset[1]) for model_dataset in model_dataset_pairs
68 | ]
69 |
70 |
71 | def select_existing_embedding_definitions(
72 | by_dataset: bool = False,
73 | count: int | None = None,
74 | ) -> list[EmbeddingDefinition]:
75 | defs = read_all_cached_embeddings(as_list=True)
76 |
77 | if by_dataset:
78 | # Subset definitions to specific dataset
79 | defs = _by_dataset(defs)
80 |
81 | if count is None:
82 | return _do_embedding_definition_selection(defs)
83 | else:
84 | return [_do_embedding_definition_selection(defs, allow_multiple=False) for _ in range(count)]
85 |
86 |
87 | def select_from_all_embedding_definitions(
88 | include_existing: bool = False, by_dataset: bool = False
89 | ) -> list[EmbeddingDefinition]:
90 | existing = set(read_all_cached_embeddings(as_list=True))
91 |
92 | models = ModelProvider.list_model_titles()
93 | datasets = DatasetProvider.list_dataset_titles()
94 |
95 | defs = [EmbeddingDefinition(dataset=d, model=m) for d, m in product(datasets, models)]
96 | if not include_existing:
97 | defs = [d for d in defs if d not in existing]
98 |
99 | if by_dataset:
100 | defs = _by_dataset(defs)
101 |
102 | return _do_embedding_definition_selection(defs)
103 |
--------------------------------------------------------------------------------
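For illustration, this is how the raw `MODEL/DATASET` strings passed to `--model-dataset` map onto `EmbeddingDefinition` objects (a sketch using titles that exist under `sources/`):

from tti_eval.cli.utils import parse_raw_embedding_definitions

defs = parse_raw_embedding_definitions(["clip/plants", "siglip_small/rsicd"])
for d in defs:
    print(d.model, d.dataset)  # "clip plants", then "siglip_small rsicd"

# A string without exactly one "/" is rejected
try:
    parse_raw_embedding_definitions(["clip-plants"])
except ValueError as e:
    print(e)
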
/tti_eval/common/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import EmbeddingDefinition, Embeddings, Split
2 | from .numpy_types import ClassArray, EmbeddingArray, ProbabilityArray, ReductionArray
3 |
--------------------------------------------------------------------------------
/tti_eval/common/base.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from enum import Enum
3 | from pathlib import Path
4 | from typing import Annotated, Any
5 |
6 | import numpy as np
7 | from pydantic import BaseModel, model_validator
8 | from pydantic.functional_validators import AfterValidator
9 |
10 | from tti_eval.constants import NPZ_KEYS, PROJECT_PATHS
11 |
12 | from .numpy_types import ClassArray, EmbeddingArray
13 | from .string_utils import safe_str
14 |
15 | SafeName = Annotated[str, AfterValidator(safe_str)]
16 | logger = logging.getLogger("multiclips")
17 |
18 |
19 | class Split(str, Enum):
20 | TRAIN = "train"
21 | VALIDATION = "validation"
22 | TEST = "test"
23 | ALL = "all"
24 |
25 |
26 | class Embeddings(BaseModel):
27 | images: EmbeddingArray
28 | classes: EmbeddingArray | None = None
29 | labels: ClassArray
30 |
31 | @model_validator(mode="after")
32 | def validate_shapes(self) -> "Embeddings":
33 | N_images = self.images.shape[0]
34 | N_labels = self.labels.shape[0]
35 | if N_images != N_labels:
36 | raise ValueError(f"Differing number of images: {N_images} and labels: {N_labels}")
37 |
38 | if self.classes is None:
39 | return self
40 |
41 | *_, d1 = self.images.shape
42 | *_, d2 = self.classes.shape
43 |
44 | if d1 != d2:
45 | raise ValueError(
46 | f"Image {self.images.shape} and class {self.classes.shape} embeddings should have same dimensionality"
47 | )
48 | return self
49 |
50 | @staticmethod
51 | def from_file(path: Path) -> "Embeddings":
52 | if not path.suffix == ".npz":
53 | raise ValueError(f"Embedding files should be `.npz` files not {path.suffix}")
54 |
55 | loaded = np.load(path)
56 |         if NPZ_KEYS.IMAGE_EMBEDDINGS not in loaded or NPZ_KEYS.LABELS not in loaded:
57 |             raise ValueError(f"Both {NPZ_KEYS.IMAGE_EMBEDDINGS} and {NPZ_KEYS.LABELS} must be present in {path}")
58 |
59 | image_embeddings: EmbeddingArray = loaded[NPZ_KEYS.IMAGE_EMBEDDINGS]
60 | labels: ClassArray = loaded[NPZ_KEYS.LABELS]
61 | label_embeddings: EmbeddingArray | None = (
62 | loaded[NPZ_KEYS.CLASS_EMBEDDINGS] if NPZ_KEYS.CLASS_EMBEDDINGS in loaded else None
63 | )
64 | return Embeddings(images=image_embeddings, labels=labels, classes=label_embeddings)
65 |
66 | def to_file(self, path: Path) -> Path:
67 | if not path.suffix == ".npz":
68 | raise ValueError(f"Embedding files should be `.npz` files not {path.suffix}")
69 | to_store: dict[str, np.ndarray] = {
70 | NPZ_KEYS.IMAGE_EMBEDDINGS: self.images,
71 | NPZ_KEYS.LABELS: self.labels,
72 | }
73 |
74 | if self.classes is not None:
75 | to_store[NPZ_KEYS.CLASS_EMBEDDINGS] = self.classes
76 |
77 | np.savez_compressed(
78 | path,
79 | **to_store,
80 | )
81 | return path
82 |
83 | class Config:
84 | arbitrary_types_allowed = True
85 |
86 |
87 | class EmbeddingDefinition(BaseModel):
88 | model: SafeName
89 | dataset: SafeName
90 |
91 | def _get_embedding_path(self, split: Split, suffix: str) -> Path:
92 | return Path(self.dataset) / f"{self.model}_{split.value}{suffix}"
93 |
94 | def embedding_path(self, split: Split) -> Path:
95 | return PROJECT_PATHS.EMBEDDINGS / self._get_embedding_path(split, ".npz")
96 |
97 | def get_reduction_path(self, reduction_name: str, split: Split):
98 | return PROJECT_PATHS.REDUCTIONS / self._get_embedding_path(split, f".{reduction_name}.2d.npy")
99 |
100 | def load_embeddings(self, split: Split) -> Embeddings | None:
101 | """
102 |         Load the embeddings stored for this embedding definition, or return None if they cannot be read.
103 | """
104 | try:
105 | return Embeddings.from_file(self.embedding_path(split))
106 | except ValueError:
107 | return None
108 |
109 | def save_embeddings(self, embeddings: Embeddings, split: Split, overwrite: bool = False) -> bool:
110 | """
111 | Save embeddings associated to the embedding definition.
112 | Args:
113 | embeddings: The embeddings to store
114 | split: The dataset split that corresponds to the embeddings
115 | overwrite: If false, won't overwrite and will return False
116 |
117 | Returns:
118 | True iff file stored successfully
119 |
120 | """
121 | if self.embedding_path(split).is_file() and not overwrite:
122 | logger.warning(
123 | f"Not saving embeddings to file `{self.embedding_path(split)}` as overwrite is False and file exists"
124 | )
125 | return False
126 | self.embedding_path(split).parent.mkdir(exist_ok=True, parents=True)
127 | embeddings.to_file(self.embedding_path(split))
128 | return True
129 |
130 | def __str__(self):
131 | return self.model + "_" + self.dataset
132 |
133 | def __eq__(self, other: Any) -> bool:
134 | return isinstance(other, EmbeddingDefinition) and self.model == other.model and self.dataset == other.dataset
135 |
136 | def __hash__(self):
137 | return hash((self.model, self.dataset))
138 |
--------------------------------------------------------------------------------
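A short sketch of the save/load round trip that `EmbeddingDefinition` provides; the titles are arbitrary and the resulting `.npz` file lands under the embeddings cache directory defined in `tti_eval.constants`:

import numpy as np

from tti_eval.common import EmbeddingDefinition, Embeddings, Split

definition = EmbeddingDefinition(model="clip", dataset="plants")
print(definition.embedding_path(Split.TRAIN))  # <cache>/embeddings/plants/clip_train.npz

embeddings = Embeddings(
    images=np.random.randn(8, 4).astype(np.float32),
    labels=np.random.randint(0, 2, size=(8,)),
)
definition.save_embeddings(embeddings, split=Split.TRAIN, overwrite=True)
loaded = definition.load_embeddings(Split.TRAIN)
assert loaded is not None and np.allclose(loaded.images, embeddings.images)
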
/tti_eval/common/numpy_types.py:
--------------------------------------------------------------------------------
1 | from typing import Annotated, Literal, TypeVar
2 |
3 | import numpy as np
4 | import numpy.typing as npt
5 |
6 | DType = TypeVar("DType", bound=np.generic)
7 |
8 | NArray = Annotated[npt.NDArray[DType], Literal["N"]]
9 | NCArray = Annotated[npt.NDArray[DType], Literal["N", "C"]]
10 | NDArray = Annotated[npt.NDArray[DType], Literal["N", "D"]]
11 | N2Array = Annotated[npt.NDArray[DType], Literal["N", 2]]
12 |
13 | # Classes
14 | ProbabilityArray = NCArray[np.float32]
15 | """
16 | Numpy array of shape [N,C] probabilities where N is number of samples and C is number of classes.
17 | """
18 | ClassArray = NArray[np.int32]
19 |
20 | # Embeddings
21 | EmbeddingArray = NDArray[np.float32]
22 | """
23 | Numpy array of shape [N,D] embeddings where N is number of samples and D is the embedding size.
24 | """
25 |
26 | # Reductions
27 | ReductionArray = N2Array[np.float32]
28 | """
29 | Numpy array of shape [N,2] reductions of embeddings.
30 | """
31 |
--------------------------------------------------------------------------------
/tti_eval/common/string_utils.py:
--------------------------------------------------------------------------------
1 | KEEP_CHARACTERS = {".", "_", " ", "-"}
2 | REPLACE_CHARACTERS = {" ": "_"}
3 |
4 |
5 | def safe_str(unsafe: str) -> str:
6 | if not isinstance(unsafe, str):
7 | raise ValueError(f"{unsafe} ({type(unsafe)}) not a string")
8 | return "".join(REPLACE_CHARACTERS.get(c, c) for c in unsafe if c.isalnum() or c in KEEP_CHARACTERS).rstrip()
9 |
--------------------------------------------------------------------------------
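Since the `SafeName` fields on `EmbeddingDefinition` are passed through `safe_str`, model and dataset titles end up file-system friendly before they are used in cache paths. A couple of illustrative calls:

from tti_eval.common.string_utils import safe_str

print(safe_str("hello there dataset"))  # hello_there_dataset
print(safe_str("ViT-B/32"))             # ViT-B32  (the slash is dropped)
safe_str(42)                            # raises ValueError: not a string
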
/tti_eval/compute.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader
2 |
3 | from tti_eval.common import EmbeddingDefinition, Embeddings, Split
4 | from tti_eval.dataset import Dataset, DatasetProvider
5 | from tti_eval.model import Model, ModelProvider
6 |
7 |
8 | def compute_embeddings(model: Model, dataset: Dataset, batch_size: int = 50) -> Embeddings:
9 | dataset.set_transform(model.get_transform())
10 | dataloader = DataLoader(dataset, collate_fn=model.get_collate_fn(), batch_size=batch_size)
11 |
12 | image_embeddings, class_embeddings, labels = model.build_embedding(dataloader)
13 | embeddings = Embeddings(images=image_embeddings, classes=class_embeddings, labels=labels)
14 | return embeddings
15 |
16 |
17 | def compute_embeddings_from_definition(definition: EmbeddingDefinition, split: Split) -> Embeddings:
18 | model = ModelProvider.get_model(definition.model)
19 | dataset = DatasetProvider.get_dataset(definition.dataset, split)
20 | return compute_embeddings(model, dataset)
21 |
--------------------------------------------------------------------------------
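The CLI calls `compute_embeddings_from_definition`, but the lower-level `compute_embeddings` can be used directly once a model and a dataset have been fetched from their providers. A sketch with titles from `sources/`; the underlying downloads and inference are heavy, so treat it as illustrative:

from tti_eval.common import Split
from tti_eval.compute import compute_embeddings
from tti_eval.dataset import DatasetProvider
from tti_eval.model import ModelProvider

model = ModelProvider.get_model("clip")
dataset = DatasetProvider.get_dataset("plants", Split.VALIDATION)
embeddings = compute_embeddings(model, dataset, batch_size=32)
print(embeddings.images.shape)  # (num_validation_images, embedding_dim)
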
/tti_eval/constants.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 | from dotenv import load_dotenv
5 |
6 | load_dotenv()
7 |
8 | # If the cache directory is not explicitly specified, use the `.cache` directory located in the project's root.
9 | _TTI_EVAL_ROOT_DIR = Path(__file__).parent.parent
10 | CACHE_PATH = Path(os.environ.get("TTI_EVAL_CACHE_PATH", _TTI_EVAL_ROOT_DIR / ".cache"))
11 | _OUTPUT_PATH = Path(os.environ.get("TTI_EVAL_OUTPUT_PATH", _TTI_EVAL_ROOT_DIR / "output"))
12 | _SOURCES_PATH = _TTI_EVAL_ROOT_DIR / "sources"
13 |
14 |
15 | class PROJECT_PATHS:
16 | EMBEDDINGS = CACHE_PATH / "embeddings"
17 | MODELS = CACHE_PATH / "models"
18 | REDUCTIONS = CACHE_PATH / "reductions"
19 |
20 |
21 | class NPZ_KEYS:
22 | IMAGE_EMBEDDINGS = "image_embeddings"
23 | CLASS_EMBEDDINGS = "class_embeddings"
24 | LABELS = "labels"
25 |
26 |
27 | class OUTPUT_PATH:
28 | ANIMATIONS = _OUTPUT_PATH / "animations"
29 | EVALUATIONS = _OUTPUT_PATH / "evaluations"
30 |
31 |
32 | class SOURCES_PATH:
33 | DATASET_INSTANCE_DEFINITIONS = _SOURCES_PATH / "datasets"
34 | MODEL_INSTANCE_DEFINITIONS = _SOURCES_PATH / "models"
35 |
--------------------------------------------------------------------------------
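Because `load_dotenv()` runs at import time and does not override variables that are already set, the cache and output locations can be redirected either from a `.env` file or from the environment before `tti_eval` is imported. A minimal sketch with placeholder paths:

import os

# Must happen before `tti_eval.constants` is imported (or live in a .env file)
os.environ["TTI_EVAL_CACHE_PATH"] = "/tmp/tti-eval-cache"    # hypothetical location
os.environ["TTI_EVAL_OUTPUT_PATH"] = "/tmp/tti-eval-output"  # hypothetical location

from tti_eval.constants import CACHE_PATH, OUTPUT_PATH

print(CACHE_PATH)               # /tmp/tti-eval-cache
print(OUTPUT_PATH.EVALUATIONS)  # /tmp/tti-eval-output/evaluations
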
/tti_eval/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import Dataset, Split
2 | from .provider import DatasetProvider
3 |
--------------------------------------------------------------------------------
/tti_eval/dataset/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from os.path import relpath
3 | from pathlib import Path
4 |
5 | from pydantic import BaseModel, ConfigDict
6 | from torch.utils.data import Dataset as TorchDataset
7 |
8 | from tti_eval.common import Split
9 | from tti_eval.constants import CACHE_PATH, SOURCES_PATH
10 |
11 | DEFAULT_DATASET_TYPES_LOCATION = (
12 | Path(relpath(str(__file__), SOURCES_PATH.DATASET_INSTANCE_DEFINITIONS)).parent / "types" / "__init__.py"
13 | )
14 |
15 |
16 | class DatasetDefinitionSpec(BaseModel):
17 | dataset_type: str
18 | title: str
19 | module_path: Path = DEFAULT_DATASET_TYPES_LOCATION
20 | split: Split | None = None
21 | title_in_source: str | None = None
22 | cache_dir: Path | None = None
23 |
24 | # Allow additional dataset configuration fields
25 | model_config = ConfigDict(extra="allow")
26 |
27 |
28 | class Dataset(TorchDataset, ABC):
29 | def __init__(
30 | self,
31 | title: str,
32 | *,
33 | split: Split = Split.ALL,
34 | title_in_source: str | None = None,
35 | transform=None,
36 | cache_dir: str | None = None,
37 | **kwargs,
38 | ):
39 | self.transform = transform
40 | self.__title = title
41 | self.__title_in_source = title if title_in_source is None else title_in_source
42 | self.__split = split
43 | self.__class_names = []
44 | if cache_dir is None:
45 | cache_dir = CACHE_PATH
46 | self._cache_dir: Path = Path(cache_dir).expanduser().resolve() / "datasets" / title
47 |
48 | @abstractmethod
49 | def __getitem__(self, idx):
50 | pass
51 |
52 | @abstractmethod
53 | def __len__(self):
54 | pass
55 |
56 | @property
57 | def split(self) -> Split:
58 | return self.__split
59 |
60 | @property
61 | def title(self) -> str:
62 | return self.__title
63 |
64 | @property
65 | def title_in_source(self) -> str:
66 | return self.__title_in_source
67 |
68 | @property
69 | def class_names(self) -> list[str]:
70 | return self.__class_names
71 |
72 | @class_names.setter
73 | def class_names(self, class_names: list[str]) -> None:
74 | self.__class_names = class_names
75 |
76 | def set_transform(self, transform):
77 | self.transform = transform
78 |
79 | @property
80 | def text_queries(self) -> list[str]:
81 | return [f"An image of a {class_name}" for class_name in self.class_names]
82 |
83 | @abstractmethod
84 | def _setup(self, **kwargs):
85 | pass
86 |
--------------------------------------------------------------------------------
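A new dataset type only has to subclass `Dataset`, implement `__getitem__`, `__len__` and `_setup`, and it can then be referenced from a JSON definition through `dataset_type` and `module_path`. A bare-bones sketch with made-up in-memory data:

from tti_eval.dataset import Dataset, Split


class InMemoryDataset(Dataset):  # hypothetical example type
    def __init__(self, title: str, *, split: Split = Split.ALL, **kwargs):
        super().__init__(title, split=split, **kwargs)
        self._setup(**kwargs)

    def _setup(self, **kwargs):
        # A real implementation would download or read data here
        self.class_names = ["cat", "dog"]
        self._items = [{"image": None, "label": 0}, {"image": None, "label": 1}]

    def __getitem__(self, idx):
        item = self._items[idx]
        return self.transform(item) if self.transform is not None else item

    def __len__(self):
        return len(self._items)
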
/tti_eval/dataset/provider.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | from pathlib import Path
3 | from typing import Any
4 |
5 | from natsort import natsorted, ns
6 |
7 | from tti_eval.constants import CACHE_PATH, SOURCES_PATH
8 |
9 | from .base import Dataset, DatasetDefinitionSpec, Split
10 | from .utils import load_class_from_path
11 |
12 |
13 | class DatasetProvider:
14 | __instance = None
15 | __global_settings: dict[str, Any] = dict()
16 | __known_dataset_types: dict[tuple[Path, str], Any] = dict()
17 |
18 | def __init__(self):
19 | self._datasets = {}
20 |
21 | @classmethod
22 | def prepare(cls):
23 | if cls.__instance is None:
24 | cls.__instance = cls()
25 | cls.register_datasets_from_sources_dir(SOURCES_PATH.DATASET_INSTANCE_DEFINITIONS)
26 | # Global settings
27 | cls.__instance.add_global_setting("cache_dir", CACHE_PATH)
28 | return cls.__instance
29 |
30 | @classmethod
31 | def global_settings(cls) -> dict:
32 | return deepcopy(cls.__global_settings)
33 |
34 | @classmethod
35 | def add_global_setting(cls, name: str, value: Any) -> None:
36 | cls.__global_settings[name] = value
37 |
38 | @classmethod
39 | def remove_global_setting(cls, name: str) -> None:
40 | cls.__global_settings.pop(name, None)
41 |
42 | @classmethod
43 | def register_dataset(cls, source: type[Dataset], title: str, split: Split | None = None, **kwargs) -> None:
44 | instance = cls.prepare()
45 | if split is None:
46 | # One dataset with all the split definitions
47 | instance._datasets[title] = (source, kwargs)
48 | else:
49 | # One dataset is defined per split
50 | kwargs.update(split=split)
51 | instance._datasets[(title, split)] = (source, kwargs)
52 |
53 | @classmethod
54 | def register_dataset_from_json_definition(cls, json_definition: Path) -> None:
55 | spec = DatasetDefinitionSpec.model_validate_json(json_definition.read_text(encoding="utf-8"))
56 | if not spec.module_path.is_absolute(): # Handle relative module paths
57 | spec.module_path = (json_definition.parent / spec.module_path).resolve()
58 | if not spec.module_path.is_file():
59 | raise FileNotFoundError(
60 | f"Could not find the specified module path at `{spec.module_path.as_posix()}` "
61 | f"when registering dataset `{spec.title}`"
62 | )
63 |
64 | # Fetch the class of the dataset type stated in the definition
65 | dataset_type = cls.__known_dataset_types.get((spec.module_path, spec.dataset_type))
66 | if dataset_type is None:
67 | dataset_type = load_class_from_path(spec.module_path.as_posix(), spec.dataset_type)
68 | if not issubclass(dataset_type, Dataset):
69 | raise ValueError(
70 | f"Dataset type specified in the JSON definition file `{json_definition.as_posix()}` "
71 | f"does not inherit from the base class `Dataset`"
72 | )
73 | cls.__known_dataset_types[(spec.module_path, spec.dataset_type)] = dataset_type
74 | cls.register_dataset(dataset_type, **spec.model_dump(exclude={"module_path", "dataset_type"}))
75 |
76 | @classmethod
77 | def register_datasets_from_sources_dir(cls, source_dir: Path) -> None:
78 | for f in source_dir.glob("*.json"):
79 | cls.register_dataset_from_json_definition(f)
80 |
81 | @classmethod
82 | def get_dataset(cls, title: str, split: Split) -> Dataset:
83 | instance = cls.prepare()
84 | if (title, split) in instance._datasets:
85 | # The split corresponds to fetching a whole dataset (one-to-one relationship)
86 | dict_key = (title, split)
87 | split = Split.ALL # Ensure to read the whole dataset
88 | elif title in instance._datasets:
89 | # The dataset internally knows how to determine the split
90 | dict_key = title
91 | else:
92 | raise ValueError(f"Unrecognized dataset: {title}")
93 |
94 | source, kwargs = instance._datasets[dict_key]
95 | # Apply global settings. Values of local settings take priority when local and global settings share keys.
96 | kwargs_with_global_settings = cls.global_settings() | kwargs
97 | return source(title, split=split, **kwargs_with_global_settings)
98 |
99 | @classmethod
100 | def list_dataset_titles(cls) -> list[str]:
101 | dataset_titles = [
102 | dict_key[0] if isinstance(dict_key, tuple) else dict_key for dict_key in cls.prepare()._datasets.keys()
103 | ]
104 | return natsorted(set(dataset_titles), alg=ns.IGNORECASE)
105 |
--------------------------------------------------------------------------------
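The JSON files under `sources/datasets/` are registered automatically by `prepare()`, but the same registry can be extended at runtime. A sketch that registers the Hugging Face source from `sources/datasets/sports-classification.json` under a second title; fetching it triggers a full dataset download, so it is illustrative only:

from tti_eval.dataset import DatasetProvider, Split
from tti_eval.dataset.types import HFDataset

DatasetProvider.register_dataset(HFDataset, "sports-manual", title_in_source="HES-XPLAIN/SportsImageClassification")

print("sports-manual" in DatasetProvider.list_dataset_titles())  # True
dataset = DatasetProvider.get_dataset("sports-manual", Split.VALIDATION)
print(len(dataset), dataset.class_names[:3])
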
/tti_eval/dataset/types/__init__.py:
--------------------------------------------------------------------------------
1 | from tti_eval.dataset.types.encord_ds import EncordDataset
2 | from tti_eval.dataset.types.hugging_face import HFDataset
3 |
--------------------------------------------------------------------------------
/tti_eval/dataset/types/encord_ds.py:
--------------------------------------------------------------------------------
1 | import json
2 | import multiprocessing
3 | import os
4 | from collections.abc import Callable
5 | from concurrent.futures import ThreadPoolExecutor
6 | from dataclasses import dataclass
7 | from functools import partial
8 | from pathlib import Path
9 | from typing import Any
10 |
11 | from encord import EncordUserClient, Project
12 | from encord.common.constants import DATETIME_STRING_FORMAT
13 | from encord.objects import Classification, LabelRowV2
14 | from encord.objects.common import PropertyType
15 | from encord.orm.dataset import DataType, Image, Video
16 | from PIL import Image as PILImage
17 | from tqdm.auto import tqdm
18 |
19 | from tti_eval.dataset import Dataset, Split
20 | from tti_eval.dataset.utils import collect_async, download_file, simple_random_split
21 |
22 |
23 | class EncordDataset(Dataset):
24 | def __init__(
25 | self,
26 | title: str,
27 | project_hash: str,
28 | classification_hash: str,
29 | *,
30 | split: Split = Split.ALL,
31 | title_in_source: str | None = None,
32 | transform=None,
33 | cache_dir: str | None = None,
34 | ssh_key_path: str | None = None,
35 | **kwargs,
36 | ):
37 | super().__init__(
38 | title,
39 | split=split,
40 | title_in_source=title_in_source,
41 | transform=transform,
42 | cache_dir=cache_dir,
43 | )
44 | self._setup(project_hash, classification_hash, ssh_key_path, **kwargs)
45 |
46 | def __getitem__(self, idx):
47 | frame_path = self._dataset_indices_info[idx].image_file
48 | img = PILImage.open(frame_path)
49 | label = self._dataset_indices_info[idx].label
50 |
51 | if self.transform is not None:
52 | _d = self.transform(dict(image=[img], label=[label]))
53 | res_item = dict(image=_d["image"][0], label=_d["label"][0])
54 | else:
55 | res_item = dict(image=img, label=label)
56 | return res_item
57 |
58 | def __len__(self):
59 | return len(self._dataset_indices_info)
60 |
61 | def _get_frame_file(self, label_row: LabelRowV2, frame: int) -> Path:
62 | return get_frame_file(
63 | data_dir=self._cache_dir,
64 | project_hash=self._project.project_hash,
65 | label_row=label_row,
66 | frame=frame,
67 | )
68 |
69 | def _get_label_row_annotations_file(self, label_row: LabelRowV2) -> Path:
70 | return get_label_row_annotations_file(
71 | data_dir=self._cache_dir,
72 | project_hash=self._project.project_hash,
73 | label_row_hash=label_row.label_hash,
74 | )
75 |
76 | def _ensure_answers_availability(self) -> dict:
77 | lrs_info_file = get_label_rows_info_file(self._cache_dir, self._project.project_hash)
78 | label_rows_info: dict = json.loads(lrs_info_file.read_text(encoding="utf-8"))
79 | should_update_info = False
80 | class_name_to_idx = {name: idx for idx, name in enumerate(self.class_names)} # Fast lookup of class indices
81 | for label_row in self._label_rows:
82 | if "answers" not in label_rows_info[label_row.label_hash]:
83 | if not label_row.is_labelling_initialised:
84 | # Retrieve label row content from local storage
85 | anns_path = self._get_label_row_annotations_file(label_row)
86 | label_row.from_labels_dict(json.loads(anns_path.read_text(encoding="utf-8")))
87 |
88 | answers = dict()
89 | for frame_view in label_row.get_frame_views():
90 | clf_instances = frame_view.get_classification_instances(self._classification)
91 | # Skip frames where the input classification is missing
92 | if len(clf_instances) == 0:
93 | continue
94 |
95 | clf_instance = clf_instances[0]
96 | clf_answer = clf_instance.get_answer(self._attribute)
97 | # Skip frames where the input classification has no answer (probable annotation error)
98 | if clf_answer is None:
99 | continue
100 |
101 | answers[frame_view.frame] = {
102 | "image_file": self._get_frame_file(label_row, frame_view.frame).as_posix(),
103 | "label": class_name_to_idx[clf_answer.title],
104 | }
105 | label_rows_info[label_row.label_hash]["answers"] = answers
106 | should_update_info = True
107 | if should_update_info:
108 | lrs_info_file.write_text(json.dumps(label_rows_info), encoding="utf-8")
109 | return label_rows_info
110 |
111 | def _setup(
112 | self,
113 | project_hash: str,
114 | classification_hash: str,
115 | ssh_key_path: str | None = None,
116 | **kwargs,
117 | ):
118 | ssh_key_path = ssh_key_path or os.getenv("ENCORD_SSH_KEY_PATH")
119 | if ssh_key_path is None:
120 | raise ValueError(
121 | "The `ssh_key_path` parameter and the `ENCORD_SSH_KEY_PATH` environment variable are both missing. "
122 | "Please set one of them to proceed."
123 | )
124 | client = EncordUserClient.create_with_ssh_private_key(ssh_private_key_path=ssh_key_path)
125 | self._project = client.get_project(project_hash)
126 |
127 | self._classification = self._project.ontology_structure.get_child_by_hash(
128 | classification_hash, type_=Classification
129 | )
130 | radio_attribute = self._classification.attributes[0]
131 | if radio_attribute.get_property_type() != PropertyType.RADIO:
132 | raise ValueError("Expected a classification hash with an attribute of type `Radio`")
133 | self._attribute = radio_attribute
134 | self.class_names = [o.title for o in self._attribute.options]
135 |
136 | # Fetch the label rows of the selected split
137 | splits_file = self._cache_dir / "splits.json"
138 | split_to_lr_hashes: dict[str, list[str]]
139 | if splits_file.exists():
140 | split_to_lr_hashes = json.loads(splits_file.read_text(encoding="utf-8"))
141 | else:
142 | split_to_lr_hashes = simple_project_split(self._project)
143 | splits_file.parent.mkdir(parents=True, exist_ok=True)
144 | splits_file.write_text(json.dumps(split_to_lr_hashes), encoding="utf-8")
145 | self._label_rows = self._project.list_label_rows_v2(label_hashes=split_to_lr_hashes[self.split])
146 |
147 | # Get data from source. Users may supply the `overwrite_annotations` keyword in the init to download everything
148 | download_data_from_project(
149 | self._project,
150 | self._cache_dir,
151 | self._label_rows,
152 | tqdm_desc=f"Downloading {self.split} data from Encord project `{self._project.title}`",
153 | **kwargs,
154 | )
155 |
156 | # Prepare data for the __getitem__ method
157 | self._dataset_indices_info: list[EncordDataset.DatasetIndexInfo] = []
158 | label_rows_info = self._ensure_answers_availability()
159 | for label_row in self._label_rows:
160 | answers: dict[int, Any] = label_rows_info[label_row.label_hash]["answers"]
161 | for frame_num in sorted(answers.keys()):
162 | self._dataset_indices_info.append(EncordDataset.DatasetIndexInfo(**answers[frame_num]))
163 |
164 | @dataclass
165 | class DatasetIndexInfo:
166 | image_file: Path | str
167 | label: int
168 |
169 |
170 | # -----------------------------------------------------------------------
171 | # UTILITY FUNCTIONS
172 | # -----------------------------------------------------------------------
173 |
174 |
175 | def _download_image(image_data: Image | Video, destination_dir: Path) -> Path:
176 |     # TODO The type of `image_data` is also Video because of an SDK bug explained in `_download_label_row_image_data`.
177 | file_name = get_frame_name(image_data.image_hash, image_data.title)
178 | destination_path = destination_dir / file_name
179 | if not destination_path.exists():
180 | download_file(image_data.file_link, destination_path)
181 | return destination_path
182 |
183 |
184 | def _download_label_row_image_data(data_dir: Path, project: Project, label_row: LabelRowV2) -> list[Path]:
185 | label_row.initialise_labels()
186 | label_row_dir = get_label_row_dir(data_dir, project.project_hash, label_row.label_hash)
187 | label_row_dir.mkdir(parents=True, exist_ok=True)
188 |
189 | if label_row.data_type == DataType.IMAGE:
190 |         # TODO This `if` is here because of an SDK bug, remove it when IMAGE data is stored in the proper image field [1]
191 | images_data = [project.get_data(label_row.data_hash, get_signed_url=True)[0]]
192 | # Missing field caused by the SDK bug
193 | images_data[0]["image_hash"] = label_row.data_hash
194 | else:
195 | images_data = project.get_data(label_row.data_hash, get_signed_url=True)[1]
196 | return collect_async(
197 | lambda image_data: _download_image(image_data, label_row_dir),
198 | images_data,
199 | max_workers=4,
200 | disable=True,
201 | )
202 |
203 |
204 | def _download_label_row(
205 | label_row: LabelRowV2,
206 | project: Project,
207 | data_dir: Path,
208 | overwrite_annotations: bool,
209 | label_rows_info: dict[str, Any],
210 | update_pbar: Callable[[], Any],
211 | ):
212 | if label_row.data_type not in {DataType.IMAGE, DataType.IMG_GROUP}:
213 | return
214 | save_annotations = False
215 | # Trigger the images download if the label hash is not found or is None (never downloaded).
216 | if label_row.label_hash not in label_rows_info.keys():
217 | _download_label_row_image_data(data_dir, project, label_row)
218 | save_annotations = True
219 | # Overwrite annotations only if `last_edited_at` values differ between the existing and new annotations.
220 | elif (
221 | overwrite_annotations
222 | and label_row.last_edited_at.strftime(DATETIME_STRING_FORMAT)
223 | != label_rows_info[label_row.label_hash]["last_edited_at"]
224 | ):
225 | label_row.initialise_labels()
226 | save_annotations = True
227 |
228 | if save_annotations:
229 | annotations_file = get_label_row_annotations_file(data_dir, project.project_hash, label_row.label_hash)
230 | annotations_file.write_text(json.dumps(label_row.to_encord_dict()), encoding="utf-8")
231 |         label_rows_info[label_row.label_hash] = {"last_edited_at": label_row.last_edited_at.strftime(DATETIME_STRING_FORMAT)}
232 | update_pbar()
233 |
234 |
235 | def _download_label_rows(
236 | project: Project,
237 | data_dir: Path,
238 | label_rows: list[LabelRowV2],
239 | overwrite_annotations: bool,
240 | label_rows_info: dict[str, Any],
241 | tqdm_desc: str | None,
242 | ):
243 | if tqdm_desc is None:
244 | tqdm_desc = f"Downloading data from Encord project `{project.title}`"
245 |
246 | pbar = tqdm(total=len(label_rows), desc=tqdm_desc)
247 | _do_download = partial(
248 | _download_label_row,
249 | project=project,
250 | data_dir=data_dir,
251 | overwrite_annotations=overwrite_annotations,
252 | label_rows_info=label_rows_info,
253 | update_pbar=lambda: pbar.update(1),
254 | )
255 |
256 | with ThreadPoolExecutor(min(multiprocessing.cpu_count(), 24)) as exe:
257 | exe.map(_do_download, label_rows)
258 |
259 |
260 | def download_data_from_project(
261 | project: Project,
262 | data_dir: Path,
263 | label_rows: list[LabelRowV2] | None = None,
264 | overwrite_annotations: bool = False,
265 | tqdm_desc: str | None = None,
266 | ) -> None:
267 | """
268 | Iterates through the images of the project and downloads their content, adhering to the following file structure:
269 | data_dir/
270 | ├── project-hash/
271 | │ ├── image-group-hash1/
272 | │ │ ├── image1.jpg
273 | │ │ ├── image2.jpg
274 | │ │ └── ...
275 | │ │ └── annotations.json
276 | │ └── image-hash1/
277 | │ │ ├── image1.jpeg
278 | │ │ ├── annotations.json
279 | │ └── ...
280 | └── ...
281 | :param project: The project containing the images with their annotations.
282 | :param data_dir: The directory where the project data will be downloaded.
283 | :param label_rows: The label rows that will be downloaded. If None, all label rows will be downloaded.
284 | :param overwrite_annotations: Flag that indicates whether to overwrite existing annotations if they exist.
285 | :param tqdm_desc: Optional description for tqdm progress bar.
286 | Defaults to 'Downloading data from Encord project `{project.title}`'
287 | """
288 | # Read file that tracks the downloaded data progress
289 | lrs_info_file = get_label_rows_info_file(data_dir, project.project_hash)
290 | lrs_info_file.parent.mkdir(parents=True, exist_ok=True)
291 | label_rows_info = json.loads(lrs_info_file.read_text(encoding="utf-8")) if lrs_info_file.is_file() else dict()
292 |
293 | # Retrieve only the unseen data if there is no explicit annotation update
294 | filtered_label_rows = (
295 | label_rows
296 | if overwrite_annotations
297 | else [lr for lr in label_rows if lr.label_hash not in label_rows_info.keys()]
298 | )
299 | if len(filtered_label_rows) == 0:
300 | return
301 |
302 | try:
303 | _download_label_rows(
304 | project,
305 | data_dir,
306 | filtered_label_rows,
307 | overwrite_annotations,
308 | label_rows_info,
309 | tqdm_desc=tqdm_desc,
310 | )
311 | finally:
312 | # Save the current download progress in case of failure
313 | lrs_info_file.write_text(json.dumps(label_rows_info), encoding="utf-8")
314 |
315 |
316 | def get_frame_name(frame_hash: str, frame_title: str) -> str:
317 | file_extension = frame_title.rsplit(sep=".", maxsplit=1)[-1]
318 | return f"{frame_hash}.{file_extension}"
319 |
320 |
321 | def get_frame_file(data_dir: Path, project_hash: str, label_row: LabelRowV2, frame: int) -> Path:
322 | label_row_dir = get_label_row_dir(data_dir, project_hash, label_row.label_hash)
323 | frame_view = label_row.get_frame_view(frame)
324 | return label_row_dir / get_frame_name(frame_view.image_hash, frame_view.image_title)
325 |
326 |
327 | def get_frame_file_raw(
328 | data_dir: Path,
329 | project_hash: str,
330 | label_row_hash: str,
331 | frame_hash: str,
332 | frame_title: str,
333 | ) -> Path:
334 | return get_label_row_dir(data_dir, project_hash, label_row_hash) / get_frame_name(frame_hash, frame_title)
335 |
336 |
337 | def get_label_row_annotations_file(data_dir: Path, project_hash: str, label_row_hash: str) -> Path:
338 | return get_label_row_dir(data_dir, project_hash, label_row_hash) / "annotations.json"
339 |
340 |
341 | def get_label_row_dir(data_dir: Path, project_hash: str, label_row_hash: str) -> Path:
342 | return data_dir / project_hash / label_row_hash
343 |
344 |
345 | def get_label_rows_info_file(data_dir: Path, project_hash: str) -> Path:
346 | return data_dir / project_hash / "label_rows_info.json"
347 |
348 |
349 | def simple_project_split(
350 | project: Project,
351 | seed: int = 42,
352 | train_split: float = 0.7,
353 | validation_split: float = 0.15,
354 | ) -> dict[Split, list[str]]:
355 | """
356 | Split the label rows of a project into training, validation, and test sets using simple random splitting.
357 |
358 | :param project: The project containing the label rows to split.
359 | :param seed: Random seed for reproducibility. Defaults to 42.
360 | :param train_split: Percentage of the dataset to allocate to the training set. Defaults to 0.7.
361 | :param validation_split: Percentage of the dataset to allocate to the validation set. Defaults to 0.15.
362 | :return: A dictionary containing lists with the label hashes of the data represented in the training,
363 | validation, and test sets.
364 |
365 | :raises ValueError: If the sum of `train_split` and `validation_split` is greater than 1,
366 | or if `train_split` or `validation_split` are less than 0.
367 | """
368 | label_rows = project.list_label_rows_v2()
369 | split_to_indices = simple_random_split(len(label_rows), seed, train_split, validation_split)
370 | enforce_label_rows_initialization(label_rows) # Ensure that all label rows have a label hash
371 | return {split: [label_rows[i].label_hash for i in indices] for split, indices in split_to_indices.items()}
372 |
373 |
374 | def enforce_label_rows_initialization(label_rows: list[LabelRowV2]):
375 | for lr in label_rows:
376 | if lr.label_hash is None:
377 | lr.initialise_labels()
378 |
--------------------------------------------------------------------------------
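For completeness, an Encord-backed dataset is wired up like the Hugging Face ones, except that it additionally needs the project hash, the (radio) classification hash and an SSH key for the Encord client. A sketch in which every value is a placeholder; a real project hash, classification hash and key path are required for this to run:

from tti_eval.dataset import Split
from tti_eval.dataset.types import EncordDataset

dataset = EncordDataset(
    title="my-encord-dataset",                              # hypothetical title
    project_hash="00000000-0000-0000-0000-000000000000",    # hypothetical project hash
    classification_hash="aaaaaaaa",                         # hypothetical classification hash
    split=Split.TRAIN,
    ssh_key_path="~/.ssh/encord_key",                       # or set ENCORD_SSH_KEY_PATH instead
)
print(len(dataset), dataset.class_names)
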
/tti_eval/dataset/types/hugging_face.py:
--------------------------------------------------------------------------------
1 | from datasets import ClassLabel, DatasetDict, Sequence, Value, load_dataset
2 | from datasets import Dataset as _RemoteHFDataset
3 |
4 | from tti_eval.dataset import Dataset, Split
5 |
6 |
7 | class HFDataset(Dataset):
8 | def __init__(
9 | self,
10 | title: str,
11 | *,
12 | split: Split = Split.ALL,
13 | title_in_source: str | None = None,
14 | transform=None,
15 | cache_dir: str | None = None,
16 | target_feature: str = "label",
17 | **kwargs,
18 | ):
19 | super().__init__(title, split=split, title_in_source=title_in_source, transform=transform, cache_dir=cache_dir)
20 | self._target_feature = target_feature
21 | self._setup(**kwargs)
22 |
23 | def __getitem__(self, idx):
24 | return self._dataset[idx]
25 |
26 | def __len__(self):
27 | return len(self._dataset)
28 |
29 | def set_transform(self, transform):
30 | super().set_transform(transform)
31 | self._dataset.set_transform(transform)
32 |
33 | @staticmethod
34 | def _get_available_splits(dataset_dict: DatasetDict) -> list[Split]:
35 | return [Split(s) for s in dataset_dict.keys() if s in [_ for _ in Split]] + [Split.ALL]
36 |
37 | def _get_hf_dataset_split(self, **kwargs) -> _RemoteHFDataset:
38 | try:
39 | if self.split == Split.ALL: # Retrieve all the dataset data if the split is ALL
40 | return load_dataset(self.title_in_source, split="all", cache_dir=self._cache_dir.as_posix(), **kwargs)
41 | dataset_dict: DatasetDict = load_dataset(
42 | self.title_in_source,
43 | cache_dir=self._cache_dir.as_posix(),
44 | **kwargs,
45 | )
46 | except Exception as e:
47 | raise ValueError(f"Failed to load dataset from Hugging Face: `{self.title_in_source}`") from e
48 |
49 | available_splits = HFDataset._get_available_splits(dataset_dict)
50 | missing_splits = [s for s in Split if s not in available_splits]
51 |
52 | # Return target dataset if it already exists and won't be modified
53 | if self.split in [Split.VALIDATION, Split.TEST] and self.split in available_splits:
54 | return dataset_dict[self.split]
55 | if self.split == Split.TRAIN:
56 | if self.split in missing_splits:
57 | # Train split must always exist
58 | raise AttributeError(f"Missing train split in Hugging Face dataset: `{self.title_in_source}`")
59 | if not missing_splits:
60 | # No need to split the train dataset, can be returned as a whole
61 | return dataset_dict[self.split]
62 |
63 |         # Take 15% of the train dataset for each missing split (VALIDATION, TEST or both)
64 | # This operation includes data shuffling to prevent splits with skewed class counts because of the input order
65 | split_percent = 0.15 * len(missing_splits)
66 | split_seed = 42
67 | # Split the original train dataset into two, the final train dataset and the missing splits dataset
68 | split_to_dataset = dataset_dict["train"].train_test_split(test_size=split_percent, seed=split_seed)
69 | if self.split == Split.TRAIN:
70 | return split_to_dataset[self.split]
71 |
72 | if len(missing_splits) == 1:
73 | # One missing split (either VALIDATION or TEST), so we return the 15% stored in "test"
74 | return split_to_dataset["test"]
75 | else:
76 | # Both VALIDATION and TEST splits are missing
77 | # Each one will take a half of the 30% stored in "test"
78 | if self.split == Split.VALIDATION:
79 | return split_to_dataset["test"].train_test_split(test_size=0.5, seed=split_seed)["train"]
80 | else:
81 | return split_to_dataset["test"].train_test_split(test_size=0.5, seed=split_seed)["test"]
82 |
83 | def _setup(self, **kwargs):
84 | self._dataset = self._get_hf_dataset_split(**kwargs)
85 |
86 | if self._target_feature not in self._dataset.features:
87 | raise ValueError(f"The dataset `{self.title}` does not have the target feature `{self._target_feature}`")
88 | # Encode the target feature if necessary (e.g. use integer instead of string in the label values)
89 | if isinstance(self._dataset.features[self._target_feature], Value):
90 | self._dataset = self._dataset.class_encode_column(self._target_feature)
91 |
92 | # Rename the target feature to `label`
93 | # TODO do not rename the target feature but use it in the embeddings computations instead of the `label` tag
94 | if self._target_feature != "label":
95 | self._dataset = self._dataset.rename_column(self._target_feature, "label")
96 |
97 | label_feature = self._dataset.features["label"]
98 | if isinstance(label_feature, Sequence): # Drop potential wrapper
99 | label_feature = label_feature.feature
100 |
101 | if isinstance(label_feature, ClassLabel):
102 | self.class_names = label_feature.names
103 | else:
104 | raise TypeError(f"Expected target feature of type `ClassLabel`, found `{type(label_feature).__name__}`")
105 |
106 | def __del__(self):
107 | for f in self._cache_dir.glob("**/*.lock"):
108 | f.unlink(missing_ok=True)
109 |
--------------------------------------------------------------------------------
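Datasets whose label column is not called `label` can be wired in through the `target_feature` argument, which `_setup` renames internally (and class-encodes when the column holds raw strings). A sketch; the dataset identifier and column name below are hypothetical stand-ins:

from tti_eval.dataset import Split
from tti_eval.dataset.types import HFDataset

dataset = HFDataset(
    "my-dataset",                              # hypothetical local title
    title_in_source="some-user/some-dataset",  # hypothetical Hugging Face identifier
    split=Split.TRAIN,
    target_feature="category",                 # hypothetical label column name
)
print(dataset.class_names)
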
/tti_eval/dataset/utils.py:
--------------------------------------------------------------------------------
1 | import importlib.util
2 | import os
3 | from collections.abc import Callable
4 | from concurrent.futures import ThreadPoolExecutor as Executor
5 | from concurrent.futures import as_completed
6 | from pathlib import Path
7 | from typing import Any, TypeVar
8 |
9 | import numpy as np
10 | import requests
11 | from tqdm.auto import tqdm
12 |
13 | from tti_eval.common import Split
14 |
15 | T = TypeVar("T")
16 | G = TypeVar("G")
17 |
18 | _default_max_workers = min(10, (os.cpu_count() or 1) + 4)
19 |
20 |
21 | def collect_async(
22 | fn: Callable[[T], G],
23 | job_args: list[T],
24 | max_workers=_default_max_workers,
25 | **kwargs,
26 | ) -> list[G]:
27 | """
28 | Distribute work across multiple workers. Good for, e.g., downloading data.
29 |     Will return the results in a list.
30 | :param fn: The function to be applied
31 | :param job_args: Arguments to `fn`.
32 | :param max_workers: Number of workers to distribute work over.
33 | :param kwargs: Arguments passed on to tqdm.
34 | :return: List [fn(*job_args)]
35 | """
36 | if len(job_args) == 0:
37 | tmp: list[G] = []
38 | return tmp
39 | if not isinstance(job_args[0], tuple):
40 | _job_args: list[tuple[Any]] = [(j,) for j in job_args]
41 | else:
42 | _job_args = job_args # type: ignore
43 |
44 | results: list[G] = []
45 | with tqdm(total=len(job_args), **kwargs) as pbar:
46 | with Executor(max_workers=max_workers) as exe:
47 | jobs = [exe.submit(fn, *args) for args in _job_args]
48 | for job in as_completed(jobs):
49 | result = job.result()
50 | if result is not None:
51 | results.append(result)
52 | pbar.update(1)
53 | return results
54 |
55 |
56 | def download_file(
57 | url: str,
58 | destination: Path,
59 | ) -> None:
60 | destination.parent.mkdir(parents=True, exist_ok=True)
61 | with open(destination, "wb") as f:
62 | r = requests.get(url, stream=True)
63 | if r.status_code != 200:
64 |             raise ConnectionError(f"Failed to download file from {url} (status code {r.status_code})")
65 |
66 | for chunk in r.iter_content(chunk_size=1024):
67 | if chunk:
68 | f.write(chunk)
69 | f.flush()
70 |
71 |
72 | def load_class_from_path(module_path: str, class_name: str):
73 | spec = importlib.util.spec_from_file_location(module_path, module_path)
74 | module = importlib.util.module_from_spec(spec)
75 | spec.loader.exec_module(module)
76 | return getattr(module, class_name)
77 |
78 |
79 | def simple_random_split(
80 | dataset_size: int,
81 | seed: int = 42,
82 | train_split: float = 0.7,
83 | validation_split: float = 0.15,
84 | ) -> dict[Split, np.ndarray]:
85 | """
86 | Split the dataset into training, validation, and test sets using simple random splitting.
87 |
88 | :param dataset_size: The total size of the dataset.
89 | :param seed: Random seed for reproducibility. Defaults to 42.
90 | :param train_split: Percentage of the dataset to allocate to the training set. Defaults to 0.7.
91 | :param validation_split: Percentage of the dataset to allocate to the validation set. Defaults to 0.15.
92 | :return: A dictionary containing arrays with the indices of the data represented in the training,
93 | validation, and test sets.
94 |
95 |     :raises ValueError: If `dataset_size` is less than 3, if `train_split` or `validation_split` is negative,
96 |         or if the sum of `train_split` and `validation_split` is not strictly less than 1.
97 | """
98 | if dataset_size < 3:
99 | raise ValueError(f"Expected a dataset with size at least 3, got {dataset_size}")
100 |
101 | if train_split < 0 or validation_split < 0:
102 | raise ValueError(f"Expected positive splits, got ({train_split=}, {validation_split=})")
103 | if train_split + validation_split >= 1:
104 | raise ValueError(
105 |             f"Expected the sum of `train_split` and `validation_split` to be less than 1, got {train_split + validation_split}"
106 | )
107 | rng = np.random.default_rng(seed)
108 | selection = rng.permutation(dataset_size)
109 | train_size = max(1, int(dataset_size * train_split))
110 | validation_size = max(1, int(dataset_size * validation_split))
111 |     # Ensure that the TEST split has at least one element
112 | if train_size + validation_size == dataset_size:
113 | if train_size > 1:
114 | train_size -= 1
115 | if validation_size > 1:
116 | validation_size -= 1
117 | return {
118 | Split.TRAIN: selection[:train_size],
119 | Split.VALIDATION: selection[train_size : train_size + validation_size],
120 | Split.TEST: selection[train_size + validation_size :],
121 | Split.ALL: selection,
122 | }
123 |
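A quick sketch of `simple_random_split` on a hypothetical dataset of 10 items:

```python
from tti_eval.common import Split
from tti_eval.dataset.utils import simple_random_split

splits = simple_random_split(dataset_size=10, seed=42, train_split=0.7, validation_split=0.15)
# With these values: 7 train indices, 1 validation index, 2 test indices; Split.ALL holds all 10.
print(len(splits[Split.TRAIN]), len(splits[Split.VALIDATION]), len(splits[Split.TEST]))  # 7 1 2
```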
--------------------------------------------------------------------------------
/tti_eval/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import ClassificationModel, EvaluationModel
2 | from .image_retrieval import I2IRetrievalEvaluator
3 | from .knn import WeightedKNNClassifier
4 | from .linear_probe import LinearProbeClassifier
5 | from .zero_shot import ZeroShotClassifier
6 |
--------------------------------------------------------------------------------
/tti_eval/evaluation/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Any
3 |
4 | from tti_eval.common import ClassArray, Embeddings, ProbabilityArray
5 |
6 | from .utils import normalize
7 |
8 |
9 | class EvaluationModel(ABC):
10 | def __init__(
11 | self,
12 | train_embeddings: Embeddings,
13 | validation_embeddings: Embeddings,
14 | num_classes: int | None = None,
15 | **kwargs,
16 | ) -> None:
17 | super().__init__(**kwargs)
18 | # Preprocessing the embeddings
19 | train_embeddings.images = normalize(train_embeddings.images)
20 | validation_embeddings.images = normalize(validation_embeddings.images)
21 | if train_embeddings.classes is not None:
22 | train_embeddings.classes = normalize(train_embeddings.classes)
23 | if validation_embeddings.classes is not None:
24 | validation_embeddings.classes = normalize(validation_embeddings.classes)
25 |
26 | self._train_embeddings = train_embeddings
27 | self._val_embeddings = validation_embeddings
28 | self._check_dims()
29 |
30 | # Ensure that number of classes is always valid
31 | self._num_classes: int = max(train_embeddings.labels.max(), validation_embeddings.labels.max()).item() + 1
32 | if num_classes is not None:
33 | if self._num_classes > num_classes:
34 | raise ValueError("`num_classes` is lower than the number of classes found in the embeddings")
35 | self._num_classes = num_classes
36 |
37 | def _check_dims(self):
38 | """
39 |         Raises an error if the embedding dimensions don't match.
40 | """
41 | train_images_s, train_labels_s = self._train_embeddings.images.shape, self._train_embeddings.labels.shape
42 | if train_labels_s[0] != train_images_s[0]:
43 | raise ValueError(
44 | f"Expected `{train_images_s[0]}` samples in train's label embeddings, got `{train_labels_s[0]}`"
45 | )
46 | if self._train_embeddings.classes is not None and self._train_embeddings.classes.shape[-1] != self.dim:
47 | raise ValueError(
48 | f"Expected train's class embeddings with dimensions `{self.dim}`, "
49 | f"got `{self._train_embeddings.classes.shape[-1]}`"
50 | )
51 |
52 | val_images_s, val_labels_s = self._val_embeddings.images.shape, self._val_embeddings.labels.shape
53 | if val_labels_s[0] != val_images_s[0]:
54 | raise ValueError(
55 | f"Expected `{val_images_s[0]}` samples in validation's label embeddings, got `{val_labels_s[0]}`"
56 | )
57 | if val_images_s[-1] != self.dim:
58 | raise ValueError(
59 | f"Expected validation's image embeddings with dimensions `{self.dim}`, "
60 | f"got `{self._val_embeddings.images.shape[-1]}`"
61 | )
62 | if self._val_embeddings.classes is not None and self._val_embeddings.classes.shape[-1] != self.dim:
63 | raise ValueError(
64 | f"Expected validation's class embeddings with dimensions `{self.dim}`, "
65 | f"got `{self._val_embeddings.classes.shape[-1]}`"
66 | )
67 |
68 | @property
69 | def dim(self) -> int:
70 | """
71 | Expected embedding dimensionality.
72 | """
73 | return self._train_embeddings.images.shape[-1]
74 |
75 | @staticmethod
76 | def get_default_params() -> dict[str, Any]:
77 | return {}
78 |
79 | @property
80 | def num_classes(self) -> int:
81 | return self._num_classes
82 |
83 | @abstractmethod
84 | def evaluate(self) -> float:
85 | ...
86 |
87 | @classmethod
88 | @abstractmethod
89 | def title(cls) -> str:
90 | ...
91 |
92 |
93 | class ClassificationModel(EvaluationModel):
94 | @abstractmethod
95 | def predict(self) -> tuple[ProbabilityArray, ClassArray]:
96 | ...
97 |
98 | def evaluate(self) -> float:
99 | _, y_hat = self.predict()
100 | acc: float = (y_hat == self._val_embeddings.labels).astype(float).mean().item()
101 | return acc
102 |
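A hedged sketch of a custom evaluator (not part of the repository): a subclass of `ClassificationModel` only needs `title` and `predict`, and inherits the accuracy computation from `evaluate`. The nearest-centroid logic below is purely illustrative.

```python
import numpy as np

from tti_eval.common import ClassArray, ProbabilityArray
from tti_eval.evaluation.base import ClassificationModel
from tti_eval.evaluation.utils import softmax


class CentroidClassifier(ClassificationModel):
    """Hypothetical example: score validation images against per-class train centroids."""

    @classmethod
    def title(cls) -> str:
        return "centroid"

    def predict(self) -> tuple[ProbabilityArray, ClassArray]:
        train = self._train_embeddings
        # One centroid per class; classes absent from the train split keep a zero centroid.
        centroids = np.zeros((self.num_classes, self.dim), dtype=np.float32)
        for c in range(self.num_classes):
            mask = train.labels == c
            if mask.any():
                centroids[c] = train.images[mask].mean(axis=0)
        scores = self._val_embeddings.images @ centroids.T
        probabilities = softmax(scores)
        return probabilities, np.argmax(probabilities, axis=1)
```

`CentroidClassifier(train_embeddings, validation_embeddings).evaluate()` then reports validation accuracy just like the built-in classifiers.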
--------------------------------------------------------------------------------
/tti_eval/evaluation/evaluator.py:
--------------------------------------------------------------------------------
1 | import csv
2 | from datetime import datetime
3 |
4 | from natsort import natsorted, ns
5 | from tabulate import tabulate
6 | from tqdm.auto import tqdm
7 |
8 | from tti_eval.common import EmbeddingDefinition, Split
9 | from tti_eval.constants import OUTPUT_PATH
10 | from tti_eval.evaluation import (
11 | EvaluationModel,
12 | I2IRetrievalEvaluator,
13 | LinearProbeClassifier,
14 | WeightedKNNClassifier,
15 | ZeroShotClassifier,
16 | )
17 | from tti_eval.utils import read_all_cached_embeddings
18 |
19 |
20 | def print_evaluation_results(
21 | results: dict[EmbeddingDefinition, dict[str, float]],
22 | evaluation_model_title: str,
23 | ):
24 | defs = list(results.keys())
25 | model_names = natsorted(set(map(lambda d: d.model, defs)), alg=ns.IGNORECASE)
26 | dataset_names = natsorted(set(map(lambda d: d.dataset, defs)), alg=ns.IGNORECASE)
27 |
28 | table: list[list[float | str]] = [
29 | ["Model/Dataset"] + dataset_names,
30 | ] + [[m] + ["-"] * len(dataset_names) for m in model_names]
31 |
32 | def set_score(d: EmbeddingDefinition, s: float):
33 | row = model_names.index(d.model) + 1
34 | col = dataset_names.index(d.dataset) + 1
35 | table[row][col] = f"{s:.4f}"
36 |
37 | for d, res in results.items():
38 | s = res.get(evaluation_model_title)
39 |         if s is not None:
40 |             set_score(d, s)
41 |
42 | print(f"{'='*5} {evaluation_model_title} {'=' * 5}")
43 | print(tabulate(table))
44 |
45 |
46 | def run_evaluation(
47 | evaluators: list[type[EvaluationModel]],
48 | embedding_definitions: list[EmbeddingDefinition],
49 | ) -> dict[EmbeddingDefinition, dict[str, float]]:
50 | embeddings_performance: dict[EmbeddingDefinition, dict[str, float]] = {}
51 | used_evaluators: set[str] = set()
52 |
53 | for def_ in tqdm(embedding_definitions, desc="Evaluating models", leave=False):
54 | train_embeddings = def_.load_embeddings(Split.TRAIN)
55 | validation_embeddings = def_.load_embeddings(Split.VALIDATION)
56 |
57 | if train_embeddings is None:
58 | print(f"No train embeddings were found for {def_}")
59 | continue
60 | if validation_embeddings is None:
61 | print(f"No validation embeddings were found for {def_}")
62 | continue
63 |
64 | evaluator_performance: dict[str, float] = embeddings_performance.setdefault(def_, {})
65 | for evaluator_type in evaluators:
66 | if evaluator_type == ZeroShotClassifier and train_embeddings.classes is None:
67 | continue
68 | evaluator = evaluator_type(
69 | train_embeddings=train_embeddings,
70 | validation_embeddings=validation_embeddings,
71 | )
72 | evaluator_performance[evaluator.title()] = evaluator.evaluate()
73 | used_evaluators.add(evaluator.title())
74 |
75 | for evaluator_type in evaluators:
76 | evaluator_title = evaluator_type.title()
77 | if evaluator_title in used_evaluators:
78 | print_evaluation_results(embeddings_performance, evaluator_title)
79 | return embeddings_performance
80 |
81 |
82 | def export_evaluation_to_csv(embeddings_performance: dict[EmbeddingDefinition, dict[str, float]]) -> None:
83 | ts = datetime.now()
84 | results_file = OUTPUT_PATH.EVALUATIONS / f"eval_{ts.isoformat()}.csv"
85 | results_file.parent.mkdir(parents=True, exist_ok=True) # Ensure that parent folder exists
86 |
87 | headers = ["Model", "Dataset", "Classifier", "Accuracy"]
88 | with open(results_file.as_posix(), "w", newline="") as csvfile:
89 | writer = csv.writer(csvfile)
90 | writer.writerow(headers)
91 |
92 | for def_, perf in embeddings_performance.items():
93 | def_: EmbeddingDefinition
94 | for classifier_title, accuracy in perf.items():
95 | writer.writerow([def_.model, def_.dataset, classifier_title, accuracy])
96 | print(f"Evaluation results exported to `{results_file.as_posix()}`")
97 |
98 |
99 | if __name__ == "__main__":
100 | models = [ZeroShotClassifier, LinearProbeClassifier, WeightedKNNClassifier, I2IRetrievalEvaluator]
101 | defs = read_all_cached_embeddings(as_list=True)
102 | print(defs)
103 | performances = run_evaluation(models, defs)
104 | export_evaluation_to_csv(performances)
105 | print(performances)
106 |
--------------------------------------------------------------------------------
/tti_eval/evaluation/image_retrieval.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Any
3 |
4 | import numpy as np
5 | from faiss import IndexFlatL2
6 |
7 | from tti_eval.common import Embeddings
8 | from tti_eval.utils import disable_tqdm
9 |
10 | from .base import EvaluationModel
11 |
12 | logger = logging.getLogger("multiclips")
13 |
14 |
15 | class I2IRetrievalEvaluator(EvaluationModel):
16 | @classmethod
17 | def title(cls) -> str:
18 | return "I2IR"
19 |
20 | def __init__(
21 | self,
22 | train_embeddings: Embeddings,
23 | validation_embeddings: Embeddings,
24 | num_classes: int | None = None,
25 | k: int = 100,
26 | ) -> None:
27 | """
28 | Image-to-Image retrieval evaluator.
29 |
30 | For each training embedding (e_i), this evaluator computes the percentage (accuracy) of its k nearest neighbors
31 | from the validation embeddings that share the same class. It returns the mean percentage of correct nearest
32 | neighbors across all training embeddings.
33 |
34 | :param train_embeddings: Training embeddings used for evaluation.
35 | :param validation_embeddings: Validation embeddings used for similarity search setup.
36 | :param num_classes: Number of classes. If not specified, it will be inferred from the training embeddings.
37 | :param k: Number of nearest neighbors.
38 |
39 | :raises ValueError: If the build of the faiss index for similarity search fails.
40 | """
41 | super().__init__(train_embeddings, validation_embeddings, num_classes)
42 | self.k = min(k, len(validation_embeddings.images))
43 |
44 | class_ids, counts = np.unique(self._val_embeddings.labels, return_counts=True)
45 | self._class_counts = np.zeros(self.num_classes, dtype=np.int32)
46 | self._class_counts[class_ids] = counts
47 |
48 | disable_tqdm() # Disable tqdm progress bar when building the index
49 | d = self._val_embeddings.images.shape[-1]
50 | self._index = IndexFlatL2(d)
51 | self._index.add(self._val_embeddings.images)
52 |
53 | def evaluate(self) -> float:
54 | _, nearest_indices = self._index.search(self._train_embeddings.images, self.k) # type: ignore
55 | nearest_classes = self._val_embeddings.labels[nearest_indices]
56 |
57 |         # To compute retrieval accuracy, at most Q of the k retrieved elements are counted per sample,
58 |         # where Q is the number of validation embeddings that belong to the sample's class
59 | top_nearest_per_class = np.where(self._class_counts < self.k, self._class_counts, self.k)
60 | top_nearest_per_sample = top_nearest_per_class[self._train_embeddings.labels]
61 |
62 | # Add a placeholder value for indices outside the retrieval scope
63 | nearest_classes[np.arange(self.k) >= top_nearest_per_sample[:, np.newaxis]] = -1
64 |
65 | # Count the number of neighbours that match the class of the sample and compute the mean accuracy
66 | matches_per_sample = np.sum(
67 | nearest_classes == np.array(self._train_embeddings.labels)[:, np.newaxis],
68 | axis=1,
69 | )
70 | accuracies = np.divide(
71 | matches_per_sample,
72 | top_nearest_per_sample,
73 | out=np.zeros_like(matches_per_sample, dtype=np.float64),
74 | where=top_nearest_per_sample != 0,
75 | )
76 | return accuracies.mean().item()
77 |
78 | @staticmethod
79 | def get_default_params() -> dict[str, Any]:
80 | return {"k": 100}
81 |
82 |
83 | if __name__ == "__main__":
84 | np.random.seed(42)
85 | train_embeddings = Embeddings(
86 | images=np.random.randn(80, 10).astype(np.float32),
87 | labels=np.random.randint(0, 10, size=(80,)),
88 | )
89 | val_embeddings = Embeddings(
90 | images=np.random.randn(20, 10).astype(np.float32),
91 | labels=np.random.randint(0, 10, size=(20,)),
92 | )
93 | mean_accuracy = I2IRetrievalEvaluator(
94 | train_embeddings,
95 | val_embeddings,
96 | num_classes=10,
97 | ).evaluate()
98 | print(mean_accuracy)
99 |
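A toy illustration of the per-sample capping used in `evaluate` (the numbers are hypothetical): with k = 4 retrieved neighbours but only 2 validation samples of the query's class, only the first 2 neighbours count towards that sample's accuracy.

```python
import numpy as np

k = 4
nearest_classes = np.array([[1, 0, 1, 1]])  # classes of the 4 retrieved neighbours for one class-1 sample
top_nearest_per_sample = np.array([2])      # min(validation count of class 1, k)
nearest_classes[np.arange(k) >= top_nearest_per_sample[:, np.newaxis]] = -1
# nearest_classes is now [[1, 0, -1, -1]]: 1 of the 2 counted neighbours matches class 1.
matches = np.sum(nearest_classes == np.array([1])[:, np.newaxis], axis=1)
print(matches / top_nearest_per_sample)  # [0.5]
```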
--------------------------------------------------------------------------------
/tti_eval/evaluation/knn.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Any
3 |
4 | import numpy as np
5 | from faiss import IndexFlatL2
6 |
7 | from tti_eval.common import ClassArray, Embeddings, ProbabilityArray
8 | from tti_eval.utils import disable_tqdm
9 |
10 | from .base import ClassificationModel
11 | from .utils import softmax
12 |
13 | logger = logging.getLogger("multiclips")
14 |
15 |
16 | class WeightedKNNClassifier(ClassificationModel):
17 | @classmethod
18 | def title(cls) -> str:
19 | return "wKNN"
20 |
21 | def __init__(
22 | self,
23 | train_embeddings: Embeddings,
24 | validation_embeddings: Embeddings,
25 | num_classes: int | None = None,
26 | k: int = 11,
27 | ) -> None:
28 | """
29 | Weighted KNN classifier based on the provided embeddings and labels.
30 |         The output "probabilities" is a softmax over the weighted class votes.
31 |
32 |         Given a query sample q, the classifier identifies the k nearest neighbors (e_i)
33 | with corresponding classes (y_i) from `train_embeddings.images` and `train_embeddings.labels`, respectively,
34 | and assigns a class vote of 1/||e_i - q||_2^2 for class y_i.
35 | The class with the highest vote count will be chosen.
36 |
37 | :param train_embeddings: Embeddings and their labels used for setting up the search space.
38 | :param validation_embeddings: Embeddings and their labels used for evaluating the search space.
39 | :param num_classes: Number of classes. If not specified, it will be inferred from the train labels.
40 | :param k: Number of nearest neighbors.
41 |
42 | :raises ValueError: If the build of the faiss index for KNN fails.
43 | """
44 | super().__init__(train_embeddings, validation_embeddings, num_classes)
45 | self.k = k
46 | disable_tqdm() # Disable tqdm progress bar when building the index
47 | d = train_embeddings.images.shape[-1]
48 | self._index = IndexFlatL2(d)
49 | self._index.add(train_embeddings.images)
50 |
51 | @staticmethod
52 | def get_default_params() -> dict[str, Any]:
53 | return {"k": 11}
54 |
55 | def predict(self) -> tuple[ProbabilityArray, ClassArray]:
56 | dists, nearest_indices = self._index.search(self._val_embeddings.images, self.k) # type: ignore
57 | nearest_classes = np.take(self._train_embeddings.labels, nearest_indices)
58 |
59 | # Calculate class votes from the distances (avoiding division by zero)
60 | # Note: Values stored in `dists` are the squared 2-norm values of the respective distance vectors
61 | max_value = np.finfo(np.float32).max
62 | scores = np.divide(1, dists, out=np.full_like(dists, max_value), where=dists != 0)
63 |         # NOTE: if self.k and self.num_classes are both large, this array might become large.
64 |         # We could shave off a factor of self.k by counting the votes differently here.
65 | n = len(self._val_embeddings.images)
66 | weighted_count = np.zeros((n, self.num_classes, self.k), dtype=np.float32)
67 | weighted_count[
68 | np.repeat(np.arange(n), self.k), # [0, 0, .., 0_k, 1, 1, .., 1_k, ..]
69 | nearest_classes.reshape(-1), # [class numbers]
70 | np.tile(np.arange(self.k), n), # [0, 1, .., k-1, 0, 1, .., k-1, ..]
71 | ] = scores.reshape(-1)
72 | probabilities = softmax(weighted_count.sum(-1))
73 | return probabilities, np.argmax(probabilities, axis=1)
74 |
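A usage sketch with random embeddings, mirroring the `__main__` examples in the sibling evaluation modules (values are illustrative):

```python
import numpy as np

from tti_eval.common import Embeddings
from tti_eval.evaluation.knn import WeightedKNNClassifier

train_embeddings = Embeddings(
    images=np.random.randn(100, 10).astype(np.float32),
    labels=np.random.randint(0, 10, size=(100,)),
)
val_embeddings = Embeddings(
    images=np.random.randn(20, 10).astype(np.float32),
    labels=np.random.randint(0, 10, size=(20,)),
)
knn = WeightedKNNClassifier(train_embeddings, val_embeddings, num_classes=10, k=5)
probs, pred_classes = knn.predict()
print(knn.evaluate())  # mean accuracy over the validation embeddings
```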
--------------------------------------------------------------------------------
/tti_eval/evaluation/linear_probe.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Any
3 |
4 | import numpy as np
5 | from sklearn.linear_model import LogisticRegression, LogisticRegressionCV
6 |
7 | from tti_eval.common import ClassArray, Embeddings, ProbabilityArray
8 | from tti_eval.evaluation.base import ClassificationModel
9 |
10 | logger = logging.getLogger("multiclips")
11 |
12 |
13 | class LinearProbeClassifier(ClassificationModel):
14 | @classmethod
15 | def title(cls) -> str:
16 | return "linear_probe"
17 |
18 | def __init__(
19 | self,
20 | train_embeddings: Embeddings,
21 | validation_embeddings: Embeddings,
22 | num_classes: int | None = None,
23 | log_reg_params: dict[str, Any] | None = None,
24 | use_cross_validation: bool = False,
25 | ) -> None:
26 | """
27 | Logistic Regression model based on the provided embeddings and labels.
28 |
29 | :param train_embeddings: Embeddings and their labels used for setting up the model.
30 | :param validation_embeddings: Embeddings and their labels used for evaluating the model.
31 | :param num_classes: Number of classes. If not specified, it will be inferred from the train labels.
32 | :param log_reg_params: Parameters for the Logistic Regression model.
33 |         :param use_cross_validation: Flag that indicates whether to use cross-validation when training the model.
34 | """
35 | super().__init__(train_embeddings, validation_embeddings, num_classes)
36 |
37 | params = log_reg_params or {}
38 | self.classifier: LogisticRegressionCV | LogisticRegression
39 | if use_cross_validation:
40 | self.classifier = LogisticRegressionCV(
41 | Cs=[1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3], random_state=0, **params
42 | ).fit(self._train_embeddings.images, self._train_embeddings.labels) # type: ignore
43 | else:
44 | self.classifier = LogisticRegression(random_state=0, **params).fit(
45 | self._train_embeddings.images, self._train_embeddings.labels
46 | )
47 |
48 | def predict(self) -> tuple[ProbabilityArray, ClassArray]:
49 | probabilities: ProbabilityArray = self.classifier.predict_proba(self._val_embeddings.images) # type: ignore
50 | return probabilities, np.argmax(probabilities, axis=1)
51 |
52 |
53 | if __name__ == "__main__":
54 | train_embeddings = Embeddings(
55 | images=np.random.randn(100, 10).astype(np.float32),
56 | labels=np.random.randint(0, 10, size=(100,)),
57 | )
58 | val_embeddings = Embeddings(
59 | images=np.random.randn(2, 10).astype(np.float32),
60 |         labels=np.random.randint(0, 20, size=(2,)),
61 | )
62 | linear_probe = LinearProbeClassifier(
63 | train_embeddings,
64 | val_embeddings,
65 | num_classes=20,
66 | )
67 | probs, pred_classes = linear_probe.predict()
68 | print(probs)
69 | print(pred_classes)
70 |
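`log_reg_params` is forwarded verbatim to scikit-learn, so any `LogisticRegression` keyword argument can be supplied. A hedged variant of the example above (the parameter values are illustrative, not repository defaults):

```python
probe = LinearProbeClassifier(
    train_embeddings,
    val_embeddings,
    num_classes=20,
    log_reg_params={"max_iter": 1000, "C": 0.5},  # plain LogisticRegression kwargs
    use_cross_validation=False,
)
print(probe.evaluate())  # accuracy on the validation embeddings
```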
--------------------------------------------------------------------------------
/tti_eval/evaluation/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import numpy.typing as npt
3 |
4 | from tti_eval.common.numpy_types import DType
5 |
6 |
7 | def normalize(x: npt.NDArray[DType]) -> npt.NDArray[DType]:
8 | return x / np.linalg.norm(x, ord=2, axis=1, keepdims=True)
9 |
10 |
11 | def softmax(x: npt.NDArray[DType]) -> npt.NDArray[DType]:
12 | z = x - x.max(axis=1, keepdims=True)
13 | numerator = np.exp(z)
14 |     return numerator / np.sum(numerator, axis=1, keepdims=True)
15 |
--------------------------------------------------------------------------------
/tti_eval/evaluation/zero_shot.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from tti_eval.common import ClassArray, Embeddings, ProbabilityArray
4 | from tti_eval.evaluation.base import ClassificationModel
5 | from tti_eval.evaluation.utils import softmax
6 |
7 |
8 | class ZeroShotClassifier(ClassificationModel):
9 | @classmethod
10 | def title(cls) -> str:
11 | return "zero_shot"
12 |
13 | def __init__(
14 | self,
15 | train_embeddings: Embeddings,
16 | validation_embeddings: Embeddings,
17 | num_classes: int | None = None,
18 | ) -> None:
19 | """
20 | Zero-Shot classifier based on the provided embeddings and labels.
21 |
22 | :param train_embeddings: Embeddings and their labels used for setting up the search space.
23 | :param validation_embeddings: Embeddings and their labels used for evaluating the search space.
24 | :param num_classes: Number of classes. If not specified, it will be inferred from the train labels.
25 | """
26 | super().__init__(train_embeddings, validation_embeddings, num_classes)
27 | if self._train_embeddings.classes is None:
28 | raise ValueError("Expected class embeddings in `train_embeddings`, got `None`")
29 |
30 | @property
31 | def dim(self) -> int:
32 | return self._train_embeddings.classes.shape[-1]
33 |
34 | def predict(self) -> tuple[ProbabilityArray, ClassArray]:
35 | inner_products = self._val_embeddings.images @ self._train_embeddings.classes.T
36 | probabilities = softmax(inner_products)
37 | return probabilities, np.argmax(probabilities, axis=1)
38 |
39 |
40 | if __name__ == "__main__":
41 | train_embeddings = Embeddings(
42 | images=np.random.randn(100, 10).astype(np.float32),
43 | labels=np.random.randint(0, 10, size=(100,)),
44 | classes=np.random.randn(20, 10).astype(np.float32),
45 | )
46 | val_embeddings = Embeddings(
47 | images=np.random.randn(2, 10).astype(np.float32),
48 | labels=np.random.randint(0, 10, size=(2,)),
49 | )
50 | zeroshot = ZeroShotClassifier(
51 | train_embeddings,
52 | val_embeddings,
53 | num_classes=20,
54 | )
55 | probs, pred_classes = zeroshot.predict()
56 | print(probs)
57 | print(pred_classes)
58 |
--------------------------------------------------------------------------------
/tti_eval/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import Model
2 | from .provider import ModelProvider
3 |
--------------------------------------------------------------------------------
/tti_eval/model/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from collections.abc import Callable
3 | from os.path import relpath
4 | from pathlib import Path
5 | from typing import Any
6 |
7 | import torch
8 | from pydantic import BaseModel, ConfigDict
9 | from torch.utils.data import DataLoader
10 |
11 | from tti_eval.common.numpy_types import ClassArray, EmbeddingArray
12 | from tti_eval.constants import CACHE_PATH, SOURCES_PATH
13 |
14 | DEFAULT_MODEL_TYPES_LOCATION = (
15 | Path(relpath(str(__file__), SOURCES_PATH.MODEL_INSTANCE_DEFINITIONS)).parent / "types" / "__init__.py"
16 | )
17 |
18 |
19 | class ModelDefinitionSpec(BaseModel):
20 | model_type: str
21 | title: str
22 | module_path: Path = DEFAULT_MODEL_TYPES_LOCATION
23 | device: str | None = None
24 | title_in_source: str | None = None
25 | cache_dir: Path | None = None
26 |
27 | # Allow additional model configuration fields
28 | # Also, silence spurious Pydantic UserWarning: Field "model_type" has conflict with protected namespace "model_"
29 | model_config = ConfigDict(protected_namespaces=(), extra="allow")
30 |
31 |
32 | class Model(ABC):
33 | def __init__(
34 | self,
35 | title: str,
36 | device: str | None = None,
37 | *,
38 | title_in_source: str | None = None,
39 | cache_dir: str | None = None,
40 | **kwargs,
41 | ) -> None:
42 | self.__title = title
43 | self.__title_in_source = title_in_source or title
44 | device = device or ("cuda" if torch.cuda.is_available() else "cpu")
45 | self._check_device(device)
46 | self.__device = torch.device(device)
47 | if cache_dir is None:
48 | cache_dir = CACHE_PATH
49 | self._cache_dir = Path(cache_dir).expanduser().resolve() / "models" / title
50 |
51 | @property
52 | def title(self) -> str:
53 | return self.__title
54 |
55 | @property
56 | def title_in_source(self) -> str:
57 | return self.__title_in_source
58 |
59 | @property
60 | def device(self) -> torch.device:
61 | return self.__device
62 |
63 | @abstractmethod
64 | def _setup(self, **kwargs) -> None:
65 | pass
66 |
67 | @abstractmethod
68 | def get_transform(self) -> Callable[[dict[str, Any]], dict[str, list[Any]]]:
69 | ...
70 |
71 | @abstractmethod
72 | def get_collate_fn(self) -> Callable[[Any], Any]:
73 | ...
74 |
75 | @abstractmethod
76 | def build_embedding(self, dataloader: DataLoader) -> tuple[EmbeddingArray, EmbeddingArray, ClassArray]:
77 | ...
78 |
79 | @staticmethod
80 | def _check_device(device: str):
81 | # Check if the input device exists and is available
82 | if device not in {"cuda", "cpu"}:
83 | raise ValueError(f"Unrecognized device: {device}")
84 | if not getattr(torch, device).is_available():
85 | raise ValueError(f"Unavailable device: {device}")
86 |
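`ModelDefinitionSpec` is the schema that the JSON files under `sources/models/` are validated against (see `ModelProvider.register_model_from_json_definition`). A hypothetical definition, with illustrative field values:

```python
from tti_eval.model.base import ModelDefinitionSpec

# Extra keys such as `pretrained` are preserved because the model config uses extra="allow".
spec = ModelDefinitionSpec.model_validate_json(
    '{"model_type": "OpenCLIPModel", "title": "my-open-clip", '
    '"title_in_source": "ViT-B-32", "pretrained": "laion2b_s34b_b79k"}'
)
print(spec.module_path)  # defaults to the bundled model types/__init__.py (relative path)
```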
--------------------------------------------------------------------------------
/tti_eval/model/provider.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Any
3 |
4 | from natsort import natsorted, ns
5 |
6 | from tti_eval.constants import SOURCES_PATH
7 | from tti_eval.dataset.utils import load_class_from_path
8 |
9 | from .base import Model, ModelDefinitionSpec
10 |
11 |
12 | class ModelProvider:
13 | __instance = None
14 | __known_model_types: dict[tuple[Path, str], Any] = dict()
15 |
16 | def __init__(self) -> None:
17 | self._models = {}
18 |
19 | @classmethod
20 | def prepare(cls):
21 | if cls.__instance is None:
22 | cls.__instance = cls()
23 | cls.register_models_from_sources_dir(SOURCES_PATH.MODEL_INSTANCE_DEFINITIONS)
24 | return cls.__instance
25 |
26 | @classmethod
27 | def register_model(cls, source: type[Model], title: str, **kwargs):
28 | cls.prepare()._models[title] = (source, kwargs)
29 |
30 | @classmethod
31 | def register_model_from_json_definition(cls, json_definition: Path) -> None:
32 | spec = ModelDefinitionSpec.model_validate_json(json_definition.read_text(encoding="utf-8"))
33 | if not spec.module_path.is_absolute(): # Handle relative module paths
34 | spec.module_path = (json_definition.parent / spec.module_path).resolve()
35 | if not spec.module_path.is_file():
36 | raise FileNotFoundError(
37 | f"Could not find the specified module path at `{spec.module_path.as_posix()}` "
38 | f"when registering model `{spec.title}`"
39 | )
40 |
41 | # Fetch the class of the model type stated in the definition
42 | model_type = cls.__known_model_types.get((spec.module_path, spec.model_type))
43 | if model_type is None:
44 | model_type = load_class_from_path(spec.module_path.as_posix(), spec.model_type)
45 | if not issubclass(model_type, Model):
46 | raise ValueError(
47 | f"Model type specified in the JSON definition file `{json_definition.as_posix()}` "
48 | f"does not inherit from the base class `Model`"
49 | )
50 | cls.__known_model_types[(spec.module_path, spec.model_type)] = model_type
51 | cls.register_model(model_type, **spec.model_dump(exclude={"module_path", "model_type"}))
52 |
53 | @classmethod
54 | def register_models_from_sources_dir(cls, source_dir: Path) -> None:
55 | for f in source_dir.glob("*.json"):
56 | cls.register_model_from_json_definition(f)
57 |
58 | @classmethod
59 | def get_model(cls, title: str) -> Model:
60 | instance = cls.prepare()
61 | if title not in instance._models:
62 | raise ValueError(f"Unrecognized model: {title}")
63 | source, kwargs = instance._models[title]
64 | return source(title, **kwargs)
65 |
66 | @classmethod
67 | def list_model_titles(cls) -> list[str]:
68 | return natsorted(cls.prepare()._models.keys(), alg=ns.IGNORECASE)
69 |
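A sketch of how models can be listed and registered programmatically (the title and weight tag below are illustrative; registering via a JSON definition in `sources/models/` achieves the same result):

```python
from tti_eval.model import ModelProvider
from tti_eval.model.types import OpenCLIPModel

print(ModelProvider.list_model_titles())  # titles registered from sources/models/*.json

ModelProvider.register_model(
    OpenCLIPModel,
    title="my-open-clip",
    title_in_source="ViT-B-32",
    pretrained="laion2b_s34b_b79k",
)
model = ModelProvider.get_model("my-open-clip")  # instantiates the model (downloads weights on first use)
```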
--------------------------------------------------------------------------------
/tti_eval/model/types/__init__.py:
--------------------------------------------------------------------------------
1 | from tti_eval.model.types.hugging_face import HFModel
2 | from tti_eval.model.types.local_clip_model import LocalCLIPModel
3 | from tti_eval.model.types.open_clip_model import OpenCLIPModel
4 |
--------------------------------------------------------------------------------
/tti_eval/model/types/hugging_face.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Callable
2 | from typing import Any
3 |
4 | import numpy as np
5 | import torch
6 | from torch.utils.data import DataLoader
7 | from tqdm import tqdm
8 | from transformers import AutoModel as HF_AutoModel
9 | from transformers import AutoProcessor as HF_AutoProcessor
10 | from transformers import AutoTokenizer as HF_AutoTokenizer
11 |
12 | from tti_eval.common import ClassArray, EmbeddingArray
13 | from tti_eval.dataset import Dataset
14 | from tti_eval.model import Model
15 |
16 |
17 | class HFModel(Model):
18 | def __init__(
19 | self,
20 | title: str,
21 | device: str | None = None,
22 | *,
23 | title_in_source: str | None = None,
24 | cache_dir: str | None = None,
25 | **kwargs,
26 | ) -> None:
27 | super().__init__(title, device, title_in_source=title_in_source, cache_dir=cache_dir)
28 | self._setup(**kwargs)
29 |
30 | def get_transform(self) -> Callable[[dict[str, Any]], dict[str, list[Any]]]:
31 | def process_fn(batch) -> dict[str, list[Any]]:
32 | images = [i.convert("RGB") for i in batch["image"]]
33 | batch["image"] = [
34 | self.processor(images=[i], return_tensors="pt").to(self.device).pixel_values.squeeze() for i in images
35 | ]
36 | return batch
37 |
38 | return process_fn
39 |
40 | def get_collate_fn(self) -> Callable[[Any], Any]:
41 | def collate_fn(examples) -> dict[str, torch.Tensor]:
42 | images = []
43 | labels = []
44 | for example in examples:
45 | images.append(example["image"])
46 | labels.append(example["label"])
47 |
48 | pixel_values = torch.stack(images)
49 | labels = torch.tensor(labels)
50 | return {"pixel_values": pixel_values, "labels": labels}
51 |
52 | return collate_fn
53 |
54 | def _setup(self, **kwargs) -> None:
55 | self.model = HF_AutoModel.from_pretrained(self.title_in_source, cache_dir=self._cache_dir).to(self.device)
56 | load_result = HF_AutoProcessor.from_pretrained(self.title_in_source, cache_dir=self._cache_dir)
57 | self.processor = load_result[0] if isinstance(load_result, tuple) else load_result
58 | self.tokenizer = HF_AutoTokenizer.from_pretrained(self.title_in_source, cache_dir=self._cache_dir)
59 |
60 | def build_embedding(self, dataloader: DataLoader) -> tuple[EmbeddingArray, EmbeddingArray, ClassArray]:
61 | all_image_embeddings = []
62 | all_labels = []
63 | with torch.inference_mode():
64 | _dataset: Dataset = dataloader.dataset
65 | inputs = self.tokenizer(_dataset.text_queries, padding=True, return_tensors="pt").to(self.device)
66 | class_features = self.model.get_text_features(**inputs)
67 | normalized_class_features = class_features / class_features.norm(p=2, dim=-1, keepdim=True)
68 | class_embeddings = normalized_class_features.numpy(force=True)
69 | for batch in tqdm(
70 | dataloader,
71 | desc=f"Embedding ({_dataset.split}) {_dataset.title} dataset with {self.title}",
72 | ):
73 | image_features = self.model.get_image_features(pixel_values=batch["pixel_values"].to(self.device))
74 | normalized_image_features = (image_features / image_features.norm(p=2, dim=-1, keepdim=True)).squeeze()
75 | all_image_embeddings.append(normalized_image_features)
76 | all_labels.append(batch["labels"])
77 | image_embeddings = torch.concatenate(all_image_embeddings).numpy(force=True)
78 | labels = torch.concatenate(all_labels).numpy(force=True).astype(np.int32)
79 | return image_embeddings, class_embeddings, labels
80 |
--------------------------------------------------------------------------------
/tti_eval/model/types/local_clip_model.py:
--------------------------------------------------------------------------------
1 | from tti_eval.model.types.open_clip_model import OpenCLIPModel
2 |
3 |
4 | class LocalCLIPModel(OpenCLIPModel):
5 | def __init__(
6 | self,
7 | title: str,
8 | device: str | None = None,
9 | *,
10 | title_in_source: str,
11 | cache_dir: str | None = None,
12 | **kwargs,
13 | ) -> None:
14 | super().__init__(title, device, title_in_source=title_in_source, cache_dir=cache_dir, **kwargs)
15 |
16 | def _setup(self, **kwargs) -> None:
17 | self.pretrained = (self._cache_dir / "checkpoint.pt").as_posix()
18 | super()._setup(**kwargs)
19 |
--------------------------------------------------------------------------------
/tti_eval/model/types/open_clip_model.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Callable
2 | from typing import Any
3 |
4 | import numpy as np
5 | import open_clip
6 | import torch
7 | from torch.utils.data import DataLoader
8 | from tqdm import tqdm
9 |
10 | from tti_eval.common import ClassArray, EmbeddingArray
11 | from tti_eval.dataset import Dataset
12 | from tti_eval.model import Model
13 |
14 |
15 | class OpenCLIPModel(Model):
16 | def __init__(
17 | self,
18 | title: str,
19 | device: str | None = None,
20 | *,
21 | title_in_source: str,
22 | pretrained: str | None = None,
23 | cache_dir: str | None = None,
24 | **kwargs,
25 | ) -> None:
26 | self.pretrained = pretrained
27 | super().__init__(title, device, title_in_source=title_in_source, cache_dir=cache_dir, **kwargs)
28 | self._setup(**kwargs)
29 |
30 | def get_transform(self) -> Callable[[dict[str, Any]], dict[str, list[Any]]]:
31 | def process_fn(batch) -> dict[str, list[Any]]:
32 | images = [i.convert("RGB") for i in batch["image"]]
33 | batch["image"] = [self.processor(i) for i in images]
34 | return batch
35 |
36 | return process_fn
37 |
38 | def get_collate_fn(self) -> Callable[[Any], Any]:
39 | def collate_fn(examples) -> dict[str, torch.Tensor]:
40 | images = []
41 | labels = []
42 | for example in examples:
43 | images.append(example["image"])
44 | labels.append(example["label"])
45 |
46 | torch_images = torch.stack(images)
47 | labels = torch.tensor(labels)
48 | return {"image": torch_images, "labels": labels}
49 |
50 | return collate_fn
51 |
52 | def _setup(self, **kwargs) -> None:
53 | self.model, _, self.processor = open_clip.create_model_and_transforms(
54 | model_name=self.title_in_source,
55 | pretrained=self.pretrained,
56 | cache_dir=self._cache_dir.as_posix(),
57 | device=self.device,
58 | **kwargs,
59 | )
60 | self.tokenizer = open_clip.get_tokenizer(model_name=self.title_in_source)
61 |
62 | def build_embedding(self, dataloader: DataLoader) -> tuple[EmbeddingArray, EmbeddingArray, ClassArray]:
63 | all_image_embeddings = []
64 | all_labels = []
65 | with torch.inference_mode():
66 | _dataset: Dataset = dataloader.dataset
67 | text = self.tokenizer(_dataset.text_queries).to(self.device)
68 | class_embeddings = self.model.encode_text(text, normalize=True).numpy(force=True)
69 | for batch in tqdm(
70 | dataloader,
71 | desc=f"Embedding ({_dataset.split}) {_dataset.title} dataset with {self.title}",
72 | ):
73 | image_features = self.model.encode_image(batch["image"].to(self.device), normalize=True)
74 | all_image_embeddings.append(image_features)
75 | all_labels.append(batch["labels"])
76 | image_embeddings = torch.concatenate(all_image_embeddings).numpy(force=True)
77 | labels = torch.concatenate(all_labels).numpy(force=True).astype(np.int32)
78 | return image_embeddings, class_embeddings, labels
79 |
--------------------------------------------------------------------------------
/tti_eval/plotting/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/encord-team/text-to-image-eval/92ff9da010e4c00de76475a57bd52ed4d95b3716/tti_eval/plotting/__init__.py
--------------------------------------------------------------------------------
/tti_eval/plotting/animation.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | from pathlib import Path
3 | from typing import Literal, overload
4 |
5 | import matplotlib.animation as animation
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | from matplotlib.widgets import Slider
9 | from PIL import Image
10 |
11 | from tti_eval.common import EmbeddingDefinition
12 | from tti_eval.common.base import SafeName
13 | from tti_eval.common.numpy_types import ClassArray, N2Array
14 | from tti_eval.constants import OUTPUT_PATH
15 | from tti_eval.dataset import DatasetProvider, Split
16 |
17 | from .reduction import REDUCTIONS, reduction_from_string
18 |
19 |
20 | @overload
21 | def create_embedding_chart(
22 | x1: N2Array,
23 | x2: N2Array,
24 | labels: ClassArray,
25 | title1: str | SafeName,
26 | title2: str | SafeName,
27 | suptitle: str,
28 | label_names: list[str],
29 | *,
30 | interactive: Literal[False],
31 | ) -> animation.FuncAnimation:
32 | ...
33 |
34 |
35 | @overload
36 | def create_embedding_chart(
37 | x1: N2Array,
38 | x2: N2Array,
39 | labels: ClassArray,
40 | title1: str | SafeName,
41 | title2: str | SafeName,
42 | suptitle: str,
43 | label_names: list[str],
44 | *,
45 | interactive: Literal[True],
46 | ) -> None:
47 | ...
48 |
49 |
50 | @overload
51 | def create_embedding_chart(
52 | x1: N2Array,
53 | x2: N2Array,
54 | labels: ClassArray,
55 | title1: str | SafeName,
56 | title2: str | SafeName,
57 | suptitle: str,
58 | label_names: list[str],
59 | *,
60 | interactive: bool,
61 | ) -> animation.FuncAnimation | None:
62 | ...
63 |
64 |
65 | def create_embedding_chart(
66 | x1: N2Array,
67 | x2: N2Array,
68 | labels: ClassArray,
69 | title1: str | SafeName,
70 | title2: str | SafeName,
71 | suptitle: str,
72 | label_names: list[str],
73 | *,
74 | interactive: bool = False,
75 | ) -> animation.FuncAnimation | None:
76 | fig, ax = plt.subplots(figsize=(12, 7))
77 | fig.suptitle(suptitle)
78 |
79 |     # Compute chart bounds
80 |     all_points = np.concatenate([x1, x2], axis=0)
81 |     xmin, ymin = all_points.min(0)
82 |     xmax, ymax = all_points.max(0)
83 | dx = xmax - xmin
84 | dy = ymax - ymin
85 | xmin -= dx * 0.1
86 | xmax += dx * 0.1
87 | ymin -= dy * 0.1
88 | ymax += dy * 0.1
89 |
90 | # Initial plot
91 | points = plt.scatter(*x1.T, c=labels)
92 |
93 | ax.axis("off")
94 | ax.set_xlim([xmin, xmax])
95 | ax.set_ylim([ymin, ymax])
96 |
97 | handles, _ = points.legend_elements()
98 | ax.legend(handles, label_names, loc="upper right")
99 |
100 | # Position scatter and slider
101 | s_width, s_height = (0.4, 0.1)
102 | x_padding = (1 - s_width) / 2
103 |
104 | fig.subplots_adjust(bottom=s_height * 1.5)
105 | # Insert Logo
106 | logo = Image.open(Path(__file__).parents[2] / "images" / "encord.png")
107 | w, h = logo.size
108 | logo_padding_x = 0.02
109 | ax_w = x_padding - logo_padding_x * 2
110 | ax_h = ax_w * h / w
111 | ax_logo = fig.add_axes((0.02, 0.02, ax_w, ax_h))
112 | ax_logo.axis("off")
113 | ax_logo.imshow(logo)
114 |
115 | # Build slider
116 | ax_interpolation = fig.add_axes((x_padding, s_height, s_width, s_height / 2))
117 | init_time = 0
118 |
119 | def interpolate(t):
120 | t = max(min(t, 1), 0)
121 | return ((1 - t) * x1 + t * x2).T
122 |
123 | interpolate_slider = Slider(
124 | ax=ax_interpolation,
125 | label="",
126 | valmin=0.0,
127 | valmax=1.0,
128 | valinit=init_time,
129 | )
130 |
131 | # Add model_titles
132 | fig.text(x_padding - 0.02, s_height * 1.25, title1, ha="right", va="center")
133 |     fig.text(1 - x_padding + 0.06, s_height * 1.25, title2, ha="left", va="center")
134 |
135 | def update_from_slider(val):
136 | points.set_offsets(interpolate(interpolate_slider.val).T)
137 |
138 | interpolate_slider.on_changed(update_from_slider)
139 |
140 | if interactive:
141 | plt.show()
142 | return
143 |
144 | # Animation bit
145 | frames_left = np.linspace(0, 1, 20) ** 4 / 2
146 | frames = np.concatenate([frames_left, 1 - frames_left[::-1][1:]], axis=0)
147 | frames = np.concatenate([frames, frames[::-1]], axis=0)
148 |
149 | def update_from_animation(val):
150 | interpolate_slider.set_val(val)
151 |
152 | interpolate_slider.set_active(False)
153 |
154 | return animation.FuncAnimation(fig, update_from_animation, frames=frames) # type: ignore
155 |
156 |
157 | def standardize(a: N2Array) -> N2Array:
158 | mins = a.min(0, keepdims=True)
159 | maxs = a.max(0, keepdims=True)
160 | return (a - mins) / (maxs - mins)
161 |
162 |
163 | def rotate_to_target(source: N2Array, destination: N2Array) -> N2Array:
164 | from scipy.spatial.transform import Rotation as R
165 |
166 | source = np.pad(source, [(0, 0), (0, 1)], mode="constant", constant_values=0.0)
167 | destination = np.pad(destination, [(0, 0), (0, 1)], mode="constant", constant_values=0.0)
168 |
169 | rot, *_ = R.align_vectors(destination, source, return_sensitivity=True)
170 | out = rot.apply(source)
171 | return out[:, :2]
172 |
173 |
174 | @overload
175 | def build_animation(
176 | defn_1: EmbeddingDefinition,
177 | defn_2: EmbeddingDefinition,
178 | *,
179 | interactive: Literal[True],
180 | reduction: REDUCTIONS = "umap",
181 | ) -> None:
182 | ...
183 |
184 |
185 | @overload
186 | def build_animation(
187 | defn_1: EmbeddingDefinition,
188 | defn_2: EmbeddingDefinition,
189 | *,
190 | reduction: REDUCTIONS = "umap",
191 | interactive: Literal[False],
192 | ) -> animation.FuncAnimation:
193 | ...
194 |
195 |
196 | @overload
197 | def build_animation(
198 | defn_1: EmbeddingDefinition,
199 | defn_2: EmbeddingDefinition,
200 | *,
201 | reduction: REDUCTIONS = "umap",
202 | interactive: bool,
203 | ) -> animation.FuncAnimation | None:
204 | ...
205 |
206 |
207 | def build_animation(
208 | defn_1: EmbeddingDefinition,
209 | defn_2: EmbeddingDefinition,
210 | *,
211 | split: Split = Split.VALIDATION,
212 | reduction: REDUCTIONS = "umap",
213 | interactive: bool = False,
214 | ) -> animation.FuncAnimation | None:
215 | dataset = DatasetProvider.get_dataset(defn_1.dataset, split)
216 |
217 | embeds = defn_1.load_embeddings(split) # FIXME: This is expensive to get just labels
218 | if embeds is None:
219 | raise ValueError("Empty embeddings")
220 |
221 | reducer = reduction_from_string(reduction)
222 | reduced_1 = standardize(reducer.get_reduction(defn_1, split))
223 | reduced_2 = rotate_to_target(standardize(reducer.get_reduction(defn_2, split)), reduced_1)
224 | labels = embeds.labels
225 |
226 | if reduced_1.shape[0] > 2_000:
227 | selection = np.random.permutation(reduced_1.shape[0])[:2_000]
228 | reduced_1 = reduced_1[selection]
229 | reduced_2 = reduced_2[selection]
230 | labels = labels[selection]
231 |
232 | return create_embedding_chart(
233 | reduced_1,
234 | reduced_2,
235 | labels,
236 | defn_1.model,
237 | defn_2.model,
238 | suptitle=dataset.title,
239 | label_names=dataset.class_names,
240 | interactive=interactive,
241 | )
242 |
243 |
244 | def save_animation_to_file(
245 | anim: animation.FuncAnimation,
246 | def1: EmbeddingDefinition,
247 | def2: EmbeddingDefinition,
248 | split: Split = Split.VALIDATION,
249 | ):
250 | date_code = datetime.now().strftime("%Y%m%d-%H%M%S")
251 | file_name = f"transition_{def1.dataset}_{def1.model}_{def2.model}_{split}_{date_code}.gif"
252 | animation_file = OUTPUT_PATH.ANIMATIONS / file_name
253 | animation_file.parent.mkdir(parents=True, exist_ok=True) # Ensure that parent folder exists
254 | anim.save(animation_file)
255 | print(f"Animation stored at `{animation_file.resolve().as_posix()}`")
256 |
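A usage sketch, assuming cached embeddings already exist for both definitions and that the dataset is registered with `DatasetProvider` (the dataset and model titles are illustrative):

```python
from tti_eval.common import EmbeddingDefinition, Split
from tti_eval.plotting.animation import build_animation, save_animation_to_file

defn_1 = EmbeddingDefinition(dataset="plants", model="clip")
defn_2 = EmbeddingDefinition(dataset="plants", model="siglip_small")
anim = build_animation(defn_1, defn_2, split=Split.VALIDATION, interactive=False)
save_animation_to_file(anim, defn_1, defn_2, split=Split.VALIDATION)
```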
--------------------------------------------------------------------------------
/tti_eval/plotting/reduction.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from typing import Literal
3 |
4 | import numpy as np
5 | import umap
6 | from sklearn.decomposition import PCA
7 | from sklearn.manifold import TSNE
8 |
9 | from tti_eval.common import EmbeddingArray, EmbeddingDefinition, ReductionArray, Split
10 |
11 |
12 | class Reducer:
13 | @classmethod
14 | @abstractmethod
15 | def title(cls) -> str:
16 | raise NotImplementedError("This abstract method returns the title of the reducer implemented in this class.")
17 |
18 | @classmethod
19 | @abstractmethod
20 | def reduce(cls, embeddings: EmbeddingArray, **kwargs) -> ReductionArray:
21 | raise NotImplementedError("This abstract method contains the implementation for reducing embeddings.")
22 |
23 | @classmethod
24 | def get_reduction(
25 | cls,
26 | embedding_def: EmbeddingDefinition,
27 | split: Split,
28 | force_recompute: bool = False,
29 | save: bool = True,
30 | **kwargs,
31 | ) -> ReductionArray:
32 | reduction_file = embedding_def.get_reduction_path(cls.title(), split=split)
33 | if reduction_file.is_file() and not force_recompute:
34 | reduction: ReductionArray = np.load(reduction_file)
35 | return reduction
36 |
37 | elif not embedding_def.embedding_path(split).is_file():
38 | raise ValueError(
39 | f"{repr(embedding_def)} does not have embeddings stored ({embedding_def.embedding_path(split)})"
40 | )
41 |
42 | image_embeddings: EmbeddingArray = np.load(embedding_def.embedding_path(split))["image_embeddings"]
43 | reduction = cls.reduce(image_embeddings)
44 | if save:
45 | reduction_file.parent.mkdir(parents=True, exist_ok=True)
46 | np.save(reduction_file, reduction)
47 | return reduction
48 |
49 |
50 | class UMAPReducer(Reducer):
51 | @classmethod
52 | def title(cls) -> str:
53 | return "umap"
54 |
55 | @classmethod
56 | def reduce(cls, embeddings: EmbeddingArray, umap_seed: int | None = None, **kwargs) -> ReductionArray:
57 | reducer = umap.UMAP()
58 | return reducer.fit_transform(embeddings)
59 |
60 |
61 | class TSNEReducer(Reducer):
62 | @classmethod
63 | def title(cls) -> str:
64 | return "tsne"
65 |
66 | @classmethod
67 | def reduce(cls, embeddings: EmbeddingArray, **kwargs) -> ReductionArray:
68 | reducer = TSNE()
69 | return reducer.fit_transform(embeddings)
70 |
71 |
72 | class PCAReducer(Reducer):
73 | @classmethod
74 | def title(cls) -> str:
75 | return "pca"
76 |
77 | @classmethod
78 | def reduce(cls, embeddings: EmbeddingArray, **kwargs) -> ReductionArray:
79 | reducer = PCA(n_components=2)
80 | return reducer.fit_transform(embeddings)
81 |
82 |
83 | __REDUCTIONS = {
84 | UMAPReducer.title(): UMAPReducer,
85 | TSNEReducer.title(): TSNEReducer,
86 | PCAReducer.title(): PCAReducer,
87 | }
88 | REDUCTIONS = Literal["umap"] | Literal["tsne"] | Literal["pca"]
89 |
90 |
91 | def reduction_from_string(name: str) -> type[Reducer]:
92 | if name not in __REDUCTIONS:
93 | raise KeyError(f"{name} not in set {set(__REDUCTIONS.keys())}")
94 | return __REDUCTIONS[name]
95 |
96 |
97 | if __name__ == "__main__":
98 | embeddings = np.random.randn(100, 20)
99 |
100 | for cls in [UMAPReducer, TSNEReducer, PCAReducer]:
101 | print(cls.reduce(embeddings).shape)
102 |
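A sketch of the cached-reduction path, assuming embeddings for the definition are already stored on disk (the dataset and model titles are illustrative):

```python
from tti_eval.common import EmbeddingDefinition, Split
from tti_eval.plotting.reduction import reduction_from_string

defn = EmbeddingDefinition(dataset="plants", model="clip")
reducer = reduction_from_string("umap")
points_2d = reducer.get_reduction(defn, Split.VALIDATION)  # loaded from disk if previously saved
print(points_2d.shape)  # (num_validation_samples, 2)
```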
--------------------------------------------------------------------------------
/tti_eval/utils.py:
--------------------------------------------------------------------------------
1 | from functools import partialmethod
2 | from itertools import chain
3 | from typing import Literal, overload
4 |
5 | from tqdm import tqdm
6 |
7 | from tti_eval.common import EmbeddingDefinition
8 | from tti_eval.constants import PROJECT_PATHS
9 |
10 |
11 | def disable_tqdm():
12 | tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)
13 |
14 |
15 | def enable_tqdm():
16 | tqdm.__init__ = partialmethod(tqdm.__init__, disable=False)
17 |
18 |
19 | @overload
20 | def read_all_cached_embeddings(as_list: Literal[True]) -> list[EmbeddingDefinition]:
21 | ...
22 |
23 |
24 | @overload
25 | def read_all_cached_embeddings(as_list: Literal[False] = False) -> dict[str, list[EmbeddingDefinition]]:
26 | ...
27 |
28 |
29 | def read_all_cached_embeddings(
30 | as_list: bool = False,
31 | ) -> dict[str, list[EmbeddingDefinition]] | list[EmbeddingDefinition]:
32 | """
33 | Reads existing embedding definitions from the cache directory.
34 |     Returns a dictionary mapping each dataset name to a list of embedding definitions (one per model), or a flat list if `as_list` is True.
35 | """
36 | if not PROJECT_PATHS.EMBEDDINGS.exists():
37 | return dict()
38 |
39 | defs_dict = {
40 | d.name: list(
41 | {
42 | EmbeddingDefinition(dataset=d.name, model=m.stem.rsplit("_", maxsplit=1)[0])
43 | for m in d.iterdir()
44 | if m.is_file() and m.suffix == ".npz"
45 | }
46 | )
47 | for d in PROJECT_PATHS.EMBEDDINGS.iterdir()
48 | if d.is_dir()
49 | }
50 | if as_list:
51 | return list(chain(*defs_dict.values()))
52 | return defs_dict
53 |
54 |
55 | if __name__ == "__main__":
56 | print(read_all_cached_embeddings())
57 |
--------------------------------------------------------------------------------