├── README.md ├── figs └── fineclip_arch.png ├── metadata ├── coco_panoptic_clip_hand_craft_EVACLIP_ViTB16.npy └── coco_panoptic_clip_hand_craft_EVACLIP_ViTL14x336.npy ├── open_clip ├── .DS_Store ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── __init__.cpython-39.pyc │ ├── coca_model.cpython-38.pyc │ ├── coca_model.cpython-39.pyc │ ├── constants.cpython-38.pyc │ ├── constants.cpython-39.pyc │ ├── factory.cpython-38.pyc │ ├── factory.cpython-39.pyc │ ├── hf_configs.cpython-38.pyc │ ├── hf_configs.cpython-39.pyc │ ├── hf_model.cpython-38.pyc │ ├── hf_model.cpython-39.pyc │ ├── loss.cpython-38.pyc │ ├── loss.cpython-39.pyc │ ├── model.cpython-38.pyc │ ├── model.cpython-39.pyc │ ├── modified_resnet.cpython-38.pyc │ ├── modified_resnet.cpython-39.pyc │ ├── openai.cpython-38.pyc │ ├── openai.cpython-39.pyc │ ├── pretrained.cpython-38.pyc │ ├── pretrained.cpython-39.pyc │ ├── push_to_hf_hub.cpython-38.pyc │ ├── push_to_hf_hub.cpython-39.pyc │ ├── timm_model.cpython-38.pyc │ ├── timm_model.cpython-39.pyc │ ├── tokenizer.cpython-38.pyc │ ├── tokenizer.cpython-39.pyc │ ├── transform.cpython-38.pyc │ ├── transform.cpython-39.pyc │ ├── transformer.cpython-38.pyc │ ├── transformer.cpython-39.pyc │ ├── utils.cpython-38.pyc │ ├── utils.cpython-39.pyc │ ├── version.cpython-38.pyc │ └── version.cpython-39.pyc ├── bpe_simple_vocab_16e6.txt.gz ├── coca_model.py ├── constants.py ├── customs.py ├── eva_clip │ ├── .DS_Store │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── constants.cpython-38.pyc │ │ ├── constants.cpython-39.pyc │ │ ├── eva_vit_model.cpython-38.pyc │ │ ├── eva_vit_model.cpython-39.pyc │ │ ├── factory.cpython-38.pyc │ │ ├── factory.cpython-39.pyc │ │ ├── hf_configs.cpython-38.pyc │ │ ├── hf_configs.cpython-39.pyc │ │ ├── hf_model.cpython-38.pyc │ │ ├── hf_model.cpython-39.pyc │ │ ├── loss.cpython-38.pyc │ │ ├── loss.cpython-39.pyc │ │ ├── model.cpython-38.pyc │ │ ├── model.cpython-39.pyc │ │ ├── modified_resnet.cpython-38.pyc │ │ ├── modified_resnet.cpython-39.pyc │ │ ├── openai.cpython-38.pyc │ │ ├── openai.cpython-39.pyc │ │ ├── pretrained.cpython-38.pyc │ │ ├── pretrained.cpython-39.pyc │ │ ├── rope.cpython-38.pyc │ │ ├── rope.cpython-39.pyc │ │ ├── timm_model.cpython-38.pyc │ │ ├── timm_model.cpython-39.pyc │ │ ├── tokenizer.cpython-38.pyc │ │ ├── tokenizer.cpython-39.pyc │ │ ├── transform.cpython-38.pyc │ │ ├── transform.cpython-39.pyc │ │ ├── transformer.cpython-38.pyc │ │ ├── transformer.cpython-39.pyc │ │ ├── utils.cpython-38.pyc │ │ └── utils.cpython-39.pyc │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── constants.py │ ├── eva_vit_model.py │ ├── factory.py │ ├── hf_configs.py │ ├── hf_model.py │ ├── loss.py │ ├── model.py │ ├── model_configs │ │ ├── EVA01-CLIP-B-16.json │ │ ├── EVA01-CLIP-g-14-plus.json │ │ ├── EVA01-CLIP-g-14.json │ │ ├── EVA02-CLIP-B-16.json │ │ ├── EVA02-CLIP-L-14-336.json │ │ ├── EVA02-CLIP-L-14.json │ │ ├── EVA02-CLIP-bigE-14-plus.json │ │ └── EVA02-CLIP-bigE-14.json │ ├── modified_resnet.py │ ├── openai.py │ ├── pretrained.py │ ├── rope.py │ ├── timm_model.py │ ├── tokenizer.py │ ├── transform.py │ ├── transformer.py │ └── utils.py ├── factory.py ├── generation_utils.py ├── get_text_embedding_voc.py ├── hf_configs.py ├── hf_model.py ├── loss.py ├── model.py ├── model_configs │ ├── RN101-quickgelu.json │ ├── RN101.json │ ├── RN50-quickgelu.json │ ├── RN50.json │ ├── RN50x16.json │ ├── RN50x4.json │ ├── RN50x64.json │ ├── ViT-B-16-plus-240.json │ ├── ViT-B-16-plus.json │ ├── ViT-B-16.json │ ├── 
ViT-B-32-plus-256.json │ ├── ViT-B-32-quickgelu.json │ ├── ViT-B-32.json │ ├── ViT-H-14.json │ ├── ViT-H-16.json │ ├── ViT-L-14-280.json │ ├── ViT-L-14-336.json │ ├── ViT-L-14.json │ ├── ViT-L-16-320.json │ ├── ViT-L-16.json │ ├── ViT-M-16-alt.json │ ├── ViT-M-16.json │ ├── ViT-M-32-alt.json │ ├── ViT-M-32.json │ ├── ViT-S-16-alt.json │ ├── ViT-S-16.json │ ├── ViT-S-32-alt.json │ ├── ViT-S-32.json │ ├── ViT-bigG-14.json │ ├── ViT-e-14.json │ ├── ViT-g-14.json │ ├── coca_ViT-B-32.json │ ├── coca_ViT-L-14.json │ ├── coca_base.json │ ├── coca_roberta-ViT-B-32.json │ ├── convnext_base.json │ ├── convnext_base_w.json │ ├── convnext_base_w_320.json │ ├── convnext_large.json │ ├── convnext_large_d.json │ ├── convnext_large_d_320.json │ ├── convnext_small.json │ ├── convnext_tiny.json │ ├── convnext_xlarge.json │ ├── convnext_xxlarge.json │ ├── convnext_xxlarge_320.json │ ├── mt5-base-ViT-B-32.json │ ├── mt5-xl-ViT-H-14.json │ ├── roberta-ViT-B-32.json │ ├── swin_base_patch4_window7_224.json │ ├── vit_medium_patch16_gap_256.json │ ├── vit_relpos_medium_patch16_cls_224.json │ ├── xlm-roberta-base-ViT-B-32.json │ └── xlm-roberta-large-ViT-H-14.json ├── modified_resnet.py ├── openai.py ├── pretrained.py ├── push_to_hf_hub.py ├── timm_model.py ├── tokenizer.py ├── transform.py ├── transformer.py ├── utils.py └── version.py ├── requirements.txt ├── scripts ├── test_vitb16_box.sh ├── test_vitb16_flickr.sh ├── test_vitb16_mscoco.sh ├── test_vitl14_box.sh ├── test_vitl14_flickr.sh ├── test_vitl14_mscoco.sh ├── train_vitb16.sh └── train_vitl14.sh ├── tools └── generate_text_embeddings.py └── training ├── .gitignore ├── __init__.py ├── __pycache__ ├── __init__.cpython-38.pyc ├── __init__.cpython-39.pyc ├── clipself.cpython-38.pyc ├── clipself.cpython-39.pyc ├── coco_api.cpython-38.pyc ├── coco_api.cpython-39.pyc ├── custom_transforms.cpython-38.pyc ├── custom_transforms.cpython-39.pyc ├── data.cpython-38.pyc ├── data.cpython-39.pyc ├── dist_utils.cpython-38.pyc ├── dist_utils.cpython-39.pyc ├── distributed.cpython-38.pyc ├── distributed.cpython-39.pyc ├── file_utils.cpython-38.pyc ├── file_utils.cpython-39.pyc ├── logger.cpython-38.pyc ├── logger.cpython-39.pyc ├── main.cpython-38.pyc ├── main.cpython-39.pyc ├── params.cpython-38.pyc ├── params.cpython-39.pyc ├── precision.cpython-38.pyc ├── precision.cpython-39.pyc ├── region_clip.cpython-38.pyc ├── region_clip.cpython-39.pyc ├── scheduler.cpython-38.pyc ├── scheduler.cpython-39.pyc ├── test_flickr.cpython-38.pyc ├── test_mscoco.cpython-38.pyc ├── train.cpython-38.pyc ├── train.cpython-39.pyc ├── utils.cpython-38.pyc ├── utils.cpython-39.pyc ├── zero_shot.cpython-38.pyc └── zero_shot.cpython-39.pyc ├── clipself.py ├── coco_api.py ├── custom_transforms.py ├── data.py ├── dist_utils.py ├── distributed.py ├── file_utils.py ├── logger.py ├── main.py ├── params.py ├── precision.py ├── profile.py ├── region_clip.py ├── scheduler.py ├── test_flickr.py ├── test_mscoco.py ├── train.py ├── utils.py └── zero_shot.py /README.md: -------------------------------------------------------------------------------- 1 | # FineCLIP 2 | 3 | ## Introduction 4 | 5 | Official Release of FineCLIP: Self-distilled Region-based CLIP for Better Fine-grained Understanding **(NIPS2024)** 6 | 7 | > [**FineCLIP: Self-distilled Region-based CLIP for Better Fine-grained Understanding**](https://openreview.net/pdf?id=nExI4FuKWD), 8 | > Dong Jing*, Xiaolong He*, Yutian Luo, Nanyi Fei, Guoxing Yang, Wei Wei, Huiwen Zhao, Zhiwu Lu† 9 | 10 | ![](figs/fineclip_arch.png) 11 | 12 | ## TODO 13 | - 
[x] Code for training and evaluating FineCLIP on COCO dataset
14 | - [ ] Data generation pipeline of CC2.5M
15 | 
16 | ## Environment
17 | 
18 | ```bash
19 | conda create -n fineclip python=3.8
20 | conda activate fineclip
21 | pip install -r requirements.txt
22 | ```
23 | 
24 | ## Data Preparation
25 | 
26 | The main ablation experiments are conducted on images from the [COCO](https://cocodataset.org/#home)
27 | and [Flickr30K](https://shannon.cs.illinois.edu/DenotationGraph/) datasets. Please download and organize the datasets as follows:
28 | 
29 | ```text
30 | FineCLIP/
31 | ├── data
32 |     ├── coco
33 |         ├── annotations
34 |             ├── captions_train2017.json [Image Captions](https://drive.google.com/file/d/1gV7MZxQCkRbmC__FcVr0rGuNIdpacnwJ/view?usp=drive_link)
35 |             ├── instances_train2017.json
36 |             ├── panoptic_val2017.json
37 |         ├── panoptic_val2017
38 |         ├── train2017
39 |         ├── val2017
40 |         ├── coco_captions.json [Region Captions](https://drive.google.com/file/d/1hmXjZ1i8LMTkKg6JPnU5qfT1pIMulFqA/view?usp=drive_link)
41 |         ├── coco_proposals.json
42 |         ├── coco_test.json [Test Data](https://drive.google.com/file/d/1Qdqe8fINH79A53nb7D81KYski7iLsWyV/view?usp=drive_link)
43 |     ├── flickr30k
44 |         ├── flickr30k_images
45 |         ├── flickr30k_test.json [Test Data](https://drive.google.com/file/d/1zluOdlZ9JJL0XufMIw6i6ls0iW2EXYPf/view?usp=drive_link)
46 | ```
47 | 
48 | ## Original Models
49 | To run FineCLIP, first obtain the original models from
50 | [EVA-02-CLIP](https://github.com/baaivision/EVA/tree/master/EVA-CLIP), and put them under
51 | `checkpoints/` as follows:
52 | 
53 | ```text
54 | FineCLIP/
55 | ├── checkpoints
56 |     ├── EVA02_CLIP_B_psz16_s8B.pt
57 |     ├── EVA02_CLIP_L_336_psz14_s6B.pt
58 | ```
59 | 
60 | 
61 | ## Start Training
62 | 
63 | After preparing the data and models, you can start training FineCLIP by running the scripts under [scripts/](scripts).
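Before launching a training script, it can help to sanity-check that the layout above is actually in place. Below is a minimal sketch (not part of the repository); the paths are assumptions taken from the directory layouts shown above and should be adjusted to your setup. Run it from the `FineCLIP/` repository root.

```python
# Hypothetical pre-flight check (not part of FineCLIP): verify that the data and
# checkpoint layout described above exists before launching the training scripts.
from pathlib import Path

expected = [
    "data/coco/annotations/captions_train2017.json",
    "data/coco/annotations/instances_train2017.json",
    "data/coco/annotations/panoptic_val2017.json",
    "data/coco/train2017",
    "data/coco/val2017",
    "data/flickr30k/flickr30k_images",
    "checkpoints/EVA02_CLIP_B_psz16_s8B.pt",
]

missing = [p for p in expected if not Path(p).exists()]
if missing:
    print("missing paths:")
    for p in missing:
        print("  -", p)
else:
    print("all expected paths found")
```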
64 | 
65 | To train FineCLIP with a ViT-B/16 backbone on COCO, run:
66 | ```bash
67 | bash scripts/train_vitb16.sh
68 | ```
69 | 
70 | To train FineCLIP with a ViT-L/14 backbone on COCO, run:
71 | ```bash
72 | bash scripts/train_vitl14.sh
73 | ```
74 | 
75 | 
76 | ## Evaluation
77 | 
78 | We provide scripts to evaluate FineCLIP under [scripts/](scripts); they are summarized as follows:
79 | 
80 | | # | Model Type | Task                 | Benchmark     | Script                                   |
81 | |:---:|:----------:|:-----------------------:|:-------------:|:----------------------------------------:|
82 | | 1 | ViT-B/16 | Box Classification   | COCO Panoptic | [script](scripts/test_vitb16_box.sh)     |
83 | | 2 | ViT-B/16 | Image-text Retrieval | MSCOCO        | [script](scripts/test_vitb16_mscoco.sh)  |
84 | | 3 | ViT-B/16 | Image-text Retrieval | Flickr30K     | [script](scripts/test_vitb16_flickr.sh)  |
85 | | 4 | ViT-L/14 | Box Classification   | COCO Panoptic | [script](scripts/test_vitl14_box.sh)     |
86 | | 5 | ViT-L/14 | Image-text Retrieval | MSCOCO        | [script](scripts/test_vitl14_mscoco.sh)  |
87 | | 6 | ViT-L/14 | Image-text Retrieval | Flickr30K     | [script](scripts/test_vitl14_flickr.sh)  |
88 | 
89 | 
90 | ## FineCLIP Checkpoints
91 | 
92 | We provide FineCLIP checkpoints trained on COCO :) A minimal loading sketch is given after the acknowledgements below.
93 | 
94 | [FineCLIP_coco_vitb16.pt](https://drive.google.com/file/d/119eeWzjsE2rpUFBs2rDMZ7QlK1RxZYL2/view?usp=drive_link)
95 | 
96 | [FineCLIP_coco_vitl14.pt](https://drive.google.com/file/d/1lSIj5tWVVNVFNzuPHisOAdDQyCXWIzUH/view?usp=drive_link)
97 | 
98 | ## Downstream Evaluations
99 | 
100 | To evaluate FineCLIP on downstream tasks such as open-vocabulary object detection and segmentation, please refer to [CLIPSelf](https://github.com/wusize/CLIPSelf/tree/main) and [CAT-Seg](https://github.com/cvlab-kaist/CAT-Seg).
101 | We thank the authors for their valuable codebases.
102 | 
103 | ## Citation
104 | 
105 | ```bibtex
106 | @article{fineclip,
107 |   title={FineCLIP: Self-distilled Region-based CLIP for Better Fine-grained Understanding},
108 |   author={Dong Jing and Xiaolong He and Yutian Luo and Nanyi Fei and Guoxing Yang and Wei Wei and Huiwen Zhao and Zhiwu Lu},
109 |   journal={Advances in Neural Information Processing Systems},
110 |   year={2024}
111 | }
112 | ```
113 | 
114 | ## Acknowledgement
115 | 
116 | We sincerely thank the following excellent works:
117 | [CLIPSelf](https://github.com/wusize/CLIPSelf/tree/main),
118 | [RegionCLIP](https://github.com/microsoft/RegionCLIP),
119 | [OpenCLIP](https://github.com/mlfoundations/open_clip/tree/v2.16.0),
120 | [EVA-CLIP](https://github.com/baaivision/EVA/tree/master/EVA-CLIP).
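For reference, here is a minimal, untested sketch of loading one of the released checkpoints above for plain image-text matching. It assumes the bundled EVA-CLIP factory keeps its standard `create_model_and_transforms(model_name, pretrained=...)` and `get_tokenizer(model_name)` interface (as exported in `open_clip/eva_clip/__init__.py`), and that the ViT-B/16 checkpoint has been saved to `checkpoints/FineCLIP_coco_vitb16.pt`; the image path and prompts are placeholders.

```python
# Minimal inference sketch (assumptions: standard EVA-CLIP factory interface,
# checkpoint downloaded to checkpoints/FineCLIP_coco_vitb16.pt, any local image).
import torch
from PIL import Image
from open_clip.eva_clip import create_model_and_transforms, get_tokenizer

model_name = "EVA02-CLIP-B-16"  # config under open_clip/eva_clip/model_configs/
model, _, preprocess = create_model_and_transforms(
    model_name, pretrained="checkpoints/FineCLIP_coco_vitb16.pt"
)
tokenizer = get_tokenizer(model_name)
model.eval()

image = preprocess(Image.open("data/coco/val2017/000000000139.jpg")).unsqueeze(0)
texts = tokenizer(["a photo of a dining table", "a photo of a traffic light"])

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(texts)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

print(similarity)  # higher value = closer image-text match
```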
121 | 122 | ## License 123 | 124 | Creative Commons License 125 | -------------------------------------------------------------------------------- /figs/fineclip_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/figs/fineclip_arch.png -------------------------------------------------------------------------------- /metadata/coco_panoptic_clip_hand_craft_EVACLIP_ViTB16.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/metadata/coco_panoptic_clip_hand_craft_EVACLIP_ViTB16.npy -------------------------------------------------------------------------------- /metadata/coco_panoptic_clip_hand_craft_EVACLIP_ViTL14x336.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/metadata/coco_panoptic_clip_hand_craft_EVACLIP_ViTL14x336.npy -------------------------------------------------------------------------------- /open_clip/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/.DS_Store -------------------------------------------------------------------------------- /open_clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .coca_model import CoCa 2 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 3 | from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss 4 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint 5 | from .loss import ClipLoss, DistillClipLoss, CoCaLoss 6 | from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \ 7 | convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype 8 | from .openai import load_openai_model, list_openai_models 9 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \ 10 | get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained 11 | from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub 12 | from .tokenizer import SimpleTokenizer, tokenize, decode 13 | from .transform import image_transform, AugmentationCfg 14 | -------------------------------------------------------------------------------- /open_clip/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/coca_model.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/coca_model.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/coca_model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/coca_model.cpython-39.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/constants.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/constants.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/constants.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/constants.cpython-39.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/factory.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/factory.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/factory.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/factory.cpython-39.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/hf_configs.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/hf_configs.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/hf_configs.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/hf_configs.cpython-39.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/hf_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/hf_model.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/hf_model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/hf_model.cpython-39.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/loss.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/loss.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/loss.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/loss.cpython-39.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/model.cpython-39.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/modified_resnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/modified_resnet.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/modified_resnet.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/modified_resnet.cpython-39.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/openai.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/openai.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/openai.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/openai.cpython-39.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/pretrained.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/pretrained.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/pretrained.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/pretrained.cpython-39.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/push_to_hf_hub.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/push_to_hf_hub.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/push_to_hf_hub.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/push_to_hf_hub.cpython-39.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/timm_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/timm_model.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/timm_model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/timm_model.cpython-39.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/tokenizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/tokenizer.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/tokenizer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/tokenizer.cpython-39.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/transform.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/transform.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/transform.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/transform.cpython-39.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/transformer.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/transformer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/transformer.cpython-39.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/version.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/version.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip/__pycache__/version.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/version.cpython-39.pyc -------------------------------------------------------------------------------- /open_clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /open_clip/constants.py: -------------------------------------------------------------------------------- 1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) 2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) 3 | -------------------------------------------------------------------------------- /open_clip/customs.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | from torch.nn import MultiheadAttention 3 | from torch.nn import functional as F 4 | from typing import Optional, Tuple 5 | 6 | 7 | class MultiheadSelfAttention(MultiheadAttention): 8 | def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None, 9 | need_weights: bool = True, attn_mask: Optional[Tensor] = None, return_tokens: bool = False) \ 10 | -> Tuple[Tensor, Tensor]: 11 | assert query is value and value is key # self-attention 12 | if return_tokens: 13 | # in_projection 14 | tokens = F.linear(value, self.in_proj_weight, bias=self.in_proj_bias)[..., -self.embed_dim:] 15 | # out_projection 16 | tokens = F.linear(tokens, self.out_proj.weight, bias=self.out_proj.bias) 17 | else: 18 | tokens = None 19 | 20 | attn_output, attn_output_weights = F.multi_head_attention_forward( 21 | query=query, key=key, value=value, 22 | embed_dim_to_check=self.embed_dim, 23 | num_heads=self.num_heads, 24 | in_proj_weight=self.in_proj_weight, 25 | in_proj_bias=self.in_proj_bias, 26 | bias_k=None, bias_v=None, 27 | add_zero_attn=False, 28 | dropout_p=0., 29 | out_proj_weight=self.out_proj.weight, 30 | out_proj_bias=self.out_proj.bias, 31 | training=self.training, 32 | key_padding_mask=key_padding_mask, need_weights=need_weights, 33 | attn_mask=attn_mask) 34 | 35 | return attn_output, tokens # , attn_output_weights 36 | -------------------------------------------------------------------------------- /open_clip/eva_clip/.DS_Store: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/.DS_Store -------------------------------------------------------------------------------- /open_clip/eva_clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 2 | from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer 3 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint 4 | from .loss import ClipLoss 5 | from .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg,\ 6 | convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype 7 | from .openai import load_openai_model, list_openai_models 8 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\ 9 | get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained 10 | from .tokenizer import SimpleTokenizer, tokenize 11 | from .transform import image_transform -------------------------------------------------------------------------------- /open_clip/eva_clip/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip/eva_clip/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /open_clip/eva_clip/__pycache__/constants.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/constants.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip/eva_clip/__pycache__/constants.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/constants.cpython-39.pyc -------------------------------------------------------------------------------- /open_clip/eva_clip/__pycache__/eva_vit_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/eva_vit_model.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip/eva_clip/__pycache__/eva_vit_model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/eva_vit_model.cpython-39.pyc -------------------------------------------------------------------------------- 
/open_clip/eva_clip/__pycache__/factory.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/factory.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip/eva_clip/__pycache__/factory.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/factory.cpython-39.pyc -------------------------------------------------------------------------------- /open_clip/eva_clip/__pycache__/hf_configs.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/hf_configs.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip/eva_clip/__pycache__/hf_configs.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/hf_configs.cpython-39.pyc -------------------------------------------------------------------------------- /open_clip/eva_clip/__pycache__/hf_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/hf_model.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip/eva_clip/__pycache__/hf_model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/hf_model.cpython-39.pyc -------------------------------------------------------------------------------- /open_clip/eva_clip/__pycache__/loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/loss.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip/eva_clip/__pycache__/loss.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/loss.cpython-39.pyc -------------------------------------------------------------------------------- /open_clip/eva_clip/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip/eva_clip/__pycache__/model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/model.cpython-39.pyc 
-------------------------------------------------------------------------------- /open_clip/eva_clip/__pycache__/modified_resnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/modified_resnet.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip/eva_clip/__pycache__/modified_resnet.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/modified_resnet.cpython-39.pyc -------------------------------------------------------------------------------- /open_clip/eva_clip/__pycache__/openai.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/openai.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip/eva_clip/__pycache__/openai.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/openai.cpython-39.pyc -------------------------------------------------------------------------------- /open_clip/eva_clip/__pycache__/pretrained.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/pretrained.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip/eva_clip/__pycache__/pretrained.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/pretrained.cpython-39.pyc -------------------------------------------------------------------------------- /open_clip/eva_clip/__pycache__/rope.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/rope.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip/eva_clip/__pycache__/rope.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/rope.cpython-39.pyc -------------------------------------------------------------------------------- /open_clip/eva_clip/__pycache__/timm_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/timm_model.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip/eva_clip/__pycache__/timm_model.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/timm_model.cpython-39.pyc -------------------------------------------------------------------------------- /open_clip/eva_clip/__pycache__/tokenizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/tokenizer.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip/eva_clip/__pycache__/tokenizer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/tokenizer.cpython-39.pyc -------------------------------------------------------------------------------- /open_clip/eva_clip/__pycache__/transform.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/transform.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip/eva_clip/__pycache__/transform.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/transform.cpython-39.pyc -------------------------------------------------------------------------------- /open_clip/eva_clip/__pycache__/transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/transformer.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip/eva_clip/__pycache__/transformer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/transformer.cpython-39.pyc -------------------------------------------------------------------------------- /open_clip/eva_clip/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip/eva_clip/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /open_clip/eva_clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /open_clip/eva_clip/constants.py: 
-------------------------------------------------------------------------------- 1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) 2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) 3 | -------------------------------------------------------------------------------- /open_clip/eva_clip/hf_configs.py: -------------------------------------------------------------------------------- 1 | # HF architecture dict: 2 | arch_dict = { 3 | # https://huggingface.co/docs/transformers/model_doc/roberta#roberta 4 | "roberta": { 5 | "config_names": { 6 | "context_length": "max_position_embeddings", 7 | "vocab_size": "vocab_size", 8 | "width": "hidden_size", 9 | "heads": "num_attention_heads", 10 | "layers": "num_hidden_layers", 11 | "layer_attr": "layer", 12 | "token_embeddings_attr": "embeddings" 13 | }, 14 | "pooler": "mean_pooler", 15 | }, 16 | # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig 17 | "xlm-roberta": { 18 | "config_names": { 19 | "context_length": "max_position_embeddings", 20 | "vocab_size": "vocab_size", 21 | "width": "hidden_size", 22 | "heads": "num_attention_heads", 23 | "layers": "num_hidden_layers", 24 | "layer_attr": "layer", 25 | "token_embeddings_attr": "embeddings" 26 | }, 27 | "pooler": "mean_pooler", 28 | }, 29 | # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 30 | "mt5": { 31 | "config_names": { 32 | # unlimited seqlen 33 | # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 34 | # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 35 | "context_length": "", 36 | "vocab_size": "vocab_size", 37 | "width": "d_model", 38 | "heads": "num_heads", 39 | "layers": "num_layers", 40 | "layer_attr": "block", 41 | "token_embeddings_attr": "embed_tokens" 42 | }, 43 | "pooler": "mean_pooler", 44 | }, 45 | "bert": { 46 | "config_names": { 47 | "context_length": "max_position_embeddings", 48 | "vocab_size": "vocab_size", 49 | "width": "hidden_size", 50 | "heads": "num_attention_heads", 51 | "layers": "num_hidden_layers", 52 | "layer_attr": "layer", 53 | "token_embeddings_attr": "embeddings" 54 | }, 55 | "pooler": "mean_pooler", 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /open_clip/eva_clip/loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | 6 | try: 7 | import torch.distributed.nn 8 | from torch import distributed as dist 9 | has_distributed = True 10 | except ImportError: 11 | has_distributed = False 12 | 13 | try: 14 | import horovod.torch as hvd 15 | except ImportError: 16 | hvd = None 17 | 18 | from timm.loss import LabelSmoothingCrossEntropy 19 | 20 | 21 | def gather_features( 22 | image_features, 23 | text_features, 24 | local_loss=False, 25 | gather_with_grad=False, 26 | rank=0, 27 | world_size=1, 28 | use_horovod=False, 29 | skip=True 30 | ): 31 | assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.' 
32 | if use_horovod: 33 | assert hvd is not None, 'Please install horovod' 34 | if gather_with_grad: 35 | all_image_features = hvd.allgather(image_features) 36 | all_text_features = hvd.allgather(text_features) 37 | else: 38 | with torch.no_grad(): 39 | all_image_features = hvd.allgather(image_features) 40 | all_text_features = hvd.allgather(text_features) 41 | if not local_loss: 42 | # ensure grads for local rank when all_* features don't have a gradient 43 | gathered_image_features = list(all_image_features.chunk(world_size, dim=0)) 44 | gathered_text_features = list(all_text_features.chunk(world_size, dim=0)) 45 | gathered_image_features[rank] = image_features 46 | gathered_text_features[rank] = text_features 47 | all_image_features = torch.cat(gathered_image_features, dim=0) 48 | all_text_features = torch.cat(gathered_text_features, dim=0) 49 | else: 50 | # We gather tensors from all gpus 51 | if gather_with_grad: 52 | all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0) 53 | all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0) 54 | # all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features, async_op=True), dim=0) 55 | # all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features, async_op=True), dim=0) 56 | else: 57 | gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)] 58 | gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)] 59 | dist.all_gather(gathered_image_features, image_features) 60 | dist.all_gather(gathered_text_features, text_features) 61 | if not local_loss: 62 | # ensure grads for local rank when all_* features don't have a gradient 63 | gathered_image_features[rank] = image_features 64 | gathered_text_features[rank] = text_features 65 | all_image_features = torch.cat(gathered_image_features, dim=0) 66 | all_text_features = torch.cat(gathered_text_features, dim=0) 67 | if skip == False: 68 | return all_image_features, all_text_features 69 | all_image_feature_list = [] 70 | all_text_feature_list = [] 71 | for image_feature in all_image_features: 72 | if (image_feature==torch.zeros(image_feature.shape[0], device=image_feature.device)).all(): 73 | pass 74 | else: 75 | all_image_feature_list.append(image_feature.unsqueeze(0)) 76 | 77 | for text_feature in all_text_features: 78 | if (text_feature==torch.zeros(text_feature.shape[0], device=text_feature.device)).all(): 79 | pass 80 | else: 81 | all_text_feature_list.append(text_feature.unsqueeze(0)) 82 | all_image_features_new = torch.cat(all_image_feature_list, dim=0) 83 | all_text_features_new = torch.cat(all_text_feature_list, dim=0) 84 | return all_image_features_new, all_text_features_new 85 | # return all_image_features, all_text_features 86 | 87 | 88 | class ClipLoss(nn.Module): 89 | 90 | def __init__( 91 | self, 92 | local_loss=False, 93 | gather_with_grad=False, 94 | cache_labels=False, 95 | rank=0, 96 | world_size=1, 97 | use_horovod=False, 98 | smoothing=0., 99 | ): 100 | super().__init__() 101 | self.local_loss = local_loss 102 | self.gather_with_grad = gather_with_grad 103 | self.cache_labels = cache_labels 104 | self.rank = rank 105 | self.world_size = world_size 106 | self.use_horovod = use_horovod 107 | self.label_smoothing_cross_entropy = LabelSmoothingCrossEntropy(smoothing=smoothing) if smoothing > 0 else None 108 | 109 | # cache state 110 | self.prev_num_logits = 0 111 | self.labels = {} 112 | 113 | def forward(self, image_features, 
text_features, logit_scale=1., skip=True): 114 | device = image_features.device 115 | if self.world_size > 1: 116 | all_image_features, all_text_features = gather_features( 117 | image_features, text_features, 118 | self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod, skip) 119 | 120 | if self.local_loss: 121 | logits_per_image = logit_scale * image_features @ all_text_features.T 122 | logits_per_text = logit_scale * text_features @ all_image_features.T 123 | else: 124 | logits_per_image = logit_scale * all_image_features @ all_text_features.T 125 | logits_per_text = logits_per_image.T 126 | else: 127 | logits_per_image = logit_scale * image_features @ text_features.T 128 | logits_per_text = logit_scale * text_features @ image_features.T 129 | # calculated ground-truth and cache if enabled 130 | num_logits = logits_per_image.shape[0] 131 | if self.prev_num_logits != num_logits or device not in self.labels: 132 | labels = torch.arange(num_logits, device=device, dtype=torch.long) 133 | if self.world_size > 1 and self.local_loss: 134 | labels = labels + num_logits * self.rank 135 | if self.cache_labels: 136 | self.labels[device] = labels 137 | self.prev_num_logits = num_logits 138 | else: 139 | labels = self.labels[device] 140 | 141 | if self.label_smoothing_cross_entropy: 142 | total_loss = ( 143 | self.label_smoothing_cross_entropy(logits_per_image, labels) + 144 | self.label_smoothing_cross_entropy(logits_per_text, labels) 145 | ) / 2 146 | else: 147 | total_loss = ( 148 | F.cross_entropy(logits_per_image, labels) + 149 | F.cross_entropy(logits_per_text, labels) 150 | ) / 2 151 | 152 | acc = None 153 | i2t_acc = (logits_per_image.argmax(-1) == labels).sum() / len(logits_per_image) 154 | t2i_acc = (logits_per_text.argmax(-1) == labels).sum() / len(logits_per_text) 155 | acc = {"i2t": i2t_acc, "t2i": t2i_acc} 156 | return total_loss, acc -------------------------------------------------------------------------------- /open_clip/eva_clip/model_configs/EVA01-CLIP-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16, 8 | "eva_model_name": "eva-clip-b-16", 9 | "ls_init_value": 0.1, 10 | "drop_path_rate": 0.0 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-g-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "fusedLN": true 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 1024, 19 | "heads": 16, 20 | "layers": 24, 21 | "xattn": false, 22 | "fusedLN": true 23 | } 24 | } -------------------------------------------------------------------------------- /open_clip/eva_clip/model_configs/EVA01-CLIP-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | 
"mlp_ratio": 4.3637, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-g-14-x", 11 | "drop_path_rate": 0.4, 12 | "xattn": true, 13 | "fusedLN": true 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 768, 19 | "heads": 12, 20 | "layers": 12, 21 | "xattn": false, 22 | "fusedLN": true 23 | } 24 | } -------------------------------------------------------------------------------- /open_clip/eva_clip/model_configs/EVA02-CLIP-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "head_width": 64, 8 | "patch_size": 16, 9 | "mlp_ratio": 2.6667, 10 | "eva_model_name": "eva-clip-b-16-X", 11 | "drop_path_rate": 0.0, 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 512, 24 | "heads": 8, 25 | "layers": 12, 26 | "xattn": true, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /open_clip/eva_clip/model_configs/EVA02-CLIP-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "drop_path_rate": 0, 8 | "head_width": 64, 9 | "mlp_ratio": 2.6667, 10 | "patch_size": 14, 11 | "eva_model_name": "eva-clip-l-14-336", 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 768, 24 | "heads": 12, 25 | "layers": 12, 26 | "xattn": false, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /open_clip/eva_clip/model_configs/EVA02-CLIP-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "drop_path_rate": 0, 8 | "head_width": 64, 9 | "mlp_ratio": 2.6667, 10 | "patch_size": 14, 11 | "eva_model_name": "eva-clip-l-14", 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 768, 24 | "heads": 12, 25 | "layers": 12, 26 | "xattn": false, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /open_clip/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | 
-------------------------------------------------------------------------------- /open_clip/eva_clip/model_configs/EVA02-CLIP-bigE-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1024, 20 | "heads": 16, 21 | "layers": 24, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } -------------------------------------------------------------------------------- /open_clip/eva_clip/modified_resnet.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from open_clip.eva_clip.utils import freeze_batch_norm_2d 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.act1 = nn.ReLU(inplace=True) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.act2 = nn.ReLU(inplace=True) 24 | 25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 26 | 27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | self.act3 = nn.ReLU(inplace=True) 30 | 31 | self.downsample = None 32 | self.stride = stride 33 | 34 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 36 | self.downsample = nn.Sequential(OrderedDict([ 37 | ("-1", nn.AvgPool2d(stride)), 38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 39 | ("1", nn.BatchNorm2d(planes * self.expansion)) 40 | ])) 41 | 42 | def forward(self, x: torch.Tensor): 43 | identity = x 44 | 45 | out = self.act1(self.bn1(self.conv1(x))) 46 | out = self.act2(self.bn2(self.conv2(out))) 47 | out = self.avgpool(out) 48 | out = self.bn3(self.conv3(out)) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | out = self.act3(out) 55 | return out 56 | 57 | 58 | class AttentionPool2d(nn.Module): 59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 60 | super().__init__() 61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 62 | self.k_proj = nn.Linear(embed_dim, embed_dim) 63 | self.q_proj = nn.Linear(embed_dim, embed_dim) 64 | self.v_proj = nn.Linear(embed_dim, embed_dim) 65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 66 | self.num_heads = num_heads 67 | 68 | def forward(self, x): 69 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 71 | x = x + self.positional_embedding[:, None, 
:].to(x.dtype) # (HW+1)NC 72 | x, _ = F.multi_head_attention_forward( 73 | query=x, key=x, value=x, 74 | embed_dim_to_check=x.shape[-1], 75 | num_heads=self.num_heads, 76 | q_proj_weight=self.q_proj.weight, 77 | k_proj_weight=self.k_proj.weight, 78 | v_proj_weight=self.v_proj.weight, 79 | in_proj_weight=None, 80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 81 | bias_k=None, 82 | bias_v=None, 83 | add_zero_attn=False, 84 | dropout_p=0., 85 | out_proj_weight=self.c_proj.weight, 86 | out_proj_bias=self.c_proj.bias, 87 | use_separate_proj_weight=True, 88 | training=self.training, 89 | need_weights=False 90 | ) 91 | 92 | return x[0] 93 | 94 | 95 | class ModifiedResNet(nn.Module): 96 | """ 97 | A ResNet class that is similar to torchvision's but contains the following changes: 98 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 99 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 100 | - The final pooling layer is a QKV attention instead of an average pool 101 | """ 102 | 103 | def __init__(self, layers, output_dim, heads, image_size=224, width=64): 104 | super().__init__() 105 | self.output_dim = output_dim 106 | self.image_size = image_size 107 | 108 | # the 3-layer stem 109 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 110 | self.bn1 = nn.BatchNorm2d(width // 2) 111 | self.act1 = nn.ReLU(inplace=True) 112 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 113 | self.bn2 = nn.BatchNorm2d(width // 2) 114 | self.act2 = nn.ReLU(inplace=True) 115 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 116 | self.bn3 = nn.BatchNorm2d(width) 117 | self.act3 = nn.ReLU(inplace=True) 118 | self.avgpool = nn.AvgPool2d(2) 119 | 120 | # residual layers 121 | self._inplanes = width # this is a *mutable* variable used during construction 122 | self.layer1 = self._make_layer(width, layers[0]) 123 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 124 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 125 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 126 | 127 | embed_dim = width * 32 # the ResNet feature dimension 128 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) 129 | 130 | self.init_parameters() 131 | 132 | def _make_layer(self, planes, blocks, stride=1): 133 | layers = [Bottleneck(self._inplanes, planes, stride)] 134 | 135 | self._inplanes = planes * Bottleneck.expansion 136 | for _ in range(1, blocks): 137 | layers.append(Bottleneck(self._inplanes, planes)) 138 | 139 | return nn.Sequential(*layers) 140 | 141 | def init_parameters(self): 142 | if self.attnpool is not None: 143 | std = self.attnpool.c_proj.in_features ** -0.5 144 | nn.init.normal_(self.attnpool.q_proj.weight, std=std) 145 | nn.init.normal_(self.attnpool.k_proj.weight, std=std) 146 | nn.init.normal_(self.attnpool.v_proj.weight, std=std) 147 | nn.init.normal_(self.attnpool.c_proj.weight, std=std) 148 | 149 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: 150 | for name, param in resnet_block.named_parameters(): 151 | if name.endswith("bn3.weight"): 152 | nn.init.zeros_(param) 153 | 154 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 155 | assert unlocked_groups == 0, 'partial locking not currently supported for this model' 156 | for param in self.parameters(): 157 | 
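            # Locks the whole ResNet tower: every parameter below gets requires_grad turned off,
            # and freeze_batch_norm_2d() (imported from eva_clip.utils) can additionally replace
            # BatchNorm2d/SyncBatchNorm modules with FrozenBatchNorm2d so their running statistics
            # stop updating while the tower is frozen.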
param.requires_grad = False 158 | if freeze_bn_stats: 159 | freeze_batch_norm_2d(self) 160 | 161 | @torch.jit.ignore 162 | def set_grad_checkpointing(self, enable=True): 163 | # FIXME support for non-transformer 164 | pass 165 | 166 | def stem(self, x): 167 | x = self.act1(self.bn1(self.conv1(x))) 168 | x = self.act2(self.bn2(self.conv2(x))) 169 | x = self.act3(self.bn3(self.conv3(x))) 170 | x = self.avgpool(x) 171 | return x 172 | 173 | def forward(self, x): 174 | x = self.stem(x) 175 | x = self.layer1(x) 176 | x = self.layer2(x) 177 | x = self.layer3(x) 178 | x = self.layer4(x) 179 | x = self.attnpool(x) 180 | 181 | return x 182 | -------------------------------------------------------------------------------- /open_clip/eva_clip/openai.py: -------------------------------------------------------------------------------- 1 | """ OpenAI pretrained model functions 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | 6 | import os 7 | import warnings 8 | from typing import List, Optional, Union 9 | 10 | import torch 11 | 12 | from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype 13 | from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url 14 | 15 | __all__ = ["list_openai_models", "load_openai_model"] 16 | 17 | 18 | def list_openai_models() -> List[str]: 19 | """Returns the names of available CLIP models""" 20 | return list_pretrained_models_by_tag('openai') 21 | 22 | 23 | def load_openai_model( 24 | name: str, 25 | precision: Optional[str] = None, 26 | device: Optional[Union[str, torch.device]] = None, 27 | jit: bool = True, 28 | cache_dir: Optional[str] = None, 29 | ): 30 | """Load a CLIP model 31 | 32 | Parameters 33 | ---------- 34 | name : str 35 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 36 | precision: str 37 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. 38 | device : Union[str, torch.device] 39 | The device to put the loaded model 40 | jit : bool 41 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 42 | cache_dir : Optional[str] 43 | The directory to cache the downloaded model weights 44 | 45 | Returns 46 | ------- 47 | model : torch.nn.Module 48 | The CLIP model 49 | preprocess : Callable[[PIL.Image], torch.Tensor] 50 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 51 | """ 52 | if device is None: 53 | device = "cuda" if torch.cuda.is_available() else "cpu" 54 | if precision is None: 55 | precision = 'fp32' if device == 'cpu' else 'fp16' 56 | 57 | if get_pretrained_url(name, 'openai'): 58 | model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) 59 | elif os.path.isfile(name): 60 | model_path = name 61 | else: 62 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") 63 | 64 | try: 65 | # loading JIT archive 66 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 67 | state_dict = None 68 | except RuntimeError: 69 | # loading saved state dict 70 | if jit: 71 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 72 | jit = False 73 | state_dict = torch.load(model_path, map_location="cpu") 74 | 75 | if not jit: 76 | # Build a non-jit model from the OpenAI jitted model state dict 77 | cast_dtype = get_cast_dtype(precision) 78 | try: 79 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) 80 | except KeyError: 81 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 82 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) 83 | 84 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use 85 | model = model.to(device) 86 | if precision.startswith('amp') or precision == 'fp32': 87 | model.float() 88 | elif precision == 'bf16': 89 | convert_weights_to_lp(model, dtype=torch.bfloat16) 90 | 91 | return model 92 | 93 | # patch the device names 94 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 95 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 96 | 97 | def patch_device(module): 98 | try: 99 | graphs = [module.graph] if hasattr(module, "graph") else [] 100 | except RuntimeError: 101 | graphs = [] 102 | 103 | if hasattr(module, "forward1"): 104 | graphs.append(module.forward1.graph) 105 | 106 | for graph in graphs: 107 | for node in graph.findAllNodes("prim::Constant"): 108 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 109 | node.copyAttributes(device_node) 110 | 111 | model.apply(patch_device) 112 | patch_device(model.encode_image) 113 | patch_device(model.encode_text) 114 | 115 | # patch dtype to float32 (typically for CPU) 116 | if precision == 'fp32': 117 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 118 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 119 | float_node = float_input.node() 120 | 121 | def patch_float(module): 122 | try: 123 | graphs = [module.graph] if hasattr(module, "graph") else [] 124 | except RuntimeError: 125 | graphs = [] 126 | 127 | if hasattr(module, "forward1"): 128 | graphs.append(module.forward1.graph) 129 | 130 | for graph in graphs: 131 | for node in graph.findAllNodes("aten::to"): 132 | inputs = list(node.inputs()) 133 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 134 | if inputs[i].node()["value"] == 5: 135 | inputs[i].node().copyAttributes(float_node) 136 | 137 | model.apply(patch_float) 138 | patch_float(model.encode_image) 139 | patch_float(model.encode_text) 140 | model.float() 141 | 142 | # ensure image_size attr available at consistent location for both jit and non-jit 143 | model.visual.image_size = model.input_resolution.item() 144 | return model 145 | -------------------------------------------------------------------------------- /open_clip/eva_clip/timm_model.py: -------------------------------------------------------------------------------- 1 | """ timm model adapter 2 | 3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 
4 | """ 5 | import logging 6 | from collections import OrderedDict 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | try: 12 | import timm 13 | from timm.models.layers import Mlp, to_2tuple 14 | try: 15 | # old timm imports < 0.8.1 16 | from timm.models.layers.attention_pool2d import RotAttentionPool2d 17 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d 18 | except ImportError: 19 | # new timm imports >= 0.8.1 20 | from timm.layers import RotAttentionPool2d 21 | from timm.layers import AttentionPool2d as AbsAttentionPool2d 22 | except ImportError: 23 | timm = None 24 | 25 | from .utils import freeze_batch_norm_2d 26 | 27 | 28 | class TimmModel(nn.Module): 29 | """ timm model adapter 30 | # FIXME this adapter is a work in progress, may change in ways that break weight compat 31 | """ 32 | 33 | def __init__( 34 | self, 35 | model_name, 36 | embed_dim, 37 | image_size=224, 38 | pool='avg', 39 | proj='linear', 40 | proj_bias=False, 41 | drop=0., 42 | pretrained=False): 43 | super().__init__() 44 | if timm is None: 45 | raise RuntimeError("Please `pip install timm` to use timm models.") 46 | 47 | self.image_size = to_2tuple(image_size) 48 | self.trunk = timm.create_model(model_name, pretrained=pretrained) 49 | feat_size = self.trunk.default_cfg.get('pool_size', None) 50 | feature_ndim = 1 if not feat_size else 2 51 | if pool in ('abs_attn', 'rot_attn'): 52 | assert feature_ndim == 2 53 | # if attn pooling used, remove both classifier and default pool 54 | self.trunk.reset_classifier(0, global_pool='') 55 | else: 56 | # reset global pool if pool config set, otherwise leave as network default 57 | reset_kwargs = dict(global_pool=pool) if pool else {} 58 | self.trunk.reset_classifier(0, **reset_kwargs) 59 | prev_chs = self.trunk.num_features 60 | 61 | head_layers = OrderedDict() 62 | if pool == 'abs_attn': 63 | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) 64 | prev_chs = embed_dim 65 | elif pool == 'rot_attn': 66 | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) 67 | prev_chs = embed_dim 68 | else: 69 | assert proj, 'projection layer needed if non-attention pooling is used.' 
70 | 71 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used 72 | if proj == 'linear': 73 | head_layers['drop'] = nn.Dropout(drop) 74 | head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) 75 | elif proj == 'mlp': 76 | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias)) 77 | 78 | self.head = nn.Sequential(head_layers) 79 | 80 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 81 | """ lock modules 82 | Args: 83 | unlocked_groups (int): leave last n layer groups unlocked (default: 0) 84 | """ 85 | if not unlocked_groups: 86 | # lock full model 87 | for param in self.trunk.parameters(): 88 | param.requires_grad = False 89 | if freeze_bn_stats: 90 | freeze_batch_norm_2d(self.trunk) 91 | else: 92 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change 93 | try: 94 | # FIXME import here until API stable and in an official release 95 | from timm.models.helpers import group_parameters, group_modules 96 | except ImportError: 97 | raise RuntimeError( 98 | 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') 99 | matcher = self.trunk.group_matcher() 100 | gparams = group_parameters(self.trunk, matcher) 101 | max_layer_id = max(gparams.keys()) 102 | max_layer_id = max_layer_id - unlocked_groups 103 | for group_idx in range(max_layer_id + 1): 104 | group = gparams[group_idx] 105 | for param in group: 106 | self.trunk.get_parameter(param).requires_grad = False 107 | if freeze_bn_stats: 108 | gmodules = group_modules(self.trunk, matcher, reverse=True) 109 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} 110 | freeze_batch_norm_2d(self.trunk, gmodules) 111 | 112 | @torch.jit.ignore 113 | def set_grad_checkpointing(self, enable=True): 114 | try: 115 | self.trunk.set_grad_checkpointing(enable) 116 | except Exception as e: 117 | logging.warning('grad checkpointing not supported for this timm image tower, continuing without...') 118 | 119 | def forward(self, x): 120 | x = self.trunk(x) 121 | x = self.head(x) 122 | return x 123 | -------------------------------------------------------------------------------- /open_clip/eva_clip/tokenizer.py: -------------------------------------------------------------------------------- 1 | """ CLIP tokenizer 2 | 3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | import gzip 6 | import html 7 | import os 8 | from functools import lru_cache 9 | from typing import Union, List 10 | 11 | import ftfy 12 | import regex as re 13 | import torch 14 | 15 | # https://stackoverflow.com/q/62691279 16 | import os 17 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 18 | 19 | 20 | @lru_cache() 21 | def default_bpe(): 22 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 23 | 24 | 25 | @lru_cache() 26 | def bytes_to_unicode(): 27 | """ 28 | Returns list of utf-8 byte and a corresponding list of unicode strings. 29 | The reversible bpe codes work on unicode strings. 30 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 31 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 32 | This is a signficant percentage of your normal, say, 32K bpe vocab. 33 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 
34 | And avoids mapping to whitespace/control characters the bpe code barfs on. 35 | """ 36 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 37 | cs = bs[:] 38 | n = 0 39 | for b in range(2**8): 40 | if b not in bs: 41 | bs.append(b) 42 | cs.append(2**8+n) 43 | n += 1 44 | cs = [chr(n) for n in cs] 45 | return dict(zip(bs, cs)) 46 | 47 | 48 | def get_pairs(word): 49 | """Return set of symbol pairs in a word. 50 | Word is represented as tuple of symbols (symbols being variable-length strings). 51 | """ 52 | pairs = set() 53 | prev_char = word[0] 54 | for char in word[1:]: 55 | pairs.add((prev_char, char)) 56 | prev_char = char 57 | return pairs 58 | 59 | 60 | def basic_clean(text): 61 | text = ftfy.fix_text(text) 62 | text = html.unescape(html.unescape(text)) 63 | return text.strip() 64 | 65 | 66 | def whitespace_clean(text): 67 | text = re.sub(r'\s+', ' ', text) 68 | text = text.strip() 69 | return text 70 | 71 | 72 | class SimpleTokenizer(object): 73 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): 74 | self.byte_encoder = bytes_to_unicode() 75 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 76 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 77 | merges = merges[1:49152-256-2+1] 78 | merges = [tuple(merge.split()) for merge in merges] 79 | vocab = list(bytes_to_unicode().values()) 80 | vocab = vocab + [v+'' for v in vocab] 81 | for merge in merges: 82 | vocab.append(''.join(merge)) 83 | if not special_tokens: 84 | special_tokens = ['', ''] 85 | else: 86 | special_tokens = ['', ''] + special_tokens 87 | vocab.extend(special_tokens) 88 | self.encoder = dict(zip(vocab, range(len(vocab)))) 89 | self.decoder = {v: k for k, v in self.encoder.items()} 90 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 91 | self.cache = {t:t for t in special_tokens} 92 | special = "|".join(special_tokens) 93 | self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 94 | 95 | self.vocab_size = len(self.encoder) 96 | self.all_special_ids = [self.encoder[t] for t in special_tokens] 97 | 98 | def bpe(self, token): 99 | if token in self.cache: 100 | return self.cache[token] 101 | word = tuple(token[:-1]) + ( token[-1] + '',) 102 | pairs = get_pairs(word) 103 | 104 | if not pairs: 105 | return token+'' 106 | 107 | while True: 108 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 109 | if bigram not in self.bpe_ranks: 110 | break 111 | first, second = bigram 112 | new_word = [] 113 | i = 0 114 | while i < len(word): 115 | try: 116 | j = word.index(first, i) 117 | new_word.extend(word[i:j]) 118 | i = j 119 | except: 120 | new_word.extend(word[i:]) 121 | break 122 | 123 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 124 | new_word.append(first+second) 125 | i += 2 126 | else: 127 | new_word.append(word[i]) 128 | i += 1 129 | new_word = tuple(new_word) 130 | word = new_word 131 | if len(word) == 1: 132 | break 133 | else: 134 | pairs = get_pairs(word) 135 | word = ' '.join(word) 136 | self.cache[token] = word 137 | return word 138 | 139 | def encode(self, text): 140 | bpe_tokens = [] 141 | text = whitespace_clean(basic_clean(text)).lower() 142 | for token in re.findall(self.pat, text): 143 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 144 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 145 | return bpe_tokens 146 | 147 | 
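    # bpe()/encode()/decode() follow the original CLIP byte-level BPE: each input word is first
    # byte-encoded, the last symbol carries the end-of-word marker '</w>', and merges are applied
    # greedily in bpe_ranks order. The default special tokens registered in __init__ are
    # '<start_of_text>' and '<end_of_text>'; decode() maps ids back through the byte decoder and
    # turns the '</w>' marker back into a space.
    # Rough round-trip sketch (assuming the bundled bpe_simple_vocab_16e6.txt.gz):
    #   ids = SimpleTokenizer().encode("a photo of a cat")
    #   SimpleTokenizer().decode(ids)  ->  "a photo of a cat "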
def decode(self, tokens): 148 | text = ''.join([self.decoder[token] for token in tokens]) 149 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 150 | return text 151 | 152 | 153 | _tokenizer = SimpleTokenizer() 154 | 155 | 156 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: 157 | """ 158 | Returns the tokenized representation of given input string(s) 159 | 160 | Parameters 161 | ---------- 162 | texts : Union[str, List[str]] 163 | An input string or a list of input strings to tokenize 164 | context_length : int 165 | The context length to use; all CLIP models use 77 as the context length 166 | 167 | Returns 168 | ------- 169 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 170 | """ 171 | if isinstance(texts, str): 172 | texts = [texts] 173 | 174 | sot_token = _tokenizer.encoder[""] 175 | eot_token = _tokenizer.encoder[""] 176 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 177 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 178 | 179 | for i, tokens in enumerate(all_tokens): 180 | if len(tokens) > context_length: 181 | tokens = tokens[:context_length] # Truncate 182 | tokens[-1] = eot_token 183 | result[i, :len(tokens)] = torch.tensor(tokens) 184 | 185 | return result 186 | 187 | 188 | class HFTokenizer: 189 | "HuggingFace tokenizer wrapper" 190 | def __init__(self, tokenizer_name:str): 191 | from transformers import AutoTokenizer 192 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 193 | 194 | def __call__(self, texts:Union[str, List[str]], context_length:int=77) -> torch.Tensor: 195 | # same cleaning as for default tokenizer, except lowercasing 196 | # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance 197 | if isinstance(texts, str): 198 | texts = [texts] 199 | texts = [whitespace_clean(basic_clean(text)) for text in texts] 200 | input_ids = self.tokenizer(texts, return_tensors='pt', max_length=context_length, padding='max_length', truncation=True).input_ids 201 | return input_ids 202 | -------------------------------------------------------------------------------- /open_clip/eva_clip/transform.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torchvision.transforms.functional as F 6 | 7 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ 8 | CenterCrop 9 | 10 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 11 | 12 | 13 | class ResizeMaxSize(nn.Module): 14 | 15 | def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0): 16 | super().__init__() 17 | if not isinstance(max_size, int): 18 | raise TypeError(f"Size should be int. 
Got {type(max_size)}") 19 | self.max_size = max_size 20 | self.interpolation = interpolation 21 | self.fn = min if fn == 'min' else min 22 | self.fill = fill 23 | 24 | def forward(self, img): 25 | if isinstance(img, torch.Tensor): 26 | height, width = img.shape[:2] 27 | else: 28 | width, height = img.size 29 | scale = self.max_size / float(max(height, width)) 30 | if scale != 1.0: 31 | new_size = tuple(round(dim * scale) for dim in (height, width)) 32 | img = F.resize(img, new_size, self.interpolation) 33 | pad_h = self.max_size - new_size[0] 34 | pad_w = self.max_size - new_size[1] 35 | img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill) 36 | return img 37 | 38 | 39 | def _convert_to_rgb(image): 40 | return image.convert('RGB') 41 | 42 | 43 | # class CatGen(nn.Module): 44 | # def __init__(self, num=4): 45 | # self.num = num 46 | # def mixgen_batch(image, text): 47 | # batch_size = image.shape[0] 48 | # index = np.random.permutation(batch_size) 49 | 50 | # cat_images = [] 51 | # for i in range(batch_size): 52 | # # image mixup 53 | # image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:] 54 | # # text concat 55 | # text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0] 56 | # text = torch.stack(text) 57 | # return image, text 58 | 59 | 60 | def image_transform( 61 | image_size: int, 62 | is_train: bool, 63 | mean: Optional[Tuple[float, ...]] = None, 64 | std: Optional[Tuple[float, ...]] = None, 65 | resize_longest_max: bool = False, 66 | fill_color: int = 0, 67 | ): 68 | mean = mean or OPENAI_DATASET_MEAN 69 | if not isinstance(mean, (list, tuple)): 70 | mean = (mean,) * 3 71 | 72 | std = std or OPENAI_DATASET_STD 73 | if not isinstance(std, (list, tuple)): 74 | std = (std,) * 3 75 | 76 | if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: 77 | # for square size, pass size as int so that Resize() uses aspect preserving shortest edge 78 | image_size = image_size[0] 79 | 80 | normalize = Normalize(mean=mean, std=std) 81 | if is_train: 82 | return Compose([ 83 | RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC), 84 | _convert_to_rgb, 85 | ToTensor(), 86 | normalize, 87 | ]) 88 | else: 89 | if resize_longest_max: 90 | transforms = [ 91 | ResizeMaxSize(image_size, fill=fill_color) 92 | ] 93 | else: 94 | transforms = [ 95 | Resize(image_size, interpolation=InterpolationMode.BICUBIC), 96 | CenterCrop(image_size), 97 | ] 98 | transforms.extend([ 99 | _convert_to_rgb, 100 | ToTensor(), 101 | normalize, 102 | ]) 103 | return Compose(transforms) 104 | -------------------------------------------------------------------------------- /open_clip/generation_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/generation_utils.py -------------------------------------------------------------------------------- /open_clip/get_text_embedding_voc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | sys.path.append("..") 4 | from open_clip import create_model_and_transforms, get_tokenizer 5 | from PIL import Image 6 | import json 7 | import tqdm 8 | import os 9 | import numpy as np 10 | 11 | model_name = "EVA02-CLIP-B-16" 12 | pretrained = "/home/xiaolong_he/works/paper/logs/new_ours_only_itc_lr1e_5_bs32_epoch10/checkpoints/epoch_10.pt" # or "/path/to/EVA02_CLIP_B_psz16_s8B.pt" 
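# What this script does: create_model_and_transforms() below builds the EVA02-CLIP-B-16 model
# (cache_dir is reused here to point at the fine-tuned checkpoint), text_tokens.npy holds VOC
# prompts that were already tokenized to fixed-length id sequences (typically the 77-token CLIP
# context length), and the outputs of model.encode_text() are saved to text_embedding.npy for
# downstream use in the VOC experiments. The absolute /home/... paths are machine-specific
# placeholders and have to be adapted.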
13 | 14 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 15 | model, _, preprocess = create_model_and_transforms( 16 | "EVA02-CLIP-B-16", 17 | "eva", 18 | "amp", 19 | device="cpu", 20 | jit=False, 21 | force_quick_gelu=False, 22 | force_custom_text=False, 23 | force_patch_dropout=None, 24 | force_image_size=None, 25 | pretrained_image=False, 26 | image_mean=None, 27 | image_std=None, 28 | aug_cfg={}, 29 | output_dict=True, 30 | cache_dir=pretrained, 31 | det_image_size=224, 32 | dataset_type="grid_distill", 33 | ) 34 | 35 | 36 | tokenizer = get_tokenizer(model_name) 37 | model = model.to(device) 38 | text = torch.tensor(np.load("/home/xiaolong_he/works/paper/data/voc/text_tokens.npy")).to(device) 39 | 40 | text_features = model.encode_text(text) 41 | np.save("/home/xiaolong_he/works/paper/data/voc/text_embedding.npy", text_features.cpu().numpy()) 42 | print(text_features.shape) 43 | 44 | 45 | -------------------------------------------------------------------------------- /open_clip/hf_configs.py: -------------------------------------------------------------------------------- 1 | # HF architecture dict: 2 | arch_dict = { 3 | # https://huggingface.co/docs/transformers/model_doc/roberta#roberta 4 | "roberta": { 5 | "config_names": { 6 | "context_length": "max_position_embeddings", 7 | "vocab_size": "vocab_size", 8 | "width": "hidden_size", 9 | "heads": "num_attention_heads", 10 | "layers": "num_hidden_layers", 11 | "layer_attr": "layer", 12 | "token_embeddings_attr": "embeddings" 13 | }, 14 | "pooler": "mean_pooler", 15 | }, 16 | # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig 17 | "xlm-roberta": { 18 | "config_names": { 19 | "context_length": "max_position_embeddings", 20 | "vocab_size": "vocab_size", 21 | "width": "hidden_size", 22 | "heads": "num_attention_heads", 23 | "layers": "num_hidden_layers", 24 | "layer_attr": "layer", 25 | "token_embeddings_attr": "embeddings" 26 | }, 27 | "pooler": "mean_pooler", 28 | }, 29 | # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 30 | "mt5": { 31 | "config_names": { 32 | # unlimited seqlen 33 | # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 34 | # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 35 | "context_length": "", 36 | "vocab_size": "vocab_size", 37 | "width": "d_model", 38 | "heads": "num_heads", 39 | "layers": "num_layers", 40 | "layer_attr": "block", 41 | "token_embeddings_attr": "embed_tokens" 42 | }, 43 | "pooler": "mean_pooler", 44 | }, 45 | } 46 | -------------------------------------------------------------------------------- /open_clip/hf_model.py: -------------------------------------------------------------------------------- 1 | """ huggingface model adapter 2 | 3 | Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model. 
4 | """ 5 | 6 | import re 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch import TensorType 11 | 12 | try: 13 | import transformers 14 | from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig 15 | from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \ 16 | BaseModelOutputWithPoolingAndCrossAttentions 17 | except ImportError as e: 18 | transformers = None 19 | 20 | 21 | class BaseModelOutput: 22 | pass 23 | 24 | 25 | class PretrainedConfig: 26 | pass 27 | 28 | from .hf_configs import arch_dict 29 | 30 | 31 | # utils 32 | def _camel2snake(s): 33 | return re.sub(r'(? List[str]: 19 | """Returns the names of available CLIP models""" 20 | return list_pretrained_models_by_tag('openai') 21 | 22 | 23 | def load_openai_model( 24 | name: str, 25 | precision: Optional[str] = None, 26 | device: Optional[Union[str, torch.device]] = None, 27 | jit: bool = True, 28 | cache_dir: Optional[str] = None, 29 | ): 30 | """Load a CLIP model 31 | 32 | Parameters 33 | ---------- 34 | name : str 35 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 36 | precision: str 37 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. 38 | device : Union[str, torch.device] 39 | The device to put the loaded model 40 | jit : bool 41 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 42 | cache_dir : Optional[str] 43 | The directory to cache the downloaded model weights 44 | 45 | Returns 46 | ------- 47 | model : torch.nn.Module 48 | The CLIP model 49 | preprocess : Callable[[PIL.Image], torch.Tensor] 50 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 51 | """ 52 | if device is None: 53 | device = "cuda" if torch.cuda.is_available() else "cpu" 54 | if precision is None: 55 | precision = 'fp32' if device == 'cpu' else 'fp16' 56 | print(get_pretrained_url(name, 'openai')) 57 | if get_pretrained_url(name, 'openai'): 58 | model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) 59 | elif os.path.isfile(name): 60 | model_path = name 61 | else: 62 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") 63 | 64 | try: 65 | # loading JIT archive 66 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 67 | state_dict = None 68 | except RuntimeError: 69 | # loading saved state dict 70 | if jit: 71 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 72 | jit = False 73 | state_dict = torch.load(model_path, map_location="cpu") 74 | 75 | if not jit: 76 | # Build a non-jit model from the OpenAI jitted model state dict 77 | cast_dtype = get_cast_dtype(precision) 78 | try: 79 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) 80 | except KeyError: 81 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 82 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) 83 | 84 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use 85 | model = model.to(device) 86 | if precision.startswith('amp') or precision == 'fp32': 87 | model.float() 88 | elif precision == 'bf16': 89 | convert_weights_to_lp(model, dtype=torch.bfloat16) 90 | 91 | return model 92 | 93 | # patch the device names 94 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 95 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 96 | 97 | def patch_device(module): 98 | try: 99 | graphs = [module.graph] if hasattr(module, "graph") else [] 100 | except RuntimeError: 101 | graphs = [] 102 | 103 | if hasattr(module, "forward1"): 104 | graphs.append(module.forward1.graph) 105 | 106 | for graph in graphs: 107 | for node in graph.findAllNodes("prim::Constant"): 108 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 109 | node.copyAttributes(device_node) 110 | 111 | model.apply(patch_device) 112 | patch_device(model.encode_image) 113 | patch_device(model.encode_text) 114 | 115 | # patch dtype to float32 (typically for CPU) 116 | if precision == 'fp32': 117 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 118 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 119 | float_node = float_input.node() 120 | 121 | def patch_float(module): 122 | try: 123 | graphs = [module.graph] if hasattr(module, "graph") else [] 124 | except RuntimeError: 125 | graphs = [] 126 | 127 | if hasattr(module, "forward1"): 128 | graphs.append(module.forward1.graph) 129 | 130 | for graph in graphs: 131 | for node in graph.findAllNodes("aten::to"): 132 | inputs = list(node.inputs()) 133 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 134 | if inputs[i].node()["value"] == 5: 135 | inputs[i].node().copyAttributes(float_node) 136 | 137 | model.apply(patch_float) 138 | patch_float(model.encode_image) 139 | patch_float(model.encode_text) 140 | model.float() 141 | 142 | # ensure image_size attr available at consistent location for both jit and non-jit 143 | model.visual.image_size = model.input_resolution.item() 144 | return model 145 | -------------------------------------------------------------------------------- /open_clip/transform.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from dataclasses import dataclass, asdict 3 | from typing import Any, Dict, Optional, Sequence, Tuple, Union 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torchvision.transforms.functional as F 8 | 9 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ 10 | CenterCrop 11 | 12 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 13 | 14 | 15 | @dataclass 16 | class AugmentationCfg: 17 | scale: Tuple[float, float] = (0.9, 1.0) 18 | 
ratio: Optional[Tuple[float, float]] = None 19 | color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None 20 | interpolation: Optional[str] = None 21 | re_prob: Optional[float] = None 22 | re_count: Optional[int] = None 23 | use_timm: bool = False 24 | 25 | 26 | class ResizeMaxSize(nn.Module): 27 | 28 | def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0): 29 | super().__init__() 30 | if not isinstance(max_size, int): 31 | raise TypeError(f"Size should be int. Got {type(max_size)}") 32 | self.max_size = max_size 33 | self.interpolation = interpolation 34 | self.fn = min if fn == 'min' else min 35 | self.fill = fill 36 | 37 | def forward(self, img): 38 | if isinstance(img, torch.Tensor): 39 | height, width = img.shape[:2] 40 | else: 41 | width, height = img.size 42 | scale = self.max_size / float(max(height, width)) 43 | new_size = tuple(round(dim * scale) for dim in (height, width)) 44 | img = F.resize(img, new_size, self.interpolation) 45 | pad_h = self.max_size - new_size[0] 46 | pad_w = self.max_size - new_size[1] 47 | img = F.pad(img, padding=[pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2], fill=self.fill) 48 | 49 | return img 50 | 51 | 52 | def _convert_to_rgb(image): 53 | return image.convert('RGB') 54 | 55 | 56 | def image_transform( 57 | image_size: int, 58 | is_train: bool, 59 | mean: Optional[Tuple[float, ...]] = None, 60 | std: Optional[Tuple[float, ...]] = None, 61 | resize_longest_max: bool = False, 62 | fill_color: int = 0, 63 | aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, 64 | ): 65 | mean = mean or OPENAI_DATASET_MEAN 66 | if not isinstance(mean, (list, tuple)): 67 | mean = (mean,) * 3 68 | 69 | std = std or OPENAI_DATASET_STD 70 | if not isinstance(std, (list, tuple)): 71 | std = (std,) * 3 72 | 73 | if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: 74 | # for square size, pass size as int so that Resize() uses aspect preserving shortest edge 75 | image_size = image_size[0] 76 | 77 | if isinstance(aug_cfg, dict): 78 | aug_cfg = AugmentationCfg(**aug_cfg) 79 | else: 80 | aug_cfg = aug_cfg or AugmentationCfg() 81 | normalize = Normalize(mean=mean, std=std) 82 | if is_train: 83 | aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None} 84 | use_timm = aug_cfg_dict.pop('use_timm', False) 85 | if use_timm: 86 | from timm.data import create_transform # timm can still be optional 87 | if isinstance(image_size, (tuple, list)): 88 | assert len(image_size) >= 2 89 | input_size = (3,) + image_size[-2:] 90 | else: 91 | input_size = (3, image_size, image_size) 92 | # by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time 93 | aug_cfg_dict.setdefault('interpolation', 'random') 94 | aug_cfg_dict.setdefault('color_jitter', None) # disable by default 95 | train_transform = create_transform( 96 | input_size=input_size, 97 | is_training=True, 98 | hflip=0., 99 | mean=mean, 100 | std=std, 101 | re_mode='pixel', 102 | **aug_cfg_dict, 103 | ) 104 | else: 105 | train_transform = Compose([ 106 | RandomResizedCrop( 107 | image_size, 108 | scale=aug_cfg_dict.pop('scale'), 109 | interpolation=InterpolationMode.BICUBIC, 110 | ), 111 | _convert_to_rgb, 112 | ToTensor(), 113 | normalize, 114 | ]) 115 | if aug_cfg_dict: 116 | warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).') 117 | return train_transform 118 | else: 119 | if resize_longest_max: 120 | transforms = [ 121 | 
ResizeMaxSize(image_size, fill=fill_color) 122 | ] 123 | else: 124 | transforms = [ 125 | Resize(image_size, interpolation=InterpolationMode.BICUBIC), 126 | CenterCrop(image_size), 127 | ] 128 | transforms.extend([ 129 | _convert_to_rgb, 130 | ToTensor(), 131 | normalize, 132 | ]) 133 | return Compose(transforms) 134 | 135 | 136 | def det_image_transform( 137 | image_size: int, 138 | is_train: bool, 139 | mean: Optional[Tuple[float, ...]] = None, 140 | std: Optional[Tuple[float, ...]] = None, 141 | fill_color: int = 0, 142 | aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, 143 | ): 144 | mean = mean or OPENAI_DATASET_MEAN 145 | if not isinstance(mean, (list, tuple)): 146 | mean = (mean,) * 3 147 | 148 | std = std or OPENAI_DATASET_STD 149 | if not isinstance(std, (list, tuple)): 150 | std = (std,) * 3 151 | 152 | if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: 153 | # for square size, pass size as int so that Resize() uses aspect preserving shortest edge 154 | image_size = image_size[0] 155 | 156 | normalize = Normalize(mean=mean, std=std) 157 | if is_train: 158 | raise NotImplementedError 159 | else: 160 | transforms = [ 161 | Resize((image_size, image_size)), 162 | _convert_to_rgb, 163 | ToTensor(), 164 | normalize, 165 | ] 166 | return Compose(transforms) 167 | 168 | 169 | class ResizeLongest(nn.Module): 170 | def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fill=0): 171 | super().__init__() 172 | if not isinstance(max_size, int): 173 | raise TypeError(f"Size should be int. Got {type(max_size)}") 174 | self.max_size = max_size 175 | self.interpolation = interpolation 176 | self.fill = fill 177 | 178 | def forward(self, img): 179 | if isinstance(img, torch.Tensor): 180 | height, width = img.shape[1:] 181 | else: 182 | width, height = img.size 183 | scale = self.max_size / float(max(height, width)) 184 | new_height, new_width = round(height * scale), round(width * scale) 185 | 186 | img = F.resize(img, [new_height, new_width], self.interpolation) 187 | pad_h = self.max_size - new_height 188 | pad_w = self.max_size - new_width 189 | img = F.pad(img, padding=[0, 0, pad_w, pad_h], fill=self.fill) 190 | 191 | return img 192 | 193 | 194 | def get_scale(img, new_image): 195 | if isinstance(img, torch.Tensor): 196 | height, width = new_image.shape[-2:] 197 | else: 198 | width, height = img.size 199 | 200 | if isinstance(new_image, torch.Tensor): 201 | new_height, new_width = new_image.shape[-2:] 202 | else: 203 | new_width, new_height = new_image.size 204 | 205 | scale = (new_height/height, new_width/width) 206 | 207 | return scale -------------------------------------------------------------------------------- /open_clip/utils.py: -------------------------------------------------------------------------------- 1 | from itertools import repeat 2 | import collections.abc 3 | 4 | from torch import nn as nn 5 | from torchvision.ops.misc import FrozenBatchNorm2d 6 | 7 | 8 | def freeze_batch_norm_2d(module, module_match={}, name=''): 9 | """ 10 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is 11 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and 12 | returned. Otherwise, the module is walked recursively and submodules are converted in place. 13 | 14 | Args: 15 | module (torch.nn.Module): Any PyTorch module. 
16 | module_match (dict): Dictionary of full module names to freeze (all if empty) 17 | name (str): Full module name (prefix) 18 | 19 | Returns: 20 | torch.nn.Module: Resulting module 21 | 22 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 23 | """ 24 | res = module 25 | is_match = True 26 | if module_match: 27 | is_match = name in module_match 28 | if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): 29 | res = FrozenBatchNorm2d(module.num_features) 30 | res.num_features = module.num_features 31 | res.affine = module.affine 32 | if module.affine: 33 | res.weight.data = module.weight.data.clone().detach() 34 | res.bias.data = module.bias.data.clone().detach() 35 | res.running_mean.data = module.running_mean.data 36 | res.running_var.data = module.running_var.data 37 | res.eps = module.eps 38 | else: 39 | for child_name, child in module.named_children(): 40 | full_child_name = '.'.join([name, child_name]) if name else child_name 41 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name) 42 | if new_child is not child: 43 | res.add_module(child_name, new_child) 44 | return res 45 | 46 | 47 | # From PyTorch internals 48 | def _ntuple(n): 49 | def parse(x): 50 | if isinstance(x, collections.abc.Iterable): 51 | return x 52 | return tuple(repeat(x, n)) 53 | return parse 54 | 55 | 56 | to_1tuple = _ntuple(1) 57 | to_2tuple = _ntuple(2) 58 | to_3tuple = _ntuple(3) 59 | to_4tuple = _ntuple(4) 60 | to_ntuple = lambda n, x: _ntuple(n)(x) 61 | -------------------------------------------------------------------------------- /open_clip/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '2.16.0' 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.1.2 2 | torchvision==0.16.2 3 | regex 4 | ftfy 5 | tqdm 6 | timm 7 | einops 8 | pycocotools 9 | xformers==0.0.23.post1 10 | panopticapi@git+https://githubfast.com/cocodataset/panopticapi.git -------------------------------------------------------------------------------- /scripts/test_vitb16_box.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node 8 -m training.main --batch-size=32 \ 2 | --model EVA02-CLIP-B-16 --pretrained eva --test-type coco_panoptic --train-data="" \ 3 | --val-data data/coco/annotations/panoptic_val2017.json \ 4 | --embed-path metadata/coco_panoptic_clip_hand_craft_EVACLIP_ViTB16.npy \ 5 | --val-image-root data/coco/val2017 --cache-dir checkpoints/coco_vitb16.pt --extract-type="v2" \ 6 | --name test_vitb16 --downsample-factor 16 --det-image-size 224 -------------------------------------------------------------------------------- /scripts/test_vitb16_flickr.sh: -------------------------------------------------------------------------------- 1 | python -m training.test_flickr \ 2 | --model-name "EVA02-CLIP-B-16" \ 3 | --pretrained "./checkpoints/coco_vitb16.pt" \ 4 | --data "./data/flickr30k/flickr30k_test.json" \ 5 | --image-path "./data/flickr30k/flickr30k_images" \ 6 | --image-size 224 \ 7 | --device "cuda:0" -------------------------------------------------------------------------------- /scripts/test_vitb16_mscoco.sh: -------------------------------------------------------------------------------- 1 | python -m 
training.test_mscoco \ 2 | --model-name "EVA02-CLIP-B-16" \ 3 | --pretrained "./checkpoints/coco_vitb16.pt" \ 4 | --data "./data/coco/coco_test.json" \ 5 | --image-path "./data/coco/val2017" \ 6 | --image-size 224 \ 7 | --device "cuda:0" -------------------------------------------------------------------------------- /scripts/test_vitl14_box.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node 8 -m training.main --batch-size=32 \ 2 | --model EVA02-CLIP-L-14-336 --pretrained eva --test-type coco_panoptic --train-data="" \ 3 | --val-data data/coco/annotations/panoptic_val2017.json \ 4 | --embed-path metadata/coco_panoptic_clip_hand_craft_EVACLIP_ViTL14x336.npy \ 5 | --val-image-root data/coco/val2017 --cache-dir checkpoints/coco_vitl14.pt --extract-type="v2" \ 6 | --name test_vitl14 --downsample-factor 14 --det-image-size 336 -------------------------------------------------------------------------------- /scripts/test_vitl14_flickr.sh: -------------------------------------------------------------------------------- 1 | python -m training.test_flickr \ 2 | --model-name "EVA02-CLIP-L-14-336" \ 3 | --pretrained "./checkpoints/coco_vitl14.pt" \ 4 | --data "./data/flickr30k/flickr30k_test.json" \ 5 | --image-path "./data/flickr30k/flickr30k_images" \ 6 | --image-size 336 \ 7 | --device "cuda:0" -------------------------------------------------------------------------------- /scripts/test_vitl14_mscoco.sh: -------------------------------------------------------------------------------- 1 | python -m training.test_mscoco \ 2 | --model-name "EVA02-CLIP-L-14-336" \ 3 | --pretrained "./checkpoints/coco_vitl14.pt" \ 4 | --data "./data/coco/coco_test.json" \ 5 | --image-path "./data/coco/val2017" \ 6 | --image-size 336 \ 7 | --device "cuda:0" -------------------------------------------------------------------------------- /scripts/train_vitb16.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node 8 -m training.main --batch-size=32 --lr=1e-5 --wd=0.1 --epochs=10 --workers=4 \ 2 | --model EVA02-CLIP-B-16 --pretrained eva --warmup 1000 --zeroshot-frequency 1 --dataset-type proposals_distill \ 3 | --test-type coco_panoptic --train-data data/coco/coco_proposals.json --max-boxes 20 \ 4 | --val-data data/coco/annotations/panoptic_val2017.json --image-caption-path data/coco/annotations/captions_train2017.json \ 5 | --embed-path metadata/coco_panoptic_clip_hand_craft_EVACLIP_ViTB16.npy --train-image-root data/coco/train2017 \ 6 | --val-image-root data/coco/val2017 --cache-dir checkpoints/EVA02_CLIP_B_psz16_s8B.pt --log-every-n-steps 50 \ 7 | --save-frequency 1 --extract-type="v2" --image-region-caption-path "data/coco/coco_captions.json" \ 8 | --name train_vitb16 --downsample-factor 16 --det-image-size 224 --alpha 1 9 | -------------------------------------------------------------------------------- /scripts/train_vitl14.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node 8 -m training.main --batch-size=10 --lr=1e-5 --wd=0.1 --epochs=10 --workers=4 \ 2 | --model EVA02-CLIP-L-14-336 --pretrained eva --warmup 1000 --zeroshot-frequency 1 --dataset-type proposals_distill \ 3 | --test-type coco_panoptic --train-data data/coco/coco_proposals.json --max-boxes 5 \ 4 | --val-data data/coco/annotations/panoptic_val2017.json --image-caption-path data/coco/annotations/captions_train2017.json \ 5 | --embed-path 
metadata/coco_panoptic_clip_hand_craft_EVACLIP_ViTL14x336.npy --train-image-root data/coco/train2017 \ 6 | --val-image-root data/coco/val2017 --cache-dir checkpoints/EVA02_CLIP_L_336_psz14_s6B.pt --log-every-n-steps 50 \ 7 | --save-frequency 1 --extract-type="v2" \ 8 | --name train_vitl14 --downsample-factor 14 --det-image-size 336 --alpha 1 9 | -------------------------------------------------------------------------------- /tools/generate_text_embeddings.py: -------------------------------------------------------------------------------- 1 | # Modified from [ViLD](https://github.com/tensorflow/tpu/tree/master/models/official/detection/projects/vild) 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | from tqdm import tqdm 6 | import open_clip 7 | 8 | 9 | def article(name): 10 | return 'an' if name[0] in 'aeiou' else 'a' 11 | 12 | def processed_name(name, rm_dot=False): 13 | # _ for lvis 14 | # / for obj365 15 | res = name.replace('_', ' ').replace('/', ' or ').lower() 16 | if rm_dot: 17 | res = res.rstrip('.') 18 | return res 19 | 20 | 21 | single_template = [ 22 | 'a photo of {article} {}.' 23 | ] 24 | 25 | multiple_templates = [ 26 | 'There is {article} {} in the scene.', 27 | 'There is the {} in the scene.', 28 | 'a photo of {article} {} in the scene.', 29 | 'a photo of the {} in the scene.', 30 | 'a photo of one {} in the scene.', 31 | 32 | 33 | 'itap of {article} {}.', 34 | 'itap of my {}.', # itap: I took a picture of 35 | 'itap of the {}.', 36 | 'a photo of {article} {}.', 37 | 'a photo of my {}.', 38 | 'a photo of the {}.', 39 | 'a photo of one {}.', 40 | 'a photo of many {}.', 41 | 42 | 'a good photo of {article} {}.', 43 | 'a good photo of the {}.', 44 | 'a bad photo of {article} {}.', 45 | 'a bad photo of the {}.', 46 | 'a photo of a nice {}.', 47 | 'a photo of the nice {}.', 48 | 'a photo of a cool {}.', 49 | 'a photo of the cool {}.', 50 | 'a photo of a weird {}.', 51 | 'a photo of the weird {}.', 52 | 53 | 'a photo of a small {}.', 54 | 'a photo of the small {}.', 55 | 'a photo of a large {}.', 56 | 'a photo of the large {}.', 57 | 58 | 'a photo of a clean {}.', 59 | 'a photo of the clean {}.', 60 | 'a photo of a dirty {}.', 61 | 'a photo of the dirty {}.', 62 | 63 | 'a bright photo of {article} {}.', 64 | 'a bright photo of the {}.', 65 | 'a dark photo of {article} {}.', 66 | 'a dark photo of the {}.', 67 | 68 | 'a photo of a hard to see {}.', 69 | 'a photo of the hard to see {}.', 70 | 'a low resolution photo of {article} {}.', 71 | 'a low resolution photo of the {}.', 72 | 'a cropped photo of {article} {}.', 73 | 'a cropped photo of the {}.', 74 | 'a close-up photo of {article} {}.', 75 | 'a close-up photo of the {}.', 76 | 'a jpeg corrupted photo of {article} {}.', 77 | 'a jpeg corrupted photo of the {}.', 78 | 'a blurry photo of {article} {}.', 79 | 'a blurry photo of the {}.', 80 | 'a pixelated photo of {article} {}.', 81 | 'a pixelated photo of the {}.', 82 | 83 | 'a black and white photo of the {}.', 84 | 'a black and white photo of {article} {}.', 85 | 86 | 'a plastic {}.', 87 | 'the plastic {}.', 88 | 89 | 'a toy {}.', 90 | 'the toy {}.', 91 | 'a plushie {}.', 92 | 'the plushie {}.', 93 | 'a cartoon {}.', 94 | 'the cartoon {}.', 95 | 96 | 'an embroidered {}.', 97 | 'the embroidered {}.', 98 | 99 | 'a painting of the {}.', 100 | 'a painting of a {}.', 101 | ] 102 | 103 | 104 | def build_text_embedding_coco(categories, model): 105 | templates = multiple_templates 106 | with torch.no_grad(): 107 | zeroshot_weights = [] 108 | attn12_weights 
= [] 109 | for category in categories: 110 | texts = [ 111 | template.format(processed_name(category, rm_dot=True), article=article(category)) 112 | for template in templates 113 | ] 114 | texts = [ 115 | "This is " + text if text.startswith("a") or text.startswith("the") else text 116 | for text in texts 117 | ] 118 | texts = open_clip.tokenize(texts).cuda() # tokenize 119 | text_embeddings = model.encode_text(texts) 120 | text_attnfeatures, _, _ = model.encode_text_endk(texts, stepk=12, normalize=True) 121 | 122 | text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True) 123 | text_embedding = text_embeddings.mean(dim=0) 124 | text_embedding /= text_embedding.norm() 125 | 126 | text_attnfeatures = text_attnfeatures.mean(0) 127 | text_attnfeatures = F.normalize(text_attnfeatures, dim=0) 128 | attn12_weights.append(text_attnfeatures) 129 | zeroshot_weights.append(text_embedding) 130 | zeroshot_weights = torch.stack(zeroshot_weights, dim=0) 131 | attn12_weights = torch.stack(attn12_weights, dim=0) 132 | 133 | return zeroshot_weights, attn12_weights 134 | 135 | 136 | def build_text_embedding_lvis(categories, model, tokenizer): 137 | templates = multiple_templates 138 | 139 | with torch.no_grad(): 140 | all_text_embeddings = [] 141 | for category in tqdm(categories): 142 | texts = [ 143 | template.format( 144 | processed_name(category, rm_dot=True), article=article(category) 145 | ) 146 | for template in templates 147 | ] 148 | texts = [ 149 | "This is " + text if text.startswith("a") or text.startswith("the") else text 150 | for text in texts 151 | ] 152 | texts = tokenizer(texts).cuda() # tokenize 153 | 154 | text_embeddings = model.encode_text(texts) 155 | text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True) 156 | text_embedding = text_embeddings.mean(dim=0) 157 | text_embedding /= text_embedding.norm() 158 | 159 | all_text_embeddings.append(text_embedding) 160 | all_text_embeddings = torch.stack(all_text_embeddings, dim=0) 161 | 162 | return all_text_embeddings 163 | 164 | 165 | # voc_cats = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 166 | # 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 167 | # 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 168 | # 'tvmonitor') 169 | # text_embeddings, _ = build_text_embedding_coco(voc_cats) 170 | # np.save('datasets/metadata/voc_clip_hand_craft.npy', text_embeddings.cpu().numpy()) 171 | 172 | import argparse 173 | import json 174 | 175 | if __name__ == '__main__': 176 | parser = argparse.ArgumentParser() 177 | parser.add_argument('--model_version', default='ViT-L-14-336') 178 | parser.add_argument('--ann', default='data/coco/annotations/panoptic_val2017.json') 179 | parser.add_argument('--out_path', default='metadata/coco_panoptic_clip_hand_craft_ViTL14x336.npy') 180 | parser.add_argument('--pretrained', default='openai') 181 | parser.add_argument('--cache_dir', default='checkpoints') 182 | 183 | args = parser.parse_args() 184 | 185 | model = open_clip.create_model( 186 | args.model_version, pretrained=args.pretrained, cache_dir=args.cache_dir 187 | ) 188 | tokenizer = open_clip.get_tokenizer(args.model_version) 189 | model.cuda() 190 | 191 | print('Loading', args.ann) 192 | data = json.load(open(args.ann, 'r')) 193 | cat_names = [x['name'] for x in \ 194 | sorted(data['categories'], key=lambda x: x['id'])] 195 | out_path = args.out_path 196 | text_embeddings = build_text_embedding_lvis(cat_names, model, tokenizer) 197 | np.save(out_path, text_embeddings.cpu().numpy()) 198 | 
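# build_text_embedding_lvis() above implements standard prompt ensembling: each category name is
# formatted into every entry of multiple_templates, every prompt is encoded, the per-prompt
# embeddings are L2-normalized, averaged, and the mean is renormalized, giving one unit vector
# per class. A compact sketch of that reduction for a single category (it reuses the imports and
# helpers defined at the top of this file; model/tokenizer are whatever create_model and
# get_tokenizer returned):
@torch.no_grad()
def ensemble_class_embedding(category, model, tokenizer, device="cuda"):
    prompts = [
        template.format(processed_name(category, rm_dot=True), article=article(category))
        for template in multiple_templates
    ]
    tokens = tokenizer(prompts).to(device)
    emb = model.encode_text(tokens)              # [num_templates, embed_dim]
    emb = F.normalize(emb, dim=-1).mean(dim=0)   # average the unit vectors
    return F.normalize(emb, dim=0)               # renormalize the ensembled embedding

# Example invocation that regenerates the EVA ViT-B/16 metadata consumed by
# scripts/test_vitb16_box.sh; the 'eva' pretrained tag and checkpoint path mirror
# scripts/train_vitb16.sh and are assumptions here, not something this file hard-codes:
#   python tools/generate_text_embeddings.py \
#       --model_version EVA02-CLIP-B-16 --pretrained eva \
#       --cache_dir checkpoints/EVA02_CLIP_B_psz16_s8B.pt \
#       --ann data/coco/annotations/panoptic_val2017.json \
#       --out_path metadata/coco_panoptic_clip_hand_craft_EVACLIP_ViTB16.npy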
-------------------------------------------------------------------------------- /training/.gitignore: -------------------------------------------------------------------------------- 1 | logs/ 2 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/training/__init__.py -------------------------------------------------------------------------------- /training/clipself.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | class CLIPSelf: 7 | def __call__(self, batch, model, dist_model, itc_loss, device, cast_dtype, distributed, args): 8 | if distributed: 9 | model = model.module 10 | dist_model = dist_model.module 11 | images, normed_boxes, image_crops, texts, region_texts = batch 12 | 13 | images = images.to(device=device, dtype=cast_dtype, non_blocking=True) 14 | normed_boxes = normed_boxes.to(device=device, dtype=cast_dtype, non_blocking=True) 15 | image_crops = image_crops.to(device=device, dtype=cast_dtype, non_blocking=True) 16 | texts = texts.to(device=device, dtype=cast_dtype, non_blocking=True) 17 | texts = texts.squeeze(1) # alignment dimension 18 | region_texts = region_texts.to(device=device, dtype=cast_dtype, non_blocking=True) 19 |
20 | if args.multiscale: 21 | cur_h, cur_w = images.shape[2:] 22 | assert cur_h == cur_w 23 | if cur_h == 1024: 24 | tar_sizes = [320, 640, 896, 1024] 25 | elif cur_h == 896: 26 | tar_sizes = [336, 448, 672, 896] 27 | else: 28 | raise NotImplementedError 29 | tar_size = random.choice(tar_sizes) 30 | images = F.interpolate(images, size=(tar_size, tar_size), mode='bilinear') 31 | 32 | rois_list = [] 33 | crops_list = [] 34 | region_text_list = [] 35 | for bboxes_per_image, crops_per_image, region_per_text in zip(normed_boxes, image_crops, region_texts): 36 | valid = bboxes_per_image[:, -1] > 0.5 37 | rois_list.append(bboxes_per_image[valid, :4]) 38 | crops_list.append(crops_per_image[valid]) 39 | region_text_list.append(region_per_text[valid]) 40 | 41 | image_crops = torch.cat(crops_list) 42 | region_texts = torch.cat(region_text_list) 43 | teacher_crop_features = model.encode_image(image_crops, normalize=False) 44 | student_roi_features = model.encode_pseudo_boxes(images, rois_list, normalize=False, extract_type=args.extract_type) 45 | region_text_features = model.encode_text(region_texts, normalize=False) 46 | 47 | normed_student_features = F.normalize(student_roi_features, dim=-1) 48 | normed_teacher_features = F.normalize(teacher_crop_features, dim=-1) 49 | normed_region_text_features = F.normalize(region_text_features, dim=-1) 50 | losses = {} 51 | if "distill" in args.loss_type: 52 | loss_cosine = 1.0 - (normed_student_features * normed_teacher_features).sum(-1).mean() 53 | losses["distill"] = loss_cosine 54 | 55 | if "global_itc" in args.loss_type: 56 | image_features, text_features, logit_scale = model(images, texts) 57 | global_itc_loss, _ = itc_loss(image_features, text_features, logit_scale) 58 | losses["global_itc"] = global_itc_loss 59 | 60 | if "region_itc" in args.loss_type: 61 | assert normed_student_features.shape[0] == normed_region_text_features.shape[0] 62 | all_student_features = torch.zeros((args.batch_size*args.max_boxes, normed_student_features.shape[1])).to(device=device, dtype=cast_dtype, non_blocking=True) 63 | all_student_features[0:normed_student_features.shape[0]] = normed_student_features 64 | 65 | all_region_text_features = torch.zeros((args.batch_size*args.max_boxes, normed_student_features.shape[1])).to(device=device, dtype=cast_dtype, non_blocking=True) 66 | all_region_text_features[0:normed_region_text_features.shape[0]] = normed_region_text_features 67 | 68 | region_itc_loss, _ = itc_loss(all_student_features, all_region_text_features, model.logit_scale.exp()) 69 | losses["region_itc"] = region_itc_loss 70 | 71 | return losses, len(images), model.logit_scale.exp() 72 | -------------------------------------------------------------------------------- /training/coco_api.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | # This file add snake case alias for coco api 3 | 4 | import warnings 5 | from collections import defaultdict 6 | from typing import List, Optional, Union 7 | 8 | import pycocotools 9 | from pycocotools.coco import COCO as _COCO 10 | from pycocotools.cocoeval import COCOeval as _COCOeval 11 | 12 | 13 | class COCO(_COCO): 14 | """This class is almost the same as official pycocotools package. 15 | 16 | It implements some snake case function aliases. So that the COCO class has 17 | the same interface as LVIS class. 
18 | """ 19 | 20 | def __init__(self, annotation_file=None): 21 | if getattr(pycocotools, '__version__', '0') >= '12.0.2': 22 | warnings.warn( 23 | 'mmpycocotools is deprecated. Please install official pycocotools by "pip install pycocotools"', # noqa: E501 24 | UserWarning) 25 | super().__init__(annotation_file=annotation_file) 26 | self.img_ann_map = self.imgToAnns 27 | self.cat_img_map = self.catToImgs 28 | 29 | def get_ann_ids(self, img_ids=[], cat_ids=[], area_rng=[], iscrowd=None): 30 | return self.getAnnIds(img_ids, cat_ids, area_rng, iscrowd) 31 | 32 | def get_cat_ids(self, cat_names=[], sup_names=[], cat_ids=[]): 33 | return self.getCatIds(cat_names, sup_names, cat_ids) 34 | 35 | def get_img_ids(self, img_ids=[], cat_ids=[]): 36 | return self.getImgIds(img_ids, cat_ids) 37 | 38 | def load_anns(self, ids): 39 | return self.loadAnns(ids) 40 | 41 | def load_cats(self, ids): 42 | return self.loadCats(ids) 43 | 44 | def load_imgs(self, ids): 45 | return self.loadImgs(ids) 46 | 47 | 48 | # just for the ease of import 49 | COCOeval = _COCOeval 50 | 51 | 52 | class COCOPanoptic(COCO): 53 | """This wrapper is for loading the panoptic style annotation file. 54 | 55 | The format is shown in the CocoPanopticDataset class. 56 | 57 | Args: 58 | annotation_file (str, optional): Path of annotation file. 59 | Defaults to None. 60 | """ 61 | 62 | def __init__(self, annotation_file: Optional[str] = None) -> None: 63 | super(COCOPanoptic, self).__init__(annotation_file) 64 | 65 | def createIndex(self) -> None: 66 | """Create index.""" 67 | # create index 68 | print('creating index...') 69 | # anns stores 'segment_id -> annotation' 70 | anns, cats, imgs = {}, {}, {} 71 | img_to_anns, cat_to_imgs = defaultdict(list), defaultdict(list) 72 | if 'annotations' in self.dataset: 73 | for ann in self.dataset['annotations']: 74 | for seg_ann in ann['segments_info']: 75 | # to match with instance.json 76 | seg_ann['image_id'] = ann['image_id'] 77 | img_to_anns[ann['image_id']].append(seg_ann) 78 | # segment_id is not unique in coco dataset orz... 79 | # annotations from different images but 80 | # may have same segment_id 81 | if seg_ann['id'] in anns.keys(): 82 | anns[seg_ann['id']].append(seg_ann) 83 | else: 84 | anns[seg_ann['id']] = [seg_ann] 85 | 86 | # filter out annotations from other images 87 | img_to_anns_ = defaultdict(list) 88 | for k, v in img_to_anns.items(): 89 | img_to_anns_[k] = [x for x in v if x['image_id'] == k] 90 | img_to_anns = img_to_anns_ 91 | 92 | if 'images' in self.dataset: 93 | for img_info in self.dataset['images']: 94 | img_info['segm_file'] = img_info['file_name'].replace( 95 | 'jpg', 'png') 96 | imgs[img_info['id']] = img_info 97 | 98 | if 'categories' in self.dataset: 99 | for cat in self.dataset['categories']: 100 | cats[cat['id']] = cat 101 | 102 | if 'annotations' in self.dataset and 'categories' in self.dataset: 103 | for ann in self.dataset['annotations']: 104 | for seg_ann in ann['segments_info']: 105 | cat_to_imgs[seg_ann['category_id']].append(ann['image_id']) 106 | 107 | print('index created!') 108 | 109 | self.anns = anns 110 | self.imgToAnns = img_to_anns 111 | self.catToImgs = cat_to_imgs 112 | self.imgs = imgs 113 | self.cats = cats 114 | 115 | def load_anns(self, 116 | ids: Union[List[int], int] = []) -> Optional[List[dict]]: 117 | """Load anns with the specified ids. 118 | 119 | ``self.anns`` is a list of annotation lists instead of a 120 | list of annotations. 121 | 122 | Args: 123 | ids (Union[List[int], int]): Integer ids specifying anns. 
124 | 125 | Returns: 126 | anns (List[dict], optional): Loaded ann objects. 127 | """ 128 | anns = [] 129 | 130 | if hasattr(ids, '__iter__') and hasattr(ids, '__len__'): 131 | # self.anns is a list of annotation lists instead of 132 | # a list of annotations 133 | for id in ids: 134 | anns += self.anns[id] 135 | return anns 136 | elif type(ids) == int: 137 | return self.anns[ids] 138 | -------------------------------------------------------------------------------- /training/custom_transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.nn as nn 4 | import torchvision.transforms.functional as F 5 | from torchvision.transforms import RandomCrop, InterpolationMode 6 | 7 | 8 | class CustomRandomResize(nn.Module): 9 | 10 | def __init__(self, scale=(0.5, 2.0), interpolation=InterpolationMode.BILINEAR): 11 | super().__init__() 12 | self.min_scale, self.max_scale = min(scale), max(scale) 13 | self.interpolation = interpolation 14 | 15 | def forward(self, img): 16 | if isinstance(img, torch.Tensor): 17 | height, width = img.shape[:2] 18 | else: 19 | width, height = img.size 20 | scale = random.uniform(self.min_scale, self.max_scale) 21 | new_size = [int(height * scale), int(width * scale)] 22 | img = F.resize(img, new_size, self.interpolation) 23 | 24 | return img 25 | 26 | 27 | class CustomRandomCrop(RandomCrop): 28 | def forward(self, img): 29 | """ 30 | Args: 31 | img (PIL Image or Tensor): Image to be cropped. 32 | 33 | Returns: 34 | PIL Image or Tensor: Cropped image. 35 | """ 36 | 37 | width, height = F.get_image_size(img) 38 | tar_h, tar_w = self.size 39 | 40 | tar_h = min(tar_h, height) 41 | tar_w = min(tar_w, width) 42 | i, j, h, w = self.get_params(img, (tar_h, tar_w)) 43 | 44 | return F.crop(img, i, j, h, w) 45 | -------------------------------------------------------------------------------- /training/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.distributed as dist 5 | 6 | try: 7 | import horovod.torch as hvd 8 | except ImportError: 9 | hvd = None 10 | 11 | 12 | def is_global_master(args): 13 | return args.rank == 0 14 | 15 | 16 | def is_local_master(args): 17 | return args.local_rank == 0 18 | 19 | 20 | def is_master(args, local=False): 21 | return is_local_master(args) if local else is_global_master(args) 22 | 23 | 24 | def is_using_horovod(): 25 | # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set 26 | # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required... 
27 | ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"] 28 | pmi_vars = ["PMI_RANK", "PMI_SIZE"] 29 | if all([var in os.environ for var in ompi_vars]) or all([var in os.environ for var in pmi_vars]): 30 | return True 31 | else: 32 | return False 33 | 34 | 35 | def is_using_distributed(): 36 | if 'WORLD_SIZE' in os.environ: 37 | return int(os.environ['WORLD_SIZE']) > 1 38 | if 'SLURM_NTASKS' in os.environ: 39 | return int(os.environ['SLURM_NTASKS']) > 1 40 | return False 41 | 42 | 43 | def world_info_from_env(): 44 | local_rank = 0 45 | for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'): 46 | if v in os.environ: 47 | local_rank = int(os.environ[v]) 48 | break 49 | global_rank = 0 50 | for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'): 51 | if v in os.environ: 52 | global_rank = int(os.environ[v]) 53 | break 54 | world_size = 1 55 | for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'): 56 | if v in os.environ: 57 | world_size = int(os.environ[v]) 58 | break 59 | 60 | return local_rank, global_rank, world_size 61 | 62 | 63 | def init_distributed_device(args): 64 | # Distributed training = training on more than one GPU. 65 | # Works in both single and multi-node scenarios. 66 | args.distributed = False 67 | args.world_size = 1 68 | args.rank = 0 # global rank 69 | args.local_rank = 0 70 | if args.horovod: 71 | assert hvd is not None, "Horovod is not installed" 72 | hvd.init() 73 | args.local_rank = int(hvd.local_rank()) 74 | args.rank = hvd.rank() 75 | args.world_size = hvd.size() 76 | args.distributed = True 77 | os.environ['LOCAL_RANK'] = str(args.local_rank) 78 | os.environ['RANK'] = str(args.rank) 79 | os.environ['WORLD_SIZE'] = str(args.world_size) 80 | elif is_using_distributed(): 81 | if 'SLURM_PROCID' in os.environ: 82 | # DDP via SLURM 83 | args.local_rank, args.rank, args.world_size = world_info_from_env() 84 | # SLURM var -> torch.distributed vars in case needed 85 | os.environ['LOCAL_RANK'] = str(args.local_rank) 86 | os.environ['RANK'] = str(args.rank) 87 | os.environ['WORLD_SIZE'] = str(args.world_size) 88 | torch.distributed.init_process_group( 89 | backend=args.dist_backend, 90 | init_method=args.dist_url, 91 | world_size=args.world_size, 92 | rank=args.rank, 93 | ) 94 | else: 95 | # DDP via torchrun, torch.distributed.launch 96 | args.local_rank, _, _ = world_info_from_env() 97 | torch.distributed.init_process_group( 98 | backend=args.dist_backend, 99 | init_method=args.dist_url) 100 | args.world_size = torch.distributed.get_world_size() 101 | args.rank = torch.distributed.get_rank() 102 | args.distributed = True 103 | 104 | if torch.cuda.is_available(): 105 | if args.distributed and not args.no_set_device_rank: 106 | device = 'cuda:%d' % args.local_rank 107 | else: 108 | device = 'cuda:0' 109 | torch.cuda.set_device(device) 110 | else: 111 | device = 'cpu' 112 | args.device = device 113 | device = torch.device(device) 114 | return device 115 | 116 | 117 | def broadcast_object(args, obj, src=0): 118 | # broadcast a pickle-able python object from rank-0 to all ranks 119 | if args.horovod: 120 | return hvd.broadcast_object(obj, root_rank=src) 121 | else: 122 | if args.rank == src: 123 | objects = [obj] 124 | else: 125 | objects = [None] 126 | dist.broadcast_object_list(objects, src=src) 127 | return objects[0] 128 | 129 | 130 | def all_gather_object(args, obj, dst=0): 131 | # gather a pickle-able python object across all ranks 132 | if args.horovod: 133 | return hvd.allgather_object(obj) 
134 | else: 135 | objects = [None for _ in range(args.world_size)] 136 | dist.all_gather_object(objects, obj) 137 | return objects 138 | -------------------------------------------------------------------------------- /training/file_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import multiprocessing 4 | import subprocess 5 | import time 6 | import fsspec 7 | import torch 8 | from tqdm import tqdm 9 | 10 | def remote_sync_s3(local_dir, remote_dir): 11 | # skip epoch_latest which can change during sync. 12 | result = subprocess.run(["aws", "s3", "sync", local_dir, remote_dir, '--exclude', '*epoch_latest.pt'], stdout=subprocess.PIPE, stderr=subprocess.PIPE) 13 | if result.returncode != 0: 14 | logging.error(f"Error: Failed to sync with S3 bucket {result.stderr.decode('utf-8')}") 15 | return False 16 | 17 | logging.info(f"Successfully synced with S3 bucket") 18 | return True 19 | 20 | def remote_sync_fsspec(local_dir, remote_dir): 21 | # FIXME currently this is slow and not recommended. Look into speeding up. 22 | a = fsspec.get_mapper(local_dir) 23 | b = fsspec.get_mapper(remote_dir) 24 | 25 | for k in a: 26 | # skip epoch_latest which can change during sync. 27 | if 'epoch_latest.pt' in k: 28 | continue 29 | 30 | logging.info(f'Attempting to sync {k}') 31 | if k in b and len(a[k]) == len(b[k]): 32 | logging.debug(f'Skipping remote sync for {k}.') 33 | continue 34 | 35 | try: 36 | logging.info(f'Successful sync for {k}.') 37 | b[k] = a[k] 38 | except Exception as e: 39 | logging.info(f'Error during remote sync for {k}: {e}') 40 | return False 41 | 42 | return True 43 | 44 | def remote_sync(local_dir, remote_dir, protocol): 45 | logging.info('Starting remote sync.') 46 | if protocol == 's3': 47 | return remote_sync_s3(local_dir, remote_dir) 48 | elif protocol == 'fsspec': 49 | return remote_sync_fsspec(local_dir, remote_dir) 50 | else: 51 | logging.error('Remote protocol not known') 52 | return False 53 | 54 | def keep_running_remote_sync(sync_every, local_dir, remote_dir, protocol): 55 | while True: 56 | time.sleep(sync_every) 57 | remote_sync(local_dir, remote_dir, protocol) 58 | 59 | def start_sync_process(sync_every, local_dir, remote_dir, protocol): 60 | p = multiprocessing.Process(target=keep_running_remote_sync, args=(sync_every, local_dir, remote_dir, protocol)) 61 | return p 62 | 63 | # Note: we are not currently using this save function. 
64 | def pt_save(pt_obj, file_path): 65 | of = fsspec.open(file_path, "wb") 66 | with of as f: 67 | torch.save(pt_obj, file_path) 68 | 69 | def pt_load(file_path, map_location=None): 70 | if file_path.startswith('s3'): 71 | logging.info('Loading remote checkpoint, which may take a bit.') 72 | of = fsspec.open(file_path, "rb") 73 | with of as f: 74 | out = torch.load(f, map_location=map_location) 75 | return out 76 | 77 | def check_exists(file_path): 78 | try: 79 | with fsspec.open(file_path): 80 | pass 81 | except FileNotFoundError: 82 | return False 83 | return True 84 | -------------------------------------------------------------------------------- /training/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def setup_logging(log_file, level, include_host=False): 5 | if include_host: 6 | import socket 7 | hostname = socket.gethostname() 8 | formatter = logging.Formatter( 9 | f'%(asctime)s | {hostname} | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') 10 | else: 11 | formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') 12 | 13 | logging.root.setLevel(level) 14 | loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict] 15 | for logger in loggers: 16 | logger.setLevel(level) 17 | 18 | stream_handler = logging.StreamHandler() 19 | stream_handler.setFormatter(formatter) 20 | logging.root.addHandler(stream_handler) 21 | 22 | if log_file: 23 | file_handler = logging.FileHandler(filename=log_file) 24 | file_handler.setFormatter(formatter) 25 | logging.root.addHandler(file_handler) 26 | 27 | -------------------------------------------------------------------------------- /training/precision.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from contextlib import suppress 3 | 4 | 5 | def get_autocast(precision): 6 | if precision == 'amp': 7 | return torch.cuda.amp.autocast 8 | elif precision == 'amp_bfloat16' or precision == 'amp_bf16': 9 | # amp_bfloat16 is more stable than amp float16 for clip training 10 | return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16) 11 | else: 12 | return suppress 13 | -------------------------------------------------------------------------------- /training/profile.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import open_clip 5 | import pandas as pd 6 | from fvcore.nn import FlopCountAnalysis, flop_count_str, ActivationCountAnalysis 7 | 8 | 9 | parser = argparse.ArgumentParser(description='OpenCLIP Profiler') 10 | 11 | # benchmark specific args 12 | parser.add_argument('--model', metavar='NAME', default='', 13 | help='model(s) to profile') 14 | parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', 15 | help='Output csv file for results') 16 | 17 | 18 | def profile_fvcore( 19 | model, 20 | image_input_size=(3, 224, 224), 21 | text_input_size=(77,), 22 | batch_size=1, 23 | detailed=False, 24 | force_cpu=False 25 | ): 26 | if force_cpu: 27 | model = model.to('cpu') 28 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype 29 | example_image_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype) 30 | example_text_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64) 31 | fca = FlopCountAnalysis(model, (example_image_input, example_text_input)) 32 | aca = 
ActivationCountAnalysis(model, (example_image_input, example_text_input)) 33 | if detailed: 34 | fcs = flop_count_str(fca) 35 | print(fcs) 36 | return fca.total(), aca.total() 37 | 38 | 39 | def profile_fvcore_text( 40 | model, 41 | text_input_size=(77,), 42 | batch_size=1, 43 | detailed=False, 44 | force_cpu=False 45 | ): 46 | if force_cpu: 47 | model = model.to('cpu') 48 | device = next(model.parameters()).device 49 | example_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64) 50 | fca = FlopCountAnalysis(model, example_input) 51 | aca = ActivationCountAnalysis(model, example_input) 52 | if detailed: 53 | fcs = flop_count_str(fca) 54 | print(fcs) 55 | return fca.total(), aca.total() 56 | 57 | 58 | def profile_fvcore_image( 59 | model, 60 | image_input_size=(3, 224, 224), 61 | batch_size=1, 62 | detailed=False, 63 | force_cpu=False 64 | ): 65 | if force_cpu: 66 | model = model.to('cpu') 67 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype 68 | example_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype) 69 | fca = FlopCountAnalysis(model, example_input) 70 | aca = ActivationCountAnalysis(model, example_input) 71 | if detailed: 72 | fcs = flop_count_str(fca) 73 | print(fcs) 74 | return fca.total(), aca.total() 75 | 76 | 77 | def count_params(model): 78 | return sum([m.numel() for m in model.parameters()]) 79 | 80 | 81 | def profile_model(model_name): 82 | model = open_clip.create_model(model_name, force_custom_text=True, pretrained_hf=False) 83 | model.eval() 84 | if torch.cuda.is_available(): 85 | model = model.cuda() 86 | 87 | if isinstance(model.visual.image_size, (tuple, list)): 88 | image_input_size = (3,) + tuple(model.visual.image_size[-2:]) 89 | else: 90 | image_input_size = (3, model.visual.image_size, model.visual.image_size) 91 | text_input_size = (77,) 92 | 93 | results = {} 94 | results['model'] = model_name 95 | results['image_size'] = image_input_size[1] 96 | 97 | model_cfg = open_clip.get_model_config(model_name) 98 | if model_cfg: 99 | vision_cfg = open_clip.CLIPVisionCfg(**model_cfg['vision_cfg']) 100 | text_cfg = open_clip.CLIPTextCfg(**model_cfg['text_cfg']) 101 | results['image_width'] = int(vision_cfg.width) 102 | results['text_width'] = int(text_cfg.width) 103 | results['embed_dim'] = int(model_cfg['embed_dim']) 104 | else: 105 | results['image_width'] = 0 106 | results['text_width'] = 0 107 | results['embed_dim'] = 0 108 | 109 | retries = 2 110 | while retries: 111 | retries -= 1 112 | try: 113 | macs, acts = profile_fvcore( 114 | model, image_input_size=image_input_size, text_input_size=text_input_size, force_cpu=not retries) 115 | 116 | image_macs, image_acts = profile_fvcore_image( 117 | model.visual, image_input_size=image_input_size, force_cpu=not retries) 118 | 119 | text_macs, text_acts = profile_fvcore_text( 120 | model.text, text_input_size=text_input_size, force_cpu=not retries) 121 | 122 | results['gmacs'] = round(macs / 1e9, 2) 123 | results['macts'] = round(acts / 1e6, 2) 124 | results['mparams'] = round(count_params(model) / 1e6, 2) 125 | results['image_gmacs'] = round(image_macs / 1e9, 2) 126 | results['image_macts'] = round(image_acts / 1e6, 2) 127 | results['image_mparams'] = round(count_params(model.visual) / 1e6, 2) 128 | results['text_gmacs'] = round(text_macs / 1e9, 2) 129 | results['text_macts'] = round(text_acts / 1e6, 2) 130 | results['text_mparams'] = round(count_params(model.text) / 1e6, 2) 131 | except RuntimeError as e: 132 | pass 133 | return 
results 134 | 135 | 136 | def main(): 137 | args = parser.parse_args() 138 | 139 | # FIXME accept a text file name to allow lists of models in txt/csv 140 | if args.model == 'all': 141 | parsed_model = open_clip.list_models() 142 | else: 143 | parsed_model = args.model.split(',') 144 | 145 | results = [] 146 | for m in parsed_model: 147 | row = profile_model(m) 148 | results.append(row) 149 | 150 | df = pd.DataFrame(results, columns=results[0].keys()) 151 | df = df.sort_values('gmacs') 152 | print(df) 153 | if args.results_file: 154 | df.to_csv(args.results_file, index=False) 155 | 156 | 157 | if __name__ == '__main__': 158 | main() 159 | -------------------------------------------------------------------------------- /training/region_clip.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import torch.nn as nn 5 | 6 | 7 | def get_fed_loss_inds(gt_classes, num_sample_cats, C): 8 | appeared = torch.unique(gt_classes) # C' 9 | prob = appeared.new_ones(C).float() 10 | if len(appeared) < num_sample_cats: 11 | prob[appeared] = 0 12 | more_appeared = torch.multinomial( 13 | prob, num_sample_cats - len(appeared), 14 | replacement=False) 15 | appeared = torch.cat([appeared, more_appeared]) 16 | return appeared 17 | 18 | 19 | class RegionCLIP(nn.Module): 20 | def __init__(self, args): 21 | super().__init__() 22 | embed_path = args.train_embed_path 23 | noun_embeddings = torch.from_numpy(np.load(embed_path)) 24 | noun_embeddings = F.normalize(noun_embeddings, dim=-1) 25 | self.register_buffer("noun_embeddings", noun_embeddings) 26 | self.place_holder = nn.Parameter(torch.ones(1)) 27 | 28 | def __call__(self, batch, model, dist_model, loss, device, cast_dtype, 29 | distributed, args): 30 | if distributed: 31 | model = model.module 32 | images, boxes = batch 33 | images = images.to(device=device, dtype=cast_dtype, non_blocking=True) 34 | boxes = boxes.to(device=device, non_blocking=True) 35 | 36 | boxes_list = [] 37 | boxes_label_list = [] 38 | 39 | for boxes_per_image in boxes: 40 | boxes_per_image = boxes_per_image[boxes_per_image[:, -1] > 0.5] 41 | boxes_label_list.append(boxes_per_image[:, 4].long()) 42 | boxes_list.append(boxes_per_image[:, :4]) 43 | boxes_labels = torch.cat(boxes_label_list) 44 | box_features = model.encode_pseudo_boxes(images, boxes_list, normalize=True, 45 | extract_type=args.extract_type) 46 | temp = model.logit_scale.exp().detach() 47 | boxes2nouns = box_features @ self.noun_embeddings.T * temp 48 | target = torch.zeros_like(boxes2nouns) 49 | target[range(len(boxes_labels)), boxes_labels] = 1.0 50 | 51 | appeared = get_fed_loss_inds(boxes_labels, 100, self.noun_embeddings.shape[0]) 52 | target = target[:, appeared] 53 | boxes2nouns = boxes2nouns[:, appeared] 54 | 55 | loss_cls = F.binary_cross_entropy_with_logits(boxes2nouns, target, reduction='none') # B x C 56 | loss_cls = loss_cls.sum(-1).mean() 57 | 58 | image_size = model.visual.image_size 59 | if isinstance(image_size, int): 60 | tar_h = tar_w = image_size 61 | else: 62 | tar_h, tar_w = image_size 63 | images = F.interpolate(images, size=(tar_h, tar_w), mode='bilinear') 64 | 65 | losses = dict(loss_contrast=loss_cls * args.contrast_weight) 66 | 67 | return losses, len(images), temp 68 | -------------------------------------------------------------------------------- /training/scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def 
assign_learning_rate(optimizer, new_lr): 5 | for param_group in optimizer.param_groups: 6 | param_group["lr"] = new_lr 7 | 8 | 9 | def _warmup_lr(base_lr, warmup_length, step): 10 | return base_lr * (step + 1) / warmup_length 11 | 12 | 13 | def const_lr(optimizer, base_lr, warmup_length, steps): 14 | def _lr_adjuster(step): 15 | if step < warmup_length: 16 | lr = _warmup_lr(base_lr, warmup_length, step) 17 | else: 18 | lr = base_lr 19 | assign_learning_rate(optimizer, lr) 20 | return lr 21 | return _lr_adjuster 22 | 23 | 24 | def const_lr_cooldown(optimizer, base_lr, warmup_length, steps, cooldown_steps, cooldown_power=1.0, cooldown_end_lr=0.): 25 | def _lr_adjuster(step): 26 | start_cooldown_step = steps - cooldown_steps 27 | if step < warmup_length: 28 | lr = _warmup_lr(base_lr, warmup_length, step) 29 | else: 30 | if step < start_cooldown_step: 31 | lr = base_lr 32 | else: 33 | e = step - start_cooldown_step 34 | es = steps - start_cooldown_step 35 | # linear decay if power == 1; polynomial decay otherwise; 36 | decay = (1 - (e/es)) ** cooldown_power 37 | lr = decay * (base_lr - cooldown_end_lr) + cooldown_end_lr 38 | assign_learning_rate(optimizer, lr) 39 | return lr 40 | return _lr_adjuster 41 | 42 | 43 | def cosine_lr(optimizer, base_lr, warmup_length, steps): 44 | def _lr_adjuster(step): 45 | if step < warmup_length: 46 | lr = _warmup_lr(base_lr, warmup_length, step) 47 | else: 48 | e = step - warmup_length 49 | es = steps - warmup_length 50 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr 51 | assign_learning_rate(optimizer, lr) 52 | return lr 53 | return _lr_adjuster 54 | -------------------------------------------------------------------------------- /training/test_flickr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from open_clip import create_model_and_transforms, get_tokenizer 3 | from PIL import Image 4 | import json 5 | import tqdm 6 | import os 7 | import argparse 8 | 9 | 10 | def main(args): 11 | model_name = args.model_name 12 | pretrained = args.pretrained 13 | data = args.data 14 | image_path = args.image_path 15 | image_size = args.image_size 16 | device = args.device 17 | model, _, preprocess = create_model_and_transforms( 18 | model_name, 19 | "eva", 20 | "amp", 21 | device="cpu", 22 | jit=False, 23 | force_quick_gelu=False, 24 | force_custom_text=False, 25 | force_patch_dropout=None, 26 | force_image_size=None, 27 | pretrained_image=False, 28 | image_mean=None, 29 | image_std=None, 30 | aug_cfg={}, 31 | output_dict=True, 32 | cache_dir=pretrained, 33 | det_image_size=image_size, 34 | dataset_type="grid_distill", 35 | ) 36 | tokenizer = get_tokenizer(model_name) 37 | model = model.to(device) 38 | image_list = [] 39 | caption_list = [] 40 | image_fea = [] 41 | caption_fea = [] 42 | txt2img = {} 43 | img2txt = {} 44 | txt_id = 0 45 | with open(data) as f: 46 | annotation = json.load(f) 47 | for img_id, ann in enumerate(annotation): 48 | img2txt[img_id] = [] 49 | image_list.append(ann['image']) 50 | for caption in ann['caption']: 51 | caption_list.append(caption) 52 | img2txt[img_id].append(txt_id) 53 | txt2img[txt_id] = img_id 54 | txt_id += 1 55 | 56 | with torch.no_grad(), torch.cuda.amp.autocast(): 57 | tmp_image_list = [] 58 | for i in tqdm.tqdm(range(len(image_list))): 59 | tmp_image_list.append(preprocess[0](Image.open(os.path.join(image_path, image_list[i]))).unsqueeze(0).to(device)) 60 | if len(tmp_image_list) % 64 == 0 or i == len(image_list)-1: 61 | tmp_image = torch.cat(tmp_image_list, dim=0) 62 
| image_features = model.encode_image(tmp_image) 63 | image_features /= image_features.norm(dim=-1, keepdim=True) 64 | image_fea.append(image_features) 65 | tmp_image_list = [] 66 | 67 | tmp_text_list = [] 68 | for i in tqdm.tqdm(range(len(caption_list))): 69 | tmp_text_list.append(caption_list[i]) 70 | if len(tmp_text_list) % 64 == 0 or i == len(caption_list)-1: 71 | text = tokenizer(tmp_text_list).to(device) 72 | text_features = model.encode_text(text) 73 | text_features /= text_features.norm(dim=-1, keepdim=True) 74 | caption_fea.append(text_features) 75 | tmp_text_list = [] 76 | 77 | image_fea_total = torch.cat(image_fea, dim=0) 78 | caption_fea_total = torch.cat(caption_fea, dim=0) 79 | sims = image_fea_total@caption_fea_total.t() 80 | _, topk_idx = sims.topk(k=1, dim=0) 81 | count = 0 82 | 83 | for i in range(topk_idx.shape[1]): 84 | if topk_idx[0,i] == txt2img[i]: 85 | count += 1 86 | print("Text-to-image retrieval accuracy: {:.2f}%".format(100*count/topk_idx.shape[1])) 87 | 88 | sims = sims.t() 89 | _, topk_idx = sims.topk(k=1, dim=0) 90 | 91 | count = 0 92 | new_list = [] 93 | for i in range(topk_idx.shape[1]): 94 | if topk_idx[0,i] in img2txt[i]: 95 | count += 1 96 | else: 97 | new_list.append({"image": image_list[i], "caption":caption_list[img2txt[i][0]]}) 98 | print("Image-to-text retrieval accuracy: {:.2f}%".format(100*count/topk_idx.shape[1])) 99 | 100 | 101 | if __name__ == "__main__": 102 | parser = argparse.ArgumentParser() 103 | parser.add_argument( 104 | "--model-name", 105 | type=str, 106 | default="EVA02-CLIP-B-16", 107 | ) 108 | parser.add_argument( 109 | "--pretrained", 110 | type=str, 111 | default="./checkpoints/coco_vitb16.pt", 112 | ) 113 | parser.add_argument( 114 | "--data", 115 | type=str, 116 | default="./data/flickr30k/flickr30k_test.json" 117 | ) 118 | parser.add_argument( 119 | "--image-path", 120 | type=str, 121 | default="./data/flickr30k/flickr30k_images" 122 | ) 123 | parser.add_argument( 124 | "--image-size", 125 | type=int, 126 | default=224 127 | ) 128 | parser.add_argument( 129 | "--device", 130 | type=str, 131 | default="cuda:0" 132 | ) 133 | args = parser.parse_args() 134 | main(args) -------------------------------------------------------------------------------- /training/test_mscoco.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from open_clip import create_model_and_transforms, get_tokenizer 3 | from PIL import Image 4 | import json 5 | import tqdm 6 | import os 7 | import argparse 8 | 9 | 10 | def main(args): 11 | model_name = args.model_name 12 | pretrained = args.pretrained 13 | data = args.data 14 | image_path = args.image_path 15 | image_size = args.image_size 16 | device = args.device 17 | model, _, preprocess = create_model_and_transforms( 18 | model_name, 19 | "eva", 20 | "amp", 21 | device="cpu", 22 | jit=False, 23 | force_quick_gelu=False, 24 | force_custom_text=False, 25 | force_patch_dropout=None, 26 | force_image_size=None, 27 | pretrained_image=False, 28 | image_mean=None, 29 | image_std=None, 30 | aug_cfg={}, 31 | output_dict=True, 32 | cache_dir=pretrained, 33 | det_image_size=image_size, 34 | dataset_type="grid_distill", 35 | ) 36 | tokenizer = get_tokenizer(model_name) 37 | model = model.to(device) 38 | image_list = [] 39 | caption_list = [] 40 | image_fea = [] 41 | caption_fea = [] 42 | txt2img = {} 43 | img2txt = {} 44 | txt_id = 0 45 | with open(data) as f: 46 | annotation = json.load(f) 47 | for img_id, ann in enumerate(annotation): 48 | img2txt[img_id] = [] 49 | image_list.append(ann['image']) 50 | for caption
in ann['caption']: 51 | caption_list.append(caption) 52 | img2txt[img_id].append(txt_id) 53 | txt2img[txt_id] = img_id 54 | txt_id += 1 55 | 56 | with torch.no_grad(), torch.cuda.amp.autocast(): 57 | tmp_image_list = [] 58 | for i in tqdm.tqdm(range(len(image_list))): 59 | tmp_image_list.append(preprocess[0](Image.open(os.path.join(image_path, image_list[i]))).unsqueeze(0).to(device)) 60 | if len(tmp_image_list) % 64 == 0 or i == len(image_list)-1: 61 | tmp_image = torch.cat(tmp_image_list, dim=0) 62 | image_features = model.encode_image(tmp_image) 63 | image_features /= image_features.norm(dim=-1, keepdim=True) 64 | image_fea.append(image_features) 65 | tmp_image_list = [] 66 | 67 | tmp_text_list = [] 68 | for i in tqdm.tqdm(range(len(caption_list))): 69 | tmp_text_list.append(caption_list[i]) 70 | if len(tmp_text_list) % 64 == 0 or i == len(caption_list)-1: 71 | text = tokenizer(tmp_text_list).to(device) 72 | text_features = model.encode_text(text) 73 | text_features /= text_features.norm(dim=-1, keepdim=True) 74 | caption_fea.append(text_features) 75 | tmp_text_list = [] 76 | 77 | image_fea_total = torch.cat(image_fea, dim=0) 78 | caption_fea_total = torch.cat(caption_fea, dim=0) 79 | sims = image_fea_total@caption_fea_total.t() 80 | _, topk_idx = sims.topk(k=1, dim=0) 81 | count = 0 82 | 83 | for i in range(topk_idx.shape[1]): 84 | if topk_idx[0,i] == txt2img[i]: 85 | count += 1 86 | print("Text-to-image retrieval accuracy: {:.2f}%".format(100*count/topk_idx.shape[1])) 87 | 88 | sims = sims.t() 89 | _, topk_idx = sims.topk(k=1, dim=0) 90 | 91 | count = 0 92 | new_list = [] 93 | for i in range(topk_idx.shape[1]): 94 | if topk_idx[0,i] in img2txt[i]: 95 | count += 1 96 | else: 97 | new_list.append({"image": image_list[i], "caption":caption_list[img2txt[i][0]]}) 98 | print("Image-to-text retrieval accuracy: {:.2f}%".format(100*count/topk_idx.shape[1])) 99 | 100 | 101 | if __name__ == "__main__": 102 | parser = argparse.ArgumentParser() 103 | parser.add_argument( 104 | "--model-name", 105 | type=str, 106 | default="EVA02-CLIP-B-16", 107 | ) 108 | parser.add_argument( 109 | "--pretrained", 110 | type=str, 111 | default="./checkpoints/coco_vitb16.pt", 112 | ) 113 | parser.add_argument( 114 | "--data", 115 | type=str, 116 | default="./data/coco/coco_test.json" 117 | ) 118 | parser.add_argument( 119 | "--image-path", 120 | type=str, 121 | default="./data/coco/val2017" 122 | ) 123 | parser.add_argument( 124 | "--image-size", 125 | type=int, 126 | default=224 127 | ) 128 | parser.add_argument( 129 | "--device", 130 | type=str, 131 | default="cuda:0" 132 | ) 133 | args = parser.parse_args() 134 | main(args) -------------------------------------------------------------------------------- /training/train.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import math 4 | import time 5 | import torch 6 | 7 | from open_clip.eva_clip import ClipLoss, get_cast_dtype 8 | from .distributed import is_master 9 | from .zero_shot import zero_shot_eval 10 | from .precision import get_autocast 11 | import os 12 | 13 | 14 | class AverageMeter(object): 15 | """Computes and stores the average and current value""" 16 | 17 | def __init__(self): 18 | self.reset() 19 | 20 | def reset(self): 21 | self.val = 0 22 | self.avg = 0 23 | self.sum = 0 24 | self.count = 0 25 | 26 | def update(self, val, n=1): 27 | self.val = val 28 | self.sum += val * n 29 | self.count += n 30 | self.avg = self.sum / self.count 31 | 32 | def postprocess_clip_output(model_out): 33 | return { 34 | "image_features":
model_out[0], 35 | "text_features": model_out[1], 36 | "logit_scale": model_out[2] 37 | } 38 | 39 | def unwrap_model(model): 40 | if hasattr(model, 'module'): 41 | return model.module 42 | else: 43 | return model 44 | 45 | 46 | def backward(total_loss, scaler): 47 | if scaler is not None: 48 | scaler.scale(total_loss).backward() 49 | else: 50 | total_loss.backward() 51 | 52 | 53 | @torch.no_grad() 54 | def student_teacher_ensemble(student, teacher, alpha=0.5): 55 | target_state_dict = {} 56 | for k, v in student.items(): 57 | target_state_dict[k] = v * alpha + teacher[k] * (1.0 - alpha) 58 | 59 | return target_state_dict 60 | 61 | 62 | def train_one_epoch(model, method, data, loss, epoch, optimizer, scaler, scheduler, dist_model, args): 63 | device = torch.device(args.device) 64 | autocast = get_autocast(args.precision) 65 | cast_dtype = get_cast_dtype(args.precision) 66 | 67 | model.train() 68 | itc_loss = ClipLoss( 69 | local_loss=args.local_loss, 70 | gather_with_grad=args.gather_with_grad, 71 | cache_labels=True, 72 | rank=args.rank, 73 | world_size=args.world_size, 74 | ) 75 | if dist_model is not None: 76 | dist_model.eval() 77 | 78 | data['train'].set_epoch(epoch) # set epoch in process safe manner via sampler or shared_epoch 79 | dataloader = data['train'].dataloader 80 | num_batches_per_epoch = dataloader.num_batches // args.accum_freq 81 | sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10)) 82 | 83 | losses_m = {} 84 | batch_time_m = AverageMeter() 85 | data_time_m = AverageMeter() 86 | end = time.time() 87 | for i, batch in enumerate(dataloader): 88 | i_accum = i // args.accum_freq 89 | step = num_batches_per_epoch * epoch + i_accum 90 | 91 | if not args.skip_scheduler: 92 | scheduler(step) 93 | 94 | data_time_m.update(time.time() - end) 95 | optimizer.zero_grad() 96 | assert args.accum_freq == 1, "accum freq disabled" 97 | with autocast(): 98 | losses, batch_size, logit_scale = method(batch, model, dist_model, itc_loss, device, cast_dtype, 99 | args.distributed, args) 100 | total_loss = sum(losses.values()) 101 | losses["loss"] = total_loss 102 | 103 | backward(total_loss, scaler) 104 | 105 | if scaler is not None: 106 | if args.horovod: 107 | optimizer.synchronize() 108 | scaler.unscale_(optimizer) 109 | if args.grad_clip_norm is not None: 110 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0) 111 | with optimizer.skip_synchronize(): 112 | scaler.step(optimizer) 113 | else: 114 | if args.grad_clip_norm is not None: 115 | scaler.unscale_(optimizer) 116 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0) 117 | scaler.step(optimizer) 118 | scaler.update() 119 | else: 120 | if args.grad_clip_norm is not None: 121 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0) 122 | optimizer.step() 123 | 124 | # Note: we clamp to 4.6052 = ln(100), as in the original paper. 
125 | with torch.no_grad(): 126 | unwrap_model(model).logit_scale.clamp_(0, math.log(100)) 127 | 128 | batch_time_m.update(time.time() - end) 129 | end = time.time() 130 | batch_count = i_accum + 1 131 | if is_master(args) and (i_accum % args.log_every_n_steps == 0 or batch_count == num_batches_per_epoch): 132 | # batch_size = len(images) 133 | num_samples = batch_count * batch_size * args.accum_freq * args.world_size 134 | samples_per_epoch = dataloader.num_samples 135 | percent_complete = 100.0 * batch_count / num_batches_per_epoch 136 | 137 | # NOTE loss is coarsely sampled, just master node and per log update 138 | for key, val in losses.items(): 139 | if key not in losses_m: 140 | losses_m[key] = AverageMeter() 141 | losses_m[key].update(val.item(), batch_size) 142 | 143 | logit_scale_scalar = logit_scale.item() 144 | loss_log = " ".join( 145 | [ 146 | f"{loss_name.capitalize()}: {loss_m.val:#.5g} ({loss_m.avg:#.5g})" 147 | for loss_name, loss_m in losses_m.items() 148 | ] 149 | ) 150 | samples_per_second = args.accum_freq * args.batch_size * args.world_size / batch_time_m.val 151 | samples_per_second_per_gpu = args.accum_freq * args.batch_size / batch_time_m.val 152 | logging.info( 153 | f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " 154 | f"Data (t): {data_time_m.avg:.3f} " 155 | f"Batch (t): {batch_time_m.avg:.3f}, {samples_per_second:#g}/s, {samples_per_second_per_gpu:#g}/s/gpu " 156 | f"LR: {optimizer.param_groups[0]['lr']:5f} " 157 | f"Logit Scale: {logit_scale_scalar:.3f} " + loss_log 158 | ) 159 | 160 | # Save train loss / etc. Using non avg meter values as loggers have their own smoothing 161 | log_data = { 162 | "data_time": data_time_m.val, 163 | "batch_time": batch_time_m.val, 164 | "samples_per_second": samples_per_second, 165 | "samples_per_second_per_gpu": samples_per_second_per_gpu, 166 | "scale": logit_scale_scalar, 167 | "lr": optimizer.param_groups[0]["lr"] 168 | } 169 | log_data.update({name:val.val for name,val in losses_m.items()}) 170 | # resetting batch / data time meters per log window 171 | batch_time_m.reset() 172 | data_time_m.reset() 173 | 174 | 175 | def evaluate(model, data, epoch, args): 176 | metrics = {} 177 | model.eval() 178 | 179 | zero_shot_metrics = zero_shot_eval(model, data, epoch, args) 180 | if not is_master(args): 181 | return {} 182 | metrics.update(zero_shot_metrics) 183 | if not metrics: 184 | return metrics 185 | 186 | keys = ''.join([f"{k}, " for k in metrics.keys() if 'all' in k])[:-2] 187 | values = ''.join([f'{round(v, 4):.4f}, ' for k, v in metrics.items() if 'all' in k])[:-2] 188 | 189 | logging.info( 190 | f"Eval Epoch: {epoch}. " 191 | + f"{keys}: {values}." 192 | ) 193 | # TODO save the results as plots 194 | logging.info(metrics) 195 | 196 | if args.save_logs: 197 | with open(os.path.join(args.checkpoint_path, "results.json"), "a+") as f: 198 | f.write(json.dumps(metrics)) 199 | f.write("\n") 200 | 201 | return metrics 202 | -------------------------------------------------------------------------------- /training/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from functools import partial 3 | from six.moves import map, zip 4 | 5 | 6 | def multi_apply(func, *args, **kwargs): 7 | """Apply function to a list of arguments. 8 | Note: 9 | This function applies the ``func`` to multiple inputs and 10 | map the multiple outputs of the ``func`` into different 11 | list. 

--------------------------------------------------------------------------------
/training/utils.py:
--------------------------------------------------------------------------------
import numpy as np
from functools import partial
from six.moves import map, zip


def multi_apply(func, *args, **kwargs):
    """Apply a function to a list of arguments.

    Note:
        This function applies ``func`` to multiple inputs and maps the
        multiple outputs of ``func`` into separate lists. Each list contains
        the same type of outputs corresponding to different inputs.

    Args:
        func (Function): A function that will be applied to a list of
            arguments.

    Returns:
        tuple(list): A tuple of lists, where each list contains one kind of
            result returned by the function.
    """
    pfunc = partial(func, **kwargs) if kwargs else func
    map_results = map(pfunc, *args)
    return tuple(map(list, zip(*map_results)))


def mask2box(mask):
    ys, xs = np.where(mask)
    y0, y1 = ys.min(), ys.max()
    x0, x1 = xs.min(), xs.max()

    return x0, y0, x1, y1
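

# Usage sketch (illustrative only, not used by the training code):
# `multi_apply` maps one function over several inputs and regroups the
# per-call outputs into separate lists, and `mask2box` converts a binary
# mask into an (x0, y0, x1, y1) box in pixel indices.
def _example_usage():
    masks = [np.array([[0, 1, 1],
                       [0, 1, 0]], dtype=bool)]
    x0s, y0s, x1s, y1s = multi_apply(mask2box, masks)
    return x0s, y0s, x1s, y1s  # -> [1], [0], [2], [1]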

--------------------------------------------------------------------------------
/training/zero_shot.py:
--------------------------------------------------------------------------------
import logging
import torch
import torch.nn.functional as F
from training.dist_utils import all_gather
from tqdm import tqdm
from .distributed import is_master
from open_clip import get_cast_dtype
from .precision import get_autocast


def run(model, dataloader, args):
    cls_embeddings = dataloader.dataset.embeddings
    cls_embeddings = F.normalize(torch.from_numpy(cls_embeddings).float(), dim=-1)
    cls_embeddings = cls_embeddings.to(args.device)
    autocast = get_autocast(args.precision)
    cast_dtype = get_cast_dtype(args.precision)
    if cast_dtype is not None:
        cls_embeddings = cls_embeddings.to(dtype=cast_dtype)
    with torch.no_grad():
        correct_rois = []
        all_cls_labels = []
        for images, bboxes, _, gt_masks, _ \
                in tqdm(dataloader, disable=not is_master(args)):
            images = images.to(args.device)
            bboxes = bboxes.to(args.device)
            gt_masks = gt_masks.to(args.device)
            if cast_dtype is not None:
                images = images.to(dtype=cast_dtype)
                bboxes = bboxes.to(dtype=cast_dtype)
                gt_masks = gt_masks.to(dtype=cast_dtype)
            cls_labels = []
            rois = []
            for bboxes_per_image in bboxes:
                valid = bboxes_per_image[:, 5] > 0.5
                rois.append(bboxes_per_image[valid, :4])
                cls_labels.append(bboxes_per_image[valid, 4])
            cls_labels = torch.cat(cls_labels, dim=0).to(torch.long)
            if cls_labels.shape[0] == 0:
                continue
            with autocast():
                # predict
                if args.distributed and not args.horovod:
                    module = model.module
                else:
                    module = model
                roi_extractor = module.encode_pseudo_boxes
                roi_features = roi_extractor(images, rois, normalize=True,
                                             extract_type=args.extract_type)

                if cast_dtype is not None:
                    roi_features = roi_features.to(dtype=cast_dtype)

                roi_logits = roi_features @ cls_embeddings.T

            _, roi_top5_inds = roi_logits.topk(5)
            correct_rois.append(roi_top5_inds == cls_labels.view(-1, 1))
            all_cls_labels.append(cls_labels)

        correct_rois = torch.cat(correct_rois).float()
        all_cls_labels = torch.cat(all_cls_labels)
    if args.distributed and not args.horovod:
        correct_rois = multi_gpu_sync(correct_rois)
        all_cls_labels = multi_gpu_sync(all_cls_labels)

    return correct_rois, all_cls_labels


def multi_gpu_sync(x):
    device = x.device
    x_list = all_gather(x.cpu())
    x = torch.cat([res.to(device) for res in x_list])
    return x


def macc_with_box(correct_matrix, all_cls_labels, prefix):
    def _macc(corrects, cls_labels):
        min_id = cls_labels.min().item()
        max_id = cls_labels.max().item()
        cand_labels = list(range(min_id, max_id + 1))

        acc_per_cls = []

        for lb in cand_labels:
            corrects_per_cls = corrects[cls_labels == lb]
            if corrects_per_cls.shape[0] == 0:
                continue
            acc_per_cls.append(corrects_per_cls.mean().half().item())

        return sum(acc_per_cls) / len(acc_per_cls)

    results = {}

    box_top1_acc = _macc(correct_matrix[:, 0], all_cls_labels)
    box_top5_acc = _macc(correct_matrix.sum(-1), all_cls_labels)

    results[f'{prefix}.box.macc1'] = box_top1_acc
    results[f'{prefix}.box.macc5'] = box_top5_acc

    return results


def zero_shot_eval(model, data, epoch, args):
    if 'val' not in data:
        return {}
    if args.zeroshot_frequency == 0:
        return {}
    if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs:
        return {}
    logging.info('Region classifier')
    results = {}
    correct_rois, all_cls_labels = run(model, data['val'].dataloader, args)
    results.update(macc_with_box(correct_rois, all_cls_labels, 'rois'))
    return results
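

# Illustrative sanity check for macc_with_box (added as documentation, not
# executed by the evaluation pipeline): three ROIs with labels [0, 0, 1];
# class 0 is correct at top-1 once out of twice and class 1 once out of once,
# so 'rois.box.macc1' = (0.5 + 1.0) / 2 = 0.75 and 'rois.box.macc5' = 1.0.
def _example_macc():
    correct_matrix = torch.tensor([[1., 0., 0., 0., 0.],
                                   [0., 0., 1., 0., 0.],
                                   [1., 0., 0., 0., 0.]])
    cls_labels = torch.tensor([0, 0, 1])
    return macc_with_box(correct_matrix, cls_labels, 'rois')

--------------------------------------------------------------------------------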