├── .idea
│   ├── .gitignore
│   ├── Mask-Conditioned-Latent-Space-Diffusion.iml
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── vcs.xml
│   └── workspace.xml
├── LICENSE
├── ReadMe.md
├── assets
│   ├── denoising_process
│   │   ├── 2018
│   │   │   ├── 0.png
│   │   │   ├── 1.png
│   │   │   ├── 10.png
│   │   │   ├── 11.png
│   │   │   ├── 12.png
│   │   │   ├── 13.png
│   │   │   ├── 14.png
│   │   │   ├── 15.png
│   │   │   ├── 16.png
│   │   │   ├── 17.png
│   │   │   ├── 18.png
│   │   │   ├── 19.png
│   │   │   ├── 2.png
│   │   │   ├── 20.png
│   │   │   ├── 3.png
│   │   │   ├── 4.png
│   │   │   ├── 5.png
│   │   │   ├── 6.png
│   │   │   ├── 7.png
│   │   │   ├── 8.png
│   │   │   ├── 9.png
│   │   │   └── test.gif
│   │   ├── 3063
│   │   │   ├── 0.png
│   │   │   ├── 1.png
│   │   │   ├── 10.png
│   │   │   ├── 11.png
│   │   │   ├── 12.png
│   │   │   ├── 13.png
│   │   │   ├── 14.png
│   │   │   ├── 15.png
│   │   │   ├── 16.png
│   │   │   ├── 17.png
│   │   │   ├── 18.png
│   │   │   ├── 19.png
│   │   │   ├── 2.png
│   │   │   ├── 20.png
│   │   │   ├── 3.png
│   │   │   ├── 4.png
│   │   │   ├── 5.png
│   │   │   ├── 6.png
│   │   │   ├── 7.png
│   │   │   ├── 8.png
│   │   │   ├── 9.png
│   │   │   └── test.gif
│   │   └── 5096
│   │       ├── 0.png
│   │       ├── 1.png
│   │       ├── 10.png
│   │       ├── 11.png
│   │       ├── 12.png
│   │       ├── 13.png
│   │       ├── 14.png
│   │       ├── 15.png
│   │       ├── 16.png
│   │       ├── 17.png
│   │       ├── 18.png
│   │       ├── 19.png
│   │       ├── 2.png
│   │       ├── 20.png
│   │       ├── 3.png
│   │       ├── 4.png
│   │       ├── 5.png
│   │       ├── 6.png
│   │       ├── 7.png
│   │       ├── 8.png
│   │       ├── 9.png
│   │       └── test.gif
│   ├── framework.png
│   └── teaser.png
├── configs
│   ├── BIPED_sample.yaml
│   ├── BIPED_train.yaml
│   ├── BSDS_sample.yaml
│   ├── BSDS_train.yaml
│   ├── NYUD_sample.yaml
│   ├── NYUD_train.yaml
│   ├── default.yaml
│   └── first_stage_d4.yaml
├── demo.py
├── demo_trt.py
├── denoising_diffusion_pytorch
│   ├── __init__.py
│   ├── data.py
│   ├── ddm_const_sde.py
│   ├── efficientnet.py
│   ├── ema.py
│   ├── encoder_decoder.py
│   ├── imagenet.py
│   ├── loss.py
│   ├── mask_cond_unet.py
│   ├── quantization.py
│   ├── resnet.py
│   ├── swin_transformer.py
│   ├── uncond_unet.py
│   ├── utils.py
│   ├── vgg.py
│   ├── wavelet.py
│   └── wcc.py
├── metrics
│   ├── datasets.py
│   ├── defaults.py
│   ├── feature_extractor_base.py
│   ├── feature_extractor_inceptionv3.py
│   ├── generative_model_base.py
│   ├── helpers.py
│   ├── interpolate_compat_tensorflow.py
│   ├── metric.py
│   ├── metric_fid.py
│   ├── metric_isc.py
│   ├── metric_kid.py
│   ├── metric_ppl.py
│   ├── noise.py
│   ├── registry.py
│   ├── sample_similarity_base.py
│   ├── sample_similarity_lpips.py
│   └── utils.py
├── requirement.txt
├── sample_cond_ldm.py
├── taming
│   ├── __init__.py
│   ├── data
│   │   ├── ade20k.py
│   │   ├── annotated_objects_coco.py
│   │   ├── annotated_objects_dataset.py
│   │   ├── annotated_objects_open_images.py
│   │   ├── base.py
│   │   ├── coco.py
│   │   ├── conditional_builder
│   │   │   ├── objects_bbox.py
│   │   │   ├── objects_center_points.py
│   │   │   └── utils.py
│   │   ├── custom.py
│   │   ├── faceshq.py
│   │   ├── helper_types.py
│   │   ├── image_transforms.py
│   │   ├── imagenet.py
│   │   ├── open_images_helper.py
│   │   ├── sflckr.py
│   │   └── utils.py
│   ├── modules
│   │   ├── autoencoder
│   │   │   └── lpips
│   │   │       └── vgg.pth
│   │   ├── diffusionmodules
│   │   │   └── model.py
│   │   ├── discriminator
│   │   │   └── model.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── lpips.py
│   │   │   ├── segmentation.py
│   │   │   ├── util.py
│   │   │   └── vqperceptual.py
│   │   ├── misc
│   │   │   └── coord.py
│   │   ├── util.py
│   │   └── vqvae
│   │       └── quantize.py
│   └── util.py
├── train_cond_ldm.py
├── train_vae.py
└── unet_plus
    ├── __init__.py
    ├── ema.py
    ├── layers.py
    ├── layerspp.py
    ├── ncsnpp.py
    ├── ncsnpp2.py
    ├── ncsnpp3.py
    ├── ncsnpp4.py
    ├── ncsnpp5.py
    ├── ncsnpp6.py
    ├── ncsnpp7.py
    ├── ncsnpp8.py
    ├── ncsnpp9.py
    ├── ncsnv2.py
    ├── normalization.py
    ├── op
    │   ├── __init__.py
    │   ├── fused_act.py
    │   ├── fused_bias_act.cpp
    │   ├── fused_bias_act_kernel.cu
    │   ├── upfirdn2d.cpp
    │   ├── upfirdn2d.py
    │   └── upfirdn2d_kernel.cu
    ├── unet_pp.py
    ├── up_or_down_sampling.py
    └── utils.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/.idea/.gitignore
--------------------------------------------------------------------------------
/ReadMe.md:
--------------------------------------------------------------------------------
1 | ## DiffusionEdge: Diffusion Probabilistic Model for Crisp Edge Detection ([arxiv](https://arxiv.org/abs/2401.02032))
2 | [Yunfan Ye](https://yunfan1202.github.io), [Yuhang Huang](https://github.com/GuHuangAI), [Renjiao Yi](https://renjiaoyi.github.io/), [Zhiping Cai](), [Kai Xu](http://kevinkaixu.net/index.html).
3 |
4 | 
5 | 
6 | 
7 |
8 | # News
9 | - We have released a real-time model trained on BSDS; please see **[Real-time DiffusionEdge](#vi-real-time-diffusionedge)**.
10 | - We have created a [WeChat Group](https://github.com/GuHuangAI/DiffusionEdge/issues/17) for open discussion.
11 | Please scan the QR code with the WeChat app.
12 | - 2023-12-09: The paper was accepted to **AAAI-2024**.
13 | - Uploaded the pretrained **first-stage checkpoint** ([download](https://github.com/GuHuangAI/DiffusionEdge/releases/download/v1.1/first_stage_total_320.pt)).
14 | - Uploaded **pretrained weights** and **pre-computed results**.
15 | - Added a simple demo; please see **[Quick Demo](#iii-quick-demo)**.
16 | - First commit.
17 |
18 | ## I. Before Starting.
19 | 1. Install PyTorch:
20 | ~~~
21 | conda create -n diffedge python=3.9
22 | conda activate diffedge
23 | pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
24 | ~~~
25 | 2. Install the other packages:
26 | ~~~
27 | pip install -r requirement.txt
28 | ~~~
29 | 3. Prepare the accelerate config:
30 | ~~~
31 | accelerate config
32 | ~~~
33 |
34 | ## II. Prepare Data.
35 | The training data structure should look like:
36 | ```commandline
37 | |-- $data_root
38 | | |-- image
39 | | |-- |-- raw
40 | | |-- |-- |-- XXXXX.jpg
41 | | |-- |-- |-- XXXXX.jpg
42 | | |-- edge
43 | | |-- |-- raw
44 | | |-- |-- |-- XXXXX.png
45 | | |-- |-- |-- XXXXX.png
46 | ```
47 | The testing data structure should look like:
48 | ```commandline
49 | |-- $data_root
50 | | |-- XXXXX.jpg
51 | | |-- XXXXX.jpg
52 | ```
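You can sanity-check the training layout before launching a run. A minimal sketch (assuming `.jpg` images under `image/raw` and `.png` edge maps under `edge/raw`, paired by filename stem):
```python
from pathlib import Path

def check_training_layout(data_root):
    # Collect raw images and edge maps, keyed by filename stem.
    data_root = Path(data_root)
    images = {p.stem: p for p in (data_root / "image" / "raw").glob("*.jpg")}
    edges = {p.stem: p for p in (data_root / "edge" / "raw").glob("*.png")}
    paired = images.keys() & edges.keys()
    print(f"{len(images)} images, {len(edges)} edge maps, {len(paired)} matched pairs")
    # Report a few unmatched files, if any.
    for stem in sorted(images.keys() - edges.keys())[:10]:
        print("image without an edge map:", images[stem])
    for stem in sorted(edges.keys() - images.keys())[:10]:
        print("edge map without an image:", edges[stem])

# check_training_layout("/path/to/$data_root")
```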
53 |
54 | ## III. Quick Demo!
55 | 1. Download the pretrained weights:
56 |
57 | | Dataset | ODS (SEval/CEval) | OIS (SEval/CEval) | AC | Weight | Pre-computed results |
58 | |---------|--------------------------------------------------------------------|--------------------------------------------------------------------|-------|----------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------|
59 | | BSDS | 0.834 / 0.749 | 0.848 / 0.754 | 0.476 | [download](https://github.com/GuHuangAI/DiffusionEdge/releases/download/v1.1/bsds.pt) | [download](https://github.com/GuHuangAI/DiffusionEdge/releases/download/v1/results_bsds_stride240_step5.zip) |
60 | | NYUD | 0.761 / 0.732 | 0.766 / 0.738 | 0.846 | [download](https://github.com/GuHuangAI/DiffusionEdge/releases/download/v1.1/nyud.pt) | [download](https://github.com/GuHuangAI/DiffusionEdge/releases/download/v1/results_nyud_stride240_step5.zip) |
61 | | BIPED | 0.899 | 0.901 | 0.849 | [download](https://github.com/GuHuangAI/DiffusionEdge/releases/download/v1.1/biped.pt) | [download](https://github.com/GuHuangAI/DiffusionEdge/releases/download/v1/results_biped_stride240_step5.zip) |
62 |
63 | 2. Put your images in a directory and run:
64 | ~~~
65 | python demo.py --input_dir $your input dir$ --pre_weight $the downloaded weight path$ --out_dir $the path saves your results$ --bs 8
66 | ~~~
67 | A larger `--bs` gives faster inference but uses more CUDA memory; a rough timing sketch follows below.
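To see the trade-off on your own data, a timing loop over a few `--bs` values could look like the sketch below (the flags match the command above; the input, weight, and output paths are placeholders):
```python
import subprocess
import time

for bs in (1, 4, 8):
    t0 = time.time()
    subprocess.run(
        ["python", "demo.py",
         "--input_dir", "./my_images",             # placeholder input directory
         "--pre_weight", "./checkpoints/bsds.pt",  # placeholder weight path
         "--out_dir", f"./results_bs{bs}",
         "--bs", str(bs)],
        check=True,
    )
    print(f"--bs {bs}: {time.time() - t0:.1f}s")
```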
68 |
69 | ## IV. Training.
70 | 1. Train the first-stage model (the autoencoder):
71 | ~~~
72 | accelerate launch train_vae.py --cfg ./configs/first_stage_d4.yaml
73 | ~~~
74 | 2. Set the path of the final first-stage weight in the config file `./configs/BSDS_train.yaml` (**line 42**), then train the latent diffusion edge model (a pre-flight check is sketched after the command):
75 | ~~~
76 | accelerate launch train_cond_ldm.py --cfg ./configs/BSDS_train.yaml
77 | ~~~
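Before launching, it can help to verify that **line 42** really points at your trained first-stage weight. A minimal pre-flight sketch (assuming PyYAML is installed and that `first_stage` sits under the `model` block, as in the configs shipped with this repo):
```python
import os
import yaml  # PyYAML

with open("./configs/BSDS_train.yaml") as f:
    cfg = yaml.safe_load(f)

# first_stage is nested under model in this repo's configs; fall back to the
# top level just in case the nesting differs.
first_stage = cfg["model"].get("first_stage") or cfg.get("first_stage") or {}
ckpt = first_stage.get("ckpt_path")
print("first-stage checkpoint:", ckpt)
assert ckpt and os.path.isfile(ckpt), "point line 42 at your trained first-stage weight"
```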
78 |
79 | ## V. Inference.
80 | Make sure your model weight path is set in the config file `./configs/BSDS_sample.yaml` (**line 73**), then run:
81 | ~~~
82 | python sample_cond_ldm.py --cfg ./configs/BSDS_sample.yaml
83 | ~~~
84 | Note that you can modify `sampling_timesteps` (**line 11**) to control the inference speed (fewer steps run faster).
85 |
86 | ## VI. Real-time DiffusionEdge.
87 | 1. We have only tested in the following environment so far; more details will be released soon.
88 |
89 | | Environment | Version |
90 | |-------------|---------|
91 | | TensorRT | 8.6.1 |
92 | | CUDA | 11.6 |
93 | | cuDNN | 8.7.0 |
94 | | pycuda | 2024.1 |
95 |
96 | Please follow this [link](https://github.com/NVIDIA/TensorRT) to install TensorRT.
97 |
98 | 2. Download the pretrained [weight](https://github.com/GuHuangAI/DiffusionEdge/releases/download/v1.1/model_crop_size_256_fps_150_ods_0813_ois_0825.trt).
99 | Real-time inference, start it up!
100 | ~~~
101 | python demo_trt.py --input_dir $your input dir$ --pre_weight $the downloaded weight path$ --out_dir $the path saves your results$
102 | ~~~
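Optionally, you can first check that the downloaded engine deserializes in your TensorRT installation. A small sanity-check sketch using the TensorRT 8.x Python API (`demo_trt.py` remains the supported entry point; the binding calls below are the pre-9.x API):
```python
import pycuda.autoinit  # noqa: F401  (creates a CUDA context; pycuda is listed above)
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

# Deserialize the downloaded engine file.
with open("model_crop_size_256_fps_150_ods_0813_ois_0825.trt", "rb") as f, \
        trt.Runtime(TRT_LOGGER) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())

# List the engine's I/O bindings and their shapes.
for i in range(engine.num_bindings):
    print(engine.get_binding_name(i),
          tuple(engine.get_binding_shape(i)),
          "input" if engine.binding_is_input(i) else "output")
```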
103 |
104 | ## Contact
105 | If you have any questions, please contact huangai@nudt.edu.cn.
106 | ## Thanks
107 | Thanks to [DDM-Public](https://github.com/GuHuangAI/DDM-Public), on which this code is based.
108 | ## Citation
109 | ~~~
110 | @inproceedings{ye2024diffusionedge,
111 | title={DiffusionEdge: Diffusion Probabilistic Model for Crisp Edge Detection},
112 | author={Yunfan Ye and Kai Xu and Yuhang Huang and Renjiao Yi and Zhiping Cai},
113 | year={2024},
114 | booktitle={AAAI}
115 | }
116 | ~~~
117 |
--------------------------------------------------------------------------------
/assets/denoising_process/2018/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/2018/0.png
--------------------------------------------------------------------------------
/assets/denoising_process/2018/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/2018/1.png
--------------------------------------------------------------------------------
/assets/denoising_process/2018/10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/2018/10.png
--------------------------------------------------------------------------------
/assets/denoising_process/2018/11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/2018/11.png
--------------------------------------------------------------------------------
/assets/denoising_process/2018/12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/2018/12.png
--------------------------------------------------------------------------------
/assets/denoising_process/2018/13.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/2018/13.png
--------------------------------------------------------------------------------
/assets/denoising_process/2018/14.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/2018/14.png
--------------------------------------------------------------------------------
/assets/denoising_process/2018/15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/2018/15.png
--------------------------------------------------------------------------------
/assets/denoising_process/2018/16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/2018/16.png
--------------------------------------------------------------------------------
/assets/denoising_process/2018/17.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/2018/17.png
--------------------------------------------------------------------------------
/assets/denoising_process/2018/18.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/2018/18.png
--------------------------------------------------------------------------------
/assets/denoising_process/2018/19.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/2018/19.png
--------------------------------------------------------------------------------
/assets/denoising_process/2018/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/2018/2.png
--------------------------------------------------------------------------------
/assets/denoising_process/2018/20.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/2018/20.png
--------------------------------------------------------------------------------
/assets/denoising_process/2018/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/2018/3.png
--------------------------------------------------------------------------------
/assets/denoising_process/2018/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/2018/4.png
--------------------------------------------------------------------------------
/assets/denoising_process/2018/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/2018/5.png
--------------------------------------------------------------------------------
/assets/denoising_process/2018/6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/2018/6.png
--------------------------------------------------------------------------------
/assets/denoising_process/2018/7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/2018/7.png
--------------------------------------------------------------------------------
/assets/denoising_process/2018/8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/2018/8.png
--------------------------------------------------------------------------------
/assets/denoising_process/2018/9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/2018/9.png
--------------------------------------------------------------------------------
/assets/denoising_process/2018/test.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/2018/test.gif
--------------------------------------------------------------------------------
/assets/denoising_process/3063/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/3063/0.png
--------------------------------------------------------------------------------
/assets/denoising_process/3063/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/3063/1.png
--------------------------------------------------------------------------------
/assets/denoising_process/3063/10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/3063/10.png
--------------------------------------------------------------------------------
/assets/denoising_process/3063/11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/3063/11.png
--------------------------------------------------------------------------------
/assets/denoising_process/3063/12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/3063/12.png
--------------------------------------------------------------------------------
/assets/denoising_process/3063/13.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/3063/13.png
--------------------------------------------------------------------------------
/assets/denoising_process/3063/14.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/3063/14.png
--------------------------------------------------------------------------------
/assets/denoising_process/3063/15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/3063/15.png
--------------------------------------------------------------------------------
/assets/denoising_process/3063/16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/3063/16.png
--------------------------------------------------------------------------------
/assets/denoising_process/3063/17.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/3063/17.png
--------------------------------------------------------------------------------
/assets/denoising_process/3063/18.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/3063/18.png
--------------------------------------------------------------------------------
/assets/denoising_process/3063/19.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/3063/19.png
--------------------------------------------------------------------------------
/assets/denoising_process/3063/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/3063/2.png
--------------------------------------------------------------------------------
/assets/denoising_process/3063/20.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/3063/20.png
--------------------------------------------------------------------------------
/assets/denoising_process/3063/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/3063/3.png
--------------------------------------------------------------------------------
/assets/denoising_process/3063/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/3063/4.png
--------------------------------------------------------------------------------
/assets/denoising_process/3063/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/3063/5.png
--------------------------------------------------------------------------------
/assets/denoising_process/3063/6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/3063/6.png
--------------------------------------------------------------------------------
/assets/denoising_process/3063/7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/3063/7.png
--------------------------------------------------------------------------------
/assets/denoising_process/3063/8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/3063/8.png
--------------------------------------------------------------------------------
/assets/denoising_process/3063/9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/3063/9.png
--------------------------------------------------------------------------------
/assets/denoising_process/3063/test.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/3063/test.gif
--------------------------------------------------------------------------------
/assets/denoising_process/5096/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/5096/0.png
--------------------------------------------------------------------------------
/assets/denoising_process/5096/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/5096/1.png
--------------------------------------------------------------------------------
/assets/denoising_process/5096/10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/5096/10.png
--------------------------------------------------------------------------------
/assets/denoising_process/5096/11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/5096/11.png
--------------------------------------------------------------------------------
/assets/denoising_process/5096/12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/5096/12.png
--------------------------------------------------------------------------------
/assets/denoising_process/5096/13.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/5096/13.png
--------------------------------------------------------------------------------
/assets/denoising_process/5096/14.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/5096/14.png
--------------------------------------------------------------------------------
/assets/denoising_process/5096/15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/5096/15.png
--------------------------------------------------------------------------------
/assets/denoising_process/5096/16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/5096/16.png
--------------------------------------------------------------------------------
/assets/denoising_process/5096/17.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/5096/17.png
--------------------------------------------------------------------------------
/assets/denoising_process/5096/18.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/5096/18.png
--------------------------------------------------------------------------------
/assets/denoising_process/5096/19.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/5096/19.png
--------------------------------------------------------------------------------
/assets/denoising_process/5096/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/5096/2.png
--------------------------------------------------------------------------------
/assets/denoising_process/5096/20.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/5096/20.png
--------------------------------------------------------------------------------
/assets/denoising_process/5096/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/5096/3.png
--------------------------------------------------------------------------------
/assets/denoising_process/5096/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/5096/4.png
--------------------------------------------------------------------------------
/assets/denoising_process/5096/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/5096/5.png
--------------------------------------------------------------------------------
/assets/denoising_process/5096/6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/5096/6.png
--------------------------------------------------------------------------------
/assets/denoising_process/5096/7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/5096/7.png
--------------------------------------------------------------------------------
/assets/denoising_process/5096/8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/5096/8.png
--------------------------------------------------------------------------------
/assets/denoising_process/5096/9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/5096/9.png
--------------------------------------------------------------------------------
/assets/denoising_process/5096/test.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/denoising_process/5096/test.gif
--------------------------------------------------------------------------------
/assets/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/framework.png
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/teaser.png
--------------------------------------------------------------------------------
/configs/BIPED_sample.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | model_type: const_sde
3 | model_name: cond_unet
4 | image_size: [320, 320]
5 | input_keys: ['image', 'cond']
6 | ckpt_path:
7 | ignore_keys: [ ]
8 | only_model: False
9 | timesteps: 1000
10 | train_sample: -1
11 | sampling_timesteps: 5
12 | loss_type: l2
13 | objective: pred_KC
14 | start_dist: normal
15 | perceptual_weight: 0
16 | scale_factor: 0.3
17 | scale_by_std: True
18 | default_scale: True
19 | scale_by_softsign: False
20 | eps: !!float 1e-4
21 | weighting_loss: True
22 | first_stage:
23 | embed_dim: 3
24 | lossconfig:
25 | disc_start: 50001
26 | kl_weight: 0.000001
27 | disc_weight: 0.5
28 | disc_in_channels: 1
29 | ddconfig:
30 | double_z: True
31 | z_channels: 3
32 | resolution: [ 320, 320 ]
33 | in_channels: 1
34 | out_ch: 1
35 | ch: 128
36 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1
37 | num_res_blocks: 2
38 | attn_resolutions: [ ]
39 | dropout: 0.0
40 | ckpt_path: 'checkpoints/first_stage_total_320.pt'
41 | unet:
42 | dim: 128
43 | cond_net: swin
44 | channels: 3
45 | out_mul: 1
46 | dim_mults: [ 1, 2, 4, 4, ] # num_down = len(dim_mults)
47 | cond_in_dim: 3
48 | cond_dim: 128
49 | cond_dim_mults: [ 2, 4 ] # num_down = len(cond_dim_mults)
50 | # window_sizes1: [ [4, 4], [2, 2], [1, 1], [1, 1] ]
51 | # window_sizes2: [ [4, 4], [2, 2], [1, 1], [1, 1] ]
52 | window_sizes1: [ [ 8, 8 ], [ 4, 4 ], [ 2, 2 ], [ 1, 1 ] ]
53 | window_sizes2: [ [ 8, 8 ], [ 4, 4 ], [ 2, 2 ], [ 1, 1 ] ]
54 | fourier_scale: 16
55 | cond_pe: False
56 | num_pos_feats: 128
57 | cond_feature_size: [ 80, 80 ]
58 |
59 | data:
60 | name: edge
61 | img_folder: '/data/yeyunfan/edge_detection_datasets/datasets/BIPED_test'
62 | augment_horizontal_flip: True
63 | batch_size: 8
64 | num_workers: 4
65 |
66 | sampler:
67 | sample_type: "slide"
68 | stride: [240, 240]
69 | batch_size: 1
70 | sample_num: 300
71 | use_ema: False
72 | save_folder: "./results"
73 | ckpt_path: "/data/huang/diffusion_edge/checkpoints/BIPED_size320_swin_unet12_no_resize_disloss/model-9.pt"
--------------------------------------------------------------------------------
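Note on the `sampler` block above: `sample_type: "slide"` runs sliding-window inference with 320x320 crops (the model's `image_size`) moved at a stride of 240 px, so adjacent windows overlap by 80 px. A small sketch of the tiling arithmetic those numbers imply (illustration only, not the repo's sampling code):

```python
import math

CROP, STRIDE = 320, 240  # image_size and sampler.stride from the config above

def num_windows(length, crop=CROP, stride=STRIDE):
    # Window positions needed to cover `length` pixels; overlap is crop - stride.
    return 1 if length <= crop else math.ceil((length - crop) / stride) + 1

# e.g. BSDS-sized inputs (321x481 / 481x321) and a larger image
for h, w in [(321, 481), (481, 321), (720, 1280)]:
    print(f"{h}x{w} -> {num_windows(h)} x {num_windows(w)} windows")
```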
/configs/BIPED_train.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | model_type: const_sde
3 | model_name: cond_unet
4 | image_size: [320, 320]
5 | input_keys: ['image', 'cond']
6 | ckpt_path:
7 | ignore_keys: [ ]
8 | only_model: False
9 | timesteps: 1000
10 | train_sample: -1
11 | sampling_timesteps: 10
12 | loss_type: l2
13 | objective: pred_KC
14 | start_dist: normal
15 | perceptual_weight: 0
16 | scale_factor: 0.3
17 | scale_by_std: True
18 | default_scale: True
19 | scale_by_softsign: False
20 | eps: !!float 1e-4
21 | weighting_loss: True
22 | use_disloss: True
23 |
24 | first_stage:
25 | embed_dim: 3
26 | lossconfig:
27 | disc_start: 50001
28 | kl_weight: 0.000001
29 | disc_weight: 0.5
30 | disc_in_channels: 1
31 | ddconfig:
32 | double_z: True
33 | z_channels: 3
34 | resolution: [ 320, 320 ]
35 | in_channels: 1
36 | out_ch: 1
37 | ch: 128
38 | ch_mult: [1, 2, 4] # num_down = len(ch_mult)-1
39 | num_res_blocks: 2
40 | attn_resolutions: [ ]
41 | dropout: 0.0
42 | ckpt_path: './checkpoints/first_stage_total_320.pt'
43 |
44 | unet:
45 | dim: 128
46 | cond_net: swin
47 | fix_bb: False
48 | channels: 3
49 | out_mul: 1
50 | dim_mults: [ 1, 2, 4, 4, ] # num_down = len(dim_mults)
51 | cond_in_dim: 3
52 | cond_dim: 128
53 | cond_dim_mults: [ 2, 4 ] # num_down = len(cond_dim_mults)
54 | # window_sizes1: [ [4, 4], [2, 2], [1, 1], [1, 1] ]
55 | # window_sizes2: [ [4, 4], [2, 2], [1, 1], [1, 1] ]
56 | window_sizes1: [ [ 8, 8 ], [ 4, 4 ], [ 2, 2 ], [ 1, 1 ] ]
57 | window_sizes2: [ [ 8, 8 ], [ 4, 4 ], [ 2, 2 ], [ 1, 1 ] ]
58 | fourier_scale: 16
59 | cond_pe: False
60 | num_pos_feats: 128
61 | cond_feature_size: [ 80, 80 ]
62 | input_size: [80, 80]
63 |
64 | data:
65 | name: edge
66 | crop_type: rand_crop
67 | img_folder: '/data/yeyunfan/edge_detection_datasets/datasets/BIPED_100%'
68 | augment_horizontal_flip: True
69 | batch_size: 8
70 | num_workers: 8
71 |
72 | trainer:
73 | gradient_accumulate_every: 1
74 | lr: !!float 5e-5
75 | min_lr: !!float 5e-6
76 | train_num_steps: 100000
77 | save_and_sample_every: 5000
78 | enable_resume: False
79 | log_freq: 1000
80 | results_folder: "./training/BIPED_size320_swin_unet12_no_resize_disloss"
81 | amp: False
82 | fp16: False
83 | resume_milestone: 0
84 | test_before: True
85 | ema_update_after_step: 10000
86 | ema_update_every: 10
87 |
--------------------------------------------------------------------------------
/configs/BSDS_sample.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | model_type: const_sde
3 | model_name: cond_unet
4 | image_size: [320, 320]
5 | input_keys: ['image', 'cond']
6 | ckpt_path:
7 | ignore_keys: [ ]
8 | only_model: False
9 | timesteps: 1000
10 | train_sample: -1
11 | sampling_timesteps: 5
12 | loss_type: l2
13 | objective: pred_noise
14 | start_dist: normal
15 | perceptual_weight: 0
16 | scale_factor: 0.3
17 | scale_by_std: True
18 | default_scale: True
19 | scale_by_softsign: False
20 | eps: !!float 1e-4
21 | weighting_loss: False
22 | first_stage:
23 | embed_dim: 3
24 | lossconfig:
25 | disc_start: 50001
26 | kl_weight: 0.000001
27 | disc_weight: 0.5
28 | disc_in_channels: 1
29 | ddconfig:
30 | double_z: True
31 | z_channels: 3
32 | resolution: [ 320, 320 ]
33 | in_channels: 1
34 | out_ch: 1
35 | ch: 128
36 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1
37 | num_res_blocks: 2
38 | attn_resolutions: [ ]
39 | dropout: 0.0
40 | ckpt_path: 'checkpoints/first_stage_total_320.pt'
41 | unet:
42 | dim: 128
43 | cond_net: swin
44 | channels: 3
45 | out_mul: 1
46 | dim_mults: [ 1, 2, 4, 4, ] # num_down = len(dim_mults)
47 | cond_in_dim: 3
48 | cond_dim: 128
49 | cond_dim_mults: [ 2, 4 ] # num_down = len(cond_dim_mults)
50 | # window_sizes1: [ [4, 4], [2, 2], [1, 1], [1, 1] ]
51 | # window_sizes2: [ [4, 4], [2, 2], [1, 1], [1, 1] ]
52 | window_sizes1: [ [ 8, 8 ], [ 4, 4 ], [ 2, 2 ], [ 1, 1 ] ]
53 | window_sizes2: [ [ 8, 8 ], [ 4, 4 ], [ 2, 2 ], [ 1, 1 ] ]
54 | fourier_scale: 16
55 | cond_pe: False
56 | num_pos_feats: 128
57 | cond_feature_size: [ 80, 80 ]
58 |
59 | data:
60 | name: edge
61 | img_folder: '/data/yeyunfan/edge_detection_datasets/datasets/BSDS_test'
62 | augment_horizontal_flip: True
63 | batch_size: 8
64 | num_workers: 4
65 |
66 | sampler:
67 | sample_type: "slide"
68 | stride: [240, 240]
69 | batch_size: 1
70 | sample_num: 300
71 | use_ema: False
72 | save_folder: './results'
73 | ckpt_path: "/data/huang/diffusion_edge/checkpoints/BSDS_swin_unet12_disloss_bs2x8/model-4.pt"
--------------------------------------------------------------------------------
/configs/BSDS_train.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | model_type: const_sde
3 | model_name: cond_unet
4 | image_size: [320, 320]
5 | input_keys: ['image', 'cond']
6 | ckpt_path:
7 | ignore_keys: [ ]
8 | only_model: False
9 | timesteps: 1000
10 | train_sample: -1
11 | sampling_timesteps: 1
12 | loss_type: l2
13 | objective: pred_KC
14 | start_dist: normal
15 | perceptual_weight: 0
16 | scale_factor: 0.3
17 | scale_by_std: True
18 | default_scale: True
19 | scale_by_softsign: False
20 | eps: !!float 1e-4
21 | weighting_loss: True
22 | use_disloss: True
23 |
24 | first_stage:
25 | embed_dim: 3
26 | lossconfig:
27 | disc_start: 50001
28 | kl_weight: 0.000001
29 | disc_weight: 0.5
30 | disc_in_channels: 1
31 | ddconfig:
32 | double_z: True
33 | z_channels: 3
34 | resolution: [ 320, 320 ]
35 | in_channels: 1
36 | out_ch: 1
37 | ch: 128
38 | ch_mult: [1, 2, 4] # num_down = len(ch_mult)-1
39 | num_res_blocks: 2
40 | attn_resolutions: [ ]
41 | dropout: 0.0
42 | ckpt_path: './checkpoints/first_stage_total_320.pt' # the weight is obtained by training the first stage model
43 |
44 | unet:
45 | dim: 128
46 | cond_net: swin
47 | fix_bb: False
48 | channels: 3
49 | out_mul: 1
50 | dim_mults: [ 1, 2, 4, 4, ] # num_down = len(dim_mults)
51 | cond_in_dim: 3
52 | cond_dim: 128
53 | cond_dim_mults: [ 2, 4 ] # num_down = len(cond_dim_mults)
54 | window_sizes1: [ [ 8, 8 ], [ 4, 4 ], [ 2, 2 ], [ 1, 1 ] ]
55 | window_sizes2: [ [ 8, 8 ], [ 4, 4 ], [ 2, 2 ], [ 1, 1 ] ]
56 | fourier_scale: 16
57 | cond_pe: False
58 | num_pos_feats: 128
59 | cond_feature_size: [ 80, 80 ]
60 | input_size: [80, 80]
61 |
62 | data:
63 | name: edge
64 | crop_type: rand_resize_crop
65 | img_folder: '/data/yeyunfan/edge_detection_datasets/datasets/BSDS_100%'
66 | augment_horizontal_flip: True
67 | batch_size: 2
68 | num_workers: 8
69 |
70 | trainer:
71 | gradient_accumulate_every: 8
72 | lr: !!float 5e-5
73 | min_lr: !!float 5e-6
74 | train_num_steps: 100000
75 | save_and_sample_every: 5000
76 | enable_resume: False
77 | log_freq: 1000
78 | results_folder: "./training/BSDS_swin_unet12_disloss_bs2x8"
79 | amp: False
80 | fp16: False
81 | resume_milestone: 0
82 | test_before: True
83 | ema_update_after_step: 10000
84 | ema_update_every: 10
85 |
--------------------------------------------------------------------------------
/configs/NYUD_sample.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | model_type: const_sde
3 | model_name: cond_unet
4 | image_size: [320, 320]
5 | input_keys: ['image', 'cond']
6 | ckpt_path:
7 | ignore_keys: [ ]
8 | only_model: False
9 | timesteps: 1000
10 | train_sample: -1
11 | sampling_timesteps: 5
12 | loss_type: l2
13 | objective: pred_KC
14 | start_dist: normal
15 | perceptual_weight: 0
16 | scale_factor: 0.3
17 | scale_by_std: True
18 | default_scale: True
19 | scale_by_softsign: False
20 | eps: !!float 1e-4
21 | weighting_loss: True
22 | first_stage:
23 | embed_dim: 3
24 | lossconfig:
25 | disc_start: 50001
26 | kl_weight: 0.000001
27 | disc_weight: 0.5
28 | disc_in_channels: 1
29 | ddconfig:
30 | double_z: True
31 | z_channels: 3
32 | resolution: [ 320, 320 ]
33 | in_channels: 1
34 | out_ch: 1
35 | ch: 128
36 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1
37 | num_res_blocks: 2
38 | attn_resolutions: [ ]
39 | dropout: 0.0
40 | ckpt_path: 'checkpoints/first_stage_total_320.pt'
41 | unet:
42 | dim: 128
43 | cond_net: swin
44 | channels: 3
45 | out_mul: 1
46 | dim_mults: [ 1, 2, 4, 4, ] # num_down = len(dim_mults)
47 | cond_in_dim: 3
48 | cond_dim: 128
49 | cond_dim_mults: [ 2, 4 ] # num_down = len(cond_dim_mults)
50 | # window_sizes1: [ [4, 4], [2, 2], [1, 1], [1, 1] ]
51 | # window_sizes2: [ [4, 4], [2, 2], [1, 1], [1, 1] ]
52 | window_sizes1: [ [ 8, 8 ], [ 4, 4 ], [ 2, 2 ], [ 1, 1 ] ]
53 | window_sizes2: [ [ 8, 8 ], [ 4, 4 ], [ 2, 2 ], [ 1, 1 ] ]
54 | fourier_scale: 16
55 | cond_pe: False
56 | num_pos_feats: 128
57 | cond_feature_size: [ 80, 80 ]
58 |
59 | data:
60 | name: edge
61 | img_folder: '/data/yeyunfan/edge_detection_datasets/datasets/NYUD_100%/test/image'
62 | augment_horizontal_flip: True
63 | batch_size: 8
64 | num_workers: 4
65 |
66 | sampler:
67 | sample_type: "slide"
68 | stride: [240, 240]
69 | batch_size: 1
70 | sample_num: 300
71 | use_ema: False
72 | save_folder: "./results"
73 | ckpt_path: "/data/huang/diffusion_edge/checkpoints/NYUD_swin_unet12_no_resize_disloss/model-17.pt"
--------------------------------------------------------------------------------
/configs/NYUD_train.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | model_type: const_sde
3 | model_name: cond_unet
4 | image_size: [320, 320]
5 | input_keys: ['image', 'cond']
6 | ckpt_path:
7 | ignore_keys: [ ]
8 | only_model: False
9 | timesteps: 1000
10 | train_sample: -1
11 | sampling_timesteps: 10
12 | loss_type: l2
13 | objective: pred_KC
14 | start_dist: normal
15 | perceptual_weight: 0
16 | scale_factor: 0.3
17 | scale_by_std: True
18 | default_scale: True
19 | scale_by_softsign: False
20 | eps: !!float 1e-4
21 | weighting_loss: True
22 | use_disloss: True
23 |
24 | first_stage:
25 | embed_dim: 3
26 | lossconfig:
27 | disc_start: 50001
28 | kl_weight: 0.000001
29 | disc_weight: 0.5
30 | disc_in_channels: 1
31 | ddconfig:
32 | double_z: True
33 | z_channels: 3
34 | resolution: [ 320, 320 ]
35 | in_channels: 1
36 | out_ch: 1
37 | ch: 128
38 | ch_mult: [1, 2, 4] # num_down = len(ch_mult)-1
39 | num_res_blocks: 2
40 | attn_resolutions: [ ]
41 | dropout: 0.0
42 | ckpt_path: './checkpoints/first_stage_total_320.pt'
43 |
44 | unet:
45 | dim: 128
46 | cond_net: swin
47 | fix_bb: False
48 | channels: 3
49 | out_mul: 1
50 | dim_mults: [ 1, 2, 4, 4, ] # num_down = len(dim_mults)
51 | cond_in_dim: 3
52 | cond_dim: 128
53 | cond_dim_mults: [ 2, 4 ] # num_down = len(cond_dim_mults)
54 | # window_sizes1: [ [4, 4], [2, 2], [1, 1], [1, 1] ]
55 | # window_sizes2: [ [4, 4], [2, 2], [1, 1], [1, 1] ]
56 | window_sizes1: [ [ 8, 8 ], [ 4, 4 ], [ 2, 2 ], [ 1, 1 ] ]
57 | window_sizes2: [ [ 8, 8 ], [ 4, 4 ], [ 2, 2 ], [ 1, 1 ] ]
58 | fourier_scale: 16
59 | cond_pe: False
60 | num_pos_feats: 128
61 | cond_feature_size: [ 80, 80 ]
62 | input_size: [80, 80]
63 |
64 | data:
65 | name: edge
66 | crop_type: rand_crop
67 | img_folder: '/data/yeyunfan/edge_detection_datasets/datasets/NYUD_100%/train'
68 | augment_horizontal_flip: True
69 | batch_size: 8
70 | num_workers: 8
71 |
72 | trainer:
73 | gradient_accumulate_every: 2
74 | lr: !!float 5e-5
75 | min_lr: !!float 5e-6
76 | train_num_steps: 100000
77 | save_and_sample_every: 5000
78 | enable_resume: False
79 | log_freq: 1000
80 | results_folder: "./training/NYUD_swin_unet12_no_resize_disloss"
81 | amp: False
82 | fp16: False
83 | resume_milestone: 0
84 | test_before: True
85 | ema_update_after_step: 10000
86 | ema_update_every: 10
87 |
--------------------------------------------------------------------------------
/configs/default.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | model_type: const_sde
3 | model_name: cond_unet
4 | image_size: [320, 320]
5 | input_keys: ['image', 'cond']
6 | ckpt_path:
7 | ignore_keys: [ ]
8 | only_model: False
9 | timesteps: 1000
10 | train_sample: -1
11 | sampling_timesteps: 1
12 | loss_type: l2
13 | objective: pred_noise
14 | start_dist: normal
15 | perceptual_weight: 0
16 | scale_factor: 0.3
17 | scale_by_std: True
18 | default_scale: True
19 | scale_by_softsign: False
20 | eps: !!float 1e-4
21 | weighting_loss: False
22 | first_stage:
23 | embed_dim: 3
24 | lossconfig:
25 | disc_start: 50001
26 | kl_weight: 0.000001
27 | disc_weight: 0.5
28 | disc_in_channels: 1
29 | ddconfig:
30 | double_z: True
31 | z_channels: 3
32 | resolution: [ 320, 320 ]
33 | in_channels: 1
34 | out_ch: 1
35 | ch: 128
36 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1
37 | num_res_blocks: 2
38 | attn_resolutions: [ ]
39 | dropout: 0.0
40 | ckpt_path:
41 | unet:
42 | dim: 128
43 | cond_net: swin
44 | without_pretrain: False
45 | channels: 3
46 | out_mul: 1
47 | dim_mults: [ 1, 2, 4, 4, ] # num_down = len(dim_mults)
48 | cond_in_dim: 3
49 | cond_dim: 128
50 | cond_dim_mults: [ 2, 4 ] # num_down = len(cond_dim_mults)
51 | # window_sizes1: [ [4, 4], [2, 2], [1, 1], [1, 1] ]
52 | # window_sizes2: [ [4, 4], [2, 2], [1, 1], [1, 1] ]
53 | window_sizes1: [ [ 8, 8 ], [ 4, 4 ], [ 2, 2 ], [ 1, 1 ] ]
54 | window_sizes2: [ [ 8, 8 ], [ 4, 4 ], [ 2, 2 ], [ 1, 1 ] ]
55 | fourier_scale: 16
56 | cond_pe: False
57 | num_pos_feats: 128
58 | cond_feature_size: [ 80, 80 ]
59 |
60 | data:
61 | name: edge
62 | img_folder: '/data/yeyunfan/edge_detection_datasets/datasets/BSDS_test'
63 | augment_horizontal_flip: True
64 | batch_size: 8
65 | num_workers: 4
66 |
67 | sampler:
68 | sample_type: "slide"
69 | stride: [240, 240]
70 | batch_size: 1
71 | sample_num: 300
72 | use_ema: True
73 | save_folder:
74 | ckpt_path:
--------------------------------------------------------------------------------
/configs/first_stage_d4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | embed_dim: 3
3 | lossconfig:
4 | disc_start: 50001
5 | kl_weight: 0.000001
6 | disc_weight: 0.5
7 | disc_in_channels: 1
8 | ddconfig:
9 | double_z: True
10 | z_channels: 3
11 | resolution: [320, 320]
12 | in_channels: 1
13 | out_ch: 1
14 | ch: 128
15 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1
16 | num_res_blocks: 2
17 | attn_resolutions: [ ]
18 | dropout: 0.0
19 | ckpt_path: #'/home/zhuchenyang/project/hyx/data/pretrain_weights/model-kl-d4.ckpt'
20 |
21 | data:
22 | name: edge
23 | img_folder: '/home/zhuchenyang/project/hyx/data/total_edges'
24 | augment_horizontal_flip: True
25 | batch_size: 8
26 |
27 | trainer:
28 | gradient_accumulate_every: 2
29 | lr: !!float 5e-6
30 | min_lr: !!float 5e-7
31 | train_num_steps: 150000
32 | save_and_sample_every: 10000
33 | log_freq: 100
34 | results_folder: '/home/zhuchenyang/project/hyx/data/total_edges/results_ae_kl_320x320_d4'
35 | amp: False
36 | fp16: False
--------------------------------------------------------------------------------
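Per the `num_down = len(ch_mult) - 1` comment above, this first-stage autoencoder downsamples by 2**2 = 4 (the "d4" in the file name), so 320x320 edge maps are encoded to 80x80 latents with `z_channels: 3`, matching the `[80, 80]` feature sizes in the diffusion configs. A small sketch that derives this from the YAML (assuming PyYAML; `ddconfig` is read from under `model`, with a top-level fallback):

```python
import yaml  # PyYAML

with open("./configs/first_stage_d4.yaml") as f:
    cfg = yaml.safe_load(f)

dd = cfg["model"].get("ddconfig") or cfg.get("ddconfig")
num_down = len(dd["ch_mult"]) - 1                      # per the config comment
h, w = dd["resolution"]
print(f"downsample factor: {2 ** num_down}x")
print(f"latent shape: {h // 2 ** num_down} x {w // 2 ** num_down} x {dd['z_channels']}")
```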
/denoising_diffusion_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | # from denoising_diffusion_pytorch.denoising_diffusion_pytorch import GaussianDiffusion, Unet, Trainer
2 |
--------------------------------------------------------------------------------
/denoising_diffusion_pytorch/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import sys
4 | # sys.path.append(...)  # make sure the taming package is importable
5 | from taming.modules.losses.vqperceptual import *
6 |
7 |
8 | class LPIPSWithDiscriminator(nn.Module):
9 | def __init__(self, *, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
10 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
11 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
12 | disc_loss="hinge"):
13 |
14 | super().__init__()
15 | assert disc_loss in ["hinge", "vanilla"]
16 | self.kl_weight = kl_weight
17 | self.pixel_weight = pixelloss_weight
18 | self.perceptual_loss = LPIPS().eval()
19 | self.perceptual_weight = perceptual_weight
20 | # output log variance
21 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
22 |
23 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
24 | n_layers=disc_num_layers,
25 | use_actnorm=use_actnorm
26 | ).apply(weights_init)
27 | self.discriminator_iter_start = disc_start
28 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
29 | self.disc_factor = disc_factor
30 | self.discriminator_weight = disc_weight
31 | self.disc_conditional = disc_conditional
32 |
33 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
34 | if last_layer is not None:
35 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
36 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
37 | else:
38 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
39 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
40 |
41 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
42 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
43 | d_weight = d_weight * self.discriminator_weight
44 | return d_weight
45 |
46 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
47 | global_step, last_layer=None, cond=None, split="train",
48 | weights=None):
49 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + \
50 | F.mse_loss(inputs, reconstructions, reduction="none")
51 | if self.perceptual_weight > 0:
52 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
53 | rec_loss = rec_loss + self.perceptual_weight * p_loss
54 |
55 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
56 | weighted_nll_loss = nll_loss
57 | if weights is not None:
58 | weighted_nll_loss = weights*nll_loss
59 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
60 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
61 | kl_loss = posteriors.kl()
62 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
63 |
64 | # now the GAN part
65 | if optimizer_idx == 0:
66 | # generator update
67 | if cond is None:
68 | assert not self.disc_conditional
69 | logits_fake = self.discriminator(reconstructions.contiguous())
70 | else:
71 | assert self.disc_conditional
72 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
73 | g_loss = -torch.mean(logits_fake)
74 |
75 | if self.disc_factor > 0.0:
76 | try:
77 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
78 | except RuntimeError:
79 | assert not self.training
80 | d_weight = torch.tensor(0.0)
81 | else:
82 | d_weight = torch.tensor(0.0)
83 |
84 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
85 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
86 |
87 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
88 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
89 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
90 | "{}/d_weight".format(split): d_weight.detach(),
91 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
92 | "{}/g_loss".format(split): g_loss.detach().mean(),
93 | }
94 | return loss, log
95 |
96 | if optimizer_idx == 1:
97 | # second pass for discriminator update
98 | if cond is None:
99 | logits_real = self.discriminator(inputs.contiguous().detach())
100 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
101 | else:
102 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
103 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
104 |
105 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
106 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
107 |
108 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
109 | "{}/logits_real".format(split): logits_real.detach().mean(),
110 | "{}/logits_fake".format(split): logits_fake.detach().mean()
111 | }
112 | return d_loss, log
113 |
114 |
--------------------------------------------------------------------------------
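The `calculate_adaptive_weight` method above balances the generator loss against the reconstruction/NLL loss by the ratio of their gradient norms at the decoder's last layer, then clamps the result. A standalone toy illustration of that formula in plain PyTorch (stand-in losses and a stand-in "last layer", not the repo's classes):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
decoder_last = nn.Conv2d(8, 1, 3, padding=1)   # stand-in for the decoder's last layer
feats = torch.randn(2, 8, 16, 16)
recon = decoder_last(feats)
target = torch.randn_like(recon)

nll_loss = (recon - target).abs().mean()       # stand-in reconstruction / NLL term
g_loss = -recon.mean()                         # stand-in non-saturating generator loss

# Same arithmetic as calculate_adaptive_weight: ratio of gradient norms, clamped.
nll_grads = torch.autograd.grad(nll_loss, decoder_last.weight, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, decoder_last.weight, retain_graph=True)[0]
d_weight = torch.clamp(torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4), 0.0, 1e4).detach()
print("adaptive discriminator weight:", float(d_weight))
```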
/denoising_diffusion_pytorch/quantization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn as nn
3 | from torch.nn import Parameter
4 |
5 |
6 | def weight_quantization(b):
7 | def uniform_quant(x, b):
8 | xdiv = x.mul((2 ** b - 1))
9 | xhard = xdiv.round().div(2 ** b - 1)
10 | return xhard
11 |
12 | class _pq(torch.autograd.Function):
13 | @staticmethod
14 | def forward(ctx, input, alpha):
15 | input.div_(alpha) # weights are first divided by alpha
16 | input_c = input.clamp(min=-1, max=1) # then clipped to [-1,1]
17 | sign = input_c.sign()
18 | input_abs = input_c.abs()
19 | input_q = uniform_quant(input_abs, b).mul(sign)
20 | ctx.save_for_backward(input, input_q)
21 | input_q = input_q.mul(alpha) # rescale to the original range
22 | return input_q
23 |
24 | @staticmethod
25 | def backward(ctx, grad_output):
26 | grad_input = grad_output.clone() # grad for weights will not be clipped
27 | input, input_q = ctx.saved_tensors
28 | i = (input.abs() > 1.).float()
29 | sign = input.sign()
30 | grad_alpha = (grad_output * (sign * i + (input_q - input) * (1 - i))).sum()
31 | return grad_input, grad_alpha
32 |
33 |     return _pq.apply
34 |
35 |
36 | class weight_quantize_fn(nn.Module):
37 | def __init__(self, bit_w):
38 | super(weight_quantize_fn, self).__init__()
39 | assert bit_w > 0
40 |
41 | self.bit_w = bit_w - 1
42 | self.weight_q = weight_quantization(b=self.bit_w)
43 | self.register_parameter('w_alpha', Parameter(torch.tensor(3.0), requires_grad=True))
44 |
45 | def forward(self, weight):
46 | mean = weight.data.mean()
47 | std = weight.data.std()
48 | weight = weight.add(-mean).div(std) # weights normalization
49 | weight_q = self.weight_q(weight, self.w_alpha)
50 | return weight_q
51 |
52 | def change_bit(self, bit_w):
53 | self.bit_w = bit_w - 1
54 | self.weight_q = weight_quantization(b=self.bit_w)
55 |
56 | def act_quantization(b, signed=False):
57 | def uniform_quant(x, b=3):
58 | xdiv = x.mul(2 ** b - 1)
59 | xhard = xdiv.round().div(2 ** b - 1)
60 | return xhard
61 |
62 | class _uq(torch.autograd.Function):
63 | @staticmethod
64 | def forward(ctx, input, alpha):
65 | input = input.div(alpha)
66 | input_c = input.clamp(min=-1, max=1) if signed else input.clamp(max=1)
67 | input_q = uniform_quant(input_c, b)
68 | ctx.save_for_backward(input, input_q)
69 | input_q = input_q.mul(alpha)
70 | return input_q
71 |
72 | @staticmethod
73 | def backward(ctx, grad_output):
74 | grad_input = grad_output.clone()
75 | input, input_q = ctx.saved_tensors
76 | i = (input.abs() > 1.).float()
77 | sign = input.sign()
78 | grad_alpha = (grad_output * (sign * i + (input_q - input) * (1 - i))).sum()
79 | grad_input = grad_input * (1 - i)
80 | return grad_input, grad_alpha
81 |
82 |     return _uq.apply
83 |
84 | class act_quantize_fn(nn.Module):
85 | def __init__(self, bit_a, signed=False):
86 | super(act_quantize_fn, self).__init__()
87 | self.bit_a = bit_a
88 | self.signed = signed
89 | if signed:
90 | self.bit_a -= 1
91 | assert bit_a > 0
92 |
93 | self.act_q = act_quantization(b=self.bit_a, signed=signed)
94 | self.register_parameter('a_alpha', Parameter(torch.tensor(8.0), requires_grad=True))
95 |
96 | def forward(self, x):
97 | return self.act_q(x, self.a_alpha)
98 |
99 | def change_bit(self, bit_a):
100 | self.bit_a = bit_a
101 | if self.signed:
102 | self.bit_a -= 1
103 | self.act_q = act_quantization(b=self.bit_a, signed=self.signed)
--------------------------------------------------------------------------------
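A minimal usage sketch (not part of the repository) of the two quantizers above; the bit widths and tensor shapes are assumptions chosen for illustration. weight_quantize_fn normalizes the weight, clips it to [-1, 1], quantizes it uniformly and rescales by the learnable w_alpha; act_quantize_fn does the same for activations with a_alpha.

import torch
from denoising_diffusion_pytorch.quantization import weight_quantize_fn, act_quantize_fn

# Illustrative sketch, not repository code: 8-bit fake quantization of weights and activations.
w_quant = weight_quantize_fn(bit_w=8)             # learnable scale w_alpha (init 3.0)
a_quant = act_quantize_fn(bit_a=8, signed=True)   # learnable scale a_alpha (init 8.0)

weight = torch.randn(64, 32, 1, 1)                # e.g. a 1x1 conv weight (assumed shape)
x = torch.randn(4, 32, 16, 16)                    # e.g. an activation map (assumed shape)

weight_q = w_quant(weight)                        # fake-quantized weights, same shape and dtype
x_q = a_quant(x)                                  # fake-quantized activations
print(weight_q.shape, x_q.shape)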
/denoising_diffusion_pytorch/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | import time
4 | import logging
5 | import math
6 |
7 | def create_logger(root_dir, des=''):
8 | root_output_dir = Path(root_dir)
9 | # set up logger
10 | if not root_output_dir.exists():
11 | print('=> creating {}'.format(root_output_dir))
12 | root_output_dir.mkdir(exist_ok=True, parents=True)
13 | time_str = time.strftime('%Y-%m-%d-%H-%M')
14 | log_file = '{}_{}.log'.format(time_str, des)
15 | final_log_file = root_output_dir / log_file
16 | head = '%(asctime)-15s %(message)s'
17 | logging.basicConfig(filename=str(final_log_file), format=head)
18 | logger = logging.getLogger()
19 | logger.setLevel(logging.INFO)
20 | console = logging.StreamHandler()
21 | logging.getLogger('').addHandler(console)
22 | return logger
23 |
24 | def exists(x):
25 | return x is not None
26 |
27 | def default(val, d):
28 | if exists(val):
29 | return val
30 | return d() if callable(d) else d
31 |
32 | def identity(t, *args, **kwargs):
33 | return t
34 |
35 | def cycle(dl):
36 | while True:
37 | for data in dl:
38 | yield data
39 |
40 | def has_int_squareroot(num):
41 | return (math.sqrt(num) ** 2) == num
42 |
43 | def num_to_groups(num, divisor):
44 | groups = num // divisor
45 | remainder = num % divisor
46 | arr = [divisor] * groups
47 | if remainder > 0:
48 | arr.append(remainder)
49 | return arr
50 |
51 | def convert_image_to_fn(img_type, image):
52 | if image.mode != img_type:
53 | return image.convert(img_type)
54 | return image
55 |
56 | # normalization functions
57 |
58 | def normalize_to_neg_one_to_one(img):
59 | return img * 2 - 1
60 |
61 | def unnormalize_to_zero_to_one(t):
62 | return (t + 1) * 0.5
63 |
64 | def dict2str(d):
65 | s = ''
66 |     for k, v in d.items():
67 | s += "{}: {:.5f}, ".format(k, v)
68 | return s
--------------------------------------------------------------------------------
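A small usage sketch (not part of the repository) of the helpers above, showing how num_to_groups splits a sample count into batch sizes and how the two normalization helpers round-trip between [0, 1] and [-1, 1].

import torch
from denoising_diffusion_pytorch.utils import num_to_groups, normalize_to_neg_one_to_one, unnormalize_to_zero_to_one

# Illustrative sketch, not repository code.
print(num_to_groups(25, 8))                    # [8, 8, 8, 1]
img = torch.rand(1, 3, 64, 64)                 # values in [0, 1]
img_pm1 = normalize_to_neg_one_to_one(img)     # values in [-1, 1]
img_01 = unnormalize_to_zero_to_one(img_pm1)   # back to [0, 1]
print(torch.allclose(img, img_01))             # True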
/denoising_diffusion_pytorch/wavelet.py:
--------------------------------------------------------------------------------
1 | import pywt
2 | import pywt.data
3 | import torch
4 | from torch import nn
5 | from torch.autograd import Function
6 | import torch.nn.functional as F
7 |
8 |
9 | def create_wavelet_filter(wave, in_size, out_size, type=torch.float):
10 | w = pywt.Wavelet(wave)
11 | dec_hi = torch.tensor(w.dec_hi[::-1], dtype=type)
12 | dec_lo = torch.tensor(w.dec_lo[::-1], dtype=type)
13 | dec_filters = torch.stack([dec_lo.unsqueeze(0) * dec_lo.unsqueeze(1),
14 | dec_lo.unsqueeze(0) * dec_hi.unsqueeze(1),
15 | dec_hi.unsqueeze(0) * dec_lo.unsqueeze(1),
16 | dec_hi.unsqueeze(0) * dec_hi.unsqueeze(1)], dim=0)
17 |
18 | dec_filters = dec_filters[:, None].repeat(in_size, 1, 1, 1)
19 |
20 | rec_hi = torch.tensor(w.rec_hi[::-1], dtype=type).flip(dims=[0])
21 | rec_lo = torch.tensor(w.rec_lo[::-1], dtype=type).flip(dims=[0])
22 | rec_filters = torch.stack([rec_lo.unsqueeze(0) * rec_lo.unsqueeze(1),
23 | rec_lo.unsqueeze(0) * rec_hi.unsqueeze(1),
24 | rec_hi.unsqueeze(0) * rec_lo.unsqueeze(1),
25 | rec_hi.unsqueeze(0) * rec_hi.unsqueeze(1)], dim=0)
26 |
27 | rec_filters = rec_filters[:, None].repeat(out_size, 1, 1, 1)
28 |
29 | return dec_filters, rec_filters
30 |
31 |
32 | def wt(x, filters, in_size, level):
33 | _, _, h, w = x.shape
34 | pad = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1)
35 | res = F.conv2d(x, filters, stride=2, groups=in_size, padding=pad)
36 | if level > 1:
37 | res[:, ::4] = wt(res[:, ::4], filters, in_size, level - 1)
38 |     res = res.reshape(-1, 2, h // 2, w // 2).transpose(1, 2).reshape(-1, in_size, h, w)  # rearrange the 4 subbands of each channel into a single h x w map
39 | return res
40 |
41 |
42 | def iwt(x, inv_filters, in_size, level):
43 | _, _, h, w = x.shape
44 | pad = (inv_filters.shape[2] // 2 - 1, inv_filters.shape[3] // 2 - 1)
45 |     res = x.reshape(-1, h // 2, 2, w // 2).transpose(1, 2).reshape(-1, 4 * in_size, h // 2, w // 2)  # undo the subband rearrangement done in wt
46 | if level > 1:
47 | res[:, ::4] = iwt(res[:, ::4], inv_filters, in_size, level - 1)
48 | res = F.conv_transpose2d(res, inv_filters, stride=2, groups=in_size, padding=pad)
49 | return res
50 |
51 |
52 | def get_inverse_transform(weights, in_size, level):
53 | class InverseWaveletTransform(Function):
54 |
55 | @staticmethod
56 | def forward(ctx, input):
57 | with torch.no_grad():
58 | x = iwt(input, weights, in_size, level)
59 | return x
60 |
61 | @staticmethod
62 | def backward(ctx, grad_output):
63 | grad = wt(grad_output, weights, in_size, level)
64 | return grad, None
65 |
66 |     return InverseWaveletTransform.apply
67 |
68 |
69 | def get_transform(weights, in_size, level):
70 | class WaveletTransform(Function):
71 |
72 | @staticmethod
73 | def forward(ctx, input):
74 | with torch.no_grad():
75 | x = wt(input, weights, in_size, level)
76 | return x
77 |
78 | @staticmethod
79 | def backward(ctx, grad_output):
80 | grad = iwt(grad_output, weights, in_size, level)
81 | return grad, None
82 |
83 |     return WaveletTransform.apply
--------------------------------------------------------------------------------
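A usage sketch (not part of the repository) of the wavelet helpers above: build Haar (db1) decomposition/reconstruction filters, apply a two-level forward transform and invert it. The channel count, spatial size and level are assumptions; the spatial size should be divisible by 2**level, which is what WCC's padding enforces.

import torch
from denoising_diffusion_pytorch import wavelet

# Illustrative sketch, not repository code.
C, level = 8, 2
dec_filters, rec_filters = wavelet.create_wavelet_filter("db1", in_size=C, out_size=C)
wt = wavelet.get_transform(dec_filters, C, level)
iwt = wavelet.get_inverse_transform(rec_filters, C, level)

x = torch.randn(1, C, 32, 32)
coeffs = wt(x)           # same shape as x, with subbands rearranged spatially
x_rec = iwt(coeffs)      # inverse transform
print(coeffs.shape, (x - x_rec).abs().max())   # for an orthogonal wavelet like db1 this should be close to zero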
/denoising_diffusion_pytorch/wcc.py:
--------------------------------------------------------------------------------
1 | from typing import Union, Tuple
2 |
3 | import torch
4 | from torch import nn as nn
5 | from torch.nn import functional as F
6 |
7 | from denoising_diffusion_pytorch.quantization import weight_quantize_fn, act_quantize_fn
8 | from denoising_diffusion_pytorch import wavelet
9 |
10 |
11 | class WCC(nn.Conv1d):
12 | def __init__(self, in_channels: int,
13 | out_channels: int,
14 | stride: Union[int, Tuple] = 1,
15 | padding: Union[int, Tuple] = 0,
16 | dilation: Union[int, Tuple] = 1,
17 | groups: int = 1,
18 | bias: bool = False,
19 | levels: int = 3,
20 | compress_rate: float = 0.25,
21 | bit_w: int = 8,
22 | bit_a: int = 8,
23 | wt_type: str = "db1"):
24 | super(WCC, self).__init__(in_channels, out_channels, 1, stride, padding, dilation, groups, bias)
25 | self.layer_type = 'WCC'
26 | self.bit_w = bit_w
27 | self.bit_a = bit_a
28 |
29 | self.weight_quant = weight_quantize_fn(self.bit_w)
30 | self.act_quant = act_quantize_fn(self.bit_a, signed=True)
31 |
32 | self.levels = levels
33 | self.wt_type = wt_type
34 | self.compress_rate = compress_rate
35 |
36 | dec_filters, rec_filters = wavelet.create_wavelet_filter(wave=self.wt_type,
37 | in_size=in_channels,
38 | out_size=out_channels)
39 | self.wt_filters = nn.Parameter(dec_filters, requires_grad=False)
40 | self.iwt_filters = nn.Parameter(rec_filters, requires_grad=False)
41 | self.wt = wavelet.get_transform(self.wt_filters, in_channels, levels)
42 | self.iwt = wavelet.get_inverse_transform(self.iwt_filters, out_channels, levels)
43 |
44 | self.get_pad = lambda n: ((2 ** levels) - n) % (2 ** levels)
45 |
46 | def forward(self, x):
47 | in_shape = x.shape
48 |         pads = (0, self.get_pad(in_shape[3]), 0, self.get_pad(in_shape[2]))  # F.pad order: (W_left, W_right, H_top, H_bottom)
49 | x = F.pad(x, pads) # pad to match 2^(levels)
50 |
51 | weight_q = self.weight_quant(self.weight) # quantize weights
52 | x = self.wt(x) # H
53 | topk, ids = self.compress(x) # T
54 | topk_q = self.act_quant(topk) # quantize activations
55 | topk_q = F.conv1d(topk_q, weight_q, self.bias, self.stride, self.padding, self.dilation, self.groups) # K_1x1
56 | x = self.decompress(topk_q, ids, x.shape) # T^T
57 | x = self.iwt(x) # H^T
58 |
59 | x = x[:, :, :in_shape[2], :in_shape[3]] # remove pads
60 | return x
61 |
62 | def compress(self, x):
63 | b, c, h, w = x.shape
64 | acc = x.norm(dim=1).pow(2)
65 | acc = acc.view(b, h * w)
66 | k = int(h * w * self.compress_rate)
67 | ids = acc.topk(k, dim=1, sorted=False)[1]
68 | ids.unsqueeze_(dim=1)
69 | topk = x.reshape((b, c, h * w)).gather(dim=2, index=ids.repeat(1, c, 1))
70 | return topk, ids
71 |
72 | def decompress(self, topk, ids, shape):
73 | b, _, h, w = shape
74 | ids = ids.repeat(1, self.out_channels, 1)
75 | x = torch.zeros(size=(b, self.out_channels, h * w), requires_grad=True, device=topk.device)
76 | x = x.scatter(dim=2, index=ids, src=topk)
77 | x = x.reshape((b, self.out_channels, h, w))
78 | return x
79 |
80 | def change_wt_params(self, compress_rate, levels, wt_type="db1"):
81 | self.compress_rate = compress_rate
82 |         self.levels = levels
83 |         self.wt_type = wt_type  # store the new wavelet type (previously the argument was ignored)
84 |         dec_filters, rec_filters = wavelet.create_wavelet_filter(
85 |             wave=self.wt_type, in_size=self.in_channels, out_size=self.out_channels)
86 | self.wt_filters = nn.Parameter(dec_filters, requires_grad=False)
87 | self.iwt_filters = nn.Parameter(rec_filters, requires_grad=False)
88 | self.wt = wavelet.get_transform(self.wt_filters, self.in_channels, levels)
89 | self.iwt = wavelet.get_inverse_transform(self.iwt_filters, self.out_channels, levels)
90 |
91 | def change_bit(self, bit_w, bit_a):
92 | self.bit_w = bit_w
93 | self.bit_a = bit_a
94 | self.weight_quant.change_bit(bit_w)
95 | self.act_quant.change_bit(bit_a)
96 |
97 | if __name__ == '__main__':
98 | wcc = WCC(80, 80)
99 | x = torch.rand(1, 80, 80, 80)
100 | y = wcc(x)
101 | pause = 0
--------------------------------------------------------------------------------
/metrics/datasets.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from contextlib import redirect_stdout
3 |
4 | import torch
5 | from PIL import Image
6 | from torch.utils.data import Dataset
7 | from torchvision.datasets import CIFAR10, STL10
8 |
9 | from metrics.helpers import vassert
10 |
11 |
12 | class TransformPILtoRGBTensor:
13 | def __call__(self, img):
14 | vassert(type(img) is Image.Image, 'Input is not a PIL.Image')
15 | width, height = img.size
16 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(img.tobytes())).view(height, width, 3)
17 | img = img.permute(2, 0, 1)
18 | return img
19 |
20 |
21 | class ImagesPathDataset(Dataset):
22 | def __init__(self, files, transforms=None):
23 | self.files = files
24 | self.transforms = TransformPILtoRGBTensor() if transforms is None else transforms
25 |
26 | def __len__(self):
27 | return len(self.files)
28 |
29 | def __getitem__(self, i):
30 | path = self.files[i]
31 | img = Image.open(path).convert('RGB')
32 | img = self.transforms(img)
33 | return img
34 |
35 |
36 | class Cifar10_RGB(CIFAR10):
37 | def __init__(self, *args, **kwargs):
38 | with redirect_stdout(sys.stderr):
39 | super().__init__(*args, **kwargs)
40 |
41 | def __getitem__(self, index):
42 | img, target = super().__getitem__(index)
43 | return img
44 |
45 |
46 | class STL10_RGB(STL10):
47 | def __init__(self, *args, **kwargs):
48 | with redirect_stdout(sys.stderr):
49 | super().__init__(*args, **kwargs)
50 |
51 | def __getitem__(self, index):
52 | img, target = super().__getitem__(index)
53 | return img
54 |
55 |
56 | class RandomlyGeneratedDataset(Dataset):
57 | def __init__(self, num_samples, *dimensions, dtype=torch.uint8, seed=2021):
58 | vassert(dtype == torch.uint8, 'Unsupported dtype')
59 | rng_stash = torch.get_rng_state()
60 | try:
61 | torch.manual_seed(seed)
62 | self.imgs = torch.randint(0, 255, (num_samples, *dimensions), dtype=dtype)
63 | finally:
64 | torch.set_rng_state(rng_stash)
65 |
66 | def __len__(self):
67 | return self.imgs.shape[0]
68 |
69 | def __getitem__(self, i):
70 | return self.imgs[i]
71 |
--------------------------------------------------------------------------------
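A usage sketch (not part of the repository): wrapping a folder of generated samples with ImagesPathDataset so downstream feature extractors receive uint8 CHW tensors. The "samples" directory and glob pattern are hypothetical, and the default collate assumes all images share the same resolution.

from pathlib import Path
from torch.utils.data import DataLoader
from metrics.datasets import ImagesPathDataset

# Illustrative sketch, not repository code; "samples" is a hypothetical folder of equally sized PNGs.
files = sorted(str(p) for p in Path("samples").glob("*.png"))
ds = ImagesPathDataset(files)                    # yields uint8 tensors of shape (3, H, W)
loader = DataLoader(ds, batch_size=16)
for batch in loader:
    print(batch.shape, batch.dtype)              # e.g. torch.Size([16, 3, H, W]) torch.uint8
    break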
/metrics/defaults.py:
--------------------------------------------------------------------------------
1 | DEFAULTS = {
2 | 'input1': None,
3 | 'input2': None,
4 | 'cuda': True,
5 | 'batch_size': 64,
6 | 'isc': False,
7 | 'fid': False,
8 | 'kid': False,
9 | 'ppl': False,
10 | 'feature_extractor': 'inception-v3-compat',
11 | 'feature_layer_isc': 'logits_unbiased',
12 | 'feature_layer_fid': '2048',
13 | 'feature_layer_kid': '2048',
14 | 'feature_extractor_weights_path': None,
15 | 'isc_splits': 10,
16 | 'kid_subsets': 100,
17 | 'kid_subset_size': 1000,
18 | 'kid_degree': 3,
19 | 'kid_gamma': None,
20 | 'kid_coef0': 1,
21 | 'ppl_epsilon': 1e-4,
22 | 'ppl_reduction': 'mean',
23 | 'ppl_sample_similarity': 'lpips-vgg16',
24 | 'ppl_sample_similarity_resize': 64,
25 | 'ppl_sample_similarity_dtype': 'uint8',
26 | 'ppl_discard_percentile_lower': 1,
27 | 'ppl_discard_percentile_higher': 99,
28 | 'ppl_z_interp_mode': 'lerp',
29 | 'samples_shuffle': True,
30 | 'samples_find_deep': False,
31 | 'samples_find_ext': 'png,jpg,jpeg',
32 | 'samples_ext_lossy': 'jpg,jpeg',
33 | 'datasets_root': None,
34 | 'datasets_download': True,
35 | 'cache_root': None,
36 | 'cache': True,
37 | 'input1_cache_name': None,
38 | 'input1_model_z_type': 'normal',
39 | 'input1_model_z_size': None,
40 | 'input1_model_num_classes': 0,
41 | 'input1_model_num_samples': None,
42 | 'input2_cache_name': None,
43 | 'input2_model_z_type': 'normal',
44 | 'input2_model_z_size': None,
45 | 'input2_model_num_classes': 0,
46 | 'input2_model_num_samples': None,
47 | 'rng_seed': 2020,
48 | 'save_cpu_ram': False,
49 | 'verbose': True,
50 | }
51 |
--------------------------------------------------------------------------------
/metrics/feature_extractor_base.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from metrics.helpers import vassert
4 |
5 |
6 | class FeatureExtractorBase(nn.Module):
7 | def __init__(self, name, features_list):
8 | """
9 | Base class for feature extractors that can be used in :func:`calculate_metrics`.
10 |
11 | Args:
12 |
13 | name (str): Unique name of the subclassed feature extractor, must be the same as used in
14 | :func:`register_feature_extractor`.
15 |
16 | features_list (list): List of feature names, provided by the subclassed feature extractor.
17 | """
18 | super(FeatureExtractorBase, self).__init__()
19 | vassert(type(name) is str, 'Feature extractor name must be a string')
20 | vassert(type(features_list) in (list, tuple), 'Wrong features list type')
21 | vassert(
22 | all((a in self.get_provided_features_list() for a in features_list)),
23 | f'Requested features {tuple(features_list)} are not on the list provided by the selected feature extractor '
24 | f'{self.get_provided_features_list()}'
25 | )
26 | vassert(len(features_list) == len(set(features_list)), 'Duplicate features requested')
27 | self.name = name
28 | self.features_list = features_list
29 |
30 | def get_name(self):
31 | return self.name
32 |
33 | @staticmethod
34 | def get_provided_features_list():
35 | """
36 | Returns a tuple of feature names, extracted by the subclassed feature extractor.
37 | """
38 | raise NotImplementedError
39 |
40 | def get_requested_features_list(self):
41 | return self.features_list
42 |
43 | def convert_features_tuple_to_dict(self, features):
44 | # The only compound return type of the forward function amenable to JIT tracing is tuple.
45 | # This function simply helps to recover the mapping.
46 | vassert(
47 | type(features) is tuple and len(features) == len(self.features_list),
48 | 'Features must be the output of forward function'
49 | )
50 | return dict(((name, feature) for name, feature in zip(self.features_list, features)))
51 |
52 | def forward(self, input):
53 | """
54 | Returns a tuple of tensors extracted from the `input`, in the same order as they are provided by
55 | `get_provided_features_list()`.
56 | """
57 | raise NotImplementedError
58 |
--------------------------------------------------------------------------------
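A minimal sketch (not part of the repository) of the interface FeatureExtractorBase expects from subclasses: declare the available feature names, accept a subset of them, and return a tuple of tensors in that order. DummyFeatureExtractor is hypothetical and only illustrates the contract.

import torch
import torch.nn as nn
from metrics.feature_extractor_base import FeatureExtractorBase

# Illustrative sketch, not repository code.
class DummyFeatureExtractor(FeatureExtractorBase):
    def __init__(self, name, features_list):
        super().__init__(name, features_list)
        self.pool = nn.AdaptiveAvgPool2d(1)

    @staticmethod
    def get_provided_features_list():
        return ('pooled',)

    def forward(self, input):
        # input: uint8 NCHW batch; return features as a tuple, in features_list order
        feats = self.pool(input.float()).flatten(1)
        return (feats,)

fx = DummyFeatureExtractor('dummy', ['pooled'])
out = fx(torch.randint(0, 255, (2, 3, 8, 8), dtype=torch.uint8))
print(fx.convert_features_tuple_to_dict(out)['pooled'].shape)   # torch.Size([2, 3])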
/metrics/generative_model_base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import torch
4 |
5 |
6 | class GenerativeModelBase(ABC, torch.nn.Module):
7 | """
8 | Base class for generative models that can be used as inputs in :func:`calculate_metrics`.
9 | """
10 |
11 | @property
12 | @abstractmethod
13 | def z_size(self):
14 | """
15 | Size of the noise dimension of the generative model (positive integer).
16 | """
17 | pass
18 |
19 | @property
20 | @abstractmethod
21 | def z_type(self):
22 | """
23 |         Type of the noise used by the generative model (see :ref:`registry` for a list of preregistered noise
24 | types, see :func:`register_noise_source` for registering a new noise type).
25 | """
26 | pass
27 |
28 | @property
29 | @abstractmethod
30 | def num_classes(self):
31 | """
32 | Number of classes used by a conditional generative model. Must return zero for unconditional models.
33 | """
34 | pass
35 |
--------------------------------------------------------------------------------
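A sketch (not part of the repository) of a tiny unconditional generator wrapped as a GenerativeModelBase, to show which properties a subclass must implement. ToyGenerator is hypothetical; it returns uint8 RGB batches, which is what the feature extractors in this package are assumed to consume.

import torch
from metrics.generative_model_base import GenerativeModelBase

# Illustrative sketch, not repository code.
class ToyGenerator(GenerativeModelBase):
    def __init__(self, z_size=128):
        super().__init__()
        self._z_size = z_size
        self.net = torch.nn.Linear(z_size, 3 * 32 * 32)

    @property
    def z_size(self):
        return self._z_size

    @property
    def z_type(self):
        return 'normal'

    @property
    def num_classes(self):
        return 0     # unconditional

    def forward(self, z):
        # map noise to uint8 RGB images
        x = torch.sigmoid(self.net(z)).view(-1, 3, 32, 32)
        return (x * 255).to(torch.uint8)

g = ToyGenerator()
imgs = g(torch.randn(4, g.z_size))
print(imgs.shape, imgs.dtype)     # torch.Size([4, 3, 32, 32]) torch.uint8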
/metrics/helpers.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sys
3 |
4 | from metrics.defaults import DEFAULTS
5 |
6 |
7 | def vassert(truecond, message):
8 | if not truecond:
9 | raise ValueError(message)
10 |
11 |
12 | def vprint(verbose, message):
13 | if verbose:
14 | print(message, file=sys.stderr)
15 |
16 |
17 | def get_kwarg(name, kwargs):
18 | return kwargs.get(name, DEFAULTS[name])
19 |
20 |
21 | def json_decode_string(s):
22 | try:
23 | out = json.loads(s)
24 | except json.JSONDecodeError as e:
25 | print(f'Failed to decode JSON string: {s}', file=sys.stderr)
26 | raise
27 | return out
28 |
--------------------------------------------------------------------------------
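A small sketch (not part of the repository) showing how get_kwarg resolves an option: caller-supplied kwargs win, otherwise the value falls back to the DEFAULTS table in metrics/defaults.py.

from metrics.helpers import get_kwarg

# Illustrative sketch, not repository code.
kwargs = {"batch_size": 16, "fid": True}      # caller-provided overrides
print(get_kwarg("batch_size", kwargs))        # 16 (taken from kwargs)
print(get_kwarg("kid_subset_size", kwargs))   # 1000 (falls back to DEFAULTS)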
/metrics/interpolate_compat_tensorflow.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from torch.nn.modules.utils import _ntuple
6 |
7 |
8 | def interpolate_bilinear_2d_like_tensorflow1x(input, size=None, scale_factor=None, align_corners=None, method='slow'):
9 | r"""Down/up samples the input to either the given :attr:`size` or the given :attr:`scale_factor`
10 |
11 | Epsilon-exact bilinear interpolation as it is implemented in TensorFlow 1.x:
12 | https://github.com/tensorflow/tensorflow/blob/f66daa493e7383052b2b44def2933f61faf196e0/tensorflow/core/kernels/image_resizer_state.h#L41
13 | https://github.com/tensorflow/tensorflow/blob/6795a8c3a3678fb805b6a8ba806af77ddfe61628/tensorflow/core/kernels/resize_bilinear_op.cc#L85
14 | as per proposal:
15 | https://github.com/pytorch/pytorch/issues/10604#issuecomment-465783319
16 |
17 | Related materials:
18 | https://hackernoon.com/how-tensorflows-tf-image-resize-stole-60-days-of-my-life-aba5eb093f35
19 | https://jricheimer.github.io/tensorflow/2019/02/11/resize-confusion/
20 | https://machinethink.net/blog/coreml-upsampling/
21 |
22 | Currently only 2D spatial sampling is supported, i.e. expected inputs are 4-D in shape.
23 |
24 | The input dimensions are interpreted in the form:
25 | `mini-batch x channels x height x width`.
26 |
27 | Args:
28 | input (Tensor): the input tensor
29 | size (Tuple[int, int]): output spatial size.
30 | scale_factor (float or Tuple[float]): multiplier for spatial size. Has to match input size if it is a tuple.
31 | align_corners (bool, optional): Same meaning as in TensorFlow 1.x.
32 | method (str, optional):
33 | 'slow' (1e-4 L_inf error on GPU, bit-exact on CPU, with checkerboard 32x32->299x299), or
34 | 'fast' (1e-3 L_inf error on GPU and CPU, with checkerboard 32x32->299x299)
35 | """
36 | if method not in ('slow', 'fast'):
37 |         raise ValueError('method can only be one of "slow", "fast"')
38 |
39 | if input.dim() != 4:
40 | raise ValueError('input must be a 4-D tensor')
41 |
42 | if not torch.is_floating_point(input):
43 | raise ValueError('input must be of floating point dtype')
44 |
45 | if size is not None and (type(size) not in (tuple, list) or len(size) != 2):
46 | raise ValueError('size must be a list or a tuple of two elements')
47 |
48 | if align_corners is None:
49 | raise ValueError('align_corners is not specified (use this function for a complete determinism)')
50 |
51 | def _check_size_scale_factor(dim):
52 | if size is None and scale_factor is None:
53 | raise ValueError('either size or scale_factor should be defined')
54 | if size is not None and scale_factor is not None:
55 | raise ValueError('only one of size or scale_factor should be defined')
56 | if scale_factor is not None and isinstance(scale_factor, tuple) and len(scale_factor) != dim:
57 | raise ValueError('scale_factor shape must match input shape. '
58 | 'Input is {}D, scale_factor size is {}'.format(dim, len(scale_factor)))
59 |
60 | is_tracing = torch._C._get_tracing_state()
61 |
62 | def _output_size(dim):
63 | _check_size_scale_factor(dim)
64 | if size is not None:
65 | if is_tracing:
66 | return [torch.tensor(i) for i in size]
67 | else:
68 | return size
69 | scale_factors = _ntuple(dim)(scale_factor)
70 | # math.floor might return float in py2.7
71 |
72 | # make scale_factor a tensor in tracing so constant doesn't get baked in
73 | if is_tracing:
74 | return [
75 | (torch.floor((input.size(i + 2).float() * torch.tensor(scale_factors[i], dtype=torch.float32)).float()))
76 | for i in range(dim)
77 | ]
78 | else:
79 | return [int(math.floor(float(input.size(i + 2)) * scale_factors[i])) for i in range(dim)]
80 |
81 | def tf_calculate_resize_scale(in_size, out_size):
82 | if align_corners:
83 | if is_tracing:
84 | return (in_size - 1) / (out_size.float() - 1).clamp(min=1)
85 | else:
86 | return (in_size - 1) / max(1, out_size - 1)
87 | else:
88 | if is_tracing:
89 | return in_size / out_size.float()
90 | else:
91 | return in_size / out_size
92 |
93 | out_size = _output_size(2)
94 | scale_x = tf_calculate_resize_scale(input.shape[3], out_size[1])
95 | scale_y = tf_calculate_resize_scale(input.shape[2], out_size[0])
96 |
97 | def resample_using_grid_sample():
98 | grid_x = torch.arange(0, out_size[1], 1, dtype=input.dtype, device=input.device)
99 | grid_x = grid_x * (2 * scale_x / (input.shape[3] - 1)) - 1
100 |
101 | grid_y = torch.arange(0, out_size[0], 1, dtype=input.dtype, device=input.device)
102 | grid_y = grid_y * (2 * scale_y / (input.shape[2] - 1)) - 1
103 |
104 | grid_x = grid_x.view(1, out_size[1]).repeat(out_size[0], 1)
105 | grid_y = grid_y.view(out_size[0], 1).repeat(1, out_size[1])
106 |
107 | grid_xy = torch.cat((grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)), dim=2).unsqueeze(0)
108 | grid_xy = grid_xy.repeat(input.shape[0], 1, 1, 1)
109 |
110 | out = F.grid_sample(input, grid_xy, mode='bilinear', padding_mode='border', align_corners=True)
111 | return out
112 |
113 | def resample_manually():
114 | grid_x = torch.arange(0, out_size[1], 1, dtype=input.dtype, device=input.device)
115 | grid_x = grid_x * torch.tensor(scale_x, dtype=torch.float32)
116 | grid_x_lo = grid_x.long()
117 | grid_x_hi = (grid_x_lo + 1).clamp_max(input.shape[3] - 1)
118 | grid_dx = grid_x - grid_x_lo.float()
119 |
120 | grid_y = torch.arange(0, out_size[0], 1, dtype=input.dtype, device=input.device)
121 | grid_y = grid_y * torch.tensor(scale_y, dtype=torch.float32)
122 | grid_y_lo = grid_y.long()
123 | grid_y_hi = (grid_y_lo + 1).clamp_max(input.shape[2] - 1)
124 | grid_dy = grid_y - grid_y_lo.float()
125 |
126 | # could be improved with index_select
127 | in_00 = input[:, :, grid_y_lo, :][:, :, :, grid_x_lo]
128 | in_01 = input[:, :, grid_y_lo, :][:, :, :, grid_x_hi]
129 | in_10 = input[:, :, grid_y_hi, :][:, :, :, grid_x_lo]
130 | in_11 = input[:, :, grid_y_hi, :][:, :, :, grid_x_hi]
131 |
132 | in_0 = in_00 + (in_01 - in_00) * grid_dx.view(1, 1, 1, out_size[1])
133 | in_1 = in_10 + (in_11 - in_10) * grid_dx.view(1, 1, 1, out_size[1])
134 | out = in_0 + (in_1 - in_0) * grid_dy.view(1, 1, out_size[0], 1)
135 |
136 | return out
137 |
138 | if method == 'slow':
139 | out = resample_manually()
140 | else:
141 | out = resample_using_grid_sample()
142 |
143 | return out
144 |
--------------------------------------------------------------------------------
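A usage sketch (not part of the repository): resizing a batch to the 299x299 InceptionV3 input size with the TF1-compatible bilinear routine above; note that align_corners must be given explicitly.

import torch
from metrics.interpolate_compat_tensorflow import interpolate_bilinear_2d_like_tensorflow1x

# Illustrative sketch, not repository code.
x = torch.rand(2, 3, 64, 64)
y = interpolate_bilinear_2d_like_tensorflow1x(x, size=(299, 299), align_corners=False, method='slow')
print(y.shape)   # torch.Size([2, 3, 299, 299])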
/metrics/metric_fid.py:
--------------------------------------------------------------------------------
1 | # Functions fid_features_to_statistics and fid_statistics_to_metric are adapted from
2 | # https://github.com/bioinf-jku/TTUR/blob/master/fid.py commit id d4baae8
3 | # Distributed under Apache License 2.0: https://github.com/bioinf-jku/TTUR/blob/master/LICENSE
4 |
5 | import numpy as np
6 | import scipy.linalg
7 | import torch
8 |
9 | from metrics.helpers import get_kwarg, vprint
10 | from metrics.utils import get_cacheable_input_name, cache_lookup_one_recompute_on_miss, \
11 | extract_featuresdict_from_input_id_cached, create_feature_extractor
12 |
13 | KEY_METRIC_FID = 'frechet_inception_distance'
14 |
15 |
16 | def fid_features_to_statistics(features):
17 | assert torch.is_tensor(features) and features.dim() == 2
18 | features = features.numpy()
19 | mu = np.mean(features, axis=0)
20 | sigma = np.cov(features, rowvar=False)
21 | return {
22 | 'mu': mu,
23 | 'sigma': sigma,
24 | }
25 |
26 |
27 | def fid_statistics_to_metric(stat_1, stat_2, verbose):
28 | eps = 1e-6
29 |
30 | mu1, sigma1 = stat_1['mu'], stat_1['sigma']
31 | mu2, sigma2 = stat_2['mu'], stat_2['sigma']
32 | assert mu1.shape == mu2.shape and mu1.dtype == mu2.dtype
33 | assert sigma1.shape == sigma2.shape and sigma1.dtype == sigma2.dtype
34 |
35 | mu1 = np.atleast_1d(mu1)
36 | mu2 = np.atleast_1d(mu2)
37 |
38 | sigma1 = np.atleast_2d(sigma1)
39 | sigma2 = np.atleast_2d(sigma2)
40 |
41 | assert mu1.shape == mu2.shape, 'Training and test mean vectors have different lengths'
42 | assert sigma1.shape == sigma2.shape, 'Training and test covariances have different dimensions'
43 |
44 | diff = mu1 - mu2
45 |
46 | # Product might be almost singular
47 | covmean, _ = scipy.linalg.sqrtm(sigma1.dot(sigma2), disp=False)
48 | if not np.isfinite(covmean).all():
49 | vprint(verbose,
50 | f'WARNING: fid calculation produces singular product; '
51 | f'adding {eps} to diagonal of cov estimates'
52 | )
53 | offset = np.eye(sigma1.shape[0]) * eps
54 | covmean = scipy.linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset), disp=verbose)
55 |
56 | # Numerical error might give slight imaginary component
57 | if np.iscomplexobj(covmean):
58 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
59 | m = np.max(np.abs(covmean.imag))
60 | assert False, 'Imaginary component {}'.format(m)
61 | covmean = covmean.real
62 |
63 | tr_covmean = np.trace(covmean)
64 |
65 | out = {
66 | KEY_METRIC_FID: float(diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean)
67 | }
68 |
69 | vprint(verbose, f'Frechet Inception Distance: {out[KEY_METRIC_FID]}')
70 |
71 | return out
72 |
73 |
74 | def fid_featuresdict_to_statistics(featuresdict, feat_layer_name):
75 | features = featuresdict[feat_layer_name]
76 | statistics = fid_features_to_statistics(features)
77 | return statistics
78 |
79 |
80 | def fid_featuresdict_to_statistics_cached(
81 | featuresdict, cacheable_input_name, feat_extractor, feat_layer_name, **kwargs
82 | ):
83 |
84 | def fn_recompute():
85 | return fid_featuresdict_to_statistics(featuresdict, feat_layer_name)
86 |
87 | if cacheable_input_name is not None:
88 | feat_extractor_name = feat_extractor.get_name()
89 | cached_name = f'{cacheable_input_name}-{feat_extractor_name}-stat-fid-{feat_layer_name}'
90 | stat = cache_lookup_one_recompute_on_miss(cached_name, fn_recompute, **kwargs)
91 | else:
92 | stat = fn_recompute()
93 | return stat
94 |
95 |
96 | def fid_input_id_to_statistics(input_id, feat_extractor, feat_layer_name, **kwargs):
97 | featuresdict = extract_featuresdict_from_input_id_cached(input_id, feat_extractor, **kwargs)
98 | return fid_featuresdict_to_statistics(featuresdict, feat_layer_name)
99 |
100 |
101 | def fid_input_id_to_statistics_cached(input_id, feat_extractor, feat_layer_name, **kwargs):
102 |
103 | def fn_recompute():
104 | return fid_input_id_to_statistics(input_id, feat_extractor, feat_layer_name, **kwargs)
105 |
106 | cacheable_input_name = get_cacheable_input_name(input_id, **kwargs)
107 |
108 | if cacheable_input_name is not None:
109 | feat_extractor_name = feat_extractor.get_name()
110 | cached_name = f'{cacheable_input_name}-{feat_extractor_name}-stat-fid-{feat_layer_name}'
111 | stat = cache_lookup_one_recompute_on_miss(cached_name, fn_recompute, **kwargs)
112 | else:
113 | stat = fn_recompute()
114 | return stat
115 |
116 |
117 | def fid_inputs_to_metric(feat_extractor, **kwargs):
118 | feat_layer_name = get_kwarg('feature_layer_fid', kwargs)
119 | verbose = get_kwarg('verbose', kwargs)
120 |
121 | vprint(verbose, f'Extracting statistics from input 1')
122 | stats_1 = fid_input_id_to_statistics_cached(1, feat_extractor, feat_layer_name, **kwargs)
123 |
124 | vprint(verbose, f'Extracting statistics from input 2')
125 | stats_2 = fid_input_id_to_statistics_cached(2, feat_extractor, feat_layer_name, **kwargs)
126 |
127 | metric = fid_statistics_to_metric(stats_1, stats_2, get_kwarg('verbose', kwargs))
128 | return metric
129 |
130 |
131 | def calculate_fid(**kwargs):
132 | feature_extractor = get_kwarg('feature_extractor', kwargs)
133 | feat_layer_name = get_kwarg('feature_layer_fid', kwargs)
134 | feat_extractor = create_feature_extractor(feature_extractor, [feat_layer_name], **kwargs)
135 | metric = fid_inputs_to_metric(feat_extractor, **kwargs)
136 | return metric
137 |
--------------------------------------------------------------------------------
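A sketch (not part of the repository) of the FID plumbing above, applied to random feature matrices so it runs standalone; real use feeds 2048-dimensional InceptionV3 features, and the value computed here is meaningless.

import torch
from metrics.metric_fid import fid_features_to_statistics, fid_statistics_to_metric, KEY_METRIC_FID

# Illustrative sketch, not repository code: random features stand in for Inception features.
feats_real = torch.randn(256, 64)
feats_fake = torch.randn(256, 64)
stat_real = fid_features_to_statistics(feats_real)
stat_fake = fid_features_to_statistics(feats_fake)
out = fid_statistics_to_metric(stat_real, stat_fake, verbose=True)
print(out[KEY_METRIC_FID])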
/metrics/metric_isc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from metrics.helpers import get_kwarg, vprint
5 | from metrics.utils import extract_featuresdict_from_input_id_cached, create_feature_extractor
6 |
7 | KEY_METRIC_ISC_MEAN = 'inception_score_mean'
8 | KEY_METRIC_ISC_STD = 'inception_score_std'
9 |
10 |
11 | def isc_features_to_metric(feature, splits=10, shuffle=True, rng_seed=2020):
12 | assert torch.is_tensor(feature) and feature.dim() == 2
13 | N, C = feature.shape
14 | if shuffle:
15 | rng = np.random.RandomState(rng_seed)
16 | feature = feature[rng.permutation(N), :]
17 | feature = feature.double()
18 |
19 | p = feature.softmax(dim=1)
20 | log_p = feature.log_softmax(dim=1)
21 |
22 | scores = []
23 | for i in range(splits):
24 | p_chunk = p[(i * N // splits): ((i + 1) * N // splits), :]
25 | log_p_chunk = log_p[(i * N // splits): ((i + 1) * N // splits), :]
26 | q_chunk = p_chunk.mean(dim=0, keepdim=True)
27 | kl = p_chunk * (log_p_chunk - q_chunk.log())
28 | kl = kl.sum(dim=1).mean().exp().item()
29 | scores.append(kl)
30 |
31 | return {
32 | KEY_METRIC_ISC_MEAN: float(np.mean(scores)),
33 | KEY_METRIC_ISC_STD: float(np.std(scores)),
34 | }
35 |
36 |
37 | def isc_featuresdict_to_metric(featuresdict, feat_layer_name, **kwargs):
38 | features = featuresdict[feat_layer_name]
39 |
40 | out = isc_features_to_metric(
41 | features,
42 | get_kwarg('isc_splits', kwargs),
43 | get_kwarg('samples_shuffle', kwargs),
44 | get_kwarg('rng_seed', kwargs),
45 | )
46 |
47 | vprint(get_kwarg('verbose', kwargs), f'Inception Score: {out[KEY_METRIC_ISC_MEAN]} ± {out[KEY_METRIC_ISC_STD]}')
48 |
49 | return out
50 |
51 |
52 | def isc_input_id_to_metric(input_id, feat_extractor, feat_layer_name, **kwargs):
53 | featuresdict = extract_featuresdict_from_input_id_cached(input_id, feat_extractor, **kwargs)
54 | return isc_featuresdict_to_metric(featuresdict, feat_layer_name, **kwargs)
55 |
56 |
57 | def calculate_isc(input_id, **kwargs):
58 | feature_extractor = get_kwarg('feature_extractor', kwargs)
59 | feat_layer_name = get_kwarg('feature_layer_isc', kwargs)
60 | feat_extractor = create_feature_extractor(feature_extractor, [feat_layer_name], **kwargs)
61 | metric = isc_input_id_to_metric(input_id, feat_extractor, feat_layer_name, **kwargs)
62 | return metric
63 |
--------------------------------------------------------------------------------
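A sketch (not part of the repository) of isc_features_to_metric on random class logits; the resulting score is meaningless and only demonstrates the call.

import torch
from metrics.metric_isc import isc_features_to_metric

# Illustrative sketch, not repository code: random logits stand in for Inception logits.
logits = torch.randn(1000, 10)
print(isc_features_to_metric(logits, splits=10))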
/metrics/metric_kid.py:
--------------------------------------------------------------------------------
1 | # Functions mmd2 and polynomial_kernel are adapted from
2 | # https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py
3 | # Distributed under BSD 3-Clause: https://github.com/mbinkowski/MMD-GAN/blob/master/LICENSE
4 |
5 | import numpy as np
6 | import torch
7 | from tqdm import tqdm
8 |
9 | from metrics.helpers import get_kwarg, vassert, vprint
10 | from metrics.utils import create_feature_extractor, extract_featuresdict_from_input_id_cached
11 |
12 | KEY_METRIC_KID_MEAN = 'kernel_inception_distance_mean'
13 | KEY_METRIC_KID_STD = 'kernel_inception_distance_std'
14 |
15 |
16 | def mmd2(K_XX, K_XY, K_YY, unit_diagonal=False, mmd_est='unbiased'):
17 | vassert(mmd_est in ('biased', 'unbiased', 'u-statistic'), 'Invalid value of mmd_est')
18 |
19 | m = K_XX.shape[0]
20 | assert K_XX.shape == (m, m)
21 | assert K_XY.shape == (m, m)
22 | assert K_YY.shape == (m, m)
23 |
24 | # Get the various sums of kernels that we'll use
25 | # Kts drop the diagonal, but we don't need to compute them explicitly
26 | if unit_diagonal:
27 | diag_X = diag_Y = 1
28 | sum_diag_X = sum_diag_Y = m
29 | else:
30 | diag_X = np.diagonal(K_XX)
31 | diag_Y = np.diagonal(K_YY)
32 |
33 | sum_diag_X = diag_X.sum()
34 | sum_diag_Y = diag_Y.sum()
35 |
36 | Kt_XX_sums = K_XX.sum(axis=1) - diag_X
37 | Kt_YY_sums = K_YY.sum(axis=1) - diag_Y
38 | K_XY_sums_0 = K_XY.sum(axis=0)
39 |
40 | Kt_XX_sum = Kt_XX_sums.sum()
41 | Kt_YY_sum = Kt_YY_sums.sum()
42 | K_XY_sum = K_XY_sums_0.sum()
43 |
44 | if mmd_est == 'biased':
45 | mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * m)
46 | + (Kt_YY_sum + sum_diag_Y) / (m * m)
47 | - 2 * K_XY_sum / (m * m))
48 | else:
49 | mmd2 = (Kt_XX_sum + Kt_YY_sum) / (m * (m-1))
50 | if mmd_est == 'unbiased':
51 | mmd2 -= 2 * K_XY_sum / (m * m)
52 | else:
53 | mmd2 -= 2 * (K_XY_sum - np.trace(K_XY)) / (m * (m-1))
54 |
55 | return mmd2
56 |
57 |
58 | def polynomial_kernel(X, Y, degree=3, gamma=None, coef0=1):
59 | if gamma is None:
60 | gamma = 1.0 / X.shape[1]
61 | K = (np.matmul(X, Y.T) * gamma + coef0) ** degree
62 | return K
63 |
64 |
65 | def polynomial_mmd(features_1, features_2, degree, gamma, coef0):
66 | k_11 = polynomial_kernel(features_1, features_1, degree=degree, gamma=gamma, coef0=coef0)
67 | k_22 = polynomial_kernel(features_2, features_2, degree=degree, gamma=gamma, coef0=coef0)
68 | k_12 = polynomial_kernel(features_1, features_2, degree=degree, gamma=gamma, coef0=coef0)
69 | return mmd2(k_11, k_12, k_22)
70 |
71 |
72 | def kid_features_to_metric(features_1, features_2, **kwargs):
73 | assert torch.is_tensor(features_1) and features_1.dim() == 2
74 | assert torch.is_tensor(features_2) and features_2.dim() == 2
75 | assert features_1.shape[1] == features_2.shape[1]
76 |
77 | kid_subsets = get_kwarg('kid_subsets', kwargs)
78 | kid_subset_size = get_kwarg('kid_subset_size', kwargs)
79 | verbose = get_kwarg('verbose', kwargs)
80 |
81 | n_samples_1, n_samples_2 = len(features_1), len(features_2)
82 | vassert(
83 | n_samples_1 >= kid_subset_size and n_samples_2 >= kid_subset_size,
84 |         f'KID subset size {kid_subset_size} cannot be larger than the number of samples (input_1: {n_samples_1}, '
85 | f'input_2: {n_samples_2}). Consider using "kid_subset_size" kwarg or "--kid-subset-size" command line key to '
86 | f'proceed.'
87 | )
88 |
89 | features_1 = features_1.cpu().numpy()
90 | features_2 = features_2.cpu().numpy()
91 |
92 | mmds = np.zeros(kid_subsets)
93 | rng = np.random.RandomState(get_kwarg('rng_seed', kwargs))
94 |
95 | for i in tqdm(
96 | range(kid_subsets), disable=not verbose, leave=False, unit='subsets',
97 | desc='Kernel Inception Distance'
98 | ):
99 | f1 = features_1[rng.choice(n_samples_1, kid_subset_size, replace=False)]
100 | f2 = features_2[rng.choice(n_samples_2, kid_subset_size, replace=False)]
101 | o = polynomial_mmd(
102 | f1,
103 | f2,
104 | get_kwarg('kid_degree', kwargs),
105 | get_kwarg('kid_gamma', kwargs),
106 | get_kwarg('kid_coef0', kwargs),
107 | )
108 | mmds[i] = o
109 |
110 | out = {
111 | KEY_METRIC_KID_MEAN: float(np.mean(mmds)),
112 | KEY_METRIC_KID_STD: float(np.std(mmds)),
113 | }
114 |
115 | vprint(verbose, f'Kernel Inception Distance: {out[KEY_METRIC_KID_MEAN]} ± {out[KEY_METRIC_KID_STD]}')
116 |
117 | return out
118 |
119 |
120 | def kid_featuresdict_to_metric(featuresdict_1, featuresdict_2, feat_layer_name, **kwargs):
121 | features_1 = featuresdict_1[feat_layer_name]
122 | features_2 = featuresdict_2[feat_layer_name]
123 | metric = kid_features_to_metric(features_1, features_2, **kwargs)
124 | return metric
125 |
126 |
127 | def calculate_kid(**kwargs):
128 | feature_extractor = get_kwarg('feature_extractor', kwargs)
129 | feat_layer_name = get_kwarg('feature_layer_kid', kwargs)
130 | feat_extractor = create_feature_extractor(feature_extractor, [feat_layer_name], **kwargs)
131 | featuresdict_1 = extract_featuresdict_from_input_id_cached(1, feat_extractor, **kwargs)
132 | featuresdict_2 = extract_featuresdict_from_input_id_cached(2, feat_extractor, **kwargs)
133 | metric = kid_featuresdict_to_metric(featuresdict_1, featuresdict_2, feat_layer_name, **kwargs)
134 | return metric
135 |
--------------------------------------------------------------------------------
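A sketch (not part of the repository) of kid_features_to_metric on random feature matrices; the subset size is lowered from the default 1000 so the toy sample count passes the size check.

import torch
from metrics.metric_kid import kid_features_to_metric

# Illustrative sketch, not repository code: random features stand in for Inception features.
feats_1 = torch.randn(300, 64)
feats_2 = torch.randn(300, 64)
print(kid_features_to_metric(feats_1, feats_2, kid_subsets=10, kid_subset_size=100, verbose=False))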
/metrics/metric_ppl.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from tqdm import tqdm
4 |
5 | from metrics.generative_model_base import GenerativeModelBase
6 | from metrics.helpers import get_kwarg, vassert, vprint
7 | from metrics.utils import sample_random, batch_interp, create_sample_similarity, \
8 | prepare_input_descriptor_from_input_id, prepare_input_from_descriptor
9 |
10 | KEY_METRIC_PPL_RAW = 'perceptual_path_length_raw'
11 | KEY_METRIC_PPL_MEAN = 'perceptual_path_length_mean'
12 | KEY_METRIC_PPL_STD = 'perceptual_path_length_std'
13 |
14 |
15 | def calculate_ppl(input_id, **kwargs):
16 | """
17 | Inspired by https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py
18 | """
19 | batch_size = get_kwarg('batch_size', kwargs)
20 | is_cuda = get_kwarg('cuda', kwargs)
21 | verbose = get_kwarg('verbose', kwargs)
22 | epsilon = get_kwarg('ppl_epsilon', kwargs)
23 | interp = get_kwarg('ppl_z_interp_mode', kwargs)
24 | reduction = get_kwarg('ppl_reduction', kwargs)
25 | similarity_name = get_kwarg('ppl_sample_similarity', kwargs)
26 | sample_similarity_resize = get_kwarg('ppl_sample_similarity_resize', kwargs)
27 | sample_similarity_dtype = get_kwarg('ppl_sample_similarity_dtype', kwargs)
28 | discard_percentile_lower = get_kwarg('ppl_discard_percentile_lower', kwargs)
29 | discard_percentile_higher = get_kwarg('ppl_discard_percentile_higher', kwargs)
30 |
31 | input_desc = prepare_input_descriptor_from_input_id(input_id, **kwargs)
32 | model = prepare_input_from_descriptor(input_desc, **kwargs)
33 | vassert(
34 | isinstance(model, GenerativeModelBase),
35 | 'Input needs to be an instance of GenerativeModelBase, which can be either passed programmatically by wrapping '
36 | 'a model with GenerativeModelModuleWrapper, or via command line by specifying a path to a ONNX or PTH (JIT) '
37 | 'model and a set of input1_model_* arguments'
38 | )
39 |
40 | if is_cuda:
41 | model.cuda()
42 |
43 | input_model_num_samples = input_desc['input_model_num_samples']
44 | input_model_num_classes = model.num_classes
45 | input_model_z_size = model.z_size
46 | input_model_z_type = model.z_type
47 |
48 | vassert(input_model_num_classes >= 0, 'Model can be unconditional (0 classes) or conditional (positive)')
49 | vassert(type(input_model_z_size) is int and input_model_z_size > 0,
50 | 'Dimensionality of generator noise not specified ("input1_model_z_size" argument)')
51 | vassert(type(epsilon) is float and epsilon > 0, 'Epsilon must be a small positive floating point number')
52 | vassert(type(input_model_num_samples) is int and input_model_num_samples > 0, 'Number of samples must be positive')
53 | vassert(reduction in ('none', 'mean'), 'Reduction must be one of [none, mean]')
54 | vassert(discard_percentile_lower is None or 0 < discard_percentile_lower < 100, 'Invalid percentile')
55 | vassert(discard_percentile_higher is None or 0 < discard_percentile_higher < 100, 'Invalid percentile')
56 | if discard_percentile_lower is not None and discard_percentile_higher is not None:
57 | vassert(0 < discard_percentile_lower < discard_percentile_higher < 100, 'Invalid percentiles')
58 |
59 | sample_similarity = create_sample_similarity(
60 | similarity_name,
61 | sample_similarity_resize=sample_similarity_resize,
62 | sample_similarity_dtype=sample_similarity_dtype,
63 | **kwargs
64 | )
65 |
66 | is_cond = input_desc['input_model_num_classes'] > 0
67 |
68 | rng = np.random.RandomState(get_kwarg('rng_seed', kwargs))
69 |
70 | lat_e0 = sample_random(rng, (input_model_num_samples, input_model_z_size), input_model_z_type)
71 | lat_e1 = sample_random(rng, (input_model_num_samples, input_model_z_size), input_model_z_type)
72 | lat_e1 = batch_interp(lat_e0, lat_e1, epsilon, interp)
73 |
74 | labels = None
75 | if is_cond:
76 | labels = torch.from_numpy(rng.randint(0, input_model_num_classes, (input_model_num_samples,)))
77 |
78 | distances = []
79 |
80 | with tqdm(disable=not verbose, leave=False, unit='samples', total=input_model_num_samples,
81 | desc='Perceptual Path Length') as t, torch.no_grad():
82 | for begin_id in range(0, input_model_num_samples, batch_size):
83 | end_id = min(begin_id + batch_size, input_model_num_samples)
84 | batch_sz = end_id - begin_id
85 |
86 | batch_lat_e0 = lat_e0[begin_id:end_id]
87 | batch_lat_e1 = lat_e1[begin_id:end_id]
88 | if is_cond:
89 | batch_labels = labels[begin_id:end_id]
90 |
91 | if is_cuda:
92 | batch_lat_e0 = batch_lat_e0.cuda(non_blocking=True)
93 | batch_lat_e1 = batch_lat_e1.cuda(non_blocking=True)
94 | if is_cond:
95 | batch_labels = batch_labels.cuda(non_blocking=True)
96 |
97 | if is_cond:
98 | rgb_e01 = model.forward(
99 | torch.cat((batch_lat_e0, batch_lat_e1), dim=0),
100 | torch.cat((batch_labels, batch_labels), dim=0)
101 | )
102 | else:
103 | rgb_e01 = model.forward(
104 | torch.cat((batch_lat_e0, batch_lat_e1), dim=0)
105 | )
106 | rgb_e0, rgb_e1 = rgb_e01.chunk(2)
107 |
108 | sim = sample_similarity(rgb_e0, rgb_e1)
109 | dist_lat_e01 = sim / (epsilon ** 2)
110 | distances.append(dist_lat_e01.cpu().numpy())
111 |
112 | t.update(batch_sz)
113 |
114 | distances = np.concatenate(distances, axis=0)
115 |
116 | cond, lo, hi = None, None, None
117 | if discard_percentile_lower is not None:
118 | lo = np.percentile(distances, discard_percentile_lower, interpolation='lower')
119 | cond = lo <= distances
120 | if discard_percentile_higher is not None:
121 | hi = np.percentile(distances, discard_percentile_higher, interpolation='higher')
122 | cond = np.logical_and(cond, distances <= hi)
123 | if cond is not None:
124 | distances = np.extract(cond, distances)
125 |
126 | out = {
127 | KEY_METRIC_PPL_MEAN: float(np.mean(distances)),
128 | KEY_METRIC_PPL_STD: float(np.std(distances))
129 | }
130 | if reduction == 'none':
131 | out[KEY_METRIC_PPL_RAW] = distances
132 |
133 | vprint(verbose, f'Perceptual Path Length: {out[KEY_METRIC_PPL_MEAN]} ± {out[KEY_METRIC_PPL_STD]}')
134 |
135 | return out
136 |
--------------------------------------------------------------------------------
/metrics/noise.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def batch_normalize_last_dim(v, eps=1e-7):
5 | return v / (v ** 2).sum(dim=-1, keepdim=True).sqrt().clamp_min(eps)
6 |
7 |
8 | def random_normal(rng, shape):
9 | return torch.from_numpy(rng.randn(*shape)).float()
10 |
11 |
12 | def random_unit(rng, shape):
13 | return batch_normalize_last_dim(torch.from_numpy(rng.rand(*shape)).float())
14 |
15 |
16 | def random_uniform_0_1(rng, shape):
17 | return torch.from_numpy(rng.rand(*shape)).float()
18 |
19 |
20 | def batch_lerp(a, b, t):
21 | return a + (b - a) * t
22 |
23 |
24 | def batch_slerp_any(a, b, t, eps=1e-7):
25 | assert torch.is_tensor(a) and torch.is_tensor(b) and a.dim() >= 2 and a.shape == b.shape
26 | ndims, N = a.dim() - 1, a.shape[-1]
27 | a_1 = batch_normalize_last_dim(a, eps)
28 | b_1 = batch_normalize_last_dim(b, eps)
29 | d = (a_1 * b_1).sum(dim=-1, keepdim=True)
30 | mask_zero = (a_1.norm(dim=-1, keepdim=True) < eps) | (b_1.norm(dim=-1, keepdim=True) < eps)
31 | mask_collinear = (d > 1 - eps) | (d < -1 + eps)
32 | mask_lerp = (mask_zero | mask_collinear).repeat([1 for _ in range(ndims)] + [N])
33 | omega = d.acos()
34 | denom = omega.sin().clamp_min(eps)
35 | coef_a = ((1 - t) * omega).sin() / denom
36 | coef_b = (t * omega).sin() / denom
37 | out = coef_a * a + coef_b * b
38 | out[mask_lerp] = batch_lerp(a, b, t)[mask_lerp]
39 | return out
40 |
41 |
42 | def batch_slerp_unit(a, b, t, eps=1e-7):
43 | out = batch_slerp_any(a, b, t, eps)
44 | out = batch_normalize_last_dim(out, eps)
45 | return out
46 |
--------------------------------------------------------------------------------
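A sketch (not part of the repository) of the latent interpolation helpers above, comparing linear and spherical interpolation at the midpoint of a few random latent pairs. For high-dimensional Gaussian latents the two behave quite differently.

import numpy as np
import torch
from metrics.noise import random_normal, batch_lerp, batch_slerp_any

# Illustrative sketch, not repository code.
rng = np.random.RandomState(0)
a = random_normal(rng, (4, 128))
b = random_normal(rng, (4, 128))
print(batch_lerp(a, b, 0.5).norm(dim=-1))       # linear midpoint shrinks the norm
print(batch_slerp_any(a, b, 0.5).norm(dim=-1))  # spherical midpoint roughly preserves it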
/metrics/sample_similarity_base.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from metrics.helpers import vassert
4 |
5 |
6 | class SampleSimilarityBase(nn.Module):
7 | def __init__(self, name):
8 | """
9 | Base class for samples similarity measures that can be used in :func:`calculate_metrics`.
10 |
11 | Args:
12 |
13 | name (str): Unique name of the subclassed sample similarity measure, must be the same as used in
14 | :func:`register_sample_similarity`.
15 | """
16 | super(SampleSimilarityBase, self).__init__()
17 | vassert(type(name) is str, 'Sample similarity name must be a string')
18 | self.name = name
19 |
20 | def get_name(self):
21 | return self.name
22 |
23 | def forward(self, *args):
24 | """
25 | Returns the value of sample similarity between the inputs.
26 | """
27 | raise NotImplementedError
28 |
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
1 | #torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
2 | accelerate==0.16.0
3 | einops
4 | ema-pytorch
5 | pytorch-lightning==1.9.3
6 | scikit-learn
7 | scipy
8 | thop
9 | timm==0.6.12
10 | tensorboard
11 | fvcore
12 | albumentations
13 | omegaconf
14 | numpy==1.23.5
--------------------------------------------------------------------------------
/taming/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/taming/__init__.py
--------------------------------------------------------------------------------
/taming/data/ade20k.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import cv2
4 | import albumentations
5 | from PIL import Image
6 | from torch.utils.data import Dataset
7 |
8 | from taming.data.sflckr import SegmentationBase # for examples included in repo
9 |
10 |
11 | class Examples(SegmentationBase):
12 | def __init__(self, size=256, random_crop=False, interpolation="bicubic"):
13 | super().__init__(data_csv="data/ade20k_examples.txt",
14 | data_root="data/ade20k_images",
15 | segmentation_root="data/ade20k_segmentations",
16 | size=size, random_crop=random_crop,
17 | interpolation=interpolation,
18 | n_labels=151, shift_segmentation=False)
19 |
20 |
21 | # With semantic map and scene label
22 | class ADE20kBase(Dataset):
23 | def __init__(self, config=None, size=None, random_crop=False, interpolation="bicubic", crop_size=None):
24 | self.split = self.get_split()
25 | self.n_labels = 151 # unknown + 150
26 | self.data_csv = {"train": "data/ade20k_train.txt",
27 | "validation": "data/ade20k_test.txt"}[self.split]
28 | self.data_root = "data/ade20k_root"
29 | with open(os.path.join(self.data_root, "sceneCategories.txt"), "r") as f:
30 | self.scene_categories = f.read().splitlines()
31 | self.scene_categories = dict(line.split() for line in self.scene_categories)
32 | with open(self.data_csv, "r") as f:
33 | self.image_paths = f.read().splitlines()
34 | self._length = len(self.image_paths)
35 | self.labels = {
36 | "relative_file_path_": [l for l in self.image_paths],
37 | "file_path_": [os.path.join(self.data_root, "images", l)
38 | for l in self.image_paths],
39 | "relative_segmentation_path_": [l.replace(".jpg", ".png")
40 | for l in self.image_paths],
41 | "segmentation_path_": [os.path.join(self.data_root, "annotations",
42 | l.replace(".jpg", ".png"))
43 | for l in self.image_paths],
44 | "scene_category": [self.scene_categories[l.split("/")[1].replace(".jpg", "")]
45 | for l in self.image_paths],
46 | }
47 |
48 | size = None if size is not None and size<=0 else size
49 | self.size = size
50 | if crop_size is None:
51 | self.crop_size = size if size is not None else None
52 | else:
53 | self.crop_size = crop_size
54 | if self.size is not None:
55 | self.interpolation = interpolation
56 | self.interpolation = {
57 | "nearest": cv2.INTER_NEAREST,
58 | "bilinear": cv2.INTER_LINEAR,
59 | "bicubic": cv2.INTER_CUBIC,
60 | "area": cv2.INTER_AREA,
61 | "lanczos": cv2.INTER_LANCZOS4}[self.interpolation]
62 | self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
63 | interpolation=self.interpolation)
64 | self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
65 | interpolation=cv2.INTER_NEAREST)
66 |
67 | if crop_size is not None:
68 | self.center_crop = not random_crop
69 | if self.center_crop:
70 | self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
71 | else:
72 | self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
73 | self.preprocessor = self.cropper
74 |
75 | def __len__(self):
76 | return self._length
77 |
78 | def __getitem__(self, i):
79 | example = dict((k, self.labels[k][i]) for k in self.labels)
80 | image = Image.open(example["file_path_"])
81 | if not image.mode == "RGB":
82 | image = image.convert("RGB")
83 | image = np.array(image).astype(np.uint8)
84 | if self.size is not None:
85 | image = self.image_rescaler(image=image)["image"]
86 | segmentation = Image.open(example["segmentation_path_"])
87 | segmentation = np.array(segmentation).astype(np.uint8)
88 | if self.size is not None:
89 | segmentation = self.segmentation_rescaler(image=segmentation)["image"]
90 | if self.size is not None:
91 | processed = self.preprocessor(image=image, mask=segmentation)
92 | else:
93 | processed = {"image": image, "mask": segmentation}
94 | example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32)
95 | segmentation = processed["mask"]
96 | onehot = np.eye(self.n_labels)[segmentation]
97 | example["segmentation"] = onehot
98 | return example
99 |
100 |
101 | class ADE20kTrain(ADE20kBase):
102 | # default to random_crop=True
103 | def __init__(self, config=None, size=None, random_crop=True, interpolation="bicubic", crop_size=None):
104 | super().__init__(config=config, size=size, random_crop=random_crop,
105 | interpolation=interpolation, crop_size=crop_size)
106 |
107 | def get_split(self):
108 | return "train"
109 |
110 |
111 | class ADE20kValidation(ADE20kBase):
112 | def get_split(self):
113 | return "validation"
114 |
115 |
116 | if __name__ == "__main__":
117 | dset = ADE20kValidation()
118 | ex = dset[0]
119 | for k in ["image", "scene_category", "segmentation"]:
120 | print(type(ex[k]))
121 | try:
122 | print(ex[k].shape)
123 |         except AttributeError:  # scene_category is a plain string and has no .shape
124 | print(ex[k])
125 |
--------------------------------------------------------------------------------
/taming/data/annotated_objects_coco.py:
--------------------------------------------------------------------------------
1 | import json
2 | from itertools import chain
3 | from pathlib import Path
4 | from typing import Iterable, Dict, List, Callable, Any
5 | from collections import defaultdict
6 |
7 | from tqdm import tqdm
8 |
9 | from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
10 | from taming.data.helper_types import Annotation, ImageDescription, Category
11 |
12 | COCO_PATH_STRUCTURE = {
13 | 'train': {
14 | 'top_level': '',
15 | 'instances_annotations': 'annotations/instances_train2017.json',
16 | 'stuff_annotations': 'annotations/stuff_train2017.json',
17 | 'files': 'train2017'
18 | },
19 | 'validation': {
20 | 'top_level': '',
21 | 'instances_annotations': 'annotations/instances_val2017.json',
22 | 'stuff_annotations': 'annotations/stuff_val2017.json',
23 | 'files': 'val2017'
24 | }
25 | }
26 |
27 |
28 | def load_image_descriptions(description_json: List[Dict]) -> Dict[str, ImageDescription]:
29 | return {
30 | str(img['id']): ImageDescription(
31 | id=img['id'],
32 | license=img.get('license'),
33 | file_name=img['file_name'],
34 | coco_url=img['coco_url'],
35 | original_size=(img['width'], img['height']),
36 | date_captured=img.get('date_captured'),
37 | flickr_url=img.get('flickr_url')
38 | )
39 | for img in description_json
40 | }
41 |
42 |
43 | def load_categories(category_json: Iterable) -> Dict[str, Category]:
44 | return {str(cat['id']): Category(id=str(cat['id']), super_category=cat['supercategory'], name=cat['name'])
45 | for cat in category_json if cat['name'] != 'other'}
46 |
47 |
48 | def load_annotations(annotations_json: List[Dict], image_descriptions: Dict[str, ImageDescription],
49 | category_no_for_id: Callable[[str], int], split: str) -> Dict[str, List[Annotation]]:
50 | annotations = defaultdict(list)
51 | total = sum(len(a) for a in annotations_json)
52 | for ann in tqdm(chain(*annotations_json), f'Loading {split} annotations', total=total):
53 | image_id = str(ann['image_id'])
54 | if image_id not in image_descriptions:
55 | raise ValueError(f'image_id [{image_id}] has no image description.')
56 | category_id = ann['category_id']
57 | try:
58 | category_no = category_no_for_id(str(category_id))
59 | except KeyError:
60 | continue
61 |
62 | width, height = image_descriptions[image_id].original_size
63 | bbox = (ann['bbox'][0] / width, ann['bbox'][1] / height, ann['bbox'][2] / width, ann['bbox'][3] / height)
64 |
65 | annotations[image_id].append(
66 | Annotation(
67 | id=ann['id'],
68 | area=bbox[2]*bbox[3], # use bbox area
69 | is_group_of=ann['iscrowd'],
70 | image_id=ann['image_id'],
71 | bbox=bbox,
72 | category_id=str(category_id),
73 | category_no=category_no
74 | )
75 | )
76 | return dict(annotations)
77 |
78 |
79 | class AnnotatedObjectsCoco(AnnotatedObjectsDataset):
80 | def __init__(self, use_things: bool = True, use_stuff: bool = True, **kwargs):
81 | """
82 |         @param data_path: path to the following folder structure:
83 | coco/
84 | ├── annotations
85 | │ ├── instances_train2017.json
86 | │ ├── instances_val2017.json
87 | │ ├── stuff_train2017.json
88 | │ └── stuff_val2017.json
89 | ├── train2017
90 | │ ├── 000000000009.jpg
91 | │ ├── 000000000025.jpg
92 | │ └── ...
93 | ├── val2017
94 | │ ├── 000000000139.jpg
95 | │ ├── 000000000285.jpg
96 | │ └── ...
97 |         @param split: one of 'train' or 'validation'
98 |         @param: desired image size (returns square images)
99 | """
100 | super().__init__(**kwargs)
101 | self.use_things = use_things
102 | self.use_stuff = use_stuff
103 |
104 | with open(self.paths['instances_annotations']) as f:
105 | inst_data_json = json.load(f)
106 | with open(self.paths['stuff_annotations']) as f:
107 | stuff_data_json = json.load(f)
108 |
109 | category_jsons = []
110 | annotation_jsons = []
111 | if self.use_things:
112 | category_jsons.append(inst_data_json['categories'])
113 | annotation_jsons.append(inst_data_json['annotations'])
114 | if self.use_stuff:
115 | category_jsons.append(stuff_data_json['categories'])
116 | annotation_jsons.append(stuff_data_json['annotations'])
117 |
118 | self.categories = load_categories(chain(*category_jsons))
119 | self.filter_categories()
120 | self.setup_category_id_and_number()
121 |
122 | self.image_descriptions = load_image_descriptions(inst_data_json['images'])
123 | annotations = load_annotations(annotation_jsons, self.image_descriptions, self.get_category_number, self.split)
124 | self.annotations = self.filter_object_number(annotations, self.min_object_area,
125 | self.min_objects_per_image, self.max_objects_per_image)
126 | self.image_ids = list(self.annotations.keys())
127 | self.clean_up_annotations_and_image_descriptions()
128 |
129 | def get_path_structure(self) -> Dict[str, str]:
130 | if self.split not in COCO_PATH_STRUCTURE:
131 |             raise ValueError(f'Split [{self.split}] does not exist for COCO data.')
132 | return COCO_PATH_STRUCTURE[self.split]
133 |
134 | def get_image_path(self, image_id: str) -> Path:
135 | return self.paths['files'].joinpath(self.image_descriptions[str(image_id)].file_name)
136 |
137 | def get_image_description(self, image_id: str) -> Dict[str, Any]:
138 | # noinspection PyProtectedMember
139 | return self.image_descriptions[image_id]._asdict()
140 |
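A small worked example of the relative-bbox conversion performed in load_annotations above; COCO stores boxes as absolute (x, y, w, h) pixels, and the image size and box values below are invented for illustration:

    # hypothetical image size and COCO-style pixel bbox
    width, height = 640, 480                    # image_descriptions[image_id].original_size
    coco_bbox = (64.0, 48.0, 320.0, 240.0)      # ann['bbox'] in pixels
    rel_bbox = (coco_bbox[0] / width, coco_bbox[1] / height,
                coco_bbox[2] / width, coco_bbox[3] / height)
    print(rel_bbox)                   # (0.1, 0.1, 0.5, 0.5)
    print(rel_bbox[2] * rel_bbox[3])  # 0.25 -- stored as Annotation.area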
--------------------------------------------------------------------------------
/taming/data/annotated_objects_open_images.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from csv import DictReader, reader as TupleReader
3 | from pathlib import Path
4 | from typing import Dict, List, Any
5 | import warnings
6 |
7 | from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
8 | from taming.data.helper_types import Annotation, Category
9 | from tqdm import tqdm
10 |
11 | OPEN_IMAGES_STRUCTURE = {
12 | 'train': {
13 | 'top_level': '',
14 | 'class_descriptions': 'class-descriptions-boxable.csv',
15 | 'annotations': 'oidv6-train-annotations-bbox.csv',
16 | 'file_list': 'train-images-boxable.csv',
17 | 'files': 'train'
18 | },
19 | 'validation': {
20 | 'top_level': '',
21 | 'class_descriptions': 'class-descriptions-boxable.csv',
22 | 'annotations': 'validation-annotations-bbox.csv',
23 | 'file_list': 'validation-images.csv',
24 | 'files': 'validation'
25 | },
26 | 'test': {
27 | 'top_level': '',
28 | 'class_descriptions': 'class-descriptions-boxable.csv',
29 | 'annotations': 'test-annotations-bbox.csv',
30 | 'file_list': 'test-images.csv',
31 | 'files': 'test'
32 | }
33 | }
34 |
35 |
36 | def load_annotations(descriptor_path: Path, min_object_area: float, category_mapping: Dict[str, str],
37 | category_no_for_id: Dict[str, int]) -> Dict[str, List[Annotation]]:
38 | annotations: Dict[str, List[Annotation]] = defaultdict(list)
39 | with open(descriptor_path) as file:
40 | reader = DictReader(file)
41 | for i, row in tqdm(enumerate(reader), total=14620000, desc='Loading OpenImages annotations'):
42 | width = float(row['XMax']) - float(row['XMin'])
43 | height = float(row['YMax']) - float(row['YMin'])
44 | area = width * height
45 | category_id = row['LabelName']
46 | if category_id in category_mapping:
47 | category_id = category_mapping[category_id]
48 | if area >= min_object_area and category_id in category_no_for_id:
49 | annotations[row['ImageID']].append(
50 | Annotation(
51 | id=i,
52 | image_id=row['ImageID'],
53 | source=row['Source'],
54 | category_id=category_id,
55 | category_no=category_no_for_id[category_id],
56 | confidence=float(row['Confidence']),
57 | bbox=(float(row['XMin']), float(row['YMin']), width, height),
58 | area=area,
59 | is_occluded=bool(int(row['IsOccluded'])),
60 | is_truncated=bool(int(row['IsTruncated'])),
61 | is_group_of=bool(int(row['IsGroupOf'])),
62 | is_depiction=bool(int(row['IsDepiction'])),
63 | is_inside=bool(int(row['IsInside']))
64 | )
65 | )
66 | if 'train' in str(descriptor_path) and i < 14000000:
67 | warnings.warn(f'Running with subset of Open Images. Train dataset has length [{len(annotations)}].')
68 | return dict(annotations)
69 |
70 |
71 | def load_image_ids(csv_path: Path) -> List[str]:
72 | with open(csv_path) as file:
73 | reader = DictReader(file)
74 | return [row['image_name'] for row in reader]
75 |
76 |
77 | def load_categories(csv_path: Path) -> Dict[str, Category]:
78 | with open(csv_path) as file:
79 | reader = TupleReader(file)
80 | return {row[0]: Category(id=row[0], name=row[1], super_category=None) for row in reader}
81 |
82 |
83 | class AnnotatedObjectsOpenImages(AnnotatedObjectsDataset):
84 | def __init__(self, use_additional_parameters: bool, **kwargs):
85 | """
86 |         @param data_path: path to the following folder structure:
87 | open_images/
88 | │ oidv6-train-annotations-bbox.csv
89 | ├── class-descriptions-boxable.csv
90 | ├── oidv6-train-annotations-bbox.csv
91 | ├── test
92 | │ ├── 000026e7ee790996.jpg
93 | │ ├── 000062a39995e348.jpg
94 | │ └── ...
95 | ├── test-annotations-bbox.csv
96 | ├── test-images.csv
97 | ├── train
98 | │ ├── 000002b66c9c498e.jpg
99 | │ ├── 000002b97e5471a0.jpg
100 | │ └── ...
101 | ├── train-images-boxable.csv
102 | ├── validation
103 | │ ├── 0001eeaf4aed83f9.jpg
104 | │ ├── 0004886b7d043cfd.jpg
105 | │ └── ...
106 | ├── validation-annotations-bbox.csv
107 | └── validation-images.csv
108 | @param: split: one of 'train', 'validation' or 'test'
109 | @param: desired image size (returns square images)
110 | """
111 |
112 | super().__init__(**kwargs)
113 | self.use_additional_parameters = use_additional_parameters
114 |
115 | self.categories = load_categories(self.paths['class_descriptions'])
116 | self.filter_categories()
117 | self.setup_category_id_and_number()
118 |
119 | self.image_descriptions = {}
120 | annotations = load_annotations(self.paths['annotations'], self.min_object_area, self.category_mapping,
121 | self.category_number)
122 | self.annotations = self.filter_object_number(annotations, self.min_object_area, self.min_objects_per_image,
123 | self.max_objects_per_image)
124 | self.image_ids = list(self.annotations.keys())
125 | self.clean_up_annotations_and_image_descriptions()
126 |
127 | def get_path_structure(self) -> Dict[str, str]:
128 | if self.split not in OPEN_IMAGES_STRUCTURE:
129 | raise ValueError(f'Split [{self.split} does not exist for Open Images data.]')
130 | return OPEN_IMAGES_STRUCTURE[self.split]
131 |
132 | def get_image_path(self, image_id: str) -> Path:
133 | return self.paths['files'].joinpath(f'{image_id:0>16}.jpg')
134 |
135 | def get_image_description(self, image_id: str) -> Dict[str, Any]:
136 | image_path = self.get_image_path(image_id)
137 | return {'file_path': str(image_path), 'file_name': image_path.name}
138 |
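A sketch of the per-row filter used by load_annotations above: Open Images boxes are already relative, so (XMax - XMin) * (YMax - YMin) is the relative box area compared against min_object_area. The CSV row and threshold below are made up:

    row = {'XMin': '0.25', 'XMax': '0.75', 'YMin': '0.25', 'YMax': '0.75'}
    width = float(row['XMax']) - float(row['XMin'])    # 0.5
    height = float(row['YMax']) - float(row['YMin'])   # 0.5
    area = width * height                              # 0.25
    min_object_area = 0.0001                           # example threshold
    print(area >= min_object_area)                     # True -> annotation is kept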
--------------------------------------------------------------------------------
/taming/data/base.py:
--------------------------------------------------------------------------------
1 | import bisect
2 | import numpy as np
3 | import albumentations
4 | from PIL import Image
5 | from torch.utils.data import Dataset, ConcatDataset
6 |
7 |
8 | class ConcatDatasetWithIndex(ConcatDataset):
9 | """Modified from original pytorch code to return dataset idx"""
10 | def __getitem__(self, idx):
11 | if idx < 0:
12 | if -idx > len(self):
13 | raise ValueError("absolute value of index should not exceed dataset length")
14 | idx = len(self) + idx
15 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
16 | if dataset_idx == 0:
17 | sample_idx = idx
18 | else:
19 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
20 | return self.datasets[dataset_idx][sample_idx], dataset_idx
21 |
22 |
23 | class ImagePaths(Dataset):
24 | def __init__(self, paths, size=None, random_crop=False, labels=None):
25 | self.size = size
26 | self.random_crop = random_crop
27 |
28 | self.labels = dict() if labels is None else labels
29 | self.labels["file_path_"] = paths
30 | self._length = len(paths)
31 |
32 | if self.size is not None and self.size > 0:
33 | self.rescaler = albumentations.SmallestMaxSize(max_size = self.size)
34 | if not self.random_crop:
35 | self.cropper = albumentations.CenterCrop(height=self.size,width=self.size)
36 | else:
37 | self.cropper = albumentations.RandomCrop(height=self.size,width=self.size)
38 | self.preprocessor = albumentations.Compose([self.rescaler, self.cropper])
39 | else:
40 | self.preprocessor = lambda **kwargs: kwargs
41 |
42 | def __len__(self):
43 | return self._length
44 |
45 | def preprocess_image(self, image_path):
46 | image = Image.open(image_path)
47 | if not image.mode == "RGB":
48 | image = image.convert("RGB")
49 | image = np.array(image).astype(np.uint8)
50 | image = self.preprocessor(image=image)["image"]
51 | image = (image/127.5 - 1.0).astype(np.float32)
52 | return image
53 |
54 | def __getitem__(self, i):
55 | example = dict()
56 | example["image"] = self.preprocess_image(self.labels["file_path_"][i])
57 | for k in self.labels:
58 | example[k] = self.labels[k][i]
59 | return example
60 |
61 |
62 | class NumpyPaths(ImagePaths):
63 | def preprocess_image(self, image_path):
64 | image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024
65 | image = np.transpose(image, (1,2,0))
66 | image = Image.fromarray(image, mode="RGB")
67 | image = np.array(image).astype(np.uint8)
68 | image = self.preprocessor(image=image)["image"]
69 | image = (image/127.5 - 1.0).astype(np.float32)
70 | return image
71 |
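A minimal usage sketch for ImagePaths, assuming two local image files (the paths are hypothetical); images are rescaled so the shorter side equals size, center-cropped, and returned as float32 in [-1, 1]:

    from taming.data.base import ImagePaths

    dataset = ImagePaths(paths=["data/example_0.png", "data/example_1.png"],
                         size=256, random_crop=False)
    example = dataset[0]
    print(example["image"].shape, example["image"].dtype)  # (256, 256, 3) float32
    print(example["file_path_"])                           # data/example_0.png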
--------------------------------------------------------------------------------
/taming/data/conditional_builder/objects_bbox.py:
--------------------------------------------------------------------------------
1 | from itertools import cycle
2 | from typing import List, Tuple, Callable, Optional
3 |
4 | from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
5 | from more_itertools.recipes import grouper
6 | from taming.data.image_transforms import convert_pil_to_tensor
7 | from torch import LongTensor, Tensor
8 |
9 | from taming.data.helper_types import BoundingBox, Annotation
10 | from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
11 | from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \
12 | pad_list, get_plot_font_size, absolute_bbox
13 |
14 |
15 | class ObjectsBoundingBoxConditionalBuilder(ObjectsCenterPointsConditionalBuilder):
16 | @property
17 | def object_descriptor_length(self) -> int:
18 | return 3
19 |
20 | def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]:
21 | object_triples = [
22 | (self.object_representation(ann), *self.token_pair_from_bbox(ann.bbox))
23 | for ann in annotations
24 | ]
25 | empty_triple = (self.none, self.none, self.none)
26 | object_triples = pad_list(object_triples, empty_triple, self.no_max_objects)
27 | return object_triples
28 |
29 | def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]:
30 | conditional_list = conditional.tolist()
31 | crop_coordinates = None
32 | if self.encode_crop:
33 | crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1])
34 | conditional_list = conditional_list[:-2]
35 | object_triples = grouper(conditional_list, 3)
36 | assert conditional.shape[0] == self.embedding_dim
37 | return [
38 | (object_triple[0], self.bbox_from_token_pair(object_triple[1], object_triple[2]))
39 | for object_triple in object_triples if object_triple[0] != self.none
40 | ], crop_coordinates
41 |
42 | def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int],
43 | line_width: int = 3, font_size: Optional[int] = None) -> Tensor:
44 | plot = pil_image.new('RGB', figure_size, WHITE)
45 | draw = pil_img_draw.Draw(plot)
46 | font = ImageFont.truetype(
47 | "/usr/share/fonts/truetype/lato/Lato-Regular.ttf",
48 | size=get_plot_font_size(font_size, figure_size)
49 | )
50 | width, height = plot.size
51 | description, crop_coordinates = self.inverse_build(conditional)
52 | for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)):
53 | annotation = self.representation_to_annotation(representation)
54 | class_label = label_for_category_no(annotation.category_no) + ' ' + additional_parameters_string(annotation)
55 | bbox = absolute_bbox(bbox, width, height)
56 | draw.rectangle(bbox, outline=color, width=line_width)
57 | draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK, font=font)
58 | if crop_coordinates is not None:
59 | draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width)
60 | return convert_pil_to_tensor(plot) / 127.5 - 1.
61 |
--------------------------------------------------------------------------------
/taming/data/conditional_builder/utils.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from typing import List, Any, Tuple, Optional
3 |
4 | from taming.data.helper_types import BoundingBox, Annotation
5 |
6 | # source: seaborn, color palette tab10
7 | COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188),
8 | (139, 85, 74), (226, 118, 193), (126, 126, 126), (187, 188, 33), (22, 189, 206)]
9 | BLACK = (0, 0, 0)
10 | GRAY_75 = (63, 63, 63)
11 | GRAY_50 = (127, 127, 127)
12 | GRAY_25 = (191, 191, 191)
13 | WHITE = (255, 255, 255)
14 | FULL_CROP = (0., 0., 1., 1.)
15 |
16 |
17 | def intersection_area(rectangle1: BoundingBox, rectangle2: BoundingBox) -> float:
18 | """
19 | Give intersection area of two rectangles.
20 | @param rectangle1: (x0, y0, w, h) of first rectangle
21 | @param rectangle2: (x0, y0, w, h) of second rectangle
22 | """
23 | rectangle1 = rectangle1[0], rectangle1[1], rectangle1[0] + rectangle1[2], rectangle1[1] + rectangle1[3]
24 | rectangle2 = rectangle2[0], rectangle2[1], rectangle2[0] + rectangle2[2], rectangle2[1] + rectangle2[3]
25 | x_overlap = max(0., min(rectangle1[2], rectangle2[2]) - max(rectangle1[0], rectangle2[0]))
26 | y_overlap = max(0., min(rectangle1[3], rectangle2[3]) - max(rectangle1[1], rectangle2[1]))
27 | return x_overlap * y_overlap
28 |
29 |
30 | def horizontally_flip_bbox(bbox: BoundingBox) -> BoundingBox:
31 | return 1 - (bbox[0] + bbox[2]), bbox[1], bbox[2], bbox[3]
32 |
33 |
34 | def absolute_bbox(relative_bbox: BoundingBox, width: int, height: int) -> Tuple[int, int, int, int]:
35 | bbox = relative_bbox
36 | bbox = bbox[0] * width, bbox[1] * height, (bbox[0] + bbox[2]) * width, (bbox[1] + bbox[3]) * height
37 | return int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
38 |
39 |
40 | def pad_list(list_: List, pad_element: Any, pad_to_length: int) -> List:
41 | return list_ + [pad_element for _ in range(pad_to_length - len(list_))]
42 |
43 |
44 | def rescale_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox, flip: bool) -> \
45 | List[Annotation]:
46 | def clamp(x: float):
47 | return max(min(x, 1.), 0.)
48 |
49 | def rescale_bbox(bbox: BoundingBox) -> BoundingBox:
50 | x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
51 | y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
52 | w = min(bbox[2] / crop_coordinates[2], 1 - x0)
53 | h = min(bbox[3] / crop_coordinates[3], 1 - y0)
54 | if flip:
55 | x0 = 1 - (x0 + w)
56 | return x0, y0, w, h
57 |
58 | return [a._replace(bbox=rescale_bbox(a.bbox)) for a in annotations]
59 |
60 |
61 | def filter_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox) -> List:
62 | return [a for a in annotations if intersection_area(a.bbox, crop_coordinates) > 0.0]
63 |
64 |
65 | def additional_parameters_string(annotation: Annotation, short: bool = True) -> str:
66 | sl = slice(1) if short else slice(None)
67 | string = ''
68 | if not (annotation.is_group_of or annotation.is_occluded or annotation.is_depiction or annotation.is_inside):
69 | return string
70 | if annotation.is_group_of:
71 | string += 'group'[sl] + ','
72 | if annotation.is_occluded:
73 | string += 'occluded'[sl] + ','
74 | if annotation.is_depiction:
75 | string += 'depiction'[sl] + ','
76 | if annotation.is_inside:
77 | string += 'inside'[sl]
78 | return '(' + string.strip(",") + ')'
79 |
80 |
81 | def get_plot_font_size(font_size: Optional[int], figure_size: Tuple[int, int]) -> int:
82 | if font_size is None:
83 | font_size = 10
84 | if max(figure_size) >= 256:
85 | font_size = 12
86 | if max(figure_size) >= 512:
87 | font_size = 15
88 | return font_size
89 |
90 |
91 | def get_circle_size(figure_size: Tuple[int, int]) -> int:
92 | circle_size = 2
93 | if max(figure_size) >= 256:
94 | circle_size = 3
95 | if max(figure_size) >= 512:
96 | circle_size = 4
97 | return circle_size
98 |
99 |
100 | def load_object_from_string(object_string: str) -> Any:
101 | """
102 | Source: https://stackoverflow.com/a/10773699
103 | """
104 | module_name, class_name = object_string.rsplit(".", 1)
105 | return getattr(importlib.import_module(module_name), class_name)
106 |
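A worked example for the bbox helpers above (values invented): boxes are relative (x0, y0, w, h), absolute_bbox converts them to integer pixel corners (x0, y0, x1, y1), and intersection_area operates directly on the relative form:

    from taming.data.conditional_builder.utils import absolute_bbox, intersection_area

    a = (0.0, 0.0, 0.5, 0.5)
    b = (0.25, 0.25, 0.5, 0.5)
    print(intersection_area(a, b))                  # 0.0625 (a 0.25 x 0.25 overlap)
    print(absolute_bbox(a, width=640, height=480))  # (0, 0, 320, 240)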
--------------------------------------------------------------------------------
/taming/data/custom.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import albumentations
4 | from torch.utils.data import Dataset
5 |
6 | from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
7 |
8 |
9 | class CustomBase(Dataset):
10 | def __init__(self, *args, **kwargs):
11 | super().__init__()
12 | self.data = None
13 |
14 | def __len__(self):
15 | return len(self.data)
16 |
17 | def __getitem__(self, i):
18 | example = self.data[i]
19 | return example
20 |
21 |
22 |
23 | class CustomTrain(CustomBase):
24 | def __init__(self, size, training_images_list_file):
25 | super().__init__()
26 | with open(training_images_list_file, "r") as f:
27 | paths = f.read().splitlines()
28 | self.data = ImagePaths(paths=paths, size=size, random_crop=False)
29 |
30 |
31 | class CustomTest(CustomBase):
32 | def __init__(self, size, test_images_list_file):
33 | super().__init__()
34 | with open(test_images_list_file, "r") as f:
35 | paths = f.read().splitlines()
36 | self.data = ImagePaths(paths=paths, size=size, random_crop=False)
37 |
38 |
39 |
--------------------------------------------------------------------------------
/taming/data/faceshq.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import albumentations
4 | from torch.utils.data import Dataset
5 |
6 | from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
7 |
8 |
9 | class FacesBase(Dataset):
10 | def __init__(self, *args, **kwargs):
11 | super().__init__()
12 | self.data = None
13 | self.keys = None
14 |
15 | def __len__(self):
16 | return len(self.data)
17 |
18 | def __getitem__(self, i):
19 | example = self.data[i]
20 | ex = {}
21 | if self.keys is not None:
22 | for k in self.keys:
23 | ex[k] = example[k]
24 | else:
25 | ex = example
26 | return ex
27 |
28 |
29 | class CelebAHQTrain(FacesBase):
30 | def __init__(self, size, keys=None):
31 | super().__init__()
32 | root = "data/celebahq"
33 | with open("data/celebahqtrain.txt", "r") as f:
34 | relpaths = f.read().splitlines()
35 | paths = [os.path.join(root, relpath) for relpath in relpaths]
36 | self.data = NumpyPaths(paths=paths, size=size, random_crop=False)
37 | self.keys = keys
38 |
39 |
40 | class CelebAHQValidation(FacesBase):
41 | def __init__(self, size, keys=None):
42 | super().__init__()
43 | root = "data/celebahq"
44 | with open("data/celebahqvalidation.txt", "r") as f:
45 | relpaths = f.read().splitlines()
46 | paths = [os.path.join(root, relpath) for relpath in relpaths]
47 | self.data = NumpyPaths(paths=paths, size=size, random_crop=False)
48 | self.keys = keys
49 |
50 |
51 | class FFHQTrain(FacesBase):
52 | def __init__(self, size, keys=None):
53 | super().__init__()
54 | root = "data/ffhq"
55 | with open("data/ffhqtrain.txt", "r") as f:
56 | relpaths = f.read().splitlines()
57 | paths = [os.path.join(root, relpath) for relpath in relpaths]
58 | self.data = ImagePaths(paths=paths, size=size, random_crop=False)
59 | self.keys = keys
60 |
61 |
62 | class FFHQValidation(FacesBase):
63 | def __init__(self, size, keys=None):
64 | super().__init__()
65 | root = "data/ffhq"
66 | with open("data/ffhqvalidation.txt", "r") as f:
67 | relpaths = f.read().splitlines()
68 | paths = [os.path.join(root, relpath) for relpath in relpaths]
69 | self.data = ImagePaths(paths=paths, size=size, random_crop=False)
70 | self.keys = keys
71 |
72 |
73 | class FacesHQTrain(Dataset):
74 | # CelebAHQ [0] + FFHQ [1]
75 | def __init__(self, size, keys=None, crop_size=None, coord=False):
76 | d1 = CelebAHQTrain(size=size, keys=keys)
77 | d2 = FFHQTrain(size=size, keys=keys)
78 | self.data = ConcatDatasetWithIndex([d1, d2])
79 | self.coord = coord
80 | if crop_size is not None:
81 | self.cropper = albumentations.RandomCrop(height=crop_size,width=crop_size)
82 | if self.coord:
83 | self.cropper = albumentations.Compose([self.cropper],
84 | additional_targets={"coord": "image"})
85 |
86 | def __len__(self):
87 | return len(self.data)
88 |
89 | def __getitem__(self, i):
90 | ex, y = self.data[i]
91 | if hasattr(self, "cropper"):
92 | if not self.coord:
93 | out = self.cropper(image=ex["image"])
94 | ex["image"] = out["image"]
95 | else:
96 | h,w,_ = ex["image"].shape
97 | coord = np.arange(h*w).reshape(h,w,1)/(h*w)
98 | out = self.cropper(image=ex["image"], coord=coord)
99 | ex["image"] = out["image"]
100 | ex["coord"] = out["coord"]
101 | ex["class"] = y
102 | return ex
103 |
104 |
105 | class FacesHQValidation(Dataset):
106 | # CelebAHQ [0] + FFHQ [1]
107 | def __init__(self, size, keys=None, crop_size=None, coord=False):
108 | d1 = CelebAHQValidation(size=size, keys=keys)
109 | d2 = FFHQValidation(size=size, keys=keys)
110 | self.data = ConcatDatasetWithIndex([d1, d2])
111 | self.coord = coord
112 | if crop_size is not None:
113 | self.cropper = albumentations.CenterCrop(height=crop_size,width=crop_size)
114 | if self.coord:
115 | self.cropper = albumentations.Compose([self.cropper],
116 | additional_targets={"coord": "image"})
117 |
118 | def __len__(self):
119 | return len(self.data)
120 |
121 | def __getitem__(self, i):
122 | ex, y = self.data[i]
123 | if hasattr(self, "cropper"):
124 | if not self.coord:
125 | out = self.cropper(image=ex["image"])
126 | ex["image"] = out["image"]
127 | else:
128 | h,w,_ = ex["image"].shape
129 | coord = np.arange(h*w).reshape(h,w,1)/(h*w)
130 | out = self.cropper(image=ex["image"], coord=coord)
131 | ex["image"] = out["image"]
132 | ex["coord"] = out["coord"]
133 | ex["class"] = y
134 | return ex
135 |
--------------------------------------------------------------------------------
/taming/data/helper_types.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Tuple, Optional, NamedTuple, Union
2 | from PIL.Image import Image as pil_image
3 | from torch import Tensor
4 |
5 | try:
6 | from typing import Literal
7 | except ImportError:
8 | from typing_extensions import Literal
9 |
10 | Image = Union[Tensor, pil_image]
11 | BoundingBox = Tuple[float, float, float, float] # x0, y0, w, h
12 | CropMethodType = Literal['none', 'random', 'center', 'random-2d']
13 | SplitType = Literal['train', 'validation', 'test']
14 |
15 |
16 | class ImageDescription(NamedTuple):
17 | id: int
18 | file_name: str
19 | original_size: Tuple[int, int] # w, h
20 | url: Optional[str] = None
21 | license: Optional[int] = None
22 | coco_url: Optional[str] = None
23 | date_captured: Optional[str] = None
24 | flickr_url: Optional[str] = None
25 | flickr_id: Optional[str] = None
26 | coco_id: Optional[str] = None
27 |
28 |
29 | class Category(NamedTuple):
30 | id: str
31 | super_category: Optional[str]
32 | name: str
33 |
34 |
35 | class Annotation(NamedTuple):
36 | area: float
37 | image_id: str
38 | bbox: BoundingBox
39 | category_no: int
40 | category_id: str
41 | id: Optional[int] = None
42 | source: Optional[str] = None
43 | confidence: Optional[float] = None
44 | is_group_of: Optional[bool] = None
45 | is_truncated: Optional[bool] = None
46 | is_occluded: Optional[bool] = None
47 | is_depiction: Optional[bool] = None
48 | is_inside: Optional[bool] = None
49 | segmentation: Optional[Dict] = None
50 |
--------------------------------------------------------------------------------
/taming/data/image_transforms.py:
--------------------------------------------------------------------------------
1 | import random
2 | import warnings
3 | from typing import Union
4 |
5 | import torch
6 | from torch import Tensor
7 | from torchvision.transforms import RandomCrop, functional as F, CenterCrop, RandomHorizontalFlip, PILToTensor
8 | from torchvision.transforms.functional import _get_image_size as get_image_size
9 |
10 | from taming.data.helper_types import BoundingBox, Image
11 |
12 | pil_to_tensor = PILToTensor()
13 |
14 |
15 | def convert_pil_to_tensor(image: Image) -> Tensor:
16 | with warnings.catch_warnings():
17 | # to filter PyTorch UserWarning as described here: https://github.com/pytorch/vision/issues/2194
18 | warnings.simplefilter("ignore")
19 | return pil_to_tensor(image)
20 |
21 |
22 | class RandomCrop1dReturnCoordinates(RandomCrop):
23 | def forward(self, img: Image) -> (BoundingBox, Image):
24 | """
25 | Additionally to cropping, returns the relative coordinates of the crop bounding box.
26 | Args:
27 | img (PIL Image or Tensor): Image to be cropped.
28 |
29 | Returns:
30 | Bounding box: x0, y0, w, h
31 | PIL Image or Tensor: Cropped image.
32 |
33 | Based on:
34 | torchvision.transforms.RandomCrop, torchvision 1.7.0
35 | """
36 | if self.padding is not None:
37 | img = F.pad(img, self.padding, self.fill, self.padding_mode)
38 |
39 | width, height = get_image_size(img)
40 | # pad the width if needed
41 | if self.pad_if_needed and width < self.size[1]:
42 | padding = [self.size[1] - width, 0]
43 | img = F.pad(img, padding, self.fill, self.padding_mode)
44 | # pad the height if needed
45 | if self.pad_if_needed and height < self.size[0]:
46 | padding = [0, self.size[0] - height]
47 | img = F.pad(img, padding, self.fill, self.padding_mode)
48 |
49 | i, j, h, w = self.get_params(img, self.size)
50 | bbox = (j / width, i / height, w / width, h / height) # x0, y0, w, h
51 | return bbox, F.crop(img, i, j, h, w)
52 |
53 |
54 | class Random2dCropReturnCoordinates(torch.nn.Module):
55 | """
56 | Additionally to cropping, returns the relative coordinates of the crop bounding box.
57 | Args:
58 | img (PIL Image or Tensor): Image to be cropped.
59 |
60 | Returns:
61 | Bounding box: x0, y0, w, h
62 | PIL Image or Tensor: Cropped image.
63 |
64 | Based on:
65 | torchvision.transforms.RandomCrop, torchvision 1.7.0
66 | """
67 |
68 | def __init__(self, min_size: int):
69 | super().__init__()
70 | self.min_size = min_size
71 |
72 | def forward(self, img: Image) -> (BoundingBox, Image):
73 | width, height = get_image_size(img)
74 | max_size = min(width, height)
75 | if max_size <= self.min_size:
76 | size = max_size
77 | else:
78 | size = random.randint(self.min_size, max_size)
79 | top = random.randint(0, height - size)
80 | left = random.randint(0, width - size)
81 | bbox = left / width, top / height, size / width, size / height
82 | return bbox, F.crop(img, top, left, size, size)
83 |
84 |
85 | class CenterCropReturnCoordinates(CenterCrop):
86 | @staticmethod
87 | def get_bbox_of_center_crop(width: int, height: int) -> BoundingBox:
88 | if width > height:
89 | w = height / width
90 | h = 1.0
91 | x0 = 0.5 - w / 2
92 | y0 = 0.
93 | else:
94 | w = 1.0
95 | h = width / height
96 | x0 = 0.
97 | y0 = 0.5 - h / 2
98 | return x0, y0, w, h
99 |
100 | def forward(self, img: Union[Image, Tensor]) -> (BoundingBox, Union[Image, Tensor]):
101 | """
102 | Additionally to cropping, returns the relative coordinates of the crop bounding box.
103 | Args:
104 | img (PIL Image or Tensor): Image to be cropped.
105 |
106 | Returns:
107 | Bounding box: x0, y0, w, h
108 | PIL Image or Tensor: Cropped image.
109 | Based on:
110 | torchvision.transforms.RandomHorizontalFlip (version 1.7.0)
111 | """
112 | width, height = get_image_size(img)
113 | return self.get_bbox_of_center_crop(width, height), F.center_crop(img, self.size)
114 |
115 |
116 | class RandomHorizontalFlipReturn(RandomHorizontalFlip):
117 | def forward(self, img: Image) -> (bool, Image):
118 | """
119 | Additionally to flipping, returns a boolean whether it was flipped or not.
120 | Args:
121 | img (PIL Image or Tensor): Image to be flipped.
122 |
123 | Returns:
124 | flipped: whether the image was flipped or not
125 | PIL Image or Tensor: Randomly flipped image.
126 |
127 | Based on:
128 | torchvision.transforms.RandomHorizontalFlip (version 1.7.0)
129 | """
130 | if torch.rand(1) < self.p:
131 | return True, F.hflip(img)
132 | return False, img
133 |
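A worked example for CenterCropReturnCoordinates.get_bbox_of_center_crop above; it is a static method and needs no image. For a 640x480 landscape image the square center crop spans the full height and three quarters of the width, starting at x0 = 0.125:

    from taming.data.image_transforms import CenterCropReturnCoordinates

    print(CenterCropReturnCoordinates.get_bbox_of_center_crop(width=640, height=480))
    # (0.125, 0.0, 0.75, 1.0)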
--------------------------------------------------------------------------------
/taming/data/sflckr.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import cv2
4 | import albumentations
5 | from PIL import Image
6 | from torch.utils.data import Dataset
7 |
8 |
9 | class SegmentationBase(Dataset):
10 | def __init__(self,
11 | data_csv, data_root, segmentation_root,
12 | size=None, random_crop=False, interpolation="bicubic",
13 | n_labels=182, shift_segmentation=False,
14 | ):
15 | self.n_labels = n_labels
16 | self.shift_segmentation = shift_segmentation
17 | self.data_csv = data_csv
18 | self.data_root = data_root
19 | self.segmentation_root = segmentation_root
20 | with open(self.data_csv, "r") as f:
21 | self.image_paths = f.read().splitlines()
22 | self._length = len(self.image_paths)
23 | self.labels = {
24 | "relative_file_path_": [l for l in self.image_paths],
25 | "file_path_": [os.path.join(self.data_root, l)
26 | for l in self.image_paths],
27 | "segmentation_path_": [os.path.join(self.segmentation_root, l.replace(".jpg", ".png"))
28 | for l in self.image_paths]
29 | }
30 |
31 | size = None if size is not None and size<=0 else size
32 | self.size = size
33 | if self.size is not None:
34 | self.interpolation = interpolation
35 | self.interpolation = {
36 | "nearest": cv2.INTER_NEAREST,
37 | "bilinear": cv2.INTER_LINEAR,
38 | "bicubic": cv2.INTER_CUBIC,
39 | "area": cv2.INTER_AREA,
40 | "lanczos": cv2.INTER_LANCZOS4}[self.interpolation]
41 | self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
42 | interpolation=self.interpolation)
43 | self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
44 | interpolation=cv2.INTER_NEAREST)
45 | self.center_crop = not random_crop
46 | if self.center_crop:
47 | self.cropper = albumentations.CenterCrop(height=self.size, width=self.size)
48 | else:
49 | self.cropper = albumentations.RandomCrop(height=self.size, width=self.size)
50 | self.preprocessor = self.cropper
51 |
52 | def __len__(self):
53 | return self._length
54 |
55 | def __getitem__(self, i):
56 | example = dict((k, self.labels[k][i]) for k in self.labels)
57 | image = Image.open(example["file_path_"])
58 | if not image.mode == "RGB":
59 | image = image.convert("RGB")
60 | image = np.array(image).astype(np.uint8)
61 | if self.size is not None:
62 | image = self.image_rescaler(image=image)["image"]
63 | segmentation = Image.open(example["segmentation_path_"])
64 | assert segmentation.mode == "L", segmentation.mode
65 | segmentation = np.array(segmentation).astype(np.uint8)
66 | if self.shift_segmentation:
67 | # used to support segmentations containing unlabeled==255 label
68 | segmentation = segmentation+1
69 | if self.size is not None:
70 | segmentation = self.segmentation_rescaler(image=segmentation)["image"]
71 | if self.size is not None:
72 | processed = self.preprocessor(image=image,
73 | mask=segmentation
74 | )
75 | else:
76 | processed = {"image": image,
77 | "mask": segmentation
78 | }
79 | example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32)
80 | segmentation = processed["mask"]
81 | onehot = np.eye(self.n_labels)[segmentation]
82 | example["segmentation"] = onehot
83 | return example
84 |
85 |
86 | class Examples(SegmentationBase):
87 | def __init__(self, size=None, random_crop=False, interpolation="bicubic"):
88 | super().__init__(data_csv="data/sflckr_examples.txt",
89 | data_root="data/sflckr_images",
90 | segmentation_root="data/sflckr_segmentations",
91 | size=size, random_crop=random_crop, interpolation=interpolation)
92 |
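A sketch of the one-hot conversion used in SegmentationBase.__getitem__ above: indexing an identity matrix with the (H, W) label map produces an (H, W, n_labels) array. The tiny label map and n_labels below are made up:

    import numpy as np

    seg = np.array([[0, 2],
                    [1, 2]], dtype=np.uint8)   # toy 2x2 label map
    onehot = np.eye(3)[seg]                    # n_labels = 3 here
    print(onehot.shape)                        # (2, 2, 3)
    print(onehot[0, 1])                        # [0. 0. 1.] -> label 2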
--------------------------------------------------------------------------------
/taming/data/utils.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import os
3 | import tarfile
4 | import urllib
5 | import zipfile
6 | from pathlib import Path
7 |
8 | import numpy as np
9 | import torch
10 | from taming.data.helper_types import Annotation
11 | from torch._six import string_classes
12 | from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
13 | from tqdm import tqdm
14 |
15 |
16 | def unpack(path):
17 | if path.endswith("tar.gz"):
18 | with tarfile.open(path, "r:gz") as tar:
19 | tar.extractall(path=os.path.split(path)[0])
20 | elif path.endswith("tar"):
21 | with tarfile.open(path, "r:") as tar:
22 | tar.extractall(path=os.path.split(path)[0])
23 | elif path.endswith("zip"):
24 | with zipfile.ZipFile(path, "r") as f:
25 | f.extractall(path=os.path.split(path)[0])
26 | else:
27 | raise NotImplementedError(
28 | "Unknown file extension: {}".format(os.path.splitext(path)[1])
29 | )
30 |
31 |
32 | def reporthook(bar):
33 | """tqdm progress bar for downloads."""
34 |
35 | def hook(b=1, bsize=1, tsize=None):
36 | if tsize is not None:
37 | bar.total = tsize
38 | bar.update(b * bsize - bar.n)
39 |
40 | return hook
41 |
42 |
43 | def get_root(name):
44 | base = "data/"
45 | root = os.path.join(base, name)
46 | os.makedirs(root, exist_ok=True)
47 | return root
48 |
49 |
50 | def is_prepared(root):
51 | return Path(root).joinpath(".ready").exists()
52 |
53 |
54 | def mark_prepared(root):
55 | Path(root).joinpath(".ready").touch()
56 |
57 |
58 | def prompt_download(file_, source, target_dir, content_dir=None):
59 | targetpath = os.path.join(target_dir, file_)
60 | while not os.path.exists(targetpath):
61 | if content_dir is not None and os.path.exists(
62 | os.path.join(target_dir, content_dir)
63 | ):
64 | break
65 | print(
66 | "Please download '{}' from '{}' to '{}'.".format(file_, source, targetpath)
67 | )
68 | if content_dir is not None:
69 | print(
70 | "Or place its content into '{}'.".format(
71 | os.path.join(target_dir, content_dir)
72 | )
73 | )
74 | input("Press Enter when done...")
75 | return targetpath
76 |
77 |
78 | def download_url(file_, url, target_dir):
79 | targetpath = os.path.join(target_dir, file_)
80 | os.makedirs(target_dir, exist_ok=True)
81 | with tqdm(
82 | unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_
83 | ) as bar:
84 | urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar))
85 | return targetpath
86 |
87 |
88 | def download_urls(urls, target_dir):
89 | paths = dict()
90 | for fname, url in urls.items():
91 | outpath = download_url(fname, url, target_dir)
92 | paths[fname] = outpath
93 | return paths
94 |
95 |
96 | def quadratic_crop(x, bbox, alpha=1.0):
97 | """bbox is xmin, ymin, xmax, ymax"""
98 | im_h, im_w = x.shape[:2]
99 | bbox = np.array(bbox, dtype=np.float32)
100 | bbox = np.clip(bbox, 0, max(im_h, im_w))
101 | center = 0.5 * (bbox[0] + bbox[2]), 0.5 * (bbox[1] + bbox[3])
102 | w = bbox[2] - bbox[0]
103 | h = bbox[3] - bbox[1]
104 | l = int(alpha * max(w, h))
105 | l = max(l, 2)
106 |
107 | required_padding = -1 * min(
108 | center[0] - l, center[1] - l, im_w - (center[0] + l), im_h - (center[1] + l)
109 | )
110 | required_padding = int(np.ceil(required_padding))
111 | if required_padding > 0:
112 | padding = [
113 | [required_padding, required_padding],
114 | [required_padding, required_padding],
115 | ]
116 | padding += [[0, 0]] * (len(x.shape) - 2)
117 | x = np.pad(x, padding, "reflect")
118 | center = center[0] + required_padding, center[1] + required_padding
119 | xmin = int(center[0] - l / 2)
120 | ymin = int(center[1] - l / 2)
121 | return np.array(x[ymin : ymin + l, xmin : xmin + l, ...])
122 |
123 |
124 | def custom_collate(batch):
125 | r"""source: pytorch 1.9.0, only one modification to original code """
126 |
127 | elem = batch[0]
128 | elem_type = type(elem)
129 | if isinstance(elem, torch.Tensor):
130 | out = None
131 | if torch.utils.data.get_worker_info() is not None:
132 | # If we're in a background process, concatenate directly into a
133 | # shared memory tensor to avoid an extra copy
134 | numel = sum([x.numel() for x in batch])
135 | storage = elem.storage()._new_shared(numel)
136 | out = elem.new(storage)
137 | return torch.stack(batch, 0, out=out)
138 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
139 | and elem_type.__name__ != 'string_':
140 | if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
141 | # array of string classes and object
142 | if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
143 | raise TypeError(default_collate_err_msg_format.format(elem.dtype))
144 |
145 | return custom_collate([torch.as_tensor(b) for b in batch])
146 | elif elem.shape == (): # scalars
147 | return torch.as_tensor(batch)
148 | elif isinstance(elem, float):
149 | return torch.tensor(batch, dtype=torch.float64)
150 | elif isinstance(elem, int):
151 | return torch.tensor(batch)
152 | elif isinstance(elem, string_classes):
153 | return batch
154 | elif isinstance(elem, collections.abc.Mapping):
155 | return {key: custom_collate([d[key] for d in batch]) for key in elem}
156 | elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
157 | return elem_type(*(custom_collate(samples) for samples in zip(*batch)))
158 | if isinstance(elem, collections.abc.Sequence) and isinstance(elem[0], Annotation): # added
159 | return batch # added
160 | elif isinstance(elem, collections.abc.Sequence):
161 | # check to make sure that the elements in batch have consistent size
162 | it = iter(batch)
163 | elem_size = len(next(it))
164 | if not all(len(elem) == elem_size for elem in it):
165 | raise RuntimeError('each element in list of batch should be of equal size')
166 | transposed = zip(*batch)
167 | return [custom_collate(samples) for samples in transposed]
168 |
169 | raise TypeError(default_collate_err_msg_format.format(elem_type))
170 |
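A sketch of the single behavioural change custom_collate makes relative to PyTorch's default collate: a batch whose samples are lists of Annotation objects is returned unchanged instead of being transposed field by field. The annotations below are invented:

    from taming.data.helper_types import Annotation
    from taming.data.utils import custom_collate

    ann = Annotation(area=0.25, image_id='0', bbox=(0.1, 0.1, 0.5, 0.5),
                     category_no=0, category_id='1')
    batch = [[ann], [ann, ann]]             # samples with 1 and 2 boxes
    print(custom_collate(batch) is batch)   # True -> passed through as-is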
--------------------------------------------------------------------------------
/taming/modules/autoencoder/lpips/vgg.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/taming/modules/autoencoder/lpips/vgg.pth
--------------------------------------------------------------------------------
/taming/modules/discriminator/model.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import torch.nn as nn
3 |
4 |
5 | from taming.modules.util import ActNorm
6 |
7 |
8 | def weights_init(m):
9 | classname = m.__class__.__name__
10 | if classname.find('Conv') != -1:
11 | nn.init.normal_(m.weight.data, 0.0, 0.02)
12 | elif classname.find('BatchNorm') != -1:
13 | nn.init.normal_(m.weight.data, 1.0, 0.02)
14 | nn.init.constant_(m.bias.data, 0)
15 |
16 |
17 | class NLayerDiscriminator(nn.Module):
18 | """Defines a PatchGAN discriminator as in Pix2Pix
19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
20 | """
21 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
22 | """Construct a PatchGAN discriminator
23 | Parameters:
24 | input_nc (int) -- the number of channels in input images
25 | ndf (int) -- the number of filters in the last conv layer
26 | n_layers (int) -- the number of conv layers in the discriminator
27 | norm_layer -- normalization layer
28 | """
29 | super(NLayerDiscriminator, self).__init__()
30 | if not use_actnorm:
31 | norm_layer = nn.BatchNorm2d
32 | else:
33 | norm_layer = ActNorm
34 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
35 | use_bias = norm_layer.func != nn.BatchNorm2d
36 | else:
37 | use_bias = norm_layer != nn.BatchNorm2d
38 |
39 | kw = 4
40 | padw = 1
41 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
42 | nf_mult = 1
43 | nf_mult_prev = 1
44 | for n in range(1, n_layers): # gradually increase the number of filters
45 | nf_mult_prev = nf_mult
46 | nf_mult = min(2 ** n, 8)
47 | sequence += [
48 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
49 | norm_layer(ndf * nf_mult),
50 | nn.LeakyReLU(0.2, True)
51 | ]
52 |
53 | nf_mult_prev = nf_mult
54 | nf_mult = min(2 ** n_layers, 8)
55 | sequence += [
56 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
57 | norm_layer(ndf * nf_mult),
58 | nn.LeakyReLU(0.2, True)
59 | ]
60 |
61 | sequence += [
62 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
63 | self.main = nn.Sequential(*sequence)
64 |
65 | def forward(self, input):
66 | """Standard forward."""
67 | return self.main(input)
68 |
69 | class NLayerDiscriminator2(nn.Module):
70 | """Defines a PatchGAN discriminator as in Pix2Pix
71 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
72 | """
73 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
74 | """Construct a PatchGAN discriminator
75 | Parameters:
76 | input_nc (int) -- the number of channels in input images
77 | ndf (int) -- the number of filters in the last conv layer
78 | n_layers (int) -- the number of conv layers in the discriminator
79 | norm_layer -- normalization layer
80 | """
81 | super(NLayerDiscriminator2, self).__init__()
82 | if not use_actnorm:
83 | norm_layer = nn.BatchNorm3d
84 | else:
85 | norm_layer = ActNorm
86 |         if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm3d has affine parameters
87 | use_bias = norm_layer.func != nn.BatchNorm3d
88 | else:
89 | use_bias = norm_layer != nn.BatchNorm3d
90 |
91 | kw = 4
92 | padw = 1
93 | sequence = [nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
94 | nf_mult = 1
95 | nf_mult_prev = 1
96 | for n in range(1, n_layers): # gradually increase the number of filters
97 | nf_mult_prev = nf_mult
98 | nf_mult = min(2 ** n, 8)
99 | sequence += [
100 | nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2,
101 | padding=padw, bias=use_bias, groups=8),
102 | norm_layer(ndf * nf_mult),
103 | nn.LeakyReLU(0.2, True)
104 | ]
105 |
106 | nf_mult_prev = nf_mult
107 | nf_mult = min(2 ** n_layers, 8)
108 | sequence += [
109 | nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1,
110 | padding=padw, bias=use_bias, groups=8),
111 | norm_layer(ndf * nf_mult),
112 | nn.LeakyReLU(0.2, True)
113 | ]
114 |
115 | sequence += [
116 | nn.Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw),
117 | # nn.Sigmoid()
118 | ] # output 1 channel prediction map
119 | self.main = nn.Sequential(*sequence)
120 |
121 | def forward(self, input):
122 | """Standard forward."""
123 | return self.main(input)
124 |
125 | if __name__ == "__main__":
126 | import torch
127 | model = NLayerDiscriminator2(input_nc=3, ndf=64, n_layers=3)
128 | x = torch.rand(1, 3, 64, 64, 64)
129 | with torch.no_grad():
130 | y = model(x)
131 | pause = 0
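
A minimal usage sketch for the 2-D PatchGAN discriminator above (input size chosen arbitrarily): the output is a one-channel map of per-patch logits rather than a single scalar per image:

    import torch
    from taming.modules.discriminator.model import NLayerDiscriminator

    model = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3)
    x = torch.rand(1, 3, 256, 256)
    with torch.no_grad():
        logits = model(x)
    print(logits.shape)   # torch.Size([1, 1, 30, 30])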
--------------------------------------------------------------------------------
/taming/modules/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from taming.modules.losses.vqperceptual import DummyLoss
2 |
3 |
--------------------------------------------------------------------------------
/taming/modules/losses/lpips.py:
--------------------------------------------------------------------------------
1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torchvision import models
6 | from collections import namedtuple
7 |
8 | from .util import get_ckpt_path
9 |
10 | class LPIPS(nn.Module):
11 | # Learned perceptual metric
12 | def __init__(self, use_dropout=True):
13 | super().__init__()
14 | self.scaling_layer = ScalingLayer()
15 |         self.chns = [64, 128, 256, 512, 512]  # vgg16 features
16 | self.net = vgg16(pretrained=True, requires_grad=False)
17 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
18 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
19 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
20 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
21 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
22 | self.load_from_pretrained()
23 | for param in self.parameters():
24 | param.requires_grad = False
25 |
26 | def load_from_pretrained(self, name="vgg_lpips"):
27 | ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
28 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
29 | print("loaded pretrained LPIPS loss from {}".format(ckpt))
30 |
31 | @classmethod
32 | def from_pretrained(cls, name="vgg_lpips"):
33 | if name != "vgg_lpips":
34 | raise NotImplementedError
35 | model = cls()
36 | ckpt = get_ckpt_path(name)
37 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
38 | return model
39 |
40 | def forward(self, input, target):
41 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
42 | outs0, outs1 = self.net(in0_input), self.net(in1_input)
43 | feats0, feats1, diffs = {}, {}, {}
44 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
45 | for kk in range(len(self.chns)):
46 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
47 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
48 |
49 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
50 | val = res[0]
51 | for l in range(1, len(self.chns)):
52 | val += res[l]
53 | return val
54 |
55 |
56 | class ScalingLayer(nn.Module):
57 | def __init__(self):
58 | super(ScalingLayer, self).__init__()
59 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
60 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
61 |
62 | def forward(self, inp):
63 | return (inp - self.shift) / self.scale
64 |
65 |
66 | class NetLinLayer(nn.Module):
67 | """ A single linear layer which does a 1x1 conv """
68 | def __init__(self, chn_in, chn_out=1, use_dropout=False):
69 | super(NetLinLayer, self).__init__()
70 | layers = [nn.Dropout(), ] if (use_dropout) else []
71 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
72 | self.model = nn.Sequential(*layers)
73 |
74 |
75 | class vgg16(torch.nn.Module):
76 | def __init__(self, requires_grad=False, pretrained=True):
77 | super(vgg16, self).__init__()
78 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
79 | self.slice1 = torch.nn.Sequential()
80 | self.slice2 = torch.nn.Sequential()
81 | self.slice3 = torch.nn.Sequential()
82 | self.slice4 = torch.nn.Sequential()
83 | self.slice5 = torch.nn.Sequential()
84 | self.N_slices = 5
85 | for x in range(4):
86 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
87 | for x in range(4, 9):
88 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
89 | for x in range(9, 16):
90 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
91 | for x in range(16, 23):
92 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
93 | for x in range(23, 30):
94 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
95 | if not requires_grad:
96 | for param in self.parameters():
97 | param.requires_grad = False
98 |
99 | def forward(self, X):
100 | h = self.slice1(X)
101 | h_relu1_2 = h
102 | h = self.slice2(h)
103 | h_relu2_2 = h
104 | h = self.slice3(h)
105 | h_relu3_3 = h
106 | h = self.slice4(h)
107 | h_relu4_3 = h
108 | h = self.slice5(h)
109 | h_relu5_3 = h
110 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
111 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
112 | return out
113 |
114 |
115 | def normalize_tensor(x,eps=1e-10):
116 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
117 | return x/(norm_factor+eps)
118 |
119 |
120 | def spatial_average(x, keepdim=True):
121 | return x.mean([2,3],keepdim=keepdim)
122 |
123 |
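A minimal usage sketch for the LPIPS module above; it fetches the vgg_lpips checkpoint (and torchvision's VGG16 weights) on first use and expects inputs roughly in [-1, 1]:

    import torch
    from taming.modules.losses.lpips import LPIPS

    lpips = LPIPS().eval()                    # loads pretrained weights
    x = torch.rand(1, 3, 256, 256) * 2 - 1    # fake images in [-1, 1]
    y = torch.rand(1, 3, 256, 256) * 2 - 1
    with torch.no_grad():
        d = lpips(x, y)                       # per-image perceptual distance
    print(d.shape)                            # torch.Size([1, 1, 1, 1])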
--------------------------------------------------------------------------------
/taming/modules/losses/segmentation.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
5 | class BCELoss(nn.Module):
6 | def forward(self, prediction, target):
7 | loss = F.binary_cross_entropy_with_logits(prediction,target)
8 | return loss, {}
9 |
10 |
11 | class BCELossWithQuant(nn.Module):
12 | def __init__(self, codebook_weight=1.):
13 | super().__init__()
14 | self.codebook_weight = codebook_weight
15 |
16 | def forward(self, qloss, target, prediction, split):
17 | bce_loss = F.binary_cross_entropy_with_logits(prediction,target)
18 | loss = bce_loss + self.codebook_weight*qloss
19 | return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(),
20 | "{}/bce_loss".format(split): bce_loss.detach().mean(),
21 | "{}/quant_loss".format(split): qloss.detach().mean()
22 | }
23 |
--------------------------------------------------------------------------------
/taming/modules/losses/util.py:
--------------------------------------------------------------------------------
1 | import os, hashlib
2 | import requests
3 | from tqdm import tqdm
4 |
5 | URL_MAP = {
6 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
7 | }
8 |
9 | CKPT_MAP = {
10 | "vgg_lpips": "vgg.pth"
11 | }
12 |
13 | MD5_MAP = {
14 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
15 | }
16 |
17 |
18 | def download(url, local_path, chunk_size=1024):
19 | os.makedirs(os.path.split(local_path)[0], exist_ok=True)
20 | with requests.get(url, stream=True) as r:
21 | total_size = int(r.headers.get("content-length", 0))
22 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
23 | with open(local_path, "wb") as f:
24 | for data in r.iter_content(chunk_size=chunk_size):
25 | if data:
26 | f.write(data)
27 | pbar.update(chunk_size)
28 |
29 |
30 | def md5_hash(path):
31 | with open(path, "rb") as f:
32 | content = f.read()
33 | return hashlib.md5(content).hexdigest()
34 |
35 |
36 | def get_ckpt_path(name, root, check=False):
37 | assert name in URL_MAP
38 | path = os.path.join(root, CKPT_MAP[name])
39 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
40 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
41 | download(URL_MAP[name], path)
42 | md5 = md5_hash(path)
43 | assert md5 == MD5_MAP[name], md5
44 | return path
45 |
46 |
47 | class KeyNotFoundError(Exception):
48 | def __init__(self, cause, keys=None, visited=None):
49 | self.cause = cause
50 | self.keys = keys
51 | self.visited = visited
52 | messages = list()
53 | if keys is not None:
54 | messages.append("Key not found: {}".format(keys))
55 | if visited is not None:
56 | messages.append("Visited: {}".format(visited))
57 | messages.append("Cause:\n{}".format(cause))
58 | message = "\n".join(messages)
59 | super().__init__(message)
60 |
61 |
62 | def retrieve(
63 | list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
64 | ):
65 | """Given a nested list or dict return the desired value at key expanding
66 | callable nodes if necessary and :attr:`expand` is ``True``. The expansion
67 | is done in-place.
68 |
69 | Parameters
70 | ----------
71 | list_or_dict : list or dict
72 | Possibly nested list or dictionary.
73 | key : str
74 | key/to/value, path like string describing all keys necessary to
75 | consider to get to the desired value. List indices can also be
76 | passed here.
77 | splitval : str
78 | String that defines the delimiter between keys of the
79 | different depth levels in `key`.
80 | default : obj
81 | Value returned if :attr:`key` is not found.
82 | expand : bool
83 | Whether to expand callable nodes on the path or not.
84 |
85 | Returns
86 | -------
87 | The desired value or if :attr:`default` is not ``None`` and the
88 | :attr:`key` is not found returns ``default``.
89 |
90 | Raises
91 | ------
92 | Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
93 | ``None``.
94 | """
95 |
96 | keys = key.split(splitval)
97 |
98 | success = True
99 | try:
100 | visited = []
101 | parent = None
102 | last_key = None
103 | for key in keys:
104 | if callable(list_or_dict):
105 | if not expand:
106 | raise KeyNotFoundError(
107 | ValueError(
108 | "Trying to get past callable node with expand=False."
109 | ),
110 | keys=keys,
111 | visited=visited,
112 | )
113 | list_or_dict = list_or_dict()
114 | parent[last_key] = list_or_dict
115 |
116 | last_key = key
117 | parent = list_or_dict
118 |
119 | try:
120 | if isinstance(list_or_dict, dict):
121 | list_or_dict = list_or_dict[key]
122 | else:
123 | list_or_dict = list_or_dict[int(key)]
124 | except (KeyError, IndexError, ValueError) as e:
125 | raise KeyNotFoundError(e, keys=keys, visited=visited)
126 |
127 | visited += [key]
128 | # final expansion of retrieved value
129 | if expand and callable(list_or_dict):
130 | list_or_dict = list_or_dict()
131 | parent[last_key] = list_or_dict
132 | except KeyNotFoundError as e:
133 | if default is None:
134 | raise e
135 | else:
136 | list_or_dict = default
137 | success = False
138 |
139 | if not pass_success:
140 | return list_or_dict
141 | else:
142 | return list_or_dict, success
143 |
144 |
145 | if __name__ == "__main__":
146 | config = {"keya": "a",
147 | "keyb": "b",
148 | "keyc":
149 | {"cc1": 1,
150 | "cc2": 2,
151 | }
152 | }
153 | from omegaconf import OmegaConf
154 | config = OmegaConf.create(config)
155 | print(config)
156 | retrieve(config, "keya")
157 |
158 |
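Another small usage sketch for retrieve, this time on a plain nested dict with an invented key layout, showing slash-separated paths, list indexing, and the default fallback:

    from taming.modules.losses.util import retrieve

    cfg = {"model": {"params": {"embed_dim": 4, "layers": [2, 4, 8]}}}
    print(retrieve(cfg, "model/params/embed_dim"))           # 4
    print(retrieve(cfg, "model/params/layers/1"))            # 4 (list index)
    print(retrieve(cfg, "model/params/missing", default=0))  # 0, key absent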
--------------------------------------------------------------------------------
/taming/modules/losses/vqperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from taming.modules.losses.lpips import LPIPS
6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init, NLayerDiscriminator2
7 |
8 |
9 | class DummyLoss(nn.Module):
10 | def __init__(self):
11 | super().__init__()
12 |
13 |
14 | def adopt_weight(weight, global_step, threshold=0, value=0.):
15 | if global_step < threshold:
16 | weight = value
17 | return weight
18 |
19 |
20 | def hinge_d_loss(logits_real, logits_fake):
21 | loss_real = torch.mean(F.relu(1. - logits_real))
22 | loss_fake = torch.mean(F.relu(1. + logits_fake))
23 | d_loss = 0.5 * (loss_real + loss_fake)
24 | return d_loss
25 |
26 |
27 | def vanilla_d_loss(logits_real, logits_fake):
28 | d_loss = 0.5 * (
29 | torch.mean(torch.nn.functional.softplus(-logits_real)) +
30 | torch.mean(torch.nn.functional.softplus(logits_fake)))
31 | return d_loss
32 |
33 |
34 | class VQLPIPSWithDiscriminator(nn.Module):
35 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
36 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
37 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
38 | disc_ndf=64, disc_loss="hinge"):
39 | super().__init__()
40 | assert disc_loss in ["hinge", "vanilla"]
41 | self.codebook_weight = codebook_weight
42 | self.pixel_weight = pixelloss_weight
43 | self.perceptual_loss = LPIPS().eval()
44 | self.perceptual_weight = perceptual_weight
45 |
46 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
47 | n_layers=disc_num_layers,
48 | use_actnorm=use_actnorm,
49 | ndf=disc_ndf
50 | ).apply(weights_init)
51 | self.discriminator_iter_start = disc_start
52 | if disc_loss == "hinge":
53 | self.disc_loss = hinge_d_loss
54 | elif disc_loss == "vanilla":
55 | self.disc_loss = vanilla_d_loss
56 | else:
57 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
58 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
59 | self.disc_factor = disc_factor
60 | self.discriminator_weight = disc_weight
61 | self.disc_conditional = disc_conditional
62 |
63 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
64 | if last_layer is not None:
65 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
66 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
67 | else:
68 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
69 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
70 |
71 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
72 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
73 | d_weight = d_weight * self.discriminator_weight
74 | return d_weight
75 |
76 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
77 | global_step, last_layer=None, cond=None, split="train"):
78 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
79 | if self.perceptual_weight > 0:
80 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
81 | rec_loss = rec_loss + self.perceptual_weight * p_loss
82 | else:
83 | p_loss = torch.tensor([0.0])
84 |
85 | nll_loss = rec_loss
86 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
87 | nll_loss = torch.mean(nll_loss)
88 |
89 | # now the GAN part
90 | if optimizer_idx == 0:
91 | # generator update
92 | if cond is None:
93 | assert not self.disc_conditional
94 | logits_fake = self.discriminator(reconstructions.contiguous())
95 | else:
96 | assert self.disc_conditional
97 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
98 | g_loss = -torch.mean(logits_fake)
99 |
100 | try:
101 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
102 | except RuntimeError:
103 | assert not self.training
104 | d_weight = torch.tensor(0.0)
105 |
106 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
107 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
108 |
109 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
110 | "{}/quant_loss".format(split): codebook_loss.detach().mean(),
111 | "{}/nll_loss".format(split): nll_loss.detach().mean(),
112 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
113 | "{}/p_loss".format(split): p_loss.detach().mean(),
114 | "{}/d_weight".format(split): d_weight.detach(),
115 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
116 | "{}/g_loss".format(split): g_loss.detach().mean(),
117 | }
118 | return loss, log
119 |
120 | if optimizer_idx == 1:
121 | # second pass for discriminator update
122 | if cond is None:
123 | logits_real = self.discriminator(inputs.contiguous().detach())
124 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
125 | else:
126 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
127 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
128 |
129 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
130 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
131 |
132 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
133 | "{}/logits_real".format(split): logits_real.detach().mean(),
134 | "{}/logits_fake".format(split): logits_fake.detach().mean()
135 | }
136 | return d_loss, log
137 |
--------------------------------------------------------------------------------
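A small sketch of the two standalone helpers above (illustrative only; the logits are made up, and the import assumes the repo root plus torch/torchvision are available). `adopt_weight` is what keeps the GAN term switched off until `global_step` reaches `disc_start`, and `hinge_d_loss` penalises real logits below +1 and fake logits above -1.

    import torch
    from taming.modules.losses.vqperceptual import adopt_weight, hinge_d_loss

    logits_real = torch.tensor([2.0, 0.5, -0.2])    # discriminator outputs on real patches
    logits_fake = torch.tensor([-1.5, 0.3, 0.1])    # outputs on reconstructions
    print(hinge_d_loss(logits_real, logits_fake))   # tensor(0.6833)

    for step in (0, 5000, 20000):                   # disc_factor stays 0 before the threshold
        print(step, adopt_weight(1.0, step, threshold=10000))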
/taming/modules/misc/coord.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class CoordStage(object):
4 | def __init__(self, n_embed, down_factor):
5 | self.n_embed = n_embed
6 | self.down_factor = down_factor
7 |
8 | def eval(self):
9 | return self
10 |
11 | def encode(self, c):
12 | """fake vqmodel interface"""
13 | assert 0.0 <= c.min() and c.max() <= 1.0
14 | b,ch,h,w = c.shape
15 | assert ch == 1
16 |
17 | c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor,
18 | mode="area")
19 | c = c.clamp(0.0, 1.0)
20 | c = self.n_embed*c
21 | c_quant = c.round()
22 | c_ind = c_quant.to(dtype=torch.long)
23 |
24 | info = None, None, c_ind
25 | return c_quant, None, info
26 |
27 | def decode(self, c):
28 | c = c/self.n_embed
29 | c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor,
30 | mode="nearest")
31 | return c
32 |
--------------------------------------------------------------------------------
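A round-trip sketch for `CoordStage` (illustrative; the tensor size and the `n_embed`/`down_factor` values are arbitrary, and the import assumes the repo root is on PYTHONPATH): `encode` area-downsamples a [0, 1] coordinate map and quantises it into `n_embed` integer bins, `decode` nearest-upsamples it back.

    import torch
    from taming.modules.misc.coord import CoordStage

    stage = CoordStage(n_embed=256, down_factor=4)
    coords = torch.rand(1, 1, 32, 32)                # values must lie in [0, 1]

    c_quant, _, (_, _, c_ind) = stage.encode(coords)
    print(c_quant.shape, c_ind.dtype)                # torch.Size([1, 1, 8, 8]) torch.int64
    print(stage.decode(c_quant).shape)               # torch.Size([1, 1, 32, 32])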
/taming/modules/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def count_params(model):
6 | total_params = sum(p.numel() for p in model.parameters())
7 | return total_params
8 |
9 |
10 | class ActNorm(nn.Module):
11 | def __init__(self, num_features, logdet=False, affine=True,
12 | allow_reverse_init=False):
13 | assert affine
14 | super().__init__()
15 | self.logdet = logdet
16 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
17 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
18 | self.allow_reverse_init = allow_reverse_init
19 |
20 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
21 |
22 | def initialize(self, input):
23 | with torch.no_grad():
24 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
25 | mean = (
26 | flatten.mean(1)
27 | .unsqueeze(1)
28 | .unsqueeze(2)
29 | .unsqueeze(3)
30 | .permute(1, 0, 2, 3)
31 | )
32 | std = (
33 | flatten.std(1)
34 | .unsqueeze(1)
35 | .unsqueeze(2)
36 | .unsqueeze(3)
37 | .permute(1, 0, 2, 3)
38 | )
39 |
40 | self.loc.data.copy_(-mean)
41 | self.scale.data.copy_(1 / (std + 1e-6))
42 |
43 | def forward(self, input, reverse=False):
44 | if reverse:
45 | return self.reverse(input)
46 | if len(input.shape) == 2:
47 | input = input[:,:,None,None]
48 | squeeze = True
49 | else:
50 | squeeze = False
51 |
52 | _, _, height, width = input.shape
53 |
54 | if self.training and self.initialized.item() == 0:
55 | self.initialize(input)
56 | self.initialized.fill_(1)
57 |
58 | h = self.scale * (input + self.loc)
59 |
60 | if squeeze:
61 | h = h.squeeze(-1).squeeze(-1)
62 |
63 | if self.logdet:
64 | log_abs = torch.log(torch.abs(self.scale))
65 | logdet = height*width*torch.sum(log_abs)
66 | logdet = logdet * torch.ones(input.shape[0]).to(input)
67 | return h, logdet
68 |
69 | return h
70 |
71 | def reverse(self, output):
72 | if self.training and self.initialized.item() == 0:
73 | if not self.allow_reverse_init:
74 | raise RuntimeError(
75 | "Initializing ActNorm in reverse direction is "
76 | "disabled by default. Use allow_reverse_init=True to enable."
77 | )
78 | else:
79 | self.initialize(output)
80 | self.initialized.fill_(1)
81 |
82 | if len(output.shape) == 2:
83 | output = output[:,:,None,None]
84 | squeeze = True
85 | else:
86 | squeeze = False
87 |
88 | h = output / self.scale - self.loc
89 |
90 | if squeeze:
91 | h = h.squeeze(-1).squeeze(-1)
92 | return h
93 |
94 |
95 | class AbstractEncoder(nn.Module):
96 | def __init__(self):
97 | super().__init__()
98 |
99 | def encode(self, *args, **kwargs):
100 | raise NotImplementedError
101 |
102 |
103 | class Labelator(AbstractEncoder):
104 | """Net2Net Interface for Class-Conditional Model"""
105 | def __init__(self, n_classes, quantize_interface=True):
106 | super().__init__()
107 | self.n_classes = n_classes
108 | self.quantize_interface = quantize_interface
109 |
110 | def encode(self, c):
111 | c = c[:,None]
112 | if self.quantize_interface:
113 | return c, None, [None, None, c.long()]
114 | return c
115 |
116 |
117 | class SOSProvider(AbstractEncoder):
118 | # for unconditional training
119 | def __init__(self, sos_token, quantize_interface=True):
120 | super().__init__()
121 | self.sos_token = sos_token
122 | self.quantize_interface = quantize_interface
123 |
124 | def encode(self, x):
125 | # get batch size from data and replicate sos_token
126 | c = torch.ones(x.shape[0], 1)*self.sos_token
127 | c = c.long().to(x.device)
128 | if self.quantize_interface:
129 | return c, None, [None, None, c]
130 | return c
131 |
--------------------------------------------------------------------------------
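A sketch of ActNorm's data-dependent initialisation (illustrative shapes; the import assumes the repo root is on PYTHONPATH): the first forward pass in training mode sets `loc`/`scale` so that each channel of that batch comes out with roughly zero mean and unit variance, and `reverse()` inverts the affine map.

    import torch
    from taming.modules.util import ActNorm

    norm = ActNorm(num_features=3)
    x = torch.randn(16, 3, 8, 8) * 5.0 + 2.0          # arbitrary per-batch scale and shift

    norm.train()
    y = norm(x)                                       # first call initialises loc and scale
    flat = y.permute(1, 0, 2, 3).reshape(3, -1)
    print(flat.mean(1), flat.std(1))                  # per-channel ~0 and ~1

    print(torch.allclose(norm(y, reverse=True), x, atol=1e-4))   # True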
/taming/util.py:
--------------------------------------------------------------------------------
1 | import os, hashlib
2 | import requests
3 | from tqdm import tqdm
4 |
5 | URL_MAP = {
6 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
7 | }
8 |
9 | CKPT_MAP = {
10 | "vgg_lpips": "vgg.pth"
11 | }
12 |
13 | MD5_MAP = {
14 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
15 | }
16 |
17 |
18 | def download(url, local_path, chunk_size=1024):
19 | os.makedirs(os.path.split(local_path)[0], exist_ok=True)
20 | with requests.get(url, stream=True) as r:
21 | total_size = int(r.headers.get("content-length", 0))
22 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
23 | with open(local_path, "wb") as f:
24 | for data in r.iter_content(chunk_size=chunk_size):
25 | if data:
26 | f.write(data)
27 | pbar.update(chunk_size)
28 |
29 |
30 | def md5_hash(path):
31 | with open(path, "rb") as f:
32 | content = f.read()
33 | return hashlib.md5(content).hexdigest()
34 |
35 |
36 | def get_ckpt_path(name, root, check=False):
37 | assert name in URL_MAP
38 | path = os.path.join(root, CKPT_MAP[name])
39 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
40 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
41 | download(URL_MAP[name], path)
42 | md5 = md5_hash(path)
43 | assert md5 == MD5_MAP[name], md5
44 | return path
45 |
46 |
47 | class KeyNotFoundError(Exception):
48 | def __init__(self, cause, keys=None, visited=None):
49 | self.cause = cause
50 | self.keys = keys
51 | self.visited = visited
52 | messages = list()
53 | if keys is not None:
54 | messages.append("Key not found: {}".format(keys))
55 | if visited is not None:
56 | messages.append("Visited: {}".format(visited))
57 | messages.append("Cause:\n{}".format(cause))
58 | message = "\n".join(messages)
59 | super().__init__(message)
60 |
61 |
62 | def retrieve(
63 | list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
64 | ):
65 | """Given a nested list or dict return the desired value at key expanding
66 | callable nodes if necessary and :attr:`expand` is ``True``. The expansion
67 | is done in-place.
68 |
69 | Parameters
70 | ----------
71 | list_or_dict : list or dict
72 | Possibly nested list or dictionary.
73 | key : str
74 | key/to/value, path-like string describing all keys necessary to
75 | consider to get to the desired value. List indices can also be
76 | passed here.
77 | splitval : str
78 | String that defines the delimiter between keys of the
79 | different depth levels in `key`.
80 | default : obj
81 | Value returned if :attr:`key` is not found.
82 | expand : bool
83 | Whether to expand callable nodes on the path or not.
84 |
85 | Returns
86 | -------
87 | The desired value, or ``default`` if :attr:`key` is not found and
88 | :attr:`default` is not ``None``.
89 |
90 | Raises
91 | ------
92 | Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
93 | ``None``.
94 | """
95 |
96 | keys = key.split(splitval)
97 |
98 | success = True
99 | try:
100 | visited = []
101 | parent = None
102 | last_key = None
103 | for key in keys:
104 | if callable(list_or_dict):
105 | if not expand:
106 | raise KeyNotFoundError(
107 | ValueError(
108 | "Trying to get past callable node with expand=False."
109 | ),
110 | keys=keys,
111 | visited=visited,
112 | )
113 | list_or_dict = list_or_dict()
114 | parent[last_key] = list_or_dict
115 |
116 | last_key = key
117 | parent = list_or_dict
118 |
119 | try:
120 | if isinstance(list_or_dict, dict):
121 | list_or_dict = list_or_dict[key]
122 | else:
123 | list_or_dict = list_or_dict[int(key)]
124 | except (KeyError, IndexError, ValueError) as e:
125 | raise KeyNotFoundError(e, keys=keys, visited=visited)
126 |
127 | visited += [key]
128 | # final expansion of retrieved value
129 | if expand and callable(list_or_dict):
130 | list_or_dict = list_or_dict()
131 | parent[last_key] = list_or_dict
132 | except KeyNotFoundError as e:
133 | if default is None:
134 | raise e
135 | else:
136 | list_or_dict = default
137 | success = False
138 |
139 | if not pass_success:
140 | return list_or_dict
141 | else:
142 | return list_or_dict, success
143 |
144 |
145 | if __name__ == "__main__":
146 | config = {"keya": "a",
147 | "keyb": "b",
148 | "keyc":
149 | {"cc1": 1,
150 | "cc2": 2,
151 | }
152 | }
153 | from omegaconf import OmegaConf
154 | config = OmegaConf.create(config)
155 | print(config)
156 | retrieve(config, "keya")
157 |
158 |
--------------------------------------------------------------------------------
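A usage sketch for the checkpoint helper above (illustrative: the `root` directory name is arbitrary, and the call needs network access the first time because it downloads the LPIPS VGG weights and verifies their MD5 sum):

    from taming.util import get_ckpt_path, md5_hash

    path = get_ckpt_path("vgg_lpips", root="pretrained/lpips", check=True)
    print(path, md5_hash(path))     # .../pretrained/lpips/vgg.pth d507d7349b931f0638a25a48a722f98a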
/unet_plus/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/unet_plus/ema.py:
--------------------------------------------------------------------------------
1 | # Modified from https://raw.githubusercontent.com/fadel/pytorch_ema/master/torch_ema/ema.py
2 |
3 | from __future__ import division
4 | from __future__ import unicode_literals
5 |
6 | import torch
7 |
8 |
9 | # Partially based on: https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py
10 | class ExponentialMovingAverage:
11 | """
12 | Maintains (exponential) moving average of a set of parameters.
13 | """
14 |
15 | def __init__(self, parameters, decay, use_num_updates=True):
16 | """
17 | Args:
18 | parameters: Iterable of `torch.nn.Parameter`; usually the result of
19 | `model.parameters()`.
20 | decay: The exponential decay.
21 | use_num_updates: Whether to use number of updates when computing
22 | averages.
23 | """
24 | if decay < 0.0 or decay > 1.0:
25 | raise ValueError('Decay must be between 0 and 1')
26 | self.decay = decay
27 | self.num_updates = 0 if use_num_updates else None
28 | self.shadow_params = [p.clone().detach()
29 | for p in parameters if p.requires_grad]
30 | self.collected_params = []
31 |
32 | def update(self, parameters):
33 | """
34 | Update currently maintained parameters.
35 |
36 | Call this every time the parameters are updated, such as the result of
37 | the `optimizer.step()` call.
38 |
39 | Args:
40 | parameters: Iterable of `torch.nn.Parameter`; usually the same set of
41 | parameters used to initialize this object.
42 | """
43 | decay = self.decay
44 | if self.num_updates is not None:
45 | self.num_updates += 1
46 | decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
47 | one_minus_decay = 1.0 - decay
48 | with torch.no_grad():
49 | parameters = [p for p in parameters if p.requires_grad]
50 | for s_param, param in zip(self.shadow_params, parameters):
51 | s_param.sub_(one_minus_decay * (s_param - param))
52 |
53 | def copy_to(self, parameters):
54 | """
55 | Copy current parameters into given collection of parameters.
56 |
57 | Args:
58 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
59 | updated with the stored moving averages.
60 | """
61 | parameters = [p for p in parameters if p.requires_grad]
62 | for s_param, param in zip(self.shadow_params, parameters):
63 | if param.requires_grad:
64 | param.data.copy_(s_param.data)
65 |
66 | def store(self, parameters):
67 | """
68 | Save the current parameters for restoring later.
69 |
70 | Args:
71 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
72 | temporarily stored.
73 | """
74 | self.collected_params = [param.clone() for param in parameters]
75 |
76 | def restore(self, parameters):
77 | """
78 | Restore the parameters stored with the `store` method.
79 | Useful to validate the model with EMA parameters without affecting the
80 | original optimization process. Store the parameters before the
81 | `copy_to` method. After validation (or model saving), use this to
82 | restore the former parameters.
83 |
84 | Args:
85 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
86 | updated with the stored parameters.
87 | """
88 | for c_param, param in zip(self.collected_params, parameters):
89 | param.data.copy_(c_param.data)
90 |
91 | def state_dict(self):
92 | return dict(decay=self.decay, num_updates=self.num_updates,
93 | shadow_params=self.shadow_params)
94 |
95 | def load_state_dict(self, state_dict):
96 | self.decay = state_dict['decay']
97 | self.num_updates = state_dict['num_updates']
98 | self.shadow_params = state_dict['shadow_params']
--------------------------------------------------------------------------------
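A minimal EMA loop (illustrative model, optimizer and step count; the import path follows the file above): `update()` is called after every optimizer step, and `store()`/`copy_to()`/`restore()` bracket evaluation so the raw weights are untouched.

    import torch
    from unet_plus.ema import ExponentialMovingAverage

    model = torch.nn.Linear(4, 4)
    opt = torch.optim.SGD(model.parameters(), lr=1e-2)
    ema = ExponentialMovingAverage(model.parameters(), decay=0.999)

    for _ in range(10):                                # toy training loop
        loss = model(torch.randn(8, 4)).pow(2).mean()
        opt.zero_grad(); loss.backward(); opt.step()
        ema.update(model.parameters())                 # shadow <- shadow - (1 - decay) * (shadow - param)

    ema.store(model.parameters())                      # keep the raw weights
    ema.copy_to(model.parameters())                    # evaluate with EMA weights here
    ema.restore(model.parameters())                    # then put the raw weights back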
/unet_plus/op/__init__.py:
--------------------------------------------------------------------------------
1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu
2 | from .upfirdn2d import upfirdn2d
3 |
--------------------------------------------------------------------------------
/unet_plus/op/fused_act.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 | from torch.autograd import Function
7 | from torch.utils.cpp_extension import load
8 |
9 |
10 | module_path = os.path.dirname(__file__)
11 | fused = load(
12 | "fused",
13 | sources=[
14 | os.path.join(module_path, "fused_bias_act.cpp"),
15 | os.path.join(module_path, "fused_bias_act_kernel.cu"),
16 | ],
17 | )
18 |
19 |
20 | class FusedLeakyReLUFunctionBackward(Function):
21 | @staticmethod
22 | def forward(ctx, grad_output, out, negative_slope, scale):
23 | ctx.save_for_backward(out)
24 | ctx.negative_slope = negative_slope
25 | ctx.scale = scale
26 |
27 | empty = grad_output.new_empty(0)
28 |
29 | grad_input = fused.fused_bias_act(
30 | grad_output, empty, out, 3, 1, negative_slope, scale
31 | )
32 |
33 | dim = [0]
34 |
35 | if grad_input.ndim > 2:
36 | dim += list(range(2, grad_input.ndim))
37 |
38 | grad_bias = grad_input.sum(dim).detach()
39 |
40 | return grad_input, grad_bias
41 |
42 | @staticmethod
43 | def backward(ctx, gradgrad_input, gradgrad_bias):
44 | out, = ctx.saved_tensors
45 | gradgrad_out = fused.fused_bias_act(
46 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
47 | )
48 |
49 | return gradgrad_out, None, None, None
50 |
51 |
52 | class FusedLeakyReLUFunction(Function):
53 | @staticmethod
54 | def forward(ctx, input, bias, negative_slope, scale):
55 | empty = input.new_empty(0)
56 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
57 | ctx.save_for_backward(out)
58 | ctx.negative_slope = negative_slope
59 | ctx.scale = scale
60 |
61 | return out
62 |
63 | @staticmethod
64 | def backward(ctx, grad_output):
65 | out, = ctx.saved_tensors
66 |
67 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
68 | grad_output, out, ctx.negative_slope, ctx.scale
69 | )
70 |
71 | return grad_input, grad_bias, None, None
72 |
73 |
74 | class FusedLeakyReLU(nn.Module):
75 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
76 | super().__init__()
77 |
78 | self.bias = nn.Parameter(torch.zeros(channel))
79 | self.negative_slope = negative_slope
80 | self.scale = scale
81 |
82 | def forward(self, input):
83 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
84 |
85 |
86 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
87 | if input.device.type == "cpu":
88 | rest_dim = [1] * (input.ndim - bias.ndim - 1)
89 | return (
90 | F.leaky_relu(
91 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
92 | )
93 | * scale
94 | )
95 |
96 | else:
97 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
98 |
--------------------------------------------------------------------------------
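Importing the module above JIT-compiles the CUDA extension via `torch.utils.cpp_extension.load`, so it needs an NVCC toolchain. For reference, the fused op computes bias add + LeakyReLU + a constant gain; the CPU branch above is equivalent to the plain PyTorch sketch below (it does not touch the compiled op; note the CPU branch hard-codes the 0.2 slope):

    import torch
    import torch.nn.functional as F

    def fused_leaky_relu_reference(x, bias, negative_slope=0.2, scale=2 ** 0.5):
        # broadcast the per-channel bias over trailing dims, apply LeakyReLU, then rescale
        rest = [1] * (x.ndim - bias.ndim - 1)
        return F.leaky_relu(x + bias.view(1, bias.shape[0], *rest),
                            negative_slope=negative_slope) * scale

    x = torch.randn(2, 8, 16, 16)
    print(fused_leaky_relu_reference(x, torch.zeros(8)).shape)   # torch.Size([2, 8, 16, 16])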
/unet_plus/op/fused_bias_act.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 |
4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
5 | int act, int grad, float alpha, float scale);
6 |
7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
10 |
11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
12 | int act, int grad, float alpha, float scale) {
13 | CHECK_CUDA(input);
14 | CHECK_CUDA(bias);
15 |
16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
17 | }
18 |
19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
21 | }
--------------------------------------------------------------------------------
/unet_plus/op/fused_bias_act_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAContext.h>
12 | #include <ATen/cuda/CUDAApplyUtils.cuh>
13 |
14 | #include <cuda.h>
15 | #include <cuda_runtime.h>
16 |
17 |
18 | template <typename scalar_t>
19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
22 |
23 | scalar_t zero = 0.0;
24 |
25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
26 | scalar_t x = p_x[xi];
27 |
28 | if (use_bias) {
29 | x += p_b[(xi / step_b) % size_b];
30 | }
31 |
32 | scalar_t ref = use_ref ? p_ref[xi] : zero;
33 |
34 | scalar_t y;
35 |
36 | switch (act * 10 + grad) {
37 | default:
38 | case 10: y = x; break;
39 | case 11: y = x; break;
40 | case 12: y = 0.0; break;
41 |
42 | case 30: y = (x > 0.0) ? x : x * alpha; break;
43 | case 31: y = (ref > 0.0) ? x : x * alpha; break;
44 | case 32: y = 0.0; break;
45 | }
46 |
47 | out[xi] = y * scale;
48 | }
49 | }
50 |
51 |
52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
53 | int act, int grad, float alpha, float scale) {
54 | int curDevice = -1;
55 | cudaGetDevice(&curDevice);
56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
57 |
58 | auto x = input.contiguous();
59 | auto b = bias.contiguous();
60 | auto ref = refer.contiguous();
61 |
62 | int use_bias = b.numel() ? 1 : 0;
63 | int use_ref = ref.numel() ? 1 : 0;
64 |
65 | int size_x = x.numel();
66 | int size_b = b.numel();
67 | int step_b = 1;
68 |
69 | for (int i = 1 + 1; i < x.dim(); i++) {
70 | step_b *= x.size(i);
71 | }
72 |
73 | int loop_x = 4;
74 | int block_size = 4 * 32;
75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
76 |
77 | auto y = torch::empty_like(x);
78 |
79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
80 | fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
81 | y.data_ptr<scalar_t>(),
82 | x.data_ptr<scalar_t>(),
83 | b.data_ptr<scalar_t>(),
84 | ref.data_ptr<scalar_t>(),
85 | act,
86 | grad,
87 | alpha,
88 | scale,
89 | loop_x,
90 | size_x,
91 | step_b,
92 | size_b,
93 | use_bias,
94 | use_ref
95 | );
96 | });
97 |
98 | return y;
99 | }
--------------------------------------------------------------------------------
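The `act * 10 + grad` switch in the kernel encodes the elementwise rule: `act=1` is linear, `act=3` is LeakyReLU, and `grad` selects forward (0), first derivative (1, taking the sign from the saved forward output `ref`) or second derivative (2). A scalar Python restatement of that rule, for reference only:

    def fused_bias_act_rule(x, ref, act, grad, alpha, scale):
        code = act * 10 + grad
        if code in (10, 11):              # linear: forward and first derivative are identity
            y = x
        elif code == 30:                  # leaky ReLU forward
            y = x if x > 0.0 else x * alpha
        elif code == 31:                  # first derivative, sign taken from ref
            y = x if ref > 0.0 else x * alpha
        elif code in (12, 32):            # second derivatives are zero
            y = 0.0
        else:                             # kernel's default branch
            y = x
        return y * scale

    print(fused_bias_act_rule(-1.0, 0.0, act=3, grad=0, alpha=0.2, scale=2 ** 0.5))  # ~ -0.2828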
/unet_plus/op/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 |
4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
5 | int up_x, int up_y, int down_x, int down_y,
6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1);
7 |
8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
11 |
12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
13 | int up_x, int up_y, int down_x, int down_y,
14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
15 | CHECK_CUDA(input);
16 | CHECK_CUDA(kernel);
17 |
18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
19 | }
20 |
21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
23 | }
--------------------------------------------------------------------------------
/unet_plus/op/upfirdn2d.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch.nn import functional as F
5 | from torch.autograd import Function
6 | from torch.utils.cpp_extension import load
7 |
8 |
9 | module_path = os.path.dirname(__file__)
10 | upfirdn2d_op = load(
11 | "upfirdn2d",
12 | sources=[
13 | os.path.join(module_path, "upfirdn2d.cpp"),
14 | os.path.join(module_path, "upfirdn2d_kernel.cu"),
15 | ],
16 | )
17 |
18 |
19 | class UpFirDn2dBackward(Function):
20 | @staticmethod
21 | def forward(
22 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
23 | ):
24 |
25 | up_x, up_y = up
26 | down_x, down_y = down
27 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
28 |
29 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
30 |
31 | grad_input = upfirdn2d_op.upfirdn2d(
32 | grad_output,
33 | grad_kernel,
34 | down_x,
35 | down_y,
36 | up_x,
37 | up_y,
38 | g_pad_x0,
39 | g_pad_x1,
40 | g_pad_y0,
41 | g_pad_y1,
42 | )
43 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
44 |
45 | ctx.save_for_backward(kernel)
46 |
47 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
48 |
49 | ctx.up_x = up_x
50 | ctx.up_y = up_y
51 | ctx.down_x = down_x
52 | ctx.down_y = down_y
53 | ctx.pad_x0 = pad_x0
54 | ctx.pad_x1 = pad_x1
55 | ctx.pad_y0 = pad_y0
56 | ctx.pad_y1 = pad_y1
57 | ctx.in_size = in_size
58 | ctx.out_size = out_size
59 |
60 | return grad_input
61 |
62 | @staticmethod
63 | def backward(ctx, gradgrad_input):
64 | kernel, = ctx.saved_tensors
65 |
66 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
67 |
68 | gradgrad_out = upfirdn2d_op.upfirdn2d(
69 | gradgrad_input,
70 | kernel,
71 | ctx.up_x,
72 | ctx.up_y,
73 | ctx.down_x,
74 | ctx.down_y,
75 | ctx.pad_x0,
76 | ctx.pad_x1,
77 | ctx.pad_y0,
78 | ctx.pad_y1,
79 | )
80 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
81 | gradgrad_out = gradgrad_out.view(
82 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
83 | )
84 |
85 | return gradgrad_out, None, None, None, None, None, None, None, None
86 |
87 |
88 | class UpFirDn2d(Function):
89 | @staticmethod
90 | def forward(ctx, input, kernel, up, down, pad):
91 | up_x, up_y = up
92 | down_x, down_y = down
93 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
94 |
95 | kernel_h, kernel_w = kernel.shape
96 | batch, channel, in_h, in_w = input.shape
97 | ctx.in_size = input.shape
98 |
99 | input = input.reshape(-1, in_h, in_w, 1)
100 |
101 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
102 |
103 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
104 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
105 | ctx.out_size = (out_h, out_w)
106 |
107 | ctx.up = (up_x, up_y)
108 | ctx.down = (down_x, down_y)
109 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
110 |
111 | g_pad_x0 = kernel_w - pad_x0 - 1
112 | g_pad_y0 = kernel_h - pad_y0 - 1
113 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
114 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
115 |
116 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
117 |
118 | out = upfirdn2d_op.upfirdn2d(
119 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
120 | )
121 | # out = out.view(major, out_h, out_w, minor)
122 | out = out.view(-1, channel, out_h, out_w)
123 |
124 | return out
125 |
126 | @staticmethod
127 | def backward(ctx, grad_output):
128 | kernel, grad_kernel = ctx.saved_tensors
129 |
130 | grad_input = UpFirDn2dBackward.apply(
131 | grad_output,
132 | kernel,
133 | grad_kernel,
134 | ctx.up,
135 | ctx.down,
136 | ctx.pad,
137 | ctx.g_pad,
138 | ctx.in_size,
139 | ctx.out_size,
140 | )
141 |
142 | return grad_input, None, None, None, None
143 |
144 |
145 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
146 | if input.device.type == "cpu":
147 | out = upfirdn2d_native(
148 | input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
149 | )
150 |
151 | else:
152 | out = UpFirDn2d.apply(
153 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
154 | )
155 |
156 | return out
157 |
158 |
159 | def upfirdn2d_native(
160 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
161 | ):
162 | _, channel, in_h, in_w = input.shape
163 | input = input.reshape(-1, in_h, in_w, 1)
164 |
165 | _, in_h, in_w, minor = input.shape
166 | kernel_h, kernel_w = kernel.shape
167 |
168 | out = input.view(-1, in_h, 1, in_w, 1, minor)
169 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
170 | out = out.view(-1, in_h * up_y, in_w * up_x, minor)
171 |
172 | out = F.pad(
173 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
174 | )
175 | out = out[
176 | :,
177 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
178 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
179 | :,
180 | ]
181 |
182 | out = out.permute(0, 3, 1, 2)
183 | out = out.reshape(
184 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
185 | )
186 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
187 | out = F.conv2d(out, w)
188 | out = out.reshape(
189 | -1,
190 | minor,
191 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
192 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
193 | )
194 | out = out.permute(0, 2, 3, 1)
195 | out = out[:, ::down_y, ::down_x, :]
196 |
197 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
198 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
199 |
200 | return out.view(-1, channel, out_h, out_w)
201 |
--------------------------------------------------------------------------------
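`upfirdn2d` zero-upsamples by `up`, convolves with the FIR `kernel`, then keeps every `down`-th sample; the spatial output size follows the formula in `UpFirDn2d.forward`. A pure-Python sketch of that bookkeeping (no compiled op involved; the 4-tap/2x example values mirror StyleGAN2-style up/downsampling):

    def upfirdn2d_out_size(in_h, in_w, kernel_h, kernel_w, up=1, down=1, pad=(0, 0)):
        # same arithmetic as out_h / out_w in UpFirDn2d.forward
        out_h = (in_h * up + pad[0] + pad[1] - kernel_h) // down + 1
        out_w = (in_w * up + pad[0] + pad[1] - kernel_w) // down + 1
        return out_h, out_w

    print(upfirdn2d_out_size(16, 16, 4, 4, up=2, down=1, pad=(2, 1)))   # (32, 32)
    print(upfirdn2d_out_size(32, 32, 4, 4, up=1, down=2, pad=(1, 1)))   # (16, 16)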
/unet_plus/unet_pp.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # pylint: skip-file
17 | """DDPM model.
18 |
19 | This code is the pytorch equivalent of:
20 | https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/models/unet.py
21 | """
22 | import torch
23 | import torch.nn as nn
24 | import functools
25 |
26 | from unet_plus import utils, layers, normalization
27 |
28 | RefineBlock = layers.RefineBlock
29 | ResidualBlock = layers.ResidualBlock
30 | ResnetBlockDDPM = layers.ResnetBlockDDPM
31 | Upsample = layers.Upsample
32 | Downsample = layers.Downsample
33 | conv3x3 = layers.ddpm_conv3x3
34 | get_act = layers.get_act
35 | get_normalization = normalization.get_normalization
36 | default_initializer = layers.default_init
37 |
38 |
39 | # @utils.register_model(name='ddpm')
40 | class UnetPlus(nn.Module):
41 | def __init__(self, config):
42 | super().__init__()
43 | self.act = act = get_act(config)
44 | self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config)))
45 |
46 | self.nf = nf = config.nf
47 | ch_mult = config.ch_mult
48 | self.num_res_blocks = num_res_blocks = config.num_res_blocks
49 | self.attn_resolutions = attn_resolutions = config.attn_resolutions
50 | dropout = config.dropout
51 | resamp_with_conv = config.resamp_with_conv
52 | self.num_resolutions = num_resolutions = len(ch_mult)
53 | self.all_resolutions = all_resolutions = [config.image_size // (2 ** i) for i in range(num_resolutions)]
54 |
55 | AttnBlock = functools.partial(layers.AttnBlock)
56 | self.conditional = conditional = config.conditional
57 | ResnetBlock = functools.partial(ResnetBlockDDPM, act=act, temb_dim=4 * nf, dropout=dropout)
58 | if conditional:
59 | # Condition on noise levels.
60 | modules = [nn.Linear(nf, nf * 4)]
61 | modules[0].weight.data = default_initializer()(modules[0].weight.data.shape)
62 | nn.init.zeros_(modules[0].bias)
63 | modules.append(nn.Linear(nf * 4, nf * 4))
64 | modules[1].weight.data = default_initializer()(modules[1].weight.data.shape)
65 | nn.init.zeros_(modules[1].bias)
66 |
67 | # self.centered = config.data.centered
68 | channels = config.in_channels
69 |
70 | # Downsampling block
71 | modules.append(conv3x3(channels, nf))
72 | hs_c = [nf]
73 | in_ch = nf
74 | for i_level in range(num_resolutions):
75 | # Residual blocks for this resolution
76 | for i_block in range(num_res_blocks):
77 | out_ch = nf * ch_mult[i_level]
78 | modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
79 | in_ch = out_ch
80 | if all_resolutions[i_level] in attn_resolutions:
81 | modules.append(AttnBlock(channels=in_ch))
82 | hs_c.append(in_ch)
83 | if i_level != num_resolutions - 1:
84 | modules.append(Downsample(channels=in_ch, with_conv=resamp_with_conv))
85 | hs_c.append(in_ch)
86 |
87 | in_ch = hs_c[-1]
88 | modules.append(ResnetBlock(in_ch=in_ch))
89 | modules.append(AttnBlock(channels=in_ch))
90 | modules.append(ResnetBlock(in_ch=in_ch))
91 |
92 | # Upsampling block
93 | for i_level in reversed(range(num_resolutions)):
94 | for i_block in range(num_res_blocks + 1):
95 | out_ch = nf * ch_mult[i_level]
96 | modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
97 | in_ch = out_ch
98 | if all_resolutions[i_level] in attn_resolutions:
99 | modules.append(AttnBlock(channels=in_ch))
100 | if i_level != 0:
101 | modules.append(Upsample(channels=in_ch, with_conv=resamp_with_conv))
102 |
103 | assert not hs_c
104 | modules.append(nn.GroupNorm(num_channels=in_ch, num_groups=32, eps=1e-6))
105 | modules.append(conv3x3(in_ch, channels, init_scale=0.))
106 | self.all_modules = nn.ModuleList(modules)
107 |
108 | self.scale_by_sigma = config.scale_by_sigma
109 |
110 | def forward(self, x, times=None):
111 | modules = self.all_modules
112 | m_idx = 0
113 | if times is not None:
114 | # timestep/scale embedding
115 | timesteps = times
116 | temb = layers.get_timestep_embedding(timesteps, self.nf)
117 | temb = modules[m_idx](temb)
118 | m_idx += 1
119 | temb = modules[m_idx](self.act(temb))
120 | m_idx += 1
121 | else:
122 | temb = times
123 |
124 | # if self.centered:
125 | # # Input is in [-1, 1]
126 | # h = x
127 | # else:
128 | # # Input is in [0, 1]
129 | # h = 2 * x - 1.
130 | h = x
131 |
132 | # Downsampling block
133 | hs = [modules[m_idx](h)]
134 | m_idx += 1
135 | for i_level in range(self.num_resolutions):
136 | # Residual blocks for this resolution
137 | for i_block in range(self.num_res_blocks):
138 | h = modules[m_idx](hs[-1], temb)
139 | m_idx += 1
140 | if h.shape[-1] in self.attn_resolutions:
141 | h = modules[m_idx](h)
142 | m_idx += 1
143 | hs.append(h)
144 | if i_level != self.num_resolutions - 1:
145 | hs.append(modules[m_idx](hs[-1]))
146 | m_idx += 1
147 |
148 | h = hs[-1]
149 | h = modules[m_idx](h, temb)
150 | m_idx += 1
151 | h = modules[m_idx](h)
152 | m_idx += 1
153 | h = modules[m_idx](h, temb)
154 | m_idx += 1
155 |
156 | # Upsampling block
157 | for i_level in reversed(range(self.num_resolutions)):
158 | for i_block in range(self.num_res_blocks + 1):
159 | h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
160 | m_idx += 1
161 | if h.shape[-1] in self.attn_resolutions:
162 | h = modules[m_idx](h)
163 | m_idx += 1
164 | if i_level != 0:
165 | h = modules[m_idx](h)
166 | m_idx += 1
167 |
168 | assert not hs
169 | h = self.act(modules[m_idx](h))
170 | m_idx += 1
171 | h = modules[m_idx](h)
172 | m_idx += 1
173 | assert m_idx == len(modules)
174 |
175 | if self.scale_by_sigma:
176 | # Divide the output by sigmas. Useful for training with the NCSN loss.
177 | # The DDPM loss scales the network output by sigma in the loss function,
178 | # so no need of doing it here.
179 | used_sigmas = self.sigmas[times, None, None, None]
180 | h = h / used_sigmas
181 |
182 | return h
183 |
184 | if __name__ == '__main__':
185 |
186 | pass
--------------------------------------------------------------------------------
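`layers.get_timestep_embedding`, used in the forward pass above, is not reproduced in this dump; for orientation, the conventional DDPM sinusoidal embedding it corresponds to looks roughly like the sketch below (a generic reference, not the repository's exact implementation):

    import math
    import torch

    def timestep_embedding(timesteps, dim, max_period=10000):
        # half of the channels get sin, half get cos, at geometrically spaced frequencies
        half = dim // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
        args = timesteps.float()[:, None] * freqs[None, :]
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
        if dim % 2 == 1:                              # zero-pad odd widths
            emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
        return emb

    print(timestep_embedding(torch.arange(4), dim=128).shape)   # torch.Size([4, 128])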
/unet_plus/utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """All functions and modules related to model definition.
17 | """
18 |
19 | import torch
20 | # import sde_lib
21 | import numpy as np
22 |
23 |
24 | _MODELS = {}
25 |
26 |
27 | def register_model(cls=None, *, name=None):
28 | """A decorator for registering model classes."""
29 |
30 | def _register(cls):
31 | if name is None:
32 | local_name = cls.__name__
33 | else:
34 | local_name = name
35 | if local_name in _MODELS:
36 | raise ValueError(f'Already registered model with name: {local_name}')
37 | _MODELS[local_name] = cls
38 | return cls
39 |
40 | if cls is None:
41 | return _register
42 | else:
43 | return _register(cls)
44 |
45 |
46 | def get_model(name):
47 | return _MODELS[name]
48 |
49 |
50 | def get_sigmas(config):
51 | """Get sigmas --- the set of noise levels for SMLD from config files.
52 | Args:
53 | config: A ConfigDict object parsed from the config file
54 | Returns:
55 | sigmas: a numpy array of noise levels
56 | """
57 | sigmas = np.exp(
58 | np.linspace(np.log(config.sigma_max), np.log(config.sigma_min), config.num_scales))
59 |
60 | return sigmas
61 |
62 |
63 | def get_ddpm_params(config):
64 | """Get betas and alphas --- parameters used in the original DDPM paper."""
65 | num_diffusion_timesteps = 1000
66 | # parameters need to be adapted if number of time steps differs from 1000
67 | beta_start = config.beta_min / config.num_scales
68 | beta_end = config.beta_max / config.num_scales
69 | betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
70 |
71 | alphas = 1. - betas
72 | alphas_cumprod = np.cumprod(alphas, axis=0)
73 | sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
74 | sqrt_1m_alphas_cumprod = np.sqrt(1. - alphas_cumprod)
75 |
76 | return {
77 | 'betas': betas,
78 | 'alphas': alphas,
79 | 'alphas_cumprod': alphas_cumprod,
80 | 'sqrt_alphas_cumprod': sqrt_alphas_cumprod,
81 | 'sqrt_1m_alphas_cumprod': sqrt_1m_alphas_cumprod,
82 | 'beta_min': beta_start * (num_diffusion_timesteps - 1),
83 | 'beta_max': beta_end * (num_diffusion_timesteps - 1),
84 | 'num_diffusion_timesteps': num_diffusion_timesteps
85 | }
86 |
87 |
88 | def create_model(config):
89 | """Create the score model."""
90 | model_name = config.name
91 | score_model = get_model(model_name)(config)
92 | score_model = score_model.to(config.device)
93 | score_model = torch.nn.DataParallel(score_model)
94 | return score_model
95 |
96 |
97 | def get_model_fn(model, train=False):
98 | """Create a function to give the output of the score-based model.
99 |
100 | Args:
101 | model: The score model.
102 | train: `True` for training and `False` for evaluation.
103 |
104 | Returns:
105 | A model function.
106 | """
107 |
108 | def model_fn(x, labels):
109 | """Compute the output of the score-based model.
110 |
111 | Args:
112 | x: A mini-batch of input data.
113 | labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently
114 | for different models.
115 |
116 | Returns:
117 | A tuple of (model output, new mutable states)
118 | """
119 | if not train:
120 | model.eval()
121 | return model(x, labels)
122 | else:
123 | model.train()
124 | return model(x, labels)
125 |
126 | return model_fn
127 |
128 | '''
129 | def get_score_fn(sde, model, train=False, continuous=False):
130 | """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.
131 |
132 | Args:
133 | sde: An `sde_lib.SDE` object that represents the forward SDE.
134 | model: A score model.
135 | train: `True` for training and `False` for evaluation.
136 | continuous: If `True`, the score-based model is expected to directly take continuous time steps.
137 |
138 | Returns:
139 | A score function.
140 | """
141 | model_fn = get_model_fn(model, train=train)
142 |
143 | if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
144 | def score_fn(x, t):
145 | # Scale neural network output by standard deviation and flip sign
146 | if continuous or isinstance(sde, sde_lib.subVPSDE):
147 | # For VP-trained models, t=0 corresponds to the lowest noise level
148 | # The maximum value of time embedding is assumed to be 999 for
149 | # continuously-trained models.
150 | labels = t * 999
151 | score = model_fn(x, labels)
152 | std = sde.marginal_prob(torch.zeros_like(x), t)[1]
153 | else:
154 | # For VP-trained models, t=0 corresponds to the lowest noise level
155 | labels = t * (sde.N - 1)
156 | score = model_fn(x, labels)
157 | std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()]
158 |
159 | score = -score / std[:, None, None, None]
160 | return score
161 |
162 | elif isinstance(sde, sde_lib.VESDE):
163 | def score_fn(x, t):
164 | if continuous:
165 | labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
166 | else:
167 | # For VE-trained models, t=0 corresponds to the highest noise level
168 | labels = sde.T - t
169 | labels *= sde.N - 1
170 | labels = torch.round(labels).long()
171 |
172 | score = model_fn(x, labels)
173 | return score
174 |
175 | else:
176 | raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")
177 |
178 | return score_fn
179 | '''
180 |
181 | def to_flattened_numpy(x):
182 | """Flatten a torch tensor `x` and convert it to numpy."""
183 | return x.detach().cpu().numpy().reshape((-1,))
184 |
185 |
186 | def from_flattened_numpy(x, shape):
187 | """Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
188 | return torch.from_numpy(x.reshape(shape))
--------------------------------------------------------------------------------
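A quick sketch of the schedule helpers above with an ad-hoc config (values are illustrative, not taken from the repository's configs): `get_sigmas` builds a geometric noise schedule from `sigma_max` down to `sigma_min`, and `get_ddpm_params` derives the linear-beta quantities used by DDPM.

    from types import SimpleNamespace
    from unet_plus.utils import get_sigmas, get_ddpm_params

    cfg = SimpleNamespace(sigma_min=0.01, sigma_max=50.0, num_scales=1000,
                          beta_min=0.1, beta_max=20.0)

    sigmas = get_sigmas(cfg)                       # geometric: 50.0, ..., 0.01
    print(sigmas.shape, sigmas[0], sigmas[-1])     # (1000,) 50.0 0.01

    ddpm = get_ddpm_params(cfg)                    # betas linear in [1e-4, 2e-2]
    print(ddpm['betas'][0], ddpm['betas'][-1], ddpm['num_diffusion_timesteps'])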