├── LICENSE
├── README.md
├── __pycache__
├── dataset.cpython-38.pyc
└── few_shot.cpython-38.pyc
├── data
├── mvtec.py
└── visa.py
├── dataset.py
├── few_shot.py
├── few_shot.sh
├── reproduce_WinCLIP.py
├── requirements.txt
├── src
├── open_clip
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-38.pyc
│ │ ├── coca_model.cpython-38.pyc
│ │ ├── constants.cpython-38.pyc
│ │ ├── factory.cpython-38.pyc
│ │ ├── hf_configs.cpython-38.pyc
│ │ ├── hf_model.cpython-38.pyc
│ │ ├── loss.cpython-38.pyc
│ │ ├── model.cpython-38.pyc
│ │ ├── model_revise.cpython-38.pyc
│ │ ├── model_revise_learn.cpython-38.pyc
│ │ ├── modified_resnet.cpython-38.pyc
│ │ ├── openai.cpython-38.pyc
│ │ ├── pretrained.cpython-38.pyc
│ │ ├── push_to_hf_hub.cpython-38.pyc
│ │ ├── timm_model.cpython-38.pyc
│ │ ├── tokenizer.cpython-38.pyc
│ │ ├── transform.cpython-38.pyc
│ │ ├── transformer.cpython-38.pyc
│ │ ├── utils.cpython-38.pyc
│ │ └── version.cpython-38.pyc
│ ├── bpe_simple_vocab_16e6.txt.gz
│ ├── coca_model.py
│ ├── constants.py
│ ├── factory.py
│ ├── generation_utils.py
│ ├── hf_configs.py
│ ├── hf_model.py
│ ├── loss.py
│ ├── model.py
│ ├── model_configs
│ │ ├── RN101-quickgelu.json
│ │ ├── RN101.json
│ │ ├── RN50-quickgelu.json
│ │ ├── RN50.json
│ │ ├── RN50x16.json
│ │ ├── RN50x4.json
│ │ ├── RN50x64.json
│ │ ├── ViT-B-16-plus-240.json
│ │ ├── ViT-B-16-plus.json
│ │ ├── ViT-B-16.json
│ │ ├── ViT-B-32-plus-256.json
│ │ ├── ViT-B-32-quickgelu.json
│ │ ├── ViT-B-32.json
│ │ ├── ViT-H-14.json
│ │ ├── ViT-H-16.json
│ │ ├── ViT-L-14-280.json
│ │ ├── ViT-L-14-336.json
│ │ ├── ViT-L-14.json
│ │ ├── ViT-L-16-320.json
│ │ ├── ViT-L-16.json
│ │ ├── ViT-M-16-alt.json
│ │ ├── ViT-M-16.json
│ │ ├── ViT-M-32-alt.json
│ │ ├── ViT-M-32.json
│ │ ├── ViT-S-16-alt.json
│ │ ├── ViT-S-16.json
│ │ ├── ViT-S-32-alt.json
│ │ ├── ViT-S-32.json
│ │ ├── ViT-bigG-14.json
│ │ ├── ViT-e-14.json
│ │ ├── ViT-g-14.json
│ │ ├── coca_ViT-B-32.json
│ │ ├── coca_ViT-L-14.json
│ │ ├── coca_base.json
│ │ ├── coca_roberta-ViT-B-32.json
│ │ ├── convnext_base.json
│ │ ├── convnext_base_w.json
│ │ ├── convnext_base_w_320.json
│ │ ├── convnext_large.json
│ │ ├── convnext_large_d.json
│ │ ├── convnext_large_d_320.json
│ │ ├── convnext_small.json
│ │ ├── convnext_tiny.json
│ │ ├── convnext_xlarge.json
│ │ ├── convnext_xxlarge.json
│ │ ├── convnext_xxlarge_320.json
│ │ ├── mt5-base-ViT-B-32.json
│ │ ├── mt5-xl-ViT-H-14.json
│ │ ├── roberta-ViT-B-32.json
│ │ ├── swin_base_patch4_window7_224.json
│ │ ├── vit_medium_patch16_gap_256.json
│ │ ├── vit_relpos_medium_patch16_cls_224.json
│ │ ├── xlm-roberta-base-ViT-B-32.json
│ │ └── xlm-roberta-large-ViT-H-14.json
│ ├── model_revise.py
│ ├── modified_resnet.py
│ ├── openai.py
│ ├── pretrained.py
│ ├── push_to_hf_hub.py
│ ├── timm_model.py
│ ├── tokenizer.py
│ ├── transform.py
│ ├── transformer.py
│ ├── utils.py
│ ├── version.py
│ └── xx.py
├── test.py
└── training
│ ├── .gitignore
│ ├── __init__.py
│ ├── data.py
│ ├── distributed.py
│ ├── file_utils.py
│ ├── imagenet_zeroshot_data.py
│ ├── logger.py
│ ├── main.py
│ ├── params.py
│ ├── precision.py
│ ├── profile.py
│ ├── scheduler.py
│ ├── train.py
│ └── zero_shot.py
└── zero_shot.sh
/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Qihang Zhou 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation
the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WinCLIP 2 | This is an unofficial implementation of [WinCLIP](https://openaccess.thecvf.com/content/CVPR2023/papers/Jeong_WinCLIP_Zero-Few-Shot_Anomaly_Classification_and_Segmentation_CVPR_2023_paper.pdf) in [AnomalyCLIP](https://arxiv.org/abs/2310.18961) 3 | 4 | The implementation of CLIP is based on [open_clip](https://github.com/mlfoundations/open_clip) 5 | ## Updates 6 | 7 | - **03.20.2024**: Update the 2-shot, 4-shot, and 8-shot results of VisA. 8 | - **08.08.2024**: Update the visa.py to generate the JSON for VisA. 9 | 10 | ## Performance evaluation 11 | ### Few-shot 12 | #### MVTec AD (1-shot) 13 | | objects | auroc_px | f1_px | ap_px | aupro | auroc_sp | f1_sp | ap_sp | 14 | |:-----------|-----------:|--------:|--------:|--------:|-----------:|--------:|--------:| 15 | | carpet | 99.1 | 66.1 | 69.4 | 95.9 | 100 | 99.4 | 100 | 16 | | bottle | 94.3 | 60.9 | 64.9 | 85.1 | 99.4 | 98.4 | 99.8 | 17 | | hazelnut | 98.5 | 58.9 | 61.1 | 93.4 | 98 | 95.6 | 99 | 18 | | leather | 99.2 | 45.4 | 39.3 | 97.8 | 100 | 99.5 | 100 | 19 | | cable | 86.9 | 28.7 | 22.8 | 65 | 89.2 | 86.3 | 93.4 | 20 | | capsule | 96.4 | 32.3 | 24.7 | 89.7 | 83.5 | 92.4 | 96.3 | 21 | | grid | 94.1 | 28.7 | 19.1 | 82.1 | 99.6 | 99.1 | 99.9 | 22 | | pill | 92.4 | 36.1 | 28.7 | 89.8 | 89.6 | 93.3 | 98 | 23 | | transistor | 90 | 41.2 | 41.1 | 67.5 | 89.6 | 80.9 | 85.7 | 24 | | metal_nut | 78.5 | 36.5 | 28.7 | 75.3 | 98.2 | 97.4 | 99.6 | 25 | | screw | 95.9 | 23.5 | 14.4 | 84.5 | 81.5 | 86.8 | 93.1 | 26 | | toothbrush | 96 | 33.6 | 26.3 | 82.8 | 91.4 | 90.6 | 96.6 | 27 | | zipper | 97 | 46.5 | 40.8 | 90.5 | 86.4 | 90.3 | 95.8 | 28 | | tile | 91.7 | 53.5 | 46.2 | 77.5 | 100 | 99.4 | 100 | 29 | | wood | 94.5 | 56.4 | 59.4 | 84.5 | 99 | 96.8 | 99.7 | 30 | | mean | 93.6 | 43.2 | 39.1 | 84.1 | 93.7 | 93.7 | 97.1 | 31 | #### MVTec AD (2-shot) 32 | | objects | auroc_px | f1_px | ap_px | aupro | auroc_sp | f1_sp | ap_sp | 33 | |:-----------|-----------:|--------:|--------:|--------:|-----------:|--------:|--------:| 34 | | carpet | 99 | 64.5 | 68.2 | 95.6 | 99.8 | 98.9 | 99.9 | 35 | | bottle | 94.8 | 62.4 | 66.3 | 85.9 | 99.6 | 98.4 | 99.9 | 36 | | hazelnut | 98.7 | 61.6 | 63.9 | 93.6 | 97.9 | 95.7 | 98.9 | 37 | | leather | 99.2 | 45.2 | 39 | 97.9 | 99.9 | 99.5 | 100 | 38 | | cable | 88.8 | 31.5 | 25.3 | 72.5 | 91 | 89.5 | 94 | 39 | | capsule | 95.6 | 23.1 | 11.9 | 86.8 | 66 | 92.6 | 88.1 | 40 | | grid | 94.8 | 30.3 | 20.7 | 83.9 | 99.4 | 99.1 | 99.8 | 41 | | pill | 92.8 | 39.5 | 32.9 | 90.4 | 92.9 | 
95.3 | 98.6 | 42 | | transistor | 89.8 | 41 | 40.4 | 66.7 | 89.5 | 79.2 | 85.6 | 43 | | metal_nut | 76.7 | 35.2 | 26.8 | 73.8 | 98.5 | 98.4 | 99.7 | 44 | | screw | 96.7 | 25.6 | 18 | 87.5 | 82.9 | 86.9 | 93.5 | 45 | | toothbrush | 96.4 | 36.6 | 29.9 | 82 | 93.3 | 92.1 | 97.6 | 46 | | zipper | 97.2 | 50 | 43.9 | 91.1 | 95.2 | 94.8 | 98.7 | 47 | | tile | 92 | 53.9 | 46.4 | 78 | 99.9 | 99.4 | 100 | 48 | | wood | 94.5 | 56.2 | 58.5 | 86 | 99.5 | 98.3 | 99.8 | 49 | | mean | 93.8 | 43.8 | 39.5 | 84.8 | 93.7 | 94.5 | 96.9 | 50 | #### MVTec AD (4-shot) 51 | | objects | auroc_px | f1_px | ap_px | aupro | auroc_sp | f1_sp | ap_sp | 52 | |:-----------|-----------:|--------:|--------:|--------:|-----------:|--------:|--------:| 53 | | carpet | 99 | 65.1 | 68.9 | 95.5 | 99.9 | 99.4 | 100 | 54 | | bottle | 94.4 | 62.1 | 65.7 | 85.2 | 99.4 | 97.6 | 99.8 | 55 | | hazelnut | 98.5 | 60.7 | 62.5 | 92.8 | 97.6 | 95 | 98.8 | 56 | | leather | 99.3 | 45.4 | 39.3 | 97.8 | 100 | 100 | 100 | 57 | | cable | 89 | 31.7 | 25.8 | 71.4 | 89.6 | 88.4 | 92.9 | 58 | | capsule | 97.2 | 35.7 | 27.9 | 91.1 | 86.5 | 94 | 96.9 | 59 | | grid | 95.1 | 30 | 22 | 84 | 99.7 | 98.2 | 99.9 | 60 | | pill | 93 | 40.9 | 34.4 | 90.9 | 92.4 | 94.1 | 98.5 | 61 | | transistor | 89.4 | 40.6 | 39.2 | 65.5 | 90.4 | 80.4 | 87.3 | 62 | | metal_nut | 80.2 | 38 | 31.1 | 78 | 99.3 | 98.4 | 99.8 | 63 | | screw | 96 | 22 | 15.1 | 85 | 81.4 | 89.1 | 91.6 | 64 | | toothbrush | 98.2 | 55.1 | 50.8 | 88.6 | 98.1 | 96.7 | 99.3 | 65 | | zipper | 97.4 | 51.3 | 46.2 | 91.2 | 95.5 | 94.8 | 98.8 | 66 | | tile | 91.7 | 53.1 | 45.3 | 77.7 | 100 | 99.4 | 100 | 67 | | wood | 94.5 | 56.6 | 59.3 | 86.6 | 99.3 | 97.5 | 99.8 | 68 | | mean | 94.2 | 45.9 | 42.2 | 85.4 | 95.3 | 94.9 | 97.6 | 69 | #### MVTec AD (8-shot) 70 | | objects | auroc_px | f1_px | ap_px | aupro | auroc_sp | f1_sp | ap_sp | 71 | |:-----------|-----------:|--------:|--------:|--------:|-----------:|--------:|--------:| 72 | | carpet | 98.9 | 64.8 | 68.5 | 94.7 | 99.8 | 99.4 | 99.9 | 73 | | bottle | 94.8 | 61.9 | 65.9 | 85.3 | 99.5 | 98.4 | 99.8 | 74 | | hazelnut | 98.8 | 62.3 | 64.3 | 93.8 | 98.4 | 96.4 | 99.1 | 75 | | leather | 99.2 | 44.8 | 38.2 | 97.4 | 100 | 100 | 100 | 76 | | cable | 89.7 | 32.9 | 26.8 | 73.9 | 91.4 | 89.8 | 94 | 77 | | capsule | 96.9 | 34.6 | 27 | 90.6 | 85.5 | 93.6 | 96.6 | 78 | | grid | 95.6 | 31 | 23.6 | 85.9 | 99.6 | 99.1 | 99.9 | 79 | | pill | 93.4 | 42.4 | 35.9 | 91.4 | 93.3 | 94.4 | 98.7 | 80 | | transistor | 90.8 | 42.6 | 41.6 | 68.6 | 91.1 | 81.4 | 86.8 | 81 | | metal_nut | 79.9 | 37.9 | 30.6 | 77.9 | 99.2 | 98.9 | 99.8 | 82 | | screw | 96.8 | 17.9 | 12.4 | 86 | 80.4 | 88.8 | 91.1 | 83 | | toothbrush | 98.3 | 55.4 | 51.8 | 89.4 | 98.9 | 96.7 | 99.6 | 84 | | zipper | 97.4 | 51.2 | 45.5 | 91.6 | 97.6 | 96 | 99.4 | 85 | | tile | 91.6 | 53 | 45 | 76.6 | 100 | 99.4 | 100 | 86 | | wood | 94.8 | 56.1 | 59.2 | 87.3 | 99.6 | 98.4 | 99.9 | 87 | | mean | 94.5 | 45.9 | 42.4 | 86 | 95.6 | 95.4 | 97.6 | 88 | 89 | #### VisA (1-shot) 90 | | objects | auroc_px | f1_px | ap_px | aupro | auroc_sp | f1_sp | ap_sp | 91 | |:-----------|-----------:|--------:|--------:|--------:|-----------:|--------:|--------:| 92 | | candle | 94.8 | 16.8 | 8 | 90.5 | 96.4 | 91 | 96.9 | 93 | | capsules | 95.9 | 32.3 | 22.4 | 65.4 | 80.2 | 80.2 | 87.8 | 94 | | cashew | 96.6 | 32.3 | 22.2 | 90.3 | 95.4 | 91.8 | 97.9 | 95 | | chewinggum | 99 | 57 | 54.6 | 85.9 | 97.7 | 94.8 | 99 | 96 | | fryum | 94.4 | 32.4 | 26 | 86 | 87.7 | 86.6 | 94.4 | 97 | | macaroni1 | 91.5 | 13 | 4 | 78.9 | 85.6 | 78.8 | 87.7 | 98 | | 
macaroni2 | 91.6 | 4 | 0.9 | 72.1 | 75.4 | 72.4 | 74.6 | 99 | | pcb1 | 96 | 16.1 | 7.3 | 76.8 | 85.6 | 83.7 | 84.5 | 100 | | pcb2 | 91.5 | 6.6 | 2.9 | 66.4 | 59.6 | 67.9 | 57 | 101 | | pcb3 | 92.9 | 13.4 | 7.6 | 77.6 | 68.9 | 71.2 | 68.9 | 102 | | pcb4 | 95.3 | 22.4 | 15.9 | 82.4 | 85.5 | 79 | 85.6 | 103 | | pipe_fryum | 96.4 | 28 | 18.9 | 94.1 | 88 | 85.6 | 94.2 | 104 | | mean | 94.7 | 22.9 | 15.9 | 80.5 | 83.8 | 81.9 | 85.7 | 105 | #### VisA (2-shot) 106 | | objects | auroc_px | f1_px | ap_px | aupro | auroc_sp | f1_sp | ap_sp | 107 | |:-----------|-----------:|--------:|--------:|--------:|-----------:|--------:|--------:| 108 | | candle | 95.5 | 18.2 | 8.7 | 91.2 | 96 | 90.9 | 96.7 | 109 | | capsules | 95.8 | 31.8 | 22.5 | 65.4 | 82.4 | 83 | 89.8 | 110 | | cashew | 96.9 | 34.2 | 24 | 89.1 | 93.4 | 89.4 | 97 | 111 | | chewinggum | 99 | 57.4 | 56 | 86.5 | 98.2 | 95.9 | 99.2 | 112 | | fryum | 95 | 35.5 | 27.9 | 86.1 | 84.7 | 84.8 | 92.7 | 113 | | macaroni1 | 93.8 | 13.4 | 4.5 | 83.8 | 88 | 81.5 | 89.6 | 114 | | macaroni2 | 91.4 | 3.7 | 0.8 | 70.2 | 73.3 | 72.6 | 73.4 | 115 | | pcb1 | 96.2 | 17 | 8.1 | 77.7 | 85.4 | 83.7 | 83.6 | 116 | | pcb2 | 92.1 | 7.1 | 3.2 | 65.9 | 58 | 69.2 | 57.7 | 117 | | pcb3 | 93.8 | 19.2 | 10.3 | 80.7 | 72 | 70.3 | 70.5 | 118 | | pcb4 | 95.9 | 21.9 | 14.1 | 84.1 | 79.3 | 80.7 | 70.1 | 119 | | pipe_fryum | 96.2 | 28 | 18.7 | 93.9 | 89.7 | 87.9 | 95 | 120 | | mean | 95.1 | 23.9 | 16.6 | 81.2 | 83.4 | 82.5 | 84.6 | 121 | #### VisA (4-shot) 122 | | objects | auroc_px | f1_px | ap_px | aupro | auroc_sp | f1_sp | ap_sp | 123 | |:-----------|-----------:|--------:|--------:|--------:|-----------:|--------:|--------:| 124 | | candle | 95.8 | 19.2 | 9.3 | 91.2 | 96.7 | 90.7 | 97.2 | 125 | | capsules | 96.1 | 32.2 | 23.1 | 66.6 | 82.2 | 81.7 | 89.3 | 126 | | cashew | 96.7 | 32.4 | 22.2 | 89.4 | 93.4 | 89.3 | 97 | 127 | | chewinggum | 99 | 57.2 | 55.5 | 85.9 | 98 | 96.5 | 99.2 | 128 | | fryum | 94.9 | 34.3 | 27.8 | 87.5 | 87.1 | 86.6 | 93.8 | 129 | | macaroni1 | 93.9 | 14.2 | 4.8 | 83.6 | 89.1 | 82.9 | 90.7 | 130 | | macaroni2 | 90 | 4.8 | 1 | 68.9 | 76 | 73.2 | 76.1 | 131 | | pcb1 | 96.2 | 16.9 | 8 | 77.1 | 86.8 | 84.2 | 84.2 | 132 | | pcb2 | 91.7 | 9.8 | 4.5 | 65.4 | 59.6 | 68.8 | 59.4 | 133 | | pcb3 | 94.6 | 23.3 | 13.2 | 81.1 | 69.9 | 70.7 | 68.7 | 134 | | pcb4 | 96.7 | 30.8 | 23.4 | 86.1 | 80.7 | 76.6 | 79.5 | 135 | | pipe_fryum | 96.3 | 28.1 | 19.1 | 94.4 | 89.8 | 87.1 | 95 | 136 | | mean | 95.2 | 25.3 | 17.7 | 81.4 | 84.1 | 82.4 | 85.8 | 137 | #### VisA (8-shot) 138 | | objects | auroc_px | f1_px | ap_px | aupro | auroc_sp | f1_sp | ap_sp | 139 | |:-----------|-----------:|--------:|--------:|--------:|-----------:|--------:|--------:| 140 | | candle | 95.9 | 19.2 | 9.4 | 91.1 | 96.9 | 91.5 | 97.3 | 141 | | capsules | 96.2 | 32.2 | 23.2 | 66.9 | 82.6 | 83 | 89.3 | 142 | | cashew | 96.9 | 34.2 | 24.4 | 89.5 | 95.6 | 92.2 | 98 | 143 | | chewinggum | 99 | 56.7 | 54.5 | 85.8 | 97.9 | 96 | 99.1 | 144 | | fryum | 95 | 35.4 | 28 | 88.2 | 89.7 | 86.7 | 95.3 | 145 | | macaroni1 | 93.9 | 13.7 | 4.7 | 82.9 | 89.6 | 82 | 90.8 | 146 | | macaroni2 | 89 | 3.8 | 0.6 | 68.3 | 76.7 | 73.5 | 76.1 | 147 | | pcb1 | 96.2 | 17 | 8.3 | 76.8 | 87.4 | 84.7 | 85.4 | 148 | | pcb2 | 92.5 | 10.8 | 4.9 | 66 | 63.7 | 70.1 | 61 | 149 | | pcb3 | 95.2 | 23.9 | 14.7 | 81.7 | 76.1 | 74.8 | 74 | 150 | | pcb4 | 97.2 | 33.5 | 25.8 | 88 | 84.6 | 81.8 | 83.1 | 151 | | pipe_fryum | 96.5 | 29.3 | 19.7 | 94.3 | 91.2 | 89.3 | 95.7 | 152 | | mean | 95.3 | 25.8 | 18.2 | 81.6 | 86 | 83.8 | 87.1 | 153 | 154 | ### 
Zero-shot 155 | #### [MVTec AD](https://www.mvtec.com/company/research/datasets/mvtec-ad/) 156 | | objects | auroc_px | f1_px | ap_px | aupro | auroc_sp | f1_sp | ap_sp | 157 | |:-----------|-----------:|--------:|--------:|--------:|-----------:|--------:|--------:| 158 | | carpet | 90.9 | 33.9 | 26 | 66.3 | 99.3 | 97.8 | 99.8 | 159 | | bottle | 85.7 | 49.4 | 49.8 | 69.9 | 98.6 | 97.6 | 99.5 | 160 | | hazelnut | 95.7 | 39.1 | 33.3 | 81.3 | 92.3 | 88.6 | 96 | 161 | | leather | 95.5 | 30.8 | 20.5 | 86 | 100 | 100 | 100 | 162 | | cable | 61.3 | 12.2 | 6.2 | 39.4 | 85 | 84.8 | 89.8 | 163 | | capsule | 87 | 14.3 | 8.6 | 63.8 | 68.7 | 93.5 | 90.5 | 164 | | grid | 79.4 | 13.7 | 5.7 | 49.3 | 99.2 | 98.2 | 99.7 | 165 | | pill | 72.7 | 11.8 | 7 | 66.9 | 81.5 | 91.6 | 96.4 | 166 | | transistor | 83.7 | 27 | 20.2 | 45.5 | 89.1 | 80 | 84.9 | 167 | | metal_nut | 49.3 | 23.8 | 10.8 | 39.7 | 96.2 | 95.3 | 99.1 | 168 | | screw | 91.1 | 11.3 | 5.4 | 70.2 | 71.7 | 85.9 | 87.7 | 169 | | toothbrush | 86.2 | 10.5 | 5.5 | 67.9 | 85.3 | 88.9 | 94.5 | 170 | | zipper | 91.7 | 27.8 | 19.4 | 72 | 91.2 | 93.4 | 97.5 | 171 | | tile | 79.1 | 30.8 | 21.2 | 54.5 | 99.9 | 99.4 | 100 | 172 | | wood | 85.1 | 35.4 | 32.9 | 56.3 | 97.6 | 95.2 | 99.3 | 173 | | mean | 82.3 | 24.8 | 18.2 | 61.9 | 90.4 | 92.7 | 95.6 | 174 | 175 | #### VisA 176 | | objects | auroc_px | f1_px | ap_px | aupro | auroc_sp | f1_sp | ap_sp | 177 | |:-----------|-----------:|--------:|--------:|--------:|-----------:|--------:|--------:| 178 | | candle | 87 | 8.9 | 2.3 | 77.7 | 94.9 | 90.6 | 95.4 | 179 | | capsules | 80 | 4.2 | 1.4 | 39.4 | 79.4 | 80.5 | 87.9 | 180 | | cashew | 84.8 | 9.6 | 4.8 | 78.4 | 91.2 | 88.9 | 96 | 181 | | chewinggum | 95.4 | 31.5 | 24 | 69.6 | 95.5 | 93.8 | 98.2 | 182 | | fryum | 87.7 | 16.2 | 11.1 | 74.4 | 73.6 | 80 | 86.9 | 183 | | macaroni1 | 50.3 | 0.1 | 0 | 24.7 | 79 | 74.2 | 80 | 184 | | macaroni2 | 44.7 | 0.1 | 0 | 8 | 67.1 | 68.8 | 65.1 | 185 | | pcb1 | 38.7 | 0.9 | 0.4 | 20.7 | 72.1 | 70.2 | 73 | 186 | | pcb2 | 58.7 | 1.5 | 0.4 | 20.6 | 47 | 67.1 | 46.1 | 187 | | pcb3 | 76 | 2.1 | 0.7 | 43.7 | 63.9 | 67.6 | 63 | 188 | | pcb4 | 91.4 | 24.6 | 15.5 | 74.5 | 74.2 | 75.7 | 70.1 | 189 | | pipe_fryum | 83.6 | 8.3 | 4.4 | 80.3 | 67.8 | 80.3 | 82.1 | 190 | | mean | 73.2 | 9 | 5.4 | 51 | 75.5 | 78.2 | 78.7 | 191 | 192 | ## Quick start 193 | Zero-shot anomaly detection 194 | ```sh 195 | bash zero_shot.sh 196 | ``` 197 | Few-shot anomaly detection 198 | ```sh 199 | bash few_shot.sh 200 | ``` 201 | 202 | 203 | ## BibTex Citation 204 | 205 | If you find this paper and repository useful, please cite our paper. 
206 | 207 | ``` 208 | @article{zhou2024anomalyclip, 209 | title={AnomalyCLIP: Object-agnostic Prompt Learning for Zero-shot Anomaly Detection}, 210 | author={Zhou, Qihang and Pang, Guansong and Tian, Yu and He, Shibo and Chen, Jiming}, 211 | journal={The Twelfth International Conference on Learning Representations}, 212 | year={2024} 213 | } 214 | 215 | @misc{jeong2023winclipzerofewshotanomalyclassification, 216 | title={WinCLIP: Zero-/Few-Shot Anomaly Classification and Segmentation}, 217 | author={Jongheon Jeong and Yang Zou and Taewan Kim and Dongqing Zhang and Avinash Ravichandran and Onkar Dabeer}, 218 | year={2023}, 219 | eprint={2303.14814}, 220 | archivePrefix={arXiv}, 221 | primaryClass={cs.CV}, 222 | url={https://arxiv.org/abs/2303.14814}, 223 | } 224 | ``` 225 | -------------------------------------------------------------------------------- /__pycache__/dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zqhang/Accurate-WinCLIP-pytorch/eef4f8cce6a80eddfa07415b503c22c6c3427351/__pycache__/dataset.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/few_shot.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zqhang/Accurate-WinCLIP-pytorch/eef4f8cce6a80eddfa07415b503c22c6c3427351/__pycache__/few_shot.cpython-38.pyc -------------------------------------------------------------------------------- /data/mvtec.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | 5 | class MVTecSolver(object): 6 | CLSNAMES = [ 7 | 'bottle', 'cable', 'capsule', 'carpet', 'grid', 8 | 'hazelnut', 'leather', 'metal_nut', 'pill', 'screw', 9 | 'tile', 'toothbrush', 'transistor', 'wood', 'zipper', 10 | ] 11 | 12 | def __init__(self, root='data/mvtec'): 13 | self.root = root 14 | self.meta_path = f'{root}/meta.json' 15 | 16 | def run(self): 17 | info = dict(train={}, test={}) 18 | anomaly_samples = 0 19 | normal_samples = 0 20 | for cls_name in self.CLSNAMES: 21 | cls_dir = f'{self.root}/{cls_name}' 22 | for phase in ['train', 'test']: 23 | cls_info = [] 24 | species = os.listdir(f'{cls_dir}/{phase}') 25 | for specie in species: 26 | is_abnormal = True if specie not in ['good'] else False 27 | img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') 28 | mask_names = os.listdir(f'{cls_dir}/groundtruth/{specie}') if is_abnormal else None 29 | img_names.sort() 30 | mask_names.sort() if mask_names is not None else None 31 | for idx, img_name in enumerate(img_names): 32 | info_img = dict( 33 | img_path=f'{cls_name}/{phase}/{specie}/{img_name}', 34 | mask_path=f'{cls_name}/groundtruth/{specie}/{mask_names[idx]}' if is_abnormal else '', 35 | cls_name=cls_name, 36 | specie_name=specie, 37 | anomaly=1 if is_abnormal else 0, 38 | ) 39 | cls_info.append(info_img) 40 | if phase == 'test': 41 | if is_abnormal: 42 | anomaly_samples = anomaly_samples + 1 43 | else: 44 | normal_samples = normal_samples + 1 45 | info[phase][cls_name] = cls_info 46 | with open(self.meta_path, 'w') as f: 47 | f.write(json.dumps(info, indent=4) + "\n") 48 | print('normal_samples', normal_samples, 'anomaly_samples', anomaly_samples) 49 | if __name__ == '__main__': 50 | runner = MVTecSolver(root='/remote-home/iot_zhouqihang/data/mvdataset') 51 | runner.run() 52 | -------------------------------------------------------------------------------- /data/visa.py: 
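Both converters write the same `meta.json` layout: a top-level `train`/`test` split, keyed by class name, whose entries carry the fields that `dataset.py` later reads. Below is a minimal illustration of what `MVTecSolver.run()` above produces (the VisA converter that follows emits the same schema); the example paths and class/defect names are placeholders, not real dataset contents.

```python
# Sketch of the meta.json structure written by MVTecSolver.run() (and VisASolver.run()).
# Values are illustrative only -- the real file is generated from the dataset on disk.
meta = {
    "train": {
        "bottle": [
            {
                "img_path": "bottle/train/good/000.png",
                "mask_path": "",  # empty string for normal samples
                "cls_name": "bottle",
                "specie_name": "good",
                "anomaly": 0,
            },
        ],
    },
    "test": {
        "bottle": [
            {
                "img_path": "bottle/test/broken_large/000.png",
                "mask_path": "bottle/groundtruth/broken_large/000_mask.png",
                "cls_name": "bottle",
                "specie_name": "broken_large",
                "anomaly": 1,
            },
        ],
    },
}
```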
-------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pandas as pd 4 | 5 | 6 | class VisASolver(object): 7 | CLSNAMES = [ 8 | 'candle', 'capsules', 'cashew', 'chewinggum', 'fryum', 9 | 'macaroni1', 'macaroni2', 'pcb1', 'pcb2', 'pcb3', 10 | 'pcb4', 'pipe_fryum', 11 | ] 12 | 13 | def __init__(self, root='data/visa'): 14 | self.root = root 15 | self.meta_path = f'{root}/meta.json' 16 | self.phases = ['train', 'test'] 17 | self.csv_data = pd.read_csv(f'{root}/split_csv/1cls.csv', header=0) 18 | 19 | def run(self): 20 | columns = self.csv_data.columns # [object, split, label, image, mask] 21 | info = {phase: {} for phase in self.phases} 22 | anomaly_samples = 0 23 | normal_samples = 0 24 | for cls_name in self.CLSNAMES: 25 | cls_data = self.csv_data[self.csv_data[columns[0]] == cls_name] 26 | for phase in self.phases: 27 | cls_info = [] 28 | cls_data_phase = cls_data[cls_data[columns[1]] == phase] 29 | cls_data_phase.index = list(range(len(cls_data_phase))) 30 | for idx in range(cls_data_phase.shape[0]): 31 | data = cls_data_phase.loc[idx] 32 | is_abnormal = True if data[2] == 'anomaly' else False 33 | info_img = dict( 34 | img_path=data[3], 35 | mask_path=data[4] if is_abnormal else '', 36 | cls_name=cls_name, 37 | specie_name='', 38 | anomaly=1 if is_abnormal else 0, 39 | ) 40 | cls_info.append(info_img) 41 | if phase == 'test': 42 | if is_abnormal: 43 | anomaly_samples = anomaly_samples + 1 44 | else: 45 | normal_samples = normal_samples + 1 46 | info[phase][cls_name] = cls_info 47 | with open(self.meta_path, 'w') as f: 48 | f.write(json.dumps(info, indent=4) + "\n") 49 | print('normal_samples', normal_samples, 'anomaly_samples', anomaly_samples) 50 | 51 | 52 | if __name__ == '__main__': 53 | runner = VisASolver(root='/remote-home/iot_zhouqihang/data/Visa') 54 | runner.run() 55 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import json 3 | import random 4 | from PIL import Image 5 | import numpy as np 6 | import torch 7 | import os 8 | 9 | Vis_CLSNAMES = ['candle', 'capsules', 'cashew', 'chewinggum', 'fryum', 'macaroni1', 'macaroni2', 10 | 'pcb1', 'pcb2', 'pcb3', 'pcb4', 'pipe_fryum'] 11 | 12 | Vis_CLSNAMES_map_index = {} 13 | for k, index in zip(Vis_CLSNAMES, range(len(Vis_CLSNAMES))): 14 | Vis_CLSNAMES_map_index[k] = index 15 | 16 | CLSNAMES = ['carpet', 'bottle', 'hazelnut', 'leather', 'cable', 'capsule', 'grid', 'pill', 17 | 'transistor', 'metal_nut', 'screw', 'toothbrush', 'zipper', 'tile', 'wood'] 18 | CLSNAMES_map_index = {} 19 | for k, index in zip(CLSNAMES, range(len(CLSNAMES))): 20 | CLSNAMES_map_index[k] = index 21 | 22 | 23 | 24 | class VisaDataset(data.Dataset): 25 | def __init__(self, root, transform, target_transform, mode='test', k_shot=0, save_dir=None, obj_name=None): 26 | self.root = root 27 | self.transform = transform 28 | self.target_transform = target_transform 29 | 30 | self.data_all = [] 31 | meta_info = json.load(open(f'{self.root}/meta.json', 'r')) 32 | name = self.root.split('/')[-1] 33 | meta_info = meta_info[mode] 34 | 35 | if mode == 'train': 36 | self.cls_names = [obj_name] 37 | save_dir = os.path.join(save_dir, 'k_shot.txt') 38 | else: 39 | self.cls_names = list(meta_info.keys()) 40 | for cls_name in self.cls_names: 41 | if mode == 'train': 42 | data_tmp = meta_info[cls_name] 43 | indices = torch.randint(0, len(data_tmp), 
(k_shot,)) 44 | for i in range(len(indices)): 45 | self.data_all.append(data_tmp[indices[i]]) 46 | with open(save_dir, "a") as f: 47 | f.write(data_tmp[indices[i]]['img_path'] + '\n') 48 | else: 49 | self.data_all.extend(meta_info[cls_name]) 50 | self.length = len(self.data_all) 51 | 52 | def __len__(self): 53 | return self.length 54 | 55 | def __getitem__(self, index): 56 | data = self.data_all[index] 57 | img_path, mask_path, cls_name, specie_name, anomaly = data['img_path'], data['mask_path'], data['cls_name'], \ 58 | data['specie_name'], data['anomaly'] 59 | img = Image.open(os.path.join(self.root, img_path)) 60 | if anomaly == 0: 61 | img_mask = Image.fromarray(np.zeros((img.size[0], img.size[1])), mode='L') 62 | else: 63 | img_mask = np.array(Image.open(os.path.join(self.root, mask_path)).convert('L')) > 0 64 | img_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L') 65 | img = self.transform(img) if self.transform is not None else img 66 | img_mask = self.target_transform( 67 | img_mask) if self.target_transform is not None and img_mask is not None else img_mask 68 | img_mask = [] if img_mask is None else img_mask 69 | 70 | return {'img': img, 'img_mask': img_mask, 'cls_name': cls_name, 'anomaly': anomaly, 71 | 'img_path': os.path.join(self.root, img_path), "cls_id":Vis_CLSNAMES_map_index[cls_name]} 72 | 73 | 74 | 75 | class MVTecDataset(data.Dataset): 76 | def __init__(self, root, transform, target_transform, aug_rate, mode='test', k_shot=0, save_dir=None, obj_name=None): 77 | self.root = root 78 | self.transform = transform 79 | self.target_transform = target_transform 80 | self.aug_rate = aug_rate 81 | 82 | self.data_all = [] 83 | meta_info = json.load(open(f'{self.root}/meta.json', 'r')) 84 | name = self.root.split('/')[-1] 85 | meta_info = meta_info[mode] 86 | 87 | if mode == 'train': 88 | if isinstance(obj_name, list): 89 | self.cls_names = obj_name 90 | else: 91 | self.cls_names = [obj_name] 92 | save_dir = os.path.join(save_dir, 'k_shot.txt') 93 | else: 94 | self.cls_names = list(meta_info.keys()) 95 | for cls_name in self.cls_names: 96 | if mode == 'train': 97 | data_tmp = meta_info[cls_name] 98 | indices = torch.randint(0, len(data_tmp), (k_shot,)) 99 | for i in range(len(indices)): 100 | self.data_all.append(data_tmp[indices[i]]) 101 | with open(save_dir, "a") as f: 102 | f.write(data_tmp[indices[i]]['img_path'] + '\n') 103 | else: 104 | self.data_all.extend(meta_info[cls_name]) 105 | self.length = len(self.data_all) 106 | 107 | def __len__(self): 108 | return self.length 109 | 110 | 111 | def __getitem__(self, index): 112 | data = self.data_all[index] 113 | img_path, mask_path, cls_name, specie_name, anomaly = data['img_path'], data['mask_path'], data['cls_name'], \ 114 | data['specie_name'], data['anomaly'] 115 | 116 | img = Image.open(os.path.join(self.root, img_path)) 117 | if anomaly == 0: 118 | img_mask = Image.fromarray(np.zeros((img.size[0], img.size[1])), mode='L') 119 | else: 120 | img_mask = np.array(Image.open(os.path.join(self.root, mask_path)).convert('L')) > 0 121 | img_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L') 122 | # transforms 123 | img = self.transform(img) if self.transform is not None else img 124 | img_mask = self.target_transform( 125 | img_mask) if self.target_transform is not None and img_mask is not None else img_mask 126 | img_mask = [] if img_mask is None else img_mask 127 | return {'img': img, 'img_mask': img_mask, 'cls_name': cls_name, 'anomaly': anomaly, 128 | 'img_path': os.path.join(self.root, 
img_path), "cls_id":CLSNAMES_map_index[cls_name]} 129 | 130 | 131 | -------------------------------------------------------------------------------- /few_shot.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dataset import * 3 | 4 | 5 | 6 | from collections import OrderedDict 7 | def initialize_memory(obj_list): 8 | 9 | mid = [] 10 | large = [] 11 | patch = [] 12 | for x in obj_list: 13 | mid.append((x, [])) 14 | large.append((x, [])) 15 | patch.append((x, [])) 16 | mid_memory = OrderedDict(mid) 17 | large_memory = OrderedDict(large) 18 | patch_memory = OrderedDict(patch) 19 | return mid_memory, large_memory, patch_memory 20 | 21 | 22 | @torch.no_grad() 23 | def memory(model, obj_list, dataset_dir, save_path, preprocess, transform, k_shot, few_shot_features, 24 | dataset_name, device): 25 | normal_features_ls = {} 26 | mid_memory, large_memory, patch_memory = initialize_memory(obj_list) 27 | for i in range(len(obj_list)): 28 | if dataset_name == 'mvtec': 29 | normal_data = MVTecDataset(root=dataset_dir, transform=preprocess, target_transform=transform, 30 | aug_rate=-1, mode='train', k_shot=k_shot, save_dir=save_path, 31 | obj_name=obj_list[i]) 32 | elif dataset_name == 'visa': 33 | normal_data = VisaDataset(root=dataset_dir, transform=preprocess, target_transform=transform, 34 | mode='train', k_shot=k_shot, save_dir=save_path, obj_name=obj_list[i]) 35 | 36 | normal_dataloader = torch.utils.data.DataLoader(normal_data, batch_size=1, shuffle=False) 37 | for index, items in enumerate(normal_dataloader): 38 | 39 | images = items['img'].to(device) 40 | cls_name = items['cls_name'] 41 | cls_id = items['cls_id'] 42 | patch_size = 16 43 | gt_mask = items['img_mask'] 44 | gt_mask[gt_mask > 0.5], gt_mask[gt_mask <= 0.5] = 1, 0 45 | # print("class_name", cls_name) 46 | large_scale_tokens, mid_scale_tokens, patch_tokens, class_tokens, large_scale, mid_scale = model.encode_image(images, patch_size) 47 | # print("large_scale_tokens", large_scale_tokens.shape, mid_scale_tokens.shape, patch_tokens.shape) 48 | for class_name, tokens in zip(cls_name, large_scale_tokens): 49 | large_memory[class_name].append(tokens) 50 | for class_name, tokens in zip(cls_name, mid_scale_tokens): 51 | mid_memory[class_name].append(tokens) 52 | for class_name, tokens in zip(cls_name, patch_tokens): 53 | patch_memory[class_name].append(tokens) 54 | # print("lennnnnshape", tokens.shape) 55 | # print("large_memory", large_memory) 56 | # print("mid_memory", mid_memory) 57 | # print("large_memory", patch_memory) 58 | for class_name in obj_list: 59 | large_memory[class_name] = torch.cat(large_memory[class_name]) 60 | mid_memory[class_name] = torch.cat(mid_memory[class_name]) 61 | patch_memory[class_name] = torch.cat(patch_memory[class_name]) 62 | # print("lennnnnshape", patch_memory[class_name].shape) 63 | 64 | 65 | return large_memory, mid_memory, patch_memory -------------------------------------------------------------------------------- /few_shot.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | few_shots=(1 2 4) 4 | 5 | for few_num in "${!few_shots[@]}";do 6 | ## train on the VisA dataset 7 | base_dir=winclip_mvtec 8 | save_dir=./exps_${base_dir}/mvtecvit_large_14_518/ 9 | 10 | CUDA_VISIBLE_DEVICES=3 python reproduce_WinCLIP.py --dataset mvtec \ 11 | --data_path /remote-home/iot_zhouqihang/data/mvdataset --save_path ./results/mvtec_${base_dir}/few_shot_${few_shots[few_num]} \ 12 | --model ViT-B-16-plus-240 
--pretrained openai --k_shot ${few_shots[few_num]} --image_size 240 13 | wait 14 | done 15 | 16 | 17 | for few_num in "${!few_shots[@]}";do 18 | ## train on the VisA dataset 19 | base_dir=winclip_visa 20 | save_dir=./exps_${base_dir}/mvtecvit_large_14_518/ 21 | 22 | 23 | CUDA_VISIBLE_DEVICES=3 python reproduce_WinCLIP.py --dataset visa \ 24 | --data_path /remote-home/iot_zhouqihang/data/Visa --save_path ./results/mvtec_${base_dir}/few_shot_${few_shots[few_num]} \ 25 | --model ViT-B-16-plus-240 --pretrained openai --k_shot ${few_shots[few_num]} --image_size 240 26 | wait 27 | done 28 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.0.0 2 | addict==2.4.0 3 | ansi2html==1.8.0 4 | asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1670263926556/work 5 | attrs==22.2.0 6 | backcall @ file:///home/conda/feedstock_root/build_artifacts/backcall_1592338393461/work 7 | backports.functools-lru-cache @ file:///home/conda/feedstock_root/build_artifacts/backports.functools_lru_cache_1618230623929/work 8 | beautifulsoup4==4.12.0 9 | bleach==6.0.0 10 | cachetools==5.3.1 11 | certifi==2022.12.7 12 | charset-normalizer==3.1.0 13 | click==8.1.3 14 | cmake==3.26.1 15 | colorama==0.4.6 16 | comm==0.1.4 17 | ConfigArgParse==1.7 18 | contourpy==1.0.7 19 | cycler==0.11.0 20 | dash==2.13.0 21 | dash-core-components==2.0.0 22 | dash-html-components==2.0.0 23 | dash-table==5.0.0 24 | debugpy @ file:///tmp/build/80754af9/debugpy_1637091796427/work 25 | decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work 26 | defusedxml==0.7.1 27 | einops==0.6.0 28 | entrypoints @ file:///home/conda/feedstock_root/build_artifacts/entrypoints_1643888246732/work 29 | executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1667317341051/work 30 | fastjsonschema==2.16.3 31 | filelock==3.10.7 32 | Flask==2.2.5 33 | fonttools==4.39.3 34 | ftfy==6.1.1 35 | future==0.18.3 36 | google-auth==2.23.2 37 | google-auth-oauthlib==1.0.0 38 | grpcio==1.59.0 39 | huggingface-hub==0.13.3 40 | hydra-core==0.11.3 41 | idna==3.4 42 | imageio==2.27.0 43 | importlib-metadata==6.1.0 44 | importlib-resources==5.12.0 45 | ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1655369107642/work 46 | ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1680185408135/work 47 | ipywidgets==8.1.1 48 | itsdangerous==2.1.2 49 | jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1669134318875/work 50 | Jinja2==3.1.2 51 | joblib==1.2.0 52 | jsonschema==4.17.3 53 | jupyter-client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1633454794268/work 54 | jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1678994158927/work 55 | jupyterlab-pygments==0.2.2 56 | jupyterlab-widgets==3.0.9 57 | kiwisolver==1.4.4 58 | kornia==0.7.0 59 | lazy_loader==0.2 60 | lit==16.0.0 61 | Markdown==3.4.3 62 | markdown-it-py==2.2.0 63 | MarkupSafe==2.1.2 64 | matplotlib==3.7.1 65 | matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1660814786464/work 66 | mdurl==0.1.2 67 | mistune==2.0.5 68 | mmcv-full==1.7.1 69 | model-index==0.1.11 70 | mpmath==1.3.0 71 | nbclient==0.7.3 72 | nbconvert==7.3.0 73 | nbformat==5.7.0 74 | nest-asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1664684991461/work 75 | networkx==3.0 76 | 
ninja==1.11.1 77 | numpy==1.23.5 78 | nvidia-cublas-cu11==11.10.3.66 79 | nvidia-cuda-cupti-cu11==11.7.101 80 | nvidia-cuda-nvrtc-cu11==11.7.99 81 | nvidia-cuda-runtime-cu11==11.7.99 82 | nvidia-cudnn-cu11==8.5.0.96 83 | nvidia-cufft-cu11==10.9.0.58 84 | nvidia-curand-cu11==10.2.10.91 85 | nvidia-cusolver-cu11==11.4.0.1 86 | nvidia-cusparse-cu11==11.7.4.91 87 | nvidia-nccl-cu11==2.14.3 88 | nvidia-nvtx-cu11==11.7.91 89 | oauthlib==3.2.2 90 | omegaconf==1.4.1 91 | open-clip-torch==2.16.0 92 | open3d==0.17.0 93 | opencv-python==4.7.0.72 94 | opencv-python-headless==4.8.0.76 95 | openmim==0.3.7 96 | ordered-set==4.1.0 97 | packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1673482170163/work 98 | pandas==2.0.0 99 | pandocfilters==1.5.0 100 | parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1638334955874/work 101 | pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1667297516076/work 102 | pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work 103 | Pillow==9.4.0 104 | pkgutil_resolve_name==1.3.10 105 | platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1679871349196/work 106 | plotly==5.17.0 107 | # Editable install with no version control (pointnet2==3.0.0) 108 | -e /remote-home/iot_zhouqihang/root/zqh/AnomalyPointCLIP/Pointnet2_PyTorch-master 109 | prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1677600924538/work 110 | protobuf==3.20.3 111 | psutil @ file:///opt/conda/conda-bld/psutil_1656431268089/work 112 | ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl 113 | pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work 114 | pyasn1==0.5.0 115 | pyasn1-modules==0.3.0 116 | Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1672682006896/work 117 | pyparsing==3.0.9 118 | pyquaternion==0.9.9 119 | pyrsistent==0.19.3 120 | python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1626286286081/work 121 | pytorch-lightning==0.7.1 122 | pytz==2023.3 123 | PyWavelets==1.4.1 124 | PyYAML==6.0 125 | pyzmq==19.0.2 126 | regex==2023.3.23 127 | requests==2.28.2 128 | requests-oauthlib==1.3.1 129 | retrying==1.3.4 130 | rich==13.3.5 131 | rsa==4.9 132 | scikit-image==0.20.0 133 | scikit-learn==1.2.2 134 | scipy==1.9.1 135 | seaborn==0.12.2 136 | sentencepiece==0.1.97 137 | six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work 138 | soupsieve==2.4 139 | stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work 140 | sympy==1.11.1 141 | tabulate==0.9.0 142 | tenacity==8.2.3 143 | tensorboard==2.14.0 144 | tensorboard-data-server==0.7.1 145 | threadpoolctl==3.1.0 146 | tifffile==2023.3.21 147 | timm==0.6.13 148 | tinycss2==1.2.1 149 | tomli==2.0.1 150 | torch==2.0.0 151 | torchsummary==1.5.1 152 | torchvision==0.15.1 153 | tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1648827257044/work 154 | tqdm==4.65.0 155 | traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1675110562325/work 156 | triton==2.0.0 157 | typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1678559861143/work 158 | tzdata==2023.3 159 | urllib3==1.26.15 160 | wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1673864653149/work 161 | webencodings==0.5.1 162 | 
Werkzeug==2.2.3 163 | widgetsnbextension==4.0.9 164 | yacs==0.1.8 165 | yapf==0.33.0 166 | zipp==3.15.0 167 | -------------------------------------------------------------------------------- /src/open_clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .coca_model import CoCa 2 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 3 | from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss 4 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint 5 | from .loss import ClipLoss, DistillClipLoss, CoCaLoss 6 | from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \ 7 | convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype 8 | from .openai import load_openai_model, list_openai_models 9 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \ 10 | get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained 11 | from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub 12 | from .tokenizer import SimpleTokenizer, tokenize, decode 13 | from .transform import image_transform, AugmentationCfg 14 | 15 | 16 | from .factory import create_customer_model_and_transforms -------------------------------------------------------------------------------- /src/open_clip/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zqhang/Accurate-WinCLIP-pytorch/eef4f8cce6a80eddfa07415b503c22c6c3427351/src/open_clip/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/open_clip/__pycache__/coca_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zqhang/Accurate-WinCLIP-pytorch/eef4f8cce6a80eddfa07415b503c22c6c3427351/src/open_clip/__pycache__/coca_model.cpython-38.pyc -------------------------------------------------------------------------------- /src/open_clip/__pycache__/constants.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zqhang/Accurate-WinCLIP-pytorch/eef4f8cce6a80eddfa07415b503c22c6c3427351/src/open_clip/__pycache__/constants.cpython-38.pyc -------------------------------------------------------------------------------- /src/open_clip/__pycache__/factory.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zqhang/Accurate-WinCLIP-pytorch/eef4f8cce6a80eddfa07415b503c22c6c3427351/src/open_clip/__pycache__/factory.cpython-38.pyc -------------------------------------------------------------------------------- /src/open_clip/__pycache__/hf_configs.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zqhang/Accurate-WinCLIP-pytorch/eef4f8cce6a80eddfa07415b503c22c6c3427351/src/open_clip/__pycache__/hf_configs.cpython-38.pyc -------------------------------------------------------------------------------- /src/open_clip/__pycache__/hf_model.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zqhang/Accurate-WinCLIP-pytorch/eef4f8cce6a80eddfa07415b503c22c6c3427351/src/open_clip/__pycache__/hf_model.cpython-38.pyc -------------------------------------------------------------------------------- /src/open_clip/__pycache__/loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zqhang/Accurate-WinCLIP-pytorch/eef4f8cce6a80eddfa07415b503c22c6c3427351/src/open_clip/__pycache__/loss.cpython-38.pyc -------------------------------------------------------------------------------- /src/open_clip/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zqhang/Accurate-WinCLIP-pytorch/eef4f8cce6a80eddfa07415b503c22c6c3427351/src/open_clip/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /src/open_clip/__pycache__/model_revise.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zqhang/Accurate-WinCLIP-pytorch/eef4f8cce6a80eddfa07415b503c22c6c3427351/src/open_clip/__pycache__/model_revise.cpython-38.pyc -------------------------------------------------------------------------------- /src/open_clip/__pycache__/model_revise_learn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zqhang/Accurate-WinCLIP-pytorch/eef4f8cce6a80eddfa07415b503c22c6c3427351/src/open_clip/__pycache__/model_revise_learn.cpython-38.pyc -------------------------------------------------------------------------------- /src/open_clip/__pycache__/modified_resnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zqhang/Accurate-WinCLIP-pytorch/eef4f8cce6a80eddfa07415b503c22c6c3427351/src/open_clip/__pycache__/modified_resnet.cpython-38.pyc -------------------------------------------------------------------------------- /src/open_clip/__pycache__/openai.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zqhang/Accurate-WinCLIP-pytorch/eef4f8cce6a80eddfa07415b503c22c6c3427351/src/open_clip/__pycache__/openai.cpython-38.pyc -------------------------------------------------------------------------------- /src/open_clip/__pycache__/pretrained.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zqhang/Accurate-WinCLIP-pytorch/eef4f8cce6a80eddfa07415b503c22c6c3427351/src/open_clip/__pycache__/pretrained.cpython-38.pyc -------------------------------------------------------------------------------- /src/open_clip/__pycache__/push_to_hf_hub.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zqhang/Accurate-WinCLIP-pytorch/eef4f8cce6a80eddfa07415b503c22c6c3427351/src/open_clip/__pycache__/push_to_hf_hub.cpython-38.pyc -------------------------------------------------------------------------------- /src/open_clip/__pycache__/timm_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zqhang/Accurate-WinCLIP-pytorch/eef4f8cce6a80eddfa07415b503c22c6c3427351/src/open_clip/__pycache__/timm_model.cpython-38.pyc 
-------------------------------------------------------------------------------- /src/open_clip/__pycache__/tokenizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zqhang/Accurate-WinCLIP-pytorch/eef4f8cce6a80eddfa07415b503c22c6c3427351/src/open_clip/__pycache__/tokenizer.cpython-38.pyc -------------------------------------------------------------------------------- /src/open_clip/__pycache__/transform.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zqhang/Accurate-WinCLIP-pytorch/eef4f8cce6a80eddfa07415b503c22c6c3427351/src/open_clip/__pycache__/transform.cpython-38.pyc -------------------------------------------------------------------------------- /src/open_clip/__pycache__/transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zqhang/Accurate-WinCLIP-pytorch/eef4f8cce6a80eddfa07415b503c22c6c3427351/src/open_clip/__pycache__/transformer.cpython-38.pyc -------------------------------------------------------------------------------- /src/open_clip/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zqhang/Accurate-WinCLIP-pytorch/eef4f8cce6a80eddfa07415b503c22c6c3427351/src/open_clip/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /src/open_clip/__pycache__/version.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zqhang/Accurate-WinCLIP-pytorch/eef4f8cce6a80eddfa07415b503c22c6c3427351/src/open_clip/__pycache__/version.cpython-38.pyc -------------------------------------------------------------------------------- /src/open_clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zqhang/Accurate-WinCLIP-pytorch/eef4f8cce6a80eddfa07415b503c22c6c3427351/src/open_clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /src/open_clip/constants.py: -------------------------------------------------------------------------------- 1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) 2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) 3 | -------------------------------------------------------------------------------- /src/open_clip/generation_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zqhang/Accurate-WinCLIP-pytorch/eef4f8cce6a80eddfa07415b503c22c6c3427351/src/open_clip/generation_utils.py -------------------------------------------------------------------------------- /src/open_clip/hf_configs.py: -------------------------------------------------------------------------------- 1 | # HF architecture dict: 2 | arch_dict = { 3 | # https://huggingface.co/docs/transformers/model_doc/roberta#roberta 4 | "roberta": { 5 | "config_names": { 6 | "context_length": "max_position_embeddings", 7 | "vocab_size": "vocab_size", 8 | "width": "hidden_size", 9 | "heads": "num_attention_heads", 10 | "layers": "num_hidden_layers", 11 | "layer_attr": "layer", 12 | "token_embeddings_attr": "embeddings" 13 | }, 14 | "pooler": "mean_pooler", 15 | }, 16 | # 
https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig 17 | "xlm-roberta": { 18 | "config_names": { 19 | "context_length": "max_position_embeddings", 20 | "vocab_size": "vocab_size", 21 | "width": "hidden_size", 22 | "heads": "num_attention_heads", 23 | "layers": "num_hidden_layers", 24 | "layer_attr": "layer", 25 | "token_embeddings_attr": "embeddings" 26 | }, 27 | "pooler": "mean_pooler", 28 | }, 29 | # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 30 | "mt5": { 31 | "config_names": { 32 | # unlimited seqlen 33 | # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 34 | # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 35 | "context_length": "", 36 | "vocab_size": "vocab_size", 37 | "width": "d_model", 38 | "heads": "num_heads", 39 | "layers": "num_layers", 40 | "layer_attr": "block", 41 | "token_embeddings_attr": "embed_tokens" 42 | }, 43 | "pooler": "mean_pooler", 44 | }, 45 | } 46 | -------------------------------------------------------------------------------- /src/open_clip/hf_model.py: -------------------------------------------------------------------------------- 1 | """ huggingface model adapter 2 | 3 | Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model. 4 | """ 5 | 6 | import re 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch import TensorType 11 | 12 | try: 13 | import transformers 14 | from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig 15 | from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \ 16 | BaseModelOutputWithPoolingAndCrossAttentions 17 | except ImportError as e: 18 | transformers = None 19 | 20 | 21 | class BaseModelOutput: 22 | pass 23 | 24 | 25 | class PretrainedConfig: 26 | pass 27 | 28 | from .hf_configs import arch_dict 29 | 30 | 31 | # utils 32 | def _camel2snake(s): 33 | return re.sub(r'(? 
torch.Tensor: 90 | # calculated ground-truth and cache if enabled 91 | if self.prev_num_logits != num_logits or device not in self.labels: 92 | labels = torch.arange(num_logits, device=device, dtype=torch.long) 93 | if self.world_size > 1 and self.local_loss: 94 | labels = labels + num_logits * self.rank 95 | if self.cache_labels: 96 | self.labels[device] = labels 97 | self.prev_num_logits = num_logits 98 | else: 99 | labels = self.labels[device] 100 | return labels 101 | 102 | def get_logits(self, image_features, text_features, logit_scale): 103 | if self.world_size > 1: 104 | all_image_features, all_text_features = gather_features( 105 | image_features, text_features, 106 | self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) 107 | 108 | if self.local_loss: 109 | logits_per_image = logit_scale * image_features @ all_text_features.T 110 | logits_per_text = logit_scale * text_features @ all_image_features.T 111 | else: 112 | logits_per_image = logit_scale * all_image_features @ all_text_features.T 113 | logits_per_text = logits_per_image.T 114 | else: 115 | logits_per_image = logit_scale * image_features @ text_features.T 116 | logits_per_text = logit_scale * text_features @ image_features.T 117 | 118 | return logits_per_image, logits_per_text 119 | 120 | def forward(self, image_features, text_features, logit_scale, output_dict=False): 121 | device = image_features.device 122 | logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale) 123 | 124 | labels = self.get_ground_truth(device, logits_per_image.shape[0]) 125 | 126 | total_loss = ( 127 | F.cross_entropy(logits_per_image, labels) + 128 | F.cross_entropy(logits_per_text, labels) 129 | ) / 2 130 | 131 | return {"contrastive_loss": total_loss} if output_dict else total_loss 132 | 133 | 134 | class CoCaLoss(ClipLoss): 135 | def __init__( 136 | self, 137 | caption_loss_weight, 138 | clip_loss_weight, 139 | pad_id=0, # pad_token for open_clip custom tokenizer 140 | local_loss=False, 141 | gather_with_grad=False, 142 | cache_labels=False, 143 | rank=0, 144 | world_size=1, 145 | use_horovod=False, 146 | ): 147 | super().__init__( 148 | local_loss=local_loss, 149 | gather_with_grad=gather_with_grad, 150 | cache_labels=cache_labels, 151 | rank=rank, 152 | world_size=world_size, 153 | use_horovod=use_horovod 154 | ) 155 | 156 | self.clip_loss_weight = clip_loss_weight 157 | self.caption_loss_weight = caption_loss_weight 158 | self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id) 159 | 160 | def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False): 161 | clip_loss = super().forward(image_features, text_features, logit_scale) 162 | clip_loss = self.clip_loss_weight * clip_loss 163 | 164 | caption_loss = self.caption_loss( 165 | logits.permute(0, 2, 1), 166 | labels, 167 | ) 168 | caption_loss = caption_loss * self.caption_loss_weight 169 | 170 | if output_dict: 171 | return {"contrastive_loss": clip_loss, "caption_loss": caption_loss} 172 | 173 | return clip_loss, caption_loss 174 | 175 | 176 | class DistillClipLoss(ClipLoss): 177 | 178 | def dist_loss(self, teacher_logits, student_logits): 179 | return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0) 180 | 181 | def forward( 182 | self, 183 | image_features, 184 | text_features, 185 | logit_scale, 186 | dist_image_features, 187 | dist_text_features, 188 | dist_logit_scale, 189 | output_dict=False, 190 | ): 191 | logits_per_image, 
logits_per_text = \ 192 | self.get_logits(image_features, text_features, logit_scale) 193 | 194 | dist_logits_per_image, dist_logits_per_text = \ 195 | self.get_logits(dist_image_features, dist_text_features, dist_logit_scale) 196 | 197 | labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0]) 198 | 199 | contrastive_loss = ( 200 | F.cross_entropy(logits_per_image, labels) + 201 | F.cross_entropy(logits_per_text, labels) 202 | ) / 2 203 | 204 | distill_loss = ( 205 | self.dist_loss(dist_logits_per_image, logits_per_image) + 206 | self.dist_loss(dist_logits_per_text, logits_per_text) 207 | ) / 2 208 | 209 | if output_dict: 210 | return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss} 211 | 212 | return contrastive_loss, distill_loss 213 | -------------------------------------------------------------------------------- /src/open_clip/model_configs/RN101-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 23, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/RN101.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 23, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/RN50-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 6, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /src/open_clip/model_configs/RN50.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 6, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/RN50x16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 384, 5 | "layers": [ 6 | 6, 7 | 8, 8 | 18, 9 | 8 10 | ], 11 | "width": 96, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 768, 18 | "heads": 12, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- 
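The `-quickgelu` model configs above differ from their plain counterparts only in the `"quick_gelu": true` flag, which switches the MLP activation from exact GELU to OpenAI CLIP's sigmoid approximation (the actual selection happens in the vendored transformer/model code, which is not shown in this excerpt). A minimal sketch of the two activations, for reference:

```python
import torch
import torch.nn as nn


class QuickGELU(nn.Module):
    """OpenAI CLIP's GELU approximation: x * sigmoid(1.702 * x).

    Checkpoints trained with QuickGELU (e.g. the original OpenAI weights) are meant to be
    loaded with a *-quickgelu config; most other checkpoints assume exact nn.GELU.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(1.702 * x)


x = torch.linspace(-3, 3, 7)
print(QuickGELU()(x))  # what "quick_gelu": true selects
print(nn.GELU()(x))    # the default when the flag is absent or false
```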
/src/open_clip/model_configs/RN50x4.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 288, 5 | "layers": [ 6 | 4, 7 | 6, 8 | 10, 9 | 6 10 | ], 11 | "width": 80, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 640, 18 | "heads": 10, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/RN50x64.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 448, 5 | "layers": [ 6 | 3, 7 | 15, 8 | 36, 9 | 10 10 | ], 11 | "width": 128, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 1024, 18 | "heads": 16, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-B-16-plus-240.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 240, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-B-16-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-B-32-plus-256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 256, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-B-32-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 32 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-B-32.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-H-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | } 17 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-H-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 16 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | } 17 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-L-14-280.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 280, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-L-16-320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 320, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-L-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | 
"vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-M-16-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 16, 8 | "ls_init_value": 1e-4 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 384, 14 | "heads": 6, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-M-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-M-32-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 384, 13 | "heads": 6, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-M-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-S-16-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 256, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 256, 13 | "heads": 4, 14 | "layers": 10 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-S-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 384, 13 | "heads": 6, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-S-32-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 256, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 
11 | "vocab_size": 49408, 12 | "width": 256, 13 | "heads": 4, 14 | "layers": 10 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-S-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 384, 13 | "heads": 6, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-bigG-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 48, 6 | "width": 1664, 7 | "head_width": 104, 8 | "mlp_ratio": 4.9231, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1280, 15 | "heads": 20, 16 | "layers": 32 17 | } 18 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-e-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 56, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.5715, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1280, 15 | "heads": 20, 16 | "layers": 36 17 | } 18 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1024, 15 | "heads": 16, 16 | "layers": 24 17 | } 18 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/coca_ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32, 8 | "attentional_pool": true, 9 | "attn_pooler_heads": 8, 10 | "output_tokens": true 11 | }, 12 | "text_cfg": { 13 | "context_length": 76, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12, 18 | "embed_cls": true, 19 | "output_tokens": true 20 | }, 21 | "multimodal_cfg": { 22 | "context_length": 76, 23 | "vocab_size": 49408, 24 | "width": 512, 25 | "heads": 8, 26 | "layers": 12, 27 | "attn_pooler_heads": 8 28 | }, 29 | "custom_text": true 30 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/coca_ViT-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14, 8 | "attentional_pool": true, 9 | "attn_pooler_heads": 8, 10 | "output_tokens": true 11 | }, 12 | "text_cfg": { 13 | "context_length": 76, 14 | "vocab_size": 49408, 15 | "width": 768, 16 | "heads": 12, 17 | 
"layers": 12, 18 | "embed_cls": true, 19 | "output_tokens": true 20 | }, 21 | "multimodal_cfg": { 22 | "context_length": 76, 23 | "vocab_size": 49408, 24 | "width": 768, 25 | "heads": 12, 26 | "layers": 12, 27 | "attn_pooler_heads": 12 28 | }, 29 | "custom_text": true 30 | } 31 | -------------------------------------------------------------------------------- /src/open_clip/model_configs/coca_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "multimodal_cfg": { 4 | "width": 768, 5 | "context_length": 76, 6 | "vocab_size": 64000, 7 | "mlp_ratio": 4, 8 | "layers": 12, 9 | "dim_head": 64, 10 | "heads": 12, 11 | "n_queries": 256, 12 | "attn_pooler_heads": 8 13 | }, 14 | "vision_cfg": { 15 | "image_size": 288, 16 | "layers": 12, 17 | "width": 768, 18 | "patch_size": 18, 19 | "output_tokens": true 20 | }, 21 | "text_cfg": { 22 | "context_length": 76, 23 | "vocab_size": 64000, 24 | "layers": 12, 25 | "heads": 12, 26 | "width": 768, 27 | "embed_cls": true, 28 | "output_tokens": true 29 | }, 30 | "custom_text": true 31 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/coca_roberta-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32, 8 | "output_tokens": true 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "roberta-base", 12 | "hf_tokenizer_name": "roberta-base", 13 | "proj": "linear", 14 | "width": 768, 15 | "output_tokens": true 16 | }, 17 | "multimodal_cfg": { 18 | "context_length": 76, 19 | "width": 768, 20 | "heads": 8, 21 | "layers": 12 22 | }, 23 | "custom_text": true 24 | } 25 | -------------------------------------------------------------------------------- /src/open_clip/model_configs/convnext_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_base", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 224 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/convnext_base_w.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_base", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 256 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 640, 16 | "heads": 10, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/convnext_base_w_320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_base", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 320 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 
49408, 15 | "width": 640, 16 | "heads": 10, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/convnext_large.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_large", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 224 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 768, 16 | "heads": 12, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/convnext_large_d.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_large", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "mlp", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 256 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 768, 16 | "heads": 12, 17 | "layers": 16 18 | } 19 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/convnext_large_d_320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_large", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "mlp", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 320 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 768, 16 | "heads": 12, 17 | "layers": 16 18 | } 19 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/convnext_small.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_small", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 224 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/convnext_tiny.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_tiny", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 224 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/convnext_xlarge.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_xlarge", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | 
"image_size": 256 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 1024, 16 | "heads": 16, 17 | "layers": 20 18 | } 19 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/convnext_xxlarge.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_xxlarge", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 256 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 1024, 16 | "heads": 16, 17 | "layers": 24 18 | } 19 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/convnext_xxlarge_320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_xxlarge", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 320 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 1024, 16 | "heads": 16, 17 | "layers": 24 18 | } 19 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/mt5-base-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "hf_model_name": "google/mt5-base", 11 | "hf_tokenizer_name": "google/mt5-base", 12 | "proj": "mlp", 13 | "pooler_type": "mean_pooler" 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /src/open_clip/model_configs/mt5-xl-ViT-H-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "google/mt5-xl", 12 | "hf_tokenizer_name": "google/mt5-xl", 13 | "proj": "mlp", 14 | "pooler_type": "mean_pooler" 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /src/open_clip/model_configs/roberta-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 32 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "roberta-base", 12 | "hf_tokenizer_name": "roberta-base", 13 | "proj": "mlp", 14 | "pooler_type": "mean_pooler" 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /src/open_clip/model_configs/swin_base_patch4_window7_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "timm_model_name": "swin_base_patch4_window7_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 640, 14 | 
"heads": 10, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/vit_medium_patch16_gap_256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "vit_medium_patch16_gap_256", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 256 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "vit_relpos_medium_patch16_cls_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "hf_model_name": "xlm-roberta-base", 11 | "hf_tokenizer_name": "xlm-roberta-base", 12 | "proj": "mlp", 13 | "pooler_type": "mean_pooler" 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /src/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "xlm-roberta-large", 12 | "hf_tokenizer_name": "xlm-roberta-large", 13 | "proj": "mlp", 14 | "pooler_type": "mean_pooler" 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /src/open_clip/modified_resnet.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from open_clip.utils import freeze_batch_norm_2d 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.act1 = nn.ReLU(inplace=True) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.act2 = nn.ReLU(inplace=True) 24 | 25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 26 | 27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | self.act3 = nn.ReLU(inplace=True) 30 | 31 | self.downsample = None 32 | self.stride = stride 33 | 34 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 36 | self.downsample = nn.Sequential(OrderedDict([ 37 | ("-1", nn.AvgPool2d(stride)), 38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 39 | ("1", nn.BatchNorm2d(planes * self.expansion)) 40 | ])) 41 | 42 | def forward(self, x: torch.Tensor): 43 | identity = x 44 | 45 | out = self.act1(self.bn1(self.conv1(x))) 46 | out = self.act2(self.bn2(self.conv2(out))) 47 | out = self.avgpool(out) 48 | out = self.bn3(self.conv3(out)) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | out = self.act3(out) 55 | return out 56 | 57 | 58 | class AttentionPool2d(nn.Module): 59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 60 | super().__init__() 61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 62 | self.k_proj = nn.Linear(embed_dim, embed_dim) 63 | self.q_proj = nn.Linear(embed_dim, embed_dim) 64 | self.v_proj = nn.Linear(embed_dim, embed_dim) 65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 66 | self.num_heads = num_heads 67 | 68 | def forward(self, x): 69 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 72 | x, _ = F.multi_head_attention_forward( 73 | query=x, key=x, value=x, 74 | embed_dim_to_check=x.shape[-1], 75 | num_heads=self.num_heads, 76 | q_proj_weight=self.q_proj.weight, 77 | k_proj_weight=self.k_proj.weight, 78 | v_proj_weight=self.v_proj.weight, 79 | in_proj_weight=None, 80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 81 | bias_k=None, 82 | bias_v=None, 83 | add_zero_attn=False, 84 | dropout_p=0., 85 | out_proj_weight=self.c_proj.weight, 86 | out_proj_bias=self.c_proj.bias, 87 | use_separate_proj_weight=True, 88 | training=self.training, 89 | need_weights=False 90 | ) 91 | 92 | return x[0] 93 | 94 | 95 | class ModifiedResNet(nn.Module): 96 | """ 97 | A ResNet class that is similar to torchvision's but contains the following changes: 98 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
99 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 100 | - The final pooling layer is a QKV attention instead of an average pool 101 | """ 102 | 103 | def __init__(self, layers, output_dim, heads, image_size=224, width=64): 104 | super().__init__() 105 | self.output_dim = output_dim 106 | self.image_size = image_size 107 | 108 | # the 3-layer stem 109 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 110 | self.bn1 = nn.BatchNorm2d(width // 2) 111 | self.act1 = nn.ReLU(inplace=True) 112 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 113 | self.bn2 = nn.BatchNorm2d(width // 2) 114 | self.act2 = nn.ReLU(inplace=True) 115 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 116 | self.bn3 = nn.BatchNorm2d(width) 117 | self.act3 = nn.ReLU(inplace=True) 118 | self.avgpool = nn.AvgPool2d(2) 119 | 120 | # residual layers 121 | self._inplanes = width # this is a *mutable* variable used during construction 122 | self.layer1 = self._make_layer(width, layers[0]) 123 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 124 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 125 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 126 | 127 | embed_dim = width * 32 # the ResNet feature dimension 128 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) 129 | 130 | self.init_parameters() 131 | 132 | def _make_layer(self, planes, blocks, stride=1): 133 | layers = [Bottleneck(self._inplanes, planes, stride)] 134 | 135 | self._inplanes = planes * Bottleneck.expansion 136 | for _ in range(1, blocks): 137 | layers.append(Bottleneck(self._inplanes, planes)) 138 | 139 | return nn.Sequential(*layers) 140 | 141 | def init_parameters(self): 142 | if self.attnpool is not None: 143 | std = self.attnpool.c_proj.in_features ** -0.5 144 | nn.init.normal_(self.attnpool.q_proj.weight, std=std) 145 | nn.init.normal_(self.attnpool.k_proj.weight, std=std) 146 | nn.init.normal_(self.attnpool.v_proj.weight, std=std) 147 | nn.init.normal_(self.attnpool.c_proj.weight, std=std) 148 | 149 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: 150 | for name, param in resnet_block.named_parameters(): 151 | if name.endswith("bn3.weight"): 152 | nn.init.zeros_(param) 153 | 154 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 155 | assert unlocked_groups == 0, 'partial locking not currently supported for this model' 156 | for param in self.parameters(): 157 | param.requires_grad = False 158 | if freeze_bn_stats: 159 | freeze_batch_norm_2d(self) 160 | 161 | @torch.jit.ignore 162 | def set_grad_checkpointing(self, enable=True): 163 | # FIXME support for non-transformer 164 | pass 165 | 166 | def stem(self, x): 167 | x = self.act1(self.bn1(self.conv1(x))) 168 | x = self.act2(self.bn2(self.conv2(x))) 169 | x = self.act3(self.bn3(self.conv3(x))) 170 | x = self.avgpool(x) 171 | return x 172 | 173 | def forward(self, x): 174 | x = self.stem(x) 175 | x = self.layer1(x) 176 | x = self.layer2(x) 177 | x = self.layer3(x) 178 | x = self.layer4(x) 179 | x = self.attnpool(x) 180 | 181 | return x 182 | -------------------------------------------------------------------------------- /src/open_clip/openai.py: -------------------------------------------------------------------------------- 1 | """ OpenAI pretrained model functions 2 | 3 | Adapted from https://github.com/openai/CLIP. 
Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | 6 | import os 7 | import warnings 8 | from typing import List, Optional, Union 9 | 10 | import torch 11 | 12 | from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype 13 | from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url 14 | 15 | __all__ = ["list_openai_models", "load_openai_model"] 16 | 17 | 18 | def list_openai_models() -> List[str]: 19 | """Returns the names of available CLIP models""" 20 | return list_pretrained_models_by_tag('openai') 21 | 22 | 23 | def load_openai_model( 24 | name: str, 25 | precision: Optional[str] = None, 26 | device: Optional[Union[str, torch.device]] = None, 27 | jit: bool = True, 28 | cache_dir: Optional[str] = None, 29 | ): 30 | """Load a CLIP model 31 | 32 | Parameters 33 | ---------- 34 | name : str 35 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 36 | precision: str 37 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. 38 | device : Union[str, torch.device] 39 | The device to put the loaded model 40 | jit : bool 41 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 42 | cache_dir : Optional[str] 43 | The directory to cache the downloaded model weights 44 | 45 | Returns 46 | ------- 47 | model : torch.nn.Module 48 | The CLIP model 49 | preprocess : Callable[[PIL.Image], torch.Tensor] 50 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 51 | """ 52 | if device is None: 53 | device = "cuda" if torch.cuda.is_available() else "cpu" 54 | if precision is None: 55 | precision = 'fp32' if device == 'cpu' else 'fp16' 56 | 57 | if get_pretrained_url(name, 'openai'): 58 | model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) 59 | elif os.path.isfile(name): 60 | model_path = name 61 | else: 62 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") 63 | 64 | try: 65 | # loading JIT archive 66 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 67 | state_dict = None 68 | except RuntimeError: 69 | # loading saved state dict 70 | if jit: 71 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 72 | jit = False 73 | state_dict = torch.load(model_path, map_location="cpu") 74 | 75 | if not jit: 76 | # Build a non-jit model from the OpenAI jitted model state dict 77 | cast_dtype = get_cast_dtype(precision) 78 | try: 79 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) 80 | except KeyError: 81 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 82 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) 83 | 84 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use 85 | model = model.to(device) 86 | if precision.startswith('amp') or precision == 'fp32': 87 | model.float() 88 | elif precision == 'bf16': 89 | convert_weights_to_lp(model, dtype=torch.bfloat16) 90 | 91 | return model 92 | 93 | # patch the device names 94 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 95 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 96 | 97 | def patch_device(module): 98 | try: 99 | graphs = [module.graph] if hasattr(module, "graph") else [] 100 | except RuntimeError: 101 | graphs = [] 102 | 103 | if hasattr(module, "forward1"): 104 | graphs.append(module.forward1.graph) 105 | 106 | for graph in graphs: 107 | for node in graph.findAllNodes("prim::Constant"): 108 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 109 | node.copyAttributes(device_node) 110 | 111 | model.apply(patch_device) 112 | patch_device(model.encode_image) 113 | patch_device(model.encode_text) 114 | 115 | # patch dtype to float32 (typically for CPU) 116 | if precision == 'fp32': 117 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 118 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 119 | float_node = float_input.node() 120 | 121 | def patch_float(module): 122 | try: 123 | graphs = [module.graph] if hasattr(module, "graph") else [] 124 | except RuntimeError: 125 | graphs = [] 126 | 127 | if hasattr(module, "forward1"): 128 | graphs.append(module.forward1.graph) 129 | 130 | for graph in graphs: 131 | for node in graph.findAllNodes("aten::to"): 132 | inputs = list(node.inputs()) 133 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 134 | if inputs[i].node()["value"] == 5: 135 | inputs[i].node().copyAttributes(float_node) 136 | 137 | model.apply(patch_float) 138 | patch_float(model.encode_image) 139 | patch_float(model.encode_text) 140 | model.float() 141 | 142 | # ensure image_size attr available at consistent location for both jit and non-jit 143 | model.visual.image_size = model.input_resolution.item() 144 | return model 145 | -------------------------------------------------------------------------------- /src/open_clip/pretrained.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from functools import partial 6 | from typing import Dict, Union 7 | 8 | from tqdm import tqdm 9 | 10 | from .version import __version__ 11 | 12 | try: 13 | from huggingface_hub import hf_hub_download 14 | hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__) 15 | _has_hf_hub = True 16 | except ImportError: 17 | hf_hub_download = None 18 | _has_hf_hub = False 19 | 20 | 21 | def _pcfg(url='', hf_hub='', mean=None, 
std=None): 22 | return dict( 23 | url=url, 24 | hf_hub=hf_hub, 25 | mean=mean, 26 | std=std, 27 | ) 28 | 29 | 30 | _RN50 = dict( 31 | openai=_pcfg( 32 | "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"), 33 | yfcc15m=_pcfg( 34 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"), 35 | cc12m=_pcfg( 36 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"), 37 | ) 38 | 39 | _RN50_quickgelu = dict( 40 | openai=_pcfg( 41 | "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"), 42 | yfcc15m=_pcfg( 43 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"), 44 | cc12m=_pcfg( 45 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"), 46 | ) 47 | 48 | _RN101 = dict( 49 | openai=_pcfg( 50 | "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"), 51 | yfcc15m=_pcfg( 52 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"), 53 | ) 54 | 55 | _RN101_quickgelu = dict( 56 | openai=_pcfg( 57 | "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"), 58 | yfcc15m=_pcfg( 59 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"), 60 | ) 61 | 62 | _RN50x4 = dict( 63 | openai=_pcfg( 64 | "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"), 65 | ) 66 | 67 | _RN50x16 = dict( 68 | openai=_pcfg( 69 | "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"), 70 | ) 71 | 72 | _RN50x64 = dict( 73 | openai=_pcfg( 74 | "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"), 75 | ) 76 | 77 | _VITB32 = dict( 78 | openai=_pcfg( 79 | "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), 80 | laion400m_e31=_pcfg( 81 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), 82 | laion400m_e32=_pcfg( 83 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), 84 | laion2b_e16=_pcfg( 85 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"), 86 | laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/') 87 | ) 88 | 89 | _VITB32_quickgelu = dict( 90 | openai=_pcfg( 91 | "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), 92 | laion400m_e31=_pcfg( 93 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), 94 | laion400m_e32=_pcfg( 95 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), 96 | ) 97 | 98 | _VITB16 = dict( 99 | openai=_pcfg( 100 | 
"https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"), 101 | laion400m_e31=_pcfg( 102 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"), 103 | laion400m_e32=_pcfg( 104 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"), 105 | # laion400m_32k=_pcfg( 106 | # url="", 107 | # mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 108 | # laion400m_64k=_pcfg( 109 | # url="", 110 | # mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 111 | laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'), 112 | ) 113 | 114 | _VITB16_PLUS_240 = dict( 115 | laion400m_e31=_pcfg( 116 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"), 117 | laion400m_e32=_pcfg( 118 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"), 119 | ) 120 | 121 | _VITL14 = dict( 122 | openai=_pcfg( 123 | "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"), 124 | laion400m_e31=_pcfg( 125 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"), 126 | laion400m_e32=_pcfg( 127 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"), 128 | laion2b_s32b_b82k=_pcfg( 129 | hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/', 130 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 131 | ) 132 | 133 | _VITL14_336 = dict( 134 | openai=_pcfg( 135 | "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"), 136 | ) 137 | 138 | _VITH14 = dict( 139 | laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'), 140 | ) 141 | 142 | _VITg14 = dict( 143 | laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'), 144 | laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'), 145 | ) 146 | 147 | _VITbigG14 = dict( 148 | laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'), 149 | ) 150 | 151 | _robertaViTB32 = dict( 152 | laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'), 153 | ) 154 | 155 | _xlmRobertaBaseViTB32 = dict( 156 | laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'), 157 | ) 158 | 159 | _xlmRobertaLargeFrozenViTH14 = dict( 160 | frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'), 161 | ) 162 | 163 | _convnext_base = dict( 164 | laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'), 165 | ) 166 | 167 | _convnext_base_w = dict( 168 | laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'), 169 | laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'), 170 | laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'), 171 | ) 172 | 173 | _convnext_base_w_320 = dict( 174 | laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'), 175 | laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'), 176 | ) 177 | 178 | _convnext_large_d = dict( 179 | 
laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'), 180 | ) 181 | 182 | _convnext_large_d_320 = dict( 183 | laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'), 184 | laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'), 185 | ) 186 | 187 | _convnext_xxlarge = dict( 188 | laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'), 189 | laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'), 190 | laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'), 191 | ) 192 | 193 | _coca_VITB32 = dict( 194 | laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'), 195 | mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/') 196 | ) 197 | 198 | _coca_VITL14 = dict( 199 | laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'), 200 | mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/') 201 | ) 202 | 203 | 204 | _PRETRAINED = { 205 | "RN50": _RN50, 206 | "RN50-quickgelu": _RN50_quickgelu, 207 | "RN101": _RN101, 208 | "RN101-quickgelu": _RN101_quickgelu, 209 | "RN50x4": _RN50x4, 210 | "RN50x16": _RN50x16, 211 | "RN50x64": _RN50x64, 212 | "ViT-B-32": _VITB32, 213 | "ViT-B-32-quickgelu": _VITB32_quickgelu, 214 | "ViT-B-16": _VITB16, 215 | "ViT-B-16-plus-240": _VITB16_PLUS_240, 216 | "ViT-L-14": _VITL14, 217 | "ViT-L-14-336": _VITL14_336, 218 | "ViT-H-14": _VITH14, 219 | "ViT-g-14": _VITg14, 220 | "ViT-bigG-14": _VITbigG14, 221 | "roberta-ViT-B-32": _robertaViTB32, 222 | "xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32, 223 | "xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14, 224 | "convnext_base": _convnext_base, 225 | "convnext_base_w": _convnext_base_w, 226 | "convnext_base_w_320": _convnext_base_w_320, 227 | "convnext_large_d": _convnext_large_d, 228 | "convnext_large_d_320": _convnext_large_d_320, 229 | "convnext_xxlarge": _convnext_xxlarge, 230 | "coca_ViT-B-32": _coca_VITB32, 231 | "coca_ViT-L-14": _coca_VITL14, 232 | } 233 | 234 | 235 | def _clean_tag(tag: str): 236 | # normalize pretrained tags 237 | return tag.lower().replace('-', '_') 238 | 239 | 240 | def list_pretrained(as_str: bool = False): 241 | """ returns list of pretrained models 242 | Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True 243 | """ 244 | return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()] 245 | 246 | 247 | def list_pretrained_models_by_tag(tag: str): 248 | """ return all models having the specified pretrain tag """ 249 | models = [] 250 | tag = _clean_tag(tag) 251 | for k in _PRETRAINED.keys(): 252 | if tag in _PRETRAINED[k]: 253 | models.append(k) 254 | return models 255 | 256 | 257 | def list_pretrained_tags_by_model(model: str): 258 | """ return all pretrain tags for the specified model architecture """ 259 | tags = [] 260 | if model in _PRETRAINED: 261 | tags.extend(_PRETRAINED[model].keys()) 262 | return tags 263 | 264 | 265 | def is_pretrained_cfg(model: str, tag: str): 266 | if model not in _PRETRAINED: 267 | return False 268 | return _clean_tag(tag) in _PRETRAINED[model] 269 | 270 | 271 | def get_pretrained_cfg(model: str, tag: str): 272 | if model not in _PRETRAINED: 273 | return {} 274 | model_pretrained = 
_PRETRAINED[model] 275 | return model_pretrained.get(_clean_tag(tag), {}) 276 | 277 | 278 | def get_pretrained_url(model: str, tag: str): 279 | cfg = get_pretrained_cfg(model, _clean_tag(tag)) 280 | return cfg.get('url', '') 281 | 282 | 283 | def download_pretrained_from_url( 284 | url: str, 285 | cache_dir: Union[str, None] = None, 286 | ): 287 | if not cache_dir: 288 | cache_dir = os.path.expanduser("~/.cache/clip") 289 | os.makedirs(cache_dir, exist_ok=True) 290 | filename = os.path.basename(url) 291 | 292 | if 'openaipublic' in url: 293 | expected_sha256 = url.split("/")[-2] 294 | elif 'mlfoundations' in url: 295 | expected_sha256 = os.path.splitext(filename)[0].split("-")[-1] 296 | else: 297 | expected_sha256 = '' 298 | 299 | download_target = os.path.join(cache_dir, filename) 300 | 301 | if os.path.exists(download_target) and not os.path.isfile(download_target): 302 | raise RuntimeError(f"{download_target} exists and is not a regular file") 303 | 304 | if os.path.isfile(download_target): 305 | if expected_sha256: 306 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): 307 | return download_target 308 | else: 309 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 310 | else: 311 | return download_target 312 | 313 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 314 | with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 315 | while True: 316 | buffer = source.read(8192) 317 | if not buffer: 318 | break 319 | 320 | output.write(buffer) 321 | loop.update(len(buffer)) 322 | 323 | if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): 324 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not match") 325 | 326 | return download_target 327 | 328 | 329 | def has_hf_hub(necessary=False): 330 | if not _has_hf_hub and necessary: 331 | # if no HF Hub module installed, and it is necessary to continue, raise error 332 | raise RuntimeError( 333 | 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.') 334 | return _has_hf_hub 335 | 336 | 337 | def download_pretrained_from_hf( 338 | model_id: str, 339 | filename: str = 'open_clip_pytorch_model.bin', 340 | revision=None, 341 | cache_dir: Union[str, None] = None, 342 | ): 343 | has_hf_hub(True) 344 | cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir) 345 | return cached_file 346 | 347 | 348 | def download_pretrained( 349 | cfg: Dict, 350 | force_hf_hub: bool = False, 351 | cache_dir: Union[str, None] = None, 352 | ): 353 | target = '' 354 | if not cfg: 355 | return target 356 | 357 | download_url = cfg.get('url', '') 358 | download_hf_hub = cfg.get('hf_hub', '') 359 | if download_hf_hub and force_hf_hub: 360 | # use HF hub even if url exists 361 | download_url = '' 362 | 363 | if download_url: 364 | target = download_pretrained_from_url(download_url, cache_dir=cache_dir) 365 | elif download_hf_hub: 366 | has_hf_hub(True) 367 | # we assume the hf_hub entries in pretrained config combine model_id + filename in 368 | # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and 369 | # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.
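# Illustrative example, not part of the original source: for a registered entry such as
# hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/' (see _VITB32 above), os.path.split()
# below yields model_id='laion/CLIP-ViT-B-32-laion2B-s34B-b79K' and filename='', so the
# 'open_clip_pytorch_model.bin' default of download_pretrained_from_hf is used.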
370 | model_id, filename = os.path.split(download_hf_hub) 371 | if filename: 372 | target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir) 373 | else: 374 | target = download_pretrained_from_hf(model_id, cache_dir=cache_dir) 375 | 376 | return target 377 | -------------------------------------------------------------------------------- /src/open_clip/push_to_hf_hub.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from pathlib import Path 4 | from tempfile import TemporaryDirectory 5 | from typing import Optional, Tuple 6 | 7 | import torch 8 | 9 | try: 10 | from huggingface_hub import ( 11 | create_repo, 12 | get_hf_file_metadata, 13 | hf_hub_download, 14 | hf_hub_url, 15 | repo_type_and_id_from_hf_id, 16 | upload_folder, 17 | ) 18 | from huggingface_hub.utils import EntryNotFoundError 19 | _has_hf_hub = True 20 | except ImportError: 21 | _has_hf_hub = False 22 | 23 | from .factory import create_model_from_pretrained, get_model_config, get_tokenizer 24 | from .tokenizer import HFTokenizer 25 | 26 | 27 | def save_config_for_hf( 28 | model, 29 | config_path: str, 30 | model_config: Optional[dict] 31 | ): 32 | preprocess_cfg = { 33 | 'mean': model.visual.image_mean, 34 | 'std': model.visual.image_std, 35 | } 36 | hf_config = { 37 | 'model_cfg': model_config, 38 | 'preprocess_cfg': preprocess_cfg, 39 | } 40 | 41 | with config_path.open('w') as f: 42 | json.dump(hf_config, f, indent=2) 43 | 44 | 45 | def save_for_hf( 46 | model, 47 | tokenizer: HFTokenizer, 48 | model_config: dict, 49 | save_directory: str, 50 | weights_filename='open_clip_pytorch_model.bin', 51 | config_filename='open_clip_config.json', 52 | ): 53 | save_directory = Path(save_directory) 54 | save_directory.mkdir(exist_ok=True, parents=True) 55 | 56 | weights_path = save_directory / weights_filename 57 | torch.save(model.state_dict(), weights_path) 58 | 59 | tokenizer.save_pretrained(save_directory) 60 | 61 | config_path = save_directory / config_filename 62 | save_config_for_hf(model, config_path, model_config=model_config) 63 | 64 | 65 | def push_to_hf_hub( 66 | model, 67 | tokenizer, 68 | model_config: Optional[dict], 69 | repo_id: str, 70 | commit_message: str = 'Add model', 71 | token: Optional[str] = None, 72 | revision: Optional[str] = None, 73 | private: bool = False, 74 | create_pr: bool = False, 75 | model_card: Optional[dict] = None, 76 | ): 77 | if not isinstance(tokenizer, HFTokenizer): 78 | # default CLIP tokenizers use https://huggingface.co/openai/clip-vit-large-patch14 79 | tokenizer = HFTokenizer('openai/clip-vit-large-patch14') 80 | 81 | # Create repo if it doesn't exist yet 82 | repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True) 83 | 84 | # Infer complete repo_id from repo_url 85 | # Can be different from the input `repo_id` if repo_owner was implicit 86 | _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url) 87 | repo_id = f"{repo_owner}/{repo_name}" 88 | 89 | # Check if README file already exist in repo 90 | try: 91 | get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision)) 92 | has_readme = True 93 | except EntryNotFoundError: 94 | has_readme = False 95 | 96 | # Dump model and push to Hub 97 | with TemporaryDirectory() as tmpdir: 98 | # Save model weights and config. 
99 | save_for_hf( 100 | model, 101 | tokenizer=tokenizer, 102 | model_config=model_config, 103 | save_directory=tmpdir, 104 | ) 105 | 106 | # Add readme if it does not exist 107 | if not has_readme: 108 | model_card = model_card or {} 109 | model_name = repo_id.split('/')[-1] 110 | readme_path = Path(tmpdir) / "README.md" 111 | readme_text = generate_readme(model_card, model_name) 112 | readme_path.write_text(readme_text) 113 | 114 | # Upload model and return 115 | return upload_folder( 116 | repo_id=repo_id, 117 | folder_path=tmpdir, 118 | revision=revision, 119 | create_pr=create_pr, 120 | commit_message=commit_message, 121 | ) 122 | 123 | 124 | def push_pretrained_to_hf_hub( 125 | model_name, 126 | pretrained: str, 127 | repo_id: str, 128 | image_mean: Optional[Tuple[float, ...]] = None, 129 | image_std: Optional[Tuple[float, ...]] = None, 130 | commit_message: str = 'Add model', 131 | token: Optional[str] = None, 132 | revision: Optional[str] = None, 133 | private: bool = False, 134 | create_pr: bool = False, 135 | model_card: Optional[dict] = None, 136 | ): 137 | model, preprocess_eval = create_model_from_pretrained( 138 | model_name, 139 | pretrained=pretrained, 140 | image_mean=image_mean, 141 | image_std=image_std, 142 | ) 143 | 144 | model_config = get_model_config(model_name) 145 | assert model_config 146 | 147 | tokenizer = get_tokenizer(model_name) 148 | 149 | push_to_hf_hub( 150 | model=model, 151 | tokenizer=tokenizer, 152 | model_config=model_config, 153 | repo_id=repo_id, 154 | commit_message=commit_message, 155 | token=token, 156 | revision=revision, 157 | private=private, 158 | create_pr=create_pr, 159 | model_card=model_card, 160 | ) 161 | 162 | 163 | def generate_readme(model_card: dict, model_name: str): 164 | readme_text = "---\n" 165 | readme_text += "tags:\n- zero-shot-image-classification\n- clip\n" 166 | readme_text += "library_tag: open_clip\n" 167 | readme_text += f"license: {model_card.get('license', 'mit')}\n" 168 | if 'details' in model_card and 'Dataset' in model_card['details']: 169 | readme_text += 'datasets:\n' 170 | readme_text += f"- {model_card['details']['Dataset'].lower()}\n" 171 | readme_text += "---\n" 172 | readme_text += f"# Model card for {model_name}\n" 173 | if 'description' in model_card: 174 | readme_text += f"\n{model_card['description']}\n" 175 | if 'details' in model_card: 176 | readme_text += f"\n## Model Details\n" 177 | for k, v in model_card['details'].items(): 178 | if isinstance(v, (list, tuple)): 179 | readme_text += f"- **{k}:**\n" 180 | for vi in v: 181 | readme_text += f" - {vi}\n" 182 | elif isinstance(v, dict): 183 | readme_text += f"- **{k}:**\n" 184 | for ki, vi in v.items(): 185 | readme_text += f" - {ki}: {vi}\n" 186 | else: 187 | readme_text += f"- **{k}:** {v}\n" 188 | if 'usage' in model_card: 189 | readme_text += f"\n## Model Usage\n" 190 | readme_text += model_card['usage'] 191 | readme_text += '\n' 192 | 193 | if 'comparison' in model_card: 194 | readme_text += f"\n## Model Comparison\n" 195 | readme_text += model_card['comparison'] 196 | readme_text += '\n' 197 | 198 | if 'citation' in model_card: 199 | readme_text += f"\n## Citation\n" 200 | if not isinstance(model_card['citation'], (list, tuple)): 201 | citations = [model_card['citation']] 202 | else: 203 | citations = model_card['citation'] 204 | for c in citations: 205 | readme_text += f"```bibtex\n{c}\n```\n" 206 | 207 | return readme_text 208 | 209 | 210 | if __name__ == "__main__": 211 | parser = argparse.ArgumentParser(description="Push to Hugging Face Hub") 
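# Usage sketch (illustrative; the pretrained tag and repo id below are placeholders, and this assumes
# the open_clip package is importable, e.g. run as a module from the src/ directory):
#   python -m open_clip.push_to_hf_hub --model ViT-B-32 --pretrained laion2b_s34b_b79k --repo-id myorg/my-clip-model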
212 | parser.add_argument( 213 | "--model", type=str, help="Name of the model to use.", 214 | ) 215 | parser.add_argument( 216 | "--pretrained", type=str, 217 | help="Use pretrained CLIP model weights with the specified tag or file path.", 218 | ) 219 | parser.add_argument( 220 | "--repo-id", type=str, 221 | help="Destination HF Hub repo-id i.e. 'organization/model_id'.", 222 | ) 223 | parser.add_argument( 224 | '--image-mean', type=float, nargs='+', default=None, metavar='MEAN', 225 | help='Override default image mean value of dataset') 226 | parser.add_argument( 227 | '--image-std', type=float, nargs='+', default=None, metavar='STD', 228 | help='Override default image std deviation of dataset') 229 | args = parser.parse_args() 230 | 231 | print(f'Saving model {args.model} with pretrained weights {args.pretrained} to Hugging Face Hub at {args.repo_id}') 232 | 233 | # FIXME add support to pass model_card json / template from file via cmd line 234 | 235 | push_pretrained_to_hf_hub( 236 | args.model, 237 | args.pretrained, 238 | args.repo_id, 239 | image_mean=args.image_mean, # override image mean/std if trained w/ non defaults 240 | image_std=args.image_std, 241 | ) 242 | 243 | print(f'{args.model} saved.') 244 | -------------------------------------------------------------------------------- /src/open_clip/timm_model.py: -------------------------------------------------------------------------------- 1 | """ timm model adapter 2 | 3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 4 | """ 5 | import logging 6 | from collections import OrderedDict 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | try: 12 | import timm 13 | from timm.models.layers import Mlp, to_2tuple 14 | try: 15 | # old timm imports < 0.8.1 16 | from timm.models.layers.attention_pool2d import RotAttentionPool2d 17 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d 18 | except ImportError: 19 | # new timm imports >= 0.8.1 20 | from timm.layers import RotAttentionPool2d 21 | from timm.layers import AttentionPool2d as AbsAttentionPool2d 22 | except ImportError: 23 | timm = None 24 | 25 | from .utils import freeze_batch_norm_2d 26 | 27 | 28 | class TimmModel(nn.Module): 29 | """ timm model adapter 30 | # FIXME this adapter is a work in progress, may change in ways that break weight compat 31 | """ 32 | 33 | def __init__( 34 | self, 35 | model_name, 36 | embed_dim, 37 | image_size=224, 38 | pool='avg', 39 | proj='linear', 40 | proj_bias=False, 41 | drop=0., 42 | drop_path=None, 43 | pretrained=False, 44 | ): 45 | super().__init__() 46 | if timm is None: 47 | raise RuntimeError("Please `pip install timm` to use timm models.") 48 | 49 | self.image_size = to_2tuple(image_size) 50 | timm_kwargs = {} 51 | if drop_path is not None: 52 | timm_kwargs['drop_path_rate'] = drop_path 53 | self.trunk = timm.create_model(model_name, pretrained=pretrained, **timm_kwargs) 54 | feat_size = self.trunk.default_cfg.get('pool_size', None) 55 | feature_ndim = 1 if not feat_size else 2 56 | if pool in ('abs_attn', 'rot_attn'): 57 | assert feature_ndim == 2 58 | # if attn pooling used, remove both classifier and default pool 59 | self.trunk.reset_classifier(0, global_pool='') 60 | else: 61 | # reset global pool if pool config set, otherwise leave as network default 62 | reset_kwargs = dict(global_pool=pool) if pool else {} 63 | self.trunk.reset_classifier(0, **reset_kwargs) 64 | prev_chs = self.trunk.num_features 65 | 66 | head_layers =
OrderedDict() 67 | if pool == 'abs_attn': 68 | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) 69 | prev_chs = embed_dim 70 | elif pool == 'rot_attn': 71 | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) 72 | prev_chs = embed_dim 73 | else: 74 | assert proj, 'projection layer needed if non-attention pooling is used.' 75 | 76 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used 77 | if proj == 'linear': 78 | head_layers['drop'] = nn.Dropout(drop) 79 | head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) 80 | elif proj == 'mlp': 81 | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias)) 82 | 83 | self.head = nn.Sequential(head_layers) 84 | 85 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 86 | """ lock modules 87 | Args: 88 | unlocked_groups (int): leave last n layer groups unlocked (default: 0) 89 | """ 90 | if not unlocked_groups: 91 | # lock full model 92 | for param in self.trunk.parameters(): 93 | param.requires_grad = False 94 | if freeze_bn_stats: 95 | freeze_batch_norm_2d(self.trunk) 96 | else: 97 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change 98 | try: 99 | # FIXME import here until API stable and in an official release 100 | from timm.models.helpers import group_parameters, group_modules 101 | except ImportError: 102 | raise RuntimeError( 103 | 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') 104 | matcher = self.trunk.group_matcher() 105 | gparams = group_parameters(self.trunk, matcher) 106 | max_layer_id = max(gparams.keys()) 107 | max_layer_id = max_layer_id - unlocked_groups 108 | for group_idx in range(max_layer_id + 1): 109 | group = gparams[group_idx] 110 | for param in group: 111 | self.trunk.get_parameter(param).requires_grad = False 112 | if freeze_bn_stats: 113 | gmodules = group_modules(self.trunk, matcher, reverse=True) 114 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} 115 | freeze_batch_norm_2d(self.trunk, gmodules) 116 | 117 | @torch.jit.ignore 118 | def set_grad_checkpointing(self, enable=True): 119 | try: 120 | self.trunk.set_grad_checkpointing(enable) 121 | except Exception as e: 122 | logging.warning('grad checkpointing not supported for this timm image tower, continuing without...') 123 | 124 | def forward(self, x): 125 | x = self.trunk(x) 126 | x = self.head(x) 127 | return x 128 | -------------------------------------------------------------------------------- /src/open_clip/tokenizer.py: -------------------------------------------------------------------------------- 1 | """ CLIP tokenizer 2 | 3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | import gzip 6 | import html 7 | import os 8 | from functools import lru_cache 9 | from typing import Union, List 10 | 11 | import ftfy 12 | import regex as re 13 | import torch 14 | 15 | # https://stackoverflow.com/q/62691279 16 | import os 17 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 18 | 19 | 20 | @lru_cache() 21 | def default_bpe(): 22 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 23 | 24 | 25 | @lru_cache() 26 | def bytes_to_unicode(): 27 | """ 28 | Returns list of utf-8 byte and a corresponding list of unicode strings. 29 | The reversible bpe codes work on unicode strings. 
30 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 31 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 32 | This is a significant percentage of your normal, say, 32K bpe vocab. 33 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 34 | And avoids mapping to whitespace/control characters the bpe code barfs on. 35 | """ 36 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 37 | cs = bs[:] 38 | n = 0 39 | for b in range(2**8): 40 | if b not in bs: 41 | bs.append(b) 42 | cs.append(2**8+n) 43 | n += 1 44 | cs = [chr(n) for n in cs] 45 | return dict(zip(bs, cs)) 46 | 47 | 48 | def get_pairs(word): 49 | """Return set of symbol pairs in a word. 50 | Word is represented as tuple of symbols (symbols being variable-length strings). 51 | """ 52 | pairs = set() 53 | prev_char = word[0] 54 | for char in word[1:]: 55 | pairs.add((prev_char, char)) 56 | prev_char = char 57 | return pairs 58 | 59 | 60 | def basic_clean(text): 61 | text = ftfy.fix_text(text) 62 | text = html.unescape(html.unescape(text)) 63 | return text.strip() 64 | 65 | 66 | def whitespace_clean(text): 67 | text = re.sub(r'\s+', ' ', text) 68 | text = text.strip() 69 | return text 70 | 71 | 72 | class SimpleTokenizer(object): 73 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): 74 | self.byte_encoder = bytes_to_unicode() 75 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 76 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 77 | merges = merges[1:49152-256-2+1] 78 | merges = [tuple(merge.split()) for merge in merges] 79 | vocab = list(bytes_to_unicode().values()) 80 | vocab = vocab + [v+'</w>' for v in vocab] 81 | for merge in merges: 82 | vocab.append(''.join(merge)) 83 | if not special_tokens: 84 | special_tokens = ['<start_of_text>', '<end_of_text>'] 85 | else: 86 | special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens 87 | vocab.extend(special_tokens) 88 | self.encoder = dict(zip(vocab, range(len(vocab)))) 89 | self.decoder = {v: k for k, v in self.encoder.items()} 90 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 91 | self.cache = {t:t for t in special_tokens} 92 | special = "|".join(special_tokens) 93 | self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 94 | 95 | self.vocab_size = len(self.encoder) 96 | self.all_special_ids = [self.encoder[t] for t in special_tokens] 97 | 98 | def bpe(self, token): 99 | if token in self.cache: 100 | return self.cache[token] 101 | word = tuple(token[:-1]) + ( token[-1] + '</w>',) 102 | pairs = get_pairs(word) 103 | 104 | if not pairs: 105 | return token+'</w>' 106 | 107 | while True: 108 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 109 | if bigram not in self.bpe_ranks: 110 | break 111 | first, second = bigram 112 | new_word = [] 113 | i = 0 114 | while i < len(word): 115 | try: 116 | j = word.index(first, i) 117 | new_word.extend(word[i:j]) 118 | i = j 119 | except: 120 | new_word.extend(word[i:]) 121 | break 122 | 123 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 124 | new_word.append(first+second) 125 | i += 2 126 | else: 127 | new_word.append(word[i]) 128 | i += 1 129 | new_word = tuple(new_word) 130 | word = new_word 131 | if len(word) == 1: 132 | break 133 | else: 134 | pairs = get_pairs(word) 135 | word = ' '.join(word) 136 | self.cache[token] = word 137 | return word 138 |
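# Worked sketch of the merge loop above, with hypothetical merge ranks (the real ranks come from
# bpe_simple_vocab_16e6.txt.gz): for the token 'low' the word starts as ('l', 'o', 'w</w>');
# if ('l', 'o') is the best-ranked pair it becomes ('lo', 'w</w>'), and if ('lo', 'w</w>') is also
# ranked the word collapses to ('low</w>',), so bpe('low') would return 'low</w>' and cache it.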
139 | def encode(self, text): 140 | bpe_tokens = [] 141 | text = whitespace_clean(basic_clean(text)).lower() 142 | for token in re.findall(self.pat, text): 143 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 144 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 145 | return bpe_tokens 146 | 147 | def decode(self, tokens): 148 | text = ''.join([self.decoder[token] for token in tokens]) 149 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') 150 | return text 151 | 152 | 153 | _tokenizer = SimpleTokenizer() 154 | 155 | def decode(output_ids: torch.Tensor): 156 | output_ids = output_ids.cpu().numpy() 157 | return _tokenizer.decode(output_ids) 158 | 159 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: 160 | """ 161 | Returns the tokenized representation of given input string(s) 162 | 163 | Parameters 164 | ---------- 165 | texts : Union[str, List[str]] 166 | An input string or a list of input strings to tokenize 167 | context_length : int 168 | The context length to use; all CLIP models use 77 as the context length 169 | 170 | Returns 171 | ------- 172 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 173 | """ 174 | if isinstance(texts, str): 175 | texts = [texts] 176 | 177 | sot_token = _tokenizer.encoder["<start_of_text>"] 178 | eot_token = _tokenizer.encoder["<end_of_text>"] 179 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 180 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 181 | 182 | for i, tokens in enumerate(all_tokens): 183 | if len(tokens) > context_length: 184 | tokens = tokens[:context_length] # Truncate 185 | tokens[-1] = eot_token 186 | result[i, :len(tokens)] = torch.tensor(tokens) 187 | 188 | return result 189 | 190 | 191 | class HFTokenizer: 192 | """HuggingFace tokenizer wrapper""" 193 | 194 | def __init__(self, tokenizer_name: str): 195 | from transformers import AutoTokenizer 196 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 197 | 198 | def save_pretrained(self, dest): 199 | self.tokenizer.save_pretrained(dest) 200 | 201 | def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor: 202 | # same cleaning as for default tokenizer, except lowercasing 203 | # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance 204 | if isinstance(texts, str): 205 | texts = [texts] 206 | texts = [whitespace_clean(basic_clean(text)) for text in texts] 207 | input_ids = self.tokenizer( 208 | texts, 209 | return_tensors='pt', 210 | max_length=context_length, 211 | padding='max_length', 212 | truncation=True, 213 | ).input_ids 214 | return input_ids 215 | -------------------------------------------------------------------------------- /src/open_clip/transform.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from dataclasses import dataclass, asdict 3 | from typing import Any, Dict, Optional, Sequence, Tuple, Union 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torchvision.transforms.functional as F 8 | 9 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ 10 | CenterCrop 11 | 12 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 13 | 14 | 15 | @dataclass 16 | class AugmentationCfg: 17 | scale:
Tuple[float, float] = (0.9, 1.0) 18 | ratio: Optional[Tuple[float, float]] = None 19 | color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None 20 | interpolation: Optional[str] = None 21 | re_prob: Optional[float] = None 22 | re_count: Optional[int] = None 23 | use_timm: bool = False 24 | 25 | 26 | class ResizeMaxSize(nn.Module): 27 | 28 | def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0): 29 | super().__init__() 30 | if not isinstance(max_size, int): 31 | raise TypeError(f"Size should be int. Got {type(max_size)}") 32 | self.max_size = max_size 33 | self.interpolation = interpolation 34 | self.fn = min if fn == 'min' else max 35 | self.fill = fill 36 | 37 | def forward(self, img): 38 | if isinstance(img, torch.Tensor): 39 | height, width = img.shape[:2] 40 | else: 41 | width, height = img.size 42 | scale = self.max_size / float(max(height, width)) 43 | if scale != 1.0: 44 | new_size = tuple(round(dim * scale) for dim in (height, width)) 45 | img = F.resize(img, new_size, self.interpolation) 46 | pad_h = self.max_size - new_size[0] 47 | pad_w = self.max_size - new_size[1] 48 | img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill) 49 | return img 50 | 51 | 52 | def _convert_to_rgb(image): 53 | return image.convert('RGB') 54 | 55 | 56 | def image_transform( 57 | image_size: int, 58 | is_train: bool, 59 | mean: Optional[Tuple[float, ...]] = None, 60 | std: Optional[Tuple[float, ...]] = None, 61 | resize_longest_max: bool = False, 62 | fill_color: int = 0, 63 | aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, 64 | ): 65 | mean = mean or OPENAI_DATASET_MEAN 66 | if not isinstance(mean, (list, tuple)): 67 | mean = (mean,) * 3 68 | 69 | std = std or OPENAI_DATASET_STD 70 | if not isinstance(std, (list, tuple)): 71 | std = (std,) * 3 72 | 73 | if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: 74 | # for square size, pass size as int so that Resize() uses aspect preserving shortest edge 75 | image_size = image_size[0] 76 | 77 | if isinstance(aug_cfg, dict): 78 | aug_cfg = AugmentationCfg(**aug_cfg) 79 | else: 80 | aug_cfg = aug_cfg or AugmentationCfg() 81 | normalize = Normalize(mean=mean, std=std) 82 | if is_train: 83 | aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None} 84 | use_timm = aug_cfg_dict.pop('use_timm', False) 85 | if use_timm: 86 | from timm.data import create_transform # timm can still be optional 87 | if isinstance(image_size, (tuple, list)): 88 | assert len(image_size) >= 2 89 | input_size = (3,) + image_size[-2:] 90 | else: 91 | input_size = (3, image_size, image_size) 92 | # by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time 93 | aug_cfg_dict.setdefault('interpolation', 'random') 94 | aug_cfg_dict.setdefault('color_jitter', None) # disable by default 95 | train_transform = create_transform( 96 | input_size=input_size, 97 | is_training=True, 98 | hflip=0., 99 | mean=mean, 100 | std=std, 101 | re_mode='pixel', 102 | **aug_cfg_dict, 103 | ) 104 | else: 105 | train_transform = Compose([ 106 | RandomResizedCrop( 107 | image_size, 108 | scale=aug_cfg_dict.pop('scale'), 109 | interpolation=InterpolationMode.BICUBIC, 110 | ), 111 | _convert_to_rgb, 112 | ToTensor(), 113 | normalize, 114 | ]) 115 | if aug_cfg_dict: 116 | warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).') 117 | return train_transform 118 | else: 119 | if
resize_longest_max: 120 | transforms = [ 121 | ResizeMaxSize(image_size, fill=fill_color) 122 | ] 123 | else: 124 | transforms = [ 125 | Resize(image_size, interpolation=InterpolationMode.BICUBIC), 126 | CenterCrop(image_size), 127 | ] 128 | transforms.extend([ 129 | _convert_to_rgb, 130 | ToTensor(), 131 | normalize, 132 | ]) 133 | return Compose(transforms) 134 | -------------------------------------------------------------------------------- /src/open_clip/utils.py: -------------------------------------------------------------------------------- 1 | from itertools import repeat 2 | import collections.abc 3 | 4 | from torch import nn as nn 5 | from torchvision.ops.misc import FrozenBatchNorm2d 6 | 7 | 8 | def freeze_batch_norm_2d(module, module_match={}, name=''): 9 | """ 10 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is 11 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and 12 | returned. Otherwise, the module is walked recursively and submodules are converted in place. 13 | 14 | Args: 15 | module (torch.nn.Module): Any PyTorch module. 16 | module_match (dict): Dictionary of full module names to freeze (all if empty) 17 | name (str): Full module name (prefix) 18 | 19 | Returns: 20 | torch.nn.Module: Resulting module 21 | 22 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 23 | """ 24 | res = module 25 | is_match = True 26 | if module_match: 27 | is_match = name in module_match 28 | if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): 29 | res = FrozenBatchNorm2d(module.num_features) 30 | res.num_features = module.num_features 31 | res.affine = module.affine 32 | if module.affine: 33 | res.weight.data = module.weight.data.clone().detach() 34 | res.bias.data = module.bias.data.clone().detach() 35 | res.running_mean.data = module.running_mean.data 36 | res.running_var.data = module.running_var.data 37 | res.eps = module.eps 38 | else: 39 | for child_name, child in module.named_children(): 40 | full_child_name = '.'.join([name, child_name]) if name else child_name 41 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name) 42 | if new_child is not child: 43 | res.add_module(child_name, new_child) 44 | return res 45 | 46 | 47 | # From PyTorch internals 48 | def _ntuple(n): 49 | def parse(x): 50 | if isinstance(x, collections.abc.Iterable): 51 | return x 52 | return tuple(repeat(x, n)) 53 | return parse 54 | 55 | 56 | to_1tuple = _ntuple(1) 57 | to_2tuple = _ntuple(2) 58 | to_3tuple = _ntuple(3) 59 | to_4tuple = _ntuple(4) 60 | to_ntuple = lambda n, x: _ntuple(n)(x) 61 | -------------------------------------------------------------------------------- /src/open_clip/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '2.16.0' 2 | -------------------------------------------------------------------------------- /src/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | import open_clip 4 | 5 | model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32-quickgelu', pretrained='laion400m_e32') 6 | tokenizer = open_clip.get_tokenizer('ViT-B-32-quickgelu') 7 | 8 | image = preprocess(Image.open(".CLIP.png")).unsqueeze(0) 9 | text = tokenizer(["a diagram", "a dog", "a cat"]) 10 
| 11 | with torch.no_grad(), torch.cuda.amp.autocast(): 12 | image_features = model.encode_image(image) 13 | text_features = model.encode_text(text) 14 | image_features /= image_features.norm(dim=-1, keepdim=True) 15 | text_features /= text_features.norm(dim=-1, keepdim=True) 16 | 17 | text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1) 18 | 19 | print("Label probs:", text_probs) # prints: [[1., 0., 0.]] -------------------------------------------------------------------------------- /src/training/.gitignore: -------------------------------------------------------------------------------- 1 | logs/ 2 | -------------------------------------------------------------------------------- /src/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zqhang/Accurate-WinCLIP-pytorch/eef4f8cce6a80eddfa07415b503c22c6c3427351/src/training/__init__.py -------------------------------------------------------------------------------- /src/training/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.distributed as dist 5 | 6 | try: 7 | import horovod.torch as hvd 8 | except ImportError: 9 | hvd = None 10 | 11 | 12 | def is_global_master(args): 13 | return args.rank == 0 14 | 15 | 16 | def is_local_master(args): 17 | return args.local_rank == 0 18 | 19 | 20 | def is_master(args, local=False): 21 | return is_local_master(args) if local else is_global_master(args) 22 | 23 | 24 | def is_using_horovod(): 25 | # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set 26 | # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required... 27 | ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"] 28 | pmi_vars = ["PMI_RANK", "PMI_SIZE"] 29 | if all([var in os.environ for var in ompi_vars]) or all([var in os.environ for var in pmi_vars]): 30 | return True 31 | else: 32 | return False 33 | 34 | 35 | def is_using_distributed(): 36 | if 'WORLD_SIZE' in os.environ: 37 | return int(os.environ['WORLD_SIZE']) > 1 38 | if 'SLURM_NTASKS' in os.environ: 39 | return int(os.environ['SLURM_NTASKS']) > 1 40 | return False 41 | 42 | 43 | def world_info_from_env(): 44 | local_rank = 0 45 | for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'): 46 | if v in os.environ: 47 | local_rank = int(os.environ[v]) 48 | break 49 | global_rank = 0 50 | for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'): 51 | if v in os.environ: 52 | global_rank = int(os.environ[v]) 53 | break 54 | world_size = 1 55 | for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'): 56 | if v in os.environ: 57 | world_size = int(os.environ[v]) 58 | break 59 | 60 | return local_rank, global_rank, world_size 61 | 62 | 63 | def init_distributed_device(args): 64 | # Distributed training = training on more than one GPU. 65 | # Works in both single and multi-node scenarios. 
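# For reference (values illustrative): a torchrun launch exports RANK, LOCAL_RANK and WORLD_SIZE,
# e.g. RANK=3 LOCAL_RANK=1 WORLD_SIZE=8, which world_info_from_env() above reads; under SLURM the
# SLURM_PROCID / SLURM_LOCALID / SLURM_NTASKS variables play the same role.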
66 | args.distributed = False 67 | args.world_size = 1 68 | args.rank = 0 # global rank 69 | args.local_rank = 0 70 | if args.horovod: 71 | assert hvd is not None, "Horovod is not installed" 72 | hvd.init() 73 | args.local_rank = int(hvd.local_rank()) 74 | args.rank = hvd.rank() 75 | args.world_size = hvd.size() 76 | args.distributed = True 77 | os.environ['LOCAL_RANK'] = str(args.local_rank) 78 | os.environ['RANK'] = str(args.rank) 79 | os.environ['WORLD_SIZE'] = str(args.world_size) 80 | elif is_using_distributed(): 81 | if 'SLURM_PROCID' in os.environ: 82 | # DDP via SLURM 83 | args.local_rank, args.rank, args.world_size = world_info_from_env() 84 | # SLURM var -> torch.distributed vars in case needed 85 | os.environ['LOCAL_RANK'] = str(args.local_rank) 86 | os.environ['RANK'] = str(args.rank) 87 | os.environ['WORLD_SIZE'] = str(args.world_size) 88 | torch.distributed.init_process_group( 89 | backend=args.dist_backend, 90 | init_method=args.dist_url, 91 | world_size=args.world_size, 92 | rank=args.rank, 93 | ) 94 | else: 95 | # DDP via torchrun, torch.distributed.launch 96 | args.local_rank, _, _ = world_info_from_env() 97 | torch.distributed.init_process_group( 98 | backend=args.dist_backend, 99 | init_method=args.dist_url) 100 | args.world_size = torch.distributed.get_world_size() 101 | args.rank = torch.distributed.get_rank() 102 | args.distributed = True 103 | 104 | if torch.cuda.is_available(): 105 | if args.distributed and not args.no_set_device_rank: 106 | device = 'cuda:%d' % args.local_rank 107 | else: 108 | device = 'cuda:0' 109 | torch.cuda.set_device(device) 110 | else: 111 | device = 'cpu' 112 | args.device = device 113 | device = torch.device(device) 114 | return device 115 | 116 | 117 | def broadcast_object(args, obj, src=0): 118 | # broadcast a pickle-able python object from rank-0 to all ranks 119 | if args.horovod: 120 | return hvd.broadcast_object(obj, root_rank=src) 121 | else: 122 | if args.rank == src: 123 | objects = [obj] 124 | else: 125 | objects = [None] 126 | dist.broadcast_object_list(objects, src=src) 127 | return objects[0] 128 | 129 | 130 | def all_gather_object(args, obj, dst=0): 131 | # gather a pickle-able python object across all ranks 132 | if args.horovod: 133 | return hvd.allgather_object(obj) 134 | else: 135 | objects = [None for _ in range(args.world_size)] 136 | dist.all_gather_object(objects, obj) 137 | return objects 138 | -------------------------------------------------------------------------------- /src/training/file_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import multiprocessing 4 | import subprocess 5 | import time 6 | import fsspec 7 | import torch 8 | from tqdm import tqdm 9 | 10 | def remote_sync_s3(local_dir, remote_dir): 11 | # skip epoch_latest which can change during sync. 12 | result = subprocess.run(["aws", "s3", "sync", local_dir, remote_dir, '--exclude', '*epoch_latest.pt'], stdout=subprocess.PIPE, stderr=subprocess.PIPE) 13 | if result.returncode != 0: 14 | logging.error(f"Error: Failed to sync with S3 bucket {result.stderr.decode('utf-8')}") 15 | return False 16 | 17 | logging.info(f"Successfully synced with S3 bucket") 18 | return True 19 | 20 | def remote_sync_fsspec(local_dir, remote_dir): 21 | # FIXME currently this is slow and not recommended. Look into speeding up. 22 | a = fsspec.get_mapper(local_dir) 23 | b = fsspec.get_mapper(remote_dir) 24 | 25 | for k in a: 26 | # skip epoch_latest which can change during sync. 
27 | if 'epoch_latest.pt' in k: 28 | continue 29 | 30 | logging.info(f'Attempting to sync {k}') 31 | if k in b and len(a[k]) == len(b[k]): 32 | logging.debug(f'Skipping remote sync for {k}.') 33 | continue 34 | 35 | try: 36 | logging.info(f'Successful sync for {k}.') 37 | b[k] = a[k] 38 | except Exception as e: 39 | logging.info(f'Error during remote sync for {k}: {e}') 40 | return False 41 | 42 | return True 43 | 44 | def remote_sync(local_dir, remote_dir, protocol): 45 | logging.info('Starting remote sync.') 46 | if protocol == 's3': 47 | return remote_sync_s3(local_dir, remote_dir) 48 | elif protocol == 'fsspec': 49 | return remote_sync_fsspec(local_dir, remote_dir) 50 | else: 51 | logging.error('Remote protocol not known') 52 | return False 53 | 54 | def keep_running_remote_sync(sync_every, local_dir, remote_dir, protocol): 55 | while True: 56 | time.sleep(sync_every) 57 | remote_sync(local_dir, remote_dir, protocol) 58 | 59 | def start_sync_process(sync_every, local_dir, remote_dir, protocol): 60 | p = multiprocessing.Process(target=keep_running_remote_sync, args=(sync_every, local_dir, remote_dir, protocol)) 61 | return p 62 | 63 | # Note: we are not currently using this save function. 64 | def pt_save(pt_obj, file_path): 65 | of = fsspec.open(file_path, "wb") 66 | with of as f: 67 | torch.save(pt_obj, file_path) 68 | 69 | def pt_load(file_path, map_location=None): 70 | if file_path.startswith('s3'): 71 | logging.info('Loading remote checkpoint, which may take a bit.') 72 | of = fsspec.open(file_path, "rb") 73 | with of as f: 74 | out = torch.load(f, map_location=map_location) 75 | return out 76 | 77 | def check_exists(file_path): 78 | try: 79 | with fsspec.open(file_path): 80 | pass 81 | except FileNotFoundError: 82 | return False 83 | return True 84 | -------------------------------------------------------------------------------- /src/training/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def setup_logging(log_file, level, include_host=False): 5 | if include_host: 6 | import socket 7 | hostname = socket.gethostname() 8 | formatter = logging.Formatter( 9 | f'%(asctime)s | {hostname} | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') 10 | else: 11 | formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') 12 | 13 | logging.root.setLevel(level) 14 | loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict] 15 | for logger in loggers: 16 | logger.setLevel(level) 17 | 18 | stream_handler = logging.StreamHandler() 19 | stream_handler.setFormatter(formatter) 20 | logging.root.addHandler(stream_handler) 21 | 22 | if log_file: 23 | file_handler = logging.FileHandler(filename=log_file) 24 | file_handler.setFormatter(formatter) 25 | logging.root.addHandler(file_handler) 26 | 27 | -------------------------------------------------------------------------------- /src/training/params.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import ast 3 | 4 | 5 | def get_default_params(model_name): 6 | # Params from paper (https://arxiv.org/pdf/2103.00020.pdf) 7 | model_name = model_name.lower() 8 | if "vit" in model_name: 9 | return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6} 10 | else: 11 | return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8} 12 | 13 | 14 | class ParseKwargs(argparse.Action): 15 | def __call__(self, parser, namespace, values, 
option_string=None): 16 | kw = {} 17 | for value in values: 18 | key, value = value.split('=') 19 | try: 20 | kw[key] = ast.literal_eval(value) 21 | except ValueError: 22 | kw[key] = str(value) # fallback to string (avoid need to escape on command line) 23 | setattr(namespace, self.dest, kw) 24 | 25 | 26 | def parse_args(args): 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument( 29 | "--train-data", 30 | type=str, 31 | default=None, 32 | help="Path to file(s) with training data. When using webdataset, multiple datasources can be combined using the `::` separator.", 33 | ) 34 | parser.add_argument( 35 | "--train-data-upsampling-factors", 36 | type=str, 37 | default=None, 38 | help=( 39 | "When using multiple data sources with webdataset and sampling with replacement, this can be used to upsample specific data sources. " 40 | "Similar to --train-data, this should be a string with as many numbers as there are data sources, separated by `::` (e.g. 1::2::0.5) " 41 | "By default, datapoints are sampled uniformly regardless of the dataset sizes." 42 | ) 43 | ) 44 | parser.add_argument( 45 | "--val-data", 46 | type=str, 47 | default=None, 48 | help="Path to file(s) with validation data", 49 | ) 50 | parser.add_argument( 51 | "--train-num-samples", 52 | type=int, 53 | default=None, 54 | help="Number of samples in dataset. Required for webdataset if not available in info file.", 55 | ) 56 | parser.add_argument( 57 | "--val-num-samples", 58 | type=int, 59 | default=None, 60 | help="Number of samples in dataset. Useful for webdataset if not available in info file.", 61 | ) 62 | parser.add_argument( 63 | "--dataset-type", 64 | choices=["webdataset", "csv", "synthetic", "auto"], 65 | default="auto", 66 | help="Which type of dataset to process." 67 | ) 68 | parser.add_argument( 69 | "--dataset-resampled", 70 | default=False, 71 | action="store_true", 72 | help="Whether to use sampling with replacement for webdataset shard selection." 73 | ) 74 | parser.add_argument( 75 | "--csv-separator", 76 | type=str, 77 | default="\t", 78 | help="For csv-like datasets, which separator to use." 79 | ) 80 | parser.add_argument( 81 | "--csv-img-key", 82 | type=str, 83 | default="filepath", 84 | help="For csv-like datasets, the name of the key for the image paths." 85 | ) 86 | parser.add_argument( 87 | "--csv-caption-key", 88 | type=str, 89 | default="title", 90 | help="For csv-like datasets, the name of the key for the captions." 91 | ) 92 | parser.add_argument( 93 | "--imagenet-val", 94 | type=str, 95 | default=None, 96 | help="Path to imagenet val set for conducting zero shot evaluation.", 97 | ) 98 | parser.add_argument( 99 | "--imagenet-v2", 100 | type=str, 101 | default=None, 102 | help="Path to imagenet v2 for conducting zero shot evaluation.", 103 | ) 104 | parser.add_argument( 105 | "--logs", 106 | type=str, 107 | default="./logs/", 108 | help="Where to store tensorboard logs. Use None to avoid storing logs.", 109 | ) 110 | parser.add_argument( 111 | "--log-local", 112 | action="store_true", 113 | default=False, 114 | help="log files on local master, otherwise global master only.", 115 | ) 116 | parser.add_argument( 117 | "--name", 118 | type=str, 119 | default=None, 120 | help="Optional identifier for the experiment when storing logs. Otherwise use current time.", 121 | ) 122 | parser.add_argument( 123 | "--workers", type=int, default=1, help="Number of dataloader workers per GPU." 124 | ) 125 | parser.add_argument( 126 | "--batch-size", type=int, default=64, help="Batch size per GPU." 
127 | ) 128 | parser.add_argument( 129 | "--epochs", type=int, default=32, help="Number of epochs to train for." 130 | ) 131 | parser.add_argument( 132 | "--epochs-cooldown", type=int, default=None, 133 | help="When scheduler w/ cooldown used, perform cooldown from total_epochs - cooldown_epochs onwards." 134 | ) 135 | parser.add_argument("--lr", type=float, default=None, help="Learning rate.") 136 | parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.") 137 | parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.") 138 | parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.") 139 | parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.") 140 | parser.add_argument( 141 | "--warmup", type=int, default=10000, help="Number of steps to warmup for." 142 | ) 143 | parser.add_argument( 144 | "--use-bn-sync", 145 | default=False, 146 | action="store_true", 147 | help="Whether to use batch norm sync.") 148 | parser.add_argument( 149 | "--skip-scheduler", 150 | action="store_true", 151 | default=False, 152 | help="Use this flag to skip the learning rate decay.", 153 | ) 154 | parser.add_argument( 155 | "--lr-scheduler", 156 | type=str, 157 | default='cosine', 158 | help="LR scheduler. One of: 'cosine', 'const' (constant), 'const-cooldown' (constant w/ cooldown). Default: cosine", 159 | ) 160 | parser.add_argument( 161 | "--lr-cooldown-end", type=float, default=0.0, 162 | help="End learning rate for cooldown schedule. Default: 0" 163 | ) 164 | parser.add_argument( 165 | "--lr-cooldown-power", type=float, default=1.0, 166 | help="Power for polynomial cooldown schedule. Default: 1.0 (linear decay)" 167 | ) 168 | parser.add_argument( 169 | "--save-frequency", type=int, default=1, help="How often to save checkpoints." 170 | ) 171 | parser.add_argument( 172 | "--save-most-recent", 173 | action="store_true", 174 | default=False, 175 | help="Always save the most recent model trained to epoch_latest.pt.", 176 | ) 177 | parser.add_argument( 178 | "--zeroshot-frequency", type=int, default=2, help="How often to run zero shot." 179 | ) 180 | parser.add_argument( 181 | "--val-frequency", type=int, default=1, help="How often to run evaluation with val data." 182 | ) 183 | parser.add_argument( 184 | "--resume", 185 | default=None, 186 | type=str, 187 | help="path to latest checkpoint (default: none)", 188 | ) 189 | parser.add_argument( 190 | "--precision", 191 | choices=["amp", "amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32"], 192 | default="amp", 193 | help="Floating point precision." 
194 | ) 195 | parser.add_argument( 196 | "--model", 197 | type=str, 198 | default="RN50", 199 | help="Name of the vision backbone to use.", 200 | ) 201 | parser.add_argument( 202 | "--pretrained", 203 | default='', 204 | type=str, 205 | help="Use pretrained CLIP model weights with the specified tag or file path.", 206 | ) 207 | parser.add_argument( 208 | "--pretrained-image", 209 | default=False, 210 | action='store_true', 211 | help="Load imagenet pretrained weights for image tower backbone if available.", 212 | ) 213 | parser.add_argument( 214 | "--lock-image", 215 | default=False, 216 | action='store_true', 217 | help="Lock full image tower by disabling gradients.", 218 | ) 219 | parser.add_argument( 220 | "--lock-image-unlocked-groups", 221 | type=int, 222 | default=0, 223 | help="Leave last n image tower layer groups unlocked.", 224 | ) 225 | parser.add_argument( 226 | "--lock-image-freeze-bn-stats", 227 | default=False, 228 | action='store_true', 229 | help="Freeze BatchNorm running stats in image tower for any locked layers.", 230 | ) 231 | parser.add_argument( 232 | '--image-mean', type=float, nargs='+', default=None, metavar='MEAN', 233 | help='Override default image mean value of dataset') 234 | parser.add_argument( 235 | '--image-std', type=float, nargs='+', default=None, metavar='STD', 236 | help='Override default image std deviation of dataset') 237 | parser.add_argument('--aug-cfg', nargs='*', default={}, action=ParseKwargs) 238 | parser.add_argument( 239 | "--grad-checkpointing", 240 | default=False, 241 | action='store_true', 242 | help="Enable gradient checkpointing.", 243 | ) 244 | parser.add_argument( 245 | "--local-loss", 246 | default=False, 247 | action="store_true", 248 | help="calculate loss w/ local features @ global (instead of realizing full global @ global matrix)" 249 | ) 250 | parser.add_argument( 251 | "--gather-with-grad", 252 | default=False, 253 | action="store_true", 254 | help="enable full distributed gradient for feature gather" 255 | ) 256 | parser.add_argument( 257 | '--force-image-size', type=int, nargs='+', default=None, 258 | help='Override default image size' 259 | ) 260 | parser.add_argument( 261 | "--force-quick-gelu", 262 | default=False, 263 | action='store_true', 264 | help="Force use of QuickGELU activation for non-OpenAI transformer models.", 265 | ) 266 | parser.add_argument( 267 | "--force-patch-dropout", 268 | default=None, 269 | type=float, 270 | help="Override the patch dropout during training, for fine tuning with no dropout near the end as in the paper", 271 | ) 272 | parser.add_argument( 273 | "--force-custom-text", 274 | default=False, 275 | action='store_true', 276 | help="Force use of CustomTextCLIP model (separate text-tower).", 277 | ) 278 | parser.add_argument( 279 | "--torchscript", 280 | default=False, 281 | action='store_true', 282 | help="torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'", 283 | ) 284 | parser.add_argument( 285 | "--trace", 286 | default=False, 287 | action='store_true', 288 | help="torch.jit.trace the model for inference / eval only", 289 | ) 290 | parser.add_argument( 291 | "--accum-freq", type=int, default=1, help="Update the model every --accum-freq steps."
292 | ) 293 | # arguments for distributed training 294 | parser.add_argument( 295 | "--dist-url", 296 | default="env://", 297 | type=str, 298 | help="url used to set up distributed training", 299 | ) 300 | parser.add_argument( 301 | "--dist-backend", default="nccl", type=str, help="distributed backend" 302 | ) 303 | parser.add_argument( 304 | "--report-to", 305 | default='', 306 | type=str, 307 | help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']" 308 | ) 309 | parser.add_argument( 310 | "--wandb-notes", 311 | default='', 312 | type=str, 313 | help="Notes if logging with wandb" 314 | ) 315 | parser.add_argument( 316 | "--wandb-project-name", 317 | type=str, 318 | default='open-clip', 319 | help="Name of the project if logging with wandb.", 320 | ) 321 | parser.add_argument( 322 | "--debug", 323 | default=False, 324 | action="store_true", 325 | help="If true, more information is logged." 326 | ) 327 | parser.add_argument( 328 | "--copy-codebase", 329 | default=False, 330 | action="store_true", 331 | help="If true, we copy the entire codebase to the log directory, and execute from there." 332 | ) 333 | parser.add_argument( 334 | "--horovod", 335 | default=False, 336 | action="store_true", 337 | help="Use horovod for distributed training." 338 | ) 339 | parser.add_argument( 340 | "--ddp-static-graph", 341 | default=False, 342 | action='store_true', 343 | help="Enable static graph optimization for DDP in PyTorch >= 1.11.", 344 | ) 345 | parser.add_argument( 346 | "--no-set-device-rank", 347 | default=False, 348 | action="store_true", 349 | help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc)." 350 | ) 351 | parser.add_argument( 352 | "--seed", type=int, default=0, help="Default random seed." 353 | ) 354 | parser.add_argument( 355 | "--grad-clip-norm", type=float, default=None, help="Gradient clip." 356 | ) 357 | parser.add_argument( 358 | "--lock-text", 359 | default=False, 360 | action='store_true', 361 | help="Lock full text tower by disabling gradients.", 362 | ) 363 | parser.add_argument( 364 | "--lock-text-unlocked-layers", 365 | type=int, 366 | default=0, 367 | help="Leave last n text tower layers unlocked.", 368 | ) 369 | parser.add_argument( 370 | "--lock-text-freeze-layer-norm", 371 | default=False, 372 | action='store_true', 373 | help="Freeze LayerNorm running stats in text tower for any locked layers.", 374 | ) 375 | parser.add_argument( 376 | "--log-every-n-steps", 377 | type=int, 378 | default=100, 379 | help="Log every n steps to tensorboard/console/wandb.", 380 | ) 381 | parser.add_argument( 382 | "--coca-caption-loss-weight", 383 | type=float, 384 | default=2.0, 385 | help="Weight assigned to caption loss in CoCa." 386 | ) 387 | parser.add_argument( 388 | "--coca-contrastive-loss-weight", 389 | type=float, 390 | default=1.0, 391 | help="Weight assigned to contrastive loss when training CoCa."
392 | ) 393 | parser.add_argument( 394 | "--remote-sync", 395 | type=str, 396 | default=None, 397 | help="Optionally sync with a remote path specified by this arg", 398 | ) 399 | parser.add_argument( 400 | "--remote-sync-frequency", 401 | type=int, 402 | default=300, 403 | help="How frequently to sync to a remote directory if --remote-sync is not None.", 404 | ) 405 | parser.add_argument( 406 | "--remote-sync-protocol", 407 | choices=["s3", "fsspec"], 408 | default="s3", 409 | help="How to do the remote sync backup if --remote-sync is not None.", 410 | ) 411 | parser.add_argument( 412 | "--delete-previous-checkpoint", 413 | default=False, 414 | action="store_true", 415 | help="If true, delete previous checkpoint after storing a new one." 416 | ) 417 | parser.add_argument( 418 | "--distill-model", 419 | default=None, 420 | help='Which model arch to distill from, if any.' 421 | ) 422 | parser.add_argument( 423 | "--distill-pretrained", 424 | default=None, 425 | help='Which pre-trained weights to distill from, if any.' 426 | ) 427 | args = parser.parse_args(args) 428 | 429 | # If some params are not passed, we use the default values based on model name. 430 | default_params = get_default_params(args.model) 431 | for name, val in default_params.items(): 432 | if getattr(args, name) is None: 433 | setattr(args, name, val) 434 | 435 | return args 436 | -------------------------------------------------------------------------------- /src/training/precision.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from contextlib import suppress 3 | 4 | 5 | def get_autocast(precision): 6 | if precision == 'amp': 7 | return torch.cuda.amp.autocast 8 | elif precision == 'amp_bfloat16' or precision == 'amp_bf16': 9 | # amp_bfloat16 is more stable than amp float16 for clip training 10 | return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16) 11 | else: 12 | return suppress 13 | -------------------------------------------------------------------------------- /src/training/profile.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import open_clip 5 | import pandas as pd 6 | from fvcore.nn import FlopCountAnalysis, flop_count_str, ActivationCountAnalysis 7 | 8 | 9 | parser = argparse.ArgumentParser(description='OpenCLIP Profiler') 10 | 11 | # benchmark specific args 12 | parser.add_argument('--model', metavar='NAME', default='', 13 | help='model(s) to profile') 14 | parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', 15 | help='Output csv file for results') 16 | 17 | 18 | def profile_fvcore( 19 | model, 20 | image_input_size=(3, 224, 224), 21 | text_input_size=(77,), 22 | batch_size=1, 23 | detailed=False, 24 | force_cpu=False 25 | ): 26 | if force_cpu: 27 | model = model.to('cpu') 28 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype 29 | example_image_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype) 30 | example_text_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64) 31 | fca = FlopCountAnalysis(model, (example_image_input, example_text_input)) 32 | aca = ActivationCountAnalysis(model, (example_image_input, example_text_input)) 33 | if detailed: 34 | fcs = flop_count_str(fca) 35 | print(fcs) 36 | return fca.total(), aca.total() 37 | 38 | 39 | def profile_fvcore_text( 40 | model, 41 | text_input_size=(77,), 42 | batch_size=1, 43 | detailed=False, 44
| force_cpu=False 45 | ): 46 | if force_cpu: 47 | model = model.to('cpu') 48 | device = next(model.parameters()).device 49 | example_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64) 50 | fca = FlopCountAnalysis(model, example_input) 51 | aca = ActivationCountAnalysis(model, example_input) 52 | if detailed: 53 | fcs = flop_count_str(fca) 54 | print(fcs) 55 | return fca.total(), aca.total() 56 | 57 | 58 | def profile_fvcore_image( 59 | model, 60 | image_input_size=(3, 224, 224), 61 | batch_size=1, 62 | detailed=False, 63 | force_cpu=False 64 | ): 65 | if force_cpu: 66 | model = model.to('cpu') 67 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype 68 | example_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype) 69 | fca = FlopCountAnalysis(model, example_input) 70 | aca = ActivationCountAnalysis(model, example_input) 71 | if detailed: 72 | fcs = flop_count_str(fca) 73 | print(fcs) 74 | return fca.total(), aca.total() 75 | 76 | 77 | def count_params(model): 78 | return sum([m.numel() for m in model.parameters()]) 79 | 80 | 81 | def profile_model(model_name): 82 | model = open_clip.create_model(model_name, force_custom_text=True, pretrained_hf=False) 83 | model.eval() 84 | if torch.cuda.is_available(): 85 | model = model.cuda() 86 | 87 | if isinstance(model.visual.image_size, (tuple, list)): 88 | image_input_size = (3,) + tuple(model.visual.image_size[-2:]) 89 | else: 90 | image_input_size = (3, model.visual.image_size, model.visual.image_size) 91 | text_input_size = (77,) 92 | 93 | results = {} 94 | results['model'] = model_name 95 | results['image_size'] = image_input_size[1] 96 | 97 | model_cfg = open_clip.get_model_config(model_name) 98 | if model_cfg: 99 | vision_cfg = open_clip.CLIPVisionCfg(**model_cfg['vision_cfg']) 100 | text_cfg = open_clip.CLIPTextCfg(**model_cfg['text_cfg']) 101 | results['image_width'] = int(vision_cfg.width) 102 | results['text_width'] = int(text_cfg.width) 103 | results['embed_dim'] = int(model_cfg['embed_dim']) 104 | else: 105 | results['image_width'] = 0 106 | results['text_width'] = 0 107 | results['embed_dim'] = 0 108 | 109 | retries = 2 110 | while retries: 111 | retries -= 1 112 | try: 113 | macs, acts = profile_fvcore( 114 | model, image_input_size=image_input_size, text_input_size=text_input_size, force_cpu=not retries) 115 | 116 | image_macs, image_acts = profile_fvcore_image( 117 | model.visual, image_input_size=image_input_size, force_cpu=not retries) 118 | 119 | text_macs, text_acts = profile_fvcore_text( 120 | model.text, text_input_size=text_input_size, force_cpu=not retries) 121 | 122 | results['gmacs'] = round(macs / 1e9, 2) 123 | results['macts'] = round(acts / 1e6, 2) 124 | results['mparams'] = round(count_params(model) / 1e6, 2) 125 | results['image_gmacs'] = round(image_macs / 1e9, 2) 126 | results['image_macts'] = round(image_acts / 1e6, 2) 127 | results['image_mparams'] = round(count_params(model.visual) / 1e6, 2) 128 | results['text_gmacs'] = round(text_macs / 1e9, 2) 129 | results['text_macts'] = round(text_acts / 1e6, 2) 130 | results['text_mparams'] = round(count_params(model.text) / 1e6, 2) 131 | except RuntimeError as e: 132 | pass 133 | return results 134 | 135 | 136 | def main(): 137 | args = parser.parse_args() 138 | 139 | # FIXME accept a text file name to allow lists of models in txt/csv 140 | if args.model == 'all': 141 | parsed_model = open_clip.list_models() 142 | else: 143 | parsed_model = args.model.split(',') 144 | 145 | results 
= [] 146 | for m in parsed_model: 147 | row = profile_model(m) 148 | results.append(row) 149 | 150 | df = pd.DataFrame(results, columns=results[0].keys()) 151 | df = df.sort_values('gmacs') 152 | print(df) 153 | if args.results_file: 154 | df.to_csv(args.results_file, index=False) 155 | 156 | 157 | if __name__ == '__main__': 158 | main() 159 | -------------------------------------------------------------------------------- /src/training/scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def assign_learning_rate(optimizer, new_lr): 5 | for param_group in optimizer.param_groups: 6 | param_group["lr"] = new_lr 7 | 8 | 9 | def _warmup_lr(base_lr, warmup_length, step): 10 | return base_lr * (step + 1) / warmup_length 11 | 12 | 13 | def const_lr(optimizer, base_lr, warmup_length, steps): 14 | def _lr_adjuster(step): 15 | if step < warmup_length: 16 | lr = _warmup_lr(base_lr, warmup_length, step) 17 | else: 18 | lr = base_lr 19 | assign_learning_rate(optimizer, lr) 20 | return lr 21 | return _lr_adjuster 22 | 23 | 24 | def const_lr_cooldown(optimizer, base_lr, warmup_length, steps, cooldown_steps, cooldown_power=1.0, cooldown_end_lr=0.): 25 | def _lr_adjuster(step): 26 | start_cooldown_step = steps - cooldown_steps 27 | if step < warmup_length: 28 | lr = _warmup_lr(base_lr, warmup_length, step) 29 | else: 30 | if step < start_cooldown_step: 31 | lr = base_lr 32 | else: 33 | e = step - start_cooldown_step 34 | es = steps - start_cooldown_step 35 | # linear decay if power == 1; polynomial decay otherwise; 36 | decay = (1 - (e/es)) ** cooldown_power 37 | lr = decay * (base_lr - cooldown_end_lr) + cooldown_end_lr 38 | assign_learning_rate(optimizer, lr) 39 | return lr 40 | return _lr_adjuster 41 | 42 | 43 | def cosine_lr(optimizer, base_lr, warmup_length, steps): 44 | def _lr_adjuster(step): 45 | if step < warmup_length: 46 | lr = _warmup_lr(base_lr, warmup_length, step) 47 | else: 48 | e = step - warmup_length 49 | es = steps - warmup_length 50 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr 51 | assign_learning_rate(optimizer, lr) 52 | return lr 53 | return _lr_adjuster 54 | -------------------------------------------------------------------------------- /src/training/train.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import math 4 | import os 5 | import time 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.nn.parallel.distributed import DistributedDataParallel 11 | 12 | try: 13 | import wandb 14 | except ImportError: 15 | wandb = None 16 | 17 | from open_clip import get_cast_dtype, CLIP, CustomTextCLIP 18 | from .distributed import is_master 19 | from .zero_shot import zero_shot_eval 20 | from .precision import get_autocast 21 | 22 | 23 | class AverageMeter(object): 24 | """Computes and stores the average and current value""" 25 | 26 | def __init__(self): 27 | self.reset() 28 | 29 | def reset(self): 30 | self.val = 0 31 | self.avg = 0 32 | self.sum = 0 33 | self.count = 0 34 | 35 | def update(self, val, n=1): 36 | self.val = val 37 | self.sum += val * n 38 | self.count += n 39 | self.avg = self.sum / self.count 40 | 41 | def postprocess_clip_output(model_out): 42 | return { 43 | "image_features": model_out[0], 44 | "text_features": model_out[1], 45 | "logit_scale": model_out[2] 46 | } 47 | 48 | def unwrap_model(model): 49 | if hasattr(model, 'module'): 50 | return model.module 51 | else: 52 | 
return model 53 | 54 | 55 | def backward(total_loss, scaler): 56 | if scaler is not None: 57 | scaler.scale(total_loss).backward() 58 | else: 59 | total_loss.backward() 60 | 61 | 62 | def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist_model, args, tb_writer=None): 63 | device = torch.device(args.device) 64 | autocast = get_autocast(args.precision) 65 | cast_dtype = get_cast_dtype(args.precision) 66 | 67 | 68 | model.train() 69 | if args.distill: 70 | dist_model.eval() 71 | 72 | data['train'].set_epoch(epoch) # set epoch in process safe manner via sampler or shared_epoch 73 | dataloader = data['train'].dataloader 74 | num_batches_per_epoch = dataloader.num_batches // args.accum_freq 75 | sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10)) 76 | 77 | if args.accum_freq > 1: 78 | accum_images, accum_texts, accum_features = [], [], {} 79 | 80 | losses_m = {} 81 | batch_time_m = AverageMeter() 82 | data_time_m = AverageMeter() 83 | end = time.time() 84 | for i, batch in enumerate(dataloader): 85 | i_accum = i // args.accum_freq 86 | step = num_batches_per_epoch * epoch + i_accum 87 | 88 | if not args.skip_scheduler: 89 | scheduler(step) 90 | 91 | images, texts = batch 92 | images = images.to(device=device, dtype=cast_dtype, non_blocking=True) 93 | texts = texts.to(device=device, non_blocking=True) 94 | 95 | data_time_m.update(time.time() - end) 96 | optimizer.zero_grad() 97 | 98 | if args.accum_freq == 1: 99 | with autocast(): 100 | model_out = model(images, texts) 101 | logit_scale = model_out["logit_scale"] 102 | if args.distill: 103 | with torch.no_grad(): 104 | dist_model_out = dist_model(images, texts) 105 | model_out.update({f'dist_{k}' : v for k, v in dist_model_out.items()}) 106 | losses = loss(**model_out, output_dict=True) 107 | 108 | total_loss = sum(losses.values()) 109 | losses["loss"] = total_loss 110 | 111 | backward(total_loss, scaler) 112 | else: 113 | # First, cache the features without any gradient tracking. 114 | with torch.no_grad(): 115 | with autocast(): 116 | model_out = model(images, texts) 117 | model_out.pop("logit_scale") 118 | for key, val in model_out.items(): 119 | if key in accum_features: 120 | accum_features[key].append(val) 121 | else: 122 | accum_features[key] = [val] 123 | 124 | accum_images.append(images) 125 | accum_texts.append(texts) 126 | 127 | # If (i + 1) % accum_freq is not zero, move on to the next batch. 128 | if ((i + 1) % args.accum_freq) > 0: 129 | # FIXME this makes data time logging unreliable when accumulating 130 | continue 131 | 132 | # Now, ready to take gradients for the last accum_freq batches. 133 | # Re-do the forward pass for those batches, and use the cached features from the other batches as negatives. 134 | # Call backwards each time, but only step optimizer at the end. 
135 | optimizer.zero_grad() 136 | for j in range(args.accum_freq): 137 | images = accum_images[j] 138 | texts = accum_texts[j] 139 | with autocast(): 140 | model_out = model(images, texts) 141 | logit_scale = model_out.pop("logit_scale") 142 | inputs = {} 143 | for key, val in accum_features.items(): 144 | accumulated = accum_features[key] 145 | inputs[key] = torch.cat(accumulated[:j] + [model_out[key]] + accumulated[j + 1:]) 146 | losses = loss(**inputs, logit_scale=logit_scale, output_dict=True) 147 | del inputs 148 | total_loss = sum(losses.values()) 149 | losses["loss"] = total_loss 150 | backward(total_loss, scaler) 151 | 152 | if scaler is not None: 153 | if args.horovod: 154 | optimizer.synchronize() 155 | scaler.unscale_(optimizer) 156 | if args.grad_clip_norm is not None: 157 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0) 158 | with optimizer.skip_synchronize(): 159 | scaler.step(optimizer) 160 | else: 161 | if args.grad_clip_norm is not None: 162 | scaler.unscale_(optimizer) 163 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0) 164 | scaler.step(optimizer) 165 | scaler.update() 166 | else: 167 | if args.grad_clip_norm is not None: 168 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0) 169 | optimizer.step() 170 | 171 | # reset gradient accum, if enabled 172 | if args.accum_freq > 1: 173 | accum_images, accum_texts, accum_features = [], [], {} 174 | 175 | # Note: we clamp to 4.6052 = ln(100), as in the original paper. 176 | with torch.no_grad(): 177 | unwrap_model(model).logit_scale.clamp_(0, math.log(100)) 178 | 179 | batch_time_m.update(time.time() - end) 180 | end = time.time() 181 | batch_count = i_accum + 1 182 | if is_master(args) and (i_accum % args.log_every_n_steps == 0 or batch_count == num_batches_per_epoch): 183 | batch_size = len(images) 184 | num_samples = batch_count * batch_size * args.accum_freq * args.world_size 185 | samples_per_epoch = dataloader.num_samples 186 | percent_complete = 100.0 * batch_count / num_batches_per_epoch 187 | 188 | # NOTE loss is coarsely sampled, just master node and per log update 189 | for key, val in losses.items(): 190 | if key not in losses_m: 191 | losses_m[key] = AverageMeter() 192 | losses_m[key].update(val.item(), batch_size) 193 | 194 | logit_scale_scalar = logit_scale.item() 195 | loss_log = " ".join( 196 | [ 197 | f"{loss_name.capitalize()}: {loss_m.val:#.5g} ({loss_m.avg:#.5g})" 198 | for loss_name, loss_m in losses_m.items() 199 | ] 200 | ) 201 | samples_per_second = args.accum_freq * args.batch_size * args.world_size / batch_time_m.val 202 | samples_per_second_per_gpu = args.accum_freq * args.batch_size / batch_time_m.val 203 | logging.info( 204 | f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " 205 | f"Data (t): {data_time_m.avg:.3f} " 206 | f"Batch (t): {batch_time_m.avg:.3f}, {samples_per_second:#g}/s, {samples_per_second_per_gpu:#g}/s/gpu " 207 | f"LR: {optimizer.param_groups[0]['lr']:5f} " 208 | f"Logit Scale: {logit_scale_scalar:.3f} " + loss_log 209 | ) 210 | 211 | # Save train loss / etc. 
Using non avg meter values as loggers have their own smoothing 212 | log_data = { 213 | "data_time": data_time_m.val, 214 | "batch_time": batch_time_m.val, 215 | "samples_per_second": samples_per_second, 216 | "samples_per_second_per_gpu": samples_per_second_per_gpu, 217 | "scale": logit_scale_scalar, 218 | "lr": optimizer.param_groups[0]["lr"] 219 | } 220 | log_data.update({name:val.val for name,val in losses_m.items()}) 221 | 222 | for name, val in log_data.items(): 223 | name = "train/" + name 224 | if tb_writer is not None: 225 | tb_writer.add_scalar(name, val, step) 226 | if args.wandb: 227 | assert wandb is not None, 'Please install wandb.' 228 | wandb.log({name: val, 'step': step}) 229 | 230 | # resetting batch / data time meters per log window 231 | batch_time_m.reset() 232 | data_time_m.reset() 233 | # end for 234 | 235 | 236 | def evaluate(model, data, epoch, args, tb_writer=None): 237 | metrics = {} 238 | if not is_master(args): 239 | return metrics 240 | device = torch.device(args.device) 241 | model.eval() 242 | 243 | zero_shot_metrics = zero_shot_eval(model, data, epoch, args) 244 | metrics.update(zero_shot_metrics) 245 | 246 | autocast = get_autocast(args.precision) 247 | cast_dtype = get_cast_dtype(args.precision) 248 | 249 | if 'val' in data and (args.val_frequency and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)): 250 | dataloader = data['val'].dataloader 251 | num_samples = 0 252 | samples_per_val = dataloader.num_samples 253 | 254 | # FIXME this does not scale past small eval datasets 255 | # all_image_features @ all_text_features will blow up memory and compute very quickly 256 | cumulative_loss = 0.0 257 | cumulative_gen_loss = 0.0 258 | all_image_features, all_text_features = [], [] 259 | with torch.no_grad(): 260 | for i, batch in enumerate(dataloader): 261 | images, texts = batch 262 | images = images.to(device=device, dtype=cast_dtype, non_blocking=True) 263 | texts = texts.to(device=device, non_blocking=True) 264 | 265 | with autocast(): 266 | model_out = model(images, texts) 267 | image_features = model_out["image_features"] 268 | text_features = model_out["text_features"] 269 | logit_scale = model_out["logit_scale"] 270 | # features are accumulated in CPU tensors, otherwise GPU memory exhausted quickly 271 | # however, system RAM is easily exceeded and compute time becomes problematic 272 | all_image_features.append(image_features.cpu()) 273 | all_text_features.append(text_features.cpu()) 274 | logit_scale = logit_scale.mean() 275 | logits_per_image = logit_scale * image_features @ text_features.t() 276 | logits_per_text = logits_per_image.t() 277 | 278 | batch_size = images.shape[0] 279 | labels = torch.arange(batch_size, device=device).long() 280 | total_loss = ( 281 | F.cross_entropy(logits_per_image, labels) + 282 | F.cross_entropy(logits_per_text, labels) 283 | ) / 2 284 | 285 | gen_loss = maybe_compute_generative_loss(model_out) 286 | 287 | cumulative_loss += total_loss * batch_size 288 | num_samples += batch_size 289 | if is_master(args) and (i % 100) == 0: 290 | logging.info( 291 | f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]\t" 292 | f"Clip Loss: {cumulative_loss / num_samples:.6f}\t") 293 | 294 | if gen_loss is not None: 295 | cumulative_gen_loss += gen_loss * batch_size 296 | logging.info( 297 | f"Generative Loss: {cumulative_gen_loss / num_samples:.6f}\t") 298 | 299 | val_metrics = get_clip_metrics( 300 | image_features=torch.cat(all_image_features), 301 | text_features=torch.cat(all_text_features), 302 | 
logit_scale=logit_scale.cpu(), 303 | ) 304 | loss = cumulative_loss / num_samples 305 | metrics.update( 306 | {**val_metrics, "clip_val_loss": loss.item(), "epoch": epoch, "num_samples": num_samples} 307 | ) 308 | if gen_loss is not None: 309 | gen_loss = cumulative_gen_loss / num_samples 310 | metrics.update({"val_generative_loss": gen_loss.item()}) 311 | 312 | if not metrics: 313 | return metrics 314 | 315 | logging.info( 316 | f"Eval Epoch: {epoch} " 317 | + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()]) 318 | ) 319 | 320 | if args.save_logs: 321 | for name, val in metrics.items(): 322 | if tb_writer is not None: 323 | tb_writer.add_scalar(f"val/{name}", val, epoch) 324 | 325 | with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f: 326 | f.write(json.dumps(metrics)) 327 | f.write("\n") 328 | 329 | if args.wandb: 330 | assert wandb is not None, 'Please install wandb.' 331 | for name, val in metrics.items(): 332 | wandb.log({f"val/{name}": val, 'epoch': epoch}) 333 | 334 | return metrics 335 | 336 | 337 | def get_clip_metrics(image_features, text_features, logit_scale): 338 | metrics = {} 339 | logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu() 340 | logits_per_text = logits_per_image.t().detach().cpu() 341 | 342 | logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text} 343 | ground_truth = torch.arange(len(text_features)).view(-1, 1) 344 | 345 | for name, logit in logits.items(): 346 | ranking = torch.argsort(logit, descending=True) 347 | preds = torch.where(ranking == ground_truth)[1] 348 | preds = preds.detach().cpu().numpy() 349 | metrics[f"{name}_mean_rank"] = preds.mean() + 1 350 | metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1 351 | for k in [1, 5, 10]: 352 | metrics[f"{name}_R@{k}"] = np.mean(preds < k) 353 | 354 | return metrics 355 | 356 | 357 | def maybe_compute_generative_loss(model_out): 358 | if "logits" in model_out and "labels" in model_out: 359 | token_logits = model_out["logits"] 360 | token_labels = model_out["labels"] 361 | return F.cross_entropy(token_logits.permute(0, 2, 1), token_labels) 362 | -------------------------------------------------------------------------------- /src/training/zero_shot.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from tqdm import tqdm 6 | 7 | from open_clip import get_cast_dtype, get_tokenizer 8 | from .precision import get_autocast 9 | from .imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template 10 | 11 | 12 | def zero_shot_classifier(model, classnames, templates, args): 13 | tokenizer = get_tokenizer(args.model) 14 | with torch.no_grad(): 15 | zeroshot_weights = [] 16 | for classname in tqdm(classnames): 17 | texts = [template(classname) for template in templates] # format with class 18 | texts = tokenizer(texts).to(args.device) # tokenize 19 | if args.distributed and not args.horovod: 20 | class_embeddings = model.module.encode_text(texts) 21 | else: 22 | class_embeddings = model.encode_text(texts) 23 | class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) 24 | class_embedding /= class_embedding.norm() 25 | zeroshot_weights.append(class_embedding) 26 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(args.device) 27 | return zeroshot_weights 28 | 29 | 30 | def accuracy(output, target, topk=(1,)): 31 | pred = output.topk(max(topk), 1, True, True)[1].t() 32 | 
correct = pred.eq(target.view(1, -1).expand_as(pred)) 33 | return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk] 34 | 35 | 36 | def run(model, classifier, dataloader, args): 37 | autocast = get_autocast(args.precision) 38 | cast_dtype = get_cast_dtype(args.precision) 39 | with torch.no_grad(): 40 | top1, top5, n = 0., 0., 0. 41 | for images, target in tqdm(dataloader, unit_scale=args.batch_size): 42 | images = images.to(args.device) 43 | if cast_dtype is not None: 44 | images = images.to(dtype=cast_dtype) 45 | target = target.to(args.device) 46 | 47 | with autocast(): 48 | # predict 49 | if args.distributed and not args.horovod: 50 | image_features = model.module.encode_image(images) 51 | else: 52 | image_features = model.encode_image(images) 53 | image_features = F.normalize(image_features, dim=-1) 54 | logits = 100. * image_features @ classifier 55 | 56 | # measure accuracy 57 | acc1, acc5 = accuracy(logits, target, topk=(1, 5)) 58 | top1 += acc1 59 | top5 += acc5 60 | n += images.size(0) 61 | 62 | top1 = (top1 / n) 63 | top5 = (top5 / n) 64 | return top1, top5 65 | 66 | 67 | def zero_shot_eval(model, data, epoch, args): 68 | if 'imagenet-val' not in data and 'imagenet-v2' not in data: 69 | return {} 70 | if args.zeroshot_frequency == 0: 71 | return {} 72 | if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs: 73 | return {} 74 | 75 | logging.info('Starting zero-shot imagenet.') 76 | 77 | logging.info('Building zero-shot classifier') 78 | classifier = zero_shot_classifier(model, imagenet_classnames, openai_imagenet_template, args) 79 | 80 | logging.info('Using classifier') 81 | results = {} 82 | if 'imagenet-val' in data: 83 | top1, top5 = run(model, classifier, data['imagenet-val'].dataloader, args) 84 | results['imagenet-zeroshot-val-top1'] = top1 85 | results['imagenet-zeroshot-val-top5'] = top5 86 | if 'imagenet-v2' in data: 87 | top1, top5 = run(model, classifier, data['imagenet-v2'].dataloader, args) 88 | results['imagenetv2-zeroshot-val-top1'] = top1 89 | results['imagenetv2-zeroshot-val-top5'] = top5 90 | 91 | logging.info('Finished zero-shot imagenet.') 92 | 93 | return results 94 | -------------------------------------------------------------------------------- /zero_shot.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | few_shots=(0) 4 | 5 | for few_num in "${!few_shots[@]}";do 6 | ## zero-shot evaluation on the MVTec AD dataset 7 | base_dir=winclip_mvtec 8 | save_dir=./exps_${base_dir}/mvtecvit_large_14_518/ 9 | 10 | CUDA_VISIBLE_DEVICES=3 python reproduce_WinCLIP.py --dataset mvtec \ 11 | --data_path /remote-home/iot_zhouqihang/data/mvdataset --save_path ./results/mvtec_${base_dir}/few_shot_${few_shots[few_num]} \ 12 | --model ViT-B-16-plus-240 --pretrained openai --k_shot ${few_shots[few_num]} --image_size 240 13 | wait 14 | done 15 | 16 | 17 | for few_num in "${!few_shots[@]}";do 18 | ## zero-shot evaluation on the VisA dataset 19 | base_dir=winclip_visa 20 | save_dir=./exps_${base_dir}/mvtecvit_large_14_518/ 21 | 22 | 23 | CUDA_VISIBLE_DEVICES=3 python reproduce_WinCLIP.py --dataset visa \ 24 | --data_path /remote-home/iot_zhouqihang/data/Visa --save_path ./results/mvtec_${base_dir}/few_shot_${few_shots[few_num]} \ 25 | --model ViT-B-16-plus-240 --pretrained openai --k_shot ${few_shots[few_num]} --image_size 240 26 | wait 27 | done 28 | --------------------------------------------------------------------------------