├── open_clip_customized ├── src │ ├── training │ │ ├── __init__.py │ │ ├── .gitignore │ │ ├── precision.py │ │ ├── logger.py │ │ ├── scheduler.py │ │ ├── file_utils.py │ │ ├── zero_shot.py │ │ └── distributed.py │ └── open_clip │ │ ├── version.py │ │ ├── bpe_simple_vocab_16e6.txt.gz │ │ ├── constants.py │ │ ├── model_configs │ │ ├── ViT-B-16.json │ │ ├── ViT-B-32.json │ │ ├── ViT-M-16.json │ │ ├── ViT-M-32.json │ │ ├── ViT-S-16.json │ │ ├── ViT-S-32.json │ │ ├── ViT-B-16-plus.json │ │ ├── ViT-L-14-280.json │ │ ├── ViT-L-14-336.json │ │ ├── ViT-L-14.json │ │ ├── ViT-L-16-320.json │ │ ├── ViT-L-16.json │ │ ├── ViT-M-32-alt.json │ │ ├── ViT-S-16-alt.json │ │ ├── ViT-S-32-alt.json │ │ ├── ViT-B-16-plus-240.json │ │ ├── ViT-B-32-256.json │ │ ├── ViT-B-32-plus-256.json │ │ ├── mt5-base-ViT-B-32.json │ │ ├── xlm-roberta-base-ViT-B-32.json │ │ ├── latent_ViT.json │ │ ├── ViT-H-14.json │ │ ├── ViT-H-16.json │ │ ├── ViT-B-16-quickgelu.json │ │ ├── ViT-B-32-quickgelu.json │ │ ├── ViT-L-14-quickgelu.json │ │ ├── ViT-M-16-alt.json │ │ ├── roberta-ViT-B-32.json │ │ ├── mt5-xl-ViT-H-14.json │ │ ├── xlm-roberta-large-ViT-H-14.json │ │ ├── ViT-e-14.json │ │ ├── ViT-g-14.json │ │ ├── ViT-H-14-quickgelu.json │ │ ├── ViT-bigG-14.json │ │ ├── ViT-H-14-378-quickgelu.json │ │ ├── nllb-clip-base.json │ │ ├── RN50.json │ │ ├── RN101.json │ │ ├── RN50x16.json │ │ ├── RN50x4.json │ │ ├── RN50x64.json │ │ ├── vit_medium_patch16_gap_256.json │ │ ├── swin_base_patch4_window7_224.json │ │ ├── nllb-clip-large.json │ │ ├── EVA01-g-14.json │ │ ├── vit_relpos_medium_patch16_cls_224.json │ │ ├── EVA01-g-14-plus.json │ │ ├── EVA02-B-16.json │ │ ├── EVA02-L-14.json │ │ ├── EVA02-E-14.json │ │ ├── EVA02-L-14-336.json │ │ ├── RN101-quickgelu.json │ │ ├── latent_RN50.json │ │ ├── EVA02-E-14-plus.json │ │ ├── RN50-quickgelu.json │ │ ├── convnext_base.json │ │ ├── convnext_base_w.json │ │ ├── convnext_large.json │ │ ├── convnext_large_d.json │ │ ├── convnext_small.json │ │ ├── convnext_tiny.json │ │ ├── convnext_base_w_320.json │ │ ├── convnext_large_d_320.json │ │ ├── convnext_xlarge.json │ │ ├── convnext_xxlarge.json │ │ ├── convnext_xxlarge_320.json │ │ ├── nllb-clip-base-siglip.json │ │ ├── nllb-clip-large-siglip.json │ │ ├── coca_roberta-ViT-B-32.json │ │ ├── ViT-L-14-CLIPA.json │ │ ├── ViT-L-14-CLIPA-336.json │ │ ├── ViT-H-14-CLIPA.json │ │ ├── ViT-H-14-CLIPA-336.json │ │ ├── ViT-bigG-14-CLIPA.json │ │ ├── ViT-bigG-14-CLIPA-336.json │ │ ├── coca_ViT-B-32.json │ │ ├── coca_ViT-L-14.json │ │ ├── coca_base.json │ │ ├── ViT-B-16-SigLIP.json │ │ ├── ViT-B-16-SigLIP-256.json │ │ ├── ViT-B-16-SigLIP-384.json │ │ ├── ViT-B-16-SigLIP-512.json │ │ ├── ViT-L-16-SigLIP-256.json │ │ ├── ViT-L-16-SigLIP-384.json │ │ ├── ViT-B-16-SigLIP-i18n-256.json │ │ ├── ViT-SO400M-14-SigLIP.json │ │ └── ViT-SO400M-14-SigLIP-384.json │ │ ├── __init__.py │ │ ├── hf_configs.py │ │ ├── openai.py │ │ ├── utils.py │ │ ├── pos_embed.py │ │ ├── zero_shot_classifier.py │ │ ├── timm_model.py │ │ ├── hf_model.py │ │ ├── modified_resnet.py │ │ └── big_vision.py ├── pytest.ini ├── requirements-test.txt ├── docs │ ├── CLIP.png │ ├── scaling.png │ ├── clip_loss.png │ ├── clip_recall.png │ ├── clip_val_loss.png │ ├── clip_zeroshot.png │ ├── clipa_acc_compute.png │ ├── effective_robustness.png │ ├── inverse_scaling_law.png │ ├── laion_clip_zeroshot.png │ ├── clipa_reduce_text_token.png │ ├── laion_clip_zeroshot_b16.png │ ├── laion_clip_zeroshot_l14.png │ ├── clipa_reduce_image_token.png │ ├── laion2b_clip_zeroshot_b32.png │ ├── laion_openai_compare_b32.jpg │ ├── 
laion_clip_zeroshot_b16_plus_240.png │ ├── clip_conceptual_captions.md │ ├── script_examples │ │ ├── clipa │ │ │ ├── vit_b16 │ │ │ │ ├── i50_t16_pretrain.sh │ │ │ │ └── i50_t16_finetune.sh │ │ │ └── vit_l16 │ │ │ │ ├── i17_t16_pretrain.sh │ │ │ │ ├── i37_t8_pretrain.sh │ │ │ │ ├── i37_t8_finetune.sh │ │ │ │ └── i17_t16_finetune.sh │ │ ├── clipav2 │ │ │ └── vit_h14 │ │ │ │ ├── i50_t8_pretrain.sh │ │ │ │ ├── i257_t32_finetunex4.sh │ │ │ │ └── i577_t32_finetunex1.sh │ │ └── stability_example.sh │ ├── LOW_ACC.md │ ├── model_profile.csv │ ├── datacomp_models.md │ └── clipa.md ├── MANIFEST.in ├── requirements.txt ├── requirements-training.txt ├── scripts │ ├── clipav1_vit_l16_i37_t8.sh │ ├── clipav2_vit_h14_i84_224_336_cl32_gap_datacomp1b.sh │ ├── h14_224_32_finetune.sh │ └── h14_84_8_pretrain.sh ├── Makefile ├── tests │ ├── test_num_shards.py │ ├── test_hf_model.py │ ├── test_inference_simple.py │ ├── test_training_simple.py │ ├── test_wds.py │ ├── test_inference.py │ └── test_download_pretrained.py ├── CITATION.cff ├── LICENSE ├── setup.py ├── .gitignore └── HISTORY.md ├── local_gradio └── style.css ├── ddim_solver.py ├── README.md └── dataset.py /open_clip_customized/src/training/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /open_clip_customized/src/training/.gitignore: -------------------------------------------------------------------------------- 1 | logs/ 2 | -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '2.24.0' 2 | -------------------------------------------------------------------------------- /open_clip_customized/pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | markers = 3 | regression_test 4 | -------------------------------------------------------------------------------- /open_clip_customized/requirements-test.txt: -------------------------------------------------------------------------------- 1 | pytest-split==0.8.0 2 | pytest==7.2.0 3 | transformers 4 | timm>=0.9.8 5 | -------------------------------------------------------------------------------- /open_clip_customized/docs/CLIP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ji4chenLi/rg-lcd/HEAD/open_clip_customized/docs/CLIP.png -------------------------------------------------------------------------------- /open_clip_customized/docs/scaling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ji4chenLi/rg-lcd/HEAD/open_clip_customized/docs/scaling.png -------------------------------------------------------------------------------- /open_clip_customized/docs/clip_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ji4chenLi/rg-lcd/HEAD/open_clip_customized/docs/clip_loss.png -------------------------------------------------------------------------------- /open_clip_customized/docs/clip_recall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ji4chenLi/rg-lcd/HEAD/open_clip_customized/docs/clip_recall.png 
-------------------------------------------------------------------------------- /open_clip_customized/docs/clip_val_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ji4chenLi/rg-lcd/HEAD/open_clip_customized/docs/clip_val_loss.png -------------------------------------------------------------------------------- /open_clip_customized/docs/clip_zeroshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ji4chenLi/rg-lcd/HEAD/open_clip_customized/docs/clip_zeroshot.png -------------------------------------------------------------------------------- /open_clip_customized/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include src/open_clip/bpe_simple_vocab_16e6.txt.gz 2 | include src/open_clip/model_configs/*.json 3 | 4 | -------------------------------------------------------------------------------- /open_clip_customized/docs/clipa_acc_compute.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ji4chenLi/rg-lcd/HEAD/open_clip_customized/docs/clipa_acc_compute.png -------------------------------------------------------------------------------- /open_clip_customized/docs/effective_robustness.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ji4chenLi/rg-lcd/HEAD/open_clip_customized/docs/effective_robustness.png -------------------------------------------------------------------------------- /open_clip_customized/docs/inverse_scaling_law.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ji4chenLi/rg-lcd/HEAD/open_clip_customized/docs/inverse_scaling_law.png -------------------------------------------------------------------------------- /open_clip_customized/docs/laion_clip_zeroshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ji4chenLi/rg-lcd/HEAD/open_clip_customized/docs/laion_clip_zeroshot.png -------------------------------------------------------------------------------- /open_clip_customized/docs/clipa_reduce_text_token.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ji4chenLi/rg-lcd/HEAD/open_clip_customized/docs/clipa_reduce_text_token.png -------------------------------------------------------------------------------- /open_clip_customized/docs/laion_clip_zeroshot_b16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ji4chenLi/rg-lcd/HEAD/open_clip_customized/docs/laion_clip_zeroshot_b16.png -------------------------------------------------------------------------------- /open_clip_customized/docs/laion_clip_zeroshot_l14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ji4chenLi/rg-lcd/HEAD/open_clip_customized/docs/laion_clip_zeroshot_l14.png -------------------------------------------------------------------------------- /open_clip_customized/docs/clipa_reduce_image_token.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ji4chenLi/rg-lcd/HEAD/open_clip_customized/docs/clipa_reduce_image_token.png 
-------------------------------------------------------------------------------- /open_clip_customized/docs/laion2b_clip_zeroshot_b32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ji4chenLi/rg-lcd/HEAD/open_clip_customized/docs/laion2b_clip_zeroshot_b32.png -------------------------------------------------------------------------------- /open_clip_customized/docs/laion_openai_compare_b32.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ji4chenLi/rg-lcd/HEAD/open_clip_customized/docs/laion_openai_compare_b32.jpg -------------------------------------------------------------------------------- /open_clip_customized/requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.9.0 2 | torchvision 3 | regex 4 | ftfy 5 | tqdm 6 | huggingface_hub 7 | sentencepiece 8 | protobuf 9 | timm 10 | -------------------------------------------------------------------------------- /open_clip_customized/docs/laion_clip_zeroshot_b16_plus_240.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ji4chenLi/rg-lcd/HEAD/open_clip_customized/docs/laion_clip_zeroshot_b16_plus_240.png -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ji4chenLi/rg-lcd/HEAD/open_clip_customized/src/open_clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /open_clip_customized/requirements-training.txt: -------------------------------------------------------------------------------- 1 | torch>=1.9.0 2 | torchvision 3 | webdataset>=0.2.5 4 | regex 5 | ftfy 6 | tqdm 7 | pandas 8 | braceexpand 9 | huggingface_hub 10 | transformers 11 | timm>=0.9.8 12 | fsspec 13 | -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/constants.py: -------------------------------------------------------------------------------- 1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) 2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) 3 | IMAGENET_MEAN = (0.485, 0.456, 0.406) 4 | IMAGENET_STD = (0.229, 0.224, 0.225) 5 | INCEPTION_MEAN = (0.5, 0.5, 0.5) 6 | INCEPTION_STD = (0.5, 0.5, 0.5) 7 | -------------------------------------------------------------------------------- /local_gradio/style.css: -------------------------------------------------------------------------------- 1 | h1 { 2 | text-align: center; 3 | } 4 | 5 | #duplicate-button { 6 | margin: auto; 7 | color: #fff; 8 | background: #1565c0; 9 | border-radius: 100vh; 10 | } 11 | 12 | #component-0 { 13 | max-width: 830px; 14 | margin: auto; 15 | padding-top: 1.5rem; 16 | } -------------------------------------------------------------------------------- /open_clip_customized/scripts/clipav1_vit_l16_i37_t8.sh: -------------------------------------------------------------------------------- 1 | # eval on a single gpu 2 | CUDA_VISIBLE_DEVICES=2 TORCH_CUDNN_V8_API_ENABLED=1 TFDS_PREFETCH_SIZE=8192 python3 -m training.main \ 3 | --model ViT-L-16-CL32-GAP \ 4 | --pretrained "/path/to/clipa_vit_l16_i37_t8.pt" \ 5 | --seed 0 \ 6 | --imagenet-val '/path/to/ImageNet/val' 
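The script above runs zero-shot ImageNet evaluation through `training.main`, naming a model config and a local checkpoint via `--model` and `--pretrained`. A minimal sketch of the same zero-shot recipe, written directly against the `open_clip` package, follows; it assumes this customized fork keeps upstream open_clip's factory API (`create_model_and_transforms`, `get_tokenizer`), and the model name, class list, and prompt templates are illustrative placeholders rather than values taken from this repo.

import torch
import open_clip

# Hedged sketch of the zero-shot recipe: average prompt embeddings per class to form a
# classifier, then score normalized image features against it.
model, _, preprocess_val = open_clip.create_model_and_transforms(
    'ViT-B-32',       # placeholder; any filename under src/open_clip/model_configs/ is a valid name
    pretrained=None,  # a local checkpoint path could be passed here, mirroring --pretrained above
)
tokenizer = open_clip.get_tokenizer('ViT-B-32')
model.eval()

classnames = ['dog', 'cat']                                   # placeholder class names
templates = ['a photo of a {}.', 'a blurry photo of a {}.']   # placeholder prompt templates

with torch.no_grad():
    class_embeddings = []
    for name in classnames:
        tokens = tokenizer([t.format(name) for t in templates])
        emb = model.encode_text(tokens)
        emb = emb / emb.norm(dim=-1, keepdim=True)
        mean_emb = emb.mean(dim=0)
        class_embeddings.append(mean_emb / mean_emb.norm())
    classifier = torch.stack(class_embeddings, dim=1)         # (embed_dim, num_classes)

    # preprocess_val would be applied to real PIL images; random tensors stand in here.
    images = torch.randn(4, 3, 224, 224)
    feats = model.encode_image(images)
    feats = feats / feats.norm(dim=-1, keepdim=True)
    logits = 100.0 * feats @ classifier
    predictions = logits.argmax(dim=-1)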
-------------------------------------------------------------------------------- /open_clip_customized/docs/clip_conceptual_captions.md: -------------------------------------------------------------------------------- 1 | ## Additional training curves for CLIP on Conceptual Captions 2 | 3 | # Zero shot accuracy 4 | ![](/docs/clip_zeroshot.png) 5 | 6 | # Training loss curve 7 | ![](/docs/clip_loss.png) 8 | 9 | # Validation loss curve 10 | ![](/docs/clip_val_loss.png) 11 | 12 | # Validation recall 13 | ![](/docs/clip_recall.png) -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-M-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-M-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-S-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 384, 13 | "heads": 6, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-S-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 384, 13 | "heads": 6, 14 | "layers": 12 15 | } 16 | } 
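Each JSON above (and the rest of model_configs/ listed in the tree) defines one named architecture: an `embed_dim` plus `vision_cfg` and `text_cfg` blocks, with the filename doubling as the model name. A small sketch of how such a config can be looked up, assuming the fork keeps upstream open_clip's factory helpers (`list_models`, `get_model_config`, `add_model_config`):

import json
import open_clip

# Hedged sketch: model names are resolved from the JSON filenames under
# src/open_clip/model_configs/, so 'ViT-S-32' maps to the ViT-S-32.json shown above.
print('ViT-S-32' in open_clip.list_models())

cfg = open_clip.get_model_config('ViT-S-32')   # dict mirroring the JSON contents
print(json.dumps(cfg, indent=2))

# Additional config files (e.g. a copy of latent_ViT.json kept outside the package)
# could be registered at runtime; the path below is a placeholder.
# open_clip.add_model_config('/path/to/extra/model_configs')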
-------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-B-16-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-L-14-280.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 280, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-L-16-320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 320, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-L-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-M-32-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 384, 13 | "heads": 6, 14 | "layers": 12 15 | } 16 | } 
-------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-S-16-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 256, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 256, 13 | "heads": 4, 14 | "layers": 10 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-S-32-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 256, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 256, 13 | "heads": 4, 14 | "layers": 10 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-B-16-plus-240.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 240, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-B-32-256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 256, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-B-32-plus-256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 256, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/mt5-base-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "hf_model_name": "google/mt5-base", 11 | "hf_tokenizer_name": "google/mt5-base", 12 | "hf_pooler_type": "mean_pooler" 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "hf_model_name": "xlm-roberta-base", 11 | "hf_tokenizer_name": 
"xlm-roberta-base", 12 | "hf_pooler_type": "mean_pooler" 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/latent_ViT.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "in_chans": 4, 5 | "image_size": 64, 6 | "layers": 12, 7 | "width": 512, 8 | "patch_size": 4 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-H-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-H-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 16 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-B-16-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 16 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-B-32-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 32 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-L-14-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 24, 7 | "width": 1024, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 768, 14 | "heads": 12, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-M-16-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | 
"image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 16, 8 | "ls_init_value": 1e-4 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 384, 14 | "heads": 6, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/roberta-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 32 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "roberta-base", 12 | "hf_tokenizer_name": "roberta-base", 13 | "hf_pooler_type": "mean_pooler" 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/mt5-xl-ViT-H-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "google/mt5-xl", 12 | "hf_tokenizer_name": "google/mt5-xl", 13 | "hf_pooler_type": "mean_pooler" 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /open_clip_customized/Makefile: -------------------------------------------------------------------------------- 1 | install: ## [Local development] Upgrade pip, install requirements, install package. 2 | python -m pip install -U pip 3 | python -m pip install -e . 4 | 5 | install-training: 6 | python -m pip install -r requirements-training.txt 7 | 8 | install-test: ## [Local development] Install test requirements 9 | python -m pip install -r requirements-test.txt 10 | 11 | test: ## [Local development] Run unit tests 12 | python -m pytest -x -s -v tests 13 | -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "xlm-roberta-large", 12 | "hf_tokenizer_name": "xlm-roberta-large", 13 | "hf_pooler_type": "mean_pooler" 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /open_clip_customized/src/training/precision.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from contextlib import suppress 3 | 4 | 5 | def get_autocast(precision): 6 | if precision == 'amp': 7 | return torch.cuda.amp.autocast 8 | elif precision == 'amp_bfloat16' or precision == 'amp_bf16': 9 | # amp_bfloat16 is more stable than amp float16 for clip training 10 | return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16) 11 | else: 12 | return suppress 13 | -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-e-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 56, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.5715, 9 | 
"patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1280, 15 | "heads": 20, 16 | "layers": 36 17 | } 18 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1024, 15 | "heads": 16, 16 | "layers": 24 17 | } 18 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-H-14-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 32, 7 | "width": 1280, 8 | "head_width": 80, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1024, 15 | "heads": 16, 16 | "layers": 24 17 | } 18 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-bigG-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 48, 6 | "width": 1664, 7 | "head_width": 104, 8 | "mlp_ratio": 4.9231, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1280, 15 | "heads": 20, 16 | "layers": 32 17 | } 18 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-H-14-378-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 378, 6 | "layers": 32, 7 | "width": 1280, 8 | "head_width": 80, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1024, 15 | "heads": 16, 16 | "layers": 24 17 | } 18 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/nllb-clip-base.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "hf_model_name": "facebook/nllb-200-distilled-600M", 11 | "hf_tokenizer_name": "facebook/nllb-200-distilled-600M", 12 | "hf_proj_type": "linear", 13 | "hf_pooler_type": "cls_pooler" 14 | } 15 | } -------------------------------------------------------------------------------- /open_clip_customized/scripts/clipav2_vit_h14_i84_224_336_cl32_gap_datacomp1b.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=1 python3 -m training.main \ 2 | --model ViT-H-14-CL32-GAP-BigVision \ 3 | --pretrained "/path/to/vit_h14_i84_224_336_cl32_gap_datacomp1b.pt" \ 4 | --force-image-size 336 \ 5 | --square-resize-only \ 6 | --interpolation 'bilinear' \ 7 | --image-mean 0.485 0.456 0.406 \ 8 | --image-std 
0.229 0.224 0.225 \ 9 | --seed 0 \ 10 | --imagenet-val '/path/to/ImageNet/val' 11 | -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/RN50.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 6, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/RN101.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 23, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/RN50x16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 384, 5 | "layers": [ 6 | 6, 7 | 8, 8 | 18, 9 | 8 10 | ], 11 | "width": 96, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 768, 18 | "heads": 12, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/RN50x4.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 288, 5 | "layers": [ 6 | 4, 7 | 6, 8 | 10, 9 | 6 10 | ], 11 | "width": 80, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 640, 18 | "heads": 10, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/RN50x64.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 448, 5 | "layers": [ 6 | 3, 7 | 15, 8 | 36, 9 | 10 10 | ], 11 | "width": 128, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 1024, 18 | "heads": 16, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/vit_medium_patch16_gap_256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "vit_medium_patch16_gap_256", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 256 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/swin_base_patch4_window7_224.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "timm_model_name": "swin_base_patch4_window7_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 640, 14 | "heads": 10, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/nllb-clip-large.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "facebook/nllb-200-distilled-1.3B", 12 | "hf_tokenizer_name": "facebook/nllb-200-distilled-1.3B", 13 | "hf_proj_type": "linear", 14 | "hf_pooler_type": "cls_pooler" 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/EVA01-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "timm_model_name": "eva_giant_patch14_224", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 768, 14 | "heads": 12, 15 | "layers": 12 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "vit_relpos_medium_patch16_cls_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/EVA01-g-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "timm_model_name": "eva_giant_patch14_224", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/EVA02-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "timm_model_name": "eva02_base_patch16_clip_224", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | }, 17 | "custom_text": true 18 | } 
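EVA02-B-16 above, like the other configs whose `vision_cfg` sets `timm_model_name` (the EVA, ConvNeXt, Swin, and SigLIP entries), builds its image tower through `timm` rather than the built-in VisionTransformer, which is why the requirements pin `timm>=0.9.8`. A quick check, assuming upstream open_clip's factory behavior is unchanged in this fork:

import open_clip
from open_clip.timm_model import TimmModel

# Hedged sketch: configs with "timm_model_name" route the vision tower through timm.
model, _, _ = open_clip.create_model_and_transforms('EVA02-B-16', pretrained=None)
print(isinstance(model.visual, TimmModel))   # expected: True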
-------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/EVA02-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "timm_model_name": "eva02_large_patch14_clip_224", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 768, 14 | "heads": 12, 15 | "layers": 12 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/EVA02-E-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "timm_model_name": "eva02_enormous_patch14_clip_224", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/EVA02-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "timm_model_name": "eva02_large_patch14_clip_336", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 768, 14 | "heads": 12, 15 | "layers": 12 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/RN101-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 23, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/latent_RN50.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "in_chans": 4, 5 | "image_size": 64, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 6, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 1024, 19 | "heads": 16, 20 | "layers": 24 21 | } 22 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/EVA02-E-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "timm_model_name": "eva02_enormous_patch14_clip_224", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1280, 14 | "heads": 20, 15 | 
"layers": 32 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/RN50-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 6, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/convnext_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_base", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 224 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/convnext_base_w.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_base", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 256 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 640, 16 | "heads": 10, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/convnext_large.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_large", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 224 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 768, 16 | "heads": 12, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/convnext_large_d.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_large", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "mlp", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 256 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 768, 16 | "heads": 12, 17 | "layers": 16 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/convnext_small.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_small", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | 
"timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 224 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/convnext_tiny.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_tiny", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 224 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/convnext_base_w_320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_base", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 320 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 640, 16 | "heads": 10, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/convnext_large_d_320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_large", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "mlp", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 320 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 768, 16 | "heads": 12, 17 | "layers": 16 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/convnext_xlarge.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_xlarge", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 256 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 1024, 16 | "heads": 16, 17 | "layers": 20 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/convnext_xxlarge.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_xxlarge", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 256 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 1024, 16 | "heads": 16, 17 | "layers": 24 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/convnext_xxlarge_320.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_xxlarge", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 320 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 1024, 16 | "heads": 16, 17 | "layers": 24 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/nllb-clip-base-siglip.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "custom_text": true, 4 | "init_logit_bias": -10, 5 | "vision_cfg": { 6 | "image_size": 384, 7 | "timm_model_name": "vit_base_patch16_siglip_384", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "hf_model_name": "facebook/nllb-200-distilled-600M", 14 | "hf_tokenizer_name": "facebook/nllb-200-distilled-600M", 15 | "hf_proj_type": "linear", 16 | "hf_pooler_type": "cls_pooler" 17 | } 18 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/nllb-clip-large-siglip.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1152, 3 | "custom_text": true, 4 | "init_logit_bias": -10, 5 | "vision_cfg": { 6 | "image_size": 384, 7 | "timm_model_name": "vit_so400m_patch14_siglip_384", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "hf_model_name": "facebook/nllb-200-distilled-1.3B", 14 | "hf_tokenizer_name": "facebook/nllb-200-distilled-1.3B", 15 | "hf_proj_type": "linear", 16 | "hf_pooler_type": "cls_pooler" 17 | } 18 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/coca_roberta-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32, 8 | "output_tokens": true 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "roberta-base", 12 | "hf_tokenizer_name": "roberta-base", 13 | "hf_proj_type": "linear", 14 | "width": 768, 15 | "output_tokens": true 16 | }, 17 | "multimodal_cfg": { 18 | "context_length": 76, 19 | "width": 768, 20 | "heads": 8, 21 | "layers": 12 22 | }, 23 | "custom_text": true 24 | } 25 | -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-L-14-CLIPA.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14, 8 | "no_ln_pre": true, 9 | "pool_type": "avg", 10 | "final_ln_after_pool": true 11 | }, 12 | "text_cfg": { 13 | "context_length": 32, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "bert-base-uncased", 16 | "tokenizer_kwargs": { 17 | "strip_sep_token": true 18 | }, 19 | "width": 768, 20 | "heads": 12, 21 | "layers": 12, 22 | "pool_type": "last", 23 | "no_causal_mask": true 24 | } 25 | } -------------------------------------------------------------------------------- 
/open_clip_customized/src/open_clip/model_configs/ViT-L-14-CLIPA-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14, 8 | "no_ln_pre": true, 9 | "pool_type": "avg", 10 | "final_ln_after_pool": true 11 | }, 12 | "text_cfg": { 13 | "context_length": 32, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "bert-base-uncased", 16 | "tokenizer_kwargs": { 17 | "strip_sep_token": true 18 | }, 19 | "width": 768, 20 | "heads": 12, 21 | "layers": 12, 22 | "pool_type": "last", 23 | "no_causal_mask": true 24 | } 25 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-H-14-CLIPA.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14, 9 | "no_ln_pre": true, 10 | "pool_type": "avg", 11 | "final_ln_after_pool": true 12 | }, 13 | "text_cfg": { 14 | "context_length": 32, 15 | "vocab_size": 32000, 16 | "hf_tokenizer_name": "bert-base-uncased", 17 | "tokenizer_kwargs": { 18 | "strip_sep_token": true 19 | }, 20 | "width": 1024, 21 | "heads": 16, 22 | "layers": 24, 23 | "pool_type": "last", 24 | "no_causal_mask": true 25 | } 26 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-H-14-CLIPA-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14, 9 | "no_ln_pre": true, 10 | "pool_type": "avg", 11 | "final_ln_after_pool": true 12 | }, 13 | "text_cfg": { 14 | "context_length": 32, 15 | "vocab_size": 32000, 16 | "hf_tokenizer_name": "bert-base-uncased", 17 | "tokenizer_kwargs": { 18 | "strip_sep_token": true 19 | }, 20 | "width": 1024, 21 | "heads": 16, 22 | "layers": 24, 23 | "pool_type": "last", 24 | "no_causal_mask": true 25 | } 26 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-bigG-14-CLIPA.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 48, 6 | "width": 1664, 7 | "head_width": 104, 8 | "mlp_ratio": 4.9231, 9 | "patch_size": 14, 10 | "no_ln_pre": true, 11 | "pool_type": "avg", 12 | "final_ln_after_pool": true 13 | }, 14 | "text_cfg": { 15 | "context_length": 32, 16 | "vocab_size": 32000, 17 | "hf_tokenizer_name": "bert-base-uncased", 18 | "tokenizer_kwargs": { 19 | "strip_sep_token": true 20 | }, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "pool_type": "last", 25 | "no_causal_mask": true 26 | } 27 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-bigG-14-CLIPA-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 48, 6 | "width": 1664, 7 | "head_width": 104, 8 | "mlp_ratio": 4.9231, 9 | "patch_size": 14, 10 | "no_ln_pre": true, 11 | "pool_type": "avg", 12 | 
"final_ln_after_pool": true 13 | }, 14 | "text_cfg": { 15 | "context_length": 32, 16 | "vocab_size": 32000, 17 | "hf_tokenizer_name": "bert-base-uncased", 18 | "tokenizer_kwargs": { 19 | "strip_sep_token": true 20 | }, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "pool_type": "last", 25 | "no_causal_mask": true 26 | } 27 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/coca_ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32, 8 | "attentional_pool": true, 9 | "attn_pooler_heads": 8, 10 | "output_tokens": true 11 | }, 12 | "text_cfg": { 13 | "context_length": 76, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12, 18 | "embed_cls": true, 19 | "output_tokens": true 20 | }, 21 | "multimodal_cfg": { 22 | "context_length": 76, 23 | "vocab_size": 49408, 24 | "width": 512, 25 | "heads": 8, 26 | "layers": 12, 27 | "attn_pooler_heads": 8 28 | }, 29 | "custom_text": true 30 | } -------------------------------------------------------------------------------- /open_clip_customized/docs/script_examples/clipa/vit_b16/i50_t16_pretrain.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node 8 -m training.main \ 2 | --save-frequency 1 \ 3 | --save-most-recent \ 4 | --zeroshot-frequency 1 \ 5 | --train-data '/path/to/laion-400m' \ 6 | --dataset-type webdataset \ 7 | --lr "2.048e-3" \ 8 | --beta1 0.9 \ 9 | --beta2 0.95 \ 10 | --warmup 782 \ 11 | --wd 0.2 \ 12 | --batch-size 8192 \ 13 | --aug-cfg scale='(0.4, 1.0)' \ 14 | --epochs 6 \ 15 | --workers 6 \ 16 | --model ViT-B-16-CL16 \ 17 | --precision 'amp_bf16' \ 18 | --ddp-static-graph \ 19 | --local-loss \ 20 | --gather-with-grad \ 21 | --force-image-size 112 \ 22 | --grad-checkpointing \ 23 | --log-every-n-steps 32 \ 24 | --seed 0 \ 25 | --logs ./logs/ \ 26 | --imagenet-val '/path/to/imagenet/val' -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/coca_ViT-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14, 8 | "attentional_pool": true, 9 | "attn_pooler_heads": 8, 10 | "output_tokens": true 11 | }, 12 | "text_cfg": { 13 | "context_length": 76, 14 | "vocab_size": 49408, 15 | "width": 768, 16 | "heads": 12, 17 | "layers": 12, 18 | "embed_cls": true, 19 | "output_tokens": true 20 | }, 21 | "multimodal_cfg": { 22 | "context_length": 76, 23 | "vocab_size": 49408, 24 | "width": 768, 25 | "heads": 12, 26 | "layers": 12, 27 | "attn_pooler_heads": 12 28 | }, 29 | "custom_text": true 30 | } 31 | -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/coca_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "multimodal_cfg": { 4 | "width": 768, 5 | "context_length": 76, 6 | "vocab_size": 64000, 7 | "mlp_ratio": 4, 8 | "layers": 12, 9 | "dim_head": 64, 10 | "heads": 12, 11 | "n_queries": 256, 12 | "attn_pooler_heads": 8 13 | }, 14 | "vision_cfg": { 15 | "image_size": 288, 16 | "layers": 12, 17 | "width": 768, 18 | 
"patch_size": 18, 19 | "output_tokens": true 20 | }, 21 | "text_cfg": { 22 | "context_length": 76, 23 | "vocab_size": 64000, 24 | "layers": 12, 25 | "heads": 12, 26 | "width": 768, 27 | "embed_cls": true, 28 | "output_tokens": true 29 | }, 30 | "custom_text": true 31 | } -------------------------------------------------------------------------------- /open_clip_customized/tests/test_num_shards.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from training.data import get_dataset_size 4 | 5 | @pytest.mark.parametrize( 6 | "shards,expected_size", 7 | [ 8 | ('/path/to/shard.tar', 1), 9 | ('/path/to/shard_{000..000}.tar', 1), 10 | ('/path/to/shard_{000..009}.tar', 10), 11 | ('/path/to/shard_{000..009}_{000..009}.tar', 100), 12 | ('/path/to/shard.tar::/path/to/other_shard_{000..009}.tar', 11), 13 | ('/path/to/shard_{000..009}.tar::/path/to/other_shard_{000..009}.tar', 20), 14 | (['/path/to/shard.tar'], 1), 15 | (['/path/to/shard.tar', '/path/to/other_shard.tar'], 2), 16 | ] 17 | ) 18 | def test_num_shards(shards, expected_size): 19 | _, size = get_dataset_size(shards) 20 | assert size == expected_size, f'Expected {expected_size} for {shards} but found {size} instead.' 21 | -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-B-16-SigLIP.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "init_logit_bias": -10, 4 | "custom_text": true, 5 | "vision_cfg": { 6 | "image_size": 224, 7 | "timm_model_name": "vit_base_patch16_siglip_224", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "context_length": 64, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", 16 | "tokenizer_kwargs": { 17 | "clean": "canonicalize" 18 | }, 19 | "width": 768, 20 | "heads": 12, 21 | "layers": 12, 22 | "no_causal_mask": true, 23 | "proj_bias": true, 24 | "pool_type": "last", 25 | "norm_kwargs":{ 26 | "eps": 1e-6 27 | } 28 | } 29 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-B-16-SigLIP-256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "init_logit_bias": -10, 4 | "custom_text": true, 5 | "vision_cfg": { 6 | "image_size": 256, 7 | "timm_model_name": "vit_base_patch16_siglip_256", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "context_length": 64, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", 16 | "tokenizer_kwargs": { 17 | "clean": "canonicalize" 18 | }, 19 | "width": 768, 20 | "heads": 12, 21 | "layers": 12, 22 | "no_causal_mask": true, 23 | "proj_bias": true, 24 | "pool_type": "last", 25 | "norm_kwargs":{ 26 | "eps": 1e-6 27 | } 28 | } 29 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-B-16-SigLIP-384.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "init_logit_bias": -10, 4 | "custom_text": true, 5 | "vision_cfg": { 6 | "image_size": 384, 7 | "timm_model_name": "vit_base_patch16_siglip_384", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 
| "context_length": 64, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", 16 | "tokenizer_kwargs": { 17 | "clean": "canonicalize" 18 | }, 19 | "width": 768, 20 | "heads": 12, 21 | "layers": 12, 22 | "no_causal_mask": true, 23 | "proj_bias": true, 24 | "pool_type": "last", 25 | "norm_kwargs":{ 26 | "eps": 1e-6 27 | } 28 | } 29 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-B-16-SigLIP-512.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "init_logit_bias": -10, 4 | "custom_text": true, 5 | "vision_cfg": { 6 | "image_size": 512, 7 | "timm_model_name": "vit_base_patch16_siglip_512", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "context_length": 64, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", 16 | "tokenizer_kwargs": { 17 | "clean": "canonicalize" 18 | }, 19 | "width": 768, 20 | "heads": 12, 21 | "layers": 12, 22 | "no_causal_mask": true, 23 | "proj_bias": true, 24 | "pool_type": "last", 25 | "norm_kwargs":{ 26 | "eps": 1e-6 27 | } 28 | } 29 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-L-16-SigLIP-256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "init_logit_bias": -10, 4 | "custom_text": true, 5 | "vision_cfg": { 6 | "image_size": 256, 7 | "timm_model_name": "vit_large_patch16_siglip_256", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "context_length": 64, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", 16 | "tokenizer_kwargs": { 17 | "clean": "canonicalize" 18 | }, 19 | "width": 1024, 20 | "heads": 16, 21 | "layers": 24, 22 | "no_causal_mask": true, 23 | "proj_bias": true, 24 | "pool_type": "last", 25 | "norm_kwargs":{ 26 | "eps": 1e-6 27 | } 28 | } 29 | } -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-L-16-SigLIP-384.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "init_logit_bias": -10, 4 | "custom_text": true, 5 | "vision_cfg": { 6 | "image_size": 384, 7 | "timm_model_name": "vit_large_patch16_siglip_384", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "context_length": 64, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", 16 | "tokenizer_kwargs": { 17 | "clean": "canonicalize" 18 | }, 19 | "width": 1024, 20 | "heads": 16, 21 | "layers": 24, 22 | "no_causal_mask": true, 23 | "proj_bias": true, 24 | "pool_type": "last", 25 | "norm_kwargs":{ 26 | "eps": 1e-6 27 | } 28 | } 29 | } -------------------------------------------------------------------------------- /open_clip_customized/docs/script_examples/clipa/vit_b16/i50_t16_finetune.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node 8 -m training.main \ 2 | --save-frequency 1 \ 3 | --save-most-recent \ 4 | --zeroshot-frequency 1 \ 5 | --train-data '/path/to/laion-400m' \ 6 | --dataset-type webdataset \ 7 | --lr "2.56e-5" \ 8 | --beta1 0.9 \ 9 | --beta2 0.95 \ 10 | --warmup 3072 \ 
11 | --wd 0.2 \ 12 | --batch-size 1024 \ 13 | --aug-cfg scale='(0.4, 1.0)' \ 14 | --epochs 1 \ 15 | --train-num-samples 131072000 \ 16 | --workers 6 \ 17 | --model ViT-B-16-CL16 \ 18 | --pretrained '/path/to/ckpt' \ 19 | --precision 'amp_bf16' \ 20 | --ddp-static-graph \ 21 | --local-loss \ 22 | --gather-with-grad \ 23 | --grad-checkpointing \ 24 | --log-every-n-steps 256 \ 25 | --seed 0 \ 26 | --logs ./logs/ \ 27 | --imagenet-val '/path/to/imagenet/val' 28 | -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-B-16-SigLIP-i18n-256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "init_logit_bias": -10, 4 | "custom_text": true, 5 | "vision_cfg": { 6 | "image_size": 256, 7 | "timm_model_name": "vit_base_patch16_siglip_256", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "context_length": 64, 14 | "vocab_size": 250000, 15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP-i18n-256", 16 | "tokenizer_kwargs": { 17 | "clean": "canonicalize" 18 | }, 19 | "width": 768, 20 | "heads": 12, 21 | "layers": 12, 22 | "no_causal_mask": true, 23 | "proj_bias": true, 24 | "pool_type": "last", 25 | "norm_kwargs":{ 26 | "eps": 1e-6 27 | } 28 | } 29 | } -------------------------------------------------------------------------------- /open_clip_customized/docs/script_examples/clipa/vit_l16/i17_t16_pretrain.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node 8 -m training.main \ 2 | --save-frequency 1 \ 3 | --save-most-recent \ 4 | --zeroshot-frequency 1 \ 5 | --train-data '/path/to/laion-400m' \ 6 | --dataset-type webdataset \ 7 | --lr "1.024e-3" \ 8 | --beta1 0.9 \ 9 | --beta2 0.95 \ 10 | --warmup 1563 \ 11 | --wd 0.2 \ 12 | --batch-size 4096 \ 13 | --aug-cfg scale='(0.4, 1.0)' color_jitter='(0.32, 0.32, 0.32, 0.08)' color_jitter_prob=0.8 gray_scale_prob=0.2 \ 14 | --epochs 6 \ 15 | --workers 6 \ 16 | --model ViT-L-16-CL16-GAP \ 17 | --precision 'amp_bf16' \ 18 | --ddp-static-graph \ 19 | --local-loss \ 20 | --gather-with-grad \ 21 | --force-image-size 64 \ 22 | --grad-checkpointing \ 23 | --log-every-n-steps 64 \ 24 | --seed 0 \ 25 | --logs ./logs/ \ 26 | --imagenet-val '/path/to/imagenet/val' -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-SO400M-14-SigLIP.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1152, 3 | "init_logit_bias": -10, 4 | "custom_text": true, 5 | "vision_cfg": { 6 | "image_size": 224, 7 | "timm_model_name": "vit_so400m_patch14_siglip_224", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "context_length": 16, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", 16 | "tokenizer_kwargs": { 17 | "clean": "canonicalize" 18 | }, 19 | "width": 1152, 20 | "heads": 16, 21 | "layers": 27, 22 | "mlp_ratio": 3.7362, 23 | "no_causal_mask": true, 24 | "proj_bias": true, 25 | "pool_type": "last", 26 | "norm_kwargs":{ 27 | "eps": 1e-6 28 | } 29 | } 30 | } -------------------------------------------------------------------------------- /open_clip_customized/docs/script_examples/clipa/vit_l16/i37_t8_pretrain.sh: -------------------------------------------------------------------------------- 1 | 
torchrun --nproc_per_node 8 -m training.main \ 2 | --save-frequency 1 \ 3 | --save-most-recent \ 4 | --zeroshot-frequency 1 \ 5 | --train-data '/path/to/laion-400m' \ 6 | --dataset-type webdataset \ 7 | --lr "1.024e-3" \ 8 | --beta1 0.9 \ 9 | --beta2 0.95 \ 10 | --warmup 1563 \ 11 | --wd 0.2 \ 12 | --batch-size 4096 \ 13 | --aug-cfg scale='(0.4, 1.0)' color_jitter='(0.32, 0.32, 0.32, 0.08)' color_jitter_prob=0.8 gray_scale_prob=0.2 \ 14 | --epochs 6 \ 15 | --workers 6 \ 16 | --model ViT-L-16-CL8-Syntax-GAP \ 17 | --precision 'amp_bf16' \ 18 | --ddp-static-graph \ 19 | --local-loss \ 20 | --gather-with-grad \ 21 | --force-image-size 96 \ 22 | --grad-checkpointing \ 23 | --log-every-n-steps 64 \ 24 | --seed 0 \ 25 | --logs ./logs/ \ 26 | --imagenet-val '/path/to/imagenet/val' -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/model_configs/ViT-SO400M-14-SigLIP-384.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1152, 3 | "init_logit_bias": -10, 4 | "custom_text": true, 5 | "vision_cfg": { 6 | "image_size": 384, 7 | "timm_model_name": "vit_so400m_patch14_siglip_384", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "context_length": 64, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", 16 | "tokenizer_kwargs": { 17 | "clean": "canonicalize" 18 | }, 19 | "width": 1152, 20 | "heads": 16, 21 | "layers": 27, 22 | "mlp_ratio": 3.7362, 23 | "no_causal_mask": true, 24 | "proj_bias": true, 25 | "pool_type": "last", 26 | "norm_kwargs":{ 27 | "eps": 1e-6 28 | } 29 | } 30 | } -------------------------------------------------------------------------------- /open_clip_customized/docs/script_examples/clipa/vit_l16/i37_t8_finetune.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node 8 -m training.main \ 2 | --save-frequency 1 \ 3 | --save-most-recent \ 4 | --zeroshot-frequency 1 \ 5 | --train-data '/path/to/laion-400m' \ 6 | --dataset-type webdataset \ 7 | --lr "2.24e-5" \ 8 | --beta1 0.9 \ 9 | --beta2 0.95 \ 10 | --warmup 3571 \ 11 | --wd 0.2 \ 12 | --batch-size 896 \ 13 | --aug-cfg scale='(0.4, 1.0)' color_jitter='(0.32, 0.32, 0.32, 0.08)' color_jitter_prob=0.8 gray_scale_prob=0.2 \ 14 | --epochs 1 \ 15 | --train-num-samples 131072000 \ 16 | --workers 6 \ 17 | --model ViT-L-16-CL32-GAP \ 18 | --pretrained '/path/to/ckpt' \ 19 | --precision 'amp_bf16' \ 20 | --ddp-static-graph \ 21 | --local-loss \ 22 | --gather-with-grad \ 23 | --grad-checkpointing \ 24 | --log-every-n-steps 293 \ 25 | --seed 0 \ 26 | --logs ./logs/ \ 27 | --imagenet-val '/path/to/imagenet/val' -------------------------------------------------------------------------------- /open_clip_customized/docs/script_examples/clipa/vit_l16/i17_t16_finetune.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node 8 -m training.main \ 2 | --save-frequency 1 \ 3 | --save-most-recent \ 4 | --zeroshot-frequency 1 \ 5 | --train-data '/path/to/laion-400m' \ 6 | --dataset-type webdataset \ 7 | --lr "2.24e-5" \ 8 | --beta1 0.9 \ 9 | --beta2 0.95 \ 10 | --warmup 3571 \ 11 | --wd 0.2 \ 12 | --batch-size 896 \ 13 | --aug-cfg scale='(0.4, 1.0)' color_jitter='(0.32, 0.32, 0.32, 0.08)' color_jitter_prob=0.8 gray_scale_prob=0.2 \ 14 | --epochs 1 \ 15 | --train-num-samples 131072000 \ 16 | --workers 6 \ 17 | --model 
ViT-L-16-CL16-GAP \ 18 | --pretrained '/path/to/ckpt' \ 19 | --precision 'amp_bf16' \ 20 | --ddp-static-graph \ 21 | --local-loss \ 22 | --gather-with-grad \ 23 | --grad-checkpointing \ 24 | --log-every-n-steps 293 \ 25 | --seed 0 \ 26 | --logs ./logs/ \ 27 | --imagenet-val '/path/to/imagenet/val' -------------------------------------------------------------------------------- /open_clip_customized/CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.1.0 2 | message: If you use this software, please cite it as below. 3 | authors: 4 | - family-names: Ilharco 5 | given-names: Gabriel 6 | - family-names: Wortsman 7 | given-names: Mitchell 8 | - family-names: Wightman 9 | given-names: Ross 10 | - family-names: Gordon 11 | given-names: Cade 12 | - family-names: Carlini 13 | given-names: Nicholas 14 | - family-names: Taori 15 | given-names: Rohan 16 | - family-names: Dave 17 | given-names: Achal 18 | - family-names: Shankar 19 | given-names: Vaishaal 20 | - family-names: Namkoong 21 | given-names: Hongseok 22 | - family-names: Miller 23 | given-names: John 24 | - family-names: Hajishirzi 25 | given-names: Hannaneh 26 | - family-names: Farhadi 27 | given-names: Ali 28 | - family-names: Schmidt 29 | given-names: Ludwig 30 | title: OpenCLIP 31 | version: v0.1 32 | doi: 10.5281/zenodo.5143773 33 | date-released: 2021-07-28 34 | -------------------------------------------------------------------------------- /open_clip_customized/src/training/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def setup_logging(log_file, level, include_host=False): 5 | if include_host: 6 | import socket 7 | hostname = socket.gethostname() 8 | formatter = logging.Formatter( 9 | f'%(asctime)s | {hostname} | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') 10 | else: 11 | formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') 12 | 13 | logging.root.setLevel(level) 14 | loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict] 15 | for logger in loggers: 16 | logger.setLevel(level) 17 | 18 | stream_handler = logging.StreamHandler() 19 | stream_handler.setFormatter(formatter) 20 | logging.root.addHandler(stream_handler) 21 | 22 | if log_file: 23 | file_handler = logging.FileHandler(filename=log_file) 24 | file_handler.setFormatter(formatter) 25 | logging.root.addHandler(file_handler) 26 | 27 | -------------------------------------------------------------------------------- /open_clip_customized/scripts/h14_224_32_finetune.sh: -------------------------------------------------------------------------------- 1 | # 64k batchsize for 2.048e-3 lr 2 | TORCH_CUDNN_V8_API_ENABLED=1 torchrun --nproc_per_node 8 -m training.main \ 3 | --save-frequency 1 \ 4 | --save-most-recent \ 5 | --zeroshot-frequency 1 \ 6 | --train-data '/path/to/laion' \ 7 | --dataset-type webdataset \ 8 | --lr "2.048e-3" \ 9 | --beta1 0.9 \ 10 | --beta2 0.95 \ 11 | --warmup 782 \ 12 | --wd 0.2 \ 13 | --batch-size 4096 \ 14 | --aug-cfg scale='(0.4, 1.0)' color_jitter='(0.32, 0.32, 0.32, 0.08)' color_jitter_prob=0.8 gray_scale_prob=0.2 \ 15 | --epochs=7 \ 16 | --workers=6 \ 17 | --model ViT-H-14-CL32-GAP \ 18 | --precision 'amp_bf16' \ 19 | --local-loss \ 20 | --gather-with-grad \ 21 | --force-image-size 224 \ 22 | --grad-checkpointing \ 23 | --log-every-n-steps 32 \ 24 | --seed 0 \ 25 | --logs ./logs/ \ 26 | --imagenet-val '/path/to/ImageNet/val' \ 27 
| --name 'name' \ 28 | --report-to "wandb" \ 29 | --wandb-project-name "project_name" 30 | 31 | 32 | -------------------------------------------------------------------------------- /open_clip_customized/scripts/h14_84_8_pretrain.sh: -------------------------------------------------------------------------------- 1 | # 64k batchsize for 2.048e-3 lr 2 | TORCH_CUDNN_V8_API_ENABLED=1 torchrun --nproc_per_node 8 -m training.main \ 3 | --save-frequency 1 \ 4 | --save-most-recent \ 5 | --zeroshot-frequency 1 \ 6 | --train-data '/path/to/laion' \ 7 | --dataset-type webdataset \ 8 | --lr "2.048e-3" \ 9 | --beta1 0.9 \ 10 | --beta2 0.95 \ 11 | --warmup 782 \ 12 | --wd 0.2 \ 13 | --batch-size 4096 \ 14 | --aug-cfg scale='(0.4, 1.0)' color_jitter='(0.32, 0.32, 0.32, 0.08)' color_jitter_prob=0.8 gray_scale_prob=0.2 \ 15 | --epochs=7 \ 16 | --workers=6 \ 17 | --model ViT-H-14-CL8-SyntaxMask-GAP \ 18 | --precision 'amp_bf16' \ 19 | --local-loss \ 20 | --gather-with-grad \ 21 | --force-image-size 84 \ 22 | --grad-checkpointing \ 23 | --log-every-n-steps 32 \ 24 | --seed 0 \ 25 | --logs ./logs/ \ 26 | --imagenet-val '/path/to/ImageNet/val' \ 27 | --name 'name' \ 28 | --report-to "wandb" \ 29 | --wandb-project-name "project_name" 30 | 31 | 32 | -------------------------------------------------------------------------------- /open_clip_customized/tests/test_hf_model.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import torch 4 | from open_clip.hf_model import _POOLERS, HFTextEncoder 5 | from transformers import AutoConfig 6 | from transformers.modeling_outputs import BaseModelOutput 7 | # test poolers 8 | def test_poolers(): 9 | bs, sl, d = 2, 10, 5 10 | h = torch.arange(sl).repeat(bs).reshape(bs, sl)[..., None] * torch.linspace(0.2, 1., d) 11 | mask = torch.ones(bs, sl, dtype=torch.bool) 12 | mask[:2, 6:] = False 13 | x = BaseModelOutput(h) 14 | for name, cls in _POOLERS.items(): 15 | pooler = cls() 16 | res = pooler(x, mask) 17 | assert res.shape == (bs, d), f"{name} returned wrong shape" 18 | 19 | # test HFTextEncoder 20 | @pytest.mark.parametrize("model_id", ["arampacha/roberta-tiny", "roberta-base", "xlm-roberta-base", "google/mt5-base"]) 21 | def test_pretrained_text_encoder(model_id): 22 | bs, sl, d = 2, 10, 64 23 | cfg = AutoConfig.from_pretrained(model_id) 24 | model = HFTextEncoder(model_id, d, proj_type='linear') 25 | x = torch.randint(0, cfg.vocab_size, (bs, sl)) 26 | with torch.no_grad(): 27 | emb = model(x) 28 | 29 | assert emb.shape == (bs, d) 30 | -------------------------------------------------------------------------------- /open_clip_customized/docs/script_examples/clipav2/vit_h14/i50_t8_pretrain.sh: -------------------------------------------------------------------------------- 1 | # have not been tested. use it at your own discretion 2 | # the original experiment was run on tpu v3-256. 3 | # this example script assumes 8 gpus, each with huge memory. Tune batchsize, warmup, and lr accordingly if you have different machine setups. 
4 | torchrun --nproc_per_node 8 -m training.main \ 5 | --save-frequency 1 \ 6 | --save-most-recent \ 7 | --zeroshot-frequency 1 \ 8 | --train-data '/path/to/laion2b_or_datacomp1b' \ 9 | --train-num-samples 4e8 \ 10 | --dataset-type webdataset \ 11 | --lr "2.048e-3" \ 12 | --beta1 0.9 \ 13 | --beta2 0.95 \ 14 | --warmup 3200 \ 15 | --wd 0.2 \ 16 | --batch-size 8192 \ 17 | --aug-cfg scale='(0.4, 1.0)' color_jitter='(0.32, 0.32, 0.32, 0.08)' color_jitter_prob=0.8 gray_scale_prob=0.2 \ 18 | --epochs 32 \ 19 | --workers 6 \ 20 | --model ViT-H-14-CL8-Syntax-GAP \ 21 | --precision 'amp_bf16' \ 22 | --ddp-static-graph \ 23 | --local-loss \ 24 | --gather-with-grad \ 25 | --force-image-size 84 \ 26 | --grad-checkpointing \ 27 | --log-every-n-steps 32 \ 28 | --seed 0 \ 29 | --logs ./logs/ \ 30 | --imagenet-val '/path/to/imagenet/val' -------------------------------------------------------------------------------- /open_clip_customized/docs/script_examples/clipav2/vit_h14/i257_t32_finetunex4.sh: -------------------------------------------------------------------------------- 1 | # have not been tested. use it at your own discretion 2 | # the original experiment was run on tpu v3-256. 3 | # this example script assumes 8 gpus, each with huge memory. Tune batchsize, warmup, and lr accordingly if you have different machine setups. 4 | torchrun --nproc_per_node 8 -m training.main \ 5 | --save-frequency 1 \ 6 | --save-most-recent \ 7 | --zeroshot-frequency 1 \ 8 | --train-data '/path/to/laion2b_or_datacomp1b' \ 9 | --train-num-samples 131072000 \ 10 | --dataset-type webdataset \ 11 | --lr "5.12e-5" \ 12 | --beta1 0.9 \ 13 | --beta2 0.95 \ 14 | --warmup 800 \ 15 | --wd 0.2 \ 16 | --batch-size 4096 \ 17 | --aug-cfg scale='(0.4, 1.0)' color_jitter='(0.32, 0.32, 0.32, 0.08)' color_jitter_prob=0.8 gray_scale_prob=0.2 \ 18 | --epochs 4 \ 19 | --workers 6 \ 20 | --model ViT-H-14-CL32-GAP \ 21 | --pretrained '/path/to/pretrain84_ckpt' \ 22 | --precision 'amp_bf16' \ 23 | --ddp-static-graph \ 24 | --local-loss \ 25 | --gather-with-grad \ 26 | --force-image-size 224 \ 27 | --force-patch-dropout 0.3 \ 28 | --grad-checkpointing \ 29 | --log-every-n-steps 64 \ 30 | --seed 0 \ 31 | --logs ./logs/ \ 32 | --imagenet-val '/path/to/imagenet/val' -------------------------------------------------------------------------------- /open_clip_customized/docs/script_examples/clipav2/vit_h14/i577_t32_finetunex1.sh: -------------------------------------------------------------------------------- 1 | # have not been tested. use it at your own discretion 2 | # the original experiment was run on tpu v3-256. 3 | # this example script assumes 8 gpus, each with huge memory. Tune batchsize, warmup, and lr accordingly if you have different machine setups. 
4 | torchrun --nproc_per_node 8 -m training.main \ 5 | --save-frequency 1 \ 6 | --save-most-recent \ 7 | --zeroshot-frequency 1 \ 8 | --train-data '/path/to/laion2b_or_datacomp1b' \ 9 | --train-num-samples 131072000 \ 10 | --dataset-type webdataset \ 11 | --lr "6.4e-6" \ 12 | --beta1 0.9 \ 13 | --beta2 0.95 \ 14 | --warmup 1600 \ 15 | --wd 0.2 \ 16 | --batch-size 2048 \ 17 | --aug-cfg scale='(0.4, 1.0)' color_jitter='(0.32, 0.32, 0.32, 0.08)' color_jitter_prob=0.8 gray_scale_prob=0.2 \ 18 | --epochs 1 \ 19 | --workers 6 \ 20 | --model ViT-H-14-CL32-GAP \ 21 | --pretrained '/path/to/finetune224_ckpt' \ 22 | --precision 'amp_bf16' \ 23 | --ddp-static-graph \ 24 | --local-loss \ 25 | --gather-with-grad \ 26 | --force-image-size 336 \ 27 | --force-patch-dropout 0.4 \ 28 | --grad-checkpointing \ 29 | --log-every-n-steps 64 \ 30 | --seed 0 \ 31 | --logs ./logs/ \ 32 | --imagenet-val '/path/to/imagenet/val' -------------------------------------------------------------------------------- /open_clip_customized/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2012-2021 Gabriel Ilharco, Mitchell Wortsman, 2 | Nicholas Carlini, Rohan Taori, Achal Dave, Vaishaal Shankar, 3 | John Miller, Hongseok Namkoong, Hannaneh Hajishirzi, Ali Farhadi, 4 | Ludwig Schmidt 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining 7 | a copy of this software and associated documentation files (the 8 | "Software"), to deal in the Software without restriction, including 9 | without limitation the rights to use, copy, modify, merge, publish, 10 | distribute, sublicense, and/or sell copies of the Software, and to 11 | permit persons to whom the Software is furnished to do so, subject to 12 | the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be 15 | included in all copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 18 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 19 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 20 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 21 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 22 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 23 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
24 | -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .coca_model import CoCa 2 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 3 | from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss 4 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint 5 | from .loss import ClipLoss, DistillClipLoss, CoCaLoss 6 | from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \ 7 | convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype, \ 8 | get_model_tokenize_cfg, get_model_preprocess_cfg, set_model_preprocess_cfg 9 | from .openai import load_openai_model, list_openai_models 10 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \ 11 | get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained 12 | from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub 13 | from .tokenizer import SimpleTokenizer, tokenize, decode 14 | from .transform import image_transform, AugmentationCfg 15 | from .zero_shot_classifier import build_zero_shot_classifier, build_zero_shot_classifier_legacy 16 | from .zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, SIMPLE_IMAGENET_TEMPLATES, IMAGENET_CLASSNAMES 17 | -------------------------------------------------------------------------------- /ddim_solver.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from common_utils import extract_into_tensor 4 | 5 | 6 | class DDIMSolver: 7 | def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50): 8 | # DDIM sampling parameters 9 | step_ratio = timesteps // ddim_timesteps 10 | self.ddim_timesteps = ( 11 | np.arange(1, ddim_timesteps + 1) * step_ratio 12 | ).round().astype(np.int64) - 1 13 | self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps] 14 | self.ddim_alpha_cumprods_prev = np.asarray( 15 | [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist() 16 | ) 17 | # convert to torch tensors 18 | self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long() 19 | self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods) 20 | self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev) 21 | 22 | def to(self, device): 23 | self.ddim_timesteps = self.ddim_timesteps.to(device) 24 | self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device) 25 | self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device) 26 | return self 27 | 28 | def ddim_step(self, pred_x0, pred_noise, timestep_index): 29 | alpha_cumprod_prev = extract_into_tensor( 30 | self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape 31 | ) 32 | dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise 33 | x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt 34 | return x_prev 35 | 36 | -------------------------------------------------------------------------------- /open_clip_customized/tests/test_inference_simple.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | from open_clip.factory import get_tokenizer 4 | import pytest 5 | import open_clip 6 | import os 7 | 
os.environ["CUDA_VISIBLE_DEVICES"] = "" 8 | 9 | if hasattr(torch._C, '_jit_set_profiling_executor'): 10 | # legacy executor is too slow to compile large models for unit tests 11 | # no need for the fusion performance here 12 | torch._C._jit_set_profiling_executor(True) 13 | torch._C._jit_set_profiling_mode(False) 14 | 15 | 16 | test_simple_models = [ 17 | # model, pretrained, jit, force_custom_text 18 | ("ViT-B-32", "laion2b_s34b_b79k", False, False), 19 | ("ViT-B-32", "laion2b_s34b_b79k", True, False), 20 | ("ViT-B-32", "laion2b_s34b_b79k", True, True), 21 | ("roberta-ViT-B-32", "laion2b_s12b_b32k", False, False), 22 | ] 23 | 24 | 25 | @pytest.mark.parametrize("model_type,pretrained,jit,force_custom_text", test_simple_models) 26 | def test_inference_simple( 27 | model_type, 28 | pretrained, 29 | jit, 30 | force_custom_text, 31 | ): 32 | model, _, preprocess = open_clip.create_model_and_transforms( 33 | model_type, 34 | pretrained=pretrained, 35 | jit=jit, 36 | force_custom_text=force_custom_text, 37 | ) 38 | tokenizer = get_tokenizer(model_type) 39 | 40 | current_dir = os.path.dirname(os.path.realpath(__file__)) 41 | 42 | image = preprocess(Image.open(current_dir + "/../docs/CLIP.png")).unsqueeze(0) 43 | text = tokenizer(["a diagram", "a dog", "a cat"]) 44 | 45 | with torch.no_grad(): 46 | image_features = model.encode_image(image) 47 | text_features = model.encode_text(text) 48 | 49 | text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1) 50 | 51 | assert text_probs.cpu().numpy()[0].tolist() == [1.0, 0.0, 0.0] 52 | -------------------------------------------------------------------------------- /open_clip_customized/docs/script_examples/stability_example.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=g40423 3 | #SBATCH --job-name=testopenclip 4 | #SBATCH --nodes 30 5 | #SBATCH --ntasks-per-node=8 6 | #SBATCH --cpus-per-task=12 7 | #SBATCH --output=%x_%j.out 8 | #SBATCH --comment=laion 9 | #SBATCH --open-mode=append 10 | #SBATCH --exclusive 11 | 12 | module load openmpi 13 | module load cuda/11.7 14 | 15 | export MASTER_ADDR=`hostname` 16 | export MASTER_PORT=12802 17 | export NCCL_PROTO=simple 18 | export FI_EFA_FORK_SAFE=1 19 | export FI_LOG_LEVEL=1 20 | export FI_EFA_USE_DEVICE_RDMA=1 21 | export NCCL_DEBUG=info 22 | 23 | export PYTHONFAULTHANDLER=1 24 | 25 | export CUDA_LAUNCH_BLOCKING=0 26 | export OMPI_MCA_mtl_base_verbose=1 27 | export FI_EFA_ENABLE_SHM_TRANSFER=0 28 | export FI_PROVIDER=efa 29 | export FI_EFA_TX_MIN_CREDITS=64 30 | export NCCL_TREE_THRESHOLD=0 31 | 32 | cd /admin/home-mitchellw/open_clip/src 33 | export PYTHONPATH="$PYTHONPATH:/admin/home-mitchellw/open_clip/src" 34 | 35 | EXP_NAME="test-B-32-laion5b-lr1e-3-bs90k" 36 | 37 | srun --comment laion --cpu_bind=v --accel-bind=gn python -m training.main \ 38 | --save-frequency 1 \ 39 | --train-data="pipe:aws s3 cp s3://s-datasets/laion5b/{laion2B-data/{000000..231349}.tar,laion2B-multi-data/{000000..226687}.tar,laion1B-nolang-data/{000000..127231}.tar} -" \ 40 | --train-num-samples 135646078 \ 41 | --dataset-type webdataset \ 42 | --dataset-resampled \ 43 | --warmup 2000 \ 44 | --batch-size=375 \ 45 | --epochs=97 \ 46 | --lr 1e-3 \ 47 | --workers=8 \ 48 | --report-to wandb \ 49 | --name ${EXP_NAME} \ 50 | --logs /scratch/logs/ \ 51 | --model ViT-B-32 \ 52 | --seed 0 \ 53 | --ddp-static-graph \ 54 | --local-loss \ 55 | --gather-with-grad \ 56 | --grad-checkpointing \ 57 | --precision amp_bfloat16 \ 58 | --wandb-project-name 
open_clip6 \ 59 | --resume "latest" \ 60 | --remote-sync s3://s-laion/mitchellw/logs 61 | -------------------------------------------------------------------------------- /open_clip_customized/src/training/scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def assign_learning_rate(optimizer, new_lr): 5 | for param_group in optimizer.param_groups: 6 | param_group["lr"] = new_lr 7 | 8 | 9 | def _warmup_lr(base_lr, warmup_length, step): 10 | return base_lr * (step + 1) / warmup_length 11 | 12 | 13 | def const_lr(optimizer, base_lr, warmup_length, steps): 14 | def _lr_adjuster(step): 15 | if step < warmup_length: 16 | lr = _warmup_lr(base_lr, warmup_length, step) 17 | else: 18 | lr = base_lr 19 | assign_learning_rate(optimizer, lr) 20 | return lr 21 | return _lr_adjuster 22 | 23 | 24 | def const_lr_cooldown(optimizer, base_lr, warmup_length, steps, cooldown_steps, cooldown_power=1.0, cooldown_end_lr=0.): 25 | def _lr_adjuster(step): 26 | start_cooldown_step = steps - cooldown_steps 27 | if step < warmup_length: 28 | lr = _warmup_lr(base_lr, warmup_length, step) 29 | else: 30 | if step < start_cooldown_step: 31 | lr = base_lr 32 | else: 33 | e = step - start_cooldown_step 34 | es = steps - start_cooldown_step 35 | # linear decay if power == 1; polynomial decay otherwise; 36 | decay = (1 - (e/es)) ** cooldown_power 37 | lr = decay * (base_lr - cooldown_end_lr) + cooldown_end_lr 38 | assign_learning_rate(optimizer, lr) 39 | return lr 40 | return _lr_adjuster 41 | 42 | 43 | def cosine_lr(optimizer, base_lr, warmup_length, steps): 44 | def _lr_adjuster(step): 45 | if step < warmup_length: 46 | lr = _warmup_lr(base_lr, warmup_length, step) 47 | else: 48 | e = step - warmup_length 49 | es = steps - warmup_length 50 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr 51 | assign_learning_rate(optimizer, lr) 52 | return lr 53 | return _lr_adjuster 54 | -------------------------------------------------------------------------------- /open_clip_customized/setup.py: -------------------------------------------------------------------------------- 1 | """ Setup 2 | """ 3 | from setuptools import setup, find_packages 4 | from codecs import open 5 | from os import path 6 | 7 | here = path.abspath(path.dirname(__file__)) 8 | 9 | # Get the long description from the README file 10 | with open(path.join(here, 'README.md'), encoding='utf-8') as f: 11 | long_description = f.read() 12 | 13 | def _read_reqs(relpath): 14 | fullpath = path.join(path.dirname(__file__), relpath) 15 | with open(fullpath) as f: 16 | return [s.strip() for s in f.readlines() if (s.strip() and not s.startswith("#"))] 17 | 18 | REQUIREMENTS = _read_reqs("requirements.txt") 19 | TRAINING_REQUIREMENTS = _read_reqs("requirements-training.txt") 20 | 21 | exec(open('src/open_clip/version.py').read()) 22 | setup( 23 | name='open_clip_torch', 24 | version=__version__, 25 | description='OpenCLIP', 26 | long_description=long_description, 27 | long_description_content_type='text/markdown', 28 | url='https://github.com/mlfoundations/open_clip', 29 | author='', 30 | author_email='', 31 | classifiers=[ 32 | # How mature is this project? 
Common values are 33 | # 3 - Alpha 34 | # 4 - Beta 35 | # 5 - Production/Stable 36 | 'Development Status :: 3 - Alpha', 37 | 'Intended Audience :: Education', 38 | 'Intended Audience :: Science/Research', 39 | 'License :: OSI Approved :: Apache Software License', 40 | 'Programming Language :: Python :: 3.7', 41 | 'Programming Language :: Python :: 3.8', 42 | 'Programming Language :: Python :: 3.9', 43 | 'Programming Language :: Python :: 3.10', 44 | 'Topic :: Scientific/Engineering', 45 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 46 | 'Topic :: Software Development', 47 | 'Topic :: Software Development :: Libraries', 48 | 'Topic :: Software Development :: Libraries :: Python Modules', 49 | ], 50 | 51 | # Note that this is a string of words separated by whitespace, not a list. 52 | keywords='CLIP pretrained', 53 | package_dir={'': 'src'}, 54 | packages=find_packages(where='src'), 55 | include_package_data=True, 56 | install_requires=REQUIREMENTS, 57 | extras_require={ 58 | "training": TRAINING_REQUIREMENTS, 59 | }, 60 | python_requires='>=3.7', 61 | ) 62 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Reward Guided Latent Consistency Distillation 2 | 3 | ## 🔥News 4 | 5 | - (🔥New) 10/09/2024 We release the training code! 6 | - (🔥New) 05/28/2024 We release the model weights and the local gradio demo! The model weights can be downloaded from [here](https://huggingface.co/jiachenli-ucsb/RG-LCM-SD-2.1-768-HPSv2.1). We will release other model weights soon! 7 | - 03/18/2024 Our repo for RG-LCD has been created. We will release our code and models very soon!! Please stay tuned! 8 | 9 | ## 🏭 Installation 10 | 11 | ``` 12 | pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 diffusers transformers accelerate webdataset open_clip_torch gradio==3.48.0 13 | ``` 14 | 15 | ## ✅ Local Gradio Demo (Text-to-Image) 16 | 17 | Launch the Gradio demo (for macOS users, set device="mps" in app.py; for Intel GPU users, set device="xpu" in app.py): 18 | 19 | ``` 20 | python local_gradio/app.py --model_name MODEL_NAME 21 | ``` 22 | 23 | You can find the currently available models [here](https://huggingface.co/jiachenli-ucsb) with the prefix `RG-LCM`. By default, `MODEL_NAME` is set to `jiachenli-ucsb/RG-LCM-SD-2.1-768-HPSv2.1`, which is distilled from [Stable Diffusion 2.1](https://huggingface.co/stabilityai/stable-diffusion-2-1) with reward feedback from [HPSv2.1](https://github.com/tgxs002/HPSv2/tree/master).
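Beyond the Gradio demo, the released checkpoint can also be used programmatically with diffusers. The snippet below is only a hedged sketch: it assumes the Hugging Face repo `jiachenli-ucsb/RG-LCM-SD-2.1-768-HPSv2.1` is packaged as a full Stable Diffusion pipeline and that `LCMScheduler` with very few steps is appropriate for the distilled model; `local_gradio/app.py` remains the authoritative loading code.

```python
# Hypothetical sketch -- assumes the released RG-LCM checkpoint loads as a full
# diffusers pipeline; check the model card / local_gradio/app.py for the exact loading code.
import torch
from diffusers import DiffusionPipeline, LCMScheduler

pipe = DiffusionPipeline.from_pretrained(
    "jiachenli-ucsb/RG-LCM-SD-2.1-768-HPSv2.1", torch_dtype=torch.float16
)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")  # or "mps" / "xpu", matching the device notes above

image = pipe(
    "a photo of an astronaut riding a horse",
    num_inference_steps=4,   # consistency-distilled models need only a few steps
    guidance_scale=1.0,      # assumption: tune per the model card; LCM-style sampling uses little or no CFG
    height=768, width=768,
).images[0]
image.save("sample.png")
```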
24 | 25 | ## 🏋️ Training commands 26 | 27 | To perform RG-LCD with HPSv2.1 as the reward model, run: 28 | 29 | ```bash 30 | accelerate launch main.py \ 31 | --output_dir=PATH_TO_LOG \ 32 | --gradient_checkpointing \ 33 | --use_8bit_adam \ 34 | --enable_xformers_memory_efficient_attention \ 35 | --resolution 768 \ 36 | --allow_tf32 \ 37 | --mixed_precision bf16 \ 38 | --train_shards_path_or_url "pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01090}.tar?download=true" \ 39 | --optimize_reward_fn \ 40 | --direct_optim_expert_reward \ 41 | --reward_fn_name hpsv2 \ 42 | --reward_scale 1 43 | ``` 44 | 45 | ## 📃 Citation 46 | 47 | ``` 48 | @article{ 49 | li2024reward, 50 | title={Reward Guided Latent Consistency Distillation}, 51 | author={Jiachen Li and Weixi Feng and Wenhu Chen and William Yang Wang}, 52 | journal={Transactions on Machine Learning Research}, 53 | issn={2835-8856}, 54 | year={2024}, 55 | url={https://openreview.net/forum?id=z116TO4LDT}, 56 | note={Featured Certification} 57 | } 58 | ``` 59 | -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/hf_configs.py: -------------------------------------------------------------------------------- 1 | # HF architecture dict: 2 | arch_dict = { 3 | # https://huggingface.co/docs/transformers/model_doc/roberta#roberta 4 | "roberta": { 5 | "config_names": { 6 | "context_length": "max_position_embeddings", 7 | "vocab_size": "vocab_size", 8 | "width": "hidden_size", 9 | "heads": "num_attention_heads", 10 | "layers": "num_hidden_layers", 11 | "layer_attr": "layer", 12 | "token_embeddings_attr": "embeddings" 13 | }, 14 | "pooler": "mean_pooler", 15 | }, 16 | # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig 17 | "xlm-roberta": { 18 | "config_names": { 19 | "context_length": "max_position_embeddings", 20 | "vocab_size": "vocab_size", 21 | "width": "hidden_size", 22 | "heads": "num_attention_heads", 23 | "layers": "num_hidden_layers", 24 | "layer_attr": "layer", 25 | "token_embeddings_attr": "embeddings" 26 | }, 27 | "pooler": "mean_pooler", 28 | }, 29 | # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 30 | "mt5": { 31 | "config_names": { 32 | # unlimited seqlen 33 | # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 34 | # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 35 | "context_length": "", 36 | "vocab_size": "vocab_size", 37 | "width": "d_model", 38 | "heads": "num_heads", 39 | "layers": "num_layers", 40 | "layer_attr": "block", 41 | "token_embeddings_attr": "embed_tokens" 42 | }, 43 | "pooler": "mean_pooler", 44 | }, 45 | # https://huggingface.co/docs/transformers/model_doc/bert 46 | "bert": { 47 | "config_names": { 48 | "context_length": "max_position_embeddings", 49 | "vocab_size": "vocab_size", 50 | "width": "hidden_size", 51 | "heads": "num_attention_heads", 52 | "layers": "num_hidden_layers", 53 | }, 54 | "pooler": "cls_pooler", 55 | }, 56 | # https://huggingface.co/docs/transformers/model_doc/m2m_100 57 | "m2m_100": { 58 | "config_names": { 59 | "context_length": "max_position_embeddings", 60 | "vocab_size": "vocab_size", 61 | "width": "d_model", 62 | "heads": "encoder_attention_heads", 63 | "layers": "encoder_layers", 64 | }, 65 | "pooler": "cls_pooler", 66 | }, 67 | } 68 | --------------------------------------------------------------------------------
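`hf_configs.py` above is a plain lookup table: for each supported Hugging Face text-tower architecture it records which `AutoConfig` attribute holds the context length, width, head count, and depth, plus the default pooler, and `hf_model.py` consults it when wrapping an HF encoder. An illustrative sketch of reading the table (assuming `transformers` is installed and this repo's `src/` is importable so that `open_clip.hf_configs` resolves):

```python
# Illustrative sketch of consuming arch_dict; assumptions noted in the lead-in.
from transformers import AutoConfig
from open_clip.hf_configs import arch_dict

cfg = AutoConfig.from_pretrained("roberta-base")          # downloads/reads the HF config
names = arch_dict[cfg.model_type]["config_names"]         # cfg.model_type == "roberta"

width = getattr(cfg, names["width"])                      # hidden_size -> 768
layers = getattr(cfg, names["layers"])                    # num_hidden_layers -> 12
pooler = arch_dict[cfg.model_type]["pooler"]              # "mean_pooler"
print(cfg.model_type, width, layers, pooler)
```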
/open_clip_customized/.gitignore: -------------------------------------------------------------------------------- 1 | logs/ 2 | wandb/ 3 | models/ 4 | features/ 5 | results/ 6 | 7 | tests/data/ 8 | *.pt 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | pip-wheel-metadata/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | *.py,cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 104 | __pypackages__/ 105 | 106 | # Celery stuff 107 | celerybeat-schedule 108 | celerybeat.pid 109 | 110 | # SageMath parsed files 111 | *.sage.py 112 | 113 | # Environments 114 | .env 115 | .venv 116 | env/ 117 | venv/ 118 | ENV/ 119 | env.bak/ 120 | venv.bak/ 121 | 122 | # Spyder project settings 123 | .spyderproject 124 | .spyproject 125 | 126 | # Rope project settings 127 | .ropeproject 128 | 129 | # mkdocs documentation 130 | /site 131 | 132 | # mypy 133 | .mypy_cache/ 134 | .dmypy.json 135 | dmypy.json 136 | 137 | # Pyre type checker 138 | .pyre/ 139 | sync.sh 140 | gpu1sync.sh 141 | .idea 142 | *.pdf 143 | **/._* 144 | **/*DS_* 145 | **.jsonl 146 | src/sbatch 147 | src/misc 148 | .vscode 149 | src/debug 150 | core.* 151 | 152 | # Allow 153 | !src/evaluation/misc/results_dbs/* -------------------------------------------------------------------------------- /open_clip_customized/docs/LOW_ACC.md: -------------------------------------------------------------------------------- 1 | As we describe in more detail below, CLIP models in a medium accuracy regime already allow us to draw conclusions about the robustness of larger CLIP models since the models follow reliable scaling laws. 2 | 3 | [Cherti et al., 2022](https://arxiv.org/abs/2212.07143) and [Gadre et al., 2023](https://arxiv.org/abs/2304.14108) show additional discussions about the scaling behavior of CLIP models. 
4 | 5 | ## Scaling trends 6 | 7 | The plot below shows how zero-shot performance of CLIP models varies as we scale the number of samples used for training. Zero-shot performance increases steadily for both ImageNet and [ImageNetV2](https://arxiv.org/abs/1902.10811), and is far from saturated at ~15M samples. 8 | 9 | 10 | 11 | ## Why are low-accuracy CLIP models interesting? 12 | 13 | **TL;DR:** CLIP models have high effective robustness, even at small scales. 14 | 15 | CLIP models are particularly intriguing because they are more robust to natural distribution shifts (see Section 3.3 in the [CLIP paper](https://arxiv.org/abs/2103.00020)). 16 | This phenomenon is illustrated by the figure below, with ImageNet accuracy on the x-axis 17 | and [ImageNetV2](https://arxiv.org/abs/1902.10811) (a reproduction of the ImageNet validation set with distribution shift) accuracy on the y-axis. 18 | Standard training denotes training on the ImageNet train set and the CLIP zero-shot models 19 | are shown as stars. 20 | 21 | ![CLIP scatter plot](https://raw.githubusercontent.com/mlfoundations/open_clip/main/docs/effective_robustness.png) 22 | 23 | As observed by [Taori et al., 2020](https://arxiv.org/abs/2007.00644) and [Miller et al., 2021](https://arxiv.org/abs/2107.04649), the in-distribution 24 | and out-of-distribution accuracies of models trained on ImageNet follow a predictable linear trend (the red line in the above plot). *Effective robustness* 25 | quantifies robustness as accuracy beyond this baseline, i.e., how far a model lies above the red line. Ideally a model would not suffer from distribution shift and fall on the y = x line ([trained human labelers are within a percentage point of the y = x line](http://proceedings.mlr.press/v119/shankar20c.html)). 26 | 27 | Even though the CLIP models trained with 28 | this codebase achieve much lower accuracy than those trained by OpenAI, our models still lie on the same 29 | trend of improved effective robustness (the purple line). Therefore, we can study what makes 30 | CLIP robust without requiring industrial-scale compute. 31 | 32 | For more information on effective robustness, please see: 33 | 34 | - [Recht et al., 2019](https://arxiv.org/abs/1902.10811). 35 | - [Taori et al., 2020](https://arxiv.org/abs/2007.00644). 36 | - [Miller et al., 2021](https://arxiv.org/abs/2107.04649). 37 | 38 | To know more about the factors that contribute to CLIP's robustness, refer to [Fang et al., 2022](https://arxiv.org/abs/2205.01397). -------------------------------------------------------------------------------- /open_clip_customized/src/training/file_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import multiprocessing 4 | import subprocess 5 | import time 6 | import fsspec 7 | import torch 8 | from tqdm import tqdm 9 | 10 | def remote_sync_s3(local_dir, remote_dir): 11 | # skip epoch_latest which can change during sync. 12 | result = subprocess.run(["aws", "s3", "sync", local_dir, remote_dir, '--exclude', '*epoch_latest.pt'], stdout=subprocess.PIPE, stderr=subprocess.PIPE) 13 | if result.returncode != 0: 14 | logging.error(f"Error: Failed to sync with S3 bucket {result.stderr.decode('utf-8')}") 15 | return False 16 | 17 | logging.info(f"Successfully synced with S3 bucket") 18 | return True 19 | 20 | def remote_sync_fsspec(local_dir, remote_dir): 21 | # FIXME currently this is slow and not recommended. Look into speeding up.
22 | a = fsspec.get_mapper(local_dir) 23 | b = fsspec.get_mapper(remote_dir) 24 | 25 | for k in a: 26 | # skip epoch_latest which can change during sync. 27 | if 'epoch_latest.pt' in k: 28 | continue 29 | 30 | logging.info(f'Attempting to sync {k}') 31 | if k in b and len(a[k]) == len(b[k]): 32 | logging.debug(f'Skipping remote sync for {k}.') 33 | continue 34 | 35 | try: 36 | logging.info(f'Successful sync for {k}.') 37 | b[k] = a[k] 38 | except Exception as e: 39 | logging.info(f'Error during remote sync for {k}: {e}') 40 | return False 41 | 42 | return True 43 | 44 | def remote_sync(local_dir, remote_dir, protocol): 45 | logging.info('Starting remote sync.') 46 | if protocol == 's3': 47 | return remote_sync_s3(local_dir, remote_dir) 48 | elif protocol == 'fsspec': 49 | return remote_sync_fsspec(local_dir, remote_dir) 50 | else: 51 | logging.error('Remote protocol not known') 52 | return False 53 | 54 | def keep_running_remote_sync(sync_every, local_dir, remote_dir, protocol): 55 | while True: 56 | time.sleep(sync_every) 57 | remote_sync(local_dir, remote_dir, protocol) 58 | 59 | def start_sync_process(sync_every, local_dir, remote_dir, protocol): 60 | p = multiprocessing.Process(target=keep_running_remote_sync, args=(sync_every, local_dir, remote_dir, protocol)) 61 | return p 62 | 63 | # Note: we are not currently using this save function. 64 | def pt_save(pt_obj, file_path): 65 | of = fsspec.open(file_path, "wb") 66 | with of as f: 67 | torch.save(pt_obj, file_path) 68 | 69 | def pt_load(file_path, map_location=None): 70 | if file_path.startswith('s3'): 71 | logging.info('Loading remote checkpoint, which may take a bit.') 72 | of = fsspec.open(file_path, "rb") 73 | with of as f: 74 | out = torch.load(f, map_location=map_location) 75 | return out 76 | 77 | def check_exists(file_path): 78 | try: 79 | with fsspec.open(file_path): 80 | pass 81 | except FileNotFoundError: 82 | return False 83 | return True 84 | -------------------------------------------------------------------------------- /open_clip_customized/src/training/zero_shot.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | from tqdm import tqdm 5 | 6 | from open_clip import get_input_dtype, get_tokenizer, build_zero_shot_classifier, \ 7 | IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES 8 | from .precision import get_autocast 9 | 10 | 11 | def accuracy(output, target, topk=(1,)): 12 | pred = output.topk(max(topk), 1, True, True)[1].t() 13 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 14 | return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk] 15 | 16 | 17 | def run(model, classifier, dataloader, args): 18 | autocast = get_autocast(args.precision) 19 | input_dtype = get_input_dtype(args.precision) 20 | 21 | with torch.no_grad(): 22 | top1, top5, n = 0., 0., 0. 23 | for images, target in tqdm(dataloader, unit_scale=args.batch_size): 24 | images = images.to(device=args.device, dtype=input_dtype) 25 | target = target.to(args.device) 26 | 27 | with autocast(): 28 | # predict 29 | output = model(image=images) 30 | image_features = output['image_features'] if isinstance(output, dict) else output[0] 31 | logits = 100. 
* image_features @ classifier 32 | 33 | # measure accuracy 34 | acc1, acc5 = accuracy(logits, target, topk=(1, 5)) 35 | top1 += acc1 36 | top5 += acc5 37 | n += images.size(0) 38 | 39 | top1 = (top1 / n) 40 | top5 = (top5 / n) 41 | return top1, top5 42 | 43 | 44 | def zero_shot_eval(model, data, epoch, args, tokenizer=None): 45 | if 'imagenet-val' not in data and 'imagenet-v2' not in data: 46 | return {} 47 | if args.zeroshot_frequency == 0: 48 | return {} 49 | if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs: 50 | return {} 51 | if args.distributed and not args.horovod: 52 | model = model.module 53 | 54 | logging.info('Starting zero-shot imagenet.') 55 | if tokenizer is None: 56 | tokenizer = get_tokenizer(args.model) 57 | 58 | logging.info('Building zero-shot classifier') 59 | autocast = get_autocast(args.precision) 60 | with autocast(): 61 | classifier = build_zero_shot_classifier( 62 | model, 63 | tokenizer=tokenizer, 64 | classnames=IMAGENET_CLASSNAMES, 65 | templates=OPENAI_IMAGENET_TEMPLATES, 66 | num_classes_per_batch=10, 67 | device=args.device, 68 | use_tqdm=True, 69 | ) 70 | 71 | logging.info('Using classifier') 72 | results = {} 73 | if 'imagenet-val' in data: 74 | top1, top5 = run(model, classifier, data['imagenet-val'].dataloader, args) 75 | results['imagenet-zeroshot-val-top1'] = top1 76 | results['imagenet-zeroshot-val-top5'] = top5 77 | if 'imagenet-v2' in data: 78 | top1, top5 = run(model, classifier, data['imagenet-v2'].dataloader, args) 79 | results['imagenetv2-zeroshot-val-top1'] = top1 80 | results['imagenetv2-zeroshot-val-top5'] = top5 81 | 82 | logging.info('Finished zero-shot imagenet.') 83 | 84 | return results 85 | -------------------------------------------------------------------------------- /open_clip_customized/tests/test_training_simple.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | import pytest 5 | from PIL import Image 6 | import torch 7 | from training.main import main 8 | 9 | os.environ["CUDA_VISIBLE_DEVICES"] = "" 10 | 11 | if hasattr(torch._C, '_jit_set_profiling_executor'): 12 | # legacy executor is too slow to compile large models for unit tests 13 | # no need for the fusion performance here 14 | torch._C._jit_set_profiling_executor(True) 15 | torch._C._jit_set_profiling_mode(False) 16 | 17 | @pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals") 18 | def test_training(): 19 | main([ 20 | '--save-frequency', '1', 21 | '--zeroshot-frequency', '1', 22 | '--dataset-type', "synthetic", 23 | '--train-num-samples', '16', 24 | '--warmup', '1', 25 | '--batch-size', '4', 26 | '--lr', '1e-3', 27 | '--wd', '0.1', 28 | '--epochs', '1', 29 | '--workers', '2', 30 | '--model', 'RN50' 31 | ]) 32 | 33 | @pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals") 34 | def test_training_coca(): 35 | main([ 36 | '--save-frequency', '1', 37 | '--zeroshot-frequency', '1', 38 | '--dataset-type', "synthetic", 39 | '--train-num-samples', '16', 40 | '--warmup', '1', 41 | '--batch-size', '4', 42 | '--lr', '1e-3', 43 | '--wd', '0.1', 44 | '--epochs', '1', 45 | '--workers', '2', 46 | '--model', 'coca_ViT-B-32' 47 | ]) 48 | 49 | @pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals") 50 | def test_training_mt5(): 51 | main([ 52 | '--save-frequency', '1', 53 | '--zeroshot-frequency', '1', 54 | '--dataset-type', "synthetic", 55 | '--train-num-samples', '16', 56 | '--warmup', '1', 
57 | '--batch-size', '4', 58 | '--lr', '1e-3', 59 | '--wd', '0.1', 60 | '--epochs', '1', 61 | '--workers', '2', 62 | '--model', 'mt5-base-ViT-B-32', 63 | '--lock-text', 64 | '--lock-text-unlocked-layers', '2' 65 | ]) 66 | 67 | 68 | 69 | @pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals") 70 | def test_training_unfreezing_vit(): 71 | main([ 72 | '--save-frequency', '1', 73 | '--zeroshot-frequency', '1', 74 | '--dataset-type', "synthetic", 75 | '--train-num-samples', '16', 76 | '--warmup', '1', 77 | '--batch-size', '4', 78 | '--lr', '1e-3', 79 | '--wd', '0.1', 80 | '--epochs', '1', 81 | '--workers', '2', 82 | '--model', 'ViT-B-32', 83 | '--lock-image', 84 | '--lock-image-unlocked-groups', '5', 85 | '--accum-freq', '2' 86 | ]) 87 | 88 | 89 | @pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals") 90 | def test_training_clip_with_jit(): 91 | main([ 92 | '--save-frequency', '1', 93 | '--zeroshot-frequency', '1', 94 | '--dataset-type', "synthetic", 95 | '--train-num-samples', '16', 96 | '--warmup', '1', 97 | '--batch-size', '4', 98 | '--lr', '1e-3', 99 | '--wd', '0.1', 100 | '--epochs', '1', 101 | '--workers', '2', 102 | '--model', 'ViT-B-32', 103 | '--torchscript' 104 | ]) 105 | -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/openai.py: -------------------------------------------------------------------------------- 1 | """ OpenAI pretrained model functions 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | 6 | import os 7 | import warnings 8 | from typing import List, Optional, Union 9 | 10 | import torch 11 | 12 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 13 | from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype 14 | from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url 15 | 16 | __all__ = ["list_openai_models", "load_openai_model"] 17 | 18 | 19 | def list_openai_models() -> List[str]: 20 | """Returns the names of available CLIP models""" 21 | return list_pretrained_models_by_tag('openai') 22 | 23 | 24 | def load_openai_model( 25 | name: str, 26 | precision: Optional[str] = None, 27 | device: Optional[Union[str, torch.device]] = None, 28 | cache_dir: Optional[str] = None, 29 | ): 30 | """Load a CLIP model 31 | 32 | Parameters 33 | ---------- 34 | name : str 35 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 36 | precision: str 37 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. 
38 | device : Union[str, torch.device] 39 | The device to put the loaded model 40 | cache_dir : Optional[str] 41 | The directory to cache the downloaded model weights 42 | 43 | Returns 44 | ------- 45 | model : torch.nn.Module 46 | The CLIP model 47 | preprocess : Callable[[PIL.Image], torch.Tensor] 48 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 49 | """ 50 | if device is None: 51 | device = "cuda" if torch.cuda.is_available() else "cpu" 52 | if precision is None: 53 | precision = 'fp32' if device == 'cpu' else 'fp16' 54 | 55 | if get_pretrained_url(name, 'openai'): 56 | model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) 57 | elif os.path.isfile(name): 58 | model_path = name 59 | else: 60 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") 61 | 62 | try: 63 | # loading JIT archive 64 | model = torch.jit.load(model_path, map_location="cpu").eval() 65 | state_dict = None 66 | except RuntimeError: 67 | # loading saved state dict 68 | state_dict = torch.load(model_path, map_location="cpu") 69 | 70 | # Build a non-jit model from the OpenAI jitted model state dict 71 | cast_dtype = get_cast_dtype(precision) 72 | try: 73 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) 74 | except KeyError: 75 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 76 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) 77 | 78 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use 79 | model = model.to(device) 80 | # FIXME support pure fp16/bf16 precision modes 81 | if precision != 'fp16': 82 | model.float() 83 | if precision == 'bf16': 84 | # for bf16, convert back to low-precision 85 | convert_weights_to_lp(model, dtype=torch.bfloat16) 86 | 87 | # add mean / std attributes for consistency with OpenCLIP models 88 | model.visual.image_mean = OPENAI_DATASET_MEAN 89 | model.visual.image_std = OPENAI_DATASET_STD 90 | return model 91 | -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/utils.py: -------------------------------------------------------------------------------- 1 | from itertools import repeat 2 | import collections.abc 3 | 4 | import torch 5 | from torch import nn as nn 6 | from torchvision.ops.misc import FrozenBatchNorm2d 7 | 8 | 9 | def freeze_batch_norm_2d(module, module_match={}, name=''): 10 | """ 11 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is 12 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and 13 | returned. Otherwise, the module is walked recursively and submodules are converted in place. 14 | 15 | Args: 16 | module (torch.nn.Module): Any PyTorch module. 
17 | module_match (dict): Dictionary of full module names to freeze (all if empty) 18 | name (str): Full module name (prefix) 19 | 20 | Returns: 21 | torch.nn.Module: Resulting module 22 | 23 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 24 | """ 25 | res = module 26 | is_match = True 27 | if module_match: 28 | is_match = name in module_match 29 | if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): 30 | res = FrozenBatchNorm2d(module.num_features) 31 | res.num_features = module.num_features 32 | res.affine = module.affine 33 | if module.affine: 34 | res.weight.data = module.weight.data.clone().detach() 35 | res.bias.data = module.bias.data.clone().detach() 36 | res.running_mean.data = module.running_mean.data 37 | res.running_var.data = module.running_var.data 38 | res.eps = module.eps 39 | else: 40 | for child_name, child in module.named_children(): 41 | full_child_name = '.'.join([name, child_name]) if name else child_name 42 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name) 43 | if new_child is not child: 44 | res.add_module(child_name, new_child) 45 | return res 46 | 47 | 48 | # From PyTorch internals 49 | def _ntuple(n): 50 | def parse(x): 51 | if isinstance(x, collections.abc.Iterable): 52 | return x 53 | return tuple(repeat(x, n)) 54 | return parse 55 | 56 | 57 | to_1tuple = _ntuple(1) 58 | to_2tuple = _ntuple(2) 59 | to_3tuple = _ntuple(3) 60 | to_4tuple = _ntuple(4) 61 | to_ntuple = lambda n, x: _ntuple(n)(x) 62 | 63 | # Replaces all linear layers with linear_replacement 64 | # TODO: add int8 support for other linear layers including attn and convnets 65 | def replace_linear(model, linear_replacement, include_modules=['c_fc', 'c_proj'], copy_weights=True): 66 | for name, module in model.named_children(): 67 | if len(list(module.children())) > 0: 68 | replace_linear(module, linear_replacement, include_modules, copy_weights) 69 | 70 | if isinstance(module, torch.nn.Linear) and name in include_modules: 71 | old_module = model._modules[name] 72 | model._modules[name] = linear_replacement( 73 | module.in_features, 74 | module.out_features, 75 | module.bias is not None, 76 | ) 77 | if copy_weights: 78 | model._modules[name].weight.data.copy_(old_module.weight.data) 79 | if model._modules[name].bias is not None: 80 | model._modules[name].bias.data.copy_(old_module.bias) 81 | 82 | return model 83 | 84 | def convert_int8_model_to_inference_mode(model): 85 | for m in model.modules(): 86 | if hasattr(m, 'prepare_for_eval'): 87 | int8_original_dtype = m.weight.dtype 88 | m.prepare_for_eval() 89 | m.int8_original_dtype = int8_original_dtype -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | # -------------------------------------------------------- 15 | # 2D sine-cosine position embedding 16 | # References: 17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 18 | # MoCo v3: https://github.com/facebookresearch/moco-v3 19 | # -------------------------------------------------------- 20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 21 | """ 22 | grid_size: int of the grid height and width 23 | return: 24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 25 | """ 26 | grid_h = np.arange(grid_size, dtype=np.float32) 27 | grid_w = np.arange(grid_size, dtype=np.float32) 28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 29 | grid = np.stack(grid, axis=0) 30 | 31 | grid = grid.reshape([2, 1, grid_size, grid_size]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | if cls_token: 34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 35 | return pos_embed 36 | 37 | 38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 39 | assert embed_dim % 2 == 0 40 | 41 | # use half of dimensions to encode grid_h 42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 44 | 45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 46 | return emb 47 | 48 | 49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 50 | """ 51 | embed_dim: output dimension for each position 52 | pos: a list of positions to be encoded: size (M,) 53 | out: (M, D) 54 | """ 55 | assert embed_dim % 2 == 0 56 | omega = np.arange(embed_dim // 2, dtype=float) 57 | omega /= embed_dim / 2. 58 | omega = 1. 
/ 10000**omega # (D/2,) 59 | 60 | pos = pos.reshape(-1) # (M,) 61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 62 | 63 | emb_sin = np.sin(out) # (M, D/2) 64 | emb_cos = np.cos(out) # (M, D/2) 65 | 66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 67 | return emb 68 | 69 | 70 | # -------------------------------------------------------- 71 | # Interpolate position embeddings for high-resolution 72 | # References: 73 | # DeiT: https://github.com/facebookresearch/deit 74 | # -------------------------------------------------------- 75 | def interpolate_pos_embed(model, checkpoint_model): 76 | if 'pos_embed' in checkpoint_model: 77 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 78 | embedding_size = pos_embed_checkpoint.shape[-1] 79 | num_patches = model.patch_embed.num_patches 80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 81 | # height (== width) for the checkpoint position embedding 82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 83 | # height (== width) for the new position embedding 84 | new_size = int(num_patches ** 0.5) 85 | # class_token and dist_token are kept unchanged 86 | if orig_size != new_size: 87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 89 | # only the position tokens are interpolated 90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 92 | pos_tokens = torch.nn.functional.interpolate( 93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 96 | checkpoint_model['pos_embed'] = new_pos_embed 97 | -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/zero_shot_classifier.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from itertools import islice 3 | from typing import Callable, List, Optional, Sequence, Union 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | 9 | def batched(iterable, n): 10 | """Batch data into lists of length *n*. The last batch may be shorter. 11 | NOTE based on more-itertools impl, to be replaced by python 3.12 itertools.batched impl 12 | """ 13 | it = iter(iterable) 14 | while True: 15 | batch = list(islice(it, n)) 16 | if not batch: 17 | break 18 | yield batch 19 | 20 | 21 | def build_zero_shot_classifier( 22 | model, 23 | tokenizer, 24 | classnames: Sequence[str], 25 | templates: Sequence[Union[Callable, str]], 26 | num_classes_per_batch: Optional[int] = 10, 27 | device: Union[str, torch.device] = 'cpu', 28 | use_tqdm: bool = False, 29 | ): 30 | """ Build zero-shot classifier weights by iterating over class names in batches 31 | Args: 32 | model: CLIP model instance 33 | tokenizer: CLIP tokenizer instance 34 | classnames: A sequence of class (label) names 35 | templates: A sequence of callables or format() friendly strings to produce templates per class name 36 | num_classes_per_batch: The number of classes to batch together in each forward, all if None 37 | device: Device to use. 38 | use_tqdm: Enable TQDM progress bar. 
39 | """ 40 | assert isinstance(templates, Sequence) and len(templates) > 0 41 | assert isinstance(classnames, Sequence) and len(classnames) > 0 42 | use_format = isinstance(templates[0], str) 43 | num_templates = len(templates) 44 | num_classes = len(classnames) 45 | if use_tqdm: 46 | import tqdm 47 | num_iter = 1 if num_classes_per_batch is None else ((num_classes - 1) // num_classes_per_batch + 1) 48 | iter_wrap = partial(tqdm.tqdm, total=num_iter, unit_scale=num_classes_per_batch) 49 | else: 50 | iter_wrap = iter 51 | 52 | def _process_batch(batch_classnames): 53 | num_batch_classes = len(batch_classnames) 54 | texts = [template.format(c) if use_format else template(c) for c in batch_classnames for template in templates] 55 | texts = tokenizer(texts).to(device) 56 | class_embeddings = model.encode_text(texts, normalize=True) 57 | class_embeddings = class_embeddings.reshape(num_batch_classes, num_templates, -1).mean(dim=1) 58 | class_embeddings = class_embeddings / class_embeddings.norm(dim=1, keepdim=True) 59 | class_embeddings = class_embeddings.T 60 | return class_embeddings 61 | 62 | with torch.no_grad(): 63 | if num_classes_per_batch: 64 | batched_embeds = [_process_batch(batch) for batch in iter_wrap(batched(classnames, num_classes_per_batch))] 65 | zeroshot_weights = torch.cat(batched_embeds, dim=1) 66 | else: 67 | zeroshot_weights = _process_batch(classnames) 68 | return zeroshot_weights 69 | 70 | 71 | def build_zero_shot_classifier_legacy( 72 | model, 73 | tokenizer, 74 | classnames: Sequence[str], 75 | templates: Sequence[Union[Callable, str]], 76 | device: Union[str, torch.device] = 'cpu', 77 | use_tqdm: bool = False, 78 | ): 79 | """ Build zero-shot classifier weights by iterating over class names 1 by 1 80 | Args: 81 | model: CLIP model instance 82 | tokenizer: CLIP tokenizer instance 83 | classnames: A sequence of class (label) names 84 | templates: A sequence of callables or format() friendly strings to produce templates per class name 85 | device: Device to use. 86 | use_tqdm: Enable TQDM progress bar. 
87 | """ 88 | assert isinstance(templates, Sequence) and len(templates) > 0 89 | assert isinstance(classnames, Sequence) and len(classnames) > 0 90 | if use_tqdm: 91 | import tqdm 92 | iter_wrap = tqdm.tqdm 93 | else: 94 | iter_wrap = iter 95 | 96 | use_format = isinstance(templates[0], str) 97 | 98 | with torch.no_grad(): 99 | zeroshot_weights = [] 100 | for classname in iter_wrap(classnames): 101 | texts = [template.format(classname) if use_format else template(classname) for template in templates] 102 | texts = tokenizer(texts).to(device) # tokenize 103 | class_embeddings = model.encode_text(texts) 104 | class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) 105 | class_embedding /= class_embedding.norm() 106 | zeroshot_weights.append(class_embedding) 107 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device) 108 | 109 | return zeroshot_weights 110 | 111 | -------------------------------------------------------------------------------- /open_clip_customized/src/training/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.distributed as dist 5 | 6 | try: 7 | import horovod.torch as hvd 8 | except ImportError: 9 | hvd = None 10 | 11 | 12 | def is_global_master(args): 13 | return args.rank == 0 14 | 15 | 16 | def is_local_master(args): 17 | return args.local_rank == 0 18 | 19 | 20 | def is_master(args, local=False): 21 | return is_local_master(args) if local else is_global_master(args) 22 | 23 | 24 | def is_using_horovod(): 25 | # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set 26 | # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required... 27 | ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"] 28 | pmi_vars = ["PMI_RANK", "PMI_SIZE"] 29 | if all([var in os.environ for var in ompi_vars]) or all([var in os.environ for var in pmi_vars]): 30 | return True 31 | else: 32 | return False 33 | 34 | 35 | def is_using_distributed(): 36 | if 'WORLD_SIZE' in os.environ: 37 | return int(os.environ['WORLD_SIZE']) > 1 38 | if 'SLURM_NTASKS' in os.environ: 39 | return int(os.environ['SLURM_NTASKS']) > 1 40 | return False 41 | 42 | 43 | def world_info_from_env(): 44 | local_rank = 0 45 | for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'): 46 | if v in os.environ: 47 | local_rank = int(os.environ[v]) 48 | break 49 | global_rank = 0 50 | for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'): 51 | if v in os.environ: 52 | global_rank = int(os.environ[v]) 53 | break 54 | world_size = 1 55 | for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'): 56 | if v in os.environ: 57 | world_size = int(os.environ[v]) 58 | break 59 | 60 | return local_rank, global_rank, world_size 61 | 62 | 63 | def init_distributed_device(args): 64 | # Distributed training = training on more than one GPU. 65 | # Works in both single and multi-node scenarios. 
66 | args.distributed = False 67 | args.world_size = 1 68 | args.rank = 0 # global rank 69 | args.local_rank = 0 70 | if args.horovod: 71 | assert hvd is not None, "Horovod is not installed" 72 | hvd.init() 73 | args.local_rank = int(hvd.local_rank()) 74 | args.rank = hvd.rank() 75 | args.world_size = hvd.size() 76 | args.distributed = True 77 | os.environ['LOCAL_RANK'] = str(args.local_rank) 78 | os.environ['RANK'] = str(args.rank) 79 | os.environ['WORLD_SIZE'] = str(args.world_size) 80 | elif is_using_distributed(): 81 | if 'SLURM_PROCID' in os.environ: 82 | # DDP via SLURM 83 | args.local_rank, args.rank, args.world_size = world_info_from_env() 84 | # SLURM var -> torch.distributed vars in case needed 85 | os.environ['LOCAL_RANK'] = str(args.local_rank) 86 | os.environ['RANK'] = str(args.rank) 87 | os.environ['WORLD_SIZE'] = str(args.world_size) 88 | torch.distributed.init_process_group( 89 | backend=args.dist_backend, 90 | init_method=args.dist_url, 91 | world_size=args.world_size, 92 | rank=args.rank, 93 | ) 94 | else: 95 | # DDP via torchrun, torch.distributed.launch 96 | args.local_rank, _, _ = world_info_from_env() 97 | torch.distributed.init_process_group( 98 | backend=args.dist_backend, 99 | init_method=args.dist_url) 100 | args.world_size = torch.distributed.get_world_size() 101 | args.rank = torch.distributed.get_rank() 102 | args.distributed = True 103 | 104 | if torch.cuda.is_available(): 105 | if args.distributed and not args.no_set_device_rank: 106 | device = 'cuda:%d' % args.local_rank 107 | else: 108 | device = 'cuda:0' 109 | torch.cuda.set_device(device) 110 | else: 111 | device = 'cpu' 112 | args.device = device 113 | device = torch.device(device) 114 | return device 115 | 116 | 117 | def broadcast_object(args, obj, src=0): 118 | # broadcast a pickle-able python object from rank-0 to all ranks 119 | if args.horovod: 120 | return hvd.broadcast_object(obj, root_rank=src) 121 | else: 122 | if args.rank == src: 123 | objects = [obj] 124 | else: 125 | objects = [None] 126 | dist.broadcast_object_list(objects, src=src) 127 | return objects[0] 128 | 129 | 130 | def all_gather_object(args, obj, dst=0): 131 | # gather a pickle-able python object across all ranks 132 | if args.horovod: 133 | return hvd.allgather_object(obj) 134 | else: 135 | objects = [None for _ in range(args.world_size)] 136 | dist.all_gather_object(objects, obj) 137 | return objects 138 | -------------------------------------------------------------------------------- /open_clip_customized/tests/test_wds.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | import util_test 4 | import collections 5 | import tarfile 6 | import io 7 | from PIL import Image 8 | 9 | from training.data import get_wds_dataset 10 | from training.params import parse_args 11 | from training.main import random_seed 12 | 13 | TRAIN_NUM_SAMPLES = 10_000 14 | RTOL = 0.2 15 | 16 | # NOTE: we use two test tar files, which are created on the fly and saved to data/input. 
17 | # 000.tar has 10 samples, and the captions are 000_0, 000_1, ..., 000_9 18 | # 001.tar has 5 samples, and the captions are 001_0, 001_1, ..., 001_4 19 | def build_inputs(test_name): 20 | base_input_dir, _ = util_test.get_data_dirs() 21 | input_dir = os.path.join(base_input_dir, test_name) 22 | os.makedirs(input_dir, exist_ok=True) 23 | 24 | def save_tar(idx, num_samples): 25 | filename = os.path.join(input_dir, f'test_data_{idx:03d}.tar') 26 | tar = tarfile.open(filename, 'w') 27 | 28 | for sample_idx in range(num_samples): 29 | # Image 30 | image = Image.new('RGB', (32, 32)) 31 | info = tarfile.TarInfo(f'{sample_idx}.png') 32 | bio = io.BytesIO() 33 | image.save(bio, format='png') 34 | size = bio.tell() 35 | bio.seek(0) 36 | info.size = size 37 | tar.addfile(info, bio) 38 | 39 | # Caption 40 | info = tarfile.TarInfo(f'{sample_idx}.txt') 41 | bio = io.BytesIO() 42 | bio.write(f'{idx:03d}_{sample_idx}'.encode('utf-8')) 43 | size = bio.tell() 44 | bio.seek(0) 45 | info.size = size 46 | tar.addfile(info, bio) 47 | 48 | tar.close() 49 | 50 | save_tar(0, 10) 51 | save_tar(1, 5) 52 | 53 | return input_dir 54 | 55 | 56 | def build_params(input_shards, seed=0): 57 | args = parse_args([]) 58 | args.train_data = input_shards 59 | args.train_num_samples = TRAIN_NUM_SAMPLES 60 | args.dataset_resampled = True 61 | args.seed = seed 62 | args.workers = 1 63 | args.world_size = 1 64 | args.batch_size = 1 65 | random_seed(seed) 66 | 67 | preprocess_img = lambda x: x 68 | tokenizer = lambda x: [x.strip()] 69 | 70 | return args, preprocess_img, tokenizer 71 | 72 | 73 | def get_dataloader(input_shards): 74 | args, preprocess_img, tokenizer = build_params(input_shards) 75 | dataset = get_wds_dataset(args, preprocess_img, is_train=True, tokenizer=tokenizer) 76 | dataloader = dataset.dataloader 77 | return dataloader 78 | 79 | 80 | def test_single_source(): 81 | """Test webdataset with a single tar file.""" 82 | input_dir = build_inputs('single_source') 83 | input_shards = os.path.join(input_dir, 'test_data_000.tar') 84 | dataloader = get_dataloader(input_shards) 85 | 86 | counts = collections.defaultdict(int) 87 | for sample in dataloader: 88 | txts = sample[1] 89 | for txt in txts: 90 | counts[txt] += 1 91 | 92 | for key, count in counts.items(): 93 | assert count == pytest.approx(TRAIN_NUM_SAMPLES / 10, RTOL) 94 | 95 | 96 | def test_two_sources(): 97 | """Test webdataset with two tar files.""" 98 | input_dir = build_inputs('two_sources') 99 | input_shards = os.path.join(input_dir, 'test_data_{000..001}.tar') 100 | dataloader = get_dataloader(input_shards) 101 | 102 | counts = collections.defaultdict(int) 103 | for sample in dataloader: 104 | txts = sample[1] 105 | for txt in txts: 106 | counts[txt] += 1 107 | 108 | for key, count in counts.items(): 109 | assert count == pytest.approx(TRAIN_NUM_SAMPLES / 15, RTOL), f'{key}, {count}' 110 | 111 | 112 | def test_two_sources_same_weights(): 113 | """Test webdataset with two tar files, using --train-data-upsampling-factors=1::1.""" 114 | input_dir = build_inputs('two_sources_same_weights') 115 | input_shards = f"{os.path.join(input_dir, 'test_data_000.tar')}::{os.path.join(input_dir, 'test_data_001.tar')}" 116 | args, preprocess_img, tokenizer = build_params(input_shards) 117 | args.train_data_upsampling_factors = '1::1' 118 | dataset = get_wds_dataset(args, preprocess_img, is_train=True, tokenizer=tokenizer) 119 | dataloader = dataset.dataloader 120 | 121 | counts = collections.defaultdict(int) 122 | for sample in dataloader: 123 | txts = sample[1] 124 | for txt in txts: 125 | counts[txt] += 1 126 | 127 | for key, count in counts.items(): 128 | assert count == pytest.approx(TRAIN_NUM_SAMPLES / 15, RTOL), f'{key}, {count}' 129 | 130 | def test_two_sources_with_upsampling(): 131 | """Test webdataset with two tar files with upsampling.""" 132 | input_dir = build_inputs('two_sources_with_upsampling') 133 | input_shards = f"{os.path.join(input_dir, 'test_data_000.tar')}::{os.path.join(input_dir, 'test_data_001.tar')}" 134 | args, preprocess_img, tokenizer = build_params(input_shards) 135 | args.train_data_upsampling_factors = '1::2' 136 | dataset = get_wds_dataset(args, preprocess_img, is_train=True, tokenizer=tokenizer) 137 | dataloader = dataset.dataloader 138 | 139 | counts = collections.defaultdict(int) 140 | for sample in dataloader: 141 | txts = sample[1] 142 | for txt in txts: 143 | counts[txt] += 1 144 | 145 | for key, count in counts.items(): 146 | if key.startswith('000'): 147 | assert count == pytest.approx(TRAIN_NUM_SAMPLES / 20, RTOL), f'{key}, {count}' 148 | else: 149 | assert count == pytest.approx(TRAIN_NUM_SAMPLES / 10, RTOL), f'{key}, {count}' 150 | -------------------------------------------------------------------------------- /open_clip_customized/tests/test_inference.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import pytest 4 | import torch 5 | import open_clip 6 | import util_test 7 | 8 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 9 | 10 | if hasattr(torch._C, '_jit_set_profiling_executor'): 11 | # legacy executor is too slow to compile large models for unit tests 12 | # no need for the fusion performance here 13 | torch._C._jit_set_profiling_executor(True) 14 | torch._C._jit_set_profiling_mode(False) 15 | 16 | models_to_test = set(open_clip.list_models()) 17 | 18 | # testing exemptions 19 | models_to_test = models_to_test.difference({ 20 | # not available with timm yet 21 | # see https://github.com/mlfoundations/open_clip/issues/219 22 | 'convnext_xlarge', 23 | 'convnext_xxlarge', 24 | 'convnext_xxlarge_320', 25 | 'vit_medium_patch16_gap_256', 26 | # exceeds GH runner memory limit 27 | 'ViT-bigG-14', 28 | 'ViT-e-14', 29 | 'mt5-xl-ViT-H-14', 30 | 'coca_base', 31 | 'coca_ViT-B-32', 32 | 'coca_roberta-ViT-B-32' 33 | }) 34 | 35 | if 'OPEN_CLIP_TEST_REG_MODELS' in os.environ: 36 | external_model_list = os.environ['OPEN_CLIP_TEST_REG_MODELS'] 37 | with open(external_model_list, 'r') as f: 38 | models_to_test = set(f.read().splitlines()).intersection(models_to_test) 39 | print(f"Selected models from {external_model_list}: {models_to_test}") 40 | 41 | # TODO: add "coca_ViT-B-32" once https://github.com/pytorch/pytorch/issues/92073 gets fixed 42 | models_to_test = list(models_to_test) 43 | models_to_test.sort() 44 | models_to_test = [(model_name, False) for model_name in models_to_test] 45 | 46 | models_to_jit_test = {"ViT-B-32"} 47 | models_to_jit_test = list(models_to_jit_test) 48 | models_to_jit_test = [(model_name, True) for model_name in models_to_jit_test] 49 | models_to_test_fully = models_to_test + models_to_jit_test 50 | 51 | 52 | @pytest.mark.regression_test 53 | @pytest.mark.parametrize("model_name,jit", models_to_test_fully) 54 | def test_inference_with_data( 55 | model_name, 56 | jit, 57 | pretrained = None, 58 | pretrained_hf = False, 59 | precision = 'fp32', 60 | force_quick_gelu = False, 61 | ): 62 | util_test.seed_all() 63 | model, _, preprocess_val = open_clip.create_model_and_transforms( 64 | model_name, 65 | pretrained = pretrained, 66 | precision =
precision, 67 | jit = jit, 68 | force_quick_gelu = force_quick_gelu, 69 | pretrained_hf = pretrained_hf 70 | ) 71 | model_id = f'{model_name}_{pretrained or pretrained_hf}_{precision}' 72 | input_dir, output_dir = util_test.get_data_dirs() 73 | # text 74 | input_text_path = os.path.join(input_dir, 'random_text.pt') 75 | gt_text_path = os.path.join(output_dir, f'{model_id}_random_text.pt') 76 | if not os.path.isfile(input_text_path): 77 | pytest.skip(reason = f"missing test data, expected at {input_text_path}") 78 | if not os.path.isfile(gt_text_path): 79 | pytest.skip(reason = f"missing test data, expected at {gt_text_path}") 80 | input_text = torch.load(input_text_path) 81 | gt_text = torch.load(gt_text_path) 82 | y_text = util_test.inference_text(model, model_name, input_text) 83 | assert (y_text == gt_text).all(), f"text output differs @ {input_text_path}" 84 | # image 85 | image_size = model.visual.image_size 86 | if not isinstance(image_size, tuple): 87 | image_size = (image_size, image_size) 88 | input_image_path = os.path.join(input_dir, f'random_image_{image_size[0]}_{image_size[1]}.pt') 89 | gt_image_path = os.path.join(output_dir, f'{model_id}_random_image.pt') 90 | if not os.path.isfile(input_image_path): 91 | pytest.skip(reason = f"missing test data, expected at {input_image_path}") 92 | if not os.path.isfile(gt_image_path): 93 | pytest.skip(reason = f"missing test data, expected at {gt_image_path}") 94 | input_image = torch.load(input_image_path) 95 | gt_image = torch.load(gt_image_path) 96 | y_image = util_test.inference_image(model, preprocess_val, input_image) 97 | assert (y_image == gt_image).all(), f"image output differs @ {input_image_path}" 98 | 99 | if not jit: 100 | model.eval() 101 | model_out = util_test.forward_model(model, model_name, preprocess_val, input_image, input_text) 102 | if type(model) not in [open_clip.CLIP, open_clip.CustomTextCLIP]: 103 | assert type(model_out) == dict 104 | else: 105 | model.output_dict = True 106 | model_out_dict = util_test.forward_model(model, model_name, preprocess_val, input_image, input_text) 107 | assert (model_out_dict["image_features"] == model_out[0]).all() 108 | assert (model_out_dict["text_features"] == model_out[1]).all() 109 | assert (model_out_dict["logit_scale"] == model_out[2]).all() 110 | model.output_dict = None 111 | else: 112 | model, _, preprocess_val = open_clip.create_model_and_transforms( 113 | model_name, 114 | pretrained = pretrained, 115 | precision = precision, 116 | jit = False, 117 | force_quick_gelu = force_quick_gelu, 118 | pretrained_hf = pretrained_hf 119 | ) 120 | 121 | test_model = util_test.TestWrapper(model, model_name, output_dict=False) 122 | test_model = torch.jit.script(test_model) 123 | model_out = util_test.forward_model(test_model, model_name, preprocess_val, input_image, input_text) 124 | assert model_out["test_output"].shape[-1] == 2 125 | 126 | test_model = util_test.TestWrapper(model, model_name, output_dict=True) 127 | test_model = torch.jit.script(test_model) 128 | model_out = util_test.forward_model(test_model, model_name, preprocess_val, input_image, input_text) 129 | assert model_out["test_output"].shape[-1] == 2 130 | 131 | 132 | 133 | 134 | -------------------------------------------------------------------------------- /open_clip_customized/tests/test_download_pretrained.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import torch 3 | from PIL import Image 4 | import hashlib 5 | import tempfile 6 | import unittest 7 
| from io import BytesIO 8 | from pathlib import Path 9 | from unittest.mock import patch 10 | 11 | from urllib3 import HTTPResponse 12 | from urllib3._collections import HTTPHeaderDict 13 | 14 | import open_clip 15 | from open_clip.pretrained import download_pretrained_from_url 16 | 17 | 18 | class DownloadPretrainedTests(unittest.TestCase): 19 | 20 | def create_response(self, data, status_code=200, content_type='application/octet-stream'): 21 | fp = BytesIO(data) 22 | headers = HTTPHeaderDict({ 23 | 'Content-Type': content_type, 24 | 'Content-Length': str(len(data)) 25 | }) 26 | raw = HTTPResponse(fp, preload_content=False, headers=headers, status=status_code) 27 | return raw 28 | 29 | @patch('open_clip.pretrained.urllib') 30 | def test_download_pretrained_from_url_from_openaipublic(self, urllib): 31 | file_contents = b'pretrained model weights' 32 | expected_hash = hashlib.sha256(file_contents).hexdigest() 33 | urllib.request.urlopen.return_value = self.create_response(file_contents) 34 | with tempfile.TemporaryDirectory() as root: 35 | url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt' 36 | download_pretrained_from_url(url, root) 37 | urllib.request.urlopen.assert_called_once() 38 | 39 | @patch('open_clip.pretrained.urllib') 40 | def test_download_pretrained_from_url_from_openaipublic_corrupted(self, urllib): 41 | file_contents = b'pretrained model weights' 42 | expected_hash = hashlib.sha256(file_contents).hexdigest() 43 | urllib.request.urlopen.return_value = self.create_response(b'corrupted pretrained model') 44 | with tempfile.TemporaryDirectory() as root: 45 | url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt' 46 | with self.assertRaisesRegex(RuntimeError, r'checksum does not not match'): 47 | download_pretrained_from_url(url, root) 48 | urllib.request.urlopen.assert_called_once() 49 | 50 | @patch('open_clip.pretrained.urllib') 51 | def test_download_pretrained_from_url_from_openaipublic_valid_cache(self, urllib): 52 | file_contents = b'pretrained model weights' 53 | expected_hash = hashlib.sha256(file_contents).hexdigest() 54 | urllib.request.urlopen.return_value = self.create_response(file_contents) 55 | with tempfile.TemporaryDirectory() as root: 56 | local_file = Path(root) / 'RN50.pt' 57 | local_file.write_bytes(file_contents) 58 | url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt' 59 | download_pretrained_from_url(url, root) 60 | urllib.request.urlopen.assert_not_called() 61 | 62 | @patch('open_clip.pretrained.urllib') 63 | def test_download_pretrained_from_url_from_openaipublic_corrupted_cache(self, urllib): 64 | file_contents = b'pretrained model weights' 65 | expected_hash = hashlib.sha256(file_contents).hexdigest() 66 | urllib.request.urlopen.return_value = self.create_response(file_contents) 67 | with tempfile.TemporaryDirectory() as root: 68 | local_file = Path(root) / 'RN50.pt' 69 | local_file.write_bytes(b'corrupted pretrained model') 70 | url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt' 71 | download_pretrained_from_url(url, root) 72 | urllib.request.urlopen.assert_called_once() 73 | 74 | @patch('open_clip.pretrained.urllib') 75 | def test_download_pretrained_from_url_from_mlfoundations(self, urllib): 76 | file_contents = b'pretrained model weights' 77 | expected_hash = hashlib.sha256(file_contents).hexdigest()[:8] 78 | urllib.request.urlopen.return_value = self.create_response(file_contents) 79 | with tempfile.TemporaryDirectory() as root: 80 | url 
= f'https://github.com/mlfoundations/download/v0.2-weights/rn50-quickgelu-{expected_hash}.pt' 81 | download_pretrained_from_url(url, root) 82 | urllib.request.urlopen.assert_called_once() 83 | 84 | @patch('open_clip.pretrained.urllib') 85 | def test_download_pretrained_from_url_from_mlfoundations_corrupted(self, urllib): 86 | file_contents = b'pretrained model weights' 87 | expected_hash = hashlib.sha256(file_contents).hexdigest()[:8] 88 | urllib.request.urlopen.return_value = self.create_response(b'corrupted pretrained model') 89 | with tempfile.TemporaryDirectory() as root: 90 | url = f'https://github.com/mlfoundations/download/v0.2-weights/rn50-quickgelu-{expected_hash}.pt' 91 | with self.assertRaisesRegex(RuntimeError, r'checksum does not not match'): 92 | download_pretrained_from_url(url, root) 93 | urllib.request.urlopen.assert_called_once() 94 | 95 | @patch('open_clip.pretrained.urllib') 96 | def test_download_pretrained_from_hfh(self, urllib): 97 | model, _, preprocess = open_clip.create_model_and_transforms('hf-hub:hf-internal-testing/tiny-open-clip-model') 98 | tokenizer = open_clip.get_tokenizer('hf-hub:hf-internal-testing/tiny-open-clip-model') 99 | img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/coco_sample.png" 100 | image = preprocess(Image.open(requests.get(img_url, stream=True).raw)).unsqueeze(0) 101 | text = tokenizer(["a diagram", "a dog", "a cat"]) 102 | 103 | with torch.no_grad(): 104 | image_features = model.encode_image(image) 105 | text_features = model.encode_text(text) 106 | image_features /= image_features.norm(dim=-1, keepdim=True) 107 | text_features /= text_features.norm(dim=-1, keepdim=True) 108 | 109 | text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1) 110 | 111 | self.assertTrue(torch.allclose(text_probs, torch.tensor([[0.0597, 0.6349, 0.3053]]), 1e-3)) 112 | -------------------------------------------------------------------------------- /open_clip_customized/HISTORY.md: -------------------------------------------------------------------------------- 1 | ## 2.24.0 2 | 3 | * Fix missing space in error message 4 | * use model flag for normalizing embeddings 5 | * init logit_bias for non siglip pretrained models 6 | * Fix logit_bias load_checkpoint addition 7 | * Make CoCa model match CLIP models for logit scale/bias init 8 | * Fix missing return of "logit_bias" in CoCa.forward 9 | * Add NLLB-CLIP with SigLIP models 10 | * Add get_logits method and NLLB tokenizer 11 | * Remove the empty file src/open_clip/generation_utils.py 12 | * Update params.py: "BatchNorm" -> "LayerNorm" in the description string for "--lock-text-freeze-layer-norm" 13 | 14 | ## 2.23.0 15 | 16 | * Add CLIPA-v2 models 17 | * Add SigLIP models 18 | * Add MetaCLIP models 19 | * Add NLLB-CLIP models 20 | * CLIPA train code 21 | * Minor changes/fixes 22 | * Remove protobuf version limit 23 | * Stop checking model name when loading CoCa models 24 | * Log native wandb step 25 | * Use bool instead of long masks 26 | 27 | ## 2.21.0 28 | 29 | * Add SigLIP loss + training support 30 | * Add more DataComp models (B/16, B/32 and B/32@256) 31 | * Update default num workers 32 | * Update CoCa generation for `transformers>=4.31` 33 | * PyTorch 2.0 `state_dict()` compatibility fix for compiled models 34 | * Fix padding in `ResizeMaxSize` 35 | * Convert JIT model on state dict load for `pretrained='filename…'` 36 | * Other minor changes and fixes (typos, README, dependencies, CI) 37 | 38 | ## 2.20.0 39 | 40 | * Add EVA models 41 | * Support 
serial worker training 42 | * Fix Python 3.7 compatibility 43 | 44 | ## 2.19.0 45 | 46 | * Add DataComp models 47 | 48 | ## 2.18.0 49 | 50 | * Enable int8 inference without `.weight` attribute 51 | 52 | ## 2.17.2 53 | 54 | * Update push_to_hf_hub 55 | 56 | ## 2.17.0 57 | 58 | * Add int8 support 59 | * Update notebook demo 60 | * Refactor zero-shot classification code 61 | 62 | ## 2.16.2 63 | 64 | * Fixes for context_length and vocab_size attributes 65 | 66 | ## 2.16.1 67 | 68 | * Fixes for context_length and vocab_size attributes 69 | * Fix --train-num-samples logic 70 | * Add HF BERT configs for PubMed CLIP model 71 | 72 | ## 2.16.0 73 | 74 | * Add improved g-14 weights 75 | * Update protobuf version 76 | 77 | ## 2.15.0 78 | 79 | * Add convnext_xxlarge weights 80 | * Fixed import in readme 81 | * Add samples per second per gpu logging 82 | * Fix slurm example 83 | 84 | ## 2.14.0 85 | 86 | * Move dataset mixtures logic to shard level 87 | * Fix CoCa accum-grad training 88 | * Safer transformers import guard 89 | * get_labels refactoring 90 | 91 | ## 2.13.0 92 | 93 | * Add support for dataset mixtures with different sampling weights 94 | * Make transformers optional again 95 | 96 | ## 2.12.0 97 | 98 | * Updated convnext configs for consistency 99 | * Added input_patchnorm option 100 | * Clean and improve CoCa generation 101 | * Support model distillation 102 | * Add ConvNeXt-Large 320x320 fine-tune weights 103 | 104 | ## 2.11.1 105 | 106 | * Make transformers optional 107 | * Add MSCOCO CoCa finetunes to pretrained models 108 | 109 | ## 2.11.0 110 | 111 | * coca support and weights 112 | * ConvNeXt-Large weights 113 | 114 | ## 2.10.1 115 | 116 | * `hf-hub:org/model_id` support for loading models w/ config and weights in Hugging Face Hub 117 | 118 | ## 2.10.0 119 | 120 | * Added a ViT-bigG-14 model. 121 | * Added an up-to-date example slurm script for large training jobs. 122 | * Added a option to sync logs and checkpoints to S3 during training. 123 | * New options for LR schedulers, constant and constant with cooldown 124 | * Fix wandb autoresuming when resume is not set 125 | * ConvNeXt `base` & `base_w` pretrained models added 126 | * `timm-` model prefix removed from configs 127 | * `timm` augmentation + regularization (dropout / drop-path) supported 128 | 129 | ## 2.9.3 130 | 131 | * Fix wandb collapsing multiple parallel runs into a single one 132 | 133 | ## 2.9.2 134 | 135 | * Fix braceexpand memory explosion for complex webdataset urls 136 | 137 | ## 2.9.1 138 | 139 | * Fix release 140 | 141 | ## 2.9.0 142 | 143 | * Add training feature to auto-resume from the latest checkpoint on restart via `--resume latest` 144 | * Allow webp in webdataset 145 | * Fix logging for number of samples when using gradient accumulation 146 | * Add model configs for convnext xxlarge 147 | 148 | ## 2.8.2 149 | 150 | * wrapped patchdropout in a torch.nn.Module 151 | 152 | ## 2.8.1 153 | 154 | * relax protobuf dependency 155 | * override the default patch dropout value in 'vision_cfg' 156 | 157 | ## 2.8.0 158 | 159 | * better support for HF models 160 | * add support for gradient accumulation 161 | * CI fixes 162 | * add support for patch dropout 163 | * add convnext configs 164 | 165 | 166 | ## 2.7.0 167 | 168 | * add multilingual H/14 xlm roberta large 169 | 170 | ## 2.6.1 171 | 172 | * fix setup.py _read_reqs 173 | 174 | ## 2.6.0 175 | 176 | * Make openclip training usable from pypi. 177 | * Add xlm roberta large vit h 14 config. 
178 | 179 | ## 2.5.0 180 | 181 | * pretrained B/32 xlm roberta base: first multilingual clip trained on laion5B 182 | * pretrained B/32 roberta base: first clip trained using an HF text encoder 183 | 184 | ## 2.4.1 185 | 186 | * Add missing hf_tokenizer_name in CLIPTextCfg. 187 | 188 | ## 2.4.0 189 | 190 | * Fix #211, missing RN50x64 config. Fix type of dropout param for ResNet models 191 | * Bring back LayerNorm impl that casts to input for non bf16/fp16 192 | * zero_shot.py: set correct tokenizer based on args 193 | * training/params.py: remove hf params and get them from model config 194 | 195 | ## 2.3.1 196 | 197 | * Implement grad checkpointing for hf model. 198 | * custom_text: True if hf_model_name is set 199 | * Disable hf tokenizer parallelism 200 | 201 | ## 2.3.0 202 | 203 | * Generalizable Text Transformer with HuggingFace Models (@iejMac) 204 | 205 | ## 2.2.0 206 | 207 | * Support for custom text tower 208 | * Add checksum verification for pretrained model weights 209 | 210 | ## 2.1.0 211 | 212 | * lot including sota models, bfloat16 option, better loading, better metrics 213 | 214 | ## 1.2.0 215 | 216 | * ViT-B/32 trained on Laion2B-en 217 | * add missing openai RN50x64 model 218 | 219 | ## 1.1.1 220 | 221 | * ViT-B/16+ 222 | * Add grad checkpointing support 223 | * more robust data loader 224 | -------------------------------------------------------------------------------- /open_clip_customized/docs/model_profile.csv: -------------------------------------------------------------------------------- 1 | model,image_size,image_width,text_width,embed_dim,mparams,image_mparams,text_mparams,gflops,image_gflops,text_gflops 2 | ViT-S-32-alt,224,384,256,256,43.22,22.59,20.63,3.56,2.29,1.27 3 | ViT-S-32,224,384,384,384,63.09,22.64,40.44,5.66,2.29,3.38 4 | ViT-M-32-alt,224,512,384,384,80.07,39.63,40.44,7.37,3.99,3.38 5 | ViT-M-32,224,512,512,512,103.12,39.69,63.43,9.95,3.99,5.96 6 | ViT-S-16-alt,224,384,256,256,42.4,21.76,20.63,10.47,9.2,1.27 7 | ViT-S-16,224,384,384,384,62.26,21.81,40.44,12.58,9.2,3.38 8 | ViT-B-32,224,768,512,512,151.28,87.85,63.43,14.78,8.82,5.96 9 | ViT-B-32-quickgelu,224,768,512,512,151.28,87.85,63.43,14.78,8.82,5.96 10 | convnext_tiny,224,768,512,1024,92.3,28.61,63.69,14.87,8.91,5.96 11 | ViT-B-32-256,256,768,512,512,151.29,87.86,63.43,17.46,11.5,5.96 12 | RN50,224,64,512,1024,102.01,38.32,63.69,18.18,12.22,5.96 13 | RN50-quickgelu,224,64,512,1024,102.01,38.32,63.69,18.18,12.22,5.96 14 | ViT-M-16-alt,224,512,384,384,78.98,38.53,40.44,19.36,15.98,3.38 15 | ViT-M-16,224,512,512,512,102.02,38.59,63.43,21.94,15.98,5.96 16 | vit_relpos_medium_patch16_cls_224,224,768,512,512,101.94,38.51,63.43,21.99,16.03,5.96 17 | mt5-base-ViT-B-32,224,768,512,512,365.71,87.85,277.86,22.12,8.82,13.3 18 | convnext_small,224,768,512,512,113.28,49.85,63.43,23.33,17.37,5.96 19 | ViT-B-32-plus-256,256,896,640,640,210.3,119.13,91.16,24.83,15.56,9.27 20 | RN101,224,64,512,512,119.69,56.26,63.43,25.5,19.54,5.96 21 | RN101-quickgelu,224,64,512,512,119.69,56.26,63.43,25.5,19.54,5.96 22 | vit_medium_patch16_gap_256,256,768,512,512,102.04,38.61,63.43,27.1,21.14,5.96 23 | coca_ViT-B-32,224,768,512,512,253.56,89.16,63.43,33.34,9.19,5.96 24 | convnext_base,224,768,512,512,151.52,88.09,63.43,36.67,30.71,5.96 25 | swin_base_patch4_window7_224,224,768,640,640,178.56,87.4,91.16,40.13,30.86,9.27 26 | ViT-B-16,224,768,512,512,149.62,86.19,63.43,41.09,35.13,5.96 27 | ViT-B-16-quickgelu,224,768,512,512,149.62,86.19,63.43,41.09,35.13,5.96 28 | EVA02-B-16,224,768,512,512,149.69,86.26,63.43,41.09,35.13,5.96 29 
| ViT-B-16-SigLIP,224,768,768,768,203.16,92.88,110.27,46.44,35.42,11.02 30 | convnext_base_w,256,768,640,640,179.39,88.22,91.16,49.38,40.11,9.27 31 | RN50x4,288,80,640,640,178.3,87.14,91.16,51.82,42.56,9.27 32 | coca_roberta-ViT-B-32,224,768,768,512,420.37,87.85,124.45,53.12,8.82,13.12 33 | ViT-B-16-plus,224,896,640,640,208.35,117.19,91.16,56.75,47.49,9.27 34 | ViT-B-16-SigLIP-256,256,768,768,768,203.2,92.93,110.27,57.84,46.82,11.02 35 | ViT-B-16-SigLIP-i18n-256,256,768,768,768,370.63,92.93,277.7,57.84,46.82,11.02 36 | ViT-B-16-plus-240,240,896,640,640,208.38,117.21,91.16,64.03,54.76,9.27 37 | convnext_base_w_320,320,768,640,640,179.39,88.22,91.16,71.94,62.67,9.27 38 | convnext_large,224,768,768,768,321.06,197.41,123.65,82.02,68.72,13.3 39 | coca_base,288,768,768,512,440.34,86.4,134.66,99.09,46.47,13.3 40 | roberta-ViT-B-32,224,768,512,512,212.72,87.85,124.87,105.87,8.82,97.05 41 | xlm-roberta-base-ViT-B-32,224,768,512,512,366.12,87.85,278.27,105.87,8.82,97.05 42 | convnext_large_d,256,768,768,768,351.77,199.77,152.0,107.5,89.76,17.73 43 | ViT-B-16-SigLIP-384,384,768,768,768,203.45,93.18,110.27,123.15,112.13,11.02 44 | ViT-L-16,224,1024,768,768,427.74,304.09,123.65,136.41,123.11,13.3 45 | convnext_large_d_320,320,768,768,768,351.77,199.77,152.0,157.98,140.25,17.73 46 | RN50x16,384,96,768,768,290.98,167.33,123.65,162.69,149.39,13.3 47 | ViT-L-14-CLIPA,224,1024,768,768,414.21,303.96,110.25,167.5,162.03,5.47 48 | EVA02-L-14,224,768,768,768,427.76,304.11,123.65,175.3,162.0,13.3 49 | ViT-L-14,224,1024,768,768,427.62,303.97,123.65,175.33,162.03,13.3 50 | ViT-L-14-quickgelu,224,1024,768,768,427.62,303.97,123.65,175.33,162.03,13.3 51 | convnext_xlarge,256,768,1024,1024,653.89,350.25,303.65,198.38,159.14,39.24 52 | ViT-L-16-SigLIP-256,256,768,1024,1024,652.15,315.96,336.19,201.62,162.56,39.06 53 | coca_ViT-L-14,224,1024,768,768,638.45,306.72,123.65,214.52,163.64,13.3 54 | ViT-B-16-SigLIP-512,512,768,768,768,203.79,93.52,110.27,227.26,216.24,11.02 55 | ViT-SO400M-14-SigLIP,224,768,1152,1152,877.36,427.68,449.68,233.54,220.35,13.19 56 | ViT-L-14-280,280,1024,768,768,427.76,304.11,123.65,271.79,258.49,13.3 57 | ViT-L-16-320,320,1024,768,768,427.95,304.3,123.65,271.93,258.63,13.3 58 | ViT-H-16,224,1280,1024,1024,986.26,632.23,354.03,301.72,254.63,47.09 59 | ViT-H-14-CLIPA,224,1280,1024,1024,968.24,632.07,336.16,354.02,334.59,19.43 60 | nllb-clip-base,224,768,512,512,501.89,87.85,414.04,369.6,8.82,360.78 61 | ViT-H-14,224,1280,1024,1024,986.11,632.08,354.03,381.68,334.59,47.09 62 | ViT-H-14-quickgelu,224,1280,1024,1024,986.11,632.08,354.03,381.68,334.59,47.09 63 | ViT-L-14-CLIPA-336,336,1024,768,768,414.54,304.29,110.25,387.39,381.92,5.47 64 | EVA02-L-14-336,336,768,768,768,428.08,304.43,123.65,395.16,381.86,13.3 65 | ViT-L-14-336,336,1024,768,768,427.94,304.29,123.65,395.22,381.92,13.3 66 | ViT-L-16-SigLIP-384,384,768,1024,1024,652.48,316.28,336.19,422.91,383.85,39.06 67 | convnext_xxlarge,256,768,1024,1024,1200.58,846.54,354.03,443.03,395.94,47.09 68 | nllb-clip-base-siglip,384,768,512,768,507.47,93.18,414.3,472.91,112.13,360.78 69 | mt5-xl-ViT-H-14,224,1280,512,1024,2306.75,632.08,1674.68,514.04,334.59,179.45 70 | EVA01-g-14,224,768,768,1024,1136.44,1012.59,123.85,547.36,534.06,13.3 71 | RN50x64,448,128,1024,1024,623.26,420.38,202.88,552.65,529.11,23.55 72 | EVA01-g-14-plus,224,768,1024,1024,1366.62,1012.59,354.03,581.15,534.06,47.09 73 | ViT-g-14,224,1408,1024,1024,1366.68,1012.65,354.03,581.15,534.06,47.09 74 | convnext_xxlarge_320,320,768,1024,1024,1200.58,846.54,354.03,665.74,618.65,47.09 75 | 
xlm-roberta-large-ViT-H-14,224,1280,512,1024,1193.01,632.08,560.94,671.01,334.59,336.42 76 | ViT-SO400M-14-SigLIP-384,384,768,1152,1152,877.96,428.23,449.73,723.48,670.35,53.13 77 | ViT-H-14-CLIPA-336,336,1280,1024,1024,968.64,632.48,336.16,800.88,781.45,19.43 78 | ViT-bigG-14-CLIPA,224,1664,1280,1280,2517.22,1844.9,672.32,1007.93,967.5,40.44 79 | ViT-H-14-378-quickgelu,378,1280,1024,1024,986.71,632.68,354.03,1054.05,1006.96,47.09 80 | ViT-bigG-14,224,1664,1280,1280,2539.57,1844.91,694.66,1065.36,967.5,97.86 81 | nllb-clip-large,224,1280,512,1024,1399.22,632.08,767.14,1468.46,334.59,1133.87 82 | nllb-clip-large-siglip,384,768,512,1152,1195.5,428.23,767.27,1804.22,670.35,1133.87 83 | ViT-e-14,224,1792,1280,1280,4581.09,3807.72,773.37,2091.45,1981.35,110.1 84 | ViT-bigG-14-CLIPA-336,336,1664,1280,1280,2517.76,1845.44,672.32,2271.58,2231.15,40.44 85 | EVA02-E-14,224,768,1024,1024,4704.59,4350.56,354.03,2311.42,2264.33,47.09 86 | EVA02-E-14-plus,224,768,1280,1024,5044.89,4350.56,694.33,2362.19,2264.33,97.86 87 | -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/timm_model.py: -------------------------------------------------------------------------------- 1 | """ timm model adapter 2 | 3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 4 | """ 5 | import logging 6 | from collections import OrderedDict 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | try: 12 | import timm 13 | from timm.models.layers import Mlp, to_2tuple 14 | try: 15 | # old timm imports < 0.8.1 16 | from timm.models.layers.attention_pool2d import RotAttentionPool2d 17 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d 18 | except ImportError: 19 | # new timm imports >= 0.8.1 20 | from timm.layers import RotAttentionPool2d 21 | from timm.layers import AttentionPool2d as AbsAttentionPool2d 22 | except ImportError: 23 | timm = None 24 | 25 | from .utils import freeze_batch_norm_2d 26 | 27 | 28 | class TimmModel(nn.Module): 29 | """ timm model adapter 30 | """ 31 | 32 | def __init__( 33 | self, 34 | model_name, 35 | embed_dim, 36 | image_size=224, 37 | pool='avg', 38 | proj='linear', 39 | proj_bias=False, 40 | drop=0., 41 | drop_path=None, 42 | patch_drop=None, 43 | pretrained=False, 44 | ): 45 | super().__init__() 46 | if timm is None: 47 | raise RuntimeError("Please `pip install timm` to use timm models.") 48 | self.image_size = to_2tuple(image_size) 49 | 50 | # setup kwargs that may not be common across all models 51 | timm_kwargs = {} 52 | if drop_path is not None: 53 | timm_kwargs['drop_path_rate'] = drop_path 54 | if patch_drop is not None: 55 | timm_kwargs['patch_drop_rate'] = patch_drop 56 | 57 | custom_pool = pool in ('abs_attn', 'rot_attn') 58 | if proj: 59 | assert proj in ("linear", "mlp", "none") 60 | extra_proj = proj in ("linear", "mlp") 61 | if not extra_proj and not custom_pool: 62 | # use network classifier head as projection if no proj specified and no custom pooling used 63 | # if projection is explicitly set to "none" will be pass through from network trunk 64 | proj_dim = 0 if proj == 'none' else embed_dim 65 | self.trunk = timm.create_model( 66 | model_name, 67 | num_classes=proj_dim, 68 | global_pool=pool, 69 | pretrained=pretrained, 70 | **timm_kwargs, 71 | ) 72 | prev_chs = embed_dim 73 | else: 74 | self.trunk = timm.create_model( 75 | model_name, 76 | pretrained=pretrained, 77 | **timm_kwargs, 78 | ) 79 | feat_size = 
self.trunk.default_cfg.get('pool_size', None) 80 | feature_ndim = 1 if not feat_size else 2 81 | if custom_pool: 82 | assert feature_ndim == 2 83 | # if attn pooling used, remove both classifier and default pool 84 | self.trunk.reset_classifier(0, global_pool='') 85 | else: 86 | # reset global pool if pool config set, otherwise leave as network default 87 | reset_kwargs = dict(global_pool=pool) if pool else {} 88 | self.trunk.reset_classifier(0, **reset_kwargs) 89 | prev_chs = self.trunk.num_features 90 | 91 | head_layers = OrderedDict() 92 | 93 | # Add custom pooling to head 94 | if pool == 'abs_attn': 95 | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) 96 | prev_chs = embed_dim 97 | elif pool == 'rot_attn': 98 | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) 99 | prev_chs = embed_dim 100 | 101 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used 102 | if proj == 'linear': 103 | head_layers['drop'] = nn.Dropout(drop) 104 | head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) 105 | elif proj == 'mlp': 106 | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias)) 107 | 108 | self.head = nn.Sequential(head_layers) 109 | 110 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 111 | """ lock modules 112 | Args: 113 | unlocked_groups (int): leave last n layer groups unlocked (default: 0) 114 | """ 115 | if not unlocked_groups: 116 | # lock full model 117 | for param in self.trunk.parameters(): 118 | param.requires_grad = False 119 | if freeze_bn_stats: 120 | freeze_batch_norm_2d(self.trunk) 121 | else: 122 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change 123 | try: 124 | # FIXME import here until API stable and in an official release 125 | from timm.models.helpers import group_parameters, group_modules 126 | except ImportError: 127 | raise RuntimeError( 128 | 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') 129 | matcher = self.trunk.group_matcher() 130 | gparams = group_parameters(self.trunk, matcher) 131 | max_layer_id = max(gparams.keys()) 132 | max_layer_id = max_layer_id - unlocked_groups 133 | for group_idx in range(max_layer_id + 1): 134 | group = gparams[group_idx] 135 | for param in group: 136 | self.trunk.get_parameter(param).requires_grad = False 137 | if freeze_bn_stats: 138 | gmodules = group_modules(self.trunk, matcher, reverse=True) 139 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} 140 | freeze_batch_norm_2d(self.trunk, gmodules) 141 | 142 | @torch.jit.ignore 143 | def set_grad_checkpointing(self, enable=True): 144 | try: 145 | self.trunk.set_grad_checkpointing(enable) 146 | except Exception as e: 147 | logging.warning('grad checkpointing not supported for this timm image tower, continuing without...') 148 | 149 | def forward(self, x): 150 | x = self.trunk(x) 151 | x = self.head(x) 152 | return x 153 | -------------------------------------------------------------------------------- /open_clip_customized/docs/datacomp_models.md: -------------------------------------------------------------------------------- 1 | ## CommonPool and DataComp models 2 | 3 | As part of [DataComp](https://github.com/mlfoundations/datacomp), we trained models on CommonPool using various data filtering strategies. 
4 | We release models for all four scales of the competition (small, medium, large, and xlarge), corresponding to pool sizes and numbers of samples seen of 12.8M, 128M, 1.28B and 12.8B, respectively. 5 | 6 | The models are specified below; see our paper [DataComp: In search of the next generation of multimodal datasets](https://arxiv.org/abs/2304.14108) for more details. 7 | 8 | 9 | ## xlarge scale models 10 | 11 | * `datacomp_xl_s13b_b90k`: A ViT-L/14 trained on DataComp-1B for 12.8B steps and batch size 90k. Achieves 79.2% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K. 12 | 13 | * `commonpool_xl_clip_s13b_b90k`: A ViT-L/14 trained on CommonPool-XL filtered using CLIP scores, for 12.8B steps and batch size 90k. Achieves 76.4% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-L-14-CommonPool.XL.clip-s13B-b90K. 14 | 15 | * `commonpool_xl_laion_s13b_b90k`: A ViT-L/14 trained on CommonPool-XL filtered using the LAION-2B filtering scheme, for 12.8B steps and batch size 90k. Achieves 75.5% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-L-14-CommonPool.XL.laion-s13B-b90K. 16 | 17 | * `commonpool_xl_s13b_b90k`: A ViT-L/14 trained on CommonPool-XL without any filtering, for 12.8B steps and batch size 90k. Achieves 72.3% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-L-14-CommonPool.XL-s13B-b90K. 18 | 19 | 20 | ## large scale models 21 | 22 | * `datacomp_l_s1b_b8k`: A ViT-B/16 trained on a 140M subset of DataComp-1B, for 1.28B steps and batch size 8k. Achieves 63.1% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-16-DataComp.L-s1B-b8K. 23 | 24 | * `commonpool_l_clip_s1b_b8k`: A ViT-B/16 trained on CommonPool-L filtered using CLIP scores, for 1.28B steps and batch size 8k. Achieves 57.8% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-16-CommonPool.L.clip-s1B-b8K. 25 | 26 | * `commonpool_l_laion_s1b_b8k`: A ViT-B/16 trained on CommonPool-L filtered using the LAION-2B filtering scheme, for 1.28B steps and batch size 8k. Achieves 55.3% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-16-CommonPool.L.laion-s1B-b8K. 27 | 28 | * `commonpool_l_image_s1b_b8k`: A ViT-B/16 trained on CommonPool-L filtered using image-based filtering, for 1.28B steps and batch size 8k. Achieves 57.2% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-16-CommonPool.L.image-s1B-b8K. 29 | 30 | * `commonpool_l_text_s1b_b8k`: A ViT-B/16 trained on CommonPool-L filtered using text-based filtering, for 1.28B steps and batch size 8k. Achieves 56.1% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-16-CommonPool.L.text-s1B-b8K. 31 | 32 | * `commonpool_l_basic_s1b_b8k`: A ViT-B/16 trained on CommonPool-L filtered using basic filtering (English filtering + caption length and image size filtering), for 1.28B steps and batch size 8k. Achieves 51.6% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-16-CommonPool.L.basic-s1B-b8K. 33 | 34 | * `commonpool_l_s1b_b8k`: A ViT-B/16 trained on CommonPool-L without any filtering, for 1.28B steps and batch size 8k. Achieves 45.9% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-16-CommonPool.L-s1B-b8K.
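Any of the released models above can typically be loaded through the standard `open_clip` API by passing its tag as the `pretrained` argument. The snippet below is a minimal sketch and assumes the `datacomp_l_s1b_b8k` tag is registered for `ViT-B-16` in the installed `open_clip` version:

```
import open_clip

# Minimal sketch: load the large-scale DataComp ViT-B/16 by its pretrained tag.
# Assumes 'datacomp_l_s1b_b8k' is registered for 'ViT-B-16' in this open_clip version.
model, _, preprocess = open_clip.create_model_and_transforms(
    'ViT-B-16', pretrained='datacomp_l_s1b_b8k')
tokenizer = open_clip.get_tokenizer('ViT-B-16')
model.eval()
```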
35 | 36 | 37 | ## medium scale models 38 | 39 | * `datacomp_m_s128m_b4k`: A ViT-B/32 trained on a 14M subset of DataComp-1B, for 128M steps and batch size 4k. Achieves 29.7% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-32-DataComp.M-s128M-b4K. 40 | 41 | * `commonpool_m_clip_s128m_b4k`: A ViT-B/32 trained on CommonPool-M filtered using CLIP scores, for 128M steps and batch size 4k. Achieves 27.3% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-32-CommonPool.M.clip-s128M-b4K. 42 | 43 | * `commonpool_m_laion_s128m_b4k`: A ViT-B/32 trained on CommonPool-M filtered using the LAION-2B filtering scheme, for 128M steps and batch size 4k. Achieves 23.0% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-32-CommonPool.M.laion-s128M-b4K. 44 | 45 | * `commonpool_m_image_s128m_b4k`: A ViT-B/32 trained on CommonPool-M filtered using image-based filtering, for 128M steps and batch size 4k. Achieves 26.8% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-32-CommonPool.M.image-s128M-b4K. 46 | 47 | * `commonpool_m_text_s128m_b4k`: A ViT-B/32 trained on CommonPool-M filtered using text-based filtering, for 128M steps and batch size 4k. Achieves 25.5% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-32-CommonPool.M.text-s128M-b4K. 48 | 49 | * `commonpool_m_basic_s128m_b4k`: A ViT-B/32 trained on CommonPool-M filtered using basic filtering (English filtering + caption length and image size filtering), for 128M steps and batch size 4k. Achieves 22.6% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-32-CommonPool.M.basic-s128M-b4K. 50 | 51 | * `commonpool_m_s128m_b4k`: A ViT-B/32 trained on CommonPool-M without any filtering, for 128M steps and batch size 4k. Achieves 17.6% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-32-CommonPool.M-s128M-b4K. 52 | 53 | 54 | ## small scale models 55 | 56 | * `datacomp_s_s13m_b4k`: A ViT-B/32 trained on a 1.4M subset of DataComp-1B, for 12.8M steps and batch size 4k. Achieves 3.9% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-32-DataComp.S-s13M-b4K. 57 | 58 | * `commonpool_s_clip_s13m_b4k`: A ViT-B/32 trained on CommonPool-S filtered using CLIP scores, for 12.8M steps and batch size 4k. Achieves 5.1% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-32-CommonPool.S.clip-s13M-b4K. 59 | 60 | * `commonpool_s_laion_s13m_b4k`: A ViT-B/32 trained on CommonPool-S filtered using the LAION-2B filtering scheme, for 12.8M steps and batch size 4k. Achieves 3.1% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-32-CommonPool.S.laion-s13M-b4K. 61 | 62 | * `commonpool_s_image_s13m_b4k`: A ViT-B/32 trained on CommonPool-S filtered using image-based filtering, for 12.8M steps and batch size 4k. Achieves 4.3% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-32-CommonPool.S.image-s13M-b4K. 63 | 64 | * `commonpool_s_text_s13m_b4k`: A ViT-B/32 trained on CommonPool-S filtered using text-based filtering, for 12.8M steps and batch size 4k. Achieves 4.6% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-32-CommonPool.S.text-s13M-b4K.
65 | 66 | * `commonpool_s_basic_s13m_b4k`: A ViT-B/32 trained on CommonPool-S filtered using basic filtering (English filtering + caption length and image size filtering), for 12.8M steps and batch size 4k. Achieves 3.0% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-32-CommonPool.S.basic-s13M-b4K. 67 | 68 | * `commonpool_s_s13m_b4k`: A ViT-B/32 trained on CommonPool-S without any filtering, for 12.8M steps and batch size 4k. Achieves 2.5% zero-shot accuracy on ImageNet. Available at https://huggingface.co/laion/CLIP-ViT-B-32-CommonPool.S-s13M-b4K. 69 | 70 | -------------------------------------------------------------------------------- /open_clip_customized/docs/clipa.md: -------------------------------------------------------------------------------- 1 | ## CLIPA 2 | 3 | In this work, we present a surprising finding that there exists an _inverse_ scaling law for CLIP training, 4 | whereby the larger the image/text encoders used, the shorter the sequence length of image/text tokens that can be applied in training. 5 | Moreover, we showcase that the strategy for reducing image/text token length plays a crucial role in determining the quality of this scaling law. 6 | 7 | ![](/docs/inverse_scaling_law.png) 8 | 9 | As a result of this finding, we are able to successfully train CLIP even with academic resources. 10 | For example, on an A100 eight-GPU server, our CLIP models achieve zero-shot top-1 ImageNet accuracies of **63.2%** in about **2 days**, 11 | **67.8%** in about **3 days**, and **69.3%** in about **4 days**. 12 | 13 | Moreover, we find that CLIPA at scale leads to state-of-the-art performance. For example, our CLIPA-v2 H/14 achieves a zero-shot top-1 ImageNet accuracy of **81.8%**, 14 | with a budget of less than **$15,000**. 15 | 16 | ![](/docs/clipa_acc_compute.png) 17 | 18 | For more details, please see our paper [An Inverse Scaling Law for CLIP Training](https://arxiv.org/abs/2305.07017) and 19 | [CLIPA-v2: Scaling CLIP Training with 81.1% Zero-shot ImageNet Accuracy within a $10,000 Budget; An Extra $4,000 Unlocks 81.8% Accuracy](https://arxiv.org/abs/2306.15658). 20 | 21 | 22 | Eight token length reduction strategies are investigated in this work, detailed as follows. 23 | 24 | 25 | ## Image token length reduction 26 | 27 | ![](/docs/clipa_reduce_image_token.png) 28 | 29 | * `resize`: Use `--force-image-size` to specify the image size you want to adopt. We find this strategy generally works the best as it retains full image information. 30 | 31 | * `random mask`: Randomly mask out image patches. Use `--force-patch-dropout` to specify the mask ratio you want to adopt. 32 | 33 | * `grid mask`: Preserve one patch in each 2 × 2 grid window. We do not provide an implementation for grid masking, as it is only experimental and we generally find resizing works better. 34 | 35 | * `block mask`: Keep a single block and remove other patches. We do not provide an implementation for block masking, as it is only experimental and we generally find resizing works better. 36 | 37 | 38 | ## Text token length reduction 39 | 40 | * `syntax mask`: Assign different masking priorities to parts of speech. Specify `"text_mask": syntax` in `"tokenizer_kwargs"` in `"text_cfg"` of the model config `json` file to use it. 41 | Specifically, we prioritize retaining nouns, followed by adjectives, and then other words. 42 | We find this strategy generally works the best as it retains critical information for contrastive learning.
43 | 44 | * `truncate`: Truncation selects the first N text tokens and discards the rest. This is the default setting of `open_clip`. 45 | 46 | * `random mask`: Randomly drops a portion of the text tokens. Specify `"text_mask": random` in `"tokenizer_kwargs"` in `"text_cfg"` of the model config `json` file to use it. 47 | 48 | * `block mask`: Randomly preserves consecutive text sequences. Specify `"text_mask": block` in `"tokenizer_kwargs"` in `"text_cfg"` of the model config `json` file to use it. 49 | 50 | 51 | ## Installation 52 | 53 | Installation is the same as for `open_clip`, except that the `syntax mask` text token length reduction strategy uses the Natural Language Toolkit (NLTK). 54 | Please follow the [official doc](https://www.nltk.org/) to install NLTK. 55 | 56 | Note that the usage of NLTK brings two constraints: 57 | * Because certain functions like `nltk.pos_tag` from NLTK only support English and Russian for now, the `syntax mask` only works for English. 58 | We have not tested it on Russian or any other language. Theoretically, it should work the same, given a proper language processing toolkit for the target language. 59 | If you still want to apply `syntax mask` to other languages, try finding the right toolkit. Otherwise, use one of the other text token length reduction strategies. 60 | * Some NLTK modules like `punkt` and `averaged_perceptron_tagger` need to be downloaded before NLTK can be used. 61 | We have included the downloading code in `tokenizer.py`, but this might cause trouble in certain cases. 62 | You may want to manually download those modules first, via `nltk.download('punkt')` and `nltk.download('averaged_perceptron_tagger')`, 63 | and then set the environment variable before running the script: `export NLTK_DATA=cache`. 64 | Note that this is a one-time effort. Remember to comment out those `nltk.download` lines in `tokenizer.py` afterwards. 65 | 66 | ## Training 67 | We provide example scripts to reproduce our CLIPA results on an A100 eight-GPU machine under `docs/script_examples/clipa`. 68 | 69 | For instance, to reproduce the CLIPA-L/16(I37,T8) results, first run the pre-training script 70 | ``` 71 | bash docs/script_examples/clipa/vit_l16/i37_t8_pretrain.sh 72 | ``` 73 | and fine-tune the pre-trained checkpoint with 74 | ``` 75 | bash docs/script_examples/clipa/vit_l16/i37_t8_finetune.sh 76 | ``` 77 | - Remember to change the dataset path to your own. 78 | - This is a two-stage training pipeline. Remember to change the path to the pre-trained checkpoint to your own when fine-tuning. 79 | - The training time is ~3 days for pre-training and ~1 day for fine-tuning on an A100 eight-GPU machine. 80 | 81 | ## Model Weights 82 | Below are CLIPA weights trained on LAION-400M with an A100 eight-GPU machine. 83 | All models are pre-trained for 6 epochs with reduced input token lengths and subsequently fine-tuned for 0.36 epoch with full input token lengths.
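Once a checkpoint from the tables below has been downloaded, it can typically be loaded by passing its local file path as the `pretrained` argument of the standard `open_clip` API. The snippet below is a minimal sketch: the model name and file path are placeholders, and the model config must match the one used by the corresponding training script.

```
import open_clip

# Minimal sketch: load a downloaded CLIPA checkpoint from a local file.
# 'ViT-L-16' and the file path are placeholders; use the model config from
# the corresponding training script and your actual download location.
model, _, preprocess = open_clip.create_model_and_transforms(
    'ViT-L-16', pretrained='/path/to/clipa_l16_i37_t8_finetuned.pt')
model.eval()
```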
84 | 85 | 86 | | | Pre-trained Weights | zero-shot IN-1K | 87 | |---------------------|:----------------------------------------------------------------------------------------------:|:---------------:| 88 | | CLIPA-B/16(I50,T16) | [download](https://drive.google.com/file/d/1MDpz8gV2Vjaazk16rBhLxU8811U7_cGL/view?usp=sharing) | 59.7 | 89 | | CLIPA-L/16(I17,T16) | [download](https://drive.google.com/file/d/1Tr2GYiKAaMH6EGIn5l7eX_1K20eaA3WA/view?usp=sharing) | 60.3 | 90 | | CLIPA-L/16(I37,T8) | [download](https://drive.google.com/file/d/1EM1ChRNARpLckkJjf6m7njCY3xyvpGBu/view?usp=sharing) | 57.9 | 91 | 92 | | | Fine-tuned Weights | zero-shot IN-1K | 93 | |---------------------|:----------------------------------------------------------------------------------------------:|:-----:| 94 | | CLIPA-B/16(I50,T16) | [download](https://drive.google.com/file/d/1fURK0K_a3-83jVEI4PVEbnEJb_V6UbGv/view?usp=sharing) | 63.2 | 95 | | CLIPA-L/16(I17,T16) | [download](https://drive.google.com/file/d/18qqZGOTGOgb3I3JWONuat6qObsgLq7sR/view?usp=sharing) | 67.8 | 96 | | CLIPA-L/16(I37,T8) | [download](https://drive.google.com/file/d/1lV7pLORUK04T9QKKx9TpYtMws-AZrib0/view?usp=sharing) | 69.3 | 97 | 98 | 99 | ## CLIPA-v2 100 | We also provide example scripts to reproduce our CLIPA-v2 H/14 results under `docs/script_examples/clipav2`. 101 | Note that the original results were obtained with [our JAX implementation](https://github.com/UCSC-VLAA/CLIPA/tree/master/clipa_jax). 102 | These scripts were written after manually scanning the JAX config files. 103 | As it is infeasible for us to retrain those models with PyTorch, their correctness cannot be verified with 100% confidence. Use them at your own discretion. 104 | -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/hf_model.py: -------------------------------------------------------------------------------- 1 | """ huggingface model adapter 2 | 3 | Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model. 4 | """ 5 | import re 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch import TensorType 10 | 11 | try: 12 | import transformers 13 | from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig 14 | from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \ 15 | BaseModelOutputWithPoolingAndCrossAttentions 16 | except ImportError as e: 17 | transformers = None 18 | 19 | 20 | class BaseModelOutput: 21 | pass 22 | 23 | 24 | class PretrainedConfig: 25 | pass 26 | 27 | from .hf_configs import arch_dict 28 | 29 | 30 | # utils 31 | def _camel2snake(s): 32 | return re.sub(r'(?
1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.act1 = nn.ReLU(inplace=True) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.act2 = nn.ReLU(inplace=True) 24 | 25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 26 | 27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | self.act3 = nn.ReLU(inplace=True) 30 | 31 | self.downsample = None 32 | self.stride = stride 33 | 34 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 36 | self.downsample = nn.Sequential(OrderedDict([ 37 | ("-1", nn.AvgPool2d(stride)), 38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 39 | ("1", nn.BatchNorm2d(planes * self.expansion)) 40 | ])) 41 | 42 | def forward(self, x: torch.Tensor): 43 | identity = x 44 | 45 | out = self.act1(self.bn1(self.conv1(x))) 46 | out = self.act2(self.bn2(self.conv2(out))) 47 | out = self.avgpool(out) 48 | out = self.bn3(self.conv3(out)) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | out = self.act3(out) 55 | return out 56 | 57 | 58 | class AttentionPool2d(nn.Module): 59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 60 | super().__init__() 61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 62 | self.k_proj = nn.Linear(embed_dim, embed_dim) 63 | self.q_proj = nn.Linear(embed_dim, embed_dim) 64 | self.v_proj = nn.Linear(embed_dim, embed_dim) 65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 66 | self.num_heads = num_heads 67 | 68 | def forward(self, x): 69 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 72 | x, _ = F.multi_head_attention_forward( 73 | query=x, key=x, value=x, 74 | embed_dim_to_check=x.shape[-1], 75 | num_heads=self.num_heads, 76 | q_proj_weight=self.q_proj.weight, 77 | k_proj_weight=self.k_proj.weight, 78 | v_proj_weight=self.v_proj.weight, 79 | in_proj_weight=None, 80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 81 | bias_k=None, 82 | bias_v=None, 83 | add_zero_attn=False, 84 | dropout_p=0., 85 | out_proj_weight=self.c_proj.weight, 86 | out_proj_bias=self.c_proj.bias, 87 | use_separate_proj_weight=True, 88 | training=self.training, 89 | need_weights=False 90 | ) 91 | 92 | return x[0] 93 | 94 | 95 | class ModifiedResNet(nn.Module): 96 | """ 97 | A ResNet class that is similar to torchvision's but contains the following changes: 98 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
99 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 100 | - The final pooling layer is a QKV attention instead of an average pool 101 | """ 102 | 103 | def __init__(self, layers, output_dim, heads, image_size=224, width=64, in_chans=3): 104 | super().__init__() 105 | self.output_dim = output_dim 106 | self.image_size = image_size 107 | 108 | # the 3-layer stem 109 | self.conv1 = nn.Conv2d(in_chans, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 110 | self.bn1 = nn.BatchNorm2d(width // 2) 111 | self.act1 = nn.ReLU(inplace=True) 112 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 113 | self.bn2 = nn.BatchNorm2d(width // 2) 114 | self.act2 = nn.ReLU(inplace=True) 115 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 116 | self.bn3 = nn.BatchNorm2d(width) 117 | self.act3 = nn.ReLU(inplace=True) 118 | self.avgpool = nn.AvgPool2d(2) 119 | 120 | # residual layers 121 | self._inplanes = width # this is a *mutable* variable used during construction 122 | self.layer1 = self._make_layer(width, layers[0]) 123 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 124 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 125 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 126 | 127 | embed_dim = width * 32 # the ResNet feature dimension 128 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) 129 | 130 | self.init_parameters() 131 | 132 | def _make_layer(self, planes, blocks, stride=1): 133 | layers = [Bottleneck(self._inplanes, planes, stride)] 134 | 135 | self._inplanes = planes * Bottleneck.expansion 136 | for _ in range(1, blocks): 137 | layers.append(Bottleneck(self._inplanes, planes)) 138 | 139 | return nn.Sequential(*layers) 140 | 141 | def init_parameters(self): 142 | if self.attnpool is not None: 143 | std = self.attnpool.c_proj.in_features ** -0.5 144 | nn.init.normal_(self.attnpool.q_proj.weight, std=std) 145 | nn.init.normal_(self.attnpool.k_proj.weight, std=std) 146 | nn.init.normal_(self.attnpool.v_proj.weight, std=std) 147 | nn.init.normal_(self.attnpool.c_proj.weight, std=std) 148 | 149 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: 150 | for name, param in resnet_block.named_parameters(): 151 | if name.endswith("bn3.weight"): 152 | nn.init.zeros_(param) 153 | 154 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 155 | assert unlocked_groups == 0, 'partial locking not currently supported for this model' 156 | for param in self.parameters(): 157 | param.requires_grad = False 158 | if freeze_bn_stats: 159 | freeze_batch_norm_2d(self) 160 | 161 | @torch.jit.ignore 162 | def set_grad_checkpointing(self, enable=True): 163 | # FIXME support for non-transformer 164 | pass 165 | 166 | def stem(self, x): 167 | x = self.act1(self.bn1(self.conv1(x))) 168 | x = self.act2(self.bn2(self.conv2(x))) 169 | x = self.act3(self.bn3(self.conv3(x))) 170 | x = self.avgpool(x) 171 | return x 172 | 173 | def forward(self, x): 174 | x = self.stem(x) 175 | x = self.layer1(x) 176 | x = self.layer2(x) 177 | x = self.layer3(x) 178 | x = self.layer4(x) 179 | x = self.attnpool(x) 180 | 181 | return x 182 | -------------------------------------------------------------------------------- /open_clip_customized/src/open_clip/big_vision.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from 
.model import CustomTextCLIP 5 | from .transformer import TextTransformer, Transformer 6 | 7 | 8 | @torch.no_grad() 9 | def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str): 10 | """ Load weights from .npz checkpoints for official Google big_vision image-text models 11 | 12 | Currently the SigLIP source models are supported and a CustomTextCLIP destination model 13 | w/ timm image encoder. 14 | """ 15 | from timm.layers import resample_patch_embed, resample_abs_pos_embed 16 | 17 | def _n2p(w, t=True): 18 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: 19 | w = w.flatten() 20 | if t: 21 | if w.ndim == 4: 22 | w = w.transpose([3, 2, 0, 1]) 23 | elif w.ndim == 3: 24 | w = w.transpose([2, 0, 1]) 25 | elif w.ndim == 2: 26 | w = w.transpose([1, 0]) 27 | return torch.from_numpy(w) 28 | 29 | w = np.load(checkpoint_path) 30 | interpolation = 'bilinear' 31 | antialias = False 32 | 33 | def _convert_timm_img(module, prefix): 34 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) 35 | if embed_conv_w.shape[-2:] != module.patch_embed.proj.weight.shape[-2:]: 36 | embed_conv_w = resample_patch_embed( 37 | embed_conv_w, 38 | module.patch_embed.proj.weight.shape[-2:], 39 | interpolation=interpolation, 40 | antialias=antialias, 41 | verbose=True, 42 | ) 43 | module.patch_embed.proj.weight.copy_(embed_conv_w) 44 | module.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) 45 | 46 | if module.cls_token is not None: 47 | module.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) 48 | 49 | pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False) 50 | if pos_embed_w.shape != module.pos_embed.shape: 51 | assert False, f'{pos_embed_w.shape}, {module.pos_embed.shape}' 52 | num_prefix_tokens = 0 if getattr(module, 'no_embed_class', False) else getattr(module, 'num_prefix_tokens', 1) 53 | pos_embed_w = resample_abs_pos_embed( # resize pos embedding when different size from pretrained weights 54 | pos_embed_w, 55 | new_size=module.patch_embed.grid_size, 56 | num_prefix_tokens=num_prefix_tokens, 57 | interpolation=interpolation, 58 | antialias=antialias, 59 | verbose=True, 60 | ) 61 | module.pos_embed.copy_(pos_embed_w) 62 | 63 | mha_sub, b_sub, ln1_sub = (0, 0, 1) 64 | for i, block in enumerate(module.blocks.children()): 65 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/' 66 | mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/' 67 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 68 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 69 | block.attn.qkv.weight.copy_(torch.cat([ 70 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) 71 | block.attn.qkv.bias.copy_(torch.cat([ 72 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 73 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 74 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 75 | for r in range(2): 76 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'])) 77 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'])) 78 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'])) 79 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'])) 80 | 81 | module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) 82 | 
module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) 83 | 84 | if module.attn_pool is not None: 85 | block_prefix = f'{prefix}MAPHead_0/' 86 | mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/' 87 | module.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False)) 88 | module.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T) 89 | module.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1)) 90 | module.attn_pool.kv.weight.copy_(torch.cat([ 91 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')])) 92 | module.attn_pool.kv.bias.copy_(torch.cat([ 93 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')])) 94 | module.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 95 | module.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 96 | module.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 97 | module.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 98 | for r in range(2): 99 | getattr(module.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel'])) 100 | getattr(module.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias'])) 101 | 102 | def _convert_openclip_transformer(module: Transformer, prefix): 103 | for i, block in enumerate(module.resblocks.children()): 104 | block_prefix = f'{prefix}encoderblock_{i}/' 105 | mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/' 106 | block.ln_1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 107 | block.ln_1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 108 | block.attn.in_proj_weight.copy_(torch.cat([ 109 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) 110 | block.attn.in_proj_bias.copy_(torch.cat([ 111 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 112 | block.attn.out_proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 113 | block.attn.out_proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 114 | block.ln_2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/scale'])) 115 | block.ln_2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/bias'])) 116 | block.mlp.c_fc.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/kernel'])) 117 | block.mlp.c_fc.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/bias'])) 118 | block.mlp.c_proj.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/kernel'])) 119 | block.mlp.c_proj.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/bias'])) 120 | 121 | def _convert_openclip_txt(module: TextTransformer, prefix): 122 | module.token_embedding.weight.copy_(_n2p(w[f'{prefix}Embed_0/embedding'], t=False)) 123 | pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False).squeeze(0) 124 | module.positional_embedding.copy_(pos_embed_w) 125 | _convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/') 126 | module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale'])) 127 | module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias'])) 128 | module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) 129 | module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias'])) 130 | 131 | _convert_timm_img(model.visual.trunk, 'params/img/') 132 | _convert_openclip_txt(model.text, 'params/txt/') 133 | 
model.logit_bias.copy_(_n2p(w['params/b'])[0]) 134 | model.logit_scale.copy_(_n2p(w['params/t'])[0]) 135 | 136 | 137 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | import math 3 | import itertools 4 | import webdataset as wds 5 | from braceexpand import braceexpand 6 | 7 | from torchvision import transforms 8 | import torchvision.transforms.functional as TF 9 | from diffusers.training_utils import resolve_interpolation_mode 10 | 11 | from common_utils import tarfile_to_samples_nothrow, filter_keys, default_collate 12 | 13 | 14 | class SDText2ImageDataset: 15 | def __init__( 16 | self, 17 | train_shards_path_or_url: Union[str, List[str]], 18 | num_train_examples: int, 19 | per_gpu_batch_size: int, 20 | global_batch_size: int, 21 | num_workers: int, 22 | resolution: int = 512, 23 | interpolation_type: str = "bilinear", 24 | shuffle_buffer_size: int = 1000, 25 | pin_memory: bool = False, 26 | persistent_workers: bool = False, 27 | work_on_latent: bool = False, 28 | ): 29 | if not isinstance(train_shards_path_or_url, str): 30 | train_shards_path_or_url = [ 31 | list(braceexpand(urls)) for urls in train_shards_path_or_url 32 | ] 33 | # flatten list using itertools 34 | train_shards_path_or_url = list( 35 | itertools.chain.from_iterable(train_shards_path_or_url) 36 | ) 37 | 38 | if work_on_latent: 39 | processing_pipeline = [ 40 | wds.decode("l", handler=wds.ignore_and_continue), 41 | wds.rename( 42 | param="npy", 43 | text="txt", 44 | handler=wds.warn_and_continue, 45 | ), 46 | wds.map(filter_keys({"param", "text"})), 47 | wds.to_tuple("param", "text"), 48 | ] 49 | else: 50 | interpolation_mode = resolve_interpolation_mode(interpolation_type) 51 | 52 | def transform(example): 53 | # resize image 54 | image = example["image"] 55 | image = TF.resize(image, resolution, interpolation=interpolation_mode) 56 | 57 | # get crop coordinates and crop image 58 | c_top, c_left, _, _ = transforms.RandomCrop.get_params( 59 | image, output_size=(resolution, resolution) 60 | ) 61 | image = TF.crop(image, c_top, c_left, resolution, resolution) 62 | image = TF.to_tensor(image) 63 | image = TF.normalize(image, [0.5], [0.5]) 64 | 65 | example["image"] = image 66 | return example 67 | 68 | processing_pipeline = [ 69 | wds.decode("pil", handler=wds.ignore_and_continue), 70 | wds.rename( 71 | image="jpg;png;jpeg;webp", 72 | text="text;txt;caption", 73 | handler=wds.warn_and_continue, 74 | ), 75 | wds.map(filter_keys({"image", "text"})), 76 | wds.map(transform), 77 | wds.to_tuple("image", "text"), 78 | ] 79 | 80 | # Create train dataset and loader 81 | pipeline = [ 82 | wds.ResampledShards(train_shards_path_or_url), 83 | tarfile_to_samples_nothrow, 84 | wds.shuffle(shuffle_buffer_size), 85 | *processing_pipeline, 86 | wds.batched( 87 | per_gpu_batch_size, partial=False, collation_fn=default_collate 88 | ), 89 | ] 90 | 91 | num_worker_batches = math.ceil( 92 | num_train_examples / (global_batch_size * num_workers) 93 | ) # per dataloader worker 94 | num_batches = num_worker_batches * num_workers 95 | num_samples = num_batches * global_batch_size 96 | 97 | # each worker is iterating over this 98 | self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches) 99 | self._train_dataloader = wds.WebLoader( 100 | self._train_dataset, 101 | batch_size=None, 102 | shuffle=False, 103 | num_workers=num_workers, 104 | pin_memory=pin_memory, 105 
| persistent_workers=persistent_workers, 106 | ) 107 | # add meta-data to dataloader instance for convenience 108 | self._train_dataloader.num_batches = num_batches 109 | self._train_dataloader.num_samples = num_samples 110 | 111 | @property 112 | def train_dataset(self): 113 | return self._train_dataset 114 | 115 | @property 116 | def train_dataloader(self): 117 | return self._train_dataloader 118 | 119 | 120 | class Text2ImageDataset: 121 | def __init__( 122 | self, 123 | train_shards_path_or_url: Union[str, List[str]], 124 | num_train_examples: int, 125 | per_gpu_batch_size: int, 126 | global_batch_size: int, 127 | num_workers: int, 128 | resolution: int = 1024, 129 | shuffle_buffer_size: int = 1000, 130 | pin_memory: bool = False, 131 | persistent_workers: bool = False, 132 | use_fix_crop_and_size: bool = False, 133 | ): 134 | if not isinstance(train_shards_path_or_url, str): 135 | train_shards_path_or_url = [ 136 | list(braceexpand(urls)) for urls in train_shards_path_or_url 137 | ] 138 | # flatten list using itertools 139 | train_shards_path_or_url = list( 140 | itertools.chain.from_iterable(train_shards_path_or_url) 141 | ) 142 | 143 | def get_orig_size(json): 144 | if use_fix_crop_and_size: 145 | return (resolution, resolution) 146 | else: 147 | return ( 148 | int(json.get("original_width", 0.0)), 149 | int(json.get("original_height", 0.0)), 150 | ) 151 | 152 | def transform(example): 153 | # resize image 154 | image = example["image"] 155 | image = TF.resize( 156 | image, resolution, interpolation=transforms.InterpolationMode.BILINEAR 157 | ) 158 | 159 | # get crop coordinates and crop image 160 | c_top, c_left, _, _ = transforms.RandomCrop.get_params( 161 | image, output_size=(resolution, resolution) 162 | ) 163 | image = TF.crop(image, c_top, c_left, resolution, resolution) 164 | image = TF.to_tensor(image) 165 | image = TF.normalize(image, [0.5], [0.5]) 166 | 167 | example["image"] = image 168 | example["crop_coords"] = ( 169 | (c_top, c_left) if not use_fix_crop_and_size else (0, 0) 170 | ) 171 | return example 172 | 173 | processing_pipeline = [ 174 | wds.decode("pil", handler=wds.ignore_and_continue), 175 | wds.rename( 176 | image="jpg;png;jpeg;webp", 177 | text="text;txt;caption", 178 | orig_size="json", 179 | handler=wds.warn_and_continue, 180 | ), 181 | wds.map(filter_keys({"image", "text", "orig_size"})), 182 | wds.map_dict(orig_size=get_orig_size), 183 | wds.map(transform), 184 | wds.to_tuple("image", "text", "orig_size", "crop_coords"), 185 | ] 186 | 187 | # Create train dataset and loader 188 | pipeline = [ 189 | wds.ResampledShards(train_shards_path_or_url), 190 | tarfile_to_samples_nothrow, 191 | wds.shuffle(shuffle_buffer_size), 192 | *processing_pipeline, 193 | wds.batched( 194 | per_gpu_batch_size, partial=False, collation_fn=default_collate 195 | ), 196 | ] 197 | 198 | num_worker_batches = math.ceil( 199 | num_train_examples / (global_batch_size * num_workers) 200 | ) # per dataloader worker 201 | num_batches = num_worker_batches * num_workers 202 | num_samples = num_batches * global_batch_size 203 | 204 | # each worker is iterating over this 205 | self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches) 206 | self._train_dataloader = wds.WebLoader( 207 | self._train_dataset, 208 | batch_size=None, 209 | shuffle=False, 210 | num_workers=num_workers, 211 | pin_memory=pin_memory, 212 | persistent_workers=persistent_workers, 213 | ) 214 | # add meta-data to dataloader instance for convenience 215 | self._train_dataloader.num_batches = 
num_batches 216 | self._train_dataloader.num_samples = num_samples 217 | 218 | @property 219 | def train_dataset(self): 220 | return self._train_dataset 221 | 222 | @property 223 | def train_dataloader(self): 224 | return self._train_dataloader 225 | --------------------------------------------------------------------------------
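For reference, a minimal usage sketch of `Text2ImageDataset` as defined above; the shard path pattern, sample count, and batch sizes are hypothetical placeholders.

```
# Minimal usage sketch of Text2ImageDataset; the shard path pattern,
# sample count, and batch sizes are hypothetical placeholders.
dataset = Text2ImageDataset(
    train_shards_path_or_url="/path/to/shards/{000000..000127}.tar",
    num_train_examples=1_000_000,
    per_gpu_batch_size=8,
    global_batch_size=64,  # per_gpu_batch_size * num_gpus * grad_accum_steps
    num_workers=4,
    resolution=1024,
)
for image, text, orig_size, crop_coords in dataset.train_dataloader:
    ...  # run one training step per batch
```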