├── .gitignore ├── README.md ├── demo.py ├── mmc_model_RAM_profiling.ipynb ├── pyproject.toml.INSTALL-ALL ├── setup.py ├── src └── mmc │ ├── __init__.py │ ├── ez │ ├── CLIP.py │ └── __init__.py │ ├── loaders │ ├── __init__.py │ ├── basemmcloader.py │ ├── clipfaloader.py │ ├── cloobloader.py │ ├── fairsliploader.py │ ├── keliploader.py │ ├── mlfcliploader.py │ ├── openaicliploader.py │ └── sbertclibloader.py │ ├── mock │ ├── __init__.py │ └── openai.py │ ├── modalities.py │ ├── multimmc.py │ ├── multimodalcomparator.py │ ├── napm_installs │ └── __init__.py │ └── registry.py └── tests ├── __init__.py ├── assets ├── dummy.txt └── marley_birthday.jpg ├── test_api_mock.py ├── test_ezmode_clip.py ├── test_mmc.py ├── test_mmc_clipfa.py ├── test_mmc_fairslip.py ├── test_mmc_fairslip_cc12m.py ├── test_mmc_fairslip_cc3m.py ├── test_mmc_katcloob.py ├── test_mmc_loaders.py ├── test_mmc_mlf.py ├── test_mmc_sbert.py ├── test_modalities.py ├── test_module_mock.py ├── test_multimmc.py └── test_registry.py /.gitignore: -------------------------------------------------------------------------------- 1 | # for vs code 2 | settings.json 3 | 4 | Pipfile.lock 5 | poetry.lock 6 | 7 | _venv 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | cover/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | db.sqlite3-journal 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | .pybuilder/ 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | # For a library or package, you might want to ignore these files since the code is 95 | # intended to run in multiple environments; otherwise, check them in: 96 | # .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # poetry 106 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 107 | # This is especially recommended for binary packages to ensure reproducibility, and is more 108 | # commonly ignored for libraries. 
109 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 110 | #poetry.lock 111 | 112 | # pdm 113 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 114 | #pdm.lock 115 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 116 | # in version control. 117 | # https://pdm.fming.dev/#use-with-ide 118 | .pdm.toml 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # mmc 2 | 3 | # installation 4 | 5 | ``` 6 | git clone https://github.com/dmarx/Multi-Modal-Comparators 7 | cd 'Multi-Modal-Comparators' 8 | cp pyproject.toml.INSTALL-ALL pyproject.toml 9 | pip install poetry 10 | poetry build 11 | pip install dist/mmc*.whl 12 | 13 | # optional final step: 14 | #poe napm_installs 15 | python src/mmc/napm_installs/__init__.py 16 | ``` 17 | 18 | To see which models are immediately available, run: 19 | 20 | ``` 21 | python -m mmc.loaders 22 | ``` 23 | 24 | ### That optional `poe napm_installs` step 25 | 26 | For the most convenient experience, it is recommended that you perform the final `poe napm_installs` step. 27 | Omitting this step will make your one-time setup faster, but will make certain use cases more complex. 28 | 29 | If you did not perform the optional `poe napm_installs` step, you likely received several warnings about 30 | models whose loaders could not be registered. These are models whose codebases depend on python code which 31 | is not trivially installable. You will still have access to all of the models supported by the library as if 32 | you had run the last step, but their loaders will not be queryable from the registry (see below) and will need 33 | to be loaded via the appropriate mmc.loader directly, which may be non-trivial to identify without the ability to 34 | query it from mmc's registry. 
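If you skipped this step at install time, it can be run at any later time from the root of the cloned repository (assuming the checkout is still available), using the same commands as the optional install step above: ``` python src/mmc/napm_installs/__init__.py # or, if poethepoet is installed: #poe napm_installs ```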
35 | 36 | As a concrete example, if the napm step is skipped, the model `[cloob - crowsonkb - cloob_laion_400m_vit_b_16_32_epochs]` 37 | will not appear in the list of registered loaders, but can still be loaded like this: 38 | 39 | ``` 40 | from mmc.loaders import KatCloobLoader 41 | 42 | model = KatCloobLoader(id='cloob_laion_400m_vit_b_16_32_epochs').load() 43 | ``` 44 | 45 | Invoking the `load()` method on an unregistered loader will invoke [napm](https://github.com/dmarx/not-a-package-manager) 46 | to prepare any uninstallable dependencies required to load the model. The next time you run `python -m mmc.loaders`, 47 | the CLOOB loader will show as registered, and spinning up the registry will no longer emit a warning for that model. 48 | 49 | 50 | # Usage 51 | 52 | **TLDR** 53 | 54 | ``` 55 | # spin up the registry 56 | from mmc import loaders 57 | 58 | ## Using the 'mocked' openai API to load supported CLIP models 59 | 60 | from mmc.ez.CLIP import clip 61 | clip.available_models() 62 | 63 | # requesting a tokenizer before loading the model 64 | # returns the openai clip SimpleTokenizer 65 | #tokenize = clip.tokenize 66 | 67 | # either of these works 68 | model, preprocessor = clip.load('RN50') 69 | model, preprocessor = clip.load('[clip - openai - RN50]') 70 | 71 | # if we request the tokenizer *after* a model has been loaded, 72 | # the tokenizer appropriate to the loaded model is returned 73 | tokenize = clip.tokenize 74 | 75 | ############################### 76 | 77 | ## Slightly "closer to the metal" usage 78 | 79 | from mmc.mock.openai import MockOpenaiClip 80 | from mmc.registry import REGISTRY 81 | 82 | cloob_query = {'architecture': 'cloob'} 83 | cloob_loaders = REGISTRY.find(**cloob_query) 84 | 85 | # a loader's repr prints the attributes needed to query it uniquely 86 | print(cloob_loaders) 87 | 88 | # the loader returns a perceptor whose API is standardized across mmc 89 | cloob_model = cloob_loaders[0].load() 90 | 91 | # wrapper classes are provided for mocking popular implementations 92 | # to facilitate drop-in compatibility with existing code 93 | drop_in_replacement__cloob_model = MockOpenaiClip(cloob_model) 94 | ``` 95 | 96 | ## Querying the Model Registry 97 | 98 | Spin up the model registry by importing the loaders module: 99 | 100 | ```from mmc import loaders``` 101 | 102 | To see which models are available: 103 | 104 | ``` 105 | from mmc.registry import REGISTRY 106 | 107 | for loader in REGISTRY.find(): 108 | print(loader) 109 | ``` 110 | 111 | You can constrain the result set by querying the registry for specific metadata attributes: 112 | 113 | ``` 114 | # all CLIP models 115 | clip_loaders = REGISTRY.find(architecture='clip') 116 | 117 | # CLIP models published by openai 118 | openai_clip_loaders = REGISTRY.find(architecture='clip', publisher='openai') 119 | 120 | # All models published by MLFoundations (openCLIP) 121 | mlf_loaders = REGISTRY.find(publisher='mlfoundations') 122 | 123 | # A specific model 124 | rn50_loader = REGISTRY.find(architecture='clip', publisher='openai', id='RN50') 125 | # NB: there may be multiple models matching a particular "id". The 'id' field 126 | # only needs to be unique for a given architecture-publisher pair. 127 | ``` 128 | 129 | All pretrained checkpoints are uniquely identifiable by a combination of `architecture`, `publisher`, and `id`. 130 | 131 | The above queries return lists of **loader** objects. If model artifacts (checkpoints, config) need to be downloaded, they will only be downloaded after the `load()` method on the loader is invoked.
132 | 133 | ``` 134 | loaders = REGISTRY.find(...) 135 | loader = loaders[0] # just picking an arbitrary return value here, remember: loaders is a *list* of loaders 136 | model = loader.load() 137 | ``` 138 | 139 | The `load()` method returns an instance of an `mmc.MultiModalComparator`. The `MultiModalComparator` class 140 | is a modality-agnostic abstraction. I'll get to the ins and outs of that another time. 141 | 142 | ## API Mocking 143 | 144 | You want something you can just drop into your code and it'll work. We got you. This library provides wrapper 145 | classes to mock the APIs of commonly used CLIP implementations (at present, OpenAI's CLIP is the only API which can be mocked). 146 | Individual loaders can be wrapped after instantiation (see below), but we also provide an "easy mode" API for best user experience. 147 | 148 | ### Using the 'Easy Mode' CLIP API 149 | 150 | Let's consider a codebase that already has openai/CLIP installed, via e.g. `pip install git+https://github.com/openai/CLIP` or `pip install clip-anytorch`. 151 | 152 | All we have to do to integrate MMC is change 153 | 154 | **Step 1:** 155 | 156 | ``` 157 | # let's call this the "normal clip object" 158 | import clip 159 | ``` 160 | 161 | to 162 | 163 | ``` 164 | #from mmc import loaders ## optional, populates the mmc registry with all supported loaders 165 | from mmc.ez.CLIP import clip 166 | ``` 167 | 168 | **Step 2.** (optional but strongly advised) 169 | 170 | Make sure any references to `clip.tokenize` appear *after* `clip.load()` has already been invoked. 171 | 172 | 173 | And that's it! 174 | 175 | 176 | Here's what this change gives us: 177 | 178 | * `clip.available_models()` 179 | - In addition to all of the values normally returned when this method is invoked on the normal `clip` object, also returns all currently available models known to the MMC registry. 180 | 181 | * `clip.load()` 182 | - Supports any model aliases returned by `clip.available_models()` 183 | - Invoking `clip.load()` with arguments like `'RN50'` or `'ViT-B/16'` will return the expected OpenAI CLIP model. 184 | - Additionally supports loading models using the mmc alias convention, e.g. `clip.load('RN50')` is equivalent to `clip.load('[clip - openai - RN50]')` 185 | 186 | * `clip.tokenize` 187 | - Returns the OpenAI tokenizer by default, exactly as if we had invoked `clip.tokenize` on the original `clip` object rather than `mmc.ez.clip`. 188 | - If a model has already been loaded using the `clip.load()` method above, then `clip.tokenize` returns the text preprocessor appropriate to that model. At present, most CLIP implementations are published with OpenAI's tokenizer, so you might not experience any issues if you don't reorganize this part of your code. 189 | 190 | ### Mocking Individual Loaders 191 | 192 | To wrap a `MultiModalComparator` so it can 193 | be used as a drop-in replacement with code compatible with OpenAI's CLIP: 194 | 195 | ``` 196 | from mmc.mock.openai import MockOpenaiClip 197 | 198 | my_model = my_model_loader.load() 199 | model = MockOpenaiClip(my_model) 200 | ``` 201 | 202 | ## MultiMMC: Multi-Perceptor Implementation 203 | 204 | *(WIP, behavior likely to change in near future)* 205 | 206 | The `MultiMMC` class can be used to run inference against multiple mmc models in parallel. This form of 207 | ensemble is sometimes referred to as a "multi-perceptor".
208 | 209 | To ensure that all models loaded into the MultiMMC are compatible, the MultiMMC instance is initialized 210 | by specifying the modalities it supports. We'll discuss modality objects in a bit. 211 | 212 | ``` 213 | from mmc.multimmc import MultiMMC 214 | from mmc.modalities import TEXT, IMAGE 215 | 216 | perceptor = MultiMMC(TEXT, IMAGE) 217 | ``` 218 | 219 | To load and use a model: 220 | 221 | ``` 222 | perceptor.load_model( 223 | architecture='clip', 224 | publisher='openai', 225 | id='RN50', 226 | ) 227 | 228 | score = perceptor.compare( 229 | image=PIL.Image.open(...), 230 | text=text_pos, 231 | ) 232 | ``` 233 | 234 | Additional models can be added to the ensemble via the `load_model()` method. 235 | 236 | The MultiMMC does not support API mocking because of its reliance on the `compare` method. 237 | 238 | 239 | # Available Pre-trained Models 240 | 241 | Some model comparisons [here](https://t.co/iShJpm5GjL) 242 | 243 | ``` 244 | # [architecture - publisher - id] 245 | [clip - openai - RN50] 246 | [clip - openai - RN101] 247 | [clip - openai - RN50x4] 248 | [clip - openai - RN50x16] 249 | [clip - openai - RN50x64] 250 | [clip - openai - ViT-B/32] 251 | [clip - openai - ViT-B/16] 252 | [clip - openai - ViT-L/14] 253 | [clip - openai - ViT-L/14@336px] 254 | [clip - mlfoundations - RN50--openai] 255 | [clip - mlfoundations - RN50--yfcc15m] 256 | [clip - mlfoundations - RN50--cc12m] 257 | [clip - mlfoundations - RN50-quickgelu--openai] 258 | [clip - mlfoundations - RN50-quickgelu--yfcc15m] 259 | [clip - mlfoundations - RN50-quickgelu--cc12m] 260 | [clip - mlfoundations - RN101--openai] 261 | [clip - mlfoundations - RN101--yfcc15m] 262 | [clip - mlfoundations - RN101-quickgelu--openai] 263 | [clip - mlfoundations - RN101-quickgelu--yfcc15m] 264 | [clip - mlfoundations - RN50x4--openai] 265 | [clip - mlfoundations - RN50x16--openai] 266 | [clip - mlfoundations - ViT-B-32--openai] 267 | [clip - mlfoundations - ViT-B-32--laion400m_e31] 268 | [clip - mlfoundations - ViT-B-32--laion400m_e32] 269 | [clip - mlfoundations - ViT-B-32--laion400m_avg] 270 | [clip - mlfoundations - ViT-B-32-quickgelu--openai] 271 | [clip - mlfoundations - ViT-B-32-quickgelu--laion400m_e31] 272 | [clip - mlfoundations - ViT-B-32-quickgelu--laion400m_e32] 273 | [clip - mlfoundations - ViT-B-32-quickgelu--laion400m_avg] 274 | [clip - mlfoundations - ViT-B-16--openai] 275 | [clip - mlfoundations - ViT-L-14--openai] 276 | [clip - sbert - ViT-B-32-multilingual-v1] 277 | [clip - sajjjadayobi - clipfa] 278 | 279 | # The following models depend on napm for setup 280 | [clip - navervision - kelip_ViT-B/32] 281 | [cloob - crowsonkb - cloob_laion_400m_vit_b_16_16_epochs] 282 | [cloob - crowsonkb - cloob_laion_400m_vit_b_16_32_epochs] 283 | [clip - facebookresearch - clip_small_25ep] 284 | [clip - facebookresearch - clip_base_25ep] 285 | [clip - facebookresearch - clip_large_25ep] 286 | [slip - facebookresearch - slip_small_25ep] 287 | [slip - facebookresearch - slip_small_50ep] 288 | [slip - facebookresearch - slip_small_100ep] 289 | [slip - facebookresearch - slip_base_25ep] 290 | [slip - facebookresearch - slip_base_50ep] 291 | [slip - facebookresearch - slip_base_100ep] 292 | [slip - facebookresearch - slip_large_25ep] 293 | [slip - facebookresearch - slip_large_50ep] 294 | [slip - facebookresearch - slip_large_100ep] 295 | [simclr - facebookresearch - simclr_small_25ep] 296 | [simclr - facebookresearch - simclr_base_25ep] 297 | [simclr - facebookresearch - simclr_large_25ep] 298 | [clip - facebookresearch - clip_base_cc3m_40ep] 299 |
[clip - facebookresearch - clip_base_cc12m_35ep] 300 | [slip - facebookresearch - slip_base_cc3m_40ep] 301 | [slip - facebookresearch - slip_base_cc12m_35ep] 302 | ``` 303 | 304 | # VRAM Cost 305 | 306 | The following is an estimate of the amount of space the loaded model occupies in memory: 307 | 308 | | | publisher | architecture | model_name | vram_mb | 309 | |---:|:-----------------|:---------------|:------------------------------------|----------:| 310 | | 0 | openai | clip | RN50 | 358 | 311 | | 1 | openai | clip | RN101 | 294 | 312 | | 2 | openai | clip | RN50x4 | 424 | 313 | | 3 | openai | clip | RN50x16 | 660 | 314 | | 4 | openai | clip | RN50x64 | 1350 | 315 | | 5 | openai | clip | ViT-B/32 | 368 | 316 | | 6 | openai | clip | ViT-B/16 | 348 | 317 | | 7 | openai | clip | ViT-L/14 | 908 | 318 | | 8 | openai | clip | ViT-L/14@336px | 908 | 319 | | 9 | mlfoundations | clip | RN50--openai | 402 | 320 | | 10 | mlfoundations | clip | RN50--yfcc15m | 402 | 321 | | 11 | mlfoundations | clip | RN50--cc12m | 402 | 322 | | 12 | mlfoundations | clip | RN50-quickgelu--openai | 402 | 323 | | 13 | mlfoundations | clip | RN50-quickgelu--yfcc15m | 402 | 324 | | 14 | mlfoundations | clip | RN50-quickgelu--cc12m | 402 | 325 | | 15 | mlfoundations | clip | RN101--openai | 476 | 326 | | 16 | mlfoundations | clip | RN101--yfcc15m | 476 | 327 | | 17 | mlfoundations | clip | RN101-quickgelu--openai | 476 | 328 | | 18 | mlfoundations | clip | RN101-quickgelu--yfcc15m | 476 | 329 | | 19 | mlfoundations | clip | RN50x4--openai | 732 | 330 | | 20 | mlfoundations | clip | RN50x16--openai | 1200 | 331 | | 21 | mlfoundations | clip | ViT-B-32--openai | 634 | 332 | | 22 | mlfoundations | clip | ViT-B-32--laion400m_e31 | 634 | 333 | | 23 | mlfoundations | clip | ViT-B-32--laion400m_e32 | 634 | 334 | | 24 | mlfoundations | clip | ViT-B-32--laion400m_avg | 634 | 335 | | 25 | mlfoundations | clip | ViT-B-32-quickgelu--openai | 634 | 336 | | 26 | mlfoundations | clip | ViT-B-32-quickgelu--laion400m_e31 | 634 | 337 | | 27 | mlfoundations | clip | ViT-B-32-quickgelu--laion400m_e32 | 634 | 338 | | 28 | mlfoundations | clip | ViT-B-32-quickgelu--laion400m_avg | 634 | 339 | | 29 | mlfoundations | clip | ViT-B-16--openai | 634 | 340 | | 30 | mlfoundations | clip | ViT-L-14--openai | 1688 | 341 | | 32 | sajjjadayobi | clip | clipfa | 866 | 342 | | 33 | crowsonkb | cloob | cloob_laion_400m_vit_b_16_16_epochs | 610 | 343 | | 34 | crowsonkb | cloob | cloob_laion_400m_vit_b_16_32_epochs | 610 | 344 | | 36 | facebookresearch | slip | slip_small_25ep | 728 | 345 | | 37 | facebookresearch | slip | slip_small_50ep | 650 | 346 | | 38 | facebookresearch | slip | slip_small_100ep | 650 | 347 | | 39 | facebookresearch | slip | slip_base_25ep | 714 | 348 | | 40 | facebookresearch | slip | slip_base_50ep | 714 | 349 | | 41 | facebookresearch | slip | slip_base_100ep | 714 | 350 | | 42 | facebookresearch | slip | slip_large_25ep | 1534 | 351 | | 43 | facebookresearch | slip | slip_large_50ep | 1522 | 352 | | 44 | facebookresearch | slip | slip_large_100ep | 1522 | 353 | | 45 | facebookresearch | slip | slip_base_cc3m_40ep | 714 | 354 | | 46 | facebookresearch | slip | slip_base_cc12m_35ep | 714 | 355 | 356 | # Contributing 357 | 358 | ## Suggest a pre-trained model 359 | 360 | If you would like to suggest a pre-trained model for future addition, you can add a comment to [this issue](https://github.com/dmarx/Multi-Modal-Comparators/issues/2) 361 | 362 | ## Add a pre-trained model 363 | 364 | 1. 
Create a loader class that encapsulates the logic for importing the model, loading weights, preprocessing inputs, and performing projections. 365 | 2. At the bottom of the file defining the loader class, add a code snippet that registers each respective checkpoint's loader with the registry. 366 | 3. Add an import for the new file to `mmc/loaders/__init__.py`. The imports in this file are the reason `import mmc.loaders` "spins up" the registry. 367 | 4. If the codebase on which the model depends can be installed, update `pyproject.toml` to install it. 368 | 5. Otherwise, add napm preparation at the top of the loader's `load` method (see cloob or kelip for examples), and also add napm setup to `mmc/napm_installs/__init__.py` 369 | 6. Add a test case to `tests/test_mmc_loaders.py` 370 | 7. Add a test script for the loader (see `test_mmc_katcloob` as an example) 371 | 372 | 373 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from mmc.multimmc import MultiMMC 4 | from mmc.modalities import TEXT, IMAGE 5 | 6 | 7 | # for now at least, I'm referring to models like CLIP, CLOOB, SLIP etc. as 8 | # "multi-modal comparators" (MMCs). The MultiMMC class is a generic wrapper 9 | # that serves the same function as a "MultiCLIPPerceptor", but is 10 | # intended to be sufficiently generic to be able to wrap collections of models 11 | # that aren't all from the same family. The only constraint is that 12 | # the individual MMCs attached to the MultiMMC must each be compatible with 13 | # the modalities the MultiMMC supports. 14 | 15 | perceptor = MultiMMC(TEXT, IMAGE)#, shared_latent=True) 16 | 17 | oa_clip_modelnames = [ 18 | 'RN50', 19 | 'RN101', 20 | 'ViT-L/14', 21 | ... 22 | ] 23 | 24 | #perceptor.load_model(architecture='slip', id='some-clip-model') 25 | #perceptor.load_model(architecture='blip', id='that-one-blip-model') 26 | 27 | # Individual MMCs can be ascribed weights.
Potentially ways this could be used: 28 | # * weighted ensemble of perceptors 29 | # * compensate for perceptors that produce outputs at different scales 30 | for model_name in oa_clip_modelnames: 31 | perceptor.load_model( 32 | architecture='clip', 33 | publisher='openai', 34 | id=model_name, 35 | #weight=1, # default 36 | ) 37 | 38 | # add a model that takes 50% responsibility for score, cause why not 39 | perceptor.load_model( 40 | architecture='cloob', 41 | publisher='crowsonkb', 42 | weight=len(perceptor.models), 43 | ) 44 | 45 | logger.debug(perceptor.models.keys()) 46 | 47 | assert perceptor.supports_text 48 | assert perceptor.supports_image 49 | #assert perceptor.has_shared_latent 50 | 51 | [m.name for m in perceptor.modalities] 52 | 53 | 54 | text=["foo bar baz"] 55 | image=IMAGE.read_from_disk('foobar.jpg') 56 | 57 | multi_similarity_score = perceptor.compare( 58 | text=text_container, 59 | image=image_container, 60 | return_projections = False, 61 | ) 62 | 63 | -------------------------------------------------------------------------------- /pyproject.toml.INSTALL-ALL: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "mmc" #"Multi-Modal Comparators" 3 | version = "0.1.0" 4 | description = "Unified API to facilitate usage of pre-trained \"perceptor\" models, a la CLIP" 5 | authors = ["David Marx "] 6 | license = "MIT" 7 | 8 | [tool.poetry.dependencies] 9 | python = "^3.7" 10 | clip = {git = "https://github.com/openai/CLIP", branch = "main"} 11 | loguru = "^0.6.0" 12 | #Pillow = "^7.1.2" #"^9.1.0" 13 | #torch = "^1.11.0" 14 | #torchvision = "^0.12.0" 15 | #torchaudio = "^0.11.0" 16 | kornia = "^0.6.4" 17 | open-clip-torch = {git = "https://github.com/mlfoundations/open_clip", branch = "main"} 18 | declip = {git = "https://github.com/pytti-tools/DeCLIP", branch = "installable"} 19 | kelip = {git = "https://github.com/navervision/KELIP.git", branch = "master"} 20 | sentence-transformers = "^2.2.0" 21 | napm="^0.2.0" 22 | timm="^0.5.4" 23 | 24 | [tool.poetry.dev-dependencies] 25 | black = "^22.3.0" 26 | pytest = "^7.1.1" 27 | poethepoet = "^0.13.1" 28 | 29 | [build-system] 30 | requires = ["poetry-core>=1.0.0"] 31 | build-backend = "poetry.core.masonry.api" 32 | 33 | 34 | # https://github.com/nat-n/poethepoet 35 | [tool.poe.tasks] 36 | napm_installs = { "script" = "mmc.napm_installs:all" } -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | setup( 3 | name='mmc', 4 | version='0.2.0', 5 | install_requires=[ 6 | "wheel", 7 | "loguru", 8 | "clip-anytorch", 9 | "napm", 10 | ], 11 | packages=find_packages( 12 | #where='src/mmc', 13 | where='src', 14 | include=['mmc*'], # ["*"] by default 15 | #exclude=['mypackage.tests'], # empty by default 16 | ), 17 | package_dir={'mmc': 'src/mmc'}, 18 | ) 19 | -------------------------------------------------------------------------------- /src/mmc/__init__.py: -------------------------------------------------------------------------------- 1 | #from .loaders import * 2 | #from .registry import * -------------------------------------------------------------------------------- /src/mmc/ez/CLIP.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | #from CLIP import clip as openai_clip 4 | import clip as openai_clip 5 | 6 | # initialize 7 | #from ..loaders 8 | import 
mmc 9 | import mmc.loaders 10 | 11 | #from ..mock.openai import MockOpenaiClip 12 | #from ..registry import REGISTRY 13 | 14 | from mmc.mock.openai import MockOpenaiClip 15 | from mmc.registry import REGISTRY 16 | 17 | 18 | 19 | 20 | class EzClip: 21 | def __init__(self): 22 | self._last_fetched_loader = None 23 | self._last_fetched_model = None 24 | self.d_openai={m.id:str(m) for m in REGISTRY.find(publisher='openai')} 25 | def _id_to_alias(self, id: str) -> str: 26 | """ 27 | Converts a model id to an MMC alias. 28 | """ 29 | if id in self.d_openai: 30 | return self.d_openai[id] 31 | else: 32 | return id 33 | 34 | def _alias_to_query(self, alias: str) -> dict: 35 | """ 36 | Converts an MMC alias to a query. 37 | """ 38 | architecture, publisher, id = alias[1:-1].split(' - ') 39 | return {'id':id, 'publisher':publisher, 'architecture':architecture} 40 | 41 | def available_models(self): 42 | """ 43 | Returns a list of available models. 44 | """ 45 | return list(self.d_openai.keys()) + [str(m) for m in REGISTRY.find()] 46 | 47 | def load(self, id, device=None): 48 | """ 49 | Loads a model from the registry. 50 | """ 51 | alias = self._id_to_alias(id) 52 | query = self._alias_to_query(alias) 53 | hits = REGISTRY.find(**query) 54 | if len(hits) < 1: 55 | raise ValueError(f"No model found for id: {id}") 56 | elif len(hits) > 1: 57 | raise ValueError(f"Multiple models found for id: {id}") 58 | 59 | loader = hits[0] 60 | model = loader.load(device) 61 | self._last_fetched_loader = loader 62 | self._last_fetched_model = model 63 | 64 | mocked = MockOpenaiClip(model, device) 65 | image_preprocessor = model 66 | return mocked, image_preprocessor 67 | 68 | @property 69 | def tokenize(self): 70 | """ 71 | Returns the tokenizer for the last loaded model. 72 | """ 73 | if self._last_fetched_model is None: 74 | #raise ValueError("No model loaded.") 75 | warnings.warn( 76 | "No model loaded. Returning OpenAI's default tokenizer. " 77 | "If this is not what you want, call `clip.load` before requesting the tokenizer." 
78 | ) 79 | return openai_clip.tokenize 80 | return self._last_fetched_model.modes['text']['preprocessor'] 81 | 82 | 83 | clip = EzClip() 84 | -------------------------------------------------------------------------------- /src/mmc/ez/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmarx/Multi-Modal-Comparators/e8ae395c0aca6b8b872ae38feb160bd53078e721/src/mmc/ez/__init__.py -------------------------------------------------------------------------------- /src/mmc/loaders/__init__.py: -------------------------------------------------------------------------------- 1 | from .basemmcloader import * 2 | from .openaicliploader import * 3 | from .mlfcliploader import * 4 | from .sbertclibloader import * 5 | from .clipfaloader import * 6 | from .cloobloader import * 7 | from .keliploader import * 8 | from .fairsliploader import * -------------------------------------------------------------------------------- /src/mmc/loaders/basemmcloader.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import TYPE_CHECKING 3 | 4 | if TYPE_CHECKING: 5 | from ..multimodalcomparator import MultiModalComparator 6 | 7 | class BaseMmcLoader(abc.ABC): 8 | """ 9 | Base class that manages the procedure for loading MMC objects 10 | """ 11 | def __init__( 12 | self, 13 | architecture=None, 14 | publisher=None, 15 | id=None, 16 | ): 17 | self.architecture = architecture 18 | self.publisher = publisher 19 | self.id = id 20 | self.modalities = () 21 | @abc.abstractmethod 22 | def load(self) -> "MultiModalComparator": 23 | """ 24 | Load the MMC object associated with this loader. 25 | """ 26 | return 27 | 28 | def supports_modality(self, modality) -> bool: 29 | """ 30 | Generic test for clarifying whether a specific modality is supported by the MMC this loader returns. 31 | """ 32 | return any(modality.name == m.name for m in self.modalities) 33 | 34 | def __str__(self) -> str: 35 | return f"[{self.architecture} - {self.publisher} - {self.id}]" 36 | 37 | def __repr__(self) -> str: 38 | return self.__str__() 39 | -------------------------------------------------------------------------------- /src/mmc/loaders/clipfaloader.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Loaders for pretrained CLIP models published by OpenAI 4 | """ 5 | 6 | #import clip # this should probably be isolated somehow 7 | from loguru import logger 8 | import torch 9 | 10 | from .basemmcloader import BaseMmcLoader 11 | from ..modalities import TEXT, IMAGE 12 | from ..multimodalcomparator import MultiModalComparator 13 | from ..registry import REGISTRY, register_model 14 | 15 | DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 16 | 17 | 18 | class ClipFaLoader(BaseMmcLoader): 19 | """ 20 | CLIP model trained for the Farsi language (Persian) 21 | https://github.com/sajjjadayobi/CLIPfa 22 | """ 23 | def __init__( 24 | self, 25 | #id, 26 | device=DEVICE, 27 | ): 28 | self.device=device 29 | self.architecture = 'clip' # should this be a type too? 30 | self.publisher = 'sajjjadayobi' 31 | self.id = 'clipfa' 32 | self.modalities = (TEXT, IMAGE) 33 | def load(self, device=None): 34 | """ 35 | Returns the MMC associated with this loader. 
36 | """ 37 | if device is None: 38 | device = self.device 39 | #import clip 40 | #model, preprocess_image = clip.load(self.id, jit=False, device=device) 41 | #model.eval() 42 | #model.requires_grad_(False) 43 | #model.to(device, memory_format=torch.channels_last) 44 | #tokenizer = clip.tokenize # clip.simple_tokenizer.SimpleTokenizer() 45 | #def preprocess_image_extended(*args, **kwargs): 46 | # x = preprocess_image(*args, **kwargs) 47 | # if x.ndim == 3: 48 | # logger.debug("adding batch dimension") 49 | # x = x.unsqueeze(0) 50 | # return x 51 | from transformers import CLIPVisionModel, RobertaModel, AutoTokenizer, CLIPFeatureExtractor 52 | # download pre-trained models 53 | vision_encoder = CLIPVisionModel.from_pretrained('SajjadAyoubi/clip-fa-vision') 54 | preprocessor = CLIPFeatureExtractor.from_pretrained('SajjadAyoubi/clip-fa-vision') 55 | text_encoder = RobertaModel.from_pretrained('SajjadAyoubi/clip-fa-text') 56 | tokenizer = AutoTokenizer.from_pretrained('SajjadAyoubi/clip-fa-text') 57 | vision_encoder.to(device) 58 | text_encoder.to(device) 59 | #text_embedding = text_encoder(**tokenizer(text, return_tensors='pt')).pooler_output 60 | #image_embedding = vision_encoder(**preprocessor(image, return_tensors='pt')).pooler_output 61 | mmc = MultiModalComparator(name=str(self), device=device) 62 | mmc.register_modality(modality=TEXT, projector=text_encoder, preprocessor=tokenizer) 63 | mmc.register_modality(modality=IMAGE, projector=vision_encoder, preprocessor=preprocessor) 64 | mmc._model = vision_encoder 65 | return mmc 66 | 67 | 68 | register_model( 69 | ClipFaLoader() 70 | ) -------------------------------------------------------------------------------- /src/mmc/loaders/cloobloader.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Loaders for pretrained CLOOB model by crowsonkb 4 | https://github.com/crowsonkb/cloob-training 5 | """ 6 | 7 | # importing this first is necessary for cloob to be available 8 | import napm 9 | 10 | from loguru import logger 11 | import torch 12 | 13 | from .basemmcloader import BaseMmcLoader 14 | from ..modalities import TEXT, IMAGE 15 | from ..multimodalcomparator import MultiModalComparator 16 | from ..registry import REGISTRY, register_model 17 | 18 | from torchvision.transforms import ToTensor 19 | 20 | DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 21 | 22 | from typing import TYPE_CHECKING 23 | 24 | if TYPE_CHECKING: 25 | import PIL 26 | 27 | class KatCloobLoader(BaseMmcLoader): 28 | """ 29 | CLOOB models by crowsonkb, initially trained on LAION datasets 30 | https://github.com/crowsonkb/cloob-training 31 | """ 32 | def __init__( 33 | self, 34 | id='cloob_laion_400m_vit_b_16_32_epochs', 35 | ): 36 | self.architecture = 'cloob' # should this be a type too? 37 | self.publisher = 'crowsonkb' 38 | self.id = id 39 | self.modalities = (TEXT, IMAGE) 40 | def load(self, device=DEVICE): 41 | """ 42 | Returns the MMC associated with this loader. 
43 | """ 44 | logger.debug('using napm to "install" katCLOOB') 45 | url = "https://github.com/crowsonkb/cloob-training" 46 | napm.pseudoinstall_git_repo(url, env_name='mmc', package_name='cloob') 47 | napm.populate_pythonpaths('mmc') 48 | from cloob.cloob_training import model_pt, pretrained 49 | 50 | config = pretrained.get_config(self.id) 51 | model = model_pt.get_pt_model(config) 52 | checkpoint = pretrained.download_checkpoint(config) 53 | model.load_state_dict(model_pt.get_pt_params(config, checkpoint)) 54 | model.eval().requires_grad_(False).to(device) 55 | d_im = config['image_encoder']['image_size'] 56 | 57 | def _preprocess_closure(img: "PIL.Image.Image") -> torch.Tensor: 58 | img = img.resize((d_im, d_im)).convert('RGB') 59 | t_img = ToTensor()(img) 60 | if t_img.ndim == 3: 61 | t_img = t_img.unsqueeze(0) 62 | t_img = t_img.to(device) 63 | return model.normalize(t_img) 64 | 65 | mmc = MultiModalComparator(name=str(self), device=device) 66 | mmc.register_modality(modality=TEXT, projector=model.text_encoder, preprocessor=model.tokenize) 67 | mmc.register_modality(modality=IMAGE, projector=model.image_encoder, preprocessor=_preprocess_closure) 68 | mmc._model = model 69 | return mmc 70 | 71 | try: 72 | from cloob.cloob_training import model_pt, pretrained 73 | for model_name in pretrained.list_configs(): 74 | register_model( 75 | KatCloobLoader(id=model_name) 76 | ) 77 | except: 78 | logger.warning( 79 | "unable to import cloob: bypassing loader registration. You can still isntall and load cloob via:" 80 | "`model = KatCloobLoader(id=...).load()`" 81 | ) -------------------------------------------------------------------------------- /src/mmc/loaders/fairsliploader.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Loaders for pretrained CLOOB model by crowsonkb 4 | https://github.com/crowsonkb/cloob-training 5 | """ 6 | from collections import OrderedDict 7 | from pathlib import Path 8 | from platform import architecture 9 | from typing import TYPE_CHECKING 10 | 11 | from loguru import logger 12 | import napm 13 | import torch 14 | from torch import hub 15 | from torchvision import transforms 16 | 17 | from .basemmcloader import BaseMmcLoader 18 | from ..modalities import TEXT, IMAGE 19 | from ..multimodalcomparator import MultiModalComparator 20 | from ..registry import register_model 21 | 22 | 23 | DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 24 | 25 | if TYPE_CHECKING: 26 | import PIL 27 | 28 | 29 | val_transform = transforms.Compose([ 30 | transforms.Resize(224), 31 | transforms.CenterCrop(224), 32 | transforms.ToTensor(), 33 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 34 | std=[0.229, 0.224, 0.225]) 35 | ]) 36 | 37 | 38 | #url_template = "https://dl.fbaipublicfiles.com/slip/{arch}_{size}_{eps}ep.pt" 39 | # let's use "{arch}_{size}_{eps}ep" as the id 40 | 41 | def parse_id(slip_id: str): 42 | arch, size, *_ = slip_id.split('_') 43 | return arch, size 44 | 45 | def loader_name_from_id(slip_id: str): 46 | arch, size = parse_id(slip_id) 47 | name_str = f"{arch.upper()}_VIT{size[0].upper()}16" 48 | return name_str 49 | 50 | def model_factory_from_id(slip_id: str): 51 | from SLIP.models import ( 52 | SLIP_VITS16, 53 | SLIP_VITB16, 54 | SLIP_VITL16, 55 | ) 56 | name = loader_name_from_id(slip_id) 57 | logger.debug(name) 58 | #model_factory = globals().get(name) 59 | model_factory = locals().get(name) 60 | return model_factory 61 | 62 | def url_from_id(slip_id: str): 63 | return 
f"https://dl.fbaipublicfiles.com/slip/{slip_id}.pt" 64 | 65 | def id_from_url(url): 66 | fname = url.split('/')[-1] 67 | return fname.split('.')[0] 68 | 69 | 70 | def fetch_weights(url, namespace, device=DEVICE): 71 | """ 72 | Downloads the weights from the given url and saves them to the given path. 73 | If weights have already been downloaded, they are loaded from the path. 74 | 75 | :param url: The URL of the checkpoint file 76 | :param namespace: The name of the model 77 | :param device: The device to load the weights on 78 | :return: A dictionary of the weights and biases of the model. 79 | """ 80 | fname = url.split('/')[-1] 81 | fpath = Path(hub.get_dir()) / namespace / fname 82 | try: 83 | ckpt = torch.load(fpath, map_location=device) 84 | except FileNotFoundError: 85 | download(url, fpath) 86 | ckpt = torch.load(fpath, map_location=device) 87 | return ckpt 88 | 89 | 90 | def download(url, fpath): 91 | """ 92 | If the file doesn't exist, download it 93 | 94 | :param url: The URL of the file to download 95 | :param fpath: The path to the file to download 96 | """ 97 | if not Path(fpath).exists(): 98 | Path(fpath).parent.mkdir(parents=True, exist_ok=True) 99 | hub.download_url_to_file(url, fpath) 100 | if not Path(fpath).exists(): 101 | raise FileNotFoundError(f"Download failed: {url}") 102 | 103 | def fix_param_names_old(ckpt): 104 | """ 105 | Takes a checkpoint dictionary and removes the "module" prefix from the keys in the state_dict 106 | 107 | :param ckpt: the checkpoint file 108 | """ 109 | logger.debug(ckpt.keys()) 110 | logger.debug(ckpt['args']) 111 | sd = ckpt['state_dict'] 112 | real_sd = {} 113 | for k, v in sd.items(): 114 | new_key = '.'.join(k.split('.')[1:]) # strips "module" prefix. sure, why not. 115 | real_sd[new_key] = v 116 | del ckpt['state_dict'] 117 | ckpt['state_dict'] = real_sd 118 | 119 | def fix_param_names(ckpt): 120 | # via https://github.com/pixray/pixray/blob/master/slip.py#L127-L128 121 | state_dict = OrderedDict() 122 | for k, v in ckpt['state_dict'].items(): 123 | state_dict[k.replace('module.', '')] = v 124 | ckpt['state_dict'] = state_dict 125 | 126 | 127 | ####################################################################################################################### 128 | 129 | 130 | 131 | 132 | class FairSlipLoaderBase(BaseMmcLoader): 133 | """ 134 | SLIP models via https://github.com/facebookresearch/SLIP 135 | """ 136 | def __init__( 137 | self, 138 | id, 139 | architecture, 140 | ): 141 | self.architecture = architecture 142 | self.publisher = 'facebookresearch' 143 | self.id = id 144 | self.modalities = (TEXT, IMAGE) 145 | def _napm_install(self): 146 | logger.debug('using napm to "install" facebookresearch/SLIP') 147 | url = "https://github.com/facebookresearch/SLIP" 148 | napm.pseudoinstall_git_repo(url, env_name='mmc', add_install_dir_to_path=True) 149 | napm.populate_pythonpaths('mmc') 150 | from SLIP.models import ( 151 | SLIP_VITS16, 152 | SLIP_VITB16, 153 | SLIP_VITL16 154 | ) 155 | 156 | def load(self, device=DEVICE): 157 | """ 158 | Returns the MMC associated with this loader. 
159 | """ 160 | self._napm_install() 161 | 162 | model_factory = model_factory_from_id(self.id) 163 | logger.debug(f"model_factory: {model_factory}") 164 | ckpt_url = url_from_id(self.id) 165 | ckpt = fetch_weights( 166 | url=ckpt_url, 167 | namespace='fair_slip', 168 | device=device, 169 | ) 170 | d_args = vars(ckpt['args']) 171 | kwargs = {k:d_args[k] for k in ('ssl_emb_dim', 'ssl_mlp_dim') if k in d_args} 172 | logger.debug(kwargs) 173 | fix_param_names(ckpt) 174 | model = model_factory(**kwargs) 175 | model.load_state_dict(ckpt['state_dict'], strict=True) 176 | model = model.eval().to(device) 177 | 178 | from SLIP.tokenizer import SimpleTokenizer 179 | tokenizer = SimpleTokenizer() 180 | 181 | def preprocess_image_extended(*args, **kwargs): 182 | x = val_transform(*args, **kwargs) 183 | if x.ndim == 3: 184 | logger.debug("adding batch dimension") 185 | x = x.unsqueeze(0) 186 | return x.to(device) 187 | #logger.debug(model) 188 | mmc = MultiModalComparator(name=str(self), device=device) 189 | mmc.register_modality(modality=TEXT, projector=model.encode_text, preprocessor=tokenizer) 190 | mmc.register_modality(modality=IMAGE, projector=model.encode_image, preprocessor= preprocess_image_extended) 191 | mmc._model = model 192 | return mmc 193 | 194 | 195 | class FairSlipLoader_YFCC15M(FairSlipLoaderBase): 196 | """ 197 | SLIP models via https://github.com/facebookresearch/SLIP 198 | """ 199 | def __init__( 200 | self, 201 | id, 202 | architecture, 203 | ): 204 | super().__init__(id, architecture) 205 | self.dataset = 'YFCC15M' 206 | 207 | 208 | class FairSlipLoader_CC3M(FairSlipLoaderBase): 209 | """ 210 | SLIP models via https://github.com/facebookresearch/SLIP 211 | """ 212 | def __init__( 213 | self, 214 | id, 215 | architecture, 216 | ): 217 | super().__init__(id, architecture) 218 | self.dataset = 'CC3M' 219 | 220 | 221 | class FairSlipLoader_CC12M(FairSlipLoaderBase): 222 | """ 223 | SLIP models via https://github.com/facebookresearch/SLIP 224 | """ 225 | def __init__( 226 | self, 227 | id, 228 | architecture, 229 | ): 230 | super().__init__(id, architecture) 231 | self.dataset = 'CC12M' 232 | 233 | 234 | 235 | # To do: register models 236 | 237 | # ViT-Small (MoCo v3 version w/ 12 vs. 
6 heads) 238 | model_ids = [ 239 | 'clip_small_25ep', 240 | 'simclr_small_25ep', 241 | 'slip_small_25ep', 242 | 'slip_small_50ep', 243 | 'slip_small_100ep', 244 | 'clip_base_25ep', 245 | 'simclr_base_25ep', 246 | 'slip_base_25ep', 247 | 'slip_base_50ep', 248 | 'slip_base_100ep', 249 | 'clip_large_25ep', 250 | 'simclr_large_25ep', 251 | 'slip_large_25ep', 252 | 'slip_large_50ep', 253 | 'slip_large_100ep', 254 | ] 255 | 256 | for mid in model_ids: 257 | arch, _, _ = mid.split('_') 258 | register_model( 259 | FairSlipLoader_YFCC15M( 260 | id=mid, 261 | architecture=arch, 262 | ) 263 | ) 264 | 265 | 266 | model_ids_cc3m = [ 267 | 'clip_base_cc3m_40ep', 268 | 'slip_base_cc3m_40ep', 269 | ] 270 | 271 | for mid in model_ids_cc3m: 272 | arch, _, _, _ = mid.split('_') 273 | register_model( 274 | FairSlipLoader_CC3M( 275 | id=mid, 276 | architecture=arch, 277 | ) 278 | ) 279 | 280 | 281 | model_ids_cc12m = [ 282 | 'clip_base_cc12m_35ep', 283 | 'slip_base_cc12m_35ep', 284 | ] 285 | 286 | for mid in model_ids_cc12m: 287 | arch, _, _, _ = mid.split('_') 288 | register_model( 289 | FairSlipLoader_CC12M( 290 | id=mid, 291 | architecture=arch, 292 | ) 293 | ) 294 | 295 | -------------------------------------------------------------------------------- /src/mmc/loaders/keliploader.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Loaders for pretrained Korean CLIP (KELIP) published by navervision 4 | https://github.com/navervision/KELIP 5 | """ 6 | 7 | #import clip # this should probably be isolated somehow 8 | from loguru import logger 9 | import torch 10 | 11 | from .basemmcloader import BaseMmcLoader 12 | from ..modalities import TEXT, IMAGE 13 | from ..multimodalcomparator import MultiModalComparator 14 | from ..registry import REGISTRY, register_model 15 | 16 | DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 17 | 18 | 19 | #class ClipFaLoader(BaseMmcLoader): 20 | class ClipKelipLoader(BaseMmcLoader): 21 | """ 22 | CLIP model trained for Korean and English languages 23 | https://github.com/navervision/KELIP 24 | """ 25 | def __init__( 26 | self, 27 | id='kelip_ViT-B/32', 28 | ): 29 | self.architecture = 'clip' # should this be a type too? 30 | self.publisher = 'navervision' 31 | self.id = id 32 | self.modalities = (TEXT, IMAGE) 33 | def load(self, device=DEVICE): 34 | """ 35 | Returns the MMC associated with this loader. 
36 | """ 37 | import kelip 38 | _id = self.id.replace('kelip_','') 39 | model, preprocess_img, tokenizer = kelip.build_model(_id) 40 | 41 | mmc = MultiModalComparator(name=str(self), device=device) 42 | mmc.register_modality(modality=TEXT, projector=model.encode_text, preprocessor=tokenizer) 43 | mmc.register_modality(modality=IMAGE, projector=model.encode_image, preprocessor=preprocess_img) 44 | mmc._model = model 45 | return mmc 46 | 47 | 48 | register_model( 49 | #They don't have a systematic way for listing their weights it for now and only support ViT-B/32 50 | ClipKelipLoader(id='kelip_ViT-B/32') 51 | ) -------------------------------------------------------------------------------- /src/mmc/loaders/mlfcliploader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Loaders for pretrained CLIP models published by MLFoundations 3 | https://github.com/mlfoundations/open_clip 4 | """ 5 | 6 | 7 | #import clip # this should probably be isolated somehow 8 | #import open_clip 9 | from loguru import logger 10 | import torch 11 | 12 | from .basemmcloader import BaseMmcLoader 13 | from ..modalities import TEXT, IMAGE 14 | from ..multimodalcomparator import MultiModalComparator 15 | from ..registry import register_model 16 | 17 | DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 18 | 19 | 20 | class MlfClipLoader(BaseMmcLoader): 21 | """ 22 | Generic class for loading CLIP models published by MLFoundations. 23 | https://github.com/mlfoundations/open_clip 24 | 25 | There should be a one-to-one mapping between loader objects 26 | and specific sets of pretrained weights (distinguished by the "id" field). 27 | """ 28 | def __init__( 29 | self, 30 | id, 31 | metadata=None, 32 | device=DEVICE, 33 | ): 34 | self.architecture = 'clip' # should this be a type too? 35 | self.publisher = 'mlfoundations' 36 | self.id = id 37 | self.modalities = (TEXT, IMAGE) 38 | self.metadata = {} if metadata is None else metadata 39 | self.device = device 40 | 41 | def load(self, device=None): 42 | """ 43 | Returns the MMC associated with this loader. 
44 | """ 45 | if device is None: 46 | device = self.device 47 | 48 | import open_clip 49 | #model, preprocess_image = clip.load(self.id, jit=False, device=device) 50 | model_name, dataset = self.id.split('--') 51 | #model, _, preprocess_image = open_clip.create_model_and_transforms( 52 | model, preprocess_image, _ = open_clip.create_model_and_transforms( 53 | model_name=model_name, 54 | pretrained=dataset) 55 | 56 | model.requires_grad_(False) 57 | model.eval() 58 | #model.set_grad_checkpointing() 59 | model.to(device, memory_format=torch.channels_last) 60 | #tokenizer = clip.tokenize # clip.simple_tokenizer.SimpleTokenizer() 61 | tokenizer = open_clip.tokenize # clip.simple_tokenizer.SimpleTokenizer() 62 | def preprocess_image_extended(*args, **kwargs): 63 | x = preprocess_image(*args, **kwargs) 64 | if x.ndim == 3: 65 | logger.debug("adding batch dimension") 66 | x = x.unsqueeze(0) 67 | return x 68 | mmc = MultiModalComparator(name=str(self), device=device) 69 | mmc.register_modality(modality=TEXT, projector=model.encode_text, preprocessor=tokenizer) 70 | mmc.register_modality(modality=IMAGE, projector=model.encode_image, preprocessor=preprocess_image_extended) 71 | mmc._model = model 72 | return mmc 73 | 74 | try: 75 | import open_clip 76 | for model_name, dataset in open_clip.list_pretrained(): 77 | metadata = {} 78 | if model_name == "ViT-B-16-plus-240": 79 | metadata['input_resolution'] = 240 80 | logger.debug((model_name, metadata)) 81 | register_model( 82 | MlfClipLoader( 83 | id=f"{model_name}--{dataset}", 84 | metadata=metadata), 85 | ) 86 | except ImportError: 87 | pass 88 | -------------------------------------------------------------------------------- /src/mmc/loaders/openaicliploader.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Loaders for pretrained CLIP models published by OpenAI 4 | """ 5 | 6 | import clip # this should probably be isolated somehow 7 | from loguru import logger 8 | import torch 9 | 10 | from .basemmcloader import BaseMmcLoader 11 | from ..modalities import TEXT, IMAGE 12 | from ..multimodalcomparator import MultiModalComparator 13 | from ..registry import REGISTRY, register_model 14 | 15 | DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 16 | 17 | 18 | class OpenAiClipLoader(BaseMmcLoader): 19 | """ 20 | Generic class for loading CLIP models published by OpenAI. 21 | There should be a one-to-one mapping between loader objects 22 | and specific sets of pretrained weights (distinguished by the "id" field) 23 | """ 24 | def __init__( 25 | self, 26 | id, 27 | device=DEVICE, 28 | ): 29 | self.architecture = 'clip' # should this be a type too? 30 | self.publisher = 'openai' 31 | self.id = id 32 | self.modalities = (TEXT, IMAGE) 33 | self.device = device 34 | 35 | def load(self, device=None): 36 | """ 37 | Returns the MMC associated with this loader. 
38 | """ 39 | if device is None: 40 | device = self.device 41 | import clip 42 | model, preprocess_image = clip.load(self.id, jit=False, device=device) 43 | model.eval() 44 | model.requires_grad_(False) 45 | #model.to(device, memory_format=torch.channels_last) 46 | tokenizer = clip.tokenize # clip.simple_tokenizer.SimpleTokenizer() 47 | def preprocess_image_extended(*args, **kwargs): 48 | x = preprocess_image(*args, **kwargs) 49 | if x.ndim == 3: 50 | logger.debug("adding batch dimension") 51 | x = x.unsqueeze(0) 52 | return x 53 | mmc = MultiModalComparator(name=str(self), device=device) 54 | mmc.register_modality(modality=TEXT, projector=model.encode_text, preprocessor=tokenizer) 55 | mmc.register_modality(modality=IMAGE, projector=model.encode_image, preprocessor=preprocess_image_extended) 56 | mmc._model = model 57 | return mmc 58 | 59 | 60 | for model_name in clip.available_models(): 61 | #REGISTRY.loaders.append( 62 | # OpenAiClipLoader(id=model_name) 63 | #) 64 | register_model( 65 | OpenAiClipLoader(id=model_name) 66 | ) -------------------------------------------------------------------------------- /src/mmc/loaders/sbertclibloader.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Loaders for pretrained CLIP models published by OpenAI 4 | """ 5 | 6 | #import clip # this should probably be isolated somehow 7 | from loguru import logger 8 | import torch 9 | 10 | from .basemmcloader import BaseMmcLoader 11 | from ..modalities import TEXT, IMAGE 12 | from ..multimodalcomparator import MultiModalComparator 13 | from ..registry import REGISTRY, register_model 14 | 15 | DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 16 | 17 | 18 | class SBertClipLoader(BaseMmcLoader): 19 | """ 20 | Multilingual text encoder aligned to the latent space of OpenAI's CLIP-ViT-B-32 21 | * https://huggingface.co/sentence-transformers/clip-ViT-B-32-multilingual-v1 22 | * https://www.sbert.net/docs/pretrained_models.html#image-text-models 23 | 24 | Primary language support: ar, bg, ca, cs, da, de, el, es, et, fa, fi, fr, fr-ca, 25 | gl, gu, he, hi, hr, hu, hy, id, it, ja, ka, ko, ku, lt, lv, mk, mn, mr, ms, my, nb, 26 | nl, pl, pt, pt, pt-br, ro, ru, sk, sl, sq, sr, sv, th, tr, uk, ur, vi, zh-cn, zh-tw 27 | 28 | Likely weak support for all languages compatible with multi-lingual DistillBERT: 29 | https://github.com/google-research/bert/blob/master/multilingual.md#list-of-languages 30 | """ 31 | def __init__( 32 | self, 33 | #id, 34 | ): 35 | self.architecture = 'clip' # should this be a type too? 36 | self.publisher = 'sbert' 37 | self.id = 'ViT-B-32-multilingual-v1' #id 38 | self.modalities = (TEXT, IMAGE) 39 | def load(self, device=DEVICE): 40 | """ 41 | Returns the MMC associated with this loader. 
42 | """ 43 | # TO DO: only load text encoder if the OpenAI CLIP image encoder is already 44 | # attached to the invoking multimmc 45 | #import clip 46 | #model, preprocess_image = clip.load(self.id, jit=False, device=device) 47 | #model.eval() 48 | #model.requires_grad_(False) 49 | #model.to(device, memory_format=torch.channels_last) 50 | #tokenizer = clip.tokenize # clip.simple_tokenizer.SimpleTokenizer() 51 | #def preprocess_image_extended(*args, **kwargs): 52 | # x = preprocess_image(*args, **kwargs) 53 | # if x.ndim == 3: 54 | # logger.debug("adding batch dimension") 55 | # x = x.unsqueeze(0) 56 | # return x 57 | from sentence_transformers import SentenceTransformer 58 | img_model = SentenceTransformer('clip-ViT-B-32') # this should be identical to the model published by openai 59 | text_model = SentenceTransformer('sentence-transformers/clip-ViT-B-32-multilingual-v1') 60 | 61 | # default behavior returns numpy arrays. converting to tensors for API consistency 62 | def image_project_to_tensor(img): 63 | return torch.tensor(img_model.encode(img)).to(device) 64 | 65 | def text_project_to_tensor(txt): 66 | return torch.tensor(text_model.encode(txt)).to(device) 67 | 68 | # To do: we have a 'preprocess' pattern, should add a 'postprocess' pattern too. 69 | # then instead of defining closures here, could just pass in TF.to_tensor() 70 | 71 | mmc = MultiModalComparator(name=str(self), device=device) 72 | mmc.register_modality(modality=TEXT, projector=text_project_to_tensor )#, preprocessor=tokenizer) 73 | mmc.register_modality(modality=IMAGE, projector=image_project_to_tensor )#, preprocessor=preprocess_image_extended) 74 | mmc._model = img_model 75 | return mmc 76 | 77 | register_model(SBertClipLoader()) -------------------------------------------------------------------------------- /src/mmc/mock/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmarx/Multi-Modal-Comparators/e8ae395c0aca6b8b872ae38feb160bd53078e721/src/mmc/mock/__init__.py -------------------------------------------------------------------------------- /src/mmc/mock/openai.py: -------------------------------------------------------------------------------- 1 | """ 2 | API Mocks for models published by OpenAI 3 | """ 4 | 5 | import torch 6 | from dataclasses import dataclass 7 | 8 | DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 9 | 10 | from loguru import logger 11 | 12 | 13 | class MockOpenaiClipModule: 14 | """ 15 | Mocks the OpenAI CLIP.clip module API 16 | """ 17 | def __init__(self, loader, device=DEVICE): 18 | self._loader = loader 19 | self.device = device 20 | 21 | def available_models(self): 22 | # ...should this return the mmc registry? 
23 | return str(self._loader) 24 | 25 | @property 26 | def _model(self): 27 | if not hasattr(self, '_model_'): 28 | self._model_ = self._loader.load(self.device) 29 | return self._model_ 30 | 31 | @property 32 | def tokenize(self): 33 | return self._model.modes['text']['preprocessor'] 34 | 35 | @property 36 | def preprocess_image(self): 37 | return self._model.modes['image']['preprocessor'] 38 | 39 | @property 40 | def load(self, id, device=DEVICE): 41 | clip = MockOpenaiClip(self._model, self.device) 42 | return clip, self.preprocess_image 43 | 44 | 45 | @dataclass 46 | class MockVisionModel: 47 | input_resolution: int = 224 48 | output_dim: int = 1024 49 | 50 | 51 | class MockOpenaiClip: 52 | """ 53 | Wrapper class to facilitate drop-in replacement with MMC models where 54 | model interface conforms to OpenAI's CLIP implementationare. 55 | """ 56 | def __init__( 57 | self, 58 | mmc_object, 59 | device=DEVICE, 60 | ): 61 | assert mmc_object.supports_text 62 | assert mmc_object.supports_image 63 | 64 | #if (mmc_object.publisher == 'openai') and (mmc_object.architecture == 'clip'): 65 | # return mmc_object._model 66 | 67 | self.device = device 68 | self.mmc_object = mmc_object 69 | 70 | vision_args = {} 71 | if hasattr(mmc_object, 'input_resolution'): 72 | vision_args['input_resolution'] = mmc_object.input_resolution 73 | elif hasattr(mmc_object, 'metadata'): 74 | vision_args['input_resolution'] = mmc_object.metadata.get( 75 | 'input_resolution', 76 | vision_args['input_resolution'] 77 | ) 78 | 79 | if hasattr(mmc_object, '_model'): 80 | if hasattr(mmc_object._model, 'visual'): 81 | if hasattr(mmc_object._model.visual, 'input_resolution'): 82 | self.visual = mmc_object._model.visual 83 | elif hasattr(mmc_object._model.visual, 'image_size'): 84 | self.visual = mmc_object._model.visual 85 | self.visual.input_resolution = self.visual.image_size 86 | 87 | if not hasattr(self, 'visual'): 88 | logger.debug("'visual' attribute not found in model. Mocking vision model API.") 89 | logger.debug(vision_args) 90 | self.visual = MockVisionModel(**vision_args) 91 | 92 | def encode_image( 93 | self, 94 | image: torch.Tensor, 95 | ) -> torch.Tensor: 96 | #return self.mmc_object.project_image(image) 97 | # bypass pre-processor 98 | #project = self.mmc_object.modes['image']['projector'] 99 | #return project(image) 100 | return self.mmc_object.project_image(image, preprocess=False) 101 | 102 | def encode_text( 103 | self, 104 | text: torch.Tensor, 105 | ) -> torch.Tensor: 106 | #return self.mmc_object.project_text(text) 107 | # bypass pre-processor 108 | #project = self.mmc_object.modes['text']['projector'] 109 | #return project(text) 110 | return self.mmc_object.project_text(text, preprocess=False) 111 | -------------------------------------------------------------------------------- /src/mmc/modalities.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains descriptors for generic data modalities that a particular MMC may support. 3 | 4 | The purpose of these classes is to facilitate determining modality compatibility between MMCs, 5 | and to abstract away specifics of data management that should be common to all media of a given modality. 6 | """ 7 | 8 | # for minimization of dependencies, maybe we could move this import inside the modalities that need it? 9 | # e.g. PIL wouldn't get imported until an Image modality object is invoked somewhere in the library. 
10 | import PIL 11 | 12 | class Modality: 13 | """ 14 | A "modality" in our context is a distribution over data that is limited to a single medium. 15 | This class characterizes modalities via metadata attributes and serialization methods. 16 | """ 17 | def __init__(self, 18 | name, 19 | #read_func, 20 | #write_func, 21 | #default_loss, 22 | #default_projector, # e.g. =CLIP, 23 | ): 24 | self.name=name 25 | 26 | def read_from_disk(self, fpath): 27 | """ 28 | Preferred method for loading data of this modality from disk 29 | """ 30 | with open(fpath, 'r') as f: 31 | return f.read() 32 | 33 | def write_to_disk(self, fpath, obj): 34 | """ 35 | Preferred method for writing data of this modality from disk 36 | """ 37 | with open(fpath, 'w') as f: 38 | return f.write(obj) 39 | 40 | 41 | # to do: better mechanism for registering modalities 42 | 43 | TEXT = Modality(name='text') 44 | 45 | IMAGE = Modality(name='image') 46 | IMAGE.read_from_disk = PIL.Image.open 47 | IMAGE.write_to_disk = lambda fpath, obj: obj.save(fpath) 48 | 49 | AUDIO = Modality('audio') 50 | -------------------------------------------------------------------------------- /src/mmc/multimmc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for facilitating computing similarity scores via ensembles of MMCs, 3 | which must be at least compatible with respect to the two modes being compared. 4 | """ 5 | 6 | from loguru import logger 7 | import torch 8 | 9 | from .registry import REGISTRY 10 | from .multimodalcomparator import MultiModalComparator 11 | 12 | DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 13 | 14 | 15 | class MultiMMC(MultiModalComparator): 16 | _registry = REGISTRY #? 17 | def __init__( 18 | self, 19 | *modalities, 20 | device=DEVICE, 21 | ): 22 | self.device=device 23 | self.modalities = modalities 24 | self.models = {} 25 | # probably shouldn't need to redefine/override this method 26 | def supports_modality(self, modality): 27 | return any(modality.name == m.name for m in self.modalities) 28 | # probably shouldn't need to redefine/override this method 29 | def _supports_mode(self, modality_name): 30 | return any(modality_name == m.name for m in self.modalities) #self.modes 31 | def load_model( 32 | self, 33 | architecture='clip', 34 | publisher='openai', 35 | id=None, 36 | weight=1, 37 | device=None, 38 | ): 39 | if device is None: 40 | device = self.device 41 | model_loaders = self._registry.find(architecture=architecture, publisher=publisher, id=id) 42 | for model_loader in model_loaders: 43 | assert all(model_loader.supports_modality(m) for m in self.modalities) 44 | model_key = f"[{architecture} - {publisher} - {id}]" 45 | if model_key not in self.models: 46 | model = model_loader.load() 47 | self.models[model_key] = {'model':model, 'weight':weight} 48 | #self.models[model_key] = {'model':model.to(device), 'weight':weight} 49 | else: 50 | logger.warning(f"Model already loaded: {model_key}") 51 | 52 | def _project_item(self, item, mode) -> dict: 53 | assert self._supports_mode(mode) 54 | projections = {} 55 | for model_name, d_model in self.models.items(): 56 | model, weight = d_model['model'], d_model['weight'] 57 | logger.debug(model_name) 58 | logger.debug(model) 59 | if model._supports_mode(mode): 60 | #item.to(model.device) 61 | #logger.debug(model_name) 62 | #logger.debug(model.name) 63 | #projections[model.name] = {'projection':model._project_item(item, mode), 'weight':weight} 64 | projections[model_name] = 
{'projection':model._project_item(item, mode), 'weight':weight} 65 | return {'modality':mode, 'projections':projections} 66 | 67 | def compare( 68 | self, 69 | return_projections = False, 70 | **kwargs, 71 | ): 72 | projections = {} # d['modality']['model_key'] = vector 73 | for modality_name, item in kwargs.items(): 74 | projections[modality_name] = self._project_item(item, modality_name) 75 | outv = self._reduce_projections(projections) 76 | if return_projections: 77 | outv = (outv, projections) 78 | return outv 79 | def _reduce_projections(self, projections, return_raw_scores=False): 80 | #logger.debug(projections) 81 | accumulator = 0 82 | raw_scores = {} 83 | # this is hideous and should be trivially vectorizable. 84 | for model_key, d_perceptor in self.models.items(): 85 | perceptor = d_perceptor['model'] 86 | weight = d_perceptor['weight'] 87 | 88 | # compute per-model scores and compute weighted sum 89 | kargs = {} 90 | for modality_name, d_vectors in projections.items(): 91 | #kargs[modality_name] = d_vectors[model_key] 92 | kargs[modality_name] = d_vectors['projections'][model_key]['projection'] 93 | score = perceptor._reduce_projections(**kargs) 94 | accumulator += score * weight 95 | 96 | if return_raw_scores: 97 | raw_scores[model_key] = score 98 | 99 | outv = accumulator 100 | if return_raw_scores: 101 | outv = (accumulator, raw_scores) 102 | return outv 103 | -------------------------------------------------------------------------------- /src/mmc/multimodalcomparator.py: -------------------------------------------------------------------------------- 1 | # this should probably be a nn.Module 2 | #class MultiModalComparator(nn.Module): 3 | 4 | from loguru import logger 5 | import torch 6 | 7 | DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 8 | 9 | 10 | class MultiModalComparator: 11 | """ 12 | Generic class for encapsulating models that can compare data across multiple modalities 13 | """ 14 | def __init__(self, 15 | name=None, 16 | #modalities=[TEXT] 17 | device=DEVICE, 18 | ): 19 | self.modes = {} 20 | self.device=device 21 | #for m in modalities: 22 | # self.register_modality(m) 23 | #assert len(self.modes) > 0 24 | 25 | def register_modality( 26 | self, 27 | modality, 28 | projector, 29 | preprocessor=None, 30 | postprocessor=None, 31 | ): 32 | """ 33 | Register a modality with this MMC. 34 | 35 | The MMC class (this) will manage how data from this modality is processed 36 | for performing comparisons with other modalities supported by this MMC. 37 | This registration function specifies that procedure for a single modality. 38 | 39 | An MMC is required to have a single processing procedure for each modality it 40 | supports. If data for a modality can be read in from different formats, that should 41 | be addressed by the Modality class. If your given modality requires different processing 42 | procedures, you may need to revisit how you are defining "modalities" here. For example, 43 | the CARP model compares passages of narrative text to criticisms of the passages. Although 44 | both the passages and criticisms are "text", from the perspective of CARP they are 45 | separate modalities. To implement a CARP-like model, you would first create PASSAGE and 46 | CRITIQUE modalities (subclassing from the TEXT modality), and then you could register 47 | the appropriate processing procedures for those modalities separately here. 
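        A hypothetical sketch of that CARP-like setup (PASSAGE, CRITIQUE, and the projector
        functions below are illustrative placeholders, not part of this library, and the
        TEXT-subclassing detail is glossed over):

            PASSAGE = Modality(name='passage')
            CRITIQUE = Modality(name='critique')
            carp = MultiModalComparator(name='carp-like')
            carp.register_modality(modality=PASSAGE, projector=project_passage)
            carp.register_modality(modality=CRITIQUE, projector=project_critique)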
48 | """ 49 | assert modality.name not in self.modes 50 | #if preprocessor is not None: 51 | # preprocessor.to(self.device) 52 | #if postprocessor is not None: 53 | # postprocessor.to(self.device) 54 | if preprocessor is None: 55 | preprocessor = lambda x: x # could lambdas cause pickling issues? 56 | if postprocessor is None: 57 | postprocessor = lambda x: x 58 | self.modes[modality.name] = { 59 | 'modality_obj':modality, 60 | 'projector':projector, #.to(self.device), 61 | 'preprocessor':preprocessor, 62 | 'postprocessor':postprocessor, 63 | } 64 | def supports_modality(self, modality): 65 | return modality.name in self.modes 66 | 67 | def _supports_mode(self, modality_name): 68 | return modality_name in self.modes 69 | 70 | def _preprocess_item(self, item, mode): 71 | preprocess = self.modes[mode]['preprocessor'] 72 | item = preprocess(item) 73 | # If preprocessor is identity, item will not be a tensor 74 | try: 75 | item = item.to(self.device) 76 | except: 77 | pass 78 | if hasattr(item, 'ndim') and (item.ndim == 1): 79 | item = item.unsqueeze(0) 80 | return item 81 | 82 | def _project_item(self, item, mode, preprocess=True): 83 | assert self._supports_mode(mode) 84 | project = self.modes[mode]['projector'] 85 | if preprocess: 86 | item = self._preprocess_item(item, mode) 87 | return project(item) 88 | 89 | @property 90 | def supports_text(self): 91 | return self._supports_mode('text') 92 | @property 93 | def supports_image(self): 94 | return self._supports_mode('image') 95 | @property 96 | def supports_audio(self): 97 | return self._supports_mode('audio') 98 | 99 | def project_text(self, text, preprocess=True): 100 | return self._project_item(item=text, mode='text', preprocess=preprocess) 101 | def project_image(self, image, preprocess=True): 102 | return self._project_item(item=image, mode='image', preprocess=preprocess) 103 | def project_audio(self, audio, preprocess=True): 104 | return self._project_item(item=audio, mode='audio', preprocess=preprocess) 105 | @property 106 | def name(self): 107 | return str(self) 108 | def _reduce_projections(self, **kargs): 109 | #logger.debug(kargs) 110 | projections = [v.squeeze() for v in kargs.values()] #list(kargs.values()) 111 | return torch.dot(*projections) 112 | -------------------------------------------------------------------------------- /src/mmc/napm_installs/__init__.py: -------------------------------------------------------------------------------- 1 | import napm 2 | from loguru import logger 3 | 4 | 5 | def napm_pi_katcloob(): 6 | """ 7 | Usage: 8 | 9 | import cloob 10 | from cloob.cloob_training import model_pt, pretrained 11 | 12 | config = pretrained.get_config('cloob_laion_400m_vit_b_16_16_epochs') 13 | model = model_pt.get_pt_model(config) 14 | checkpoint = pretrained.download_checkpoint(config) 15 | model.load_state_dict(model_pt.get_pt_params(config, checkpoint), ) 16 | model.eval().requires_grad_(False).to('cuda') 17 | """ 18 | logger.debug('using napm to "install" katCLOOB') 19 | url = "https://github.com/crowsonkb/cloob-training" 20 | napm.pseudoinstall_git_repo(url, package_name='cloob') 21 | 22 | 23 | def all(): 24 | napm_pi_katcloob() 25 | 26 | 27 | if __name__ == '__main__': 28 | all() -------------------------------------------------------------------------------- /src/mmc/registry.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | 3 | from typing import List, TYPE_CHECKING 4 | 5 | if TYPE_CHECKING: 6 | from .loaders.basemmcloader import 
BaseMmcLoader 7 | 8 | class MmcRegistry: 9 | def __init__(self): 10 | self.loaders: List["BaseMmcLoader"] = [] 11 | def find( 12 | self, 13 | **query 14 | ) -> List["BaseMmcLoader"]: 15 | """ 16 | Searches the registry for MMCs loaders 17 | """ 18 | hits = [] 19 | for item in self.loaders: 20 | is_hit = True 21 | for k, v_query in query.items(): 22 | v_item = getattr(item, k) 23 | if (v_item is not None) and (v_item != v_query): 24 | is_hit = False 25 | break 26 | if is_hit: 27 | hits.append(item) 28 | if len(hits) <1: 29 | logger.warning(f"No hits found for query: {query}") 30 | return hits 31 | 32 | REGISTRY = MmcRegistry() 33 | 34 | def register_model(mmc_loader: "BaseMmcLoader"): 35 | """ 36 | Decorator that attaches mmc loaders to the REGISTRY 37 | """ 38 | logger.debug(f"registering model: {mmc_loader}") 39 | REGISTRY.loaders.append(mmc_loader) 40 | return mmc_loader -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmarx/Multi-Modal-Comparators/e8ae395c0aca6b8b872ae38feb160bd53078e721/tests/__init__.py -------------------------------------------------------------------------------- /tests/assets/dummy.txt: -------------------------------------------------------------------------------- 1 | foo bar baz -------------------------------------------------------------------------------- /tests/assets/marley_birthday.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmarx/Multi-Modal-Comparators/e8ae395c0aca6b8b872ae38feb160bd53078e721/tests/assets/marley_birthday.jpg -------------------------------------------------------------------------------- /tests/test_api_mock.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from loguru import logger 3 | import mmc 4 | import PIL 5 | import torch 6 | 7 | 8 | def test_oai_mocking_itself(): 9 | from mmc.mock.openai import MockOpenaiClip 10 | from mmc.loaders import OpenAiClipLoader 11 | 12 | ldr = OpenAiClipLoader(id='RN50') 13 | oai_clip = ldr.load() 14 | model = MockOpenaiClip(oai_clip) 15 | assert model.visual.input_resolution == 224 16 | 17 | 18 | def test_mlf_mocking_oai(): 19 | from mmc.mock.openai import MockOpenaiClip 20 | from mmc.loaders import MlfClipLoader 21 | 22 | ldr = MlfClipLoader(id='RN50--yfcc15m') 23 | mlf_clip = ldr.load() 24 | model = MockOpenaiClip(mlf_clip) 25 | assert model.visual.input_resolution == 224 26 | 27 | 28 | class TestMlfVitb16plus: 29 | 30 | loader_args = {'id':'ViT-B-16-plus-240--laion400m_e32'} 31 | 32 | def test_mock_oai(self): 33 | from mmc.mock.openai import MockOpenaiClip 34 | from mmc.loaders import MlfClipLoader 35 | 36 | ldr = MlfClipLoader(**self.loader_args) 37 | mlf_clip = ldr.load() 38 | model = MockOpenaiClip(mlf_clip) 39 | assert model.visual.input_resolution == (240, 240) 40 | 41 | 42 | def test_project_text(self): 43 | from mmc.mock.openai import MockOpenaiClip 44 | from mmc.loaders import MlfClipLoader 45 | #from clip.simple_tokenizer import SimpleTokenizer 46 | import clip 47 | 48 | ldr = MlfClipLoader(**self.loader_args) 49 | mlf_clip = ldr.load() 50 | model = MockOpenaiClip(mlf_clip) 51 | tokens = clip.tokenize("foo bar baz").to(model.device) 52 | projection = model.encode_text(tokens) 53 | assert isinstance(projection, torch.Tensor) 54 | logger.debug(projection.shape) 55 | 56 | 57 | def 
test_project_img(self): 58 | from mmc.mock.openai import MockOpenaiClip 59 | from mmc.loaders import MlfClipLoader 60 | 61 | ldr = MlfClipLoader(**self.loader_args) 62 | mlf_clip = ldr.load() 63 | model = MockOpenaiClip(mlf_clip) 64 | im_size = model.visual.input_resolution[0] 65 | logger.debug(im_size) 66 | img = torch.rand(1,3,im_size, im_size) # batch x channels x height x width 67 | #img = torch.rand(3,im_size, im_size) # batch x channels x height x width 68 | logger.debug(img.shape) 69 | img = img.to(model.device) 70 | projection = model.encode_image(img) 71 | assert isinstance(projection, torch.Tensor) 72 | logger.debug(projection.shape) 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | # what even is my expected behavior here? It passes the test but, I'm not sure how I'd even use this. 84 | # maybe this should throw an error? 85 | """ 86 | def test_multi_oai_mocking_oai_init(): 87 | from mmc.mock.openai import MockOpenaiClip 88 | from mmc.multimmc import MultiMMC 89 | from mmc.modalities import TEXT, IMAGE 90 | 91 | perceptor = MultiMMC(TEXT, IMAGE) 92 | models = [ 93 | dict( 94 | architecture='clip', 95 | publisher='openai', 96 | id='RN50', 97 | ), 98 | dict( 99 | architecture='clip', 100 | publisher='openai', 101 | id='ViT-B/32', 102 | )] 103 | for m in models: 104 | perceptor.load_model(**m) 105 | #dir(perceptor) 106 | """ 107 | 108 | -------------------------------------------------------------------------------- /tests/test_ezmode_clip.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | def test_import(): 4 | from mmc.ez.CLIP import clip 5 | 6 | def test_available_models(): 7 | from mmc.ez.CLIP import clip 8 | clip.available_models() 9 | 10 | def test_load_openai_ez(): 11 | from mmc.ez.CLIP import clip 12 | model, preprocessor = clip.load('RN50') 13 | assert model 14 | assert preprocessor 15 | 16 | def test_load_openai_alias(): 17 | from mmc.ez.CLIP import clip 18 | model, preprocessor = clip.load('[clip - openai - RN50]') 19 | assert model 20 | assert preprocessor 21 | -------------------------------------------------------------------------------- /tests/test_mmc.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from mmc.loaders import OpenAiClipLoader 3 | import PIL 4 | from loguru import logger 5 | import torch 6 | 7 | def test_oai_clip_project_text(): 8 | ldr = OpenAiClipLoader(id='RN50') 9 | oai_clip = ldr.load() 10 | projection = oai_clip.project_text("foo bar baz") 11 | #logger.debug(type(projection)) 12 | #assert projection['modality'] == 'text' 13 | #logger.debug(projection.shape) # [1 1024] 14 | assert isinstance(projection, torch.Tensor) 15 | 16 | def test_oai_clip_project_img(): 17 | ldr = OpenAiClipLoader(id='RN50') 18 | oai_clip = ldr.load() 19 | img = PIL.Image.open("./tests/assets/marley_birthday.jpg").resize((250,200)) 20 | projection = oai_clip.project_image(img) 21 | #logger.debug(type(projection)) 22 | #assert projection['modality'] == 'image' 23 | #logger.debug(projection.shape) # [1 1024] 24 | assert isinstance(projection, torch.Tensor) 25 | 26 | def test_oai_clip_supported_modalities(): 27 | ldr = OpenAiClipLoader(id='RN50') 28 | oai_clip = ldr.load() 29 | assert oai_clip.supports_text 30 | assert oai_clip.supports_image 31 | assert not oai_clip.supports_audio 32 | -------------------------------------------------------------------------------- /tests/test_mmc_clipfa.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | from mmc.loaders import ClipFaLoader as loader 3 | import PIL 4 | from loguru import logger 5 | import torch 6 | 7 | # for some reason SBERT is returning np.ndarrays instead of tensors. 8 | # might be sentence-transformer wonkiness that could be resolved by 9 | # using hugginface/transformers directly. 10 | 11 | #loader_args = {'id':'ViT-B-32-multilingual-v1'} 12 | loader_args = {} 13 | 14 | def test_project_text(): 15 | ldr = loader(**loader_args) 16 | perceptor = ldr.load() 17 | projection = perceptor.project_text("foo bar baz") 18 | print(type(projection)) 19 | assert isinstance(projection, torch.Tensor) 20 | 21 | def test_project_img(): 22 | ldr = loader(**loader_args) 23 | perceptor = ldr.load() 24 | img = PIL.Image.open("./tests/assets/marley_birthday.jpg").resize((250,200)) 25 | projection = perceptor.project_image(img) 26 | print(type(projection)) 27 | assert isinstance(projection, torch.Tensor) 28 | 29 | def test_supported_modalities(): 30 | ldr = loader(**loader_args) 31 | perceptor = ldr.load() 32 | assert perceptor.supports_text 33 | assert perceptor.supports_image 34 | assert not perceptor.supports_audio 35 | -------------------------------------------------------------------------------- /tests/test_mmc_fairslip.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from mmc.loaders import FairSlipLoader_YFCC15M as loader 3 | import PIL 4 | from loguru import logger 5 | import torch 6 | 7 | 8 | loader_args = { 9 | 'id': 'slip_small_100ep', 10 | 'architecture': 'slip', 11 | } 12 | 13 | 14 | def test_project_text(): 15 | ldr = loader(**loader_args) 16 | perceptor = ldr.load() 17 | projection = perceptor.project_text("foo bar baz") 18 | assert isinstance(projection, torch.Tensor) 19 | 20 | def test_project_img(): 21 | ldr = loader(**loader_args) 22 | perceptor = ldr.load() 23 | img = PIL.Image.open("./tests/assets/marley_birthday.jpg").resize((250,200)) 24 | projection = perceptor.project_image(img) 25 | assert isinstance(projection, torch.Tensor) 26 | 27 | def test_supported_modalities(): 28 | ldr = loader(**loader_args) 29 | perceptor = ldr.load() 30 | assert perceptor.supports_text 31 | assert perceptor.supports_image 32 | assert not perceptor.supports_audio 33 | -------------------------------------------------------------------------------- /tests/test_mmc_fairslip_cc12m.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from mmc.loaders import FairSlipLoader_CC12M as loader 3 | import PIL 4 | from loguru import logger 5 | import torch 6 | 7 | 8 | loader_args = { 9 | 'id': 'slip_base_cc12m_35ep', 10 | 'architecture': 'slip', 11 | } 12 | 13 | 14 | def test_project_text(): 15 | ldr = loader(**loader_args) 16 | perceptor = ldr.load() 17 | projection = perceptor.project_text("foo bar baz") 18 | assert isinstance(projection, torch.Tensor) 19 | 20 | def test_project_img(): 21 | ldr = loader(**loader_args) 22 | perceptor = ldr.load() 23 | img = PIL.Image.open("./tests/assets/marley_birthday.jpg").resize((250,200)) 24 | projection = perceptor.project_image(img) 25 | assert isinstance(projection, torch.Tensor) 26 | 27 | def test_supported_modalities(): 28 | ldr = loader(**loader_args) 29 | perceptor = ldr.load() 30 | assert perceptor.supports_text 31 | assert perceptor.supports_image 32 | assert not perceptor.supports_audio 33 | 
-------------------------------------------------------------------------------- /tests/test_mmc_fairslip_cc3m.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from mmc.loaders import FairSlipLoader_CC3M as loader 3 | import PIL 4 | from loguru import logger 5 | import torch 6 | 7 | 8 | loader_args = { 9 | 'id': 'slip_base_cc3m_40ep', 10 | 'architecture': 'slip', 11 | } 12 | 13 | 14 | def test_project_text(): 15 | ldr = loader(**loader_args) 16 | perceptor = ldr.load() 17 | projection = perceptor.project_text("foo bar baz") 18 | assert isinstance(projection, torch.Tensor) 19 | 20 | def test_project_img(): 21 | ldr = loader(**loader_args) 22 | perceptor = ldr.load() 23 | img = PIL.Image.open("./tests/assets/marley_birthday.jpg").resize((250,200)) 24 | projection = perceptor.project_image(img) 25 | assert isinstance(projection, torch.Tensor) 26 | 27 | def test_supported_modalities(): 28 | ldr = loader(**loader_args) 29 | perceptor = ldr.load() 30 | assert perceptor.supports_text 31 | assert perceptor.supports_image 32 | assert not perceptor.supports_audio 33 | -------------------------------------------------------------------------------- /tests/test_mmc_katcloob.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from mmc.loaders import KatCloobLoader as loader 3 | import PIL 4 | from loguru import logger 5 | import torch 6 | 7 | 8 | #loader_args = {'id':'RN50--cc12m'} 9 | loader_args = {} 10 | 11 | def test_project_text(): 12 | ldr = loader(**loader_args) 13 | perceptor = ldr.load() 14 | projection = perceptor.project_text("foo bar baz") 15 | assert isinstance(projection, torch.Tensor) 16 | 17 | def test_project_img(): 18 | ldr = loader(**loader_args) 19 | perceptor = ldr.load() 20 | img = PIL.Image.open("./tests/assets/marley_birthday.jpg").resize((250,200)) 21 | projection = perceptor.project_image(img) 22 | assert isinstance(projection, torch.Tensor) 23 | 24 | def test_supported_modalities(): 25 | ldr = loader(**loader_args) 26 | perceptor = ldr.load() 27 | assert perceptor.supports_text 28 | assert perceptor.supports_image 29 | assert not perceptor.supports_audio 30 | -------------------------------------------------------------------------------- /tests/test_mmc_loaders.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from mmc.modalities import TEXT, IMAGE 4 | 5 | 6 | def test_loader_import(): 7 | from mmc.loaders import OpenAiClipLoader 8 | 9 | 10 | def test_loader_attrs(): 11 | from mmc.loaders import OpenAiClipLoader 12 | ldr = OpenAiClipLoader(id='RN50') 13 | assert ldr.architecture == 'clip' 14 | assert ldr.publisher == 'openai' 15 | assert ldr.id =='RN50' 16 | assert len(ldr.modalities) == 2 17 | assert ldr.supports_modality(TEXT) 18 | assert ldr.supports_modality(IMAGE) 19 | 20 | 21 | def test_load_oai_clip(): 22 | from mmc.loaders import OpenAiClipLoader 23 | ldr = OpenAiClipLoader(id='RN50') 24 | oai_clip = ldr.load() 25 | 26 | def test_load_mlf_clip(): 27 | from mmc.loaders import MlfClipLoader 28 | ldr = MlfClipLoader(id='RN50--cc12m') 29 | mlf_clip = ldr.load() 30 | 31 | ## Models below pass load but fail inference tests. 
32 | # Commonality here I think is models loaded from huggingface 33 | 34 | def test_load_sbert_mclip(): 35 | from mmc.loaders import SBertClipLoader 36 | ldr = SBertClipLoader() 37 | sbert_mclip = ldr.load() 38 | 39 | def test_load_clipfa(): 40 | from mmc.loaders import ClipFaLoader 41 | ldr = ClipFaLoader() 42 | farsi_clip = ldr.load() 43 | 44 | def test_load_katcloob(): 45 | from mmc.loaders import KatCloobLoader 46 | ldr = KatCloobLoader() 47 | cloob = ldr.load() 48 | 49 | def test_load_kelip(): 50 | from mmc.loaders import ClipKelipLoader 51 | ldr = ClipKelipLoader() 52 | kelip = ldr.load() 53 | 54 | def test_load_fairslip_yfcc15m(): 55 | from mmc.loaders.fairsliploader import FairSlipLoader_YFCC15M 56 | ldr = FairSlipLoader_YFCC15M( 57 | architecture='slip', 58 | id='slip_small_100ep', 59 | ) 60 | slip = ldr.load() -------------------------------------------------------------------------------- /tests/test_mmc_mlf.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from mmc.loaders import MlfClipLoader as loader 3 | import PIL 4 | from loguru import logger 5 | import torch 6 | 7 | loader_args = {'id':'RN50--cc12m'} 8 | 9 | def test_project_text(): 10 | ldr = loader(**loader_args) 11 | perceptor = ldr.load() 12 | projection = perceptor.project_text("foo bar baz") 13 | assert isinstance(projection, torch.Tensor) 14 | 15 | def test_project_img(): 16 | ldr = loader(**loader_args) 17 | perceptor = ldr.load() 18 | img = PIL.Image.open("./tests/assets/marley_birthday.jpg").resize((250,200)) 19 | projection = perceptor.project_image(img) 20 | assert isinstance(projection, torch.Tensor) 21 | 22 | def test_supported_modalities(): 23 | ldr = loader(**loader_args) 24 | perceptor = ldr.load() 25 | assert perceptor.supports_text 26 | assert perceptor.supports_image 27 | assert not perceptor.supports_audio 28 | 29 | class TestMlfVitb16plus: 30 | loader_args = {'id':'ViT-B-16-plus-240--laion400m_e32'} 31 | 32 | def test_project_text(self): 33 | ldr = loader(**self.loader_args) 34 | perceptor = ldr.load() 35 | projection = perceptor.project_text("foo bar baz") 36 | assert isinstance(projection, torch.Tensor) 37 | logger.debug(projection.shape) 38 | 39 | def test_project_img(self): 40 | ldr = loader(**self.loader_args) 41 | perceptor = ldr.load() 42 | img = PIL.Image.open("./tests/assets/marley_birthday.jpg").resize((300,300)) 43 | projection = perceptor.project_image(img) 44 | assert isinstance(projection, torch.Tensor) 45 | logger.debug(projection.shape) 46 | 47 | def test_supported_modalities(self): 48 | ldr = loader(**self.loader_args) 49 | perceptor = ldr.load() 50 | assert perceptor.supports_text 51 | assert perceptor.supports_image 52 | assert not perceptor.supports_audio 53 | -------------------------------------------------------------------------------- /tests/test_mmc_sbert.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from mmc.loaders import SBertClipLoader as loader 3 | import PIL 4 | from loguru import logger 5 | import torch 6 | 7 | # for some reason SBERT is returning np.ndarrays instead of tensors. 8 | # might be sentence-transformer wonkiness that could be resolved by 9 | # using hugginface/transformers directly. 
10 | 11 | #loader_args = {'id':'ViT-B-32-multilingual-v1'} 12 | loader_args = {} 13 | 14 | def test_project_text(): 15 | ldr = loader(**loader_args) 16 | perceptor = ldr.load() 17 | projection = perceptor.project_text("foo bar baz") 18 | assert isinstance(projection, torch.Tensor) 19 | logger.debug(projection.shape) 20 | 21 | def test_project_img(): 22 | ldr = loader(**loader_args) 23 | perceptor = ldr.load() 24 | img = PIL.Image.open("./tests/assets/marley_birthday.jpg").resize((250,200)) 25 | projection = perceptor.project_image(img) 26 | assert isinstance(projection, torch.Tensor) 27 | logger.debug(projection.shape) 28 | 29 | def test_supported_modalities(): 30 | ldr = loader(**loader_args) 31 | perceptor = ldr.load() 32 | assert perceptor.supports_text 33 | assert perceptor.supports_image 34 | assert not perceptor.supports_audio 35 | -------------------------------------------------------------------------------- /tests/test_modalities.py: -------------------------------------------------------------------------------- 1 | 2 | def test_TEXT(): 3 | from mmc.modalities import TEXT 4 | TEXT.name == 'text' 5 | 6 | def test_IMAGE(): 7 | from mmc.modalities import IMAGE 8 | IMAGE.name == 'image' 9 | 10 | 11 | def test_AUDIO(): 12 | from mmc.modalities import AUDIO 13 | AUDIO.name == 'audio' 14 | 15 | 16 | def test_Modality(): 17 | from mmc.modalities import Modality 18 | FOO = Modality(name='foo') 19 | FOO.name == 'foo' 20 | 21 | #TEXT = Modality(name='text') 22 | 23 | 24 | #def test_(): 25 | # from mmc.modalities import 26 | 27 | 28 | #def test_(): 29 | # from mmc.modalities import 30 | 31 | 32 | #def test_(): 33 | # from mmc.modalities import 34 | 35 | 36 | #read_from_disk 37 | #write_to_disk 38 | #TEXT = Modality(name='text') 39 | 40 | #IMAGE = Modality(name='image') 41 | #IMAGE.read_from_disk = PIL.Image.open 42 | #IMAGE.write_to_disk = lambda fpath, obj: obj.save(fpath) 43 | 44 | #AUDIO = Modality('audio') 45 | -------------------------------------------------------------------------------- /tests/test_module_mock.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from loguru import logger 3 | import mmc 4 | import PIL 5 | import torch 6 | 7 | 8 | def test_init(): 9 | from mmc.mock.openai import MockOpenaiClipModule 10 | from mmc.loaders import OpenAiClipLoader 11 | 12 | ldr = OpenAiClipLoader(id='RN50') 13 | clip = MockOpenaiClipModule(ldr) 14 | 15 | def test_available_models(): 16 | from mmc.mock.openai import MockOpenaiClipModule 17 | from mmc.loaders import OpenAiClipLoader 18 | 19 | ldr = OpenAiClipLoader(id='RN50') 20 | clip = MockOpenaiClipModule(ldr) 21 | assert clip.available_models() == str(ldr) 22 | 23 | def test_private_mmc(): 24 | from mmc.mock.openai import MockOpenaiClipModule, MockOpenaiClip 25 | from mmc.loaders import OpenAiClipLoader 26 | from mmc.multimodalcomparator import MultiModalComparator 27 | 28 | ldr = OpenAiClipLoader(id='RN50') 29 | clip = MockOpenaiClipModule(ldr) 30 | #assert isinstance(clip._clip, MockOpenaiClip) 31 | assert isinstance(clip._model, MultiModalComparator) 32 | 33 | 34 | def test_tokenize(): 35 | from mmc.mock.openai import MockOpenaiClipModule 36 | from mmc.loaders import OpenAiClipLoader 37 | from clip import tokenize 38 | 39 | ldr = OpenAiClipLoader(id='RN50') 40 | clip = MockOpenaiClipModule(ldr) 41 | 42 | test_text = "foo bar baz" 43 | tokens = clip.tokenize(test_text) 44 | assert isinstance(tokens, torch.Tensor) 45 | assert tokens.shape[1] == tokenize(test_text).shape[1] 46 | 47 
| def test_preprocess_image(): 48 | pass 49 | 50 | def test_load(): 51 | pass 52 | 53 | 54 | -------------------------------------------------------------------------------- /tests/test_multimmc.py: -------------------------------------------------------------------------------- 1 | from mmc.multimmc import MultiMMC 2 | from mmc.modalities import TEXT, IMAGE, AUDIO 3 | import PIL 4 | from loguru import logger 5 | 6 | def test_init_perceptor(): 7 | perceptor = MultiMMC(TEXT, IMAGE) 8 | 9 | def test_load_clip(): 10 | perceptor = MultiMMC(TEXT, IMAGE) 11 | perceptor.load_model( 12 | architecture='clip', 13 | publisher='openai', 14 | id='RN50', 15 | ) 16 | 17 | def test_supports_modality_property(): 18 | perceptor = MultiMMC(TEXT, IMAGE) 19 | perceptor.load_model( 20 | architecture='clip', 21 | publisher='openai', 22 | id='RN50', 23 | ) 24 | assert perceptor.supports_image 25 | assert perceptor.supports_text 26 | assert not perceptor.supports_audio 27 | 28 | def test_supports_modality_function(): 29 | perceptor = MultiMMC(TEXT, IMAGE) 30 | perceptor.load_model( 31 | architecture='clip', 32 | publisher='openai', 33 | id='RN50', 34 | ) 35 | assert perceptor.supports_modality(IMAGE) 36 | assert perceptor.supports_modality(TEXT) 37 | assert not perceptor.supports_modality(AUDIO) 38 | 39 | 40 | def test_supports_modality_name(): 41 | perceptor = MultiMMC(TEXT, IMAGE) 42 | perceptor.load_model( 43 | architecture='clip', 44 | publisher='openai', 45 | id='RN50', 46 | ) 47 | assert perceptor._supports_mode('image') 48 | assert perceptor._supports_mode('text') 49 | assert not perceptor._supports_mode('audio') 50 | 51 | def test_compare_text2img(): 52 | perceptor = MultiMMC(TEXT, IMAGE) 53 | perceptor.load_model( 54 | architecture='clip', 55 | publisher='openai', 56 | id='RN50', 57 | ) 58 | text_pos = "a photo of a dog" 59 | text_neg = "a painting of a cat" 60 | img = PIL.Image.open('./tests/assets/marley_birthday.jpg').resize((250,200)) 61 | v_pos = perceptor.compare(image=img, text=text_pos) 62 | v_neg = perceptor.compare(image=img, text=text_neg) 63 | logger.debug((v_pos, v_neg)) 64 | assert v_pos > v_neg 65 | 66 | 67 | def test_multi_same_publisher_and_arch(): 68 | from mmc.multimmc import MultiMMC 69 | from mmc.modalities import TEXT, IMAGE 70 | 71 | perceptor = MultiMMC(TEXT, IMAGE) 72 | models = [ 73 | dict( 74 | architecture='clip', 75 | publisher='openai', 76 | id='RN50', 77 | ), 78 | dict( 79 | architecture='clip', 80 | publisher='openai', 81 | id='ViT-B/32', 82 | )] 83 | for m in models: 84 | perceptor.load_model(**m) 85 | 86 | text_pos = "a photo of a dog" 87 | text_neg = "a painting of a cat" 88 | img = PIL.Image.open('./tests/assets/marley_birthday.jpg').resize((250,200)) 89 | v_pos = perceptor.compare(image=img, text=text_pos) 90 | v_neg = perceptor.compare(image=img, text=text_neg) 91 | logger.debug((v_pos, v_neg)) 92 | assert v_pos > v_neg 93 | -------------------------------------------------------------------------------- /tests/test_registry.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | def test_import_module(): 4 | from mmc import registry 5 | 6 | def test_import_REGISTRY(): 7 | from mmc.registry import REGISTRY 8 | 9 | def test_empty_search_REGISTRY(): 10 | from mmc.registry import REGISTRY 11 | REGISTRY.find() 12 | 13 | def test_clip_search_REGISTRY(): 14 | from mmc.registry import REGISTRY 15 | hits = REGISTRY.find(architecture='clip') 16 | assert len(hits) > 1 17 | 18 | def test_register_model(): 19 | from 
mmc.registry import register_model 20 | --------------------------------------------------------------------------------