├── data
│   └── put data here.txt
├── checkpoints
│   └── put model checkpoints here.txt
├── assets
│   └── poster.pdf
├── requirements.txt
├── scripts
│   ├── evaluate.sh
│   ├── inference_generate_data.sh
│   ├── inference.sh
│   ├── train_cvae.sh
│   ├── train_diffharmony.sh
│   ├── train_refinement_stage.sh
│   ├── misc
│   │   └── classify_cand_gen_data.py
│   ├── inference
│   │   ├── inverse.py
│   │   └── main.py
│   ├── evaluate
│   │   └── main.py
│   └── train
│       └── cvae_with_gen_data.py
├── configs
│   ├── acc_configs
│   │   ├── multi_deepspeed.yaml
│   │   ├── multi_default.yaml
│   │   └── single_default.yaml
│   └── stage2_configs
│       ├── small.json
│       ├── base.json
│       └── large.json
├── src
│   ├── dataset
│   │   ├── harmony_gen.py
│   │   └── ihd_dataset.py
│   ├── utils.py
│   └── models
│       ├── unet_2d_blocks.py
│       ├── unet_2d.py
│       ├── vae.py
│       └── condition_vae.py
├── README_zh.md
├── README.md
└── .gitignore
/data/put data here.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/checkpoints/put model checkpoints here.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/poster.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nicecv/DiffHarmony/HEAD/assets/poster.pdf
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.30.1
2 | bitsandbytes==0.43.1
3 | deepspeed==0.14.4
4 | diffusers==0.26.3
5 | einops==0.8.0
6 | huggingface-hub==0.23.4
7 | opencv-python-headless==4.10.0.84
8 | pandas==2.2.2
9 | safetensors==0.4.3
10 | sentencepiece==0.2.0
11 | scikit-image==0.24.0
12 | tensorboard==2.16.2
13 | tqdm==4.66.4
14 | transformers==4.41.2
15 | wandb==0.17.3
--------------------------------------------------------------------------------
/scripts/evaluate.sh:
--------------------------------------------------------------------------------
1 | export PYTHONPATH=.:$PYTHONPATH
2 |
3 | OUTPUT_DIR=""
4 | DATA_DIR=data/iHarmony4
5 | TEST_FILE=test.jsonl
6 |
7 | python scripts/evaluate/main.py \
8 | --input_dir $OUTPUT_DIR \
9 | --output_dir $OUTPUT_DIR-evaluation \
10 | --data_dir $DATA_DIR \
11 | --json_file_path $TEST_FILE \
12 | --resolution=256 \
13 | --use_gt_bg \
14 | --num_processes=32
--------------------------------------------------------------------------------
/configs/acc_configs/multi_deepspeed.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | deepspeed_config:
3 | gradient_accumulation_steps: 1
4 | gradient_clipping: 1.0
5 | offload_optimizer_device: none
6 | offload_param_device: none
7 | zero3_init_flag: false
8 | zero_stage: 2
9 | distributed_type: DEEPSPEED
10 | fsdp_config: {}
11 | machine_rank: 0
12 | main_process_ip: null
13 | main_process_port: null
14 | main_training_function: main
15 | mixed_precision: fp16
16 | num_machines: 1
17 | num_processes: 2
18 | use_cpu: false
--------------------------------------------------------------------------------
/configs/acc_configs/multi_default.yaml:
--------------------------------------------------------------------------------
1 | command_file: null
2 | commands: null
3 | compute_environment: LOCAL_MACHINE
4 | deepspeed_config: {}
5 | distributed_type: MULTI_GPU
6 | downcast_bf16: 'no'
7 | dynamo_backend: 'NO'
8 | fsdp_config: {}
9 | gpu_ids:
10 | machine_rank: 0
11 | main_process_ip: null
12 | main_process_port: null
13 | main_training_function: main
14 | megatron_lm_config: {}
15 | mixed_precision: 'no'
16 | num_machines: 1
17 | num_processes: 2
18 | rdzv_backend: static
19 | same_network: true
20 | tpu_name: null
21 | tpu_zone: null
22 | use_cpu: false
23 |
--------------------------------------------------------------------------------
/configs/acc_configs/single_default.yaml:
--------------------------------------------------------------------------------
1 | command_file: null
2 | commands: null
3 | compute_environment: LOCAL_MACHINE
4 | deepspeed_config: {}
5 | distributed_type: 'NO'
6 | downcast_bf16: 'no'
7 | dynamo_backend: 'NO'
8 | fsdp_config: {}
9 | gpu_ids:
10 | machine_rank: 0
11 | main_process_ip: null
12 | main_process_port: null
13 | main_training_function: main
14 | megatron_lm_config: {}
15 | mixed_precision: 'no'
16 | num_machines: 1
17 | num_processes: 1
18 | rdzv_backend: static
19 | same_network: true
20 | tpu_name: null
21 | tpu_zone: null
22 | use_cpu: false
23 |
--------------------------------------------------------------------------------
/configs/stage2_configs/small.json:
--------------------------------------------------------------------------------
1 | {
2 | "_class_name": "UNet2DCustom",
3 | "sample_size": 256,
4 | "in_channels": 3,
5 | "out_channels": 3,
6 |
7 | "down_block_types": [
8 | "DownBlock2D",
9 | "DownBlock2D",
10 | "DownBlock2D",
11 | "DownBlock2D"
12 | ],
13 | "up_block_types": [
14 | "UpBlock2D",
15 | "UpBlock2D",
16 | "UpBlock2D",
17 | "UpBlock2D"
18 | ],
19 | "block_out_channels": [64, 128, 256, 256],
20 |
21 | "layers_per_block": 2,
22 | "mid_block_scale_factor": 1,
23 | "downsample_padding": 0,
24 | "downsample_type": "conv",
25 | "upsample_type": "conv",
26 | "act_fn": "silu",
27 | "attention_head_dim": null,
28 | "norm_num_groups": 32,
29 | "norm_eps": 1e-6,
30 | "add_attention": true,
31 | "input_residual": true
32 | }
--------------------------------------------------------------------------------
/configs/stage2_configs/base.json:
--------------------------------------------------------------------------------
1 | {
2 | "_class_name": "UNet2DCustom",
3 | "sample_size": 256,
4 | "in_channels": 3,
5 | "out_channels": 3,
6 |
7 | "down_block_types": [
8 | "DownBlock2D",
9 | "DownBlock2D",
10 | "DownBlock2D",
11 | "DownBlock2D"
12 | ],
13 | "up_block_types": [
14 | "UpBlock2D",
15 | "UpBlock2D",
16 | "UpBlock2D",
17 | "UpBlock2D"
18 | ],
19 | "block_out_channels": [128, 256, 512, 512],
20 |
21 | "layers_per_block": 2,
22 | "mid_block_scale_factor": 1,
23 | "downsample_padding": 0,
24 | "downsample_type": "conv",
25 | "upsample_type": "conv",
26 | "act_fn": "silu",
27 | "attention_head_dim": null,
28 | "norm_num_groups": 32,
29 | "norm_eps": 1e-6,
30 | "add_attention": true,
31 | "input_residual": true
32 | }
33 |
--------------------------------------------------------------------------------
/configs/stage2_configs/large.json:
--------------------------------------------------------------------------------
1 | {
2 | "_class_name": "UNet2DCustom",
3 | "sample_size": 256,
4 | "in_channels": 3,
5 | "out_channels": 3,
6 |
7 | "down_block_types": [
8 | "DownBlock2D",
9 | "DownBlock2D",
10 | "DownBlock2D",
11 | "DownBlock2D"
12 | ],
13 | "up_block_types": [
14 | "UpBlock2D",
15 | "UpBlock2D",
16 | "UpBlock2D",
17 | "UpBlock2D"
18 | ],
19 | "block_out_channels": [256, 512, 1024, 1024],
20 |
21 | "layers_per_block": 2,
22 | "mid_block_scale_factor": 1,
23 | "downsample_padding": 0,
24 | "downsample_type": "conv",
25 | "upsample_type": "conv",
26 | "act_fn": "silu",
27 | "attention_head_dim": null,
28 | "norm_num_groups": 32,
29 | "norm_eps": 1e-6,
30 | "add_attention": true,
31 | "input_residual": true
32 | }
33 |
--------------------------------------------------------------------------------
/scripts/inference_generate_data.sh:
--------------------------------------------------------------------------------
1 | export PYTHONPATH=.:$PYTHONPATH
2 | ACC_CONFIG_FILE="configs/acc_configs/multi_default.yaml"
3 | export CUDA_VISIBLE_DEVICES="0,1"
4 | NUM_PROCESSES=2
5 | MASTER_PORT=29500
6 |
7 | DATASET_ROOT=
8 | OUTPUT_DIR=$DATASET_ROOT/cand_composite_images
9 | TEST_FILE="all_mask_metadata.jsonl"
10 |
11 |
12 | accelerate launch --config_file $ACC_CONFIG_FILE --num_processes $NUM_PROCESSES --main_process_port $MASTER_PORT \
13 | scripts/inference/inverse.py \
14 | --pretrained_model_name_or_path checkpoints/stable-diffusion-inpainting \
15 | --pretrained_vae_model_name_or_path \
16 | --pretrained_unet_model_name_or_path \
17 | --dataset_root $DATASET_ROOT \
18 | --test_file $TEST_FILE \
19 | --output_dir $OUTPUT_DIR \
20 | --seed=0 \
21 | --resolution=1024 \
22 | --output_resolution=1024 \
23 | --eval_batch_size= \
24 | --dataloader_num_workers= \
25 | --mixed_precision="fp16" \
26 | --rounds=10
--------------------------------------------------------------------------------
/scripts/inference.sh:
--------------------------------------------------------------------------------
1 | export PYTHONPATH=.:$PYTHONPATH
2 | ACC_CONFIG_FILE="configs/acc_configs/multi_default.yaml"
3 | export CUDA_VISIBLE_DEVICES="0,1"
4 | NUM_PROCESSES=2
5 | MASTER_PORT=29500
6 |
7 | OUTPUT_DIR=""
8 | mkdir -p $OUTPUT_DIR
9 | cat "$0" >> $OUTPUT_DIR/run_script.sh
10 |
11 | DATA_DIR=data/iHarmony4
12 | TEST_FILE=test.jsonl
13 |
14 | accelerate launch --config_file $ACC_CONFIG_FILE --num_processes $NUM_PROCESSES --main_process_port $MASTER_PORT \
15 | scripts/inference/main.py \
16 | --pretrained_model_name_or_path checkpoints/stable-diffusion-inpainting \
17 | --pretrained_vae_model_name_or_path checkpoints/sd-vae-ft-mse \
18 | --pretrained_unet_model_name_or_path "" \
19 | --dataset_root $DATA_DIR \
20 | --test_file $TEST_FILE \
21 | --output_dir $OUTPUT_DIR \
22 | --seed=0 \
23 | --resolution=512 \
24 | --output_resolution=256 \
25 | --eval_batch_size= \
26 | --dataloader_num_workers= \
27 | --mixed_precision="fp16"
28 |
29 | # --stage2_model_name_or_path ""
--------------------------------------------------------------------------------
/scripts/train_cvae.sh:
--------------------------------------------------------------------------------
1 | export PYTHONPATH=.:$PYTHONPATH
2 | ACC_CONFIG_FILE="configs/acc_configs/multi_default.yaml"
3 | export CUDA_VISIBLE_DEVICES="0,1"
4 | NUM_PROCESSES=2
5 | MASTER_PORT=29500
6 |
7 | OUTPUT_DIR=""
8 | mkdir -p $OUTPUT_DIR
9 | cat "$0" >> $OUTPUT_DIR/run_script.sh
10 |
11 | accelerate launch --config_file $ACC_CONFIG_FILE --num_processes $NUM_PROCESSES --main_process_port $MASTER_PORT \
12 | scripts/train/cvae.py \
13 | --pretrained_vae_model_name_or_path checkpoints/sd-vae-ft-mse \
14 | --output_dir $OUTPUT_DIR \
15 | --seed= \
16 | --train_batch_size= \
17 | --eval_batch_size= \
18 | --dataloader_num_workers= \
19 | --num_train_epochs= \
20 | --gradient_accumulation_steps= \
21 | --learning_rate= \
22 | --lr_scheduler "" \
23 | --lr_warmup_ratio= \
24 | --use_ema \
25 | --adam_weight_decay= \
26 | --ema_decay= \
27 | --mixed_precision="fp16" \
28 | --checkpointing_epochs= \
29 | --checkpoints_total_limit= \
30 | --image_logging_epochs= \
31 | --dataset_root "data/iHarmony4" \
32 | --train_file "train.jsonl" \
33 | --test_file "test.jsonl" \
34 | --resolution=256 \
35 | --additional_in_channels=1 \
36 | --gradient_checkpointing
--------------------------------------------------------------------------------
/scripts/train_diffharmony.sh:
--------------------------------------------------------------------------------
1 | export PYTHONPATH=.:$PYTHONPATH
2 |
3 | ACC_CONFIG_FILE="configs/acc_configs/multi_default.yaml"
4 | export CUDA_VISIBLE_DEVICES="0,1"
5 | NUM_PROCESSES=2
6 | MASTER_PORT=29500
7 |
8 | OUTPUT_DIR=""
9 | mkdir -p $OUTPUT_DIR
10 | cat "$0" >> $OUTPUT_DIR/run_script.sh
11 |
12 | accelerate launch --config_file $ACC_CONFIG_FILE --num_processes $NUM_PROCESSES --main_process_port $MASTER_PORT \
13 | scripts/train/diffharmony.py \
14 | --pretrained_model_name_or_path "checkpoints/stable-diffusion-inpainting" \
15 | --vae_path "checkpoints/sd-vae-ft-mse" \
16 | --output_dir $OUTPUT_DIR \
17 | --seed= \
18 | --train_batch_size= \
19 | --eval_batch_size= \
20 | --dataloader_num_workers= \
21 | --num_train_epochs= \
22 | --gradient_accumulation_steps= \
23 | --learning_rate= \
24 | --lr_scheduler "" \
25 | --lr_warmup_ratio= \
26 | --use_ema \
27 | --adam_weight_decay= \
28 | --mixed_precision="fp16" \
29 | --checkpointing_epochs= \
30 | --checkpoints_total_limit= \
31 | --dataset_root "data/iHarmony4" \
32 | --train_file "train.jsonl" \
33 | --test_file "test.jsonl" \
34 | --resolution=512 \
35 | --infer_resolution=512 \
36 | --image_log_interval= \
37 | --gradient_checkpointing \
38 | --enable_xformers_memory_efficient_attention
--------------------------------------------------------------------------------
/scripts/train_refinement_stage.sh:
--------------------------------------------------------------------------------
1 | export PYTHONPATH=.:$PYTHONPATH
2 | ACC_CONFIG_FILE="configs/acc_configs/multi_default.yaml"
3 | export CUDA_VISIBLE_DEVICES="0,1"
4 | NUM_PROCESSES=2
5 | MASTER_PORT=29500
6 |
7 | OUTPUT_DIR=""
8 | mkdir -p $OUTPUT_DIR
9 | cat "$0" >> $OUTPUT_DIR/run_script.sh
10 |
11 | accelerate launch --config_file $ACC_CONFIG_FILE --num_processes $NUM_PROCESSES --main_process_port $MASTER_PORT \
12 | scripts/train/refinement_stage.py \
13 | --pipeline_path "checkpoints/stable-diffusion-inpainting" \
14 | --pretrained_vae_path "checkpoints/sd-vae-ft-mse" \
15 | --pretrained_unet_path "checkpoints/base/unet" \
16 | --model_path configs/stage2_configs/base.json \
17 | --output_dir $OUTPUT_DIR \
18 | --seed= \
19 | --dataloader_num_workers= \
20 | --train_batch_size= \
21 | --num_train_epochs= \
22 | --gradient_accumulation_steps= \
23 | --learning_rate= \
24 | --lr_scheduler "" \
25 | --lr_warmup_ratio= \
26 | --use_ema \
27 | --ema_decay= \
28 | --adam_weight_decay= \
29 | --mixed_precision="fp16" \
30 | --checkpointing_epochs= \
31 | --checkpoints_total_limit= \
32 | --infer_resolution=512 \
33 | --resolution=256 \
34 | --in_channels=7 \
35 | --dataset_root "data/iHarmony4" \
36 | --train_file "train.jsonl" \
37 | --gradient_checkpointing \
38 | --enable_xformers_memory_efficient_attention
39 |
40 | # --kl_div_weight=1e-8 \
--------------------------------------------------------------------------------
/src/dataset/harmony_gen.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | from torch.utils.data import Dataset
3 | import torchvision.transforms as transforms
4 | from argparse import Namespace
5 | from PIL import Image
6 | import json
7 | import os
8 |
9 |
10 | class GenHarmonyDataset(Dataset):
11 | def __init__(self, dataset_root, resolution):
12 | self.metadata = [
13 | json.loads(line)
14 | for line in open(os.path.join(dataset_root, "metadata.jsonl")).readlines()
15 | ]
16 | self.dataset_root = dataset_root
17 | self.resolution = resolution
18 | self.transforms = Namespace(
19 | resize=transforms.Resize(
20 | [self.resolution, self.resolution],
21 | interpolation=transforms.InterpolationMode.BILINEAR,
22 | antialias=True,
23 | ),
24 | convert = transforms.Compose([
25 | transforms.ToTensor(),
26 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
27 | ]),
28 | )
29 |
30 | def __len__(self):
31 | return len(self.metadata)
32 |
33 | def __getitem__(self, index):
34 | image_path = os.path.join(self.dataset_root, self.metadata[index]["sample"])
35 | cond_path = self.metadata[index]["cond"]
36 | target_path = self.metadata[index]["target"]
37 |
38 | image = Image.open(image_path).convert("RGB")
39 | if image.size != (self.resolution, self.resolution):
40 | image = self.transforms.resize(image)
41 |
42 | cond_image = Image.open(cond_path).convert("RGB")
43 | if cond_image.size != (self.resolution, self.resolution):
44 | cond_image = self.transforms.resize(cond_image)
45 |
46 | target_image = Image.open(target_path).convert("RGB")
47 | if target_image.size != (self.resolution, self.resolution):
48 | target_image = self.transforms.resize(target_image)
49 |
50 | image= self.transforms.convert(image)
51 | cond_image = self.transforms.convert(cond_image)
52 | target_image = self.transforms.convert(target_image)
53 |
54 | return {
55 | "image": image,
56 | "cond": cond_image,
57 | "target": target_image,
58 | "image_path": image_path,
59 | "cond_path": cond_path,
60 | "target_path": target_path,
61 | }
--------------------------------------------------------------------------------
/README_zh.md:
--------------------------------------------------------------------------------
1 | # DiffHarmony & DiffHarmony++
2 |
3 | The official PyTorch implementation of [DiffHarmony](https://arxiv.org/abs/2404.06139) and [DiffHarmony++](https://dl.acm.org/doi/10.1145/3664647.3681466).
4 | 
5 | The full conference poster of DiffHarmony is available [here](./assets/poster.pdf).
6 |
7 | ## Preparation
8 | 
9 | ### Environment
10 | 
11 | First, prepare a virtual environment. You can use **conda** or any other tool you like.
12 | ```shell
13 | python 3.10
14 | pytorch 2.2.0
15 | cuda 12.1
16 | xformers 0.0.24
17 | ```
18 |
19 | Then, install the required dependencies.
20 | ```shell
21 | pip install -r requirements.txt
22 | ```
23 |
24 | ## Dataset
25 | 
26 | Download the iHarmony4 dataset from [here](https://github.com/bcmi/Image-Harmonization-Dataset-iHarmony4).
27 | 
28 | Make sure the directory structure looks like this:
29 | ```shell
30 | data/iHarmony4
31 | |- HCOCO
32 |     |- composite_images
33 |     |- masks
34 |     |- real_images
35 |     |- ...
36 | |- HAdobe5k
37 | |- HFlickr
38 | |- Hday2night
39 | |- train.jsonl
40 | |- test.jsonl
41 | ```
42 |
43 | The content of `train.jsonl` follows this format:
44 | ```json
45 | {"file_name": "HAdobe5k/composite_images/a0001_1_1.jpg", "text": ""}
46 | {"file_name": "HAdobe5k/composite_images/a0001_1_2.jpg", "text": ""}
47 | {"file_name": "HAdobe5k/composite_images/a0001_1_3.jpg", "text": ""}
48 | {"file_name": "HAdobe5k/composite_images/a0001_1_4.jpg", "text": ""}
49 | ...
50 | ```
51 | All `file_name` entries come from the original `IHD_train.txt`; `test.jsonl` is built the same way from `IHD_test.txt`.
52 |
53 | ## Training
54 | ### Train the diffharmony model
55 | ```shell
56 | sh scripts/train_diffharmony.sh
57 | ```
58 |
59 | ### Train the refinement model
60 | ```shell
61 | sh scripts/train_refinement_stage.sh
62 | ```
63 |
64 | ### Train the condition vae (cvae)
65 | ```shell
66 | sh scripts/train_cvae.sh
67 | ```
68 |
69 | ### Train diffharmony-gen and cvae-gen
70 | Add the following to your training arguments:
71 | ```shell
72 | $script
73 | ...
74 | --mode "inverse"
75 | ```
76 | Basically, it will use the ground-truth images as the condition instead of the composite images.
77 |
78 | ### (Optional) Online training of the condition vae
79 | Refer to `scripts/train/cvae_online.py`
80 |
81 | ### (Optional) Train the cvae with generated data
82 | Refer to `scripts/train/cvae_with_gen_data.py`
83 |
84 | The purpose is to further improve cvae performance on a specific domain, i.e., our generated dataset.
85 |
86 | ## Inference
87 | Run inference on the iHarmony4 dataset:
88 | ```shell
89 | sh scripts/inference.sh
90 | ```
91 |
92 | ### Use diffharmony-gen and cvae-gen to augment HFlickr and Hday2night
93 | ```shell
94 | sh scripts/inference_generate_data.sh
95 | ```
96 | The `all_mask_metadata.jsonl` file follows this format:
97 | ```json
98 | {"file_name": "masks/f800_1.png", "text": ""}
99 | {"file_name": "masks/f801_1.png", "text": ""}
100 | {"file_name": "masks/f803_1.png", "text": ""}
101 | {"file_name": "masks/f804_1.png", "text": ""}
102 | ...
103 | ```
104 |
105 | ### Make the HumanHarmony dataset
106 | First, generate some candidate composite images.
107 | 
108 | Then, use the harmony classifier to select the most inharmonious images.
109 | ```shell
110 | python scripts/misc/classify_cand_gen_data.py
111 | ```
112 |
113 | ## Evaluation
114 | ```shell
115 | sh scripts/evaluate.sh
116 | ```
117 |
118 | ## Pretrained Models
119 | [Baidu](https://pan.baidu.com/s/1IkF6YP4C3fsEAi0_9eCESg), extraction code: aqqd
120 |
121 | [Google Drive](https://drive.google.com/file/d/1rezNdcuZbwejbC9rH9S1SUuaWTGTz_wG/view?usp=drive_link)
122 |
123 | ## Citation
124 | If you find this work useful, please consider citing:
125 | ```bibtex
126 | @inproceedings{zhou2024diffharmony,
127 | title={DiffHarmony: Latent Diffusion Model Meets Image Harmonization},
128 | author={Zhou, Pengfei and Feng, Fangxiang and Wang, Xiaojie},
129 | booktitle={Proceedings of the 2024 International Conference on Multimedia Retrieval},
130 | pages={1130--1134},
131 | year={2024}
132 | }
133 | @inproceedings{zhou2024diffharmonypp,
134 | title={DiffHarmony++: Enhancing Image Harmonization with Harmony-VAE and Inverse Harmonization Model},
135 | author={Zhou, Pengfei and Feng, Fangxiang and Liu, Guang and Li, Ruifan and Wang, Xiaojie},
136 | booktitle={ACM MM},
137 | year={2024}
138 | }
139 | ```
140 |
141 | ## Contact
142 | If you have any questions, please feel free to contact me via `zhoupengfei@bupt.edu.cn`.
143 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [中文](./README_zh.md)
2 |
3 | # DiffHarmony & DiffHarmony++
4 |
5 | The official PyTorch implementation of [DiffHarmony](https://arxiv.org/abs/2404.06139) and [DiffHarmony++](https://dl.acm.org/doi/10.1145/3664647.3681466).
6 |
7 | The full conference poster of DiffHarmony is available [here](./assets/poster.pdf).
8 |
9 | ## Preparation
10 |
11 | ### Environment
12 |
13 | First, prepare a virtual environment. You can use **conda** or any other tool you like.
14 | ```shell
15 | python 3.10
16 | pytorch 2.2.0
17 | cuda 12.1
18 | xformers 0.0.24
19 | ```
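For example, a minimal setup with conda could look like the following (the environment name and the exact install commands are illustrative; any setup matching the versions above works):
```shell
conda create -n diffharmony python=3.10 -y
conda activate diffharmony
# a PyTorch 2.2.0 + CUDA 12.1 build, e.g. from the official wheel index
pip install torch==2.2.0 --index-url https://download.pytorch.org/whl/cu121
pip install xformers==0.0.24
```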
20 |
21 | Then, install the required dependencies.
22 | ```shell
23 | pip install -r requirements.txt
24 | ```
25 | ## Dataset
26 |
27 | Download iHarmony4 dataset from [here](https://github.com/bcmi/Image-Harmonization-Dataset-iHarmony4).
28 |
29 | Make sure the directory structure looks like this:
30 | ```shell
31 | data/iHarmony4
32 | |- HCOCO
33 |     |- composite_images
34 |     |- masks
35 |     |- real_images
36 |     |- ...
37 | |- HAdobe5k
38 | |- HFlickr
39 | |- Hday2night
40 | |- train.jsonl
41 | |- test.jsonl
42 | ```
43 |
44 | The content of `train.jsonl` follows this format:
45 | ```json
46 | {"file_name": "HAdobe5k/composite_images/a0001_1_1.jpg", "text": ""}
47 | {"file_name": "HAdobe5k/composite_images/a0001_1_2.jpg", "text": ""}
48 | {"file_name": "HAdobe5k/composite_images/a0001_1_3.jpg", "text": ""}
49 | {"file_name": "HAdobe5k/composite_images/a0001_1_4.jpg", "text": ""}
50 | ...
51 | ```
52 | All `file_name` entries come from the original `IHD_train.txt`; `test.jsonl` is built the same way from `IHD_test.txt`.
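If you only have the original split files, a minimal conversion sketch could look like the following (it assumes each line of `IHD_train.txt` is a composite-image path relative to the dataset root; the helper below is illustrative and not part of the original scripts):
```python
import json

def split_txt_to_jsonl(txt_path: str, jsonl_path: str) -> None:
    # one composite-image path per input line -> one {"file_name", "text"} object per output line
    with open(txt_path) as fin, open(jsonl_path, "w") as fout:
        for line in fin:
            name = line.strip()
            if name:
                fout.write(json.dumps({"file_name": name, "text": ""}) + "\n")

split_txt_to_jsonl("data/iHarmony4/IHD_train.txt", "data/iHarmony4/train.jsonl")
split_txt_to_jsonl("data/iHarmony4/IHD_test.txt", "data/iHarmony4/test.jsonl")
```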
53 |
54 |
55 | ## Training
56 | ### Train diffharmony model
57 | ```shell
58 | sh scripts/train_diffharmony.sh
59 | ```
60 |
61 | ### Train refinement model
62 | ```shell
63 | sh scripts/train_refinement_stage.sh
64 | ```
65 |
66 | ### Train condition vae (cvae)
67 | ```shell
68 | sh scripts/train_cvae.sh
69 | ```
70 |
71 | ### Train diffharmony-gen and cvae-gen
72 | Just add this to your training args:
73 | ```shell
74 | $script
75 | ...
76 | --mode "inverse"
77 | ```
78 | Basically, it will use the ground-truth images as the condition instead of the composite images.
79 |
80 |
81 | ### (Optional) Online training of the condition vae
82 | Refer to `scripts/train/cvae_online.py`
83 |
84 | ### (Optional) Train the cvae with generated data
85 | Refer to `scripts/train/cvae_with_gen_data.py`
86 |
87 | The purpose is to further improve cvae performance on a specific domain, i.e., our generated dataset.
88 |
89 | ## Inference
90 | Run inference on the iHarmony4 dataset:
91 | ```shell
92 | sh scripts/inference.sh
93 | ```
94 |
95 | ### Use diffharmony-gen and cvae-gen to augment HFlickr and Hday2night
96 | ```shell
97 | sh scripts/inference_generate_data.sh
98 | ```
99 | As its name suggests, the `all_mask_metadata.jsonl` file follows this format:
100 | ```json
101 | {"file_name": "masks/f800_1.png", "text": ""}
102 | {"file_name": "masks/f801_1.png", "text": ""}
103 | {"file_name": "masks/f803_1.png", "text": ""}
104 | {"file_name": "masks/f804_1.png", "text": ""}
105 | ...
106 | ```
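A minimal sketch for building `all_mask_metadata.jsonl` from the mask files could look like this (the dataset path and glob pattern are assumptions based on the format above):
```python
import glob
import json
import os

dataset_root = "data/iHarmony4/HFlickr"  # illustrative; use the dataset you want to augment
out_path = os.path.join(dataset_root, "all_mask_metadata.jsonl")
with open(out_path, "w") as fout:
    for mask_path in sorted(glob.glob(os.path.join(dataset_root, "masks", "*.png"))):
        rel_path = os.path.relpath(mask_path, dataset_root)  # e.g. "masks/f800_1.png"
        fout.write(json.dumps({"file_name": rel_path, "text": ""}) + "\n")
```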
107 |
108 | ### Make HumanHarmony dataset
109 | First, generate some candidate composite images.
110 |
111 | Then, use the harmony classifier to select the most inharmonious images.
112 | ```shell
113 | python scripts/misc/classify_cand_gen_data.py
114 | ```
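The classifier script expects `dataset_root` and the classifier checkpoint path to be set inside `scripts/misc/classify_cand_gen_data.py`, and assumes a directory layout roughly like the following (inferred from that script):
```shell
<dataset_root>
|- cand_composite_images   # candidate *.jpg files, grouped by the name part before the first "_"
|- real_images             # one <name>.jpg real image per group
|- composite_images        # the selected (most unharmonious) candidate of each group is copied here
```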
115 |
116 | ## Evaluation
117 | ```shell
118 | sh scripts/evaluate.sh
119 | ```
120 | ## Pretrained Models
121 | [Baidu](https://pan.baidu.com/s/1IkF6YP4C3fsEAi0_9eCESg), code: aqqd
122 |
123 | [Google Drive](https://drive.google.com/file/d/1rezNdcuZbwejbC9rH9S1SUuaWTGTz_wG/view?usp=drive_link)
124 |
125 | ## Citation
126 | If you find this work useful, please consider citing:
127 | ```bibtex
128 | @inproceedings{zhou2024diffharmony,
129 | title={DiffHarmony: Latent Diffusion Model Meets Image Harmonization},
130 | author={Zhou, Pengfei and Feng, Fangxiang and Wang, Xiaojie},
131 | booktitle={Proceedings of the 2024 International Conference on Multimedia Retrieval},
132 | pages={1130--1134},
133 | year={2024}
134 | }
135 | @inproceedings{zhou2024diffharmonypp,
136 | title={DiffHarmony++: Enhancing Image Harmonization with Harmony-VAE and Inverse Harmonization Model},
137 | author={Zhou, Pengfei and Feng, Fangxiang and Liu, Guang and Li, Ruifan and Wang, Xiaojie},
138 | booktitle={ACM MM},
139 | year={2024}
140 | }
141 | ```
142 | ## Contact
143 | If you have any questions, please feel free to contact me via `zhoupengfei@bupt.edu.cn`.
144 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by https://www.toptal.com/developers/gitignore/api/python
2 | # Edit at https://www.toptal.com/developers/gitignore?templates=python
3 |
4 | output
5 | .vscode
6 |
7 | ### Python ###
8 | # Byte-compiled / optimized / DLL files
9 | __pycache__/
10 | *.py[cod]
11 | *$py.class
12 |
13 | # C extensions
14 | *.so
15 |
16 | # Distribution / packaging
17 | .Python
18 | build/
19 | develop-eggs/
20 | dist/
21 | downloads/
22 | eggs/
23 | .eggs/
24 | lib/
25 | lib64/
26 | parts/
27 | sdist/
28 | var/
29 | wheels/
30 | share/python-wheels/
31 | *.egg-info/
32 | .installed.cfg
33 | *.egg
34 | MANIFEST
35 |
36 | # PyInstaller
37 | # Usually these files are written by a python script from a template
38 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
39 | *.manifest
40 | *.spec
41 |
42 | # Installer logs
43 | pip-log.txt
44 | pip-delete-this-directory.txt
45 |
46 | # Unit test / coverage reports
47 | htmlcov/
48 | .tox/
49 | .nox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | *.py,cover
57 | .hypothesis/
58 | .pytest_cache/
59 | cover/
60 |
61 | # Translations
62 | *.mo
63 | *.pot
64 |
65 | # Django stuff:
66 | *.log
67 | local_settings.py
68 | db.sqlite3
69 | db.sqlite3-journal
70 |
71 | # Flask stuff:
72 | instance/
73 | .webassets-cache
74 |
75 | # Scrapy stuff:
76 | .scrapy
77 |
78 | # Sphinx documentation
79 | docs/_build/
80 |
81 | # PyBuilder
82 | .pybuilder/
83 | target/
84 |
85 | # Jupyter Notebook
86 | .ipynb_checkpoints
87 |
88 | # IPython
89 | profile_default/
90 | ipython_config.py
91 |
92 | # pyenv
93 | # For a library or package, you might want to ignore these files since the code is
94 | # intended to run in multiple environments; otherwise, check them in:
95 | # .python-version
96 |
97 | # pipenv
98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
101 | # install all needed dependencies.
102 | #Pipfile.lock
103 |
104 | # poetry
105 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
106 | # This is especially recommended for binary packages to ensure reproducibility, and is more
107 | # commonly ignored for libraries.
108 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
109 | #poetry.lock
110 |
111 | # pdm
112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113 | #pdm.lock
114 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
115 | # in version control.
116 | # https://pdm.fming.dev/#use-with-ide
117 | .pdm.toml
118 |
119 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
120 | __pypackages__/
121 |
122 | # Celery stuff
123 | celerybeat-schedule
124 | celerybeat.pid
125 |
126 | # SageMath parsed files
127 | *.sage.py
128 |
129 | # Environments
130 | .env
131 | .venv
132 | env/
133 | venv/
134 | ENV/
135 | env.bak/
136 | venv.bak/
137 |
138 | # Spyder project settings
139 | .spyderproject
140 | .spyproject
141 |
142 | # Rope project settings
143 | .ropeproject
144 |
145 | # mkdocs documentation
146 | /site
147 |
148 | # mypy
149 | .mypy_cache/
150 | .dmypy.json
151 | dmypy.json
152 |
153 | # Pyre type checker
154 | .pyre/
155 |
156 | # pytype static type analyzer
157 | .pytype/
158 |
159 | # Cython debug symbols
160 | cython_debug/
161 |
162 | # PyCharm
163 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
164 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
165 | # and can be added to the global gitignore or merged into this file. For a more nuclear
166 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
167 | #.idea/
168 |
169 | ### Python Patch ###
170 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
171 | poetry.toml
172 |
173 | # ruff
174 | .ruff_cache/
175 |
176 | # LSP config files
177 | pyrightconfig.json
178 |
179 | # End of https://www.toptal.com/developers/gitignore/api/python
180 |
181 |
--------------------------------------------------------------------------------
/scripts/misc/classify_cand_gen_data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 | import os
4 | from torchvision import models, transforms
5 | from torchvision.transforms import InterpolationMode
6 | from tqdm import tqdm
7 | from PIL import Image
8 | import glob
9 | from collections import defaultdict
10 | from accelerate import Accelerator
11 | from einops import rearrange
12 | from accelerate.utils import set_seed
13 |
14 |
15 | class CustomDataset(Dataset):
16 | def __init__(self, dataset_root):
17 | cand_composite_paths = sorted(glob.glob(
18 | os.path.join(dataset_root, "cand_composite_images", "*.jpg")
19 | ))
20 | data = defaultdict(list)
21 | for cand_composite_path in tqdm(cand_composite_paths):
22 | image_name = os.path.basename(cand_composite_path).split("_")[0]
23 | data[image_name].append(cand_composite_path)
24 | self.data = [
25 | {
26 | "real_path": os.path.join(dataset_root, "real_images", k + ".jpg"),
27 | "cand_composite_paths": v,
28 | }
29 | for k, v in data.items()
30 | ]
31 |
32 | self.transform = transforms.Compose(
33 | [
34 | transforms.Resize(
35 | (256, 256), interpolation=InterpolationMode.BILINEAR, antialias=True
36 | ),
37 | transforms.ToTensor(),
38 | transforms.Normalize(
39 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
40 | ),
41 | ]
42 | )
43 |
44 | def __len__(self):
45 | return len(self.data)
46 |
47 | def __getitem__(self, idx):
48 | item = self.data[idx]
49 | cand_composite_images = [
50 | self.transform(Image.open(cand_composite_path).convert("RGB"))
51 | for cand_composite_path in item["cand_composite_paths"]
52 | ]
53 | cand_composite_images = torch.stack(
54 | cand_composite_images, dim=0
55 | ) # (t, c, h, w)
56 | return {
57 | "cand_composite_images": cand_composite_images,
58 | "cand_composite_paths": item["cand_composite_paths"],
59 | }
60 |
61 | def collate_fn(batch):
62 | cand_composite_images = [item["cand_composite_images"] for item in batch]
63 | cand_composite_paths = [item["cand_composite_paths"] for item in batch]
64 | return {
65 | "cand_composite_images": torch.stack(cand_composite_images, dim=0), # (b, t, c, h, w)
66 | "cand_composite_paths": cand_composite_paths,
67 | }
68 |
69 |
70 | if __name__ == "__main__":
71 | seed = 0
72 | dataset_root = ""
73 | model_path = "checkpoints/harmony_classifier/model_state_dict.pth"
74 | batch_size = 16
75 |
76 | set_seed(seed)
77 | accelerator = Accelerator()
78 | model = models.resnet50()
79 | num_ftrs = model.fc.in_features
80 | model.fc = torch.nn.Linear(num_ftrs, 2)
81 | model.load_state_dict(torch.load(model_path))
82 | model.eval()
83 | model.to(accelerator.device)
84 |
85 | dataset = CustomDataset(dataset_root)
86 | dataloader = DataLoader(
87 | dataset, batch_size=batch_size, shuffle=False, num_workers=4, collate_fn=collate_fn
88 | )
89 | dataloader = accelerator.prepare(dataloader)
90 |
91 | progress_bar = tqdm(
92 | range(0, len(dataloader)),
93 | initial=0,
94 | desc="Batches",
95 | # Only show the progress bar once on each machine.
96 | disable=not accelerator.is_local_main_process,
97 | )
98 | for step, batch in enumerate(dataloader):
99 | imgs = batch["cand_composite_images"]
100 | paths = batch["cand_composite_paths"]
101 | bs = imgs.size(0)
102 | imgs = (
103 | rearrange(imgs, "b t c h w -> (b t) c h w")
104 | .contiguous()
105 | .to(accelerator.device)
106 | )
107 | with torch.inference_mode():
108 | logits = model(imgs) # (b*t, 2)
109 | probabilities = torch.nn.functional.softmax(logits, dim=-1)
110 | unharmony_probs = rearrange(
111 | probabilities[:, 1], "(b t) -> b t", b=bs
112 | ).contiguous()
113 | _, max_index = torch.max(unharmony_probs, dim=1) # (b,)
114 |             # copy the candidate image selected by max_index into composite_images
115 | for i, (p, idx) in enumerate(zip(paths, max_index)):
116 | os.system(
117 | f"cp {p[idx.item()]} {os.path.join(dataset_root, 'composite_images')}"
118 | )
119 | progress_bar.close()
120 | accelerator.wait_for_everyone()
121 |
--------------------------------------------------------------------------------
/scripts/inference/inverse.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from pathlib import Path
4 |
5 | import torch
6 | import torch.utils.checkpoint
7 | from accelerate import Accelerator
8 | from accelerate.utils import set_seed
9 | from torchvision.transforms.functional import to_pil_image, resize, to_tensor
10 | from tqdm.auto import tqdm
11 |
12 | from diffusers import (
13 | AutoencoderKL,
14 | UNet2DConditionModel,
15 | EulerAncestralDiscreteScheduler,
16 | )
17 | from src.pipelines.pipeline_stable_diffusion_harmony import (
18 | StableDiffusionHarmonyPipeline,
19 | )
20 | from src.dataset.ihd_dataset import IhdDatasetMultiRes as Dataset
21 | from src.models.condition_vae import ConditionVAE
22 |
23 | def parse_args(input_args=None):
24 | parser = argparse.ArgumentParser(
25 |         description="Simple example of an inference script."
26 | )
27 | parser.add_argument(
28 | "--pretrained_model_name_or_path",
29 | type=str,
30 | default=None,
31 | required=True,
32 | help="Path to pretrained model or model identifier from huggingface.co/models.",
33 | )
34 | parser.add_argument(
35 | "--pretrained_vae_model_name_or_path",
36 | type=str,
37 | default=None,
38 | help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
39 | )
40 | parser.add_argument(
41 | "--pretrained_unet_model_name_or_path",
42 | type=str,
43 | default=None,
44 | )
45 | parser.add_argument(
46 | "--stage2_model_name_or_path",
47 | type=str,
48 | default=None,
49 | )
50 | parser.add_argument(
51 | "--dataset_root",
52 | type=str,
53 | default=None,
54 | help=(
55 | "A folder containing the training data. Folder contents must follow the structure described in"
56 | " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
57 | " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
58 | ),
59 | )
60 | parser.add_argument(
61 | "--train_file",
62 | type=str,
63 | default=None,
64 | )
65 | parser.add_argument(
66 | "--test_file",
67 | type=str,
68 | default=None,
69 | )
70 | parser.add_argument(
71 | "--output_dir",
72 | type=str,
73 | required=True,
74 | )
75 | parser.add_argument(
76 | "--seed", type=int, default=42, help="A seed for reproducible training."
77 | )
78 | parser.add_argument(
79 | "--resolution",
80 | type=int,
81 | default=1024,
82 | )
83 | parser.add_argument(
84 | "--output_resolution",
85 | type=int,
86 | default=256,
87 | )
88 | parser.add_argument(
89 | "--random_crop",
90 | default=False,
91 | action="store_true",
92 | )
93 | parser.add_argument(
94 | "--random_flip",
95 | default=False,
96 | action="store_true",
97 | )
98 | parser.add_argument(
99 | "--mask_dilate",
100 | type=int,
101 | default=0,
102 | )
103 | parser.add_argument(
104 | "--eval_batch_size",
105 | type=int,
106 | default=4,
107 | help="The number of images to generate for evaluation.",
108 | )
109 | parser.add_argument(
110 | "--dataloader_num_workers",
111 | type=int,
112 | default=0,
113 | help=(
114 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
115 | ),
116 | )
117 | parser.add_argument(
118 | "--mixed_precision",
119 | type=str,
120 | default=None,
121 | choices=["no", "fp16", "bf16"],
122 | help=(
123 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
124 |             " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
125 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
126 | ),
127 | )
128 | parser.add_argument(
129 | "--rounds",
130 | type=int,
131 | default=1,
132 | )
133 | if input_args is not None:
134 | args = parser.parse_args(input_args)
135 | else:
136 | args = parser.parse_args()
137 | return args
138 |
139 | def list_in_str(input_list, target_str):
140 | for item in input_list:
141 | if item in target_str:
142 | return True
143 | return False
144 |
145 | def replace_background(fake_image:torch.Tensor, real_image:torch.Tensor, mask:torch.Tensor):
146 | real_image = real_image.to(fake_image)
147 | mask = mask.to(fake_image)
148 |
149 | fake_image = fake_image * mask + real_image * (1 - mask)
150 | fake_image = to_pil_image(fake_image)
151 | return fake_image
152 |
153 |
154 | def main(args):
155 | accelerator = Accelerator(
156 | mixed_precision=args.mixed_precision,
157 | )
158 | if args.seed is not None:
159 | set_seed(args.seed)
160 | if accelerator.is_main_process:
161 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
162 | weight_dtype = torch.float32
163 | if accelerator.mixed_precision == "fp16":
164 | weight_dtype = torch.float16
165 | elif accelerator.mixed_precision == "bf16":
166 | weight_dtype = torch.bfloat16
167 |
168 | if list_in_str(["condition_vae", "cvae"], args.pretrained_vae_model_name_or_path):
169 | vae_cls = ConditionVAE
170 | else:
171 | vae_cls = AutoencoderKL
172 |
173 | vae = vae_cls.from_pretrained(
174 | args.pretrained_vae_model_name_or_path,
175 | torch_dtype=weight_dtype,
176 | )
177 | unet = UNet2DConditionModel.from_pretrained(
178 | args.pretrained_unet_model_name_or_path,
179 | torch_dtype=weight_dtype,
180 | )
181 | pipeline = StableDiffusionHarmonyPipeline.from_pretrained(
182 | args.pretrained_model_name_or_path,
183 | vae=vae,
184 | unet=unet,
185 | torch_dtype=weight_dtype,
186 | )
187 | pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
188 | pipeline.scheduler.config
189 | )
190 | pipeline.to(accelerator.device)
191 | pipeline.enable_xformers_memory_efficient_attention()
192 | pipeline.set_progress_bar_config(disable=True)
193 |
194 | dataset = Dataset(
195 | split="test",
196 | tokenizer=None,
197 | resolutions=[args.resolution, args.output_resolution],
198 | opt=args,
199 | )
200 | dataloader = torch.utils.data.DataLoader(
201 | dataset,
202 | batch_size=args.eval_batch_size,
203 | shuffle=True,
204 | num_workers=args.dataloader_num_workers,
205 | )
206 |
207 | dataloader = accelerator.prepare(dataloader)
208 |
209 | for num_round in range(1, args.rounds+1):
210 | progress_bar = tqdm(
211 | range(0, len(dataloader)),
212 | initial=0,
213 | desc="Batches",
214 | # Only show the progress bar once on each machine.
215 | disable=not accelerator.is_local_main_process,
216 | )
217 | for step, batch in enumerate(dataloader):
218 | eval_mask_images = batch[args.resolution]["mask"]
219 | eval_gt_images = batch[args.resolution]["real"]
220 | generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
221 | with torch.inference_mode():
222 | samples = pipeline(
223 | prompt=batch[args.resolution]["caption"],
224 | image=eval_gt_images,
225 | mask_image=eval_mask_images,
226 | height=args.resolution,
227 | width=args.resolution,
228 | num_inference_steps=5,
229 | guidance_scale=1.0,
230 | generator=generator,
231 | output_type="pt",
232 | ).images # [0,1] torch tensor
233 |
234 | samples = samples.clamp(0, 1)
235 |
236 | for i, sample in enumerate(samples):
237 | sample = to_pil_image(sample)
238 | output_shape = (args.output_resolution, args.output_resolution)
239 | if sample.size != output_shape:
240 | sample = resize(sample, output_shape, antialias=True)
241 |
242 | sample = replace_background(to_tensor(sample), batch[args.output_resolution]["real"][i].add(1).div(2), batch[args.output_resolution]["mask"][i])
243 |
244 | save_name = (
245 | batch[args.output_resolution]["mask_path"][i]
246 | .split("/")[-1]
247 | .split(".")[0]
248 | + f"_{num_round}.jpg"
249 | )
250 | sample.save(
251 | os.path.join(args.output_dir, save_name), quality=100
252 | )
253 | progress_bar.update(1)
254 | progress_bar.close()
255 | accelerator.wait_for_everyone()
256 |
257 |
258 | if __name__ == "__main__":
259 | args = parse_args()
260 | main(args)
261 |
--------------------------------------------------------------------------------
/scripts/inference/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from pathlib import Path
4 |
5 | import torch
6 | import torch.utils.checkpoint
7 | from accelerate import Accelerator
8 | from accelerate.utils import set_seed
9 | from torchvision.transforms.functional import to_pil_image, resize
10 | from tqdm.auto import tqdm
11 |
12 | from diffusers import (
13 | AutoencoderKL,
14 | UNet2DConditionModel,
15 | EulerAncestralDiscreteScheduler,
16 | )
17 | from src.pipelines.pipeline_stable_diffusion_harmony import (
18 | StableDiffusionHarmonyPipeline,
19 | )
20 | from src.dataset.ihd_dataset import IhdDatasetMultiRes as Dataset
21 | from src.models.condition_vae import ConditionVAE
22 | from src.models.unet_2d import UNet2DCustom
23 | from src.utils import make_stage2_input
24 |
25 |
26 | def parse_args(input_args=None):
27 | parser = argparse.ArgumentParser(
28 |         description="Simple example of an inference script."
29 | )
30 | parser.add_argument(
31 | "--pretrained_model_name_or_path",
32 | type=str,
33 | default=None,
34 | required=True,
35 | help="Path to pretrained model or model identifier from huggingface.co/models.",
36 | )
37 | parser.add_argument(
38 | "--pretrained_vae_model_name_or_path",
39 | type=str,
40 | default=None,
41 | help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
42 | )
43 | parser.add_argument(
44 | "--pretrained_unet_model_name_or_path",
45 | type=str,
46 | default=None,
47 | )
48 | parser.add_argument(
49 | "--stage2_model_name_or_path",
50 | type=str,
51 | default=None,
52 | )
53 | parser.add_argument(
54 | "--dataset_root",
55 | type=str,
56 | default=None,
57 | help=(
58 | "A folder containing the training data. Folder contents must follow the structure described in"
59 | " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
60 | " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
61 | ),
62 | )
63 | parser.add_argument(
64 | "--train_file",
65 | type=str,
66 | default=None,
67 | )
68 | parser.add_argument(
69 | "--test_file",
70 | type=str,
71 | default=None,
72 | )
73 | parser.add_argument(
74 | "--output_dir",
75 | type=str,
76 | required=True,
77 | )
78 | parser.add_argument(
79 | "--seed", type=int, default=42, help="A seed for reproducible training."
80 | )
81 | parser.add_argument(
82 | "--resolution",
83 | type=int,
84 | default=1024,
85 | )
86 | parser.add_argument(
87 | "--output_resolution",
88 | type=int,
89 | default=256,
90 | )
91 | parser.add_argument(
92 | "--random_crop",
93 | default=False,
94 | action="store_true",
95 | )
96 | parser.add_argument(
97 | "--random_flip",
98 | default=False,
99 | action="store_true",
100 | )
101 | parser.add_argument(
102 | "--mask_dilate",
103 | type=int,
104 | default=0,
105 | )
106 | parser.add_argument(
107 | "--eval_batch_size",
108 | type=int,
109 | default=4,
110 | help="The number of images to generate for evaluation.",
111 | )
112 | parser.add_argument(
113 | "--dataloader_num_workers",
114 | type=int,
115 | default=0,
116 | help=(
117 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
118 | ),
119 | )
120 | parser.add_argument(
121 | "--mixed_precision",
122 | type=str,
123 | default=None,
124 | choices=["no", "fp16", "bf16"],
125 | help=(
126 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
127 |             " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
128 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
129 | ),
130 | )
131 | parser.add_argument(
132 | "--strict_mode",
133 | default=False,
134 | action="store_true",
135 | )
136 | if input_args is not None:
137 | args = parser.parse_args(input_args)
138 | else:
139 | args = parser.parse_args()
140 | return args
141 |
142 | def list_in_str(input_list, target_str):
143 | for item in input_list:
144 | if item in target_str:
145 | return True
146 | return False
147 |
148 | def main(args):
149 | accelerator = Accelerator(
150 | mixed_precision=args.mixed_precision,
151 | )
152 | if args.seed is not None:
153 | set_seed(args.seed)
154 | if accelerator.is_main_process:
155 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
156 | weight_dtype = torch.float32
157 | if accelerator.mixed_precision == "fp16":
158 | weight_dtype = torch.float16
159 | elif accelerator.mixed_precision == "bf16":
160 | weight_dtype = torch.bfloat16
161 |
162 | if list_in_str(["condition_vae", "cvae"], args.pretrained_vae_model_name_or_path):
163 | vae_cls = ConditionVAE
164 | else:
165 | vae_cls = AutoencoderKL
166 |
167 | vae = vae_cls.from_pretrained(
168 | args.pretrained_vae_model_name_or_path,
169 | torch_dtype=weight_dtype,
170 | )
171 | unet = UNet2DConditionModel.from_pretrained(
172 | args.pretrained_unet_model_name_or_path,
173 | torch_dtype=weight_dtype,
174 | )
175 | pipeline = StableDiffusionHarmonyPipeline.from_pretrained(
176 | args.pretrained_model_name_or_path,
177 | vae=vae,
178 | unet=unet,
179 | torch_dtype=weight_dtype,
180 | )
181 | pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
182 | pipeline.scheduler.config
183 | )
184 | pipeline.to(accelerator.device)
185 | # pipeline.enable_model_cpu_offload(device=accelerator.device)
186 | pipeline.enable_xformers_memory_efficient_attention()
187 | pipeline.set_progress_bar_config(disable=True)
188 |
189 | use_stage2 = args.stage2_model_name_or_path is not None
190 | if use_stage2:
191 | stage2_model = UNet2DCustom.from_pretrained(
192 | args.stage2_model_name_or_path,
193 | torch_dtype=weight_dtype,
194 | )
195 | stage2_model.to(accelerator.device)
196 | stage2_model.eval()
197 | stage2_model.requires_grad_(False)
198 | in_channels = stage2_model.config.in_channels
199 | stage2_model.enable_xformers_memory_efficient_attention()
200 |
201 | dataset = Dataset(
202 | split="test",
203 | tokenizer=None,
204 | resolutions=[args.resolution, args.output_resolution],
205 | opt=args,
206 | )
207 | dataloader = torch.utils.data.DataLoader(
208 | dataset,
209 | batch_size=args.eval_batch_size,
210 | shuffle=True,
211 | num_workers=args.dataloader_num_workers,
212 | )
213 |
214 | dataloader = accelerator.prepare(dataloader)
215 | progress_bar = tqdm(
216 | range(0, len(dataloader)),
217 | initial=0,
218 | desc="Batches",
219 | # Only show the progress bar once on each machine.
220 | disable=not accelerator.is_local_main_process,
221 | )
222 | for step, batch in enumerate(dataloader):
223 | if args.strict_mode:
224 | eval_mask_images = batch[args.output_resolution]["mask"]
225 | eval_composite_images = batch[args.output_resolution]["comp"]
226 | if args.output_resolution != args.resolution:
227 | tgt_size = [args.resolution, args.resolution]
228 | eval_mask_images = resize(
229 | eval_mask_images,
230 | size=tgt_size,
231 | antialias=True,
232 | ).clamp(0, 1)
233 | eval_composite_images = resize(
234 | eval_composite_images,
235 | size=tgt_size,
236 | antialias=True,
237 | ).clamp(-1, 1)
238 | else:
239 | eval_mask_images = batch[args.resolution]["mask"]
240 | eval_composite_images = batch[args.resolution]["comp"]
241 | generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
242 | with torch.inference_mode():
243 | samples = pipeline(
244 | prompt=batch[args.resolution]["caption"],
245 | image=eval_composite_images,
246 | mask_image=eval_mask_images,
247 | height=args.resolution,
248 | width=args.resolution,
249 | num_inference_steps=5,
250 | guidance_scale=1.0,
251 | generator=generator,
252 | output_type="pt",
253 | ).images # [0,1] torch tensor
254 |
255 | if use_stage2:
256 | samples = 2 * samples - 1 # [-1,1] torch tensor
257 | if tuple(samples.shape[-2:]) != (
258 | args.output_resolution,
259 | args.output_resolution,
260 | ):
261 | samples = resize(
262 | samples,
263 | size=[args.output_resolution, args.output_resolution],
264 | antialias=True,
265 | ).clamp(-1, 1)
266 | stage2_input = make_stage2_input(
267 | samples,
268 | batch[args.output_resolution]["comp"],
269 | batch[args.output_resolution]["mask"],
270 | in_channels,
271 | )
272 | samples = (
273 | stage2_model(
274 | stage2_input.to(device=accelerator.device, dtype=stage2_model.dtype)
275 | )
276 | .sample.cpu()
277 | .clamp(-1, 1)
278 | )
279 | samples = (samples + 1) / 2
280 |
281 | samples = samples.clamp(0, 1)
282 |
283 | for i, sample in enumerate(samples):
284 | sample = to_pil_image(sample)
285 | output_shape = (args.output_resolution, args.output_resolution)
286 | if sample.size != output_shape:
287 | sample = resize(sample, output_shape, antialias=True)
288 | save_name = (
289 | batch[args.output_resolution]["comp_path"][i]
290 | .split("/")[-1]
291 | .split(".")[0]
292 | + ".png"
293 | )
294 | sample.save(
295 | os.path.join(args.output_dir, save_name), compression=None, quality=100
296 | )
297 | progress_bar.update(1)
298 | progress_bar.close()
299 | accelerator.wait_for_everyone()
300 |
301 |
302 | if __name__ == "__main__":
303 | args = parse_args()
304 | main(args)
305 |
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from PIL import Image
4 | def tensor_to_pil(tensor: torch.Tensor, mode="RGB"):
5 | if mode == "RGB":
6 | image_np = (
7 | tensor.permute(1, 2, 0).add(1).multiply(127.5).numpy().astype(np.uint8)
8 | )
9 | image = Image.fromarray(image_np, mode=mode)
10 | elif mode == "1":
11 | image_np = tensor.squeeze().multiply(255).numpy().astype(np.uint8)
12 | image = Image.fromarray(image_np).convert("1")
13 | else:
14 | raise ValueError(f"not supported mode {mode}")
15 | return image
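# Usage sketch (illustrative): expects an RGB tensor in [-1, 1] of shape (3, H, W) for mode="RGB",
# or a single-channel mask in [0, 1] for mode="1"; e.g. tensor_to_pil(torch.zeros(3, 256, 256))
# returns a mid-gray 256x256 PIL image.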
16 |
17 | import os
18 | def get_paths(comp_path):
19 | # .../ds_name/composite_images/xxx_x_x.jpg
20 | parts = comp_path.split("/")
21 | img_name_parts = parts[-1].split("_")
22 | real_path = os.path.join(*parts[:-2], "real_images", f"{img_name_parts[0]}.jpg")
23 | mask_path = os.path.join(
24 | *parts[:-2], "masks", f"{img_name_parts[0]}_{img_name_parts[1]}.png"
25 | )
26 |
27 | return {
28 | "real": real_path,
29 | "mask": mask_path,
30 | "comp": comp_path,
31 | }
32 |
33 | comp_to_paths = get_paths
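# Usage sketch (the path is hypothetical, shown only to illustrate the naming convention above):
# get_paths("data/iHarmony4/HCOCO/composite_images/c00001_1_2.jpg") returns
#   {"real": "data/iHarmony4/HCOCO/real_images/c00001.jpg",
#    "mask": "data/iHarmony4/HCOCO/masks/c00001_1.png",
#    "comp": "data/iHarmony4/HCOCO/composite_images/c00001_1_2.jpg"}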
34 |
35 |
36 | def harm_to_names(harm_name):
37 | # xxx_x_x_harmonized.jpg
38 | img_name_parts = harm_name.split("_")
39 | real_name = f"{img_name_parts[0]}.jpg"
40 | mask_name = f"{img_name_parts[0]}_{img_name_parts[1]}.png"
41 | comp_name = f"{img_name_parts[0]}_{img_name_parts[1]}_{img_name_parts[2]}.jpg"
42 |
43 | return {
44 | "real": real_name,
45 | "mask": mask_name,
46 | "comp": comp_name,
47 | "harm": harm_name,
48 | }
49 |
50 |
51 | def comp_to_harm_path(output_dir, comp_path, suffix="jpg"):
52 | img_name = comp_path.split("/")[-1]
53 | prefix, _ = img_name.split(".")
54 | out_path = os.path.join(output_dir, f"{prefix}_harmonized.{suffix}")
55 | return out_path
56 |
57 |
58 | import torch
59 | import torch.nn.functional as F
60 |
61 |
62 | def get_mask_shift_with_corresponding_center(
63 | masks: torch.Tensor, cand_masks: torch.Tensor
64 | ):
65 | # masks : (b,c,h,w)
66 | h, w = masks.shape[-2:]
67 |
68 | # mask boundaries
69 | true_rows = torch.any(masks, dim=3)
70 | true_cols = torch.any(masks, dim=2)
71 | true_cand_rows = torch.any(cand_masks, dim=3)
72 | true_cand_cols = torch.any(cand_masks, dim=2)
73 | bs = len(masks)
74 | dy, dx = [], []
75 |
76 | for i in range(bs):
77 | try:
78 | row_min, row_max = torch.nonzero(true_rows[i], as_tuple=True)[-1][
79 | [0, -1]
80 | ].tolist()
81 | col_min, col_max = torch.nonzero(true_cols[i], as_tuple=True)[-1][
82 | [0, -1]
83 | ].tolist()
84 | cand_row_min, cand_row_max = torch.nonzero(
85 | true_cand_rows[i], as_tuple=True
86 | )[-1][[0, -1]].tolist()
87 | cand_col_min, cand_col_max = torch.nonzero(
88 | true_cand_cols[i], as_tuple=True
89 | )[-1][[0, -1]].tolist()
90 | center_y, center_x = (row_min + row_max) / 2, (col_min + col_max) / 2
91 | cand_center_y, cand_center_x = (cand_row_min + cand_row_max) / 2, (
92 | cand_col_min + cand_col_max
93 | ) / 2
94 | dy_i = (
95 | torch.tensor([cand_center_y - center_y])
96 | .float()
97 | .clamp(-row_min, (h - 1) - row_max)
98 | )
99 | dx_i = (
100 | torch.tensor([cand_center_x - center_x])
101 | .float()
102 | .clamp(-col_min, (w - 1) - col_max)
103 | )
104 | dy.append(dy_i)
105 | dx.append(dx_i)
106 | except:
107 | dy.append(torch.tensor([0], dtype=torch.float32))
108 | dx.append(torch.tensor([0], dtype=torch.float32))
109 | dy = torch.cat(dy, dim=0)[..., None, None]
110 | dx = torch.cat(dx, dim=0)[..., None, None]
111 |
112 | shift = torch.stack([dy, dx], dim=0).to(masks.device)
113 | # (2,b,1,1)
114 | return shift
115 |
116 |
117 | def get_random_mask_shift(
118 | masks,
119 | ):
120 | """ """
121 | # masks : (b,c,h,w)
122 | h, w = masks.shape[-2:]
123 |
124 | # mask boundaries
125 | true_rows = torch.any(masks, dim=3)
126 | true_cols = torch.any(masks, dim=2)
127 | bs = len(masks)
128 | dy, dx = [], []
129 | for i in range(bs):
130 | try:
131 | row_min, row_max = torch.nonzero(true_rows[i], as_tuple=True)[-1][
132 | [0, -1]
133 | ].tolist()
134 | col_min, col_max = torch.nonzero(true_cols[i], as_tuple=True)[-1][
135 | [0, -1]
136 | ].tolist()
137 | dy.append(torch.randint(-row_min, h - row_max, (1,)).float())
138 | dx.append(torch.randint(-col_min, w - col_max, (1,)).float())
139 | except:
140 | dy.append(torch.tensor([0], dtype=torch.float32))
141 | dx.append(torch.tensor([0], dtype=torch.float32))
142 | dy = torch.cat(dy, dim=0)[..., None, None]
143 | dx = torch.cat(dx, dim=0)[..., None, None]
144 |
145 | shift = torch.stack([dy, dx], dim=0).to(masks.device)
146 | # (2,b,1,1)
147 | return shift
148 |
149 |
150 | def shift_grid(shape, shift):
151 | h, w = shape
152 | dy, dx = shift
153 |
154 | # make grid for space transformation
155 | y, x = torch.meshgrid(torch.linspace(-1, 1, h), torch.linspace(-1, 1, w), indexing='ij')
156 | y, x = y.to(shift.device)[None, ...], x.to(shift.device)[None, ...] # (1,h,w)
157 |
158 | # xy shift
159 | y = y + 2 * (dy.float() / (h - 1))
160 | x = x + 2 * (dx.float() / (w - 1))
161 |
162 | grid = torch.stack((x, y), dim=-1)
163 | return grid
164 |
165 |
166 | def make_stage2_input(
167 | harm: torch.Tensor, comp: torch.Tensor, mask: torch.Tensor, in_channels
168 | ):
169 | if harm.dim()==4:
170 | cat_dim=1
171 | elif harm.dim()==3:
172 | cat_dim=0
173 | else:
174 | raise ValueError(f"image dims should be 3 or 4 but got {harm.dim()}")
175 | if in_channels == 3:
176 | stage2_input = harm
177 | elif in_channels == 4:
178 | stage2_input = torch.cat([harm, mask.to(harm)], dim=cat_dim)
179 | elif in_channels == 7:
180 | stage2_input = torch.cat([harm, mask.to(harm), comp.to(harm)], dim=cat_dim)
181 | else:
182 | raise ValueError(
183 | f"unsupported stage2 input type : got in channels {in_channels}"
184 | )
185 | return stage2_input
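# Usage sketch (illustrative shapes): for batched inputs harm (B, 3, H, W), mask (B, 1, H, W) and
# comp (B, 3, H, W), in_channels=7 concatenates them along dim 1 into a (B, 7, H, W) tensor,
# in_channels=4 uses harm + mask, and in_channels=3 returns harm unchanged.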
186 |
187 |
188 | import random
189 | import torch.nn.functional as F
190 | from typing import Union
191 |
192 | METHOD_VERSION=("v1", "v2-fgfg", "v2-fgbg")
193 |
194 | def select_cand(real:torch.Tensor, mask:torch.Tensor, method_version:str):
195 |     assert method_version in METHOD_VERSION, f"method_version {method_version} not implemented"
196 | if method_version=='v1': # * random selection
197 | cand_indices = []
198 | bs = len(real)
199 | range_indices = list(range(bs))
200 |         while True:  # reshuffle until a derangement is found (requires bs > 1)
201 | shuffled_indices = list(range(bs))
202 | random.shuffle(shuffled_indices)
203 | if any([i == j for i, j in zip(range_indices, shuffled_indices)]):
204 | continue
205 | break
206 | cand_indices = torch.tensor(shuffled_indices, dtype=torch.long)
207 | elif method_version.startswith("v2"): # * select based on lightness difference
208 | fg_images = real * mask
209 | h, w = real.shape[-2:]
210 | EPS = 1e-6
211 | fg_lightness = (
212 | ((fg_images.max(dim=1).values + fg_images.min(dim=1).values) / 2).mean(
213 | dim=[1, 2]
214 | )
215 | * (h * w)
216 | / (mask.sum(dim=[1, 2, 3]) + EPS)
217 | ) # (b,c,h,w) -> (b,h,w) -> (b,) ; average lightness of foreground region
218 | if method_version == "v2-fgfg":
219 | lightness_diff = fg_lightness.view(-1, 1) - fg_lightness.view(1, -1) # (b,b)
220 | elif method_version == "v2-fgbg":
221 | bg_images = real * (1 - mask)
222 | bg_lightness = (
223 | ((bg_images.max(dim=1).values + bg_images.min(dim=1).values) / 2).mean(
224 | dim=[1, 2]
225 | )
226 | * (h * w)
227 | / ((1 - mask).sum(dim=[1, 2, 3]) + EPS)
228 | ) # (b,c,h,w) -> (b,h,w) -> (b,) ; average lightness of background region
229 |
230 | lightness_diff = fg_lightness.view(-1, 1) - bg_lightness.view(1, -1) # (b,b)
231 | cand_indices = lightness_diff.abs().max(dim=1).indices
232 | return cand_indices
233 |
234 | @torch.inference_mode()
235 | def make_comp(
236 | real: torch.Tensor,
237 | cand: torch.Tensor,
238 | mask: torch.Tensor,
239 | method_version:str,
240 | pipeline,
241 | stage2_model = None,
242 | cand_mask: torch.Tensor = None,
243 | return_dict=False,
244 | ):
245 | assert method_version in METHOD_VERSION , f"method_version {method_version} not implemented"
246 |     if method_version == "v1":
247 |         shift = get_random_mask_shift(mask)
248 |     elif method_version.startswith("v2"):
249 | shift = get_mask_shift_with_corresponding_center(mask, cand_mask)
250 |
251 | grid = shift_grid(mask.shape[-2:], -shift)
252 | mask_cand = F.grid_sample(
253 | input=mask.float(), grid=grid, mode="nearest", padding_mode="zeros", align_corners=False
254 | ).to(
255 | mask
256 | ) # (b,c,h,w)
257 | infer_input = (
258 | mask_cand
259 | * F.grid_sample(
260 | input=real.float(), grid=grid, mode="nearest", padding_mode="zeros", align_corners=False
261 | ).to(real)
262 | + (1 - mask_cand) * cand
263 | )
264 |
265 | if return_dict:
266 |         dict_output = {
267 |             "infer_input": infer_input,
268 |         }
269 |
270 |     bs = len(real)
271 | h, w = infer_input.shape[-2:]
272 | infer_output = pipeline(
273 | prompt=[""] * bs,
274 | image=infer_input,
275 | mask_image=mask_cand,
276 | # generator=generator,
277 | height=h,
278 | width=w,
279 | guidance_scale=1.0,
280 | num_inference_steps=5,
281 | output_type="numpy",
282 | ).images
283 | # output (b,h,w,c) numpy ndarray , [0,1] range value
284 | infer_output = (
285 | torch.tensor(infer_output).to(real).permute(0, 3, 1, 2).sub(0.5).multiply(2)
286 | ) # convert to [-1,1] (b,c,h,w,) pytorch tensor
287 |
288 | if stage2_model is not None:
289 | stage2_input = make_stage2_input(
290 | infer_output,
291 | infer_input,
292 | mask,
293 | stage2_model.config.in_channels,
294 | )
295 | infer_output = stage2_model(stage2_input).sample.clamp(-1, 1)
296 |
297 | if return_dict:
298 |         dict_output["infer_output"] = infer_output
299 |
300 | inverse_grid = shift_grid(mask.shape[-2:], shift)
301 | comp = (1 - mask) * real + mask * F.grid_sample(
302 | input=infer_output.float(),
303 | grid=inverse_grid,
304 | mode="nearest",
305 | padding_mode="zeros",
306 | align_corners=False,
307 | ).to(infer_output)
308 |
309 | if return_dict:
310 |         dict_output["comp"] = comp
311 | return dict_output
312 |
313 | return comp
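314 |
315 |
316 | if __name__ == "__main__":
317 |     # Illustrative smoke test, not part of the original pipeline: it sketches how the
318 |     # helpers above compose. A random in-frame shift is sampled for each mask, turned
319 |     # into a sampling grid, and applied with F.grid_sample -- the same way make_comp
320 |     # relocates the candidate foreground. The toy tensors below are assumptions made
321 |     # up purely for this demo.
322 |     toy_masks = torch.zeros(2, 1, 64, 64)
323 |     toy_masks[:, :, 16:32, 16:32] = 1.0  # one square foreground region per sample
324 |
325 |     shift = get_random_mask_shift(toy_masks)  # (2, b, 1, 1) stacked (dy, dx) pixel offsets
326 |     grid = shift_grid(toy_masks.shape[-2:], -shift)
327 |     shifted_masks = F.grid_sample(
328 |         input=toy_masks, grid=grid, mode="nearest", padding_mode="zeros", align_corners=False
329 |     )
330 |     # The shift is sampled so the foreground stays inside the frame, so the shifted mask
331 |     # should cover roughly the same area as the original one.
332 |     print(toy_masks.sum(dim=[1, 2, 3]), shifted_masks.sum(dim=[1, 2, 3]))
333 |
334 |     # make_stage2_input just concatenates the refinement-stage inputs along the channel
335 |     # dim; with in_channels=7 that is harmonized (3) + mask (1) + composite (3) channels.
336 |     harm = torch.randn(2, 3, 64, 64)
337 |     comp = torch.randn(2, 3, 64, 64)
338 |     stage2_input = make_stage2_input(harm, comp, toy_masks, in_channels=7)
339 |     print(stage2_input.shape)  # torch.Size([2, 7, 64, 64])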
--------------------------------------------------------------------------------
/src/models/unet_2d_blocks.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Optional, Tuple, Union
2 |
3 | import torch
4 | from torch import nn
5 |
6 | from diffusers.utils import logging
7 | from diffusers.models.resnet import (
8 | Downsample2D,
9 | ResnetBlock2D,
10 | ResnetBlockCondNorm2D,
11 | Upsample2D,
12 | )
13 |
14 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
15 |
16 |
17 | def get_down_block(
18 | down_block_type: str,
19 | num_layers: int,
20 | in_channels: int,
21 | out_channels: int,
22 | temb_channels: int,
23 | add_downsample: bool,
24 | resnet_eps: float,
25 | resnet_act_fn: str,
26 | transformer_layers_per_block: int = 1,
27 | num_attention_heads: Optional[int] = None,
28 | resnet_groups: Optional[int] = None,
29 | cross_attention_dim: Optional[int] = None,
30 | downsample_padding: Optional[int] = None,
31 | dual_cross_attention: bool = False,
32 | use_linear_projection: bool = False,
33 | only_cross_attention: bool = False,
34 | upcast_attention: bool = False,
35 | resnet_time_scale_shift: str = "default",
36 | attention_type: str = "default",
37 | resnet_skip_time_act: bool = False,
38 | resnet_out_scale_factor: float = 1.0,
39 | cross_attention_norm: Optional[str] = None,
40 | attention_head_dim: Optional[int] = None,
41 | downsample_type: Optional[str] = None,
42 | dropout: float = 0.0,
43 | ):
44 | # If attn head dim is not defined, we default it to the number of heads
45 | if attention_head_dim is None:
46 |         logger.warning(
47 | f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
48 | )
49 | attention_head_dim = num_attention_heads
50 |
51 | down_block_type = (
52 | down_block_type[7:]
53 | if down_block_type.startswith("UNetRes")
54 | else down_block_type
55 | )
56 |
57 | if down_block_type == "DownEncoderBlock2D":
58 | return DownEncoderSkipBlock2D(
59 | num_layers=num_layers,
60 | in_channels=in_channels,
61 | out_channels=out_channels,
62 | dropout=dropout,
63 | add_downsample=add_downsample,
64 | resnet_eps=resnet_eps,
65 | resnet_act_fn=resnet_act_fn,
66 | resnet_groups=resnet_groups,
67 | downsample_padding=downsample_padding,
68 | resnet_time_scale_shift=resnet_time_scale_shift,
69 | )
70 |
71 | raise ValueError(f"{down_block_type} does not exist.")
72 |
73 |
74 | def get_up_block(
75 | up_block_type: str,
76 | num_layers: int,
77 | in_channels: int,
78 | out_channels: int,
79 | prev_output_channel: int,
80 | temb_channels: int,
81 | add_upsample: bool,
82 | resnet_eps: float,
83 | resnet_act_fn: str,
84 | resolution_idx: Optional[int] = None,
85 | transformer_layers_per_block: int = 1,
86 | num_attention_heads: Optional[int] = None,
87 | resnet_groups: Optional[int] = None,
88 | cross_attention_dim: Optional[int] = None,
89 | dual_cross_attention: bool = False,
90 | use_linear_projection: bool = False,
91 | only_cross_attention: bool = False,
92 | upcast_attention: bool = False,
93 | resnet_time_scale_shift: str = "default",
94 | attention_type: str = "default",
95 | resnet_skip_time_act: bool = False,
96 | resnet_out_scale_factor: float = 1.0,
97 | cross_attention_norm: Optional[str] = None,
98 | attention_head_dim: Optional[int] = None,
99 | upsample_type: Optional[str] = None,
100 | dropout: float = 0.0,
101 | ) -> nn.Module:
102 | # If attn head dim is not defined, we default it to the number of heads
103 | if attention_head_dim is None:
104 |         logger.warning(
105 | f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
106 | )
107 | attention_head_dim = num_attention_heads
108 |
109 | up_block_type = (
110 | up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
111 | )
112 |
113 | if up_block_type == "UpDecoderBlock2D":
114 | return UpDecoderSkipBlock2D(
115 | num_layers=num_layers,
116 | in_channels=in_channels,
117 | out_channels=out_channels,
118 | resolution_idx=resolution_idx,
119 | dropout=dropout,
120 | add_upsample=add_upsample,
121 | resnet_eps=resnet_eps,
122 | resnet_act_fn=resnet_act_fn,
123 | resnet_groups=resnet_groups,
124 | resnet_time_scale_shift=resnet_time_scale_shift,
125 | temb_channels=temb_channels,
126 | )
127 | raise ValueError(f"{up_block_type} does not exist.")
128 |
129 |
130 | class DownEncoderSkipBlock2D(nn.Module):
131 | def __init__(
132 | self,
133 | in_channels: int,
134 | out_channels: int,
135 | dropout: float = 0.0,
136 | num_layers: int = 1,
137 | resnet_eps: float = 1e-6,
138 | resnet_time_scale_shift: str = "default",
139 | resnet_act_fn: str = "swish",
140 | resnet_groups: int = 32,
141 | resnet_pre_norm: bool = True,
142 | output_scale_factor: float = 1.0,
143 | add_downsample: bool = True,
144 | downsample_padding: int = 1,
145 | ):
146 | super().__init__()
147 | resnets = []
148 |
149 | for i in range(num_layers):
150 | in_channels = in_channels if i == 0 else out_channels
151 | if resnet_time_scale_shift == "spatial":
152 | resnets.append(
153 | ResnetBlockCondNorm2D(
154 | in_channels=in_channels,
155 | out_channels=out_channels,
156 | temb_channels=None,
157 | eps=resnet_eps,
158 | groups=resnet_groups,
159 | dropout=dropout,
160 | time_embedding_norm="spatial",
161 | non_linearity=resnet_act_fn,
162 | output_scale_factor=output_scale_factor,
163 | )
164 | )
165 | else:
166 | resnets.append(
167 | ResnetBlock2D(
168 | in_channels=in_channels,
169 | out_channels=out_channels,
170 | temb_channels=None,
171 | eps=resnet_eps,
172 | groups=resnet_groups,
173 | dropout=dropout,
174 | time_embedding_norm=resnet_time_scale_shift,
175 | non_linearity=resnet_act_fn,
176 | output_scale_factor=output_scale_factor,
177 | pre_norm=resnet_pre_norm,
178 | )
179 | )
180 |
181 | self.resnets = nn.ModuleList(resnets)
182 |
183 | if add_downsample:
184 | self.downsamplers = nn.ModuleList(
185 | [
186 | Downsample2D(
187 | out_channels,
188 | use_conv=True,
189 | out_channels=out_channels,
190 | padding=downsample_padding,
191 | name="op",
192 | )
193 | ]
194 | )
195 | else:
196 | self.downsamplers = None
197 |
198 | def forward(
199 | self, hidden_states: torch.FloatTensor, scale: float = 1.0
200 | ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
201 | output_states = ()
202 | for resnet in self.resnets:
203 | hidden_states = resnet(hidden_states, temb=None, scale=scale)
204 | output_states = output_states + (hidden_states,)
205 |
206 | if self.downsamplers is not None:
207 | for downsampler in self.downsamplers:
208 | hidden_states = downsampler(hidden_states, scale)
209 | output_states = output_states + (hidden_states,)
210 |
211 | return hidden_states, output_states
212 |
213 |
214 | class UpDecoderSkipBlock2D(nn.Module):
215 | def __init__(
216 | self,
217 | in_channels: int,
218 | out_channels: int,
219 | resolution_idx: Optional[int] = None,
220 | dropout: float = 0.0,
221 | num_layers: int = 1,
222 | resnet_eps: float = 1e-6,
223 | resnet_time_scale_shift: str = "default", # default, spatial
224 | resnet_act_fn: str = "swish",
225 | resnet_groups: int = 32,
226 | resnet_pre_norm: bool = True,
227 | output_scale_factor: float = 1.0,
228 | add_upsample: bool = True,
229 | temb_channels: Optional[int] = None,
230 | ):
231 | super().__init__()
232 | resnets = []
233 |
234 | for i in range(num_layers):
235 | input_channels = in_channels if i == 0 else out_channels
236 |
237 | if resnet_time_scale_shift == "spatial":
238 | resnets.append(
239 | ResnetBlockCondNorm2D(
240 | in_channels=input_channels,
241 | out_channels=out_channels,
242 | temb_channels=temb_channels,
243 | eps=resnet_eps,
244 | groups=resnet_groups,
245 | dropout=dropout,
246 | time_embedding_norm="spatial",
247 | non_linearity=resnet_act_fn,
248 | output_scale_factor=output_scale_factor,
249 | )
250 | )
251 | else:
252 | resnets.append(
253 | ResnetBlock2D(
254 | in_channels=input_channels,
255 | out_channels=out_channels,
256 | temb_channels=temb_channels,
257 | eps=resnet_eps,
258 | groups=resnet_groups,
259 | dropout=dropout,
260 | time_embedding_norm=resnet_time_scale_shift,
261 | non_linearity=resnet_act_fn,
262 | output_scale_factor=output_scale_factor,
263 | pre_norm=resnet_pre_norm,
264 | )
265 | )
266 |
267 | self.resnets = nn.ModuleList(resnets)
268 |
269 | if add_upsample:
270 | self.upsamplers = nn.ModuleList(
271 | [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
272 | )
273 | else:
274 | self.upsamplers = None
275 |
276 | self.resolution_idx = resolution_idx
277 |
278 | def forward(
279 | self,
280 | hidden_states: torch.FloatTensor,
281 | res_hidden_states_tuple: Optional[Tuple[torch.FloatTensor, ...]] = None,
282 | temb: Optional[torch.FloatTensor] = None,
283 | scale: float = 1.0,
284 | ) -> torch.FloatTensor:
285 | for resnet in self.resnets:
286 | hidden_states = resnet(hidden_states, temb=temb, scale=scale)
287 | if res_hidden_states_tuple is not None:
288 | res_hidden_states = res_hidden_states_tuple[-1]
289 | res_hidden_states_tuple = res_hidden_states_tuple[:-1]
290 | hidden_states = hidden_states + res_hidden_states
291 |
292 | if self.upsamplers is not None:
293 | for upsampler in self.upsamplers:
294 | hidden_states = upsampler(hidden_states)
295 |
296 | return hidden_states
297 |
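298 |
299 | if __name__ == "__main__":
300 |     # Illustrative sketch, not how the surrounding models wire these blocks: it only
301 |     # demonstrates the skip mechanism these subclasses add on top of the vanilla
302 |     # diffusers encoder/decoder blocks. The channel sizes, group count, and the reuse
303 |     # of `h` as a stand-in skip feature are assumptions for this demo; in practice each
304 |     # skip tensor must match the shape of the resnet output it is added to.
305 |     down = DownEncoderSkipBlock2D(
306 |         in_channels=32, out_channels=64, num_layers=2, resnet_groups=8
307 |     )
308 |     up = UpDecoderSkipBlock2D(
309 |         in_channels=64, out_channels=64, num_layers=2, resnet_groups=8
310 |     )
311 |
312 |     x = torch.randn(1, 32, 64, 64)
313 |     # The down block returns its final feature map plus every intermediate state
314 |     # (one per resnet layer, plus the downsampled output).
315 |     h, skip_states = down(x)
316 |     print(h.shape, [s.shape for s in skip_states])
317 |
318 |     # The up block pops skip features from the end of the tuple and adds one to each
319 |     # resnet output before the optional upsampling.
320 |     y = up(h, res_hidden_states_tuple=(h, h))
321 |     print(y.shape)  # torch.Size([1, 64, 64, 64])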
--------------------------------------------------------------------------------
/scripts/evaluate/main.py:
--------------------------------------------------------------------------------
1 | from os.path import join as pjoin
2 | from skimage.metrics import mean_squared_error as mse
3 | from skimage.metrics import peak_signal_noise_ratio as psnr
4 | from PIL import Image
5 | import numpy as np
6 | import os
7 | import json
8 | from collections import defaultdict
9 | from pathlib import Path
10 | from tqdm import tqdm
11 | import torchvision.transforms.functional as TF
12 | import argparse
13 | from concurrent.futures import ThreadPoolExecutor, as_completed
14 | from threading import Lock
15 |
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument("--input_dir", type=str, required=True)
18 | parser.add_argument("--output_dir", type=str, required=True)
19 | parser.add_argument("--data_dir", type=str, default="./data/iHarmony4")
20 | parser.add_argument("--json_file_path", type=str, default="DataSpecs/all_test.jsonl")
21 | parser.add_argument("--resolution", type=int, default=256)
22 | parser.add_argument("--num_processes", type=int, default=4)
23 | parser.add_argument("--use_gt_bg", action="store_true")
24 | args = parser.parse_args()
25 |
26 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
27 | (Path(args.output_dir) / "imgs").mkdir(parents=True, exist_ok=True)
28 |
29 | # Initialize metric accumulators
30 | count = 0
31 | psnr_total = 0
32 | mse_total = 0
33 | fmse_total = 0
34 | bmse_total = 0
35 | ds2psnr = defaultdict(list)
36 | ds2mse = defaultdict(list)
37 | ds2fmse = defaultdict(list)
38 | ds2bmse = defaultdict(list)
39 | ds2mask = defaultdict(list)
40 |
41 |
42 | def get_paths(comp_path):
43 | # .../ds_name/composite_images/xxx_x_x.jpg
44 | parts = comp_path.split("/")
45 | img_name_parts = parts[-1].split("_")
46 | real_path = os.path.join(*parts[:-2], "real_images", f"{img_name_parts[0]}.jpg")
47 | mask_path = os.path.join(
48 | *parts[:-2], "masks", f"{img_name_parts[0]}_{img_name_parts[1]}.png"
49 | )
50 |
51 | return {
52 | "real": real_path,
53 | "mask": mask_path,
54 | "comp": comp_path,
55 | }
56 |
57 |
58 | comp_to_paths = get_paths
59 |
60 |
61 | # Read a jsonl file and store its records in a list
62 | def read_jsonl_to_list(jsonl_file):
63 | data_list = []
64 | with open(jsonl_file, "r") as file:
65 | for line in file:
66 |             # Parse each line and append it to the list
67 | data = json.loads(line)
68 | data_list.append(data)
69 | return data_list
70 |
71 |
72 | def possible_resize(image: Image.Image, resolution: int):
73 | if image.size != (resolution, resolution):
74 | image = TF.resize(image, [resolution, resolution], antialias=True)
75 | return image
76 |
77 |
78 | dataset_map = {
79 | "a": "HAdobe5k",
80 | "c": "HCOCO",
81 | "d": "Hday2night",
82 | "f": "HFlickr",
83 | "s": "SAM",
84 | "g": "Gen",
85 | }
86 |
87 |
88 | def process_fn(comp_path, lock):
89 |
90 | global count
91 | global psnr_total
92 | global mse_total
93 | global fmse_total
94 | global bmse_total
95 | global ds2psnr
96 | global ds2mse
97 | global ds2fmse
98 | global ds2bmse
99 | global ds2mask
100 |
101 | _output_str = ""
102 | _htmlstr = ""
103 |
104 | all_paths = comp_to_paths(comp_path)
105 | target_path = all_paths["real"]
106 | mask_path = all_paths["mask"]
107 | comp_path = all_paths["comp"]
108 |
109 | harm_path = os.path.join(
110 | args.input_dir, comp_path.split("/")[-1].replace(".jpg", ".png")
111 | )
112 | harm_name = harm_path.split("/")[-1]
113 |
114 | ot_file = harm_path
115 |
116 | dataset = dataset_map[harm_name[0]]
117 |
118 | if os.path.exists(ot_file):
119 | # import pdb ; pdb.set_trace()
120 | ot_img = Image.open(ot_file).convert("RGB")
121 | ot_img = possible_resize(ot_img, args.resolution)
122 |
123 | mk_img = Image.open(mask_path).convert("1")
124 | mk_img = possible_resize(mk_img, args.resolution)
125 |
126 | gt_img = Image.open(target_path).convert("RGB")
127 | gt_img = possible_resize(gt_img, args.resolution)
128 |
129 | mk_np = np.array(mk_img, dtype=np.float32)[..., np.newaxis]
130 | gt_np = np.array(gt_img, dtype=np.float32)
131 | ot_np = np.array(ot_img, dtype=np.float32)
132 |
133 | if args.use_gt_bg:
134 | ot_np = ot_np * mk_np + gt_np * (1 - mk_np)
135 |
136 | mse_score = mse(ot_np, gt_np)
137 | psnr_score = psnr(gt_np, ot_np, data_range=255)
138 |
139 |         w, h = ot_img.size
140 |         fscore = mse(ot_np * mk_np, gt_np * mk_np) * (h * w) / (mk_np.sum())
141 |         bscore = (
142 |             mse(ot_np * (1 - mk_np), gt_np * (1 - mk_np))
143 |             * (h * w)
144 |             / ((1 - mk_np).sum())
145 |         )
146 |
147 |         # Accumulate the shared totals under the lock so that concurrent
148 |         # workers do not race on the global counters.
149 |         with lock:
150 |             mse_total += mse_score
151 |             psnr_total += psnr_score
152 |             fmse_total += fscore
153 |             bmse_total += bscore
154 |
155 | _output_str = f"""
156 | filename : {ot_file}
157 | image size : {ot_img.size}
158 | psnr : {psnr_score:.3f}
159 | mse : {mse_score:.3f}
160 | fmse : {fscore:.3f}
161 | bmse : {bscore:.3f}
162 | \n"""
163 | if fscore > 1000:
164 | gt_save_name = (
165 | args.output_dir + "/" + "imgs/" + "gt_" + target_path.split("/")[-1]
166 | )
167 | comp_save_name = (
168 | args.output_dir + "/" + "imgs/" + "comp_" + comp_path.split("/")[-1]
169 | )
170 | ot_save_name = (
171 | args.output_dir + "/" + "imgs/" + "out_" + ot_file.split("/")[-1]
172 | )
173 | mk_save_name = (
174 | args.output_dir + "/" + "imgs/" + "mask_" + mask_path.split("/")[-1]
175 | )
176 |
177 | gt_img.save(gt_save_name)
178 |
179 | comp_img = Image.open(comp_path).convert("RGB")
180 | comp_img = possible_resize(comp_img, args.resolution)
181 | comp_img.save(comp_save_name)
182 |
183 | ot_img.save(ot_save_name)
184 |             mk_img.save(
185 |                 mk_save_name
186 |             )
187 |             _htmlstr = (
188 |                 f'<img src="{gt_save_name}">'
193 |             )
194 |             _htmlstr += (
195 |                 f'<img src="{comp_save_name}">'
200 |             )
201 |             _htmlstr += (
202 |                 f'<img src="{ot_save_name}">'
207 |             )
208 |             _htmlstr += (
209 |                 f'<img src="{mk_save_name}">'
214 |             )
215 |
216 | _htmlstr += '