├── DIVA_for_DFN.sh
├── DIVA_for_MetaCLIP.sh
├── DIVA_for_OpenAICLIP.sh
├── DIVA_for_SigLIP.sh
├── LICENSE
├── README.md
├── accelerator.json
├── arguments.py
├── assets
│   ├── introduction.png
│   ├── methodology.png
│   └── qualitative_mmvp.png
├── callbacks.py
├── condition
│   ├── DFN_for_openclip_transformer.py
│   ├── MetaCLIP_for_openclip_transformer.py
│   ├── OpenAICLIP_for_clip_model.py
│   └── SigLIP_for_timm_models_visiontransformer.py
├── config.py
├── data
│   ├── __init__.py
│   ├── constants.py
│   ├── data.py
│   ├── image_data.py
│   ├── register.py
│   └── transform.py
├── models
│   ├── CLIP_bank.py
│   ├── SD_with_DFN.py
│   ├── SD_with_MetaCLIP.py
│   ├── SD_with_OpenAICLIP.py
│   ├── SD_with_SigLIP.py
│   ├── build.py
│   └── utils.py
├── requirements.txt
├── run_DIVA_with_DFN.py
├── run_DIVA_with_MetaCLIP.py
├── run_DIVA_with_OpenAICLIP.py
├── run_DIVA_with_SigLIP.py
└── trainer.py
/DIVA_for_DFN.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/bash
2 |
3 | set -x
4 |
5 | export TORCH_DISTRIBUTED_DEBUG=INFO
6 | export NCCL_DEBUG=INFO
7 | export OMP_NUM_THREADS=4
8 | export NCCL_P2P_DISABLE=1
9 | MASTER_ADDR=$(hostname -I | awk '{print $1}')
10 | MASTER_PORT=12345
11 |
12 | run_name=DIVA_for_DFN
13 | num_steps=4600
14 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 \
15 | --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} --use_env \
16 | run_DIVA_with_DFN.py \
17 | --clip_image_size 224 \
18 | --visual_pattern None \
19 | --train_steps 2 \
20 | --image_size 512 \
21 | --fixed_image_size False \
22 | --dataset_path dataset/cc3m/\*.tar \
23 | --output_dir ./outputs/$run_name \
24 | --remove_unused_columns False \
25 | --do_train \
26 | --ddp_find_unused_parameters True \
27 | --dataloader_num_workers 8 \
28 | --learning_rate 1e-4 \
29 | --bf16 True \
30 | --tf32 True \
31 | --warmup_ratio 0.005 \
32 | --weight_decay 0 \
33 | --max_steps $num_steps \
34 | --per_device_train_batch_size 16 \
35 | --logging_strategy steps \
36 | --logging_steps 50 \
37 | --gradient_accumulation_steps 5 \
38 | --save_strategy steps \
39 | --save_steps $num_steps \
40 | --save_total_limit 1 \
41 | --ddp_backend nccl \
42 | --report_to wandb \
43 | --run_name $run_name \
44 | --enable_flash True \
45 | --lr_scheduler_type "cosine" \
46 | --seed 42 \
47 | --accelerator_config accelerator.json > ./logs/debug_$run_name.log
48 |
--------------------------------------------------------------------------------
/DIVA_for_MetaCLIP.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/bash
2 |
3 | set -x
4 |
5 | export TORCH_DISTRIBUTED_DEBUG=INFO
6 | export NCCL_DEBUG=INFO
7 | export OMP_NUM_THREADS=4
8 | export NCCL_P2P_DISABLE=1
9 | MASTER_ADDR=$(hostname -I | awk '{print $1}')
10 | MASTER_PORT=12345
11 |
12 | run_name=DIVA_for_MetaCLIP
13 | num_steps=4600
14 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 \
15 | --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} --use_env \
16 | run_DIVA_with_MetaCLIP.py \
17 | --metaclip_version large \
18 | --clip_image_size 224 \
19 | --visual_pattern None \
20 | --train_steps 2 \
21 | --image_size 512 \
22 | --fixed_image_size False \
23 | --dataset_path dataset/cc3m/\*.tar \
24 | --output_dir ./outputs/$run_name \
25 | --remove_unused_columns False \
26 | --do_train \
27 | --ddp_find_unused_parameters True \
28 | --dataloader_num_workers 8 \
29 | --learning_rate 1e-4 \
30 | --bf16 True \
31 | --tf32 True \
32 | --warmup_ratio 0.005 \
33 | --weight_decay 0 \
34 | --max_steps $num_steps \
35 | --per_device_train_batch_size 16 \
36 | --logging_strategy steps \
37 | --logging_steps 50 \
38 | --gradient_accumulation_steps 5 \
39 | --save_strategy steps \
40 | --save_steps $num_steps \
41 | --save_total_limit 1 \
42 | --ddp_backend nccl \
43 | --report_to wandb \
44 | --run_name $run_name \
45 | --enable_flash True \
46 | --lr_scheduler_type "cosine" \
47 | --seed 42 \
48 | --accelerator_config accelerator.json > ./logs/debug_$run_name.log
49 |
--------------------------------------------------------------------------------
/DIVA_for_OpenAICLIP.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/bash
2 |
3 | set -x
4 |
5 | export TORCH_DISTRIBUTED_DEBUG=INFO
6 | export NCCL_DEBUG=INFO
7 | export OMP_NUM_THREADS=4
8 | export NCCL_P2P_DISABLE=1
9 | MASTER_ADDR=$(hostname -I | awk '{print $1}')
10 | MASTER_PORT=12345
11 |
12 | run_name=DIVA_for_OpenAICLIP
13 | num_steps=4600
14 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 \
15 | --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} --use_env \
16 | run_DIVA_with_OpenAICLIP.py \
17 | --clip_image_size 224 \
18 | --visual_pattern None \
19 | --train_steps 2 \
20 | --image_size 512 \
21 | --fixed_image_size False \
22 | --dataset_path dataset/cc3m/\*.tar \
23 | --output_dir ./outputs/$run_name \
24 | --remove_unused_columns False \
25 | --do_train \
26 | --ddp_find_unused_parameters True \
27 | --dataloader_num_workers 8 \
28 | --learning_rate 1e-4 \
29 | --bf16 True \
30 | --tf32 True \
31 | --warmup_ratio 0.005 \
32 | --weight_decay 0 \
33 | --max_steps $num_steps \
34 | --per_device_train_batch_size 16 \
35 | --logging_strategy steps \
36 | --logging_steps 50 \
37 | --gradient_accumulation_steps 5 \
38 | --save_strategy steps \
39 | --save_steps $num_steps \
40 | --save_total_limit 1 \
41 | --ddp_backend nccl \
42 | --report_to wandb \
43 | --run_name $run_name \
44 | --enable_flash True \
45 | --lr_scheduler_type "cosine" \
46 | --seed 42 \
47 | --accelerator_config accelerator.json > ./logs/debug_$run_name.log
48 |
--------------------------------------------------------------------------------
/DIVA_for_SigLIP.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/bash
2 |
3 | set -x
4 |
5 | export TORCH_DISTRIBUTED_DEBUG=INFO
6 | export NCCL_DEBUG=INFO
7 | export OMP_NUM_THREADS=4
8 | export NCCL_P2P_DISABLE=1
9 | MASTER_ADDR=$(hostname -I | awk '{print $1}')
10 | MASTER_PORT=12345
11 |
12 | run_name=DIVA_for_SigLIP
13 | num_steps=4600
14 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 \
15 | --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} --use_env \
16 | run_DIVA_with_SigLIP.py \
17 | --clip_image_size 224 \
18 | --visual_pattern None \
19 | --train_steps 2 \
20 | --image_size 512 \
21 | --fixed_image_size False \
22 | --dataset_path dataset/cc3m/\*.tar \
23 | --output_dir ./outputs/$run_name \
24 | --remove_unused_columns False \
25 | --do_train \
26 | --ddp_find_unused_parameters True \
27 | --dataloader_num_workers 8 \
28 | --learning_rate 1e-4 \
29 | --bf16 True \
30 | --tf32 True \
31 | --warmup_ratio 0.005 \
32 | --weight_decay 0 \
33 | --max_steps $num_steps \
34 | --per_device_train_batch_size 16 \
35 | --logging_strategy steps \
36 | --logging_steps 50 \
37 | --gradient_accumulation_steps 5 \
38 | --save_strategy steps \
39 | --save_steps $num_steps \
40 | --save_total_limit 1 \
41 | --ddp_backend nccl \
42 | --report_to wandb \
43 | --run_name $run_name \
44 | --enable_flash True \
45 | --lr_scheduler_type "cosine" \
46 | --seed 42 \
47 | --accelerator_config accelerator.json > ./logs/debug_$run_name.log
48 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 BAAI-Vision
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Diffusion Feedback Helps CLIP See Better
2 |
3 |
4 |
5 | [Wenxuan Wang](https://scholar.google.com/citations?user=75OyC-oAAAAJ&hl=zh-CN)<sup>1,2,3*</sup>, [Quan Sun](https://scholar.google.cz/citations?user=pVKiHdEAAAAJ&hl=zh-CN&oi=ao)<sup>3*</sup>, [Fan Zhang](https://scholar.google.cz/citations?hl=zh-CN&user=VsJ39HMAAAAJ&view_op=list_works&sortby=pubdate)<sup>3</sup>, [Yepeng Tang](https://scholar.google.cz/citations?user=CAC_4OUAAAAJ&hl=zh-CN&oi=ao)<sup>4</sup>, [Jing Liu](https://scholar.google.com/citations?user=sOI-S7oAAAAJ&hl=zh-CN)<sup>1,2</sup>, [Xinlong Wang](https://scholar.google.com/citations?hl=zh-CN&user=DPz0DjYAAAAJ&view_op=list_works&sortby=pubdate/)<sup>3</sup>
6 |
7 | <sup>1</sup>[CASIA](http://english.ia.cas.cn/), <sup>2</sup>[UCAS](https://english.ucas.ac.cn/), <sup>3</sup>[BAAI](https://www.baai.ac.cn/english.html), <sup>4</sup>[BJTU](https://en.bjtu.edu.cn/)
8 | <sup>*</sup>Equal Contribution
9 |
10 |
11 |
12 |
13 | ## ⏰ Schedule
14 |
15 | ### [2025-01-23] Our [paper](https://arxiv.org/abs/2407.20171) is accepted by ICLR 2025 ! 💥
16 | ### [2024-08-07] We release [CLIP model weights](https://huggingface.co/BAAI/DIVA) ! 💥
17 | ### [2024-08-05] We release [training & evaluation code](https://github.com/baaivision/DIVA) ! 💥
18 | ### [2024-07-30] Our [paper](https://arxiv.org/abs/2407.20171) is released on arXiv ! 💥
19 |
20 |
21 | ## 💡 Motivation
22 |
23 | ![Motivation](assets/introduction.png)
24 |
25 |
26 |
27 | In this work, we present a simple post-training approach for CLIP models that largely overcomes their visual shortcomings via a self-supervised diffusion process. We introduce DIVA, which uses the DIffusion model as a Visual Assistant for CLIP. Specifically, DIVA leverages generative feedback from text-to-image diffusion models to optimize CLIP representations, using only images (without corresponding text). We demonstrate that DIVA improves CLIP's performance by a large margin (e.g., 3-7% ↑) on the challenging MMVP-VLM benchmark, which assesses fine-grained visual abilities, and that it enhances the performance of MLLMs and vision models on multimodal understanding and segmentation tasks. Extensive evaluation on 29 image classification and retrieval benchmarks confirms that DIVA preserves CLIP's strong zero-shot capabilities.
28 |
29 |
30 | ## 🤖 Architecture
31 |
32 | ![Methodology](assets/methodology.png)
33 |
34 |
35 |
36 | Given an image, the CLIP model encodes its visual features as the main part of the condition; the generative diffusion model then takes the noisy image and this condition as input and predicts the added noise. We optimize CLIP's representation by maximizing the image likelihood through the diffusion loss, i.e., via generative feedback.
37 |
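For orientation, the objective can be sketched with the `diffusers` API as below. This is a minimal illustration, not the repository's actual training code (which lives in `models/SD_with_*.py`): the names `clip_visual`, `vae`, `unet` and `scheduler` are placeholders, and the real condition is built from the class token plus randomly sampled local tokens (see the files under `condition/`).

```python
# Minimal sketch of one generative-feedback step (illustrative only, not the repo's exact code).
# Assumes frozen Stable Diffusion components (`vae`, `unet`, `scheduler`) from diffusers and a
# trainable CLIP visual encoder `clip_visual` whose token embeddings serve as the condition.
import torch
import torch.nn.functional as F

def generative_feedback_step(clip_visual, vae, unet, scheduler, images, optimizer):
    # In practice the VAE branch and the CLIP branch use different image preprocessing.
    with torch.no_grad():  # the diffusion model stays frozen
        latents = vae.encode(images).latent_dist.sample() * vae.config.scaling_factor
    cond = clip_visual(images)                      # [B, N, C] visual tokens used as condition
    noise = torch.randn_like(latents)
    t = torch.randint(0, scheduler.config.num_train_timesteps,
                      (latents.shape[0],), device=latents.device)
    noisy_latents = scheduler.add_noise(latents, noise, t)
    pred = unet(noisy_latents, t, encoder_hidden_states=cond).sample
    loss = F.mse_loss(pred, noise)                  # standard noise-prediction (diffusion) loss
    loss.backward()                                 # gradients flow only into the CLIP encoder
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()
```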
38 |
39 | ## 🔨 Installation
40 | Clone this repository and install the required packages:
41 |
42 | ```shell
43 | git clone https://github.com/baaivision/DIVA.git
44 | cd DIVA
45 | mkdir -p outputs logs datasets pretrained_weights/CLIP pretrained_weights/SD
46 |
47 | conda create -n diva python=3.9
48 | conda activate diva
49 | pip install -r requirements.txt
50 | ```
51 | Core packages:
52 | - [PyTorch](https://pytorch.org/) version 2.0.0
53 | - [open-clip-torch](https://github.com/mlfoundations/open_clip) version 2.24.0
54 | - [timm](https://github.com/rwightman/pytorch-image-models) version 0.9.8
55 |
56 |
57 | ## 🍹 Preparation for DIVA's Generative Fine-tuning
58 |
59 | ### Data Acquisition
60 | For data preparation, please refer to [img2dataset](https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc3m.md) and [MMVP](https://github.com/tsb0601/MMVP/tree/main) for the training and evaluation data used in this work. After collecting the corresponding datasets, put them directly into the `dataset/` folder.
61 |
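Once the CC3M webdataset shards are in place, you can confirm that they stream correctly by loading them the same way `data/data.py` does (the `dataset/cc3m/*.tar` pattern below matches the path used in the training scripts):

```python
# Quick sanity check: stream one sample with the same loader call used in data/data.py.
from datasets import load_dataset

data = load_dataset("webdataset", data_dir="dataset/cc3m/*.tar", split="train", streaming=True)
sample = next(iter(data))
print(sample.keys())  # expect the image under a key such as 'jpg' and the caption under 'txt'
```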
62 | ### Pre-trained Weight Downloading
63 | For pre-trained weight preparation, please refer to [OpenAI ViT-L-14/224&336](https://github.com/openai/CLIP/blob/main/clip/clip.py), [MetaCLIP ViT-L/H-14](https://github.com/facebookresearch/metaclip), [SigLIP ViT-SO-14/224](https://huggingface.co/timm/ViT-SO400M-14-SigLIP), [SigLIP ViT-SO-14/384](https://huggingface.co/timm/ViT-SO400M-14-SigLIP-384), [DFN ViT-H-14/224](https://huggingface.co/apple/DFN5B-CLIP-ViT-H-14), [DFN ViT-H-14/378](https://huggingface.co/apple/DFN5B-CLIP-ViT-H-14-378) and [SD-2-1-base](https://huggingface.co/stabilityai/stable-diffusion-2-1-base) to acquire the weights of the discriminative CLIP models and of the diffusion model that provides generative feedback. After downloading all the necessary weights, move them to the corresponding folders `pretrained_weights/CLIP/` and `pretrained_weights/SD/`.
64 |
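For the checkpoints hosted on Hugging Face, `huggingface_hub` can download them directly into the prepared folders. The snippet below covers two of the listed repositories as an example; the exact sub-folder layout expected by the training code is not documented here, so adjust the `local_dir` paths to your setup.

```python
# Example: fetch two of the required checkpoints (repo IDs taken from the links above).
from huggingface_hub import snapshot_download

snapshot_download(repo_id="stabilityai/stable-diffusion-2-1-base",
                  local_dir="pretrained_weights/SD/stable-diffusion-2-1-base")
snapshot_download(repo_id="apple/DFN5B-CLIP-ViT-H-14",
                  local_dir="pretrained_weights/CLIP/DFN5B-CLIP-ViT-H-14")
```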
65 | ### Code Modification
66 | To prepare DIVA's condition design, some source code in the installed [CLIP](https://github.com/openai/CLIP), [OpenCLIP](https://github.com/mlfoundations/open_clip), and [timm](https://github.com/huggingface/pytorch-image-models) packages needs to be modified.
67 |
68 | For OpenAI CLIP, use the content in our provided `condition/OpenAICLIP_for_clip_model.py` to replace the content in `Your Conda Installation Path/anaconda3/envs/diva/lib/python3.9/site-packages/clip/model.py`.
69 |
70 | For MetaCLIP and DFN, use the content in our provided `condition/MetaCLIP_for_openclip_transformer.py` and `condition/DFN_for_openclip_transformer.py` to replace the content in `Your Conda Installation Path/anaconda3/envs/diva/lib/python3.9/site-packages/open_clip/transformer.py`, respectively.
71 |
72 | For SigLIP, use the content in our provided `condition/SigLIP_for_timm_models_visiontransformer.py` to replace the content in `Your Conda Installation Path/anaconda3/envs/diva/lib/python3.9/site-packages/timm/models/vision_transformer.py`.
73 |
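One way to apply these replacements without hard-coding the conda prefix is to resolve the installed package locations from Python itself, for example as sketched below (back up the original files first):

```python
# Overwrite the installed modules with DIVA's condition files (back up the originals first).
import os, shutil
import clip, open_clip
import timm.models.vision_transformer as timm_vit

shutil.copy("condition/OpenAICLIP_for_clip_model.py",
            os.path.join(os.path.dirname(clip.__file__), "model.py"))
# For MetaCLIP *or* DFN, copy the matching file; both target open_clip/transformer.py.
shutil.copy("condition/MetaCLIP_for_openclip_transformer.py",
            os.path.join(os.path.dirname(open_clip.__file__), "transformer.py"))
shutil.copy("condition/SigLIP_for_timm_models_visiontransformer.py", timm_vit.__file__)
```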
74 |
75 | ## 🍻 Quick Start for Training & Evaluation
76 |
77 | After all the above preparation steps, you can simply start training for our DIVA with the following command:
78 | ```shell
79 | # For OpenAICLIP
80 | bash DIVA_for_OpenAICLIP.sh
81 |
82 | # For MetaCLIP
83 | bash DIVA_for_MetaCLIP.sh
84 |
85 | # For SigLIP
86 | bash DIVA_for_SigLIP.sh
87 |
88 | # For DFN
89 | bash DIVA_for_DFN.sh
90 | ```
91 |
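With the settings in the provided scripts, each optimization step sees an effective global batch of 8 GPUs × 16 images per device × 5 gradient-accumulation steps = 640 images, and training runs for 4600 steps. Checkpoints are written to `./outputs/<run_name>` and the console output is redirected to `./logs/debug_<run_name>.log`.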
92 | ## Model Zoo
93 |
94 | | Method | Image Size | Params (M) | MMVP-VLM Average Score |
95 | |----------------------|------------|------------|---------------|
96 | | [OpenAI ViT-L-14](https://huggingface.co/BAAI/DIVA/blob/main/OpenAICLIP/OpenAI-ViT-L-14-224.pth) | 224² | 427.6 | 25.9 (+6.6) |
97 | | [OpenAI ViT-L-14](https://huggingface.co/BAAI/DIVA/blob/main/OpenAICLIP/OpenAI-ViT-L-14-336.pth) | 336² | 427.9 | 25.2 (+5.2) |
98 | | [MetaCLIP ViT-L-14](https://huggingface.co/BAAI/DIVA/blob/main/MetaCLIP/MetaCLIP-ViT-L-14.pth) | 224² | 427.6 | 27.4 (+3.7) |
99 | | [MetaCLIP ViT-H-14](https://huggingface.co/BAAI/DIVA/blob/main/MetaCLIP/MetaCLIP-ViT-H-14.pth) | 224² | 986.1 | 31.9 (+6.7) |
100 | | [SigLIP ViT-SO-14](https://huggingface.co/BAAI/DIVA/blob/main/SigLIP/SigLIP-ViT-SO-14-224.pth) | 224² | 877.4 | 40.7 (+2.9) |
101 | | [SigLIP ViT-SO-14](https://huggingface.co/BAAI/DIVA/blob/main/SigLIP/SigLIP-ViT-SO-14-384.pth) | 384² | 878.0 | 38.5 (+1.5) |
102 | | [DFN ViT-H-14](https://huggingface.co/BAAI/DIVA/blob/main/DFN/DFN-ViT-H-14-224.pth) | 224² | 986.1 | 43.7 (+4.4) |
103 | | [DFN ViT-H-14](https://huggingface.co/BAAI/DIVA/blob/main/DFN/DFN-ViT-H-14-378.pth) | 378² | 986.7 | 37.8 (+3.0) |
104 |
105 |
106 | Note that, due to randomness both in the introduced condition design during training and in the selection of local patch tokens during inference for OpenAI CLIP, the scores obtained on the MMVP-VLM benchmark with our provided OpenAI CLIP weights may not exactly match the results reported in our paper. If the scores do not meet expectations, we recommend trying several different random seeds.
107 |
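Since the serialization format of the released `.pth` files is not described here, it can help to inspect a checkpoint before wiring it into an evaluation script; the snippet below only prints its structure and makes no assumption about the layout.

```python
# Inspect a downloaded DIVA checkpoint (the file name is an example from the table above).
import torch

ckpt = torch.load("OpenAI-ViT-L-14-224.pth", map_location="cpu")
print(type(ckpt))
if isinstance(ckpt, dict):
    for k, v in list(ckpt.items())[:10]:
        print(k, tuple(v.shape) if hasattr(v, "shape") else type(v))
```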
108 | ## 🎨 Visualization
109 |
110 | ![Qualitative results on MMVP](assets/qualitative_mmvp.png)
111 |
112 |
113 |
114 |
115 | ## 💙 Acknowledgement
116 | DIVA is built upon the awesome [Diffusion-TTA](https://github.com/mihirp1998/Diffusion-TTA), [MMVP](https://github.com/tsb0601/MMVP), [CLIP](https://github.com/openai/CLIP), [OpenCLIP](https://github.com/mlfoundations/open_clip), and [timm](https://github.com/huggingface/pytorch-image-models/).
117 |
118 | ## 📝 Citation
119 | ```bib
120 | @article{wang2024diffusion,
121 | title={Diffusion Feedback Helps CLIP See Better},
122 | author={Wang, Wenxuan and Sun, Quan and Zhang, Fan and Tang, Yepeng and Liu, Jing and Wang, Xinlong},
123 | journal={arXiv preprint arXiv:2407.20171},
124 | year={2024}
125 | }
126 | ```
127 |
--------------------------------------------------------------------------------
/accelerator.json:
--------------------------------------------------------------------------------
1 | {
2 | "dispatch_batches": false
3 | }
--------------------------------------------------------------------------------
/arguments.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Optional, List
3 | import transformers
4 |
5 | @dataclass
6 | class DataTrainingArguments:
7 |
8 | dataset_path: str = "dataset_path"
9 |
10 | arbitrary_resolution: Optional[bool] = field(
11 | default=False,
12 | metadata={
13 | "help": "If true, images will have arbitrary resolutions."
14 | },
15 | )
16 |
17 | max_train_samples: Optional[int] = field(
18 | default=None,
19 | metadata={
20 | "help": (
21 | "For debugging purposes or quicker training, truncate the number of training examples to this "
22 | "value if set."
23 | )
24 | },
25 | )
26 | max_eval_samples: Optional[int] = field(
27 | default=None,
28 | metadata={
29 | "help": (
30 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
31 | "value if set."
32 | )
33 | },
34 | )
35 |
36 | one_minus_one_data_transform: Optional[bool] = field(
37 | default=False,
38 | metadata={
39 | "help": "If true, the data will be scaled to [-1, 1] instead of [0, 1]."
40 | },
41 | )
42 |
43 |
44 | @dataclass
45 | class ModelArguments:
46 | """
47 | Arguments pertaining to which model/config/image processor we are going to pre-train.
48 | """
49 |
50 | model_type: Optional[str] = field(
51 | default=None,
52 | metadata={"help": "If training from scratch, pass a model type from the list: "},
53 | )
54 | config_overrides: Optional[str] = field(
55 | default=None,
56 | metadata={
57 | "help": (
58 | "Override some existing default config settings when a model is trained from scratch. Example: "
59 | "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
60 | )
61 | },
62 | )
63 | image_size: Optional[int] = field(
64 | default=None,
65 | metadata={
66 | "help": (
67 | "The size (resolution) of each image. If not specified, will use `image_size` of the configuration."
68 | )
69 | },
70 | )
71 |
72 | fixed_image_size: Optional[bool] = field(
73 | default=True,
74 | )
75 |
76 | patch_size: Optional[int] = field(
77 | default=None,
78 | metadata={
79 | "help": (
80 | "The size (resolution) of each patch. If not specified, will use `patch_size` of the configuration."
81 | )
82 | },
83 | )
84 |
85 | tublet_size: Optional[List[int]] = field(
86 | default_factory=lambda: [2, 16, 16],
87 | metadata={
88 | "help": (
89 | "The size of each tubelet (3D patch size). If not specified, will use `tubelet_size` of the configuration."
90 | )
91 | },
92 | )
93 |
94 | cost_gradient_penalty: Optional[float] = field(
95 | default=0, # 0.2
96 | )
97 |
98 | enable_flash: Optional[bool] = field(
99 | default=False,
100 | )
101 |
102 | @dataclass
103 | class TrainingArguments(transformers.TrainingArguments):
104 |
105 | multiple_optimizer_training: Optional[bool] = field(default=False, metadata={"help": "Becomes True if `gan_loss_weight` in `model_args` is set, to allow multiple optimizers."})
106 |
107 | wandb_api_key: Optional[str] = field(
108 | default=None,
109 | metadata={
110 | "help": "wandb api key"
111 | }
112 | )
113 |
114 | train_steps: Optional[int] = field(default=1,)
115 |
116 | visual_pattern: Optional[str] = field(default=None,)
117 |
118 | clip_image_size: Optional[int] = field(default=224,)
119 |
120 | metaclip_version: Optional[str] = field(default=None,)
--------------------------------------------------------------------------------
/assets/introduction.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baaivision/DIVA/dee07edc8ac299757e310bf84c31cc82a8ccc44c/assets/introduction.png
--------------------------------------------------------------------------------
/assets/methodology.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baaivision/DIVA/dee07edc8ac299757e310bf84c31cc82a8ccc44c/assets/methodology.png
--------------------------------------------------------------------------------
/assets/qualitative_mmvp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baaivision/DIVA/dee07edc8ac299757e310bf84c31cc82a8ccc44c/assets/qualitative_mmvp.png
--------------------------------------------------------------------------------
/callbacks.py:
--------------------------------------------------------------------------------
1 | import os
2 | from transformers.training_args import TrainingArguments
3 | from transformers.utils import logging, is_torch_tpu_available
4 | logger = logging.get_logger(__name__)
5 |
6 | from transformers.integrations import WandbCallback
7 | from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
8 |
9 | class MyWandbCallback(WandbCallback):
10 | def setup(self, args, state, model, **kwargs):
11 | """
12 | Setup the optional Weights & Biases (*wandb*) integration.
13 |
14 | One can subclass and override this method to customize the setup if needed. Find more information
15 | [here](https://docs.wandb.ai/guides/integrations/huggingface). You can also override the following environment
16 | variables:
17 |
18 | Environment:
19 | - **WANDB_LOG_MODEL** (`str`, *optional*, defaults to `"false"`):
20 | Whether to log model and checkpoints during training. Can be `"end"`, `"checkpoint"` or `"false"`. If set
21 | to `"end"`, the model will be uploaded at the end of training. If set to `"checkpoint"`, the checkpoint
22 | will be uploaded every `args.save_steps` . If set to `"false"`, the model will not be uploaded. Use along
23 | with [`~transformers.TrainingArguments.load_best_model_at_end`] to upload best model.
24 |
25 |
26 |
27 | Setting `WANDB_LOG_MODEL` as `bool` will be deprecated in version 5 of 🤗 Transformers.
28 |
29 |
30 | - **WANDB_WATCH** (`str`, *optional* defaults to `"false"`):
31 | Can be `"gradients"`, `"all"`, `"parameters"`, or `"false"`. Set to `"all"` to log gradients and
32 | parameters.
33 | """
34 | if self._wandb is None:
35 | return
36 | self._initialized = True
37 | if state.is_world_process_zero:
38 | logger.info(
39 | 'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
40 | )
41 | combined_dict = {**args.to_dict()}
42 |
43 | if hasattr(model, "config") and model.config is not None:
44 | model_config = model.config.to_dict()
45 | combined_dict = {**model_config, **combined_dict}
46 | trial_name = state.trial_name
47 | init_args = {}
48 | if trial_name is not None:
49 | init_args["name"] = trial_name
50 | init_args["group"] = args.run_name
51 | else:
52 | if not (args.run_name is None or args.run_name == args.output_dir):
53 | init_args["name"] = args.run_name
54 |
55 | if self._wandb.run is None:
56 | self._wandb.init(
57 | project=os.getenv("WANDB_PROJECT", "huggingface"),
58 | settings=self._wandb.Settings(code_dir="/share/project/qiying/projects/vision-tokenizer"),
59 | **init_args,
60 | )
61 | # add config parameters (run may have been created manually)
62 | self._wandb.config.update(combined_dict, allow_val_change=True)
63 |
64 | # define default x-axis (for latest wandb versions)
65 | if getattr(self._wandb, "define_metric", None):
66 | self._wandb.define_metric("train/global_step")
67 | self._wandb.define_metric("*", step_metric="train/global_step", step_sync=True)
68 |
69 | # keep track of model topology and gradients, unsupported on TPU
70 | _watch_model = os.getenv("WANDB_WATCH", "false")
71 | if not is_torch_tpu_available() and _watch_model in ("all", "parameters", "gradients"):
72 | self._wandb.watch(model, log=_watch_model, log_freq=max(100, state.logging_steps))
73 | self._wandb.run._label(code="transformers_trainer")
74 |
75 | class ModelCallback(TrainerCallback):
76 | def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
77 | kwargs['model'].global_step = state.global_step
78 | return super().on_step_begin(args, state, control, **kwargs)
--------------------------------------------------------------------------------
/condition/OpenAICLIP_for_clip_model.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Tuple, Union
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn
8 |
9 |
10 | class Bottleneck(nn.Module):
11 | expansion = 4
12 |
13 | def __init__(self, inplanes, planes, stride=1):
14 | super().__init__()
15 |
16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18 | self.bn1 = nn.BatchNorm2d(planes)
19 | self.relu1 = nn.ReLU(inplace=True)
20 |
21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
22 | self.bn2 = nn.BatchNorm2d(planes)
23 | self.relu2 = nn.ReLU(inplace=True)
24 |
25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
26 |
27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
29 | self.relu3 = nn.ReLU(inplace=True)
30 |
31 | self.downsample = None
32 | self.stride = stride
33 |
34 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
36 | self.downsample = nn.Sequential(OrderedDict([
37 | ("-1", nn.AvgPool2d(stride)),
38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
39 | ("1", nn.BatchNorm2d(planes * self.expansion))
40 | ]))
41 |
42 | def forward(self, x: torch.Tensor):
43 | identity = x
44 |
45 | out = self.relu1(self.bn1(self.conv1(x)))
46 | out = self.relu2(self.bn2(self.conv2(out)))
47 | out = self.avgpool(out)
48 | out = self.bn3(self.conv3(out))
49 |
50 | if self.downsample is not None:
51 | identity = self.downsample(x)
52 |
53 | out += identity
54 | out = self.relu3(out)
55 | return out
56 |
57 |
58 | class AttentionPool2d(nn.Module):
59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
60 | super().__init__()
61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
62 | self.k_proj = nn.Linear(embed_dim, embed_dim)
63 | self.q_proj = nn.Linear(embed_dim, embed_dim)
64 | self.v_proj = nn.Linear(embed_dim, embed_dim)
65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
66 | self.num_heads = num_heads
67 |
68 | def forward(self, x):
69 | x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
72 | x, _ = F.multi_head_attention_forward(
73 | query=x[:1], key=x, value=x,
74 | embed_dim_to_check=x.shape[-1],
75 | num_heads=self.num_heads,
76 | q_proj_weight=self.q_proj.weight,
77 | k_proj_weight=self.k_proj.weight,
78 | v_proj_weight=self.v_proj.weight,
79 | in_proj_weight=None,
80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
81 | bias_k=None,
82 | bias_v=None,
83 | add_zero_attn=False,
84 | dropout_p=0,
85 | out_proj_weight=self.c_proj.weight,
86 | out_proj_bias=self.c_proj.bias,
87 | use_separate_proj_weight=True,
88 | training=self.training,
89 | need_weights=False
90 | )
91 | return x.squeeze(0)
92 |
93 |
94 | class ModifiedResNet(nn.Module):
95 | """
96 | A ResNet class that is similar to torchvision's but contains the following changes:
97 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
98 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
99 | - The final pooling layer is a QKV attention instead of an average pool
100 | """
101 |
102 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
103 | super().__init__()
104 | self.output_dim = output_dim
105 | self.input_resolution = input_resolution
106 |
107 | # the 3-layer stem
108 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
109 | self.bn1 = nn.BatchNorm2d(width // 2)
110 | self.relu1 = nn.ReLU(inplace=True)
111 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
112 | self.bn2 = nn.BatchNorm2d(width // 2)
113 | self.relu2 = nn.ReLU(inplace=True)
114 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
115 | self.bn3 = nn.BatchNorm2d(width)
116 | self.relu3 = nn.ReLU(inplace=True)
117 | self.avgpool = nn.AvgPool2d(2)
118 |
119 | # residual layers
120 | self._inplanes = width # this is a *mutable* variable used during construction
121 | self.layer1 = self._make_layer(width, layers[0])
122 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
123 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
124 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
125 |
126 | embed_dim = width * 32 # the ResNet feature dimension
127 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
128 |
129 | def _make_layer(self, planes, blocks, stride=1):
130 | layers = [Bottleneck(self._inplanes, planes, stride)]
131 |
132 | self._inplanes = planes * Bottleneck.expansion
133 | for _ in range(1, blocks):
134 | layers.append(Bottleneck(self._inplanes, planes))
135 |
136 | return nn.Sequential(*layers)
137 |
138 | def forward(self, x):
139 | def stem(x):
140 | x = self.relu1(self.bn1(self.conv1(x)))
141 | x = self.relu2(self.bn2(self.conv2(x)))
142 | x = self.relu3(self.bn3(self.conv3(x)))
143 | x = self.avgpool(x)
144 | return x
145 |
146 | x = x.type(self.conv1.weight.dtype)
147 | x = stem(x)
148 | x = self.layer1(x)
149 | x = self.layer2(x)
150 | x = self.layer3(x)
151 | x = self.layer4(x)
152 | x = self.attnpool(x)
153 |
154 | return x
155 |
156 |
157 | class LayerNorm(nn.LayerNorm):
158 | """Subclass torch's LayerNorm to handle fp16."""
159 |
160 | def forward(self, x: torch.Tensor):
161 | orig_type = x.dtype
162 | ret = super().forward(x.type(torch.float32))
163 | # ret = super().forward(x)
164 | return ret.type(orig_type)
165 |
166 |
167 | class QuickGELU(nn.Module):
168 | def forward(self, x: torch.Tensor):
169 | return x * torch.sigmoid(1.702 * x)
170 |
171 |
172 | class ResidualAttentionBlock(nn.Module):
173 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
174 | super().__init__()
175 |
176 | self.attn = nn.MultiheadAttention(d_model, n_head)
177 | self.ln_1 = LayerNorm(d_model)
178 | self.mlp = nn.Sequential(OrderedDict([
179 | ("c_fc", nn.Linear(d_model, d_model * 4)),
180 | ("gelu", QuickGELU()),
181 | ("c_proj", nn.Linear(d_model * 4, d_model))
182 | ]))
183 | self.ln_2 = LayerNorm(d_model)
184 | self.attn_mask = attn_mask
185 |
186 | def attention(self, x: torch.Tensor):
187 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
188 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
189 |
190 | def forward(self, x: torch.Tensor):
191 | x = x + self.attention(self.ln_1(x))
192 | x = x + self.mlp(self.ln_2(x))
193 | return x
194 |
195 |
196 | class Transformer(nn.Module):
197 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
198 | super().__init__()
199 | self.width = width
200 | self.layers = layers
201 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
202 |
203 | def forward(self, x: torch.Tensor):
204 | return self.resblocks(x)
205 |
206 |
207 | class VisionTransformer(nn.Module):
208 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
209 | super().__init__()
210 | self.input_resolution = input_resolution
211 | self.output_dim = output_dim
212 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
213 |
214 | scale = width ** -0.5
215 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
216 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
217 | self.ln_pre = LayerNorm(width)
218 |
219 | self.transformer = Transformer(width, layers, heads)
220 |
221 | self.ln_post = LayerNorm(width)
222 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
223 | self.output_dim = output_dim
224 |
225 | def forward(self, x: torch.Tensor):
226 | x = self.conv1(x) # shape = [*, width, grid, grid]
227 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
228 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
229 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
230 | x = x + self.positional_embedding.to(x.dtype)
231 | x = self.ln_pre(x)
232 |
233 | x = x.permute(1, 0, 2) # NLD -> LND
234 | x = self.transformer(x)
235 | x = x.permute(1, 0, 2) # LND -> NLD
236 |
237 |
238 | # original code
239 | # x = self.ln_post(x[:, 0, :]) #[2, 1024]
240 | # if self.proj is not None:
241 | # x = x @ self.proj #[2, 768]
242 | # return x.unsqueeze(1)
243 |
244 |
245 | # DIVA's condition for OpenAI-CLIP(vit-l/224 & vit-l/336)
246 | class_token = x[:, 0, :].unsqueeze(1)
247 | remaining_tokens = x[:, 1:, :]
248 | num_selected_tokens = 77 - 1  # keep 76 local tokens so the condition has 77 tokens (CLIP/SD text context length)
249 | random_indices = torch.randint(0, remaining_tokens.shape[1], (x.shape[0], num_selected_tokens))
250 | batch_indices = torch.arange(x.shape[0]).unsqueeze(1).expand(x.shape[0], num_selected_tokens)
251 | selected_tokens = remaining_tokens[batch_indices, random_indices]
252 | x = torch.cat((class_token, selected_tokens), dim=1)
253 | B,N,C = x.shape
254 | x = x.reshape(B*N,C)
255 | x = self.ln_post(x)
256 | if self.proj is not None:
257 | x = x @ self.proj
258 | x = x.reshape(B,N,self.output_dim)
259 | return x
260 |
261 |
262 | class CLIP(nn.Module):
263 | def __init__(self,
264 | embed_dim: int,
265 | # vision
266 | image_resolution: int,
267 | vision_layers: Union[Tuple[int, int, int, int], int],
268 | vision_width: int,
269 | vision_patch_size: int,
270 | # text
271 | context_length: int,
272 | vocab_size: int,
273 | transformer_width: int,
274 | transformer_heads: int,
275 | transformer_layers: int
276 | ):
277 | super().__init__()
278 |
279 | self.context_length = context_length
280 |
281 | if isinstance(vision_layers, (tuple, list)):
282 | vision_heads = vision_width * 32 // 64
283 | self.visual = ModifiedResNet(
284 | layers=vision_layers,
285 | output_dim=embed_dim,
286 | heads=vision_heads,
287 | input_resolution=image_resolution,
288 | width=vision_width
289 | )
290 | else:
291 | vision_heads = vision_width // 64
292 | self.visual = VisionTransformer(
293 | input_resolution=image_resolution,
294 | patch_size=vision_patch_size,
295 | width=vision_width,
296 | layers=vision_layers,
297 | heads=vision_heads,
298 | output_dim=embed_dim
299 | )
300 |
301 | self.transformer = Transformer(
302 | width=transformer_width,
303 | layers=transformer_layers,
304 | heads=transformer_heads,
305 | attn_mask=self.build_attention_mask()
306 | )
307 |
308 | self.vocab_size = vocab_size
309 | self.token_embedding = nn.Embedding(vocab_size, transformer_width)
310 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
311 | self.ln_final = LayerNorm(transformer_width)
312 |
313 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
314 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
315 |
316 | self.initialize_parameters()
317 |
318 | def initialize_parameters(self):
319 | nn.init.normal_(self.token_embedding.weight, std=0.02)
320 | nn.init.normal_(self.positional_embedding, std=0.01)
321 |
322 | if isinstance(self.visual, ModifiedResNet):
323 | if self.visual.attnpool is not None:
324 | std = self.visual.attnpool.c_proj.in_features ** -0.5
325 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
326 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
327 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
328 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
329 |
330 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
331 | for name, param in resnet_block.named_parameters():
332 | if name.endswith("bn3.weight"):
333 | nn.init.zeros_(param)
334 |
335 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
336 | attn_std = self.transformer.width ** -0.5
337 | fc_std = (2 * self.transformer.width) ** -0.5
338 | for block in self.transformer.resblocks:
339 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
340 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
341 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
342 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
343 |
344 | if self.text_projection is not None:
345 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
346 |
347 | def build_attention_mask(self):
348 | # lazily create causal attention mask, with full attention between the vision tokens
349 | # pytorch uses additive attention mask; fill with -inf
350 | mask = torch.empty(self.context_length, self.context_length)
351 | mask.fill_(float("-inf"))
352 | mask.triu_(1) # zero out the lower diagonal
353 | return mask
354 |
355 | @property
356 | def dtype(self):
357 | return self.visual.conv1.weight.dtype
358 |
359 | def encode_image(self, image):
360 | return self.visual(image.type(self.dtype))
361 |
362 | def encode_text(self, text):
363 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
364 | x = x + self.positional_embedding.type(self.dtype)
365 | x = x.permute(1, 0, 2) # NLD -> LND
366 | x = self.transformer(x)
367 | x = x.permute(1, 0, 2) # LND -> NLD
368 | x = self.ln_final(x).type(self.dtype)
369 |
370 | # x.shape = [batch_size, n_ctx, transformer.width]
371 | # take features from the eot embedding (eot_token is the highest number in each sequence)
372 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
373 |
374 | return x
375 |
376 | def forward(self, image, text):
377 |
378 | # original code
379 | # image_features = self.encode_image(image)
380 |
381 | # ours: fuse the global class-token feature with the mean of the (randomly sampled) local patch-token features
382 | global_image_features = self.encode_image(image)[:,0,:]
383 | local_image_features = self.encode_image(image)[:,1:,:].mean(dim=1)
384 | image_features = global_image_features+local_image_features
385 |
386 | text_features = self.encode_text(text)
387 |
388 | # normalized features
389 | image_features = image_features / image_features.norm(dim=1, keepdim=True)
390 | text_features = text_features / text_features.norm(dim=1, keepdim=True)
391 |
392 | # cosine similarity as logits
393 | logit_scale = self.logit_scale.exp()
394 | logits_per_image = logit_scale * image_features @ text_features.t()
395 | logits_per_text = logits_per_image.t()
396 |
397 | # shape = [global_batch_size, global_batch_size]
398 | return logits_per_image, logits_per_text
399 |
400 |
401 | def convert_weights(model: nn.Module):
402 | """Convert applicable model parameters to fp16"""
403 |
404 | def _convert_weights_to_fp16(l):
405 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
406 | l.weight.data = l.weight.data.half()
407 | if l.bias is not None:
408 | l.bias.data = l.bias.data.half()
409 |
410 | if isinstance(l, nn.MultiheadAttention):
411 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
412 | tensor = getattr(l, attr)
413 | if tensor is not None:
414 | tensor.data = tensor.data.half()
415 |
416 | for name in ["text_projection", "proj"]:
417 | if hasattr(l, name):
418 | attr = getattr(l, name)
419 | if attr is not None:
420 | attr.data = attr.data.half()
421 |
422 | model.apply(_convert_weights_to_fp16)
423 |
424 |
425 | def build_model(state_dict: dict):
426 | vit = "visual.proj" in state_dict
427 |
428 | if vit:
429 | vision_width = state_dict["visual.conv1.weight"].shape[0]
430 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
431 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
432 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
433 |
434 | #default
435 | image_resolution = vision_patch_size * grid_size
436 |
437 | # The code below performs positional-embedding interpolation for CLIP (kept commented out by default)
438 | # image_resolution=int(512)
439 | # src = state_dict["visual.positional_embedding"]
440 | # src_cls = src[0:1]
441 | # src_oth = src[1:]
442 | # src_oth = F.interpolate(src_oth.reshape(grid_size,grid_size,1024).permute(2,0,1).unsqueeze(0),(image_resolution//14,image_resolution//14),mode='bilinear')
443 | # src_oth = src_oth[0].permute(1,2,0).reshape(-1,1024)
444 | # tgt = torch.cat((src_cls,src_oth),dim=0)
445 | # state_dict["visual.positional_embedding"] = tgt
446 |
447 |
448 | else:
449 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
450 | vision_layers = tuple(counts)
451 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
452 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
453 | vision_patch_size = None
454 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
455 | image_resolution = output_width * 32
456 |
457 | embed_dim = state_dict["text_projection"].shape[1]
458 | context_length = state_dict["positional_embedding"].shape[0]
459 | vocab_size = state_dict["token_embedding.weight"].shape[0]
460 | transformer_width = state_dict["ln_final.weight"].shape[0]
461 | transformer_heads = transformer_width // 64
462 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
463 |
464 | model = CLIP(
465 | embed_dim,
466 | image_resolution, vision_layers, vision_width, vision_patch_size,
467 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
468 | )
469 |
470 | for key in ["input_resolution", "context_length", "vocab_size"]:
471 | if key in state_dict:
472 | del state_dict[key]
473 |
474 | convert_weights(model)
475 | model.load_state_dict(state_dict)
476 | return model.eval()
477 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | from transformers.configuration_utils import PretrainedConfig
2 | from transformers.utils import logging
3 |
4 | logger = logging.get_logger(__name__)
5 |
6 | class EmptyClass(PretrainedConfig):
7 | def __init__(self):
8 | pass
9 | class SDConfig(PretrainedConfig):
10 |
11 | def __init__(self,
12 | sd_version = '2-1',
13 | override_total_steps = -1,
14 | freeze_class_embeds = True,
15 | freeze_vae = False,
16 | use_flash = False,
17 | adapt_only_classifier = True,
18 | adapt_topk = -1,
19 | loss = 'mse',
20 | actual_bs = 16,
21 | mean = [0.485, 0.456, 0.406],
22 | std = [0.229, 0.224, 0.225],
23 | use_same_noise_among_timesteps = False,
24 | random_timestep_per_iteration = True,
25 | rand_timestep_equal_int = False,
26 | weight_decay = 0,
27 | train_steps = 1,
28 | accum_iter = 1,
29 | optimizer = 'sgd',
30 | optimizer_momentum = 0.9,
31 | pred_noise_batch_size = 1,
32 | output_dir = './outputs/First_Start',
33 | visual_pattern = None,
34 | clip_image_size = 224,
35 | metaclip_version = 1
36 | ):
37 | super().__init__()
38 | self.model = EmptyClass()
39 | self.model.sd_version = sd_version
40 | self.model.override_total_steps = override_total_steps
41 | self.model.freeze_class_embeds = freeze_class_embeds
42 | self.model.freeze_vae = freeze_vae
43 | self.model.use_flash = use_flash
44 | self.model.adapt_only_classifier = adapt_only_classifier
45 | self.tta = EmptyClass()
46 | self.tta.gradient_descent = EmptyClass()
47 | self.tta.adapt_topk = adapt_topk
48 | self.tta.loss = loss
49 | self.tta.use_same_noise_among_timesteps = use_same_noise_among_timesteps
50 | self.tta.random_timestep_per_iteration = random_timestep_per_iteration
51 | self.tta.rand_timestep_equal_int = rand_timestep_equal_int
52 | self.tta.gradient_descent.weight_decay = weight_decay
53 | self.tta.gradient_descent.train_steps = train_steps
54 | self.tta.gradient_descent.accum_iter = accum_iter
55 | self.tta.gradient_descent.optimizer = optimizer
56 | self.tta.gradient_descent.optimizer_momentum = optimizer_momentum
57 | self.input = EmptyClass()
58 | self.input.batch_size = pred_noise_batch_size
59 | self.input.mean = mean
60 | self.input.std = std
61 | self.output_dir = output_dir
62 | self.actual_bs = actual_bs
63 | self.visual_pattern = visual_pattern
64 | self.clip_image_size = clip_image_size
65 | self.metaclip_version = metaclip_version
66 |
67 | if __name__ =='__main__':
68 | SDConfig()
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .transform import image_transform, image_transform_vq, DiffAugment
2 | from .data import get_wds_dataset_and_collator, get_cc3m_wds_dataset_and_collator, get_in1k_val_dataset
3 | from .image_data import get_wds_dataset_and_collator_arbitrary_resolution, get_highres_eval_dataset, get_in1k_dataset
4 | from .constants import ASPECT_RATIO_1024, ASPECT_RATIO_512, ASPECT_RATIO_256, DEFAULT_IMAGE_FILE_SUFFIX, OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
--------------------------------------------------------------------------------
/data/constants.py:
--------------------------------------------------------------------------------
1 |
2 | OPENAI_DATASET_MEAN = [0.48145466, 0.4578275, 0.40821073]
3 | OPENAI_DATASET_STD = [0.26862954, 0.26130258, 0.27577711]
4 |
5 | DEFAULT_IMAGE_FILE_SUFFIX = ['jpg', '0.jpg', '0.png', 'png', 'jpeg', '0.jpeg', 'webp']
6 | DEFAULT_VIDEO_FILE_SUFFIX = ['mp4', 'video.mp4']
7 |
8 | ASPECT_RATIO_1024 = {
9 | '0.25': [512., 2048.], '0.26': [512., 1984.], '0.27': [512., 1920.], '0.28': [512., 1856.],
10 | '0.32': [576., 1792.], '0.33': [576., 1728.], '0.35': [576., 1664.], '0.4': [640., 1600.],
11 | '0.42': [640., 1536.], '0.48': [704., 1472.], '0.5': [704., 1408.], '0.52': [704., 1344.],
12 | '0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.],
13 | '0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.],
14 | '1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.],
15 | '1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.],
16 | '1.75': [1344., 768.], '2.0': [1408., 704.], '2.09': [1472., 704.], '2.4': [1536., 640.],
17 | '2.5': [1600., 640.], '2.89': [1664., 576.], '3.0': [1728., 576.], '3.11': [1792., 576.],
18 | '3.62': [1856., 512.], '3.75': [1920., 512.], '3.88': [1984., 512.], '4.0': [2048., 512.],
19 | }
20 |
21 |
22 | ASPECT_RATIO_1024_norm = {
23 | '0.68': [832, 1216],
24 | '0.72': [832, 1152],
25 | '0.78': [896, 1152],
26 | '1.0': [1024, 1024],
27 | '1.29': [1152, 896],
28 | '1.46': [1216, 832],
29 | '1.75': [1344, 768],
30 | }
31 |
32 | ASPECT_RATIO_512 = {
33 | '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0],
34 | '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
35 | '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
36 | '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
37 | '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
38 | '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
39 | '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
40 | '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
41 | '2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0],
42 | '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0]
43 | }
44 |
45 |
46 | ASPECT_RATIO_512_norm = {
47 | '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
48 | '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
49 | '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
50 | '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
51 | '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
52 | '2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0],
53 | '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0]
54 | }
55 |
56 | ASPECT_RATIO_256 = {
57 | '0.25': [128.0, 512.0], '0.26': [128.0, 496.0], '0.27': [128.0, 480.0], '0.28': [128.0, 464.0],
58 | '0.32': [144.0, 448.0], '0.33': [144.0, 432.0], '0.35': [144.0, 416.0], '0.4': [160.0, 400.0],
59 | '0.42': [160.0, 384.0], '0.48': [176.0, 368.0], '0.5': [176.0, 352.0], '0.52': [176.0, 336.0],
60 | '0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0],
61 | '0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0],
62 | '1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0],
63 | '1.29': [288.0, 224.0], '1.38': [288.0, 208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0],
64 | '1.75': [336.0, 192.0], '2.0': [352.0, 176.0], '2.09': [368.0, 176.0], '2.4': [384.0, 160.0],
65 | '2.5': [400.0, 160.0], '2.89': [416.0, 144.0], '3.0': [432.0, 144.0], '3.11': [448.0, 144.0],
66 | '3.62': [464.0, 128.0], '3.75': [480.0, 128.0], '3.88': [496.0, 128.0], '4.0': [512.0, 128.0]
67 | }
--------------------------------------------------------------------------------
/data/data.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import io
3 | import os
4 | import json
5 | from functools import partial
6 | from typing import Sequence, Dict, Union, Tuple
7 | from dataclasses import dataclass
8 | import numpy as np
9 | from einops import rearrange
10 | from PIL import Image
11 | import random
12 | import torch
13 | import torchvision
14 | from torch.utils.data import Dataset
15 | from datasets import load_dataset
16 | from torchvision import transforms
17 |
18 | from .constants import ASPECT_RATIO_256, ASPECT_RATIO_512, ASPECT_RATIO_1024, DEFAULT_IMAGE_FILE_SUFFIX
19 | from .transform import image_transform, image_transform_original_resolution, image_transform_original_resolution_test
20 |
21 |
22 | aspect_ratio_database = {
23 | 256: ASPECT_RATIO_256,
24 | 512: ASPECT_RATIO_512,
25 | 1024: ASPECT_RATIO_1024
26 | }
27 |
28 | def ratio_sample(ratio, aspect_ratios=ASPECT_RATIO_1024):
29 | closest_ratio = min(aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
30 | return closest_ratio
31 |
32 | def find_image(sample):
33 | for suffix in DEFAULT_IMAGE_FILE_SUFFIX:
34 | if suffix in sample.keys():
35 | sample['0.jpg'] = sample[suffix]
36 | break
37 | return sample
38 |
39 | def get_cc3m_wds_dataset_and_collator(data_args, model_args):
40 | img_size = model_args.image_size
41 | train_processor = image_transform(img_size, is_train=True)
42 | val_processor = image_transform(img_size, is_train=False)
43 |
44 | data = load_dataset("webdataset", data_dir=data_args.dataset_path, split="train", streaming=True)
45 | data = data.shuffle(buffer_size=2_000, seed=data_args.seed)
46 |
47 | def decode(sample, img_processor):
48 | sample = find_image(sample)
49 | sample['image'] = img_processor(sample['jpg'])
50 | sample['text'] = sample['txt']
51 | return sample
52 | data = data.map(
53 | partial(decode, img_processor=train_processor),
54 | remove_columns=['__key__', '__url__']
55 | )
56 | data = data.filter(lambda sample: 'image' in sample and 'text' in sample) # filter return samples that match the given condition
57 | data_collator = CC3M_WebdatasetCollator(model_args.patch_size)
58 |
59 | return data, data_collator
60 |
61 | def get_wds_dataset_and_collator(data_args, model_args):
62 | img_size = model_args.image_size
63 |
64 | train_processor = image_transform(img_size, is_train=True) if model_args.fixed_image_size else image_transform
65 | data = load_dataset("webdataset", data_dir=data_args.dataset_path, split="train", streaming=True)
66 | data = data.shuffle(buffer_size=2_000, seed=data_args.seed)
67 |
68 | def decode(sample, img_processor):
69 | sample = find_image(sample)
70 | if model_args.fixed_image_size:
71 | sample['0.jpg'] = img_processor(sample['0.jpg'])
72 | return sample
73 |
74 | data = data.map(
75 | partial(decode, img_processor=train_processor),
76 | remove_columns=['__key__', '__url__']
77 | )
78 | data = data.filter(lambda sample: '0.jpg' in sample)
79 | data = data.rename_columns({'0.jpg': 'image'})
80 |
81 | aspect_ratios = aspect_ratio_database.get(model_args.image_size, None)
82 | aspect_ratios = aspect_ratios or ASPECT_RATIO_512
83 |
84 | data_collator = WebdatasetCollator(model_args.fixed_image_size, model_args.patch_size, aspect_ratios)
85 | # import ipdb  # debugging breakpoint disabled
86 | # ipdb.set_trace()
87 | return data, data_collator
88 |
89 | def get_wds_dataset_and_collator_arbitrary_resolution(data_args, model_args):
90 |
91 | data = load_dataset("webdataset", data_dir=data_args.dataset_path, split="train", streaming=True)
92 | data = data.shuffle(buffer_size=2_000, seed=data_args.seed)
93 |
94 | def decode_sample(sample, img_processor):
95 | sample = find_image(sample)
96 | sample['0.jpg'], sample['size'] = img_processor(sample['0.jpg'])
97 | return sample
98 |
99 | data = data.map(
100 | partial(
101 | decode_sample,
102 | img_processor=partial(image_transform_original_resolution, patch_size=model_args.patch_size)
103 | ),
104 | remove_columns=['__key__', '__url__']
105 | )
106 | data = data.filter(lambda sample: '0.jpg' in sample and sample['0.jpg'].ndim == 3 and sample['0.jpg'].shape[-1] > 0 and sample['0.jpg'].shape[-2] > 0) # filter return samples that match the given condition
107 | data = data.rename_columns({'0.jpg': 'image'})
108 | data_collator = WebdatasetCollator(model_args.patch_size)
109 |
110 | return data, data_collator
111 |
112 | def dataset_test(data_args, model_args):
113 | from datasets import load_dataset
114 | import numpy as np
115 | from functools import partial
116 | from torchvision import transforms
117 | OPENAI_DATASET_MEAN = np.array([0.48145466, 0.4578275, 0.40821073])
118 | OPENAI_DATASET_STD = np.array([0.26862954, 0.26130258, 0.27577711])
119 | DEFAULT_IMAGE_FILE_SUFFIX = ['jpg', '0.jpg', '0.png', 'png', 'jpeg', '0.jpeg', 'webp']
120 | data = load_dataset("webdataset", data_dir="/share/project/datasets/laion-high-resolution/*/*.tar", split="train", streaming=True)
121 | data = data.shuffle(buffer_size=2_000, seed=1)
122 |
123 | data_iter = iter(data)
124 |
125 | def decode_sample(sample, img_processor):
126 | sample = find_image(sample)
127 | sample['0.jpg'], sample['size'] = img_processor(sample['0.jpg'])
128 | return sample
129 |
130 | data = data.map(
131 | partial(
132 | decode_sample,
133 | img_processor=partial(image_transform_original_resolution_test, patch_size=16)
134 | ),
135 | remove_columns=['__key__', '__url__']
136 | )
137 |     data = data.filter(lambda sample: '0.jpg' in sample)  # keep only samples that contain an image
138 | data = data.rename_columns({'0.jpg': 'image'})
139 |     data_collator = WebdatasetCollator(patch_size=16)  # match the patch_size used by the transform above
140 |
141 | return data, data_collator
142 |
143 | def collate_anyres(images, sizes, patch_size, max_size=2048):
144 | """
145 | Args:
146 | * images: list of images
147 | * sizes: list of image sizes in (ph, pw), i.e., number of patches in h and w
148 |
149 | Return: args accepted by VQModel
150 | * pixel_values: packed images
151 | * cu_seqlens_img
152 | * max_seqlen_img
153 | * grid_hw
154 | * image_sizes
155 | """
156 | b, c = len(images), images[0].shape[0]
157 | max_patch_num = max_size // patch_size
158 |
159 | image_sizes = torch.tensor([(image.shape[1], image.shape[2]) for image in images])
160 | H, W = image_sizes.max(dim=0).values
161 | padded_images = images[0].new_zeros(size=(b, c, H.item(), W.item()))
162 |
163 | h, w = torch.tensor(sizes).max(dim=0).values
164 | padding_masks = torch.zeros(size=(b, h.item(), w.item()), dtype=torch.bool)
165 |
166 | for i, (image, mask_size) in enumerate(zip(images, sizes)):
167 | padded_images[i, :, : image.shape[1], : image.shape[2]].copy_(image)
168 | padding_masks[i, : mask_size[0], : mask_size[1]] = 1
169 |
170 | padded_images = padded_images.reshape(b, c, h, patch_size, w, patch_size)
171 | padded_images = torch.einsum("nchpwq->nhwpqc", padded_images)
172 | padded_images = padded_images.reshape(b, h, w, -1)
173 | packed_images = padded_images[padding_masks]
174 |
175 | seq_lens = padding_masks.flatten(1, 2).sum(dim=-1)
176 | cu_seqlens_img = torch.nn.functional.pad(
177 | torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
178 | )
179 | max_seqlen_img = seq_lens.max()
180 |
181 | grid_h = torch.arange(0, h)[None, :, None].repeat(b, 1, w)
182 | grid_w = torch.arange(0, w)[None, None, :].repeat(b, h, 1)
183 | grid_hw = grid_h[padding_masks] * max_patch_num + grid_w[padding_masks]
184 |
185 | return packed_images, cu_seqlens_img, max_seqlen_img, grid_hw, torch.tensor(sizes)
186 | @dataclass
187 | class CC3M_WebdatasetCollator:
188 | def __init__(self, patch_size: int = 1):
189 | self.patch_size = patch_size
190 | self.count = 0
191 |
192 | def __call__(
193 | self,
194 | samples: Sequence[Dict],
195 | ) -> Dict[str, torch.Tensor]:
196 |
197 | self.count += 1
198 | images = [sample["image"] for sample in samples]
199 | texts = [sample["text"] for sample in samples]
200 |
201 | if "size" in samples[0]:
202 | sizes = [sample['size'] for sample in samples]
203 |
204 | batch = {}
205 |
206 | if all(x is not None and x.shape == images[0].shape for x in images):
207 | batch['image'] = torch.stack(images)
208 | else:
209 | if "size" in samples[0]:
210 | batch['image'], batch['cu_seqlens_img'], \
211 | batch['max_seqlen_img'], batch['grid_hw'], \
212 | batch['image_sizes'] = collate_anyres(images, sizes, self.patch_size)
213 | else:
214 | batch['image'] = images
215 | batch['text'] = texts
216 | return batch
217 |
218 | @dataclass
219 | class WebdatasetCollator:
220 |     def __init__(self, fixed_image_size: bool = True, patch_size: int = 8, aspect_ratios=ASPECT_RATIO_512):
221 | self.fixed_image_size = fixed_image_size
222 | self.patch_size = patch_size
223 | self.aspect_ratios = aspect_ratios
224 | self.count = 0
225 |
226 | def __call__(
227 | self,
228 | samples: Sequence[Dict],
229 | ) -> Dict[str, torch.Tensor]:
230 |
231 | self.count += 1
232 | images = [sample["image"] for sample in samples]
233 |
234 | if "size" in samples[0]:
235 | sizes = [sample['size'] for sample in samples]
236 |
237 | if "size" not in samples[0] and not self.fixed_image_size:
238 |             np.random.seed(self.count)  # reseed by batch count so the sampled aspect ratio is deterministic per batch
239 |
240 | aspect_ratio = np.random.choice(list(self.aspect_ratios.keys()))
241 | image_sizes = [int(x) for x in self.aspect_ratios[aspect_ratio]]
242 | img_processor = image_transform(image_sizes, is_train=True)
243 | images = [img_processor(image) for image in images]
244 |
245 | batch = {}
246 |
247 | if all(x is not None and x.shape == images[0].shape for x in images):
248 | batch['pixel_values'] = torch.stack(images)
249 | else:
250 | if "size" in samples[0]:
251 | batch['pixel_values'], batch['cu_seqlens_img'], \
252 | batch['max_seqlen_img'], batch['grid_hw'], \
253 | batch['image_sizes'] = collate_anyres(images, sizes, self.patch_size)
254 | else:
255 | batch['pixel_values'] = images
256 | return batch
257 |
258 | def anyres_process_images_for_model(image_path=None, pil_image=None, patch_size=32):
259 | """
260 |     Given a list of image paths or PIL images, build the packed any-resolution inputs expected by the model.
261 | """
262 | if image_path is not None:
263 | assert pil_image is None
264 | if not isinstance(image_path, list):
265 | image_path = [image_path]
266 | pil_image = []
267 | for p in image_path:
268 | pil_image.append(Image.open(p).convert('RGB'))
269 | if not isinstance(pil_image, list):
270 | pil_image = [pil_image]
271 |
272 | if len(pil_image) % 2 != 0:
273 | pil_image.append(pil_image[-1])
274 |
275 | image_tensors, sizes = [], []
276 | for pil_i in pil_image:
277 | image_tensor, size = image_transform_original_resolution(image=pil_i, patch_size=patch_size)
278 | image_tensors.append(image_tensor)
279 | sizes.append(size)
280 |
281 | pixel_values, cu_seqlens_img, max_seqlen_img, grid_hw, image_sizes = collate_anyres(image_tensors, sizes, patch_size)
282 |
283 | return {
284 | 'pixel_values': pixel_values,
285 | 'cu_seqlens_img': cu_seqlens_img,
286 | 'max_seqlen_img': max_seqlen_img,
287 | 'grid_hw': grid_hw,
288 | 'image_sizes': image_sizes
289 | }
290 |
291 | def get_in1k_val_dataset(data_args, model_args):
292 | import torchvision
293 | transform = image_transform(model_args.image_size, is_train=False)
294 | dataset = torchvision.datasets.ImageFolder(root="/share/project/datasets/ImageNet/val", transform=transform)
295 | def in1k_collator(samples):
296 | if model_args.gan_loss_weight:
297 | return {"pixel_values": torch.stack([sample[0] for sample in samples]), "optimizer_idx": 0}
298 | return {"pixel_values": torch.stack([sample[0] for sample in samples])}
299 | def in1k_collator_anyres(samples):
300 | images = [sample[0] for sample in samples]
301 | sizes = [[image.shape[1] // model_args.patch_size, image.shape[2] // model_args.patch_size] for image in images]
302 | b, c = len(images), images[0].shape[0]
303 | max_patch_num = 1024 // model_args.patch_size
304 |
305 | image_sizes = torch.tensor([(image.shape[1], image.shape[2]) for image in images])
306 | H, W = image_sizes.max(dim=0).values
307 | padded_images = images[0].new_zeros(size=(b, c, H.item(), W.item()))
308 |
309 | h, w = torch.tensor(sizes).max(dim=0).values
310 | padding_masks = torch.zeros(size=(b, h.item(), w.item()), dtype=torch.bool)
311 |
312 | for i, (image, mask_size) in enumerate(zip(images, sizes)):
313 | padded_images[i, :, : image.shape[1], : image.shape[2]].copy_(image)
314 | padding_masks[i, : mask_size[0], : mask_size[1]] = 1
315 |
316 | padded_images = padded_images.reshape(b, c, h, model_args.patch_size, w, model_args.patch_size)
317 | padded_images = torch.einsum("nchpwq->nhwpqc", padded_images)
318 | padded_images = padded_images.reshape(b, h, w, -1)
319 | packed_images = padded_images[padding_masks]
320 |
321 | seq_lens = padding_masks.flatten(1, 2).sum(dim=-1)
322 | cu_seqlens_img = torch.nn.functional.pad(
323 | torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
324 | )
325 | max_seqlen_img = seq_lens.max()
326 |
327 | grid_h = torch.arange(0, h)[None, :, None].repeat(b, 1, w)
328 | grid_w = torch.arange(0, w)[None, None, :].repeat(b, h, 1)
329 | grid_hw = grid_h[padding_masks] * max_patch_num + grid_w[padding_masks]
330 |
331 | batch = {}
332 | batch['pixel_values'] = packed_images
333 | batch['cu_seqlens_img'] = cu_seqlens_img
334 | batch['max_seqlen_img'] = max_seqlen_img
335 | batch['grid_hw'] = grid_hw
336 | batch['image_sizes'] = torch.tensor(sizes)
337 | if model_args.gan_loss_weight:
338 | batch["optimizer_idx"] = 0
339 | return batch
340 |
341 | return dataset, in1k_collator_anyres if data_args.arbitrary_resolution else in1k_collator
342 |
343 |
344 | def split_val_set():
345 | data = load_dataset("webdataset", data_dir="/share/project/datasets/laion-high-resolution/*/*.tar", split="train", streaming=True)
346 | data = data.shuffle(buffer_size=100_000, seed=100)
347 | val_set = []
348 | val_ids = set()
349 | for i, item in enumerate(data):
350 | item_id = item['__url__'] + item['__key__']
351 | if i == 50_000:
352 | break
353 | val_set.append(item)
354 | val_ids.add(item_id)
355 | with open("/share/project/datasets/laion-high-resolution/50k_eval_ids.pkl", "wb") as f:
356 | pickle.dump(val_ids, f)
357 |
358 | import webdataset as wds
359 | from PIL import Image
360 | from pathlib import Path
361 | for i in range(50):
362 | sink = wds.TarWriter(f"/share/project/datasets/laion-high-resolution/eval/eval_{i}.tar")
363 | for sample in val_set[i * 1000: (i + 1) * 1000]:
364 | sink.write(sample)
365 | sink.close()
366 |
367 | def preprocess_val_data():
368 | data = load_dataset("webdataset", data_dir="/share/project/datasets/laion-high-resolution/eval/*.tar", split="train", streaming=True)
369 | save_dir = "/share/project/datasets/laion-high-resolution/eval/"
370 | info = []
371 | def decode_sample(sample):
372 | sample = find_image(sample)
373 | sample['0.jpg'], sample['size'] = image_transform_original_resolution(sample['0.jpg'], patch_size=32)
374 | return sample
375 | for i, item in tqdm(enumerate(data)):
376 | image_file = save_dir + f"image_{i}.pkl"
377 | # process image -> size, image pkl
378 | item = decode_sample(item)
379 | with open(image_file, "wb") as f:
380 | pickle.dump(item['0.jpg'], f)
381 | info.append({
382 | "image_path": image_file,
383 | "size": item['size']
384 | })
385 | print(i)
386 | with open(save_dir + f"image_info.pkl", "wb") as f:
387 | pickle.dump(info, f)
388 |
389 | class HighresEvalDataset(Dataset):
390 | def __init__(self):
391 | with open("/share/project/datasets/laion-high-resolution/eval/image_info.pkl", "rb") as f:
392 | self.info = pickle.load(f)
393 |
394 | def __getitem__(self, index):
395 | info = self.info[index]
396 | image_path, size = info['image_path'], info['size']
397 | with open(image_path, "rb") as f:
398 | image = pickle.load(f)
399 | return {"image": image, "size": size}
400 |
401 | def __len__(self):
402 | return len(self.info)
403 |
404 | def get_highres_eval_dataset(data_args, model_args):
405 | data = HighresEvalDataset()
406 |     data_collator = WebdatasetCollator(patch_size=model_args.patch_size)  # pass by keyword: the first positional parameter is fixed_image_size
407 |
408 | return data, data_collator
--------------------------------------------------------------------------------
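
A minimal, self-contained sketch of the packing that collate_anyres performs: pad a batch of variable-resolution images to a common size, patchify, and keep only the real (non-padding) patches as one flat sequence with cumulative sequence lengths, in the varlen flash-attention style suggested by the cu_seqlens/max_seqlen names. The tensor shapes and patch_size=32 below are illustrative only, not values taken from the repo's configs.

```python
import torch

patch_size = 32
images = [torch.randn(3, 256, 320), torch.randn(3, 192, 224)]       # (C, H, W), multiples of 32
sizes = [(img.shape[1] // patch_size, img.shape[2] // patch_size) for img in images]

b, c = len(images), images[0].shape[0]
H = max(img.shape[1] for img in images)
W = max(img.shape[2] for img in images)
h = max(ph for ph, _ in sizes)
w = max(pw for _, pw in sizes)

# pad every image to the largest H x W in the batch and mark its real patches
padded = images[0].new_zeros(b, c, H, W)
mask = torch.zeros(b, h, w, dtype=torch.bool)
for i, (img, (ph, pw)) in enumerate(zip(images, sizes)):
    padded[i, :, : img.shape[1], : img.shape[2]] = img
    mask[i, :ph, :pw] = True

# patchify: (b, c, h*p, w*p) -> (b, h, w, p*p*c), then keep only the real patches
patches = padded.reshape(b, c, h, patch_size, w, patch_size)
patches = torch.einsum("nchpwq->nhwpqc", patches).reshape(b, h, w, -1)
packed = patches[mask]                                   # (num_real_patches, p*p*c)

seq_lens = mask.flatten(1).sum(-1)                       # patches per image: [80, 42]
cu_seqlens = torch.nn.functional.pad(seq_lens.cumsum(0, dtype=torch.int32), (1, 0))
print(packed.shape, cu_seqlens.tolist())                 # torch.Size([122, 3072]) [0, 80, 122]
```
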
/data/image_data.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import copy
3 | import os
4 | import json
5 | import pickle
6 | from functools import partial
7 | from typing import Sequence, Dict
8 | from dataclasses import dataclass
9 | import numpy as np
10 |
11 | from PIL import Image
12 |
13 | import torch
14 | from torch.utils.data import Dataset
15 | import transformers
16 | from datasets import load_dataset
17 | from torchvision import transforms
18 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
19 | CenterCrop
20 |
21 | OPENAI_DATASET_MEAN = np.array([0.48145466, 0.4578275, 0.40821073])
22 | OPENAI_DATASET_STD = np.array([0.26862954, 0.26130258, 0.27577711])
23 | DEFAULT_IMAGE_FILE_SUFFIX = ['jpg', '0.jpg', '0.png', 'png', 'jpeg', '0.jpeg', 'webp']
24 |
25 | def find_image(sample):
26 | for suffix in DEFAULT_IMAGE_FILE_SUFFIX:
27 | if suffix in sample.keys():
28 | sample['0.jpg'] = sample[suffix]
29 | break
30 | return sample
31 |
32 | # remove_columns=['__key__', '__url__', '0.txt', 'original_prompt']
33 | def get_wds_dataset_and_collator(data_args, model_args):
34 | img_size = model_args.image_size
35 | train_processor = image_transform(img_size, is_train=True)
36 | val_processor = image_transform(img_size, is_train=False)
37 |
38 | data = load_dataset("webdataset", data_dir=data_args.dataset_path, split="train", streaming=True)
39 | data = data.shuffle(buffer_size=2_000, seed=data_args.seed)
40 |
41 | def decode(sample, img_processor):
42 | sample = find_image(sample)
43 | sample['0.jpg'] = img_processor(sample['0.jpg'])
44 | return sample
45 | data = data.map(
46 | partial(decode, img_processor=train_processor),
47 | # remove_columns=['__key__', '__url__', '0.txt', 'seg.txt']
48 | remove_columns=['__key__', '__url__']
49 | )
50 |     data = data.filter(lambda sample: '0.jpg' in sample)  # keep only samples that contain an image
51 | data = data.rename_columns({'0.jpg': 'image'})
52 | data_collator = WebdatasetCollator(model_args.patch_size)
53 |
54 | return data, data_collator
55 |
56 | def image_transform_original_resolution(
57 | image,
58 | patch_size: int,
59 | ):
60 |     """Accept a PIL image and return a normalized tensor whose H and W are multiples of patch_size (scaled down if larger than 1024), plus the patch-grid size (ph, pw)."""
61 | w, h = map(lambda x: x // patch_size * patch_size, image.size)
62 | if w > 1024:
63 | h = int(h / (w / 1024) // patch_size * patch_size)
64 | w = 1024
65 | elif h > 1024:
66 | w = int(w / (h / 1024) // patch_size * patch_size)
67 | h = 1024
68 | def _convert_to_rgb(image):
69 | return image.convert('RGB')
70 | normalize = transforms.Normalize(
71 | mean=OPENAI_DATASET_MEAN,
72 | std=OPENAI_DATASET_STD
73 | )
74 | transform = transforms.Compose([
75 | transforms.CenterCrop((h, w)),
76 | _convert_to_rgb,
77 | transforms.ToTensor(),
78 | normalize,
79 | ])
80 | ph, pw = h // patch_size, w // patch_size
81 | return transform(image), (ph, pw)
82 |
83 | def get_wds_dataset_and_collator_arbitrary_resolution(data_args, model_args):
84 | data = load_dataset("webdataset", data_dir=data_args.dataset_path, split="train", streaming=True)
85 | data = data.shuffle(buffer_size=2_000, seed=data_args.seed)
86 |
87 | def decode_sample(sample, img_processor):
88 | sample = find_image(sample)
89 | sample['0.jpg'], sample['size'] = img_processor(sample['0.jpg'])
90 | return sample
91 |
92 | data = data.map(
93 | partial(
94 | decode_sample,
95 | img_processor=partial(image_transform_original_resolution, patch_size=model_args.patch_size)
96 | ),
97 | remove_columns=['__key__', '__url__']
98 | )
99 |     data = data.filter(lambda sample: '0.jpg' in sample and sample['0.jpg'].ndim == 3 and sample['0.jpg'].shape[-1] > 0 and sample['0.jpg'].shape[-2] > 0)  # keep only valid, non-empty image tensors
100 | data = data.rename_columns({'0.jpg': 'image'})
101 | data_collator = WebdatasetCollator(model_args.patch_size)
102 |
103 | return data, data_collator
104 |
105 | def dataset_test(data_args, model_args):
106 | from datasets import load_dataset
107 | import numpy as np
108 | from functools import partial
109 | from torchvision import transforms
110 | OPENAI_DATASET_MEAN = np.array([0.48145466, 0.4578275, 0.40821073])
111 | OPENAI_DATASET_STD = np.array([0.26862954, 0.26130258, 0.27577711])
112 | DEFAULT_IMAGE_FILE_SUFFIX = ['jpg', '0.jpg', '0.png', 'png', 'jpeg', '0.jpeg', 'webp']
113 | data = load_dataset("webdataset", data_dir="/share/project/datasets/laion-high-resolution/*/*.tar", split="train", streaming=True)
114 | data = data.shuffle(buffer_size=2_000, seed=1)
115 |
116 | data_iter = iter(data)
117 |
118 | def decode_sample(sample, img_processor):
119 | sample = find_image(sample)
120 | sample['0.jpg'], sample['size'] = img_processor(sample['0.jpg'])
121 | return sample
122 |
123 | def image_transform_original_resolution(
124 | image,
125 | patch_size: int,
126 | ):
127 | w, h = map(lambda x: x // patch_size * patch_size, image.size)
128 | def _convert_to_rgb(image):
129 | return image.convert('RGB')
130 | normalize = transforms.Normalize(
131 | mean=OPENAI_DATASET_MEAN,
132 | std=OPENAI_DATASET_STD
133 | )
134 | transform = transforms.Compose([
135 | transforms.Resize((h, w), interpolation=transforms.InterpolationMode.BICUBIC),
136 | _convert_to_rgb,
137 | transforms.ToTensor(),
138 | normalize,
139 | ])
140 |         return transform(image), (h // patch_size, w // patch_size)  # also return the patch grid so decode_sample can unpack (image, size)
141 |
142 | data = data.map(
143 | partial(
144 | decode_sample,
145 | img_processor=partial(image_transform_original_resolution, patch_size=16)
146 | ),
147 | remove_columns=['__key__', '__url__']
148 | )
149 |     data = data.filter(lambda sample: '0.jpg' in sample)  # keep only samples that contain an image
150 | data = data.rename_columns({'0.jpg': 'image'})
151 | data_collator = WebdatasetCollator()
152 |
153 | return data, data_collator
154 |
155 | def collate_anyres(images, sizes, patch_size):
156 | """
157 | Args:
158 | * images: list of images
159 | * sizes: list of image sizes in (ph, pw), i.e., number of patches in h and w
160 |
161 | Return: args accepted by VQModel
162 | * pixel_values: packed images
163 | * cu_seqlens_img:
164 | * max_seqlen_img
165 | * grid_hw
166 | * image_sizes
167 | """
168 | b, c = len(images), images[0].shape[0]
169 | max_patch_num = 1024 // patch_size
170 |
171 | image_sizes = torch.tensor([(image.shape[1], image.shape[2]) for image in images])
172 | H, W = image_sizes.max(dim=0).values
173 | padded_images = images[0].new_zeros(size=(b, c, H.item(), W.item()))
174 |
175 | h, w = torch.tensor(sizes).max(dim=0).values
176 | padding_masks = torch.zeros(size=(b, h.item(), w.item()), dtype=torch.bool)
177 |
178 | for i, (image, mask_size) in enumerate(zip(images, sizes)):
179 | padded_images[i, :, : image.shape[1], : image.shape[2]].copy_(image)
180 | padding_masks[i, : mask_size[0], : mask_size[1]] = 1
181 |
182 | padded_images = padded_images.reshape(b, c, h, patch_size, w, patch_size)
183 | padded_images = torch.einsum("nchpwq->nhwpqc", padded_images)
184 | padded_images = padded_images.reshape(b, h, w, -1)
185 | packed_images = padded_images[padding_masks]
186 |
187 | seq_lens = padding_masks.flatten(1, 2).sum(dim=-1)
188 | cu_seqlens_img = torch.nn.functional.pad(
189 | torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
190 | )
191 | max_seqlen_img = seq_lens.max()
192 |
193 | grid_h = torch.arange(0, h)[None, :, None].repeat(b, 1, w)
194 | grid_w = torch.arange(0, w)[None, None, :].repeat(b, h, 1)
195 | grid_hw = grid_h[padding_masks] * max_patch_num + grid_w[padding_masks]
196 |
197 | return packed_images, cu_seqlens_img, max_seqlen_img, grid_hw, torch.tensor(sizes)
198 |
199 | @dataclass
200 | class WebdatasetCollator:
201 | patch_size: int
202 | def __call__(self, samples: Sequence[Dict]) -> Dict[str, torch.Tensor]:
203 | images = [sample["image"] for sample in samples]
204 | if "size" in samples[0]:
205 | sizes = [sample['size'] for sample in samples]
206 |
207 | batch = {}
208 |
209 | if all(x is not None and x.shape == images[0].shape for x in images):
210 | batch['pixel_values'] = torch.stack(images)
211 | else:
212 | batch['pixel_values'], batch['cu_seqlens_img'], \
213 | batch['max_seqlen_img'], batch['grid_hw'], \
214 | batch['image_sizes'] = collate_anyres(images, sizes, self.patch_size)
215 |
216 | # print(f"{[image.shape for image in batch['pixel_values']]=}")
217 | return batch
218 |
219 | def anyres_process_images_for_model(image_path=None, pil_image=None, patch_size=32):
220 | """
221 |     Given a list of image paths or PIL images, build the packed any-resolution inputs expected by the model.
222 | """
223 | if image_path is not None:
224 | assert pil_image is None
225 | if not isinstance(image_path, list):
226 | image_path = [image_path]
227 | pil_image = []
228 | for p in image_path:
229 | pil_image.append(Image.open(p).convert('RGB'))
230 | if not isinstance(pil_image, list):
231 | pil_image = [pil_image]
232 |
233 | if len(pil_image) % 2 != 0:
234 | pil_image.append(pil_image[-1])
235 |
236 | image_tensors, sizes = [], []
237 | for pil_i in pil_image:
238 | image_tensor, size = image_transform_original_resolution(image=pil_i, patch_size=patch_size)
239 | image_tensors.append(image_tensor)
240 | sizes.append(size)
241 |
242 | pixel_values, cu_seqlens_img, max_seqlen_img, grid_hw, image_sizes = collate_anyres(image_tensors, sizes, patch_size)
243 |
244 | return {
245 | 'pixel_values': pixel_values,
246 | 'cu_seqlens_img': cu_seqlens_img,
247 | 'max_seqlen_img': max_seqlen_img,
248 | 'grid_hw': grid_hw,
249 | 'image_sizes': image_sizes
250 | }
251 |
252 | def get_in1k_dataset(data_args, model_args):
253 | import torchvision
254 | transform = image_transform(model_args.image_size, is_train=False)
255 | dataset = torchvision.datasets.ImageFolder(root="/share/project/datasets/ImageNet/val", transform=transform)
256 | def in1k_collator(samples):
257 | if model_args.gan_loss_weight:
258 | return {"pixel_values": torch.stack([sample[0] for sample in samples]), "optimizer_idx": 0}
259 | return {"pixel_values": torch.stack([sample[0] for sample in samples])}
260 | def in1k_collator_anyres(samples):
261 | images = [sample[0] for sample in samples]
262 | sizes = [[image.shape[1] // model_args.patch_size, image.shape[2] // model_args.patch_size] for image in images]
263 | b, c = len(images), images[0].shape[0]
264 | max_patch_num = 1024 // model_args.patch_size
265 |
266 | image_sizes = torch.tensor([(image.shape[1], image.shape[2]) for image in images])
267 | H, W = image_sizes.max(dim=0).values
268 | padded_images = images[0].new_zeros(size=(b, c, H.item(), W.item()))
269 |
270 | h, w = torch.tensor(sizes).max(dim=0).values
271 | padding_masks = torch.zeros(size=(b, h.item(), w.item()), dtype=torch.bool)
272 |
273 | for i, (image, mask_size) in enumerate(zip(images, sizes)):
274 | padded_images[i, :, : image.shape[1], : image.shape[2]].copy_(image)
275 | padding_masks[i, : mask_size[0], : mask_size[1]] = 1
276 |
277 | padded_images = padded_images.reshape(b, c, h, model_args.patch_size, w, model_args.patch_size)
278 | padded_images = torch.einsum("nchpwq->nhwpqc", padded_images)
279 | padded_images = padded_images.reshape(b, h, w, -1)
280 | packed_images = padded_images[padding_masks]
281 |
282 | seq_lens = padding_masks.flatten(1, 2).sum(dim=-1)
283 | cu_seqlens_img = torch.nn.functional.pad(
284 | torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
285 | )
286 | max_seqlen_img = seq_lens.max()
287 |
288 | grid_h = torch.arange(0, h)[None, :, None].repeat(b, 1, w)
289 | grid_w = torch.arange(0, w)[None, None, :].repeat(b, h, 1)
290 | grid_hw = grid_h[padding_masks] * max_patch_num + grid_w[padding_masks]
291 |
292 | batch = {}
293 | batch['pixel_values'] = packed_images
294 | batch['cu_seqlens_img'] = cu_seqlens_img
295 | batch['max_seqlen_img'] = max_seqlen_img
296 | batch['grid_hw'] = grid_hw
297 | batch['image_sizes'] = torch.tensor(sizes)
298 | if model_args.gan_loss_weight:
299 | batch["optimizer_idx"] = 0
300 | return batch
301 | # dataset = load_dataset("imagefolder", data_dir="/share/project/qiying/datasets/ImageNet/ImageNet/val")['validation']
302 | # transform = image_transform(model_args.image_size, is_train=False)
303 | # def transforms(examples):
304 | # examples["pixel_values"] = [transform(image) for image in examples["pixel_values"]]
305 | # return examples
306 | # dataset.set_transform(transforms)
307 | # dataset = dataset.remove_columns(["label"])
308 | # dataset = dataset.rename_column('image', 'pixel_values')
309 | return dataset, in1k_collator_anyres if data_args.arbitrary_resolution else in1k_collator
310 |
311 | def get_highres_eval_dataset(data_args, model_args):
312 | # data = load_dataset("webdataset", data_dir="/share/project/datasets/laion-high-resolution/eval/eval_*.tar", split="train", streaming=True)
313 |
314 | # def decode_sample(sample, img_processor):
315 | # sample = find_image(sample)
316 | # sample['0.jpg'], sample['size'] = img_processor(sample['0.jpg'])
317 | # return sample
318 |
319 | # data = data.map(
320 | # partial(
321 | # decode_sample,
322 | # img_processor=partial(image_transform_original_resolution, patch_size=model_args.patch_size)
323 | # ),
324 | # remove_columns=['__key__', '__url__'],
325 | # )
326 | # data = data.filter(lambda sample: '0.jpg' in sample and sample['0.jpg'].ndim == 3 and sample['0.jpg'].shape[-1] > 0 and sample['0.jpg'].shape[-2] > 0) # filter return samples that match the given condition
327 | # data = data.rename_columns({'0.jpg': 'image'})
328 | data = HighresEvalDataset()
329 | data_collator = WebdatasetCollator(model_args.patch_size)
330 |
331 | return data, data_collator
332 |
333 | def image_transform(
334 | image_size: int,
335 | is_train: bool,
336 | ):
337 | mean = OPENAI_DATASET_MEAN
338 | std = OPENAI_DATASET_STD
339 | def _convert_to_rgb(image):
340 | return image.convert('RGB')
341 | normalize = transforms.Normalize(mean=mean, std=std)
342 | if is_train:
343 | return transforms.Compose([
344 | transforms.RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
345 | _convert_to_rgb,
346 | transforms.ToTensor(),
347 | normalize,
348 | ])
349 | else:
350 | return transforms.Compose([
351 | transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BICUBIC),
352 | transforms.CenterCrop(image_size),
353 | _convert_to_rgb,
354 | transforms.ToTensor(),
355 | normalize,
356 | ])
357 |
358 |
359 | def norm_vq_img(img):
360 | arr = np.array(img)
361 | arr = arr.astype(np.float32) / 127.5 - 1
362 | img = torch.from_numpy(np.transpose(arr, [2, 0, 1]))
363 | return img
364 |
365 |
366 | def prepare_image(img):
367 | """ Transform and normalize PIL Image to tensor. """
368 | transform = transforms.Compose([
369 | transforms.RandomResizedCrop(512, scale=(1., 1.), ratio=(1., 1.), interpolation=InterpolationMode.BICUBIC),
370 | ])
371 | pil_image = transform(img)
372 | arr = np.array(pil_image.convert("RGB"))
373 | arr = arr.astype(np.float32) / 127.5 - 1
374 | return torch.from_numpy(np.transpose(arr, [2, 0, 1]))
375 |
376 |
377 | def image_transform_for_vq(
378 | image_size: int,
379 | is_train: bool,
380 | ):
381 |     def _convert_to_rgb(image): return image.convert('RGB')  # local helper; this module defines no top-level _convert_to_rgb
382 | if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
383 | # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
384 | image_size = image_size[0]
385 |
386 | if is_train:
387 | return Compose([
388 | RandomResizedCrop(image_size, scale=(1., 1.), ratio=(1., 1.), interpolation=InterpolationMode.BICUBIC),
389 | _convert_to_rgb,
390 | norm_vq_img,
391 | ])
392 | else:
393 | transforms = [
394 | RandomResizedCrop(image_size, scale=(1., 1.), ratio=(1., 1.), interpolation=InterpolationMode.BICUBIC),
395 | _convert_to_rgb,
396 | norm_vq_img
397 | ]
398 | return Compose(transforms)
399 |
400 | def split_val_set():
401 | data = load_dataset("webdataset", data_dir="/share/project/datasets/laion-high-resolution/*/*.tar", split="train", streaming=True)
402 | data = data.shuffle(buffer_size=100_000, seed=100)
403 | val_set = []
404 | val_ids = set()
405 | for i, item in enumerate(data):
406 | item_id = item['__url__'] + item['__key__']
407 | if i == 50_000:
408 | break
409 | val_set.append(item)
410 | val_ids.add(item_id)
411 | # with open("/share/project/datasets/laion-high-resolution/50k_eval.pkl", "wb") as f:
412 | # pickle.dump(val_set, f)
413 | with open("/share/project/datasets/laion-high-resolution/50k_eval_ids.pkl", "wb") as f:
414 | pickle.dump(val_ids, f)
415 |
416 | import webdataset as wds
417 | from PIL import Image
418 | from pathlib import Path
419 | for i in range(50):
420 | sink = wds.TarWriter(f"/share/project/datasets/laion-high-resolution/eval/eval_{i}.tar")
421 | for sample in val_set[i * 1000: (i + 1) * 1000]:
422 | sink.write(sample)
423 | sink.close()
424 |
425 | def preprocess_val_data():
426 | data = load_dataset("webdataset", data_dir="/share/project/datasets/laion-high-resolution/eval/*.tar", split="train", streaming=True)
427 | save_dir = "/share/project/datasets/laion-high-resolution/eval/"
428 | info = []
429 | def decode_sample(sample):
430 | sample = find_image(sample)
431 | sample['0.jpg'], sample['size'] = image_transform_original_resolution(sample['0.jpg'], patch_size=32)
432 | return sample
433 | for i, item in tqdm(enumerate(data)):
434 | image_file = save_dir + f"image_{i}.pkl"
435 | # process image -> size, image pkl
436 | item = decode_sample(item)
437 | with open(image_file, "wb") as f:
438 | pickle.dump(item['0.jpg'], f)
439 | info.append({
440 | "image_path": image_file,
441 | "size": item['size']
442 | })
443 | print(i)
444 | with open(save_dir + f"image_info.pkl", "wb") as f:
445 | pickle.dump(info, f)
446 |
447 | class HighresEvalDataset(Dataset):
448 | def __init__(self):
449 | with open("/share/project/datasets/laion-high-resolution/eval/image_info.pkl", "rb") as f:
450 | self.info = pickle.load(f)
451 |
452 | def __getitem__(self, index):
453 | info = self.info[index]
454 | image_path, size = info['image_path'], info['size']
455 | with open(image_path, "rb") as f:
456 | image = pickle.load(f)
457 | return {"image": image, "size": size}
458 |
459 | def __len__(self):
460 | return len(self.info)
461 |
462 | # preprocess_val_data()
463 | # split_val_set()
464 |
465 | # def find_image(sample):
466 | # for suffix in DEFAULT_IMAGE_FILE_SUFFIX:
467 | # if suffix in sample.keys():
468 | # sample['0.jpg'] = sample[suffix]
469 | # break
470 | # return sample
471 |
472 | # def decode_sample(sample, img_processor):
473 | # sample = find_image(sample)
474 | # sample['0.jpg'], sample['size'] = img_processor(sample['0.jpg'])
475 | # return sample
476 |
477 | # data = data.map(
478 | # partial(
479 | # decode_sample,
480 | # img_processor=partial(image_transform_original_resolution, patch_size=model_args.patch_size)
481 | # ),
482 | # remove_columns=['__key__', '__url__']
483 | # )
484 | # data = data.filter(lambda sample: '0.jpg' in sample and sample['0.jpg'].ndim == 3 and sample['0.jpg'].shape[-1] > 0 and sample['0.jpg'].shape[-2] > 0) # filter return samples that match the given condition
485 | # data = data.rename_columns({'0.jpg': 'image'})
486 | # data_collator = WebdatasetCollator(model_args.patch_size)
487 |
488 | # return data, data_collator
--------------------------------------------------------------------------------
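
A small usage sketch of image_transform_original_resolution from the file above, assuming the repository root is on PYTHONPATH so that data.image_data is importable and its dependencies are installed; the 1000x700 input is a synthetic stand-in for a real photo.

```python
from PIL import Image
from data.image_data import image_transform_original_resolution

img = Image.new("RGB", (1000, 700))    # stand-in image, W=1000, H=700
tensor, (ph, pw) = image_transform_original_resolution(img, patch_size=32)
# width/height are floored to multiples of 32 (992 x 672), so the output is a
# (3, 672, 992) tensor with a 21 x 31 patch grid
print(tensor.shape, ph, pw)
```
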
/data/register.py:
--------------------------------------------------------------------------------
1 | class Registry:
2 | mapping = {
3 | "data_name_mapping": {},
4 | }
5 |
6 | @classmethod
7 | def data_builder(cls, name):
8 | def wrap(data):
9 | cls.mapping["data_name_mapping"][name] = data
10 | return data
11 | return wrap
12 |
13 | @classmethod
14 | def get_data_builder(cls, name):
15 | return cls.mapping["data_name_mapping"].get(name, None)
16 |
17 |
18 | registry = Registry()
19 |
--------------------------------------------------------------------------------
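
The Registry above maps dataset names to builder callables via a decorator. A short sketch of how it would be used, assuming data.register is importable from the repository root; the builder name "toy_webdataset" is illustrative, not necessarily one the repo registers.

```python
from data.register import registry

@registry.data_builder("toy_webdataset")           # illustrative name
def build_toy(data_args, model_args):
    """Would build and return (dataset, data_collator)."""
    return None, None

assert registry.get_data_builder("toy_webdataset") is build_toy
assert registry.get_data_builder("missing") is None    # unknown names fall back to None
```
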
/data/transform.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Sequence, Dict, Union, Tuple
2 | import numpy as np
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | from torchvision import transforms as T
7 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
8 | CenterCrop
9 |
10 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
11 |
12 | def to_tensor(image):
13 | try:
14 | image = ToTensor()(image)
15 | except Exception as e:
16 | image = image.float()
17 | print(e)
18 | return image
19 |
20 | def _convert_to_rgb(image):
21 | try:
22 | image = image.convert('RGB')
23 | except Exception as e:
24 | print(e)
25 | return image
26 |
27 | def image_transform_original_resolution(
28 | image,
29 | patch_size: int,
30 |     max_size: int = 2048
31 | ):
32 |     """Accept a PIL image and return a normalized tensor whose H and W are multiples of patch_size (scaled down if the longer side exceeds max_size), plus the patch-grid size (ph, pw)."""
33 | w, h = map(lambda x: x // patch_size * patch_size, image.size)
34 | if max(w, h) > max_size:
35 | if w > h:
36 | h = int(h / (w / max_size) // patch_size * patch_size)
37 | w = max_size
38 | else:
39 | w = int(w / (h / max_size) // patch_size * patch_size)
40 | h = max_size
41 |
42 | def _convert_to_rgb(image):
43 | return image.convert('RGB')
44 |
45 | normalize = Normalize(
46 | mean=OPENAI_DATASET_MEAN,
47 | std=OPENAI_DATASET_STD
48 | )
49 | transform = Compose([
50 | CenterCrop((h, w)),
51 | _convert_to_rgb,
52 | to_tensor,
53 | normalize,
54 | ])
55 | ph, pw = h // patch_size, w // patch_size
56 | return transform(image), (ph, pw)
57 |
58 | def image_transform_original_resolution_test(
59 | image,
60 | patch_size: int,
61 | ):
62 | w, h = map(lambda x: x // patch_size * patch_size, image.size)
63 | normalize = Normalize(
64 | mean=OPENAI_DATASET_MEAN,
65 | std=OPENAI_DATASET_STD
66 | )
67 | transform = Compose([
68 | Resize((h, w), interpolation=InterpolationMode.BICUBIC),
69 | _convert_to_rgb,
70 | to_tensor,
71 | normalize,
72 | ])
73 |     return transform(image), (h // patch_size, w // patch_size)  # also return the patch grid so callers can unpack (image, size)
74 |
75 | def image_transform(
76 | image_size: Union[int, Tuple[int, int]],
77 | is_train: bool,
78 | mean: Optional[Tuple[float, ...]] = None,
79 | std: Optional[Tuple[float, ...]] = None,
80 | ):
81 | mean = mean or OPENAI_DATASET_MEAN
82 | if not isinstance(mean, (list, tuple)):
83 | mean = (mean,) * 3
84 |
85 | std = std or OPENAI_DATASET_STD
86 | if not isinstance(std, (list, tuple)):
87 | std = (std,) * 3
88 |
89 | normalize = Normalize(mean=mean, std=std)
90 |
91 | if is_train:
92 | return Compose([
93 | RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
94 | _convert_to_rgb,
95 | # ToTensor(),
96 | to_tensor,
97 | normalize,
98 | ])
99 | else:
100 | return Compose([
101 | Resize(image_size, interpolation=InterpolationMode.BICUBIC),
102 | CenterCrop(image_size),
103 | _convert_to_rgb,
104 | # ToTensor(),
105 | to_tensor,
106 | normalize,
107 | ])
108 |
109 | def norm_img_vq(img):
110 | arr = np.array(img)
111 | arr = arr.astype(np.float32) / 127.5 - 1
112 | img = torch.from_numpy(np.transpose(arr, [2, 0, 1]))
113 | return img
114 |
115 |
116 | def prepare_image(img):
117 | """ Transform and normalize PIL Image to tensor. """
118 | transform = Compose([
119 | RandomResizedCrop(512, scale=(1., 1.), ratio=(1., 1.), interpolation=InterpolationMode.BICUBIC),
120 | ])
121 | pil_image = transform(img)
122 | arr = np.array(pil_image.convert("RGB"))
123 | arr = arr.astype(np.float32) / 127.5 - 1
124 | return torch.from_numpy(np.transpose(arr, [2, 0, 1]))
125 |
126 |
127 | def image_transform_vq(
128 | image_size: Union[int, Tuple[int, int]],
129 | is_train: bool,
130 | ):
131 |
132 | # if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
133 | # # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
134 | # image_size = image_size[0]
135 | def _convert_to_rgb(image):
136 | return image.convert('RGB')
137 |
138 | return Compose([
139 | RandomResizedCrop(image_size, scale=(1., 1.), ratio=(1., 1.), interpolation=InterpolationMode.BICUBIC),
140 | _convert_to_rgb,
141 | norm_img_vq
142 | ])
143 |
144 |
145 | def DiffAugment(x, policy='color,translation,cutout', is_tensor=True, channels_first=True):
146 | if policy:
147 | if not is_tensor and not channels_first:
148 | x = x.permute(0, 3, 1, 2)
149 | for p in policy.split(','):
150 | for f in AUGMENT_FNS[p]:
151 | x = f(x)
152 | if not is_tensor and not channels_first:
153 | x = x.permute(0, 2, 3, 1)
154 | x = x.contiguous()
155 | return x
156 |
157 |
158 | def rand_brightness(x):
159 | x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
160 | return x
161 |
162 |
163 | def rand_saturation(x):
164 | x_mean = x.mean(dim=1, keepdim=True)
165 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
166 | return x
167 |
168 |
169 | def rand_contrast(x):
170 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
171 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
172 | return x
173 |
174 |
175 | def rand_translation(x, ratio=0.125):
176 | shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
177 | translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
178 | translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
179 | grid_batch, grid_x, grid_y = torch.meshgrid(
180 | torch.arange(x.size(0), dtype=torch.long, device=x.device),
181 | torch.arange(x.size(2), dtype=torch.long, device=x.device),
182 | torch.arange(x.size(3), dtype=torch.long, device=x.device),
183 | )
184 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
185 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
186 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
187 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2).contiguous()
188 | return x
189 |
190 |
191 | def rand_cutout(x, ratio=0.5):
192 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
193 | offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
194 | offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
195 | grid_batch, grid_x, grid_y = torch.meshgrid(
196 | torch.arange(x.size(0), dtype=torch.long, device=x.device),
197 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
198 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
199 | )
200 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
201 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
202 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
203 | mask[grid_batch, grid_x, grid_y] = 0
204 | x = x * mask.unsqueeze(1)
205 | return x
206 |
207 |
208 | AUGMENT_FNS = {
209 | 'color': [rand_brightness, rand_saturation, rand_contrast],
210 | 'translation': [rand_translation],
211 | 'cutout': [rand_cutout],
212 | }
--------------------------------------------------------------------------------
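
DiffAugment applies the listed differentiable augmentations (brightness/saturation/contrast, random translation, random cutout) to a channels-first batch and returns a tensor of the same shape. A toy usage sketch, assuming data.transform is importable from the repository root; the batch values and shape are arbitrary.

```python
import torch
from data.transform import DiffAugment

x = torch.randn(4, 3, 64, 64)                            # toy batch, roughly in [-1, 1]
aug = DiffAugment(x, policy='color,translation,cutout')  # default policy, shown explicitly
print(aug.shape)                                         # unchanged: torch.Size([4, 3, 64, 64])
```
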
/models/CLIP_bank.py:
--------------------------------------------------------------------------------
1 | import clip
2 | import torch.nn as nn
3 | from open_clip import create_model_from_pretrained, create_model_and_transforms
4 |
5 |
6 | class OpenAICLIP(nn.Module):
7 | def __init__(self, config):
8 | super().__init__()
9 |
10 | if config.clip_image_size == 224:
11 | model, _ = clip.load("pretrained_weights/CLIP/ViT-L-14.pt", jit=False)
12 | if config.clip_image_size == 336:
13 | model, _ = clip.load("pretrained_weights/CLIP/ViT-L-14-336px.pt",jit=False)
14 |
15 | self.final_fc = nn.Linear(768, config.actual_bs, bias=False)
16 | self.model = model
17 | self.config = config
18 |
19 | def forward(self, images):
20 |
21 | image_features = self.model.encode_image(images).float()
22 | logits = 100. * self.final_fc(image_features[:,0,:]).float()
23 |
24 | return image_features, logits
25 |
26 |
27 | class DFN(nn.Module):
28 | def __init__(self, config):
29 | super().__init__()
30 |
31 | if config.clip_image_size == 224:
32 | model, _ = create_model_from_pretrained(model_name='ViT-H-14-quickgelu', pretrained="pretrained_weights/CLIP/DFN5B-CLIP-ViT-H-14/open_clip_pytorch_model.bin")
33 | if config.clip_image_size == 378:
34 | model, _ = create_model_from_pretrained(model_name='ViT-H-14-378-quickgelu', pretrained="pretrained_weights/CLIP/DFN5B-CLIP-ViT-H-14-378/open_clip_pytorch_model.bin")
35 |
36 | self.final_fc = nn.Linear(1024, config.actual_bs, bias=False)
37 | self.model = model
38 | self.config = config
39 |
40 | def forward(self, images):
41 |
42 | image_features = self.model.encode_image(images).float()
43 | logits = 100. * self.final_fc(image_features[:,0,:]).float()
44 |
45 | return image_features, logits
46 |
47 |
48 | class SigLIP(nn.Module):
49 | def __init__(self, config):
50 | super().__init__()
51 |
52 | if config.clip_image_size == 224:
53 | model, _ = create_model_from_pretrained(model_name='ViT-SO400M-14-SigLIP', pretrained="pretrained_weights/CLIP/ViT-SO400M-14-SigLIP/open_clip_pytorch_model.bin",
54 | image_mean=([0.5,0.5,0.5]), image_std=([0.5,0.5,0.5]), image_interpolation="bicubic", image_resize_mode="squash")
55 | if config.clip_image_size == 384:
56 | model, _ = create_model_from_pretrained(model_name='ViT-SO400M-14-SigLIP-384', pretrained="pretrained_weights/CLIP/ViT-SO400M-14-SigLIP-384/open_clip_pytorch_model.bin",
57 | image_mean=([0.5,0.5,0.5]), image_std=([0.5,0.5,0.5]), image_interpolation="bicubic", image_resize_mode="squash")
58 |
59 | self.final_fc = nn.Linear(1152, config.actual_bs, bias=False)
60 | self.model = model
61 | self.config = config
62 |
63 | def forward(self, images):
64 |
65 | image_features = self.model.encode_image(images).float()
66 | logits = 100. * self.final_fc(image_features[:,0,:]).float()
67 |
68 | return image_features, logits
69 |
70 |
71 | class MetaCLIP(nn.Module):
72 | def __init__(self, config):
73 | super().__init__()
74 |
75 | if config.metaclip_version == "large":
76 | model, _, _ = create_model_and_transforms(model_name='ViT-L-14-quickgelu', pretrained="pretrained_weights/CLIP/MetaCLIP/l14_fullcc2.5b.pt")
77 | self.final_fc = nn.Linear(768, config.actual_bs, bias=False)
78 | if config.metaclip_version == "huge":
79 | model, _, _ = create_model_and_transforms(model_name='ViT-H-14-quickgelu', pretrained="pretrained_weights/CLIP/MetaCLIP/h14_fullcc2.5b.pt")
80 | self.final_fc = nn.Linear(1024, config.actual_bs, bias=False)
81 |
82 | self.model = model
83 | self.config = config
84 |
85 | def forward(self, images):
86 |
87 | image_features = self.model.encode_image(images).float()
88 | logits = 100. * self.final_fc(image_features[:,0,:]).float()
89 |
90 | return image_features, logits
91 |
--------------------------------------------------------------------------------
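
Each wrapper in CLIP_bank exposes a bias-free final_fc whose weight is later overwritten with zero-shot text embeddings (see the SD_with_* models), so the forward pass produces CLIP-style logits of shape (batch, num_prompts). A minimal sketch of that mechanism with random stand-in tensors; nothing here loads real CLIP weights, and the 768-dim single-prompt setup only mirrors the ViT-L case.

```python
import torch
import torch.nn as nn

feat_dim, num_prompts = 768, 1                      # mirrors OpenAICLIP ViT-L with one prompt
final_fc = nn.Linear(feat_dim, num_prompts, bias=False)

zeroshot_weights = torch.randn(feat_dim, num_prompts)           # stand-in text embeddings
zeroshot_weights = zeroshot_weights / zeroshot_weights.norm(dim=0, keepdim=True)
final_fc.weight.data = zeroshot_weights.T.contiguous()          # weight: (num_prompts, feat_dim)

image_features = torch.randn(4, feat_dim)                       # stand-in CLS features
logits = 100. * final_fc(image_features)
print(logits.shape)                                             # torch.Size([4, 1])
```
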
/models/SD_with_DFN.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional, Tuple
3 | from dataclasses import dataclass
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from einops import rearrange
8 | import numpy as np
9 | import torchvision
10 | from transformers.utils import ModelOutput
11 | from transformers.modeling_utils import PreTrainedModel
12 | from .build import load_sd_model, load_clip_model_DFN
13 | from .utils import initiate_time_steps, prepare_class_text_embeddings
14 | import torchvision.transforms as transforms
15 | from open_clip import get_tokenizer
16 |
17 | @dataclass
18 | class SDOutput(ModelOutput):
19 | loss: Optional[torch.FloatTensor] = None
20 |
21 | class SDModel(PreTrainedModel):
22 | def __init__(
23 | self,
24 | config = None,
25 | ):
26 | super().__init__(config)
27 |
28 | self.model_id, self.pipe, self.vae, self.tokenizer, self.text_encoder, self.unet, self.scheduler, self.image_renormalizer = load_sd_model(config)
29 | self.text_encoder.eval()
30 | self.vae.eval()
31 | self.unet.eval()
32 |
33 | self.pattern_dictionary={'None':['']}
34 | self.config.actual_bs = len(self.pattern_dictionary[self.config.visual_pattern])
35 | self.class_model = load_clip_model_DFN(config)
36 | self.class_model.eval()
37 | self.config = config
38 | discrimi_size = self.config.clip_image_size
39 | self.resize_transform_discrimi = transforms.Resize((discrimi_size, discrimi_size))
40 | self.visual_proj = nn.Linear(1024, 1024)
41 |
42 |
43 | def classify(self, image, classes):
44 |
45 | image_features, logits = self.class_model(image)
46 |
47 | if classes is not None:
48 | logits = logits[:, classes]
49 |
50 | probs = logits.softmax(-1)
51 | max_idx = probs.argmax(-1)
52 | K = probs.shape[-1] if self.config.tta.adapt_topk == -1 else self.config.tta.adapt_topk
53 | topk_idx = probs.argsort(descending=True)[:, :K]
54 |
55 | if classes is not None:
56 | classes = torch.tensor(classes).to(logits.device)
57 | max_class_idx = classes[max_idx.flatten()].view(max_idx.shape)
58 | topk_class_idx = classes[topk_idx.flatten()].view(topk_idx.shape)
59 | else:
60 | max_class_idx, topk_class_idx = max_idx, topk_idx
61 |
62 | return image_features, logits, topk_idx, max_class_idx, topk_class_idx
63 |
64 | def _unet_pred_noise(self, x_start, t, noise, context):
65 |
66 | _,c,h,w = x_start.shape
67 | device = t.device
68 | nt = t.shape[0]
69 |
70 | x_start = x_start.unsqueeze(1)
71 | x_start = x_start.expand(-1, nt//x_start.shape[0], -1, -1, -1)
72 | x_start = x_start.reshape(-1,c,h,w)
73 |
74 | alphas_cumprod = self.scheduler.alphas_cumprod.to(device)
75 | noised_latent = (
76 | x_start * (alphas_cumprod[t]**0.5).view(-1, 1, 1, 1).to(device)
77 | + noise * ((1 - alphas_cumprod[t])**0.5).view(-1, 1, 1, 1).to(device)
78 | )
79 | pred_noise = self.unet(noised_latent, t, encoder_hidden_states=context.expand(nt, -1, -1)).sample
80 |
81 | return pred_noise
82 |
83 | def zeroshot_classifier_DFN(self, classnames, templates, model):
84 | with torch.no_grad():
85 | zeroshot_weights = []
86 | tokenizer = get_tokenizer('ViT-H-14')
87 | for classname in classnames:
88 | texts = [template.format(classname) for template in templates]
89 | texts = tokenizer(texts, context_length=model.context_length).cuda()
90 | class_embeddings = model.encode_text(texts)
91 | class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
92 | class_embedding = class_embeddings.mean(dim=0)
93 | class_embedding /= class_embedding.norm()
94 | zeroshot_weights.append(class_embedding)
95 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
96 | return zeroshot_weights
97 |
98 | def forward(
99 | self,
100 | image: torch.Tensor = None,
101 | text = None
102 | ) -> SDOutput:
103 |
104 | text = self.pattern_dictionary[self.config.visual_pattern]
105 | with torch.no_grad():
106 | imagenet_templates = ['{}',]
107 | zeroshot_weights = self.zeroshot_classifier_DFN(text, imagenet_templates, self.class_model.model.float()).float()
108 |
109 | self.class_model.final_fc.weight.data = zeroshot_weights.T
110 | self.class_model.final_fc.weight.data = self.class_model.final_fc.weight.data.contiguous()
111 | classes = [i for i in range(len(text))]
112 |
113 | discrimi_image = self.resize_transform_discrimi(image)
114 | genera_image = image
115 | real_BS = image.shape[0]
116 | after_DF_expand_BS = real_BS*self.config.input.batch_size
117 |
118 | # prepare_vae_latent
119 | self.vae, self.text_encoder, self.unet = self.vae.to(torch.float32), self.text_encoder.to(torch.float32), self.unet.to(torch.float32)
120 | renormed_image = self.image_renormalizer(genera_image).detach()
121 | x0 = self.vae.encode(renormed_image).latent_dist.mean.float()
122 | latent = x0 * 0.18215
123 |
124 | # prepare_total_timesteps
125 | total_timestep = self.scheduler.num_train_timesteps
126 |
127 | for step in range(self.config.tta.gradient_descent.train_steps):
128 | # Initiate timesteps and noise
129 | timesteps = initiate_time_steps(step, total_timestep, after_DF_expand_BS, self.config).long()
130 | timesteps = timesteps.cuda()
131 |
132 | c, h, w = latent.shape[1:]
133 | if not self.config.tta.use_same_noise_among_timesteps:
134 | noise = torch.randn((real_BS* self.config.input.batch_size, c, h, w)).cuda()
135 | else:
136 | noise = torch.randn((1, c, h, w)).cuda()
137 | noise = noise.repeat(real_BS* self.config.input.batch_size, 1, 1, 1)
138 |
139 | if self.config.tta.adapt_topk == -1:
140 | image_features, logits, _, _, _ = self.classify(discrimi_image, classes)
141 | pred_top_idx = None
142 | else:
143 | image_features, logits, pred_top_idx, _, _ = self.classify(discrimi_image, classes)
144 | real_BS, C = logits.shape[:2]
145 |
146 | # Pick top-K predictions
147 | if pred_top_idx is not None:
148 | pred_top_idx = pred_top_idx.squeeze(0)
149 | else:
150 | pred_top_idx = torch.arange(C).cuda()
151 |
152 | logits = logits[:, pred_top_idx]
153 |
154 | class_text_embeddings = prepare_class_text_embeddings(self.tokenizer, self.text_encoder, class_names=text)
155 | class_text_embeddings = class_text_embeddings.detach()
156 | class_text_embeddings = class_text_embeddings[pred_top_idx, :]
157 |
158 | # Compute conditional text embeddings using weighted-summed predictions
159 | probs = logits.softmax(-1)
160 | probs = probs[:, :, None, None]
161 | class_text_embeddings = (class_text_embeddings.unsqueeze(0).repeat(after_DF_expand_BS, 1, 1, 1))
162 | _, word_num, _, _ = probs.shape
163 | probs = probs.unsqueeze(1).repeat(1,self.config.input.batch_size,1,1,1).reshape(-1,word_num,1,1)
164 | context = (probs * class_text_embeddings).sum(1)
165 | image_features = self.visual_proj(image_features)
166 | context = context.mean(dim=1).unsqueeze(1) + image_features
167 |
168 | # Predict noise with the diffusion model
169 | pred_noise = self._unet_pred_noise(x_start=latent, t=timesteps, noise=noise, context=context).float()
170 |
171 | # Compute diffusion loss
172 | if self.config.tta.loss == "l1":
173 | loss = torch.nn.functional.l1_loss(pred_noise, noise)
174 | else:
175 | loss = torch.nn.functional.mse_loss(pred_noise, noise)
176 |
177 | if step != (self.config.tta.gradient_descent.train_steps-1):
178 | loss.backward()
179 |
180 | return SDOutput(loss=loss)
--------------------------------------------------------------------------------
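
The loss computed in SDModel.forward is a standard DDPM noise-prediction objective: the clean VAE latent is mixed with Gaussian noise according to the scheduler's cumulative alphas at sampled timesteps, and the CLIP-conditioned UNet is scored on how well it recovers that noise. A toy, self-contained sketch of just the noising and loss; the alphas below are a made-up stand-in for scheduler.alphas_cumprod and the UNet prediction is a random placeholder.

```python
import torch

latent = torch.randn(2, 4, 64, 64)                     # stand-in VAE latents (512x512 images)
noise = torch.randn_like(latent)
alphas_cumprod = torch.linspace(0.9999, 0.01, 1000)    # stand-in for scheduler.alphas_cumprod
t = torch.randint(0, 1000, (2,))                       # one timestep per latent

# q-sample: mix the clean latent with noise according to the cumulative alphas at t
noised = (latent * alphas_cumprod[t].sqrt().view(-1, 1, 1, 1)
          + noise * (1 - alphas_cumprod[t]).sqrt().view(-1, 1, 1, 1))

# in the model this would be self.unet(noised, t, encoder_hidden_states=context).sample
pred_noise = torch.randn_like(noise)                   # random placeholder prediction
loss = torch.nn.functional.mse_loss(pred_noise, noise)
print(noised.shape, loss.item())
```
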
/models/SD_with_MetaCLIP.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional
3 | from dataclasses import dataclass
4 | import torch
5 | import torch.nn as nn
6 | from transformers.utils import ModelOutput
7 | from transformers.modeling_utils import PreTrainedModel
8 | from .build import load_sd_model, load_clip_model_MetaCLIP
9 | from .utils import initiate_time_steps, prepare_class_text_embeddings
10 | import torchvision.transforms as transforms
11 | from open_clip import tokenize
12 |
13 | @dataclass
14 | class SDOutput(ModelOutput):
15 | loss: Optional[torch.FloatTensor] = None
16 |
17 | class SDModel(PreTrainedModel):
18 | def __init__(
19 | self,
20 | config = None,
21 | ):
22 | super().__init__(config)
23 |
24 | self.model_id, self.pipe, self.vae, self.tokenizer, self.text_encoder, self.unet, self.scheduler, self.image_renormalizer = load_sd_model(config)
25 | self.text_encoder.eval()
26 | self.vae.eval()
27 | self.unet.eval()
28 |
29 | self.pattern_dictionary={'None':['']}
30 | self.config.actual_bs = len(self.pattern_dictionary[self.config.visual_pattern])
31 | self.class_model = load_clip_model_MetaCLIP(config)
32 | self.class_model.eval()
33 | self.config = config
34 | discrimi_size = self.config.clip_image_size
35 | self.resize_transform_discrimi = transforms.Resize((discrimi_size, discrimi_size))
36 | if config.metaclip_version == "large":
37 | self.visual_proj = nn.Linear(768, 1024)
38 | if config.metaclip_version == "huge":
39 | self.visual_proj = nn.Linear(1024, 1024)
40 |
41 | def classify(self, image, classes):
42 |
43 | image_features, logits = self.class_model(image)
44 |
45 | if classes is not None:
46 | logits = logits[:, classes]
47 |
48 | probs = logits.softmax(-1)
49 | max_idx = probs.argmax(-1)
50 | K = probs.shape[-1] if self.config.tta.adapt_topk == -1 else self.config.tta.adapt_topk
51 | topk_idx = probs.argsort(descending=True)[:, :K]
52 |
53 | if classes is not None:
54 | classes = torch.tensor(classes).to(logits.device)
55 | max_class_idx = classes[max_idx.flatten()].view(max_idx.shape)
56 | topk_class_idx = classes[topk_idx.flatten()].view(topk_idx.shape)
57 | else:
58 | max_class_idx, topk_class_idx = max_idx, topk_idx
59 |
60 | return image_features, logits, topk_idx, max_class_idx, topk_class_idx
61 |
62 | def _unet_pred_noise(self, x_start, t, noise, context):
63 |
64 | _,c,h,w = x_start.shape
65 | device = t.device
66 | nt = t.shape[0]
67 |
68 | x_start = x_start.unsqueeze(1)
69 | x_start = x_start.expand(-1, nt//x_start.shape[0], -1, -1, -1)
70 | x_start = x_start.reshape(-1,c,h,w)
71 |
72 | alphas_cumprod = self.scheduler.alphas_cumprod.to(device)
73 | noised_latent = (
74 | x_start * (alphas_cumprod[t]**0.5).view(-1, 1, 1, 1).to(device)
75 | + noise * ((1 - alphas_cumprod[t])**0.5).view(-1, 1, 1, 1).to(device)
76 | )
77 | pred_noise = self.unet(noised_latent, t, encoder_hidden_states=context.expand(nt, -1, -1)).sample
78 |
79 | return pred_noise
80 |
81 | def zeroshot_classifier_MetaCLIP(self, classnames, templates, model):
82 | with torch.no_grad():
83 | zeroshot_weights = []
84 | for classname in classnames:
85 | texts = [template.format(classname) for template in templates] #format with class
86 | texts = tokenize(texts).cuda() #tokenize
87 | class_embeddings = model.encode_text(texts) #embed with text encoder
88 | class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
89 | class_embedding = class_embeddings.mean(dim=0)
90 | class_embedding /= class_embedding.norm()
91 | zeroshot_weights.append(class_embedding)
92 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
93 | return zeroshot_weights
94 |
95 | def forward(
96 | self,
97 | image: torch.Tensor = None,
98 | text = None
99 | ) -> SDOutput:
100 |
101 | text = self.pattern_dictionary[self.config.visual_pattern]
102 | with torch.no_grad():
103 | imagenet_templates = ['{}',]
104 | zeroshot_weights = self.zeroshot_classifier_MetaCLIP(text, imagenet_templates, self.class_model.model.float()).float()
105 |
106 | self.class_model.final_fc.weight.data = zeroshot_weights.T
107 | self.class_model.final_fc.weight.data = self.class_model.final_fc.weight.data.contiguous()
108 | classes = [i for i in range(len(text))]
109 |
110 | discrimi_image = self.resize_transform_discrimi(image)
111 | genera_image = image
112 | real_BS = image.shape[0]
113 | after_DF_expand_BS = real_BS*self.config.input.batch_size
114 |
115 | # prepare_vae_latent
116 | self.vae, self.text_encoder, self.unet = self.vae.to(torch.float32), self.text_encoder.to(torch.float32), self.unet.to(torch.float32)
117 | renormed_image = self.image_renormalizer(genera_image).detach()
118 | x0 = self.vae.encode(renormed_image).latent_dist.mean.float()
119 | latent = x0 * 0.18215
120 |
121 | # prepare_total_timesteps
122 | total_timestep = self.scheduler.num_train_timesteps
123 |
124 | for step in range(self.config.tta.gradient_descent.train_steps):
125 | # Initiate timesteps and noise
126 | timesteps = initiate_time_steps(step, total_timestep, after_DF_expand_BS, self.config).long()
127 | timesteps = timesteps.cuda()
128 |
129 | c, h, w = latent.shape[1:]
130 | if not self.config.tta.use_same_noise_among_timesteps:
131 | noise = torch.randn((real_BS* self.config.input.batch_size, c, h, w)).cuda()
132 | else:
133 | noise = torch.randn((1, c, h, w)).cuda()
134 | noise = noise.repeat(real_BS* self.config.input.batch_size, 1, 1, 1)
135 |
136 | if self.config.tta.adapt_topk == -1:
137 | image_features, logits, _, _, _ = self.classify(discrimi_image, classes)
138 | pred_top_idx = None
139 | else:
140 | image_features, logits, pred_top_idx, _, _ = self.classify(discrimi_image, classes)
141 | real_BS, C = logits.shape[:2]
142 |
143 | # Pick top-K predictions
144 | if pred_top_idx is not None:
145 | pred_top_idx = pred_top_idx.squeeze(0)
146 | else:
147 | pred_top_idx = torch.arange(C).cuda()
148 |
149 | logits = logits[:, pred_top_idx]
150 |
151 | class_text_embeddings = prepare_class_text_embeddings(self.tokenizer, self.text_encoder, class_names=text)
152 | class_text_embeddings = class_text_embeddings.detach()
153 | class_text_embeddings = class_text_embeddings[pred_top_idx, :]
154 |
155 | # Compute conditional text embeddings using weighted-summed predictions
156 | probs = logits.softmax(-1)
157 | probs = probs[:, :, None, None]
158 | class_text_embeddings = (class_text_embeddings.unsqueeze(0).repeat(after_DF_expand_BS, 1, 1, 1))
159 | _, word_num, _, _ = probs.shape
160 | probs = probs.unsqueeze(1).repeat(1,self.config.input.batch_size,1,1,1).reshape(-1,word_num,1,1)
161 | context = (probs * class_text_embeddings).sum(1)
162 | image_features = self.visual_proj(image_features)
163 | context = context.mean(dim=1).unsqueeze(1) + image_features
164 |
165 | # Predict noise with the diffusion model
166 | pred_noise = self._unet_pred_noise(x_start=latent, t=timesteps, noise=noise, context=context).float()
167 |
168 | # Compute diffusion loss
169 | if self.config.tta.loss == "l1":
170 | loss = torch.nn.functional.l1_loss(pred_noise, noise)
171 | else:
172 | loss = torch.nn.functional.mse_loss(pred_noise, noise)
173 |
174 | if step != (self.config.tta.gradient_descent.train_steps-1):
175 | loss.backward()
176 |
177 | return SDOutput(loss=loss)
--------------------------------------------------------------------------------
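
zeroshot_classifier_MetaCLIP (like its DFN and OpenAI CLIP counterparts) builds the text-side classifier by prompt ensembling: encode every template per class, L2-normalize, average, re-normalize, and stack the class prototypes column-wise. A stand-in sketch of that recipe without loading a real text encoder; encode_text and the class/template lists below are placeholders (the SD_with_* forward actually passes a single empty-string pattern).

```python
import torch

def encode_text(texts):                       # placeholder for model.encode_text
    return torch.randn(len(texts), 768)

classnames = ["dog", "cat"]
templates = ["a photo of a {}.", "a blurry photo of a {}."]

weights = []
for name in classnames:
    emb = encode_text([t.format(name) for t in templates])
    emb = emb / emb.norm(dim=-1, keepdim=True)    # normalize each template embedding
    proto = emb.mean(dim=0)
    weights.append(proto / proto.norm())          # re-normalize the class prototype
zeroshot_weights = torch.stack(weights, dim=1)    # (dim, num_classes)
print(zeroshot_weights.shape)                     # torch.Size([768, 2])
```
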
/models/SD_with_OpenAICLIP.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional
3 | from dataclasses import dataclass
4 | import torch
5 | import torch.nn as nn
6 | from einops import rearrange
7 | from transformers.utils import ModelOutput
8 | from transformers.modeling_utils import PreTrainedModel
9 | from .build import load_sd_model, load_clip_model_OpenAICLIP
10 | from .utils import initiate_time_steps, prepare_class_text_embeddings
11 | import torchvision.transforms as transforms
12 | import clip
13 |
14 |
15 | @dataclass
16 | class SDOutput(ModelOutput):
17 | loss: Optional[torch.FloatTensor] = None
18 |
19 |
20 | class SDModel(PreTrainedModel):
21 | def __init__(
22 | self,
23 | config = None,
24 | ):
25 | super().__init__(config)
26 |
27 | self.model_id, self.pipe, self.vae, self.tokenizer, self.text_encoder, self.unet, self.scheduler, self.image_renormalizer = load_sd_model(config)
28 | self.text_encoder.eval()
29 | self.vae.eval()
30 | self.unet.eval()
31 |
32 | self.pattern_dictionary={'None':['']}
33 | self.config.actual_bs = len(self.pattern_dictionary[self.config.visual_pattern])
34 | self.class_model = load_clip_model_OpenAICLIP(config)
35 | self.class_model.eval()
36 | self.config = config
37 | discrimi_size = self.config.clip_image_size
38 | self.resize_transform_discrimi = transforms.Resize((discrimi_size, discrimi_size))
39 | self.visual_proj = nn.Linear(768, 1024)
40 |
41 | def classify(self, image, classes):
42 |
43 | image_features, logits = self.class_model(image)
44 |
45 | if classes is not None:
46 | logits = logits[:, classes]
47 |
48 | probs = logits.softmax(-1)
49 | max_idx = probs.argmax(-1)
50 | K = probs.shape[-1] if self.config.tta.adapt_topk == -1 else self.config.tta.adapt_topk
51 | topk_idx = probs.argsort(descending=True)[:, :K]
52 |
53 | if classes is not None:
54 | classes = torch.tensor(classes).to(logits.device)
55 | max_class_idx = classes[max_idx.flatten()].view(max_idx.shape)
56 | topk_class_idx = classes[topk_idx.flatten()].view(topk_idx.shape)
57 | else:
58 | max_class_idx, topk_class_idx = max_idx, topk_idx
59 |
60 | return image_features, logits, topk_idx, max_class_idx, topk_class_idx
61 |
62 | def _unet_pred_noise(self, x_start, t, noise, context):
63 |
64 | _,c,h,w = x_start.shape
65 | device = t.device
66 | nt = t.shape[0]
67 |
68 | x_start = x_start.unsqueeze(1)
69 | x_start = x_start.expand(-1, nt//x_start.shape[0], -1, -1, -1)
70 | x_start = x_start.reshape(-1,c,h,w)
71 |
72 | alphas_cumprod = self.scheduler.alphas_cumprod.to(device)
73 | noised_latent = (
74 | x_start * (alphas_cumprod[t]**0.5).view(-1, 1, 1, 1).to(device)
75 | + noise * ((1 - alphas_cumprod[t])**0.5).view(-1, 1, 1, 1).to(device)
76 | )
77 | pred_noise = self.unet(noised_latent, t, encoder_hidden_states=context.expand(nt, -1, -1)).sample
78 |
79 | return pred_noise
80 |
81 | def zeroshot_classifier(self, classnames, templates, model):
82 | with torch.no_grad():
83 | zeroshot_weights = []
84 | for classname in classnames:
85 | texts = [template.format(classname) for template in templates]
86 | texts = clip.tokenize(texts, truncate=True).cuda()
87 | class_embeddings = model.encode_text(texts)
88 | class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
89 | class_embedding = class_embeddings.mean(dim=0)
90 | class_embedding /= class_embedding.norm()
91 | zeroshot_weights.append(class_embedding)
92 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
93 | return zeroshot_weights
94 |
95 | def forward(
96 | self,
97 | image: torch.Tensor = None,
98 | text = None
99 | ) -> SDOutput:
100 |
101 | text = self.pattern_dictionary[self.config.visual_pattern]
102 | with torch.no_grad():
103 | imagenet_templates = ['{}',]
104 | zeroshot_weights = self.zeroshot_classifier(text, imagenet_templates, self.class_model.model.float()).float()
105 |
106 | self.class_model.final_fc.weight.data = zeroshot_weights.T
107 | self.class_model.final_fc.weight.data = self.class_model.final_fc.weight.data.contiguous()
108 | classes = [i for i in range(len(text))]
109 |
110 | discrimi_image = self.resize_transform_discrimi(image)
111 | genera_image = image
112 | real_BS = image.shape[0]
113 | after_DF_expand_BS = real_BS*self.config.input.batch_size
114 |
115 | # prepare_vae_latent
116 | self.vae, self.text_encoder, self.unet = self.vae.to(torch.float32), self.text_encoder.to(torch.float32), self.unet.to(torch.float32)
117 | renormed_image = self.image_renormalizer(genera_image).detach()
118 | x0 = self.vae.encode(renormed_image).latent_dist.mean.float()
119 | latent = x0 * 0.18215
120 |
121 | # prepare_total_timesteps
122 | total_timestep = self.scheduler.num_train_timesteps
123 |
124 | for step in range(self.config.tta.gradient_descent.train_steps):
125 | # Initiate timesteps and noise
126 | timesteps = initiate_time_steps(step, total_timestep, after_DF_expand_BS, self.config).long()
127 | timesteps = timesteps.cuda()
128 |
129 | c, h, w = latent.shape[1:]
130 | if not self.config.tta.use_same_noise_among_timesteps:
131 | noise = torch.randn((real_BS* self.config.input.batch_size, c, h, w)).cuda()
132 | else:
133 | noise = torch.randn((1, c, h, w)).cuda()
134 | noise = noise.repeat(real_BS* self.config.input.batch_size, 1, 1, 1)
135 |
136 | if self.config.tta.adapt_topk == -1:
137 | image_features, logits, _, _, _ = self.classify(discrimi_image, classes)
138 | pred_top_idx = None
139 | else:
140 | image_features, logits, pred_top_idx, _, _ = self.classify(discrimi_image, classes)
141 | real_BS, C = logits.shape[:2]
142 |
143 | # Pick top-K predictions
144 | if pred_top_idx is not None:
145 | pred_top_idx = pred_top_idx.squeeze(0)
146 | else:
147 | pred_top_idx = torch.arange(C).cuda()
148 |
149 | logits = logits[:, pred_top_idx]
150 |
151 | class_text_embeddings = prepare_class_text_embeddings(self.tokenizer, self.text_encoder, class_names=text)
152 | class_text_embeddings = class_text_embeddings.detach()
153 | class_text_embeddings = class_text_embeddings[pred_top_idx, :]
154 |
155 | # Compute conditional text embeddings using weighted-summed predictions
156 | probs = logits.softmax(-1)
157 | probs = probs[:, :, None, None]
158 | class_text_embeddings = (class_text_embeddings.unsqueeze(0).repeat(after_DF_expand_BS, 1, 1, 1))
159 | _, word_num, _, _ = probs.shape
160 | probs = probs.unsqueeze(1).repeat(1,self.config.input.batch_size,1,1,1).reshape(-1,word_num,1,1)
161 | context = (probs * class_text_embeddings).sum(1)
162 | image_features = self.visual_proj(image_features)
163 | context = context + image_features
164 |
165 | # Predict noise with the diffusion model
166 | pred_noise = self._unet_pred_noise(x_start=latent, t=timesteps, noise=noise, context=context).float()
167 |
168 | # Compute diffusion loss
169 | if self.config.tta.loss == "l1":
170 | loss = torch.nn.functional.l1_loss(pred_noise, noise)
171 | else:
172 | loss = torch.nn.functional.mse_loss(pred_noise, noise)
173 |
174 | if step != (self.config.tta.gradient_descent.train_steps-1):
175 | loss.backward()
176 |
177 | return SDOutput(loss=loss)
--------------------------------------------------------------------------------
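
Note (illustrative, not part of the repository): zeroshot_classifier() above follows the standard CLIP recipe, and forward() then copies the result into final_fc so the linear head computes class similarities. The sketch below reproduces that weight-loading pattern with a random stand-in for the text encoder; the embedding width and class count are arbitrary.

# Minimal sketch of building zero-shot weights and loading them into a linear head.
import torch
import torch.nn as nn

D, num_classes, num_templates = 512, 3, 4
fake_encode_text = lambda n: torch.randn(n, D)      # stand-in for model.encode_text

weights = []
for _ in range(num_classes):
    emb = fake_encode_text(num_templates)           # one embedding per template
    emb = emb / emb.norm(dim=-1, keepdim=True)      # normalize each template embedding
    cls_emb = emb.mean(dim=0)                       # average over templates
    cls_emb = cls_emb / cls_emb.norm()              # renormalize the class embedding
    weights.append(cls_emb)
zeroshot_weights = torch.stack(weights, dim=1)      # [D, num_classes]

final_fc = nn.Linear(D, num_classes, bias=False)
final_fc.weight.data = zeroshot_weights.T.contiguous()   # logits = features @ weights
print(final_fc(torch.randn(2, D)).shape)            # torch.Size([2, 3])
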
/models/SD_with_SigLIP.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional, Tuple
3 | from dataclasses import dataclass
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from einops import rearrange
8 | import numpy as np
9 | import torchvision
10 | from transformers.utils import ModelOutput
11 | from transformers.modeling_utils import PreTrainedModel
12 | from .build import load_sd_model, load_clip_model_SigLIP
13 | from .utils import initiate_time_steps, prepare_class_text_embeddings
14 | import torchvision.transforms as transforms
15 | from open_clip import get_tokenizer
16 |
17 | @dataclass
18 | class SDOutput(ModelOutput):
19 | loss: Optional[torch.FloatTensor] = None
20 |
21 | class SDModel(PreTrainedModel):
22 | def __init__(
23 | self,
24 | config = None,
25 | ):
26 | super().__init__(config)
27 |
28 | self.model_id, self.pipe, self.vae, self.tokenizer, self.text_encoder, self.unet, self.scheduler, self.image_renormalizer = load_sd_model(config)
29 | self.text_encoder.eval()
30 | self.vae.eval()
31 | self.unet.eval()
32 |
33 | self.pattern_dictionary={'None':['']}
34 | self.config.actual_bs = len(self.pattern_dictionary[self.config.visual_pattern])
35 | self.class_model = load_clip_model_SigLIP(config)
36 | self.class_model.eval()
37 | self.config = config
38 | discrimi_size = self.config.clip_image_size
39 | self.resize_transform_discrimi = transforms.Resize((discrimi_size, discrimi_size))
40 | self.visual_proj = nn.Linear(1152, 1024)
41 |
42 |
43 | def classify(self, image, classes):
44 |
45 | image_features, logits = self.class_model(image)
46 |
47 | if classes is not None:
48 | logits = logits[:, classes]
49 |
50 | probs = logits.softmax(-1)
51 | max_idx = probs.argmax(-1)
52 | K = probs.shape[-1] if self.config.tta.adapt_topk == -1 else self.config.tta.adapt_topk
53 | topk_idx = probs.argsort(descending=True)[:, :K]
54 |
55 | if classes is not None:
56 | classes = torch.tensor(classes).to(logits.device)
57 | max_class_idx = classes[max_idx.flatten()].view(max_idx.shape)
58 | topk_class_idx = classes[topk_idx.flatten()].view(topk_idx.shape)
59 | else:
60 | max_class_idx, topk_class_idx = max_idx, topk_idx
61 |
62 | return image_features, logits, topk_idx, max_class_idx, topk_class_idx
63 |
64 | def _unet_pred_noise(self, x_start, t, noise, context):
65 |
66 | _,c,h,w = x_start.shape
67 | device = t.device
68 | nt = t.shape[0]
69 |
70 | x_start = x_start.unsqueeze(1)
71 | x_start = x_start.expand(-1, nt//x_start.shape[0], -1, -1, -1)
72 | x_start = x_start.reshape(-1,c,h,w)
73 |
74 | alphas_cumprod = self.scheduler.alphas_cumprod.to(device)
75 | noised_latent = (
76 | x_start * (alphas_cumprod[t]**0.5).view(-1, 1, 1, 1).to(device)
77 | + noise * ((1 - alphas_cumprod[t])**0.5).view(-1, 1, 1, 1).to(device)
78 | )
79 | pred_noise = self.unet(noised_latent, t, encoder_hidden_states=context.expand(nt, -1, -1)).sample
80 |
81 | return pred_noise
82 |
83 | def zeroshot_classifier_SigLIP(self, classnames, templates, model):
84 | with torch.no_grad():
85 | zeroshot_weights = []
86 | if self.config.clip_image_size == 224:
87 | tokenizer = get_tokenizer('ViT-SO400M-14-SigLIP')
88 | if self.config.clip_image_size == 384:
89 | tokenizer = get_tokenizer('ViT-SO400M-14-SigLIP-384')
90 | for classname in classnames:
91 | texts = [template.format(classname) for template in templates]
92 | texts = tokenizer(texts, context_length=model.context_length).cuda()
93 | class_embeddings = model.encode_text(texts)
94 | class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
95 | class_embedding = class_embeddings.mean(dim=0)
96 | class_embedding /= class_embedding.norm()
97 | zeroshot_weights.append(class_embedding)
98 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
99 | return zeroshot_weights
100 |
101 | def forward(
102 | self,
103 | image: torch.Tensor = None,
104 | text = None
105 | ) -> SDOutput:
106 |
107 | text = self.pattern_dictionary[self.config.visual_pattern]
108 | with torch.no_grad():
109 | imagenet_templates = ['{}',]
110 | zeroshot_weights = self.zeroshot_classifier_SigLIP(text, imagenet_templates, self.class_model.model.float()).float()
111 |
112 | self.class_model.final_fc.weight.data = zeroshot_weights.T
113 | self.class_model.final_fc.weight.data = self.class_model.final_fc.weight.data.contiguous()
114 | classes = [i for i in range(len(text))]
115 |
116 | discrimi_image = self.resize_transform_discrimi(image)
117 | genera_image = image
118 | real_BS = image.shape[0]
119 | after_DF_expand_BS = real_BS*self.config.input.batch_size
120 |
121 | # prepare_vae_latent
122 | self.vae, self.text_encoder, self.unet = self.vae.to(torch.float32), self.text_encoder.to(torch.float32), self.unet.to(torch.float32)
123 | renormed_image = self.image_renormalizer(genera_image).detach()
124 | x0 = self.vae.encode(renormed_image).latent_dist.mean.float()
125 | latent = x0 * 0.18215
126 |
127 | # prepare_total_timesteps
128 | total_timestep = self.scheduler.num_train_timesteps
129 |
130 | for step in range(self.config.tta.gradient_descent.train_steps):
131 | # Initiate timesteps and noise
132 | timesteps = initiate_time_steps(step, total_timestep, after_DF_expand_BS, self.config).long()
133 | timesteps = timesteps.cuda()
134 |
135 | c, h, w = latent.shape[1:]
136 | if not self.config.tta.use_same_noise_among_timesteps:
137 | noise = torch.randn((real_BS* self.config.input.batch_size, c, h, w)).cuda()
138 | else:
139 | noise = torch.randn((1, c, h, w)).cuda()
140 | noise = noise.repeat(real_BS* self.config.input.batch_size, 1, 1, 1)
141 |
142 | if self.config.tta.adapt_topk == -1:
143 | image_features, logits, _, _, _ = self.classify(discrimi_image, classes)
144 | pred_top_idx = None
145 | else:
146 | image_features, logits, pred_top_idx, _, _ = self.classify(discrimi_image, classes)
147 | real_BS, C = logits.shape[:2]
148 |
149 | # Pick top-K predictions
150 | if pred_top_idx is not None:
151 | pred_top_idx = pred_top_idx.squeeze(0)
152 | else:
153 | pred_top_idx = torch.arange(C).cuda()
154 |
155 | logits = logits[:, pred_top_idx]
156 |
157 | class_text_embeddings = prepare_class_text_embeddings(self.tokenizer, self.text_encoder, class_names=text)
158 | class_text_embeddings = class_text_embeddings.detach()
159 | class_text_embeddings = class_text_embeddings[pred_top_idx, :]
160 |
161 | # Compute conditional text embeddings using weighted-summed predictions
162 | probs = logits.softmax(-1)
163 | probs = probs[:, :, None, None]
164 | class_text_embeddings = (class_text_embeddings.unsqueeze(0).repeat(after_DF_expand_BS, 1, 1, 1))
165 | _, word_num, _, _ = probs.shape
166 | probs = probs.unsqueeze(1).repeat(1,self.config.input.batch_size,1,1,1).reshape(-1,word_num,1,1)
167 | context = (probs * class_text_embeddings).sum(1)
168 | image_features = self.visual_proj(image_features)
169 | context = context.mean(dim=1).unsqueeze(1) + image_features
170 |
171 | # Predict noise with the diffusion model
172 | pred_noise = self._unet_pred_noise(x_start=latent, t=timesteps, noise=noise, context=context).float()
173 |
174 | # Compute diffusion loss
175 | if self.config.tta.loss == "l1":
176 | loss = torch.nn.functional.l1_loss(pred_noise, noise)
177 | else:
178 | loss = torch.nn.functional.mse_loss(pred_noise, noise)
179 |
180 | if step != (self.config.tta.gradient_descent.train_steps-1):
181 | loss.backward()
182 |
183 | return SDOutput(loss=loss)
--------------------------------------------------------------------------------
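
Note (illustrative, not part of the repository): _unet_pred_noise() and the loss branch above implement the usual DDPM noise-prediction objective: noise the clean latent with the cumulative alphas at the sampled timesteps, ask the UNet for the noise, and take an L1 or MSE loss against the true noise. A compact sketch of just that arithmetic follows; the linear beta schedule and the identity "UNet" are stand-ins so it runs without the pipeline (the real scheduler uses beta_schedule="scaled_linear").

# Minimal sketch of DDPM-style noising and the noise-prediction loss.
import torch
import torch.nn.functional as F

T = 1000
betas = torch.linspace(0.00085, 0.012, T)            # simplified; real config is scaled_linear
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

latent = torch.randn(4, 4, 64, 64)                   # VAE latent (already scaled by 0.18215)
noise = torch.randn_like(latent)
t = torch.randint(0, T, (latent.shape[0],))

a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)
noised_latent = latent * a_bar.sqrt() + noise * (1.0 - a_bar).sqrt()

pred_noise = noised_latent                           # stand-in for unet(noised_latent, t, context).sample
loss_l1 = F.l1_loss(pred_noise, noise)               # used when config.tta.loss == "l1"
loss_mse = F.mse_loss(pred_noise, noise)             # used otherwise
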
/models/build.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from diffusers import AutoencoderKL, UNet2DConditionModel, StableDiffusionPipeline, DDPMScheduler, EulerDiscreteScheduler
3 | from transformers import CLIPTextModel, CLIPTokenizer
4 | from .utils import VQVAEUnNormalize
5 | from .CLIP_bank import OpenAICLIP, DFN, SigLIP, MetaCLIP
6 |
7 | def load_sd_model(config):
8 | """Load Stable Diffusion model"""
9 | dtype = torch.float32
10 | image_renormalizer = VQVAEUnNormalize(
11 | mean=config.input.mean, std=config.input.std
12 | )
13 | if config.model.sd_version == '1-4':
14 | if config.model.use_flash:
15 | model_id = "CompVis/stable-diffusion-v1-4"
16 | scheduler = EulerDiscreteScheduler.from_pretrained(
17 | model_id, subfolder="scheduler"
18 | )
19 | pipe = StableDiffusionPipeline.from_pretrained(
20 | model_id, scheduler=scheduler, torch_dtype=dtype
21 | ).cuda()
22 | pipe.enable_xformers_memory_efficient_attention()
23 | vae = pipe.vae.cuda()
24 | tokenizer = pipe.tokenizer
25 | text_encoder = pipe.text_encoder.cuda()
26 | unet = pipe.unet.cuda()
27 | else:
28 | vae = AutoencoderKL.from_pretrained(
29 | f"CompVis/stable-diffusion-v{config.model.sd_version}",
30 | subfolder="vae", torch_dtype=dtype
31 | ).cuda()
32 | tokenizer = CLIPTokenizer.from_pretrained(
33 | "/share/project/wangwenxuan/projects/Overcome_VS/MMVP_Test/openai/clip-vit-large-patch14"
34 | )
35 | text_encoder = CLIPTextModel.from_pretrained(
36 | "/share/project/wangwenxuan/projects/Overcome_VS/MMVP_Test/openai/clip-vit-large-patch14", torch_dtype=dtype
37 | ).cuda()
38 | unet = UNet2DConditionModel.from_pretrained(
39 | f"CompVis/stable-diffusion-v{config.model.sd_version}",
40 | subfolder="unet", torch_dtype=dtype
41 | ).cuda()
42 | scheduler_config = get_scheduler_config(config)
43 | scheduler = DDPMScheduler(
44 | num_train_timesteps=scheduler_config['num_train_timesteps'],
45 | beta_start=scheduler_config['beta_start'],
46 | beta_end=scheduler_config['beta_end'],
47 | beta_schedule=scheduler_config['beta_schedule']
48 | )
49 | elif config.model.sd_version == '2-1':
50 |
51 | model_id = "pretrained_weights/SD/stable-diffusion-2-1-base"
52 | print(f'model_id:{model_id}')
53 |
54 | scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
55 | pipe = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, torch_dtype=dtype)
56 | pipe.to(dtype)
57 |
58 | pipe.enable_xformers_memory_efficient_attention()
59 | vae = pipe.vae.cuda()
60 | tokenizer = pipe.tokenizer
61 | text_encoder = pipe.text_encoder.cuda()
62 | unet = pipe.unet.cuda()
63 |
64 | if config.model.adapt_only_classifier:
65 | for m in [vae, text_encoder, unet]:
66 | for param in m.parameters():
67 | param.requires_grad = False
68 | for m in [vae, text_encoder]:
69 | for param in m.parameters():
70 | param.requires_grad = False
71 |
72 | return (model_id, pipe, vae, tokenizer, text_encoder, unet, scheduler, image_renormalizer)
73 |
74 |
75 | def get_scheduler_config(config):
76 | assert config.model.sd_version in {'1-4', '2-1'}
77 | if config.model.sd_version == '1-4':
78 | schedule_config = {
79 | "_class_name": "PNDMScheduler",
80 | "_diffusers_version": "0.7.0.dev0",
81 | "beta_end": 0.012,
82 | "beta_schedule": "scaled_linear",
83 | "beta_start": 0.00085,
84 | "num_train_timesteps": 1000,
85 | "set_alpha_to_one": False,
86 | "skip_prk_steps": True,
87 | "steps_offset": 1,
88 | "trained_betas": None,
89 | "clip_sample": False
90 | }
91 | elif config.model.sd_version == '2-1':
92 | schedule_config = {
93 | "_class_name": "EulerDiscreteScheduler",
94 | "_diffusers_version": "0.10.2",
95 | "beta_end": 0.012,
96 | "beta_schedule": "scaled_linear",
97 | "beta_start": 0.00085,
98 | "clip_sample": False,
99 | "num_train_timesteps": 1000,
100 | "prediction_type": "epsilon",
101 | "set_alpha_to_one": False,
102 | "skip_prk_steps": True,
103 | "steps_offset": 1, # todo
104 | "trained_betas": None
105 | }
106 | else:
107 | raise NotImplementedError
108 |
109 | return schedule_config
110 |
111 |
112 | def load_clip_model_OpenAICLIP(config):
113 |
114 | class_model = OpenAICLIP(config)
115 | class_model.to(torch.float32)
116 |
117 | return class_model
118 |
119 |
120 | def load_clip_model_DFN(config):
121 |
122 | class_model = DFN(config)
123 | class_model.to(torch.float32)
124 |
125 | return class_model
126 |
127 |
128 | def load_clip_model_SigLIP(config):
129 |
130 | class_model = SigLIP(config)
131 | class_model.to(torch.float32)
132 |
133 | return class_model
134 |
135 |
136 | def load_clip_model_MetaCLIP(config):
137 |
138 | class_model = MetaCLIP(config)
139 | class_model.to(torch.float32)
140 |
141 | return class_model
--------------------------------------------------------------------------------
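
Note (illustrative, not part of the repository): the '1-4' branch of load_sd_model() consumes the dictionary returned by get_scheduler_config() by passing a handful of keys to diffusers' DDPMScheduler. A minimal sketch of that hand-off, with the same values hard-coded:

# Minimal sketch: get_scheduler_config()-style dict -> diffusers DDPMScheduler.
from diffusers import DDPMScheduler

sched_cfg = {
    "num_train_timesteps": 1000,
    "beta_start": 0.00085,
    "beta_end": 0.012,
    "beta_schedule": "scaled_linear",
}
scheduler = DDPMScheduler(
    num_train_timesteps=sched_cfg["num_train_timesteps"],
    beta_start=sched_cfg["beta_start"],
    beta_end=sched_cfg["beta_end"],
    beta_schedule=sched_cfg["beta_schedule"],
)
print(scheduler.alphas_cumprod.shape)   # torch.Size([1000]); consumed by _unet_pred_noise in the SD_with_* models
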
/models/utils.py:
--------------------------------------------------------------------------------
1 | """Utility functions"""
2 | import importlib
3 | import random
4 |
5 | import torch
6 | import numpy as np
7 | from PIL import Image
8 | from omegaconf import OmegaConf, open_dict
9 |
10 |
11 | class UnNormalize(object):
12 |     """Unnormalize image as: image = (image * std) + mean
13 | """
14 | def __init__(self, mean, std):
15 | self.mean = torch.tensor(mean)
16 | self.std = torch.tensor(std)
17 |
18 | def __call__(self, tensor):
19 | """
20 | Args:
21 | tensor: A tensor of shape [C, H, W] or [N, C, H, W]
22 |
23 | Returns:
24 | tensor: A tensor of shape [C, H, W] or [N, C, H, W]
25 | """
26 |
27 | std = self.std.to(tensor.device)
28 | mean = self.mean.to(tensor.device)
29 | if tensor.ndim == 3:
30 | std, mean = std.view(-1, 1, 1), mean.view(-1, 1, 1)
31 | elif tensor.ndim == 4:
32 | std, mean = std.view(1, -1, 1, 1), mean.view(1, -1, 1, 1)
33 | tensor = (tensor * std) + mean
34 | return tensor
35 |
36 |
37 | class VQVAEUnNormalize(UnNormalize):
38 |     """Unnormalize image as:
39 | First: image = (image * std) + mean
40 | Second: image = (image * 2) - 1
41 | """
42 | def __call__(self, tensor):
43 | """
44 | Args:
45 | tensor (Tensor): Tensor image of size (C, H, W) or (N, C, H, W)
46 | to be unnormalized.
47 | Returns:
48 | Tensor: UnNormalized image.
49 | """
50 | tensor = super().__call__(tensor)
51 | tensor = 2 * tensor - 1
52 | return tensor
53 |
54 |
55 | def mean_list(l):
56 | l = [int(_l) for _l in l]
57 | return float(sum(l)) / len(l)
58 |
59 |
60 | def segment_mean(x, index):
61 |     """Function analogous to tf.segment_mean: average rows of x grouped by index.
62 | """
63 | x = x.view(-1, x.shape[-1])
64 | index = index.view(-1)
65 |
66 | max_index = index.max() + 1
67 | sum_x = torch.zeros((max_index, x.shape[-1]),
68 | dtype=x.dtype,
69 | device=x.device)
70 | num_index = torch.zeros((max_index,),
71 | dtype=x.dtype,
72 | device=x.device)
73 |
74 | num_index = num_index.scatter_add_(
75 | 0, index, torch.ones_like(index, dtype=x.dtype))
76 | num_index = torch.where(torch.eq(num_index, 0),
77 | torch.ones_like(num_index, dtype=x.dtype),
78 | num_index)
79 |
80 | index_2d = index.view(-1, 1).expand(-1, x.shape[-1])
81 | sum_x = sum_x.scatter_add_(0, index_2d, x)
82 | mean_x = sum_x.div_(num_index.view(-1, 1))
83 |
84 | return mean_x
85 |
86 |
87 | def get_class_sd_features(tokenizer, text_encoder, input, device=None):
88 | """Prepare class text embeddings for Stable Diffusion
89 |
90 | Args:
91 | tokenizer: A nn.Module object of tokenizer.
92 | text_encoder: A nn.Module object of text encoder.
93 | input: A string
94 | device: GPU/CPU device
95 | """
96 | with torch.no_grad():
97 |         input_val = (f'a photo of a {input}.',)  # one-element tuple; the tokenizer treats it as a batch of one prompt
98 | # Tokenize the text
99 | text_input = tokenizer(input_val, padding="max_length",
100 | max_length=tokenizer.model_max_length,
101 | truncation=True,
102 | return_tensors="pt")
103 | # Get the text embeddings
104 | text_embeddings = text_encoder(text_input.input_ids.cuda())[0]
105 |
106 | return text_embeddings
107 |
108 |
109 | def prepare_class_text_embeddings(tokenizer=None, text_encoder=None, class_names=None):
110 |
111 | text_features = []
112 | for class_name in class_names:
113 | text_features.append(
114 | get_class_sd_features(tokenizer, text_encoder, class_name)
115 | )
116 | text_features = torch.cat(text_features, dim=0)
117 |
118 | return text_features
119 |
120 |
121 | def initiate_time_steps(step, total_timestep, batch_size, config):
122 | """A helper function to initiate time steps for the diffusion model.
123 |
124 | Args:
125 |         step: An integer of the current adaptation step (only used when timesteps are fixed)
126 | total_timestep: An integer of the total timesteps of the diffusion model
127 | batch_size: An integer of the batch size
128 | config: A config object
129 |
130 | Returns:
131 | timesteps: A tensor of shape [batch_size,] of the time steps
132 | """
133 | if config.tta.rand_timestep_equal_int:
134 | interval_val = total_timestep // batch_size
135 | start_point = random.randint(0, interval_val - 1)
136 | timesteps = torch.tensor(
137 | list(range(start_point, total_timestep, interval_val))
138 | ).long()
139 | return timesteps
140 | elif config.tta.random_timestep_per_iteration:
141 | return torch.randint(0, total_timestep, (batch_size,)).long() #default
142 | else:
143 | return torch.tensor([step] * batch_size).long()
144 |
145 |
146 | def instantiate_from_config(config):
147 | """A helper function to instantiate a class from a config object.
148 | See https://github.com/CompVis/stable-diffusion/blob/main/ldm/util.py
149 | """
150 |     if "target" not in config:
151 | if config == '__is_first_stage__':
152 | return None
153 | elif config == "__is_unconditional__":
154 | return None
155 | raise KeyError("Expected key `target` to instantiate.")
156 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
157 |
158 |
159 | def get_obj_from_str(string, reload=False):
160 |     """A helper function to resolve a class or function from a dotted import path string.
161 | See https://github.com/CompVis/stable-diffusion/blob/main/ldm/util.py
162 | """
163 | module, cls = string.rsplit(".", 1)
164 | if reload:
165 | module_imp = importlib.import_module(module)
166 | importlib.reload(module_imp)
167 | return getattr(importlib.import_module(module, package=None), cls)
--------------------------------------------------------------------------------
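
Note (illustrative, not part of the repository): initiate_time_steps() only reads the two tta flags it branches on, so it can be exercised in isolation with a small OmegaConf object, and get_obj_from_str() simply resolves a dotted import path. A short usage sketch, assuming the repository root is on sys.path so that models.utils is importable:

# Minimal usage sketch for two helpers from models/utils.py.
import torch
from omegaconf import OmegaConf
from models.utils import initiate_time_steps, get_obj_from_str

cfg = OmegaConf.create({
    "tta": {
        "rand_timestep_equal_int": False,
        "random_timestep_per_iteration": True,    # the default branch: uniform random timesteps
    }
})
timesteps = initiate_time_steps(step=0, total_timestep=1000, batch_size=8, config=cfg)
print(timesteps.shape, timesteps.dtype)            # torch.Size([8]) torch.int64

linear_cls = get_obj_from_str("torch.nn.Linear")   # resolve a class from its dotted path
print(linear_cls(16, 4)(torch.randn(2, 16)).shape) # torch.Size([2, 4])
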
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.1.0
2 | accelerate==0.28.0
3 | aiohttp==3.9.3
4 | aiosignal==1.3.1
5 | annotated-types==0.6.0
6 | antlr4-python3-runtime==4.9.3
7 | anyio==4.3.0
8 | appdirs==1.4.4
9 | asttokens==2.4.1
10 | async-timeout==4.0.3
11 | attrs==23.2.0
12 | beartype==0.17.2
13 | blessed==1.20.0
14 | bypy==1.8.4
15 | certifi==2024.2.2
16 | charset-normalizer==3.3.2
17 | click==8.1.7
18 | clip @ git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33
19 | cmake==3.28.3
20 | contourpy==1.2.0
21 | cycler==0.12.1
22 | datasets==2.17.1
23 | decorator==5.1.1
24 | decord==0.6.0
25 | deepspeed==0.14.0
26 | diffusers==0.25.0
27 | dill==0.3.8
28 | distro==1.9.0
29 | docker-pycreds==0.4.0
30 | einops==0.6.0
31 | ema-pytorch==0.4.2
32 | exceptiongroup==1.2.0
33 | executing==2.0.1
34 | filelock==3.13.1
35 | flash-attn==2.5.6
36 | fonttools==4.47.2
37 | frozenlist==1.4.1
38 | fsspec==2023.10.0
39 | ftfy==6.1.3
40 | gitdb==4.0.11
41 | GitPython==3.1.42
42 | gpustat==1.1.1
43 | grpcio==1.60.1
44 | h11==0.14.0
45 | hjson==3.1.0
46 | httpcore==1.0.4
47 | httpx==0.27.0
48 | huggingface-hub==0.20.3
49 | hydra-core==1.3.2
50 | idna==3.6
51 | importlib-metadata==7.0.1
52 | importlib_resources==6.1.2
53 | ipdb==0.13.13
54 | ipython==8.18.1
55 | jedi==0.19.1
56 | Jinja2==3.1.3
57 | joblib==1.3.2
58 | kiwisolver==1.4.5
59 | lightning-utilities==0.11.2
60 | lit==17.0.6
61 | Markdown==3.5.2
62 | MarkupSafe==2.1.5
63 | matplotlib==3.8.2
64 | matplotlib-inline==0.1.6
65 | mergedeep==1.3.4
66 | mpmath==1.3.0
67 | multidict==6.0.5
68 | multiprocess==0.70.16
69 | mypy-extensions==1.0.0
70 | networkx==3.2.1
71 | ninja==1.11.1.1
72 | numpy==1.26.4
73 | nvidia-cublas-cu11==11.10.3.66
74 | nvidia-cuda-cupti-cu11==11.7.101
75 | nvidia-cuda-nvrtc-cu11==11.7.99
76 | nvidia-cuda-runtime-cu11==11.7.99
77 | nvidia-cudnn-cu11==8.5.0.96
78 | nvidia-cufft-cu11==10.9.0.58
79 | nvidia-curand-cu11==10.2.10.91
80 | nvidia-cusolver-cu11==11.4.0.1
81 | nvidia-cusparse-cu11==11.7.4.91
82 | nvidia-ml-py==12.535.133
83 | nvidia-nccl-cu11==2.14.3
84 | nvidia-nvtx-cu11==11.7.91
85 | omegaconf==2.2.3
86 | open-clip-torch==2.24.0
87 | openai==0.28.1
88 | opencv-python==4.9.0.80
89 | packaging==23.2
90 | pandas==2.2.0
91 | parso==0.8.3
92 | pexpect==4.9.0
93 | pillow==10.2.0
94 | prompt-toolkit==3.0.43
95 | protobuf==4.25.3
96 | psutil==5.9.8
97 | ptyprocess==0.7.0
98 | pure-eval==0.2.2
99 | py-cpuinfo==9.0.0
100 | pyarrow==12.0.0
101 | pyarrow-hotfix==0.6
102 | pydantic==2.6.3
103 | pydantic_core==2.16.3
104 | Pygments==2.17.2
105 | pynvml==11.5.0
106 | pyparsing==3.1.1
107 | pyre-extensions==0.0.29
108 | python-dateutil==2.8.2
109 | pytz==2024.1
110 | PyYAML==6.0.1
111 | regex==2023.12.25
112 | requests==2.31.0
113 | requests-toolbelt==1.0.0
114 | safetensors==0.4.2
115 | scikit-learn==1.4.1.post1
116 | scipy==1.12.0
117 | sentencepiece==0.2.0
118 | sentry-sdk==1.40.5
119 | setproctitle==1.3.3
120 | six==1.16.0
121 | smmap==5.0.1
122 | sniffio==1.3.1
123 | stack-data==0.6.3
124 | sympy==1.12
125 | tensorboard==2.16.2
126 | tensorboard-data-server==0.7.2
127 | threadpoolctl==3.3.0
128 | timm==0.9.8
129 | tokenizers==0.15.2
130 | tomli==2.0.1
131 | torch==2.0.0
132 | torchaudio==2.0.1
133 | torchmetrics==1.3.2
134 | torchvision==0.15.1
135 | tqdm==4.66.2
136 | traitlets==5.14.1
137 | transformers==4.39.3
138 | triton==2.0.0
139 | typing-inspect==0.9.0
140 | typing_extensions==4.9.0
141 | tzdata==2023.4
142 | urllib3==2.2.1
143 | wandb==0.16.3
144 | wcwidth==0.2.13
145 | Werkzeug==3.0.1
146 | xformers==0.0.19
147 | xxhash==3.4.1
148 | yarl==1.9.4
149 | zipp==3.17.0
150 |
--------------------------------------------------------------------------------
/run_DIVA_with_DFN.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 | import random
5 | import datetime
6 | import builtins
7 | sys.path.append(os.getcwd())
8 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
9 | import numpy as np
10 | import torch
11 | import transformers
12 | from transformers.trainer_utils import set_seed
13 | from transformers import HfArgumentParser
14 | from transformers.trainer_utils import get_last_checkpoint
15 | from transformers.utils.versions import require_version
16 | from trainer import CustomTrainer
17 | from arguments import DataTrainingArguments, ModelArguments, TrainingArguments
18 | import csv
19 | from tqdm import tqdm
20 | from PIL import Image
21 | import json
22 | from open_clip import create_model_from_pretrained, get_tokenizer
23 | logger = logging.getLogger(__name__)
24 | import warnings
25 | warnings.filterwarnings("ignore")
26 |
27 |
28 | def random_seed(seed=42, rank=0):
29 | set_seed(seed)
30 | torch.manual_seed(seed + rank)
31 | np.random.seed(seed + rank)
32 | random.seed(seed + rank)
33 | try:
34 | import deepspeed
35 | deepspeed.runtime.utils.set_random_seed(seed + rank)
36 | except:
37 | print("deepspeed.runtime.utils.set_random_seed is not available")
38 |
39 |
40 | def setup_for_distributed(is_master):
41 | """
42 | This function disables printing when not in master process
43 | """
44 | builtin_print = builtins.print
45 |
46 | def print(*args, **kwargs):
47 | force = kwargs.pop('force', False)
48 | if is_master or force:
49 | now = datetime.datetime.now().time()
50 | builtin_print('[{}] '.format(now), end='')
51 | builtin_print(*args, **kwargs)
52 |
53 | builtins.print = print
54 |
55 |
56 | def setup_wandb_env(wandb_api_key=None):
57 | os.environ["WANDB_API_KEY"] = wandb_api_key or ''
58 | os.environ["WANDB_MODE"] = "offline"
59 | os.environ["WANDB__SERVICE_WAIT"] = "300"
60 | os.environ["WANDB_CONFIG_DIR"] = "./wandb"
61 |
62 |
63 | def main():
64 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
65 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
66 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
67 | else:
68 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
69 |
70 | training_args.ddp_find_unused_parameters = True
71 | training_args.multiple_optimizer_training = False
72 | training_args.one_minus_one_data_transform = data_args.one_minus_one_data_transform
73 | training_args.cost_gradient_penalty = model_args.cost_gradient_penalty
74 | setup_wandb_env(training_args.wandb_api_key)
75 |
76 | # Setup logging
77 | logging.basicConfig(
78 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
79 | datefmt="%m/%d/%Y %H:%M:%S",
80 | handlers=[logging.StreamHandler(sys.stdout)],
81 | )
82 |
83 | if training_args.should_log:
84 | transformers.utils.logging.set_verbosity_info()
85 |
86 | log_level = training_args.get_process_log_level()
87 | logger.setLevel(log_level)
88 | transformers.utils.logging.set_verbosity(log_level)
89 | transformers.utils.logging.enable_default_handler()
90 | transformers.utils.logging.enable_explicit_format()
91 |
92 | # Log on each process the small summary:
93 | logger.warning(
94 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
95 | + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16 or training_args.bf16}"
96 | )
97 | logger.info(f"Training/evaluation parameters {training_args}")
98 | logger.info(f"Model parameters {model_args}")
99 | logger.info(f"Data parameters {data_args}")
100 |
101 | # Detecting last checkpoint.
102 | last_checkpoint = None
103 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
104 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
105 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
106 | raise ValueError(
107 | f"Output directory ({training_args.output_dir}) already exists and is not empty. "
108 | "Use --overwrite_output_dir to overcome."
109 | )
110 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
111 | logger.info(
112 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
113 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
114 | )
115 |
116 | # data_args.data_seed
117 | random_seed(training_args.seed)
118 | data_args.seed = training_args.seed
119 | training_args.model_type = "image"
120 |
121 | from models.SD_with_DFN import SDModel
122 | from config import SDConfig
123 |
124 | config = SDConfig()
125 | config.tta.gradient_descent.train_steps = training_args.train_steps
126 | config.visual_pattern = training_args.visual_pattern
127 | config.clip_image_size = training_args.clip_image_size
128 | model = SDModel(config)
129 |
130 | # print model parameters
131 | logger.info(f"{str(model)}")
132 | model.cuda()
133 |
134 | from data import get_cc3m_wds_dataset_and_collator
135 | wds_dataset, wds_collator = get_cc3m_wds_dataset_and_collator(data_args, model_args)
136 |
137 | if config.model.freeze_class_embeds:
138 | params = []
139 | for key,parm in model.named_parameters():
140 | if 'final_fc' not in key:
141 | params.append(parm)
142 |
143 | optimizer = torch.optim.SGD(
144 | params, lr=training_args.learning_rate,
145 | weight_decay=training_args.weight_decay,
146 | momentum=config.tta.gradient_descent.optimizer_momentum
147 | )
148 | scheduler = None
149 |
150 | trainer = CustomTrainer(
151 | model=model,
152 | args=training_args,
153 | train_dataset=wds_dataset,
154 | data_collator=wds_collator,
155 | optimizers=(optimizer, scheduler)
156 | )
157 | setup_for_distributed(torch.distributed.get_rank() == 0)
158 |
159 | from callbacks import ModelCallback
160 | trainer.add_callback(ModelCallback)
161 |
162 | # Evaluation
163 | if training_args.local_rank == 0:
164 | print("CLIP's Performance on MMVP-VLM —— Before Generative Fine-tuning")
165 | results_before = official_evaluation(model.class_model.model, config)
166 | print(results_before)
167 |
168 | # Training
169 | if training_args.do_train:
170 | checkpoint = None
171 | if training_args.resume_from_checkpoint is not None:
172 | checkpoint = training_args.resume_from_checkpoint
173 | elif last_checkpoint is not None:
174 | checkpoint = last_checkpoint
175 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
176 | trainer.save_model(output_dir=training_args.output_dir)
177 | trainer.log_metrics("train", train_result.metrics)
178 | trainer.save_metrics("train", train_result.metrics)
179 | trainer.save_state()
180 |
181 | # Evaluation
182 | if training_args.local_rank == 0:
183 | print("CLIP's Performance on MMVP-VLM —— After Generative Fine-tuning")
184 | model_weight_save_path = os.path.join(training_args.output_dir, 'CLIP_after_GenFT.pth')
185 | torch.save(trainer.model.state_dict(), model_weight_save_path)
186 | results_final_after = official_evaluation(trainer.model.class_model.model, config)
187 | print(results_final_after)
188 | save_results(results_before, results_final_after, output_dir=training_args.output_dir)
189 |
190 |
191 | def benchmark_model(model, benchmark_dir, device = "cpu", config=None):
192 | if config.clip_image_size == 224:
193 | _, preprocess = create_model_from_pretrained(model_name='ViT-H-14-quickgelu', pretrained="pretrained_weights/CLIP/DFN5B-CLIP-ViT-H-14/open_clip_pytorch_model.bin", device=device)
194 | if config.clip_image_size == 378:
195 | _, preprocess = create_model_from_pretrained(model_name='ViT-H-14-378-quickgelu', pretrained="pretrained_weights/CLIP/DFN5B-CLIP-ViT-H-14-378/open_clip_pytorch_model.bin", device=device)
196 |
197 | tokenizer = get_tokenizer('ViT-H-14')
198 |
199 | image_dir = os.path.join(benchmark_dir, 'MLLM_VLM_Images')
200 | csv_file = os.path.join(benchmark_dir, 'Questions.csv')
201 |
202 | csv_outfile = open('Prediction_Results_DFN.csv', 'w', newline='')
203 | csv_writer = csv.writer(csv_outfile)
204 | csv_writer.writerow(['qid1', 'qid2', 'pred1', 'pred2', 'gt1', 'gt2', 'q1score', 'q2score']) # header
205 |
206 | categories = [
207 | 'Orientation and Direction', 'Presence of Specific Features',
208 | 'State and Condition', 'Quantity and Count',
209 | 'Positional and Relational Context', 'Color and Appearance',
210 | 'Structural Characteristics', 'Texts',
211 | 'Viewpoint and Perspective'
212 | ]
213 |
214 | pair_accuracies = {category: 0 for category in categories}
215 | num_pairs = 0
216 |
217 | with open(csv_file, 'r') as f:
218 | reader = csv.reader(f)
219 | next(reader) # skip header
220 | for i, row in tqdm(enumerate(reader)):
221 | qid1, qtype1, statement1 = row
222 |
223 | # Get next row for the pair
224 | row = next(reader, None)
225 | if not row:
226 | break
227 | qid2, qtype2, statement2 = row
228 |
229 | qid1, qid2 = int(qid1), int(qid2)
230 |
231 | img1 = Image.open(os.path.join(image_dir, qtype1, f'{qid1}.jpg'))
232 | img2 = Image.open(os.path.join(image_dir, qtype1, f'{qid2}.jpg'))
233 |
234 | text1 = 'a photo of ' + statement1
235 | text2 = 'a photo of ' + statement2
236 |
237 | text1 = tokenizer(text1).to(device)
238 | text2 = tokenizer(text2).to(device)
239 |
240 | img1 = preprocess(img1).unsqueeze(0).to(device)
241 | img2 = preprocess(img2).unsqueeze(0).to(device)
242 | imgs = torch.cat((img1, img2), dim=0)
243 |
244 | with torch.no_grad(), torch.cuda.amp.autocast():
245 | model.eval().float()
246 |
247 | # original code
248 | # image_features = model.encode_image(imgs)
249 |
250 | # ours
251 | if config.clip_image_size == 224:
252 | image_features = model.encode_image(imgs, normalize=True)[:,0,:]
253 | if config.clip_image_size == 378:
254 | global_image_features = model.encode_image(imgs, normalize=True)[:,0,:]
255 | local_image_features = model.encode_image(imgs, normalize=True)[:,1:,:].mean(dim=1)
256 | image_features = global_image_features + local_image_features
257 |
258 | text1_features = model.encode_text(text1, normalize=True)
259 | text2_features = model.encode_text(text2, normalize=True)
260 | logits_per_image1 = model.logit_scale.exp() * image_features @ text1_features.T
261 | logits_per_text1 = logits_per_image1.T
262 | logits_per_image2 = model.logit_scale.exp() * image_features @ text2_features.T
263 | logits_per_text2 = logits_per_image2.T
264 | probs1 = logits_per_text1.softmax(dim=-1).cpu().numpy()
265 | probs2 = logits_per_text2.softmax(dim=-1).cpu().numpy()
266 |
267 |
268 | img1_score1 = probs1[0][0]
269 | img1_score2 = probs2[0][0]
270 |
271 | pred1 = "img1" if img1_score1 > 0.5 else "img2"
272 | pred2 = "img1" if img1_score2 > 0.5 else "img2"
273 |
274 | gt1 = "img1" if qid1 % 2 == 1 else "img2"
275 | gt2 = "img1" if qid2 % 2 == 1 else "img2"
276 |
277 | csv_writer.writerow([qid1, qid2, pred1, pred2, gt1, gt2, img1_score1, img1_score2])
278 |
279 | current_category = categories[num_pairs // 15]
280 | if pred1 == gt1 and pred2 == gt2:
281 | pair_accuracies[current_category] += 1
282 | num_pairs += 1
283 |
284 | csv_outfile.close()
285 |
286 | # Calculate percentage accuracies
287 | Category_Score_List = []
288 |
289 | for category in pair_accuracies:
290 | pair_accuracies[category] = (pair_accuracies[category] / (num_pairs // len(categories))) * 100
291 | Category_Score_List.append(pair_accuracies[category])
292 |
293 | pair_accuracies['average_score'] = sum(Category_Score_List)/len(Category_Score_List)
294 |
295 | return pair_accuracies
296 |
297 | def official_evaluation(clip_model, config):
298 |
299 | with torch.no_grad():
300 | clip_model.eval()
301 |
302 | # models
303 | data = "dataset/MMVP_VLM"
304 | clip_model_device = next(clip_model.parameters()).device
305 | if config.clip_image_size == 224:
306 | results_openai = {f'DFN5B-CLIP-ViT-H-14': benchmark_model(clip_model, data, clip_model_device, config)}
307 | if config.clip_image_size == 378:
308 | results_openai = {f'DFN5B-CLIP-ViT-H-14-378': benchmark_model(clip_model, data, clip_model_device, config)}
309 |
310 | # Merge results
311 | results = {**results_openai}
312 |
313 | # Convert results to format suitable for star plot
314 | categories = results[list(results.keys())[0]].keys()
315 | data = {'Categories': list(categories)}
316 | for model in list(results_openai.keys()):
317 | data[model] = [results[model][category] for category in categories]
318 |
319 | return results
320 |
321 | def save_results(results_before, results_final_after, output_dir, filename='pred_result.json'):
322 |
323 | os.makedirs(output_dir, exist_ok=True)
324 |
325 | output_data = {
326 | 'results_before': results_before,
327 | 'results_final_after': results_final_after
328 | }
329 |
330 | output_path = os.path.join(output_dir, filename)
331 |
332 | with open(output_path, 'w') as f:
333 | json.dump(output_data, f, indent=4)
334 |
335 |
336 | if __name__ == "__main__":
337 | main()
338 |
--------------------------------------------------------------------------------
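
Note (illustrative, not part of the repository): the scoring rule buried in benchmark_model() is: each consecutive pair of CSV rows shares two images, a statement is assigned to img1 when its image-softmax puts more than 0.5 on img1, the ground truth alternates with the parity of the question id, and a pair counts as correct only when both statements are matched. A self-contained sketch of just that rule, with made-up scores:

# Minimal sketch of the MMVP-VLM pair-scoring rule used in benchmark_model().
def score_pair(qid1, qid2, img1_score1, img1_score2):
    pred1 = "img1" if img1_score1 > 0.5 else "img2"
    pred2 = "img1" if img1_score2 > 0.5 else "img2"
    gt1 = "img1" if qid1 % 2 == 1 else "img2"      # odd question id -> statement matches img1
    gt2 = "img1" if qid2 % 2 == 1 else "img2"
    return pred1 == gt1 and pred2 == gt2           # both statements must be matched

print(score_pair(1, 2, 0.9, 0.2))   # True: statement 1 -> img1, statement 2 -> img2
print(score_pair(1, 2, 0.9, 0.8))   # False: statement 2 picked img1 but its gt is img2
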
/run_DIVA_with_MetaCLIP.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 | import random
5 | import datetime
6 | import builtins
7 | sys.path.append(os.getcwd())
8 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
9 | import numpy as np
10 | import torch
11 | import torch.nn.functional as F
12 | import transformers
13 | from transformers.trainer_utils import set_seed
14 | from transformers import HfArgumentParser
15 | from transformers.trainer_utils import get_last_checkpoint
16 | from transformers.utils.versions import require_version
17 | from trainer import CustomTrainer
18 | from arguments import DataTrainingArguments, ModelArguments, TrainingArguments
19 | import csv
20 | from tqdm import tqdm
21 | from PIL import Image
22 | import json
23 | from open_clip import create_model_and_transforms, tokenize
24 | logger = logging.getLogger(__name__)
25 | import warnings
26 | warnings.filterwarnings("ignore")
27 |
28 |
29 | def random_seed(seed=42, rank=0):
30 | set_seed(seed)
31 | torch.manual_seed(seed + rank)
32 | np.random.seed(seed + rank)
33 | random.seed(seed + rank)
34 | try:
35 | import deepspeed
36 | deepspeed.runtime.utils.set_random_seed(seed + rank)
37 | except:
38 | print("deepspeed.runtime.utils.set_random_seed is not available")
39 |
40 |
41 | def setup_for_distributed(is_master):
42 | """
43 | This function disables printing when not in master process
44 | """
45 | builtin_print = builtins.print
46 |
47 | def print(*args, **kwargs):
48 | force = kwargs.pop('force', False)
49 | if is_master or force:
50 | now = datetime.datetime.now().time()
51 | builtin_print('[{}] '.format(now), end='')
52 | builtin_print(*args, **kwargs)
53 |
54 | builtins.print = print
55 |
56 |
57 | def setup_wandb_env(wandb_api_key=None):
58 | os.environ["WANDB_API_KEY"] = wandb_api_key or ''
59 | os.environ["WANDB_MODE"] = "offline"
60 | os.environ["WANDB__SERVICE_WAIT"] = "300"
61 | os.environ["WANDB_CONFIG_DIR"] = "./wandb"
62 |
63 |
64 | def main():
65 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
66 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
67 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
68 | else:
69 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
70 |
71 | training_args.ddp_find_unused_parameters = True
72 | training_args.multiple_optimizer_training = False
73 | training_args.one_minus_one_data_transform = data_args.one_minus_one_data_transform
74 | training_args.cost_gradient_penalty = model_args.cost_gradient_penalty
75 | setup_wandb_env(training_args.wandb_api_key)
76 |
77 | # Setup logging
78 | logging.basicConfig(
79 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
80 | datefmt="%m/%d/%Y %H:%M:%S",
81 | handlers=[logging.StreamHandler(sys.stdout)],
82 | )
83 |
84 | if training_args.should_log:
85 | transformers.utils.logging.set_verbosity_info()
86 |
87 | log_level = training_args.get_process_log_level()
88 | logger.setLevel(log_level)
89 | transformers.utils.logging.set_verbosity(log_level)
90 | transformers.utils.logging.enable_default_handler()
91 | transformers.utils.logging.enable_explicit_format()
92 |
93 | # Log on each process the small summary:
94 | logger.warning(
95 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
96 | + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16 or training_args.bf16}"
97 | )
98 | logger.info(f"Training/evaluation parameters {training_args}")
99 | logger.info(f"Model parameters {model_args}")
100 | logger.info(f"Data parameters {data_args}")
101 |
102 | # Detecting last checkpoint.
103 | last_checkpoint = None
104 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
105 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
106 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
107 | raise ValueError(
108 | f"Output directory ({training_args.output_dir}) already exists and is not empty. "
109 | "Use --overwrite_output_dir to overcome."
110 | )
111 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
112 | logger.info(
113 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
114 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
115 | )
116 |
117 | random_seed(training_args.seed)
118 | data_args.seed = training_args.seed
119 | training_args.model_type = "image"
120 |
121 | from models.SD_with_MetaCLIP import SDModel
122 | from config import SDConfig
123 |
124 | config = SDConfig()
125 | config.tta.gradient_descent.train_steps = training_args.train_steps
126 | config.visual_pattern = training_args.visual_pattern
127 | config.clip_image_size = training_args.clip_image_size
128 | config.metaclip_version = training_args.metaclip_version
129 | model = SDModel(config)
130 |
131 | # print model parameters
132 | logger.info(f"{str(model)}")
133 | model.cuda()
134 |
135 | from data import get_cc3m_wds_dataset_and_collator
136 | wds_dataset, wds_collator = get_cc3m_wds_dataset_and_collator(data_args, model_args)
137 |
138 | if config.model.freeze_class_embeds:
139 | params = []
140 | for key,parm in model.named_parameters():
141 | if 'final_fc' not in key:
142 | params.append(parm)
143 |
144 | optimizer = torch.optim.SGD(
145 | params, lr=training_args.learning_rate,
146 | weight_decay=training_args.weight_decay,
147 | momentum=config.tta.gradient_descent.optimizer_momentum
148 | )
149 | scheduler = None
150 |
151 | trainer = CustomTrainer(
152 | model=model,
153 | args=training_args,
154 | train_dataset=wds_dataset,
155 | data_collator=wds_collator,
156 | optimizers=(optimizer, scheduler)
157 | )
158 |
159 | setup_for_distributed(torch.distributed.get_rank() == 0)
160 |
161 | from callbacks import ModelCallback
162 | trainer.add_callback(ModelCallback)
163 |
164 | # Evaluation
165 | if training_args.local_rank == 0:
166 | print("CLIP's Performance on MMVP-VLM —— Before Generative Fine-tuning")
167 | results_before = official_evaluation(model.class_model.model, config)
168 | print(results_before)
169 |
170 | # Training
171 | if training_args.do_train:
172 | checkpoint = None
173 | if training_args.resume_from_checkpoint is not None:
174 | checkpoint = training_args.resume_from_checkpoint
175 | elif last_checkpoint is not None:
176 | checkpoint = last_checkpoint
177 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
178 | trainer.save_model(output_dir=training_args.output_dir)
179 | trainer.log_metrics("train", train_result.metrics)
180 | trainer.save_metrics("train", train_result.metrics)
181 | trainer.save_state()
182 |
183 | # Evaluation
184 | if training_args.local_rank == 0:
185 | print("CLIP's Performance on MMVP-VLM —— After Generative Fine-tuning")
186 | model_weight_save_path = os.path.join(training_args.output_dir, 'CLIP_after_GenFT.pth')
187 | torch.save(trainer.model.state_dict(), model_weight_save_path)
188 | results_final_after = official_evaluation(trainer.model.class_model.model, config)
189 | print(results_final_after)
190 | save_results(results_before, results_final_after, output_dir=training_args.output_dir)
191 |
192 |
193 | def benchmark_model(model, benchmark_dir, device = "cpu", config=None):
194 |
195 | if config.metaclip_version == "large":
196 | _, _, preprocess = create_model_and_transforms(model_name='ViT-L-14-quickgelu', pretrained="pretrained_weights/CLIP/MetaCLIP/l14_fullcc2.5b.pt")
197 | if config.metaclip_version == "huge":
198 |         _, _, preprocess = create_model_and_transforms(model_name='ViT-H-14-quickgelu', pretrained="pretrained_weights/CLIP/MetaCLIP/h14_fullcc2.5b.pt")
199 |
200 | image_dir = os.path.join(benchmark_dir, 'MLLM_VLM_Images')
201 | csv_file = os.path.join(benchmark_dir, 'Questions.csv')
202 |
203 |     csv_outfile = open('Prediction_Results_MetaCLIP.csv', 'w', newline='')
204 | csv_writer = csv.writer(csv_outfile)
205 | csv_writer.writerow(['qid1', 'qid2', 'pred1', 'pred2', 'gt1', 'gt2', 'q1score', 'q2score']) # header
206 |
207 | categories = [
208 | 'Orientation and Direction', 'Presence of Specific Features',
209 | 'State and Condition', 'Quantity and Count',
210 | 'Positional and Relational Context', 'Color and Appearance',
211 | 'Structural Characteristics', 'Texts',
212 | 'Viewpoint and Perspective'
213 | ]
214 |
215 | pair_accuracies = {category: 0 for category in categories}
216 | num_pairs = 0
217 |
218 | with open(csv_file, 'r') as f:
219 | reader = csv.reader(f)
220 | next(reader)
221 | for i, row in tqdm(enumerate(reader)):
222 | qid1, qtype1, statement1 = row
223 |
224 | # Get next row for the pair
225 | row = next(reader, None)
226 | if not row:
227 | break
228 | qid2, qtype2, statement2 = row
229 |
230 | qid1, qid2 = int(qid1), int(qid2)
231 |
232 | img1 = Image.open(os.path.join(image_dir, qtype1, f'{qid1}.jpg'))
233 | img2 = Image.open(os.path.join(image_dir, qtype1, f'{qid2}.jpg'))
234 |
235 | text1 = 'a photo of ' + statement1
236 | text2 = 'a photo of ' + statement2
237 |
238 | text1 = tokenize(text1).to(device)
239 | text2 = tokenize(text2).to(device)
240 |
241 | img1 = preprocess(img1).unsqueeze(0).to(device)
242 | img2 = preprocess(img2).unsqueeze(0).to(device)
243 | imgs = torch.cat((img1, img2), dim=0)
244 |
245 | with torch.no_grad():
246 | model.eval().float()
247 |
248 | # original code
249 | # image_features = model.encode_image(imgs)
250 |
251 | # ours
252 | global_image_features = model.encode_image(imgs, normalize=True)[:,0,:]
253 | local_image_features = model.encode_image(imgs, normalize=True)[:,1:,:].mean(dim=1)
254 | image_features = global_image_features + local_image_features
255 |
256 | text1_features = model.encode_text(text1)
257 | text2_features = model.encode_text(text2)
258 | image_features = F.normalize(image_features, dim=-1)
259 | text1_features = F.normalize(text1_features, dim=-1)
260 | text2_features = F.normalize(text2_features, dim=-1)
261 |
262 | logits_per_image1 = 100.0 * image_features @ text1_features.T
263 | logits_per_text1 = logits_per_image1.T
264 | logits_per_image2 = 100.0 * image_features @ text2_features.T
265 | logits_per_text2 = logits_per_image2.T
266 |
267 | probs1 = logits_per_text1.softmax(dim=-1).cpu().numpy()
268 | probs2 = logits_per_text2.softmax(dim=-1).cpu().numpy()
269 |
270 | img1_score1 = probs1[0][0]
271 | img1_score2 = probs2[0][0]
272 |
273 | pred1 = "img1" if img1_score1 > 0.5 else "img2"
274 | pred2 = "img1" if img1_score2 > 0.5 else "img2"
275 |
276 | gt1 = "img1" if qid1 % 2 == 1 else "img2"
277 | gt2 = "img1" if qid2 % 2 == 1 else "img2"
278 |
279 | csv_writer.writerow([qid1, qid2, pred1, pred2, gt1, gt2, img1_score1, img1_score2])
280 |
281 | current_category = categories[num_pairs // 15]
282 | if pred1 == gt1 and pred2 == gt2:
283 | pair_accuracies[current_category] += 1
284 | num_pairs += 1
285 |
286 | csv_outfile.close()
287 |
288 | # Calculate percentage accuracies
289 | Category_Score_List = []
290 |
291 | for category in pair_accuracies:
292 | pair_accuracies[category] = (pair_accuracies[category] / (num_pairs // len(categories))) * 100
293 | Category_Score_List.append(pair_accuracies[category])
294 |
295 | pair_accuracies['average_score'] = sum(Category_Score_List)/len(Category_Score_List)
296 |
297 | return pair_accuracies
298 |
299 | def official_evaluation(clip_model, config):
300 |
301 | with torch.no_grad():
302 | clip_model.eval()
303 |
304 | # models
305 | data = "dataset/MMVP_VLM"
306 | clip_model_device = next(clip_model.parameters()).device
307 |
308 | if config.metaclip_version == "large":
309 | results_openai = {f'MetaCLIP-ViT-L-14': benchmark_model(clip_model, data, clip_model_device, config)}
310 | if config.metaclip_version == "huge":
311 | results_openai = {f'MetaCLIP-ViT-H-14': benchmark_model(clip_model, data, clip_model_device, config)}
312 |
313 | # Merge results
314 | results = {**results_openai}
315 |
316 | # Convert results to format suitable for star plot
317 | categories = results[list(results.keys())[0]].keys()
318 | data = {'Categories': list(categories)}
319 | for model in list(results_openai.keys()):
320 | data[model] = [results[model][category] for category in categories]
321 |
322 | return results
323 |
324 | def save_results(results_before, results_final_after, output_dir, filename='pred_result.json'):
325 |
326 | os.makedirs(output_dir, exist_ok=True)
327 |
328 | output_data = {
329 | 'results_before': results_before,
330 | 'results_final_after': results_final_after
331 | }
332 |
333 | output_path = os.path.join(output_dir, filename)
334 |
335 | with open(output_path, 'w') as f:
336 | json.dump(output_data, f, indent=4)
337 |
338 |
339 | if __name__ == "__main__":
340 | main()
341 |
--------------------------------------------------------------------------------
/run_DIVA_with_OpenAICLIP.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 | import random
5 | import datetime
6 | import builtins
7 | sys.path.append(os.getcwd())
8 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
9 | import numpy as np
10 | import torch
11 | import transformers
12 | from transformers.trainer_utils import set_seed
13 | from transformers import HfArgumentParser
14 | from transformers.trainer_utils import get_last_checkpoint
15 | from transformers.utils.versions import require_version
16 | from trainer import CustomTrainer
17 | from arguments import DataTrainingArguments, ModelArguments, TrainingArguments
18 | from clip import load
19 | import clip
20 | import csv
21 | from tqdm import tqdm
22 | from PIL import Image
23 | import json
24 | logger = logging.getLogger(__name__)
25 | import warnings
26 | warnings.filterwarnings("ignore")
27 |
28 |
29 | def random_seed(seed=42, rank=0):
30 | set_seed(seed)
31 | torch.manual_seed(seed + rank)
32 | np.random.seed(seed + rank)
33 | random.seed(seed + rank)
34 | try:
35 | import deepspeed
36 | deepspeed.runtime.utils.set_random_seed(seed + rank)
37 | except:
38 | print("deepspeed.runtime.utils.set_random_seed is not available")
39 |
40 |
41 | def setup_for_distributed(is_master):
42 | """
43 | This function disables printing when not in master process
44 | """
45 | builtin_print = builtins.print
46 |
47 | def print(*args, **kwargs):
48 | force = kwargs.pop('force', False)
49 | if is_master or force:
50 | now = datetime.datetime.now().time()
51 | builtin_print('[{}] '.format(now), end='')
52 | builtin_print(*args, **kwargs)
53 |
54 | builtins.print = print
55 |
56 |
57 | def setup_wandb_env(wandb_api_key=None):
58 | os.environ["WANDB_API_KEY"] = wandb_api_key or ''
59 | os.environ["WANDB_MODE"] = "offline"
60 | os.environ["WANDB__SERVICE_WAIT"] = "300"
61 | os.environ["WANDB_CONFIG_DIR"] = "./wandb"
62 |
63 |
64 | def main():
65 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
66 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
67 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
68 | else:
69 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
70 |
71 | training_args.ddp_find_unused_parameters = True
72 | training_args.multiple_optimizer_training = False
73 | training_args.one_minus_one_data_transform = data_args.one_minus_one_data_transform
74 | training_args.cost_gradient_penalty = model_args.cost_gradient_penalty
75 | setup_wandb_env(training_args.wandb_api_key)
76 |
77 | # Setup logging
78 | logging.basicConfig(
79 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
80 | datefmt="%m/%d/%Y %H:%M:%S",
81 | handlers=[logging.StreamHandler(sys.stdout)],
82 | )
83 |
84 | if training_args.should_log:
85 | transformers.utils.logging.set_verbosity_info()
86 |
87 | log_level = training_args.get_process_log_level()
88 | logger.setLevel(log_level)
89 | transformers.utils.logging.set_verbosity(log_level)
90 | transformers.utils.logging.enable_default_handler()
91 | transformers.utils.logging.enable_explicit_format()
92 |
93 | # Log on each process the small summary:
94 | logger.warning(
95 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
96 | + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16 or training_args.bf16}"
97 | )
98 | logger.info(f"Training/evaluation parameters {training_args}")
99 | logger.info(f"Model parameters {model_args}")
100 | logger.info(f"Data parameters {data_args}")
101 |
102 | # Detecting last checkpoint.
103 | last_checkpoint = None
104 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
105 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
106 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
107 | raise ValueError(
108 | f"Output directory ({training_args.output_dir}) already exists and is not empty. "
109 | "Use --overwrite_output_dir to overcome."
110 | )
111 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
112 | logger.info(
113 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
114 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
115 | )
116 |
117 | # data_args.data_seed
118 | random_seed(training_args.seed)
119 | data_args.seed = training_args.seed
120 | training_args.model_type = "image"
121 |
122 | from models.SD_with_OpenAICLIP import SDModel
123 | from config import SDConfig
124 |
125 | config = SDConfig()
126 | config.tta.gradient_descent.train_steps = training_args.train_steps
127 | config.visual_pattern = training_args.visual_pattern
128 | config.clip_image_size = training_args.clip_image_size
129 | model = SDModel(config)
130 |
131 | # print model parameters
132 | logger.info(f"{str(model)}")
133 | model.cuda()
134 |
135 | from data import get_cc3m_wds_dataset_and_collator
136 | wds_dataset, wds_collator = get_cc3m_wds_dataset_and_collator(data_args, model_args)
137 |
138 | if config.model.freeze_class_embeds:
139 | params = []
140 |         for key, param in model.named_parameters():
141 |             if 'final_fc' not in key:
142 |                 params.append(param)
143 |
144 | optimizer = torch.optim.SGD(
145 | params, lr=training_args.learning_rate,
146 | weight_decay=training_args.weight_decay,
147 | momentum=config.tta.gradient_descent.optimizer_momentum
148 | )
149 | scheduler = None
150 |
151 | trainer = CustomTrainer(
152 | model=model,
153 | args=training_args,
154 | train_dataset=wds_dataset,
155 | data_collator=wds_collator,
156 | optimizers=(optimizer, scheduler)
157 | )
158 | setup_for_distributed(torch.distributed.get_rank() == 0)
159 |
160 | from callbacks import ModelCallback
161 | trainer.add_callback(ModelCallback)
162 |
163 | # Evaluation
164 | if training_args.local_rank == 0:
165 | print("CLIP's Performance on MMVP-VLM —— Before Generative Fine-tuning")
166 | results_before = official_evaluation(model.class_model.model, config)
167 | print(results_before)
168 |
169 | # Training
170 | if training_args.do_train:
171 | checkpoint = None
172 | if training_args.resume_from_checkpoint is not None:
173 | checkpoint = training_args.resume_from_checkpoint
174 | elif last_checkpoint is not None:
175 | checkpoint = last_checkpoint
176 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
177 | trainer.save_model(output_dir=training_args.output_dir)
178 | trainer.log_metrics("train", train_result.metrics)
179 | trainer.save_metrics("train", train_result.metrics)
180 | trainer.save_state()
181 |
182 | # Evaluation
183 | if training_args.local_rank == 0:
184 | print("CLIP's Performance on MMVP-VLM —— After Generative Fine-tuning")
185 | model_weight_save_path = os.path.join(training_args.output_dir, 'CLIP_after_GenFT.pth')
186 | torch.save(trainer.model.state_dict(), model_weight_save_path)
187 | results_final_after = official_evaluation(trainer.model.class_model.model, config)
188 | print(results_final_after)
189 | save_results(results_before, results_final_after, output_dir=training_args.output_dir)
190 |
191 |
192 | def benchmark_model(base_model, model, benchmark_dir, device = "cpu"):
193 |
194 | _, preprocess = load(base_model, device=device)
195 | image_dir = os.path.join(benchmark_dir, 'MLLM_VLM_Images')
196 | csv_file = os.path.join(benchmark_dir, 'Questions.csv')
197 |
198 | csv_outfile = open('Prediction_Results_OpenAICLIP', 'w', newline='')
199 | csv_writer = csv.writer(csv_outfile)
200 | csv_writer.writerow(['qid1', 'qid2', 'pred1', 'pred2', 'gt1', 'gt2', 'q1score', 'q2score']) # header
201 |
202 | categories = [
203 | 'Orientation and Direction', 'Presence of Specific Features',
204 | 'State and Condition', 'Quantity and Count',
205 | 'Positional and Relational Context', 'Color and Appearance',
206 | 'Structural Characteristics', 'Texts',
207 | 'Viewpoint and Perspective'
208 | ]
209 |
210 | pair_accuracies = {category: 0 for category in categories}
211 | num_pairs = 0
212 |
213 | with open(csv_file, 'r') as f:
214 | reader = csv.reader(f)
215 | next(reader) # skip header
216 | for i, row in tqdm(enumerate(reader)):
217 | qid1, qtype1, statement1 = row
218 |
219 | # Get next row for the pair
220 | row = next(reader, None)
221 | if not row:
222 | break
223 | qid2, qtype2, statement2 = row
224 |
225 | qid1, qid2 = int(qid1), int(qid2)
226 |
227 | img1 = Image.open(os.path.join(image_dir, qtype1, f'{qid1}.jpg'))
228 | img2 = Image.open(os.path.join(image_dir, qtype1, f'{qid2}.jpg'))
229 |
230 | text1 = 'a photo of ' + statement1
231 | text2 = 'a photo of ' + statement2
232 |
233 | text1 = clip.tokenize([text1]).to(device)
234 | text2 = clip.tokenize([text2]).to(device)
235 |
236 | img1 = preprocess(img1).unsqueeze(0).to(device)
237 | img2 = preprocess(img2).unsqueeze(0).to(device)
238 | imgs = torch.cat((img1, img2), dim=0)
239 |
240 | with torch.no_grad():
241 | model.eval().float()
242 | logits_per_image1, logits_per_text1 = model(imgs, text1)
243 | logits_per_image2, logits_per_text2 = model(imgs, text2)
244 |
245 | probs1 = logits_per_text1.softmax(dim=-1).cpu().numpy()
246 | probs2 = logits_per_text2.softmax(dim=-1).cpu().numpy()
247 |
248 | img1_score1 = probs1[0][0]
249 | img1_score2 = probs2[0][0]
250 |
251 | pred1 = "img1" if img1_score1 > 0.5 else "img2"
252 | pred2 = "img1" if img1_score2 > 0.5 else "img2"
253 |
254 | gt1 = "img1" if qid1 % 2 == 1 else "img2"
255 | gt2 = "img1" if qid2 % 2 == 1 else "img2"
256 |
257 | csv_writer.writerow([qid1, qid2, pred1, pred2, gt1, gt2, img1_score1, img1_score2])
258 |
259 | current_category = categories[num_pairs // 15]
260 | if pred1 == gt1 and pred2 == gt2:
261 | pair_accuracies[current_category] += 1
262 | num_pairs += 1
263 |
264 | csv_outfile.close()
265 |
266 | # Calculate percentage accuracies
267 | Category_Score_List = []
268 |
269 | for category in pair_accuracies:
270 | pair_accuracies[category] = (pair_accuracies[category] / (num_pairs // len(categories))) * 100
271 | Category_Score_List.append(pair_accuracies[category])
272 |
273 | pair_accuracies['average_score'] = sum(Category_Score_List)/len(Category_Score_List)
274 |
275 | return pair_accuracies
276 |
277 | def official_evaluation(clip_model, config):
278 |
279 | with torch.no_grad():
280 | clip_model.eval()
281 |
282 |         # MMVP-VLM benchmark data and CLIP checkpoint selection
283 |         data = "dataset/MMVP_VLM"
284 |         if config.clip_image_size == 224:
285 |             base_model = "pretrained_weights/CLIP/ViT-L-14.pt"
286 |         elif config.clip_image_size == 336:
287 |             base_model = "pretrained_weights/CLIP/ViT-L-14-336px.pt"
288 | clip_model_device = next(clip_model.parameters()).device
289 | clip_model_name = base_model.split('/')[-1].split('.')[0]
290 | results_openai = {f'openai-{clip_model_name}': benchmark_model(base_model, clip_model, data, clip_model_device)}
291 |
292 | # Merge results
293 | results = {**results_openai}
294 |
295 |     # Convert results to a format suitable for a star plot (this dict is built but not currently used)
296 | categories = results[list(results.keys())[0]].keys()
297 | data = {'Categories': list(categories)}
298 | for model in list(results_openai.keys()):
299 | data[model] = [results[model][category] for category in categories]
300 |
301 | return results
302 |
303 | def save_results(results_before, results_final_after, output_dir, filename='pred_result.json'):
304 |
305 | os.makedirs(output_dir, exist_ok=True)
306 |
307 | output_data = {
308 | 'results_before': results_before,
309 | 'results_final_after': results_final_after
310 | }
311 |
312 | output_path = os.path.join(output_dir, filename)
313 |
314 | with open(output_path, 'w') as f:
315 | json.dump(output_data, f, indent=4)
316 |
317 |
318 | if __name__ == "__main__":
319 | main()
320 |
--------------------------------------------------------------------------------
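
Note on the OpenAI-CLIP evaluation above: both images of a pair are fed together with one text prompt, `logits_per_text` is softmaxed over the two images, and the prediction is `img1` whenever the first image's probability exceeds 0.5. A minimal sketch of that decision rule using the stock `clip` package (the checkpoint name, image paths, and caption below are placeholders, not the script's local weights):

    import torch
    import clip
    from PIL import Image

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-L/14", device=device)  # the script loads a local .pt checkpoint instead

    imgs = torch.cat([
        preprocess(Image.open("img1.jpg")).unsqueeze(0),
        preprocess(Image.open("img2.jpg")).unsqueeze(0),
    ], dim=0).to(device)
    text = clip.tokenize(["a photo of " + "statement 1"]).to(device)

    with torch.no_grad():
        logits_per_image, logits_per_text = model(imgs, text)  # logits_per_text has shape (1, 2)
        probs = logits_per_text.softmax(dim=-1).cpu().numpy()  # softmax over the two images
    pred = "img1" if probs[0][0] > 0.5 else "img2"
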
/run_DIVA_with_SigLIP.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 | import random
5 | import datetime
6 | import builtins
7 | sys.path.append(os.getcwd())
8 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
9 | import numpy as np
10 | import torch
11 | import torch.nn.functional as F
12 | import transformers
13 | from transformers.trainer_utils import set_seed
14 | from transformers import HfArgumentParser
15 | from transformers.trainer_utils import get_last_checkpoint
16 | from transformers.utils.versions import require_version
17 | from trainer import CustomTrainer
18 | from arguments import DataTrainingArguments, ModelArguments, TrainingArguments
19 | import csv
20 | from tqdm import tqdm
21 | from PIL import Image
22 | import json
23 | from open_clip import create_model_from_pretrained, get_tokenizer
24 | logger = logging.getLogger(__name__)
25 | import warnings
26 | warnings.filterwarnings("ignore")
27 |
28 |
29 | def random_seed(seed=42, rank=0):
30 | set_seed(seed)
31 | torch.manual_seed(seed + rank)
32 | np.random.seed(seed + rank)
33 | random.seed(seed + rank)
34 | try:
35 | import deepspeed
36 | deepspeed.runtime.utils.set_random_seed(seed + rank)
37 |     except Exception:
38 | print("deepspeed.runtime.utils.set_random_seed is not available")
39 |
40 |
41 | def setup_for_distributed(is_master):
42 | """
43 | This function disables printing when not in master process
44 | """
45 | builtin_print = builtins.print
46 |
47 | def print(*args, **kwargs):
48 | force = kwargs.pop('force', False)
49 | if is_master or force:
50 | now = datetime.datetime.now().time()
51 | builtin_print('[{}] '.format(now), end='')
52 | builtin_print(*args, **kwargs)
53 |
54 | builtins.print = print
55 |
56 |
57 | def setup_wandb_env(wandb_api_key=None):
58 | os.environ["WANDB_API_KEY"] = wandb_api_key or ''
59 | os.environ["WANDB_MODE"] = "offline"
60 | os.environ["WANDB__SERVICE_WAIT"] = "300"
61 | os.environ["WANDB_CONFIG_DIR"] = "./wandb"
62 |
63 |
64 | def main():
65 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
66 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
67 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
68 | else:
69 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
70 |
71 | training_args.ddp_find_unused_parameters = True
72 | training_args.multiple_optimizer_training = False
73 | training_args.one_minus_one_data_transform = data_args.one_minus_one_data_transform
74 | training_args.cost_gradient_penalty = model_args.cost_gradient_penalty
75 | setup_wandb_env(training_args.wandb_api_key)
76 |
77 | # Setup logging
78 | logging.basicConfig(
79 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
80 | datefmt="%m/%d/%Y %H:%M:%S",
81 | handlers=[logging.StreamHandler(sys.stdout)],
82 | )
83 |
84 | if training_args.should_log:
85 | transformers.utils.logging.set_verbosity_info()
86 |
87 | log_level = training_args.get_process_log_level()
88 | logger.setLevel(log_level)
89 | transformers.utils.logging.set_verbosity(log_level)
90 | transformers.utils.logging.enable_default_handler()
91 | transformers.utils.logging.enable_explicit_format()
92 |
93 | # Log on each process the small summary:
94 | logger.warning(
95 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
96 | + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16 or training_args.bf16}"
97 | )
98 | logger.info(f"Training/evaluation parameters {training_args}")
99 | logger.info(f"Model parameters {model_args}")
100 | logger.info(f"Data parameters {data_args}")
101 |
102 | # Detecting last checkpoint.
103 | last_checkpoint = None
104 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
105 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
106 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
107 | raise ValueError(
108 | f"Output directory ({training_args.output_dir}) already exists and is not empty. "
109 | "Use --overwrite_output_dir to overcome."
110 | )
111 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
112 | logger.info(
113 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
114 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
115 | )
116 |
117 | random_seed(training_args.seed)
118 | data_args.seed = training_args.seed
119 | training_args.model_type = "image"
120 |
121 | from models.SD_with_SigLIP import SDModel
122 | from config import SDConfig
123 |
124 | config = SDConfig()
125 | config.tta.gradient_descent.train_steps = training_args.train_steps
126 | config.visual_pattern = training_args.visual_pattern
127 | config.clip_image_size = training_args.clip_image_size
128 | model = SDModel(config)
129 |
130 | # print model parameters
131 | logger.info(f"{str(model)}")
132 | model.cuda()
133 |
134 | from data import get_cc3m_wds_dataset_and_collator
135 | wds_dataset, wds_collator = get_cc3m_wds_dataset_and_collator(data_args, model_args)
136 |
137 | if config.model.freeze_class_embeds:
138 | params = []
139 |         for key, param in model.named_parameters():
140 |             if 'final_fc' not in key:
141 |                 params.append(param)
142 |
143 | optimizer = torch.optim.SGD(
144 | params, lr=training_args.learning_rate,
145 | weight_decay=training_args.weight_decay,
146 | momentum=config.tta.gradient_descent.optimizer_momentum
147 | )
148 | scheduler = None
149 |
150 | trainer = CustomTrainer(
151 | model=model,
152 | args=training_args,
153 | train_dataset=wds_dataset,
154 | data_collator=wds_collator,
155 | optimizers=(optimizer, scheduler)
156 | )
157 | setup_for_distributed(torch.distributed.get_rank() == 0)
158 |
159 | from callbacks import ModelCallback
160 | trainer.add_callback(ModelCallback)
161 |
162 | # Evaluation
163 | if training_args.local_rank == 0:
164 | print("CLIP's Performance on MMVP-VLM —— Before Generative Fine-tuning")
165 | results_before = official_evaluation(model.class_model.model, config)
166 | print(results_before)
167 |
168 | # Training
169 | if training_args.do_train:
170 | checkpoint = None
171 | if training_args.resume_from_checkpoint is not None:
172 | checkpoint = training_args.resume_from_checkpoint
173 | elif last_checkpoint is not None:
174 | checkpoint = last_checkpoint
175 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
176 | trainer.save_model(output_dir=training_args.output_dir)
177 | trainer.log_metrics("train", train_result.metrics)
178 | trainer.save_metrics("train", train_result.metrics)
179 | trainer.save_state()
180 |
181 | # Evaluation
182 | if training_args.local_rank == 0:
183 | print("CLIP's Performance on MMVP-VLM —— After Generative Fine-tuning")
184 | model_weight_save_path = os.path.join(training_args.output_dir, 'CLIP_after_GenFT.pth')
185 | torch.save(trainer.model.state_dict(), model_weight_save_path)
186 | results_final_after = official_evaluation(trainer.model.class_model.model, config)
187 | print(results_final_after)
188 | save_results(results_before, results_final_after, output_dir=training_args.output_dir)
189 |
190 |
191 | def benchmark_model(model, benchmark_dir, device = "cpu", config=None):
192 | if config.clip_image_size == 224:
193 | _, preprocess = create_model_from_pretrained(model_name='ViT-SO400M-14-SigLIP', pretrained="pretrained_weights/CLIP/ViT-SO400M-14-SigLIP/open_clip_pytorch_model.bin", device=device,
194 | image_mean=([0.5,0.5,0.5]), image_std=([0.5,0.5,0.5]), image_interpolation="bicubic", image_resize_mode="squash")
195 | tokenizer = get_tokenizer('ViT-SO400M-14-SigLIP')
196 |     elif config.clip_image_size == 384:
197 | _, preprocess = create_model_from_pretrained(model_name='ViT-SO400M-14-SigLIP-384', pretrained="pretrained_weights/CLIP/ViT-SO400M-14-SigLIP-384/open_clip_pytorch_model.bin", device=device,
198 | image_mean=([0.5,0.5,0.5]), image_std=([0.5,0.5,0.5]), image_interpolation="bicubic", image_resize_mode="squash")
199 | tokenizer = get_tokenizer('ViT-SO400M-14-SigLIP-384')
200 |
201 | image_dir = os.path.join(benchmark_dir, 'MLLM_VLM_Images')
202 | csv_file = os.path.join(benchmark_dir, 'Questions.csv')
203 |
204 | csv_outfile = open('Prediction_Results_SigLIP', 'w', newline='')
205 | csv_writer = csv.writer(csv_outfile)
206 | csv_writer.writerow(['qid1', 'qid2', 'pred1', 'pred2', 'gt1', 'gt2', 'q1score', 'q2score']) # header
207 |
208 | categories = [
209 | 'Orientation and Direction', 'Presence of Specific Features',
210 | 'State and Condition', 'Quantity and Count',
211 | 'Positional and Relational Context', 'Color and Appearance',
212 | 'Structural Characteristics', 'Texts',
213 | 'Viewpoint and Perspective'
214 | ]
215 |
216 | pair_accuracies = {category: 0 for category in categories}
217 | num_pairs = 0
218 |
219 | with open(csv_file, 'r') as f:
220 | reader = csv.reader(f)
221 |         next(reader)  # skip header
222 | for i, row in tqdm(enumerate(reader)):
223 | qid1, qtype1, statement1 = row
224 |
225 | # Get next row for the pair
226 | row = next(reader, None)
227 | if not row:
228 | break
229 | qid2, qtype2, statement2 = row
230 |
231 | qid1, qid2 = int(qid1), int(qid2)
232 |
233 | img1 = Image.open(os.path.join(image_dir, qtype1, f'{qid1}.jpg'))
234 | img2 = Image.open(os.path.join(image_dir, qtype1, f'{qid2}.jpg'))
235 |
236 | text1 = 'a photo of ' + statement1
237 | text2 = 'a photo of ' + statement2
238 |
239 | text1 = tokenizer(text1, context_length=model.context_length).to(device)
240 | text2 = tokenizer(text2, context_length=model.context_length).to(device)
241 |
242 | img1 = preprocess(img1).unsqueeze(0).to(device)
243 | img2 = preprocess(img2).unsqueeze(0).to(device)
244 | imgs = torch.cat((img1, img2), dim=0)
245 |
246 | with torch.no_grad(), torch.cuda.amp.autocast():
247 | model.eval().float()
248 |
249 |                 # original MMVP evaluation code:
250 |                 # image_features = model.encode_image(imgs)
251 | 
252 |                 # ours: the patched visual backbone returns per-token features, so keep only the first token as the image embedding
253 |                 image_features = model.encode_image(imgs)[:,0,:]
254 |
255 | text1_features = model.encode_text(text1)
256 | text2_features = model.encode_text(text2)
257 | image_features = F.normalize(image_features, dim=-1)
258 | text1_features = F.normalize(text1_features, dim=-1)
259 | text2_features = F.normalize(text2_features, dim=-1)
260 | logits_per_image1 = image_features @ text1_features.T * model.logit_scale.exp() + model.logit_bias
261 | logits_per_text1 = logits_per_image1.T
262 | logits_per_image2 = image_features @ text2_features.T * model.logit_scale.exp() + model.logit_bias
263 | logits_per_text2 = logits_per_image2.T
264 |
265 | probs1 = logits_per_text1.softmax(dim=-1).cpu().numpy()
266 | probs2 = logits_per_text2.softmax(dim=-1).cpu().numpy()
267 |
268 | img1_score1 = probs1[0][0]
269 | img1_score2 = probs2[0][0]
270 |
271 | pred1 = "img1" if img1_score1 > 0.5 else "img2"
272 | pred2 = "img1" if img1_score2 > 0.5 else "img2"
273 |
274 | gt1 = "img1" if qid1 % 2 == 1 else "img2"
275 | gt2 = "img1" if qid2 % 2 == 1 else "img2"
276 |
277 | csv_writer.writerow([qid1, qid2, pred1, pred2, gt1, gt2, img1_score1, img1_score2])
278 |
279 | current_category = categories[num_pairs // 15]
280 | if pred1 == gt1 and pred2 == gt2:
281 | pair_accuracies[current_category] += 1
282 | num_pairs += 1
283 |
284 | csv_outfile.close()
285 |
286 | # Calculate percentage accuracies
287 | Category_Score_List = []
288 |
289 | for category in pair_accuracies:
290 | pair_accuracies[category] = (pair_accuracies[category] / (num_pairs // len(categories))) * 100
291 | Category_Score_List.append(pair_accuracies[category])
292 |
293 | pair_accuracies['average_score'] = sum(Category_Score_List)/len(Category_Score_List)
294 |
295 | return pair_accuracies
296 |
297 | def official_evaluation(clip_model, config):
298 |
299 | with torch.no_grad():
300 | clip_model.eval()
301 |
302 |         # MMVP-VLM benchmark data
303 |         data = "dataset/MMVP_VLM"
304 |         clip_model_device = next(clip_model.parameters()).device
305 |         if config.clip_image_size == 224:
306 |             results_openai = {'ViT-SO400M-14-SigLIP': benchmark_model(clip_model, data, clip_model_device, config)}
307 |         elif config.clip_image_size == 384:
308 |             results_openai = {'ViT-SO400M-14-SigLIP-384': benchmark_model(clip_model, data, clip_model_device, config)}
309 |
310 | # Merge results
311 | results = {**results_openai}
312 |
313 |     # Convert results to a format suitable for a star plot (this dict is built but not currently used)
314 | categories = results[list(results.keys())[0]].keys()
315 | data = {'Categories': list(categories)}
316 | for model in list(results_openai.keys()):
317 | data[model] = [results[model][category] for category in categories]
318 |
319 | return results
320 |
321 | def save_results(results_before, results_final_after, output_dir, filename='pred_result.json'):
322 |
323 | os.makedirs(output_dir, exist_ok=True)
324 |
325 | output_data = {
326 | 'results_before': results_before,
327 | 'results_final_after': results_final_after
328 | }
329 |
330 | output_path = os.path.join(output_dir, filename)
331 |
332 | with open(output_path, 'w') as f:
333 | json.dump(output_data, f, indent=4)
334 |
335 |
336 | if __name__ == "__main__":
337 | main()
338 |
--------------------------------------------------------------------------------
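
Note on the SigLIP evaluation above: rather than calling the model's forward pass, the script computes the image-text logits manually, applying `logit_scale` and `logit_bias` to the normalized feature similarities before the softmax over the two images. A minimal sketch of that scoring step with stock open_clip (the hub reference below is a placeholder assumption; the script loads local weights, and its patched backbone additionally needs the `[:,0,:]` token selection shown above):

    import torch
    import torch.nn.functional as F
    from open_clip import create_model_from_pretrained, get_tokenizer

    # Stock SigLIP weights from the open_clip hub (assumption; the script uses local files).
    model, preprocess = create_model_from_pretrained('hf-hub:timm/ViT-SO400M-14-SigLIP')
    tokenizer = get_tokenizer('hf-hub:timm/ViT-SO400M-14-SigLIP')

    def pair_logits(model, imgs, text):
        # imgs: (2, 3, H, W) preprocessed image pair; text: (1, L) token ids from the tokenizer
        with torch.no_grad():
            image_features = F.normalize(model.encode_image(imgs), dim=-1)  # (2, D); stock encode_image is already pooled
            text_features = F.normalize(model.encode_text(text), dim=-1)    # (1, D)
            logits = image_features @ text_features.T * model.logit_scale.exp() + model.logit_bias
        return logits.T  # (1, 2): the single text scored against both images
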