├── DIVA_for_DFN.sh ├── DIVA_for_MetaCLIP.sh ├── DIVA_for_OpenAICLIP.sh ├── DIVA_for_SigLIP.sh ├── LICENSE ├── README.md ├── accelerator.json ├── arguments.py ├── assets ├── introduction.png ├── methodology.png └── qualitative_mmvp.png ├── callbacks.py ├── condition ├── DFN_for_openclip_transformer.py ├── MetaCLIP_for_openclip_transformer.py ├── OpenAICLIP_for_clip_model.py └── SigLIP_for_timm_models_visiontransformer.py ├── config.py ├── data ├── __init__.py ├── constants.py ├── data.py ├── image_data.py ├── register.py └── transform.py ├── models ├── CLIP_bank.py ├── SD_with_DFN.py ├── SD_with_MetaCLIP.py ├── SD_with_OpenAICLIP.py ├── SD_with_SigLIP.py ├── build.py └── utils.py ├── requirements.txt ├── run_DIVA_with_DFN.py ├── run_DIVA_with_MetaCLIP.py ├── run_DIVA_with_OpenAICLIP.py ├── run_DIVA_with_SigLIP.py └── trainer.py /DIVA_for_DFN.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | 3 | set -x 4 | 5 | export TORCH_DISTRIBUTED_DEBUG=INFO 6 | export NCCL_DEBUG=INFO 7 | export OMP_NUM_THREADS=4 8 | export NCCL_P2P_DISABLE=1 9 | MASTER_ADDR=$(hostname -I | awk '{print $1}') 10 | MASTER_PORT=12345 11 | 12 | run_name=DIVA_for_DFN 13 | num_steps=4600 14 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 \ 15 | --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} --use_env \ 16 | run_DIVA_with_DFN.py \ 17 | --clip_image_size 224 \ 18 | --visual_pattern None \ 19 | --train_steps 2 \ 20 | --image_size 512 \ 21 | --fixed_image_size False \ 22 | --dataset_path dataset/cc3m/\*.tar \ 23 | --output_dir ./outputs/$run_name \ 24 | --remove_unused_columns False \ 25 | --do_train \ 26 | --ddp_find_unused_parameters True \ 27 | --dataloader_num_workers 8 \ 28 | --learning_rate 1e-4 \ 29 | --bf16 True \ 30 | --tf32 True \ 31 | --warmup_ratio 0.005 \ 32 | --weight_decay 0 \ 33 | --max_steps $num_steps \ 34 | --per_device_train_batch_size 16 \ 35 | --logging_strategy steps \ 36 | --logging_steps 50 \ 37 | --gradient_accumulation_steps 5 \ 38 | --save_strategy steps \ 39 | --save_steps $num_steps \ 40 | --save_total_limit 1 \ 41 | --ddp_backend nccl \ 42 | --report_to wandb \ 43 | --run_name $run_name \ 44 | --enable_flash True \ 45 | --lr_scheduler_type "cosine" \ 46 | --seed 42 \ 47 | --accelerator_config accelerator.json > ./logs/debug_$run_name.log 48 | -------------------------------------------------------------------------------- /DIVA_for_MetaCLIP.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | 3 | set -x 4 | 5 | export TORCH_DISTRIBUTED_DEBUG=INFO 6 | export NCCL_DEBUG=INFO 7 | export OMP_NUM_THREADS=4 8 | export NCCL_P2P_DISABLE=1 9 | MASTER_ADDR=$(hostname -I | awk '{print $1}') 10 | MASTER_PORT=12345 11 | 12 | run_name=DIVA_for_MetaCLIP 13 | num_steps=4600 14 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 \ 15 | --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} --use_env \ 16 | run_DIVA_with_MetaCLIP.py \ 17 | --metaclip_version large \ 18 | --clip_image_size 224 \ 19 | --visual_pattern None \ 20 | --train_steps 2 \ 21 | --image_size 512 \ 22 | --fixed_image_size False \ 23 | --dataset_path dataset/cc3m/\*.tar \ 24 | --output_dir ./outputs/$run_name \ 25 | --remove_unused_columns False \ 26 | --do_train \ 27 | --ddp_find_unused_parameters True \ 28 | --dataloader_num_workers 8 \ 29 | --learning_rate 1e-4 \ 30 | --bf16 True \ 31 | --tf32 True \ 32 | --warmup_ratio 0.005 \ 33 | --weight_decay 0 \ 34 | 
--max_steps $num_steps \ 35 | --per_device_train_batch_size 16 \ 36 | --logging_strategy steps \ 37 | --logging_steps 50 \ 38 | --gradient_accumulation_steps 5 \ 39 | --save_strategy steps \ 40 | --save_steps $num_steps \ 41 | --save_total_limit 1 \ 42 | --ddp_backend nccl \ 43 | --report_to wandb \ 44 | --run_name $run_name \ 45 | --enable_flash True \ 46 | --lr_scheduler_type "cosine" \ 47 | --seed 42 \ 48 | --accelerator_config accelerator.json > ./logs/debug_$run_name.log 49 | -------------------------------------------------------------------------------- /DIVA_for_OpenAICLIP.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | 3 | set -x 4 | 5 | export TORCH_DISTRIBUTED_DEBUG=INFO 6 | export NCCL_DEBUG=INFO 7 | export OMP_NUM_THREADS=4 8 | export NCCL_P2P_DISABLE=1 9 | MASTER_ADDR=$(hostname -I | awk '{print $1}') 10 | MASTER_PORT=12345 11 | 12 | run_name=DIVA_for_OpenAICLIP 13 | num_steps=4600 14 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 \ 15 | --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} --use_env \ 16 | run_DIVA_with_OpenAICLIP.py \ 17 | --clip_image_size 224 \ 18 | --visual_pattern None \ 19 | --train_steps 2 \ 20 | --image_size 512 \ 21 | --fixed_image_size False \ 22 | --dataset_path dataset/cc3m/\*.tar \ 23 | --output_dir ./outputs/$run_name \ 24 | --remove_unused_columns False \ 25 | --do_train \ 26 | --ddp_find_unused_parameters True \ 27 | --dataloader_num_workers 8 \ 28 | --learning_rate 1e-4 \ 29 | --bf16 True \ 30 | --tf32 True \ 31 | --warmup_ratio 0.005 \ 32 | --weight_decay 0 \ 33 | --max_steps $num_steps \ 34 | --per_device_train_batch_size 16 \ 35 | --logging_strategy steps \ 36 | --logging_steps 50 \ 37 | --gradient_accumulation_steps 5 \ 38 | --save_strategy steps \ 39 | --save_steps $num_steps \ 40 | --save_total_limit 1 \ 41 | --ddp_backend nccl \ 42 | --report_to wandb \ 43 | --run_name $run_name \ 44 | --enable_flash True \ 45 | --lr_scheduler_type "cosine" \ 46 | --seed 42 \ 47 | --accelerator_config accelerator.json > ./logs/debug_$run_name.log 48 | -------------------------------------------------------------------------------- /DIVA_for_SigLIP.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | 3 | set -x 4 | 5 | export TORCH_DISTRIBUTED_DEBUG=INFO 6 | export NCCL_DEBUG=INFO 7 | export OMP_NUM_THREADS=4 8 | export NCCL_P2P_DISABLE=1 9 | MASTER_ADDR=$(hostname -I | awk '{print $1}') 10 | MASTER_PORT=12345 11 | 12 | run_name=DIVA_for_SigLIP 13 | num_steps=4600 14 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 \ 15 | --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} --use_env \ 16 | run_DIVA_with_SigLIP.py \ 17 | --clip_image_size 224 \ 18 | --visual_pattern None \ 19 | --train_steps 2 \ 20 | --image_size 512 \ 21 | --fixed_image_size False \ 22 | --dataset_path dataset/cc3m/\*.tar \ 23 | --output_dir ./outputs/$run_name \ 24 | --remove_unused_columns False \ 25 | --do_train \ 26 | --ddp_find_unused_parameters True \ 27 | --dataloader_num_workers 8 \ 28 | --learning_rate 1e-4 \ 29 | --bf16 True \ 30 | --tf32 True \ 31 | --warmup_ratio 0.005 \ 32 | --weight_decay 0 \ 33 | --max_steps $num_steps \ 34 | --per_device_train_batch_size 16 \ 35 | --logging_strategy steps \ 36 | --logging_steps 50 \ 37 | --gradient_accumulation_steps 5 \ 38 | --save_strategy steps \ 39 | --save_steps $num_steps \ 40 | --save_total_limit 1 \ 41 | --ddp_backend nccl \ 42 | 
--report_to wandb \ 43 | --run_name $run_name \ 44 | --enable_flash True \ 45 | --lr_scheduler_type "cosine" \ 46 | --seed 42 \ 47 | --accelerator_config accelerator.json > ./logs/debug_$run_name.log 48 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 BAAI-Vision 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |

# Diffusion Feedback Helps CLIP See Better

4 | 5 | [Wenxuan Wang](https://scholar.google.com/citations?user=75OyC-oAAAAJ&hl=zh-CN)1,2,3*, [Quan Sun](https://scholar.google.cz/citations?user=pVKiHdEAAAAJ&hl=zh-CN&oi=ao)3*, [Fan Zhang](https://scholar.google.cz/citations?hl=zh-CN&user=VsJ39HMAAAAJ&view_op=list_works&sortby=pubdate)3, [Yepeng Tang](https://scholar.google.cz/citations?user=CAC_4OUAAAAJ&hl=zh-CN&oi=ao)4, [Jing Liu](https://scholar.google.com/citations?user=sOI-S7oAAAAJ&hl=zh-CN)1,2, [Xinlong Wang](https://scholar.google.com/citations?hl=zh-CN&user=DPz0DjYAAAAJ&view_op=list_works&sortby=pubdate/)3 6 | 7 | 1[CASIA](http://english.ia.cas.cn/), 2[UCAS](https://english.ucas.ac.cn/), 3[BAAI](https://www.baai.ac.cn/english.html), 4[BJTU](https://en.bjtu.edu.cn/)
* Equal Contribution
8 | 9 | 10 |
11 | 12 | 13 | ## ⏰ Schedule 14 | 15 | ### [2025-01-23] Our [paper](https://arxiv.org/abs/2407.20171) has been accepted by ICLR 2025! 💥 16 | ### [2024-08-07] We release the [CLIP model weights](https://huggingface.co/BAAI/DIVA)! 💥 17 | ### [2024-08-05] We release the [training & evaluation code](https://github.com/baaivision/DIVA)! 💥 18 | ### [2024-07-30] Our [paper](https://arxiv.org/abs/2407.20171) is released on arXiv! 💥 19 | 20 | 21 | ## 💡 Motivation 23 |

24 | ![overview](assets/introduction.png) 25 |

26 | 27 | In this work, we present a simple post-training approach for CLIP models, which largely overcomes its visual shortcomings via a self-supervised diffusion process. We introduce DIVA, which uses the DIffusion model as a Visual Assistant for CLIP. Specifically, DIVA leverages generative feedback from text-to-image diffusion models to optimize CLIP representations, with only images (w/o corresponding text). We demonstrate that DIVA improves CLIP's performance on the challenging MMVP-VLM benchmark which assesses fine-grained visual abilities to a large extent (e.g., 3-7% ↑), and enhances the performance of MLLMs and vision models on multimodal understanding and segmentation tasks. Extensive evaluation on 29 image classification and retrieval benchmarks confirms that DIVA preserves CLIP's strong zero-shot capabilities. 28 | 29 | 30 | ## 🤖 Architecture 31 | 32 |

33 | ![overview](assets/methodology.png) 34 |

35 | 36 | Given an image, the CLIP model encodes the visual features as the main part of condition, then the generative diffusion model predicts the added noise taking the noisy image and condition as input. We optimize the CLIP's representation by maximizing the image likelihood with the diffusion loss via generative feedback. 37 | 38 | 39 | ## 🔨 Installation 40 | Clone this repository and install the required packages: 41 | 42 | ```shell 43 | git clone https://github.com/baaivision/DIVA.git 44 | cd DIVA 45 | mkdir -p outputs logs datasets pretrained_weights/CLIP pretrained_weights/SD 46 | 47 | conda create -n diva python=3.9 48 | conda activate diva 49 | pip install -r requirements.txt 50 | ``` 51 | Core packages: 52 | - [Pytorch](https://pytorch.org/) version 2.0.0 53 | - [open-clip-torch](https://github.com/mlfoundations/open_clip) version 2.24.0 54 | - [timm](https://github.com/rwightman/pytorch-image-models) version 0.9.8 55 | 56 | 57 | ## 🍹 Preparation for DIVA's Generative Fine-tuning 58 | 59 | ### Data Acquisition 60 | For data preparation, please refer to [image2dataset](https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc3m.md) and [MMVP](https://github.com/tsb0601/MMVP/tree/main) for the employed training and evaluation data in this work. After collecting the corresponding datasets, directly put them into the `dataset/` folder path. 61 | 62 | ### Pre-trained Weight Downloading 63 | As for pre-trained weight preparation, please refer to [OpenAI ViT-L-14/224&336](https://github.com/openai/CLIP/blob/main/clip/clip.py), [MetaCLIP ViT-L/H-14](https://github.com/facebookresearch/metaclip), [SigLIP ViT-SO-14/224](https://huggingface.co/timm/ViT-SO400M-14-SigLIP), [SigLIP ViT-SO-14/384](https://huggingface.co/timm/ViT-SO400M-14-SigLIP-384), [DFN ViT-H-14/224](https://huggingface.co/apple/DFN5B-CLIP-ViT-H-14), [DFN ViT-H-14/378](https://huggingface.co/apple/DFN5B-CLIP-ViT-H-14-378) and [SD-2-1-base](https://huggingface.co/stabilityai/stable-diffusion-2-1-base) to acquire the model weights for discriminative CLIP models and the leveraged diffusion model that provides generative feedback. After downloading all these necessary weights, move them respectively to the corresponding folder path `pretrained_weights/CLIP/` and `pretrained_weights/SD/`. 64 | 65 | ### Code Modification 66 | For the preparation for our DIVA's condition design, some source code in the installed [CLIP](https://github.com/openai/CLIP) and [OpenCLIP](https://github.com/mlfoundations/open_clip) packages need to be modified. 67 | 68 | For OpenAI CLIP, use the content in our provided `condition/OpenAICLIP_for_clip_model.py` to replace the content in `Your Conda Installation Path/anaconda3/envs/diva/lib/python3.9/site-packages/clip/model.py`. 69 | 70 | For MetaCLIP and DFN, use the content in our provided `condition/MetaCLIP_for_openclip_transformer.py` and `condition/DFN_for_openclip_transformer.py` to replace the content in `Your Conda Installation Path/anaconda3/envs/diva/lib/python3.9/site-packages/open_clip/transformer.py`, respectively. 71 | 72 | For SigLIP, use the content in our provided `condition/SigLIP_for_timm_models_visiontransformer.py` to replace the content in `Your Conda Installation Path/anaconda3/envs/diva/lib/python3.9/site-packages/timm/models/vision_transformer.py`. 
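
If you would rather not hard-code the conda installation path, the short script below is one way to apply all three replacements programmatically. It is a convenience sketch rather than part of the official setup: it assumes the `diva` environment is active (so `clip`, `open_clip`, and `timm` resolve to the packages installed above) and that it is run from the repository root, and it backs up each original file before overwriting it.

```python
# Sketch: overwrite the installed package files with DIVA's condition variants.
import os
import shutil

import clip
import open_clip
import timm

replacements = {
    "condition/OpenAICLIP_for_clip_model.py":
        os.path.join(os.path.dirname(clip.__file__), "model.py"),
    # For DFN, copy condition/DFN_for_openclip_transformer.py to this same target instead.
    "condition/MetaCLIP_for_openclip_transformer.py":
        os.path.join(os.path.dirname(open_clip.__file__), "transformer.py"),
    "condition/SigLIP_for_timm_models_visiontransformer.py":
        os.path.join(os.path.dirname(timm.__file__), "models", "vision_transformer.py"),
}

for src, dst in replacements.items():
    shutil.copyfile(dst, dst + ".bak")  # keep a backup of the original package file
    shutil.copyfile(src, dst)
    print(f"replaced {dst}")
```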
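
For orientation before launching training, here is a minimal sketch of the generative-feedback objective described in the Architecture section above, written against the diffusers API and the SD-2-1-base weights prepared earlier. The local weight path, the `clip_visual` and `diffusion_feedback_loss` names, and the assumption that the CLIP condition tokens already match the UNet's cross-attention width are illustrative only; the repository's actual conditioning and training logic lives in `models/SD_with_*.py` and `trainer.py`.

```python
import torch
import torch.nn.functional as F
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel

sd_path = "pretrained_weights/SD/stable-diffusion-2-1-base"  # assumed local weight layout
vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae").requires_grad_(False)
unet = UNet2DConditionModel.from_pretrained(sd_path, subfolder="unet").requires_grad_(False)
scheduler = DDPMScheduler.from_pretrained(sd_path, subfolder="scheduler")

def diffusion_feedback_loss(clip_visual, clip_images, sd_images):
    """Diffusion loss used as generative feedback; only `clip_visual` receives gradients."""
    cond = clip_visual(clip_images)            # [B, N, D] visual tokens used as the condition
    latents = vae.encode(sd_images).latent_dist.sample() * vae.config.scaling_factor
    noise = torch.randn_like(latents)
    t = torch.randint(0, scheduler.config.num_train_timesteps,
                      (latents.shape[0],), device=latents.device)
    noisy_latents = scheduler.add_noise(latents, noise, t)
    noise_pred = unet(noisy_latents, t, encoder_hidden_states=cond).sample
    return F.mse_loss(noise_pred, noise)       # minimized with respect to the CLIP weights only
```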
73 | 74 | 75 | ## 🍻 Quick Start for Training & Evaluation 76 | 77 | After all the above preparation steps, you can simply start training for our DIVA with the following command: 78 | ```shell 79 | # For OpenAICLIP 80 | bash DIVA_for_OpenAICLIP.sh 81 | 82 | # For MetaCLIP 83 | bash DIVA_for_MetaCLIP.sh 84 | 85 | # For SigLIP 86 | bash DIVA_for_SigLIP.sh 87 | 88 | # For DFN 89 | bash DIVA_for_DFN.sh 90 | ``` 91 | 92 | ## Model Zoo 93 | 94 | | Method | Image Size | Params (M) | Average Score | 95 | |----------------------|------------|------------|---------------| 96 | | [OpenAI ViT-L-14](https://huggingface.co/BAAI/DIVA/blob/main/OpenAICLIP/OpenAI-ViT-L-14-224.pth) | 224² | 427.6 | 25.9 (+6.6) | 97 | | [OpenAI ViT-L-14](https://huggingface.co/BAAI/DIVA/blob/main/OpenAICLIP/OpenAI-ViT-L-14-336.pth) | 336² | 427.9 | 25.2 (+5.2) | 98 | | [MetaCLIP ViT-L-14](https://huggingface.co/BAAI/DIVA/blob/main/MetaCLIP/MetaCLIP-ViT-L-14.pth) | 224² | 427.6 | 27.4 (+3.7) | 99 | | [MetaCLIP ViT-H-14](https://huggingface.co/BAAI/DIVA/blob/main/MetaCLIP/MetaCLIP-ViT-H-14.pth) | 224² | 986.1 | 31.9 (+6.7) | 100 | | [SigLIP ViT-SO-14](https://huggingface.co/BAAI/DIVA/blob/main/SigLIP/SigLIP-ViT-SO-14-224.pth) | 224² | 877.4 | 40.7 (+2.9) | 101 | | [SigLIP ViT-SO-14](https://huggingface.co/BAAI/DIVA/blob/main/SigLIP/SigLIP-ViT-SO-14-384.pth) | 384² | 878.0 | 38.5 (+1.5) | 102 | | [DFN ViT-H-14](https://huggingface.co/BAAI/DIVA/blob/main/DFN/DFN-ViT-H-14-224.pth) | 224² | 986.1 | 43.7 (+4.4) | 103 | | [DFN ViT-H-14](https://huggingface.co/BAAI/DIVA/blob/main/DFN/DFN-ViT-H-14-378.pth) | 378² | 986.7 | 37.8 (+3.0) | 104 | 105 | 106 | It is worth noting that, due to the randomness among the introduced condition design during the training phase and the selection of local patch tokens during the inference phase for OpenAI CLIP, the obtained scores on MMVP_VLM benchmark using our provided OpenAI CLIP weights might not be the same as the reported results in our paper. At this time, we recommend trying different random seeds multiple times if the scores do not meet expectations. 107 | 108 | ## 🎨 Visualization 109 | 110 |

111 | ![scene](assets/qualitative_mmvp.png) 112 |

113 | 114 | 115 | ## 💙 Acknowledgement 116 | DIVA is built upon the awesome [Diffusion-TTA](https://github.com/mihirp1998/Diffusion-TTA), [MMVP](https://github.com/tsb0601/MMVP), [CLIP](https://github.com/openai/CLIP), [OpenCLIP](https://github.com/mlfoundations/open_clip), [timm](https://github.com/huggingface/pytorch-image-models/). 117 | 118 | ## 📝 Citation 119 | ```bib 120 | @article{wang2024diffusion, 121 | title={Diffusion Feedback Helps CLIP See Better}, 122 | author={Wang, Wenxuan and Sun, Quan and Zhang, Fan and Tang, Yepeng and Liu, Jing and Wang, Xinlong}, 123 | journal={arXiv preprint arXiv:2407.20171}, 124 | year={2024} 125 | } 126 | ``` 127 | -------------------------------------------------------------------------------- /accelerator.json: -------------------------------------------------------------------------------- 1 | { 2 | "dispatch_batches": false 3 | } -------------------------------------------------------------------------------- /arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional, List 3 | import transformers 4 | 5 | @dataclass 6 | class DataTrainingArguments: 7 | 8 | dataset_path: str = "dataset_path" 9 | 10 | arbitrary_resolution: Optional[bool] = field( 11 | default=False, 12 | metadata={ 13 | "help": "If true, images will have arbitrary resolutions." 14 | }, 15 | ) 16 | 17 | max_train_samples: Optional[int] = field( 18 | default=None, 19 | metadata={ 20 | "help": ( 21 | "For debugging purposes or quicker training, truncate the number of training examples to this " 22 | "value if set." 23 | ) 24 | }, 25 | ) 26 | max_eval_samples: Optional[int] = field( 27 | default=None, 28 | metadata={ 29 | "help": ( 30 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 31 | "value if set." 32 | ) 33 | }, 34 | ) 35 | 36 | one_minus_one_data_transform: Optional[bool] = field( 37 | default=False, 38 | metadata={ 39 | "help": "If true, the data will be scaled to [-1, 1] instead of [0, 1]." 40 | }, 41 | ) 42 | 43 | 44 | @dataclass 45 | class ModelArguments: 46 | """ 47 | Arguments pertaining to which model/config/image processor we are going to pre-train. 48 | """ 49 | 50 | model_type: Optional[str] = field( 51 | default=None, 52 | metadata={"help": "If training from scratch, pass a model type from the list: "}, 53 | ) 54 | config_overrides: Optional[str] = field( 55 | default=None, 56 | metadata={ 57 | "help": ( 58 | "Override some existing default config settings when a model is trained from scratch. Example: " 59 | "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index" 60 | ) 61 | }, 62 | ) 63 | image_size: Optional[int] = field( 64 | default=None, 65 | metadata={ 66 | "help": ( 67 | "The size (resolution) of each image. If not specified, will use `image_size` of the configuration." 68 | ) 69 | }, 70 | ) 71 | 72 | fixed_image_size: Optional[bool] = field( 73 | default=True, 74 | ) 75 | 76 | patch_size: Optional[int] = field( 77 | default=None, 78 | metadata={ 79 | "help": ( 80 | "The size (resolution) of each patch. If not specified, will use `patch_size` of the configuration." 81 | ) 82 | }, 83 | ) 84 | 85 | tublet_size: Optional[List[int]] = field( 86 | default_factory=lambda: [2, 16, 16], 87 | metadata={ 88 | "help": ( 89 | "The size of each tubelet (3D patch size). If not specified, will use `tubelet_size` of the configuration." 
90 | ) 91 | }, 92 | ) 93 | 94 | cost_gradient_penalty: Optional[float] = field( 95 | default=0, # 0.2 96 | ) 97 | 98 | enable_flash: Optional[bool] = field( 99 | default=False, 100 | ) 101 | 102 | @dataclass 103 | class TrainingArguments(transformers.TrainingArguments): 104 | 105 | multiple_optimizer_training: Optional[float] = field( default=False, metadata={ "help": "will become true if `gan_loss_weight` in `model_args` is set to allow multiple optimizers" } ) 106 | 107 | wandb_api_key: Optional[str] = field( 108 | default=None, 109 | metadata={ 110 | "help": "wandb api key" 111 | } 112 | ) 113 | 114 | train_steps: Optional[int] = field(default=1,) 115 | 116 | visual_pattern: Optional[str] = field(default=None,) 117 | 118 | clip_image_size: Optional[int] = field(default=224,) 119 | 120 | metaclip_version: Optional[str] = field(default=None,) -------------------------------------------------------------------------------- /assets/introduction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baaivision/DIVA/dee07edc8ac299757e310bf84c31cc82a8ccc44c/assets/introduction.png -------------------------------------------------------------------------------- /assets/methodology.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baaivision/DIVA/dee07edc8ac299757e310bf84c31cc82a8ccc44c/assets/methodology.png -------------------------------------------------------------------------------- /assets/qualitative_mmvp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baaivision/DIVA/dee07edc8ac299757e310bf84c31cc82a8ccc44c/assets/qualitative_mmvp.png -------------------------------------------------------------------------------- /callbacks.py: -------------------------------------------------------------------------------- 1 | import os 2 | from transformers.training_args import TrainingArguments 3 | from transformers.utils import logging, is_torch_tpu_available 4 | logger = logging.get_logger(__name__) 5 | 6 | from transformers.integrations import WandbCallback 7 | from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState 8 | 9 | class MyWandbCallback(WandbCallback): 10 | def setup(self, args, state, model, **kwargs): 11 | """ 12 | Setup the optional Weights & Biases (*wandb*) integration. 13 | 14 | One can subclass and override this method to customize the setup if needed. Find more information 15 | [here](https://docs.wandb.ai/guides/integrations/huggingface). You can also override the following environment 16 | variables: 17 | 18 | Environment: 19 | - **WANDB_LOG_MODEL** (`str`, *optional*, defaults to `"false"`): 20 | Whether to log model and checkpoints during training. Can be `"end"`, `"checkpoint"` or `"false"`. If set 21 | to `"end"`, the model will be uploaded at the end of training. If set to `"checkpoint"`, the checkpoint 22 | will be uploaded every `args.save_steps` . If set to `"false"`, the model will not be uploaded. Use along 23 | with [`~transformers.TrainingArguments.load_best_model_at_end`] to upload best model. 24 | 25 | 26 | 27 | Setting `WANDB_LOG_MODEL` as `bool` will be deprecated in version 5 of 🤗 Transformers. 28 | 29 | 30 | - **WANDB_WATCH** (`str`, *optional* defaults to `"false"`): 31 | Can be `"gradients"`, `"all"`, `"parameters"`, or `"false"`. Set to `"all"` to log gradients and 32 | parameters. 
33 | """ 34 | if self._wandb is None: 35 | return 36 | self._initialized = True 37 | if state.is_world_process_zero: 38 | logger.info( 39 | 'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"' 40 | ) 41 | combined_dict = {**args.to_dict()} 42 | 43 | if hasattr(model, "config") and model.config is not None: 44 | model_config = model.config.to_dict() 45 | combined_dict = {**model_config, **combined_dict} 46 | trial_name = state.trial_name 47 | init_args = {} 48 | if trial_name is not None: 49 | init_args["name"] = trial_name 50 | init_args["group"] = args.run_name 51 | else: 52 | if not (args.run_name is None or args.run_name == args.output_dir): 53 | init_args["name"] = args.run_name 54 | 55 | if self._wandb.run is None: 56 | self._wandb.init( 57 | project=os.getenv("WANDB_PROJECT", "huggingface"), 58 | settings=self._wandb.Settings(code_dir="/share/project/qiying/projects/vision-tokenizer"), 59 | **init_args, 60 | ) 61 | # add config parameters (run may have been created manually) 62 | self._wandb.config.update(combined_dict, allow_val_change=True) 63 | 64 | # define default x-axis (for latest wandb versions) 65 | if getattr(self._wandb, "define_metric", None): 66 | self._wandb.define_metric("train/global_step") 67 | self._wandb.define_metric("*", step_metric="train/global_step", step_sync=True) 68 | 69 | # keep track of model topology and gradients, unsupported on TPU 70 | _watch_model = os.getenv("WANDB_WATCH", "false") 71 | if not is_torch_tpu_available() and _watch_model in ("all", "parameters", "gradients"): 72 | self._wandb.watch(model, log=_watch_model, log_freq=max(100, state.logging_steps)) 73 | self._wandb.run._label(code="transformers_trainer") 74 | 75 | class ModelCallback(TrainerCallback): 76 | def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): 77 | kwargs['model'].global_step = state.global_step 78 | return super().on_step_begin(args, state, control, **kwargs) -------------------------------------------------------------------------------- /condition/OpenAICLIP_for_clip_model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.relu1 = nn.ReLU(inplace=True) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.relu2 = nn.ReLU(inplace=True) 24 | 25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 26 | 27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | self.relu3 = nn.ReLU(inplace=True) 30 | 31 | self.downsample = None 32 | self.stride = stride 33 | 34 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 36 | self.downsample = nn.Sequential(OrderedDict([ 37 | ("-1", nn.AvgPool2d(stride)), 38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 39 | ("1", nn.BatchNorm2d(planes * self.expansion)) 40 | ])) 41 | 42 | def forward(self, x: torch.Tensor): 43 | identity = x 44 | 45 | out = self.relu1(self.bn1(self.conv1(x))) 46 | out = self.relu2(self.bn2(self.conv2(out))) 47 | out = self.avgpool(out) 48 | out = self.bn3(self.conv3(out)) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | out = self.relu3(out) 55 | return out 56 | 57 | 58 | class AttentionPool2d(nn.Module): 59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 60 | super().__init__() 61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 62 | self.k_proj = nn.Linear(embed_dim, embed_dim) 63 | self.q_proj = nn.Linear(embed_dim, embed_dim) 64 | self.v_proj = nn.Linear(embed_dim, embed_dim) 65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 66 | self.num_heads = num_heads 67 | 68 | def forward(self, x): 69 | x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC 70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 72 | x, _ = F.multi_head_attention_forward( 73 | query=x[:1], key=x, value=x, 74 | embed_dim_to_check=x.shape[-1], 75 | num_heads=self.num_heads, 76 | q_proj_weight=self.q_proj.weight, 77 | k_proj_weight=self.k_proj.weight, 78 | v_proj_weight=self.v_proj.weight, 79 | in_proj_weight=None, 80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 81 | bias_k=None, 82 | bias_v=None, 83 | add_zero_attn=False, 84 | dropout_p=0, 85 | out_proj_weight=self.c_proj.weight, 86 | out_proj_bias=self.c_proj.bias, 87 | use_separate_proj_weight=True, 88 | training=self.training, 89 | need_weights=False 90 | ) 91 | return x.squeeze(0) 92 | 93 | 94 | class ModifiedResNet(nn.Module): 95 | """ 96 | A ResNet class that is similar to torchvision's but contains the following changes: 97 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
98 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 99 | - The final pooling layer is a QKV attention instead of an average pool 100 | """ 101 | 102 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 103 | super().__init__() 104 | self.output_dim = output_dim 105 | self.input_resolution = input_resolution 106 | 107 | # the 3-layer stem 108 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 109 | self.bn1 = nn.BatchNorm2d(width // 2) 110 | self.relu1 = nn.ReLU(inplace=True) 111 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 112 | self.bn2 = nn.BatchNorm2d(width // 2) 113 | self.relu2 = nn.ReLU(inplace=True) 114 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 115 | self.bn3 = nn.BatchNorm2d(width) 116 | self.relu3 = nn.ReLU(inplace=True) 117 | self.avgpool = nn.AvgPool2d(2) 118 | 119 | # residual layers 120 | self._inplanes = width # this is a *mutable* variable used during construction 121 | self.layer1 = self._make_layer(width, layers[0]) 122 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 123 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 124 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 125 | 126 | embed_dim = width * 32 # the ResNet feature dimension 127 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 128 | 129 | def _make_layer(self, planes, blocks, stride=1): 130 | layers = [Bottleneck(self._inplanes, planes, stride)] 131 | 132 | self._inplanes = planes * Bottleneck.expansion 133 | for _ in range(1, blocks): 134 | layers.append(Bottleneck(self._inplanes, planes)) 135 | 136 | return nn.Sequential(*layers) 137 | 138 | def forward(self, x): 139 | def stem(x): 140 | x = self.relu1(self.bn1(self.conv1(x))) 141 | x = self.relu2(self.bn2(self.conv2(x))) 142 | x = self.relu3(self.bn3(self.conv3(x))) 143 | x = self.avgpool(x) 144 | return x 145 | 146 | x = x.type(self.conv1.weight.dtype) 147 | x = stem(x) 148 | x = self.layer1(x) 149 | x = self.layer2(x) 150 | x = self.layer3(x) 151 | x = self.layer4(x) 152 | x = self.attnpool(x) 153 | 154 | return x 155 | 156 | 157 | class LayerNorm(nn.LayerNorm): 158 | """Subclass torch's LayerNorm to handle fp16.""" 159 | 160 | def forward(self, x: torch.Tensor): 161 | orig_type = x.dtype 162 | ret = super().forward(x.type(torch.float32)) 163 | # ret = super().forward(x) 164 | return ret.type(orig_type) 165 | 166 | 167 | class QuickGELU(nn.Module): 168 | def forward(self, x: torch.Tensor): 169 | return x * torch.sigmoid(1.702 * x) 170 | 171 | 172 | class ResidualAttentionBlock(nn.Module): 173 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 174 | super().__init__() 175 | 176 | self.attn = nn.MultiheadAttention(d_model, n_head) 177 | self.ln_1 = LayerNorm(d_model) 178 | self.mlp = nn.Sequential(OrderedDict([ 179 | ("c_fc", nn.Linear(d_model, d_model * 4)), 180 | ("gelu", QuickGELU()), 181 | ("c_proj", nn.Linear(d_model * 4, d_model)) 182 | ])) 183 | self.ln_2 = LayerNorm(d_model) 184 | self.attn_mask = attn_mask 185 | 186 | def attention(self, x: torch.Tensor): 187 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 188 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 189 | 190 | def forward(self, x: torch.Tensor): 191 | x = x + 
self.attention(self.ln_1(x)) 192 | x = x + self.mlp(self.ln_2(x)) 193 | return x 194 | 195 | 196 | class Transformer(nn.Module): 197 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 198 | super().__init__() 199 | self.width = width 200 | self.layers = layers 201 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 202 | 203 | def forward(self, x: torch.Tensor): 204 | return self.resblocks(x) 205 | 206 | 207 | class VisionTransformer(nn.Module): 208 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 209 | super().__init__() 210 | self.input_resolution = input_resolution 211 | self.output_dim = output_dim 212 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 213 | 214 | scale = width ** -0.5 215 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 216 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 217 | self.ln_pre = LayerNorm(width) 218 | 219 | self.transformer = Transformer(width, layers, heads) 220 | 221 | self.ln_post = LayerNorm(width) 222 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 223 | self.output_dim = output_dim 224 | 225 | def forward(self, x: torch.Tensor): 226 | x = self.conv1(x) # shape = [*, width, grid, grid] 227 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 228 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 229 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 230 | x = x + self.positional_embedding.to(x.dtype) 231 | x = self.ln_pre(x) 232 | 233 | x = x.permute(1, 0, 2) # NLD -> LND 234 | x = self.transformer(x) 235 | x = x.permute(1, 0, 2) # LND -> NLD 236 | 237 | 238 | # original code 239 | # x = self.ln_post(x[:, 0, :]) #[2, 1024] 240 | # if self.proj is not None: 241 | # x = x @ self.proj #[2, 768] 242 | # return x.unsqueeze(1) 243 | 244 | 245 | # DIVA's condition for OpenAI-CLIP(vit-l/224 & vit-l/336) 246 | class_token = x[:, 0, :].unsqueeze(1) 247 | remaining_tokens = x[:, 1:, :] 248 | num_seleted_tokens = int(77-1) 249 | random_indices = torch.randint(0, remaining_tokens.shape[1], (x.shape[0], num_seleted_tokens)) 250 | batch_indices = torch.arange(x.shape[0]).unsqueeze(1).expand(x.shape[0], num_seleted_tokens) 251 | selected_tokens = remaining_tokens[batch_indices, random_indices] 252 | x = torch.cat((class_token, selected_tokens), dim=1) 253 | B,N,C = x.shape 254 | x = x.reshape(B*N,C) 255 | x = self.ln_post(x) 256 | if self.proj is not None: 257 | x = x @ self.proj 258 | x = x.reshape(B,N,self.output_dim) 259 | return x 260 | 261 | 262 | class CLIP(nn.Module): 263 | def __init__(self, 264 | embed_dim: int, 265 | # vision 266 | image_resolution: int, 267 | vision_layers: Union[Tuple[int, int, int, int], int], 268 | vision_width: int, 269 | vision_patch_size: int, 270 | # text 271 | context_length: int, 272 | vocab_size: int, 273 | transformer_width: int, 274 | transformer_heads: int, 275 | transformer_layers: int 276 | ): 277 | super().__init__() 278 | 279 | self.context_length = context_length 280 | 281 | if isinstance(vision_layers, (tuple, list)): 282 | vision_heads = vision_width * 32 // 64 283 | self.visual = ModifiedResNet( 284 | layers=vision_layers, 285 | 
output_dim=embed_dim, 286 | heads=vision_heads, 287 | input_resolution=image_resolution, 288 | width=vision_width 289 | ) 290 | else: 291 | vision_heads = vision_width // 64 292 | self.visual = VisionTransformer( 293 | input_resolution=image_resolution, 294 | patch_size=vision_patch_size, 295 | width=vision_width, 296 | layers=vision_layers, 297 | heads=vision_heads, 298 | output_dim=embed_dim 299 | ) 300 | 301 | self.transformer = Transformer( 302 | width=transformer_width, 303 | layers=transformer_layers, 304 | heads=transformer_heads, 305 | attn_mask=self.build_attention_mask() 306 | ) 307 | 308 | self.vocab_size = vocab_size 309 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 310 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 311 | self.ln_final = LayerNorm(transformer_width) 312 | 313 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 314 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 315 | 316 | self.initialize_parameters() 317 | 318 | def initialize_parameters(self): 319 | nn.init.normal_(self.token_embedding.weight, std=0.02) 320 | nn.init.normal_(self.positional_embedding, std=0.01) 321 | 322 | if isinstance(self.visual, ModifiedResNet): 323 | if self.visual.attnpool is not None: 324 | std = self.visual.attnpool.c_proj.in_features ** -0.5 325 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 326 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 327 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 328 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 329 | 330 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 331 | for name, param in resnet_block.named_parameters(): 332 | if name.endswith("bn3.weight"): 333 | nn.init.zeros_(param) 334 | 335 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 336 | attn_std = self.transformer.width ** -0.5 337 | fc_std = (2 * self.transformer.width) ** -0.5 338 | for block in self.transformer.resblocks: 339 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 340 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 341 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 342 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 343 | 344 | if self.text_projection is not None: 345 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 346 | 347 | def build_attention_mask(self): 348 | # lazily create causal attention mask, with full attention between the vision tokens 349 | # pytorch uses additive attention mask; fill with -inf 350 | mask = torch.empty(self.context_length, self.context_length) 351 | mask.fill_(float("-inf")) 352 | mask.triu_(1) # zero out the lower diagonal 353 | return mask 354 | 355 | @property 356 | def dtype(self): 357 | return self.visual.conv1.weight.dtype 358 | 359 | def encode_image(self, image): 360 | return self.visual(image.type(self.dtype)) 361 | 362 | def encode_text(self, text): 363 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 364 | x = x + self.positional_embedding.type(self.dtype) 365 | x = x.permute(1, 0, 2) # NLD -> LND 366 | x = self.transformer(x) 367 | x = x.permute(1, 0, 2) # LND -> NLD 368 | x = self.ln_final(x).type(self.dtype) 369 | 370 | # x.shape = [batch_size, n_ctx, transformer.width] 371 | # take features from the eot embedding (eot_token is the highest number in each 
sequence) 372 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 373 | 374 | return x 375 | 376 | def forward(self, image, text): 377 | 378 | # original code 379 | # image_features = self.encode_image(image) 380 | 381 | # ours 382 | global_image_features = self.encode_image(image)[:,0,:] 383 | local_image_features = self.encode_image(image)[:,1:,:].mean(dim=1) 384 | image_features = global_image_features+local_image_features 385 | 386 | text_features = self.encode_text(text) 387 | 388 | # normalized features 389 | image_features = image_features / image_features.norm(dim=1, keepdim=True) 390 | text_features = text_features / text_features.norm(dim=1, keepdim=True) 391 | 392 | # cosine similarity as logits 393 | logit_scale = self.logit_scale.exp() 394 | logits_per_image = logit_scale * image_features @ text_features.t() 395 | logits_per_text = logits_per_image.t() 396 | 397 | # shape = [global_batch_size, global_batch_size] 398 | return logits_per_image, logits_per_text 399 | 400 | 401 | def convert_weights(model: nn.Module): 402 | """Convert applicable model parameters to fp16""" 403 | 404 | def _convert_weights_to_fp16(l): 405 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 406 | l.weight.data = l.weight.data.half() 407 | if l.bias is not None: 408 | l.bias.data = l.bias.data.half() 409 | 410 | if isinstance(l, nn.MultiheadAttention): 411 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 412 | tensor = getattr(l, attr) 413 | if tensor is not None: 414 | tensor.data = tensor.data.half() 415 | 416 | for name in ["text_projection", "proj"]: 417 | if hasattr(l, name): 418 | attr = getattr(l, name) 419 | if attr is not None: 420 | attr.data = attr.data.half() 421 | 422 | model.apply(_convert_weights_to_fp16) 423 | 424 | 425 | def build_model(state_dict: dict): 426 | vit = "visual.proj" in state_dict 427 | 428 | if vit: 429 | vision_width = state_dict["visual.conv1.weight"].shape[0] 430 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 431 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 432 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 433 | 434 | #default 435 | image_resolution = vision_patch_size * grid_size 436 | 437 | # 以下代码服务于CLIP的位置编码插值 438 | # image_resolution=int(512) 439 | # src = state_dict["visual.positional_embedding"] 440 | # src_cls = src[0:1] 441 | # src_oth = src[1:] 442 | # src_oth = F.interpolate(src_oth.reshape(grid_size,grid_size,1024).permute(2,0,1).unsqueeze(0),(image_resolution//14,image_resolution//14),mode='bilinear') 443 | # src_oth = src_oth[0].permute(1,2,0).reshape(-1,1024) 444 | # tgt = torch.cat((src_cls,src_oth),dim=0) 445 | # state_dict["visual.positional_embedding"] = tgt 446 | 447 | 448 | else: 449 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 450 | vision_layers = tuple(counts) 451 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 452 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 453 | vision_patch_size = None 454 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 455 | image_resolution = output_width * 32 456 | 457 | embed_dim = state_dict["text_projection"].shape[1] 458 | context_length = state_dict["positional_embedding"].shape[0] 459 | vocab_size = 
state_dict["token_embedding.weight"].shape[0] 460 | transformer_width = state_dict["ln_final.weight"].shape[0] 461 | transformer_heads = transformer_width // 64 462 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) 463 | 464 | model = CLIP( 465 | embed_dim, 466 | image_resolution, vision_layers, vision_width, vision_patch_size, 467 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 468 | ) 469 | 470 | for key in ["input_resolution", "context_length", "vocab_size"]: 471 | if key in state_dict: 472 | del state_dict[key] 473 | 474 | convert_weights(model) 475 | model.load_state_dict(state_dict) 476 | return model.eval() 477 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from transformers.configuration_utils import PretrainedConfig 2 | from transformers.utils import logging 3 | 4 | logger = logging.get_logger(__name__) 5 | 6 | class EmptyClass(PretrainedConfig): 7 | def __init__(self): 8 | pass 9 | class SDConfig(PretrainedConfig): 10 | 11 | def __init__(self, 12 | sd_version = '2-1', 13 | override_total_steps = -1, 14 | freeze_class_embeds = True, 15 | freeze_vae = False, 16 | use_flash = False, 17 | adapt_only_classifier = True, 18 | adapt_topk = -1, 19 | loss = 'mse', 20 | actual_bs = 16, 21 | mean = [0.485, 0.456, 0.406], 22 | std = [0.229, 0.224, 0.225], 23 | use_same_noise_among_timesteps = False, 24 | random_timestep_per_iteration = True, 25 | rand_timestep_equal_int = False, 26 | weight_decay = 0, 27 | train_steps = 1, 28 | accum_iter = 1, 29 | optimizer = 'sgd', 30 | optimizer_momentum = 0.9, 31 | pred_noise_batch_size = 1, 32 | output_dir = './outputs/First_Start', 33 | visual_pattern = None, 34 | clip_image_size = 224, 35 | metaclip_version = 1 36 | ): 37 | super().__init__() 38 | self.model = EmptyClass() 39 | self.model.sd_version = sd_version 40 | self.model.override_total_steps = override_total_steps 41 | self.model.freeze_class_embeds = freeze_class_embeds 42 | self.model.freeze_vae = freeze_vae 43 | self.model.use_flash = use_flash 44 | self.model.adapt_only_classifier = adapt_only_classifier 45 | self.tta = EmptyClass() 46 | self.tta.gradient_descent = EmptyClass() 47 | self.tta.adapt_topk = adapt_topk 48 | self.tta.loss = loss 49 | self.tta.use_same_noise_among_timesteps = use_same_noise_among_timesteps 50 | self.tta.random_timestep_per_iteration = random_timestep_per_iteration 51 | self.tta.rand_timestep_equal_int = rand_timestep_equal_int 52 | self.tta.gradient_descent.weight_decay = weight_decay 53 | self.tta.gradient_descent.train_steps = train_steps 54 | self.tta.gradient_descent.accum_iter = accum_iter 55 | self.tta.gradient_descent.optimizer = optimizer 56 | self.tta.gradient_descent.optimizer_momentum = optimizer_momentum 57 | self.input = EmptyClass() 58 | self.input.batch_size = pred_noise_batch_size 59 | self.input.mean = mean 60 | self.input.std = std 61 | self.output_dir = output_dir 62 | self.actual_bs = actual_bs 63 | self.visual_pattern = visual_pattern 64 | self.clip_image_size = clip_image_size 65 | self.metaclip_version = metaclip_version 66 | 67 | if __name__ =='__main__': 68 | SDConfig() -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .transform import image_transform, image_transform_vq, 
DiffAugment 2 | from .data import get_wds_dataset_and_collator, get_cc3m_wds_dataset_and_collator, get_in1k_val_dataset 3 | from .image_data import get_wds_dataset_and_collator_arbitrary_resolution, get_highres_eval_dataset, get_in1k_dataset 4 | from .constants import ASPECT_RATIO_1024, ASPECT_RATIO_512, ASPECT_RATIO_256, DEFAULT_IMAGE_FILE_SUFFIX, OPENAI_DATASET_MEAN, OPENAI_DATASET_STD -------------------------------------------------------------------------------- /data/constants.py: -------------------------------------------------------------------------------- 1 | 2 | OPENAI_DATASET_MEAN = [0.48145466, 0.4578275, 0.40821073] 3 | OPENAI_DATASET_STD = [0.26862954, 0.26130258, 0.27577711] 4 | 5 | DEFAULT_IMAGE_FILE_SUFFIX = ['jpg', '0.jpg', '0.png', 'png', 'jpeg', '0.jpeg', 'webp'] 6 | DEFAULT_VIDEO_FILE_SUFFIX = ['mp4', 'video.mp4'] 7 | 8 | ASPECT_RATIO_1024 = { 9 | '0.25': [512., 2048.], '0.26': [512., 1984.], '0.27': [512., 1920.], '0.28': [512., 1856.], 10 | '0.32': [576., 1792.], '0.33': [576., 1728.], '0.35': [576., 1664.], '0.4': [640., 1600.], 11 | '0.42': [640., 1536.], '0.48': [704., 1472.], '0.5': [704., 1408.], '0.52': [704., 1344.], 12 | '0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.], 13 | '0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.], 14 | '1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.], 15 | '1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.], 16 | '1.75': [1344., 768.], '2.0': [1408., 704.], '2.09': [1472., 704.], '2.4': [1536., 640.], 17 | '2.5': [1600., 640.], '2.89': [1664., 576.], '3.0': [1728., 576.], '3.11': [1792., 576.], 18 | '3.62': [1856., 512.], '3.75': [1920., 512.], '3.88': [1984., 512.], '4.0': [2048., 512.], 19 | } 20 | 21 | 22 | ASPECT_RATIO_1024_norm = { 23 | '0.68': [832, 1216], 24 | '0.72': [832, 1152], 25 | '0.78': [896, 1152], 26 | '1.0': [1024, 1024], 27 | '1.29': [1152, 896], 28 | '1.46': [1216, 832], 29 | '1.75': [1344, 768], 30 | } 31 | 32 | ASPECT_RATIO_512 = { 33 | '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0], 34 | '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0], 35 | '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0], 36 | '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0], 37 | '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0], 38 | '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0], 39 | '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0], 40 | '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0], 41 | '2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0], 42 | '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0] 43 | } 44 | 45 | 46 | ASPECT_RATIO_512_norm = { 47 | '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0], 48 | '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0], 49 | '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0], 50 | '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0], 51 | '1.75': 
[672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0], 52 | '2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0], 53 | '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0] 54 | } 55 | 56 | ASPECT_RATIO_256 = { 57 | '0.25': [128.0, 512.0], '0.26': [128.0, 496.0], '0.27': [128.0, 480.0], '0.28': [128.0, 464.0], 58 | '0.32': [144.0, 448.0], '0.33': [144.0, 432.0], '0.35': [144.0, 416.0], '0.4': [160.0, 400.0], 59 | '0.42': [160.0, 384.0], '0.48': [176.0, 368.0], '0.5': [176.0, 352.0], '0.52': [176.0, 336.0], 60 | '0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0], 61 | '0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0], 62 | '1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0], 63 | '1.29': [288.0, 224.0], '1.38': [288.0, 208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0], 64 | '1.75': [336.0, 192.0], '2.0': [352.0, 176.0], '2.09': [368.0, 176.0], '2.4': [384.0, 160.0], 65 | '2.5': [400.0, 160.0], '2.89': [416.0, 144.0], '3.0': [432.0, 144.0], '3.11': [448.0, 144.0], 66 | '3.62': [464.0, 128.0], '3.75': [480.0, 128.0], '3.88': [496.0, 128.0], '4.0': [512.0, 128.0] 67 | } -------------------------------------------------------------------------------- /data/data.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import io 3 | import os 4 | import json 5 | from functools import partial 6 | from typing import Sequence, Dict, Union, Tuple 7 | from dataclasses import dataclass 8 | import numpy as np 9 | from einops import rearrange 10 | from PIL import Image 11 | import random 12 | import torch 13 | import torchvision 14 | from torch.utils.data import Dataset 15 | from datasets import load_dataset 16 | from torchvision import transforms 17 | 18 | from .constants import ASPECT_RATIO_256, ASPECT_RATIO_512, ASPECT_RATIO_1024, DEFAULT_IMAGE_FILE_SUFFIX 19 | from .transform import image_transform, image_transform_original_resolution, image_transform_original_resolution_test 20 | 21 | 22 | aspect_ratio_database = { 23 | 256: ASPECT_RATIO_256, 24 | 512: ASPECT_RATIO_512, 25 | 1024: ASPECT_RATIO_1024 26 | } 27 | 28 | def ratio_sample(ratio, aspect_ratios=ASPECT_RATIO_1024): 29 | closest_ratio = min(aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio)) 30 | return closest_ratio 31 | 32 | def find_image(sample): 33 | for suffix in DEFAULT_IMAGE_FILE_SUFFIX: 34 | if suffix in sample.keys(): 35 | sample['0.jpg'] = sample[suffix] 36 | break 37 | return sample 38 | 39 | def get_cc3m_wds_dataset_and_collator(data_args, model_args): 40 | img_size = model_args.image_size 41 | train_processor = image_transform(img_size, is_train=True) 42 | val_processor = image_transform(img_size, is_train=False) 43 | 44 | data = load_dataset("webdataset", data_dir=data_args.dataset_path, split="train", streaming=True) 45 | data = data.shuffle(buffer_size=2_000, seed=data_args.seed) 46 | 47 | def decode(sample, img_processor): 48 | sample = find_image(sample) 49 | sample['image'] = img_processor(sample['jpg']) 50 | sample['text'] = sample['txt'] 51 | return sample 52 | data = data.map( 53 | partial(decode, img_processor=train_processor), 54 | remove_columns=['__key__', '__url__'] 55 | ) 56 | data = data.filter(lambda sample: 'image' in sample and 'text' in sample) # filter return samples that match the given condition 57 | 
data_collator = CC3M_WebdatasetCollator(model_args.patch_size) 58 | 59 | return data, data_collator 60 | 61 | def get_wds_dataset_and_collator(data_args, model_args): 62 | img_size = model_args.image_size 63 | 64 | train_processor = image_transform(img_size, is_train=True) if model_args.fixed_image_size else image_transform 65 | data = load_dataset("webdataset", data_dir=data_args.dataset_path, split="train", streaming=True) 66 | data = data.shuffle(buffer_size=2_000, seed=data_args.seed) 67 | 68 | def decode(sample, img_processor): 69 | sample = find_image(sample) 70 | if model_args.fixed_image_size: 71 | sample['0.jpg'] = img_processor(sample['0.jpg']) 72 | return sample 73 | 74 | data = data.map( 75 | partial(decode, img_processor=train_processor), 76 | remove_columns=['__key__', '__url__'] 77 | ) 78 | data = data.filter(lambda sample: '0.jpg' in sample) 79 | data = data.rename_columns({'0.jpg': 'image'}) 80 | 81 | aspect_ratios = aspect_ratio_database.get(model_args.image_size, None) 82 | aspect_ratios = aspect_ratios or ASPECT_RATIO_512 83 | 84 | data_collator = WebdatasetCollator(model_args.fixed_image_size, model_args.patch_size, aspect_ratios) 85 | import ipdb 86 | ipdb.set_trace() 87 | return data, data_collator 88 | 89 | def get_wds_dataset_and_collator_arbitrary_resolution(data_args, model_args): 90 | 91 | data = load_dataset("webdataset", data_dir=data_args.dataset_path, split="train", streaming=True) 92 | data = data.shuffle(buffer_size=2_000, seed=data_args.seed) 93 | 94 | def decode_sample(sample, img_processor): 95 | sample = find_image(sample) 96 | sample['0.jpg'], sample['size'] = img_processor(sample['0.jpg']) 97 | return sample 98 | 99 | data = data.map( 100 | partial( 101 | decode_sample, 102 | img_processor=partial(image_transform_original_resolution, patch_size=model_args.patch_size) 103 | ), 104 | remove_columns=['__key__', '__url__'] 105 | ) 106 | data = data.filter(lambda sample: '0.jpg' in sample and sample['0.jpg'].ndim == 3 and sample['0.jpg'].shape[-1] > 0 and sample['0.jpg'].shape[-2] > 0) # filter return samples that match the given condition 107 | data = data.rename_columns({'0.jpg': 'image'}) 108 | data_collator = WebdatasetCollator(model_args.patch_size) 109 | 110 | return data, data_collator 111 | 112 | def dataset_test(data_args, model_args): 113 | from datasets import load_dataset 114 | import numpy as np 115 | from functools import partial 116 | from torchvision import transforms 117 | OPENAI_DATASET_MEAN = np.array([0.48145466, 0.4578275, 0.40821073]) 118 | OPENAI_DATASET_STD = np.array([0.26862954, 0.26130258, 0.27577711]) 119 | DEFAULT_IMAGE_FILE_SUFFIX = ['jpg', '0.jpg', '0.png', 'png', 'jpeg', '0.jpeg', 'webp'] 120 | data = load_dataset("webdataset", data_dir="/share/project/datasets/laion-high-resolution/*/*.tar", split="train", streaming=True) 121 | data = data.shuffle(buffer_size=2_000, seed=1) 122 | 123 | data_iter = iter(data) 124 | 125 | def decode_sample(sample, img_processor): 126 | sample = find_image(sample) 127 | sample['0.jpg'], sample['size'] = img_processor(sample['0.jpg']) 128 | return sample 129 | 130 | data = data.map( 131 | partial( 132 | decode_sample, 133 | img_processor=partial(image_transform_original_resolution_test, patch_size=16) 134 | ), 135 | remove_columns=['__key__', '__url__'] 136 | ) 137 | data = data.filter(lambda sample: '0.jpg' in sample) # filter return samples that match the given condition 138 | data = data.rename_columns({'0.jpg': 'image'}) 139 | data_collator = WebdatasetCollator() 140 | 141 | return data, 
data_collator 142 | 143 | def collate_anyres(images, sizes, patch_size, max_size=2048): 144 | """ 145 | Args: 146 | * images: list of images 147 | * sizes: list of image sizes in (ph, pw), i.e., number of patches in h and w 148 | 149 | Return: args accepted by VQModel 150 | * pixel_values: packed images 151 | * cu_seqlens_img 152 | * max_seqlen_img 153 | * grid_hw 154 | * image_sizes 155 | """ 156 | b, c = len(images), images[0].shape[0] 157 | max_patch_num = max_size // patch_size 158 | 159 | image_sizes = torch.tensor([(image.shape[1], image.shape[2]) for image in images]) 160 | H, W = image_sizes.max(dim=0).values 161 | padded_images = images[0].new_zeros(size=(b, c, H.item(), W.item())) 162 | 163 | h, w = torch.tensor(sizes).max(dim=0).values 164 | padding_masks = torch.zeros(size=(b, h.item(), w.item()), dtype=torch.bool) 165 | 166 | for i, (image, mask_size) in enumerate(zip(images, sizes)): 167 | padded_images[i, :, : image.shape[1], : image.shape[2]].copy_(image) 168 | padding_masks[i, : mask_size[0], : mask_size[1]] = 1 169 | 170 | padded_images = padded_images.reshape(b, c, h, patch_size, w, patch_size) 171 | padded_images = torch.einsum("nchpwq->nhwpqc", padded_images) 172 | padded_images = padded_images.reshape(b, h, w, -1) 173 | packed_images = padded_images[padding_masks] 174 | 175 | seq_lens = padding_masks.flatten(1, 2).sum(dim=-1) 176 | cu_seqlens_img = torch.nn.functional.pad( 177 | torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0) 178 | ) 179 | max_seqlen_img = seq_lens.max() 180 | 181 | grid_h = torch.arange(0, h)[None, :, None].repeat(b, 1, w) 182 | grid_w = torch.arange(0, w)[None, None, :].repeat(b, h, 1) 183 | grid_hw = grid_h[padding_masks] * max_patch_num + grid_w[padding_masks] 184 | 185 | return packed_images, cu_seqlens_img, max_seqlen_img, grid_hw, torch.tensor(sizes) 186 | @dataclass 187 | class CC3M_WebdatasetCollator: 188 | def __init__(self, patch_size: int = 1): 189 | self.patch_size = patch_size 190 | self.count = 0 191 | 192 | def __call__( 193 | self, 194 | samples: Sequence[Dict], 195 | ) -> Dict[str, torch.Tensor]: 196 | 197 | self.count += 1 198 | images = [sample["image"] for sample in samples] 199 | texts = [sample["text"] for sample in samples] 200 | 201 | if "size" in samples[0]: 202 | sizes = [sample['size'] for sample in samples] 203 | 204 | batch = {} 205 | 206 | if all(x is not None and x.shape == images[0].shape for x in images): 207 | batch['image'] = torch.stack(images) 208 | else: 209 | if "size" in samples[0]: 210 | batch['image'], batch['cu_seqlens_img'], \ 211 | batch['max_seqlen_img'], batch['grid_hw'], \ 212 | batch['image_sizes'] = collate_anyres(images, sizes, self.patch_size) 213 | else: 214 | batch['image'] = images 215 | batch['text'] = texts 216 | return batch 217 | 218 | @dataclass 219 | class WebdatasetCollator: 220 | def __init__(self,fixed_image_size: bool = True, patch_size: int = 8, aspect_ratios=ASPECT_RATIO_512): 221 | self.fixed_image_size = fixed_image_size 222 | self.patch_size = patch_size 223 | self.aspect_ratios = aspect_ratios 224 | self.count = 0 225 | 226 | def __call__( 227 | self, 228 | samples: Sequence[Dict], 229 | ) -> Dict[str, torch.Tensor]: 230 | 231 | self.count += 1 232 | images = [sample["image"] for sample in samples] 233 | 234 | if "size" in samples[0]: 235 | sizes = [sample['size'] for sample in samples] 236 | 237 | if "size" not in samples[0] and not self.fixed_image_size: 238 | np.random.seed(self.count) 239 | 240 | aspect_ratio = np.random.choice(list(self.aspect_ratios.keys())) 241 | 
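            # Multi-aspect-ratio batching: one (H, W) bucket is sampled per batch and every image in the batch is transformed to it, so the batch can still be stacked into a single tensor.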
image_sizes = [int(x) for x in self.aspect_ratios[aspect_ratio]] 242 | img_processor = image_transform(image_sizes, is_train=True) 243 | images = [img_processor(image) for image in images] 244 | 245 | batch = {} 246 | 247 | if all(x is not None and x.shape == images[0].shape for x in images): 248 | batch['pixel_values'] = torch.stack(images) 249 | else: 250 | if "size" in samples[0]: 251 | batch['pixel_values'], batch['cu_seqlens_img'], \ 252 | batch['max_seqlen_img'], batch['grid_hw'], \ 253 | batch['image_sizes'] = collate_anyres(images, sizes, self.patch_size) 254 | else: 255 | batch['pixel_values'] = images 256 | return batch 257 | 258 | def anyres_process_images_for_model(image_path=None, pil_image=None, patch_size=32): 259 | """ 260 | given a list of image_path or pil_image, transform to input to model 261 | """ 262 | if image_path is not None: 263 | assert pil_image is None 264 | if not isinstance(image_path, list): 265 | image_path = [image_path] 266 | pil_image = [] 267 | for p in image_path: 268 | pil_image.append(Image.open(p).convert('RGB')) 269 | if not isinstance(pil_image, list): 270 | pil_image = [pil_image] 271 | 272 | if len(pil_image) % 2 != 0: 273 | pil_image.append(pil_image[-1]) 274 | 275 | image_tensors, sizes = [], [] 276 | for pil_i in pil_image: 277 | image_tensor, size = image_transform_original_resolution(image=pil_i, patch_size=patch_size) 278 | image_tensors.append(image_tensor) 279 | sizes.append(size) 280 | 281 | pixel_values, cu_seqlens_img, max_seqlen_img, grid_hw, image_sizes = collate_anyres(image_tensors, sizes, patch_size) 282 | 283 | return { 284 | 'pixel_values': pixel_values, 285 | 'cu_seqlens_img': cu_seqlens_img, 286 | 'max_seqlen_img': max_seqlen_img, 287 | 'grid_hw': grid_hw, 288 | 'image_sizes': image_sizes 289 | } 290 | 291 | def get_in1k_val_dataset(data_args, model_args): 292 | import torchvision 293 | transform = image_transform(model_args.image_size, is_train=False) 294 | dataset = torchvision.datasets.ImageFolder(root="/share/project/datasets/ImageNet/val", transform=transform) 295 | def in1k_collator(samples): 296 | if model_args.gan_loss_weight: 297 | return {"pixel_values": torch.stack([sample[0] for sample in samples]), "optimizer_idx": 0} 298 | return {"pixel_values": torch.stack([sample[0] for sample in samples])} 299 | def in1k_collator_anyres(samples): 300 | images = [sample[0] for sample in samples] 301 | sizes = [[image.shape[1] // model_args.patch_size, image.shape[2] // model_args.patch_size] for image in images] 302 | b, c = len(images), images[0].shape[0] 303 | max_patch_num = 1024 // model_args.patch_size 304 | 305 | image_sizes = torch.tensor([(image.shape[1], image.shape[2]) for image in images]) 306 | H, W = image_sizes.max(dim=0).values 307 | padded_images = images[0].new_zeros(size=(b, c, H.item(), W.item())) 308 | 309 | h, w = torch.tensor(sizes).max(dim=0).values 310 | padding_masks = torch.zeros(size=(b, h.item(), w.item()), dtype=torch.bool) 311 | 312 | for i, (image, mask_size) in enumerate(zip(images, sizes)): 313 | padded_images[i, :, : image.shape[1], : image.shape[2]].copy_(image) 314 | padding_masks[i, : mask_size[0], : mask_size[1]] = 1 315 | 316 | padded_images = padded_images.reshape(b, c, h, model_args.patch_size, w, model_args.patch_size) 317 | padded_images = torch.einsum("nchpwq->nhwpqc", padded_images) 318 | padded_images = padded_images.reshape(b, h, w, -1) 319 | packed_images = padded_images[padding_masks] 320 | 321 | seq_lens = padding_masks.flatten(1, 2).sum(dim=-1) 322 | cu_seqlens_img = 
torch.nn.functional.pad( 323 | torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0) 324 | ) 325 | max_seqlen_img = seq_lens.max() 326 | 327 | grid_h = torch.arange(0, h)[None, :, None].repeat(b, 1, w) 328 | grid_w = torch.arange(0, w)[None, None, :].repeat(b, h, 1) 329 | grid_hw = grid_h[padding_masks] * max_patch_num + grid_w[padding_masks] 330 | 331 | batch = {} 332 | batch['pixel_values'] = packed_images 333 | batch['cu_seqlens_img'] = cu_seqlens_img 334 | batch['max_seqlen_img'] = max_seqlen_img 335 | batch['grid_hw'] = grid_hw 336 | batch['image_sizes'] = torch.tensor(sizes) 337 | if model_args.gan_loss_weight: 338 | batch["optimizer_idx"] = 0 339 | return batch 340 | 341 | return dataset, in1k_collator_anyres if data_args.arbitrary_resolution else in1k_collator 342 | 343 | 344 | def split_val_set(): 345 | data = load_dataset("webdataset", data_dir="/share/project/datasets/laion-high-resolution/*/*.tar", split="train", streaming=True) 346 | data = data.shuffle(buffer_size=100_000, seed=100) 347 | val_set = [] 348 | val_ids = set() 349 | for i, item in enumerate(data): 350 | item_id = item['__url__'] + item['__key__'] 351 | if i == 50_000: 352 | break 353 | val_set.append(item) 354 | val_ids.add(item_id) 355 | with open("/share/project/datasets/laion-high-resolution/50k_eval_ids.pkl", "wb") as f: 356 | pickle.dump(val_ids, f) 357 | 358 | import webdataset as wds 359 | from PIL import Image 360 | from pathlib import Path 361 | for i in range(50): 362 | sink = wds.TarWriter(f"/share/project/datasets/laion-high-resolution/eval/eval_{i}.tar") 363 | for sample in val_set[i * 1000: (i + 1) * 1000]: 364 | sink.write(sample) 365 | sink.close() 366 | 367 | def preprocess_val_data(): 368 | data = load_dataset("webdataset", data_dir="/share/project/datasets/laion-high-resolution/eval/*.tar", split="train", streaming=True) 369 | save_dir = "/share/project/datasets/laion-high-resolution/eval/" 370 | info = [] 371 | def decode_sample(sample): 372 | sample = find_image(sample) 373 | sample['0.jpg'], sample['size'] = image_transform_original_resolution(sample['0.jpg'], patch_size=32) 374 | return sample 375 | for i, item in tqdm(enumerate(data)): 376 | image_file = save_dir + f"image_{i}.pkl" 377 | # process image -> size, image pkl 378 | item = decode_sample(item) 379 | with open(image_file, "wb") as f: 380 | pickle.dump(item['0.jpg'], f) 381 | info.append({ 382 | "image_path": image_file, 383 | "size": item['size'] 384 | }) 385 | print(i) 386 | with open(save_dir + f"image_info.pkl", "wb") as f: 387 | pickle.dump(info, f) 388 | 389 | class HighresEvalDataset(Dataset): 390 | def __init__(self): 391 | with open("/share/project/datasets/laion-high-resolution/eval/image_info.pkl", "rb") as f: 392 | self.info = pickle.load(f) 393 | 394 | def __getitem__(self, index): 395 | info = self.info[index] 396 | image_path, size = info['image_path'], info['size'] 397 | with open(image_path, "rb") as f: 398 | image = pickle.load(f) 399 | return {"image": image, "size": size} 400 | 401 | def __len__(self): 402 | return len(self.info) 403 | 404 | def get_highres_eval_dataset(data_args, model_args): 405 | data = HighresEvalDataset() 406 | data_collator = WebdatasetCollator(model_args.patch_size) 407 | 408 | return data, data_collator -------------------------------------------------------------------------------- /data/image_data.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import copy 3 | import os 4 | import json 5 | import pickle 6 | from 
functools import partial 7 | from typing import Sequence, Dict 8 | from dataclasses import dataclass 9 | import numpy as np 10 | 11 | from PIL import Image 12 | 13 | import torch 14 | from torch.utils.data import Dataset 15 | import transformers 16 | from datasets import load_dataset 17 | from torchvision import transforms 18 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ 19 | CenterCrop 20 | 21 | OPENAI_DATASET_MEAN = np.array([0.48145466, 0.4578275, 0.40821073]) 22 | OPENAI_DATASET_STD = np.array([0.26862954, 0.26130258, 0.27577711]) 23 | DEFAULT_IMAGE_FILE_SUFFIX = ['jpg', '0.jpg', '0.png', 'png', 'jpeg', '0.jpeg', 'webp'] 24 | 25 | def find_image(sample): 26 | for suffix in DEFAULT_IMAGE_FILE_SUFFIX: 27 | if suffix in sample.keys(): 28 | sample['0.jpg'] = sample[suffix] 29 | break 30 | return sample 31 | 32 | # remove_columns=['__key__', '__url__', '0.txt', 'original_prompt'] 33 | def get_wds_dataset_and_collator(data_args, model_args): 34 | img_size = model_args.image_size 35 | train_processor = image_transform(img_size, is_train=True) 36 | val_processor = image_transform(img_size, is_train=False) 37 | 38 | data = load_dataset("webdataset", data_dir=data_args.dataset_path, split="train", streaming=True) 39 | data = data.shuffle(buffer_size=2_000, seed=data_args.seed) 40 | 41 | def decode(sample, img_processor): 42 | sample = find_image(sample) 43 | sample['0.jpg'] = img_processor(sample['0.jpg']) 44 | return sample 45 | data = data.map( 46 | partial(decode, img_processor=train_processor), 47 | # remove_columns=['__key__', '__url__', '0.txt', 'seg.txt'] 48 | remove_columns=['__key__', '__url__'] 49 | ) 50 | data = data.filter(lambda sample: '0.jpg' in sample) # filter return samples that match the given condition 51 | data = data.rename_columns({'0.jpg': 'image'}) 52 | data_collator = WebdatasetCollator(model_args.patch_size) 53 | 54 | return data, data_collator 55 | 56 | def image_transform_original_resolution( 57 | image, 58 | patch_size: int, 59 | ): 60 | """accept a pil image and transform into torch.tensor""" 61 | w, h = map(lambda x: x // patch_size * patch_size, image.size) 62 | if w > 1024: 63 | h = int(h / (w / 1024) // patch_size * patch_size) 64 | w = 1024 65 | elif h > 1024: 66 | w = int(w / (h / 1024) // patch_size * patch_size) 67 | h = 1024 68 | def _convert_to_rgb(image): 69 | return image.convert('RGB') 70 | normalize = transforms.Normalize( 71 | mean=OPENAI_DATASET_MEAN, 72 | std=OPENAI_DATASET_STD 73 | ) 74 | transform = transforms.Compose([ 75 | transforms.CenterCrop((h, w)), 76 | _convert_to_rgb, 77 | transforms.ToTensor(), 78 | normalize, 79 | ]) 80 | ph, pw = h // patch_size, w // patch_size 81 | return transform(image), (ph, pw) 82 | 83 | def get_wds_dataset_and_collator_arbitrary_resolution(data_args, model_args): 84 | data = load_dataset("webdataset", data_dir=data_args.dataset_path, split="train", streaming=True) 85 | data = data.shuffle(buffer_size=2_000, seed=data_args.seed) 86 | 87 | def decode_sample(sample, img_processor): 88 | sample = find_image(sample) 89 | sample['0.jpg'], sample['size'] = img_processor(sample['0.jpg']) 90 | return sample 91 | 92 | data = data.map( 93 | partial( 94 | decode_sample, 95 | img_processor=partial(image_transform_original_resolution, patch_size=model_args.patch_size) 96 | ), 97 | remove_columns=['__key__', '__url__'] 98 | ) 99 | data = data.filter(lambda sample: '0.jpg' in sample and sample['0.jpg'].ndim == 3 and sample['0.jpg'].shape[-1] > 0 and 
sample['0.jpg'].shape[-2] > 0) # filter return samples that match the given condition 100 | data = data.rename_columns({'0.jpg': 'image'}) 101 | data_collator = WebdatasetCollator(model_args.patch_size) 102 | 103 | return data, data_collator 104 | 105 | def dataset_test(data_args, model_args): 106 | from datasets import load_dataset 107 | import numpy as np 108 | from functools import partial 109 | from torchvision import transforms 110 | OPENAI_DATASET_MEAN = np.array([0.48145466, 0.4578275, 0.40821073]) 111 | OPENAI_DATASET_STD = np.array([0.26862954, 0.26130258, 0.27577711]) 112 | DEFAULT_IMAGE_FILE_SUFFIX = ['jpg', '0.jpg', '0.png', 'png', 'jpeg', '0.jpeg', 'webp'] 113 | data = load_dataset("webdataset", data_dir="/share/project/datasets/laion-high-resolution/*/*.tar", split="train", streaming=True) 114 | data = data.shuffle(buffer_size=2_000, seed=1) 115 | 116 | data_iter = iter(data) 117 | 118 | def decode_sample(sample, img_processor): 119 | sample = find_image(sample) 120 | sample['0.jpg'], sample['size'] = img_processor(sample['0.jpg']) 121 | return sample 122 | 123 | def image_transform_original_resolution( 124 | image, 125 | patch_size: int, 126 | ): 127 | w, h = map(lambda x: x // patch_size * patch_size, image.size) 128 | def _convert_to_rgb(image): 129 | return image.convert('RGB') 130 | normalize = transforms.Normalize( 131 | mean=OPENAI_DATASET_MEAN, 132 | std=OPENAI_DATASET_STD 133 | ) 134 | transform = transforms.Compose([ 135 | transforms.Resize((h, w), interpolation=transforms.InterpolationMode.BICUBIC), 136 | _convert_to_rgb, 137 | transforms.ToTensor(), 138 | normalize, 139 | ]) 140 | return transform(image) 141 | 142 | data = data.map( 143 | partial( 144 | decode_sample, 145 | img_processor=partial(image_transform_original_resolution, patch_size=16) 146 | ), 147 | remove_columns=['__key__', '__url__'] 148 | ) 149 | data = data.filter(lambda sample: '0.jpg' in sample) # filter return samples that match the given condition 150 | data = data.rename_columns({'0.jpg': 'image'}) 151 | data_collator = WebdatasetCollator() 152 | 153 | return data, data_collator 154 | 155 | def collate_anyres(images, sizes, patch_size): 156 | """ 157 | Args: 158 | * images: list of images 159 | * sizes: list of image sizes in (ph, pw), i.e., number of patches in h and w 160 | 161 | Return: args accepted by VQModel 162 | * pixel_values: packed images 163 | * cu_seqlens_img: 164 | * max_seqlen_img 165 | * grid_hw 166 | * image_sizes 167 | """ 168 | b, c = len(images), images[0].shape[0] 169 | max_patch_num = 1024 // patch_size 170 | 171 | image_sizes = torch.tensor([(image.shape[1], image.shape[2]) for image in images]) 172 | H, W = image_sizes.max(dim=0).values 173 | padded_images = images[0].new_zeros(size=(b, c, H.item(), W.item())) 174 | 175 | h, w = torch.tensor(sizes).max(dim=0).values 176 | padding_masks = torch.zeros(size=(b, h.item(), w.item()), dtype=torch.bool) 177 | 178 | for i, (image, mask_size) in enumerate(zip(images, sizes)): 179 | padded_images[i, :, : image.shape[1], : image.shape[2]].copy_(image) 180 | padding_masks[i, : mask_size[0], : mask_size[1]] = 1 181 | 182 | padded_images = padded_images.reshape(b, c, h, patch_size, w, patch_size) 183 | padded_images = torch.einsum("nchpwq->nhwpqc", padded_images) 184 | padded_images = padded_images.reshape(b, h, w, -1) 185 | packed_images = padded_images[padding_masks] 186 | 187 | seq_lens = padding_masks.flatten(1, 2).sum(dim=-1) 188 | cu_seqlens_img = torch.nn.functional.pad( 189 | torch.cumsum(seq_lens, dim=0, dtype=torch.int32), 
(1, 0) 190 | ) 191 | max_seqlen_img = seq_lens.max() 192 | 193 | grid_h = torch.arange(0, h)[None, :, None].repeat(b, 1, w) 194 | grid_w = torch.arange(0, w)[None, None, :].repeat(b, h, 1) 195 | grid_hw = grid_h[padding_masks] * max_patch_num + grid_w[padding_masks] 196 | 197 | return packed_images, cu_seqlens_img, max_seqlen_img, grid_hw, torch.tensor(sizes) 198 | 199 | @dataclass 200 | class WebdatasetCollator: 201 | patch_size: int 202 | def __call__(self, samples: Sequence[Dict]) -> Dict[str, torch.Tensor]: 203 | images = [sample["image"] for sample in samples] 204 | if "size" in samples[0]: 205 | sizes = [sample['size'] for sample in samples] 206 | 207 | batch = {} 208 | 209 | if all(x is not None and x.shape == images[0].shape for x in images): 210 | batch['pixel_values'] = torch.stack(images) 211 | else: 212 | batch['pixel_values'], batch['cu_seqlens_img'], \ 213 | batch['max_seqlen_img'], batch['grid_hw'], \ 214 | batch['image_sizes'] = collate_anyres(images, sizes, self.patch_size) 215 | 216 | # print(f"{[image.shape for image in batch['pixel_values']]=}") 217 | return batch 218 | 219 | def anyres_process_images_for_model(image_path=None, pil_image=None, patch_size=32): 220 | """ 221 | given a list of image_path or pil_image, transform to input to model 222 | """ 223 | if image_path is not None: 224 | assert pil_image is None 225 | if not isinstance(image_path, list): 226 | image_path = [image_path] 227 | pil_image = [] 228 | for p in image_path: 229 | pil_image.append(Image.open(p).convert('RGB')) 230 | if not isinstance(pil_image, list): 231 | pil_image = [pil_image] 232 | 233 | if len(pil_image) % 2 != 0: 234 | pil_image.append(pil_image[-1]) 235 | 236 | image_tensors, sizes = [], [] 237 | for pil_i in pil_image: 238 | image_tensor, size = image_transform_original_resolution(image=pil_i, patch_size=patch_size) 239 | image_tensors.append(image_tensor) 240 | sizes.append(size) 241 | 242 | pixel_values, cu_seqlens_img, max_seqlen_img, grid_hw, image_sizes = collate_anyres(image_tensors, sizes, patch_size) 243 | 244 | return { 245 | 'pixel_values': pixel_values, 246 | 'cu_seqlens_img': cu_seqlens_img, 247 | 'max_seqlen_img': max_seqlen_img, 248 | 'grid_hw': grid_hw, 249 | 'image_sizes': image_sizes 250 | } 251 | 252 | def get_in1k_dataset(data_args, model_args): 253 | import torchvision 254 | transform = image_transform(model_args.image_size, is_train=False) 255 | dataset = torchvision.datasets.ImageFolder(root="/share/project/datasets/ImageNet/val", transform=transform) 256 | def in1k_collator(samples): 257 | if model_args.gan_loss_weight: 258 | return {"pixel_values": torch.stack([sample[0] for sample in samples]), "optimizer_idx": 0} 259 | return {"pixel_values": torch.stack([sample[0] for sample in samples])} 260 | def in1k_collator_anyres(samples): 261 | images = [sample[0] for sample in samples] 262 | sizes = [[image.shape[1] // model_args.patch_size, image.shape[2] // model_args.patch_size] for image in images] 263 | b, c = len(images), images[0].shape[0] 264 | max_patch_num = 1024 // model_args.patch_size 265 | 266 | image_sizes = torch.tensor([(image.shape[1], image.shape[2]) for image in images]) 267 | H, W = image_sizes.max(dim=0).values 268 | padded_images = images[0].new_zeros(size=(b, c, H.item(), W.item())) 269 | 270 | h, w = torch.tensor(sizes).max(dim=0).values 271 | padding_masks = torch.zeros(size=(b, h.item(), w.item()), dtype=torch.bool) 272 | 273 | for i, (image, mask_size) in enumerate(zip(images, sizes)): 274 | padded_images[i, :, : image.shape[1], : 
image.shape[2]].copy_(image) 275 | padding_masks[i, : mask_size[0], : mask_size[1]] = 1 276 | 277 | padded_images = padded_images.reshape(b, c, h, model_args.patch_size, w, model_args.patch_size) 278 | padded_images = torch.einsum("nchpwq->nhwpqc", padded_images) 279 | padded_images = padded_images.reshape(b, h, w, -1) 280 | packed_images = padded_images[padding_masks] 281 | 282 | seq_lens = padding_masks.flatten(1, 2).sum(dim=-1) 283 | cu_seqlens_img = torch.nn.functional.pad( 284 | torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0) 285 | ) 286 | max_seqlen_img = seq_lens.max() 287 | 288 | grid_h = torch.arange(0, h)[None, :, None].repeat(b, 1, w) 289 | grid_w = torch.arange(0, w)[None, None, :].repeat(b, h, 1) 290 | grid_hw = grid_h[padding_masks] * max_patch_num + grid_w[padding_masks] 291 | 292 | batch = {} 293 | batch['pixel_values'] = packed_images 294 | batch['cu_seqlens_img'] = cu_seqlens_img 295 | batch['max_seqlen_img'] = max_seqlen_img 296 | batch['grid_hw'] = grid_hw 297 | batch['image_sizes'] = torch.tensor(sizes) 298 | if model_args.gan_loss_weight: 299 | batch["optimizer_idx"] = 0 300 | return batch 301 | # dataset = load_dataset("imagefolder", data_dir="/share/project/qiying/datasets/ImageNet/ImageNet/val")['validation'] 302 | # transform = image_transform(model_args.image_size, is_train=False) 303 | # def transforms(examples): 304 | # examples["pixel_values"] = [transform(image) for image in examples["pixel_values"]] 305 | # return examples 306 | # dataset.set_transform(transforms) 307 | # dataset = dataset.remove_columns(["label"]) 308 | # dataset = dataset.rename_column('image', 'pixel_values') 309 | return dataset, in1k_collator_anyres if data_args.arbitrary_resolution else in1k_collator 310 | 311 | def get_highres_eval_dataset(data_args, model_args): 312 | # data = load_dataset("webdataset", data_dir="/share/project/datasets/laion-high-resolution/eval/eval_*.tar", split="train", streaming=True) 313 | 314 | # def decode_sample(sample, img_processor): 315 | # sample = find_image(sample) 316 | # sample['0.jpg'], sample['size'] = img_processor(sample['0.jpg']) 317 | # return sample 318 | 319 | # data = data.map( 320 | # partial( 321 | # decode_sample, 322 | # img_processor=partial(image_transform_original_resolution, patch_size=model_args.patch_size) 323 | # ), 324 | # remove_columns=['__key__', '__url__'], 325 | # ) 326 | # data = data.filter(lambda sample: '0.jpg' in sample and sample['0.jpg'].ndim == 3 and sample['0.jpg'].shape[-1] > 0 and sample['0.jpg'].shape[-2] > 0) # filter return samples that match the given condition 327 | # data = data.rename_columns({'0.jpg': 'image'}) 328 | data = HighresEvalDataset() 329 | data_collator = WebdatasetCollator(model_args.patch_size) 330 | 331 | return data, data_collator 332 | 333 | def image_transform( 334 | image_size: int, 335 | is_train: bool, 336 | ): 337 | mean = OPENAI_DATASET_MEAN 338 | std = OPENAI_DATASET_STD 339 | def _convert_to_rgb(image): 340 | return image.convert('RGB') 341 | normalize = transforms.Normalize(mean=mean, std=std) 342 | if is_train: 343 | return transforms.Compose([ 344 | transforms.RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC), 345 | _convert_to_rgb, 346 | transforms.ToTensor(), 347 | normalize, 348 | ]) 349 | else: 350 | return transforms.Compose([ 351 | transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BICUBIC), 352 | transforms.CenterCrop(image_size), 353 | _convert_to_rgb, 354 | transforms.ToTensor(), 355 | 
normalize, 356 | ]) 357 | 358 | 359 | def norm_vq_img(img): 360 | arr = np.array(img) 361 | arr = arr.astype(np.float32) / 127.5 - 1 362 | img = torch.from_numpy(np.transpose(arr, [2, 0, 1])) 363 | return img 364 | 365 | 366 | def prepare_image(img): 367 | """ Transform and normalize PIL Image to tensor. """ 368 | transform = transforms.Compose([ 369 | transforms.RandomResizedCrop(512, scale=(1., 1.), ratio=(1., 1.), interpolation=InterpolationMode.BICUBIC), 370 | ]) 371 | pil_image = transform(img) 372 | arr = np.array(pil_image.convert("RGB")) 373 | arr = arr.astype(np.float32) / 127.5 - 1 374 | return torch.from_numpy(np.transpose(arr, [2, 0, 1])) 375 | 376 | 377 | def image_transform_for_vq( 378 | image_size: int, 379 | is_train: bool, 380 | ): 381 | 382 | if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: 383 | # for square size, pass size as int so that Resize() uses aspect preserving shortest edge 384 | image_size = image_size[0] 385 | 386 | if is_train: 387 | return Compose([ 388 | RandomResizedCrop(image_size, scale=(1., 1.), ratio=(1., 1.), interpolation=InterpolationMode.BICUBIC), 389 | _convert_to_rgb, 390 | norm_vq_img, 391 | ]) 392 | else: 393 | transforms = [ 394 | RandomResizedCrop(image_size, scale=(1., 1.), ratio=(1., 1.), interpolation=InterpolationMode.BICUBIC), 395 | _convert_to_rgb, 396 | norm_vq_img 397 | ] 398 | return Compose(transforms) 399 | 400 | def split_val_set(): 401 | data = load_dataset("webdataset", data_dir="/share/project/datasets/laion-high-resolution/*/*.tar", split="train", streaming=True) 402 | data = data.shuffle(buffer_size=100_000, seed=100) 403 | val_set = [] 404 | val_ids = set() 405 | for i, item in enumerate(data): 406 | item_id = item['__url__'] + item['__key__'] 407 | if i == 50_000: 408 | break 409 | val_set.append(item) 410 | val_ids.add(item_id) 411 | # with open("/share/project/datasets/laion-high-resolution/50k_eval.pkl", "wb") as f: 412 | # pickle.dump(val_set, f) 413 | with open("/share/project/datasets/laion-high-resolution/50k_eval_ids.pkl", "wb") as f: 414 | pickle.dump(val_ids, f) 415 | 416 | import webdataset as wds 417 | from PIL import Image 418 | from pathlib import Path 419 | for i in range(50): 420 | sink = wds.TarWriter(f"/share/project/datasets/laion-high-resolution/eval/eval_{i}.tar") 421 | for sample in val_set[i * 1000: (i + 1) * 1000]: 422 | sink.write(sample) 423 | sink.close() 424 | 425 | def preprocess_val_data(): 426 | data = load_dataset("webdataset", data_dir="/share/project/datasets/laion-high-resolution/eval/*.tar", split="train", streaming=True) 427 | save_dir = "/share/project/datasets/laion-high-resolution/eval/" 428 | info = [] 429 | def decode_sample(sample): 430 | sample = find_image(sample) 431 | sample['0.jpg'], sample['size'] = image_transform_original_resolution(sample['0.jpg'], patch_size=32) 432 | return sample 433 | for i, item in tqdm(enumerate(data)): 434 | image_file = save_dir + f"image_{i}.pkl" 435 | # process image -> size, image pkl 436 | item = decode_sample(item) 437 | with open(image_file, "wb") as f: 438 | pickle.dump(item['0.jpg'], f) 439 | info.append({ 440 | "image_path": image_file, 441 | "size": item['size'] 442 | }) 443 | print(i) 444 | with open(save_dir + f"image_info.pkl", "wb") as f: 445 | pickle.dump(info, f) 446 | 447 | class HighresEvalDataset(Dataset): 448 | def __init__(self): 449 | with open("/share/project/datasets/laion-high-resolution/eval/image_info.pkl", "rb") as f: 450 | self.info = pickle.load(f) 451 | 452 | def __getitem__(self, index): 
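        # Each entry was written by preprocess_val_data(): load the pre-transformed image tensor from its pickle file and return it together with its (patches_h, patches_w) size.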
453 | info = self.info[index] 454 | image_path, size = info['image_path'], info['size'] 455 | with open(image_path, "rb") as f: 456 | image = pickle.load(f) 457 | return {"image": image, "size": size} 458 | 459 | def __len__(self): 460 | return len(self.info) 461 | 462 | # preprocess_val_data() 463 | # split_val_set() 464 | 465 | # def find_image(sample): 466 | # for suffix in DEFAULT_IMAGE_FILE_SUFFIX: 467 | # if suffix in sample.keys(): 468 | # sample['0.jpg'] = sample[suffix] 469 | # break 470 | # return sample 471 | 472 | # def decode_sample(sample, img_processor): 473 | # sample = find_image(sample) 474 | # sample['0.jpg'], sample['size'] = img_processor(sample['0.jpg']) 475 | # return sample 476 | 477 | # data = data.map( 478 | # partial( 479 | # decode_sample, 480 | # img_processor=partial(image_transform_original_resolution, patch_size=model_args.patch_size) 481 | # ), 482 | # remove_columns=['__key__', '__url__'] 483 | # ) 484 | # data = data.filter(lambda sample: '0.jpg' in sample and sample['0.jpg'].ndim == 3 and sample['0.jpg'].shape[-1] > 0 and sample['0.jpg'].shape[-2] > 0) # filter return samples that match the given condition 485 | # data = data.rename_columns({'0.jpg': 'image'}) 486 | # data_collator = WebdatasetCollator(model_args.patch_size) 487 | 488 | # return data, data_collator -------------------------------------------------------------------------------- /data/register.py: -------------------------------------------------------------------------------- 1 | class Registry: 2 | mapping = { 3 | "data_name_mapping": {}, 4 | } 5 | 6 | @classmethod 7 | def data_builder(cls, name): 8 | def wrap(data): 9 | cls.mapping["data_name_mapping"][name] = data 10 | return data 11 | return wrap 12 | 13 | @classmethod 14 | def get_data_builder(cls, name): 15 | return cls.mapping["data_name_mapping"].get(name, None) 16 | 17 | 18 | registry = Registry() 19 | -------------------------------------------------------------------------------- /data/transform.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence, Dict, Union, Tuple 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from torchvision import transforms as T 7 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ 8 | CenterCrop 9 | 10 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 11 | 12 | def to_tensor(image): 13 | try: 14 | image = ToTensor()(image) 15 | except Exception as e: 16 | image = image.float() 17 | print(e) 18 | return image 19 | 20 | def _convert_to_rgb(image): 21 | try: 22 | image = image.convert('RGB') 23 | except Exception as e: 24 | print(e) 25 | return image 26 | 27 | def image_transform_original_resolution( 28 | image, 29 | patch_size: int, 30 | max_size:int = 2048 31 | ): 32 | """accept a pil image and transform into torch.tensor""" 33 | w, h = map(lambda x: x // patch_size * patch_size, image.size) 34 | if max(w, h) > max_size: 35 | if w > h: 36 | h = int(h / (w / max_size) // patch_size * patch_size) 37 | w = max_size 38 | else: 39 | w = int(w / (h / max_size) // patch_size * patch_size) 40 | h = max_size 41 | 42 | def _convert_to_rgb(image): 43 | return image.convert('RGB') 44 | 45 | normalize = Normalize( 46 | mean=OPENAI_DATASET_MEAN, 47 | std=OPENAI_DATASET_STD 48 | ) 49 | transform = Compose([ 50 | CenterCrop((h, w)), 51 | _convert_to_rgb, 52 | to_tensor, 53 | normalize, 54 | ]) 55 | ph, pw = h // patch_size, w // 
patch_size 56 | return transform(image), (ph, pw) 57 | 58 | def image_transform_original_resolution_test( 59 | image, 60 | patch_size: int, 61 | ): 62 | w, h = map(lambda x: x // patch_size * patch_size, image.size) 63 | normalize = Normalize( 64 | mean=OPENAI_DATASET_MEAN, 65 | std=OPENAI_DATASET_STD 66 | ) 67 | transform = Compose([ 68 | Resize((h, w), interpolation=InterpolationMode.BICUBIC), 69 | _convert_to_rgb, 70 | to_tensor, 71 | normalize, 72 | ]) 73 | return transform(image) 74 | 75 | def image_transform( 76 | image_size: Union[int, Tuple[int, int]], 77 | is_train: bool, 78 | mean: Optional[Tuple[float, ...]] = None, 79 | std: Optional[Tuple[float, ...]] = None, 80 | ): 81 | mean = mean or OPENAI_DATASET_MEAN 82 | if not isinstance(mean, (list, tuple)): 83 | mean = (mean,) * 3 84 | 85 | std = std or OPENAI_DATASET_STD 86 | if not isinstance(std, (list, tuple)): 87 | std = (std,) * 3 88 | 89 | normalize = Normalize(mean=mean, std=std) 90 | 91 | if is_train: 92 | return Compose([ 93 | RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC), 94 | _convert_to_rgb, 95 | # ToTensor(), 96 | to_tensor, 97 | normalize, 98 | ]) 99 | else: 100 | return Compose([ 101 | Resize(image_size, interpolation=InterpolationMode.BICUBIC), 102 | CenterCrop(image_size), 103 | _convert_to_rgb, 104 | # ToTensor(), 105 | to_tensor, 106 | normalize, 107 | ]) 108 | 109 | def norm_img_vq(img): 110 | arr = np.array(img) 111 | arr = arr.astype(np.float32) / 127.5 - 1 112 | img = torch.from_numpy(np.transpose(arr, [2, 0, 1])) 113 | return img 114 | 115 | 116 | def prepare_image(img): 117 | """ Transform and normalize PIL Image to tensor. """ 118 | transform = Compose([ 119 | RandomResizedCrop(512, scale=(1., 1.), ratio=(1., 1.), interpolation=InterpolationMode.BICUBIC), 120 | ]) 121 | pil_image = transform(img) 122 | arr = np.array(pil_image.convert("RGB")) 123 | arr = arr.astype(np.float32) / 127.5 - 1 124 | return torch.from_numpy(np.transpose(arr, [2, 0, 1])) 125 | 126 | 127 | def image_transform_vq( 128 | image_size: Union[int, Tuple[int, int]], 129 | is_train: bool, 130 | ): 131 | 132 | # if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: 133 | # # for square size, pass size as int so that Resize() uses aspect preserving shortest edge 134 | # image_size = image_size[0] 135 | def _convert_to_rgb(image): 136 | return image.convert('RGB') 137 | 138 | return Compose([ 139 | RandomResizedCrop(image_size, scale=(1., 1.), ratio=(1., 1.), interpolation=InterpolationMode.BICUBIC), 140 | _convert_to_rgb, 141 | norm_img_vq 142 | ]) 143 | 144 | 145 | def DiffAugment(x, policy='color,translation,cutout', is_tensor=True, channels_first=True): 146 | if policy: 147 | if not is_tensor and not channels_first: 148 | x = x.permute(0, 3, 1, 2) 149 | for p in policy.split(','): 150 | for f in AUGMENT_FNS[p]: 151 | x = f(x) 152 | if not is_tensor and not channels_first: 153 | x = x.permute(0, 2, 3, 1) 154 | x = x.contiguous() 155 | return x 156 | 157 | 158 | def rand_brightness(x): 159 | x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) 160 | return x 161 | 162 | 163 | def rand_saturation(x): 164 | x_mean = x.mean(dim=1, keepdim=True) 165 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean 166 | return x 167 | 168 | 169 | def rand_contrast(x): 170 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True) 171 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean 
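    # per-sample contrast factor drawn uniformly from [0.5, 1.5); deviations from the per-image mean are scaled by it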
172 | return x 173 | 174 | 175 | def rand_translation(x, ratio=0.125): 176 | shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 177 | translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) 178 | translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device) 179 | grid_batch, grid_x, grid_y = torch.meshgrid( 180 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 181 | torch.arange(x.size(2), dtype=torch.long, device=x.device), 182 | torch.arange(x.size(3), dtype=torch.long, device=x.device), 183 | ) 184 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) 185 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) 186 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) 187 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2).contiguous() 188 | return x 189 | 190 | 191 | def rand_cutout(x, ratio=0.5): 192 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 193 | offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device) 194 | offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device) 195 | grid_batch, grid_x, grid_y = torch.meshgrid( 196 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 197 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device), 198 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device), 199 | ) 200 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1) 201 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) 202 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) 203 | mask[grid_batch, grid_x, grid_y] = 0 204 | x = x * mask.unsqueeze(1) 205 | return x 206 | 207 | 208 | AUGMENT_FNS = { 209 | 'color': [rand_brightness, rand_saturation, rand_contrast], 210 | 'translation': [rand_translation], 211 | 'cutout': [rand_cutout], 212 | } -------------------------------------------------------------------------------- /models/CLIP_bank.py: -------------------------------------------------------------------------------- 1 | import clip 2 | import torch.nn as nn 3 | from open_clip import create_model_from_pretrained, create_model_and_transforms 4 | 5 | 6 | class OpenAICLIP(nn.Module): 7 | def __init__(self, config): 8 | super().__init__() 9 | 10 | if config.clip_image_size == 224: 11 | model, _ = clip.load("pretrained_weights/CLIP/ViT-L-14.pt", jit=False) 12 | if config.clip_image_size == 336: 13 | model, _ = clip.load("pretrained_weights/CLIP/ViT-L-14-336px.pt",jit=False) 14 | 15 | self.final_fc = nn.Linear(768, config.actual_bs, bias=False) 16 | self.model = model 17 | self.config = config 18 | 19 | def forward(self, images): 20 | 21 | image_features = self.model.encode_image(images).float() 22 | logits = 100. 
* self.final_fc(image_features[:,0,:]).float() 23 | 24 | return image_features, logits 25 | 26 | 27 | class DFN(nn.Module): 28 | def __init__(self, config): 29 | super().__init__() 30 | 31 | if config.clip_image_size == 224: 32 | model, _ = create_model_from_pretrained(model_name='ViT-H-14-quickgelu', pretrained="pretrained_weights/CLIP/DFN5B-CLIP-ViT-H-14/open_clip_pytorch_model.bin") 33 | if config.clip_image_size == 378: 34 | model, _ = create_model_from_pretrained(model_name='ViT-H-14-378-quickgelu', pretrained="pretrained_weights/CLIP/DFN5B-CLIP-ViT-H-14-378/open_clip_pytorch_model.bin") 35 | 36 | self.final_fc = nn.Linear(1024, config.actual_bs, bias=False) 37 | self.model = model 38 | self.config = config 39 | 40 | def forward(self, images): 41 | 42 | image_features = self.model.encode_image(images).float() 43 | logits = 100. * self.final_fc(image_features[:,0,:]).float() 44 | 45 | return image_features, logits 46 | 47 | 48 | class SigLIP(nn.Module): 49 | def __init__(self, config): 50 | super().__init__() 51 | 52 | if config.clip_image_size == 224: 53 | model, _ = create_model_from_pretrained(model_name='ViT-SO400M-14-SigLIP', pretrained="pretrained_weights/CLIP/ViT-SO400M-14-SigLIP/open_clip_pytorch_model.bin", 54 | image_mean=([0.5,0.5,0.5]), image_std=([0.5,0.5,0.5]), image_interpolation="bicubic", image_resize_mode="squash") 55 | if config.clip_image_size == 384: 56 | model, _ = create_model_from_pretrained(model_name='ViT-SO400M-14-SigLIP-384', pretrained="pretrained_weights/CLIP/ViT-SO400M-14-SigLIP-384/open_clip_pytorch_model.bin", 57 | image_mean=([0.5,0.5,0.5]), image_std=([0.5,0.5,0.5]), image_interpolation="bicubic", image_resize_mode="squash") 58 | 59 | self.final_fc = nn.Linear(1152, config.actual_bs, bias=False) 60 | self.model = model 61 | self.config = config 62 | 63 | def forward(self, images): 64 | 65 | image_features = self.model.encode_image(images).float() 66 | logits = 100. * self.final_fc(image_features[:,0,:]).float() 67 | 68 | return image_features, logits 69 | 70 | 71 | class MetaCLIP(nn.Module): 72 | def __init__(self, config): 73 | super().__init__() 74 | 75 | if config.metaclip_version == "large": 76 | model, _, _ = create_model_and_transforms(model_name='ViT-L-14-quickgelu', pretrained="pretrained_weights/CLIP/MetaCLIP/l14_fullcc2.5b.pt") 77 | self.final_fc = nn.Linear(768, config.actual_bs, bias=False) 78 | if config.metaclip_version == "huge": 79 | model, _, _ = create_model_and_transforms(model_name='ViT-H-14-quickgelu', pretrained="pretrained_weights/CLIP/MetaCLIP/h14_fullcc2.5b.pt") 80 | self.final_fc = nn.Linear(1024, config.actual_bs, bias=False) 81 | 82 | self.model = model 83 | self.config = config 84 | 85 | def forward(self, images): 86 | 87 | image_features = self.model.encode_image(images).float() 88 | logits = 100. 
* self.final_fc(image_features[:,0,:]).float() 89 | 90 | return image_features, logits 91 | -------------------------------------------------------------------------------- /models/SD_with_DFN.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional, Tuple 3 | from dataclasses import dataclass 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from einops import rearrange 8 | import numpy as np 9 | import torchvision 10 | from transformers.utils import ModelOutput 11 | from transformers.modeling_utils import PreTrainedModel 12 | from .build import load_sd_model, load_clip_model_DFN 13 | from .utils import initiate_time_steps, prepare_class_text_embeddings 14 | import torchvision.transforms as transforms 15 | from open_clip import get_tokenizer 16 | 17 | @dataclass 18 | class SDOutput(ModelOutput): 19 | loss: Optional[torch.FloatTensor] = None 20 | 21 | class SDModel(PreTrainedModel): 22 | def __init__( 23 | self, 24 | config = None, 25 | ): 26 | super().__init__(config) 27 | 28 | self.model_id, self.pipe, self.vae, self.tokenizer, self.text_encoder, self.unet, self.scheduler, self.image_renormalizer = load_sd_model(config) 29 | self.text_encoder.eval() 30 | self.vae.eval() 31 | self.unet.eval() 32 | 33 | self.pattern_dictionary={'None':['']} 34 | self.config.actual_bs = len(self.pattern_dictionary[self.config.visual_pattern]) 35 | self.class_model = load_clip_model_DFN(config) 36 | self.class_model.eval() 37 | self.config = config 38 | discrimi_size = self.config.clip_image_size 39 | self.resize_transform_discrimi = transforms.Resize((discrimi_size, discrimi_size)) 40 | self.visual_proj = nn.Linear(1024, 1024) 41 | 42 | 43 | def classify(self, image, classes): 44 | 45 | image_features, logits = self.class_model(image) 46 | 47 | if classes is not None: 48 | logits = logits[:, classes] 49 | 50 | probs = logits.softmax(-1) 51 | max_idx = probs.argmax(-1) 52 | K = probs.shape[-1] if self.config.tta.adapt_topk == -1 else self.config.tta.adapt_topk 53 | topk_idx = probs.argsort(descending=True)[:, :K] 54 | 55 | if classes is not None: 56 | classes = torch.tensor(classes).to(logits.device) 57 | max_class_idx = classes[max_idx.flatten()].view(max_idx.shape) 58 | topk_class_idx = classes[topk_idx.flatten()].view(topk_idx.shape) 59 | else: 60 | max_class_idx, topk_class_idx = max_idx, topk_idx 61 | 62 | return image_features, logits, topk_idx, max_class_idx, topk_class_idx 63 | 64 | def _unet_pred_noise(self, x_start, t, noise, context): 65 | 66 | _,c,h,w = x_start.shape 67 | device = t.device 68 | nt = t.shape[0] 69 | 70 | x_start = x_start.unsqueeze(1) 71 | x_start = x_start.expand(-1, nt//x_start.shape[0], -1, -1, -1) 72 | x_start = x_start.reshape(-1,c,h,w) 73 | 74 | alphas_cumprod = self.scheduler.alphas_cumprod.to(device) 75 | noised_latent = ( 76 | x_start * (alphas_cumprod[t]**0.5).view(-1, 1, 1, 1).to(device) 77 | + noise * ((1 - alphas_cumprod[t])**0.5).view(-1, 1, 1, 1).to(device) 78 | ) 79 | pred_noise = self.unet(noised_latent, t, encoder_hidden_states=context.expand(nt, -1, -1)).sample 80 | 81 | return pred_noise 82 | 83 | def zeroshot_classifier_DFN(self, classnames, templates, model): 84 | with torch.no_grad(): 85 | zeroshot_weights = [] 86 | tokenizer = get_tokenizer('ViT-H-14') 87 | for classname in classnames: 88 | texts = [template.format(classname) for template in templates] 89 | texts = tokenizer(texts, context_length=model.context_length).cuda() 90 
| class_embeddings = model.encode_text(texts) 91 | class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) 92 | class_embedding = class_embeddings.mean(dim=0) 93 | class_embedding /= class_embedding.norm() 94 | zeroshot_weights.append(class_embedding) 95 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda() 96 | return zeroshot_weights 97 | 98 | def forward( 99 | self, 100 | image: torch.Tensor = None, 101 | text = None 102 | ) -> SDOutput: 103 | 104 | text = self.pattern_dictionary[self.config.visual_pattern] 105 | with torch.no_grad(): 106 | imagenet_templates = ['{}',] 107 | zeroshot_weights = self.zeroshot_classifier_DFN(text, imagenet_templates, self.class_model.model.float()).float() 108 | 109 | self.class_model.final_fc.weight.data = zeroshot_weights.T 110 | self.class_model.final_fc.weight.data = self.class_model.final_fc.weight.data.contiguous() 111 | classes = [i for i in range(len(text))] 112 | 113 | discrimi_image = self.resize_transform_discrimi(image) 114 | genera_image = image 115 | real_BS = image.shape[0] 116 | after_DF_expand_BS = real_BS*self.config.input.batch_size 117 | 118 | # prepare_vae_latent 119 | self.vae, self.text_encoder, self.unet = self.vae.to(torch.float32), self.text_encoder.to(torch.float32), self.unet.to(torch.float32) 120 | renormed_image = self.image_renormalizer(genera_image).detach() 121 | x0 = self.vae.encode(renormed_image).latent_dist.mean.float() 122 | latent = x0 * 0.18215 123 | 124 | # prepare_total_timesteps 125 | total_timestep = self.scheduler.num_train_timesteps 126 | 127 | for step in range(self.config.tta.gradient_descent.train_steps): 128 | # Initiate timesteps and noise 129 | timesteps = initiate_time_steps(step, total_timestep, after_DF_expand_BS, self.config).long() 130 | timesteps = timesteps.cuda() 131 | 132 | c, h, w = latent.shape[1:] 133 | if not self.config.tta.use_same_noise_among_timesteps: 134 | noise = torch.randn((real_BS* self.config.input.batch_size, c, h, w)).cuda() 135 | else: 136 | noise = torch.randn((1, c, h, w)).cuda() 137 | noise = noise.repeat(real_BS* self.config.input.batch_size, 1, 1, 1) 138 | 139 | if self.config.tta.adapt_topk == -1: 140 | image_features, logits, _, _, _ = self.classify(discrimi_image, classes) 141 | pred_top_idx = None 142 | else: 143 | image_features, logits, pred_top_idx, _, _ = self.classify(discrimi_image, classes) 144 | real_BS, C = logits.shape[:2] 145 | 146 | # Pick top-K predictions 147 | if pred_top_idx is not None: 148 | pred_top_idx = pred_top_idx.squeeze(0) 149 | else: 150 | pred_top_idx = torch.arange(C).cuda() 151 | 152 | logits = logits[:, pred_top_idx] 153 | 154 | class_text_embeddings = prepare_class_text_embeddings(self.tokenizer, self.text_encoder, class_names=text) 155 | class_text_embeddings = class_text_embeddings.detach() 156 | class_text_embeddings = class_text_embeddings[pred_top_idx, :] 157 | 158 | # Compute conditional text embeddings using weighted-summed predictions 159 | probs = logits.softmax(-1) 160 | probs = probs[:, :, None, None] 161 | class_text_embeddings = (class_text_embeddings.unsqueeze(0).repeat(after_DF_expand_BS, 1, 1, 1)) 162 | _, word_num, _, _ = probs.shape 163 | probs = probs.unsqueeze(1).repeat(1,self.config.input.batch_size,1,1,1).reshape(-1,word_num,1,1) 164 | context = (probs * class_text_embeddings).sum(1) 165 | image_features = self.visual_proj(image_features) 166 | context = context.mean(dim=1).unsqueeze(1) + image_features 167 | 168 | # Predict noise with the diffusion model 169 | pred_noise = 
self._unet_pred_noise(x_start=latent, t=timesteps, noise=noise, context=context).float() 170 | 171 | # Compute diffusion loss 172 | if self.config.tta.loss == "l1": 173 | loss = torch.nn.functional.l1_loss(pred_noise, noise) 174 | else: 175 | loss = torch.nn.functional.mse_loss(pred_noise, noise) 176 | 177 | if step != (self.config.tta.gradient_descent.train_steps-1): 178 | loss.backward() 179 | 180 | return SDOutput(loss=loss) -------------------------------------------------------------------------------- /models/SD_with_MetaCLIP.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | from dataclasses import dataclass 4 | import torch 5 | import torch.nn as nn 6 | from transformers.utils import ModelOutput 7 | from transformers.modeling_utils import PreTrainedModel 8 | from .build import load_sd_model, load_clip_model_MetaCLIP 9 | from .utils import initiate_time_steps, prepare_class_text_embeddings 10 | import torchvision.transforms as transforms 11 | from open_clip import tokenize 12 | 13 | @dataclass 14 | class SDOutput(ModelOutput): 15 | loss: Optional[torch.FloatTensor] = None 16 | 17 | class SDModel(PreTrainedModel): 18 | def __init__( 19 | self, 20 | config = None, 21 | ): 22 | super().__init__(config) 23 | 24 | self.model_id, self.pipe, self.vae, self.tokenizer, self.text_encoder, self.unet, self.scheduler, self.image_renormalizer = load_sd_model(config) 25 | self.text_encoder.eval() 26 | self.vae.eval() 27 | self.unet.eval() 28 | 29 | self.pattern_dictionary={'None':['']} 30 | self.config.actual_bs = len(self.pattern_dictionary[self.config.visual_pattern]) 31 | self.class_model = load_clip_model_MetaCLIP(config) 32 | self.class_model.eval() 33 | self.config = config 34 | discrimi_size = self.config.clip_image_size 35 | self.resize_transform_discrimi = transforms.Resize((discrimi_size, discrimi_size)) 36 | if config.metaclip_version == "large": 37 | self.visual_proj = nn.Linear(768, 1024) 38 | if config.metaclip_version == "huge": 39 | self.visual_proj = nn.Linear(1024, 1024) 40 | 41 | def classify(self, image, classes): 42 | 43 | image_features, logits = self.class_model(image) 44 | 45 | if classes is not None: 46 | logits = logits[:, classes] 47 | 48 | probs = logits.softmax(-1) 49 | max_idx = probs.argmax(-1) 50 | K = probs.shape[-1] if self.config.tta.adapt_topk == -1 else self.config.tta.adapt_topk 51 | topk_idx = probs.argsort(descending=True)[:, :K] 52 | 53 | if classes is not None: 54 | classes = torch.tensor(classes).to(logits.device) 55 | max_class_idx = classes[max_idx.flatten()].view(max_idx.shape) 56 | topk_class_idx = classes[topk_idx.flatten()].view(topk_idx.shape) 57 | else: 58 | max_class_idx, topk_class_idx = max_idx, topk_idx 59 | 60 | return image_features, logits, topk_idx, max_class_idx, topk_class_idx 61 | 62 | def _unet_pred_noise(self, x_start, t, noise, context): 63 | 64 | _,c,h,w = x_start.shape 65 | device = t.device 66 | nt = t.shape[0] 67 | 68 | x_start = x_start.unsqueeze(1) 69 | x_start = x_start.expand(-1, nt//x_start.shape[0], -1, -1, -1) 70 | x_start = x_start.reshape(-1,c,h,w) 71 | 72 | alphas_cumprod = self.scheduler.alphas_cumprod.to(device) 73 | noised_latent = ( 74 | x_start * (alphas_cumprod[t]**0.5).view(-1, 1, 1, 1).to(device) 75 | + noise * ((1 - alphas_cumprod[t])**0.5).view(-1, 1, 1, 1).to(device) 76 | ) 77 | pred_noise = self.unet(noised_latent, t, encoder_hidden_states=context.expand(nt, -1, -1)).sample 78 | 79 | return pred_noise 
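    # zeroshot_classifier_MetaCLIP (below) builds zero-shot classifier weights: each class name is
    # formatted into every prompt template, encoded with MetaCLIP's text tower, L2-normalized,
    # averaged over templates and re-normalized, then stacked along dim=1 into an
    # [embed_dim, num_classes] matrix; forward() transposes it into final_fc.weight.
    # Hypothetical usage sketch (class names and template are illustrative, not from the repo):
    #   weights = self.zeroshot_classifier_MetaCLIP(['dog', 'cat'], ['a photo of a {}'], self.class_model.model)
    #   # weights.shape == (embed_dim, 2)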
80 | 81 | def zeroshot_classifier_MetaCLIP(self, classnames, templates, model): 82 | with torch.no_grad(): 83 | zeroshot_weights = [] 84 | for classname in classnames: 85 | texts = [template.format(classname) for template in templates] #format with class 86 | texts = tokenize(texts).cuda() #tokenize 87 | class_embeddings = model.encode_text(texts) #embed with text encoder 88 | class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) 89 | class_embedding = class_embeddings.mean(dim=0) 90 | class_embedding /= class_embedding.norm() 91 | zeroshot_weights.append(class_embedding) 92 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda() 93 | return zeroshot_weights 94 | 95 | def forward( 96 | self, 97 | image: torch.Tensor = None, 98 | text = None 99 | ) -> SDOutput: 100 | 101 | text = self.pattern_dictionary[self.config.visual_pattern] 102 | with torch.no_grad(): 103 | imagenet_templates = ['{}',] 104 | zeroshot_weights = self.zeroshot_classifier_MetaCLIP(text, imagenet_templates, self.class_model.model.float()).float() 105 | 106 | self.class_model.final_fc.weight.data = zeroshot_weights.T 107 | self.class_model.final_fc.weight.data = self.class_model.final_fc.weight.data.contiguous() 108 | classes = [i for i in range(len(text))] 109 | 110 | discrimi_image = self.resize_transform_discrimi(image) 111 | genera_image = image 112 | real_BS = image.shape[0] 113 | after_DF_expand_BS = real_BS*self.config.input.batch_size 114 | 115 | # prepare_vae_latent 116 | self.vae, self.text_encoder, self.unet = self.vae.to(torch.float32), self.text_encoder.to(torch.float32), self.unet.to(torch.float32) 117 | renormed_image = self.image_renormalizer(genera_image).detach() 118 | x0 = self.vae.encode(renormed_image).latent_dist.mean.float() 119 | latent = x0 * 0.18215 120 | 121 | # prepare_total_timesteps 122 | total_timestep = self.scheduler.num_train_timesteps 123 | 124 | for step in range(self.config.tta.gradient_descent.train_steps): 125 | # Initiate timesteps and noise 126 | timesteps = initiate_time_steps(step, total_timestep, after_DF_expand_BS, self.config).long() 127 | timesteps = timesteps.cuda() 128 | 129 | c, h, w = latent.shape[1:] 130 | if not self.config.tta.use_same_noise_among_timesteps: 131 | noise = torch.randn((real_BS* self.config.input.batch_size, c, h, w)).cuda() 132 | else: 133 | noise = torch.randn((1, c, h, w)).cuda() 134 | noise = noise.repeat(real_BS* self.config.input.batch_size, 1, 1, 1) 135 | 136 | if self.config.tta.adapt_topk == -1: 137 | image_features, logits, _, _, _ = self.classify(discrimi_image, classes) 138 | pred_top_idx = None 139 | else: 140 | image_features, logits, pred_top_idx, _, _ = self.classify(discrimi_image, classes) 141 | real_BS, C = logits.shape[:2] 142 | 143 | # Pick top-K predictions 144 | if pred_top_idx is not None: 145 | pred_top_idx = pred_top_idx.squeeze(0) 146 | else: 147 | pred_top_idx = torch.arange(C).cuda() 148 | 149 | logits = logits[:, pred_top_idx] 150 | 151 | class_text_embeddings = prepare_class_text_embeddings(self.tokenizer, self.text_encoder, class_names=text) 152 | class_text_embeddings = class_text_embeddings.detach() 153 | class_text_embeddings = class_text_embeddings[pred_top_idx, :] 154 | 155 | # Compute conditional text embeddings using weighted-summed predictions 156 | probs = logits.softmax(-1) 157 | probs = probs[:, :, None, None] 158 | class_text_embeddings = (class_text_embeddings.unsqueeze(0).repeat(after_DF_expand_BS, 1, 1, 1)) 159 | _, word_num, _, _ = probs.shape 160 | probs = 
probs.unsqueeze(1).repeat(1,self.config.input.batch_size,1,1,1).reshape(-1,word_num,1,1) 161 | context = (probs * class_text_embeddings).sum(1) 162 | image_features = self.visual_proj(image_features) 163 | context = context.mean(dim=1).unsqueeze(1) + image_features 164 | 165 | # Predict noise with the diffusion model 166 | pred_noise = self._unet_pred_noise(x_start=latent, t=timesteps, noise=noise, context=context).float() 167 | 168 | # Compute diffusion loss 169 | if self.config.tta.loss == "l1": 170 | loss = torch.nn.functional.l1_loss(pred_noise, noise) 171 | else: 172 | loss = torch.nn.functional.mse_loss(pred_noise, noise) 173 | 174 | if step != (self.config.tta.gradient_descent.train_steps-1): 175 | loss.backward() 176 | 177 | return SDOutput(loss=loss) -------------------------------------------------------------------------------- /models/SD_with_OpenAICLIP.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | from dataclasses import dataclass 4 | import torch 5 | import torch.nn as nn 6 | from einops import rearrange 7 | from transformers.utils import ModelOutput 8 | from transformers.modeling_utils import PreTrainedModel 9 | from .build import load_sd_model, load_clip_model_OpenAICLIP 10 | from .utils import initiate_time_steps, prepare_class_text_embeddings 11 | import torchvision.transforms as transforms 12 | import clip 13 | 14 | 15 | @dataclass 16 | class SDOutput(ModelOutput): 17 | loss: Optional[torch.FloatTensor] = None 18 | 19 | 20 | class SDModel(PreTrainedModel): 21 | def __init__( 22 | self, 23 | config = None, 24 | ): 25 | super().__init__(config) 26 | 27 | self.model_id, self.pipe, self.vae, self.tokenizer, self.text_encoder, self.unet, self.scheduler, self.image_renormalizer = load_sd_model(config) 28 | self.text_encoder.eval() 29 | self.vae.eval() 30 | self.unet.eval() 31 | 32 | self.pattern_dictionary={'None':['']} 33 | self.config.actual_bs = len(self.pattern_dictionary[self.config.visual_pattern]) 34 | self.class_model = load_clip_model_OpenAICLIP(config) 35 | self.class_model.eval() 36 | self.config = config 37 | discrimi_size = self.config.clip_image_size 38 | self.resize_transform_discrimi = transforms.Resize((discrimi_size, discrimi_size)) 39 | self.visual_proj = nn.Linear(768, 1024) 40 | 41 | def classify(self, image, classes): 42 | 43 | image_features, logits = self.class_model(image) 44 | 45 | if classes is not None: 46 | logits = logits[:, classes] 47 | 48 | probs = logits.softmax(-1) 49 | max_idx = probs.argmax(-1) 50 | K = probs.shape[-1] if self.config.tta.adapt_topk == -1 else self.config.tta.adapt_topk 51 | topk_idx = probs.argsort(descending=True)[:, :K] 52 | 53 | if classes is not None: 54 | classes = torch.tensor(classes).to(logits.device) 55 | max_class_idx = classes[max_idx.flatten()].view(max_idx.shape) 56 | topk_class_idx = classes[topk_idx.flatten()].view(topk_idx.shape) 57 | else: 58 | max_class_idx, topk_class_idx = max_idx, topk_idx 59 | 60 | return image_features, logits, topk_idx, max_class_idx, topk_class_idx 61 | 62 | def _unet_pred_noise(self, x_start, t, noise, context): 63 | 64 | _,c,h,w = x_start.shape 65 | device = t.device 66 | nt = t.shape[0] 67 | 68 | x_start = x_start.unsqueeze(1) 69 | x_start = x_start.expand(-1, nt//x_start.shape[0], -1, -1, -1) 70 | x_start = x_start.reshape(-1,c,h,w) 71 | 72 | alphas_cumprod = self.scheduler.alphas_cumprod.to(device) 73 | noised_latent = ( 74 | x_start * 
(alphas_cumprod[t]**0.5).view(-1, 1, 1, 1).to(device) 75 | + noise * ((1 - alphas_cumprod[t])**0.5).view(-1, 1, 1, 1).to(device) 76 | ) 77 | pred_noise = self.unet(noised_latent, t, encoder_hidden_states=context.expand(nt, -1, -1)).sample 78 | 79 | return pred_noise 80 | 81 | def zeroshot_classifier(self, classnames, templates, model): 82 | with torch.no_grad(): 83 | zeroshot_weights = [] 84 | for classname in classnames: 85 | texts = [template.format(classname) for template in templates] 86 | texts = clip.tokenize(texts, truncate=True).cuda() 87 | class_embeddings = model.encode_text(texts) 88 | class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) 89 | class_embedding = class_embeddings.mean(dim=0) 90 | class_embedding /= class_embedding.norm() 91 | zeroshot_weights.append(class_embedding) 92 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda() 93 | return zeroshot_weights 94 | 95 | def forward( 96 | self, 97 | image: torch.Tensor = None, 98 | text = None 99 | ) -> SDOutput: 100 | 101 | text = self.pattern_dictionary[self.config.visual_pattern] 102 | with torch.no_grad(): 103 | imagenet_templates = ['{}',] 104 | zeroshot_weights = self.zeroshot_classifier(text, imagenet_templates, self.class_model.model.float()).float() 105 | 106 | self.class_model.final_fc.weight.data = zeroshot_weights.T 107 | self.class_model.final_fc.weight.data = self.class_model.final_fc.weight.data.contiguous() 108 | classes = [i for i in range(len(text))] 109 | 110 | discrimi_image = self.resize_transform_discrimi(image) 111 | genera_image = image 112 | real_BS = image.shape[0] 113 | after_DF_expand_BS = real_BS*self.config.input.batch_size 114 | 115 | # prepare_vae_latent 116 | self.vae, self.text_encoder, self.unet = self.vae.to(torch.float32), self.text_encoder.to(torch.float32), self.unet.to(torch.float32) 117 | renormed_image = self.image_renormalizer(genera_image).detach() 118 | x0 = self.vae.encode(renormed_image).latent_dist.mean.float() 119 | latent = x0 * 0.18215 120 | 121 | # prepare_total_timesteps 122 | total_timestep = self.scheduler.num_train_timesteps 123 | 124 | for step in range(self.config.tta.gradient_descent.train_steps): 125 | # Initiate timesteps and noise 126 | timesteps = initiate_time_steps(step, total_timestep, after_DF_expand_BS, self.config).long() 127 | timesteps = timesteps.cuda() 128 | 129 | c, h, w = latent.shape[1:] 130 | if not self.config.tta.use_same_noise_among_timesteps: 131 | noise = torch.randn((real_BS* self.config.input.batch_size, c, h, w)).cuda() 132 | else: 133 | noise = torch.randn((1, c, h, w)).cuda() 134 | noise = noise.repeat(real_BS* self.config.input.batch_size, 1, 1, 1) 135 | 136 | if self.config.tta.adapt_topk == -1: 137 | image_features, logits, _, _, _ = self.classify(discrimi_image, classes) 138 | pred_top_idx = None 139 | else: 140 | image_features, logits, pred_top_idx, _, _ = self.classify(discrimi_image, classes) 141 | real_BS, C = logits.shape[:2] 142 | 143 | # Pick top-K predictions 144 | if pred_top_idx is not None: 145 | pred_top_idx = pred_top_idx.squeeze(0) 146 | else: 147 | pred_top_idx = torch.arange(C).cuda() 148 | 149 | logits = logits[:, pred_top_idx] 150 | 151 | class_text_embeddings = prepare_class_text_embeddings(self.tokenizer, self.text_encoder, class_names=text) 152 | class_text_embeddings = class_text_embeddings.detach() 153 | class_text_embeddings = class_text_embeddings[pred_top_idx, :] 154 | 155 | # Compute conditional text embeddings using weighted-summed predictions 156 | probs = logits.softmax(-1) 157 | 
probs = probs[:, :, None, None] 158 | class_text_embeddings = (class_text_embeddings.unsqueeze(0).repeat(after_DF_expand_BS, 1, 1, 1)) 159 | _, word_num, _, _ = probs.shape 160 | probs = probs.unsqueeze(1).repeat(1,self.config.input.batch_size,1,1,1).reshape(-1,word_num,1,1) 161 | context = (probs * class_text_embeddings).sum(1) 162 | image_features = self.visual_proj(image_features) 163 | context = context + image_features 164 | 165 | # Predict noise with the diffusion model 166 | pred_noise = self._unet_pred_noise(x_start=latent, t=timesteps, noise=noise, context=context).float() 167 | 168 | # Compute diffusion loss 169 | if self.config.tta.loss == "l1": 170 | loss = torch.nn.functional.l1_loss(pred_noise, noise) 171 | else: 172 | loss = torch.nn.functional.mse_loss(pred_noise, noise) 173 | 174 | if step != (self.config.tta.gradient_descent.train_steps-1): 175 | loss.backward() 176 | 177 | return SDOutput(loss=loss) -------------------------------------------------------------------------------- /models/SD_with_SigLIP.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional, Tuple 3 | from dataclasses import dataclass 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from einops import rearrange 8 | import numpy as np 9 | import torchvision 10 | from transformers.utils import ModelOutput 11 | from transformers.modeling_utils import PreTrainedModel 12 | from .build import load_sd_model, load_clip_model_SigLIP 13 | from .utils import initiate_time_steps, prepare_class_text_embeddings 14 | import torchvision.transforms as transforms 15 | from open_clip import get_tokenizer 16 | 17 | @dataclass 18 | class SDOutput(ModelOutput): 19 | loss: Optional[torch.FloatTensor] = None 20 | 21 | class SDModel(PreTrainedModel): 22 | def __init__( 23 | self, 24 | config = None, 25 | ): 26 | super().__init__(config) 27 | 28 | self.model_id, self.pipe, self.vae, self.tokenizer, self.text_encoder, self.unet, self.scheduler, self.image_renormalizer = load_sd_model(config) 29 | self.text_encoder.eval() 30 | self.vae.eval() 31 | self.unet.eval() 32 | 33 | self.pattern_dictionary={'None':['']} 34 | self.config.actual_bs = len(self.pattern_dictionary[self.config.visual_pattern]) 35 | self.class_model = load_clip_model_SigLIP(config) 36 | self.class_model.eval() 37 | self.config = config 38 | discrimi_size = self.config.clip_image_size 39 | self.resize_transform_discrimi = transforms.Resize((discrimi_size, discrimi_size)) 40 | self.visual_proj = nn.Linear(1152, 1024) 41 | 42 | 43 | def classify(self, image, classes): 44 | 45 | image_features, logits = self.class_model(image) 46 | 47 | if classes is not None: 48 | logits = logits[:, classes] 49 | 50 | probs = logits.softmax(-1) 51 | max_idx = probs.argmax(-1) 52 | K = probs.shape[-1] if self.config.tta.adapt_topk == -1 else self.config.tta.adapt_topk 53 | topk_idx = probs.argsort(descending=True)[:, :K] 54 | 55 | if classes is not None: 56 | classes = torch.tensor(classes).to(logits.device) 57 | max_class_idx = classes[max_idx.flatten()].view(max_idx.shape) 58 | topk_class_idx = classes[topk_idx.flatten()].view(topk_idx.shape) 59 | else: 60 | max_class_idx, topk_class_idx = max_idx, topk_idx 61 | 62 | return image_features, logits, topk_idx, max_class_idx, topk_class_idx 63 | 64 | def _unet_pred_noise(self, x_start, t, noise, context): 65 | 66 | _,c,h,w = x_start.shape 67 | device = t.device 68 | nt = t.shape[0] 69 | 70 | x_start = 
x_start.unsqueeze(1) 71 | x_start = x_start.expand(-1, nt//x_start.shape[0], -1, -1, -1) 72 | x_start = x_start.reshape(-1,c,h,w) 73 | 74 | alphas_cumprod = self.scheduler.alphas_cumprod.to(device) 75 | noised_latent = ( 76 | x_start * (alphas_cumprod[t]**0.5).view(-1, 1, 1, 1).to(device) 77 | + noise * ((1 - alphas_cumprod[t])**0.5).view(-1, 1, 1, 1).to(device) 78 | ) 79 | pred_noise = self.unet(noised_latent, t, encoder_hidden_states=context.expand(nt, -1, -1)).sample 80 | 81 | return pred_noise 82 | 83 | def zeroshot_classifier_SigLIP(self, classnames, templates, model): 84 | with torch.no_grad(): 85 | zeroshot_weights = [] 86 | if self.config.clip_image_size == 224: 87 | tokenizer = get_tokenizer('ViT-SO400M-14-SigLIP') 88 | if self.config.clip_image_size == 384: 89 | tokenizer = get_tokenizer('ViT-SO400M-14-SigLIP-384') 90 | for classname in classnames: 91 | texts = [template.format(classname) for template in templates] 92 | texts = tokenizer(texts, context_length=model.context_length).cuda() 93 | class_embeddings = model.encode_text(texts) 94 | class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) 95 | class_embedding = class_embeddings.mean(dim=0) 96 | class_embedding /= class_embedding.norm() 97 | zeroshot_weights.append(class_embedding) 98 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda() 99 | return zeroshot_weights 100 | 101 | def forward( 102 | self, 103 | image: torch.Tensor = None, 104 | text = None 105 | ) -> SDOutput: 106 | 107 | text = self.pattern_dictionary[self.config.visual_pattern] 108 | with torch.no_grad(): 109 | imagenet_templates = ['{}',] 110 | zeroshot_weights = self.zeroshot_classifier_SigLIP(text, imagenet_templates, self.class_model.model.float()).float() 111 | 112 | self.class_model.final_fc.weight.data = zeroshot_weights.T 113 | self.class_model.final_fc.weight.data = self.class_model.final_fc.weight.data.contiguous() 114 | classes = [i for i in range(len(text))] 115 | 116 | discrimi_image = self.resize_transform_discrimi(image) 117 | genera_image = image 118 | real_BS = image.shape[0] 119 | after_DF_expand_BS = real_BS*self.config.input.batch_size 120 | 121 | # prepare_vae_latent 122 | self.vae, self.text_encoder, self.unet = self.vae.to(torch.float32), self.text_encoder.to(torch.float32), self.unet.to(torch.float32) 123 | renormed_image = self.image_renormalizer(genera_image).detach() 124 | x0 = self.vae.encode(renormed_image).latent_dist.mean.float() 125 | latent = x0 * 0.18215 126 | 127 | # prepare_total_timesteps 128 | total_timestep = self.scheduler.num_train_timesteps 129 | 130 | for step in range(self.config.tta.gradient_descent.train_steps): 131 | # Initiate timesteps and noise 132 | timesteps = initiate_time_steps(step, total_timestep, after_DF_expand_BS, self.config).long() 133 | timesteps = timesteps.cuda() 134 | 135 | c, h, w = latent.shape[1:] 136 | if not self.config.tta.use_same_noise_among_timesteps: 137 | noise = torch.randn((real_BS* self.config.input.batch_size, c, h, w)).cuda() 138 | else: 139 | noise = torch.randn((1, c, h, w)).cuda() 140 | noise = noise.repeat(real_BS* self.config.input.batch_size, 1, 1, 1) 141 | 142 | if self.config.tta.adapt_topk == -1: 143 | image_features, logits, _, _, _ = self.classify(discrimi_image, classes) 144 | pred_top_idx = None 145 | else: 146 | image_features, logits, pred_top_idx, _, _ = self.classify(discrimi_image, classes) 147 | real_BS, C = logits.shape[:2] 148 | 149 | # Pick top-K predictions 150 | if pred_top_idx is not None: 151 | pred_top_idx = pred_top_idx.squeeze(0) 
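# When config.tta.adapt_topk == -1, classify() already scored every class and pred_top_idx
# stays None, so the else branch below simply keeps all C class indices; otherwise only the
# top-K indices are used to subset both the logits and the class text embeddings.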
152 | else: 153 | pred_top_idx = torch.arange(C).cuda() 154 | 155 | logits = logits[:, pred_top_idx] 156 | 157 | class_text_embeddings = prepare_class_text_embeddings(self.tokenizer, self.text_encoder, class_names=text) 158 | class_text_embeddings = class_text_embeddings.detach() 159 | class_text_embeddings = class_text_embeddings[pred_top_idx, :] 160 | 161 | # Compute conditional text embeddings using weighted-summed predictions 162 | probs = logits.softmax(-1) 163 | probs = probs[:, :, None, None] 164 | class_text_embeddings = (class_text_embeddings.unsqueeze(0).repeat(after_DF_expand_BS, 1, 1, 1)) 165 | _, word_num, _, _ = probs.shape 166 | probs = probs.unsqueeze(1).repeat(1,self.config.input.batch_size,1,1,1).reshape(-1,word_num,1,1) 167 | context = (probs * class_text_embeddings).sum(1) 168 | image_features = self.visual_proj(image_features) 169 | context = context.mean(dim=1).unsqueeze(1) + image_features 170 | 171 | # Predict noise with the diffusion model 172 | pred_noise = self._unet_pred_noise(x_start=latent, t=timesteps, noise=noise, context=context).float() 173 | 174 | # Compute diffusion loss 175 | if self.config.tta.loss == "l1": 176 | loss = torch.nn.functional.l1_loss(pred_noise, noise) 177 | else: 178 | loss = torch.nn.functional.mse_loss(pred_noise, noise) 179 | 180 | if step != (self.config.tta.gradient_descent.train_steps-1): 181 | loss.backward() 182 | 183 | return SDOutput(loss=loss) -------------------------------------------------------------------------------- /models/build.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers import AutoencoderKL, UNet2DConditionModel, StableDiffusionPipeline, DDPMScheduler, EulerDiscreteScheduler 3 | from transformers import CLIPTextModel, CLIPTokenizer 4 | from .utils import VQVAEUnNormalize 5 | from .CLIP_bank import OpenAICLIP, DFN, SigLIP, MetaCLIP 6 | 7 | def load_sd_model(config): 8 | """Load Stable Diffusion model""" 9 | dtype = torch.float32 10 | image_renormalizer = VQVAEUnNormalize( 11 | mean=config.input.mean, std=config.input.std 12 | ) 13 | if config.model.sd_version == '1-4': 14 | if config.model.use_flash: 15 | model_id = "CompVis/stable-diffusion-v1-4" 16 | scheduler = EulerDiscreteScheduler.from_pretrained( 17 | model_id, subfolder="scheduler" 18 | ) 19 | pipe = StableDiffusionPipeline.from_pretrained( 20 | model_id, scheduler=scheduler, torch_dtype=dtype 21 | ).cuda() 22 | pipe.enable_xformers_memory_efficient_attention() 23 | vae = pipe.vae.cuda() 24 | tokenizer = pipe.tokenizer 25 | text_encoder = pipe.text_encoder.cuda() 26 | unet = pipe.unet.cuda() 27 | else: 28 | vae = AutoencoderKL.from_pretrained( 29 | f"CompVis/stable-diffusion-v{config.model.sd_version}", 30 | subfolder="vae", torch_dtype=dtype 31 | ).cuda() 32 | tokenizer = CLIPTokenizer.from_pretrained( 33 | "/share/project/wangwenxuan/projects/Overcome_VS/MMVP_Test/openai/clip-vit-large-patch14" 34 | ) 35 | text_encoder = CLIPTextModel.from_pretrained( 36 | "/share/project/wangwenxuan/projects/Overcome_VS/MMVP_Test/openai/clip-vit-large-patch14", torch_dtype=dtype 37 | ).cuda() 38 | unet = UNet2DConditionModel.from_pretrained( 39 | f"CompVis/stable-diffusion-v{config.model.sd_version}", 40 | subfolder="unet", torch_dtype=dtype 41 | ).cuda() 42 | scheduler_config = get_scheduler_config(config) 43 | scheduler = DDPMScheduler( 44 | num_train_timesteps=scheduler_config['num_train_timesteps'], 45 | beta_start=scheduler_config['beta_start'], 46 | beta_end=scheduler_config['beta_end'], 47 | 
beta_schedule=scheduler_config['beta_schedule'] 48 | ) 49 | elif config.model.sd_version == '2-1': 50 | 51 | model_id = "pretrained_weights/SD/stable-diffusion-2-1-base" 52 | print(f'model_id:{model_id}') 53 | 54 | scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler") 55 | pipe = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, torch_dtype=dtype) 56 | pipe.to(dtype) 57 | 58 | pipe.enable_xformers_memory_efficient_attention() 59 | vae = pipe.vae.cuda() 60 | tokenizer = pipe.tokenizer 61 | text_encoder = pipe.text_encoder.cuda() 62 | unet = pipe.unet.cuda() 63 | 64 | if config.model.adapt_only_classifier: 65 | for m in [vae, text_encoder, unet]: 66 | for param in m.parameters(): 67 | param.requires_grad = False 68 | for m in [vae, text_encoder]: 69 | for param in m.parameters(): 70 | param.requires_grad = False 71 | 72 | return (model_id, pipe, vae, tokenizer, text_encoder, unet, scheduler, image_renormalizer) 73 | 74 | 75 | def get_scheduler_config(config): 76 | assert config.model.sd_version in {'1-4', '2-1'} 77 | if config.model.sd_version == '1-4': 78 | schedule_config = { 79 | "_class_name": "PNDMScheduler", 80 | "_diffusers_version": "0.7.0.dev0", 81 | "beta_end": 0.012, 82 | "beta_schedule": "scaled_linear", 83 | "beta_start": 0.00085, 84 | "num_train_timesteps": 1000, 85 | "set_alpha_to_one": False, 86 | "skip_prk_steps": True, 87 | "steps_offset": 1, 88 | "trained_betas": None, 89 | "clip_sample": False 90 | } 91 | elif config.model.sd_version == '2-1': 92 | schedule_config = { 93 | "_class_name": "EulerDiscreteScheduler", 94 | "_diffusers_version": "0.10.2", 95 | "beta_end": 0.012, 96 | "beta_schedule": "scaled_linear", 97 | "beta_start": 0.00085, 98 | "clip_sample": False, 99 | "num_train_timesteps": 1000, 100 | "prediction_type": "epsilon", 101 | "set_alpha_to_one": False, 102 | "skip_prk_steps": True, 103 | "steps_offset": 1, # todo 104 | "trained_betas": None 105 | } 106 | else: 107 | raise NotImplementedError 108 | 109 | return schedule_config 110 | 111 | 112 | def load_clip_model_OpenAICLIP(config): 113 | 114 | class_model = OpenAICLIP(config) 115 | class_model.to(torch.float32) 116 | 117 | return class_model 118 | 119 | 120 | def load_clip_model_DFN(config): 121 | 122 | class_model = DFN(config) 123 | class_model.to(torch.float32) 124 | 125 | return class_model 126 | 127 | 128 | def load_clip_model_SigLIP(config): 129 | 130 | class_model = SigLIP(config) 131 | class_model.to(torch.float32) 132 | 133 | return class_model 134 | 135 | 136 | def load_clip_model_MetaCLIP(config): 137 | 138 | class_model = MetaCLIP(config) 139 | class_model.to(torch.float32) 140 | 141 | return class_model -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions""" 2 | import importlib 3 | import random 4 | 5 | import torch 6 | import numpy as np 7 | from PIL import Image 8 | from omegaconf import OmegaConf, open_dict 9 | 10 | 11 | class UnNormalize(object): 12 | """Unformalize image as: image = (image * std) + mean 13 | """ 14 | def __init__(self, mean, std): 15 | self.mean = torch.tensor(mean) 16 | self.std = torch.tensor(std) 17 | 18 | def __call__(self, tensor): 19 | """ 20 | Args: 21 | tensor: A tensor of shape [C, H, W] or [N, C, H, W] 22 | 23 | Returns: 24 | tensor: A tensor of shape [C, H, W] or [N, C, H, W] 25 | """ 26 | 27 | std = self.std.to(tensor.device) 28 | mean = 
self.mean.to(tensor.device) 29 | if tensor.ndim == 3: 30 | std, mean = std.view(-1, 1, 1), mean.view(-1, 1, 1) 31 | elif tensor.ndim == 4: 32 | std, mean = std.view(1, -1, 1, 1), mean.view(1, -1, 1, 1) 33 | tensor = (tensor * std) + mean 34 | return tensor 35 | 36 | 37 | class VQVAEUnNormalize(UnNormalize): 38 | """Unformalize image as: 39 | First: image = (image * std) + mean 40 | Second: image = (image * 2) - 1 41 | """ 42 | def __call__(self, tensor): 43 | """ 44 | Args: 45 | tensor (Tensor): Tensor image of size (C, H, W) or (N, C, H, W) 46 | to be unnormalized. 47 | Returns: 48 | Tensor: UnNormalized image. 49 | """ 50 | tensor = super().__call__(tensor) 51 | tensor = 2 * tensor - 1 52 | return tensor 53 | 54 | 55 | def mean_list(l): 56 | l = [int(_l) for _l in l] 57 | return float(sum(l)) / len(l) 58 | 59 | 60 | def segment_mean(x, index): 61 | """Function as tf.segment_mean. 62 | """ 63 | x = x.view(-1, x.shape[-1]) 64 | index = index.view(-1) 65 | 66 | max_index = index.max() + 1 67 | sum_x = torch.zeros((max_index, x.shape[-1]), 68 | dtype=x.dtype, 69 | device=x.device) 70 | num_index = torch.zeros((max_index,), 71 | dtype=x.dtype, 72 | device=x.device) 73 | 74 | num_index = num_index.scatter_add_( 75 | 0, index, torch.ones_like(index, dtype=x.dtype)) 76 | num_index = torch.where(torch.eq(num_index, 0), 77 | torch.ones_like(num_index, dtype=x.dtype), 78 | num_index) 79 | 80 | index_2d = index.view(-1, 1).expand(-1, x.shape[-1]) 81 | sum_x = sum_x.scatter_add_(0, index_2d, x) 82 | mean_x = sum_x.div_(num_index.view(-1, 1)) 83 | 84 | return mean_x 85 | 86 | 87 | def get_class_sd_features(tokenizer, text_encoder, input, device=None): 88 | """Prepare class text embeddings for Stable Diffusion 89 | 90 | Args: 91 | tokenizer: A nn.Module object of tokenizer. 92 | text_encoder: A nn.Module object of text encoder. 93 | input: A string 94 | device: GPU/CPU device 95 | """ 96 | with torch.no_grad(): 97 | input_val = f'a photo of a {input}.', 98 | # Tokenize the text 99 | text_input = tokenizer(input_val, padding="max_length", 100 | max_length=tokenizer.model_max_length, 101 | truncation=True, 102 | return_tensors="pt") 103 | # Get the text embeddings 104 | text_embeddings = text_encoder(text_input.input_ids.cuda())[0] 105 | 106 | return text_embeddings 107 | 108 | 109 | def prepare_class_text_embeddings(tokenizer=None, text_encoder=None, class_names=None): 110 | 111 | text_features = [] 112 | for class_name in class_names: 113 | text_features.append( 114 | get_class_sd_features(tokenizer, text_encoder, class_name) 115 | ) 116 | text_features = torch.cat(text_features, dim=0) 117 | 118 | return text_features 119 | 120 | 121 | def initiate_time_steps(step, total_timestep, batch_size, config): 122 | """A helper function to initiate time steps for the diffusion model. 
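Depending on the config, the timesteps are either evenly spaced over the schedule with a
random starting offset (tta.rand_timestep_equal_int), drawn uniformly at random for each
iteration (tta.random_timestep_per_iteration, the default), or fixed to the given constant step.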
123 | 124 | Args: 125 | step: An integer of the constant step 126 | total_timestep: An integer of the total timesteps of the diffusion model 127 | batch_size: An integer of the batch size 128 | config: A config object 129 | 130 | Returns: 131 | timesteps: A tensor of shape [batch_size,] of the time steps 132 | """ 133 | if config.tta.rand_timestep_equal_int: 134 | interval_val = total_timestep // batch_size 135 | start_point = random.randint(0, interval_val - 1) 136 | timesteps = torch.tensor( 137 | list(range(start_point, total_timestep, interval_val)) 138 | ).long() 139 | return timesteps 140 | elif config.tta.random_timestep_per_iteration: 141 | return torch.randint(0, total_timestep, (batch_size,)).long() #default 142 | else: 143 | return torch.tensor([step] * batch_size).long() 144 | 145 | 146 | def instantiate_from_config(config): 147 | """A helper function to instantiate a class from a config object. 148 | See https://github.com/CompVis/stable-diffusion/blob/main/ldm/util.py 149 | """ 150 | if not "target" in config: 151 | if config == '__is_first_stage__': 152 | return None 153 | elif config == "__is_unconditional__": 154 | return None 155 | raise KeyError("Expected key `target` to instantiate.") 156 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 157 | 158 | 159 | def get_obj_from_str(string, reload=False): 160 | """A helper function to instantiate a class from a config object. 161 | See https://github.com/CompVis/stable-diffusion/blob/main/ldm/util.py 162 | """ 163 | module, cls = string.rsplit(".", 1) 164 | if reload: 165 | module_imp = importlib.import_module(module) 166 | importlib.reload(module_imp) 167 | return getattr(importlib.import_module(module, package=None), cls) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | accelerate==0.28.0 3 | aiohttp==3.9.3 4 | aiosignal==1.3.1 5 | annotated-types==0.6.0 6 | antlr4-python3-runtime==4.9.3 7 | anyio==4.3.0 8 | appdirs==1.4.4 9 | asttokens==2.4.1 10 | async-timeout==4.0.3 11 | attrs==23.2.0 12 | beartype==0.17.2 13 | blessed==1.20.0 14 | bypy==1.8.4 15 | certifi==2024.2.2 16 | charset-normalizer==3.3.2 17 | click==8.1.7 18 | clip @ git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33 19 | cmake==3.28.3 20 | contourpy==1.2.0 21 | cycler==0.12.1 22 | datasets==2.17.1 23 | decorator==5.1.1 24 | decord==0.6.0 25 | deepspeed==0.14.0 26 | diffusers==0.25.0 27 | dill==0.3.8 28 | distro==1.9.0 29 | docker-pycreds==0.4.0 30 | einops==0.6.0 31 | ema-pytorch==0.4.2 32 | exceptiongroup==1.2.0 33 | executing==2.0.1 34 | filelock==3.13.1 35 | flash-attn==2.5.6 36 | fonttools==4.47.2 37 | frozenlist==1.4.1 38 | fsspec==2023.10.0 39 | ftfy==6.1.3 40 | gitdb==4.0.11 41 | GitPython==3.1.42 42 | gpustat==1.1.1 43 | grpcio==1.60.1 44 | h11==0.14.0 45 | hjson==3.1.0 46 | httpcore==1.0.4 47 | httpx==0.27.0 48 | huggingface-hub==0.20.3 49 | hydra-core==1.3.2 50 | idna==3.6 51 | importlib-metadata==7.0.1 52 | importlib_resources==6.1.2 53 | ipdb==0.13.13 54 | ipython==8.18.1 55 | jedi==0.19.1 56 | Jinja2==3.1.3 57 | joblib==1.3.2 58 | kiwisolver==1.4.5 59 | lightning-utilities==0.11.2 60 | lit==17.0.6 61 | Markdown==3.5.2 62 | MarkupSafe==2.1.5 63 | matplotlib==3.8.2 64 | matplotlib-inline==0.1.6 65 | mergedeep==1.3.4 66 | mpmath==1.3.0 67 | multidict==6.0.5 68 | multiprocess==0.70.16 69 | mypy-extensions==1.0.0 70 | networkx==3.2.1 71 | 
ninja==1.11.1.1 72 | numpy==1.26.4 73 | nvidia-cublas-cu11==11.10.3.66 74 | nvidia-cuda-cupti-cu11==11.7.101 75 | nvidia-cuda-nvrtc-cu11==11.7.99 76 | nvidia-cuda-runtime-cu11==11.7.99 77 | nvidia-cudnn-cu11==8.5.0.96 78 | nvidia-cufft-cu11==10.9.0.58 79 | nvidia-curand-cu11==10.2.10.91 80 | nvidia-cusolver-cu11==11.4.0.1 81 | nvidia-cusparse-cu11==11.7.4.91 82 | nvidia-ml-py==12.535.133 83 | nvidia-nccl-cu11==2.14.3 84 | nvidia-nvtx-cu11==11.7.91 85 | omegaconf==2.2.3 86 | open-clip-torch==2.24.0 87 | openai==0.28.1 88 | opencv-python==4.9.0.80 89 | packaging==23.2 90 | pandas==2.2.0 91 | parso==0.8.3 92 | pexpect==4.9.0 93 | pillow==10.2.0 94 | prompt-toolkit==3.0.43 95 | protobuf==4.25.3 96 | psutil==5.9.8 97 | ptyprocess==0.7.0 98 | pure-eval==0.2.2 99 | py-cpuinfo==9.0.0 100 | pyarrow==12.0.0 101 | pyarrow-hotfix==0.6 102 | pydantic==2.6.3 103 | pydantic_core==2.16.3 104 | Pygments==2.17.2 105 | pynvml==11.5.0 106 | pyparsing==3.1.1 107 | pyre-extensions==0.0.29 108 | python-dateutil==2.8.2 109 | pytz==2024.1 110 | PyYAML==6.0.1 111 | regex==2023.12.25 112 | requests==2.31.0 113 | requests-toolbelt==1.0.0 114 | safetensors==0.4.2 115 | scikit-learn==1.4.1.post1 116 | scipy==1.12.0 117 | sentencepiece==0.2.0 118 | sentry-sdk==1.40.5 119 | setproctitle==1.3.3 120 | six==1.16.0 121 | smmap==5.0.1 122 | sniffio==1.3.1 123 | stack-data==0.6.3 124 | sympy==1.12 125 | tensorboard==2.16.2 126 | tensorboard-data-server==0.7.2 127 | threadpoolctl==3.3.0 128 | timm==0.9.8 129 | tokenizers==0.15.2 130 | tomli==2.0.1 131 | torch==2.0.0 132 | torchaudio==2.0.1 133 | torchmetrics==1.3.2 134 | torchvision==0.15.1 135 | tqdm==4.66.2 136 | traitlets==5.14.1 137 | transformers==4.39.3 138 | triton==2.0.0 139 | typing-inspect==0.9.0 140 | typing_extensions==4.9.0 141 | tzdata==2023.4 142 | urllib3==2.2.1 143 | wandb==0.16.3 144 | wcwidth==0.2.13 145 | Werkzeug==3.0.1 146 | xformers==0.0.19 147 | xxhash==3.4.1 148 | yarl==1.9.4 149 | zipp==3.17.0 150 | -------------------------------------------------------------------------------- /run_DIVA_with_DFN.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | import random 5 | import datetime 6 | import builtins 7 | sys.path.append(os.getcwd()) 8 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) 9 | import numpy as np 10 | import torch 11 | import transformers 12 | from transformers.trainer_utils import set_seed 13 | from transformers import HfArgumentParser 14 | from transformers.trainer_utils import get_last_checkpoint 15 | from transformers.utils.versions import require_version 16 | from trainer import CustomTrainer 17 | from arguments import DataTrainingArguments, ModelArguments, TrainingArguments 18 | import csv 19 | from tqdm import tqdm 20 | from PIL import Image 21 | import json 22 | from open_clip import create_model_from_pretrained, get_tokenizer 23 | logger = logging.getLogger(__name__) 24 | import warnings 25 | warnings.filterwarnings("ignore") 26 | 27 | 28 | def random_seed(seed=42, rank=0): 29 | set_seed(seed) 30 | torch.manual_seed(seed + rank) 31 | np.random.seed(seed + rank) 32 | random.seed(seed + rank) 33 | try: 34 | import deepspeed 35 | deepspeed.runtime.utils.set_random_seed(seed + rank) 36 | except: 37 | print("deepspeed.runtime.utils.set_random_seed is not available") 38 | 39 | 40 | def setup_for_distributed(is_master): 41 | """ 42 | This function disables printing when not in master process 43 | """ 44 | 
builtin_print = builtins.print 45 | 46 | def print(*args, **kwargs): 47 | force = kwargs.pop('force', False) 48 | if is_master or force: 49 | now = datetime.datetime.now().time() 50 | builtin_print('[{}] '.format(now), end='') 51 | builtin_print(*args, **kwargs) 52 | 53 | builtins.print = print 54 | 55 | 56 | def setup_wandb_env(wandb_api_key=None): 57 | os.environ["WANDB_API_KEY"] = wandb_api_key or '' 58 | os.environ["WANDB_MODE"] = "offline" 59 | os.environ["WANDB__SERVICE_WAIT"] = "300" 60 | os.environ["WANDB_CONFIG_DIR"] = "./wandb" 61 | 62 | 63 | def main(): 64 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 65 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 66 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 67 | else: 68 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 69 | 70 | training_args.ddp_find_unused_parameters = True 71 | training_args.multiple_optimizer_training = False 72 | training_args.one_minus_one_data_transform = data_args.one_minus_one_data_transform 73 | training_args.cost_gradient_penalty = model_args.cost_gradient_penalty 74 | setup_wandb_env(training_args.wandb_api_key) 75 | 76 | # Setup logging 77 | logging.basicConfig( 78 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 79 | datefmt="%m/%d/%Y %H:%M:%S", 80 | handlers=[logging.StreamHandler(sys.stdout)], 81 | ) 82 | 83 | if training_args.should_log: 84 | transformers.utils.logging.set_verbosity_info() 85 | 86 | log_level = training_args.get_process_log_level() 87 | logger.setLevel(log_level) 88 | transformers.utils.logging.set_verbosity(log_level) 89 | transformers.utils.logging.enable_default_handler() 90 | transformers.utils.logging.enable_explicit_format() 91 | 92 | # Log on each process the small summary: 93 | logger.warning( 94 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, " 95 | + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16 or training_args.bf16}" 96 | ) 97 | logger.info(f"Training/evaluation parameters {training_args}") 98 | logger.info(f"Model parameters {model_args}") 99 | logger.info(f"Data parameters {data_args}") 100 | 101 | # Detecting last checkpoint. 102 | last_checkpoint = None 103 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 104 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 105 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 106 | raise ValueError( 107 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 108 | "Use --overwrite_output_dir to overcome." 109 | ) 110 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 111 | logger.info( 112 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 113 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 
114 | ) 115 | 116 | # data_args.data_seed 117 | random_seed(training_args.seed) 118 | data_args.seed = training_args.seed 119 | training_args.model_type = "image" 120 | 121 | from models.SD_with_DFN import SDModel 122 | from config import SDConfig 123 | 124 | config = SDConfig() 125 | config.tta.gradient_descent.train_steps = training_args.train_steps 126 | config.visual_pattern = training_args.visual_pattern 127 | config.clip_image_size = training_args.clip_image_size 128 | model = SDModel(config) 129 | 130 | # print model parameters 131 | logger.info(f"{str(model)}") 132 | model.cuda() 133 | 134 | from data import get_cc3m_wds_dataset_and_collator 135 | wds_dataset, wds_collator = get_cc3m_wds_dataset_and_collator(data_args, model_args) 136 | 137 | if config.model.freeze_class_embeds: 138 | params = [] 139 | for key,parm in model.named_parameters(): 140 | if 'final_fc' not in key: 141 | params.append(parm) 142 | 143 | optimizer = torch.optim.SGD( 144 | params, lr=training_args.learning_rate, 145 | weight_decay=training_args.weight_decay, 146 | momentum=config.tta.gradient_descent.optimizer_momentum 147 | ) 148 | scheduler = None 149 | 150 | trainer = CustomTrainer( 151 | model=model, 152 | args=training_args, 153 | train_dataset=wds_dataset, 154 | data_collator=wds_collator, 155 | optimizers=(optimizer, scheduler) 156 | ) 157 | setup_for_distributed(torch.distributed.get_rank() == 0) 158 | 159 | from callbacks import ModelCallback 160 | trainer.add_callback(ModelCallback) 161 | 162 | # Evaluation 163 | if training_args.local_rank == 0: 164 | print("CLIP's Performance on MMVP-VLM —— Before Generative Fine-tuning") 165 | results_before = official_evaluation(model.class_model.model, config) 166 | print(results_before) 167 | 168 | # Training 169 | if training_args.do_train: 170 | checkpoint = None 171 | if training_args.resume_from_checkpoint is not None: 172 | checkpoint = training_args.resume_from_checkpoint 173 | elif last_checkpoint is not None: 174 | checkpoint = last_checkpoint 175 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 176 | trainer.save_model(output_dir=training_args.output_dir) 177 | trainer.log_metrics("train", train_result.metrics) 178 | trainer.save_metrics("train", train_result.metrics) 179 | trainer.save_state() 180 | 181 | # Evaluation 182 | if training_args.local_rank == 0: 183 | print("CLIP's Performance on MMVP-VLM —— After Generative Fine-tuning") 184 | model_weight_save_path = os.path.join(training_args.output_dir, 'CLIP_after_GenFT.pth') 185 | torch.save(trainer.model.state_dict(), model_weight_save_path) 186 | results_final_after = official_evaluation(trainer.model.class_model.model, config) 187 | print(results_final_after) 188 | save_results(results_before, results_final_after, output_dir=training_args.output_dir) 189 | 190 | 191 | def benchmark_model(model, benchmark_dir, device = "cpu", config=None): 192 | if config.clip_image_size == 224: 193 | _, preprocess = create_model_from_pretrained(model_name='ViT-H-14-quickgelu', pretrained="pretrained_weights/CLIP/DFN5B-CLIP-ViT-H-14/open_clip_pytorch_model.bin", device=device) 194 | if config.clip_image_size == 378: 195 | _, preprocess = create_model_from_pretrained(model_name='ViT-H-14-378-quickgelu', pretrained="pretrained_weights/CLIP/DFN5B-CLIP-ViT-H-14-378/open_clip_pytorch_model.bin", device=device) 196 | 197 | tokenizer = get_tokenizer('ViT-H-14') 198 | 199 | image_dir = os.path.join(benchmark_dir, 'MLLM_VLM_Images') 200 | csv_file = os.path.join(benchmark_dir, 'Questions.csv') 201 | 202 | 
csv_outfile = open('Prediction_Results_DFN.csv', 'w', newline='') 203 | csv_writer = csv.writer(csv_outfile) 204 | csv_writer.writerow(['qid1', 'qid2', 'pred1', 'pred2', 'gt1', 'gt2', 'q1score', 'q2score']) # header 205 | 206 | categories = [ 207 | 'Orientation and Direction', 'Presence of Specific Features', 208 | 'State and Condition', 'Quantity and Count', 209 | 'Positional and Relational Context', 'Color and Appearance', 210 | 'Structural Characteristics', 'Texts', 211 | 'Viewpoint and Perspective' 212 | ] 213 | 214 | pair_accuracies = {category: 0 for category in categories} 215 | num_pairs = 0 216 | 217 | with open(csv_file, 'r') as f: 218 | reader = csv.reader(f) 219 | next(reader) # skip header 220 | for i, row in tqdm(enumerate(reader)): 221 | qid1, qtype1, statement1 = row 222 | 223 | # Get next row for the pair 224 | row = next(reader, None) 225 | if not row: 226 | break 227 | qid2, qtype2, statement2 = row 228 | 229 | qid1, qid2 = int(qid1), int(qid2) 230 | 231 | img1 = Image.open(os.path.join(image_dir, qtype1, f'{qid1}.jpg')) 232 | img2 = Image.open(os.path.join(image_dir, qtype1, f'{qid2}.jpg')) 233 | 234 | text1 = 'a photo of ' + statement1 235 | text2 = 'a photo of ' + statement2 236 | 237 | text1 = tokenizer(text1).to(device) 238 | text2 = tokenizer(text2).to(device) 239 | 240 | img1 = preprocess(img1).unsqueeze(0).to(device) 241 | img2 = preprocess(img2).unsqueeze(0).to(device) 242 | imgs = torch.cat((img1, img2), dim=0) 243 | 244 | with torch.no_grad(), torch.cuda.amp.autocast(): 245 | model.eval().float() 246 | 247 | # original code 248 | # image_features = model.encode_image(imgs) 249 | 250 | # ours 251 | if config.clip_image_size == 224: 252 | image_features = model.encode_image(imgs, normalize=True)[:,0,:] 253 | if config.clip_image_size == 378: 254 | global_image_features = model.encode_image(imgs, normalize=True)[:,0,:] 255 | local_image_features = model.encode_image(imgs, normalize=True)[:,1:,:].mean(dim=1) 256 | image_features = global_image_features + local_image_features 257 | 258 | text1_features = model.encode_text(text1, normalize=True) 259 | text2_features = model.encode_text(text2, normalize=True) 260 | logits_per_image1 = model.logit_scale.exp() * image_features @ text1_features.T 261 | logits_per_text1 = logits_per_image1.T 262 | logits_per_image2 = model.logit_scale.exp() * image_features @ text2_features.T 263 | logits_per_text2 = logits_per_image2.T 264 | probs1 = logits_per_text1.softmax(dim=-1).cpu().numpy() 265 | probs2 = logits_per_text2.softmax(dim=-1).cpu().numpy() 266 | 267 | 268 | img1_score1 = probs1[0][0] 269 | img1_score2 = probs2[0][0] 270 | 271 | pred1 = "img1" if img1_score1 > 0.5 else "img2" 272 | pred2 = "img1" if img1_score2 > 0.5 else "img2" 273 | 274 | gt1 = "img1" if qid1 % 2 == 1 else "img2" 275 | gt2 = "img1" if qid2 % 2 == 1 else "img2" 276 | 277 | csv_writer.writerow([qid1, qid2, pred1, pred2, gt1, gt2, img1_score1, img1_score2]) 278 | 279 | current_category = categories[num_pairs // 15] 280 | if pred1 == gt1 and pred2 == gt2: 281 | pair_accuracies[current_category] += 1 282 | num_pairs += 1 283 | 284 | csv_outfile.close() 285 | 286 | # Calculate percentage accuracies 287 | Category_Score_List = [] 288 | 289 | for category in pair_accuracies: 290 | pair_accuracies[category] = (pair_accuracies[category] / (num_pairs // len(categories))) * 100 291 | Category_Score_List.append(pair_accuracies[category]) 292 | 293 | pair_accuracies['average_score'] = sum(Category_Score_List)/len(Category_Score_List) 294 | 295 | return 
pair_accuracies 296 | 297 | def official_evaluation(clip_model, config): 298 | 299 | with torch.no_grad(): 300 | clip_model.eval() 301 | 302 | # models 303 | data = "dataset/MMVP_VLM" 304 | clip_model_device = next(clip_model.parameters()).device 305 | if config.clip_image_size == 224: 306 | results_openai = {f'DFN5B-CLIP-ViT-H-14': benchmark_model(clip_model, data, clip_model_device, config)} 307 | if config.clip_image_size == 378: 308 | results_openai = {f'DFN5B-CLIP-ViT-H-14-378': benchmark_model(clip_model, data, clip_model_device, config)} 309 | 310 | # Merge results 311 | results = {**results_openai} 312 | 313 | # Convert results to format suitable for star plot 314 | categories = results[list(results.keys())[0]].keys() 315 | data = {'Categories': list(categories)} 316 | for model in list(results_openai.keys()): 317 | data[model] = [results[model][category] for category in categories] 318 | 319 | return results 320 | 321 | def save_results(results_before, results_final_after, output_dir, filename='pred_result.json'): 322 | 323 | os.makedirs(output_dir, exist_ok=True) 324 | 325 | output_data = { 326 | 'results_before': results_before, 327 | 'results_final_after': results_final_after 328 | } 329 | 330 | output_path = os.path.join(output_dir, filename) 331 | 332 | with open(output_path, 'w') as f: 333 | json.dump(output_data, f, indent=4) 334 | 335 | 336 | if __name__ == "__main__": 337 | main() 338 | -------------------------------------------------------------------------------- /run_DIVA_with_MetaCLIP.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | import random 5 | import datetime 6 | import builtins 7 | sys.path.append(os.getcwd()) 8 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) 9 | import numpy as np 10 | import torch 11 | import torch.nn.functional as F 12 | import transformers 13 | from transformers.trainer_utils import set_seed 14 | from transformers import HfArgumentParser 15 | from transformers.trainer_utils import get_last_checkpoint 16 | from transformers.utils.versions import require_version 17 | from trainer import CustomTrainer 18 | from arguments import DataTrainingArguments, ModelArguments, TrainingArguments 19 | import csv 20 | from tqdm import tqdm 21 | from PIL import Image 22 | import json 23 | from open_clip import create_model_and_transforms, tokenize 24 | logger = logging.getLogger(__name__) 25 | import warnings 26 | warnings.filterwarnings("ignore") 27 | 28 | 29 | def random_seed(seed=42, rank=0): 30 | set_seed(seed) 31 | torch.manual_seed(seed + rank) 32 | np.random.seed(seed + rank) 33 | random.seed(seed + rank) 34 | try: 35 | import deepspeed 36 | deepspeed.runtime.utils.set_random_seed(seed + rank) 37 | except: 38 | print("deepspeed.runtime.utils.set_random_seed is not available") 39 | 40 | 41 | def setup_for_distributed(is_master): 42 | """ 43 | This function disables printing when not in master process 44 | """ 45 | builtin_print = builtins.print 46 | 47 | def print(*args, **kwargs): 48 | force = kwargs.pop('force', False) 49 | if is_master or force: 50 | now = datetime.datetime.now().time() 51 | builtin_print('[{}] '.format(now), end='') 52 | builtin_print(*args, **kwargs) 53 | 54 | builtins.print = print 55 | 56 | 57 | def setup_wandb_env(wandb_api_key=None): 58 | os.environ["WANDB_API_KEY"] = wandb_api_key or '' 59 | os.environ["WANDB_MODE"] = "offline" 60 | os.environ["WANDB__SERVICE_WAIT"] = "300" 61 | 
os.environ["WANDB_CONFIG_DIR"] = "./wandb" 62 | 63 | 64 | def main(): 65 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 66 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 67 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 68 | else: 69 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 70 | 71 | training_args.ddp_find_unused_parameters = True 72 | training_args.multiple_optimizer_training = False 73 | training_args.one_minus_one_data_transform = data_args.one_minus_one_data_transform 74 | training_args.cost_gradient_penalty = model_args.cost_gradient_penalty 75 | setup_wandb_env(training_args.wandb_api_key) 76 | 77 | # Setup logging 78 | logging.basicConfig( 79 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 80 | datefmt="%m/%d/%Y %H:%M:%S", 81 | handlers=[logging.StreamHandler(sys.stdout)], 82 | ) 83 | 84 | if training_args.should_log: 85 | transformers.utils.logging.set_verbosity_info() 86 | 87 | log_level = training_args.get_process_log_level() 88 | logger.setLevel(log_level) 89 | transformers.utils.logging.set_verbosity(log_level) 90 | transformers.utils.logging.enable_default_handler() 91 | transformers.utils.logging.enable_explicit_format() 92 | 93 | # Log on each process the small summary: 94 | logger.warning( 95 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, " 96 | + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16 or training_args.bf16}" 97 | ) 98 | logger.info(f"Training/evaluation parameters {training_args}") 99 | logger.info(f"Model parameters {model_args}") 100 | logger.info(f"Data parameters {data_args}") 101 | 102 | # Detecting last checkpoint. 103 | last_checkpoint = None 104 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 105 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 106 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 107 | raise ValueError( 108 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 109 | "Use --overwrite_output_dir to overcome." 110 | ) 111 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 112 | logger.info( 113 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 114 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 
115 | ) 116 | 117 | random_seed(training_args.seed) 118 | data_args.seed = training_args.seed 119 | training_args.model_type = "image" 120 | 121 | from models.SD_with_MetaCLIP import SDModel 122 | from config import SDConfig 123 | 124 | config = SDConfig() 125 | config.tta.gradient_descent.train_steps = training_args.train_steps 126 | config.visual_pattern = training_args.visual_pattern 127 | config.clip_image_size = training_args.clip_image_size 128 | config.metaclip_version = training_args.metaclip_version 129 | model = SDModel(config) 130 | 131 | # print model parameters 132 | logger.info(f"{str(model)}") 133 | model.cuda() 134 | 135 | from data import get_cc3m_wds_dataset_and_collator 136 | wds_dataset, wds_collator = get_cc3m_wds_dataset_and_collator(data_args, model_args) 137 | 138 | if config.model.freeze_class_embeds: 139 | params = [] 140 | for key,parm in model.named_parameters(): 141 | if 'final_fc' not in key: 142 | params.append(parm) 143 | 144 | optimizer = torch.optim.SGD( 145 | params, lr=training_args.learning_rate, 146 | weight_decay=training_args.weight_decay, 147 | momentum=config.tta.gradient_descent.optimizer_momentum 148 | ) 149 | scheduler = None 150 | 151 | trainer = CustomTrainer( 152 | model=model, 153 | args=training_args, 154 | train_dataset=wds_dataset, 155 | data_collator=wds_collator, 156 | optimizers=(optimizer, scheduler) 157 | ) 158 | 159 | setup_for_distributed(torch.distributed.get_rank() == 0) 160 | 161 | from callbacks import ModelCallback 162 | trainer.add_callback(ModelCallback) 163 | 164 | # Evaluation 165 | if training_args.local_rank == 0: 166 | print("CLIP's Performance on MMVP-VLM —— Before Generative Fine-tuning") 167 | results_before = official_evaluation(model.class_model.model, config) 168 | print(results_before) 169 | 170 | # Training 171 | if training_args.do_train: 172 | checkpoint = None 173 | if training_args.resume_from_checkpoint is not None: 174 | checkpoint = training_args.resume_from_checkpoint 175 | elif last_checkpoint is not None: 176 | checkpoint = last_checkpoint 177 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 178 | trainer.save_model(output_dir=training_args.output_dir) 179 | trainer.log_metrics("train", train_result.metrics) 180 | trainer.save_metrics("train", train_result.metrics) 181 | trainer.save_state() 182 | 183 | # Evaluation 184 | if training_args.local_rank == 0: 185 | print("CLIP's Performance on MMVP-VLM —— After Generative Fine-tuning") 186 | model_weight_save_path = os.path.join(training_args.output_dir, 'CLIP_after_GenFT.pth') 187 | torch.save(trainer.model.state_dict(), model_weight_save_path) 188 | results_final_after = official_evaluation(trainer.model.class_model.model, config) 189 | print(results_final_after) 190 | save_results(results_before, results_final_after, output_dir=training_args.output_dir) 191 | 192 | 193 | def benchmark_model(model, benchmark_dir, device = "cpu", config=None): 194 | 195 | if config.metaclip_version == "large": 196 | _, _, preprocess = create_model_and_transforms(model_name='ViT-L-14-quickgelu', pretrained="pretrained_weights/CLIP/MetaCLIP/l14_fullcc2.5b.pt") 197 | if config.metaclip_version == "huge": 198 | _, _, preprocess = create_model_and_transforms(model_name='ViT-H-14-quickgelu', pretrained="pretrained_weights/CLIP/MetaCLIPz/h14_fullcc2.5b.pt") 199 | 200 | image_dir = os.path.join(benchmark_dir, 'MLLM_VLM_Images') 201 | csv_file = os.path.join(benchmark_dir, 'Questions.csv') 202 | 203 | csv_outfile = open('Prediction_Results_MetaCLIP', 'w', 
newline='') 204 | csv_writer = csv.writer(csv_outfile) 205 | csv_writer.writerow(['qid1', 'qid2', 'pred1', 'pred2', 'gt1', 'gt2', 'q1score', 'q2score']) # header 206 | 207 | categories = [ 208 | 'Orientation and Direction', 'Presence of Specific Features', 209 | 'State and Condition', 'Quantity and Count', 210 | 'Positional and Relational Context', 'Color and Appearance', 211 | 'Structural Characteristics', 'Texts', 212 | 'Viewpoint and Perspective' 213 | ] 214 | 215 | pair_accuracies = {category: 0 for category in categories} 216 | num_pairs = 0 217 | 218 | with open(csv_file, 'r') as f: 219 | reader = csv.reader(f) 220 | next(reader) 221 | for i, row in tqdm(enumerate(reader)): 222 | qid1, qtype1, statement1 = row 223 | 224 | # Get next row for the pair 225 | row = next(reader, None) 226 | if not row: 227 | break 228 | qid2, qtype2, statement2 = row 229 | 230 | qid1, qid2 = int(qid1), int(qid2) 231 | 232 | img1 = Image.open(os.path.join(image_dir, qtype1, f'{qid1}.jpg')) 233 | img2 = Image.open(os.path.join(image_dir, qtype1, f'{qid2}.jpg')) 234 | 235 | text1 = 'a photo of ' + statement1 236 | text2 = 'a photo of ' + statement2 237 | 238 | text1 = tokenize(text1).to(device) 239 | text2 = tokenize(text2).to(device) 240 | 241 | img1 = preprocess(img1).unsqueeze(0).to(device) 242 | img2 = preprocess(img2).unsqueeze(0).to(device) 243 | imgs = torch.cat((img1, img2), dim=0) 244 | 245 | with torch.no_grad(): 246 | model.eval().float() 247 | 248 | # original code 249 | # image_features = model.encode_image(imgs) 250 | 251 | # ours 252 | global_image_features = model.encode_image(imgs, normalize=True)[:,0,:] 253 | local_image_features = model.encode_image(imgs, normalize=True)[:,1:,:].mean(dim=1) 254 | image_features = global_image_features + local_image_features 255 | 256 | text1_features = model.encode_text(text1) 257 | text2_features = model.encode_text(text2) 258 | image_features = F.normalize(image_features, dim=-1) 259 | text1_features = F.normalize(text1_features, dim=-1) 260 | text2_features = F.normalize(text2_features, dim=-1) 261 | 262 | logits_per_image1 = 100.0 * image_features @ text1_features.T 263 | logits_per_text1 = logits_per_image1.T 264 | logits_per_image2 = 100.0 * image_features @ text2_features.T 265 | logits_per_text2 = logits_per_image2.T 266 | 267 | probs1 = logits_per_text1.softmax(dim=-1).cpu().numpy() 268 | probs2 = logits_per_text2.softmax(dim=-1).cpu().numpy() 269 | 270 | img1_score1 = probs1[0][0] 271 | img1_score2 = probs2[0][0] 272 | 273 | pred1 = "img1" if img1_score1 > 0.5 else "img2" 274 | pred2 = "img1" if img1_score2 > 0.5 else "img2" 275 | 276 | gt1 = "img1" if qid1 % 2 == 1 else "img2" 277 | gt2 = "img1" if qid2 % 2 == 1 else "img2" 278 | 279 | csv_writer.writerow([qid1, qid2, pred1, pred2, gt1, gt2, img1_score1, img1_score2]) 280 | 281 | current_category = categories[num_pairs // 15] 282 | if pred1 == gt1 and pred2 == gt2: 283 | pair_accuracies[current_category] += 1 284 | num_pairs += 1 285 | 286 | csv_outfile.close() 287 | 288 | # Calculate percentage accuracies 289 | Category_Score_List = [] 290 | 291 | for category in pair_accuracies: 292 | pair_accuracies[category] = (pair_accuracies[category] / (num_pairs // len(categories))) * 100 293 | Category_Score_List.append(pair_accuracies[category]) 294 | 295 | pair_accuracies['average_score'] = sum(Category_Score_List)/len(Category_Score_List) 296 | 297 | return pair_accuracies 298 | 299 | def official_evaluation(clip_model, config): 300 | 301 | with torch.no_grad(): 302 | clip_model.eval() 303 | 304 | # 
models 305 | data = "dataset/MMVP_VLM" 306 | clip_model_device = next(clip_model.parameters()).device 307 | 308 | if config.metaclip_version == "large": 309 | results_openai = {f'MetaCLIP-ViT-L-14': benchmark_model(clip_model, data, clip_model_device, config)} 310 | if config.metaclip_version == "huge": 311 | results_openai = {f'MetaCLIP-ViT-H-14': benchmark_model(clip_model, data, clip_model_device, config)} 312 | 313 | # Merge results 314 | results = {**results_openai} 315 | 316 | # Convert results to format suitable for star plot 317 | categories = results[list(results.keys())[0]].keys() 318 | data = {'Categories': list(categories)} 319 | for model in list(results_openai.keys()): 320 | data[model] = [results[model][category] for category in categories] 321 | 322 | return results 323 | 324 | def save_results(results_before, results_final_after, output_dir, filename='pred_result.json'): 325 | 326 | os.makedirs(output_dir, exist_ok=True) 327 | 328 | output_data = { 329 | 'results_before': results_before, 330 | 'results_final_after': results_final_after 331 | } 332 | 333 | output_path = os.path.join(output_dir, filename) 334 | 335 | with open(output_path, 'w') as f: 336 | json.dump(output_data, f, indent=4) 337 | 338 | 339 | if __name__ == "__main__": 340 | main() 341 | -------------------------------------------------------------------------------- /run_DIVA_with_OpenAICLIP.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | import random 5 | import datetime 6 | import builtins 7 | sys.path.append(os.getcwd()) 8 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) 9 | import numpy as np 10 | import torch 11 | import transformers 12 | from transformers.trainer_utils import set_seed 13 | from transformers import HfArgumentParser 14 | from transformers.trainer_utils import get_last_checkpoint 15 | from transformers.utils.versions import require_version 16 | from trainer import CustomTrainer 17 | from arguments import DataTrainingArguments, ModelArguments, TrainingArguments 18 | from clip import load 19 | import clip 20 | import csv 21 | from tqdm import tqdm 22 | from PIL import Image 23 | import json 24 | logger = logging.getLogger(__name__) 25 | import warnings 26 | warnings.filterwarnings("ignore") 27 | 28 | 29 | def random_seed(seed=42, rank=0): 30 | set_seed(seed) 31 | torch.manual_seed(seed + rank) 32 | np.random.seed(seed + rank) 33 | random.seed(seed + rank) 34 | try: 35 | import deepspeed 36 | deepspeed.runtime.utils.set_random_seed(seed + rank) 37 | except: 38 | print("deepspeed.runtime.utils.set_random_seed is not available") 39 | 40 | 41 | def setup_for_distributed(is_master): 42 | """ 43 | This function disables printing when not in master process 44 | """ 45 | builtin_print = builtins.print 46 | 47 | def print(*args, **kwargs): 48 | force = kwargs.pop('force', False) 49 | if is_master or force: 50 | now = datetime.datetime.now().time() 51 | builtin_print('[{}] '.format(now), end='') 52 | builtin_print(*args, **kwargs) 53 | 54 | builtins.print = print 55 | 56 | 57 | def setup_wandb_env(wandb_api_key=None): 58 | os.environ["WANDB_API_KEY"] = wandb_api_key or '' 59 | os.environ["WANDB_MODE"] = "offline" 60 | os.environ["WANDB__SERVICE_WAIT"] = "300" 61 | os.environ["WANDB_CONFIG_DIR"] = "./wandb" 62 | 63 | 64 | def main(): 65 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 66 | if len(sys.argv) == 2 and 
sys.argv[1].endswith(".json"): 67 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 68 | else: 69 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 70 | 71 | training_args.ddp_find_unused_parameters = True 72 | training_args.multiple_optimizer_training = False 73 | training_args.one_minus_one_data_transform = data_args.one_minus_one_data_transform 74 | training_args.cost_gradient_penalty = model_args.cost_gradient_penalty 75 | setup_wandb_env(training_args.wandb_api_key) 76 | 77 | # Setup logging 78 | logging.basicConfig( 79 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 80 | datefmt="%m/%d/%Y %H:%M:%S", 81 | handlers=[logging.StreamHandler(sys.stdout)], 82 | ) 83 | 84 | if training_args.should_log: 85 | transformers.utils.logging.set_verbosity_info() 86 | 87 | log_level = training_args.get_process_log_level() 88 | logger.setLevel(log_level) 89 | transformers.utils.logging.set_verbosity(log_level) 90 | transformers.utils.logging.enable_default_handler() 91 | transformers.utils.logging.enable_explicit_format() 92 | 93 | # Log on each process the small summary: 94 | logger.warning( 95 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, " 96 | + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16 or training_args.bf16}" 97 | ) 98 | logger.info(f"Training/evaluation parameters {training_args}") 99 | logger.info(f"Model parameters {model_args}") 100 | logger.info(f"Data parameters {data_args}") 101 | 102 | # Detecting last checkpoint. 103 | last_checkpoint = None 104 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 105 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 106 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 107 | raise ValueError( 108 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 109 | "Use --overwrite_output_dir to overcome." 110 | ) 111 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 112 | logger.info( 113 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 114 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 
115 | ) 116 | 117 | # data_args.data_seed 118 | random_seed(training_args.seed) 119 | data_args.seed = training_args.seed 120 | training_args.model_type = "image" 121 | 122 | from models.SD_with_OpenAICLIP import SDModel 123 | from config import SDConfig 124 | 125 | config = SDConfig() 126 | config.tta.gradient_descent.train_steps = training_args.train_steps 127 | config.visual_pattern = training_args.visual_pattern 128 | config.clip_image_size = training_args.clip_image_size 129 | model = SDModel(config) 130 | 131 | # print model parameters 132 | logger.info(f"{str(model)}") 133 | model.cuda() 134 | 135 | from data import get_cc3m_wds_dataset_and_collator 136 | wds_dataset, wds_collator = get_cc3m_wds_dataset_and_collator(data_args, model_args) 137 | 138 | if config.model.freeze_class_embeds: 139 | params = [] 140 | for key,parm in model.named_parameters(): 141 | if 'final_fc' not in key: 142 | params.append(parm) 143 | 144 | optimizer = torch.optim.SGD( 145 | params, lr=training_args.learning_rate, 146 | weight_decay=training_args.weight_decay, 147 | momentum=config.tta.gradient_descent.optimizer_momentum 148 | ) 149 | scheduler = None 150 | 151 | trainer = CustomTrainer( 152 | model=model, 153 | args=training_args, 154 | train_dataset=wds_dataset, 155 | data_collator=wds_collator, 156 | optimizers=(optimizer, scheduler) 157 | ) 158 | setup_for_distributed(torch.distributed.get_rank() == 0) 159 | 160 | from callbacks import ModelCallback 161 | trainer.add_callback(ModelCallback) 162 | 163 | # Evaluation 164 | if training_args.local_rank == 0: 165 | print("CLIP's Performance on MMVP-VLM —— Before Generative Fine-tuning") 166 | results_before = official_evaluation(model.class_model.model, config) 167 | print(results_before) 168 | 169 | # Training 170 | if training_args.do_train: 171 | checkpoint = None 172 | if training_args.resume_from_checkpoint is not None: 173 | checkpoint = training_args.resume_from_checkpoint 174 | elif last_checkpoint is not None: 175 | checkpoint = last_checkpoint 176 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 177 | trainer.save_model(output_dir=training_args.output_dir) 178 | trainer.log_metrics("train", train_result.metrics) 179 | trainer.save_metrics("train", train_result.metrics) 180 | trainer.save_state() 181 | 182 | # Evaluation 183 | if training_args.local_rank == 0: 184 | print("CLIP's Performance on MMVP-VLM —— After Generative Fine-tuning") 185 | model_weight_save_path = os.path.join(training_args.output_dir, 'CLIP_after_GenFT.pth') 186 | torch.save(trainer.model.state_dict(), model_weight_save_path) 187 | results_final_after = official_evaluation(trainer.model.class_model.model, config) 188 | print(results_final_after) 189 | save_results(results_before, results_final_after, output_dir=training_args.output_dir) 190 | 191 | 192 | def benchmark_model(base_model, model, benchmark_dir, device = "cpu"): 193 | 194 | _, preprocess = load(base_model, device=device) 195 | image_dir = os.path.join(benchmark_dir, 'MLLM_VLM_Images') 196 | csv_file = os.path.join(benchmark_dir, 'Questions.csv') 197 | 198 | csv_outfile = open('Prediction_Results_OpenAICLIP', 'w', newline='') 199 | csv_writer = csv.writer(csv_outfile) 200 | csv_writer.writerow(['qid1', 'qid2', 'pred1', 'pred2', 'gt1', 'gt2', 'q1score', 'q2score']) # header 201 | 202 | categories = [ 203 | 'Orientation and Direction', 'Presence of Specific Features', 204 | 'State and Condition', 'Quantity and Count', 205 | 'Positional and Relational Context', 'Color and Appearance', 206 | 
'Structural Characteristics', 'Texts', 207 | 'Viewpoint and Perspective' 208 | ] 209 | 210 | pair_accuracies = {category: 0 for category in categories} 211 | num_pairs = 0 212 | 213 | with open(csv_file, 'r') as f: 214 | reader = csv.reader(f) 215 | next(reader) # skip header 216 | for i, row in tqdm(enumerate(reader)): 217 | qid1, qtype1, statement1 = row 218 | 219 | # Get next row for the pair 220 | row = next(reader, None) 221 | if not row: 222 | break 223 | qid2, qtype2, statement2 = row 224 | 225 | qid1, qid2 = int(qid1), int(qid2) 226 | 227 | img1 = Image.open(os.path.join(image_dir, qtype1, f'{qid1}.jpg')) 228 | img2 = Image.open(os.path.join(image_dir, qtype1, f'{qid2}.jpg')) 229 | 230 | text1 = 'a photo of ' + statement1 231 | text2 = 'a photo of ' + statement2 232 | 233 | text1 = clip.tokenize([text1]).to(device) 234 | text2 = clip.tokenize([text2]).to(device) 235 | 236 | img1 = preprocess(img1).unsqueeze(0).to(device) 237 | img2 = preprocess(img2).unsqueeze(0).to(device) 238 | imgs = torch.cat((img1, img2), dim=0) 239 | 240 | with torch.no_grad(): 241 | model.eval().float() 242 | logits_per_image1, logits_per_text1 = model(imgs, text1) 243 | logits_per_image2, logits_per_text2 = model(imgs, text2) 244 | 245 | probs1 = logits_per_text1.softmax(dim=-1).cpu().numpy() 246 | probs2 = logits_per_text2.softmax(dim=-1).cpu().numpy() 247 | 248 | img1_score1 = probs1[0][0] 249 | img1_score2 = probs2[0][0] 250 | 251 | pred1 = "img1" if img1_score1 > 0.5 else "img2" 252 | pred2 = "img1" if img1_score2 > 0.5 else "img2" 253 | 254 | gt1 = "img1" if qid1 % 2 == 1 else "img2" 255 | gt2 = "img1" if qid2 % 2 == 1 else "img2" 256 | 257 | csv_writer.writerow([qid1, qid2, pred1, pred2, gt1, gt2, img1_score1, img1_score2]) 258 | 259 | current_category = categories[num_pairs // 15] 260 | if pred1 == gt1 and pred2 == gt2: 261 | pair_accuracies[current_category] += 1 262 | num_pairs += 1 263 | 264 | csv_outfile.close() 265 | 266 | # Calculate percentage accuracies 267 | Category_Score_List = [] 268 | 269 | for category in pair_accuracies: 270 | pair_accuracies[category] = (pair_accuracies[category] / (num_pairs // len(categories))) * 100 271 | Category_Score_List.append(pair_accuracies[category]) 272 | 273 | pair_accuracies['average_score'] = sum(Category_Score_List)/len(Category_Score_List) 274 | 275 | return pair_accuracies 276 | 277 | def official_evaluation(clip_model, config): 278 | 279 | with torch.no_grad(): 280 | clip_model.eval() 281 | 282 | # models 283 | data = "dataset/MMVP_VLM" 284 | if config.clip_image_size == 224: 285 | base_model = "pretrained_weights/CLIP/ViT-L-14.pt" 286 | if config.clip_image_size == 336: 287 | base_model = "pretrained_weights/CLIP/ViT-L-14-336px.pt" 288 | clip_model_device = next(clip_model.parameters()).device 289 | clip_model_name = base_model.split('/')[-1].split('.')[0] 290 | results_openai = {f'openai-{clip_model_name}': benchmark_model(base_model, clip_model, data, clip_model_device)} 291 | 292 | # Merge results 293 | results = {**results_openai} 294 | 295 | # Convert results to format suitable for star plot 296 | categories = results[list(results.keys())[0]].keys() 297 | data = {'Categories': list(categories)} 298 | for model in list(results_openai.keys()): 299 | data[model] = [results[model][category] for category in categories] 300 | 301 | return results 302 | 303 | def save_results(results_before, results_final_after, output_dir, filename='pred_result.json'): 304 | 305 | os.makedirs(output_dir, exist_ok=True) 306 | 307 | output_data = { 308 | 
'results_before': results_before, 309 | 'results_final_after': results_final_after 310 | } 311 | 312 | output_path = os.path.join(output_dir, filename) 313 | 314 | with open(output_path, 'w') as f: 315 | json.dump(output_data, f, indent=4) 316 | 317 | 318 | if __name__ == "__main__": 319 | main() 320 | -------------------------------------------------------------------------------- /run_DIVA_with_SigLIP.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | import random 5 | import datetime 6 | import builtins 7 | sys.path.append(os.getcwd()) 8 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) 9 | import numpy as np 10 | import torch 11 | import torch.nn.functional as F 12 | import transformers 13 | from transformers.trainer_utils import set_seed 14 | from transformers import HfArgumentParser 15 | from transformers.trainer_utils import get_last_checkpoint 16 | from transformers.utils.versions import require_version 17 | from trainer import CustomTrainer 18 | from arguments import DataTrainingArguments, ModelArguments, TrainingArguments 19 | import csv 20 | from tqdm import tqdm 21 | from PIL import Image 22 | import json 23 | from open_clip import create_model_from_pretrained, get_tokenizer 24 | logger = logging.getLogger(__name__) 25 | import warnings 26 | warnings.filterwarnings("ignore") 27 | 28 | 29 | def random_seed(seed=42, rank=0): 30 | set_seed(seed) 31 | torch.manual_seed(seed + rank) 32 | np.random.seed(seed + rank) 33 | random.seed(seed + rank) 34 | try: 35 | import deepspeed 36 | deepspeed.runtime.utils.set_random_seed(seed + rank) 37 | except: 38 | print("deepspeed.runtime.utils.set_random_seed is not available") 39 | 40 | 41 | def setup_for_distributed(is_master): 42 | """ 43 | This function disables printing when not in master process 44 | """ 45 | builtin_print = builtins.print 46 | 47 | def print(*args, **kwargs): 48 | force = kwargs.pop('force', False) 49 | if is_master or force: 50 | now = datetime.datetime.now().time() 51 | builtin_print('[{}] '.format(now), end='') 52 | builtin_print(*args, **kwargs) 53 | 54 | builtins.print = print 55 | 56 | 57 | def setup_wandb_env(wandb_api_key=None): 58 | os.environ["WANDB_API_KEY"] = wandb_api_key or '' 59 | os.environ["WANDB_MODE"] = "offline" 60 | os.environ["WANDB__SERVICE_WAIT"] = "300" 61 | os.environ["WANDB_CONFIG_DIR"] = "./wandb" 62 | 63 | 64 | def main(): 65 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 66 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 67 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 68 | else: 69 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 70 | 71 | training_args.ddp_find_unused_parameters = True 72 | training_args.multiple_optimizer_training = False 73 | training_args.one_minus_one_data_transform = data_args.one_minus_one_data_transform 74 | training_args.cost_gradient_penalty = model_args.cost_gradient_penalty 75 | setup_wandb_env(training_args.wandb_api_key) 76 | 77 | # Setup logging 78 | logging.basicConfig( 79 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 80 | datefmt="%m/%d/%Y %H:%M:%S", 81 | handlers=[logging.StreamHandler(sys.stdout)], 82 | ) 83 | 84 | if training_args.should_log: 85 | transformers.utils.logging.set_verbosity_info() 86 | 87 | log_level = training_args.get_process_log_level() 88 | 
logger.setLevel(log_level) 89 | transformers.utils.logging.set_verbosity(log_level) 90 | transformers.utils.logging.enable_default_handler() 91 | transformers.utils.logging.enable_explicit_format() 92 | 93 | # Log on each process the small summary: 94 | logger.warning( 95 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, " 96 | + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16 or training_args.bf16}" 97 | ) 98 | logger.info(f"Training/evaluation parameters {training_args}") 99 | logger.info(f"Model parameters {model_args}") 100 | logger.info(f"Data parameters {data_args}") 101 | 102 | # Detecting last checkpoint. 103 | last_checkpoint = None 104 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 105 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 106 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 107 | raise ValueError( 108 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 109 | "Use --overwrite_output_dir to overcome." 110 | ) 111 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 112 | logger.info( 113 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 114 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 115 | ) 116 | 117 | random_seed(training_args.seed) 118 | data_args.seed = training_args.seed 119 | training_args.model_type = "image" 120 | 121 | from models.SD_with_SigLIP import SDModel 122 | from config import SDConfig 123 | 124 | config = SDConfig() 125 | config.tta.gradient_descent.train_steps = training_args.train_steps 126 | config.visual_pattern = training_args.visual_pattern 127 | config.clip_image_size = training_args.clip_image_size 128 | model = SDModel(config) 129 | 130 | # print model parameters 131 | logger.info(f"{str(model)}") 132 | model.cuda() 133 | 134 | from data import get_cc3m_wds_dataset_and_collator 135 | wds_dataset, wds_collator = get_cc3m_wds_dataset_and_collator(data_args, model_args) 136 | 137 | if config.model.freeze_class_embeds: 138 | params = [] 139 | for key,parm in model.named_parameters(): 140 | if 'final_fc' not in key: 141 | params.append(parm) 142 | 143 | optimizer = torch.optim.SGD( 144 | params, lr=training_args.learning_rate, 145 | weight_decay=training_args.weight_decay, 146 | momentum=config.tta.gradient_descent.optimizer_momentum 147 | ) 148 | scheduler = None 149 | 150 | trainer = CustomTrainer( 151 | model=model, 152 | args=training_args, 153 | train_dataset=wds_dataset, 154 | data_collator=wds_collator, 155 | optimizers=(optimizer, scheduler) 156 | ) 157 | setup_for_distributed(torch.distributed.get_rank() == 0) 158 | 159 | from callbacks import ModelCallback 160 | trainer.add_callback(ModelCallback) 161 | 162 | # Evaluation 163 | if training_args.local_rank == 0: 164 | print("CLIP's Performance on MMVP-VLM —— Before Generative Fine-tuning") 165 | results_before = official_evaluation(model.class_model.model, config) 166 | print(results_before) 167 | 168 | # Training 169 | if training_args.do_train: 170 | checkpoint = None 171 | if training_args.resume_from_checkpoint is not None: 172 | checkpoint = training_args.resume_from_checkpoint 173 | elif last_checkpoint is not None: 174 | checkpoint = last_checkpoint 175 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 176 | 
trainer.save_model(output_dir=training_args.output_dir) 177 | trainer.log_metrics("train", train_result.metrics) 178 | trainer.save_metrics("train", train_result.metrics) 179 | trainer.save_state() 180 | 181 | # Evaluation 182 | if training_args.local_rank == 0: 183 | print("CLIP's Performance on MMVP-VLM —— After Generative Fine-tuning") 184 | model_weight_save_path = os.path.join(training_args.output_dir, 'CLIP_after_GenFT.pth') 185 | torch.save(trainer.model.state_dict(), model_weight_save_path) 186 | results_final_after = official_evaluation(trainer.model.class_model.model, config) 187 | print(results_final_after) 188 | save_results(results_before, results_final_after, output_dir=training_args.output_dir) 189 | 190 | 191 | def benchmark_model(model, benchmark_dir, device = "cpu", config=None): 192 | if config.clip_image_size == 224: 193 | _, preprocess = create_model_from_pretrained(model_name='ViT-SO400M-14-SigLIP', pretrained="pretrained_weights/CLIP/ViT-SO400M-14-SigLIP/open_clip_pytorch_model.bin", device=device, 194 | image_mean=([0.5,0.5,0.5]), image_std=([0.5,0.5,0.5]), image_interpolation="bicubic", image_resize_mode="squash") 195 | tokenizer = get_tokenizer('ViT-SO400M-14-SigLIP') 196 | if config.clip_image_size == 384: 197 | _, preprocess = create_model_from_pretrained(model_name='ViT-SO400M-14-SigLIP-384', pretrained="pretrained_weights/CLIP/ViT-SO400M-14-SigLIP-384/open_clip_pytorch_model.bin", device=device, 198 | image_mean=([0.5,0.5,0.5]), image_std=([0.5,0.5,0.5]), image_interpolation="bicubic", image_resize_mode="squash") 199 | tokenizer = get_tokenizer('ViT-SO400M-14-SigLIP-384') 200 | 201 | image_dir = os.path.join(benchmark_dir, 'MLLM_VLM_Images') 202 | csv_file = os.path.join(benchmark_dir, 'Questions.csv') 203 | 204 | csv_outfile = open('Prediction_Results_SigLIP', 'w', newline='') 205 | csv_writer = csv.writer(csv_outfile) 206 | csv_writer.writerow(['qid1', 'qid2', 'pred1', 'pred2', 'gt1', 'gt2', 'q1score', 'q2score']) # header 207 | 208 | categories = [ 209 | 'Orientation and Direction', 'Presence of Specific Features', 210 | 'State and Condition', 'Quantity and Count', 211 | 'Positional and Relational Context', 'Color and Appearance', 212 | 'Structural Characteristics', 'Texts', 213 | 'Viewpoint and Perspective' 214 | ] 215 | 216 | pair_accuracies = {category: 0 for category in categories} 217 | num_pairs = 0 218 | 219 | with open(csv_file, 'r') as f: 220 | reader = csv.reader(f) 221 | next(reader) 222 | for i, row in tqdm(enumerate(reader)): 223 | qid1, qtype1, statement1 = row 224 | 225 | # Get next row for the pair 226 | row = next(reader, None) 227 | if not row: 228 | break 229 | qid2, qtype2, statement2 = row 230 | 231 | qid1, qid2 = int(qid1), int(qid2) 232 | 233 | img1 = Image.open(os.path.join(image_dir, qtype1, f'{qid1}.jpg')) 234 | img2 = Image.open(os.path.join(image_dir, qtype1, f'{qid2}.jpg')) 235 | 236 | text1 = 'a photo of ' + statement1 237 | text2 = 'a photo of ' + statement2 238 | 239 | text1 = tokenizer(text1, context_length=model.context_length).to(device) 240 | text2 = tokenizer(text2, context_length=model.context_length).to(device) 241 | 242 | img1 = preprocess(img1).unsqueeze(0).to(device) 243 | img2 = preprocess(img2).unsqueeze(0).to(device) 244 | imgs = torch.cat((img1, img2), dim=0) 245 | 246 | with torch.no_grad(), torch.cuda.amp.autocast(): 247 | model.eval().float() 248 | 249 | # original code 250 | # image_features = model.encode_image(imgs) 251 | 252 | # ours 253 | image_features = model.encode_image(imgs)[:,0,:] 254 | 255 | 
text1_features = model.encode_text(text1) 256 | text2_features = model.encode_text(text2) 257 | image_features = F.normalize(image_features, dim=-1) 258 | text1_features = F.normalize(text1_features, dim=-1) 259 | text2_features = F.normalize(text2_features, dim=-1) 260 | logits_per_image1 = image_features @ text1_features.T * model.logit_scale.exp() + model.logit_bias 261 | logits_per_text1 = logits_per_image1.T 262 | logits_per_image2 = image_features @ text2_features.T * model.logit_scale.exp() + model.logit_bias 263 | logits_per_text2 = logits_per_image2.T 264 | 265 | probs1 = logits_per_text1.softmax(dim=-1).cpu().numpy() 266 | probs2 = logits_per_text2.softmax(dim=-1).cpu().numpy() 267 | 268 | img1_score1 = probs1[0][0] 269 | img1_score2 = probs2[0][0] 270 | 271 | pred1 = "img1" if img1_score1 > 0.5 else "img2" 272 | pred2 = "img1" if img1_score2 > 0.5 else "img2" 273 | 274 | gt1 = "img1" if qid1 % 2 == 1 else "img2" 275 | gt2 = "img1" if qid2 % 2 == 1 else "img2" 276 | 277 | csv_writer.writerow([qid1, qid2, pred1, pred2, gt1, gt2, img1_score1, img1_score2]) 278 | 279 | current_category = categories[num_pairs // 15] 280 | if pred1 == gt1 and pred2 == gt2: 281 | pair_accuracies[current_category] += 1 282 | num_pairs += 1 283 | 284 | csv_outfile.close() 285 | 286 | # Calculate percentage accuracies 287 | Category_Score_List = [] 288 | 289 | for category in pair_accuracies: 290 | pair_accuracies[category] = (pair_accuracies[category] / (num_pairs // len(categories))) * 100 291 | Category_Score_List.append(pair_accuracies[category]) 292 | 293 | pair_accuracies['average_score'] = sum(Category_Score_List)/len(Category_Score_List) 294 | 295 | return pair_accuracies 296 | 297 | def official_evaluation(clip_model, config): 298 | 299 | with torch.no_grad(): 300 | clip_model.eval() 301 | 302 | # models 303 | data = "dataset/MMVP_VLM" 304 | clip_model_device = next(clip_model.parameters()).device 305 | if config.clip_image_size == 224: 306 | results_openai = {f'ViT-SO400M-14-SigLIP': benchmark_model(clip_model, data, clip_model_device, config)} 307 | if config.clip_image_size == 384: 308 | results_openai = {f'ViT-SO400M-14-SigLIP-384': benchmark_model(clip_model, data, clip_model_device, config)} 309 | 310 | # Merge results 311 | results = {**results_openai} 312 | 313 | # Convert results to format suitable for star plot 314 | categories = results[list(results.keys())[0]].keys() 315 | data = {'Categories': list(categories)} 316 | for model in list(results_openai.keys()): 317 | data[model] = [results[model][category] for category in categories] 318 | 319 | return results 320 | 321 | def save_results(results_before, results_final_after, output_dir, filename='pred_result.json'): 322 | 323 | os.makedirs(output_dir, exist_ok=True) 324 | 325 | output_data = { 326 | 'results_before': results_before, 327 | 'results_final_after': results_final_after 328 | } 329 | 330 | output_path = os.path.join(output_dir, filename) 331 | 332 | with open(output_path, 'w') as f: 333 | json.dump(output_data, f, indent=4) 334 | 335 | 336 | if __name__ == "__main__": 337 | main() 338 | --------------------------------------------------------------------------------
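Note on reusing the saved checkpoint: each runner above saves the adapted CLIP weights as CLIP_after_GenFT.pth in the output directory and records the before/after MMVP-VLM scores through save_results(). The following is a minimal sketch of how one might re-run official_evaluation() from such a checkpoint without relaunching the full distributed job. It assumes the OpenAICLIP variant, the default SDConfig() values, and the output path outputs/DIVA_for_OpenAICLIP; the path, the config defaults, and the 224-pixel setting are assumptions for illustration, not part of the repository.

    # reload_and_eval.py -- illustrative sketch, not a file in the repository
    import torch

    from config import SDConfig
    from models.SD_with_OpenAICLIP import SDModel
    from run_DIVA_with_OpenAICLIP import official_evaluation

    config = SDConfig()
    config.clip_image_size = 224   # assumption: must match the fine-tuning run
    config.visual_pattern = None   # assumption: same value the training script passed

    # Rebuild the same SDModel used in main(), then load the fine-tuned weights
    # saved via torch.save(trainer.model.state_dict(), 'CLIP_after_GenFT.pth').
    model = SDModel(config)
    state_dict = torch.load("outputs/DIVA_for_OpenAICLIP/CLIP_after_GenFT.pth",  # assumed path
                            map_location="cpu")
    model.load_state_dict(state_dict)
    model.cuda()

    # official_evaluation() takes the raw CLIP module, exactly as in the runner.
    results = official_evaluation(model.class_model.model, config)
    print(results)

As in the training scripts, this still expects dataset/MMVP_VLM and the corresponding pretrained CLIP weights (e.g. pretrained_weights/CLIP/ViT-L-14.pt) to be present on disk, since benchmark_model() loads both.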