├── README.md
├── figs
│   └── fineclip_arch.png
├── metadata
│   ├── coco_panoptic_clip_hand_craft_EVACLIP_ViTB16.npy
│   └── coco_panoptic_clip_hand_craft_EVACLIP_ViTL14x336.npy
├── open_clip
│   ├── .DS_Store
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── __init__.cpython-39.pyc
│   │   ├── coca_model.cpython-38.pyc
│   │   ├── coca_model.cpython-39.pyc
│   │   ├── constants.cpython-38.pyc
│   │   ├── constants.cpython-39.pyc
│   │   ├── factory.cpython-38.pyc
│   │   ├── factory.cpython-39.pyc
│   │   ├── hf_configs.cpython-38.pyc
│   │   ├── hf_configs.cpython-39.pyc
│   │   ├── hf_model.cpython-38.pyc
│   │   ├── hf_model.cpython-39.pyc
│   │   ├── loss.cpython-38.pyc
│   │   ├── loss.cpython-39.pyc
│   │   ├── model.cpython-38.pyc
│   │   ├── model.cpython-39.pyc
│   │   ├── modified_resnet.cpython-38.pyc
│   │   ├── modified_resnet.cpython-39.pyc
│   │   ├── openai.cpython-38.pyc
│   │   ├── openai.cpython-39.pyc
│   │   ├── pretrained.cpython-38.pyc
│   │   ├── pretrained.cpython-39.pyc
│   │   ├── push_to_hf_hub.cpython-38.pyc
│   │   ├── push_to_hf_hub.cpython-39.pyc
│   │   ├── timm_model.cpython-38.pyc
│   │   ├── timm_model.cpython-39.pyc
│   │   ├── tokenizer.cpython-38.pyc
│   │   ├── tokenizer.cpython-39.pyc
│   │   ├── transform.cpython-38.pyc
│   │   ├── transform.cpython-39.pyc
│   │   ├── transformer.cpython-38.pyc
│   │   ├── transformer.cpython-39.pyc
│   │   ├── utils.cpython-38.pyc
│   │   ├── utils.cpython-39.pyc
│   │   ├── version.cpython-38.pyc
│   │   └── version.cpython-39.pyc
│   ├── bpe_simple_vocab_16e6.txt.gz
│   ├── coca_model.py
│   ├── constants.py
│   ├── customs.py
│   ├── eva_clip
│   │   ├── .DS_Store
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── __init__.cpython-39.pyc
│   │   │   ├── constants.cpython-38.pyc
│   │   │   ├── constants.cpython-39.pyc
│   │   │   ├── eva_vit_model.cpython-38.pyc
│   │   │   ├── eva_vit_model.cpython-39.pyc
│   │   │   ├── factory.cpython-38.pyc
│   │   │   ├── factory.cpython-39.pyc
│   │   │   ├── hf_configs.cpython-38.pyc
│   │   │   ├── hf_configs.cpython-39.pyc
│   │   │   ├── hf_model.cpython-38.pyc
│   │   │   ├── hf_model.cpython-39.pyc
│   │   │   ├── loss.cpython-38.pyc
│   │   │   ├── loss.cpython-39.pyc
│   │   │   ├── model.cpython-38.pyc
│   │   │   ├── model.cpython-39.pyc
│   │   │   ├── modified_resnet.cpython-38.pyc
│   │   │   ├── modified_resnet.cpython-39.pyc
│   │   │   ├── openai.cpython-38.pyc
│   │   │   ├── openai.cpython-39.pyc
│   │   │   ├── pretrained.cpython-38.pyc
│   │   │   ├── pretrained.cpython-39.pyc
│   │   │   ├── rope.cpython-38.pyc
│   │   │   ├── rope.cpython-39.pyc
│   │   │   ├── timm_model.cpython-38.pyc
│   │   │   ├── timm_model.cpython-39.pyc
│   │   │   ├── tokenizer.cpython-38.pyc
│   │   │   ├── tokenizer.cpython-39.pyc
│   │   │   ├── transform.cpython-38.pyc
│   │   │   ├── transform.cpython-39.pyc
│   │   │   ├── transformer.cpython-38.pyc
│   │   │   ├── transformer.cpython-39.pyc
│   │   │   ├── utils.cpython-38.pyc
│   │   │   └── utils.cpython-39.pyc
│   │   ├── bpe_simple_vocab_16e6.txt.gz
│   │   ├── constants.py
│   │   ├── eva_vit_model.py
│   │   ├── factory.py
│   │   ├── hf_configs.py
│   │   ├── hf_model.py
│   │   ├── loss.py
│   │   ├── model.py
│   │   ├── model_configs
│   │   │   ├── EVA01-CLIP-B-16.json
│   │   │   ├── EVA01-CLIP-g-14-plus.json
│   │   │   ├── EVA01-CLIP-g-14.json
│   │   │   ├── EVA02-CLIP-B-16.json
│   │   │   ├── EVA02-CLIP-L-14-336.json
│   │   │   ├── EVA02-CLIP-L-14.json
│   │   │   ├── EVA02-CLIP-bigE-14-plus.json
│   │   │   └── EVA02-CLIP-bigE-14.json
│   │   ├── modified_resnet.py
│   │   ├── openai.py
│   │   ├── pretrained.py
│   │   ├── rope.py
│   │   ├── timm_model.py
│   │   ├── tokenizer.py
│   │   ├── transform.py
│   │   ├── transformer.py
│   │   └── utils.py
│   ├── factory.py
│   ├── generation_utils.py
│   ├── get_text_embedding_voc.py
│   ├── hf_configs.py
│   ├── hf_model.py
│   ├── loss.py
│   ├── model.py
│   ├── model_configs
│   │   ├── RN101-quickgelu.json
│   │   ├── RN101.json
│   │   ├── RN50-quickgelu.json
│   │   ├── RN50.json
│   │   ├── RN50x16.json
│   │   ├── RN50x4.json
│   │   ├── RN50x64.json
│   │   ├── ViT-B-16-plus-240.json
│   │   ├── ViT-B-16-plus.json
│   │   ├── ViT-B-16.json
│   │   ├── ViT-B-32-plus-256.json
│   │   ├── ViT-B-32-quickgelu.json
│   │   ├── ViT-B-32.json
│   │   ├── ViT-H-14.json
│   │   ├── ViT-H-16.json
│   │   ├── ViT-L-14-280.json
│   │   ├── ViT-L-14-336.json
│   │   ├── ViT-L-14.json
│   │   ├── ViT-L-16-320.json
│   │   ├── ViT-L-16.json
│   │   ├── ViT-M-16-alt.json
│   │   ├── ViT-M-16.json
│   │   ├── ViT-M-32-alt.json
│   │   ├── ViT-M-32.json
│   │   ├── ViT-S-16-alt.json
│   │   ├── ViT-S-16.json
│   │   ├── ViT-S-32-alt.json
│   │   ├── ViT-S-32.json
│   │   ├── ViT-bigG-14.json
│   │   ├── ViT-e-14.json
│   │   ├── ViT-g-14.json
│   │   ├── coca_ViT-B-32.json
│   │   ├── coca_ViT-L-14.json
│   │   ├── coca_base.json
│   │   ├── coca_roberta-ViT-B-32.json
│   │   ├── convnext_base.json
│   │   ├── convnext_base_w.json
│   │   ├── convnext_base_w_320.json
│   │   ├── convnext_large.json
│   │   ├── convnext_large_d.json
│   │   ├── convnext_large_d_320.json
│   │   ├── convnext_small.json
│   │   ├── convnext_tiny.json
│   │   ├── convnext_xlarge.json
│   │   ├── convnext_xxlarge.json
│   │   ├── convnext_xxlarge_320.json
│   │   ├── mt5-base-ViT-B-32.json
│   │   ├── mt5-xl-ViT-H-14.json
│   │   ├── roberta-ViT-B-32.json
│   │   ├── swin_base_patch4_window7_224.json
│   │   ├── vit_medium_patch16_gap_256.json
│   │   ├── vit_relpos_medium_patch16_cls_224.json
│   │   ├── xlm-roberta-base-ViT-B-32.json
│   │   └── xlm-roberta-large-ViT-H-14.json
│   ├── modified_resnet.py
│   ├── openai.py
│   ├── pretrained.py
│   ├── push_to_hf_hub.py
│   ├── timm_model.py
│   ├── tokenizer.py
│   ├── transform.py
│   ├── transformer.py
│   ├── utils.py
│   └── version.py
├── requirements.txt
├── scripts
│   ├── test_vitb16_box.sh
│   ├── test_vitb16_flickr.sh
│   ├── test_vitb16_mscoco.sh
│   ├── test_vitl14_box.sh
│   ├── test_vitl14_flickr.sh
│   ├── test_vitl14_mscoco.sh
│   ├── train_vitb16.sh
│   └── train_vitl14.sh
├── tools
│   └── generate_text_embeddings.py
└── training
    ├── .gitignore
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-38.pyc
    │   ├── __init__.cpython-39.pyc
    │   ├── clipself.cpython-38.pyc
    │   ├── clipself.cpython-39.pyc
    │   ├── coco_api.cpython-38.pyc
    │   ├── coco_api.cpython-39.pyc
    │   ├── custom_transforms.cpython-38.pyc
    │   ├── custom_transforms.cpython-39.pyc
    │   ├── data.cpython-38.pyc
    │   ├── data.cpython-39.pyc
    │   ├── dist_utils.cpython-38.pyc
    │   ├── dist_utils.cpython-39.pyc
    │   ├── distributed.cpython-38.pyc
    │   ├── distributed.cpython-39.pyc
    │   ├── file_utils.cpython-38.pyc
    │   ├── file_utils.cpython-39.pyc
    │   ├── logger.cpython-38.pyc
    │   ├── logger.cpython-39.pyc
    │   ├── main.cpython-38.pyc
    │   ├── main.cpython-39.pyc
    │   ├── params.cpython-38.pyc
    │   ├── params.cpython-39.pyc
    │   ├── precision.cpython-38.pyc
    │   ├── precision.cpython-39.pyc
    │   ├── region_clip.cpython-38.pyc
    │   ├── region_clip.cpython-39.pyc
    │   ├── scheduler.cpython-38.pyc
    │   ├── scheduler.cpython-39.pyc
    │   ├── test_flickr.cpython-38.pyc
    │   ├── test_mscoco.cpython-38.pyc
    │   ├── train.cpython-38.pyc
    │   ├── train.cpython-39.pyc
    │   ├── utils.cpython-38.pyc
    │   ├── utils.cpython-39.pyc
    │   ├── zero_shot.cpython-38.pyc
    │   └── zero_shot.cpython-39.pyc
    ├── clipself.py
    ├── coco_api.py
    ├── custom_transforms.py
    ├── data.py
    ├── dist_utils.py
    ├── distributed.py
    ├── file_utils.py
    ├── logger.py
    ├── main.py
    ├── params.py
    ├── precision.py
    ├── profile.py
    ├── region_clip.py
    ├── scheduler.py
    ├── test_flickr.py
    ├── test_mscoco.py
    ├── train.py
    ├── utils.py
    └── zero_shot.py
/README.md:
--------------------------------------------------------------------------------
1 | # FineCLIP
2 |
3 | ## Introduction
4 |
5 | Official release of FineCLIP: Self-distilled Region-based CLIP for Better Fine-grained Understanding **(NeurIPS 2024)**
6 |
7 | > [**FineCLIP: Self-distilled Region-based CLIP for Better Fine-grained Understanding**](https://openreview.net/pdf?id=nExI4FuKWD),
8 | > Dong Jing*, Xiaolong He*, Yutian Luo, Nanyi Fei, Guoxing Yang, Wei Wei, Huiwen Zhao, Zhiwu Lu†
9 |
10 | 
11 |
12 | ## TODO
13 | - [x] Code for training and evaluating FineCLIP on the COCO dataset
14 | - [ ] Data generation pipeline for CC2.5M
15 |
16 | ## Environment
17 |
18 | ```bash
19 | conda create -n fineclip python=3.8
20 | conda activate fineclip
21 | pip install -r requirements.txt
22 | ```
23 |
24 | ## Data Preparation
25 |
26 | The main ablation experiments are conducted using images from [COCO](https://cocodataset.org/#home)
27 | and [Flickr](https://shannon.cs.illinois.edu/DenotationGraph/) datasets. Please download and organize the datasets as follows:
28 |
29 | ```text
30 | FineCLIP/
31 | ├── data
32 | │   ├── coco
33 | │   │   ├── annotations
34 | │   │   │   ├── captions_train2017.json  [Image Captions](https://drive.google.com/file/d/1gV7MZxQCkRbmC__FcVr0rGuNIdpacnwJ/view?usp=drive_link)
35 | │   │   │   ├── instances_train2017.json
36 | │   │   │   └── panoptic_val2017.json
37 | │   │   ├── panoptic_val2017
38 | │   │   ├── train2017
39 | │   │   ├── val2017
40 | │   │   ├── coco_captions.json  [Region Captions](https://drive.google.com/file/d/1hmXjZ1i8LMTkKg6JPnU5qfT1pIMulFqA/view?usp=drive_link)
41 | │   │   ├── coco_proposals.json
42 | │   │   └── coco_test.json  [Test Data](https://drive.google.com/file/d/1Qdqe8fINH79A53nb7D81KYski7iLsWyV/view?usp=drive_link)
43 | │   └── flickr30k
44 | │       ├── flickr30k_images
45 | │       └── flickr30k_test.json  [Test Data](https://drive.google.com/file/d/1zluOdlZ9JJL0XufMIw6i6ls0iW2EXYPf/view?usp=drive_link)
46 | ```
47 |
48 | ## Original Models
49 | To run FineCLIP, first obtain the original models from
50 | [EVA-02-CLIP](https://github.com/baaivision/EVA/tree/master/EVA-CLIP), and put them under
51 | `checkpoints/` as follows:
52 |
53 | ```text
54 | FineCLIP/
55 | ├── checkpoints
56 | │   ├── EVA02_CLIP_B_psz16_s8B.pt
57 | │   └── EVA02_CLIP_L_336_psz14_s6B.pt
58 | ```
59 |
60 |
61 | ## Start Training
62 |
63 | After preparing the data and models, you can start training FineCLIP by running the scripts under [scripts/](scripts).
64 |
65 | To train FineCLIP with the ViT-B/16 backbone on COCO, run:
66 | ```bash
67 | bash scripts/train_vitb16.sh
68 | ```
69 |
70 | To train FineCLIP with the ViT-L/14 backbone on COCO, run:
71 | ```bash
72 | bash scripts/train_vitl14.sh
73 | ```
74 |
75 |
76 | ## Evaluation
77 |
78 | We provide scripts to evaluate FineCLIP under [scripts/](scripts); they are summarized as follows:
79 |
80 | | # | Model Type | Task | Benchmark | script |
81 | |:---:|:----------:|:-----------------------:|:-------------:|:----------------------------------------:|
82 | | 1 | ViT-B/16 | Box Classification | COCO Panoptic | [script](scripts/test_vitb16_box.sh) |
83 | | 2 | ViT-B/16 | Image-text Retrieval | MSCOCO | [script](scripts/test_vitb16_mscoco.sh) |
84 | | 3 | ViT-B/16 | Image-text Retrieval | Flickr30K | [script](scripts/test_vitb16_flickr.sh) |
85 | | 4 | ViT-L/14 | Box Classification | COCO Panoptic | [script](scripts/test_vitl14_box.sh) |
86 | | 5 | ViT-L/14 | Image-text Retrieval | MSCOCO | [script](scripts/test_vitl14_mscoco.sh) |
87 | | 6 | ViT-L/14 | Image-text Retrieval | Flickr30K | [script](scripts/test_vitl14_flickr.sh) |
88 |
89 |
90 | ## FineCLIP Checkpoints
91 |
92 | We provide FineCLIP checkpoints trained on COCO :)
93 |
94 | [FineCLIP_coco_vitb16.pt](https://drive.google.com/file/d/119eeWzjsE2rpUFBs2rDMZ7QlK1RxZYL2/view?usp=drive_link)
95 |
96 | [FineCLIP_coco_vitl14.pt](https://drive.google.com/file/d/1lSIj5tWVVNVFNzuPHisOAdDQyCXWIzUH/view?usp=drive_link)
97 |
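A minimal loading sketch (assumptions: the bundled `open_clip` factory resolves the `EVA02-CLIP-B-16` config and accepts a local checkpoint path via `pretrained=`; check the `scripts/test_vitb16_*.sh` files for the exact model and checkpoint arguments the authors pass):

```python
# Hedged sketch of loading a released FineCLIP checkpoint for inference.
# Assumes the bundled open_clip factory accepts the EVA02-CLIP-B-16 config name
# and a local checkpoint path; "path/to/image.jpg" is a placeholder.
import torch
from torch.nn import functional as F
from PIL import Image

from open_clip import create_model_and_transforms, get_tokenizer

model, _, preprocess = create_model_and_transforms(
    "EVA02-CLIP-B-16",
    pretrained="checkpoints/FineCLIP_coco_vitb16.pt",
)
tokenizer = get_tokenizer("EVA02-CLIP-B-16")
model.eval()

image = preprocess(Image.open("path/to/image.jpg")).unsqueeze(0)
text = tokenizer(["a photo of a cat", "a photo of a dog"])

with torch.no_grad():
    image_features = F.normalize(model.encode_image(image), dim=-1)
    text_features = F.normalize(model.encode_text(text), dim=-1)
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
print(similarity)
```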
98 | ## Downstream Evaluations
99 |
100 | If you want to evaluate FineCLIP on downstream tasks such as open-vocabulary object detection and segmentation, please refer to [CLIPSelf](https://github.com/wusize/CLIPSelf/tree/main) and [CAT-Seg](https://github.com/cvlab-kaist/CAT-Seg).
101 | We thank the authors for their valuable code bases.
102 |
103 | ## Citation
104 |
105 | ```bibtex
106 | @article{fineclip,
107 | title={FineCLIP: Self-distilled Region-based CLIP for Better Fine-grained Understanding},
108 | author={Dong Jing and Xiaolong He and Yutian Luo and Nanyi Fei and Guoxing Yang and Wei Wei and Huiwen Zhao and Zhiwu Lu},
109 | journal={Advances in Neural Information Processing Systems},
110 | year={2024}
111 | }
112 | ```
113 |
114 | ## Acknowledgement
115 |
116 | We sincerely thank the following excellent works:
117 | [CLIPSelf](https://github.com/wusize/CLIPSelf/tree/main),
118 | [RegionCLIP](https://github.com/microsoft/RegionCLIP),
119 | [OpenCLIP](https://github.com/mlfoundations/open_clip/tree/v2.16.0),
120 | [EVA-CLIP](https://github.com/baaivision/EVA/tree/master/EVA-CLIP).
121 |
122 | ## License
123 |
124 |
125 |
--------------------------------------------------------------------------------
/figs/fineclip_arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/figs/fineclip_arch.png
--------------------------------------------------------------------------------
/metadata/coco_panoptic_clip_hand_craft_EVACLIP_ViTB16.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/metadata/coco_panoptic_clip_hand_craft_EVACLIP_ViTB16.npy
--------------------------------------------------------------------------------
/metadata/coco_panoptic_clip_hand_craft_EVACLIP_ViTL14x336.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/metadata/coco_panoptic_clip_hand_craft_EVACLIP_ViTL14x336.npy
--------------------------------------------------------------------------------
/open_clip/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/.DS_Store
--------------------------------------------------------------------------------
/open_clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .coca_model import CoCa
2 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
3 | from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
4 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint
5 | from .loss import ClipLoss, DistillClipLoss, CoCaLoss
6 | from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
7 | convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
8 | from .openai import load_openai_model, list_openai_models
9 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
10 | get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
11 | from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub
12 | from .tokenizer import SimpleTokenizer, tokenize, decode
13 | from .transform import image_transform, AugmentationCfg
14 |
--------------------------------------------------------------------------------
/open_clip/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/coca_model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/coca_model.cpython-38.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/coca_model.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/coca_model.cpython-39.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/constants.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/constants.cpython-38.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/constants.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/constants.cpython-39.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/factory.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/factory.cpython-38.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/factory.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/factory.cpython-39.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/hf_configs.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/hf_configs.cpython-38.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/hf_configs.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/hf_configs.cpython-39.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/hf_model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/hf_model.cpython-38.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/hf_model.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/hf_model.cpython-39.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/loss.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/loss.cpython-38.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/loss.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/loss.cpython-39.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/model.cpython-38.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/model.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/model.cpython-39.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/modified_resnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/modified_resnet.cpython-38.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/modified_resnet.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/modified_resnet.cpython-39.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/openai.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/openai.cpython-38.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/openai.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/openai.cpython-39.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/pretrained.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/pretrained.cpython-38.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/pretrained.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/pretrained.cpython-39.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/push_to_hf_hub.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/push_to_hf_hub.cpython-38.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/push_to_hf_hub.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/push_to_hf_hub.cpython-39.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/timm_model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/timm_model.cpython-38.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/timm_model.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/timm_model.cpython-39.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/tokenizer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/tokenizer.cpython-38.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/tokenizer.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/tokenizer.cpython-39.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/transform.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/transform.cpython-38.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/transform.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/transform.cpython-39.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/transformer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/transformer.cpython-38.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/transformer.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/transformer.cpython-39.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/utils.cpython-39.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/version.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/version.cpython-38.pyc
--------------------------------------------------------------------------------
/open_clip/__pycache__/version.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/__pycache__/version.cpython-39.pyc
--------------------------------------------------------------------------------
/open_clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/open_clip/constants.py:
--------------------------------------------------------------------------------
1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
3 |
--------------------------------------------------------------------------------
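The two tuples above are the standard CLIP image normalization statistics; `transform.py` in this package builds its preprocessing from them. A small illustration (assuming torchvision is installed):

```python
# Illustration only: the per-channel normalization used by CLIP-style preprocessing,
# built from the constants defined in open_clip/constants.py.
from torchvision.transforms import Normalize

from open_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD

normalize = Normalize(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD)
```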
/open_clip/customs.py:
--------------------------------------------------------------------------------
1 | from torch import Tensor
2 | from torch.nn import MultiheadAttention
3 | from torch.nn import functional as F
4 | from typing import Optional, Tuple
5 |
6 |
7 | class MultiheadSelfAttention(MultiheadAttention):
8 | def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,
9 | need_weights: bool = True, attn_mask: Optional[Tensor] = None, return_tokens: bool = False) \
10 | -> Tuple[Tensor, Tensor]:
11 | assert query is value and value is key # self-attention
12 | if return_tokens:
13 | # in_projection
14 | tokens = F.linear(value, self.in_proj_weight, bias=self.in_proj_bias)[..., -self.embed_dim:]
15 | # out_projection
16 | tokens = F.linear(tokens, self.out_proj.weight, bias=self.out_proj.bias)
17 | else:
18 | tokens = None
19 |
20 | attn_output, attn_output_weights = F.multi_head_attention_forward(
21 | query=query, key=key, value=value,
22 | embed_dim_to_check=self.embed_dim,
23 | num_heads=self.num_heads,
24 | in_proj_weight=self.in_proj_weight,
25 | in_proj_bias=self.in_proj_bias,
26 | bias_k=None, bias_v=None,
27 | add_zero_attn=False,
28 | dropout_p=0.,
29 | out_proj_weight=self.out_proj.weight,
30 | out_proj_bias=self.out_proj.bias,
31 | training=self.training,
32 | key_padding_mask=key_padding_mask, need_weights=need_weights,
33 | attn_mask=attn_mask)
34 |
35 | return attn_output, tokens # , attn_output_weights
36 |
--------------------------------------------------------------------------------
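A brief usage sketch of `MultiheadSelfAttention` above (an illustration, assuming inputs follow `nn.MultiheadAttention`'s default sequence-first layout): it behaves like standard self-attention, and with `return_tokens=True` it additionally returns the value tokens passed through only the value and output projections, i.e. dense per-token features without attention mixing.

```python
# Hedged usage sketch for open_clip/customs.py::MultiheadSelfAttention.
import torch

from open_clip.customs import MultiheadSelfAttention

attn = MultiheadSelfAttention(embed_dim=512, num_heads=8)
x = torch.randn(197, 2, 512)  # (seq_len, batch, embed_dim); e.g. 196 patches + CLS

# query, key and value must be the same tensor (the class asserts self-attention).
attn_output, tokens = attn(x, x, x, need_weights=False, return_tokens=True)
print(attn_output.shape, tokens.shape)  # torch.Size([197, 2, 512]) for both
```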
/open_clip/eva_clip/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/.DS_Store
--------------------------------------------------------------------------------
/open_clip/eva_clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
2 | from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer
3 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint
4 | from .loss import ClipLoss
5 | from .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg,\
6 | convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
7 | from .openai import load_openai_model, list_openai_models
8 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\
9 | get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
10 | from .tokenizer import SimpleTokenizer, tokenize
11 | from .transform import image_transform
--------------------------------------------------------------------------------
/open_clip/eva_clip/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/open_clip/eva_clip/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/open_clip/eva_clip/__pycache__/constants.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/constants.cpython-38.pyc
--------------------------------------------------------------------------------
/open_clip/eva_clip/__pycache__/constants.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/constants.cpython-39.pyc
--------------------------------------------------------------------------------
/open_clip/eva_clip/__pycache__/eva_vit_model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/eva_vit_model.cpython-38.pyc
--------------------------------------------------------------------------------
/open_clip/eva_clip/__pycache__/eva_vit_model.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/eva_vit_model.cpython-39.pyc
--------------------------------------------------------------------------------
/open_clip/eva_clip/__pycache__/factory.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/factory.cpython-38.pyc
--------------------------------------------------------------------------------
/open_clip/eva_clip/__pycache__/factory.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/factory.cpython-39.pyc
--------------------------------------------------------------------------------
/open_clip/eva_clip/__pycache__/hf_configs.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/hf_configs.cpython-38.pyc
--------------------------------------------------------------------------------
/open_clip/eva_clip/__pycache__/hf_configs.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/hf_configs.cpython-39.pyc
--------------------------------------------------------------------------------
/open_clip/eva_clip/__pycache__/hf_model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/hf_model.cpython-38.pyc
--------------------------------------------------------------------------------
/open_clip/eva_clip/__pycache__/hf_model.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/hf_model.cpython-39.pyc
--------------------------------------------------------------------------------
/open_clip/eva_clip/__pycache__/loss.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/loss.cpython-38.pyc
--------------------------------------------------------------------------------
/open_clip/eva_clip/__pycache__/loss.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/loss.cpython-39.pyc
--------------------------------------------------------------------------------
/open_clip/eva_clip/__pycache__/model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/model.cpython-38.pyc
--------------------------------------------------------------------------------
/open_clip/eva_clip/__pycache__/model.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/model.cpython-39.pyc
--------------------------------------------------------------------------------
/open_clip/eva_clip/__pycache__/modified_resnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/modified_resnet.cpython-38.pyc
--------------------------------------------------------------------------------
/open_clip/eva_clip/__pycache__/modified_resnet.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/modified_resnet.cpython-39.pyc
--------------------------------------------------------------------------------
/open_clip/eva_clip/__pycache__/openai.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/openai.cpython-38.pyc
--------------------------------------------------------------------------------
/open_clip/eva_clip/__pycache__/openai.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/openai.cpython-39.pyc
--------------------------------------------------------------------------------
/open_clip/eva_clip/__pycache__/pretrained.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/pretrained.cpython-38.pyc
--------------------------------------------------------------------------------
/open_clip/eva_clip/__pycache__/pretrained.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/pretrained.cpython-39.pyc
--------------------------------------------------------------------------------
/open_clip/eva_clip/__pycache__/rope.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/rope.cpython-38.pyc
--------------------------------------------------------------------------------
/open_clip/eva_clip/__pycache__/rope.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/rope.cpython-39.pyc
--------------------------------------------------------------------------------
/open_clip/eva_clip/__pycache__/timm_model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/timm_model.cpython-38.pyc
--------------------------------------------------------------------------------
/open_clip/eva_clip/__pycache__/timm_model.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/timm_model.cpython-39.pyc
--------------------------------------------------------------------------------
/open_clip/eva_clip/__pycache__/tokenizer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/tokenizer.cpython-38.pyc
--------------------------------------------------------------------------------
/open_clip/eva_clip/__pycache__/tokenizer.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/tokenizer.cpython-39.pyc
--------------------------------------------------------------------------------
/open_clip/eva_clip/__pycache__/transform.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/transform.cpython-38.pyc
--------------------------------------------------------------------------------
/open_clip/eva_clip/__pycache__/transform.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/transform.cpython-39.pyc
--------------------------------------------------------------------------------
/open_clip/eva_clip/__pycache__/transformer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/transformer.cpython-38.pyc
--------------------------------------------------------------------------------
/open_clip/eva_clip/__pycache__/transformer.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/transformer.cpython-39.pyc
--------------------------------------------------------------------------------
/open_clip/eva_clip/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/open_clip/eva_clip/__pycache__/utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/__pycache__/utils.cpython-39.pyc
--------------------------------------------------------------------------------
/open_clip/eva_clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/eva_clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/open_clip/eva_clip/constants.py:
--------------------------------------------------------------------------------
1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
3 |
--------------------------------------------------------------------------------
/open_clip/eva_clip/hf_configs.py:
--------------------------------------------------------------------------------
1 | # HF architecture dict:
2 | arch_dict = {
3 | # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
4 | "roberta": {
5 | "config_names": {
6 | "context_length": "max_position_embeddings",
7 | "vocab_size": "vocab_size",
8 | "width": "hidden_size",
9 | "heads": "num_attention_heads",
10 | "layers": "num_hidden_layers",
11 | "layer_attr": "layer",
12 | "token_embeddings_attr": "embeddings"
13 | },
14 | "pooler": "mean_pooler",
15 | },
16 | # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
17 | "xlm-roberta": {
18 | "config_names": {
19 | "context_length": "max_position_embeddings",
20 | "vocab_size": "vocab_size",
21 | "width": "hidden_size",
22 | "heads": "num_attention_heads",
23 | "layers": "num_hidden_layers",
24 | "layer_attr": "layer",
25 | "token_embeddings_attr": "embeddings"
26 | },
27 | "pooler": "mean_pooler",
28 | },
29 | # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
30 | "mt5": {
31 | "config_names": {
32 | # unlimited seqlen
33 | # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
34 | # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
35 | "context_length": "",
36 | "vocab_size": "vocab_size",
37 | "width": "d_model",
38 | "heads": "num_heads",
39 | "layers": "num_layers",
40 | "layer_attr": "block",
41 | "token_embeddings_attr": "embed_tokens"
42 | },
43 | "pooler": "mean_pooler",
44 | },
45 | "bert": {
46 | "config_names": {
47 | "context_length": "max_position_embeddings",
48 | "vocab_size": "vocab_size",
49 | "width": "hidden_size",
50 | "heads": "num_attention_heads",
51 | "layers": "num_hidden_layers",
52 | "layer_attr": "layer",
53 | "token_embeddings_attr": "embeddings"
54 | },
55 | "pooler": "mean_pooler",
56 | }
57 | }
58 |
--------------------------------------------------------------------------------
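For orientation: `arch_dict` maps a Hugging Face architecture name to the attribute names used by that architecture's config class, so callers such as `hf_model.py` in this package can read width, depth, and context length generically. A hedged sketch of that lookup, assuming `transformers` is installed:

```python
# Hedged sketch: translate generic keys ("width", "layers", ...) into the attribute
# names of a concrete Hugging Face config via arch_dict.
from transformers import AutoConfig

from open_clip.eva_clip.hf_configs import arch_dict

cfg = AutoConfig.from_pretrained("roberta-base")
names = arch_dict[cfg.model_type]["config_names"]  # cfg.model_type == "roberta"

width = getattr(cfg, names["width"])    # hidden_size
layers = getattr(cfg, names["layers"])  # num_hidden_layers
pooler = arch_dict[cfg.model_type]["pooler"]
print(width, layers, pooler)
```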
/open_clip/eva_clip/loss.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn import functional as F
5 |
6 | try:
7 | import torch.distributed.nn
8 | from torch import distributed as dist
9 | has_distributed = True
10 | except ImportError:
11 | has_distributed = False
12 |
13 | try:
14 | import horovod.torch as hvd
15 | except ImportError:
16 | hvd = None
17 |
18 | from timm.loss import LabelSmoothingCrossEntropy
19 |
20 |
21 | def gather_features(
22 | image_features,
23 | text_features,
24 | local_loss=False,
25 | gather_with_grad=False,
26 | rank=0,
27 | world_size=1,
28 | use_horovod=False,
29 | skip=True
30 | ):
31 | assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
32 | if use_horovod:
33 | assert hvd is not None, 'Please install horovod'
34 | if gather_with_grad:
35 | all_image_features = hvd.allgather(image_features)
36 | all_text_features = hvd.allgather(text_features)
37 | else:
38 | with torch.no_grad():
39 | all_image_features = hvd.allgather(image_features)
40 | all_text_features = hvd.allgather(text_features)
41 | if not local_loss:
42 | # ensure grads for local rank when all_* features don't have a gradient
43 | gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
44 | gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
45 | gathered_image_features[rank] = image_features
46 | gathered_text_features[rank] = text_features
47 | all_image_features = torch.cat(gathered_image_features, dim=0)
48 | all_text_features = torch.cat(gathered_text_features, dim=0)
49 | else:
50 | # We gather tensors from all gpus
51 | if gather_with_grad:
52 | all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
53 | all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
54 | # all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features, async_op=True), dim=0)
55 | # all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features, async_op=True), dim=0)
56 | else:
57 | gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
58 | gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
59 | dist.all_gather(gathered_image_features, image_features)
60 | dist.all_gather(gathered_text_features, text_features)
61 | if not local_loss:
62 | # ensure grads for local rank when all_* features don't have a gradient
63 | gathered_image_features[rank] = image_features
64 | gathered_text_features[rank] = text_features
65 | all_image_features = torch.cat(gathered_image_features, dim=0)
66 | all_text_features = torch.cat(gathered_text_features, dim=0)
67 | if skip == False:
68 | return all_image_features, all_text_features
69 | all_image_feature_list = []
70 | all_text_feature_list = []
71 | for image_feature in all_image_features:
72 | if (image_feature==torch.zeros(image_feature.shape[0], device=image_feature.device)).all():
73 | pass
74 | else:
75 | all_image_feature_list.append(image_feature.unsqueeze(0))
76 |
77 | for text_feature in all_text_features:
78 | if (text_feature==torch.zeros(text_feature.shape[0], device=text_feature.device)).all():
79 | pass
80 | else:
81 | all_text_feature_list.append(text_feature.unsqueeze(0))
82 | all_image_features_new = torch.cat(all_image_feature_list, dim=0)
83 | all_text_features_new = torch.cat(all_text_feature_list, dim=0)
84 | return all_image_features_new, all_text_features_new
85 | # return all_image_features, all_text_features
86 |
87 |
88 | class ClipLoss(nn.Module):
89 |
90 | def __init__(
91 | self,
92 | local_loss=False,
93 | gather_with_grad=False,
94 | cache_labels=False,
95 | rank=0,
96 | world_size=1,
97 | use_horovod=False,
98 | smoothing=0.,
99 | ):
100 | super().__init__()
101 | self.local_loss = local_loss
102 | self.gather_with_grad = gather_with_grad
103 | self.cache_labels = cache_labels
104 | self.rank = rank
105 | self.world_size = world_size
106 | self.use_horovod = use_horovod
107 | self.label_smoothing_cross_entropy = LabelSmoothingCrossEntropy(smoothing=smoothing) if smoothing > 0 else None
108 |
109 | # cache state
110 | self.prev_num_logits = 0
111 | self.labels = {}
112 |
113 | def forward(self, image_features, text_features, logit_scale=1., skip=True):
114 | device = image_features.device
115 | if self.world_size > 1:
116 | all_image_features, all_text_features = gather_features(
117 | image_features, text_features,
118 | self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod, skip)
119 |
120 | if self.local_loss:
121 | logits_per_image = logit_scale * image_features @ all_text_features.T
122 | logits_per_text = logit_scale * text_features @ all_image_features.T
123 | else:
124 | logits_per_image = logit_scale * all_image_features @ all_text_features.T
125 | logits_per_text = logits_per_image.T
126 | else:
127 | logits_per_image = logit_scale * image_features @ text_features.T
128 | logits_per_text = logit_scale * text_features @ image_features.T
129 | # calculated ground-truth and cache if enabled
130 | num_logits = logits_per_image.shape[0]
131 | if self.prev_num_logits != num_logits or device not in self.labels:
132 | labels = torch.arange(num_logits, device=device, dtype=torch.long)
133 | if self.world_size > 1 and self.local_loss:
134 | labels = labels + num_logits * self.rank
135 | if self.cache_labels:
136 | self.labels[device] = labels
137 | self.prev_num_logits = num_logits
138 | else:
139 | labels = self.labels[device]
140 |
141 | if self.label_smoothing_cross_entropy:
142 | total_loss = (
143 | self.label_smoothing_cross_entropy(logits_per_image, labels) +
144 | self.label_smoothing_cross_entropy(logits_per_text, labels)
145 | ) / 2
146 | else:
147 | total_loss = (
148 | F.cross_entropy(logits_per_image, labels) +
149 | F.cross_entropy(logits_per_text, labels)
150 | ) / 2
151 |
152 | acc = None
153 | i2t_acc = (logits_per_image.argmax(-1) == labels).sum() / len(logits_per_image)
154 | t2i_acc = (logits_per_text.argmax(-1) == labels).sum() / len(logits_per_text)
155 | acc = {"i2t": i2t_acc, "t2i": t2i_acc}
156 | return total_loss, acc
--------------------------------------------------------------------------------
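A single-process usage sketch of `ClipLoss` above (illustration only): with `world_size=1` no feature gathering happens and the loss reduces to the symmetric cross-entropy over the in-batch image-text similarity matrix, with image-to-text and text-to-image retrieval accuracies returned for logging.

```python
# Hedged single-GPU sketch of ClipLoss (world_size=1, so gather_features is not used).
import torch
from torch.nn import functional as F

from open_clip.eva_clip.loss import ClipLoss

loss_fn = ClipLoss(local_loss=False, world_size=1)

# Random L2-normalized stand-ins for a batch of 8 paired image/text embeddings.
image_features = F.normalize(torch.randn(8, 512), dim=-1)
text_features = F.normalize(torch.randn(8, 512), dim=-1)
logit_scale = torch.tensor(100.0)  # exp() of the learned temperature, as in CLIP

total_loss, acc = loss_fn(image_features, text_features, logit_scale)
print(total_loss.item(), acc["i2t"].item(), acc["t2i"].item())
```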
/open_clip/eva_clip/model_configs/EVA01-CLIP-B-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "patch_size": 16,
8 | "eva_model_name": "eva-clip-b-16",
9 | "ls_init_value": 0.1,
10 | "drop_path_rate": 0.0
11 | },
12 | "text_cfg": {
13 | "context_length": 77,
14 | "vocab_size": 49408,
15 | "width": 512,
16 | "heads": 8,
17 | "layers": 12
18 | }
19 | }
--------------------------------------------------------------------------------
/open_clip/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 40,
6 | "width": 1408,
7 | "head_width": 88,
8 | "mlp_ratio": 4.3637,
9 | "patch_size": 14,
10 | "eva_model_name": "eva-clip-g-14-x",
11 | "drop_path_rate": 0,
12 | "xattn": true,
13 | "fusedLN": true
14 | },
15 | "text_cfg": {
16 | "context_length": 77,
17 | "vocab_size": 49408,
18 | "width": 1024,
19 | "heads": 16,
20 | "layers": 24,
21 | "xattn": false,
22 | "fusedLN": true
23 | }
24 | }
--------------------------------------------------------------------------------
/open_clip/eva_clip/model_configs/EVA01-CLIP-g-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 40,
6 | "width": 1408,
7 | "head_width": 88,
8 | "mlp_ratio": 4.3637,
9 | "patch_size": 14,
10 | "eva_model_name": "eva-clip-g-14-x",
11 | "drop_path_rate": 0.4,
12 | "xattn": true,
13 | "fusedLN": true
14 | },
15 | "text_cfg": {
16 | "context_length": 77,
17 | "vocab_size": 49408,
18 | "width": 768,
19 | "heads": 12,
20 | "layers": 12,
21 | "xattn": false,
22 | "fusedLN": true
23 | }
24 | }
--------------------------------------------------------------------------------
/open_clip/eva_clip/model_configs/EVA02-CLIP-B-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "head_width": 64,
8 | "patch_size": 16,
9 | "mlp_ratio": 2.6667,
10 | "eva_model_name": "eva-clip-b-16-X",
11 | "drop_path_rate": 0.0,
12 | "xattn": true,
13 | "fusedLN": true,
14 | "rope": true,
15 | "pt_hw_seq_len": 16,
16 | "intp_freq": true,
17 | "naiveswiglu": true,
18 | "subln": true
19 | },
20 | "text_cfg": {
21 | "context_length": 77,
22 | "vocab_size": 49408,
23 | "width": 512,
24 | "heads": 8,
25 | "layers": 12,
26 | "xattn": true,
27 | "fusedLN": true
28 | }
29 | }
--------------------------------------------------------------------------------
/open_clip/eva_clip/model_configs/EVA02-CLIP-L-14-336.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 336,
5 | "layers": 24,
6 | "width": 1024,
7 | "drop_path_rate": 0,
8 | "head_width": 64,
9 | "mlp_ratio": 2.6667,
10 | "patch_size": 14,
11 | "eva_model_name": "eva-clip-l-14-336",
12 | "xattn": true,
13 | "fusedLN": true,
14 | "rope": true,
15 | "pt_hw_seq_len": 16,
16 | "intp_freq": true,
17 | "naiveswiglu": true,
18 | "subln": true
19 | },
20 | "text_cfg": {
21 | "context_length": 77,
22 | "vocab_size": 49408,
23 | "width": 768,
24 | "heads": 12,
25 | "layers": 12,
26 | "xattn": false,
27 | "fusedLN": true
28 | }
29 | }
--------------------------------------------------------------------------------
/open_clip/eva_clip/model_configs/EVA02-CLIP-L-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 24,
6 | "width": 1024,
7 | "drop_path_rate": 0,
8 | "head_width": 64,
9 | "mlp_ratio": 2.6667,
10 | "patch_size": 14,
11 | "eva_model_name": "eva-clip-l-14",
12 | "xattn": true,
13 | "fusedLN": true,
14 | "rope": true,
15 | "pt_hw_seq_len": 16,
16 | "intp_freq": true,
17 | "naiveswiglu": true,
18 | "subln": true
19 | },
20 | "text_cfg": {
21 | "context_length": 77,
22 | "vocab_size": 49408,
23 | "width": 768,
24 | "heads": 12,
25 | "layers": 12,
26 | "xattn": false,
27 | "fusedLN": true
28 | }
29 | }
--------------------------------------------------------------------------------
/open_clip/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 64,
6 | "width": 1792,
7 | "head_width": 112,
8 | "mlp_ratio": 8.571428571428571,
9 | "patch_size": 14,
10 | "eva_model_name": "eva-clip-4b-14-x",
11 | "drop_path_rate": 0,
12 | "xattn": true,
13 | "postnorm": true,
14 | "fusedLN": true
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 1280,
20 | "heads": 20,
21 | "layers": 32,
22 | "xattn": false,
23 | "fusedLN": true
24 | }
25 | }
26 |
--------------------------------------------------------------------------------
/open_clip/eva_clip/model_configs/EVA02-CLIP-bigE-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 64,
6 | "width": 1792,
7 | "head_width": 112,
8 | "mlp_ratio": 8.571428571428571,
9 | "patch_size": 14,
10 | "eva_model_name": "eva-clip-4b-14-x",
11 | "drop_path_rate": 0,
12 | "xattn": true,
13 | "postnorm": true,
14 | "fusedLN": true
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 1024,
20 | "heads": 16,
21 | "layers": 24,
22 | "xattn": false,
23 | "fusedLN": true
24 | }
25 | }
--------------------------------------------------------------------------------
/open_clip/eva_clip/modified_resnet.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 | from open_clip.eva_clip.utils import freeze_batch_norm_2d
8 |
9 |
10 | class Bottleneck(nn.Module):
11 | expansion = 4
12 |
13 | def __init__(self, inplanes, planes, stride=1):
14 | super().__init__()
15 |
16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18 | self.bn1 = nn.BatchNorm2d(planes)
19 | self.act1 = nn.ReLU(inplace=True)
20 |
21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
22 | self.bn2 = nn.BatchNorm2d(planes)
23 | self.act2 = nn.ReLU(inplace=True)
24 |
25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
26 |
27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
29 | self.act3 = nn.ReLU(inplace=True)
30 |
31 | self.downsample = None
32 | self.stride = stride
33 |
34 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
36 | self.downsample = nn.Sequential(OrderedDict([
37 | ("-1", nn.AvgPool2d(stride)),
38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
39 | ("1", nn.BatchNorm2d(planes * self.expansion))
40 | ]))
41 |
42 | def forward(self, x: torch.Tensor):
43 | identity = x
44 |
45 | out = self.act1(self.bn1(self.conv1(x)))
46 | out = self.act2(self.bn2(self.conv2(out)))
47 | out = self.avgpool(out)
48 | out = self.bn3(self.conv3(out))
49 |
50 | if self.downsample is not None:
51 | identity = self.downsample(x)
52 |
53 | out += identity
54 | out = self.act3(out)
55 | return out
56 |
57 |
58 | class AttentionPool2d(nn.Module):
59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
60 | super().__init__()
61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
62 | self.k_proj = nn.Linear(embed_dim, embed_dim)
63 | self.q_proj = nn.Linear(embed_dim, embed_dim)
64 | self.v_proj = nn.Linear(embed_dim, embed_dim)
65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
66 | self.num_heads = num_heads
67 |
68 | def forward(self, x):
69 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
72 | x, _ = F.multi_head_attention_forward(
73 | query=x, key=x, value=x,
74 | embed_dim_to_check=x.shape[-1],
75 | num_heads=self.num_heads,
76 | q_proj_weight=self.q_proj.weight,
77 | k_proj_weight=self.k_proj.weight,
78 | v_proj_weight=self.v_proj.weight,
79 | in_proj_weight=None,
80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
81 | bias_k=None,
82 | bias_v=None,
83 | add_zero_attn=False,
84 | dropout_p=0.,
85 | out_proj_weight=self.c_proj.weight,
86 | out_proj_bias=self.c_proj.bias,
87 | use_separate_proj_weight=True,
88 | training=self.training,
89 | need_weights=False
90 | )
91 |
92 | return x[0]
93 |
94 |
95 | class ModifiedResNet(nn.Module):
96 | """
97 | A ResNet class that is similar to torchvision's but contains the following changes:
98 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
99 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
100 | - The final pooling layer is a QKV attention instead of an average pool
101 | """
102 |
103 | def __init__(self, layers, output_dim, heads, image_size=224, width=64):
104 | super().__init__()
105 | self.output_dim = output_dim
106 | self.image_size = image_size
107 |
108 | # the 3-layer stem
109 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
110 | self.bn1 = nn.BatchNorm2d(width // 2)
111 | self.act1 = nn.ReLU(inplace=True)
112 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
113 | self.bn2 = nn.BatchNorm2d(width // 2)
114 | self.act2 = nn.ReLU(inplace=True)
115 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
116 | self.bn3 = nn.BatchNorm2d(width)
117 | self.act3 = nn.ReLU(inplace=True)
118 | self.avgpool = nn.AvgPool2d(2)
119 |
120 | # residual layers
121 | self._inplanes = width # this is a *mutable* variable used during construction
122 | self.layer1 = self._make_layer(width, layers[0])
123 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
124 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
125 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
126 |
127 | embed_dim = width * 32 # the ResNet feature dimension
128 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
129 |
130 | self.init_parameters()
131 |
132 | def _make_layer(self, planes, blocks, stride=1):
133 | layers = [Bottleneck(self._inplanes, planes, stride)]
134 |
135 | self._inplanes = planes * Bottleneck.expansion
136 | for _ in range(1, blocks):
137 | layers.append(Bottleneck(self._inplanes, planes))
138 |
139 | return nn.Sequential(*layers)
140 |
141 | def init_parameters(self):
142 | if self.attnpool is not None:
143 | std = self.attnpool.c_proj.in_features ** -0.5
144 | nn.init.normal_(self.attnpool.q_proj.weight, std=std)
145 | nn.init.normal_(self.attnpool.k_proj.weight, std=std)
146 | nn.init.normal_(self.attnpool.v_proj.weight, std=std)
147 | nn.init.normal_(self.attnpool.c_proj.weight, std=std)
148 |
149 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
150 | for name, param in resnet_block.named_parameters():
151 | if name.endswith("bn3.weight"):
152 | nn.init.zeros_(param)
153 |
154 | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
155 | assert unlocked_groups == 0, 'partial locking not currently supported for this model'
156 | for param in self.parameters():
157 | param.requires_grad = False
158 | if freeze_bn_stats:
159 | freeze_batch_norm_2d(self)
160 |
161 | @torch.jit.ignore
162 | def set_grad_checkpointing(self, enable=True):
163 | # FIXME support for non-transformer
164 | pass
165 |
166 | def stem(self, x):
167 | x = self.act1(self.bn1(self.conv1(x)))
168 | x = self.act2(self.bn2(self.conv2(x)))
169 | x = self.act3(self.bn3(self.conv3(x)))
170 | x = self.avgpool(x)
171 | return x
172 |
173 | def forward(self, x):
174 | x = self.stem(x)
175 | x = self.layer1(x)
176 | x = self.layer2(x)
177 | x = self.layer3(x)
178 | x = self.layer4(x)
179 | x = self.attnpool(x)
180 |
181 | return x
182 |
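ModifiedResNet pools its final 7x7 feature map with AttentionPool2d, so the forward pass returns one embedding per image rather than a spatial map. A minimal sketch, assuming the repository root is on PYTHONPATH and using the RN50-style layer layout (3, 4, 6, 3) purely as an example (width=64 gives embed_dim = 64 * 32 = 2048, which is divisible by heads=32):

    import torch
    from open_clip.eva_clip.modified_resnet import ModifiedResNet

    model = ModifiedResNet(layers=(3, 4, 6, 3), output_dim=1024, heads=32,
                           image_size=224, width=64)
    with torch.no_grad():
        out = model(torch.randn(2, 3, 224, 224))  # stem /4, three strided stages /8 -> 7x7 grid
    print(out.shape)                              # torch.Size([2, 1024])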
--------------------------------------------------------------------------------
/open_clip/eva_clip/openai.py:
--------------------------------------------------------------------------------
1 | """ OpenAI pretrained model functions
2 |
3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4 | """
5 |
6 | import os
7 | import warnings
8 | from typing import List, Optional, Union
9 |
10 | import torch
11 |
12 | from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
13 | from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
14 |
15 | __all__ = ["list_openai_models", "load_openai_model"]
16 |
17 |
18 | def list_openai_models() -> List[str]:
19 | """Returns the names of available CLIP models"""
20 | return list_pretrained_models_by_tag('openai')
21 |
22 |
23 | def load_openai_model(
24 | name: str,
25 | precision: Optional[str] = None,
26 | device: Optional[Union[str, torch.device]] = None,
27 | jit: bool = True,
28 | cache_dir: Optional[str] = None,
29 | ):
30 | """Load a CLIP model
31 |
32 | Parameters
33 | ----------
34 | name : str
35 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
36 | precision: str
37 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
38 | device : Union[str, torch.device]
39 | The device to put the loaded model
40 | jit : bool
41 | Whether to load the optimized JIT model (default) or more hackable non-JIT model.
42 | cache_dir : Optional[str]
43 | The directory to cache the downloaded model weights
44 |
45 | Returns
46 | -------
47 | model : torch.nn.Module
48 | The CLIP model
49 | preprocess : Callable[[PIL.Image], torch.Tensor]
50 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
51 | """
52 | if device is None:
53 | device = "cuda" if torch.cuda.is_available() else "cpu"
54 | if precision is None:
55 | precision = 'fp32' if device == 'cpu' else 'fp16'
56 |
57 | if get_pretrained_url(name, 'openai'):
58 | model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
59 | elif os.path.isfile(name):
60 | model_path = name
61 | else:
62 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
63 |
64 | try:
65 | # loading JIT archive
66 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
67 | state_dict = None
68 | except RuntimeError:
69 | # loading saved state dict
70 | if jit:
71 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
72 | jit = False
73 | state_dict = torch.load(model_path, map_location="cpu")
74 |
75 | if not jit:
76 | # Build a non-jit model from the OpenAI jitted model state dict
77 | cast_dtype = get_cast_dtype(precision)
78 | try:
79 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
80 | except KeyError:
81 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
82 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
83 |
84 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
85 | model = model.to(device)
86 | if precision.startswith('amp') or precision == 'fp32':
87 | model.float()
88 | elif precision == 'bf16':
89 | convert_weights_to_lp(model, dtype=torch.bfloat16)
90 |
91 | return model
92 |
93 | # patch the device names
94 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
95 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
96 |
97 | def patch_device(module):
98 | try:
99 | graphs = [module.graph] if hasattr(module, "graph") else []
100 | except RuntimeError:
101 | graphs = []
102 |
103 | if hasattr(module, "forward1"):
104 | graphs.append(module.forward1.graph)
105 |
106 | for graph in graphs:
107 | for node in graph.findAllNodes("prim::Constant"):
108 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
109 | node.copyAttributes(device_node)
110 |
111 | model.apply(patch_device)
112 | patch_device(model.encode_image)
113 | patch_device(model.encode_text)
114 |
115 | # patch dtype to float32 (typically for CPU)
116 | if precision == 'fp32':
117 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
118 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
119 | float_node = float_input.node()
120 |
121 | def patch_float(module):
122 | try:
123 | graphs = [module.graph] if hasattr(module, "graph") else []
124 | except RuntimeError:
125 | graphs = []
126 |
127 | if hasattr(module, "forward1"):
128 | graphs.append(module.forward1.graph)
129 |
130 | for graph in graphs:
131 | for node in graph.findAllNodes("aten::to"):
132 | inputs = list(node.inputs())
133 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
134 | if inputs[i].node()["value"] == 5:
135 | inputs[i].node().copyAttributes(float_node)
136 |
137 | model.apply(patch_float)
138 | patch_float(model.encode_image)
139 | patch_float(model.encode_text)
140 | model.float()
141 |
142 | # ensure image_size attr available at consistent location for both jit and non-jit
143 | model.visual.image_size = model.input_resolution.item()
144 | return model
145 |
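A minimal usage sketch for the loader above. The name must be one returned by list_openai_models() (i.e. registered under the 'openai' tag in this package's pretrained.py), and the first call downloads the weights into cache_dir, so network access is assumed; "ViT-B-16" is only an example name.

    from open_clip.eva_clip.openai import list_openai_models, load_openai_model

    print(list_openai_models())   # names registered under the 'openai' tag
    model = load_openai_model("ViT-B-16", device="cpu", jit=False, cache_dir="checkpoints")
    # With device='cpu', precision defaults to 'fp32', so the fp16 OpenAI weights are upcast.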
--------------------------------------------------------------------------------
/open_clip/eva_clip/timm_model.py:
--------------------------------------------------------------------------------
1 | """ timm model adapter
2 |
3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
4 | """
5 | import logging
6 | from collections import OrderedDict
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 | try:
12 | import timm
13 | from timm.models.layers import Mlp, to_2tuple
14 | try:
15 | # old timm imports < 0.8.1
16 | from timm.models.layers.attention_pool2d import RotAttentionPool2d
17 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
18 | except ImportError:
19 | # new timm imports >= 0.8.1
20 | from timm.layers import RotAttentionPool2d
21 | from timm.layers import AttentionPool2d as AbsAttentionPool2d
22 | except ImportError:
23 | timm = None
24 |
25 | from .utils import freeze_batch_norm_2d
26 |
27 |
28 | class TimmModel(nn.Module):
29 | """ timm model adapter
30 | # FIXME this adapter is a work in progress, may change in ways that break weight compat
31 | """
32 |
33 | def __init__(
34 | self,
35 | model_name,
36 | embed_dim,
37 | image_size=224,
38 | pool='avg',
39 | proj='linear',
40 | proj_bias=False,
41 | drop=0.,
42 | pretrained=False):
43 | super().__init__()
44 | if timm is None:
45 | raise RuntimeError("Please `pip install timm` to use timm models.")
46 |
47 | self.image_size = to_2tuple(image_size)
48 | self.trunk = timm.create_model(model_name, pretrained=pretrained)
49 | feat_size = self.trunk.default_cfg.get('pool_size', None)
50 | feature_ndim = 1 if not feat_size else 2
51 | if pool in ('abs_attn', 'rot_attn'):
52 | assert feature_ndim == 2
53 | # if attn pooling used, remove both classifier and default pool
54 | self.trunk.reset_classifier(0, global_pool='')
55 | else:
56 | # reset global pool if pool config set, otherwise leave as network default
57 | reset_kwargs = dict(global_pool=pool) if pool else {}
58 | self.trunk.reset_classifier(0, **reset_kwargs)
59 | prev_chs = self.trunk.num_features
60 |
61 | head_layers = OrderedDict()
62 | if pool == 'abs_attn':
63 | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
64 | prev_chs = embed_dim
65 | elif pool == 'rot_attn':
66 | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
67 | prev_chs = embed_dim
68 | else:
69 | assert proj, 'projection layer needed if non-attention pooling is used.'
70 |
71 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
72 | if proj == 'linear':
73 | head_layers['drop'] = nn.Dropout(drop)
74 | head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
75 | elif proj == 'mlp':
76 | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias))
77 |
78 | self.head = nn.Sequential(head_layers)
79 |
80 | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
81 | """ lock modules
82 | Args:
83 | unlocked_groups (int): leave last n layer groups unlocked (default: 0)
84 | """
85 | if not unlocked_groups:
86 | # lock full model
87 | for param in self.trunk.parameters():
88 | param.requires_grad = False
89 | if freeze_bn_stats:
90 | freeze_batch_norm_2d(self.trunk)
91 | else:
92 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change
93 | try:
94 | # FIXME import here until API stable and in an official release
95 | from timm.models.helpers import group_parameters, group_modules
96 | except ImportError:
97 | raise RuntimeError(
98 | 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
99 | matcher = self.trunk.group_matcher()
100 | gparams = group_parameters(self.trunk, matcher)
101 | max_layer_id = max(gparams.keys())
102 | max_layer_id = max_layer_id - unlocked_groups
103 | for group_idx in range(max_layer_id + 1):
104 | group = gparams[group_idx]
105 | for param in group:
106 | self.trunk.get_parameter(param).requires_grad = False
107 | if freeze_bn_stats:
108 | gmodules = group_modules(self.trunk, matcher, reverse=True)
109 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
110 | freeze_batch_norm_2d(self.trunk, gmodules)
111 |
112 | @torch.jit.ignore
113 | def set_grad_checkpointing(self, enable=True):
114 | try:
115 | self.trunk.set_grad_checkpointing(enable)
116 | except Exception as e:
117 | logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')
118 |
119 | def forward(self, x):
120 | x = self.trunk(x)
121 | x = self.head(x)
122 | return x
123 |
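A minimal sketch of the adapter above, assuming timm is installed; "resnet50" is an arbitrary example backbone and pretrained=False avoids any download. With pool='avg' and proj='linear', the trunk's pooled features are projected to embed_dim:

    import torch
    from open_clip.eva_clip.timm_model import TimmModel

    tower = TimmModel("resnet50", embed_dim=512, image_size=224,
                      pool="avg", proj="linear", pretrained=False)
    with torch.no_grad():
        feats = tower(torch.randn(2, 3, 224, 224))
    print(feats.shape)   # torch.Size([2, 512]): 2048-dim pooled features -> linear proj to 512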
--------------------------------------------------------------------------------
/open_clip/eva_clip/tokenizer.py:
--------------------------------------------------------------------------------
1 | """ CLIP tokenizer
2 |
3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4 | """
5 | import gzip
6 | import html
7 | import os
8 | from functools import lru_cache
9 | from typing import Union, List
10 |
11 | import ftfy
12 | import regex as re
13 | import torch
14 |
15 | # https://stackoverflow.com/q/62691279
16 | import os
17 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
18 |
19 |
20 | @lru_cache()
21 | def default_bpe():
22 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
23 |
24 |
25 | @lru_cache()
26 | def bytes_to_unicode():
27 | """
28 | Returns list of utf-8 byte and a corresponding list of unicode strings.
29 | The reversible bpe codes work on unicode strings.
30 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
31 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
32 | This is a signficant percentage of your normal, say, 32K bpe vocab.
33 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
34 | And avoids mapping to whitespace/control characters the bpe code barfs on.
35 | """
36 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
37 | cs = bs[:]
38 | n = 0
39 | for b in range(2**8):
40 | if b not in bs:
41 | bs.append(b)
42 | cs.append(2**8+n)
43 | n += 1
44 | cs = [chr(n) for n in cs]
45 | return dict(zip(bs, cs))
46 |
47 |
48 | def get_pairs(word):
49 | """Return set of symbol pairs in a word.
50 | Word is represented as tuple of symbols (symbols being variable-length strings).
51 | """
52 | pairs = set()
53 | prev_char = word[0]
54 | for char in word[1:]:
55 | pairs.add((prev_char, char))
56 | prev_char = char
57 | return pairs
58 |
59 |
60 | def basic_clean(text):
61 | text = ftfy.fix_text(text)
62 | text = html.unescape(html.unescape(text))
63 | return text.strip()
64 |
65 |
66 | def whitespace_clean(text):
67 | text = re.sub(r'\s+', ' ', text)
68 | text = text.strip()
69 | return text
70 |
71 |
72 | class SimpleTokenizer(object):
73 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
74 | self.byte_encoder = bytes_to_unicode()
75 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
76 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
77 | merges = merges[1:49152-256-2+1]
78 | merges = [tuple(merge.split()) for merge in merges]
79 | vocab = list(bytes_to_unicode().values())
80 | vocab = vocab + [v+'</w>' for v in vocab]
81 | for merge in merges:
82 | vocab.append(''.join(merge))
83 | if not special_tokens:
84 | special_tokens = ['<start_of_text>', '<end_of_text>']
85 | else:
86 | special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
87 | vocab.extend(special_tokens)
88 | self.encoder = dict(zip(vocab, range(len(vocab))))
89 | self.decoder = {v: k for k, v in self.encoder.items()}
90 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
91 | self.cache = {t:t for t in special_tokens}
92 | special = "|".join(special_tokens)
93 | self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
94 |
95 | self.vocab_size = len(self.encoder)
96 | self.all_special_ids = [self.encoder[t] for t in special_tokens]
97 |
98 | def bpe(self, token):
99 | if token in self.cache:
100 | return self.cache[token]
101 | word = tuple(token[:-1]) + ( token[-1] + '</w>',)
102 | pairs = get_pairs(word)
103 |
104 | if not pairs:
105 | return token+'</w>'
106 |
107 | while True:
108 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
109 | if bigram not in self.bpe_ranks:
110 | break
111 | first, second = bigram
112 | new_word = []
113 | i = 0
114 | while i < len(word):
115 | try:
116 | j = word.index(first, i)
117 | new_word.extend(word[i:j])
118 | i = j
119 | except:
120 | new_word.extend(word[i:])
121 | break
122 |
123 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
124 | new_word.append(first+second)
125 | i += 2
126 | else:
127 | new_word.append(word[i])
128 | i += 1
129 | new_word = tuple(new_word)
130 | word = new_word
131 | if len(word) == 1:
132 | break
133 | else:
134 | pairs = get_pairs(word)
135 | word = ' '.join(word)
136 | self.cache[token] = word
137 | return word
138 |
139 | def encode(self, text):
140 | bpe_tokens = []
141 | text = whitespace_clean(basic_clean(text)).lower()
142 | for token in re.findall(self.pat, text):
143 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
144 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
145 | return bpe_tokens
146 |
147 | def decode(self, tokens):
148 | text = ''.join([self.decoder[token] for token in tokens])
149 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
150 | return text
151 |
152 |
153 | _tokenizer = SimpleTokenizer()
154 |
155 |
156 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
157 | """
158 | Returns the tokenized representation of given input string(s)
159 |
160 | Parameters
161 | ----------
162 | texts : Union[str, List[str]]
163 | An input string or a list of input strings to tokenize
164 | context_length : int
165 | The context length to use; all CLIP models use 77 as the context length
166 |
167 | Returns
168 | -------
169 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
170 | """
171 | if isinstance(texts, str):
172 | texts = [texts]
173 |
174 | sot_token = _tokenizer.encoder["<start_of_text>"]
175 | eot_token = _tokenizer.encoder["<end_of_text>"]
176 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
177 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
178 |
179 | for i, tokens in enumerate(all_tokens):
180 | if len(tokens) > context_length:
181 | tokens = tokens[:context_length] # Truncate
182 | tokens[-1] = eot_token
183 | result[i, :len(tokens)] = torch.tensor(tokens)
184 |
185 | return result
186 |
187 |
188 | class HFTokenizer:
189 | "HuggingFace tokenizer wrapper"
190 | def __init__(self, tokenizer_name:str):
191 | from transformers import AutoTokenizer
192 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
193 |
194 | def __call__(self, texts:Union[str, List[str]], context_length:int=77) -> torch.Tensor:
195 | # same cleaning as for default tokenizer, except lowercasing
196 | # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
197 | if isinstance(texts, str):
198 | texts = [texts]
199 | texts = [whitespace_clean(basic_clean(text)) for text in texts]
200 | input_ids = self.tokenizer(texts, return_tensors='pt', max_length=context_length, padding='max_length', truncation=True).input_ids
201 | return input_ids
202 |
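Quick usage of the tokenizer above: tokenize() truncates or zero-pads to context_length and inserts the <start_of_text>/<end_of_text> ids, while SimpleTokenizer.encode()/decode() round-trip raw text (the </w> word-end markers become spaces on decode).

    from open_clip.eva_clip.tokenizer import SimpleTokenizer, tokenize

    tokens = tokenize(["a photo of a cat", "a diagram"])
    print(tokens.shape)        # torch.Size([2, 77])

    tok = SimpleTokenizer()
    ids = tok.encode("a photo of a cat")
    print(tok.decode(ids))     # 'a photo of a cat ' (trailing space from the last </w>)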
--------------------------------------------------------------------------------
/open_clip/eva_clip/transform.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Sequence, Tuple
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torchvision.transforms.functional as F
6 |
7 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
8 | CenterCrop
9 |
10 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
11 |
12 |
13 | class ResizeMaxSize(nn.Module):
14 |
15 | def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
16 | super().__init__()
17 | if not isinstance(max_size, int):
18 | raise TypeError(f"Size should be int. Got {type(max_size)}")
19 | self.max_size = max_size
20 | self.interpolation = interpolation
21 | self.fn = min if fn == 'min' else min
22 | self.fill = fill
23 |
24 | def forward(self, img):
25 | if isinstance(img, torch.Tensor):
26 | height, width = img.shape[:2]
27 | else:
28 | width, height = img.size
29 | scale = self.max_size / float(max(height, width))
30 | if scale != 1.0:
31 | new_size = tuple(round(dim * scale) for dim in (height, width))
32 | img = F.resize(img, new_size, self.interpolation)
33 | pad_h = self.max_size - new_size[0]
34 | pad_w = self.max_size - new_size[1]
35 | img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)
36 | return img
37 |
38 |
39 | def _convert_to_rgb(image):
40 | return image.convert('RGB')
41 |
42 |
43 | # class CatGen(nn.Module):
44 | # def __init__(self, num=4):
45 | # self.num = num
46 | # def mixgen_batch(image, text):
47 | # batch_size = image.shape[0]
48 | # index = np.random.permutation(batch_size)
49 |
50 | # cat_images = []
51 | # for i in range(batch_size):
52 | # # image mixup
53 | # image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:]
54 | # # text concat
55 | # text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0]
56 | # text = torch.stack(text)
57 | # return image, text
58 |
59 |
60 | def image_transform(
61 | image_size: int,
62 | is_train: bool,
63 | mean: Optional[Tuple[float, ...]] = None,
64 | std: Optional[Tuple[float, ...]] = None,
65 | resize_longest_max: bool = False,
66 | fill_color: int = 0,
67 | ):
68 | mean = mean or OPENAI_DATASET_MEAN
69 | if not isinstance(mean, (list, tuple)):
70 | mean = (mean,) * 3
71 |
72 | std = std or OPENAI_DATASET_STD
73 | if not isinstance(std, (list, tuple)):
74 | std = (std,) * 3
75 |
76 | if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
77 | # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
78 | image_size = image_size[0]
79 |
80 | normalize = Normalize(mean=mean, std=std)
81 | if is_train:
82 | return Compose([
83 | RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
84 | _convert_to_rgb,
85 | ToTensor(),
86 | normalize,
87 | ])
88 | else:
89 | if resize_longest_max:
90 | transforms = [
91 | ResizeMaxSize(image_size, fill=fill_color)
92 | ]
93 | else:
94 | transforms = [
95 | Resize(image_size, interpolation=InterpolationMode.BICUBIC),
96 | CenterCrop(image_size),
97 | ]
98 | transforms.extend([
99 | _convert_to_rgb,
100 | ToTensor(),
101 | normalize,
102 | ])
103 | return Compose(transforms)
104 |
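A minimal sketch of the eval-time pipeline built by image_transform above (Resize shortest edge, CenterCrop, ToTensor, Normalize with the OpenAI statistics); the in-memory PIL image is only a stand-in for a real file:

    import torch
    from PIL import Image
    from open_clip.eva_clip.transform import image_transform

    preprocess = image_transform(image_size=224, is_train=False)
    img = Image.new("RGB", (640, 480), color=(128, 128, 128))
    x = preprocess(img)
    print(x.shape)             # torch.Size([3, 224, 224])
    batch = x.unsqueeze(0)     # add a batch dimension before feeding the vision tower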
--------------------------------------------------------------------------------
/open_clip/generation_utils.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/open_clip/generation_utils.py
--------------------------------------------------------------------------------
/open_clip/get_text_embedding_voc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import sys
3 | sys.path.append("..")
4 | from open_clip import create_model_and_transforms, get_tokenizer
5 | from PIL import Image
6 | import json
7 | import tqdm
8 | import os
9 | import numpy as np
10 |
11 | model_name = "EVA02-CLIP-B-16"
12 | pretrained = "/home/xiaolong_he/works/paper/logs/new_ours_only_itc_lr1e_5_bs32_epoch10/checkpoints/epoch_10.pt" # or "/path/to/EVA02_CLIP_B_psz16_s8B.pt"
13 |
14 | device = "cuda:0" if torch.cuda.is_available() else "cpu"
15 | model, _, preprocess = create_model_and_transforms(
16 | "EVA02-CLIP-B-16",
17 | "eva",
18 | "amp",
19 | device="cpu",
20 | jit=False,
21 | force_quick_gelu=False,
22 | force_custom_text=False,
23 | force_patch_dropout=None,
24 | force_image_size=None,
25 | pretrained_image=False,
26 | image_mean=None,
27 | image_std=None,
28 | aug_cfg={},
29 | output_dict=True,
30 | cache_dir=pretrained,
31 | det_image_size=224,
32 | dataset_type="grid_distill",
33 | )
34 |
35 |
36 | tokenizer = get_tokenizer(model_name)
37 | model = model.to(device)
38 | text = torch.tensor(np.load("/home/xiaolong_he/works/paper/data/voc/text_tokens.npy")).to(device)
39 |
40 | text_features = model.encode_text(text)
41 | np.save("/home/xiaolong_he/works/paper/data/voc/text_embedding.npy", text_features.cpu().numpy())
42 | print(text_features.shape)
43 |
44 |
45 |
--------------------------------------------------------------------------------
/open_clip/hf_configs.py:
--------------------------------------------------------------------------------
1 | # HF architecture dict:
2 | arch_dict = {
3 | # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
4 | "roberta": {
5 | "config_names": {
6 | "context_length": "max_position_embeddings",
7 | "vocab_size": "vocab_size",
8 | "width": "hidden_size",
9 | "heads": "num_attention_heads",
10 | "layers": "num_hidden_layers",
11 | "layer_attr": "layer",
12 | "token_embeddings_attr": "embeddings"
13 | },
14 | "pooler": "mean_pooler",
15 | },
16 | # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
17 | "xlm-roberta": {
18 | "config_names": {
19 | "context_length": "max_position_embeddings",
20 | "vocab_size": "vocab_size",
21 | "width": "hidden_size",
22 | "heads": "num_attention_heads",
23 | "layers": "num_hidden_layers",
24 | "layer_attr": "layer",
25 | "token_embeddings_attr": "embeddings"
26 | },
27 | "pooler": "mean_pooler",
28 | },
29 | # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
30 | "mt5": {
31 | "config_names": {
32 | # unlimited seqlen
33 | # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
34 | # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
35 | "context_length": "",
36 | "vocab_size": "vocab_size",
37 | "width": "d_model",
38 | "heads": "num_heads",
39 | "layers": "num_layers",
40 | "layer_attr": "block",
41 | "token_embeddings_attr": "embed_tokens"
42 | },
43 | "pooler": "mean_pooler",
44 | },
45 | }
46 |
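arch_dict maps each supported HF architecture to the attribute names of its config object, so callers can read width/heads/layers generically with getattr. A minimal sketch, assuming transformers is installed (it is optional for this repo) and using "roberta-base" purely as an example checkpoint:

    from transformers import AutoConfig
    from open_clip.hf_configs import arch_dict

    config = AutoConfig.from_pretrained("roberta-base")    # example checkpoint
    names = arch_dict[config.model_type]["config_names"]   # config.model_type == 'roberta'
    width = getattr(config, names["width"])                # hidden_size -> 768
    heads = getattr(config, names["heads"])                # num_attention_heads -> 12
    layers = getattr(config, names["layers"])              # num_hidden_layers -> 12
    print(width, heads, layers)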
--------------------------------------------------------------------------------
/open_clip/hf_model.py:
--------------------------------------------------------------------------------
1 | """ huggingface model adapter
2 |
3 | Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
4 | """
5 |
6 | import re
7 |
8 | import torch
9 | import torch.nn as nn
10 | from torch import TensorType
11 |
12 | try:
13 | import transformers
14 | from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
15 | from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
16 | BaseModelOutputWithPoolingAndCrossAttentions
17 | except ImportError as e:
18 | transformers = None
19 |
20 |
21 | class BaseModelOutput:
22 | pass
23 |
24 |
25 | class PretrainedConfig:
26 | pass
27 |
28 | from .hf_configs import arch_dict
29 |
30 |
31 | # utils
32 | def _camel2snake(s):
33 | return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
--------------------------------------------------------------------------------
/open_clip/openai.py:
--------------------------------------------------------------------------------
18 | def list_openai_models() -> List[str]:
19 | """Returns the names of available CLIP models"""
20 | return list_pretrained_models_by_tag('openai')
21 |
22 |
23 | def load_openai_model(
24 | name: str,
25 | precision: Optional[str] = None,
26 | device: Optional[Union[str, torch.device]] = None,
27 | jit: bool = True,
28 | cache_dir: Optional[str] = None,
29 | ):
30 | """Load a CLIP model
31 |
32 | Parameters
33 | ----------
34 | name : str
35 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
36 | precision: str
37 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
38 | device : Union[str, torch.device]
39 | The device to put the loaded model
40 | jit : bool
41 | Whether to load the optimized JIT model (default) or more hackable non-JIT model.
42 | cache_dir : Optional[str]
43 | The directory to cache the downloaded model weights
44 |
45 | Returns
46 | -------
47 | model : torch.nn.Module
48 | The CLIP model
49 | preprocess : Callable[[PIL.Image], torch.Tensor]
50 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
51 | """
52 | if device is None:
53 | device = "cuda" if torch.cuda.is_available() else "cpu"
54 | if precision is None:
55 | precision = 'fp32' if device == 'cpu' else 'fp16'
56 | print(get_pretrained_url(name, 'openai'))
57 | if get_pretrained_url(name, 'openai'):
58 | model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
59 | elif os.path.isfile(name):
60 | model_path = name
61 | else:
62 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
63 |
64 | try:
65 | # loading JIT archive
66 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
67 | state_dict = None
68 | except RuntimeError:
69 | # loading saved state dict
70 | if jit:
71 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
72 | jit = False
73 | state_dict = torch.load(model_path, map_location="cpu")
74 |
75 | if not jit:
76 | # Build a non-jit model from the OpenAI jitted model state dict
77 | cast_dtype = get_cast_dtype(precision)
78 | try:
79 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
80 | except KeyError:
81 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
82 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
83 |
84 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
85 | model = model.to(device)
86 | if precision.startswith('amp') or precision == 'fp32':
87 | model.float()
88 | elif precision == 'bf16':
89 | convert_weights_to_lp(model, dtype=torch.bfloat16)
90 |
91 | return model
92 |
93 | # patch the device names
94 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
95 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
96 |
97 | def patch_device(module):
98 | try:
99 | graphs = [module.graph] if hasattr(module, "graph") else []
100 | except RuntimeError:
101 | graphs = []
102 |
103 | if hasattr(module, "forward1"):
104 | graphs.append(module.forward1.graph)
105 |
106 | for graph in graphs:
107 | for node in graph.findAllNodes("prim::Constant"):
108 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
109 | node.copyAttributes(device_node)
110 |
111 | model.apply(patch_device)
112 | patch_device(model.encode_image)
113 | patch_device(model.encode_text)
114 |
115 | # patch dtype to float32 (typically for CPU)
116 | if precision == 'fp32':
117 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
118 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
119 | float_node = float_input.node()
120 |
121 | def patch_float(module):
122 | try:
123 | graphs = [module.graph] if hasattr(module, "graph") else []
124 | except RuntimeError:
125 | graphs = []
126 |
127 | if hasattr(module, "forward1"):
128 | graphs.append(module.forward1.graph)
129 |
130 | for graph in graphs:
131 | for node in graph.findAllNodes("aten::to"):
132 | inputs = list(node.inputs())
133 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
134 | if inputs[i].node()["value"] == 5:
135 | inputs[i].node().copyAttributes(float_node)
136 |
137 | model.apply(patch_float)
138 | patch_float(model.encode_image)
139 | patch_float(model.encode_text)
140 | model.float()
141 |
142 | # ensure image_size attr available at consistent location for both jit and non-jit
143 | model.visual.image_size = model.input_resolution.item()
144 | return model
145 |
--------------------------------------------------------------------------------
/open_clip/transform.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from dataclasses import dataclass, asdict
3 | from typing import Any, Dict, Optional, Sequence, Tuple, Union
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torchvision.transforms.functional as F
8 |
9 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
10 | CenterCrop
11 |
12 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
13 |
14 |
15 | @dataclass
16 | class AugmentationCfg:
17 | scale: Tuple[float, float] = (0.9, 1.0)
18 | ratio: Optional[Tuple[float, float]] = None
19 | color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None
20 | interpolation: Optional[str] = None
21 | re_prob: Optional[float] = None
22 | re_count: Optional[int] = None
23 | use_timm: bool = False
24 |
25 |
26 | class ResizeMaxSize(nn.Module):
27 |
28 | def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
29 | super().__init__()
30 | if not isinstance(max_size, int):
31 | raise TypeError(f"Size should be int. Got {type(max_size)}")
32 | self.max_size = max_size
33 | self.interpolation = interpolation
34 | self.fn = min if fn == 'min' else min
35 | self.fill = fill
36 |
37 | def forward(self, img):
38 | if isinstance(img, torch.Tensor):
39 | height, width = img.shape[:2]
40 | else:
41 | width, height = img.size
42 | scale = self.max_size / float(max(height, width))
43 | new_size = tuple(round(dim * scale) for dim in (height, width))
44 | img = F.resize(img, new_size, self.interpolation)
45 | pad_h = self.max_size - new_size[0]
46 | pad_w = self.max_size - new_size[1]
47 | img = F.pad(img, padding=[pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2], fill=self.fill)
48 |
49 | return img
50 |
51 |
52 | def _convert_to_rgb(image):
53 | return image.convert('RGB')
54 |
55 |
56 | def image_transform(
57 | image_size: int,
58 | is_train: bool,
59 | mean: Optional[Tuple[float, ...]] = None,
60 | std: Optional[Tuple[float, ...]] = None,
61 | resize_longest_max: bool = False,
62 | fill_color: int = 0,
63 | aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
64 | ):
65 | mean = mean or OPENAI_DATASET_MEAN
66 | if not isinstance(mean, (list, tuple)):
67 | mean = (mean,) * 3
68 |
69 | std = std or OPENAI_DATASET_STD
70 | if not isinstance(std, (list, tuple)):
71 | std = (std,) * 3
72 |
73 | if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
74 | # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
75 | image_size = image_size[0]
76 |
77 | if isinstance(aug_cfg, dict):
78 | aug_cfg = AugmentationCfg(**aug_cfg)
79 | else:
80 | aug_cfg = aug_cfg or AugmentationCfg()
81 | normalize = Normalize(mean=mean, std=std)
82 | if is_train:
83 | aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
84 | use_timm = aug_cfg_dict.pop('use_timm', False)
85 | if use_timm:
86 | from timm.data import create_transform # timm can still be optional
87 | if isinstance(image_size, (tuple, list)):
88 | assert len(image_size) >= 2
89 | input_size = (3,) + image_size[-2:]
90 | else:
91 | input_size = (3, image_size, image_size)
92 | # by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time
93 | aug_cfg_dict.setdefault('interpolation', 'random')
94 | aug_cfg_dict.setdefault('color_jitter', None) # disable by default
95 | train_transform = create_transform(
96 | input_size=input_size,
97 | is_training=True,
98 | hflip=0.,
99 | mean=mean,
100 | std=std,
101 | re_mode='pixel',
102 | **aug_cfg_dict,
103 | )
104 | else:
105 | train_transform = Compose([
106 | RandomResizedCrop(
107 | image_size,
108 | scale=aug_cfg_dict.pop('scale'),
109 | interpolation=InterpolationMode.BICUBIC,
110 | ),
111 | _convert_to_rgb,
112 | ToTensor(),
113 | normalize,
114 | ])
115 | if aug_cfg_dict:
116 | warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).')
117 | return train_transform
118 | else:
119 | if resize_longest_max:
120 | transforms = [
121 | ResizeMaxSize(image_size, fill=fill_color)
122 | ]
123 | else:
124 | transforms = [
125 | Resize(image_size, interpolation=InterpolationMode.BICUBIC),
126 | CenterCrop(image_size),
127 | ]
128 | transforms.extend([
129 | _convert_to_rgb,
130 | ToTensor(),
131 | normalize,
132 | ])
133 | return Compose(transforms)
134 |
135 |
136 | def det_image_transform(
137 | image_size: int,
138 | is_train: bool,
139 | mean: Optional[Tuple[float, ...]] = None,
140 | std: Optional[Tuple[float, ...]] = None,
141 | fill_color: int = 0,
142 | aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
143 | ):
144 | mean = mean or OPENAI_DATASET_MEAN
145 | if not isinstance(mean, (list, tuple)):
146 | mean = (mean,) * 3
147 |
148 | std = std or OPENAI_DATASET_STD
149 | if not isinstance(std, (list, tuple)):
150 | std = (std,) * 3
151 |
152 | if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
153 | # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
154 | image_size = image_size[0]
155 |
156 | normalize = Normalize(mean=mean, std=std)
157 | if is_train:
158 | raise NotImplementedError
159 | else:
160 | transforms = [
161 | Resize((image_size, image_size)),
162 | _convert_to_rgb,
163 | ToTensor(),
164 | normalize,
165 | ]
166 | return Compose(transforms)
167 |
168 |
169 | class ResizeLongest(nn.Module):
170 | def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fill=0):
171 | super().__init__()
172 | if not isinstance(max_size, int):
173 | raise TypeError(f"Size should be int. Got {type(max_size)}")
174 | self.max_size = max_size
175 | self.interpolation = interpolation
176 | self.fill = fill
177 |
178 | def forward(self, img):
179 | if isinstance(img, torch.Tensor):
180 | height, width = img.shape[1:]
181 | else:
182 | width, height = img.size
183 | scale = self.max_size / float(max(height, width))
184 | new_height, new_width = round(height * scale), round(width * scale)
185 |
186 | img = F.resize(img, [new_height, new_width], self.interpolation)
187 | pad_h = self.max_size - new_height
188 | pad_w = self.max_size - new_width
189 | img = F.pad(img, padding=[0, 0, pad_w, pad_h], fill=self.fill)
190 |
191 | return img
192 |
193 |
194 | def get_scale(img, new_image):
195 | if isinstance(img, torch.Tensor):
196 | height, width = img.shape[-2:]
197 | else:
198 | width, height = img.size
199 |
200 | if isinstance(new_image, torch.Tensor):
201 | new_height, new_width = new_image.shape[-2:]
202 | else:
203 | new_width, new_height = new_image.size
204 |
205 | scale = (new_height/height, new_width/width)
206 |
207 | return scale
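A minimal sketch of the detection-style helpers above: ResizeLongest keeps the aspect ratio and zero-pads to a square, det_image_transform resizes straight to a square with no crop, and get_scale returns the per-axis resize factors between an image and its resized copy. torchvision's Resize is used only to produce a resized copy for the get_scale call.

    from PIL import Image
    from torchvision.transforms import Resize
    from open_clip.transform import ResizeLongest, det_image_transform, get_scale

    img = Image.new("RGB", (640, 480))           # PIL size is (width, height)

    padded = ResizeLongest(max_size=224)(img)
    print(padded.size)                           # (224, 224): 640x480 -> 224x168, then padded

    resized = Resize((336, 448))(img)            # plain resize, for illustration only
    print(get_scale(img, resized))               # (336/480, 448/640) = (0.7, 0.7)

    preprocess = det_image_transform(224, is_train=False)
    print(preprocess(img).shape)                 # torch.Size([3, 224, 224])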
--------------------------------------------------------------------------------
/open_clip/utils.py:
--------------------------------------------------------------------------------
1 | from itertools import repeat
2 | import collections.abc
3 |
4 | from torch import nn as nn
5 | from torchvision.ops.misc import FrozenBatchNorm2d
6 |
7 |
8 | def freeze_batch_norm_2d(module, module_match={}, name=''):
9 | """
10 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
11 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
12 | returned. Otherwise, the module is walked recursively and submodules are converted in place.
13 |
14 | Args:
15 | module (torch.nn.Module): Any PyTorch module.
16 | module_match (dict): Dictionary of full module names to freeze (all if empty)
17 | name (str): Full module name (prefix)
18 |
19 | Returns:
20 | torch.nn.Module: Resulting module
21 |
22 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
23 | """
24 | res = module
25 | is_match = True
26 | if module_match:
27 | is_match = name in module_match
28 | if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
29 | res = FrozenBatchNorm2d(module.num_features)
30 | res.num_features = module.num_features
31 | res.affine = module.affine
32 | if module.affine:
33 | res.weight.data = module.weight.data.clone().detach()
34 | res.bias.data = module.bias.data.clone().detach()
35 | res.running_mean.data = module.running_mean.data
36 | res.running_var.data = module.running_var.data
37 | res.eps = module.eps
38 | else:
39 | for child_name, child in module.named_children():
40 | full_child_name = '.'.join([name, child_name]) if name else child_name
41 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
42 | if new_child is not child:
43 | res.add_module(child_name, new_child)
44 | return res
45 |
46 |
47 | # From PyTorch internals
48 | def _ntuple(n):
49 | def parse(x):
50 | if isinstance(x, collections.abc.Iterable):
51 | return x
52 | return tuple(repeat(x, n))
53 | return parse
54 |
55 |
56 | to_1tuple = _ntuple(1)
57 | to_2tuple = _ntuple(2)
58 | to_3tuple = _ntuple(3)
59 | to_4tuple = _ntuple(4)
60 | to_ntuple = lambda n, x: _ntuple(n)(x)
61 |
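Two quick illustrations of the helpers above: the _ntuple helpers broadcast scalars to tuples, and freeze_batch_norm_2d swaps every BatchNorm2d in a module tree for a torchvision FrozenBatchNorm2d with its statistics copied over.

    import torch.nn as nn
    from open_clip.utils import freeze_batch_norm_2d, to_2tuple

    print(to_2tuple(224))          # (224, 224)
    print(to_2tuple((224, 336)))   # returned unchanged, already iterable

    net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    net = freeze_batch_norm_2d(net)
    print(type(net[1]).__name__)   # FrozenBatchNorm2d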
--------------------------------------------------------------------------------
/open_clip/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '2.16.0'
2 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.1.2
2 | torchvision==0.16.2
3 | regex
4 | ftfy
5 | tqdm
6 | timm
7 | einops
8 | pycocotools
9 | xformers==0.0.23.post1
10 | panopticapi@git+https://githubfast.com/cocodataset/panopticapi.git
--------------------------------------------------------------------------------
/scripts/test_vitb16_box.sh:
--------------------------------------------------------------------------------
1 | torchrun --nproc_per_node 8 -m training.main --batch-size=32 \
2 | --model EVA02-CLIP-B-16 --pretrained eva --test-type coco_panoptic --train-data="" \
3 | --val-data data/coco/annotations/panoptic_val2017.json \
4 | --embed-path metadata/coco_panoptic_clip_hand_craft_EVACLIP_ViTB16.npy \
5 | --val-image-root data/coco/val2017 --cache-dir checkpoints/coco_vitb16.pt --extract-type="v2" \
6 | --name test_vitb16 --downsample-factor 16 --det-image-size 224
--------------------------------------------------------------------------------
/scripts/test_vitb16_flickr.sh:
--------------------------------------------------------------------------------
1 | python -m training.test_flickr \
2 | --model-name "EVA02-CLIP-B-16" \
3 | --pretrained "./checkpoints/coco_vitb16.pt" \
4 | --data "./data/flickr30k/flickr30k_test.json" \
5 | --image-path "./data/flickr30k/flickr30k_images" \
6 | --image-size 224 \
7 | --device "cuda:0"
--------------------------------------------------------------------------------
/scripts/test_vitb16_mscoco.sh:
--------------------------------------------------------------------------------
1 | python -m training.test_mscoco \
2 | --model-name "EVA02-CLIP-B-16" \
3 | --pretrained "./checkpoints/coco_vitb16.pt" \
4 | --data "./data/coco/coco_test.json" \
5 | --image-path "./data/coco/val2017" \
6 | --image-size 224 \
7 | --device "cuda:0"
--------------------------------------------------------------------------------
/scripts/test_vitl14_box.sh:
--------------------------------------------------------------------------------
1 | torchrun --nproc_per_node 8 -m training.main --batch-size=32 \
2 | --model EVA02-CLIP-L-14-336 --pretrained eva --test-type coco_panoptic --train-data="" \
3 | --val-data data/coco/annotations/panoptic_val2017.json \
4 | --embed-path metadata/coco_panoptic_clip_hand_craft_EVACLIP_ViTL14x336.npy \
5 | --val-image-root data/coco/val2017 --cache-dir checkpoints/coco_vitl14.pt --extract-type="v2" \
6 | --name test_vitl14 --downsample-factor 14 --det-image-size 336
--------------------------------------------------------------------------------
/scripts/test_vitl14_flickr.sh:
--------------------------------------------------------------------------------
1 | python -m training.test_flickr \
2 | --model-name "EVA02-CLIP-L-14-336" \
3 | --pretrained "./checkpoints/coco_vitl14.pt" \
4 | --data "./data/flickr30k/flickr30k_test.json" \
5 | --image-path "./data/flickr30k/flickr30k_images" \
6 | --image-size 336 \
7 | --device "cuda:0"
--------------------------------------------------------------------------------
/scripts/test_vitl14_mscoco.sh:
--------------------------------------------------------------------------------
1 | python -m training.test_mscoco \
2 | --model-name "EVA02-CLIP-L-14-336" \
3 | --pretrained "./checkpoints/coco_vitl14.pt" \
4 | --data "./data/coco/coco_test.json" \
5 | --image-path "./data/coco/val2017" \
6 | --image-size 336 \
7 | --device "cuda:0"
--------------------------------------------------------------------------------
/scripts/train_vitb16.sh:
--------------------------------------------------------------------------------
1 | torchrun --nproc_per_node 8 -m training.main --batch-size=32 --lr=1e-5 --wd=0.1 --epochs=10 --workers=4 \
2 | --model EVA02-CLIP-B-16 --pretrained eva --warmup 1000 --zeroshot-frequency 1 --dataset-type proposals_distill \
3 | --test-type coco_panoptic --train-data data/coco/coco_proposals.json --max-boxes 20 \
4 | --val-data data/coco/annotations/panoptic_val2017.json --image-caption-path data/coco/annotations/captions_train2017.json \
5 | --embed-path metadata/coco_panoptic_clip_hand_craft_EVACLIP_ViTB16.npy --train-image-root data/coco/train2017 \
6 | --val-image-root data/coco/val2017 --cache-dir checkpoints/EVA02_CLIP_B_psz16_s8B.pt --log-every-n-steps 50 \
7 | --save-frequency 1 --extract-type="v2" --image-region-caption-path "data/coco/coco_captions.json" \
8 | --name train_vitb16 --downsample-factor 16 --det-image-size 224 --alpha 1
9 |
--------------------------------------------------------------------------------
/scripts/train_vitl14.sh:
--------------------------------------------------------------------------------
1 | torchrun --nproc_per_node 8 -m training.main --batch-size=10 --lr=1e-5 --wd=0.1 --epochs=10 --workers=4 \
2 | --model EVA02-CLIP-L-14-336 --pretrained eva --warmup 1000 --zeroshot-frequency 1 --dataset-type proposals_distill \
3 | --test-type coco_panoptic --train-data data/coco/coco_proposals.json --max-boxes 5 \
4 | --val-data data/coco/annotations/panoptic_val2017.json --image-caption-path data/coco/annotations/captions_train2017.json \
5 | --embed-path metadata/coco_panoptic_clip_hand_craft_EVACLIP_ViTL14x336.npy --train-image-root data/coco/train2017 \
6 | --val-image-root data/coco/val2017 --cache-dir checkpoints/EVA02_CLIP_L_336_psz14_s6B.pt --log-every-n-steps 50 \
7 | --save-frequency 1 --extract-type="v2" \
8 | --name train_vitl14 --downsample-factor 14 --det-image-size 336 --alpha 1
9 |
--------------------------------------------------------------------------------
/tools/generate_text_embeddings.py:
--------------------------------------------------------------------------------
1 | # Modified from [ViLD](https://github.com/tensorflow/tpu/tree/master/models/official/detection/projects/vild)
2 | import numpy as np
3 | import torch
4 | import torch.nn.functional as F
5 | from tqdm import tqdm
6 | import open_clip
7 |
8 |
9 | def article(name):
10 | return 'an' if name[0] in 'aeiou' else 'a'
11 |
12 | def processed_name(name, rm_dot=False):
13 | # _ for lvis
14 | # / for obj365
15 | res = name.replace('_', ' ').replace('/', ' or ').lower()
16 | if rm_dot:
17 | res = res.rstrip('.')
18 | return res
19 |
20 |
21 | single_template = [
22 | 'a photo of {article} {}.'
23 | ]
24 |
25 | multiple_templates = [
26 | 'There is {article} {} in the scene.',
27 | 'There is the {} in the scene.',
28 | 'a photo of {article} {} in the scene.',
29 | 'a photo of the {} in the scene.',
30 | 'a photo of one {} in the scene.',
31 |
32 |
33 | 'itap of {article} {}.',
34 | 'itap of my {}.', # itap: I took a picture of
35 | 'itap of the {}.',
36 | 'a photo of {article} {}.',
37 | 'a photo of my {}.',
38 | 'a photo of the {}.',
39 | 'a photo of one {}.',
40 | 'a photo of many {}.',
41 |
42 | 'a good photo of {article} {}.',
43 | 'a good photo of the {}.',
44 | 'a bad photo of {article} {}.',
45 | 'a bad photo of the {}.',
46 | 'a photo of a nice {}.',
47 | 'a photo of the nice {}.',
48 | 'a photo of a cool {}.',
49 | 'a photo of the cool {}.',
50 | 'a photo of a weird {}.',
51 | 'a photo of the weird {}.',
52 |
53 | 'a photo of a small {}.',
54 | 'a photo of the small {}.',
55 | 'a photo of a large {}.',
56 | 'a photo of the large {}.',
57 |
58 | 'a photo of a clean {}.',
59 | 'a photo of the clean {}.',
60 | 'a photo of a dirty {}.',
61 | 'a photo of the dirty {}.',
62 |
63 | 'a bright photo of {article} {}.',
64 | 'a bright photo of the {}.',
65 | 'a dark photo of {article} {}.',
66 | 'a dark photo of the {}.',
67 |
68 | 'a photo of a hard to see {}.',
69 | 'a photo of the hard to see {}.',
70 | 'a low resolution photo of {article} {}.',
71 | 'a low resolution photo of the {}.',
72 | 'a cropped photo of {article} {}.',
73 | 'a cropped photo of the {}.',
74 | 'a close-up photo of {article} {}.',
75 | 'a close-up photo of the {}.',
76 | 'a jpeg corrupted photo of {article} {}.',
77 | 'a jpeg corrupted photo of the {}.',
78 | 'a blurry photo of {article} {}.',
79 | 'a blurry photo of the {}.',
80 | 'a pixelated photo of {article} {}.',
81 | 'a pixelated photo of the {}.',
82 |
83 | 'a black and white photo of the {}.',
84 | 'a black and white photo of {article} {}.',
85 |
86 | 'a plastic {}.',
87 | 'the plastic {}.',
88 |
89 | 'a toy {}.',
90 | 'the toy {}.',
91 | 'a plushie {}.',
92 | 'the plushie {}.',
93 | 'a cartoon {}.',
94 | 'the cartoon {}.',
95 |
96 | 'an embroidered {}.',
97 | 'the embroidered {}.',
98 |
99 | 'a painting of the {}.',
100 | 'a painting of a {}.',
101 | ]
102 |
103 |
104 | def build_text_embedding_coco(categories, model):
105 | templates = multiple_templates
106 | with torch.no_grad():
107 | zeroshot_weights = []
108 | attn12_weights = []
109 | for category in categories:
110 | texts = [
111 | template.format(processed_name(category, rm_dot=True), article=article(category))
112 | for template in templates
113 | ]
114 | texts = [
115 | "This is " + text if text.startswith("a") or text.startswith("the") else text
116 | for text in texts
117 | ]
118 | texts = open_clip.tokenize(texts).cuda() # tokenize
119 | text_embeddings = model.encode_text(texts)
120 | text_attnfeatures, _, _ = model.encode_text_endk(texts, stepk=12, normalize=True)
121 |
122 | text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
123 | text_embedding = text_embeddings.mean(dim=0)
124 | text_embedding /= text_embedding.norm()
125 |
126 | text_attnfeatures = text_attnfeatures.mean(0)
127 | text_attnfeatures = F.normalize(text_attnfeatures, dim=0)
128 | attn12_weights.append(text_attnfeatures)
129 | zeroshot_weights.append(text_embedding)
130 | zeroshot_weights = torch.stack(zeroshot_weights, dim=0)
131 | attn12_weights = torch.stack(attn12_weights, dim=0)
132 |
133 | return zeroshot_weights, attn12_weights
134 |
135 |
136 | def build_text_embedding_lvis(categories, model, tokenizer):
137 | templates = multiple_templates
138 |
139 | with torch.no_grad():
140 | all_text_embeddings = []
141 | for category in tqdm(categories):
142 | texts = [
143 | template.format(
144 | processed_name(category, rm_dot=True), article=article(category)
145 | )
146 | for template in templates
147 | ]
148 | texts = [
149 | "This is " + text if text.startswith("a") or text.startswith("the") else text
150 | for text in texts
151 | ]
152 | texts = tokenizer(texts).cuda() # tokenize
153 |
154 | text_embeddings = model.encode_text(texts)
155 | text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
156 | text_embedding = text_embeddings.mean(dim=0)
157 | text_embedding /= text_embedding.norm()
158 |
159 | all_text_embeddings.append(text_embedding)
160 | all_text_embeddings = torch.stack(all_text_embeddings, dim=0)
161 |
162 | return all_text_embeddings
163 |
164 |
165 | # voc_cats = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
166 | # 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
167 | # 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
168 | # 'tvmonitor')
169 | # text_embeddings, _ = build_text_embedding_coco(voc_cats)
170 | # np.save('datasets/metadata/voc_clip_hand_craft.npy', text_embeddings.cpu().numpy())
171 |
172 | import argparse
173 | import json
174 |
175 | if __name__ == '__main__':
176 | parser = argparse.ArgumentParser()
177 | parser.add_argument('--model_version', default='ViT-L-14-336')
178 | parser.add_argument('--ann', default='data/coco/annotations/panoptic_val2017.json')
179 | parser.add_argument('--out_path', default='metadata/coco_panoptic_clip_hand_craft_ViTL14x336.npy')
180 | parser.add_argument('--pretrained', default='openai')
181 | parser.add_argument('--cache_dir', default='checkpoints')
182 |
183 | args = parser.parse_args()
184 |
185 | model = open_clip.create_model(
186 | args.model_version, pretrained=args.pretrained, cache_dir=args.cache_dir
187 | )
188 | tokenizer = open_clip.get_tokenizer(args.model_version)
189 | model.cuda()
190 |
191 | print('Loading', args.ann)
192 | data = json.load(open(args.ann, 'r'))
193 | cat_names = [x['name'] for x in \
194 | sorted(data['categories'], key=lambda x: x['id'])]
195 | out_path = args.out_path
196 | text_embeddings = build_text_embedding_lvis(cat_names, model, tokenizer)
197 | np.save(out_path, text_embeddings.cpu().numpy())
198 |
--------------------------------------------------------------------------------
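Aside: the script above is normally driven from the command line via its argparse options, but the helper defined above (build_text_embedding_lvis) can also be called directly. A minimal sketch, assuming a CUDA device, that open_clip weights are downloadable, and that the functions above are in scope; the category list and output path are illustrative only:

    import numpy as np
    import open_clip

    model = open_clip.create_model('ViT-B-16', pretrained='openai')
    tokenizer = open_clip.get_tokenizer('ViT-B-16')
    model.cuda()

    cats = ['person', 'dog', 'traffic light']              # hypothetical category names
    emb = build_text_embedding_lvis(cats, model, tokenizer)
    print(emb.shape)                                        # (3, embed_dim); rows are L2-normalized
    np.save('metadata/example_hand_craft.npy', emb.cpu().numpy())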
/training/.gitignore:
--------------------------------------------------------------------------------
1 | logs/
2 |
--------------------------------------------------------------------------------
/training/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Timsty1/FineCLIP/d44078b66ff16be450735d70c832a125a3bc564b/training/__init__.py
--------------------------------------------------------------------------------
/training/clipself.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | import torch.nn.functional as F
4 |
5 |
6 | class CLIPSelf:
7 | def __call__(self, batch, model, dist_model, itc_loss, device, cast_dtype, distributed, args):
8 | if distributed:
9 | model = model.module
10 | dist_model = dist_model.module
11 | images, normed_boxes, image_crops, texts, region_texts = batch
12 |
13 | images = images.to(device=device, dtype=cast_dtype, non_blocking=True)
14 | normed_boxes = normed_boxes.to(device=device, dtype=cast_dtype, non_blocking=True)
15 | image_crops = image_crops.to(device=device, dtype=cast_dtype, non_blocking=True)
16 | texts = texts.to(device=device, dtype=cast_dtype, non_blocking=True)
17 |         texts = texts.squeeze(1)  # drop the singleton dim so shapes align with the text encoder input
18 | region_texts = region_texts.to(device=device, dtype=cast_dtype, non_blocking=True)
19 |
20 | if args.multiscale:
21 | cur_h, cur_w = images.shape[2:]
22 | assert cur_h == cur_w
23 | if cur_h == 1024:
24 | tar_sizes = [320, 640, 896, 1024]
25 | elif cur_h == 896:
26 | tar_sizes = [336, 448, 672, 896]
27 | else:
28 | raise NotImplementedError
29 | tar_size = random.choice(tar_sizes)
30 | images = F.interpolate(images, size=(tar_size, tar_size), mode='bilinear')
31 |
32 | rois_list = []
33 | crops_list = []
34 | region_text_list = []
35 | for bboxes_per_image, crops_per_image, region_per_text in zip(normed_boxes, image_crops, region_texts):
36 | valid = bboxes_per_image[:, -1] > 0.5
37 | rois_list.append(bboxes_per_image[valid, :4])
38 | crops_list.append(crops_per_image[valid])
39 | region_text_list.append(region_per_text[valid])
40 |
41 | image_crops = torch.cat(crops_list)
42 | region_texts = torch.cat(region_text_list)
43 | teacher_crop_features = model.encode_image(image_crops, normalize=False)
44 | student_roi_features = model.encode_pseudo_boxes(images, rois_list, normalize=False, extract_type=args.extract_type)
45 | region_text_features = model.encode_text(region_texts, normalize=False)
46 |
47 | normed_student_features = F.normalize(student_roi_features, dim=-1)
48 | normed_teacher_features = F.normalize(teacher_crop_features, dim=-1)
49 | normed_region_text_features = F.normalize(region_text_features, dim=-1)
50 | losses = {}
51 | if "distill" in args.loss_type:
52 | loss_cosine = 1.0 - (normed_student_features * normed_teacher_features).sum(-1).mean()
53 | losses["distill"] = loss_cosine
54 |
55 | if "global_itc" in args.loss_type:
56 | image_features, text_features, logit_scale = model(images, texts)
57 | global_itc_loss, _ = itc_loss(image_features, text_features, logit_scale)
58 | losses["global_itc"] = global_itc_loss
59 |
60 | if "region_itc" in args.loss_type:
61 | assert normed_student_features.shape[0] == normed_region_text_features.shape[0]
62 | all_student_features = torch.zeros((args.batch_size*args.max_boxes, normed_student_features.shape[1])).to(device=device, dtype=cast_dtype, non_blocking=True)
63 | all_student_features[0:normed_student_features.shape[0]] = normed_student_features
64 |
65 | all_region_text_features = torch.zeros((args.batch_size*args.max_boxes, normed_student_features.shape[1])).to(device=device, dtype=cast_dtype, non_blocking=True)
66 | all_region_text_features[0:normed_region_text_features.shape[0]] = normed_region_text_features
67 |
68 | region_itc_loss, _ = itc_loss(all_student_features, all_region_text_features, model.logit_scale.exp())
69 | losses["region_itc"] = region_itc_loss
70 |
71 | return losses, len(images), model.logit_scale.exp()
72 |
--------------------------------------------------------------------------------
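The "distill" term above is one minus the cosine similarity between each student RoI feature and its corresponding teacher crop feature, averaged over boxes. A minimal sketch of that computation in isolation (random tensors stand in for the real features):

    import torch
    import torch.nn.functional as F

    student = F.normalize(torch.randn(8, 512), dim=-1)   # stand-in for RoI features
    teacher = F.normalize(torch.randn(8, 512), dim=-1)   # stand-in for crop features
    loss_cosine = 1.0 - (student * teacher).sum(-1).mean()
    print(loss_cosine.item())                            # scalar in [0, 2]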
/training/coco_api.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | # This file adds snake_case aliases for the COCO API
3 |
4 | import warnings
5 | from collections import defaultdict
6 | from typing import List, Optional, Union
7 |
8 | import pycocotools
9 | from pycocotools.coco import COCO as _COCO
10 | from pycocotools.cocoeval import COCOeval as _COCOeval
11 |
12 |
13 | class COCO(_COCO):
14 |     """This class is almost the same as the official pycocotools package.
15 | 
16 |     It adds snake_case method aliases so that the COCO class has the same
17 |     interface as the LVIS class.
18 | """
19 |
20 | def __init__(self, annotation_file=None):
21 | if getattr(pycocotools, '__version__', '0') >= '12.0.2':
22 | warnings.warn(
23 | 'mmpycocotools is deprecated. Please install official pycocotools by "pip install pycocotools"', # noqa: E501
24 | UserWarning)
25 | super().__init__(annotation_file=annotation_file)
26 | self.img_ann_map = self.imgToAnns
27 | self.cat_img_map = self.catToImgs
28 |
29 | def get_ann_ids(self, img_ids=[], cat_ids=[], area_rng=[], iscrowd=None):
30 | return self.getAnnIds(img_ids, cat_ids, area_rng, iscrowd)
31 |
32 | def get_cat_ids(self, cat_names=[], sup_names=[], cat_ids=[]):
33 | return self.getCatIds(cat_names, sup_names, cat_ids)
34 |
35 | def get_img_ids(self, img_ids=[], cat_ids=[]):
36 | return self.getImgIds(img_ids, cat_ids)
37 |
38 | def load_anns(self, ids):
39 | return self.loadAnns(ids)
40 |
41 | def load_cats(self, ids):
42 | return self.loadCats(ids)
43 |
44 | def load_imgs(self, ids):
45 | return self.loadImgs(ids)
46 |
47 |
48 | # just for the ease of import
49 | COCOeval = _COCOeval
50 |
51 |
52 | class COCOPanoptic(COCO):
53 | """This wrapper is for loading the panoptic style annotation file.
54 |
55 | The format is shown in the CocoPanopticDataset class.
56 |
57 | Args:
58 | annotation_file (str, optional): Path of annotation file.
59 | Defaults to None.
60 | """
61 |
62 | def __init__(self, annotation_file: Optional[str] = None) -> None:
63 | super(COCOPanoptic, self).__init__(annotation_file)
64 |
65 | def createIndex(self) -> None:
66 | """Create index."""
67 | # create index
68 | print('creating index...')
69 | # anns stores 'segment_id -> annotation'
70 | anns, cats, imgs = {}, {}, {}
71 | img_to_anns, cat_to_imgs = defaultdict(list), defaultdict(list)
72 | if 'annotations' in self.dataset:
73 | for ann in self.dataset['annotations']:
74 | for seg_ann in ann['segments_info']:
75 | # to match with instance.json
76 | seg_ann['image_id'] = ann['image_id']
77 | img_to_anns[ann['image_id']].append(seg_ann)
78 |                     # segment_id is not unique in the COCO dataset:
79 |                     # annotations from different images
80 |                     # may have the same segment_id
81 | if seg_ann['id'] in anns.keys():
82 | anns[seg_ann['id']].append(seg_ann)
83 | else:
84 | anns[seg_ann['id']] = [seg_ann]
85 |
86 | # filter out annotations from other images
87 | img_to_anns_ = defaultdict(list)
88 | for k, v in img_to_anns.items():
89 | img_to_anns_[k] = [x for x in v if x['image_id'] == k]
90 | img_to_anns = img_to_anns_
91 |
92 | if 'images' in self.dataset:
93 | for img_info in self.dataset['images']:
94 | img_info['segm_file'] = img_info['file_name'].replace(
95 | 'jpg', 'png')
96 | imgs[img_info['id']] = img_info
97 |
98 | if 'categories' in self.dataset:
99 | for cat in self.dataset['categories']:
100 | cats[cat['id']] = cat
101 |
102 | if 'annotations' in self.dataset and 'categories' in self.dataset:
103 | for ann in self.dataset['annotations']:
104 | for seg_ann in ann['segments_info']:
105 | cat_to_imgs[seg_ann['category_id']].append(ann['image_id'])
106 |
107 | print('index created!')
108 |
109 | self.anns = anns
110 | self.imgToAnns = img_to_anns
111 | self.catToImgs = cat_to_imgs
112 | self.imgs = imgs
113 | self.cats = cats
114 |
115 | def load_anns(self,
116 | ids: Union[List[int], int] = []) -> Optional[List[dict]]:
117 | """Load anns with the specified ids.
118 |
119 |         ``self.anns`` maps each segment id to a list of annotations
120 |         rather than to a single annotation.
121 |
122 | Args:
123 | ids (Union[List[int], int]): Integer ids specifying anns.
124 |
125 | Returns:
126 | anns (List[dict], optional): Loaded ann objects.
127 | """
128 | anns = []
129 |
130 | if hasattr(ids, '__iter__') and hasattr(ids, '__len__'):
131 |             # self.anns maps each id to a list of annotations,
132 |             # not to a single annotation
133 | for id in ids:
134 | anns += self.anns[id]
135 | return anns
136 | elif type(ids) == int:
137 | return self.anns[ids]
138 |
--------------------------------------------------------------------------------
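The snake_case aliases let this wrapper be used interchangeably with the LVIS API. A minimal usage sketch, assuming the repo root is on PYTHONPATH; the annotation path is a placeholder:

    from training.coco_api import COCO

    coco = COCO('data/coco/annotations/instances_val2017.json')  # hypothetical path
    img_ids = coco.get_img_ids()
    ann_ids = coco.get_ann_ids(img_ids=img_ids[:1])
    anns = coco.load_anns(ann_ids)
    print(len(anns), 'annotations in the first image')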
/training/custom_transforms.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | import torch.nn as nn
4 | import torchvision.transforms.functional as F
5 | from torchvision.transforms import RandomCrop, InterpolationMode
6 |
7 |
8 | class CustomRandomResize(nn.Module):
9 |
10 | def __init__(self, scale=(0.5, 2.0), interpolation=InterpolationMode.BILINEAR):
11 | super().__init__()
12 | self.min_scale, self.max_scale = min(scale), max(scale)
13 | self.interpolation = interpolation
14 |
15 | def forward(self, img):
16 | if isinstance(img, torch.Tensor):
17 |             height, width = img.shape[-2:]  # tensor images are (..., H, W)
18 | else:
19 | width, height = img.size
20 | scale = random.uniform(self.min_scale, self.max_scale)
21 | new_size = [int(height * scale), int(width * scale)]
22 | img = F.resize(img, new_size, self.interpolation)
23 |
24 | return img
25 |
26 |
27 | class CustomRandomCrop(RandomCrop):
28 | def forward(self, img):
29 | """
30 | Args:
31 | img (PIL Image or Tensor): Image to be cropped.
32 |
33 | Returns:
34 | PIL Image or Tensor: Cropped image.
35 | """
36 |
37 | width, height = F.get_image_size(img)
38 | tar_h, tar_w = self.size
39 |
40 | tar_h = min(tar_h, height)
41 | tar_w = min(tar_w, width)
42 | i, j, h, w = self.get_params(img, (tar_h, tar_w))
43 |
44 | return F.crop(img, i, j, h, w)
45 |
--------------------------------------------------------------------------------
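Both transforms are meant to sit in front of the usual tensor conversion; the crop size is clipped to the (possibly rescaled) image, so no padding is ever needed. A minimal sketch, assuming the repo root is on PYTHONPATH; the dummy image is for illustration only:

    from PIL import Image
    from torchvision import transforms
    from training.custom_transforms import CustomRandomResize, CustomRandomCrop

    pipeline = transforms.Compose([
        CustomRandomResize(scale=(0.5, 2.0)),
        CustomRandomCrop(size=(640, 640)),
        transforms.ToTensor(),
    ])
    img = Image.new('RGB', (800, 600))   # dummy image
    out = pipeline(img)                  # tensor of shape (3, h, w) with h, w <= 640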
/training/distributed.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import torch.distributed as dist
5 |
6 | try:
7 | import horovod.torch as hvd
8 | except ImportError:
9 | hvd = None
10 |
11 |
12 | def is_global_master(args):
13 | return args.rank == 0
14 |
15 |
16 | def is_local_master(args):
17 | return args.local_rank == 0
18 |
19 |
20 | def is_master(args, local=False):
21 | return is_local_master(args) if local else is_global_master(args)
22 |
23 |
24 | def is_using_horovod():
25 | # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set
26 | # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required...
27 | ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"]
28 | pmi_vars = ["PMI_RANK", "PMI_SIZE"]
29 | if all([var in os.environ for var in ompi_vars]) or all([var in os.environ for var in pmi_vars]):
30 | return True
31 | else:
32 | return False
33 |
34 |
35 | def is_using_distributed():
36 | if 'WORLD_SIZE' in os.environ:
37 | return int(os.environ['WORLD_SIZE']) > 1
38 | if 'SLURM_NTASKS' in os.environ:
39 | return int(os.environ['SLURM_NTASKS']) > 1
40 | return False
41 |
42 |
43 | def world_info_from_env():
44 | local_rank = 0
45 | for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'):
46 | if v in os.environ:
47 | local_rank = int(os.environ[v])
48 | break
49 | global_rank = 0
50 | for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'):
51 | if v in os.environ:
52 | global_rank = int(os.environ[v])
53 | break
54 | world_size = 1
55 | for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'):
56 | if v in os.environ:
57 | world_size = int(os.environ[v])
58 | break
59 |
60 | return local_rank, global_rank, world_size
61 |
62 |
63 | def init_distributed_device(args):
64 | # Distributed training = training on more than one GPU.
65 | # Works in both single and multi-node scenarios.
66 | args.distributed = False
67 | args.world_size = 1
68 | args.rank = 0 # global rank
69 | args.local_rank = 0
70 | if args.horovod:
71 | assert hvd is not None, "Horovod is not installed"
72 | hvd.init()
73 | args.local_rank = int(hvd.local_rank())
74 | args.rank = hvd.rank()
75 | args.world_size = hvd.size()
76 | args.distributed = True
77 | os.environ['LOCAL_RANK'] = str(args.local_rank)
78 | os.environ['RANK'] = str(args.rank)
79 | os.environ['WORLD_SIZE'] = str(args.world_size)
80 | elif is_using_distributed():
81 | if 'SLURM_PROCID' in os.environ:
82 | # DDP via SLURM
83 | args.local_rank, args.rank, args.world_size = world_info_from_env()
84 | # SLURM var -> torch.distributed vars in case needed
85 | os.environ['LOCAL_RANK'] = str(args.local_rank)
86 | os.environ['RANK'] = str(args.rank)
87 | os.environ['WORLD_SIZE'] = str(args.world_size)
88 | torch.distributed.init_process_group(
89 | backend=args.dist_backend,
90 | init_method=args.dist_url,
91 | world_size=args.world_size,
92 | rank=args.rank,
93 | )
94 | else:
95 | # DDP via torchrun, torch.distributed.launch
96 | args.local_rank, _, _ = world_info_from_env()
97 | torch.distributed.init_process_group(
98 | backend=args.dist_backend,
99 | init_method=args.dist_url)
100 | args.world_size = torch.distributed.get_world_size()
101 | args.rank = torch.distributed.get_rank()
102 | args.distributed = True
103 |
104 | if torch.cuda.is_available():
105 | if args.distributed and not args.no_set_device_rank:
106 | device = 'cuda:%d' % args.local_rank
107 | else:
108 | device = 'cuda:0'
109 | torch.cuda.set_device(device)
110 | else:
111 | device = 'cpu'
112 | args.device = device
113 | device = torch.device(device)
114 | return device
115 |
116 |
117 | def broadcast_object(args, obj, src=0):
118 | # broadcast a pickle-able python object from rank-0 to all ranks
119 | if args.horovod:
120 | return hvd.broadcast_object(obj, root_rank=src)
121 | else:
122 | if args.rank == src:
123 | objects = [obj]
124 | else:
125 | objects = [None]
126 | dist.broadcast_object_list(objects, src=src)
127 | return objects[0]
128 |
129 |
130 | def all_gather_object(args, obj, dst=0):
131 | # gather a pickle-able python object across all ranks
132 | if args.horovod:
133 | return hvd.allgather_object(obj)
134 | else:
135 | objects = [None for _ in range(args.world_size)]
136 | dist.all_gather_object(objects, obj)
137 | return objects
138 |
--------------------------------------------------------------------------------
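world_info_from_env simply reads whichever launcher variables are present (torchrun/torch.distributed, MPI, SLURM, OMPI). A minimal sketch of what it returns for a hypothetical two-process torchrun worker, assuming the repo root is on PYTHONPATH:

    import os
    from training.distributed import world_info_from_env

    # torchrun exports these for every worker; the values here are illustrative
    os.environ.update({'LOCAL_RANK': '1', 'RANK': '1', 'WORLD_SIZE': '2'})
    local_rank, global_rank, world_size = world_info_from_env()
    print(local_rank, global_rank, world_size)   # -> 1 1 2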
/training/file_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import multiprocessing
4 | import subprocess
5 | import time
6 | import fsspec
7 | import torch
8 | from tqdm import tqdm
9 |
10 | def remote_sync_s3(local_dir, remote_dir):
11 | # skip epoch_latest which can change during sync.
12 | result = subprocess.run(["aws", "s3", "sync", local_dir, remote_dir, '--exclude', '*epoch_latest.pt'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
13 | if result.returncode != 0:
14 | logging.error(f"Error: Failed to sync with S3 bucket {result.stderr.decode('utf-8')}")
15 | return False
16 |
17 | logging.info(f"Successfully synced with S3 bucket")
18 | return True
19 |
20 | def remote_sync_fsspec(local_dir, remote_dir):
21 | # FIXME currently this is slow and not recommended. Look into speeding up.
22 | a = fsspec.get_mapper(local_dir)
23 | b = fsspec.get_mapper(remote_dir)
24 |
25 | for k in a:
26 | # skip epoch_latest which can change during sync.
27 | if 'epoch_latest.pt' in k:
28 | continue
29 |
30 | logging.info(f'Attempting to sync {k}')
31 | if k in b and len(a[k]) == len(b[k]):
32 | logging.debug(f'Skipping remote sync for {k}.')
33 | continue
34 |
35 |         try:
36 |             b[k] = a[k]
37 |             logging.info(f'Successful sync for {k}.')
38 | except Exception as e:
39 | logging.info(f'Error during remote sync for {k}: {e}')
40 | return False
41 |
42 | return True
43 |
44 | def remote_sync(local_dir, remote_dir, protocol):
45 | logging.info('Starting remote sync.')
46 | if protocol == 's3':
47 | return remote_sync_s3(local_dir, remote_dir)
48 | elif protocol == 'fsspec':
49 | return remote_sync_fsspec(local_dir, remote_dir)
50 | else:
51 | logging.error('Remote protocol not known')
52 | return False
53 |
54 | def keep_running_remote_sync(sync_every, local_dir, remote_dir, protocol):
55 | while True:
56 | time.sleep(sync_every)
57 | remote_sync(local_dir, remote_dir, protocol)
58 |
59 | def start_sync_process(sync_every, local_dir, remote_dir, protocol):
60 | p = multiprocessing.Process(target=keep_running_remote_sync, args=(sync_every, local_dir, remote_dir, protocol))
61 | return p
62 |
63 | # Note: we are not currently using this save function.
64 | def pt_save(pt_obj, file_path):
65 | of = fsspec.open(file_path, "wb")
66 | with of as f:
67 |         torch.save(pt_obj, f)  # write through the fsspec file handle, not the raw path
68 |
69 | def pt_load(file_path, map_location=None):
70 | if file_path.startswith('s3'):
71 | logging.info('Loading remote checkpoint, which may take a bit.')
72 | of = fsspec.open(file_path, "rb")
73 | with of as f:
74 | out = torch.load(f, map_location=map_location)
75 | return out
76 |
77 | def check_exists(file_path):
78 | try:
79 | with fsspec.open(file_path):
80 | pass
81 | except FileNotFoundError:
82 | return False
83 | return True
84 |
--------------------------------------------------------------------------------
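start_sync_process only constructs the background Process; the caller is expected to start it. A minimal sketch with placeholder paths (an s3:// destination requires the aws CLI to be configured), assuming the repo root is on PYTHONPATH:

    from training.file_utils import start_sync_process

    p = start_sync_process(
        sync_every=300,                      # seconds between sync attempts
        local_dir='logs/my_run',             # hypothetical local checkpoint dir
        remote_dir='s3://my-bucket/my_run',  # hypothetical remote destination
        protocol='s3',
    )
    p.start()   # runs keep_running_remote_sync in a separate process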
/training/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 |
4 | def setup_logging(log_file, level, include_host=False):
5 | if include_host:
6 | import socket
7 | hostname = socket.gethostname()
8 | formatter = logging.Formatter(
9 | f'%(asctime)s | {hostname} | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S')
10 | else:
11 | formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S')
12 |
13 | logging.root.setLevel(level)
14 | loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
15 | for logger in loggers:
16 | logger.setLevel(level)
17 |
18 | stream_handler = logging.StreamHandler()
19 | stream_handler.setFormatter(formatter)
20 | logging.root.addHandler(stream_handler)
21 |
22 | if log_file:
23 | file_handler = logging.FileHandler(filename=log_file)
24 | file_handler.setFormatter(formatter)
25 | logging.root.addHandler(file_handler)
26 |
27 |
--------------------------------------------------------------------------------
/training/precision.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from contextlib import suppress
3 |
4 |
5 | def get_autocast(precision):
6 | if precision == 'amp':
7 | return torch.cuda.amp.autocast
8 | elif precision == 'amp_bfloat16' or precision == 'amp_bf16':
9 | # amp_bfloat16 is more stable than amp float16 for clip training
10 | return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
11 | else:
12 | return suppress
13 |
--------------------------------------------------------------------------------
/training/profile.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | import open_clip
5 | import pandas as pd
6 | from fvcore.nn import FlopCountAnalysis, flop_count_str, ActivationCountAnalysis
7 |
8 |
9 | parser = argparse.ArgumentParser(description='OpenCLIP Profiler')
10 |
11 | # benchmark specific args
12 | parser.add_argument('--model', metavar='NAME', default='',
13 | help='model(s) to profile')
14 | parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
15 | help='Output csv file for results')
16 |
17 |
18 | def profile_fvcore(
19 | model,
20 | image_input_size=(3, 224, 224),
21 | text_input_size=(77,),
22 | batch_size=1,
23 | detailed=False,
24 | force_cpu=False
25 | ):
26 | if force_cpu:
27 | model = model.to('cpu')
28 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
29 | example_image_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype)
30 | example_text_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64)
31 | fca = FlopCountAnalysis(model, (example_image_input, example_text_input))
32 | aca = ActivationCountAnalysis(model, (example_image_input, example_text_input))
33 | if detailed:
34 | fcs = flop_count_str(fca)
35 | print(fcs)
36 | return fca.total(), aca.total()
37 |
38 |
39 | def profile_fvcore_text(
40 | model,
41 | text_input_size=(77,),
42 | batch_size=1,
43 | detailed=False,
44 | force_cpu=False
45 | ):
46 | if force_cpu:
47 | model = model.to('cpu')
48 | device = next(model.parameters()).device
49 | example_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64)
50 | fca = FlopCountAnalysis(model, example_input)
51 | aca = ActivationCountAnalysis(model, example_input)
52 | if detailed:
53 | fcs = flop_count_str(fca)
54 | print(fcs)
55 | return fca.total(), aca.total()
56 |
57 |
58 | def profile_fvcore_image(
59 | model,
60 | image_input_size=(3, 224, 224),
61 | batch_size=1,
62 | detailed=False,
63 | force_cpu=False
64 | ):
65 | if force_cpu:
66 | model = model.to('cpu')
67 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
68 | example_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype)
69 | fca = FlopCountAnalysis(model, example_input)
70 | aca = ActivationCountAnalysis(model, example_input)
71 | if detailed:
72 | fcs = flop_count_str(fca)
73 | print(fcs)
74 | return fca.total(), aca.total()
75 |
76 |
77 | def count_params(model):
78 | return sum([m.numel() for m in model.parameters()])
79 |
80 |
81 | def profile_model(model_name):
82 | model = open_clip.create_model(model_name, force_custom_text=True, pretrained_hf=False)
83 | model.eval()
84 | if torch.cuda.is_available():
85 | model = model.cuda()
86 |
87 | if isinstance(model.visual.image_size, (tuple, list)):
88 | image_input_size = (3,) + tuple(model.visual.image_size[-2:])
89 | else:
90 | image_input_size = (3, model.visual.image_size, model.visual.image_size)
91 | text_input_size = (77,)
92 |
93 | results = {}
94 | results['model'] = model_name
95 | results['image_size'] = image_input_size[1]
96 |
97 | model_cfg = open_clip.get_model_config(model_name)
98 | if model_cfg:
99 | vision_cfg = open_clip.CLIPVisionCfg(**model_cfg['vision_cfg'])
100 | text_cfg = open_clip.CLIPTextCfg(**model_cfg['text_cfg'])
101 | results['image_width'] = int(vision_cfg.width)
102 | results['text_width'] = int(text_cfg.width)
103 | results['embed_dim'] = int(model_cfg['embed_dim'])
104 | else:
105 | results['image_width'] = 0
106 | results['text_width'] = 0
107 | results['embed_dim'] = 0
108 |
109 | retries = 2
110 | while retries:
111 | retries -= 1
112 | try:
113 | macs, acts = profile_fvcore(
114 | model, image_input_size=image_input_size, text_input_size=text_input_size, force_cpu=not retries)
115 |
116 | image_macs, image_acts = profile_fvcore_image(
117 | model.visual, image_input_size=image_input_size, force_cpu=not retries)
118 |
119 | text_macs, text_acts = profile_fvcore_text(
120 | model.text, text_input_size=text_input_size, force_cpu=not retries)
121 |
122 | results['gmacs'] = round(macs / 1e9, 2)
123 | results['macts'] = round(acts / 1e6, 2)
124 | results['mparams'] = round(count_params(model) / 1e6, 2)
125 | results['image_gmacs'] = round(image_macs / 1e9, 2)
126 | results['image_macts'] = round(image_acts / 1e6, 2)
127 | results['image_mparams'] = round(count_params(model.visual) / 1e6, 2)
128 | results['text_gmacs'] = round(text_macs / 1e9, 2)
129 | results['text_macts'] = round(text_acts / 1e6, 2)
130 | results['text_mparams'] = round(count_params(model.text) / 1e6, 2)
131 | except RuntimeError as e:
132 | pass
133 | return results
134 |
135 |
136 | def main():
137 | args = parser.parse_args()
138 |
139 | # FIXME accept a text file name to allow lists of models in txt/csv
140 | if args.model == 'all':
141 | parsed_model = open_clip.list_models()
142 | else:
143 | parsed_model = args.model.split(',')
144 |
145 | results = []
146 | for m in parsed_model:
147 | row = profile_model(m)
148 | results.append(row)
149 |
150 | df = pd.DataFrame(results, columns=results[0].keys())
151 | df = df.sort_values('gmacs')
152 | print(df)
153 | if args.results_file:
154 | df.to_csv(args.results_file, index=False)
155 |
156 |
157 | if __name__ == '__main__':
158 | main()
159 |
--------------------------------------------------------------------------------
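Beyond the CLI (--model accepts a comma-separated list or 'all'), profile_model can be called directly; it needs fvcore installed and retries on CPU if GPU profiling fails. A minimal sketch, assuming the repo root is on PYTHONPATH:

    from training.profile import profile_model

    row = profile_model('ViT-B-16')              # any name from open_clip.list_models()
    print(row.get('gmacs'), row.get('mparams'))  # GMACs and parameter count (millions), if profiling succeeded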
/training/region_clip.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | import torch.nn as nn
5 |
6 |
7 | def get_fed_loss_inds(gt_classes, num_sample_cats, C):
8 | appeared = torch.unique(gt_classes) # C'
9 | prob = appeared.new_ones(C).float()
10 | if len(appeared) < num_sample_cats:
11 | prob[appeared] = 0
12 | more_appeared = torch.multinomial(
13 | prob, num_sample_cats - len(appeared),
14 | replacement=False)
15 | appeared = torch.cat([appeared, more_appeared])
16 | return appeared
17 |
18 |
19 | class RegionCLIP(nn.Module):
20 | def __init__(self, args):
21 | super().__init__()
22 | embed_path = args.train_embed_path
23 | noun_embeddings = torch.from_numpy(np.load(embed_path))
24 | noun_embeddings = F.normalize(noun_embeddings, dim=-1)
25 | self.register_buffer("noun_embeddings", noun_embeddings)
26 | self.place_holder = nn.Parameter(torch.ones(1))
27 |
28 | def __call__(self, batch, model, dist_model, loss, device, cast_dtype,
29 | distributed, args):
30 | if distributed:
31 | model = model.module
32 | images, boxes = batch
33 | images = images.to(device=device, dtype=cast_dtype, non_blocking=True)
34 | boxes = boxes.to(device=device, non_blocking=True)
35 |
36 | boxes_list = []
37 | boxes_label_list = []
38 |
39 | for boxes_per_image in boxes:
40 | boxes_per_image = boxes_per_image[boxes_per_image[:, -1] > 0.5]
41 | boxes_label_list.append(boxes_per_image[:, 4].long())
42 | boxes_list.append(boxes_per_image[:, :4])
43 | boxes_labels = torch.cat(boxes_label_list)
44 | box_features = model.encode_pseudo_boxes(images, boxes_list, normalize=True,
45 | extract_type=args.extract_type)
46 | temp = model.logit_scale.exp().detach()
47 | boxes2nouns = box_features @ self.noun_embeddings.T * temp
48 | target = torch.zeros_like(boxes2nouns)
49 | target[range(len(boxes_labels)), boxes_labels] = 1.0
50 |
51 | appeared = get_fed_loss_inds(boxes_labels, 100, self.noun_embeddings.shape[0])
52 | target = target[:, appeared]
53 | boxes2nouns = boxes2nouns[:, appeared]
54 |
55 | loss_cls = F.binary_cross_entropy_with_logits(boxes2nouns, target, reduction='none') # B x C
56 | loss_cls = loss_cls.sum(-1).mean()
57 |
58 | image_size = model.visual.image_size
59 | if isinstance(image_size, int):
60 | tar_h = tar_w = image_size
61 | else:
62 | tar_h, tar_w = image_size
63 | images = F.interpolate(images, size=(tar_h, tar_w), mode='bilinear')
64 |
65 | losses = dict(loss_contrast=loss_cls * args.contrast_weight)
66 |
67 | return losses, len(images), temp
68 |
--------------------------------------------------------------------------------
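get_fed_loss_inds keeps every class that actually appears in the batch and pads the set with randomly drawn extra classes until num_sample_cats classes are selected; the BCE loss above is then restricted to those columns. A minimal sketch, assuming the repo root is on PYTHONPATH:

    import torch
    from training.region_clip import get_fed_loss_inds

    gt = torch.tensor([3, 3, 7])                 # classes present in the batch
    inds = get_fed_loss_inds(gt, num_sample_cats=5, C=20)
    print(inds)                                  # contains 3 and 7 plus 3 other randomly sampled classes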
/training/scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def assign_learning_rate(optimizer, new_lr):
5 | for param_group in optimizer.param_groups:
6 | param_group["lr"] = new_lr
7 |
8 |
9 | def _warmup_lr(base_lr, warmup_length, step):
10 | return base_lr * (step + 1) / warmup_length
11 |
12 |
13 | def const_lr(optimizer, base_lr, warmup_length, steps):
14 | def _lr_adjuster(step):
15 | if step < warmup_length:
16 | lr = _warmup_lr(base_lr, warmup_length, step)
17 | else:
18 | lr = base_lr
19 | assign_learning_rate(optimizer, lr)
20 | return lr
21 | return _lr_adjuster
22 |
23 |
24 | def const_lr_cooldown(optimizer, base_lr, warmup_length, steps, cooldown_steps, cooldown_power=1.0, cooldown_end_lr=0.):
25 | def _lr_adjuster(step):
26 | start_cooldown_step = steps - cooldown_steps
27 | if step < warmup_length:
28 | lr = _warmup_lr(base_lr, warmup_length, step)
29 | else:
30 | if step < start_cooldown_step:
31 | lr = base_lr
32 | else:
33 | e = step - start_cooldown_step
34 | es = steps - start_cooldown_step
35 | # linear decay if power == 1; polynomial decay otherwise;
36 | decay = (1 - (e/es)) ** cooldown_power
37 | lr = decay * (base_lr - cooldown_end_lr) + cooldown_end_lr
38 | assign_learning_rate(optimizer, lr)
39 | return lr
40 | return _lr_adjuster
41 |
42 |
43 | def cosine_lr(optimizer, base_lr, warmup_length, steps):
44 | def _lr_adjuster(step):
45 | if step < warmup_length:
46 | lr = _warmup_lr(base_lr, warmup_length, step)
47 | else:
48 | e = step - warmup_length
49 | es = steps - warmup_length
50 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
51 | assign_learning_rate(optimizer, lr)
52 | return lr
53 | return _lr_adjuster
54 |
--------------------------------------------------------------------------------
/training/test_flickr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from open_clip import create_model_and_transforms, get_tokenizer
3 | from PIL import Image
4 | import json
5 | import tqdm
6 | import os
7 | import argparse
8 |
9 |
10 | def main(args):
11 | model_name = args.model_name
12 | pretrained = args.pretrained
13 | data = args.data
14 | image_path = args.image_path
15 | image_size = args.image_size
16 | device = args.device
17 | model, _, preprocess = create_model_and_transforms(
18 | model_name,
19 | "eva",
20 | "amp",
21 | device="cpu",
22 | jit=False,
23 | force_quick_gelu=False,
24 | force_custom_text=False,
25 | force_patch_dropout=None,
26 | force_image_size=None,
27 | pretrained_image=False,
28 | image_mean=None,
29 | image_std=None,
30 | aug_cfg={},
31 | output_dict=True,
32 | cache_dir=pretrained,
33 | det_image_size=image_size,
34 | dataset_type="grid_distill",
35 | )
36 | tokenizer = get_tokenizer(model_name)
37 | model = model.to(device)
38 | image_list = []
39 | caption_list = []
40 | image_fea = []
41 | caption_fea = []
42 | txt2img = {}
43 | img2txt = {}
44 | txt_id = 0
45 | with open(data) as f:
46 | annotation = json.load(f)
47 | for img_id, ann in enumerate(annotation):
48 | img2txt[img_id] = []
49 | image_list.append(ann['image'])
50 | for caption in ann['caption']:
51 | caption_list.append(caption)
52 | img2txt[img_id].append(txt_id)
53 | txt2img[txt_id] = img_id
54 | txt_id += 1
55 |
56 | with torch.no_grad(), torch.cuda.amp.autocast():
57 | tmp_image_list = []
58 | for i in tqdm.tqdm(range(len(image_list))):
59 | tmp_image_list.append(preprocess[0](Image.open(os.path.join(image_path, image_list[i]))).unsqueeze(0).to(device))
60 | if len(tmp_image_list) % 64 == 0 or i == len(image_list)-1:
61 | tmp_image = torch.cat(tmp_image_list, dim=0)
62 | image_features = model.encode_image(tmp_image)
63 | image_features /= image_features.norm(dim=-1, keepdim=True)
64 | image_fea.append(image_features)
65 | tmp_image_list = []
66 |
67 | tmp_text_list = []
68 | for i in tqdm.tqdm(range(len(caption_list))):
69 | tmp_text_list.append(caption_list[i])
70 | if len(tmp_text_list) % 64 == 0 or i == len(caption_list)-1:
71 | text = tokenizer(tmp_text_list).to(device)
72 | text_features = model.encode_text(text)
73 | text_features /= text_features.norm(dim=-1, keepdim=True)
74 | caption_fea.append(text_features)
75 | tmp_text_list = []
76 |
77 | image_fea_total = torch.cat(image_fea, dim=0)
78 | caption_fea_total = torch.cat(caption_fea, dim=0)
79 | sims = image_fea_total@caption_fea_total.t()
80 | _, topk_idx = sims.topk(k=1, dim=0)
81 | count = 0
82 |
83 | for i in range(topk_idx.shape[1]):
84 | if topk_idx[0,i] == txt2img[i]:
85 | count += 1
86 |     print("Text-to-image retrieval accuracy: {:.2f}%".format(100*count/topk_idx.shape[1]))
87 |
88 | sims = sims.t()
89 | _, topk_idx = sims.topk(k=1, dim=0)
90 |
91 | count = 0
92 | new_list = []
93 | for i in range(topk_idx.shape[1]):
94 | if topk_idx[0,i] in img2txt[i]:
95 | count += 1
96 | else:
97 | new_list.append({"image": image_list[i], "caption":caption_list[img2txt[i][0]]})
98 |     print("Image-to-text retrieval accuracy: {:.2f}%".format(100*count/topk_idx.shape[1]))
99 |
100 |
101 | if __name__ == "__main__":
102 | parser = argparse.ArgumentParser()
103 | parser.add_argument(
104 | "--model-name",
105 | type=str,
106 | default="EVA02-CLIP-B-16",
107 | )
108 | parser.add_argument(
109 | "--pretrained",
110 | type=str,
111 | default="./checkpoints/coco_vitb16.pt",
112 | )
113 | parser.add_argument(
114 | "--data",
115 | type=str,
116 | default="./data/flickr30k/flickr30k_test.json"
117 | )
118 | parser.add_argument(
119 | "--image-path",
120 | type=str,
121 | default="./data/flickr30k/flickr30k_images"
122 | )
123 | parser.add_argument(
124 | "--image-size",
125 | type=int,
126 | default=224
127 | )
128 | parser.add_argument(
129 | "--device",
130 | type=str,
131 | default="cuda:0"
132 | )
133 | args = parser.parse_args()
134 | main(args)
--------------------------------------------------------------------------------
/training/test_mscoco.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from open_clip import create_model_and_transforms, get_tokenizer
3 | from PIL import Image
4 | import json
5 | import tqdm
6 | import os
7 | import argparse
8 |
9 |
10 | def main(args):
11 | model_name = args.model_name
12 | pretrained = args.pretrained
13 | data = args.data
14 | image_path = args.image_path
15 | image_size = args.image_size
16 | device = args.device
17 | model, _, preprocess = create_model_and_transforms(
18 | model_name,
19 | "eva",
20 | "amp",
21 | device="cpu",
22 | jit=False,
23 | force_quick_gelu=False,
24 | force_custom_text=False,
25 | force_patch_dropout=None,
26 | force_image_size=None,
27 | pretrained_image=False,
28 | image_mean=None,
29 | image_std=None,
30 | aug_cfg={},
31 | output_dict=True,
32 | cache_dir=pretrained,
33 | det_image_size=image_size,
34 | dataset_type="grid_distill",
35 | )
36 | tokenizer = get_tokenizer(model_name)
37 | model = model.to(device)
38 | image_list = []
39 | caption_list = []
40 | image_fea = []
41 | caption_fea = []
42 | txt2img = {}
43 | img2txt = {}
44 | txt_id = 0
45 | with open(data) as f:
46 | annotation = json.load(f)
47 | for img_id, ann in enumerate(annotation):
48 | img2txt[img_id] = []
49 | image_list.append(ann['image'])
50 | for caption in ann['caption']:
51 | caption_list.append(caption)
52 | img2txt[img_id].append(txt_id)
53 | txt2img[txt_id] = img_id
54 | txt_id += 1
55 |
56 | with torch.no_grad(), torch.cuda.amp.autocast():
57 | tmp_image_list = []
58 | for i in tqdm.tqdm(range(len(image_list))):
59 | tmp_image_list.append(preprocess[0](Image.open(os.path.join(image_path, image_list[i]))).unsqueeze(0).to(device))
60 | if len(tmp_image_list) % 64 == 0 or i == len(image_list)-1:
61 | tmp_image = torch.cat(tmp_image_list, dim=0)
62 | image_features = model.encode_image(tmp_image)
63 | image_features /= image_features.norm(dim=-1, keepdim=True)
64 | image_fea.append(image_features)
65 | tmp_image_list = []
66 |
67 | tmp_text_list = []
68 | for i in tqdm.tqdm(range(len(caption_list))):
69 | tmp_text_list.append(caption_list[i])
70 | if len(tmp_text_list) % 64 == 0 or i == len(caption_list)-1:
71 | text = tokenizer(tmp_text_list).to(device)
72 | text_features = model.encode_text(text)
73 | text_features /= text_features.norm(dim=-1, keepdim=True)
74 | caption_fea.append(text_features)
75 | tmp_text_list = []
76 |
77 | image_fea_total = torch.cat(image_fea, dim=0)
78 | caption_fea_total = torch.cat(caption_fea, dim=0)
79 | sims = image_fea_total@caption_fea_total.t()
80 | _, topk_idx = sims.topk(k=1, dim=0)
81 | count = 0
82 |
83 | for i in range(topk_idx.shape[1]):
84 | if topk_idx[0,i] == txt2img[i]:
85 | count += 1
86 |     print("Text-to-image retrieval accuracy: {:.2f}%".format(100*count/topk_idx.shape[1]))
87 |
88 | sims = sims.t()
89 | _, topk_idx = sims.topk(k=1, dim=0)
90 |
91 | count = 0
92 | new_list = []
93 | for i in range(topk_idx.shape[1]):
94 | if topk_idx[0,i] in img2txt[i]:
95 | count += 1
96 | else:
97 | new_list.append({"image": image_list[i], "caption":caption_list[img2txt[i][0]]})
98 |     print("Image-to-text retrieval accuracy: {:.2f}%".format(100*count/topk_idx.shape[1]))
99 |
100 |
101 | if __name__ == "__main__":
102 | parser = argparse.ArgumentParser()
103 | parser.add_argument(
104 | "--model-name",
105 | type=str,
106 | default="EVA02-CLIP-B-16",
107 | )
108 | parser.add_argument(
109 | "--pretrained",
110 | type=str,
111 | default="./checkpoints/coco_vitb16.pt",
112 | )
113 | parser.add_argument(
114 | "--data",
115 | type=str,
116 | default="./data/coco/coco_test.json"
117 | )
118 | parser.add_argument(
119 | "--image-path",
120 | type=str,
121 | default="./data/coco/val2017"
122 | )
123 | parser.add_argument(
124 | "--image-size",
125 | type=int,
126 | default=224
127 | )
128 | parser.add_argument(
129 | "--device",
130 | type=str,
131 | default="cuda:0"
132 | )
133 | args = parser.parse_args()
134 | main(args)
--------------------------------------------------------------------------------
/training/train.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import math
4 | import time
5 | import torch
6 |
7 | from open_clip.eva_clip import ClipLoss, get_cast_dtype
8 | from .distributed import is_master
9 | from .zero_shot import zero_shot_eval
10 | from .precision import get_autocast
11 | import os
12 |
13 |
14 | class AverageMeter(object):
15 | """Computes and stores the average and current value"""
16 |
17 | def __init__(self):
18 | self.reset()
19 |
20 | def reset(self):
21 | self.val = 0
22 | self.avg = 0
23 | self.sum = 0
24 | self.count = 0
25 |
26 | def update(self, val, n=1):
27 | self.val = val
28 | self.sum += val * n
29 | self.count += n
30 | self.avg = self.sum / self.count
31 |
32 | def postprocess_clip_output(model_out):
33 | return {
34 | "image_features": model_out[0],
35 | "text_features": model_out[1],
36 | "logit_scale": model_out[2]
37 | }
38 |
39 | def unwrap_model(model):
40 | if hasattr(model, 'module'):
41 | return model.module
42 | else:
43 | return model
44 |
45 |
46 | def backward(total_loss, scaler):
47 | if scaler is not None:
48 | scaler.scale(total_loss).backward()
49 | else:
50 | total_loss.backward()
51 |
52 |
53 | @torch.no_grad()
54 | def student_teacher_ensemble(student, teacher, alpha=0.5):
55 | target_state_dict = {}
56 | for k, v in student.items():
57 | target_state_dict[k] = v * alpha + teacher[k] * (1.0 - alpha)
58 |
59 | return target_state_dict
60 |
61 |
62 | def train_one_epoch(model, method, data, loss, epoch, optimizer, scaler, scheduler, dist_model, args):
63 | device = torch.device(args.device)
64 | autocast = get_autocast(args.precision)
65 | cast_dtype = get_cast_dtype(args.precision)
66 |
67 | model.train()
68 | itc_loss = ClipLoss(
69 | local_loss=args.local_loss,
70 | gather_with_grad=args.gather_with_grad,
71 | cache_labels=True,
72 | rank=args.rank,
73 | world_size=args.world_size,
74 | )
75 | if dist_model is not None:
76 | dist_model.eval()
77 |
78 | data['train'].set_epoch(epoch) # set epoch in process safe manner via sampler or shared_epoch
79 | dataloader = data['train'].dataloader
80 | num_batches_per_epoch = dataloader.num_batches // args.accum_freq
81 | sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))
82 |
83 | losses_m = {}
84 | batch_time_m = AverageMeter()
85 | data_time_m = AverageMeter()
86 | end = time.time()
87 | for i, batch in enumerate(dataloader):
88 | i_accum = i // args.accum_freq
89 | step = num_batches_per_epoch * epoch + i_accum
90 |
91 | if not args.skip_scheduler:
92 | scheduler(step)
93 |
94 | data_time_m.update(time.time() - end)
95 | optimizer.zero_grad()
96 |         assert args.accum_freq == 1, "gradient accumulation (accum_freq > 1) is not supported"
97 | with autocast():
98 | losses, batch_size, logit_scale = method(batch, model, dist_model, itc_loss, device, cast_dtype,
99 | args.distributed, args)
100 | total_loss = sum(losses.values())
101 | losses["loss"] = total_loss
102 |
103 | backward(total_loss, scaler)
104 |
105 | if scaler is not None:
106 | if args.horovod:
107 | optimizer.synchronize()
108 | scaler.unscale_(optimizer)
109 | if args.grad_clip_norm is not None:
110 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
111 | with optimizer.skip_synchronize():
112 | scaler.step(optimizer)
113 | else:
114 | if args.grad_clip_norm is not None:
115 | scaler.unscale_(optimizer)
116 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
117 | scaler.step(optimizer)
118 | scaler.update()
119 | else:
120 | if args.grad_clip_norm is not None:
121 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
122 | optimizer.step()
123 |
124 | # Note: we clamp to 4.6052 = ln(100), as in the original paper.
125 | with torch.no_grad():
126 | unwrap_model(model).logit_scale.clamp_(0, math.log(100))
127 |
128 | batch_time_m.update(time.time() - end)
129 | end = time.time()
130 | batch_count = i_accum + 1
131 | if is_master(args) and (i_accum % args.log_every_n_steps == 0 or batch_count == num_batches_per_epoch):
132 | # batch_size = len(images)
133 | num_samples = batch_count * batch_size * args.accum_freq * args.world_size
134 | samples_per_epoch = dataloader.num_samples
135 | percent_complete = 100.0 * batch_count / num_batches_per_epoch
136 |
137 | # NOTE: loss is coarsely sampled (master node only, once per log update)
138 | for key, val in losses.items():
139 | if key not in losses_m:
140 | losses_m[key] = AverageMeter()
141 | losses_m[key].update(val.item(), batch_size)
142 |
143 | logit_scale_scalar = logit_scale.item()
144 | loss_log = " ".join(
145 | [
146 | f"{loss_name.capitalize()}: {loss_m.val:#.5g} ({loss_m.avg:#.5g})"
147 | for loss_name, loss_m in losses_m.items()
148 | ]
149 | )
150 | samples_per_second = args.accum_freq * args.batch_size * args.world_size / batch_time_m.val
151 | samples_per_second_per_gpu = args.accum_freq * args.batch_size / batch_time_m.val
152 | logging.info(
153 | f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
154 | f"Data (t): {data_time_m.avg:.3f} "
155 | f"Batch (t): {batch_time_m.avg:.3f}, {samples_per_second:#g}/s, {samples_per_second_per_gpu:#g}/s/gpu "
156 | f"LR: {optimizer.param_groups[0]['lr']:5f} "
157 | f"Logit Scale: {logit_scale_scalar:.3f} " + loss_log
158 | )
159 |
160 | # Save train loss / etc. Using non-averaged meter values, as loggers apply their own smoothing
161 | log_data = {
162 | "data_time": data_time_m.val,
163 | "batch_time": batch_time_m.val,
164 | "samples_per_second": samples_per_second,
165 | "samples_per_second_per_gpu": samples_per_second_per_gpu,
166 | "scale": logit_scale_scalar,
167 | "lr": optimizer.param_groups[0]["lr"]
168 | }
169 | log_data.update({name: val.val for name, val in losses_m.items()})
170 | # resetting batch / data time meters per log window
171 | batch_time_m.reset()
172 | data_time_m.reset()
173 |
174 |
175 | def evaluate(model, data, epoch, args):
176 | metrics = {}
177 | model.eval()
178 |
179 | zero_shot_metrics = zero_shot_eval(model, data, epoch, args)
180 | if not is_master(args):
181 | return {}
182 | metrics.update(zero_shot_metrics)
183 | if not metrics:
184 | return metrics
185 |
186 | keys = ', '.join([k for k in metrics.keys() if 'all' in k])
187 | values = ', '.join([f'{v:.4f}' for k, v in metrics.items() if 'all' in k])
188 |
189 | logging.info(
190 | f"Eval Epoch: {epoch}. "
191 | + f"{keys}: {values}."
192 | )
193 | # TODO save the results as plots
194 | logging.info(metrics)
195 |
196 | if args.save_logs:
197 | with open(os.path.join(args.checkpoint_path, "results.json"), "a+") as f:
198 | f.write(json.dumps(metrics))
199 | f.write("\n")
200 |
201 | return metrics
202 |
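
Usage note: a minimal sketch of the student_teacher_ensemble helper defined above, using toy state dicts in place of real checkpoints; the tensors, the alpha value, and the import path are illustrative assumptions, not part of the repository.

import torch
from training.train import student_teacher_ensemble  # assumed module path for this file

# Toy state dicts standing in for student and teacher checkpoints.
student_sd = {"w": torch.ones(2, 2)}
teacher_sd = {"w": torch.zeros(2, 2)}

# Every entry becomes alpha * student + (1 - alpha) * teacher.
blended = student_teacher_ensemble(student_sd, teacher_sd, alpha=0.5)
assert torch.allclose(blended["w"], torch.full((2, 2), 0.5))
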
--------------------------------------------------------------------------------
/training/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from functools import partial
3 | from six.moves import map, zip
4 |
5 |
6 | def multi_apply(func, *args, **kwargs):
7 | """Apply function to a list of arguments.
8 | Note:
9 | This function applies the ``func`` to multiple inputs and
10 | map the multiple outputs of the ``func`` into different
11 | list. Each list contains the same type of outputs corresponding
12 | to different inputs.
13 | Args:
14 | func (Function): A function that will be applied to a list of
15 | arguments
16 | Returns:
17 | tuple(list): A tuple containing multiple list, each list contains \
18 | a kind of returned results by the function
19 | """
20 | pfunc = partial(func, **kwargs) if kwargs else func
21 | map_results = map(pfunc, *args)
22 | return tuple(map(list, zip(*map_results)))
23 |
24 |
25 | def mask2box(mask):
26 | ys, xs = np.where(mask)
27 | y0, y1 = ys.min(), ys.max()
28 | x0, x1 = xs.min(), xs.max()
29 |
30 | return x0, y0, x1, y1
31 |
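
Usage note: a small illustration of the two helpers above; the corners function, the toy boxes, and the toy mask are made up for this example, and the import path is assumed.

import numpy as np
from training.utils import multi_apply, mask2box  # assumed module path for this file

# multi_apply: call a function over parallel argument lists and regroup the
# per-call outputs into separate lists, one list per output position.
def corners(box, factor=1):
    x0, y0, x1, y1 = box
    return (x0 * factor, y0 * factor), (x1 * factor, y1 * factor)

tops, bottoms = multi_apply(corners, [(1, 2, 3, 4), (5, 6, 7, 8)], factor=2)
# tops == [(2, 4), (10, 12)], bottoms == [(6, 8), (14, 16)]

# mask2box: tight (x0, y0, x1, y1) box around the nonzero pixels of a mask.
mask = np.zeros((8, 8), dtype=bool)
mask[2:5, 3:7] = True
assert mask2box(mask) == (3, 2, 6, 4)
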
--------------------------------------------------------------------------------
/training/zero_shot.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch
3 | import torch.nn.functional as F
4 | from training.dist_utils import all_gather
5 | from tqdm import tqdm
6 | from .distributed import is_master
7 | from open_clip import get_cast_dtype
8 | from .precision import get_autocast
9 |
10 |
11 | def run(model, dataloader, args):
12 | cls_embeddings = dataloader.dataset.embeddings
13 | cls_embeddings = F.normalize(torch.from_numpy(cls_embeddings).float(), dim=-1)
14 | cls_embeddings = cls_embeddings.to(args.device)
15 | autocast = get_autocast(args.precision)
16 | cast_dtype = get_cast_dtype(args.precision)
17 | if cast_dtype is not None:
18 | cls_embeddings = cls_embeddings.to(dtype=cast_dtype)
19 | with torch.no_grad():
20 | correct_rois = []
21 | all_cls_labels = []
22 | for images, bboxes, _, gt_masks, _ \
23 | in tqdm(dataloader, disable=not is_master(args)):
24 | images = images.to(args.device)
25 | bboxes = bboxes.to(args.device)
26 | gt_masks = gt_masks.to(args.device)
27 | if cast_dtype is not None:
28 | images = images.to(dtype=cast_dtype)
29 | bboxes = bboxes.to(dtype=cast_dtype)
30 | gt_masks = gt_masks.to(dtype=cast_dtype)
31 | cls_labels = []
32 | rois = []
33 | for bboxes_per_image in bboxes:
34 | valid = bboxes_per_image[:, 5] > 0.5
35 | rois.append(bboxes_per_image[valid, :4])
36 | cls_labels.append(bboxes_per_image[valid, 4])
37 | cls_labels = torch.cat(cls_labels, dim=0).to(torch.long)
38 | if cls_labels.shape[0] == 0:
39 | continue
40 | with autocast():
41 | # predict
42 | if args.distributed and not args.horovod:
43 | module = model.module
44 | else:
45 | module = model
46 | roi_extractor = module.encode_pseudo_boxes
47 | roi_features = roi_extractor(images, rois, normalize=True,
48 | extract_type=args.extract_type)
49 |
50 | if cast_dtype is not None:
51 | roi_features = roi_features.to(dtype=cast_dtype)
52 |
53 | roi_logits = roi_features @ cls_embeddings.T
54 |
55 | _, roi_top5_inds = roi_logits.topk(5)
56 | correct_rois.append(roi_top5_inds == cls_labels.view(-1, 1))
57 | all_cls_labels.append(cls_labels)
58 |
59 | correct_rois = torch.cat(correct_rois).float()
60 | all_cls_labels = torch.cat(all_cls_labels)
61 | if args.distributed and not args.horovod:
62 | correct_rois = multi_gpu_sync(correct_rois)
63 | all_cls_labels = multi_gpu_sync(all_cls_labels)
64 |
65 | return correct_rois, all_cls_labels
66 |
67 |
68 | def multi_gpu_sync(x):
69 | device = x.device
70 | x_list = all_gather(x.cpu())
71 | x = torch.cat([res.to(device) for res in x_list])
72 | return x
73 |
74 |
75 | def macc_with_box(correct_matrix, all_cls_labels, prefix):
76 | def _macc(corrects, cls_labels):
77 | min_id = cls_labels.min().item()
78 | max_id = cls_labels.max().item()
79 | cand_labels = list(range(min_id, max_id+1))
80 |
81 | acc_per_cls = []
82 |
83 | for lb in cand_labels:
84 | corrects_per_cls = corrects[cls_labels == lb]
85 | if corrects_per_cls.shape[0] == 0:
86 | continue
87 | acc_per_cls.append(corrects_per_cls.mean().half().item())
88 |
89 | return sum(acc_per_cls) / len(acc_per_cls)
90 |
91 | results = {}
92 |
93 | box_top1_acc = _macc(correct_matrix[:, 0], all_cls_labels)
94 | box_top5_acc = _macc(correct_matrix.sum(-1), all_cls_labels)
95 |
96 | results[f'{prefix}.box.macc1'] = box_top1_acc
97 | results[f'{prefix}.box.macc5'] = box_top5_acc
98 |
99 | return results
100 |
101 |
102 | def zero_shot_eval(model, data, epoch, args):
103 | if 'val' not in data:
104 | return {}
105 | if args.zeroshot_frequency == 0:
106 | return {}
107 | if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs:
108 | return {}
109 | logging.info('Region classifier')
110 | results = {}
111 | correct_rois, all_cls_labels = run(model, data['val'].dataloader, args)
112 | results.update(macc_with_box(correct_rois, all_cls_labels, 'rois'))
113 | return results
114 |
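
Usage note: a toy walk-through of macc_with_box above; the correctness matrix and labels are fabricated for illustration, and the import path is assumed.

import torch
from training.zero_shot import macc_with_box  # assumed module path for this file

# Three RoIs over two classes; each row marks whether the ground-truth label
# appeared at each of the top-5 ranks (the format produced by run()).
correct = torch.tensor([
    [1., 0., 0., 0., 0.],  # class 0: correct at rank 1
    [0., 1., 0., 0., 0.],  # class 0: correct only within the top 5
    [0., 0., 0., 0., 0.],  # class 1: missed entirely
])
labels = torch.tensor([0, 0, 1])

results = macc_with_box(correct, labels, prefix='rois')
# Per-class accuracies are averaged over classes:
# macc1 = (0.5 + 0.0) / 2 = 0.25, macc5 = (1.0 + 0.0) / 2 = 0.5
print(results)  # {'rois.box.macc1': 0.25, 'rois.box.macc5': 0.5}
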
--------------------------------------------------------------------------------