├── .gitattributes ├── .github └── workflows │ ├── ci.yml │ ├── clear-cache.yml │ └── python-publish.yml ├── .gitignore ├── CITATION.cff ├── HISTORY.md ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── docs ├── CLIP.png ├── Interacting_with_open_clip.ipynb ├── Interacting_with_open_coca.ipynb ├── LOW_ACC.md ├── PRETRAINED.md ├── clip_conceptual_captions.md ├── clip_loss.png ├── clip_recall.png ├── clip_val_loss.png ├── clip_zeroshot.png ├── clipa.md ├── clipa_acc_compute.png ├── clipa_reduce_image_token.png ├── clipa_reduce_text_token.png ├── datacomp_models.md ├── effective_robustness.png ├── inverse_scaling_law.png ├── laion2b_clip_zeroshot_b32.png ├── laion_clip_zeroshot.png ├── laion_clip_zeroshot_b16.png ├── laion_clip_zeroshot_b16_plus_240.png ├── laion_clip_zeroshot_l14.png ├── laion_openai_compare_b32.jpg ├── model_profile.csv ├── openclip_classification_results.csv ├── openclip_multilingual_retrieval_results.csv ├── openclip_results.csv ├── openclip_retrieval_results.csv ├── scaling.png └── script_examples │ ├── clipa │ ├── vit_b16 │ │ ├── i50_t16_finetune.sh │ │ └── i50_t16_pretrain.sh │ └── vit_l16 │ │ ├── i17_t16_finetune.sh │ │ ├── i17_t16_pretrain.sh │ │ ├── i37_t8_finetune.sh │ │ └── i37_t8_pretrain.sh │ ├── clipav2 │ └── vit_h14 │ │ ├── i257_t32_finetunex4.sh │ │ ├── i50_t8_pretrain.sh │ │ └── i577_t32_finetunex1.sh │ └── stability_example.sh ├── pyproject.toml ├── pytest.ini ├── requirements-test.txt ├── requirements-training.txt ├── requirements.txt ├── scripts ├── clipav1_vit_l16_i37_t8.sh ├── clipav2_vit_h14_i84_224_336_cl32_gap_datacomp1b.sh ├── h14_224_32_finetune.sh └── h14_84_8_pretrain.sh ├── src ├── open_clip │ ├── __init__.py │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── coca_model.py │ ├── constants.py │ ├── convert.py │ ├── factory.py │ ├── hf_configs.py │ ├── hf_model.py │ ├── loss.py │ ├── model.py │ ├── model_configs │ │ ├── EVA01-g-14-plus.json │ │ ├── EVA01-g-14.json │ │ ├── EVA02-B-16.json │ │ ├── EVA02-E-14-plus.json │ │ ├── EVA02-E-14.json │ │ ├── EVA02-L-14-336.json │ │ ├── EVA02-L-14.json │ │ ├── MobileCLIP-B.json │ │ ├── MobileCLIP-S1.json │ │ ├── MobileCLIP-S2.json │ │ ├── RN101-quickgelu.json │ │ ├── RN101.json │ │ ├── RN50-quickgelu.json │ │ ├── RN50.json │ │ ├── RN50x16-quickgelu.json │ │ ├── RN50x16.json │ │ ├── RN50x4-quickgelu.json │ │ ├── RN50x4.json │ │ ├── RN50x64-quickgelu.json │ │ ├── RN50x64.json │ │ ├── ViT-B-16-SigLIP-256.json │ │ ├── ViT-B-16-SigLIP-384.json │ │ ├── ViT-B-16-SigLIP-512.json │ │ ├── ViT-B-16-SigLIP-i18n-256.json │ │ ├── ViT-B-16-SigLIP.json │ │ ├── ViT-B-16-SigLIP2-256.json │ │ ├── ViT-B-16-SigLIP2-384.json │ │ ├── ViT-B-16-SigLIP2-512.json │ │ ├── ViT-B-16-SigLIP2.json │ │ ├── ViT-B-16-plus-240.json │ │ ├── ViT-B-16-plus.json │ │ ├── ViT-B-16-quickgelu.json │ │ ├── ViT-B-16.json │ │ ├── ViT-B-32-256.json │ │ ├── ViT-B-32-SigLIP2-256.json │ │ ├── ViT-B-32-plus-256.json │ │ ├── ViT-B-32-quickgelu.json │ │ ├── ViT-B-32.json │ │ ├── ViT-H-14-378-quickgelu.json │ │ ├── ViT-H-14-378.json │ │ ├── ViT-H-14-CLIPA-336.json │ │ ├── ViT-H-14-CLIPA.json │ │ ├── ViT-H-14-quickgelu.json │ │ ├── ViT-H-14.json │ │ ├── ViT-H-16.json │ │ ├── ViT-L-14-280.json │ │ ├── ViT-L-14-336-quickgelu.json │ │ ├── ViT-L-14-336.json │ │ ├── ViT-L-14-CLIPA-336.json │ │ ├── ViT-L-14-CLIPA.json │ │ ├── ViT-L-14-quickgelu.json │ │ ├── ViT-L-14.json │ │ ├── ViT-L-16-320.json │ │ ├── ViT-L-16-SigLIP-256.json │ │ ├── ViT-L-16-SigLIP-384.json │ │ ├── ViT-L-16-SigLIP2-256.json │ │ ├── ViT-L-16-SigLIP2-384.json │ │ ├── ViT-L-16-SigLIP2-512.json │ │ ├── 
ViT-L-16.json │ │ ├── ViT-M-16-alt.json │ │ ├── ViT-M-16.json │ │ ├── ViT-M-32-alt.json │ │ ├── ViT-M-32.json │ │ ├── ViT-S-16-alt.json │ │ ├── ViT-S-16.json │ │ ├── ViT-S-32-alt.json │ │ ├── ViT-S-32.json │ │ ├── ViT-SO400M-14-SigLIP-378.json │ │ ├── ViT-SO400M-14-SigLIP-384.json │ │ ├── ViT-SO400M-14-SigLIP.json │ │ ├── ViT-SO400M-14-SigLIP2-378.json │ │ ├── ViT-SO400M-14-SigLIP2.json │ │ ├── ViT-SO400M-16-SigLIP-i18n-256.json │ │ ├── ViT-SO400M-16-SigLIP2-256.json │ │ ├── ViT-SO400M-16-SigLIP2-384.json │ │ ├── ViT-SO400M-16-SigLIP2-512.json │ │ ├── ViT-bigG-14-CLIPA-336.json │ │ ├── ViT-bigG-14-CLIPA.json │ │ ├── ViT-bigG-14-quickgelu.json │ │ ├── ViT-bigG-14.json │ │ ├── ViT-e-14.json │ │ ├── ViT-g-14.json │ │ ├── ViT-gopt-16-SigLIP2-256.json │ │ ├── ViT-gopt-16-SigLIP2-384.json │ │ ├── ViTamin-B-LTT.json │ │ ├── ViTamin-B.json │ │ ├── ViTamin-L-256.json │ │ ├── ViTamin-L-336.json │ │ ├── ViTamin-L-384.json │ │ ├── ViTamin-L.json │ │ ├── ViTamin-L2-256.json │ │ ├── ViTamin-L2-336.json │ │ ├── ViTamin-L2-384.json │ │ ├── ViTamin-L2.json │ │ ├── ViTamin-S-LTT.json │ │ ├── ViTamin-S.json │ │ ├── ViTamin-XL-256.json │ │ ├── ViTamin-XL-336.json │ │ ├── ViTamin-XL-384.json │ │ ├── coca_ViT-B-32.json │ │ ├── coca_ViT-L-14.json │ │ ├── coca_base.json │ │ ├── coca_roberta-ViT-B-32.json │ │ ├── convnext_base.json │ │ ├── convnext_base_w.json │ │ ├── convnext_base_w_320.json │ │ ├── convnext_large.json │ │ ├── convnext_large_d.json │ │ ├── convnext_large_d_320.json │ │ ├── convnext_small.json │ │ ├── convnext_tiny.json │ │ ├── convnext_xlarge.json │ │ ├── convnext_xxlarge.json │ │ ├── convnext_xxlarge_320.json │ │ ├── mt5-base-ViT-B-32.json │ │ ├── mt5-xl-ViT-H-14.json │ │ ├── nllb-clip-base-siglip.json │ │ ├── nllb-clip-base.json │ │ ├── nllb-clip-large-siglip.json │ │ ├── nllb-clip-large.json │ │ ├── roberta-ViT-B-32.json │ │ ├── swin_base_patch4_window7_224.json │ │ ├── vit_medium_patch16_gap_256.json │ │ ├── vit_relpos_medium_patch16_cls_224.json │ │ ├── xlm-roberta-base-ViT-B-32.json │ │ └── xlm-roberta-large-ViT-H-14.json │ ├── modified_resnet.py │ ├── openai.py │ ├── pos_embed.py │ ├── pretrained.py │ ├── push_to_hf_hub.py │ ├── timm_model.py │ ├── tokenizer.py │ ├── transform.py │ ├── transformer.py │ ├── utils.py │ ├── version.py │ ├── zero_shot_classifier.py │ └── zero_shot_metadata.py └── open_clip_train │ ├── __init__.py │ ├── data.py │ ├── distributed.py │ ├── file_utils.py │ ├── logger.py │ ├── main.py │ ├── params.py │ ├── precision.py │ ├── profiler.py │ ├── scheduler.py │ ├── train.py │ └── zero_shot.py ├── tests ├── test_download_pretrained.py ├── test_hf_model.py ├── test_inference.py ├── test_inference_simple.py ├── test_num_shards.py ├── test_training_simple.py ├── test_wds.py └── util_test.py └── tutorials └── int8_tutorial.ipynb /.gitattributes: -------------------------------------------------------------------------------- 1 | *.py linguist-language=python 2 | *.ipynb linguist-documentation 3 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: Continuous integration 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | paths-ignore: 8 | - '**.md' 9 | - 'CITATION.cff' 10 | - 'LICENSE' 11 | - '.gitignore' 12 | - 'docs/**' 13 | pull_request: 14 | branches: 15 | - main 16 | paths-ignore: 17 | - '**.md' 18 | - 'CITATION.cff' 19 | - 'LICENSE' 20 | - '.gitignore' 21 | - 'docs/**' 22 | workflow_dispatch: 23 | inputs: 24 | 
manual_revision_reference: 25 | required: false 26 | type: string 27 | manual_revision_test: 28 | required: false 29 | type: string 30 | 31 | env: 32 | REVISION_REFERENCE: v2.8.2 33 | #9d31b2ec4df6d8228f370ff20c8267ec6ba39383 earliest compatible v2.7.0 + pretrained_hf param 34 | 35 | jobs: 36 | Tests: 37 | strategy: 38 | matrix: 39 | os: [ ubuntu-latest ] #, macos-latest ] 40 | python: [ 3.8 ] 41 | job_num: [ 4 ] 42 | job: [ 1, 2, 3, 4 ] 43 | runs-on: ${{ matrix.os }} 44 | steps: 45 | - uses: actions/checkout@v3 46 | with: 47 | fetch-depth: 0 48 | ref: ${{ inputs.manual_revision_test }} 49 | - name: Set up Python ${{ matrix.python }} 50 | id: pythonsetup 51 | uses: actions/setup-python@v4 52 | with: 53 | python-version: ${{ matrix.python }} 54 | - name: Venv cache 55 | id: venv-cache 56 | uses: actions/cache@v3 57 | with: 58 | path: .env 59 | key: venv-${{ matrix.os }}-${{ steps.pythonsetup.outputs.python-version }}-${{ hashFiles('requirements*') }} 60 | - name: Pytest durations cache 61 | uses: actions/cache@v3 62 | with: 63 | path: .test_durations 64 | key: test_durations-${{ matrix.os }}-${{ steps.pythonsetup.outputs.python-version }}-${{ matrix.job }}-${{ github.run_id }} 65 | restore-keys: test_durations-0- 66 | - name: Setup 67 | if: steps.venv-cache.outputs.cache-hit != 'true' 68 | run: | 69 | python3 -m venv .env 70 | source .env/bin/activate 71 | pip install -e .[test] 72 | - name: Prepare test data 73 | run: | 74 | source .env/bin/activate 75 | python -m pytest \ 76 | --quiet --co \ 77 | --splitting-algorithm least_duration \ 78 | --splits ${{ matrix.job_num }} \ 79 | --group ${{ matrix.job }} \ 80 | -m regression_test \ 81 | tests \ 82 | | head -n -2 | grep -Po 'test_inference_with_data\[\K[^]]*(?=-False]|-True])' \ 83 | > models_gh_runner.txt 84 | if [ -n "${{ inputs.manual_revision_reference }}" ]; then 85 | REVISION_REFERENCE=${{ inputs.manual_revision_reference }} 86 | fi 87 | python tests/util_test.py \ 88 | --save_model_list models_gh_runner.txt \ 89 | --model_list models_gh_runner.txt \ 90 | --git_revision $REVISION_REFERENCE 91 | - name: Unit tests 92 | run: | 93 | source .env/bin/activate 94 | if [[ -f .test_durations ]] 95 | then 96 | cp .test_durations durations_1 97 | mv .test_durations durations_2 98 | fi 99 | python -m pytest \ 100 | -x -s -v \ 101 | --splitting-algorithm least_duration \ 102 | --splits ${{ matrix.job_num }} \ 103 | --group ${{ matrix.job }} \ 104 | --store-durations \ 105 | --durations-path durations_1 \ 106 | --clean-durations \ 107 | -m "not regression_test" \ 108 | tests 109 | OPEN_CLIP_TEST_REG_MODELS=models_gh_runner.txt python -m pytest \ 110 | -x -s -v \ 111 | --store-durations \ 112 | --durations-path durations_2 \ 113 | --clean-durations \ 114 | -m "regression_test" \ 115 | tests 116 | jq -s -S 'add' durations_* > .test_durations 117 | - name: Collect pytest durations 118 | uses: actions/upload-artifact@v4 119 | with: 120 | name: pytest_durations_${{ matrix.os }}-${{ matrix.python }}-${{ matrix.job }} 121 | path: .test_durations 122 | -------------------------------------------------------------------------------- /.github/workflows/clear-cache.yml: -------------------------------------------------------------------------------- 1 | name: Clear cache 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | permissions: 7 | actions: write 8 | 9 | jobs: 10 | clear-cache: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Clear cache 14 | uses: actions/github-script@v6 15 | with: 16 | script: | 17 | const caches = await 
github.rest.actions.getActionsCacheList({ 18 | owner: context.repo.owner, 19 | repo: context.repo.repo, 20 | }) 21 | for (const cache of caches.data.actions_caches) { 22 | console.log(cache) 23 | await github.rest.actions.deleteActionsCacheById({ 24 | owner: context.repo.owner, 25 | repo: context.repo.repo, 26 | cache_id: cache.id, 27 | }) 28 | } 29 | 30 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - uses: actions-ecosystem/action-regex-match@v2 13 | id: regex-match 14 | with: 15 | text: ${{ github.event.head_commit.message }} 16 | regex: '^Release ([^ ]+)' 17 | - name: Set up Python 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: '3.8' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine build 25 | - name: Release 26 | if: ${{ steps.regex-match.outputs.match != '' }} 27 | uses: softprops/action-gh-release@v1 28 | with: 29 | tag_name: v${{ steps.regex-match.outputs.group1 }} 30 | - name: Build and publish 31 | if: ${{ steps.regex-match.outputs.match != '' }} 32 | env: 33 | TWINE_USERNAME: __token__ 34 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 35 | run: | 36 | python -m build 37 | twine upload dist/* 38 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/logs/ 2 | **/wandb/ 3 | models/ 4 | features/ 5 | results/ 6 | 7 | tests/data/ 8 | *.pt 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | pip-wheel-metadata/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | *.py,cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 104 | __pypackages__/ 105 | 106 | # Celery stuff 107 | celerybeat-schedule 108 | celerybeat.pid 109 | 110 | # SageMath parsed files 111 | *.sage.py 112 | 113 | # Environments 114 | .env 115 | .venv 116 | env/ 117 | venv/ 118 | ENV/ 119 | env.bak/ 120 | venv.bak/ 121 | 122 | # Spyder project settings 123 | .spyderproject 124 | .spyproject 125 | 126 | # Rope project settings 127 | .ropeproject 128 | 129 | # mkdocs documentation 130 | /site 131 | 132 | # mypy 133 | .mypy_cache/ 134 | .dmypy.json 135 | dmypy.json 136 | 137 | # Pyre type checker 138 | .pyre/ 139 | sync.sh 140 | gpu1sync.sh 141 | .idea 142 | *.pdf 143 | **/._* 144 | **/*DS_* 145 | **.jsonl 146 | src/sbatch 147 | src/misc 148 | .vscode 149 | src/debug 150 | core.* 151 | 152 | # Allow 153 | !src/evaluation/misc/results_dbs/* 154 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.1.0 2 | message: If you use this software, please cite it as below. 3 | authors: 4 | - family-names: Ilharco 5 | given-names: Gabriel 6 | - family-names: Wortsman 7 | given-names: Mitchell 8 | - family-names: Wightman 9 | given-names: Ross 10 | - family-names: Gordon 11 | given-names: Cade 12 | - family-names: Carlini 13 | given-names: Nicholas 14 | - family-names: Taori 15 | given-names: Rohan 16 | - family-names: Dave 17 | given-names: Achal 18 | - family-names: Shankar 19 | given-names: Vaishaal 20 | - family-names: Namkoong 21 | given-names: Hongseok 22 | - family-names: Miller 23 | given-names: John 24 | - family-names: Hajishirzi 25 | given-names: Hannaneh 26 | - family-names: Farhadi 27 | given-names: Ali 28 | - family-names: Schmidt 29 | given-names: Ludwig 30 | title: OpenCLIP 31 | version: v0.1 32 | doi: 10.5281/zenodo.5143773 33 | date-released: 2021-07-28 34 | -------------------------------------------------------------------------------- /HISTORY.md: -------------------------------------------------------------------------------- 1 | ## 2.24.0 2 | 3 | * Fix missing space in error message 4 | * use model flag for normalizing embeddings 5 | * init logit_bias for non siglip pretrained models 6 | * Fix logit_bias load_checkpoint addition 7 | * Make CoCa model match CLIP models for logit scale/bias init 8 | * Fix missing return of "logit_bias" in CoCa.forward 9 | * Add NLLB-CLIP with SigLIP models 10 | * Add get_logits method and NLLB tokenizer 11 | * Remove the empty file src/open_clip/generation_utils.py 12 | * Update params.py: "BatchNorm" -> "LayerNorm" in the description string for "--lock-text-freeze-layer-norm" 13 | 14 | ## 2.23.0 15 | 16 | * Add CLIPA-v2 models 17 | * Add SigLIP models 18 | * Add MetaCLIP models 19 | * Add NLLB-CLIP models 20 | * CLIPA train code 21 | * Minor changes/fixes 22 | * Remove protobuf version limit 23 | * Stop checking model name when loading CoCa models 24 | * Log native wandb step 25 | * Use bool instead of long masks 26 | 27 | ## 2.21.0 28 | 29 | * Add SigLIP loss + training support 30 | * Add more DataComp models (B/16, B/32 and B/32@256) 31 | * Update default num workers 32 | * Update CoCa generation for `transformers>=4.31` 33 | * PyTorch 2.0 
`state_dict()` compatibility fix for compiled models 34 | * Fix padding in `ResizeMaxSize` 35 | * Convert JIT model on state dict load for `pretrained='filename…'` 36 | * Other minor changes and fixes (typos, README, dependencies, CI) 37 | 38 | ## 2.20.0 39 | 40 | * Add EVA models 41 | * Support serial worker training 42 | * Fix Python 3.7 compatibility 43 | 44 | ## 2.19.0 45 | 46 | * Add DataComp models 47 | 48 | ## 2.18.0 49 | 50 | * Enable int8 inference without `.weight` attribute 51 | 52 | ## 2.17.2 53 | 54 | * Update push_to_hf_hub 55 | 56 | ## 2.17.0 57 | 58 | * Add int8 support 59 | * Update notebook demo 60 | * Refactor zero-shot classification code 61 | 62 | ## 2.16.2 63 | 64 | * Fixes for context_length and vocab_size attributes 65 | 66 | ## 2.16.1 67 | 68 | * Fixes for context_length and vocab_size attributes 69 | * Fix --train-num-samples logic 70 | * Add HF BERT configs for PubMed CLIP model 71 | 72 | ## 2.16.0 73 | 74 | * Add improved g-14 weights 75 | * Update protobuf version 76 | 77 | ## 2.15.0 78 | 79 | * Add convnext_xxlarge weights 80 | * Fixed import in readme 81 | * Add samples per second per gpu logging 82 | * Fix slurm example 83 | 84 | ## 2.14.0 85 | 86 | * Move dataset mixtures logic to shard level 87 | * Fix CoCa accum-grad training 88 | * Safer transformers import guard 89 | * get_labels refactoring 90 | 91 | ## 2.13.0 92 | 93 | * Add support for dataset mixtures with different sampling weights 94 | * Make transformers optional again 95 | 96 | ## 2.12.0 97 | 98 | * Updated convnext configs for consistency 99 | * Added input_patchnorm option 100 | * Clean and improve CoCa generation 101 | * Support model distillation 102 | * Add ConvNeXt-Large 320x320 fine-tune weights 103 | 104 | ## 2.11.1 105 | 106 | * Make transformers optional 107 | * Add MSCOCO CoCa finetunes to pretrained models 108 | 109 | ## 2.11.0 110 | 111 | * coca support and weights 112 | * ConvNeXt-Large weights 113 | 114 | ## 2.10.1 115 | 116 | * `hf-hub:org/model_id` support for loading models w/ config and weights in Hugging Face Hub 117 | 118 | ## 2.10.0 119 | 120 | * Added a ViT-bigG-14 model. 121 | * Added an up-to-date example slurm script for large training jobs. 122 | * Added a option to sync logs and checkpoints to S3 during training. 
123 | * New options for LR schedulers, constant and constant with cooldown 124 | * Fix wandb autoresuming when resume is not set 125 | * ConvNeXt `base` & `base_w` pretrained models added 126 | * `timm-` model prefix removed from configs 127 | * `timm` augmentation + regularization (dropout / drop-path) supported 128 | 129 | ## 2.9.3 130 | 131 | * Fix wandb collapsing multiple parallel runs into a single one 132 | 133 | ## 2.9.2 134 | 135 | * Fix braceexpand memory explosion for complex webdataset urls 136 | 137 | ## 2.9.1 138 | 139 | * Fix release 140 | 141 | ## 2.9.0 142 | 143 | * Add training feature to auto-resume from the latest checkpoint on restart via `--resume latest` 144 | * Allow webp in webdataset 145 | * Fix logging for number of samples when using gradient accumulation 146 | * Add model configs for convnext xxlarge 147 | 148 | ## 2.8.2 149 | 150 | * wrapped patchdropout in a torch.nn.Module 151 | 152 | ## 2.8.1 153 | 154 | * relax protobuf dependency 155 | * override the default patch dropout value in 'vision_cfg' 156 | 157 | ## 2.8.0 158 | 159 | * better support for HF models 160 | * add support for gradient accumulation 161 | * CI fixes 162 | * add support for patch dropout 163 | * add convnext configs 164 | 165 | 166 | ## 2.7.0 167 | 168 | * add multilingual H/14 xlm roberta large 169 | 170 | ## 2.6.1 171 | 172 | * fix setup.py _read_reqs 173 | 174 | ## 2.6.0 175 | 176 | * Make openclip training usable from pypi. 177 | * Add xlm roberta large vit h 14 config. 178 | 179 | ## 2.5.0 180 | 181 | * pretrained B/32 xlm roberta base: first multilingual clip trained on laion5B 182 | * pretrained B/32 roberta base: first clip trained using an HF text encoder 183 | 184 | ## 2.4.1 185 | 186 | * Add missing hf_tokenizer_name in CLIPTextCfg. 187 | 188 | ## 2.4.0 189 | 190 | * Fix #211, missing RN50x64 config. Fix type of dropout param for ResNet models 191 | * Bring back LayerNorm impl that casts to input for non bf16/fp16 192 | * zero_shot.py: set correct tokenizer based on args 193 | * training/params.py: remove hf params and get them from model config 194 | 195 | ## 2.3.1 196 | 197 | * Implement grad checkpointing for hf model. 
198 | * custom_text: True if hf_model_name is set 199 | * Disable hf tokenizer parallelism 200 | 201 | ## 2.3.0 202 | 203 | * Generalizable Text Transformer with HuggingFace Models (@iejMac) 204 | 205 | ## 2.2.0 206 | 207 | * Support for custom text tower 208 | * Add checksum verification for pretrained model weights 209 | 210 | ## 2.1.0 211 | 212 | * lot including sota models, bfloat16 option, better loading, better metrics 213 | 214 | ## 1.2.0 215 | 216 | * ViT-B/32 trained on Laion2B-en 217 | * add missing openai RN50x64 model 218 | 219 | ## 1.1.1 220 | 221 | * ViT-B/16+ 222 | * Add grad checkpointing support 223 | * more robust data loader 224 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2012-2021 Gabriel Ilharco, Mitchell Wortsman, 2 | Nicholas Carlini, Rohan Taori, Achal Dave, Vaishaal Shankar, 3 | John Miller, Hongseok Namkoong, Hannaneh Hajishirzi, Ali Farhadi, 4 | Ludwig Schmidt 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining 7 | a copy of this software and associated documentation files (the 8 | "Software"), to deal in the Software without restriction, including 9 | without limitation the rights to use, copy, modify, merge, publish, 10 | distribute, sublicense, and/or sell copies of the Software, and to 11 | permit persons to whom the Software is furnished to do so, subject to 12 | the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be 15 | included in all copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 18 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 19 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 20 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 21 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 22 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 23 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 24 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include src/open_clip/bpe_simple_vocab_16e6.txt.gz 2 | include src/open_clip/model_configs/*.json 3 | 4 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | install: ## [Local development] Upgrade pip, install requirements, install package. 2 | python -m pip install -U pip 3 | python -m pip install -e . 
4 | 5 | install-training: 6 | python -m pip install -r requirements-training.txt 7 | 8 | install-test: ## [Local development] Install test requirements 9 | python -m pip install -r requirements-test.txt 10 | 11 | test: ## [Local development] Run unit tests 12 | python -m pytest -x -s -v tests 13 | -------------------------------------------------------------------------------- /docs/CLIP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlfoundations/open_clip/131f46f6fc1c5caebdc6e7d6c487d33de5b85d2a/docs/CLIP.png -------------------------------------------------------------------------------- /docs/LOW_ACC.md: -------------------------------------------------------------------------------- 1 | As we describe in more detail below, CLIP models in a medium accuracy regime already allow us to draw conclusions about the robustness of larger CLIP models since the models follow reliable scaling laws. 2 | 3 | [Cherti et al., 2022](https://arxiv.org/abs/2212.07143) and [Gadre et al., 2023](https://arxiv.org/abs/2304.14108) provide additional discussion of the scaling behavior of CLIP models. 4 | 5 | ## Scaling trends 6 | 7 | The plot below shows how zero-shot performance of CLIP models varies as we scale the number of samples used for training. Zero-shot performance increases steadily for both ImageNet and [ImageNetV2](https://arxiv.org/abs/1902.10811), and is far from saturated at ~15M samples. 8 | 9 | 10 | 11 | ## Why are low-accuracy CLIP models interesting? 12 | 13 | **TL;DR:** CLIP models have high effective robustness, even at small scales. 14 | 15 | CLIP models are particularly intriguing because they are more robust to natural distribution shifts (see Section 3.3 in the [CLIP paper](https://arxiv.org/abs/2103.00020)). 16 | This phenomenon is illustrated by the figure below, with ImageNet accuracy on the x-axis 17 | and [ImageNetV2](https://arxiv.org/abs/1902.10811) (a reproduction of the ImageNet validation set with distribution shift) accuracy on the y-axis. 18 | Standard training denotes training on the ImageNet train set and the CLIP zero-shot models 19 | are shown as stars. 20 | 21 | ![CLIP scatter plot](https://raw.githubusercontent.com/mlfoundations/open_clip/main/docs/effective_robustness.png) 22 | 23 | As observed by [Taori et al., 2020](https://arxiv.org/abs/2007.00644) and [Miller et al., 2021](https://arxiv.org/abs/2107.04649), the in-distribution 24 | and out-of-distribution accuracies of models trained on ImageNet follow a predictable linear trend (the red line in the above plot). *Effective robustness* 25 | quantifies robustness as accuracy beyond this baseline, i.e., how far a model lies above the red line. Ideally, a model would not suffer from distribution shift and would fall on the y = x line ([trained human labelers are within a percentage point of the y = x line](http://proceedings.mlr.press/v119/shankar20c.html)). 26 | 27 | Even though the CLIP models trained with 28 | this codebase achieve much lower accuracy than those trained by OpenAI, our models still lie on the same 29 | trend of improved effective robustness (the purple line). Therefore, we can study what makes 30 | CLIP robust without requiring industrial-scale compute. 31 | 32 | For more information on effective robustness, please see: 33 | 34 | - [Recht et al., 2019](https://arxiv.org/abs/1902.10811). 35 | - [Taori et al., 2020](https://arxiv.org/abs/2007.00644). 36 | - [Miller et al., 2021](https://arxiv.org/abs/2107.04649).
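For readers who want to measure effective robustness for their own checkpoints, the following is a minimal sketch of the computation described above (it is not part of the codebase): fit the linear trend of baseline models in logit-transformed accuracy space, following Taori et al., 2020, and report how far a model's out-of-distribution accuracy sits above that fit. All accuracies in the example are placeholder values for illustration, not measured results.

```python
import numpy as np

def logit(p):
    return np.log(p / (1.0 - p))

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def effective_robustness(baseline_id, baseline_ood, model_id, model_ood):
    """Accuracy (in [0, 1]) above the baseline ID -> OOD trend, with the trend fit in logit space."""
    slope, intercept = np.polyfit(logit(np.asarray(baseline_id)),
                                  logit(np.asarray(baseline_ood)), deg=1)
    predicted_ood = sigmoid(slope * logit(model_id) + intercept)
    return model_ood - predicted_ood

# Placeholder ImageNet / ImageNetV2 accuracies of standard ImageNet-trained baselines.
baseline_id = [0.60, 0.70, 0.76]
baseline_ood = [0.47, 0.58, 0.65]
# A hypothetical zero-shot CLIP model evaluated on the same two test sets.
print(effective_robustness(baseline_id, baseline_ood, model_id=0.63, model_ood=0.56))
```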
37 | 38 | To learn more about the factors that contribute to CLIP's robustness, refer to [Fang et al., 2022](https://arxiv.org/abs/2205.01397). -------------------------------------------------------------------------------- /docs/clip_conceptual_captions.md: -------------------------------------------------------------------------------- 1 | ## Additional training curves for CLIP on Conceptual Captions 2 | 3 | ## Zero-shot accuracy 4 | ![](/docs/clip_zeroshot.png) 5 | 6 | ## Training loss curve 7 | ![](/docs/clip_loss.png) 8 | 9 | ## Validation loss curve 10 | ![](/docs/clip_val_loss.png) 11 | 12 | ## Validation recall 13 | ![](/docs/clip_recall.png) -------------------------------------------------------------------------------- /docs/clip_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlfoundations/open_clip/131f46f6fc1c5caebdc6e7d6c487d33de5b85d2a/docs/clip_loss.png -------------------------------------------------------------------------------- /docs/clip_recall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlfoundations/open_clip/131f46f6fc1c5caebdc6e7d6c487d33de5b85d2a/docs/clip_recall.png -------------------------------------------------------------------------------- /docs/clip_val_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlfoundations/open_clip/131f46f6fc1c5caebdc6e7d6c487d33de5b85d2a/docs/clip_val_loss.png -------------------------------------------------------------------------------- /docs/clip_zeroshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlfoundations/open_clip/131f46f6fc1c5caebdc6e7d6c487d33de5b85d2a/docs/clip_zeroshot.png -------------------------------------------------------------------------------- /docs/clipa.md: -------------------------------------------------------------------------------- 1 | ## CLIPA 2 | 3 | In this work, we present a surprising finding that there exists an _inverse_ scaling law for CLIP training, 4 | whereby the larger the image/text encoders used, the shorter the sequence length of image/text tokens that can be applied in training. 5 | Moreover, we showcase that the strategy for reducing image/text token length plays a crucial role in determining the quality of this scaling law. 6 | 7 | ![](/docs/inverse_scaling_law.png) 8 | 9 | As a result of this finding, we are able to successfully train CLIP even using academic resources. 10 | For example, on an A100 eight-GPU server, our CLIP models achieve zero-shot top-1 ImageNet accuracies of **63.2%** in about **2 days**, 11 | **67.8%** in about **3 days**, and **69.3%** in about **4 days**. 12 | 13 | Moreover, we find that CLIPA at scale leads to state-of-the-art performance. For example, our CLIPA-v2 H/14 achieves a zero-shot top-1 ImageNet accuracy of **81.8%**, 14 | with a budget of less than **$15,000**. 15 | 16 | ![](/docs/clipa_acc_compute.png) 17 | 18 | For more details, please see our paper [An Inverse Scaling Law for CLIP Training](https://arxiv.org/abs/2305.07017) and 19 | [CLIPA-v2: Scaling CLIP Training with 81.1% Zero-shot ImageNet Accuracy within a $10,000 Budget; An Extra $4,000 Unlocks 81.8% Accuracy](https://arxiv.org/abs/2306.15658). 20 | 21 | 22 | Eight token length reduction strategies are investigated in this work, detailed as follows.
23 | 24 | 25 | ## Image token length reduction 26 | 27 | ![](/docs/clipa_reduce_image_token.png) 28 | 29 | * `resize`: Use `--force-image-size` to specify the image size you want to adopt. We find this strategy generally works the best as it retains full image information. 30 | 31 | * `random mask`: Randomly mask out image patches. Use `--force-patch-dropout` to specify the mask ratio you want to adopt. 32 | 33 | * `grid mask`: Preserve one patch in each 2 × 2 grid window. We do not provide an implementation for grid masking, as it is only experimental and we generally find resizing works better. 34 | 35 | * `block mask`: Keep a single block and remove other patches. We do not provide an implementation for block masking, as it is only experimental and we generally find resizing works better. 36 | 37 | 38 | ## Text token length reduction 39 | 40 | * `syntax mask`: Assign different masking priorities to parts of speech. Specify `"text_mask": syntax` in `"tokenizer_kwargs"` in `"text_cfg"` of the model config `json` file to use it. 41 | Specifically, we prioritize retaining nouns, followed by adjectives, and then other words. 42 | We find this strategy generally works the best as it retains critical information for contrastive learning. 43 | 44 | * `truncate`: Truncation selects the first N text tokens and discards the rest. This is the default setting of `open_clip`. 45 | 46 | * `random mask`: Randomly drops a portion of the text tokens. Specify `"text_mask": random` in `"tokenizer_kwargs"` in `"text_cfg"` of the model config `json` file to use it. 47 | 48 | * `block mask`: Randomly preserves consecutive text sequences. Specify `"text_mask": block` in `"tokenizer_kwargs"` in `"text_cfg"` of the model config `json` file to use it. 49 | 50 | 51 | ## Installation 52 | 53 | The installation is essentially the same as for `open_clip`, except for the use of the Natural Language Toolkit (NLTK) in the `syntax mask` text token length reduction strategy. 54 | Please follow the [official doc](https://www.nltk.org/) to install NLTK. 55 | 56 | Note that the usage of NLTK brings two constraints: 57 | * Because certain functions like `nltk.pos_tag` from NLTK only support English and Russian for now, the `syntax mask` only works for English. 58 | We have not tested it on Russian or any other language. Theoretically, it should work the same, given a proper language processing toolkit for other languages. 59 | If you still want to apply `syntax mask` to other languages, try finding the right toolkit. Otherwise, use other text token length reduction strategies. 60 | * Some modules of NLTK, like `punkt` or `averaged_perceptron_tagger`, need to be downloaded before using NLTK. 61 | We have included the downloading code in `tokenizer.py`, but this might cause trouble in certain cases. 62 | You may want to manually download those modules first, via `nltk.download('punkt')` and `nltk.download('averaged_perceptron_tagger')`, 63 | and then set the environment variable before running the script: `export NLTK_DATA=cache`. 64 | Note that this is a one-time effort. Remember to comment out those `nltk.download` lines in `tokenizer.py` afterwards. 65 | 66 | ## Training 67 | We provide example scripts to reproduce our CLIPA results on an A100 eight-GPU machine under path `docs/script_examples/clipa`.
68 | 69 | For instance, to reproduce the CLIPA-L16(I37,T8) results, first run the pre-training script 70 | ``` 71 | bash docs/script_examples/clipa/vit_l16/i37_t8_pretrain.sh 72 | ``` 73 | and fine-tune the pre-trained checkpoint with 74 | ``` 75 | bash docs/script_examples/clipa/vit_l16/i37_t8_finetune.sh 76 | ``` 77 | - Remember to change the path to dataset to your own path. 78 | - This is a two-stage training pipeline. Remember to change the path to pre-trained checkpoint to your own when fine-tuning. 79 | - The training time is ~3 days for pre-training and ~1 day for fine-tuning on an A100 eight-GPU machine. 80 | 81 | ## Model Weights 82 | Below are CLIPA trained weights on LAION-400M with an A100 eight-GPU machine. 83 | All models are pre-trained for 6 epochs with reduced input token lengths and subsequently fine-tuned for 0.36 epoch with full input token lengths. 84 | 85 | 86 | | | Pre-trained Weights | zero-shot IN-1K | 87 | |---------------------|:----------------------------------------------------------------------------------------------:|:---------------:| 88 | | CLIPA-B/16(I50,T16) | [download](https://drive.google.com/file/d/1MDpz8gV2Vjaazk16rBhLxU8811U7_cGL/view?usp=sharing) | 59.7 | 89 | | CLIPA-L/16(I17,T16) | [download](https://drive.google.com/file/d/1Tr2GYiKAaMH6EGIn5l7eX_1K20eaA3WA/view?usp=sharing) | 60.3 | 90 | | CLIPA_L/16(I37,T8) | [download](https://drive.google.com/file/d/1EM1ChRNARpLckkJjf6m7njCY3xyvpGBu/view?usp=sharing) | 57.9 | 91 | 92 | | | Fine-tuned Weights | zero-shot IN-1K | 93 | |---------------------|:----------------------------------------------------------------------------------------------:|:-----:| 94 | | CLIPA-B/16(I50,T16) | [download](https://drive.google.com/file/d/1fURK0K_a3-83jVEI4PVEbnEJb_V6UbGv/view?usp=sharing) | 63.2 | 95 | | CLIPA-L/16(I17,T16) | [download](https://drive.google.com/file/d/18qqZGOTGOgb3I3JWONuat6qObsgLq7sR/view?usp=sharing) | 67.8 | 96 | | CLIPA_L/16(I37,T8) | [download](https://drive.google.com/file/d/1lV7pLORUK04T9QKKx9TpYtMws-AZrib0/view?usp=sharing) | 69.3 | 97 | 98 | 99 | ## CLIPA-v2 100 | We also provide example scripts to reproduce our CLIPA-v2 H/14 results under path `docs/script_examples/clipav2`. 101 | Note that the original results are obtained with [our JAX implementation](https://github.com/UCSC-VLAA/CLIPA/tree/master/clipa_jax). 102 | These scripts are written after manually scanning the JAX config files. 103 | As it is infeasible for us to retrain those models again with pytorch, its correctness cannot be verified with 100% confidence. Use them at your own discretion. 
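After downloading one of the checkpoints above, a quick way to sanity-check the weights is to load them through the regular `open_clip` API and run a small zero-shot prediction. The snippet below is a minimal sketch rather than an official recipe: the model name and file paths are placeholders, so substitute the config that matches the checkpoint you downloaded (e.g. the name passed to `--model` in the corresponding training script).

```python
import torch
from PIL import Image
import open_clip

# Placeholders: use the model config matching your downloaded CLIPA checkpoint
# (the same name passed to --model in the training scripts) and its local path.
model_name = 'ViT-L-14-CLIPA'
model, _, preprocess = open_clip.create_model_and_transforms(
    model_name, pretrained='/path/to/clipa_checkpoint.pt')
tokenizer = open_clip.get_tokenizer(model_name)
model.eval()

image = preprocess(Image.open('example.jpg')).unsqueeze(0)
text = tokenizer(['a diagram', 'a dog', 'a cat'])

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

print(probs)  # similarity of the image to each of the three captions
```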
104 | -------------------------------------------------------------------------------- /docs/clipa_acc_compute.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlfoundations/open_clip/131f46f6fc1c5caebdc6e7d6c487d33de5b85d2a/docs/clipa_acc_compute.png -------------------------------------------------------------------------------- /docs/clipa_reduce_image_token.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlfoundations/open_clip/131f46f6fc1c5caebdc6e7d6c487d33de5b85d2a/docs/clipa_reduce_image_token.png -------------------------------------------------------------------------------- /docs/clipa_reduce_text_token.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlfoundations/open_clip/131f46f6fc1c5caebdc6e7d6c487d33de5b85d2a/docs/clipa_reduce_text_token.png -------------------------------------------------------------------------------- /docs/datacomp_models.md: -------------------------------------------------------------------------------- 1 | ## CommonPool and DataComp models 2 | 3 | As part of [DataComp](https://github.com/mlfoundations/datacomp), we trained models on CommonPool using various data filtering strategies. 4 | We release models for all four scales of the competition (small, medium, large and xlarge), corresponding to a pool size and number of samples seen of 12.8M, 128M, 1.28B and 12.8B, respectively. 5 | 6 | The models are specified below; see our paper [DataComp: In search of the next generation of multimodal datasets](https://arxiv.org/abs/2304.14108) for more details. 7 | 8 | 9 | ## xlarge scale models 10 | 11 | * `datacomp_xl_s13b_b90k`: A ViT-L/14 trained on DataComp-1B for 12.8B steps and batch size 90k. Achieves 79.2% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K. 12 | 13 | * `commonpool_xl_clip_s13b_b90k`: A ViT-L/14 trained on CommonPool-XL filtered using CLIP scores, for 12.8B steps and batch size 90k. Achieves 76.4% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-L-14-CommonPool.XL.clip-s13B-b90K. 14 | 15 | * `commonpool_xl_laion_s13b_b90k`: A ViT-L/14 trained on CommonPool-XL filtered using the LAION-2B filtering scheme, for 12.8B steps and batch size 90k. Achieves 75.5% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-L-14-CommonPool.XL.laion-s13B-b90K. 16 | 17 | * `commonpool_xl_s13b_b90k`: A ViT-L/14 trained on CommonPool-XL without any filtering, for 12.8B steps and batch size 90k. Achieves 72.3% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-L-14-CommonPool.XL-s13B-b90K. 18 | 19 | 20 | ## large scale models 21 | 22 | * `datacomp_l_s1b_b8k`: A ViT-B/16 trained on a 140M subset of DataComp-1B, for 1.28B steps and batch size 8k. Achieves 63.1% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-16-DataComp.L-s1B-b8K. 23 | 24 | * `commonpool_l_clip_s1b_b8k`: A ViT-B/16 trained on CommonPool-L filtered using CLIP scores, for 1.28B steps and batch size 8k. Achieves 57.8% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-16-CommonPool.L.clip-s1B-b8K. 25 | 26 | * `commonpool_l_laion_s1b_b8k`: A ViT-B/16 trained on CommonPool-L filtered using the LAION-2B filtering scheme, for 1.28B steps and batch size 8k.
Achieves 55.3% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-16-CommonPool.L.laion-s1B-b8K. 27 | 28 | * `commonpool_l_image_s1b_b8k`: A ViT-B/16 trained on CommonPool-L filtered using image-based filtering, for 1.28B steps and batch size 8k. Achieves 57.2% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-16-CommonPool.L.image-s1B-b8K. 29 | 30 | * `commonpool_l_text_s1b_b8k`: A ViT-B/16 trained on CommonPool-L filtered using text-based filtering, for 1.28B steps and batch size 8k. Achieves 56.1% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-16-CommonPool.L.text-s1B-b8K. 31 | 32 | * `commonpool_l_basic_s1b_b8k`: A ViT-B/16 trained on CommonPool-L filtered using basic filtering (English filtering + caption length and image size filtering), for 1.28B steps and batch size 8k. Achieves 51.6% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-16-CommonPool.L.basic-s1B-b8K. 33 | 34 | * `commonpool_l_s1b_b8k`: A ViT-B/16 trained on CommonPool-L without any filtering, for 1.28B steps and batch size 8k. Achieves 45.9% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-16-CommonPool.L-s1B-b8K. 35 | 36 | 37 | ## medium scale models 38 | 39 | * `datacomp_m_s128m_b4k`: A ViT-B/32 trained on a 14M subset of DataComp-1B, for 128M steps and batch size 4k. Achieves 29.7% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-32-DataComp.M-s128M-b4K. 40 | 41 | * `commonpool_m_clip_s128m_b4k`: A ViT-B/32 trained on CommonPool-M filtered using CLIP scores, for 128M steps and batch size 4k. Achieves 27.3% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-32-CommonPool.M.clip-s128M-b4K. 42 | 43 | * `commonpool_m_laion_s128m_b4k`: A ViT-B/32 trained on CommonPool-M filtered using the LAION-2B filtering scheme, for 128M steps and batch size 4k. Achieves 23.0% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-32-CommonPool.M.laion-s128M-b4K. 44 | 45 | * `commonpool_m_image_s128m_b4k`: A ViT-B/32 trained on CommonPool-M filtered using image-based filtering, for 128M steps and batch size 4k. Achieves 26.8% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-32-CommonPool.M.image-s128M-b4K. 46 | 47 | * `commonpool_m_text_s128m_b4k`: A ViT-B/32 trained on CommonPool-M filtered using text-based filtering, for 128M steps and batch size 4k. Achieves 25.5% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-32-CommonPool.M.text-s128M-b4K. 48 | 49 | * `commonpool_m_basic_s128m_b4k`: A ViT-B/32 trained on CommonPool-M filtered using basic filtering (English filtering + caption length and image size filtering), for 128M steps and batch size 4k. Achieves 22.6% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-32-CommonPool.M.basic-s128M-b4K. 50 | 51 | * `commonpool_m_s128m_b4k`: A ViT-B/32 trained on CommonPool-M without any filtering, for 128M steps and batch size 4k. Achieves 17.6% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-32-CommonPool.M-s128M-b4K. 52 | 53 | 54 | ## small scale models 55 | 56 | * `datacomp_s_s13m_b4k`: A ViT-B/32 trained on a 1.4M subset of DataComp-1B, for 12.8M steps and batch size 4k. Achieves 3.9% zero-shot accuracy on ImageNet. 
Available at https://huggingface.co/laion/CLIP-ViT-B-32-DataComp.S-s13M-b4K. 57 | 58 | * `commonpool_s_clip_s13m_b4k`: A ViT-B/32 trained on CommonPool-S filtered using CLIP scores, for 12.8M steps and batch size 4k. Achieves 5.1% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-32-CommonPool.S.clip-s13M-b4K. 59 | 60 | * `commonpool_s_laion_s13m_b4k`: A ViT-B/32 trained on CommonPool-S filtered using the LAION-2B filtering scheme scores, for 12.8M steps and batch size 4k. Achieves 3.1% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-32-CommonPool.S.laion-s13M-b4K. 61 | 62 | * `commonpool_s_image_s13m_b4k`: A ViT-B/32 trained on CommonPool-S filtered using image-based filtering, for 12.8M steps and batch size 4k. Achieves 4.3% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-32-CommonPool.S.image-s13M-b4K. 63 | 64 | * `commonpool_s_text_s13m_b4k`: A ViT-B/32 trained on CommonPool-S filtered using text-based filtering, for 12.8M steps and batch size 4k. Achieves 4.6% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-32-CommonPool.S.text-s13M-b4K. 65 | 66 | * `commonpool_s_basic_s13m_b4k`: A ViT-B/32 trained on CommonPool-S filtered using basic filtering (English filtering + caption length and image size filtering), for 12.8M steps and batch size 4k. Achieves 3.0% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-32-CommonPool.S.basic-s13M-b4K. 67 | 68 | * `commonpool_s_s13m_b4k`: A ViT-B/32 trained on CommonPool-S without any filtering, for 12.8M steps and batch size 4k. Achieves 2.5% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-32-CommonPool.S-s13M-b4K. 
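The model cards linked above are published in `open_clip`-compatible form, so a checkpoint can typically be loaded directly from the Hugging Face Hub using the `hf-hub:` prefix, with no manual download step. A minimal sketch using the DataComp-1B ViT-L/14 listed above (any other model card in this document can be substituted):

```python
import open_clip

# Load the DataComp-1B ViT-L/14 listed above directly from the Hugging Face Hub.
# Swap in any other model card from this document to compare filtering strategies.
hub_id = 'hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K'
model, _, preprocess = open_clip.create_model_and_transforms(hub_id)
tokenizer = open_clip.get_tokenizer(hub_id)
model.eval()
```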
69 | 70 | -------------------------------------------------------------------------------- /docs/effective_robustness.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlfoundations/open_clip/131f46f6fc1c5caebdc6e7d6c487d33de5b85d2a/docs/effective_robustness.png -------------------------------------------------------------------------------- /docs/inverse_scaling_law.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlfoundations/open_clip/131f46f6fc1c5caebdc6e7d6c487d33de5b85d2a/docs/inverse_scaling_law.png -------------------------------------------------------------------------------- /docs/laion2b_clip_zeroshot_b32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlfoundations/open_clip/131f46f6fc1c5caebdc6e7d6c487d33de5b85d2a/docs/laion2b_clip_zeroshot_b32.png -------------------------------------------------------------------------------- /docs/laion_clip_zeroshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlfoundations/open_clip/131f46f6fc1c5caebdc6e7d6c487d33de5b85d2a/docs/laion_clip_zeroshot.png -------------------------------------------------------------------------------- /docs/laion_clip_zeroshot_b16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlfoundations/open_clip/131f46f6fc1c5caebdc6e7d6c487d33de5b85d2a/docs/laion_clip_zeroshot_b16.png -------------------------------------------------------------------------------- /docs/laion_clip_zeroshot_b16_plus_240.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlfoundations/open_clip/131f46f6fc1c5caebdc6e7d6c487d33de5b85d2a/docs/laion_clip_zeroshot_b16_plus_240.png -------------------------------------------------------------------------------- /docs/laion_clip_zeroshot_l14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlfoundations/open_clip/131f46f6fc1c5caebdc6e7d6c487d33de5b85d2a/docs/laion_clip_zeroshot_l14.png -------------------------------------------------------------------------------- /docs/laion_openai_compare_b32.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlfoundations/open_clip/131f46f6fc1c5caebdc6e7d6c487d33de5b85d2a/docs/laion_openai_compare_b32.jpg -------------------------------------------------------------------------------- /docs/model_profile.csv: -------------------------------------------------------------------------------- 1 | model,image_size,image_width,text_width,embed_dim,mparams,image_mparams,text_mparams,gflops,image_gflops,text_gflops 2 | ViT-S-32-alt,224,384,256,256,43.22,22.59,20.63,3.56,2.29,1.27 3 | ViT-S-32,224,384,384,384,63.09,22.64,40.44,5.66,2.29,3.38 4 | ViT-M-32-alt,224,512,384,384,80.07,39.63,40.44,7.37,3.99,3.38 5 | ViT-M-32,224,512,512,512,103.12,39.69,63.43,9.95,3.99,5.96 6 | ViT-S-16-alt,224,384,256,256,42.4,21.76,20.63,10.47,9.2,1.27 7 | ViT-S-16,224,384,384,384,62.26,21.81,40.44,12.58,9.2,3.38 8 | ViT-B-32,224,768,512,512,151.28,87.85,63.43,14.78,8.82,5.96 9 | ViT-B-32-quickgelu,224,768,512,512,151.28,87.85,63.43,14.78,8.82,5.96 10 | convnext_tiny,224,768,512,1024,92.3,28.61,63.69,14.87,8.91,5.96 11 | 
ViT-B-32-256,256,768,512,512,151.29,87.86,63.43,17.46,11.5,5.96 12 | RN50,224,64,512,1024,102.01,38.32,63.69,18.18,12.22,5.96 13 | RN50-quickgelu,224,64,512,1024,102.01,38.32,63.69,18.18,12.22,5.96 14 | ViT-M-16-alt,224,512,384,384,78.98,38.53,40.44,19.36,15.98,3.38 15 | ViT-M-16,224,512,512,512,102.02,38.59,63.43,21.94,15.98,5.96 16 | vit_relpos_medium_patch16_cls_224,224,768,512,512,101.94,38.51,63.43,21.99,16.03,5.96 17 | mt5-base-ViT-B-32,224,768,512,512,365.71,87.85,277.86,22.12,8.82,13.3 18 | convnext_small,224,768,512,512,113.28,49.85,63.43,23.33,17.37,5.96 19 | ViT-B-32-plus-256,256,896,640,640,210.3,119.13,91.16,24.83,15.56,9.27 20 | RN101,224,64,512,512,119.69,56.26,63.43,25.5,19.54,5.96 21 | RN101-quickgelu,224,64,512,512,119.69,56.26,63.43,25.5,19.54,5.96 22 | vit_medium_patch16_gap_256,256,768,512,512,102.04,38.61,63.43,27.1,21.14,5.96 23 | coca_ViT-B-32,224,768,512,512,253.56,89.16,63.43,33.34,9.19,5.96 24 | convnext_base,224,768,512,512,151.52,88.09,63.43,36.67,30.71,5.96 25 | swin_base_patch4_window7_224,224,768,640,640,178.56,87.4,91.16,40.13,30.86,9.27 26 | ViT-B-16,224,768,512,512,149.62,86.19,63.43,41.09,35.13,5.96 27 | ViT-B-16-quickgelu,224,768,512,512,149.62,86.19,63.43,41.09,35.13,5.96 28 | EVA02-B-16,224,768,512,512,149.69,86.26,63.43,41.09,35.13,5.96 29 | ViT-B-16-SigLIP,224,768,768,768,203.16,92.88,110.27,46.44,35.42,11.02 30 | convnext_base_w,256,768,640,640,179.39,88.22,91.16,49.38,40.11,9.27 31 | RN50x4,288,80,640,640,178.3,87.14,91.16,51.82,42.56,9.27 32 | coca_roberta-ViT-B-32,224,768,768,512,420.37,87.85,124.45,53.12,8.82,13.12 33 | ViT-B-16-plus,224,896,640,640,208.35,117.19,91.16,56.75,47.49,9.27 34 | ViT-B-16-SigLIP-256,256,768,768,768,203.2,92.93,110.27,57.84,46.82,11.02 35 | ViT-B-16-SigLIP-i18n-256,256,768,768,768,370.63,92.93,277.7,57.84,46.82,11.02 36 | ViT-B-16-plus-240,240,896,640,640,208.38,117.21,91.16,64.03,54.76,9.27 37 | convnext_base_w_320,320,768,640,640,179.39,88.22,91.16,71.94,62.67,9.27 38 | convnext_large,224,768,768,768,321.06,197.41,123.65,82.02,68.72,13.3 39 | coca_base,288,768,768,512,440.34,86.4,134.66,99.09,46.47,13.3 40 | roberta-ViT-B-32,224,768,512,512,212.72,87.85,124.87,105.87,8.82,97.05 41 | xlm-roberta-base-ViT-B-32,224,768,512,512,366.12,87.85,278.27,105.87,8.82,97.05 42 | convnext_large_d,256,768,768,768,351.77,199.77,152.0,107.5,89.76,17.73 43 | ViT-B-16-SigLIP-384,384,768,768,768,203.45,93.18,110.27,123.15,112.13,11.02 44 | ViT-L-16,224,1024,768,768,427.74,304.09,123.65,136.41,123.11,13.3 45 | convnext_large_d_320,320,768,768,768,351.77,199.77,152.0,157.98,140.25,17.73 46 | RN50x16,384,96,768,768,290.98,167.33,123.65,162.69,149.39,13.3 47 | ViT-L-14-CLIPA,224,1024,768,768,414.21,303.96,110.25,167.5,162.03,5.47 48 | EVA02-L-14,224,768,768,768,427.76,304.11,123.65,175.3,162.0,13.3 49 | ViT-L-14,224,1024,768,768,427.62,303.97,123.65,175.33,162.03,13.3 50 | ViT-L-14-quickgelu,224,1024,768,768,427.62,303.97,123.65,175.33,162.03,13.3 51 | convnext_xlarge,256,768,1024,1024,653.89,350.25,303.65,198.38,159.14,39.24 52 | ViT-L-16-SigLIP-256,256,768,1024,1024,652.15,315.96,336.19,201.62,162.56,39.06 53 | coca_ViT-L-14,224,1024,768,768,638.45,306.72,123.65,214.52,163.64,13.3 54 | ViT-B-16-SigLIP-512,512,768,768,768,203.79,93.52,110.27,227.26,216.24,11.02 55 | ViT-SO400M-14-SigLIP,224,768,1152,1152,877.36,427.68,449.68,233.54,220.35,13.19 56 | ViT-L-14-280,280,1024,768,768,427.76,304.11,123.65,271.79,258.49,13.3 57 | ViT-L-16-320,320,1024,768,768,427.95,304.3,123.65,271.93,258.63,13.3 58 | 
ViT-H-16,224,1280,1024,1024,986.26,632.23,354.03,301.72,254.63,47.09 59 | ViT-H-14-CLIPA,224,1280,1024,1024,968.24,632.07,336.16,354.02,334.59,19.43 60 | nllb-clip-base,224,768,512,512,501.89,87.85,414.04,369.6,8.82,360.78 61 | ViT-H-14,224,1280,1024,1024,986.11,632.08,354.03,381.68,334.59,47.09 62 | ViT-H-14-quickgelu,224,1280,1024,1024,986.11,632.08,354.03,381.68,334.59,47.09 63 | ViT-L-14-CLIPA-336,336,1024,768,768,414.54,304.29,110.25,387.39,381.92,5.47 64 | EVA02-L-14-336,336,768,768,768,428.08,304.43,123.65,395.16,381.86,13.3 65 | ViT-L-14-336,336,1024,768,768,427.94,304.29,123.65,395.22,381.92,13.3 66 | ViT-L-16-SigLIP-384,384,768,1024,1024,652.48,316.28,336.19,422.91,383.85,39.06 67 | convnext_xxlarge,256,768,1024,1024,1200.58,846.54,354.03,443.03,395.94,47.09 68 | nllb-clip-base-siglip,384,768,512,768,507.47,93.18,414.3,472.91,112.13,360.78 69 | mt5-xl-ViT-H-14,224,1280,512,1024,2306.75,632.08,1674.68,514.04,334.59,179.45 70 | EVA01-g-14,224,768,768,1024,1136.44,1012.59,123.85,547.36,534.06,13.3 71 | RN50x64,448,128,1024,1024,623.26,420.38,202.88,552.65,529.11,23.55 72 | EVA01-g-14-plus,224,768,1024,1024,1366.62,1012.59,354.03,581.15,534.06,47.09 73 | ViT-g-14,224,1408,1024,1024,1366.68,1012.65,354.03,581.15,534.06,47.09 74 | convnext_xxlarge_320,320,768,1024,1024,1200.58,846.54,354.03,665.74,618.65,47.09 75 | xlm-roberta-large-ViT-H-14,224,1280,512,1024,1193.01,632.08,560.94,671.01,334.59,336.42 76 | ViT-SO400M-14-SigLIP-384,384,768,1152,1152,877.96,428.23,449.73,723.48,670.35,53.13 77 | ViT-H-14-CLIPA-336,336,1280,1024,1024,968.64,632.48,336.16,800.88,781.45,19.43 78 | ViT-bigG-14-CLIPA,224,1664,1280,1280,2517.22,1844.9,672.32,1007.93,967.5,40.44 79 | ViT-H-14-378-quickgelu,378,1280,1024,1024,986.71,632.68,354.03,1054.05,1006.96,47.09 80 | ViT-bigG-14,224,1664,1280,1280,2539.57,1844.91,694.66,1065.36,967.5,97.86 81 | nllb-clip-large,224,1280,512,1024,1399.22,632.08,767.14,1468.46,334.59,1133.87 82 | nllb-clip-large-siglip,384,768,512,1152,1195.5,428.23,767.27,1804.22,670.35,1133.87 83 | ViT-e-14,224,1792,1280,1280,4581.09,3807.72,773.37,2091.45,1981.35,110.1 84 | ViT-bigG-14-CLIPA-336,336,1664,1280,1280,2517.76,1845.44,672.32,2271.58,2231.15,40.44 85 | EVA02-E-14,224,768,1024,1024,4704.59,4350.56,354.03,2311.42,2264.33,47.09 86 | EVA02-E-14-plus,224,768,1280,1024,5044.89,4350.56,694.33,2362.19,2264.33,97.86 87 | -------------------------------------------------------------------------------- /docs/scaling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlfoundations/open_clip/131f46f6fc1c5caebdc6e7d6c487d33de5b85d2a/docs/scaling.png -------------------------------------------------------------------------------- /docs/script_examples/clipa/vit_b16/i50_t16_finetune.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node 8 -m open_clip_train.main \ 2 | --save-frequency 1 \ 3 | --save-most-recent \ 4 | --zeroshot-frequency 1 \ 5 | --train-data '/path/to/laion-400m' \ 6 | --dataset-type webdataset \ 7 | --lr "2.56e-5" \ 8 | --beta1 0.9 \ 9 | --beta2 0.95 \ 10 | --warmup 3072 \ 11 | --wd 0.2 \ 12 | --batch-size 1024 \ 13 | --aug-cfg scale='(0.4, 1.0)' \ 14 | --epochs 1 \ 15 | --train-num-samples 131072000 \ 16 | --workers 6 \ 17 | --model ViT-B-16-CL16 \ 18 | --pretrained '/path/to/ckpt' \ 19 | --precision 'amp_bf16' \ 20 | --ddp-static-graph \ 21 | --local-loss \ 22 | --gather-with-grad \ 23 | --grad-checkpointing \ 24 | --log-every-n-steps 256 \ 25 | 
--seed 0 \ 26 | --logs ./logs/ \ 27 | --imagenet-val '/path/to/imagenet/val' 28 | -------------------------------------------------------------------------------- /docs/script_examples/clipa/vit_b16/i50_t16_pretrain.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node 8 -m open_clip_train.main \ 2 | --save-frequency 1 \ 3 | --save-most-recent \ 4 | --zeroshot-frequency 1 \ 5 | --train-data '/path/to/laion-400m' \ 6 | --dataset-type webdataset \ 7 | --lr "2.048e-3" \ 8 | --beta1 0.9 \ 9 | --beta2 0.95 \ 10 | --warmup 782 \ 11 | --wd 0.2 \ 12 | --batch-size 8192 \ 13 | --aug-cfg scale='(0.4, 1.0)' \ 14 | --epochs 6 \ 15 | --workers 6 \ 16 | --model ViT-B-16-CL16 \ 17 | --precision 'amp_bf16' \ 18 | --ddp-static-graph \ 19 | --local-loss \ 20 | --gather-with-grad \ 21 | --force-image-size 112 \ 22 | --grad-checkpointing \ 23 | --log-every-n-steps 32 \ 24 | --seed 0 \ 25 | --logs ./logs/ \ 26 | --imagenet-val '/path/to/imagenet/val' -------------------------------------------------------------------------------- /docs/script_examples/clipa/vit_l16/i17_t16_finetune.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node 8 -m open_clip_train.main \ 2 | --save-frequency 1 \ 3 | --save-most-recent \ 4 | --zeroshot-frequency 1 \ 5 | --train-data '/path/to/laion-400m' \ 6 | --dataset-type webdataset \ 7 | --lr "2.24e-5" \ 8 | --beta1 0.9 \ 9 | --beta2 0.95 \ 10 | --warmup 3571 \ 11 | --wd 0.2 \ 12 | --batch-size 896 \ 13 | --aug-cfg scale='(0.4, 1.0)' color_jitter='(0.32, 0.32, 0.32, 0.08)' color_jitter_prob=0.8 gray_scale_prob=0.2 \ 14 | --epochs 1 \ 15 | --train-num-samples 131072000 \ 16 | --workers 6 \ 17 | --model ViT-L-16-CL16-GAP \ 18 | --pretrained '/path/to/ckpt' \ 19 | --precision 'amp_bf16' \ 20 | --ddp-static-graph \ 21 | --local-loss \ 22 | --gather-with-grad \ 23 | --grad-checkpointing \ 24 | --log-every-n-steps 293 \ 25 | --seed 0 \ 26 | --logs ./logs/ \ 27 | --imagenet-val '/path/to/imagenet/val' -------------------------------------------------------------------------------- /docs/script_examples/clipa/vit_l16/i17_t16_pretrain.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node 8 -m open_clip_train.main \ 2 | --save-frequency 1 \ 3 | --save-most-recent \ 4 | --zeroshot-frequency 1 \ 5 | --train-data '/path/to/laion-400m' \ 6 | --dataset-type webdataset \ 7 | --lr "1.024e-3" \ 8 | --beta1 0.9 \ 9 | --beta2 0.95 \ 10 | --warmup 1563 \ 11 | --wd 0.2 \ 12 | --batch-size 4096 \ 13 | --aug-cfg scale='(0.4, 1.0)' color_jitter='(0.32, 0.32, 0.32, 0.08)' color_jitter_prob=0.8 gray_scale_prob=0.2 \ 14 | --epochs 6 \ 15 | --workers 6 \ 16 | --model ViT-L-16-CL16-GAP \ 17 | --precision 'amp_bf16' \ 18 | --ddp-static-graph \ 19 | --local-loss \ 20 | --gather-with-grad \ 21 | --force-image-size 64 \ 22 | --grad-checkpointing \ 23 | --log-every-n-steps 64 \ 24 | --seed 0 \ 25 | --logs ./logs/ \ 26 | --imagenet-val '/path/to/imagenet/val' -------------------------------------------------------------------------------- /docs/script_examples/clipa/vit_l16/i37_t8_finetune.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node 8 -m open_clip_train.main \ 2 | --save-frequency 1 \ 3 | --save-most-recent \ 4 | --zeroshot-frequency 1 \ 5 | --train-data '/path/to/laion-400m' \ 6 | --dataset-type webdataset \ 7 | --lr "2.24e-5" \ 8 | --beta1 0.9 \ 9 | --beta2 0.95 \ 10 
| --warmup 3571 \ 11 | --wd 0.2 \ 12 | --batch-size 896 \ 13 | --aug-cfg scale='(0.4, 1.0)' color_jitter='(0.32, 0.32, 0.32, 0.08)' color_jitter_prob=0.8 gray_scale_prob=0.2 \ 14 | --epochs 1 \ 15 | --train-num-samples 131072000 \ 16 | --workers 6 \ 17 | --model ViT-L-16-CL32-GAP \ 18 | --pretrained '/path/to/ckpt' \ 19 | --precision 'amp_bf16' \ 20 | --ddp-static-graph \ 21 | --local-loss \ 22 | --gather-with-grad \ 23 | --grad-checkpointing \ 24 | --log-every-n-steps 293 \ 25 | --seed 0 \ 26 | --logs ./logs/ \ 27 | --imagenet-val '/path/to/imagenet/val' -------------------------------------------------------------------------------- /docs/script_examples/clipa/vit_l16/i37_t8_pretrain.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node 8 -m open_clip_train.main \ 2 | --save-frequency 1 \ 3 | --save-most-recent \ 4 | --zeroshot-frequency 1 \ 5 | --train-data '/path/to/laion-400m' \ 6 | --dataset-type webdataset \ 7 | --lr "1.024e-3" \ 8 | --beta1 0.9 \ 9 | --beta2 0.95 \ 10 | --warmup 1563 \ 11 | --wd 0.2 \ 12 | --batch-size 4096 \ 13 | --aug-cfg scale='(0.4, 1.0)' color_jitter='(0.32, 0.32, 0.32, 0.08)' color_jitter_prob=0.8 gray_scale_prob=0.2 \ 14 | --epochs 6 \ 15 | --workers 6 \ 16 | --model ViT-L-16-CL8-Syntax-GAP \ 17 | --precision 'amp_bf16' \ 18 | --ddp-static-graph \ 19 | --local-loss \ 20 | --gather-with-grad \ 21 | --force-image-size 96 \ 22 | --grad-checkpointing \ 23 | --log-every-n-steps 64 \ 24 | --seed 0 \ 25 | --logs ./logs/ \ 26 | --imagenet-val '/path/to/imagenet/val' -------------------------------------------------------------------------------- /docs/script_examples/clipav2/vit_h14/i257_t32_finetunex4.sh: -------------------------------------------------------------------------------- 1 | # have not been tested. use it at your own discretion 2 | # the original experiment was run on tpu v3-256. 3 | # this example script assumes 8 gpus, each with huge memory. Tune batchsize, warmup, and lr accordingly if you have different machine setups. 4 | torchrun --nproc_per_node 8 -m open_clip_train.main \ 5 | --save-frequency 1 \ 6 | --save-most-recent \ 7 | --zeroshot-frequency 1 \ 8 | --train-data '/path/to/laion2b_or_datacomp1b' \ 9 | --train-num-samples 131072000 \ 10 | --dataset-type webdataset \ 11 | --lr "5.12e-5" \ 12 | --beta1 0.9 \ 13 | --beta2 0.95 \ 14 | --warmup 800 \ 15 | --wd 0.2 \ 16 | --batch-size 4096 \ 17 | --aug-cfg scale='(0.4, 1.0)' color_jitter='(0.32, 0.32, 0.32, 0.08)' color_jitter_prob=0.8 gray_scale_prob=0.2 \ 18 | --epochs 4 \ 19 | --workers 6 \ 20 | --model ViT-H-14-CL32-GAP \ 21 | --pretrained '/path/to/pretrain84_ckpt' \ 22 | --precision 'amp_bf16' \ 23 | --ddp-static-graph \ 24 | --local-loss \ 25 | --gather-with-grad \ 26 | --force-image-size 224 \ 27 | --force-patch-dropout 0.3 \ 28 | --grad-checkpointing \ 29 | --log-every-n-steps 64 \ 30 | --seed 0 \ 31 | --logs ./logs/ \ 32 | --imagenet-val '/path/to/imagenet/val' -------------------------------------------------------------------------------- /docs/script_examples/clipav2/vit_h14/i50_t8_pretrain.sh: -------------------------------------------------------------------------------- 1 | # have not been tested. use it at your own discretion 2 | # the original experiment was run on tpu v3-256. 3 | # this example script assumes 8 gpus, each with huge memory. Tune batchsize, warmup, and lr accordingly if you have different machine setups. 
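# note on the defaults below: with --nproc_per_node 8 and --batch-size 8192 (per GPU in open_clip_train),
# the effective global batch size is 8 x 8192 = 65,536 (~64k), which is the scale the 2.048e-3 peak lr is tuned for.
# --force-image-size 84 shrinks the ViT-H/14 input to 6x6 = 36 image tokens, the CLIPA-style token
# reduction that keeps this pretraining stage cheap.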
4 | torchrun --nproc_per_node 8 -m open_clip_train.main \ 5 | --save-frequency 1 \ 6 | --save-most-recent \ 7 | --zeroshot-frequency 1 \ 8 | --train-data '/path/to/laion2b_or_datacomp1b' \ 9 | --train-num-samples 4e8 \ 10 | --dataset-type webdataset \ 11 | --lr "2.048e-3" \ 12 | --beta1 0.9 \ 13 | --beta2 0.95 \ 14 | --warmup 3200 \ 15 | --wd 0.2 \ 16 | --batch-size 8192 \ 17 | --aug-cfg scale='(0.4, 1.0)' color_jitter='(0.32, 0.32, 0.32, 0.08)' color_jitter_prob=0.8 gray_scale_prob=0.2 \ 18 | --epochs 32 \ 19 | --workers 6 \ 20 | --model ViT-H-14-CL8-Syntax-GAP \ 21 | --precision 'amp_bf16' \ 22 | --ddp-static-graph \ 23 | --local-loss \ 24 | --gather-with-grad \ 25 | --force-image-size 84 \ 26 | --grad-checkpointing \ 27 | --log-every-n-steps 32 \ 28 | --seed 0 \ 29 | --logs ./logs/ \ 30 | --imagenet-val '/path/to/imagenet/val' -------------------------------------------------------------------------------- /docs/script_examples/clipav2/vit_h14/i577_t32_finetunex1.sh: -------------------------------------------------------------------------------- 1 | # have not been tested. use it at your own discretion 2 | # the original experiment was run on tpu v3-256. 3 | # this example script assumes 8 gpus, each with huge memory. Tune batchsize, warmup, and lr accordingly if you have different machine setups. 4 | torchrun --nproc_per_node 8 -m open_clip_train.main \ 5 | --save-frequency 1 \ 6 | --save-most-recent \ 7 | --zeroshot-frequency 1 \ 8 | --train-data '/path/to/laion2b_or_datacomp1b' \ 9 | --train-num-samples 131072000 \ 10 | --dataset-type webdataset \ 11 | --lr "6.4e-6" \ 12 | --beta1 0.9 \ 13 | --beta2 0.95 \ 14 | --warmup 1600 \ 15 | --wd 0.2 \ 16 | --batch-size 2048 \ 17 | --aug-cfg scale='(0.4, 1.0)' color_jitter='(0.32, 0.32, 0.32, 0.08)' color_jitter_prob=0.8 gray_scale_prob=0.2 \ 18 | --epochs 1 \ 19 | --workers 6 \ 20 | --model ViT-H-14-CL32-GAP \ 21 | --pretrained '/path/to/finetune224_ckpt' \ 22 | --precision 'amp_bf16' \ 23 | --ddp-static-graph \ 24 | --local-loss \ 25 | --gather-with-grad \ 26 | --force-image-size 336 \ 27 | --force-patch-dropout 0.4 \ 28 | --grad-checkpointing \ 29 | --log-every-n-steps 64 \ 30 | --seed 0 \ 31 | --logs ./logs/ \ 32 | --imagenet-val '/path/to/imagenet/val' -------------------------------------------------------------------------------- /docs/script_examples/stability_example.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=g40423 3 | #SBATCH --job-name=testopenclip 4 | #SBATCH --nodes 30 5 | #SBATCH --ntasks-per-node=8 6 | #SBATCH --cpus-per-task=12 7 | #SBATCH --output=%x_%j.out 8 | #SBATCH --comment=laion 9 | #SBATCH --open-mode=append 10 | #SBATCH --exclusive 11 | 12 | module load openmpi 13 | module load cuda/11.7 14 | 15 | export MASTER_ADDR=`hostname` 16 | export MASTER_PORT=12802 17 | export NCCL_PROTO=simple 18 | export FI_EFA_FORK_SAFE=1 19 | export FI_LOG_LEVEL=1 20 | export FI_EFA_USE_DEVICE_RDMA=1 21 | export NCCL_DEBUG=info 22 | 23 | export PYTHONFAULTHANDLER=1 24 | 25 | export CUDA_LAUNCH_BLOCKING=0 26 | export OMPI_MCA_mtl_base_verbose=1 27 | export FI_EFA_ENABLE_SHM_TRANSFER=0 28 | export FI_PROVIDER=efa 29 | export FI_EFA_TX_MIN_CREDITS=64 30 | export NCCL_TREE_THRESHOLD=0 31 | 32 | cd /admin/home-mitchellw/open_clip/src 33 | export PYTHONPATH="$PYTHONPATH:/admin/home-mitchellw/open_clip/src" 34 | 35 | EXP_NAME="test-B-32-laion5b-lr1e-3-bs90k" 36 | 37 | srun --comment laion --cpu_bind=v --accel-bind=gn python -m open_clip_train.main \ 38 | 
--save-frequency 1 \ 39 | --train-data="pipe:aws s3 cp s3://s-datasets/laion5b/{laion2B-data/{000000..231349}.tar,laion2B-multi-data/{000000..226687}.tar,laion1B-nolang-data/{000000..127231}.tar} -" \ 40 | --train-num-samples 135646078 \ 41 | --dataset-type webdataset \ 42 | --dataset-resampled \ 43 | --warmup 2000 \ 44 | --batch-size=375 \ 45 | --epochs=97 \ 46 | --lr 1e-3 \ 47 | --workers=8 \ 48 | --report-to wandb \ 49 | --name ${EXP_NAME} \ 50 | --logs /scratch/logs/ \ 51 | --model ViT-B-32 \ 52 | --seed 0 \ 53 | --ddp-static-graph \ 54 | --local-loss \ 55 | --gather-with-grad \ 56 | --grad-checkpointing \ 57 | --precision amp_bfloat16 \ 58 | --wandb-project-name open_clip6 \ 59 | --resume "latest" \ 60 | --remote-sync s3://s-laion/mitchellw/logs 61 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["pdm-backend"] 3 | build-backend = "pdm.backend" 4 | 5 | [project] 6 | name = "open_clip_torch" 7 | # NOTE for full list of authors see https://github.com/mlfoundations/open_clip?tab=readme-ov-file#citing 8 | # below covers most active / recent maintainers 9 | authors = [ 10 | {name = "Ross Wightman", email = "ross@huggingface.co"}, 11 | {name = "Gabriel Ilharco"}, 12 | {name = "Mitchell Wortsman"}, 13 | {name = "Romain Beaumont"}, 14 | ] 15 | description = "Open reproduction of contrastive language-image pretraining (CLIP) and related models." 16 | readme = "README.md" 17 | requires-python = ">=3.8" 18 | keywords = ["pytorch", "clip", "image-text", "language-image", "multimodal"] 19 | license = {text = "MIT"} 20 | classifiers = [ 21 | 'Development Status :: 4 - Beta', 22 | 'Intended Audience :: Education', 23 | 'Intended Audience :: Science/Research', 24 | 'License :: OSI Approved :: MIT License', 25 | 'Programming Language :: Python :: 3.8', 26 | 'Programming Language :: Python :: 3.9', 27 | 'Programming Language :: Python :: 3.10', 28 | 'Programming Language :: Python :: 3.11', 29 | 'Programming Language :: Python :: 3.12', 30 | 'Topic :: Scientific/Engineering', 31 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 32 | 'Topic :: Software Development', 33 | 'Topic :: Software Development :: Libraries', 34 | 'Topic :: Software Development :: Libraries :: Python Modules', 35 | ] 36 | dependencies = [ 37 | 'torch>=1.9.0', 38 | 'torchvision', 39 | 'regex', 40 | 'ftfy', 41 | 'tqdm', 42 | 'huggingface-hub', 43 | 'safetensors', 44 | 'timm', 45 | ] 46 | dynamic = ["version"] 47 | 48 | [project.optional-dependencies] 49 | training = [ 50 | 'torch>=2.0', 51 | 'webdataset>=0.2.5,<=0.2.86', 52 | 'pandas', 53 | 'transformers[sentencepiece]', 54 | 'timm>=1.0.10', 55 | 'fsspec', 56 | ] 57 | test = [ 58 | 'pytest-split', 59 | 'pytest', 60 | 'open_clip_torch[training]' 61 | ] 62 | 63 | [project.urls] 64 | homepage = "https://github.com/mlfoundations/open_clip" 65 | repository = "https://github.com/mlfoundations/open_clip" 66 | 67 | [tool.pdm.version] 68 | source = "file" 69 | path = "src/open_clip/version.py" 70 | 71 | [tool.pdm.build] 72 | excludes = ["./**/.git", "./**/logs/*"] 73 | package-dir = "src" 74 | includes = ["src/open_clip", "src/open_clip_train"] 75 | 76 | [tool.pytest.ini_options] 77 | testpaths = ['tests'] 78 | markers = [ 79 | 'regression_test' 80 | ] -------------------------------------------------------------------------------- /pytest.ini:
-------------------------------------------------------------------------------- 1 | [pytest] 2 | markers = 3 | regression_test 4 | -------------------------------------------------------------------------------- /requirements-test.txt: -------------------------------------------------------------------------------- 1 | pytest-split==0.8.0 2 | pytest==7.2.0 3 | transformers[sentencepiece] 4 | timm>=1.0.10 5 | -------------------------------------------------------------------------------- /requirements-training.txt: -------------------------------------------------------------------------------- 1 | torch>=1.9.0 2 | torchvision 3 | webdataset>=0.2.5,<=0.2.86 4 | regex 5 | ftfy 6 | tqdm 7 | pandas 8 | braceexpand 9 | huggingface_hub 10 | safetensors 11 | transformers[sentencepiece] 12 | timm>=1.0.15 13 | fsspec 14 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.9.0 2 | torchvision 3 | regex 4 | ftfy 5 | tqdm 6 | huggingface_hub 7 | safetensors 8 | timm 9 | -------------------------------------------------------------------------------- /scripts/clipav1_vit_l16_i37_t8.sh: -------------------------------------------------------------------------------- 1 | # eval on a single gpu 2 | CUDA_VISIBLE_DEVICES=2 TORCH_CUDNN_V8_API_ENABLED=1 TFDS_PREFETCH_SIZE=8192 python3 -m open_clip_train.main \ 3 | --model ViT-L-16-CL32-GAP \ 4 | --pretrained "/path/to/clipa_vit_l16_i37_t8.pt" \ 5 | --seed 0 \ 6 | --imagenet-val '/path/to/ImageNet/val' -------------------------------------------------------------------------------- /scripts/clipav2_vit_h14_i84_224_336_cl32_gap_datacomp1b.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=1 python3 -m open_clip_train.main \ 2 | --model ViT-H-14-CL32-GAP-BigVision \ 3 | --pretrained "/path/to/vit_h14_i84_224_336_cl32_gap_datacomp1b.pt" \ 4 | --force-image-size 336 \ 5 | --square-resize-only \ 6 | --interpolation 'bilinear' \ 7 | --image-mean 0.485 0.456 0.406 \ 8 | --image-std 0.229 0.224 0.225 \ 9 | --seed 0 \ 10 | --imagenet-val '/path/to/ImageNet/val' 11 | -------------------------------------------------------------------------------- /scripts/h14_224_32_finetune.sh: -------------------------------------------------------------------------------- 1 | # 64k batchsize for 2.048e-3 lr 2 | TORCH_CUDNN_V8_API_ENABLED=1 torchrun --nproc_per_node 8 -m open_clip_train.main \ 3 | --save-frequency 1 \ 4 | --save-most-recent \ 5 | --zeroshot-frequency 1 \ 6 | --train-data '/path/to/laion' \ 7 | --dataset-type webdataset \ 8 | --lr "2.048e-3" \ 9 | --beta1 0.9 \ 10 | --beta2 0.95 \ 11 | --warmup 782 \ 12 | --wd 0.2 \ 13 | --batch-size 4096 \ 14 | --aug-cfg scale='(0.4, 1.0)' color_jitter='(0.32, 0.32, 0.32, 0.08)' color_jitter_prob=0.8 gray_scale_prob=0.2 \ 15 | --epochs=7 \ 16 | --workers=6 \ 17 | --model ViT-H-14-CL32-GAP \ 18 | --precision 'amp_bf16' \ 19 | --local-loss \ 20 | --gather-with-grad \ 21 | --force-image-size 224 \ 22 | --grad-checkpointing \ 23 | --log-every-n-steps 32 \ 24 | --seed 0 \ 25 | --logs ./logs/ \ 26 | --imagenet-val '/path/to/ImageNet/val' \ 27 | --name 'name' \ 28 | --report-to "wandb" \ 29 | --wandb-project-name "project_name" 30 | 31 | 32 | -------------------------------------------------------------------------------- /scripts/h14_84_8_pretrain.sh: -------------------------------------------------------------------------------- 1 | # 64k 
batchsize for 2.048e-3 lr 2 | TORCH_CUDNN_V8_API_ENABLED=1 torchrun --nproc_per_node 8 -m open_clip_train.main \ 3 | --save-frequency 1 \ 4 | --save-most-recent \ 5 | --zeroshot-frequency 1 \ 6 | --train-data '/path/to/laion' \ 7 | --dataset-type webdataset \ 8 | --lr "2.048e-3" \ 9 | --beta1 0.9 \ 10 | --beta2 0.95 \ 11 | --warmup 782 \ 12 | --wd 0.2 \ 13 | --batch-size 4096 \ 14 | --aug-cfg scale='(0.4, 1.0)' color_jitter='(0.32, 0.32, 0.32, 0.08)' color_jitter_prob=0.8 gray_scale_prob=0.2 \ 15 | --epochs=7 \ 16 | --workers=6 \ 17 | --model ViT-H-14-CL8-SyntaxMask-GAP \ 18 | --precision 'amp_bf16' \ 19 | --local-loss \ 20 | --gather-with-grad \ 21 | --force-image-size 84 \ 22 | --grad-checkpointing \ 23 | --log-every-n-steps 32 \ 24 | --seed 0 \ 25 | --logs ./logs/ \ 26 | --imagenet-val '/path/to/ImageNet/val' \ 27 | --name 'name' \ 28 | --report-to "wandb" \ 29 | --wandb-project-name "project_name" 30 | 31 | 32 | -------------------------------------------------------------------------------- /src/open_clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .version import __version__ 2 | 3 | from .coca_model import CoCa 4 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 5 | from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss 6 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint 7 | from .loss import ClipLoss, DistillClipLoss, CoCaLoss 8 | from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \ 9 | convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype, \ 10 | get_model_tokenize_cfg, get_model_preprocess_cfg, set_model_preprocess_cfg 11 | from .openai import load_openai_model, list_openai_models 12 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \ 13 | get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained 14 | from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub 15 | from .tokenizer import SimpleTokenizer, tokenize, decode 16 | from .transform import image_transform, AugmentationCfg 17 | from .zero_shot_classifier import build_zero_shot_classifier, build_zero_shot_classifier_legacy 18 | from .zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, SIMPLE_IMAGENET_TEMPLATES, IMAGENET_CLASSNAMES 19 | -------------------------------------------------------------------------------- /src/open_clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlfoundations/open_clip/131f46f6fc1c5caebdc6e7d6c487d33de5b85d2a/src/open_clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /src/open_clip/constants.py: -------------------------------------------------------------------------------- 1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) 2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) 3 | IMAGENET_MEAN = (0.485, 0.456, 0.406) 4 | IMAGENET_STD = (0.229, 0.224, 0.225) 5 | INCEPTION_MEAN = (0.5, 0.5, 0.5) 6 | INCEPTION_STD = (0.5, 0.5, 0.5) 7 | 8 | # Default name for a weights file hosted on the Huggingface Hub. 
9 | HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl 10 | HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version 11 | HF_CONFIG_NAME = 'open_clip_config.json' 12 | -------------------------------------------------------------------------------- /src/open_clip/hf_configs.py: -------------------------------------------------------------------------------- 1 | # HF architecture dict: 2 | arch_dict = { 3 | # https://huggingface.co/docs/transformers/model_doc/roberta#roberta 4 | "roberta": { 5 | "config_names": { 6 | "context_length": "max_position_embeddings", 7 | "vocab_size": "vocab_size", 8 | "width": "hidden_size", 9 | "heads": "num_attention_heads", 10 | "layers": "num_hidden_layers", 11 | "layer_attr": "layer", 12 | "token_embeddings_attr": "embeddings" 13 | }, 14 | "pooler": "mean_pooler", 15 | }, 16 | # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig 17 | "xlm-roberta": { 18 | "config_names": { 19 | "context_length": "max_position_embeddings", 20 | "vocab_size": "vocab_size", 21 | "width": "hidden_size", 22 | "heads": "num_attention_heads", 23 | "layers": "num_hidden_layers", 24 | "layer_attr": "layer", 25 | "token_embeddings_attr": "embeddings" 26 | }, 27 | "pooler": "mean_pooler", 28 | }, 29 | # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 30 | "mt5": { 31 | "config_names": { 32 | # unlimited seqlen 33 | # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 34 | # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 35 | "context_length": "", 36 | "vocab_size": "vocab_size", 37 | "width": "d_model", 38 | "heads": "num_heads", 39 | "layers": "num_layers", 40 | "layer_attr": "block", 41 | "token_embeddings_attr": "embed_tokens" 42 | }, 43 | "pooler": "mean_pooler", 44 | }, 45 | # https://huggingface.co/docs/transformers/model_doc/bert 46 | "bert": { 47 | "config_names": { 48 | "context_length": "max_position_embeddings", 49 | "vocab_size": "vocab_size", 50 | "width": "hidden_size", 51 | "heads": "num_attention_heads", 52 | "layers": "num_hidden_layers", 53 | }, 54 | "pooler": "cls_pooler", 55 | }, 56 | # https://huggingface.co/docs/transformers/model_doc/m2m_100 57 | "m2m_100": { 58 | "config_names": { 59 | "context_length": "max_position_embeddings", 60 | "vocab_size": "vocab_size", 61 | "width": "d_model", 62 | "heads": "encoder_attention_heads", 63 | "layers": "encoder_layers", 64 | }, 65 | "pooler": "cls_pooler", 66 | }, 67 | } 68 | -------------------------------------------------------------------------------- /src/open_clip/hf_model.py: -------------------------------------------------------------------------------- 1 | """ huggingface model adapter 2 | 3 | Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model. 
4 | """ 5 | import re 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch import TensorType 10 | 11 | try: 12 | import transformers 13 | from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig 14 | from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \ 15 | BaseModelOutputWithPoolingAndCrossAttentions 16 | except ImportError as e: 17 | transformers = None 18 | 19 | 20 | class BaseModelOutput: 21 | pass 22 | 23 | 24 | class PretrainedConfig: 25 | pass 26 | 27 | from .hf_configs import arch_dict 28 | 29 | 30 | # utils 31 | def _camel2snake(s): 32 | return re.sub(r'(? List[str]: 20 | """Returns the names of available CLIP models""" 21 | return list_pretrained_models_by_tag('openai') 22 | 23 | 24 | def load_openai_model( 25 | name: str, 26 | precision: Optional[str] = None, 27 | device: Optional[Union[str, torch.device]] = None, 28 | cache_dir: Optional[str] = None, 29 | ): 30 | """Load a CLIP model 31 | 32 | Parameters 33 | ---------- 34 | name : str 35 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 36 | precision: str 37 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. 38 | device : Union[str, torch.device] 39 | The device to put the loaded model 40 | cache_dir : Optional[str] 41 | The directory to cache the downloaded model weights 42 | 43 | Returns 44 | ------- 45 | model : torch.nn.Module 46 | The CLIP model 47 | preprocess : Callable[[PIL.Image], torch.Tensor] 48 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 49 | """ 50 | if device is None: 51 | device = "cuda" if torch.cuda.is_available() else "cpu" 52 | if precision is None: 53 | precision = 'fp32' if device == 'cpu' else 'fp16' 54 | 55 | if get_pretrained_url(name, 'openai'): 56 | model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) 57 | elif os.path.isfile(name): 58 | model_path = name 59 | else: 60 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") 61 | 62 | try: 63 | # loading JIT archive 64 | model = torch.jit.load(model_path, map_location="cpu").eval() 65 | state_dict = None 66 | except RuntimeError: 67 | # loading saved state dict 68 | state_dict = torch.load(model_path, map_location="cpu") 69 | 70 | # Build a non-jit model from the OpenAI jitted model state dict 71 | cast_dtype = get_cast_dtype(precision) 72 | try: 73 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) 74 | except KeyError: 75 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 76 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) 77 | 78 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use 79 | model = model.to(device) 80 | # FIXME support pure fp16/bf16 precision modes 81 | if precision != 'fp16': 82 | model.float() 83 | if precision == 'bf16': 84 | # for bf16, convert back to low-precision 85 | convert_weights_to_lp(model, dtype=torch.bfloat16) 86 | 87 | # add mean / std attributes for consistency with OpenCLIP models 88 | model.visual.image_mean = OPENAI_DATASET_MEAN 89 | model.visual.image_std = OPENAI_DATASET_STD 90 | return model 91 | -------------------------------------------------------------------------------- /src/open_clip/pos_embed.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | # -------------------------------------------------------- 15 | # 2D sine-cosine position embedding 16 | # References: 17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 18 | # MoCo v3: https://github.com/facebookresearch/moco-v3 19 | # -------------------------------------------------------- 20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 21 | """ 22 | grid_size: int of the grid height and width 23 | return: 24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 25 | """ 26 | grid_h = np.arange(grid_size, dtype=np.float32) 27 | grid_w = np.arange(grid_size, dtype=np.float32) 28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 29 | grid = np.stack(grid, axis=0) 30 | 31 | grid = grid.reshape([2, 1, grid_size, grid_size]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | if cls_token: 34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 35 | return pos_embed 36 | 37 | 38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 39 | assert embed_dim % 2 == 0 40 | 41 | # use half of dimensions to encode grid_h 42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 44 | 45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 46 | return emb 47 | 48 | 49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 50 | """ 51 | embed_dim: output dimension for each position 52 | pos: a list of positions to be encoded: size (M,) 53 | out: (M, D) 54 | """ 55 | assert embed_dim % 2 == 0 56 | omega = np.arange(embed_dim // 2, dtype=float) 57 | omega /= embed_dim / 2. 58 | omega = 1. 
/ 10000**omega # (D/2,) 59 | 60 | pos = pos.reshape(-1) # (M,) 61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 62 | 63 | emb_sin = np.sin(out) # (M, D/2) 64 | emb_cos = np.cos(out) # (M, D/2) 65 | 66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 67 | return emb 68 | 69 | 70 | # -------------------------------------------------------- 71 | # Interpolate position embeddings for high-resolution 72 | # References: 73 | # DeiT: https://github.com/facebookresearch/deit 74 | # -------------------------------------------------------- 75 | def interpolate_pos_embed(model, checkpoint_model): 76 | if 'pos_embed' in checkpoint_model: 77 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 78 | embedding_size = pos_embed_checkpoint.shape[-1] 79 | num_patches = model.patch_embed.num_patches 80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 81 | # height (== width) for the checkpoint position embedding 82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 83 | # height (== width) for the new position embedding 84 | new_size = int(num_patches ** 0.5) 85 | # class_token and dist_token are kept unchanged 86 | if orig_size != new_size: 87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 89 | # only the position tokens are interpolated 90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 92 | pos_tokens = torch.nn.functional.interpolate( 93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 96 | checkpoint_model['pos_embed'] = new_pos_embed 97 | -------------------------------------------------------------------------------- /src/open_clip/utils.py: -------------------------------------------------------------------------------- 1 | import collections.abc 2 | from itertools import repeat 3 | from typing import List, Optional, Tuple, Union 4 | 5 | import torch 6 | from torch import nn as nn 7 | from torch import _assert 8 | from torchvision.ops.misc import FrozenBatchNorm2d 9 | 10 | 11 | def freeze_batch_norm_2d(module, module_match={}, name=''): 12 | """ 13 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is 14 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and 15 | returned. Otherwise, the module is walked recursively and submodules are converted in place. 16 | 17 | Args: 18 | module (torch.nn.Module): Any PyTorch module. 
19 | module_match (dict): Dictionary of full module names to freeze (all if empty) 20 | name (str): Full module name (prefix) 21 | 22 | Returns: 23 | torch.nn.Module: Resulting module 24 | 25 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 26 | """ 27 | res = module 28 | is_match = True 29 | if module_match: 30 | is_match = name in module_match 31 | if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): 32 | res = FrozenBatchNorm2d(module.num_features) 33 | res.num_features = module.num_features 34 | res.affine = module.affine 35 | if module.affine: 36 | res.weight.data = module.weight.data.clone().detach() 37 | res.bias.data = module.bias.data.clone().detach() 38 | res.running_mean.data = module.running_mean.data 39 | res.running_var.data = module.running_var.data 40 | res.eps = module.eps 41 | else: 42 | for child_name, child in module.named_children(): 43 | full_child_name = '.'.join([name, child_name]) if name else child_name 44 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name) 45 | if new_child is not child: 46 | res.add_module(child_name, new_child) 47 | return res 48 | 49 | 50 | # From PyTorch internals 51 | def _ntuple(n): 52 | def parse(x): 53 | if isinstance(x, collections.abc.Iterable): 54 | return x 55 | return tuple(repeat(x, n)) 56 | return parse 57 | 58 | 59 | to_1tuple = _ntuple(1) 60 | to_2tuple = _ntuple(2) 61 | to_3tuple = _ntuple(3) 62 | to_4tuple = _ntuple(4) 63 | to_ntuple = lambda n, x: _ntuple(n)(x) 64 | 65 | # Replaces all linear layers with linear_replacement 66 | # TODO: add int8 support for other linear layers including attn and convnets 67 | def replace_linear(model, linear_replacement, include_modules=['c_fc', 'c_proj'], copy_weights=True): 68 | for name, module in model.named_children(): 69 | if len(list(module.children())) > 0: 70 | replace_linear(module, linear_replacement, include_modules, copy_weights) 71 | 72 | if isinstance(module, torch.nn.Linear) and name in include_modules: 73 | old_module = model._modules[name] 74 | model._modules[name] = linear_replacement( 75 | module.in_features, 76 | module.out_features, 77 | module.bias is not None, 78 | ) 79 | if copy_weights: 80 | model._modules[name].weight.data.copy_(old_module.weight.data) 81 | if model._modules[name].bias is not None: 82 | model._modules[name].bias.data.copy_(old_module.bias) 83 | 84 | return model 85 | 86 | def convert_int8_model_to_inference_mode(model): 87 | for m in model.modules(): 88 | if hasattr(m, 'prepare_for_eval'): 89 | int8_original_dtype = m.weight.dtype 90 | m.prepare_for_eval() 91 | m.int8_original_dtype = int8_original_dtype 92 | 93 | 94 | def feature_take_indices( 95 | num_features: int, 96 | indices: Optional[Union[int, List[int]]] = None, 97 | as_set: bool = False, 98 | ) -> Tuple[List[int], int]: 99 | """ Determine the absolute feature indices to 'take' from. 100 | 101 | Note: This function can be called in forward() so must be torchscript compatible, 102 | which requires some incomplete typing and workaround hacks. 
103 | 104 | Args: 105 | num_features: total number of features to select from 106 | indices: indices to select, 107 | None -> select all 108 | int -> select last n 109 | list/tuple of int -> return specified (-ve indices specify from end) 110 | as_set: return as a set 111 | 112 | Returns: 113 | List (or set) of absolute (from beginning) indices, Maximum index 114 | """ 115 | if indices is None: 116 | indices = num_features # all features if None 117 | 118 | if isinstance(indices, int): 119 | # convert int -> last n indices 120 | _assert(0 < indices <= num_features, f'last-n ({indices}) is out of range (1 to {num_features})') 121 | take_indices = [num_features - indices + i for i in range(indices)] 122 | else: 123 | take_indices: List[int] = [] 124 | for i in indices: 125 | idx = num_features + i if i < 0 else i 126 | _assert(0 <= idx < num_features, f'feature index {idx} is out of range (0 to {num_features - 1})') 127 | take_indices.append(idx) 128 | 129 | if not torch.jit.is_scripting() and as_set: 130 | return set(take_indices), max(take_indices) 131 | 132 | return take_indices, max(take_indices) 133 | 134 | 135 | def _out_indices_as_tuple(x: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]: 136 | if isinstance(x, int): 137 | # if indices is an int, take last N features 138 | return tuple(range(-x, 0)) 139 | return tuple(x) 140 | -------------------------------------------------------------------------------- /src/open_clip/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '2.32.0' 2 | -------------------------------------------------------------------------------- /src/open_clip/zero_shot_classifier.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from itertools import islice 3 | from typing import Callable, List, Optional, Sequence, Union 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | 9 | def batched(iterable, n): 10 | """Batch data into lists of length *n*. The last batch may be shorter. 11 | NOTE based on more-itertools impl, to be replaced by python 3.12 itertools.batched impl 12 | """ 13 | it = iter(iterable) 14 | while True: 15 | batch = list(islice(it, n)) 16 | if not batch: 17 | break 18 | yield batch 19 | 20 | 21 | def build_zero_shot_classifier( 22 | model, 23 | tokenizer, 24 | classnames: Sequence[str], 25 | templates: Sequence[Union[Callable, str]], 26 | num_classes_per_batch: Optional[int] = 10, 27 | device: Union[str, torch.device] = 'cpu', 28 | use_tqdm: bool = False, 29 | ): 30 | """ Build zero-shot classifier weights by iterating over class names in batches 31 | Args: 32 | model: CLIP model instance 33 | tokenizer: CLIP tokenizer instance 34 | classnames: A sequence of class (label) names 35 | templates: A sequence of callables or format() friendly strings to produce templates per class name 36 | num_classes_per_batch: The number of classes to batch together in each forward, all if None 37 | device: Device to use. 38 | use_tqdm: Enable TQDM progress bar. 
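        Returns:
            Tensor of shape (embed_dim, num_classes); each column is the L2-normalized,
            template-averaged text embedding of one class, usable as `image_features @ classifier`.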
39 | """ 40 | assert isinstance(templates, Sequence) and len(templates) > 0 41 | assert isinstance(classnames, Sequence) and len(classnames) > 0 42 | use_format = isinstance(templates[0], str) 43 | num_templates = len(templates) 44 | num_classes = len(classnames) 45 | if use_tqdm: 46 | import tqdm 47 | num_iter = 1 if num_classes_per_batch is None else ((num_classes - 1) // num_classes_per_batch + 1) 48 | iter_wrap = partial(tqdm.tqdm, total=num_iter, unit_scale=num_classes_per_batch) 49 | else: 50 | iter_wrap = iter 51 | 52 | def _process_batch(batch_classnames): 53 | num_batch_classes = len(batch_classnames) 54 | texts = [template.format(c) if use_format else template(c) for c in batch_classnames for template in templates] 55 | texts = tokenizer(texts).to(device) 56 | class_embeddings = model.encode_text(texts, normalize=True) 57 | class_embeddings = class_embeddings.reshape(num_batch_classes, num_templates, -1).mean(dim=1) 58 | class_embeddings = class_embeddings / class_embeddings.norm(dim=1, keepdim=True) 59 | class_embeddings = class_embeddings.T 60 | return class_embeddings 61 | 62 | with torch.no_grad(): 63 | if num_classes_per_batch: 64 | batched_embeds = [_process_batch(batch) for batch in iter_wrap(batched(classnames, num_classes_per_batch))] 65 | zeroshot_weights = torch.cat(batched_embeds, dim=1) 66 | else: 67 | zeroshot_weights = _process_batch(classnames) 68 | return zeroshot_weights 69 | 70 | 71 | def build_zero_shot_classifier_legacy( 72 | model, 73 | tokenizer, 74 | classnames: Sequence[str], 75 | templates: Sequence[Union[Callable, str]], 76 | device: Union[str, torch.device] = 'cpu', 77 | use_tqdm: bool = False, 78 | ): 79 | """ Build zero-shot classifier weights by iterating over class names 1 by 1 80 | Args: 81 | model: CLIP model instance 82 | tokenizer: CLIP tokenizer instance 83 | classnames: A sequence of class (label) names 84 | templates: A sequence of callables or format() friendly strings to produce templates per class name 85 | device: Device to use. 86 | use_tqdm: Enable TQDM progress bar. 
87 | """ 88 | assert isinstance(templates, Sequence) and len(templates) > 0 89 | assert isinstance(classnames, Sequence) and len(classnames) > 0 90 | if use_tqdm: 91 | import tqdm 92 | iter_wrap = tqdm.tqdm 93 | else: 94 | iter_wrap = iter 95 | 96 | use_format = isinstance(templates[0], str) 97 | 98 | with torch.no_grad(): 99 | zeroshot_weights = [] 100 | for classname in iter_wrap(classnames): 101 | texts = [template.format(classname) if use_format else template(classname) for template in templates] 102 | texts = tokenizer(texts).to(device) # tokenize 103 | class_embeddings = model.encode_text(texts) 104 | class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) 105 | class_embedding /= class_embedding.norm() 106 | zeroshot_weights.append(class_embedding) 107 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device) 108 | 109 | return zeroshot_weights 110 | 111 | -------------------------------------------------------------------------------- /src/open_clip_train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlfoundations/open_clip/131f46f6fc1c5caebdc6e7d6c487d33de5b85d2a/src/open_clip_train/__init__.py -------------------------------------------------------------------------------- /src/open_clip_train/file_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import multiprocessing 4 | import subprocess 5 | import time 6 | import fsspec 7 | import torch 8 | from tqdm import tqdm 9 | 10 | def remote_sync_s3(local_dir, remote_dir): 11 | # skip epoch_latest which can change during sync. 12 | result = subprocess.run(["aws", "s3", "sync", local_dir, remote_dir, '--exclude', '*epoch_latest.pt'], stdout=subprocess.PIPE, stderr=subprocess.PIPE) 13 | if result.returncode != 0: 14 | logging.error(f"Error: Failed to sync with S3 bucket {result.stderr.decode('utf-8')}") 15 | return False 16 | 17 | logging.info(f"Successfully synced with S3 bucket") 18 | return True 19 | 20 | def remote_sync_fsspec(local_dir, remote_dir): 21 | # FIXME currently this is slow and not recommended. Look into speeding up. 22 | a = fsspec.get_mapper(local_dir) 23 | b = fsspec.get_mapper(remote_dir) 24 | 25 | for k in a: 26 | # skip epoch_latest which can change during sync. 
27 | if 'epoch_latest.pt' in k: 28 | continue 29 | 30 | logging.info(f'Attempting to sync {k}') 31 | if k in b and len(a[k]) == len(b[k]): 32 | logging.debug(f'Skipping remote sync for {k}.') 33 | continue 34 | 35 | try: 36 | b[k] = a[k] 37 | logging.info(f'Successful sync for {k}.') 38 | except Exception as e: 39 | logging.info(f'Error during remote sync for {k}: {e}') 40 | return False 41 | 42 | return True 43 | 44 | def remote_sync(local_dir, remote_dir, protocol): 45 | logging.info('Starting remote sync.') 46 | if protocol == 's3': 47 | return remote_sync_s3(local_dir, remote_dir) 48 | elif protocol == 'fsspec': 49 | return remote_sync_fsspec(local_dir, remote_dir) 50 | else: 51 | logging.error('Remote protocol not known') 52 | return False 53 | 54 | def keep_running_remote_sync(sync_every, local_dir, remote_dir, protocol): 55 | while True: 56 | time.sleep(sync_every) 57 | remote_sync(local_dir, remote_dir, protocol) 58 | 59 | def start_sync_process(sync_every, local_dir, remote_dir, protocol): 60 | p = multiprocessing.Process(target=keep_running_remote_sync, args=(sync_every, local_dir, remote_dir, protocol)) 61 | return p 62 | 63 | # Note: we are not currently using this save function. 64 | def pt_save(pt_obj, file_path): 65 | of = fsspec.open(file_path, "wb") 66 | with of as f: 67 | torch.save(pt_obj, f) 68 | 69 | def pt_load(file_path, map_location=None): 70 | if file_path.startswith('s3'): 71 | logging.info('Loading remote checkpoint, which may take a bit.') 72 | of = fsspec.open(file_path, "rb") 73 | with of as f: 74 | out = torch.load(f, map_location=map_location) 75 | return out 76 | 77 | def check_exists(file_path): 78 | try: 79 | with fsspec.open(file_path): 80 | pass 81 | except FileNotFoundError: 82 | return False 83 | return True 84 | -------------------------------------------------------------------------------- /src/open_clip_train/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def setup_logging(log_file, level, include_host=False): 5 | if include_host: 6 | import socket 7 | hostname = socket.gethostname() 8 | formatter = logging.Formatter( 9 | f'%(asctime)s | {hostname} | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') 10 | else: 11 | formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') 12 | 13 | logging.root.setLevel(level) 14 | loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict] 15 | for logger in loggers: 16 | logger.setLevel(level) 17 | 18 | stream_handler = logging.StreamHandler() 19 | stream_handler.setFormatter(formatter) 20 | logging.root.addHandler(stream_handler) 21 | 22 | if log_file: 23 | file_handler = logging.FileHandler(filename=log_file) 24 | file_handler.setFormatter(formatter) 25 | logging.root.addHandler(file_handler) 26 | 27 | -------------------------------------------------------------------------------- /src/open_clip_train/precision.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from contextlib import suppress 3 | from functools import partial 4 | 5 | 6 | def get_autocast(precision, device_type='cuda'): 7 | if precision == 'amp': 8 | amp_dtype = torch.float16 9 | elif precision == 'amp_bfloat16' or precision == 'amp_bf16': 10 | amp_dtype = torch.bfloat16 11 | else: 12 | return suppress 13 | 14 | return partial(torch.amp.autocast, device_type=device_type, dtype=amp_dtype)
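# Usage sketch (illustrative addition, not part of the training code): how get_autocast is
# typically consumed in the train/eval loops. `model`, `images`, and `texts` are assumed to be
# an open_clip model and already-batched tensors; the precision strings mirror the --precision
# values handled above.
def _example_autocast_usage(model, images, texts, precision='amp_bf16'):
    autocast = get_autocast(precision, device_type='cuda')  # factory for an autocast context (or `suppress` for non-AMP modes)
    with torch.no_grad(), autocast():
        image_features = model.encode_image(images)  # runs under bfloat16 autocast for 'amp_bf16'
        text_features = model.encode_text(texts)
    return image_features, text_features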
-------------------------------------------------------------------------------- /src/open_clip_train/scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | def assign_learning_rate(optimizer, new_lr): 5 | for param_group in optimizer.param_groups: 6 | param_group["lr"] = new_lr 7 | 8 | 9 | def _warmup_lr(base_lr, warmup_length, step): 10 | return base_lr * (step + 1) / warmup_length 11 | 12 | 13 | def const_lr(optimizer, base_lr, warmup_length, steps): 14 | def _lr_adjuster(step): 15 | if step < warmup_length: 16 | lr = _warmup_lr(base_lr, warmup_length, step) 17 | else: 18 | lr = base_lr 19 | assign_learning_rate(optimizer, lr) 20 | return lr 21 | 22 | return _lr_adjuster 23 | 24 | 25 | def const_lr_cooldown(optimizer, base_lr, warmup_length, steps, cooldown_steps, cooldown_power=1.0, cooldown_end_lr=0.): 26 | def _lr_adjuster(step): 27 | start_cooldown_step = steps - cooldown_steps 28 | if step < warmup_length: 29 | lr = _warmup_lr(base_lr, warmup_length, step) 30 | else: 31 | if step < start_cooldown_step: 32 | lr = base_lr 33 | else: 34 | e = step - start_cooldown_step 35 | es = steps - start_cooldown_step 36 | # linear decay if power == 1; polynomial decay otherwise; 37 | decay = (1 - (e / es)) ** cooldown_power 38 | lr = decay * (base_lr - cooldown_end_lr) + cooldown_end_lr 39 | assign_learning_rate(optimizer, lr) 40 | return lr 41 | 42 | return _lr_adjuster 43 | 44 | 45 | def cosine_lr(optimizer, base_lr, warmup_length, steps): 46 | def _lr_adjuster(step): 47 | if step < warmup_length: 48 | lr = _warmup_lr(base_lr, warmup_length, step) 49 | else: 50 | e = step - warmup_length 51 | es = steps - warmup_length 52 | lr = 0.5 * (1 + math.cos(math.pi * e / es)) * base_lr 53 | assign_learning_rate(optimizer, lr) 54 | return lr 55 | 56 | return _lr_adjuster 57 | 58 | -------------------------------------------------------------------------------- /src/open_clip_train/zero_shot.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | from tqdm import tqdm 5 | 6 | from open_clip import get_input_dtype, get_tokenizer, build_zero_shot_classifier, \ 7 | IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES 8 | from open_clip_train.precision import get_autocast 9 | 10 | 11 | def accuracy(output, target, topk=(1,)): 12 | pred = output.topk(max(topk), 1, True, True)[1].t() 13 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 14 | return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk] 15 | 16 | 17 | def run(model, classifier, dataloader, args): 18 | device = torch.device(args.device) 19 | autocast = get_autocast(args.precision, device_type=device.type) 20 | input_dtype = get_input_dtype(args.precision) 21 | 22 | with torch.inference_mode(): 23 | top1, top5, n = 0., 0., 0. 24 | for images, target in tqdm(dataloader, unit_scale=args.batch_size): 25 | images = images.to(device=device, dtype=input_dtype) 26 | target = target.to(device) 27 | 28 | with autocast(): 29 | # predict 30 | output = model(image=images) 31 | image_features = output['image_features'] if isinstance(output, dict) else output[0] 32 | logits = 100. 
* image_features @ classifier 33 | 34 | # measure accuracy 35 | acc1, acc5 = accuracy(logits, target, topk=(1, 5)) 36 | top1 += acc1 37 | top5 += acc5 38 | n += images.size(0) 39 | 40 | top1 = (top1 / n) 41 | top5 = (top5 / n) 42 | return top1, top5 43 | 44 | 45 | def zero_shot_eval(model, data, epoch, args, tokenizer=None): 46 | if 'imagenet-val' not in data and 'imagenet-v2' not in data: 47 | return {} 48 | if args.zeroshot_frequency == 0: 49 | return {} 50 | if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs: 51 | return {} 52 | if args.distributed and not args.horovod: 53 | model = model.module 54 | 55 | logging.info('Starting zero-shot imagenet.') 56 | if tokenizer is None: 57 | tokenizer = get_tokenizer(args.model) 58 | 59 | logging.info('Building zero-shot classifier') 60 | device = torch.device(args.device) 61 | autocast = get_autocast(args.precision, device_type=device.type) 62 | with autocast(): 63 | classifier = build_zero_shot_classifier( 64 | model, 65 | tokenizer=tokenizer, 66 | classnames=IMAGENET_CLASSNAMES, 67 | templates=OPENAI_IMAGENET_TEMPLATES, 68 | num_classes_per_batch=10, 69 | device=device, 70 | use_tqdm=True, 71 | ) 72 | 73 | logging.info('Using classifier') 74 | results = {} 75 | if 'imagenet-val' in data: 76 | top1, top5 = run(model, classifier, data['imagenet-val'].dataloader, args) 77 | results['imagenet-zeroshot-val-top1'] = top1 78 | results['imagenet-zeroshot-val-top5'] = top5 79 | if 'imagenet-v2' in data: 80 | top1, top5 = run(model, classifier, data['imagenet-v2'].dataloader, args) 81 | results['imagenetv2-zeroshot-val-top1'] = top1 82 | results['imagenetv2-zeroshot-val-top5'] = top5 83 | 84 | logging.info('Finished zero-shot imagenet.') 85 | 86 | return results 87 | -------------------------------------------------------------------------------- /tests/test_download_pretrained.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import torch 3 | from PIL import Image 4 | import hashlib 5 | import tempfile 6 | import unittest 7 | from io import BytesIO 8 | from pathlib import Path 9 | from unittest.mock import patch 10 | 11 | from urllib3 import HTTPResponse 12 | from urllib3._collections import HTTPHeaderDict 13 | 14 | import open_clip 15 | from open_clip.pretrained import download_pretrained_from_url 16 | 17 | 18 | class DownloadPretrainedTests(unittest.TestCase): 19 | 20 | def create_response(self, data, status_code=200, content_type='application/octet-stream'): 21 | fp = BytesIO(data) 22 | headers = HTTPHeaderDict({ 23 | 'Content-Type': content_type, 24 | 'Content-Length': str(len(data)) 25 | }) 26 | raw = HTTPResponse(fp, preload_content=False, headers=headers, status=status_code) 27 | return raw 28 | 29 | @patch('open_clip.pretrained.urllib') 30 | def test_download_pretrained_from_url_from_openaipublic(self, urllib): 31 | file_contents = b'pretrained model weights' 32 | expected_hash = hashlib.sha256(file_contents).hexdigest() 33 | urllib.request.urlopen.return_value = self.create_response(file_contents) 34 | with tempfile.TemporaryDirectory() as root: 35 | url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt' 36 | download_pretrained_from_url(url, root) 37 | urllib.request.urlopen.assert_called_once() 38 | 39 | @patch('open_clip.pretrained.urllib') 40 | def test_download_pretrained_from_url_from_openaipublic_corrupted(self, urllib): 41 | file_contents = b'pretrained model weights' 42 | expected_hash = hashlib.sha256(file_contents).hexdigest() 43 | 
urllib.request.urlopen.return_value = self.create_response(b'corrupted pretrained model') 44 | with tempfile.TemporaryDirectory() as root: 45 | url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt' 46 | with self.assertRaisesRegex(RuntimeError, r'checksum does not not match'): 47 | download_pretrained_from_url(url, root) 48 | urllib.request.urlopen.assert_called_once() 49 | 50 | @patch('open_clip.pretrained.urllib') 51 | def test_download_pretrained_from_url_from_openaipublic_valid_cache(self, urllib): 52 | file_contents = b'pretrained model weights' 53 | expected_hash = hashlib.sha256(file_contents).hexdigest() 54 | urllib.request.urlopen.return_value = self.create_response(file_contents) 55 | with tempfile.TemporaryDirectory() as root: 56 | local_file = Path(root) / 'RN50.pt' 57 | local_file.write_bytes(file_contents) 58 | url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt' 59 | download_pretrained_from_url(url, root) 60 | urllib.request.urlopen.assert_not_called() 61 | 62 | @patch('open_clip.pretrained.urllib') 63 | def test_download_pretrained_from_url_from_openaipublic_corrupted_cache(self, urllib): 64 | file_contents = b'pretrained model weights' 65 | expected_hash = hashlib.sha256(file_contents).hexdigest() 66 | urllib.request.urlopen.return_value = self.create_response(file_contents) 67 | with tempfile.TemporaryDirectory() as root: 68 | local_file = Path(root) / 'RN50.pt' 69 | local_file.write_bytes(b'corrupted pretrained model') 70 | url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt' 71 | download_pretrained_from_url(url, root) 72 | urllib.request.urlopen.assert_called_once() 73 | 74 | @patch('open_clip.pretrained.urllib') 75 | def test_download_pretrained_from_url_from_mlfoundations(self, urllib): 76 | file_contents = b'pretrained model weights' 77 | expected_hash = hashlib.sha256(file_contents).hexdigest()[:8] 78 | urllib.request.urlopen.return_value = self.create_response(file_contents) 79 | with tempfile.TemporaryDirectory() as root: 80 | url = f'https://github.com/mlfoundations/download/v0.2-weights/rn50-quickgelu-{expected_hash}.pt' 81 | download_pretrained_from_url(url, root) 82 | urllib.request.urlopen.assert_called_once() 83 | 84 | @patch('open_clip.pretrained.urllib') 85 | def test_download_pretrained_from_url_from_mlfoundations_corrupted(self, urllib): 86 | file_contents = b'pretrained model weights' 87 | expected_hash = hashlib.sha256(file_contents).hexdigest()[:8] 88 | urllib.request.urlopen.return_value = self.create_response(b'corrupted pretrained model') 89 | with tempfile.TemporaryDirectory() as root: 90 | url = f'https://github.com/mlfoundations/download/v0.2-weights/rn50-quickgelu-{expected_hash}.pt' 91 | with self.assertRaisesRegex(RuntimeError, r'checksum does not not match'): 92 | download_pretrained_from_url(url, root) 93 | urllib.request.urlopen.assert_called_once() 94 | 95 | @patch('open_clip.pretrained.urllib') 96 | def test_download_pretrained_from_hfh(self, urllib): 97 | model, _, preprocess = open_clip.create_model_and_transforms('hf-hub:hf-internal-testing/tiny-open-clip-model') 98 | tokenizer = open_clip.get_tokenizer('hf-hub:hf-internal-testing/tiny-open-clip-model') 99 | img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/coco_sample.png" 100 | image = preprocess(Image.open(requests.get(img_url, stream=True).raw)).unsqueeze(0) 101 | text = tokenizer(["a diagram", "a dog", "a cat"]) 102 | 103 | with torch.no_grad(): 104 | 
image_features = model.encode_image(image) 105 | text_features = model.encode_text(text) 106 | image_features /= image_features.norm(dim=-1, keepdim=True) 107 | text_features /= text_features.norm(dim=-1, keepdim=True) 108 | 109 | text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1) 110 | 111 | self.assertTrue(torch.allclose(text_probs, torch.tensor([[0.0597, 0.6349, 0.3053]]), 1e-3)) 112 | -------------------------------------------------------------------------------- /tests/test_hf_model.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import torch 4 | from open_clip.hf_model import _POOLERS, HFTextEncoder 5 | from transformers import AutoConfig 6 | from transformers.modeling_outputs import BaseModelOutput 7 | 8 | # test poolers 9 | def test_poolers(): 10 | bs, sl, d = 2, 10, 5 11 | h = torch.arange(sl).repeat(bs).reshape(bs, sl)[..., None] * torch.linspace(0.2, 1., d) 12 | mask = torch.ones(bs, sl, dtype=torch.bool) 13 | mask[:2, 6:] = False 14 | x = BaseModelOutput(h) 15 | for name, cls in _POOLERS.items(): 16 | pooler = cls() 17 | res = pooler(x, mask) 18 | assert res.shape == (bs, d), f"{name} returned wrong shape" 19 | 20 | # test HFTextEncoder 21 | @pytest.mark.parametrize("model_id", ["arampacha/roberta-tiny", "roberta-base", "xlm-roberta-base", "google/mt5-base"]) 22 | def test_pretrained_text_encoder(model_id): 23 | bs, sl, d = 2, 10, 64 24 | cfg = AutoConfig.from_pretrained(model_id) 25 | model = HFTextEncoder(model_id, d, proj_type='linear') 26 | x = torch.randint(0, cfg.vocab_size, (bs, sl)) 27 | with torch.no_grad(): 28 | emb = model(x) 29 | 30 | assert emb.shape == (bs, d) 31 | -------------------------------------------------------------------------------- /tests/test_inference.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import pytest 4 | import torch 5 | import open_clip 6 | import util_test 7 | 8 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 9 | 10 | if hasattr(torch._C, '_jit_set_profiling_executor'): 11 | # legacy executor is too slow to compile large models for unit tests 12 | # no need for the fusion performance here 13 | torch._C._jit_set_profiling_executor(True) 14 | torch._C._jit_set_profiling_mode(False) 15 | 16 | models_to_test = set(open_clip.list_models()) 17 | 18 | # testing exemptions 19 | models_to_test = models_to_test.difference({ 20 | # not available with timm yet 21 | # see https://github.com/mlfoundations/open_clip/issues/219 22 | 'convnext_xlarge', 23 | 'convnext_xxlarge', 24 | 'convnext_xxlarge_320', 25 | 'vit_medium_patch16_gap_256', 26 | # exceeds GH runner memory limit 27 | 'ViT-bigG-14', 28 | 'ViT-e-14', 29 | 'mt5-xl-ViT-H-14', 30 | 'coca_base', 31 | 'coca_ViT-B-32', 32 | 'coca_roberta-ViT-B-32' 33 | }) 34 | 35 | if 'OPEN_CLIP_TEST_REG_MODELS' in os.environ: 36 | external_model_list = os.environ['OPEN_CLIP_TEST_REG_MODELS'] 37 | with open(external_model_list, 'r') as f: 38 | models_to_test = set(f.read().splitlines()).intersection(models_to_test) 39 | print(f"Selected models from {external_model_list}: {models_to_test}") 40 | 41 | # TODO: add "coca_ViT-B-32" once https://github.com/pytorch/pytorch/issues/92073 gets fixed 42 | models_to_test = list(models_to_test) 43 | models_to_test.sort() 44 | models_to_test = [(model_name, False) for model_name in models_to_test] 45 | 46 | models_to_jit_test = {"ViT-B-32"} 47 | models_to_jit_test = list(models_to_jit_test) 48 | models_to_jit_test = [(model_name, True) 
for model_name in models_to_jit_test] 49 | models_to_test_fully = models_to_test + models_to_jit_test 50 | 51 | 52 | @pytest.mark.regression_test 53 | @pytest.mark.parametrize("model_name,jit", models_to_test_fully) 54 | def test_inference_with_data( 55 | model_name, 56 | jit, 57 | pretrained = None, 58 | pretrained_hf = False, 59 | precision = 'fp32', 60 | force_quick_gelu = False, 61 | ): 62 | util_test.seed_all() 63 | model, _, preprocess_val = open_clip.create_model_and_transforms( 64 | model_name, 65 | pretrained = pretrained, 66 | precision = precision, 67 | jit = jit, 68 | force_quick_gelu = force_quick_gelu, 69 | pretrained_hf = pretrained_hf 70 | ) 71 | model_id = f'{model_name}_{pretrained or pretrained_hf}_{precision}' 72 | input_dir, output_dir = util_test.get_data_dirs() 73 | # text 74 | input_text_path = os.path.join(input_dir, 'random_text.pt') 75 | gt_text_path = os.path.join(output_dir, f'{model_id}_random_text.pt') 76 | if not os.path.isfile(input_text_path): 77 | pytest.skip(reason = f"missing test data, expected at {input_text_path}") 78 | if not os.path.isfile(gt_text_path): 79 | pytest.skip(reason = f"missing test data, expected at {gt_text_path}") 80 | input_text = torch.load(input_text_path) 81 | gt_text = torch.load(gt_text_path) 82 | y_text = util_test.inference_text(model, model_name, input_text) 83 | assert (y_text == gt_text).all(), f"text output differs @ {input_text_path}" 84 | # image 85 | image_size = model.visual.image_size 86 | if not isinstance(image_size, tuple): 87 | image_size = (image_size, image_size) 88 | input_image_path = os.path.join(input_dir, f'random_image_{image_size[0]}_{image_size[1]}.pt') 89 | gt_image_path = os.path.join(output_dir, f'{model_id}_random_image.pt') 90 | if not os.path.isfile(input_image_path): 91 | pytest.skip(reason = f"missing test data, expected at {input_image_path}") 92 | if not os.path.isfile(gt_image_path): 93 | pytest.skip(reason = f"missing test data, expected at {gt_image_path}") 94 | input_image = torch.load(input_image_path) 95 | gt_image = torch.load(gt_image_path) 96 | y_image = util_test.inference_image(model, preprocess_val, input_image) 97 | assert (y_image == gt_image).all(), f"image output differs @ {input_image_path}" 98 | 99 | if not jit: 100 | model.eval() 101 | model_out = util_test.forward_model(model, model_name, preprocess_val, input_image, input_text) 102 | if type(model) not in [open_clip.CLIP, open_clip.CustomTextCLIP]: 103 | assert type(model_out) == dict 104 | else: 105 | model.output_dict = True 106 | model_out_dict = util_test.forward_model(model, model_name, preprocess_val, input_image, input_text) 107 | assert (model_out_dict["image_features"] == model_out[0]).all() 108 | assert (model_out_dict["text_features"] == model_out[1]).all() 109 | assert (model_out_dict["logit_scale"] == model_out[2]).all() 110 | model.output_dict = None 111 | else: 112 | model, _, preprocess_val = open_clip.create_model_and_transforms( 113 | model_name, 114 | pretrained = pretrained, 115 | precision = precision, 116 | jit = False, 117 | force_quick_gelu = force_quick_gelu, 118 | pretrained_hf = pretrained_hf 119 | ) 120 | 121 | test_model = util_test.TestWrapper(model, model_name, output_dict=False) 122 | test_model = torch.jit.script(test_model) 123 | model_out = util_test.forward_model(test_model, model_name, preprocess_val, input_image, input_text) 124 | assert model_out["test_output"].shape[-1] == 2 125 | 126 | test_model = util_test.TestWrapper(model, model_name, output_dict=True) 127 | test_model = 
torch.jit.script(test_model) 128 | model_out = util_test.forward_model(test_model, model_name, preprocess_val, input_image, input_text) 129 | assert model_out["test_output"].shape[-1] == 2 130 | 131 | 132 | 133 | 134 | -------------------------------------------------------------------------------- /tests/test_inference_simple.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | from open_clip.factory import get_tokenizer 4 | import pytest 5 | import open_clip 6 | import os 7 | os.environ["CUDA_VISIBLE_DEVICES"] = "" 8 | 9 | if hasattr(torch._C, '_jit_set_profiling_executor'): 10 | # legacy executor is too slow to compile large models for unit tests 11 | # no need for the fusion performance here 12 | torch._C._jit_set_profiling_executor(True) 13 | torch._C._jit_set_profiling_mode(False) 14 | 15 | 16 | test_simple_models = [ 17 | # model, pretrained, jit, force_custom_text 18 | ("ViT-B-32", "laion2b_s34b_b79k", False, False), 19 | ("ViT-B-32", "laion2b_s34b_b79k", True, False), 20 | ("ViT-B-32", "laion2b_s34b_b79k", True, True), 21 | ("roberta-ViT-B-32", "laion2b_s12b_b32k", False, False), 22 | ] 23 | 24 | 25 | @pytest.mark.parametrize("model_type,pretrained,jit,force_custom_text", test_simple_models) 26 | def test_inference_simple( 27 | model_type, 28 | pretrained, 29 | jit, 30 | force_custom_text, 31 | ): 32 | model, _, preprocess = open_clip.create_model_and_transforms( 33 | model_type, 34 | pretrained=pretrained, 35 | jit=jit, 36 | force_custom_text=force_custom_text, 37 | ) 38 | tokenizer = get_tokenizer(model_type) 39 | 40 | current_dir = os.path.dirname(os.path.realpath(__file__)) 41 | 42 | image = preprocess(Image.open(current_dir + "/../docs/CLIP.png")).unsqueeze(0) 43 | text = tokenizer(["a diagram", "a dog", "a cat"]) 44 | 45 | with torch.no_grad(): 46 | image_features = model.encode_image(image) 47 | text_features = model.encode_text(text) 48 | 49 | text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1) 50 | 51 | assert torch.allclose(text_probs.cpu()[0], torch.tensor([1.0, 0.0, 0.0])) 52 | -------------------------------------------------------------------------------- /tests/test_num_shards.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from open_clip_train.data import get_dataset_size 4 | 5 | @pytest.mark.parametrize( 6 | "shards,expected_size", 7 | [ 8 | ('/path/to/shard.tar', 1), 9 | ('/path/to/shard_{000..000}.tar', 1), 10 | ('/path/to/shard_{000..009}.tar', 10), 11 | ('/path/to/shard_{000..009}_{000..009}.tar', 100), 12 | ('/path/to/shard.tar::/path/to/other_shard_{000..009}.tar', 11), 13 | ('/path/to/shard_{000..009}.tar::/path/to/other_shard_{000..009}.tar', 20), 14 | (['/path/to/shard.tar'], 1), 15 | (['/path/to/shard.tar', '/path/to/other_shard.tar'], 2), 16 | ] 17 | ) 18 | def test_num_shards(shards, expected_size): 19 | _, size = get_dataset_size(shards) 20 | assert size == expected_size, f'Expected {expected_size} for {shards} but found {size} instead.' 
21 | -------------------------------------------------------------------------------- /tests/test_training_simple.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | import pytest 5 | import torch 6 | from open_clip_train.main import main 7 | 8 | os.environ["CUDA_VISIBLE_DEVICES"] = "" 9 | 10 | if hasattr(torch._C, '_jit_set_profiling_executor'): 11 | # legacy executor is too slow to compile large models for unit tests 12 | # no need for the fusion performance here 13 | torch._C._jit_set_profiling_executor(True) 14 | torch._C._jit_set_profiling_mode(False) 15 | 16 | @pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals") 17 | def test_training(): 18 | main([ 19 | '--save-frequency', '1', 20 | '--zeroshot-frequency', '1', 21 | '--dataset-type', "synthetic", 22 | '--train-num-samples', '16', 23 | '--warmup', '1', 24 | '--batch-size', '4', 25 | '--lr', '1e-3', 26 | '--wd', '0.1', 27 | '--epochs', '1', 28 | '--workers', '2', 29 | '--model', 'RN50' 30 | ]) 31 | 32 | @pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals") 33 | def test_training_coca(): 34 | main([ 35 | '--save-frequency', '1', 36 | '--zeroshot-frequency', '1', 37 | '--dataset-type', "synthetic", 38 | '--train-num-samples', '16', 39 | '--warmup', '1', 40 | '--batch-size', '4', 41 | '--lr', '1e-3', 42 | '--wd', '0.1', 43 | '--epochs', '1', 44 | '--workers', '2', 45 | '--model', 'coca_ViT-B-32' 46 | ]) 47 | 48 | @pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals") 49 | def test_training_mt5(): 50 | main([ 51 | '--save-frequency', '1', 52 | '--zeroshot-frequency', '1', 53 | '--dataset-type', "synthetic", 54 | '--train-num-samples', '16', 55 | '--warmup', '1', 56 | '--batch-size', '4', 57 | '--lr', '1e-3', 58 | '--wd', '0.1', 59 | '--epochs', '1', 60 | '--workers', '2', 61 | '--model', 'mt5-base-ViT-B-32', 62 | '--lock-text', 63 | '--lock-text-unlocked-layers', '2' 64 | ]) 65 | 66 | 67 | 68 | @pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals") 69 | def test_training_unfreezing_vit(): 70 | main([ 71 | '--save-frequency', '1', 72 | '--zeroshot-frequency', '1', 73 | '--dataset-type', "synthetic", 74 | '--train-num-samples', '16', 75 | '--warmup', '1', 76 | '--batch-size', '4', 77 | '--lr', '1e-3', 78 | '--wd', '0.1', 79 | '--epochs', '1', 80 | '--workers', '2', 81 | '--model', 'ViT-B-32', 82 | '--lock-image', 83 | '--lock-image-unlocked-groups', '5', 84 | '--accum-freq', '2' 85 | ]) 86 | 87 | 88 | @pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals") 89 | def test_training_clip_with_jit(): 90 | main([ 91 | '--save-frequency', '1', 92 | '--zeroshot-frequency', '1', 93 | '--dataset-type', "synthetic", 94 | '--train-num-samples', '16', 95 | '--warmup', '1', 96 | '--batch-size', '4', 97 | '--lr', '1e-3', 98 | '--wd', '0.1', 99 | '--epochs', '1', 100 | '--workers', '2', 101 | '--model', 'ViT-B-32', 102 | '--torchscript' 103 | ]) 104 | -------------------------------------------------------------------------------- /tests/test_wds.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | import util_test 4 | import collections 5 | import tarfile 6 | import io 7 | from PIL import Image 8 | 9 | from open_clip_train.data import get_wds_dataset 10 | from open_clip_train.params import parse_args 11 | from 
open_clip_train.main import random_seed 12 | 13 | TRAIN_NUM_SAMPLES = 10_000 14 | RTOL = 0.2 15 | 16 | # NOTE: we use two test tar files, which are created on the fly and saved to data/input. 17 | # 000.tar has 10 samples, and the captions are 000_0, 000_1, ..., 000_9 18 | # 001.tar has 5 samples, and the captions are 001_0, 001_1, ..., 001_4 19 | def build_inputs(test_name): 20 | base_input_dir, _ = util_test.get_data_dirs() 21 | input_dir = os.path.join(base_input_dir, test_name) 22 | os.makedirs(input_dir, exist_ok=True) 23 | 24 | def save_tar(idx, num_samples): 25 | filename = os.path.join(input_dir, f'test_data_{idx:03d}.tar') 26 | tar = tarfile.open(filename, 'w') 27 | 28 | for sample_idx in range(num_samples): 29 | # Image 30 | image = Image.new('RGB', (32, 32)) 31 | info = tarfile.TarInfo(f'{sample_idx}.png') 32 | bio = io.BytesIO() 33 | image.save(bio, format='png') 34 | size = bio.tell() 35 | bio.seek(0) 36 | info.size = size 37 | tar.addfile(info, bio) 38 | 39 | # Caption 40 | info = tarfile.TarInfo(f'{sample_idx}.txt') 41 | bio = io.BytesIO() 42 | bio.write(f'{idx:03d}_{sample_idx}'.encode('utf-8')) 43 | size = bio.tell() 44 | bio.seek(0) 45 | info.size = size 46 | tar.addfile(info, bio) 47 | 48 | tar.close() 49 | 50 | save_tar(0, 10) 51 | save_tar(1, 5) 52 | 53 | return input_dir 54 | 55 | 56 | def build_params(input_shards, seed=0): 57 | args = parse_args([]) 58 | args.train_data = input_shards 59 | args.train_num_samples = TRAIN_NUM_SAMPLES 60 | args.dataset_resampled = True 61 | args.seed = seed 62 | args.workers = 1 63 | args.world_size = 1 64 | args.batch_size = 1 65 | random_seed(seed) 66 | 67 | preprocess_img = lambda x: x 68 | tokenizer = lambda x: [x.strip()] 69 | 70 | return args, preprocess_img, tokenizer 71 | 72 | 73 | def get_dataloader(input_shards): 74 | args, preprocess_img, tokenizer = build_params(input_shards) 75 | dataset = get_wds_dataset(args, preprocess_img, is_train=True, tokenizer=tokenizer) 76 | dataloader = dataset.dataloader 77 | return dataloader 78 | 79 | 80 | def test_single_source(): 81 | """Test webdataset with a single tar file.""" 82 | input_dir = build_inputs('single_source') 83 | input_shards = os.path.join(input_dir, 'test_data_000.tar') 84 | dataloader = get_dataloader(input_shards) 85 | 86 | counts = collections.defaultdict(int) 87 | for sample in dataloader: 88 | txts = sample[1] 89 | for txt in txts: 90 | counts[txt] += 1 91 | 92 | for key, count in counts.items(): 93 | assert count == pytest.approx(TRAIN_NUM_SAMPLES / 10, RTOL) 94 | 95 | 96 | def test_two_sources(): 97 | """Test webdataset with two tar files.""" 98 | input_dir = build_inputs('two_sources') 99 | input_shards = os.path.join(input_dir, 'test_data_{000..001}.tar') 100 | dataloader = get_dataloader(input_shards) 101 | 102 | counts = collections.defaultdict(int) 103 | for sample in dataloader: 104 | txts = sample[1] 105 | for txt in txts: 106 | counts[txt] += 1 107 | 108 | for key, count in counts.items(): 109 | assert count == pytest.approx(TRAIN_NUM_SAMPLES / 15, RTOL), f'{key}, {count}' 110 | 111 | 112 | def test_two_sources_same_weights(): 113 | """Test webdataset with two tar files, using --train-data-upsampling-factors=1::1.""" 114 | input_dir = build_inputs('two_sources_same_weights') 115 | input_shards = f"{os.path.join(input_dir, 'test_data_000.tar')}::{os.path.join(input_dir, 'test_data_001.tar')}" 116 | args, preprocess_img, tokenizer = build_params(input_shards) 117 | args.train_data_upsampling_factors = '1::1' 118 | dataset = get_wds_dataset(args, 
preprocess_img, is_train=True, tokenizer=tokenizer) 119 | dataloader = dataset.dataloader 120 | 121 | counts = collections.defaultdict(int) 122 | for sample in dataloader: 123 | txts = sample[1] 124 | for txt in txts: 125 | counts[txt] += 1 126 | 127 | for key, count in counts.items(): 128 | assert count == pytest.approx(TRAIN_NUM_SAMPLES / 15, RTOL), f'{key}, {count}' 129 | 130 | def test_two_sources_with_upsampling(): 131 | """Test webdataset with two tar files with upsampling.""" 132 | input_dir = build_inputs('two_sources_with_upsampling') 133 | input_shards = f"{os.path.join(input_dir, 'test_data_000.tar')}::{os.path.join(input_dir, 'test_data_001.tar')}" 134 | args, preprocess_img, tokenizer = build_params(input_shards) 135 | args.train_data_upsampling_factors = '1::2' 136 | dataset = get_wds_dataset(args, preprocess_img, is_train=True, tokenizer=tokenizer) 137 | dataloader = dataset.dataloader 138 | 139 | counts = collections.defaultdict(int) 140 | for sample in dataloader: 141 | txts = sample[1] 142 | for txt in txts: 143 | counts[txt] += 1 144 | 145 | for key, count in counts.items(): 146 | if key.startswith('000'): 147 | assert count == pytest.approx(TRAIN_NUM_SAMPLES / 20, RTOL), f'{key}, {count}' 148 | else: 149 | assert count == pytest.approx(TRAIN_NUM_SAMPLES / 10, RTOL), f'{key}, {count}' 150 | --------------------------------------------------------------------------------
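# Note on the expected counts asserted in tests/test_wds.py above (a minimal sketch, not part of
# the repository): the asserted values are consistent with each '::'-separated source being drawn
# with weight proportional to its sample count multiplied by its upsampling factor. That weighting
# rule is an assumption read off the assertions themselves, not a claim about how get_wds_dataset
# is implemented, and the helper below is illustrative only.
#
# With TRAIN_NUM_SAMPLES = 10_000, source 000 holds 10 captions and source 001 holds 5:
#   * factors '1::1' give weights 10 and 5, i.e. uniform over all 15 captions -> ~10_000 / 15 each;
#   * factors '1::2' give weights 10 and 10, i.e. half the draws per source -> ~10_000 / 20 per
#     caption of 000 and ~10_000 / 10 per caption of 001.

TRAIN_NUM_SAMPLES = 10_000
samples_per_source = {'000': 10, '001': 5}

def expected_count_per_caption(upsampling_factors):
    # Weight of a source = (samples in source) * (its upsampling factor); the captions within a
    # source then share that source's draws equally.
    weights = {k: samples_per_source[k] * upsampling_factors[k] for k in samples_per_source}
    total = sum(weights.values())
    return {k: TRAIN_NUM_SAMPLES * weights[k] / (total * samples_per_source[k]) for k in weights}

assert expected_count_per_caption({'000': 1, '001': 1}) == {'000': 10_000 / 15, '001': 10_000 / 15}
assert expected_count_per_caption({'000': 1, '001': 2}) == {'000': 10_000 / 20, '001': 10_000 / 10}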