├── model ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── simple_tokenizer.py └── xttn.py ├── open_clip_long ├── version.py ├── bpe_simple_vocab_16e6.txt.gz ├── constants.py ├── model_configs │ ├── ViT-B-16.json │ ├── ViT-B-32.json │ ├── ViT-M-16.json │ ├── ViT-M-32.json │ ├── ViT-S-16.json │ ├── ViT-S-32.json │ ├── ViT-B-16-plus.json │ ├── ViT-L-14-280.json │ ├── ViT-L-14-336.json │ ├── ViT-L-14.json │ ├── ViT-L-16-320.json │ ├── ViT-L-16.json │ ├── ViT-M-32-alt.json │ ├── ViT-S-16-alt.json │ ├── ViT-S-32-alt.json │ ├── ViT-B-16-plus-240.json │ ├── ViT-B-32-256.json │ ├── ViT-B-32-plus-256.json │ ├── mt5-base-ViT-B-32.json │ ├── xlm-roberta-base-ViT-B-32.json │ ├── ViT-H-14.json │ ├── ViT-H-16.json │ ├── ViT-B-16-quickgelu.json │ ├── ViT-B-32-quickgelu.json │ ├── ViT-L-14-quickgelu.json │ ├── ViT-M-16-alt.json │ ├── roberta-ViT-B-32.json │ ├── mt5-xl-ViT-H-14.json │ ├── xlm-roberta-large-ViT-H-14.json │ ├── ViT-e-14.json │ ├── ViT-g-14.json │ ├── ViT-H-14-quickgelu.json │ ├── ViT-bigG-14.json │ ├── ViT-H-14-378-quickgelu.json │ ├── nllb-clip-base.json │ ├── RN50.json │ ├── RN101.json │ ├── RN50x16.json │ ├── RN50x4.json │ ├── RN50x64.json │ ├── vit_medium_patch16_gap_256.json │ ├── swin_base_patch4_window7_224.json │ ├── nllb-clip-large.json │ ├── EVA01-g-14.json │ ├── vit_relpos_medium_patch16_cls_224.json │ ├── EVA01-g-14-plus.json │ ├── EVA02-B-16.json │ ├── EVA02-L-14.json │ ├── EVA02-E-14.json │ ├── EVA02-L-14-336.json │ ├── RN101-quickgelu.json │ ├── EVA02-E-14-plus.json │ ├── RN50-quickgelu.json │ ├── convnext_base.json │ ├── convnext_base_w.json │ ├── convnext_large.json │ ├── convnext_large_d.json │ ├── convnext_small.json │ ├── convnext_tiny.json │ ├── convnext_base_w_320.json │ ├── convnext_large_d_320.json │ ├── convnext_xlarge.json │ ├── convnext_xxlarge.json │ ├── convnext_xxlarge_320.json │ ├── nllb-clip-base-siglip.json │ ├── nllb-clip-large-siglip.json │ ├── coca_roberta-ViT-B-32.json │ ├── ViT-L-14-CLIPA.json │ ├── ViT-L-14-CLIPA-336.json │ ├── ViT-H-14-CLIPA.json │ ├── ViT-H-14-CLIPA-336.json │ ├── ViT-bigG-14-CLIPA.json │ ├── ViT-bigG-14-CLIPA-336.json │ ├── coca_ViT-B-32.json │ ├── coca_ViT-L-14.json │ ├── coca_base.json │ ├── ViT-B-16-SigLIP.json │ ├── ViT-B-16-SigLIP-256.json │ ├── ViT-B-16-SigLIP-384.json │ ├── ViT-B-16-SigLIP-512.json │ ├── ViT-L-16-SigLIP-256.json │ ├── ViT-L-16-SigLIP-384.json │ ├── ViT-B-16-SigLIP-i18n-256.json │ ├── ViT-SO400M-14-SigLIP.json │ └── ViT-SO400M-14-SigLIP-384.json ├── __init__.py ├── hf_configs.py ├── openai.py ├── utils.py ├── pos_embed.py ├── zero_shot_classifier.py ├── timm_model.py ├── hf_model.py ├── modified_resnet.py ├── big_vision.py └── push_to_hf_hub.py ├── SDXL ├── demo_SDXL.png ├── SDXL.md ├── sdxl.py ├── fid_score.py └── inception.py ├── test.sh ├── train.sh ├── train ├── scheduler.py ├── train.md ├── utils.py ├── arguments.py ├── sharegpt4v.py ├── eval_data.py ├── loss.py ├── test.py └── train.py ├── README.md ├── requirements.txt └── .gitignore /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .finelip import * 2 | -------------------------------------------------------------------------------- /open_clip_long/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '2.24.0' 2 | -------------------------------------------------------------------------------- /SDXL/demo_SDXL.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tiiuae/FineLIP/HEAD/SDXL/demo_SDXL.png -------------------------------------------------------------------------------- /model/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiiuae/FineLIP/HEAD/model/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /open_clip_long/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiiuae/FineLIP/HEAD/open_clip_long/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /open_clip_long/constants.py: -------------------------------------------------------------------------------- 1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) 2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) 3 | IMAGENET_MEAN = (0.485, 0.456, 0.406) 4 | IMAGENET_STD = (0.229, 0.224, 0.225) 5 | INCEPTION_MEAN = (0.5, 0.5, 0.5) 6 | INCEPTION_STD = (0.5, 0.5, 0.5) 7 | -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-M-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-M-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-S-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 384, 13 | "heads": 6, 14 | "layers": 12 15 | } 16 | } 
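All of the `model_configs/*.json` files in this directory follow the same top-level schema inherited from open_clip: an `embed_dim` plus a `vision_cfg` and a `text_cfg` block, with optional flags such as `quick_gelu` or `custom_text`. As a quick sanity check, a config can be inspected directly; the snippet below is only a sketch and assumes the repository root as the working directory.

```python
import json
from pathlib import Path

# Sketch: inspect one of the bundled model configs (path assumes the repo root as cwd).
cfg = json.loads(Path("open_clip_long/model_configs/ViT-B-16.json").read_text())
print(cfg["embed_dim"])                    # 512
print(cfg["vision_cfg"]["patch_size"])     # 16
print(cfg["text_cfg"]["context_length"])   # 77
```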
-------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-S-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 384, 13 | "heads": 6, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-B-16-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-L-14-280.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 280, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-L-16-320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 320, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-L-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-M-32-alt.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 384, 13 | "heads": 6, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-S-16-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 256, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 256, 13 | "heads": 4, 14 | "layers": 10 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-S-32-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 256, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 256, 13 | "heads": 4, 14 | "layers": 10 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-B-16-plus-240.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 240, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-B-32-256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 256, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-B-32-plus-256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 256, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/mt5-base-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "hf_model_name": "google/mt5-base", 11 | "hf_tokenizer_name": "google/mt5-base", 12 | "hf_pooler_type": "mean_pooler" 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /open_clip_long/model_configs/xlm-roberta-base-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 
512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "hf_model_name": "xlm-roberta-base", 11 | "hf_tokenizer_name": "xlm-roberta-base", 12 | "hf_pooler_type": "mean_pooler" 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Set variables 4 | # CKPT_PATH="experiments/finelip_baseline/ckpt/baseline_128_epoch_6_finelip-B.pt" 5 | CKPT_PATH="experiments/finelip/ckpt/finelip_128_epoch_4_finelip-B.pt" 6 | TEST_DATA="coco" #urban,coco,flickr, docci 7 | 8 | # Use variables 9 | echo "Running evaluation on $TEST_DATA" 10 | python train/test.py --ckpt_path $CKPT_PATH --test_data $TEST_DATA --run_finelip 11 | #--finegrain -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-H-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-H-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 16 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-B-16-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 16 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-B-32-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 32 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-L-14-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 24, 7 | "width": 1024, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 768, 14 | "heads": 12, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- 
/open_clip_long/model_configs/ViT-M-16-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 16, 8 | "ls_init_value": 1e-4 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 384, 14 | "heads": 6, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/roberta-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 32 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "roberta-base", 12 | "hf_tokenizer_name": "roberta-base", 13 | "hf_pooler_type": "mean_pooler" 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /open_clip_long/model_configs/mt5-xl-ViT-H-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "google/mt5-xl", 12 | "hf_tokenizer_name": "google/mt5-xl", 13 | "hf_pooler_type": "mean_pooler" 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /open_clip_long/model_configs/xlm-roberta-large-ViT-H-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "xlm-roberta-large", 12 | "hf_tokenizer_name": "xlm-roberta-large", 13 | "hf_pooler_type": "mean_pooler" 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-e-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 56, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.5715, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1280, 15 | "heads": 20, 16 | "layers": 36 17 | } 18 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1024, 15 | "heads": 16, 16 | "layers": 24 17 | } 18 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-H-14-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 32, 7 | "width": 1280, 8 | "head_width": 80, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 
77, 13 | "vocab_size": 49408, 14 | "width": 1024, 15 | "heads": 16, 16 | "layers": 24 17 | } 18 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-bigG-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 48, 6 | "width": 1664, 7 | "head_width": 104, 8 | "mlp_ratio": 4.9231, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1280, 15 | "heads": 20, 16 | "layers": 32 17 | } 18 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-H-14-378-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 378, 6 | "layers": 32, 7 | "width": 1280, 8 | "head_width": 80, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1024, 15 | "heads": 16, 16 | "layers": 24 17 | } 18 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/nllb-clip-base.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "hf_model_name": "facebook/nllb-200-distilled-600M", 11 | "hf_tokenizer_name": "facebook/nllb-200-distilled-600M", 12 | "hf_proj_type": "linear", 13 | "hf_pooler_type": "cls_pooler" 14 | } 15 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/RN50.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 6, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/RN101.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 23, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/RN50x16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 384, 5 | "layers": [ 6 | 6, 7 | 8, 8 | 18, 9 | 8 10 | ], 11 | "width": 96, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 768, 18 | "heads": 12, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/RN50x4.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | 
"vision_cfg": { 4 | "image_size": 288, 5 | "layers": [ 6 | 4, 7 | 6, 8 | 10, 9 | 6 10 | ], 11 | "width": 80, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 640, 18 | "heads": 10, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/RN50x64.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 448, 5 | "layers": [ 6 | 3, 7 | 15, 8 | 36, 9 | 10 10 | ], 11 | "width": 128, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 1024, 18 | "heads": 16, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/vit_medium_patch16_gap_256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "vit_medium_patch16_gap_256", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 256 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/swin_base_patch4_window7_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "timm_model_name": "swin_base_patch4_window7_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 640, 14 | "heads": 10, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/nllb-clip-large.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "facebook/nllb-200-distilled-1.3B", 12 | "hf_tokenizer_name": "facebook/nllb-200-distilled-1.3B", 13 | "hf_proj_type": "linear", 14 | "hf_pooler_type": "cls_pooler" 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/EVA01-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "timm_model_name": "eva_giant_patch14_224", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 768, 14 | "heads": 12, 15 | "layers": 12 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/vit_relpos_medium_patch16_cls_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "vit_relpos_medium_patch16_cls_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | 
"timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/EVA01-g-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "timm_model_name": "eva_giant_patch14_224", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/EVA02-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "timm_model_name": "eva02_base_patch16_clip_224", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/EVA02-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "timm_model_name": "eva02_large_patch14_clip_224", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 768, 14 | "heads": 12, 15 | "layers": 12 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/EVA02-E-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "timm_model_name": "eva02_enormous_patch14_clip_224", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/EVA02-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "timm_model_name": "eva02_large_patch14_clip_336", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 768, 14 | "heads": 12, 15 | "layers": 12 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/RN101-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 23, 10 | 3 11 | ], 12 | "width": 64, 13 | 
"patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/EVA02-E-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "timm_model_name": "eva02_enormous_patch14_clip_224", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1280, 14 | "heads": 20, 15 | "layers": 32 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/RN50-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 6, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /open_clip_long/model_configs/convnext_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_base", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 224 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/convnext_base_w.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_base", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 256 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 640, 16 | "heads": 10, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/convnext_large.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_large", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 224 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 768, 16 | "heads": 12, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/convnext_large_d.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_large", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "mlp", 8 | "timm_drop": 0.0, 9 | 
"timm_drop_path": 0.1, 10 | "image_size": 256 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 768, 16 | "heads": 12, 17 | "layers": 16 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/convnext_small.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_small", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 224 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/convnext_tiny.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_tiny", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 224 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/convnext_base_w_320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_base", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 320 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 640, 16 | "heads": 10, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/convnext_large_d_320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_large", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "mlp", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 320 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 768, 16 | "heads": 12, 17 | "layers": 16 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/convnext_xlarge.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_xlarge", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 256 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 1024, 16 | "heads": 16, 17 | "layers": 20 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/convnext_xxlarge.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_xxlarge", 5 | 
"timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 256 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 1024, 16 | "heads": 16, 17 | "layers": 24 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/convnext_xxlarge_320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_xxlarge", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 320 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 1024, 16 | "heads": 16, 17 | "layers": 24 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/nllb-clip-base-siglip.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "custom_text": true, 4 | "init_logit_bias": -10, 5 | "vision_cfg": { 6 | "image_size": 384, 7 | "timm_model_name": "vit_base_patch16_siglip_384", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "hf_model_name": "facebook/nllb-200-distilled-600M", 14 | "hf_tokenizer_name": "facebook/nllb-200-distilled-600M", 15 | "hf_proj_type": "linear", 16 | "hf_pooler_type": "cls_pooler" 17 | } 18 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/nllb-clip-large-siglip.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1152, 3 | "custom_text": true, 4 | "init_logit_bias": -10, 5 | "vision_cfg": { 6 | "image_size": 384, 7 | "timm_model_name": "vit_so400m_patch14_siglip_384", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "hf_model_name": "facebook/nllb-200-distilled-1.3B", 14 | "hf_tokenizer_name": "facebook/nllb-200-distilled-1.3B", 15 | "hf_proj_type": "linear", 16 | "hf_pooler_type": "cls_pooler" 17 | } 18 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/coca_roberta-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32, 8 | "output_tokens": true 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "roberta-base", 12 | "hf_tokenizer_name": "roberta-base", 13 | "hf_proj_type": "linear", 14 | "width": 768, 15 | "output_tokens": true 16 | }, 17 | "multimodal_cfg": { 18 | "context_length": 76, 19 | "width": 768, 20 | "heads": 8, 21 | "layers": 12 22 | }, 23 | "custom_text": true 24 | } 25 | -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-L-14-CLIPA.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14, 8 | "no_ln_pre": true, 9 | "pool_type": "avg", 10 | "final_ln_after_pool": true 11 | }, 12 | "text_cfg": { 13 | "context_length": 32, 14 | 
"vocab_size": 32000, 15 | "hf_tokenizer_name": "bert-base-uncased", 16 | "tokenizer_kwargs": { 17 | "strip_sep_token": true 18 | }, 19 | "width": 768, 20 | "heads": 12, 21 | "layers": 12, 22 | "pool_type": "last", 23 | "no_causal_mask": true 24 | } 25 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-L-14-CLIPA-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14, 8 | "no_ln_pre": true, 9 | "pool_type": "avg", 10 | "final_ln_after_pool": true 11 | }, 12 | "text_cfg": { 13 | "context_length": 32, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "bert-base-uncased", 16 | "tokenizer_kwargs": { 17 | "strip_sep_token": true 18 | }, 19 | "width": 768, 20 | "heads": 12, 21 | "layers": 12, 22 | "pool_type": "last", 23 | "no_causal_mask": true 24 | } 25 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-H-14-CLIPA.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14, 9 | "no_ln_pre": true, 10 | "pool_type": "avg", 11 | "final_ln_after_pool": true 12 | }, 13 | "text_cfg": { 14 | "context_length": 32, 15 | "vocab_size": 32000, 16 | "hf_tokenizer_name": "bert-base-uncased", 17 | "tokenizer_kwargs": { 18 | "strip_sep_token": true 19 | }, 20 | "width": 1024, 21 | "heads": 16, 22 | "layers": 24, 23 | "pool_type": "last", 24 | "no_causal_mask": true 25 | } 26 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-H-14-CLIPA-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14, 9 | "no_ln_pre": true, 10 | "pool_type": "avg", 11 | "final_ln_after_pool": true 12 | }, 13 | "text_cfg": { 14 | "context_length": 32, 15 | "vocab_size": 32000, 16 | "hf_tokenizer_name": "bert-base-uncased", 17 | "tokenizer_kwargs": { 18 | "strip_sep_token": true 19 | }, 20 | "width": 1024, 21 | "heads": 16, 22 | "layers": 24, 23 | "pool_type": "last", 24 | "no_causal_mask": true 25 | } 26 | } -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export MKL_NUM_THREADS=1 3 | export NUMEXPR_NUM_THREADS=1 4 | export OMP_NUM_THREADS=4 5 | export LD_LIBRARY_PATH="" 6 | export CUDA_VISIBLE_DEVICES=0,1,2,3 7 | 8 | GPUs=4 9 | GPUs_per_node=4 10 | 11 | WORLD_SIZE=$(($GPUs)) 12 | 13 | # for baseline, run: 14 | # torchrun --nproc_per_node=$GPUs_per_node \ 15 | # --nnodes=1 \ 16 | # --node_rank=0 \ 17 | # --master_port=25678 \ 18 | # train/train.py \ 19 | # --exp_name="baseline" \ 20 | # --run_baseline \ 21 | # --enable_wandb 22 | 23 | 24 | torchrun --nproc_per_node=$GPUs_per_node \ 25 | --nnodes=1 \ 26 | --node_rank=0 \ 27 | --master_port=25678 \ 28 | train/train.py \ 29 | --exp_name="finelip" \ 30 | --enable_wandb -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-bigG-14-CLIPA.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 48, 6 | "width": 1664, 7 | "head_width": 104, 8 | "mlp_ratio": 4.9231, 9 | "patch_size": 14, 10 | "no_ln_pre": true, 11 | "pool_type": "avg", 12 | "final_ln_after_pool": true 13 | }, 14 | "text_cfg": { 15 | "context_length": 32, 16 | "vocab_size": 32000, 17 | "hf_tokenizer_name": "bert-base-uncased", 18 | "tokenizer_kwargs": { 19 | "strip_sep_token": true 20 | }, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "pool_type": "last", 25 | "no_causal_mask": true 26 | } 27 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-bigG-14-CLIPA-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 48, 6 | "width": 1664, 7 | "head_width": 104, 8 | "mlp_ratio": 4.9231, 9 | "patch_size": 14, 10 | "no_ln_pre": true, 11 | "pool_type": "avg", 12 | "final_ln_after_pool": true 13 | }, 14 | "text_cfg": { 15 | "context_length": 32, 16 | "vocab_size": 32000, 17 | "hf_tokenizer_name": "bert-base-uncased", 18 | "tokenizer_kwargs": { 19 | "strip_sep_token": true 20 | }, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "pool_type": "last", 25 | "no_causal_mask": true 26 | } 27 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/coca_ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32, 8 | "attentional_pool": true, 9 | "attn_pooler_heads": 8, 10 | "output_tokens": true 11 | }, 12 | "text_cfg": { 13 | "context_length": 76, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12, 18 | "embed_cls": true, 19 | "output_tokens": true 20 | }, 21 | "multimodal_cfg": { 22 | "context_length": 76, 23 | "vocab_size": 49408, 24 | "width": 512, 25 | "heads": 8, 26 | "layers": 12, 27 | "attn_pooler_heads": 8 28 | }, 29 | "custom_text": true 30 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/coca_ViT-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14, 8 | "attentional_pool": true, 9 | "attn_pooler_heads": 8, 10 | "output_tokens": true 11 | }, 12 | "text_cfg": { 13 | "context_length": 76, 14 | "vocab_size": 49408, 15 | "width": 768, 16 | "heads": 12, 17 | "layers": 12, 18 | "embed_cls": true, 19 | "output_tokens": true 20 | }, 21 | "multimodal_cfg": { 22 | "context_length": 76, 23 | "vocab_size": 49408, 24 | "width": 768, 25 | "heads": 12, 26 | "layers": 12, 27 | "attn_pooler_heads": 12 28 | }, 29 | "custom_text": true 30 | } 31 | -------------------------------------------------------------------------------- /open_clip_long/model_configs/coca_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "multimodal_cfg": { 4 | "width": 768, 5 | "context_length": 76, 6 | "vocab_size": 64000, 7 | "mlp_ratio": 4, 8 | "layers": 12, 9 | "dim_head": 64, 10 | "heads": 12, 11 | "n_queries": 256, 12 | 
"attn_pooler_heads": 8 13 | }, 14 | "vision_cfg": { 15 | "image_size": 288, 16 | "layers": 12, 17 | "width": 768, 18 | "patch_size": 18, 19 | "output_tokens": true 20 | }, 21 | "text_cfg": { 22 | "context_length": 76, 23 | "vocab_size": 64000, 24 | "layers": 12, 25 | "heads": 12, 26 | "width": 768, 27 | "embed_cls": true, 28 | "output_tokens": true 29 | }, 30 | "custom_text": true 31 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-B-16-SigLIP.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "init_logit_bias": -10, 4 | "custom_text": true, 5 | "vision_cfg": { 6 | "image_size": 224, 7 | "timm_model_name": "vit_base_patch16_siglip_224", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "context_length": 64, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", 16 | "tokenizer_kwargs": { 17 | "clean": "canonicalize" 18 | }, 19 | "width": 768, 20 | "heads": 12, 21 | "layers": 12, 22 | "no_causal_mask": true, 23 | "proj_bias": true, 24 | "pool_type": "last", 25 | "norm_kwargs":{ 26 | "eps": 1e-6 27 | } 28 | } 29 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-B-16-SigLIP-256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "init_logit_bias": -10, 4 | "custom_text": true, 5 | "vision_cfg": { 6 | "image_size": 256, 7 | "timm_model_name": "vit_base_patch16_siglip_256", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "context_length": 64, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", 16 | "tokenizer_kwargs": { 17 | "clean": "canonicalize" 18 | }, 19 | "width": 768, 20 | "heads": 12, 21 | "layers": 12, 22 | "no_causal_mask": true, 23 | "proj_bias": true, 24 | "pool_type": "last", 25 | "norm_kwargs":{ 26 | "eps": 1e-6 27 | } 28 | } 29 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-B-16-SigLIP-384.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "init_logit_bias": -10, 4 | "custom_text": true, 5 | "vision_cfg": { 6 | "image_size": 384, 7 | "timm_model_name": "vit_base_patch16_siglip_384", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "context_length": 64, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", 16 | "tokenizer_kwargs": { 17 | "clean": "canonicalize" 18 | }, 19 | "width": 768, 20 | "heads": 12, 21 | "layers": 12, 22 | "no_causal_mask": true, 23 | "proj_bias": true, 24 | "pool_type": "last", 25 | "norm_kwargs":{ 26 | "eps": 1e-6 27 | } 28 | } 29 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-B-16-SigLIP-512.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "init_logit_bias": -10, 4 | "custom_text": true, 5 | "vision_cfg": { 6 | "image_size": 512, 7 | "timm_model_name": "vit_base_patch16_siglip_512", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "context_length": 64, 14 | "vocab_size": 32000, 
15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", 16 | "tokenizer_kwargs": { 17 | "clean": "canonicalize" 18 | }, 19 | "width": 768, 20 | "heads": 12, 21 | "layers": 12, 22 | "no_causal_mask": true, 23 | "proj_bias": true, 24 | "pool_type": "last", 25 | "norm_kwargs":{ 26 | "eps": 1e-6 27 | } 28 | } 29 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-L-16-SigLIP-256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "init_logit_bias": -10, 4 | "custom_text": true, 5 | "vision_cfg": { 6 | "image_size": 256, 7 | "timm_model_name": "vit_large_patch16_siglip_256", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "context_length": 64, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", 16 | "tokenizer_kwargs": { 17 | "clean": "canonicalize" 18 | }, 19 | "width": 1024, 20 | "heads": 16, 21 | "layers": 24, 22 | "no_causal_mask": true, 23 | "proj_bias": true, 24 | "pool_type": "last", 25 | "norm_kwargs":{ 26 | "eps": 1e-6 27 | } 28 | } 29 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-L-16-SigLIP-384.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "init_logit_bias": -10, 4 | "custom_text": true, 5 | "vision_cfg": { 6 | "image_size": 384, 7 | "timm_model_name": "vit_large_patch16_siglip_384", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "context_length": 64, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", 16 | "tokenizer_kwargs": { 17 | "clean": "canonicalize" 18 | }, 19 | "width": 1024, 20 | "heads": 16, 21 | "layers": 24, 22 | "no_causal_mask": true, 23 | "proj_bias": true, 24 | "pool_type": "last", 25 | "norm_kwargs":{ 26 | "eps": 1e-6 27 | } 28 | } 29 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-B-16-SigLIP-i18n-256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "init_logit_bias": -10, 4 | "custom_text": true, 5 | "vision_cfg": { 6 | "image_size": 256, 7 | "timm_model_name": "vit_base_patch16_siglip_256", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "context_length": 64, 14 | "vocab_size": 250000, 15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP-i18n-256", 16 | "tokenizer_kwargs": { 17 | "clean": "canonicalize" 18 | }, 19 | "width": 768, 20 | "heads": 12, 21 | "layers": 12, 22 | "no_causal_mask": true, 23 | "proj_bias": true, 24 | "pool_type": "last", 25 | "norm_kwargs":{ 26 | "eps": 1e-6 27 | } 28 | } 29 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-SO400M-14-SigLIP.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1152, 3 | "init_logit_bias": -10, 4 | "custom_text": true, 5 | "vision_cfg": { 6 | "image_size": 224, 7 | "timm_model_name": "vit_so400m_patch14_siglip_224", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "context_length": 16, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", 16 
| "tokenizer_kwargs": { 17 | "clean": "canonicalize" 18 | }, 19 | "width": 1152, 20 | "heads": 16, 21 | "layers": 27, 22 | "mlp_ratio": 3.7362, 23 | "no_causal_mask": true, 24 | "proj_bias": true, 25 | "pool_type": "last", 26 | "norm_kwargs":{ 27 | "eps": 1e-6 28 | } 29 | } 30 | } -------------------------------------------------------------------------------- /open_clip_long/model_configs/ViT-SO400M-14-SigLIP-384.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1152, 3 | "init_logit_bias": -10, 4 | "custom_text": true, 5 | "vision_cfg": { 6 | "image_size": 384, 7 | "timm_model_name": "vit_so400m_patch14_siglip_384", 8 | "timm_model_pretrained": false, 9 | "timm_pool": "map", 10 | "timm_proj": "none" 11 | }, 12 | "text_cfg": { 13 | "context_length": 64, 14 | "vocab_size": 32000, 15 | "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", 16 | "tokenizer_kwargs": { 17 | "clean": "canonicalize" 18 | }, 19 | "width": 1152, 20 | "heads": 16, 21 | "layers": 27, 22 | "mlp_ratio": 3.7362, 23 | "no_causal_mask": true, 24 | "proj_bias": true, 25 | "pool_type": "last", 26 | "norm_kwargs":{ 27 | "eps": 1e-6 28 | } 29 | } 30 | } -------------------------------------------------------------------------------- /train/scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def assign_learning_rate(optimizer, new_lrs): 5 | for param_group, new_lr in zip(optimizer.param_groups, new_lrs): 6 | param_group["lr"] = new_lr 7 | 8 | 9 | 10 | def _warmup_lr(base_lr, warmup_length, step): 11 | return base_lr * (step + 1) / warmup_length 12 | 13 | 14 | def cosine_lr(optimizer, base_lrs, warmup_length, steps): 15 | def _lr_adjuster(step): 16 | lrs = [] 17 | for base_lr in base_lrs: 18 | if step < warmup_length: 19 | lr = _warmup_lr(base_lr, warmup_length, step) 20 | else: 21 | e = step - warmup_length 22 | es = steps - warmup_length 23 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr 24 | lrs.append(lr) 25 | assign_learning_rate(optimizer, lrs) 26 | return lrs 27 | return _lr_adjuster 28 | -------------------------------------------------------------------------------- /SDXL/SDXL.md: -------------------------------------------------------------------------------- 1 | # Finelip-SDXL 2 | To run Finelip-SDXL, please follow the following step. 3 | 4 | ### 1. Prepare SDXL Model 5 | Download the pre-trained weights of SDXL-base and SDXL-refiner in the following pages: 6 | [https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) 7 | [https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0) 8 | 9 | ### 2. Prepare the text encoders 10 | Download the pre-trained Finelip-L and Finelip-bigG respectively. 11 | 12 | [https://huggingface.co/BeichenZhang/LongCLIP-L](https://huggingface.co/BeichenZhang/LongCLIP-L) 13 | [https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) 14 | 15 | ### 3. Start generating images. 16 | Finally, you can run the `sdxl.py` for generating images. 
17 | 18 | python sdxl.py \ 19 | --ckpt_path /path/to/checkpoint.pth \ 20 | --enable_finelip \ 21 | --img_dir /path/to/images \ 22 | --dataset dataset_name -------------------------------------------------------------------------------- /open_clip_long/__init__.py: -------------------------------------------------------------------------------- 1 | from .coca_model import CoCa 2 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 3 | from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss 4 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint 5 | from .loss import ClipLoss, DistillClipLoss, CoCaLoss 6 | from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \ 7 | convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype, \ 8 | get_model_tokenize_cfg, get_model_preprocess_cfg, set_model_preprocess_cfg 9 | from .openai import load_openai_model, list_openai_models 10 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \ 11 | get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained 12 | from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub 13 | from .tokenizer import SimpleTokenizer, tokenize, decode 14 | from .transform import image_transform, AugmentationCfg 15 | from .zero_shot_classifier import build_zero_shot_classifier, build_zero_shot_classifier_legacy 16 | from .zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, SIMPLE_IMAGENET_TEMPLATES, IMAGENET_CLASSNAMES 17 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FineLIP 2 | This repository is the official implementation of FineLIP 3 | 4 | **FineLIP: Extending CLIP’s Reach via Fine-Grained Alignment with Longer 5 | Text Inputs** 6 | 7 | ## Getting Started 8 | 9 | ### Environment Setup 10 | 11 | Refer to the `requirements.txt` file to prepare the environment. 12 | 13 | ### Model Training 14 | 15 | For model training, run: 16 | ```shell 17 | bash train.sh 18 | ``` 19 | Refer to `train/arguments.py` for all the defined variables. 20 | 21 | ### Model Evaluation 22 | 23 | #### Short & Long Caption Retrieval 24 | To run retrieval on short caption datasets [COCO2017, Flickr30k] and long caption datasets [DOCCI, Urban1K], run the following command after preparing the data: 25 | 26 | ```shell 27 | bash test.sh 28 | ``` 29 | Set the appropriate variables in the bash file. 
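
The reported retrieval metrics are standard recall@K scores computed from the cosine similarity between image and text embeddings. The snippet below is only a minimal illustration of that metric, not the exact logic in `train/test.py` (which, for example, handles the five captions per image in COCO/Flickr30k); the function name and the one-caption-per-image assumption are ours:

```python
import torch

@torch.no_grad()
def recall_at_k(image_feats, text_feats, ks=(1, 5, 10)):
    # Assumes row i of image_feats and text_feats form a matched pair.
    image_feats = image_feats / image_feats.norm(dim=-1, keepdim=True)
    text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)
    sims = image_feats @ text_feats.t()                # (N, N) cosine similarities
    ranks = sims.argsort(dim=-1, descending=True)      # candidate captions per image, best first
    targets = torch.arange(sims.size(0), device=sims.device).unsqueeze(1)
    hits = ranks == targets                            # True where the matched caption appears
    return {f"R@{k}": hits[:, :k].any(dim=-1).float().mean().item() for k in ks}
```

Text-to-image retrieval is the same computation with the two feature matrices swapped.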
30 | 31 | ## 🙏 Acknowledgments 32 | This project builds upon the following open-source resources: 33 | 34 | - LongCLIP [https://github.com/beichenzbc/Long-CLIP] 35 | - LAPS [https://github.com/CrossmodalGroup/LAPS] 36 | 37 | ## Citation 38 | If this project benefits your research, please consider citing our work: 39 | ``` 40 | @InProceedings{Asokan_2025_CVPR, 41 | author = {Asokan, Mothilal and Wu, Kebin and Albreiki, Fatima}, 42 | title = {FineLIP: Extending CLIP's Reach via Fine-Grained Alignment with Longer Text Inputs}, 43 | booktitle = {Proceedings of the Computer Vision and Pattern Recognition Conference (CVPR)}, 44 | month = {June}, 45 | year = {2025}, 46 | pages = {14495-14504} 47 | } 48 | -------------------------------------------------------------------------------- /train/train.md: -------------------------------------------------------------------------------- 1 | # FineLIP Training 2 | To run the training code for FineLIP, please follow the following step. 3 | 4 | ### 1. Prepare ShareGPT4V dataset [Taken from LongCLIP Repo] 5 | 6 | First, download all images we used. 7 | - LAION-CC-SBU-558K: [images.zip](https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain/blob/main/images.zip) 8 | - COCO: [train2017](http://images.cocodataset.org/zips/train2017.zip) 9 | - WebData: [images](https://drive.google.com/drive/folders/1tCUQ-sq6vdshZVkF0ZeF3K4eztkXJgax?usp=sharing). Only for academic usage. 10 | - SAM: [images](https://ai.meta.com/datasets/segment-anything-downloads/). We only use 000000~000050.tar for now. If you just want to use ShareGPT4V for SFT, you can quickly download 9K images from [here](https://drive.google.com/file/d/1dKumdOKSXtV7lIXdrG7jsIK_z2vZv2gs/view?usp=drive_link). 11 | - GQA: [images](https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip) 12 | - OCR-VQA: [download script](https://drive.google.com/drive/folders/1_GYPY5UkUy7HIcR0zq3ZCFgeZN7BAfm_?usp=sharing). We save all files as `.jpg` 13 | - TextVQA: [trainvalimages](https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip) 14 | - VisualGenome: [part1](https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip), [part2](https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip) 15 | 16 | Then, download the long caption of these image [share-captioner_coco_lcs_sam_1246k_1107.json](https://huggingface.co/datasets/Lin-Chen/ShareGPT4V/blob/main/share-captioner_coco_lcs_sam_1246k_1107.json) 17 | 18 | ``` 19 | ShareGPT4V 20 | ├── ... 21 | ├── data 22 | | ├── share-captioner_coco_lcs_sam_1246k_1107.json 23 | │ ├── llava 24 | │ │ ├── llava_pretrain 25 | │ │ │ ├── images 26 | │ ├── coco 27 | │ │ ├── train2017 28 | │ ├── sam 29 | │ │ ├── images 30 | │ ├── gqa 31 | │ │ ├── images 32 | │ ├── ocr_vqa 33 | │ │ ├── images 34 | │ ├── textvqa 35 | │ │ ├── train_images 36 | │ ├── vg 37 | │ │ ├── VG_100K 38 | │ │ ├── VG_100K_2 39 | │ ├── share_textvqa 40 | │ │ ├── images 41 | │ ├── web-celebrity 42 | │ │ ├── images 43 | │ ├── web-landmark 44 | │ │ ├── images 45 | │ ├── wikiart 46 | │ │ ├── images 47 | ├── ... 48 | ``` 49 | Then, change the data root in `sharegpt4v.py` 50 | 51 | ### 2. Train FineLIP 52 | 53 | Run train.sh to train FineLIP. 
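
For reference, `train.sh` drives `train/train.py`, which reads the ShareGPT4V data through the dataset classes in `train/sharegpt4v.py`. A minimal sketch of that wiring is shown below (after setting the data roots at the top of `sharegpt4v.py`); the batch size and worker count are illustrative, and the distributed sampler/DDP setup used by the real training script is omitted:

```python
from torch.utils.data import DataLoader
from sharegpt4v import share4v_train_dataset, share4v_val_dataset  # run from the train/ folder

train_set = share4v_train_dataset()   # long caption + first-sentence short caption per image
val_set = share4v_val_dataset()       # first 1000 entries of the ShareGPT4V JSON

train_loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=8)

for images, captions, captions_short, indices in train_loader:
    # images: (B, 3, 224, 224) CLIP-preprocessed tensors
    # captions / captions_short: tuples of B strings
    break
```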
54 | -------------------------------------------------------------------------------- /train/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | 4 | def is_dist_avail_and_initialized(): 5 | if not dist.is_available(): 6 | return False 7 | if not dist.is_initialized(): 8 | return False 9 | return True 10 | 11 | @torch.no_grad() 12 | def concat_all_gather(tensor): 13 | """ 14 | Performs all_gather operation on the provided tensors. 15 | *** Warning ***: torch.distributed.all_gather has no gradient. 16 | """ 17 | # if use distributed training 18 | if not is_dist_avail_and_initialized(): 19 | return tensor 20 | 21 | tensors_gather = [ 22 | torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) 23 | ] 24 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 25 | 26 | output = torch.cat(tensors_gather, dim=0) 27 | return output 28 | 29 | def accuracy(output, target, topk=(1,)): 30 | pred = output.topk(max(topk), 1, True, True)[1].t() 31 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 32 | return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk] 33 | 34 | # come from BLIP, https://github.com/salesforce/BLIP/blob/b7bb1eeb6e901044a9eb1016f408ee908b216bc7/models/blip_retrieval.py#L306 35 | # Gather tensors from all workers with support for backward propagation: 36 | # This implementation does not cut the gradients as torch.distributed.all_gather does. 37 | class GatherLayer(torch.autograd.Function): 38 | 39 | @staticmethod 40 | def forward(ctx, x): 41 | output = [torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())] 42 | torch.distributed.all_gather(output, x) 43 | return tuple(output) 44 | 45 | @staticmethod 46 | def backward(ctx, *grads): 47 | all_gradients = torch.stack(grads) 48 | 49 | # op=torch.distributed.ReduceOp.SUM 50 | torch.distributed.all_reduce(all_gradients) 51 | 52 | return all_gradients[torch.distributed.get_rank()] 53 | 54 | 55 | # Performs all_gather operation on the provided tensors. 56 | # Graph remains connected for backward grad computation. 
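# Illustrative usage in a distributed contrastive setup (variable names are placeholders,
# not taken from train.py): each rank encodes its local batch, gathers the features from all
# ranks, and computes the loss against the full set of negatives while gradients still
# flow back into the local encoder:
#   image_feats_all = all_gather_with_grad(image_feats)   # (world_size * B, D)
#   text_feats_all = all_gather_with_grad(text_feats)     # (world_size * B, D)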
57 | def all_gather_with_grad(tensors): 58 | 59 | # if use distributed training 60 | if not is_dist_avail_and_initialized(): 61 | return tensors 62 | 63 | tensor_all = GatherLayer.apply(tensors) 64 | 65 | return torch.cat(tensor_all, dim=0) -------------------------------------------------------------------------------- /open_clip_long/hf_configs.py: -------------------------------------------------------------------------------- 1 | # HF architecture dict: 2 | arch_dict = { 3 | # https://huggingface.co/docs/transformers/model_doc/roberta#roberta 4 | "roberta": { 5 | "config_names": { 6 | "context_length": "max_position_embeddings", 7 | "vocab_size": "vocab_size", 8 | "width": "hidden_size", 9 | "heads": "num_attention_heads", 10 | "layers": "num_hidden_layers", 11 | "layer_attr": "layer", 12 | "token_embeddings_attr": "embeddings" 13 | }, 14 | "pooler": "mean_pooler", 15 | }, 16 | # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig 17 | "xlm-roberta": { 18 | "config_names": { 19 | "context_length": "max_position_embeddings", 20 | "vocab_size": "vocab_size", 21 | "width": "hidden_size", 22 | "heads": "num_attention_heads", 23 | "layers": "num_hidden_layers", 24 | "layer_attr": "layer", 25 | "token_embeddings_attr": "embeddings" 26 | }, 27 | "pooler": "mean_pooler", 28 | }, 29 | # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 30 | "mt5": { 31 | "config_names": { 32 | # unlimited seqlen 33 | # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 34 | # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 35 | "context_length": "", 36 | "vocab_size": "vocab_size", 37 | "width": "d_model", 38 | "heads": "num_heads", 39 | "layers": "num_layers", 40 | "layer_attr": "block", 41 | "token_embeddings_attr": "embed_tokens" 42 | }, 43 | "pooler": "mean_pooler", 44 | }, 45 | # https://huggingface.co/docs/transformers/model_doc/bert 46 | "bert": { 47 | "config_names": { 48 | "context_length": "max_position_embeddings", 49 | "vocab_size": "vocab_size", 50 | "width": "hidden_size", 51 | "heads": "num_attention_heads", 52 | "layers": "num_hidden_layers", 53 | }, 54 | "pooler": "cls_pooler", 55 | }, 56 | # https://huggingface.co/docs/transformers/model_doc/m2m_100 57 | "m2m_100": { 58 | "config_names": { 59 | "context_length": "max_position_embeddings", 60 | "vocab_size": "vocab_size", 61 | "width": "d_model", 62 | "heads": "encoder_attention_heads", 63 | "layers": "encoder_layers", 64 | }, 65 | "pooler": "cls_pooler", 66 | }, 67 | } 68 | -------------------------------------------------------------------------------- /train/arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_args(): 4 | parser = argparse.ArgumentParser(description="params") 5 | parser.add_argument("--exp_name", default="experiment", type=str, help="specify experiment name.") 6 | parser.add_argument("--run_baseline", action="store_true", help="run baseline") 7 | parser.add_argument("--seed", default=71, type=int, help="seed") 8 | parser.add_argument("--accumulation_steps", default=4, type=int, help="accumulation steps") 9 | parser.add_argument("--epochs", default=6, type=int, help="num_epochs") 10 | parser.add_argument("--lr", default=1e-6, type=float, help="lr.") 11 | parser.add_argument("--cross_net_lr", default=2e-4, type=float, help="cross_net_lr.") 12 | parser.add_argument("--weight_decay", default=1e-2, type=float, 
help="wd.") 13 | parser.add_argument("--log_scale", default=4.6052, type=float, help="clip temperature.") 14 | parser.add_argument("--warmup_length", default=200, type=int, help="warmup_length.") 15 | parser.add_argument("--base_model", default="ViT-B/16", help="CLIP Base Model") 16 | parser.add_argument("--enable_wandb", action="store_true", help="enable wandb logging") 17 | parser.add_argument("--s3_bucket", default=None, type=str, help="s3 bucket path") 18 | parser.add_argument("--debug", action="store_true", help="debug mode") 19 | parser.add_argument("--global_batch_size", default=128, type=int, help="global batch size") 20 | parser.add_argument("--resume_path", default=None, type=str, help="resume path") 21 | 22 | parser.add_argument('--embed_size', default=512, type=int, help='Dimensionality of the joint embedding.') 23 | parser.add_argument('--num_patches', default=196, type=int, help='Number of patches.') 24 | parser.add_argument('--loss_finegrain', default='vse', type=str, help='the objective function for optimization') 25 | parser.add_argument('--margin', default=0.2, type=float, help='Rank loss margin.') 26 | parser.add_argument('--max_violation', action='store_true', help='Use max instead of sum in the rank loss.') 27 | parser.add_argument('--vse_mean_warmup_epochs', type=int, default=1, help='The number of warmup epochs using mean vse loss') 28 | parser.add_argument('--embedding_warmup_epochs', type=int, default=0, help='The number of epochs for warming up the embedding layer') 29 | # cross-modal alignment 30 | parser.add_argument('--aggr_ratio', default=0.4, type=float, help='the aggr rate for visual token') 31 | parser.add_argument('--sparse_ratio', default=0.5, type=float, help='the sprase rate for visual token') 32 | parser.add_argument('--attention_weight', default=0.8, type=int, help='the weight of attention_map for mask prediction') 33 | parser.add_argument('--ratio_weight', default=2.0, type=float, help='if use detach for kt loss') 34 | return parser -------------------------------------------------------------------------------- /train/sharegpt4v.py: -------------------------------------------------------------------------------- 1 | import json 2 | from PIL import Image, UnidentifiedImageError 3 | import clip 4 | import torch.utils.data as data 5 | import io 6 | import s3fs 7 | 8 | data4v_root = 'data/' 9 | json_name = 'share-captioner_coco_lcs_sam_1246k_1107.json' 10 | image_root = "s3://dataset-bucket/ShareGPT4V/data/" 11 | 12 | class share4v_val_dataset(data.Dataset): 13 | def __init__(self, preprocess=None): 14 | self.data4v_root = data4v_root 15 | self.json_name = json_name 16 | self.image_root = image_root 17 | self.total_len = 1000 18 | with open(data4v_root + json_name, 'r',encoding='utf8')as fp: 19 | self.json_data = json.load(fp)[:self.total_len] 20 | if preprocess is None: 21 | _ , self.preprocess = clip.load("ViT-L/14") 22 | else: 23 | self.preprocess = preprocess 24 | def __len__(self): 25 | return self.total_len 26 | 27 | def __getitem__(self, index): 28 | caption = self.json_data[index]['conversations'][1]['value'] 29 | caption = caption.replace("\n", " ") 30 | image_name = self.image_root + self.json_data[index]['image'] 31 | try: 32 | image = Image.open(io.BytesIO(s3fs.S3FileSystem().open(image_name).read())) 33 | image = image.convert('RGB') 34 | except (OSError, UnidentifiedImageError) as e: 35 | print(f"Error loading image ({image_name})") 36 | image = Image.new('RGB', (224, 224), color='black') 37 | 38 | image_tensor = self.preprocess(image) 39 
| return image_tensor, caption 40 | 41 | 42 | class share4v_train_dataset(data.Dataset): 43 | def __init__(self): 44 | self.data4v_root = data4v_root 45 | self.json_name = json_name 46 | self.image_root = image_root 47 | self.total_len = 1000 48 | with open(data4v_root + json_name, 'r',encoding='utf8')as fp: 49 | self.json_data = json.load(fp)[self.total_len:] 50 | _ , self.preprocess = clip.load("ViT-L/14") 51 | 52 | def __len__(self): 53 | return len(self.json_data) 54 | 55 | def __getitem__(self, index): 56 | caption = self.json_data[index]['conversations'][1]['value'] 57 | caption = caption.replace("\n", " ") 58 | 59 | 60 | caption_short = caption.split(". ")[0] 61 | 62 | image_name = self.image_root + self.json_data[index]['image'] 63 | 64 | try: 65 | image = Image.open(io.BytesIO(s3fs.S3FileSystem().open(image_name).read())) 66 | image = image.convert('RGB') 67 | except (OSError, UnidentifiedImageError) as e: 68 | print(f"Error loading image ({image_name})") 69 | image = Image.new('RGB', (224, 224), color = 'black') 70 | 71 | image_tensor = self.preprocess(image) 72 | return image_tensor, caption, caption_short, index 73 | 74 | 75 | if __name__ == "__main__": 76 | import pdb 77 | train_dataset = share4v_train_dataset() 78 | val_dataset = share4v_train_dataset() 79 | pdb.set_trace() -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | accelerate==0.29.3 3 | aiobotocore==2.12.3 4 | aiohttp==3.9.5 5 | aioitertools==0.11.0 6 | aiosignal==1.3.1 7 | alabaster==0.7.16 8 | appdirs==1.4.4 9 | astunparse==1.6.3 10 | async-timeout==4.0.3 11 | attrs==23.2.0 12 | Babel==2.14.0 13 | beautifulsoup4==4.12.3 14 | botocore==1.34.69 15 | Brotli==1.1.0 16 | certifi==2024.2.2 17 | charset-normalizer==3.3.2 18 | click==8.1.7 19 | clip @ git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33 20 | contourpy==1.2.1 21 | cycler==0.12.1 22 | datasets==2.19.0 23 | dill==0.3.8 24 | diffusers==0.30.2 25 | docker-pycreds==0.4.0 26 | docstring_parser==0.16 27 | docutils==0.16 28 | eval_type_backport==0.2.0 29 | filelock==3.14.0 30 | fire==0.6.0 31 | flatbuffers==24.3.25 32 | fonttools==4.51.0 33 | frozenlist==1.4.1 34 | fsspec==2024.3.1 35 | ftfy==6.2.0 36 | futures==3.0.5 37 | gast==0.5.4 38 | gdown==5.2.0 39 | gitdb==4.0.11 40 | GitPython==3.1.43 41 | gmpy2==2.1.5 42 | google-pasta==0.2.0 43 | grpcio==1.63.0 44 | h5py==3.11.0 45 | huggingface-hub==0.24.7 46 | idna==3.7 47 | imagesize==1.4.1 48 | importlib_metadata==7.1.0 49 | importlib_resources==6.4.0 50 | invisible_watermark==0.2.0 51 | jinja2==3.1.3 52 | jmespath==1.0.1 53 | keras==3.3.3 54 | kiwisolver==1.4.5 55 | libclang==18.1.1 56 | Markdown==3.6 57 | markdown-it-py==3.0.0 58 | markupsafe==2.1.5 59 | matplotlib==3.8.4 60 | mdurl==0.1.2 61 | ml-dtypes==0.3.2 62 | mpmath==1.3.0 63 | multidict==6.0.5 64 | multiprocess==0.70.16 65 | namex==0.0.8 66 | namex==0.0.8 67 | numpy==1.26.4 68 | open_clip_torch==2.26.1 69 | opencv-python==4.9.0.80 70 | opt-einsum==3.3.0 71 | optree==0.11.0 72 | packaging==24.0 73 | pandas==2.2.2 74 | pillow==10.3.0 75 | protobuf==4.25.3 76 | psutil==5.9.8 77 | pyarrow==16.0.0 78 | pyarrow-hotfix==0.6 79 | pycocoevalcap @ git+https://github.com/jmhessel/pycocoevalcap 80 | Pygments==2.17.2 81 | pyparsing==3.1.2 82 | pysocks==1.7.1 83 | python-dateutil==2.9.0.post0 84 | pytz==2024.1 85 | PyWavelets==1.6.0 86 | pyyaml==6.0.1 87 | regex==2024.4.28 88 | 
requests==2.31.0 89 | rich==13.7.1 90 | s3==3.0.0 91 | s3fs==2024.3.1 92 | safetensors==0.4.3 93 | scipy==1.13.0 94 | seaborn==0.13.2 95 | sentry-sdk==2.0.1 96 | setproctitle==1.3.3 97 | setuptools==69.5.1 98 | shtab==1.7.1 99 | six==1.16.0 100 | smmap==5.0.1 101 | snowballstemmer==2.2.0 102 | soupsieve==2.6 103 | Sphinx==5.3.0 104 | sphinx-rtd-theme==0.5.2 105 | sphinxcontrib-applehelp==1.0.8 106 | sphinxcontrib-devhelp==1.0.6 107 | sphinxcontrib-htmlhelp==2.0.5 108 | sphinxcontrib-jsmath==1.0.1 109 | sphinxcontrib-qthelp==1.0.7 110 | sphinxcontrib-serializinghtml==1.1.10 111 | sympy==1.12 112 | tensorboard==2.16.2 113 | tensorboard-data-server==0.7.2 114 | tensorflow==2.16.1 115 | tensorflow-io-gcs-filesystem==0.37.0 116 | termcolor==2.4.0 117 | tokenizers==0.19.1 118 | torch==2.2.2 119 | torchaudio==2.2.2 120 | torchvision==0.17.2 121 | tqdm==4.66.2 122 | transformers==4.40.1 123 | triton==2.2.0 124 | trl==0.8.6 125 | typing_extensions==4.11.0 126 | tyro==0.8.3 127 | tzdata==2024.1 128 | urllib3==1.26.18 129 | wandb==0.16.6 130 | wcwidth==0.2.13 131 | Werkzeug==3.0.2 132 | wrapt==1.16.0 133 | xmltodict==0.13.0 134 | xxhash==3.4.1 135 | yarl==1.9.4 136 | zipp==3.18.1 137 | ffmpeg==1.4 -------------------------------------------------------------------------------- /open_clip_long/openai.py: -------------------------------------------------------------------------------- 1 | """ OpenAI pretrained model functions 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | 6 | import os 7 | import warnings 8 | from typing import List, Optional, Union 9 | 10 | import torch 11 | 12 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 13 | from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype 14 | from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url 15 | 16 | __all__ = ["list_openai_models", "load_openai_model"] 17 | 18 | 19 | def list_openai_models() -> List[str]: 20 | """Returns the names of available CLIP models""" 21 | return list_pretrained_models_by_tag('openai') 22 | 23 | 24 | def load_openai_model( 25 | name: str, 26 | precision: Optional[str] = None, 27 | device: Optional[Union[str, torch.device]] = None, 28 | cache_dir: Optional[str] = None, 29 | ): 30 | """Load a CLIP model 31 | 32 | Parameters 33 | ---------- 34 | name : str 35 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 36 | precision: str 37 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. 
38 | device : Union[str, torch.device] 39 | The device to put the loaded model 40 | cache_dir : Optional[str] 41 | The directory to cache the downloaded model weights 42 | 43 | Returns 44 | ------- 45 | model : torch.nn.Module 46 | The CLIP model 47 | preprocess : Callable[[PIL.Image], torch.Tensor] 48 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 49 | """ 50 | if device is None: 51 | device = "cuda" if torch.cuda.is_available() else "cpu" 52 | if precision is None: 53 | precision = 'fp32' if device == 'cpu' else 'fp16' 54 | 55 | if get_pretrained_url(name, 'openai'): 56 | model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) 57 | elif os.path.isfile(name): 58 | model_path = name 59 | else: 60 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") 61 | 62 | try: 63 | # loading JIT archive 64 | model = torch.jit.load(model_path, map_location="cpu").eval() 65 | state_dict = None 66 | except RuntimeError: 67 | # loading saved state dict 68 | state_dict = torch.load(model_path, map_location="cpu") 69 | 70 | # Build a non-jit model from the OpenAI jitted model state dict 71 | cast_dtype = get_cast_dtype(precision) 72 | try: 73 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) 74 | except KeyError: 75 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 76 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) 77 | 78 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use 79 | model = model.to(device) 80 | # FIXME support pure fp16/bf16 precision modes 81 | if precision != 'fp16': 82 | model.float() 83 | if precision == 'bf16': 84 | # for bf16, convert back to low-precision 85 | convert_weights_to_lp(model, dtype=torch.bfloat16) 86 | 87 | # add mean / std attributes for consistency with OpenCLIP models 88 | model.visual.image_mean = OPENAI_DATASET_MEAN 89 | model.visual.image_std = OPENAI_DATASET_STD 90 | return model 91 | -------------------------------------------------------------------------------- /open_clip_long/utils.py: -------------------------------------------------------------------------------- 1 | from itertools import repeat 2 | import collections.abc 3 | 4 | import torch 5 | from torch import nn as nn 6 | from torchvision.ops.misc import FrozenBatchNorm2d 7 | 8 | 9 | def freeze_batch_norm_2d(module, module_match={}, name=''): 10 | """ 11 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is 12 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and 13 | returned. Otherwise, the module is walked recursively and submodules are converted in place. 14 | 15 | Args: 16 | module (torch.nn.Module): Any PyTorch module. 
17 | module_match (dict): Dictionary of full module names to freeze (all if empty) 18 | name (str): Full module name (prefix) 19 | 20 | Returns: 21 | torch.nn.Module: Resulting module 22 | 23 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 24 | """ 25 | res = module 26 | is_match = True 27 | if module_match: 28 | is_match = name in module_match 29 | if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): 30 | res = FrozenBatchNorm2d(module.num_features) 31 | res.num_features = module.num_features 32 | res.affine = module.affine 33 | if module.affine: 34 | res.weight.data = module.weight.data.clone().detach() 35 | res.bias.data = module.bias.data.clone().detach() 36 | res.running_mean.data = module.running_mean.data 37 | res.running_var.data = module.running_var.data 38 | res.eps = module.eps 39 | else: 40 | for child_name, child in module.named_children(): 41 | full_child_name = '.'.join([name, child_name]) if name else child_name 42 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name) 43 | if new_child is not child: 44 | res.add_module(child_name, new_child) 45 | return res 46 | 47 | 48 | # From PyTorch internals 49 | def _ntuple(n): 50 | def parse(x): 51 | if isinstance(x, collections.abc.Iterable): 52 | return x 53 | return tuple(repeat(x, n)) 54 | return parse 55 | 56 | 57 | to_1tuple = _ntuple(1) 58 | to_2tuple = _ntuple(2) 59 | to_3tuple = _ntuple(3) 60 | to_4tuple = _ntuple(4) 61 | to_ntuple = lambda n, x: _ntuple(n)(x) 62 | 63 | # Replaces all linear layers with linear_replacement 64 | # TODO: add int8 support for other linear layers including attn and convnets 65 | def replace_linear(model, linear_replacement, include_modules=['c_fc', 'c_proj'], copy_weights=True): 66 | for name, module in model.named_children(): 67 | if len(list(module.children())) > 0: 68 | replace_linear(module, linear_replacement, include_modules, copy_weights) 69 | 70 | if isinstance(module, torch.nn.Linear) and name in include_modules: 71 | old_module = model._modules[name] 72 | model._modules[name] = linear_replacement( 73 | module.in_features, 74 | module.out_features, 75 | module.bias is not None, 76 | ) 77 | if copy_weights: 78 | model._modules[name].weight.data.copy_(old_module.weight.data) 79 | if model._modules[name].bias is not None: 80 | model._modules[name].bias.data.copy_(old_module.bias) 81 | 82 | return model 83 | 84 | def convert_int8_model_to_inference_mode(model): 85 | for m in model.modules(): 86 | if hasattr(m, 'prepare_for_eval'): 87 | int8_original_dtype = m.weight.dtype 88 | m.prepare_for_eval() 89 | m.int8_original_dtype = int8_original_dtype -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | .vscode 163 | *.ipynb 164 | *.sh 165 | data 166 | experiments 167 | wandb 168 | SDXL/images 169 | model/tome 170 | SDXL/streamlit_app* -------------------------------------------------------------------------------- /open_clip_long/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | # -------------------------------------------------------- 15 | # 2D sine-cosine position embedding 16 | # References: 17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 18 | # MoCo v3: https://github.com/facebookresearch/moco-v3 19 | # -------------------------------------------------------- 20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 21 | """ 22 | grid_size: int of the grid height and width 23 | return: 24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 25 | """ 26 | grid_h = np.arange(grid_size, dtype=np.float32) 27 | grid_w = np.arange(grid_size, dtype=np.float32) 28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 29 | grid = np.stack(grid, axis=0) 30 | 31 | grid = grid.reshape([2, 1, grid_size, grid_size]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | if cls_token: 34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 35 | return pos_embed 36 | 37 | 38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 39 | assert embed_dim % 2 == 0 40 | 41 | # use half of dimensions to encode grid_h 42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 44 | 45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 46 | return emb 47 | 48 | 49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 50 | """ 51 | embed_dim: output dimension for each position 52 | pos: a list of positions to be encoded: size (M,) 53 | out: (M, D) 54 | """ 55 | assert embed_dim % 2 == 0 56 | omega = np.arange(embed_dim // 2, dtype=float) 57 | omega /= embed_dim / 2. 58 | omega = 1. 
/ 10000**omega # (D/2,) 59 | 60 | pos = pos.reshape(-1) # (M,) 61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 62 | 63 | emb_sin = np.sin(out) # (M, D/2) 64 | emb_cos = np.cos(out) # (M, D/2) 65 | 66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 67 | return emb 68 | 69 | 70 | # -------------------------------------------------------- 71 | # Interpolate position embeddings for high-resolution 72 | # References: 73 | # DeiT: https://github.com/facebookresearch/deit 74 | # -------------------------------------------------------- 75 | def interpolate_pos_embed(model, checkpoint_model): 76 | if 'pos_embed' in checkpoint_model: 77 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 78 | embedding_size = pos_embed_checkpoint.shape[-1] 79 | num_patches = model.patch_embed.num_patches 80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 81 | # height (== width) for the checkpoint position embedding 82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 83 | # height (== width) for the new position embedding 84 | new_size = int(num_patches ** 0.5) 85 | # class_token and dist_token are kept unchanged 86 | if orig_size != new_size: 87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 89 | # only the position tokens are interpolated 90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 92 | pos_tokens = torch.nn.functional.interpolate( 93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 96 | checkpoint_model['pos_embed'] = new_pos_embed 97 | -------------------------------------------------------------------------------- /SDXL/sdxl.py: -------------------------------------------------------------------------------- 1 | import json, os 2 | from glob import glob 3 | import torch.utils.data as data 4 | 5 | import sys 6 | sys.path.append('..') 7 | from diffusers import DiffusionPipeline 8 | import torch 9 | from open_clip_long import factory as open_clip 10 | import torch.nn as nn 11 | import inspect 12 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 13 | from transformers import ( 14 | CLIPImageProcessor, 15 | CLIPTextModel, 16 | CLIPTextModelWithProjection, 17 | CLIPTokenizer, 18 | CLIPVisionModelWithProjection, 19 | ) 20 | 21 | from SDXL_pipeline import get_image 22 | from SDXL_img2img import image2image 23 | 24 | import argparse 25 | from encode_prompt import initialize 26 | import warnings 27 | 28 | warnings.filterwarnings("ignore") 29 | 30 | class docci_dataset(data.Dataset): 31 | def __init__(self, filter='test'): 32 | self.caption_root = '../data/docci/docci_descriptions.jsonlines' 33 | self.total_caption = [] 34 | with open(self.caption_root, 'r') as f: 35 | for line in f: 36 | line = json.loads(line) 37 | if line['split'] == filter: 38 | self.total_caption.append(line) 39 | 40 | def __len__(self): 41 | return len(self.total_caption) 42 | 43 | def __getitem__(self, index): 44 | caption_json = self.total_caption[index] 45 | image_name = caption_json['image_file'] 46 | caption = caption_json['description'] 47 | 48 | return image_name, caption 49 | 50 | class urban_dataset(data.Dataset): 51 | def __init__(self): 52 | self.caption_root = 
'../data/Urban1k/caption/' 53 | self.total_caption = sorted(glob(f'{self.caption_root}*.txt')) 54 | 55 | def __len__(self): 56 | return len(self.total_caption) 57 | 58 | def __getitem__(self, index): 59 | caption_name = self.total_caption[index] 60 | image_name = caption_name.split('/')[-1][:-4] + '.jpg' 61 | caption = open(caption_name).read() 62 | 63 | return image_name, caption 64 | 65 | parser = argparse.ArgumentParser(description='params') 66 | parser.add_argument('--ckpt_path', default='work/experiments/finelip-L.pt', help="ckpt_path") 67 | parser.add_argument('--enable_finelip', action='store_true', help='enable finelip') 68 | parser.add_argument('--img_dir', default='images/finelip', help='output image directory') 69 | parser.add_argument('--dataset', default='docci', help='dataset to use') 70 | args = parser.parse_args() 71 | 72 | initialize(args) 73 | # fix the seed 74 | generator = torch.Generator(device='cuda') 75 | generator.manual_seed(1971) 76 | 77 | base = DiffusionPipeline.from_pretrained( 78 | "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True 79 | ) 80 | base.to("cuda") 81 | 82 | refiner = DiffusionPipeline.from_pretrained( 83 | "stabilityai/stable-diffusion-xl-refiner-1.0", 84 | text_encoder_2=base.text_encoder_2, 85 | vae=base.vae, 86 | torch_dtype=torch.float16, 87 | use_safetensors=True, 88 | variant="fp16", 89 | ) 90 | refiner.to("cuda") 91 | 92 | # Define how many steps and what % of steps to be run on each experts (80/20) here 93 | n_steps = 40 94 | high_noise_frac = 0.8 95 | 96 | if args.dataset == 'docci': 97 | testset = docci_dataset() 98 | elif args.dataset == 'urban1k': 99 | testset = urban_dataset() 100 | else: 101 | raise ValueError("Invalid dataset") 102 | test_loader = data.DataLoader(testset, batch_size=1, shuffle=False) 103 | 104 | if not os.path.exists(args.img_dir): 105 | os.makedirs(args.img_dir) 106 | for i, (image_name, caption) in enumerate(test_loader): 107 | image_name = image_name[0] 108 | caption = caption[0] 109 | image = get_image( 110 | pipe=base, 111 | prompt=caption, 112 | num_inference_steps=n_steps, 113 | denoising_end=high_noise_frac, 114 | output_type="latent", 115 | generator=generator, 116 | ).images 117 | 118 | image = image2image( 119 | pipe=refiner, 120 | prompt=caption, 121 | num_inference_steps=n_steps, 122 | denoising_start=high_noise_frac, 123 | image=image, 124 | generator=generator, 125 | ).images[0] 126 | 127 | image_name = f"{args.img_dir}/{image_name}" 128 | image.save(image_name) 129 | 130 | print("Done!") 131 | -------------------------------------------------------------------------------- /train/eval_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import cv2 3 | from PIL import Image 4 | import sys 5 | sys.path.append('../..') 6 | import torch 7 | import torch.utils.data as data 8 | import os 9 | import numpy as np 10 | import clip 11 | from torchvision.datasets import CocoCaptions 12 | import pdb 13 | 14 | class docci_dataset(data.Dataset): 15 | def __init__(self, preprocess=None, filter='test'): 16 | self.image_root = 'evaluation_data/docci/test_images/' 17 | self.caption_root = 'evaluation_data/docci/docci_descriptions.jsonlines' 18 | self.total_image = os.listdir(self.image_root) 19 | self.total_caption = [] 20 | with open(self.caption_root, 'r') as f: 21 | for line in f: 22 | line = json.loads(line) 23 | if line['split'] == filter: 24 | self.total_caption.append(line) 25 | 26 | if preprocess is 
None: 27 | _, self.preprocess = clip.load("ViT-B/16") 28 | else: 29 | self.preprocess = preprocess 30 | 31 | def __len__(self): 32 | return len(self.total_caption) 33 | 34 | def __getitem__(self, index): 35 | caption_json = self.total_caption[index] 36 | image_name = caption_json['image_file'] 37 | image = Image.open(self.image_root + image_name) 38 | image_tensor = self.preprocess(image) 39 | caption = caption_json['description'] 40 | 41 | return image_tensor, caption 42 | 43 | class urban_dataset(data.Dataset): 44 | def __init__(self, preprocess=None): 45 | self.image_root = 'evaluation_data/Urban1k/image/' 46 | self.caption_root = 'evaluation_data/Urban1k/caption/' 47 | self.total_image = os.listdir(self.image_root) 48 | self.total_caption = os.listdir(self.caption_root) 49 | if preprocess is None: 50 | _, self.preprocess = clip.load("ViT-B/16") 51 | else: 52 | self.preprocess = preprocess 53 | 54 | def __len__(self): 55 | return len(self.total_caption) 56 | 57 | def __getitem__(self, index): 58 | caption_name = self.total_caption[index] 59 | image_name = self.total_caption[index][:-4] + '.jpg' 60 | image = Image.open(self.image_root + image_name) 61 | image_tensor = self.preprocess(image) 62 | f=open(self.caption_root + caption_name) 63 | caption = f.readlines()[0] 64 | 65 | return image_tensor, caption 66 | 67 | class coco_dataset(data.Dataset): 68 | def __init__(self, preprocess=None): 69 | self.image_root = 'evaluation_data/coco/val2017' 70 | self.caption_root = 'evaluation_data/coco/annotations/captions_val2017.json' 71 | self.coco_zipped = CocoCaptions(root=self.image_root, annFile=self.caption_root, transform=None) 72 | if preprocess is None: 73 | _, self.preprocess = clip.load("ViT-B/16") 74 | else: 75 | self.preprocess = preprocess 76 | 77 | def __len__(self): 78 | return len(self.coco_zipped) 79 | 80 | def __getitem__(self, index): 81 | org_img, org_caption = self.coco_zipped[index] 82 | image_tensor = self.preprocess(org_img) 83 | caption = org_caption[0:5] 84 | return image_tensor, caption 85 | 86 | class flickr_dataset(data.Dataset): 87 | def __init__(self, preprocess=None): 88 | self.image_root = 'evaluation_data/flickr30k/Images' 89 | self.caption_root = 'evaluation_data/flickr30k/results_20130124.token' 90 | self.zipped_dataset = self._get_list() 91 | if preprocess is None: 92 | _, self.preprocess = clip.load("ViT-B/16") 93 | else: 94 | self.preprocess = preprocess 95 | 96 | def __len__(self): 97 | return len(self.zipped_dataset) // 5 98 | 99 | def _get_list(self): 100 | with open(self.caption_root, 'r') as f: 101 | dataset_zipped = f.readlines() 102 | return dataset_zipped 103 | 104 | def __getitem__(self, index): 105 | data = self.zipped_dataset[index*5:(index+1)*5] 106 | image_name = data[0].split('\t')[0][:-2] 107 | caption = [data[i].split('\t')[1] for i in range(5)] 108 | image_full_path = Image.open(os.path.join(self.image_root,image_name)) 109 | image_tensor = self.preprocess(image_full_path) 110 | return image_tensor, caption 111 | 112 | if __name__ == "__main__": 113 | flickrDataset = flickr_dataset() 114 | pdb.set_trace() -------------------------------------------------------------------------------- /open_clip_long/zero_shot_classifier.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from itertools import islice 3 | from typing import Callable, List, Optional, Sequence, Union 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | 9 | def batched(iterable, n): 10 | """Batch 
data into lists of length *n*. The last batch may be shorter. 11 | NOTE based on more-itertools impl, to be replaced by python 3.12 itertools.batched impl 12 | """ 13 | it = iter(iterable) 14 | while True: 15 | batch = list(islice(it, n)) 16 | if not batch: 17 | break 18 | yield batch 19 | 20 | 21 | def build_zero_shot_classifier( 22 | model, 23 | tokenizer, 24 | classnames: Sequence[str], 25 | templates: Sequence[Union[Callable, str]], 26 | num_classes_per_batch: Optional[int] = 10, 27 | device: Union[str, torch.device] = 'cpu', 28 | use_tqdm: bool = False, 29 | ): 30 | """ Build zero-shot classifier weights by iterating over class names in batches 31 | Args: 32 | model: CLIP model instance 33 | tokenizer: CLIP tokenizer instance 34 | classnames: A sequence of class (label) names 35 | templates: A sequence of callables or format() friendly strings to produce templates per class name 36 | num_classes_per_batch: The number of classes to batch together in each forward, all if None 37 | device: Device to use. 38 | use_tqdm: Enable TQDM progress bar. 39 | """ 40 | assert isinstance(templates, Sequence) and len(templates) > 0 41 | assert isinstance(classnames, Sequence) and len(classnames) > 0 42 | use_format = isinstance(templates[0], str) 43 | num_templates = len(templates) 44 | num_classes = len(classnames) 45 | if use_tqdm: 46 | import tqdm 47 | num_iter = 1 if num_classes_per_batch is None else ((num_classes - 1) // num_classes_per_batch + 1) 48 | iter_wrap = partial(tqdm.tqdm, total=num_iter, unit_scale=num_classes_per_batch) 49 | else: 50 | iter_wrap = iter 51 | 52 | def _process_batch(batch_classnames): 53 | num_batch_classes = len(batch_classnames) 54 | texts = [template.format(c) if use_format else template(c) for c in batch_classnames for template in templates] 55 | texts = tokenizer(texts).to(device) 56 | class_embeddings = model.encode_text(texts, normalize=True) 57 | class_embeddings = class_embeddings.reshape(num_batch_classes, num_templates, -1).mean(dim=1) 58 | class_embeddings = class_embeddings / class_embeddings.norm(dim=1, keepdim=True) 59 | class_embeddings = class_embeddings.T 60 | return class_embeddings 61 | 62 | with torch.no_grad(): 63 | if num_classes_per_batch: 64 | batched_embeds = [_process_batch(batch) for batch in iter_wrap(batched(classnames, num_classes_per_batch))] 65 | zeroshot_weights = torch.cat(batched_embeds, dim=1) 66 | else: 67 | zeroshot_weights = _process_batch(classnames) 68 | return zeroshot_weights 69 | 70 | 71 | def build_zero_shot_classifier_legacy( 72 | model, 73 | tokenizer, 74 | classnames: Sequence[str], 75 | templates: Sequence[Union[Callable, str]], 76 | device: Union[str, torch.device] = 'cpu', 77 | use_tqdm: bool = False, 78 | ): 79 | """ Build zero-shot classifier weights by iterating over class names 1 by 1 80 | Args: 81 | model: CLIP model instance 82 | tokenizer: CLIP tokenizer instance 83 | classnames: A sequence of class (label) names 84 | templates: A sequence of callables or format() friendly strings to produce templates per class name 85 | device: Device to use. 86 | use_tqdm: Enable TQDM progress bar. 
87 | """ 88 | assert isinstance(templates, Sequence) and len(templates) > 0 89 | assert isinstance(classnames, Sequence) and len(classnames) > 0 90 | if use_tqdm: 91 | import tqdm 92 | iter_wrap = tqdm.tqdm 93 | else: 94 | iter_wrap = iter 95 | 96 | use_format = isinstance(templates[0], str) 97 | 98 | with torch.no_grad(): 99 | zeroshot_weights = [] 100 | for classname in iter_wrap(classnames): 101 | texts = [template.format(classname) if use_format else template(classname) for template in templates] 102 | texts = tokenizer(texts).to(device) # tokenize 103 | class_embeddings = model.encode_text(texts) 104 | class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) 105 | class_embedding /= class_embedding.norm() 106 | zeroshot_weights.append(class_embedding) 107 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device) 108 | 109 | return zeroshot_weights 110 | 111 | -------------------------------------------------------------------------------- /model/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /train/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def pos_neg_mask(labels): 7 | pos_mask = (labels.unsqueeze(0) == labels.unsqueeze(1)) 8 | neg_mask = labels.unsqueeze(0) != labels.unsqueeze(1) 9 | return pos_mask, neg_mask 10 | 11 | 12 | def get_sim(images, captions): 13 | similarities = images.mm(captions.t()) 14 | return similarities 
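# Illustrative note (tensor names are placeholders): with L2-normalised embeddings
# im of shape (B, D) and s of shape (B, D), get_sim(im, s) is a (B, B) cosine-similarity
# matrix whose diagonal holds the matched image-caption pairs. The losses below penalise
# mismatched (off-diagonal) scores that come within `margin` of the matched score:
#   scores = get_sim(im, s)
#   positives = scores.diag()   # similarity of each matched pair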
15 | 16 | 17 | def loss_select(opt, loss_type='vse'): 18 | if loss_type == 'vse': 19 | # default loss 20 | criterion = ContrastiveLoss(opt=opt, margin=opt.margin, max_violation=opt.max_violation) 21 | elif loss_type == 'trip': 22 | opt.dataset = '' 23 | criterion = TripletLoss(opt=opt) 24 | else: 25 | raise ValueError('Invalid loss {}'.format(loss_type)) 26 | return criterion 27 | 28 | 29 | class ContrastiveLoss(nn.Module): 30 | 31 | def __init__(self, opt, margin=0.2, max_violation=False): 32 | super(ContrastiveLoss, self).__init__() 33 | self.opt = opt 34 | self.margin = margin 35 | self.max_violation = max_violation 36 | self.mask_repeat = True 37 | 38 | self.false_hard = [] 39 | 40 | def max_violation_on(self): 41 | self.max_violation = True 42 | 43 | def max_violation_off(self): 44 | self.max_violation = False 45 | 46 | def forward(self, im, s, img_ids=None, scores=None): 47 | 48 | # compute image-sentence score matrix 49 | if scores is None: 50 | scores = get_sim(im, s) 51 | 52 | diagonal = scores.diag().view(im.size(0), 1) 53 | d1 = diagonal.expand_as(scores) 54 | d2 = diagonal.t().expand_as(scores) 55 | 56 | # compare every diagonal score to scores in its column 57 | # caption retrieval, i->t 58 | cost_s = (self.margin + scores - d1).clamp(min=0) 59 | 60 | # compare every diagonal score to scores in its row 61 | # image retrieval t->i 62 | cost_im = (self.margin + scores - d2).clamp(min=0) 63 | 64 | # clear diagonals 65 | if not self.mask_repeat: 66 | mask = torch.eye(scores.size(0), dtype=torch.bool, device=scores.device) 67 | else: 68 | mask = (img_ids.unsqueeze(1) == img_ids.unsqueeze(0)) 69 | # repeat = len(img_ids) - len(torch.unique(img_ids)) 70 | 71 | cost_s = cost_s.masked_fill_(mask, 0) 72 | cost_im = cost_im.masked_fill_(mask, 0) 73 | 74 | # keep the maximum violating negative for each query 75 | if self.max_violation: 76 | cost_s, idx_s = cost_s.max(1) 77 | cost_im, idx_im = cost_im.max(0) 78 | 79 | loss = cost_s.sum() + cost_im.sum() 80 | 81 | return loss 82 | 83 | 84 | # Triplet loss + DistanceWeight Miner 85 | class TripletLoss(nn.Module): 86 | 87 | def __init__(self, opt=None, margin=0.2, ): 88 | super().__init__() 89 | 90 | self.opt = opt 91 | self.margin = margin 92 | 93 | self.cut_off = 0.5 94 | self.d = 512 95 | 96 | self.nonzero_loss_cutoff = 1.9 if opt.dataset == 'coco' else 1.7 97 | 98 | def forward(self, im, s, img_ids, sim_mat=None): 99 | 100 | if sim_mat is None: 101 | sim_mat = get_sim(im, s) 102 | 103 | pos_mask, neg_mask = pos_neg_mask(img_ids) 104 | 105 | loss_im = self.loss_forward(sim_mat, pos_mask, neg_mask) 106 | loss_s = self.loss_forward(sim_mat.t(), pos_mask.t(), neg_mask.t()) 107 | 108 | loss = loss_im + loss_s 109 | 110 | return loss 111 | 112 | def loss_forward(self, sim_mat, pos_mask, neg_mask): 113 | 114 | pos_pair_idx = pos_mask.nonzero(as_tuple=False) 115 | anchor_idx = pos_pair_idx[:, 0] 116 | pos_idx = pos_pair_idx[:, 1] 117 | 118 | # distance-based weight 119 | # This miner works well only with low dimensionality embeddings (e.g 64-dim) and L2-normalized distances. 120 | # Check out UniformHistogramMiner for a miner that is roughly equivalent, 121 | # but works with embeddings of any dimensionality and any distance metric. 
122 | # from https://kevinmusgrave.github.io/pytorch-metric-learning/miners/ 123 | 124 | # our dimension is 1024, belong to high dimensionality 125 | dist = (2 - 2 * sim_mat).sqrt() 126 | dist = dist.clamp(min=self.cut_off) 127 | 128 | log_weight = (2.0 - self.d) * dist.log() - ((self.d - 3.0) / 2.0) * (1.0 - 0.25 * (dist * dist)).log() 129 | inf_or_nan = torch.isinf(log_weight) | torch.isnan(log_weight) 130 | 131 | log_weight = log_weight * neg_mask 132 | log_weight[inf_or_nan] = 0. 133 | 134 | weight = (log_weight - log_weight.max(dim=1, keepdim=True)[0]).exp() 135 | weight = weight * (neg_mask * (dist < self.nonzero_loss_cutoff)).float() 136 | 137 | weight = weight / (weight.sum(dim=1, keepdim=True) + 1e-20) 138 | weight = weight[anchor_idx] 139 | 140 | try: 141 | neg_idx = torch.multinomial(weight, 1).squeeze(1) 142 | except Exception: 143 | return torch.zeros([], requires_grad=True, device=sim_mat.device) 144 | 145 | 146 | s_ap = sim_mat[anchor_idx, pos_idx] 147 | s_an = sim_mat[anchor_idx, neg_idx] 148 | 149 | loss = F.relu(self.margin + s_an - s_ap) 150 | loss = loss.sum() 151 | 152 | return loss 153 | 154 | 155 | if __name__ == '__main__': 156 | 157 | pass -------------------------------------------------------------------------------- /open_clip_long/timm_model.py: -------------------------------------------------------------------------------- 1 | """ timm model adapter 2 | 3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 4 | """ 5 | import logging 6 | from collections import OrderedDict 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | try: 12 | import timm 13 | from timm.models.layers import Mlp, to_2tuple 14 | try: 15 | # old timm imports < 0.8.1 16 | from timm.models.layers.attention_pool2d import RotAttentionPool2d 17 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d 18 | except ImportError: 19 | # new timm imports >= 0.8.1 20 | from timm.layers import RotAttentionPool2d 21 | from timm.layers import AttentionPool2d as AbsAttentionPool2d 22 | except ImportError: 23 | timm = None 24 | 25 | from .utils import freeze_batch_norm_2d 26 | 27 | 28 | class TimmModel(nn.Module): 29 | """ timm model adapter 30 | """ 31 | 32 | def __init__( 33 | self, 34 | model_name, 35 | embed_dim, 36 | image_size=224, 37 | pool='avg', 38 | proj='linear', 39 | proj_bias=False, 40 | drop=0., 41 | drop_path=None, 42 | patch_drop=None, 43 | pretrained=False, 44 | ): 45 | super().__init__() 46 | if timm is None: 47 | raise RuntimeError("Please `pip install timm` to use timm models.") 48 | self.image_size = to_2tuple(image_size) 49 | 50 | # setup kwargs that may not be common across all models 51 | timm_kwargs = {} 52 | if drop_path is not None: 53 | timm_kwargs['drop_path_rate'] = drop_path 54 | if patch_drop is not None: 55 | timm_kwargs['patch_drop_rate'] = patch_drop 56 | 57 | custom_pool = pool in ('abs_attn', 'rot_attn') 58 | if proj: 59 | assert proj in ("linear", "mlp", "none") 60 | extra_proj = proj in ("linear", "mlp") 61 | if not extra_proj and not custom_pool: 62 | # use network classifier head as projection if no proj specified and no custom pooling used 63 | # if projection is explicitly set to "none" will be pass through from network trunk 64 | proj_dim = 0 if proj == 'none' else embed_dim 65 | self.trunk = timm.create_model( 66 | model_name, 67 | num_classes=proj_dim, 68 | global_pool=pool, 69 | pretrained=pretrained, 70 | **timm_kwargs, 71 | ) 72 | prev_chs = embed_dim 73 | else: 
74 | self.trunk = timm.create_model( 75 | model_name, 76 | pretrained=pretrained, 77 | **timm_kwargs, 78 | ) 79 | feat_size = self.trunk.default_cfg.get('pool_size', None) 80 | feature_ndim = 1 if not feat_size else 2 81 | if custom_pool: 82 | assert feature_ndim == 2 83 | # if attn pooling used, remove both classifier and default pool 84 | self.trunk.reset_classifier(0, global_pool='') 85 | else: 86 | # reset global pool if pool config set, otherwise leave as network default 87 | reset_kwargs = dict(global_pool=pool) if pool else {} 88 | self.trunk.reset_classifier(0, **reset_kwargs) 89 | prev_chs = self.trunk.num_features 90 | 91 | head_layers = OrderedDict() 92 | 93 | # Add custom pooling to head 94 | if pool == 'abs_attn': 95 | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) 96 | prev_chs = embed_dim 97 | elif pool == 'rot_attn': 98 | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) 99 | prev_chs = embed_dim 100 | 101 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used 102 | if proj == 'linear': 103 | head_layers['drop'] = nn.Dropout(drop) 104 | head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) 105 | elif proj == 'mlp': 106 | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias)) 107 | 108 | self.head = nn.Sequential(head_layers) 109 | 110 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 111 | """ lock modules 112 | Args: 113 | unlocked_groups (int): leave last n layer groups unlocked (default: 0) 114 | """ 115 | if not unlocked_groups: 116 | # lock full model 117 | for param in self.trunk.parameters(): 118 | param.requires_grad = False 119 | if freeze_bn_stats: 120 | freeze_batch_norm_2d(self.trunk) 121 | else: 122 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change 123 | try: 124 | # FIXME import here until API stable and in an official release 125 | from timm.models.helpers import group_parameters, group_modules 126 | except ImportError: 127 | raise RuntimeError( 128 | 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') 129 | matcher = self.trunk.group_matcher() 130 | gparams = group_parameters(self.trunk, matcher) 131 | max_layer_id = max(gparams.keys()) 132 | max_layer_id = max_layer_id - unlocked_groups 133 | for group_idx in range(max_layer_id + 1): 134 | group = gparams[group_idx] 135 | for param in group: 136 | self.trunk.get_parameter(param).requires_grad = False 137 | if freeze_bn_stats: 138 | gmodules = group_modules(self.trunk, matcher, reverse=True) 139 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} 140 | freeze_batch_norm_2d(self.trunk, gmodules) 141 | 142 | @torch.jit.ignore 143 | def set_grad_checkpointing(self, enable=True): 144 | try: 145 | self.trunk.set_grad_checkpointing(enable) 146 | except Exception as e: 147 | logging.warning('grad checkpointing not supported for this timm image tower, continuing without...') 148 | 149 | def forward(self, x): 150 | x = self.trunk(x) 151 | x = self.head(x) 152 | return x 153 | -------------------------------------------------------------------------------- /open_clip_long/hf_model.py: -------------------------------------------------------------------------------- 1 | """ huggingface model adapter 2 | 3 | Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP 
model. 4 | """ 5 | import re 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch import TensorType 10 | 11 | try: 12 | import transformers 13 | from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig 14 | from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \ 15 | BaseModelOutputWithPoolingAndCrossAttentions 16 | except ImportError as e: 17 | transformers = None 18 | 19 | 20 | class BaseModelOutput: 21 | pass 22 | 23 | 24 | class PretrainedConfig: 25 | pass 26 | 27 | from .hf_configs import arch_dict 28 | 29 | 30 | # utils 31 | def _camel2snake(s): 32 | return re.sub(r'(? 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.act1 = nn.ReLU(inplace=True) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.act2 = nn.ReLU(inplace=True) 24 | 25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 26 | 27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | self.act3 = nn.ReLU(inplace=True) 30 | 31 | self.downsample = None 32 | self.stride = stride 33 | 34 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 36 | self.downsample = nn.Sequential(OrderedDict([ 37 | ("-1", nn.AvgPool2d(stride)), 38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 39 | ("1", nn.BatchNorm2d(planes * self.expansion)) 40 | ])) 41 | 42 | def forward(self, x: torch.Tensor): 43 | identity = x 44 | 45 | out = self.act1(self.bn1(self.conv1(x))) 46 | out = self.act2(self.bn2(self.conv2(out))) 47 | out = self.avgpool(out) 48 | out = self.bn3(self.conv3(out)) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | out = self.act3(out) 55 | return out 56 | 57 | 58 | class AttentionPool2d(nn.Module): 59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 60 | super().__init__() 61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 62 | self.k_proj = nn.Linear(embed_dim, embed_dim) 63 | self.q_proj = nn.Linear(embed_dim, embed_dim) 64 | self.v_proj = nn.Linear(embed_dim, embed_dim) 65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 66 | self.num_heads = num_heads 67 | 68 | def forward(self, x): 69 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 72 | x, _ = F.multi_head_attention_forward( 73 | query=x, key=x, value=x, 74 | embed_dim_to_check=x.shape[-1], 75 | num_heads=self.num_heads, 76 | q_proj_weight=self.q_proj.weight, 77 | k_proj_weight=self.k_proj.weight, 78 | v_proj_weight=self.v_proj.weight, 79 | in_proj_weight=None, 80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 81 | bias_k=None, 82 | bias_v=None, 83 | add_zero_attn=False, 84 | dropout_p=0., 85 | out_proj_weight=self.c_proj.weight, 86 | out_proj_bias=self.c_proj.bias, 87 | use_separate_proj_weight=True, 88 | training=self.training, 89 | need_weights=False 90 | ) 91 | 92 | return x[0] 93 | 94 | 95 | class ModifiedResNet(nn.Module): 96 | """ 97 | A 
ResNet class that is similar to torchvision's but contains the following changes: 98 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 99 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 100 | - The final pooling layer is a QKV attention instead of an average pool 101 | """ 102 | 103 | def __init__(self, layers, output_dim, heads, image_size=224, width=64): 104 | super().__init__() 105 | self.output_dim = output_dim 106 | self.image_size = image_size 107 | 108 | # the 3-layer stem 109 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 110 | self.bn1 = nn.BatchNorm2d(width // 2) 111 | self.act1 = nn.ReLU(inplace=True) 112 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 113 | self.bn2 = nn.BatchNorm2d(width // 2) 114 | self.act2 = nn.ReLU(inplace=True) 115 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 116 | self.bn3 = nn.BatchNorm2d(width) 117 | self.act3 = nn.ReLU(inplace=True) 118 | self.avgpool = nn.AvgPool2d(2) 119 | 120 | # residual layers 121 | self._inplanes = width # this is a *mutable* variable used during construction 122 | self.layer1 = self._make_layer(width, layers[0]) 123 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 124 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 125 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 126 | 127 | embed_dim = width * 32 # the ResNet feature dimension 128 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) 129 | 130 | self.init_parameters() 131 | 132 | def _make_layer(self, planes, blocks, stride=1): 133 | layers = [Bottleneck(self._inplanes, planes, stride)] 134 | 135 | self._inplanes = planes * Bottleneck.expansion 136 | for _ in range(1, blocks): 137 | layers.append(Bottleneck(self._inplanes, planes)) 138 | 139 | return nn.Sequential(*layers) 140 | 141 | def init_parameters(self): 142 | if self.attnpool is not None: 143 | std = self.attnpool.c_proj.in_features ** -0.5 144 | nn.init.normal_(self.attnpool.q_proj.weight, std=std) 145 | nn.init.normal_(self.attnpool.k_proj.weight, std=std) 146 | nn.init.normal_(self.attnpool.v_proj.weight, std=std) 147 | nn.init.normal_(self.attnpool.c_proj.weight, std=std) 148 | 149 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: 150 | for name, param in resnet_block.named_parameters(): 151 | if name.endswith("bn3.weight"): 152 | nn.init.zeros_(param) 153 | 154 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 155 | assert unlocked_groups == 0, 'partial locking not currently supported for this model' 156 | for param in self.parameters(): 157 | param.requires_grad = False 158 | if freeze_bn_stats: 159 | freeze_batch_norm_2d(self) 160 | 161 | @torch.jit.ignore 162 | def set_grad_checkpointing(self, enable=True): 163 | # FIXME support for non-transformer 164 | pass 165 | 166 | def stem(self, x): 167 | x = self.act1(self.bn1(self.conv1(x))) 168 | x = self.act2(self.bn2(self.conv2(x))) 169 | x = self.act3(self.bn3(self.conv3(x))) 170 | x = self.avgpool(x) 171 | return x 172 | 173 | def forward(self, x): 174 | x = self.stem(x) 175 | x = self.layer1(x) 176 | x = self.layer2(x) 177 | x = self.layer3(x) 178 | x = self.layer4(x) 179 | x = self.attnpool(x) 180 | 181 | return x 182 | -------------------------------------------------------------------------------- 
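The ModifiedResNet above is the image tower used for CLIP's ResNet checkpoints. A minimal usage sketch follows; the import path and the RN50-style hyper-parameters (layers=(3, 4, 6, 3), width=64, heads=32, output_dim=1024) are assumptions based on this repository's layout and the standard CLIP RN50 configuration, not values read from a config file here.

import torch
from open_clip_long.modified_resnet import ModifiedResNet  # import path assumed from this repo layout

# RN50-style tower: width=64 -> 64 * 32 = 2048-dim trunk features,
# reduced to a 1024-dim joint embedding by the QKV attention pool.
model = ModifiedResNet(layers=(3, 4, 6, 3), output_dim=1024, heads=32,
                       image_size=224, width=64).eval()

with torch.no_grad():
    images = torch.randn(2, 3, 224, 224)  # dummy batch of 224x224 RGB images
    feats = model(images)                 # attention-pooled image embeddings

print(feats.shape)  # torch.Size([2, 1024])
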
/open_clip_long/big_vision.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from .model import CustomTextCLIP 5 | from .transformer import TextTransformer, Transformer 6 | 7 | 8 | @torch.no_grad() 9 | def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str): 10 | """ Load weights from .npz checkpoints for official Google big_vision image-text models 11 | 12 | Currently the SigLIP source models are supported and a CustomTextCLIP destination model 13 | w/ timm image encoder. 14 | """ 15 | from timm.layers import resample_patch_embed, resample_abs_pos_embed 16 | 17 | def _n2p(w, t=True): 18 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: 19 | w = w.flatten() 20 | if t: 21 | if w.ndim == 4: 22 | w = w.transpose([3, 2, 0, 1]) 23 | elif w.ndim == 3: 24 | w = w.transpose([2, 0, 1]) 25 | elif w.ndim == 2: 26 | w = w.transpose([1, 0]) 27 | return torch.from_numpy(w) 28 | 29 | w = np.load(checkpoint_path) 30 | interpolation = 'bilinear' 31 | antialias = False 32 | 33 | def _convert_timm_img(module, prefix): 34 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) 35 | if embed_conv_w.shape[-2:] != module.patch_embed.proj.weight.shape[-2:]: 36 | embed_conv_w = resample_patch_embed( 37 | embed_conv_w, 38 | module.patch_embed.proj.weight.shape[-2:], 39 | interpolation=interpolation, 40 | antialias=antialias, 41 | verbose=True, 42 | ) 43 | module.patch_embed.proj.weight.copy_(embed_conv_w) 44 | module.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) 45 | 46 | if module.cls_token is not None: 47 | module.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) 48 | 49 | pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False) 50 | if pos_embed_w.shape != module.pos_embed.shape: 51 | assert False, f'{pos_embed_w.shape}, {module.pos_embed.shape}' 52 | num_prefix_tokens = 0 if getattr(module, 'no_embed_class', False) else getattr(module, 'num_prefix_tokens', 1) 53 | pos_embed_w = resample_abs_pos_embed( # resize pos embedding when different size from pretrained weights 54 | pos_embed_w, 55 | new_size=module.patch_embed.grid_size, 56 | num_prefix_tokens=num_prefix_tokens, 57 | interpolation=interpolation, 58 | antialias=antialias, 59 | verbose=True, 60 | ) 61 | module.pos_embed.copy_(pos_embed_w) 62 | 63 | mha_sub, b_sub, ln1_sub = (0, 0, 1) 64 | for i, block in enumerate(module.blocks.children()): 65 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/' 66 | mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/' 67 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 68 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 69 | block.attn.qkv.weight.copy_(torch.cat([ 70 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) 71 | block.attn.qkv.bias.copy_(torch.cat([ 72 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 73 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 74 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 75 | for r in range(2): 76 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'])) 77 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'])) 78 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'])) 79 | 
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'])) 80 | 81 | module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) 82 | module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) 83 | 84 | if module.attn_pool is not None: 85 | block_prefix = f'{prefix}MAPHead_0/' 86 | mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/' 87 | module.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False)) 88 | module.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T) 89 | module.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1)) 90 | module.attn_pool.kv.weight.copy_(torch.cat([ 91 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')])) 92 | module.attn_pool.kv.bias.copy_(torch.cat([ 93 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')])) 94 | module.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 95 | module.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 96 | module.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 97 | module.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 98 | for r in range(2): 99 | getattr(module.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel'])) 100 | getattr(module.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias'])) 101 | 102 | def _convert_openclip_transformer(module: Transformer, prefix): 103 | for i, block in enumerate(module.resblocks.children()): 104 | block_prefix = f'{prefix}encoderblock_{i}/' 105 | mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/' 106 | block.ln_1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 107 | block.ln_1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 108 | block.attn.in_proj_weight.copy_(torch.cat([ 109 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) 110 | block.attn.in_proj_bias.copy_(torch.cat([ 111 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 112 | block.attn.out_proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 113 | block.attn.out_proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 114 | block.ln_2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/scale'])) 115 | block.ln_2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/bias'])) 116 | block.mlp.c_fc.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/kernel'])) 117 | block.mlp.c_fc.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/bias'])) 118 | block.mlp.c_proj.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/kernel'])) 119 | block.mlp.c_proj.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/bias'])) 120 | 121 | def _convert_openclip_txt(module: TextTransformer, prefix): 122 | module.token_embedding.weight.copy_(_n2p(w[f'{prefix}Embed_0/embedding'], t=False)) 123 | pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False).squeeze(0) 124 | module.positional_embedding.copy_(pos_embed_w) 125 | _convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/') 126 | module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale'])) 127 | module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias'])) 128 | module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) 129 | module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias'])) 
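# NOTE: the _n2p helper above adapts big_vision / Flax parameter layouts to PyTorch:
# 4-D conv kernels stored channels-last as (H, W, C_in, C_out) are transposed to
# (C_out, C_in, H, W), 2-D dense kernels stored as (in_features, out_features) become
# (out_features, in_features), and the per-head query/key/value kernels are flattened
# with .flatten(1).T and concatenated so they match the fused qkv / in_proj weight
# layout expected by the PyTorch attention blocks.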
130 | 131 | _convert_timm_img(model.visual.trunk, 'params/img/') 132 | _convert_openclip_txt(model.text, 'params/txt/') 133 | model.logit_bias.copy_(_n2p(w['params/b'])[0]) 134 | model.logit_scale.copy_(_n2p(w['params/t'])[0]) 135 | 136 | 137 | -------------------------------------------------------------------------------- /model/xttn.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn.init 4 | import torch.nn.functional as F 5 | 6 | 7 | def get_padding_mask(features, lengths): 8 | 9 | with torch.no_grad(): 10 | max_len = features.shape[1] 11 | 12 | mask = torch.arange(max_len).expand(features.shape[0], max_len).to(features.device) 13 | 14 | # (B, L) 15 | mask = (mask < lengths.long().unsqueeze(1)) 16 | 17 | return mask 18 | 19 | 20 | def l2norm(X, dim, eps=1e-8): 21 | 22 | norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps 23 | X = torch.div(X, norm) 24 | return X 25 | 26 | 27 | def func_attention(query, context, smooth=4, eps=1e-8, detach=True): 28 | 29 | queryT = torch.transpose(query, 1, 2) 30 | 31 | # (batch, sourceL, d) (batch, d, queryL) --> (batch, sourceL, queryL) 32 | attn = torch.bmm(context, queryT) 33 | 34 | attn = F.leaky_relu(attn, negative_slope=0.1) 35 | 36 | attn = l2norm(attn, 2) 37 | 38 | # --> (batch, queryL, sourceL) 39 | attn = torch.transpose(attn, 1, 2).contiguous() 40 | 41 | # --> (batch, queryL, sourceL) 42 | attn = F.softmax(attn*smooth, dim=2) 43 | 44 | # --> (batch, sourceL, queryL) 45 | attnT = torch.transpose(attn, 1, 2).contiguous() 46 | 47 | # --> (batch, d, sourceL) 48 | contextT = torch.transpose(context, 1, 2) 49 | 50 | # (batch x d x sourceL)(batch x sourceL x queryL) 51 | # --> (batch, d, queryL) 52 | weightedContext = torch.bmm(contextT, attnT) 53 | 54 | # --> (batch, queryL, d) 55 | weightedContext = torch.transpose(weightedContext, 1, 2) 56 | 57 | return weightedContext, attnT 58 | 59 | 60 | # Returns cosine similarity between x1 and x2, computed along dim 61 | def cosine_similarity(x1, x2, dim=1, eps=1e-8): 62 | 63 | x1 = F.normalize(x1, p=2, dim=dim) 64 | x2 = F.normalize(x2, p=2, dim=dim) 65 | 66 | w12 = torch.sum(x1 * x2, dim) 67 | 68 | return w12 69 | 70 | 71 | # SCAN-t2i 72 | def xattn_score_t2i(images, captions, cap_lens, smooth=9.0): 73 | 74 | similarities = [] 75 | n_image = images.size(0) 76 | n_caption = captions.size(0) 77 | 78 | images = F.normalize(images, dim=-1) 79 | captions = F.normalize(captions, dim=-1) 80 | 81 | for i in range(n_caption): 82 | 83 | # # --> (n_image, n_word, d) 84 | n_word = cap_lens[i] 85 | cap_i_expand = captions[i, :n_word, :].unsqueeze(0).repeat(n_image, 1, 1) 86 | 87 | weiContext, _ = func_attention(cap_i_expand, images, smooth=smooth, ) 88 | 89 | # (n_image, n_words) 90 | row_sim = cosine_similarity(cap_i_expand, weiContext, dim=2) 91 | row_sim = row_sim.mean(dim=1, keepdim=True) 92 | 93 | similarities.append(row_sim) 94 | 95 | # (n_image, n_caption) 96 | similarities = torch.cat(similarities, 1) 97 | 98 | return similarities 99 | 100 | 101 | # SCAN-i2t 102 | def xattn_score_i2t(images, captions, cap_lens, smooth=4): 103 | 104 | similarities = [] 105 | n_image = images.size(0) 106 | n_caption = captions.size(0) 107 | 108 | images = F.normalize(images, dim=-1) 109 | captions = F.normalize(captions, dim=-1) 110 | 111 | for i in range(n_caption): 112 | 113 | # # --> (n_image, n_word, d) 114 | n_word = cap_lens[i] 115 | cap_i_expand = captions[i, :n_word, :].unsqueeze(0).repeat(n_image, 1, 1) 116 | 117 | weiContext, _ = 
func_attention(images, cap_i_expand, smooth=smooth, ) 118 | 119 | # (n_image, n_region) 120 | row_sim = cosine_similarity(images, weiContext, dim=2) 121 | row_sim = row_sim.mean(dim=1, keepdim=True) 122 | 123 | similarities.append(row_sim) 124 | 125 | # (n_image, n_caption) 126 | similarities = torch.cat(similarities, 1) 127 | 128 | return similarities 129 | 130 | 131 | # SCAN bi-directional 132 | def xattn_score_two(images, captions, cap_lens, smooth_t2i=9, smooth_i2t=4): 133 | 134 | similarities = [] 135 | n_image = images.size(0) 136 | n_caption = captions.size(0) 137 | 138 | images = F.normalize(images, dim=-1) 139 | captions = F.normalize(captions, dim=-1) 140 | 141 | for i in range(n_caption): 142 | 143 | # # --> (n_image, n_word, d) 144 | n_word = cap_lens[i] 145 | cap_i_expand = captions[i, :n_word, :].unsqueeze(0).repeat(n_image, 1, 1) 146 | 147 | # t2i 148 | weiContext_t2i, _ = func_attention(cap_i_expand, images, smooth=smooth_t2i, ) 149 | row_sim_t2i = cosine_similarity(cap_i_expand, weiContext_t2i, dim=2).mean(dim=1, keepdim=True) 150 | 151 | # i2t 152 | weiContext_i2t, _ = func_attention(images, cap_i_expand, smooth=smooth_i2t, ) 153 | row_sim_i2t = cosine_similarity(images, weiContext_i2t, dim=2).mean(dim=1, keepdim=True) 154 | 155 | sims = (row_sim_t2i + row_sim_i2t) * 0.5 156 | similarities.append(sims) 157 | 158 | # (n_image, n_caption) 159 | similarities = torch.cat(similarities, 1) 160 | 161 | return similarities 162 | 163 | 164 | def matching_max_mean(img_regions, cap_words, cap_len, i2t=False, scan=False, bi_norm=False): 165 | 166 | similarities = [] 167 | 168 | img_regions = F.normalize(img_regions, dim=-1) 169 | cap_words = F.normalize(cap_words, dim=-1) 170 | 171 | if len(img_regions.shape) == 4: 172 | n_image = img_regions.size(1) 173 | img_regions_context = img_regions 174 | else: 175 | n_image = img_regions.size(0) 176 | img_regions_context = None 177 | 178 | n_caption = cap_words.size(0) 179 | 180 | # Each text is operated separately 181 | for i in range(n_caption): 182 | 183 | if img_regions_context: 184 | img_regions = img_regions_context[i] 185 | 186 | n_word = cap_len[i] 187 | # (n_images, cap_len, C) 188 | cap_i_expand = cap_words[i, :n_word, :].unsqueeze(0).repeat(n_image, 1, 1) 189 | 190 | # (n_images, cap_len, img_len) 191 | cap2img_sim = torch.bmm(cap_i_expand, img_regions.transpose(1, 2)) 192 | 193 | if scan: 194 | cap2img_sim = F.leaky_relu(cap2img_sim, negative_slope=0.1) 195 | 196 | cap2img_sim_norm = F.normalize(cap2img_sim, dim=1) if bi_norm else cap2img_sim 197 | 198 | # t2i 199 | # (n_images, cap_len) 200 | row_sim = cap2img_sim_norm.max(dim=2)[0] 201 | # (n_images, 1) 202 | row_sim_mean = row_sim.mean(dim=1, keepdim=True) 203 | 204 | if i2t: 205 | cap2img_sim_norm = F.normalize(cap2img_sim, dim=2) if bi_norm else cap2img_sim 206 | 207 | # (n_images, img_len) 208 | column_sim = cap2img_sim_norm.max(dim=1)[0] 209 | # (n_images, 1) 210 | column_sim_mean = column_sim.mean(dim=1, keepdim=True) 211 | 212 | similarities.append((row_sim_mean + column_sim_mean) * 0.5) 213 | else: 214 | similarities.append(row_sim_mean) 215 | 216 | # (n_image, n_caption) 217 | similarities = torch.cat(similarities, 1) 218 | 219 | return similarities 220 | 221 | 222 | # Only for one text 223 | # The required feature has been L2 regularized 224 | # img_mask (B_v, L_v) 225 | def mask_xattn_one_text(img_embs, cap_i_expand, img_mask=None, i2t=True, scan=True,): 226 | 227 | # (B_v, L_t, L_v) 228 | cap2img_sim = torch.bmm(cap_i_expand, img_embs.transpose(1, 2)) 229 | 230 | if 
scan: 231 | cap2img_sim = F.leaky_relu(cap2img_sim, negative_slope=0.1) 232 | 233 | # t2i 234 | # (B_v, L_t) 235 | if img_mask is None: 236 | row_sim = cap2img_sim.max(dim=2)[0] 237 | else: 238 | # Add a low value to the similarity of the masked patch location 239 | # to prevent it from being selected 240 | row_sim = (cap2img_sim - 1000 * (1 - img_mask).unsqueeze(1)).max(dim=2)[0] 241 | 242 | # (B_v, 1) 243 | row_sim_mean = row_sim.mean(dim=1, keepdim=True) 244 | 245 | if i2t: 246 | # i2t 247 | # (B_v, L_v) 248 | column_sim = cap2img_sim.max(dim=1)[0] 249 | 250 | if img_mask is None: 251 | column_sim_mean = column_sim.mean(dim=1, keepdim=True) 252 | else: 253 | # (B_v, 1) 254 | column_sim_mean = (column_sim * img_mask).sum(dim=-1, keepdim=True) / (img_mask.sum(dim=-1, keepdim=True) + 1e-8) 255 | 256 | sim_one_text = row_sim_mean + column_sim_mean 257 | else: 258 | sim_one_text = row_sim_mean 259 | 260 | return sim_one_text 261 | 262 | 263 | # different alignment functions 264 | def xattn_score(img_cross, cap_cross, cap_len, xattn_type='max_mean', i2t=True, scan=True): 265 | 266 | smooth_t2i = 9 267 | smooth_i2t = 4 268 | 269 | if xattn_type == 'scan_t2i': 270 | sim = xattn_score_t2i(img_cross, cap_cross, cap_len, smooth=smooth_t2i) 271 | elif xattn_type == 'scan_i2t': 272 | sim = xattn_score_i2t(img_cross, cap_cross, cap_len, smooth=smooth_i2t) 273 | elif xattn_type == 'scan_all': 274 | sim = xattn_score_two(img_cross, cap_cross, cap_len, smooth_t2i, smooth_i2t) 275 | else: 276 | sim = matching_max_mean(img_cross, cap_cross, cap_len, i2t=i2t, scan=scan) 277 | 278 | return sim 279 | 280 | 281 | if __name__ == '__main__': 282 | 283 | pass 284 | 285 | 286 | 287 | 288 | -------------------------------------------------------------------------------- /SDXL/fid_score.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser 4 | 5 | import numpy as np 6 | import torch 7 | import torchvision.transforms as TF 8 | from PIL import Image 9 | from scipy import linalg 10 | from torch.nn.functional import adaptive_avg_pool2d 11 | import clip 12 | try: 13 | from tqdm import tqdm 14 | except ImportError: 15 | # If tqdm is not available, provide a mock version of it 16 | def tqdm(x): 17 | return x 18 | 19 | from inception import InceptionV3 20 | 21 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 22 | parser.add_argument('--batch-size', type=int, default=50, 23 | help='Batch size to use') 24 | parser.add_argument('--num-workers', type=int, 25 | help=('Number of processes to use for data loading. ' 26 | 'Defaults to `min(8, num_cpus)`')) 27 | parser.add_argument('--device', type=str, default="cuda", 28 | help='Device to use. Like cuda, cuda:0 or cpu') 29 | parser.add_argument('--dims', type=int, default=2048, 30 | choices=list(InceptionV3.BLOCK_INDEX_BY_DIM), 31 | help=('Dimensionality of Inception features to use. 
' 32 | 'By default, uses pool3 features')) 33 | parser.add_argument('path', type=str, nargs=2, 34 | default=['../data/docci/test_images', 'images/clip'], 35 | help=('Paths to the generated images or ' 36 | 'to .npz statistic files')) 37 | 38 | IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm', 39 | 'tif', 'tiff', 'webp'} 40 | 41 | 42 | class ImagePathDataset(torch.utils.data.Dataset): 43 | def __init__(self, files, transforms=None): 44 | self.files = files 45 | self.transforms = transforms 46 | device = "cuda" if torch.cuda.is_available() else "cpu" 47 | _, self.preprocess = clip.load("ViT-B/32", device=device) 48 | 49 | def __len__(self): 50 | return len(self.files) 51 | 52 | def __getitem__(self, i): 53 | path = self.files[i] 54 | image = self.preprocess(Image.open(path)).unsqueeze(0) 55 | return image 56 | 57 | 58 | def get_activations(files, model, batch_size=50, dims=2048, device='cpu', 59 | num_workers=1): 60 | """Calculates the activations of the pool_3 layer for all images. 61 | Params: 62 | -- files : List of image files paths 63 | -- model : Instance of inception model 64 | -- batch_size : Batch size of images for the model to process at once. 65 | Make sure that the number of samples is a multiple of 66 | the batch size, otherwise some samples are ignored. This 67 | behavior is retained to match the original FID score 68 | implementation. 69 | -- dims : Dimensionality of features returned by Inception 70 | -- device : Device to run calculations 71 | -- num_workers : Number of parallel dataloader workers 72 | Returns: 73 | -- A numpy array of dimension (num images, dims) that contains the 74 | activations of the given tensor when feeding inception with the 75 | query tensor. 76 | """ 77 | model.eval() 78 | 79 | if batch_size > len(files): 80 | print(('Warning: batch size is bigger than the data size. ' 81 | 'Setting batch size to data size')) 82 | batch_size = len(files) 83 | 84 | dataset = ImagePathDataset(files, transforms=TF.ToTensor()) 85 | dataloader = torch.utils.data.DataLoader(dataset, 86 | batch_size=batch_size, 87 | shuffle=False, 88 | drop_last=False, 89 | num_workers=num_workers) 90 | 91 | pred_arr = np.empty((len(files), 512)) 92 | 93 | start_idx = 0 94 | 95 | for batch in tqdm(dataloader): 96 | batch = batch.to(device) 97 | 98 | with torch.no_grad(): 99 | pred = model(batch)[0] 100 | 101 | # If model output is not scalar, apply global spatial average pooling. 102 | # This happens if you choose a dimensionality not equal 2048. 103 | if pred.size(2) != 1 or pred.size(3) != 1: 104 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 105 | 106 | pred = pred.squeeze(3).squeeze(2).cpu().numpy() 107 | 108 | pred_arr[start_idx:start_idx + pred.shape[0]] = pred 109 | 110 | start_idx = start_idx + pred.shape[0] 111 | 112 | return pred_arr 113 | 114 | 115 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 116 | """Numpy implementation of the Frechet Distance. 117 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 118 | and X_2 ~ N(mu_2, C_2) is 119 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 120 | Stable version by Dougal J. Sutherland. 121 | Params: 122 | -- mu1 : Numpy array containing the activations of a layer of the 123 | inception net (like returned by the function 'get_predictions') 124 | for generated samples. 125 | -- mu2 : The sample mean over activations, precalculated on an 126 | representative data set. 127 | -- sigma1: The covariance matrix over activations for generated samples. 
128 | -- sigma2: The covariance matrix over activations, precalculated on an 129 | representative data set. 130 | Returns: 131 | -- : The Frechet Distance. 132 | """ 133 | 134 | mu1 = np.atleast_1d(mu1) 135 | mu2 = np.atleast_1d(mu2) 136 | 137 | sigma1 = np.atleast_2d(sigma1) 138 | sigma2 = np.atleast_2d(sigma2) 139 | 140 | assert mu1.shape == mu2.shape, \ 141 | 'Training and test mean vectors have different lengths' 142 | assert sigma1.shape == sigma2.shape, \ 143 | 'Training and test covariances have different dimensions' 144 | 145 | diff = mu1 - mu2 146 | 147 | # Product might be almost singular 148 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 149 | if not np.isfinite(covmean).all(): 150 | msg = ('fid calculation produces singular product; ' 151 | 'adding %s to diagonal of cov estimates') % eps 152 | print(msg) 153 | offset = np.eye(sigma1.shape[0]) * eps 154 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 155 | 156 | # Numerical error might give slight imaginary component 157 | if np.iscomplexobj(covmean): 158 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 159 | m = np.max(np.abs(covmean.imag)) 160 | raise ValueError('Imaginary component {}'.format(m)) 161 | covmean = covmean.real 162 | 163 | tr_covmean = np.trace(covmean) 164 | 165 | return (diff.dot(diff) + np.trace(sigma1) 166 | + np.trace(sigma2) - 2 * tr_covmean) 167 | 168 | 169 | def calculate_activation_statistics(files, model, batch_size=50, dims=2048, 170 | device='cpu', num_workers=1): 171 | """Calculation of the statistics used by the FID. 172 | Params: 173 | -- files : List of image files paths 174 | -- model : Instance of inception model 175 | -- batch_size : The images numpy array is split into batches with 176 | batch size batch_size. A reasonable batch size 177 | depends on the hardware. 178 | -- dims : Dimensionality of features returned by Inception 179 | -- device : Device to run calculations 180 | -- num_workers : Number of parallel dataloader workers 181 | Returns: 182 | -- mu : The mean over samples of the activations of the pool_3 layer of 183 | the inception model. 184 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 185 | the inception model. 
186 | """ 187 | act = get_activations(files, model, batch_size, dims, device, num_workers) 188 | mu = np.mean(act, axis=0) 189 | sigma = np.cov(act, rowvar=False) 190 | return mu, sigma 191 | 192 | 193 | def compute_statistics_of_path(path, model, batch_size, dims, device, 194 | num_workers=1): 195 | if path.endswith('.npz'): 196 | with np.load(path) as f: 197 | m, s = f['mu'][:], f['sigma'][:] 198 | else: 199 | path = pathlib.Path(path) 200 | files = sorted([file for ext in IMAGE_EXTENSIONS 201 | for file in path.glob('*.{}'.format(ext))]) 202 | 203 | m, s = calculate_activation_statistics(files, model, batch_size, 204 | dims, device, num_workers) 205 | 206 | return m, s 207 | 208 | 209 | def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=1): 210 | """Calculates the FID of two paths""" 211 | for p in paths: 212 | if not os.path.exists(p): 213 | raise RuntimeError('Invalid path: %s' % p) 214 | 215 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 216 | 217 | model = InceptionV3([block_idx]).to(device) 218 | 219 | m1, s1 = compute_statistics_of_path(paths[0], model, batch_size, 220 | dims, device, num_workers) 221 | m2, s2 = compute_statistics_of_path(paths[1], model, batch_size, 222 | dims, device, num_workers) 223 | fid_value = calculate_frechet_distance(m1, s1, m2, s2) 224 | 225 | return fid_value 226 | 227 | 228 | def main(): 229 | args = parser.parse_args() 230 | 231 | if args.device is None: 232 | device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu') 233 | else: 234 | device = torch.device(args.device) 235 | 236 | if args.num_workers is None: 237 | num_avail_cpus = len(os.sched_getaffinity(0)) 238 | num_workers = min(num_avail_cpus, 8) 239 | else: 240 | num_workers = args.num_workers 241 | 242 | fid_value = calculate_fid_given_paths(args.path, 243 | args.batch_size, 244 | device, 245 | args.dims, 246 | num_workers) 247 | print('FID: ', fid_value) 248 | 249 | 250 | if __name__ == '__main__': 251 | main() -------------------------------------------------------------------------------- /train/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | import sys 4 | import os 5 | from torch.utils.data import DataLoader 6 | from torch.nn.functional import normalize 7 | 8 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 9 | ROOT_DIR = os.path.dirname(BASE_DIR) 10 | sys.path.append(ROOT_DIR) 11 | sys.path.append(os.path.join(ROOT_DIR, 'model')) 12 | sys.path.append("..") 13 | 14 | from fine_lip.model import finelip 15 | 16 | # import clip # for CLIP evaluation 17 | # from open_clip_long import factory as open_clip # for bigG model 18 | 19 | from sharegpt4v import share4v_val_dataset 20 | from eval_data import urban_dataset, coco_dataset, flickr_dataset, docci_dataset 21 | 22 | import argparse 23 | import numpy as np 24 | import warnings 25 | warnings.filterwarnings("ignore") 26 | 27 | class CLIP_Clean_Train(): 28 | def __init__(self, local_rank, args): 29 | self.local_rank = local_rank 30 | self.test_data = args.test_data 31 | self.model, self.preprocess = finelip.load(args.ckpt_path, device='cpu',run_finelip= args.run_finelip) 32 | # self.model, self.preprocess = clip.load('ViT-L/14', device='cpu') # for CLIP evaluation 33 | # self.model, _, self.preprocess = open_clip.create_model_and_transforms( 34 | # 'ViT-bigG-14', 35 | # pretrained="experiments/open_clip_pytorch_model.bin", 36 | # text_cfg={'context_length': 77, 'vocab_size': 49408, "width": 1280, "heads": 20, 
"layers": 32}, 37 | # ) 38 | torch.cuda.set_device(device=f'cuda:{local_rank}') 39 | self.model = self.model.float().cuda() 40 | 41 | self.finegrain = args.finegrain 42 | 43 | 44 | @torch.no_grad() 45 | def test_epoch(self): 46 | all_image_features, all_text_features = [], [] 47 | self.all_finegrain_image_features, self.all_finegrain_text_features = [], [] 48 | self.lengths = [] 49 | 50 | for id, (images, text) in enumerate(tqdm(self.testloader)): 51 | images = images.cuda() 52 | image_features_full = self.model.encode_image_full(images) 53 | image_features = image_features_full[:, 0, :] 54 | image_features /= image_features.norm(dim=-1, keepdim=True) 55 | 56 | batch_text_features, batch_text_features_full, batch_lengths = [], [], [] 57 | for cap_list in text: 58 | text_token = finelip.tokenize(cap_list, truncate=True).cuda() 59 | # text_token = clip.tokenize(cap_list, truncate=True).cuda() # for CLIP evaluation 60 | # openclip_tokenizer = open_clip.get_tokenizer('ViT-bigG-14', context_length=77) # for CLIP-bigG evaluation 61 | # text_token = openclip_tokenizer(cap_list).cuda() # for CLIP-bigG evaluation 62 | lengths = [torch.nonzero(text_token[i]).size(0) for i in range(text_token.shape[0])] 63 | text_features_full = self.model.encode_text_full(text_token) @ self.model.text_projection 64 | text_features = text_features_full[torch.arange(text_features_full.shape[0]), text_token.argmax(dim=-1)] 65 | text_features /= text_features.norm(dim=-1, keepdim=True) 66 | batch_text_features.append(text_features.unsqueeze(0)) 67 | batch_text_features_full.append(text_features_full.unsqueeze(0)) 68 | batch_lengths.append(torch.tensor(lengths).unsqueeze(0)) 69 | 70 | batch_text_features = torch.cat(batch_text_features, dim=0) 71 | batch_text_features_full = torch.cat(batch_text_features_full, dim=0) 72 | batch_lengths = torch.cat(batch_lengths, dim=0) 73 | text_features = batch_text_features.permute(1, 0, 2).reshape(-1, batch_text_features.shape[-1]) 74 | 75 | if self.finegrain: 76 | text_features_full = batch_text_features_full.permute(1, 0, 2, 3).reshape(-1, batch_text_features_full.shape[-2], batch_text_features_full.shape[-1]) 77 | lengths = batch_lengths.permute(1, 0).reshape(-1) 78 | self.all_finegrain_image_features.append(image_features_full) 79 | self.all_finegrain_text_features.append(text_features_full.cpu()) 80 | self.lengths.extend(lengths) 81 | 82 | all_image_features.append(image_features) 83 | all_text_features.append(text_features) 84 | 85 | all_image_features = torch.cat(all_image_features, dim=0) 86 | all_text_features = torch.cat(all_text_features, dim=0) 87 | self.logits_per_image = (all_image_features @ all_text_features.t()).detach().cpu() 88 | self.logits_per_text = self.logits_per_image.t() 89 | 90 | if self.test_data in ['share','urban', 'docci']: 91 | self.ground_truth = torch.arange(len(all_text_features)).view(-1, 1) 92 | results = self.get_metrics() 93 | else: 94 | results = self.get_metrics_1v5() 95 | 96 | return results 97 | 98 | def test(self): 99 | self.model.eval() 100 | if self.test_data == 'share': 101 | testset = share4v_val_dataset(self.preprocess) 102 | elif self.test_data == 'urban': 103 | testset = urban_dataset(self.preprocess) 104 | elif self.test_data == 'coco': 105 | testset = coco_dataset(self.preprocess) 106 | elif self.test_data == 'flickr': 107 | testset = flickr_dataset(self.preprocess) 108 | elif self.test_data == 'docci': 109 | testset = docci_dataset(self.preprocess) 110 | 111 | self.testloader = DataLoader(testset, batch_size=200, num_workers=8, 
pin_memory=True) #changed batch size from 1000 to 500 due to OOM 112 | with torch.no_grad(): 113 | metrics = self.test_epoch() 114 | print("=====================================") 115 | print(f"test mean of {self.test_data} retrieval") 116 | for k, v in metrics.items(): 117 | if "@" in k: 118 | print(f"{k} {format(v,'.4f')}") 119 | print("=====================================") 120 | 121 | return 122 | 123 | def save_logits(self, logits, name): 124 | if not os.path.exists(os.path.dirname(name)): 125 | os.makedirs(os.path.dirname(name)) 126 | np.save(name, logits.numpy()) 127 | 128 | def get_metrics(self): 129 | metrics = {} 130 | 131 | # self.save_logits(self.logits_per_image, f"{os.path.dirname(args.ckpt_path)}/{args.test_data}/logits_per_image.npy") 132 | if self.finegrain: 133 | self.all_finegrain_image_features = torch.cat(self.all_finegrain_image_features, dim=0) 134 | self.all_finegrain_text_features = torch.cat(self.all_finegrain_text_features, dim=0) 135 | 136 | finegrain_logits_per_image = torch.empty_like(self.logits_per_image) 137 | finegrain_logits_per_text = torch.empty_like(self.logits_per_text) 138 | chunk_size = 100 139 | for i in tqdm(range(0, self.all_finegrain_text_features.shape[0], chunk_size)): 140 | finegrain_logits_per_image[:, i:i+chunk_size] = self.model.cross_net.forward_dual_aggr(self.all_finegrain_image_features, self.all_finegrain_text_features[i:i+chunk_size].cuda(), self.lengths[i:i+chunk_size]).detach().cpu() 141 | 142 | finegrain_logits_per_text = finegrain_logits_per_image.t() 143 | 144 | alpha = 0.8 #[0.8, 0.2] 145 | # self.save_logits(finegrain_logits_per_image, f"{os.path.dirname(args.ckpt_path)}/{args.test_data}/finegrain_logits_per_image.npy") 146 | self.logits_per_image = alpha * self.logits_per_image + (1 - alpha) * finegrain_logits_per_image 147 | self.logits_per_text = alpha * self.logits_per_text + (1 - alpha) * finegrain_logits_per_text 148 | 149 | logits = {"image_to_text": self.logits_per_image, "text_to_image": self.logits_per_text} 150 | 151 | for name, logit in logits.items(): 152 | ranking = torch.argsort(logit, descending=True) 153 | preds = torch.where(ranking == self.ground_truth)[1] 154 | preds = preds.detach().cpu().numpy() 155 | metrics[f"{name}_mean_rank"] = preds.mean() + 1 156 | metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1 157 | for k in [1, 5, 10]: 158 | metrics[f"{name}_R@{k}"] = np.mean(preds < k) 159 | 160 | return metrics 161 | 162 | def get_metrics_1v5(self): 163 | metrics = {} 164 | 165 | # self.save_logits(self.logits_per_image, f"{os.path.dirname(args.ckpt_path)}/{args.test_data}/logits_per_image.npy") 166 | if self.finegrain: 167 | self.all_finegrain_image_features = torch.cat(self.all_finegrain_image_features, dim=0) 168 | self.all_finegrain_text_features = torch.cat(self.all_finegrain_text_features, dim=0) 169 | 170 | finegrain_logits_per_image = torch.empty_like(self.logits_per_image) 171 | finegrain_logits_per_text = torch.empty_like(self.logits_per_text) 172 | chunk_size = 500 173 | for i in tqdm(range(0, self.all_finegrain_text_features.shape[0], chunk_size)): 174 | finegrain_logits_per_image[:, i:i+chunk_size] = self.model.cross_net.forward_dual_aggr(self.all_finegrain_image_features, self.all_finegrain_text_features[i:i+chunk_size].cuda(), self.lengths[i:i+chunk_size]).detach().cpu() 175 | 176 | finegrain_logits_per_text = finegrain_logits_per_image.t() 177 | 178 | alpha = 0.8 #[0.8, 0.2] 179 | # self.save_logits(finegrain_logits_per_image, 
f"{os.path.dirname(args.ckpt_path)}/{args.test_data}/finegrain_logits_per_image.npy") 180 | self.logits_per_image = alpha * self.logits_per_image + (1 - alpha) * finegrain_logits_per_image 181 | self.logits_per_text = alpha * self.logits_per_text + (1 - alpha) * finegrain_logits_per_text 182 | 183 | for k in [1, 5, 10]: 184 | pred_true = 0 185 | for i in range(self.logits_per_image.shape[0]): 186 | pred = self.logits_per_image[i] 187 | values, topk = pred.topk(k) 188 | for j in range(5): 189 | true_index = 5*i + j 190 | if true_index in topk: 191 | pred_true = pred_true + 1 192 | break 193 | metrics[f"image_to_text_R@{k}"] = pred_true/self.logits_per_image.shape[0] 194 | 195 | for k in [1, 5, 10]: 196 | pred_true = 0 197 | for i in range(self.logits_per_text.shape[0]): 198 | pred = self.logits_per_text[i] 199 | values, topk = pred.topk(k) 200 | true_index = i//5 201 | if true_index in topk: 202 | pred_true = pred_true + 1 203 | metrics[f"text_to_image_R@{k}"] = pred_true/self.logits_per_text.shape[0] 204 | 205 | return metrics 206 | 207 | if __name__ == "__main__": 208 | parser = argparse.ArgumentParser(description='params') 209 | parser.add_argument('--ckpt_path', default='work/checkpoints/finelip-B.pt', help="ckpt_path") 210 | parser.add_argument('--test_data',default='urban', help='docci,urban,coco,and flicker') 211 | parser.add_argument('--run_finelip',action='store_true', help='run finelip model') 212 | parser.add_argument('--finegrain', action='store_true', help='enable finegrain evaluation') 213 | args = parser.parse_args() 214 | 215 | trainer = CLIP_Clean_Train( 216 | local_rank=0, 217 | args=args 218 | ) 219 | trainer.test() 220 | -------------------------------------------------------------------------------- /open_clip_long/push_to_hf_hub.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from pathlib import Path 5 | from tempfile import TemporaryDirectory 6 | from typing import Optional, Tuple, Union 7 | 8 | import torch 9 | 10 | try: 11 | from huggingface_hub import ( 12 | create_repo, 13 | get_hf_file_metadata, 14 | hf_hub_download, 15 | hf_hub_url, 16 | repo_type_and_id_from_hf_id, 17 | upload_folder, 18 | list_repo_files, 19 | ) 20 | from huggingface_hub.utils import EntryNotFoundError 21 | _has_hf_hub = True 22 | except ImportError: 23 | _has_hf_hub = False 24 | 25 | try: 26 | import safetensors.torch 27 | _has_safetensors = True 28 | except ImportError: 29 | _has_safetensors = False 30 | 31 | from .factory import create_model_from_pretrained, get_model_config, get_tokenizer 32 | from .tokenizer import HFTokenizer 33 | 34 | # Default name for a weights file hosted on the Huggingface Hub. 
35 | HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl 36 | HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version 37 | HF_CONFIG_NAME = 'open_clip_config.json' 38 | 39 | 40 | def save_config_for_hf( 41 | model, 42 | config_path: str, 43 | model_config: Optional[dict] 44 | ): 45 | preprocess_cfg = { 46 | 'mean': model.visual.image_mean, 47 | 'std': model.visual.image_std, 48 | } 49 | other_pp = getattr(model.visual, 'preprocess_cfg', {}) 50 | if 'interpolation' in other_pp: 51 | preprocess_cfg['interpolation'] = other_pp['interpolation'] 52 | if 'resize_mode' in other_pp: 53 | preprocess_cfg['resize_mode'] = other_pp['resize_mode'] 54 | hf_config = { 55 | 'model_cfg': model_config, 56 | 'preprocess_cfg': preprocess_cfg, 57 | } 58 | 59 | with config_path.open('w') as f: 60 | json.dump(hf_config, f, indent=2) 61 | 62 | 63 | def save_for_hf( 64 | model, 65 | tokenizer: HFTokenizer, 66 | model_config: dict, 67 | save_directory: str, 68 | safe_serialization: Union[bool, str] = 'both', 69 | skip_weights : bool = False, 70 | ): 71 | config_filename = HF_CONFIG_NAME 72 | 73 | save_directory = Path(save_directory) 74 | save_directory.mkdir(exist_ok=True, parents=True) 75 | 76 | if not skip_weights: 77 | tensors = model.state_dict() 78 | if safe_serialization is True or safe_serialization == "both": 79 | assert _has_safetensors, "`pip install safetensors` to use .safetensors" 80 | safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME) 81 | if safe_serialization is False or safe_serialization == "both": 82 | torch.save(tensors, save_directory / HF_WEIGHTS_NAME) 83 | 84 | tokenizer.save_pretrained(save_directory) 85 | 86 | config_path = save_directory / config_filename 87 | save_config_for_hf(model, config_path, model_config=model_config) 88 | 89 | 90 | def push_to_hf_hub( 91 | model, 92 | tokenizer, 93 | model_config: Optional[dict], 94 | repo_id: str, 95 | commit_message: str = 'Add model', 96 | token: Optional[str] = None, 97 | revision: Optional[str] = None, 98 | private: bool = False, 99 | create_pr: bool = False, 100 | model_card: Optional[dict] = None, 101 | safe_serialization: Union[bool, str] = False, 102 | ): 103 | if not isinstance(tokenizer, HFTokenizer): 104 | # FIXME this makes it awkward to push models with new tokenizers, come up with better soln. 105 | # default CLIP tokenizers use https://huggingface.co/openai/clip-vit-large-patch14 106 | tokenizer = HFTokenizer('openai/clip-vit-large-patch14') 107 | 108 | # Create repo if it doesn't exist yet 109 | repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True) 110 | 111 | # Infer complete repo_id from repo_url 112 | # Can be different from the input `repo_id` if repo_owner was implicit 113 | _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url) 114 | repo_id = f"{repo_owner}/{repo_name}" 115 | 116 | # Check if repo already exists and determine what needs updating 117 | repo_exists = False 118 | repo_files = {} 119 | try: 120 | repo_files = set(list_repo_files(repo_id)) 121 | repo_exists = True 122 | except Exception as e: 123 | print('Repo does not exist', e) 124 | 125 | try: 126 | get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision)) 127 | has_readme = True 128 | except EntryNotFoundError: 129 | has_readme = False 130 | 131 | # Dump model and push to Hub 132 | with TemporaryDirectory() as tmpdir: 133 | # Save model weights and config. 
134 | save_for_hf( 135 | model, 136 | tokenizer=tokenizer, 137 | model_config=model_config, 138 | save_directory=tmpdir, 139 | safe_serialization=safe_serialization, 140 | ) 141 | 142 | # Add readme if it does not exist 143 | if not has_readme: 144 | model_card = model_card or {} 145 | model_name = repo_id.split('/')[-1] 146 | readme_path = Path(tmpdir) / "README.md" 147 | readme_text = generate_readme(model_card, model_name) 148 | readme_path.write_text(readme_text) 149 | 150 | # Upload model and return 151 | return upload_folder( 152 | repo_id=repo_id, 153 | folder_path=tmpdir, 154 | revision=revision, 155 | create_pr=create_pr, 156 | commit_message=commit_message, 157 | ) 158 | 159 | 160 | def push_pretrained_to_hf_hub( 161 | model_name, 162 | pretrained: str, 163 | repo_id: str, 164 | precision: str = 'fp32', 165 | image_mean: Optional[Tuple[float, ...]] = None, 166 | image_std: Optional[Tuple[float, ...]] = None, 167 | image_interpolation: Optional[str] = None, 168 | image_resize_mode: Optional[str] = None, # only effective for inference 169 | commit_message: str = 'Add model', 170 | token: Optional[str] = None, 171 | revision: Optional[str] = None, 172 | private: bool = False, 173 | create_pr: bool = False, 174 | model_card: Optional[dict] = None, 175 | hf_tokenizer_self: bool = False, 176 | ): 177 | model, preprocess_eval = create_model_from_pretrained( 178 | model_name, 179 | pretrained=pretrained, 180 | precision=precision, 181 | image_mean=image_mean, 182 | image_std=image_std, 183 | image_interpolation=image_interpolation, 184 | image_resize_mode=image_resize_mode, 185 | ) 186 | model_config = get_model_config(model_name) 187 | assert model_config 188 | 189 | tokenizer = get_tokenizer(model_name) 190 | if hf_tokenizer_self: 191 | # make hf tokenizer config in the uploaded model point to self instead of original location 192 | model_config['text']['hf_tokenizer_name'] = repo_id 193 | 194 | push_to_hf_hub( 195 | model=model, 196 | tokenizer=tokenizer, 197 | model_config=model_config, 198 | repo_id=repo_id, 199 | commit_message=commit_message, 200 | token=token, 201 | revision=revision, 202 | private=private, 203 | create_pr=create_pr, 204 | model_card=model_card, 205 | safe_serialization='both', 206 | ) 207 | 208 | 209 | def generate_readme(model_card: dict, model_name: str): 210 | tags = model_card.pop('tags', ('clip',)) 211 | pipeline_tag = model_card.pop('pipeline_tag', 'zero-shot-image-classification') 212 | readme_text = "---\n" 213 | if tags: 214 | readme_text += "tags:\n" 215 | for t in tags: 216 | readme_text += f"- {t}\n" 217 | readme_text += "library_name: open_clip\n" 218 | readme_text += f"pipeline_tag: {pipeline_tag}\n" 219 | readme_text += f"license: {model_card.get('license', 'mit')}\n" 220 | if 'details' in model_card and 'Dataset' in model_card['details']: 221 | readme_text += 'datasets:\n' 222 | readme_text += f"- {model_card['details']['Dataset'].lower()}\n" 223 | readme_text += "---\n" 224 | readme_text += f"# Model card for {model_name}\n" 225 | if 'description' in model_card: 226 | readme_text += f"\n{model_card['description']}\n" 227 | if 'details' in model_card: 228 | readme_text += f"\n## Model Details\n" 229 | for k, v in model_card['details'].items(): 230 | if isinstance(v, (list, tuple)): 231 | readme_text += f"- **{k}:**\n" 232 | for vi in v: 233 | readme_text += f" - {vi}\n" 234 | elif isinstance(v, dict): 235 | readme_text += f"- **{k}:**\n" 236 | for ki, vi in v.items(): 237 | readme_text += f" - {ki}: {vi}\n" 238 | else: 239 | readme_text += f"- 
**{k}:** {v}\n" 240 | if 'usage' in model_card: 241 | readme_text += f"\n## Model Usage\n" 242 | readme_text += model_card['usage'] 243 | readme_text += '\n' 244 | 245 | if 'comparison' in model_card: 246 | readme_text += f"\n## Model Comparison\n" 247 | readme_text += model_card['comparison'] 248 | readme_text += '\n' 249 | 250 | if 'citation' in model_card: 251 | readme_text += f"\n## Citation\n" 252 | if not isinstance(model_card['citation'], (list, tuple)): 253 | citations = [model_card['citation']] 254 | else: 255 | citations = model_card['citation'] 256 | for c in citations: 257 | readme_text += f"```bibtex\n{c}\n```\n" 258 | 259 | return readme_text 260 | 261 | 262 | if __name__ == "__main__": 263 | parser = argparse.ArgumentParser(description="Push to Hugging Face Hub") 264 | parser.add_argument( 265 | "--model", type=str, help="Name of the model to use.", 266 | ) 267 | parser.add_argument( 268 | "--pretrained", type=str, 269 | help="Use a pretrained CLIP model weights with the specified tag or file path.", 270 | ) 271 | parser.add_argument( 272 | "--repo-id", type=str, 273 | help="Destination HF Hub repo-id ie 'organization/model_id'.", 274 | ) 275 | parser.add_argument( 276 | "--precision", type=str, default='fp32', 277 | ) 278 | parser.add_argument( 279 | '--image-mean', type=float, nargs='+', default=None, metavar='MEAN', 280 | help='Override default image mean value of dataset') 281 | parser.add_argument( 282 | '--image-std', type=float, nargs='+', default=None, metavar='STD', 283 | help='Override default image std deviation of of dataset') 284 | parser.add_argument( 285 | '--image-interpolation', 286 | default=None, type=str, choices=['bicubic', 'bilinear', 'random'], 287 | help="image resize interpolation" 288 | ) 289 | parser.add_argument( 290 | '--image-resize-mode', 291 | default=None, type=str, choices=['shortest', 'longest', 'squash'], 292 | help="image resize mode during inference" 293 | ) 294 | parser.add_argument( 295 | "--hf-tokenizer-self", 296 | default=False, 297 | action="store_true", 298 | help="make hf_tokenizer_name point in uploaded config point to itself" 299 | ) 300 | args = parser.parse_args() 301 | 302 | print(f'Saving model {args.model} with pretrained weights {args.pretrained} to Hugging Face Hub at {args.repo_id}') 303 | 304 | # FIXME add support to pass model_card json / template from file via cmd line 305 | 306 | push_pretrained_to_hf_hub( 307 | args.model, 308 | args.pretrained, 309 | args.repo_id, 310 | precision=args.precision, 311 | image_mean=args.image_mean, # override image mean/std if trained w/ non defaults 312 | image_std=args.image_std, 313 | image_interpolation=args.image_interpolation, 314 | image_resize_mode=args.image_resize_mode, 315 | ) 316 | 317 | print(f'{args.model} saved.') 318 | -------------------------------------------------------------------------------- /SDXL/inception.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision 5 | import clip 6 | try: 7 | from torchvision.models.utils import load_state_dict_from_url 8 | except ImportError: 9 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 10 | 11 | # Inception weights ported to Pytorch from 12 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 13 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501 14 | 
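# NOTE: as used by SDXL/fid_score.py in this folder, the InceptionV3 wrapper below is
# modified to return CLIP ViT-B/32 image embeddings (512-dim) from forward() instead of
# the usual Inception pool3 activations; the Inception blocks are still constructed but
# bypassed in forward, so the reported score is an FID computed in CLIP feature space.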
15 | 16 | class InceptionV3(nn.Module): 17 | """Pretrained InceptionV3 network returning feature maps""" 18 | 19 | # Index of default block of inception to return, 20 | # corresponds to output of final average pooling 21 | DEFAULT_BLOCK_INDEX = 3 22 | 23 | # Maps feature dimensionality to their output block indices 24 | BLOCK_INDEX_BY_DIM = { 25 | 64: 0, # First max pooling features 26 | 192: 1, # Second max pooling features 27 | 768: 2, # Pre-aux classifier features 28 | 2048: 3 # Final average pooling features 29 | } 30 | 31 | def __init__(self, 32 | output_blocks=(DEFAULT_BLOCK_INDEX,), 33 | resize_input=True, 34 | normalize_input=True, 35 | requires_grad=False, 36 | use_fid_inception=True): 37 | """Build pretrained InceptionV3 38 | Parameters 39 | ---------- 40 | output_blocks : list of int 41 | Indices of blocks to return features of. Possible values are: 42 | - 0: corresponds to output of first max pooling 43 | - 1: corresponds to output of second max pooling 44 | - 2: corresponds to output which is fed to aux classifier 45 | - 3: corresponds to output of final average pooling 46 | resize_input : bool 47 | If true, bilinearly resizes input to width and height 299 before 48 | feeding input to model. As the network without fully connected 49 | layers is fully convolutional, it should be able to handle inputs 50 | of arbitrary size, so resizing might not be strictly needed 51 | normalize_input : bool 52 | If true, scales the input from range (0, 1) to the range the 53 | pretrained Inception network expects, namely (-1, 1) 54 | requires_grad : bool 55 | If true, parameters of the model require gradients. Possibly useful 56 | for finetuning the network 57 | use_fid_inception : bool 58 | If true, uses the pretrained Inception model used in TensorFlow's 59 | FID implementation. If false, uses the pretrained Inception model 60 | available in torchvision. The FID Inception model has different 61 | weights and a slightly different structure from torchvision's 62 | Inception model. If you want to compute FID scores, you are 63 | strongly advised to set this parameter to true to get comparable 64 | results.
65 | """ 66 | super(InceptionV3, self).__init__() 67 | self.clip_model, self.preprocess = clip.load("ViT-B/32", device="cuda") 68 | self.resize_input = resize_input 69 | self.normalize_input = normalize_input 70 | self.output_blocks = sorted(output_blocks) 71 | self.last_needed_block = max(output_blocks) 72 | 73 | assert self.last_needed_block <= 3, \ 74 | 'Last possible output block index is 3' 75 | 76 | self.blocks = nn.ModuleList() 77 | 78 | if use_fid_inception: 79 | inception = fid_inception_v3() 80 | else: 81 | inception = _inception_v3(pretrained=True) 82 | 83 | # Block 0: input to maxpool1 84 | block0 = [ 85 | inception.Conv2d_1a_3x3, 86 | inception.Conv2d_2a_3x3, 87 | inception.Conv2d_2b_3x3, 88 | nn.MaxPool2d(kernel_size=3, stride=2) 89 | ] 90 | self.blocks.append(nn.Sequential(*block0)) 91 | 92 | # Block 1: maxpool1 to maxpool2 93 | if self.last_needed_block >= 1: 94 | block1 = [ 95 | inception.Conv2d_3b_1x1, 96 | inception.Conv2d_4a_3x3, 97 | nn.MaxPool2d(kernel_size=3, stride=2) 98 | ] 99 | self.blocks.append(nn.Sequential(*block1)) 100 | 101 | # Block 2: maxpool2 to aux classifier 102 | if self.last_needed_block >= 2: 103 | block2 = [ 104 | inception.Mixed_5b, 105 | inception.Mixed_5c, 106 | inception.Mixed_5d, 107 | inception.Mixed_6a, 108 | inception.Mixed_6b, 109 | inception.Mixed_6c, 110 | inception.Mixed_6d, 111 | inception.Mixed_6e, 112 | ] 113 | self.blocks.append(nn.Sequential(*block2)) 114 | 115 | # Block 3: aux classifier to final avgpool 116 | if self.last_needed_block >= 3: 117 | block3 = [ 118 | inception.Mixed_7a, 119 | inception.Mixed_7b, 120 | inception.Mixed_7c, 121 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 122 | ] 123 | self.blocks.append(nn.Sequential(*block3)) 124 | 125 | for param in self.parameters(): 126 | param.requires_grad = requires_grad 127 | 128 | def forward(self, inp): 129 | """Get Inception feature maps 130 | Parameters 131 | ---------- 132 | inp : torch.autograd.Variable 133 | Input tensor of shape Bx3xHxW. Values are expected to be in 134 | range (0, 1) 135 | Returns 136 | ------- 137 | List of torch.autograd.Variable, corresponding to the selected output 138 | block, sorted ascending by index 139 | """ 140 | outp = [] 141 | x = inp 142 | # print(x.shape) 143 | image_features = self.clip_model.encode_image(x.squeeze(1)) 144 | # print(image_features.shape) 145 | outp=[image_features.unsqueeze(2).unsqueeze(3)] 146 | # print(outp[0].shape) 147 | # if self.resize_input: 148 | # x = F.interpolate(x, 149 | # size=(299, 299), 150 | # mode='bilinear', 151 | # align_corners=False) 152 | 153 | # if self.normalize_input: 154 | # x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 155 | 156 | # for idx, block in enumerate(self.blocks): 157 | # x = block(x) 158 | # if idx in self.output_blocks: 159 | # outp.append(x) 160 | 161 | # if idx == self.last_needed_block: 162 | # break 163 | return outp 164 | 165 | 166 | def _inception_v3(*args, **kwargs): 167 | """Wraps `torchvision.models.inception_v3` 168 | Skips default weight inititialization if supported by torchvision version. 169 | See https://github.com/mseitzer/pytorch-fid/issues/28. 
170 | """ 171 | try: 172 | version = tuple(map(int, torchvision.__version__.split('.')[:2])) 173 | except ValueError: 174 | # Just a caution against weird version strings 175 | version = (0,) 176 | 177 | if version >= (0, 6): 178 | kwargs['init_weights'] = False 179 | 180 | return torchvision.models.inception_v3(*args, **kwargs) 181 | 182 | 183 | def fid_inception_v3(): 184 | """Build pretrained Inception model for FID computation 185 | The Inception model for FID computation uses a different set of weights 186 | and has a slightly different structure than torchvision's Inception. 187 | This method first constructs torchvision's Inception and then patches the 188 | necessary parts that are different in the FID Inception model. 189 | """ 190 | inception = _inception_v3(num_classes=1008, 191 | aux_logits=False, 192 | pretrained=False) 193 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32) 194 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64) 195 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64) 196 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) 197 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) 198 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) 199 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) 200 | inception.Mixed_7b = FIDInceptionE_1(1280) 201 | inception.Mixed_7c = FIDInceptionE_2(2048) 202 | 203 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) 204 | inception.load_state_dict(state_dict) 205 | return inception 206 | 207 | 208 | class FIDInceptionA(torchvision.models.inception.InceptionA): 209 | """InceptionA block patched for FID computation""" 210 | def __init__(self, in_channels, pool_features): 211 | super(FIDInceptionA, self).__init__(in_channels, pool_features) 212 | 213 | def forward(self, x): 214 | branch1x1 = self.branch1x1(x) 215 | 216 | branch5x5 = self.branch5x5_1(x) 217 | branch5x5 = self.branch5x5_2(branch5x5) 218 | 219 | branch3x3dbl = self.branch3x3dbl_1(x) 220 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 221 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 222 | 223 | # Patch: Tensorflow's average pool does not use the padded zero's in 224 | # its average calculation 225 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 226 | count_include_pad=False) 227 | branch_pool = self.branch_pool(branch_pool) 228 | 229 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 230 | return torch.cat(outputs, 1) 231 | 232 | 233 | class FIDInceptionC(torchvision.models.inception.InceptionC): 234 | """InceptionC block patched for FID computation""" 235 | def __init__(self, in_channels, channels_7x7): 236 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7) 237 | 238 | def forward(self, x): 239 | branch1x1 = self.branch1x1(x) 240 | 241 | branch7x7 = self.branch7x7_1(x) 242 | branch7x7 = self.branch7x7_2(branch7x7) 243 | branch7x7 = self.branch7x7_3(branch7x7) 244 | 245 | branch7x7dbl = self.branch7x7dbl_1(x) 246 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 247 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 248 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 249 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 250 | 251 | # Patch: Tensorflow's average pool does not use the padded zero's in 252 | # its average calculation 253 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 254 | count_include_pad=False) 255 | branch_pool = self.branch_pool(branch_pool) 256 | 257 | outputs = [branch1x1, branch7x7, branch7x7dbl, 
branch_pool] 258 | return torch.cat(outputs, 1) 259 | 260 | 261 | class FIDInceptionE_1(torchvision.models.inception.InceptionE): 262 | """First InceptionE block patched for FID computation""" 263 | def __init__(self, in_channels): 264 | super(FIDInceptionE_1, self).__init__(in_channels) 265 | 266 | def forward(self, x): 267 | branch1x1 = self.branch1x1(x) 268 | 269 | branch3x3 = self.branch3x3_1(x) 270 | branch3x3 = [ 271 | self.branch3x3_2a(branch3x3), 272 | self.branch3x3_2b(branch3x3), 273 | ] 274 | branch3x3 = torch.cat(branch3x3, 1) 275 | 276 | branch3x3dbl = self.branch3x3dbl_1(x) 277 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 278 | branch3x3dbl = [ 279 | self.branch3x3dbl_3a(branch3x3dbl), 280 | self.branch3x3dbl_3b(branch3x3dbl), 281 | ] 282 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 283 | 284 | # Patch: Tensorflow's average pool does not use the padded zero's in 285 | # its average calculation 286 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 287 | count_include_pad=False) 288 | branch_pool = self.branch_pool(branch_pool) 289 | 290 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 291 | return torch.cat(outputs, 1) 292 | 293 | 294 | class FIDInceptionE_2(torchvision.models.inception.InceptionE): 295 | """Second InceptionE block patched for FID computation""" 296 | def __init__(self, in_channels): 297 | super(FIDInceptionE_2, self).__init__(in_channels) 298 | 299 | def forward(self, x): 300 | branch1x1 = self.branch1x1(x) 301 | 302 | branch3x3 = self.branch3x3_1(x) 303 | branch3x3 = [ 304 | self.branch3x3_2a(branch3x3), 305 | self.branch3x3_2b(branch3x3), 306 | ] 307 | branch3x3 = torch.cat(branch3x3, 1) 308 | 309 | branch3x3dbl = self.branch3x3dbl_1(x) 310 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 311 | branch3x3dbl = [ 312 | self.branch3x3dbl_3a(branch3x3dbl), 313 | self.branch3x3dbl_3b(branch3x3dbl), 314 | ] 315 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 316 | 317 | # Patch: The FID Inception model uses max pooling instead of average 318 | # pooling. This is likely an error in this specific Inception 319 | # implementation, as other Inception models use average pooling here 320 | # (which matches the description in the paper). 
321 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 322 | branch_pool = self.branch_pool(branch_pool) 323 | 324 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 325 | return torch.cat(outputs, 1) -------------------------------------------------------------------------------- /train/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from tqdm import tqdm 4 | import sys 5 | import os 6 | 7 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 8 | ROOT_DIR = os.path.dirname(BASE_DIR) 9 | sys.path.append(ROOT_DIR) 10 | sys.path.append(os.path.join(ROOT_DIR, 'model')) 11 | 12 | from model import finelip 13 | from loss import loss_select 14 | sys.path.append("..") 15 | from arguments import get_args 16 | from sharegpt4v import share4v_val_dataset, share4v_train_dataset 17 | 18 | from torch.utils.data.distributed import DistributedSampler 19 | from scheduler import cosine_lr 20 | import subprocess 21 | import torch.optim as optim 22 | from torch.utils.tensorboard import SummaryWriter 23 | import numpy as np 24 | import warnings 25 | import wandb 26 | import random 27 | 28 | warnings.filterwarnings("ignore") 29 | 30 | def START_SEED(seed=71): 31 | np.random.seed(seed) 32 | random.seed(seed) 33 | torch.manual_seed(seed) 34 | torch.cuda.manual_seed(seed) 35 | torch.cuda.manual_seed_all(seed) 36 | torch.backends.cudnn.benchmark = False 37 | torch.backends.cudnn.deterministic = True 38 | torch.use_deterministic_algorithms(True) 39 | os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" 40 | 41 | def push_to_s3(local_path, s3_path): 42 | command = f"aws s3 cp {local_path} {s3_path}" 43 | subprocess.run(command, shell=True) 44 | 45 | def setup_distributed(backend="nccl", port=None): 46 | """Initialize distributed training environment. 
47 | support both slurm and torch.distributed.launch 48 | see torch.distributed.init_process_group() for more details 49 | """ 50 | num_gpus = torch.cuda.device_count() 51 | 52 | if "SLURM_JOB_ID" in os.environ: 53 | rank = int(os.environ["SLURM_PROCID"]) 54 | world_size = int(os.environ["SLURM_NTASKS"]) 55 | node_list = os.environ["SLURM_NODELIST"] 56 | addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1") 57 | # specify master port 58 | if port is not None: 59 | os.environ["MASTER_PORT"] = str(port) 60 | elif "MASTER_PORT" not in os.environ: 61 | os.environ["MASTER_PORT"] = "29522" 62 | if "MASTER_ADDR" not in os.environ: 63 | os.environ["MASTER_ADDR"] = addr 64 | os.environ["WORLD_SIZE"] = str(world_size) 65 | os.environ["LOCAL_RANK"] = str(rank % num_gpus) 66 | os.environ["RANK"] = str(rank) 67 | else: 68 | rank = int(os.environ["RANK"]) 69 | world_size = int(os.environ["WORLD_SIZE"]) 70 | 71 | dist.init_process_group( 72 | backend=backend, 73 | world_size=world_size, 74 | rank=rank, 75 | ) 76 | torch.cuda.set_device(device=f'cuda:{rank % num_gpus}') 77 | return rank % num_gpus 78 | 79 | def get_embed_size(vit_variant: str) -> int: 80 | vit_variant = vit_variant.lower() 81 | 82 | if "bigg" in vit_variant: 83 | return 1280 84 | elif "l" in vit_variant: 85 | return 768 86 | elif "b" in vit_variant: 87 | return 512 88 | else: 89 | raise ValueError(f"Unknown ViT variant: {vit_variant}") 90 | 91 | class CLIP_Clean_Train(): 92 | def __init__(self, args, local_rank=0): 93 | self.args = args 94 | self.local_rank = local_rank 95 | self.exp_name = args.exp_name 96 | self.base_model = args.base_model 97 | self.model, _ = finelip.load_from_clip(self.base_model, device='cpu', run_finelip= not self.args.run_baseline) 98 | args.embed_size = get_embed_size(vit_variant=self.base_model) 99 | if not self.args.run_baseline: self.model.cross_net.__init__(opt=args) 100 | self.model.criterion = loss_select(opt=args, loss_type=args.loss_finegrain) 101 | self.model.train() 102 | self.model.logit_scale = torch.nn.Parameter(torch.ones([]) * args.log_scale) 103 | self.model = self.model.cuda() 104 | 105 | self.batch_size = args.global_batch_size // torch.cuda.device_count() 106 | self.accumulation_steps = 512 // args.global_batch_size 107 | self.num_epoch = args.epochs 108 | self.lr = args.lr 109 | self.weight_decay = args.weight_decay 110 | self.warmup_length = args.warmup_length 111 | self.logdir = f"experiments/{self.exp_name}" 112 | self.ckptdir = self.logdir + "/ckpt" 113 | os.makedirs(self.ckptdir, exist_ok=True) 114 | 115 | if self.local_rank == 0: 116 | hyperparameter_defaults = { 117 | "weight_decay":args.weight_decay, 118 | "warmup_length":args.warmup_length, 119 | "log_scale":args.log_scale, 120 | "batch_size":self.batch_size, 121 | "lr":self.lr, 122 | "num_epoch":self.num_epoch, 123 | } 124 | # Report to wandb 125 | if args.enable_wandb: 126 | wandb.tensorboard.patch(root_logdir=self.logdir) 127 | wandb.init(config=hyperparameter_defaults, project="FineLIP", sync_tensorboard=True, save_code=True, name=self.exp_name) 128 | self.writer = SummaryWriter(self.logdir) 129 | 130 | self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True) 131 | if self.args.run_baseline: 132 | self.optimizer = optim.AdamW(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay) # use this for baseline 133 | else: 134 | self.optimizer = self.create_optimizer() 135 | self.scaler = 
torch.cuda.amp.grad_scaler.GradScaler() 136 | 137 | def create_optimizer(self): 138 | finelip_params = [] 139 | for n, p in self.model.named_parameters(): 140 | if not any(nd in n for nd in ["cross_net", "criterion"]): 141 | finelip_params.append(p) 142 | param_groups = [ 143 | {'params': finelip_params, 'lr': self.lr}, 144 | {'params': self.model.module.cross_net.parameters(), 'lr': self.args.cross_net_lr} 145 | ] 146 | self.optimizer = optim.AdamW(param_groups, weight_decay=self.weight_decay) 147 | return self.optimizer 148 | 149 | def resume_checkpoint(self, checkpoint_path): 150 | state_dict = torch.load(checkpoint_path, map_location='cpu') 151 | remove = checkpoint_path[-21:-13] 152 | checkpoint = torch.load(checkpoint_path.replace('.pt', '_other.pt').replace(remove, ''), map_location='cpu') 153 | 154 | self.model.module.load_state_dict(state_dict) 155 | self.optimizer.load_state_dict(checkpoint['optimizer']) 156 | self.scaler.load_state_dict(checkpoint['scaler']) 157 | return checkpoint['epoch'] 158 | 159 | def save_checkpoint(self, epoch): 160 | if self.base_model == "ViT-B/16": 161 | name = 'finelip-B.pt' 162 | elif self.base_model == "ViT-L/14": 163 | name = 'finelip-L.pt' 164 | else: 165 | name = "finelip-others.pt" 166 | 167 | experiment_name = f'{self.ckptdir}/{self.exp_name}_{self.args.global_batch_size}_epoch_{epoch+1}_{name}' 168 | state_dict = self.model.module.state_dict() 169 | other_state_dict = { 170 | 'epoch': epoch + 1, 171 | 'optimizer': self.optimizer.state_dict(), 172 | 'scaler': self.scaler.state_dict() 173 | } 174 | torch.save(state_dict, experiment_name) 175 | torch.save(other_state_dict, experiment_name.replace('.pt', '_other.pt').replace(f'epoch_{epoch+1}', '')) 176 | if self.args.s3_bucket != None: 177 | push_to_s3(experiment_name, self.args.s3_bucket) 178 | print(f"saved model to {experiment_name}") 179 | 180 | def train_epoch(self, dataloader, epoch, start_iter=0): 181 | running_loss = 0.0 182 | loss_1, loss_3 = 0.0, 0.0 183 | num_batches_per_epoch = len(dataloader) 184 | self.optimizer.zero_grad() 185 | for i, (images, texts, short_text, img_ids) in enumerate(tqdm(dataloader, disable=(self.local_rank != 0))): 186 | step = num_batches_per_epoch * epoch + i 187 | if step < start_iter: 188 | continue 189 | img_ids = img_ids.cuda() 190 | images = images.cuda() 191 | texts = finelip.tokenize(texts, truncate=True).cuda() 192 | warmup_alpha = float(i) / num_batches_per_epoch if epoch == self.args.embedding_warmup_epochs else 1.0 193 | 194 | loss_1, loss_3 = self.model(images, texts, img_ids, 195 | warmup_alpha, self.local_rank) 196 | 197 | if torch.isnan(loss_3) or torch.isinf(loss_3): 198 | print("loss_3 is NaN or Inf") 199 | loss_3 = torch.zeros([], requires_grad=True, device=images.device) 200 | 201 | loss = (loss_1 + loss_3) / self.accumulation_steps # Normalize our loss (if averaged) 202 | loss.backward() 203 | 204 | if (i+1) % self.accumulation_steps == 0: # Wait for several backward steps 205 | self.optimizer.step() # Now we can do an optimizer step 206 | self.optimizer.zero_grad() 207 | self.scheduler(step) 208 | 209 | running_loss += loss.item() 210 | 211 | dist.all_reduce(loss) 212 | loss = loss.item() / torch.distributed.get_world_size() 213 | 214 | if step % 1000 == 0: 215 | if self.local_rank == 0: 216 | print("=====================================") 217 | for i, param_group in enumerate(self.optimizer.param_groups): 218 | print(f"train lr_{i} step {step}: {param_group['lr']}") 219 | self.writer.add_scalar(f"hyper/lr_{i}", param_group['lr'], step) 
220 | print(f"train logit_scale step {step}: {self.model.module.logit_scale.item()}") 221 | self.writer.add_scalar("logit_scale/train", self.model.module.logit_scale.item(), step) 222 | print(f"train loss step {step}: {loss}") 223 | self.writer.add_scalar("Loss/train", loss, step) 224 | print(f"train loss lvl1 step {step}: {loss_1}") 225 | self.writer.add_scalar("Loss 1/train", loss_1, step) 226 | print(f"train loss lvl3 step {step}: {loss_3}") 227 | self.writer.add_scalar("Loss 3/train", loss_3, step) 228 | print("=====================================") 229 | 230 | # with torch.no_grad(): 231 | # self.model.eval() 232 | # self.test() 233 | # self.model.train() 234 | 235 | return running_loss / num_batches_per_epoch 236 | 237 | @torch.no_grad() 238 | def test_epoch(self, dataloader): 239 | for id, (images, text) in enumerate(tqdm(dataloader, disable=(self.local_rank != 0))): 240 | 241 | images = images.cuda() 242 | image_features = self.model.module.encode_image(images) 243 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 244 | 245 | text = finelip.tokenize(text, truncate=True).cuda() 246 | text_feature = self.model.module.encode_text(text) 247 | text_feature /= text_feature.norm(dim=-1, keepdim=True) 248 | 249 | # i = 0 250 | correct = 0 251 | total = 0 252 | 253 | for i in range(text_feature.shape[0]): 254 | text = text_feature[i] 255 | sim = text @ image_features.T 256 | sim = sim.squeeze() 257 | correct_i = torch.argmax(sim) 258 | 259 | if i==correct_i: 260 | correct = correct + 1 261 | total = total + 1 262 | 263 | return correct/total 264 | 265 | def test(self): 266 | if self.local_rank == 0: 267 | self.model.eval() 268 | testset = share4v_val_dataset() 269 | testloader = torch.utils.data.DataLoader(testset, batch_size=1000, num_workers=32, pin_memory=True) 270 | with torch.no_grad(): 271 | 272 | acc = self.test_epoch(testloader) 273 | print("=====================================") 274 | print(f"test mean of share4v retrieval: {acc}") 275 | print("=====================================") 276 | 277 | return 278 | 279 | def train(self, resume=False, warmup_length=200): 280 | trainset = share4v_train_dataset() 281 | train_sampler = DistributedSampler(dataset=trainset, shuffle=True) 282 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=self.batch_size, sampler=train_sampler, num_workers=32, pin_memory=True) 283 | 284 | lrs = [p["lr"] for p in self.optimizer.param_groups] 285 | self.scheduler = cosine_lr(self.optimizer, base_lrs=lrs, warmup_length=warmup_length, steps=(self.num_epoch * len(train_loader))/self.accumulation_steps) 286 | if resume: 287 | start_epoch, resume_iter = self.resume_checkpoint(self.args.resume_path) 288 | else: 289 | start_epoch = 0 290 | resume_iter = 0 291 | 292 | for epoch in range(start_epoch, self.num_epoch): 293 | loss = self.train_epoch(train_loader, epoch, start_iter=resume_iter) 294 | print("=====================================") 295 | print(f"loss: {loss} after training epoch: {epoch+1}") 296 | print("=====================================") 297 | if self.local_rank == 0: 298 | self.save_checkpoint(epoch) 299 | 300 | if __name__ == "__main__": 301 | parser = get_args() 302 | args = parser.parse_args() 303 | START_SEED(args.seed) 304 | 305 | local_rank = setup_distributed() 306 | print("DDP Done") 307 | if local_rank == 0: 308 | print(f"args: {args}") 309 | 310 | trainer = CLIP_Clean_Train( 311 | args=args, 312 | local_rank=local_rank 313 | ) 314 | trainer.train(resume=(args.resume_path != None)) 315 | 
torch.distributed.destroy_process_group() --------------------------------------------------------------------------------
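# ---------------------------------------------------------------------------
# Launch note (sketch): setup_distributed() in train/train.py supports two
# entry points. Under SLURM it derives RANK / WORLD_SIZE / MASTER_ADDR from
# SLURM_PROCID, SLURM_NTASKS and SLURM_NODELIST (MASTER_PORT defaults to
# 29522); otherwise it expects a launcher such as torchrun to export RANK,
# WORLD_SIZE, MASTER_ADDR and MASTER_PORT, with one process per GPU. A
# single-node run could therefore look roughly like the command below; the
# train.py flag names are defined in train/arguments.py and the values shown
# are placeholders, so treat this command as an assumption and see train.sh
# for the repository's actual entry point.
#
#   torchrun --nproc_per_node=8 --master_port=29522 train/train.py \
#       --exp_name finelip_vitb16 --base_model "ViT-B/16" --global_batch_size 512
# ---------------------------------------------------------------------------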