├── .idea ├── .gitignore ├── Mask-Conditioned-Latent-Space-Diffusion.iml ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml ├── modules.xml ├── vcs.xml └── workspace.xml ├── LICENSE ├── ReadMe.md ├── assets ├── denoising_process │ ├── 2018 │ │ ├── 0.png │ │ ├── 1.png │ │ ├── 10.png │ │ ├── 11.png │ │ ├── 12.png │ │ ├── 13.png │ │ ├── 14.png │ │ ├── 15.png │ │ ├── 16.png │ │ ├── 17.png │ │ ├── 18.png │ │ ├── 19.png │ │ ├── 2.png │ │ ├── 20.png │ │ ├── 3.png │ │ ├── 4.png │ │ ├── 5.png │ │ ├── 6.png │ │ ├── 7.png │ │ ├── 8.png │ │ ├── 9.png │ │ └── test.gif │ ├── 3063 │ │ ├── 0.png │ │ ├── 1.png │ │ ├── 10.png │ │ ├── 11.png │ │ ├── 12.png │ │ ├── 13.png │ │ ├── 14.png │ │ ├── 15.png │ │ ├── 16.png │ │ ├── 17.png │ │ ├── 18.png │ │ ├── 19.png │ │ ├── 2.png │ │ ├── 20.png │ │ ├── 3.png │ │ ├── 4.png │ │ ├── 5.png │ │ ├── 6.png │ │ ├── 7.png │ │ ├── 8.png │ │ ├── 9.png │ │ └── test.gif │ └── 5096 │ │ ├── 0.png │ │ ├── 1.png │ │ ├── 10.png │ │ ├── 11.png │ │ ├── 12.png │ │ ├── 13.png │ │ ├── 14.png │ │ ├── 15.png │ │ ├── 16.png │ │ ├── 17.png │ │ ├── 18.png │ │ ├── 19.png │ │ ├── 2.png │ │ ├── 20.png │ │ ├── 3.png │ │ ├── 4.png │ │ ├── 5.png │ │ ├── 6.png │ │ ├── 7.png │ │ ├── 8.png │ │ ├── 9.png │ │ └── test.gif ├── framework.png └── teaser.png ├── configs ├── BIPED_sample.yaml ├── BIPED_train.yaml ├── BSDS_sample.yaml ├── BSDS_train.yaml ├── NYUD_sample.yaml ├── NYUD_train.yaml ├── default.yaml └── first_stage_d4.yaml ├── demo.py ├── demo_trt.py ├── denoising_diffusion_pytorch ├── __init__.py ├── data.py ├── ddm_const_sde.py ├── efficientnet.py ├── ema.py ├── encoder_decoder.py ├── imagenet.py ├── loss.py ├── mask_cond_unet.py ├── quantization.py ├── resnet.py ├── swin_transformer.py ├── uncond_unet.py ├── utils.py ├── vgg.py ├── wavelet.py └── wcc.py ├── metrics ├── datasets.py ├── defaults.py ├── feature_extractor_base.py ├── feature_extractor_inceptionv3.py ├── generative_model_base.py ├── helpers.py ├── interpolate_compat_tensorflow.py ├── metric.py ├── metric_fid.py ├── metric_isc.py ├── metric_kid.py ├── metric_ppl.py ├── noise.py ├── registry.py ├── sample_similarity_base.py ├── sample_similarity_lpips.py └── utils.py ├── requirement.txt ├── sample_cond_ldm.py ├── taming ├── __init__.py ├── data │ ├── ade20k.py │ ├── annotated_objects_coco.py │ ├── annotated_objects_dataset.py │ ├── annotated_objects_open_images.py │ ├── base.py │ ├── coco.py │ ├── conditional_builder │ │ ├── objects_bbox.py │ │ ├── objects_center_points.py │ │ └── utils.py │ ├── custom.py │ ├── faceshq.py │ ├── helper_types.py │ ├── image_transforms.py │ ├── imagenet.py │ ├── open_images_helper.py │ ├── sflckr.py │ └── utils.py ├── modules │ ├── autoencoder │ │ └── lpips │ │ │ └── vgg.pth │ ├── diffusionmodules │ │ └── model.py │ ├── discriminator │ │ └── model.py │ ├── losses │ │ ├── __init__.py │ │ ├── lpips.py │ │ ├── segmentation.py │ │ ├── util.py │ │ └── vqperceptual.py │ ├── misc │ │ └── coord.py │ ├── util.py │ └── vqvae │ │ └── quantize.py └── util.py ├── train_cond_ldm.py ├── train_vae.py └── unet_plus ├── __init__.py ├── ema.py ├── layers.py ├── layerspp.py ├── ncsnpp.py ├── ncsnpp2.py ├── ncsnpp3.py ├── ncsnpp4.py ├── ncsnpp5.py ├── ncsnpp6.py ├── ncsnpp7.py ├── ncsnpp8.py ├── ncsnpp9.py ├── ncsnv2.py ├── normalization.py ├── op ├── __init__.py ├── fused_act.py ├── fused_bias_act.cpp ├── fused_bias_act_kernel.cu ├── upfirdn2d.cpp ├── upfirdn2d.py └── upfirdn2d_kernel.cu ├── unet_pp.py ├── up_or_down_sampling.py └── utils.py /.idea/.gitignore: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/.idea/.gitignore --------------------------------------------------------------------------------

/ReadMe.md: --------------------------------------------------------------------------------

## DiffusionEdge: Diffusion Probabilistic Model for Crisp Edge Detection ([arXiv](https://arxiv.org/abs/2401.02032))
[Yunfan Ye](https://yunfan1202.github.io), [Yuhang Huang](https://github.com/GuHuangAI), [Renjiao Yi](https://renjiaoyi.github.io/), [Zhiping Cai](), [Kai Xu](http://kevinkaixu.net/index.html).

![Teaser](assets/teaser.png)
![](assets/denoising_process/3063/test.gif)
![](assets/denoising_process/5096/test.gif)

# News
- We have released a real-time model trained on BSDS; see **[Real-time DiffusionEdge](#vi-real-time-diffusionedge)**.
- We have created a [WeChat group](https://github.com/GuHuangAI/DiffusionEdge/issues/17) for discussion; scan the QR code there with the WeChat app to join.
- 2023-12-09: The paper was accepted by **AAAI-2024**.
- Uploaded the pretrained **first-stage checkpoint** ([download](https://github.com/GuHuangAI/DiffusionEdge/releases/download/v1.1/first_stage_total_320.pt)).
- Uploaded **pretrained weights** and **pre-computed results**.
- Added a simple demo; see **[Quick Demo](#iii-quick-demo)**.
- Initial commit.

## I. Before Starting.
1. Install PyTorch:
~~~
conda create -n diffedge python=3.9
conda activate diffedge
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
~~~
2. Install the other packages:
~~~
pip install -r requirement.txt
~~~
3. Prepare the `accelerate` config:
~~~
accelerate config
~~~

## II. Prepare Data.
The training data structure should look like:
```commandline
|-- $data_root
| |-- image
| |-- |-- raw
| |-- |-- |-- XXXXX.jpg
| |-- |-- |-- XXXXX.jpg
| |-- edge
| |-- |-- raw
| |-- |-- |-- XXXXX.png
| |-- |-- |-- XXXXX.png
```
The testing data structure should look like:
```commandline
|-- $data_root
| |-- XXXXX.jpg
| |-- XXXXX.jpg
```
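The repository's own data code lives in `denoising_diffusion_pytorch/data.py`. Purely as an illustration of the pairing implied by the layout above (same file stem, `.jpg` image under `image/raw/`, `.png` edge map under `edge/raw/`), a minimal loader could look like the sketch below; the `EdgePairDataset` class, its checks, and its return format are hypothetical and are not part of the repository.

```python
# Illustrative sketch only: the real loader is denoising_diffusion_pytorch/data.py
# and may differ in every detail. This just pairs $data_root/image/raw/XXXXX.jpg
# with $data_root/edge/raw/XXXXX.png as described in the layout above.
from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset


class EdgePairDataset(Dataset):
    """Hypothetical helper (not part of this repo) pairing images with edge maps."""

    def __init__(self, data_root, transform=None):
        self.img_dir = Path(data_root) / "image" / "raw"
        self.edge_dir = Path(data_root) / "edge" / "raw"
        # Every XXXXX.jpg under image/raw is expected to have XXXXX.png under edge/raw.
        self.stems = sorted(p.stem for p in self.img_dir.glob("*.jpg"))
        missing = [s for s in self.stems if not (self.edge_dir / f"{s}.png").exists()]
        if missing:
            raise FileNotFoundError(f"edge maps missing for: {missing[:5]} ...")
        self.transform = transform

    def __len__(self):
        return len(self.stems)

    def __getitem__(self, idx):
        stem = self.stems[idx]
        image = Image.open(self.img_dir / f"{stem}.jpg").convert("RGB")
        edge = Image.open(self.edge_dir / f"{stem}.png").convert("L")
        if self.transform is not None:
            image, edge = self.transform(image, edge)
        return image, edge
```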
## III. Quick Demo.
1. Download the pretrained weights:

| Dataset | ODS (SEval/CEval) | OIS (SEval/CEval) | AC    | Weight | Pre-computed results |
|---------|-------------------|-------------------|-------|--------|----------------------|
| BSDS    | 0.834 / 0.749     | 0.848 / 0.754     | 0.476 | [download](https://github.com/GuHuangAI/DiffusionEdge/releases/download/v1.1/bsds.pt)  | [download](https://github.com/GuHuangAI/DiffusionEdge/releases/download/v1/results_bsds_stride240_step5.zip)  |
| NYUD    | 0.761 / 0.732     | 0.766 / 0.738     | 0.846 | [download](https://github.com/GuHuangAI/DiffusionEdge/releases/download/v1.1/nyud.pt)  | [download](https://github.com/GuHuangAI/DiffusionEdge/releases/download/v1/results_nyud_stride240_step5.zip)  |
| BIPED   | 0.899             | 0.901             | 0.849 | [download](https://github.com/GuHuangAI/DiffusionEdge/releases/download/v1.1/biped.pt) | [download](https://github.com/GuHuangAI/DiffusionEdge/releases/download/v1/results_biped_stride240_step5.zip) |

2. Put your images in a directory and run:
~~~
python demo.py --input_dir $your input dir$ --pre_weight $the downloaded weight path$ --out_dir $the path saves your results$ --bs 8
~~~
A larger `--bs` makes inference faster but uses more CUDA memory.

## IV. Training.
1. Train the first-stage model (the autoencoder):
~~~
accelerate launch train_vae.py --cfg ./configs/first_stage_d4.yaml
~~~
2. Add the path of the final first-stage checkpoint to the config file `./configs/BSDS_train.yaml` (**line 42**), then train the latent diffusion edge model:
~~~
accelerate launch train_cond_ldm.py --cfg ./configs/BSDS_train.yaml
~~~

## V. Inference.
Make sure your trained model weight path is set in the config file `./configs/BSDS_sample.yaml` (**line 73**), then run:
~~~
python sample_cond_ldm.py --cfg ./configs/BSDS_sample.yaml
~~~
Note that you can adjust `sampling_timesteps` (**line 11**) to control the inference speed.

## VI. Real-time DiffusionEdge.
1. We have only tested the real-time model in the following environment so far; more details will be released soon (a quick environment check is sketched right after this section).

| Environment | Version |
|-------------|---------|
| TensorRT    | 8.6.1   |
| CUDA        | 11.6    |
| cuDNN       | 8.7.0   |
| pycuda      | 2024.1  |

Please follow this [link](https://github.com/NVIDIA/TensorRT) to install TensorRT.

2. Download the pretrained [weight](https://github.com/GuHuangAI/DiffusionEdge/releases/download/v1.1/model_crop_size_256_fps_150_ods_0813_ois_0825.trt) and start real-time inference:
~~~
python demo_trt.py --input_dir $your input dir$ --pre_weight $the downloaded weight path$ --out_dir $the path saves your results$
~~~
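Because the TensorRT demo is sensitive to version mismatches, it is worth confirming the local setup against the table in Section VI before running `demo_trt.py`. The snippet below is a hedged convenience check, not part of the repository; it only prints the versions reported by the installed packages.

```python
# Quick environment check (not part of the repo): print installed versions
# so they can be compared against the table in Section VI.
import torch

print("torch       :", torch.__version__)
print("CUDA (torch):", torch.version.cuda)               # e.g. 11.6
print("cuDNN       :", torch.backends.cudnn.version())   # e.g. 8700 means 8.7.0

try:
    import tensorrt as trt
    print("TensorRT    :", trt.__version__)              # e.g. 8.6.1
except ImportError:
    print("TensorRT    : not installed")

try:
    import pycuda
    print("pycuda      :", pycuda.VERSION_TEXT)          # e.g. 2024.1
except ImportError:
    print("pycuda      : not installed")
```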
## Contact
If you have any questions, please contact huangai@nudt.edu.cn.
## Thanks
This project is built on the base code of [DDM-Public](https://github.com/GuHuangAI/DDM-Public).
## Citation
~~~
@inproceedings{ye2024diffusionedge,
  title={DiffusionEdge: Diffusion Probabilistic Model for Crisp Edge Detection},
  author={Yunfan Ye and Kai Xu and Yuhang Huang and Renjiao Yi and Zhiping Cai},
  year={2024},
  booktitle={AAAI}
}
~~~

-------------------------------------------------------------------------------- /assets/: binary image files only (denoising_process frames 0.png-20.png and test.gif for samples 2018, 3063 and 5096, plus framework.png and teaser.png), each available under https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/assets/ --------------------------------------------------------------------------------
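The YAML files that follow (the BIPED/BSDS/NYUD sample and train configs, `default.yaml`, and `first_stage_d4.yaml`) share the same building blocks: diffusion-model settings, a first-stage autoencoder section, a UNet section, and data plus trainer/sampler sections. How the repository's scripts actually parse them is not reproduced here; as a rough illustration under that caveat, plain PyYAML is enough to peek at a field such as `sampling_timesteps`, which the README points to for controlling inference speed. The `find_key` helper is hypothetical.

```python
# Hedged illustration (not how the repo's scripts load their configs):
# use plain PyYAML to inspect one of the config files shown below.
import yaml


def find_key(node, key):
    """Recursively yield (path, value) for every occurrence of `key`."""
    if isinstance(node, dict):
        for k, v in node.items():
            if k == key:
                yield k, v
            yield from ((f"{k}.{p}", val) for p, val in find_key(v, key))
    elif isinstance(node, list):
        for i, item in enumerate(node):
            yield from ((f"[{i}].{p}", val) for p, val in find_key(item, key))


with open("configs/BSDS_sample.yaml") as f:
    cfg = yaml.safe_load(f)

print("top-level sections:", list(cfg.keys()))
for path, value in find_key(cfg, "sampling_timesteps"):
    print("sampling_timesteps at", path, "=", value)
```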
/configs/BIPED_sample.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | model_type: const_sde 3 | model_name: cond_unet 4 | image_size: [320, 320] 5 | input_keys: ['image', 'cond'] 6 | ckpt_path: 7 | ignore_keys: [ ] 8 | only_model: False 9 | timesteps: 1000 10 | train_sample: -1 11 | sampling_timesteps: 5 12 | loss_type: l2 13 | objective: pred_KC 14 | start_dist: normal 15 | perceptual_weight: 0 16 | scale_factor: 0.3 17 | scale_by_std: True 18 | default_scale: True 19 | scale_by_softsign: False 20 | eps: !!float 1e-4 21 | weighting_loss: True 22 | first_stage: 23 | embed_dim: 3 24 | lossconfig: 25 | disc_start: 50001 26 | kl_weight: 0.000001 27 | disc_weight: 0.5 28 | disc_in_channels: 1 29 | ddconfig: 30 | double_z: True 31 | z_channels: 3 32 | resolution: [ 320, 320 ] 33 | in_channels: 1 34 | out_ch: 1 35 | ch: 128 36 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 37 | num_res_blocks: 2 38 |
attn_resolutions: [ ] 39 | dropout: 0.0 40 | ckpt_path: 'checkpoints/first_stage_total_320.pt' 41 | unet: 42 | dim: 128 43 | cond_net: swin 44 | channels: 3 45 | out_mul: 1 46 | dim_mults: [ 1, 2, 4, 4, ] # num_down = len(dim_mults) 47 | cond_in_dim: 3 48 | cond_dim: 128 49 | cond_dim_mults: [ 2, 4 ] # num_down = len(cond_dim_mults) 50 | # window_sizes1: [ [4, 4], [2, 2], [1, 1], [1, 1] ] 51 | # window_sizes2: [ [4, 4], [2, 2], [1, 1], [1, 1] ] 52 | window_sizes1: [ [ 8, 8 ], [ 4, 4 ], [ 2, 2 ], [ 1, 1 ] ] 53 | window_sizes2: [ [ 8, 8 ], [ 4, 4 ], [ 2, 2 ], [ 1, 1 ] ] 54 | fourier_scale: 16 55 | cond_pe: False 56 | num_pos_feats: 128 57 | cond_feature_size: [ 80, 80 ] 58 | 59 | data: 60 | name: edge 61 | img_folder: '/data/yeyunfan/edge_detection_datasets/datasets/BIPED_test' 62 | augment_horizontal_flip: True 63 | batch_size: 8 64 | num_workers: 4 65 | 66 | sampler: 67 | sample_type: "slide" 68 | stride: [240, 240] 69 | batch_size: 1 70 | sample_num: 300 71 | use_ema: False 72 | save_folder: "./results" 73 | ckpt_path: "/data/huang/diffusion_edge/checkpoints/BIPED_size320_swin_unet12_no_resize_disloss/model-9.pt" -------------------------------------------------------------------------------- /configs/BIPED_train.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | model_type: const_sde 3 | model_name: cond_unet 4 | image_size: [320, 320] 5 | input_keys: ['image', 'cond'] 6 | ckpt_path: 7 | ignore_keys: [ ] 8 | only_model: False 9 | timesteps: 1000 10 | train_sample: -1 11 | sampling_timesteps: 10 12 | loss_type: l2 13 | objective: pred_KC 14 | start_dist: normal 15 | perceptual_weight: 0 16 | scale_factor: 0.3 17 | scale_by_std: True 18 | default_scale: True 19 | scale_by_softsign: False 20 | eps: !!float 1e-4 21 | weighting_loss: True 22 | use_disloss: True 23 | 24 | first_stage: 25 | embed_dim: 3 26 | lossconfig: 27 | disc_start: 50001 28 | kl_weight: 0.000001 29 | disc_weight: 0.5 30 | disc_in_channels: 1 31 | ddconfig: 32 | double_z: True 33 | z_channels: 3 34 | resolution: [ 320, 320 ] 35 | in_channels: 1 36 | out_ch: 1 37 | ch: 128 38 | ch_mult: [1, 2, 4] # num_down = len(ch_mult)-1 39 | num_res_blocks: 2 40 | attn_resolutions: [ ] 41 | dropout: 0.0 42 | ckpt_path: './checkpoints/first_stage_total_320.pt' 43 | 44 | unet: 45 | dim: 128 46 | cond_net: swin 47 | fix_bb: False 48 | channels: 3 49 | out_mul: 1 50 | dim_mults: [ 1, 2, 4, 4, ] # num_down = len(dim_mults) 51 | cond_in_dim: 3 52 | cond_dim: 128 53 | cond_dim_mults: [ 2, 4 ] # num_down = len(cond_dim_mults) 54 | # window_sizes1: [ [4, 4], [2, 2], [1, 1], [1, 1] ] 55 | # window_sizes2: [ [4, 4], [2, 2], [1, 1], [1, 1] ] 56 | window_sizes1: [ [ 8, 8 ], [ 4, 4 ], [ 2, 2 ], [ 1, 1 ] ] 57 | window_sizes2: [ [ 8, 8 ], [ 4, 4 ], [ 2, 2 ], [ 1, 1 ] ] 58 | fourier_scale: 16 59 | cond_pe: False 60 | num_pos_feats: 128 61 | cond_feature_size: [ 80, 80 ] 62 | input_size: [80, 80] 63 | 64 | data: 65 | name: edge 66 | crop_type: rand_crop 67 | img_folder: '/data/yeyunfan/edge_detection_datasets/datasets/BIPED_100%' 68 | augment_horizontal_flip: True 69 | batch_size: 8 70 | num_workers: 8 71 | 72 | trainer: 73 | gradient_accumulate_every: 1 74 | lr: !!float 5e-5 75 | min_lr: !!float 5e-6 76 | train_num_steps: 100000 77 | save_and_sample_every: 5000 78 | enable_resume: False 79 | log_freq: 1000 80 | results_folder: "./training/BIPED_size320_swin_unet12_no_resize_disloss" 81 | amp: False 82 | fp16: False 83 | resume_milestone: 0 84 | test_before: True 85 | ema_update_after_step: 10000 86 | 
ema_update_every: 10 87 | -------------------------------------------------------------------------------- /configs/BSDS_sample.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | model_type: const_sde 3 | model_name: cond_unet 4 | image_size: [320, 320] 5 | input_keys: ['image', 'cond'] 6 | ckpt_path: 7 | ignore_keys: [ ] 8 | only_model: False 9 | timesteps: 1000 10 | train_sample: -1 11 | sampling_timesteps: 5 12 | loss_type: l2 13 | objective: pred_noise 14 | start_dist: normal 15 | perceptual_weight: 0 16 | scale_factor: 0.3 17 | scale_by_std: True 18 | default_scale: True 19 | scale_by_softsign: False 20 | eps: !!float 1e-4 21 | weighting_loss: False 22 | first_stage: 23 | embed_dim: 3 24 | lossconfig: 25 | disc_start: 50001 26 | kl_weight: 0.000001 27 | disc_weight: 0.5 28 | disc_in_channels: 1 29 | ddconfig: 30 | double_z: True 31 | z_channels: 3 32 | resolution: [ 320, 320 ] 33 | in_channels: 1 34 | out_ch: 1 35 | ch: 128 36 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 37 | num_res_blocks: 2 38 | attn_resolutions: [ ] 39 | dropout: 0.0 40 | ckpt_path: 'checkpoints/first_stage_total_320.pt' 41 | unet: 42 | dim: 128 43 | cond_net: swin 44 | channels: 3 45 | out_mul: 1 46 | dim_mults: [ 1, 2, 4, 4, ] # num_down = len(dim_mults) 47 | cond_in_dim: 3 48 | cond_dim: 128 49 | cond_dim_mults: [ 2, 4 ] # num_down = len(cond_dim_mults) 50 | # window_sizes1: [ [4, 4], [2, 2], [1, 1], [1, 1] ] 51 | # window_sizes2: [ [4, 4], [2, 2], [1, 1], [1, 1] ] 52 | window_sizes1: [ [ 8, 8 ], [ 4, 4 ], [ 2, 2 ], [ 1, 1 ] ] 53 | window_sizes2: [ [ 8, 8 ], [ 4, 4 ], [ 2, 2 ], [ 1, 1 ] ] 54 | fourier_scale: 16 55 | cond_pe: False 56 | num_pos_feats: 128 57 | cond_feature_size: [ 80, 80 ] 58 | 59 | data: 60 | name: edge 61 | img_folder: '/data/yeyunfan/edge_detection_datasets/datasets/BSDS_test' 62 | augment_horizontal_flip: True 63 | batch_size: 8 64 | num_workers: 4 65 | 66 | sampler: 67 | sample_type: "slide" 68 | stride: [240, 240] 69 | batch_size: 1 70 | sample_num: 300 71 | use_ema: False 72 | save_folder: './results' 73 | ckpt_path: "/data/huang/diffusion_edge/checkpoints/BSDS_swin_unet12_disloss_bs2x8/model-4.pt" -------------------------------------------------------------------------------- /configs/BSDS_train.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | model_type: const_sde 3 | model_name: cond_unet 4 | image_size: [320, 320] 5 | input_keys: ['image', 'cond'] 6 | ckpt_path: 7 | ignore_keys: [ ] 8 | only_model: False 9 | timesteps: 1000 10 | train_sample: -1 11 | sampling_timesteps: 1 12 | loss_type: l2 13 | objective: pred_KC 14 | start_dist: normal 15 | perceptual_weight: 0 16 | scale_factor: 0.3 17 | scale_by_std: True 18 | default_scale: True 19 | scale_by_softsign: False 20 | eps: !!float 1e-4 21 | weighting_loss: True 22 | use_disloss: True 23 | 24 | first_stage: 25 | embed_dim: 3 26 | lossconfig: 27 | disc_start: 50001 28 | kl_weight: 0.000001 29 | disc_weight: 0.5 30 | disc_in_channels: 1 31 | ddconfig: 32 | double_z: True 33 | z_channels: 3 34 | resolution: [ 320, 320 ] 35 | in_channels: 1 36 | out_ch: 1 37 | ch: 128 38 | ch_mult: [1, 2, 4] # num_down = len(ch_mult)-1 39 | num_res_blocks: 2 40 | attn_resolutions: [ ] 41 | dropout: 0.0 42 | ckpt_path: './checkpoints/first_stage_total_320.pt' # the weight is obtained by training the first stage model 43 | 44 | unet: 45 | dim: 128 46 | cond_net: swin 47 | fix_bb: False 48 | channels: 3 49 | out_mul: 1 50 | dim_mults: [ 1, 2, 4, 4, ] # 
num_down = len(dim_mults) 51 | cond_in_dim: 3 52 | cond_dim: 128 53 | cond_dim_mults: [ 2, 4 ] # num_down = len(cond_dim_mults) 54 | window_sizes1: [ [ 8, 8 ], [ 4, 4 ], [ 2, 2 ], [ 1, 1 ] ] 55 | window_sizes2: [ [ 8, 8 ], [ 4, 4 ], [ 2, 2 ], [ 1, 1 ] ] 56 | fourier_scale: 16 57 | cond_pe: False 58 | num_pos_feats: 128 59 | cond_feature_size: [ 80, 80 ] 60 | input_size: [80, 80] 61 | 62 | data: 63 | name: edge 64 | crop_type: rand_resize_crop 65 | img_folder: '/data/yeyunfan/edge_detection_datasets/datasets/BSDS_100%' 66 | augment_horizontal_flip: True 67 | batch_size: 2 68 | num_workers: 8 69 | 70 | trainer: 71 | gradient_accumulate_every: 8 72 | lr: !!float 5e-5 73 | min_lr: !!float 5e-6 74 | train_num_steps: 100000 75 | save_and_sample_every: 5000 76 | enable_resume: False 77 | log_freq: 1000 78 | results_folder: "./training/BSDS_swin_unet12_disloss_bs2x8" 79 | amp: False 80 | fp16: False 81 | resume_milestone: 0 82 | test_before: True 83 | ema_update_after_step: 10000 84 | ema_update_every: 10 85 | -------------------------------------------------------------------------------- /configs/NYUD_sample.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | model_type: const_sde 3 | model_name: cond_unet 4 | image_size: [320, 320] 5 | input_keys: ['image', 'cond'] 6 | ckpt_path: 7 | ignore_keys: [ ] 8 | only_model: False 9 | timesteps: 1000 10 | train_sample: -1 11 | sampling_timesteps: 5 12 | loss_type: l2 13 | objective: pred_KC 14 | start_dist: normal 15 | perceptual_weight: 0 16 | scale_factor: 0.3 17 | scale_by_std: True 18 | default_scale: True 19 | scale_by_softsign: False 20 | eps: !!float 1e-4 21 | weighting_loss: True 22 | first_stage: 23 | embed_dim: 3 24 | lossconfig: 25 | disc_start: 50001 26 | kl_weight: 0.000001 27 | disc_weight: 0.5 28 | disc_in_channels: 1 29 | ddconfig: 30 | double_z: True 31 | z_channels: 3 32 | resolution: [ 320, 320 ] 33 | in_channels: 1 34 | out_ch: 1 35 | ch: 128 36 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 37 | num_res_blocks: 2 38 | attn_resolutions: [ ] 39 | dropout: 0.0 40 | ckpt_path: 'checkpoints/first_stage_total_320.pt' 41 | unet: 42 | dim: 128 43 | cond_net: swin 44 | channels: 3 45 | out_mul: 1 46 | dim_mults: [ 1, 2, 4, 4, ] # num_down = len(dim_mults) 47 | cond_in_dim: 3 48 | cond_dim: 128 49 | cond_dim_mults: [ 2, 4 ] # num_down = len(cond_dim_mults) 50 | # window_sizes1: [ [4, 4], [2, 2], [1, 1], [1, 1] ] 51 | # window_sizes2: [ [4, 4], [2, 2], [1, 1], [1, 1] ] 52 | window_sizes1: [ [ 8, 8 ], [ 4, 4 ], [ 2, 2 ], [ 1, 1 ] ] 53 | window_sizes2: [ [ 8, 8 ], [ 4, 4 ], [ 2, 2 ], [ 1, 1 ] ] 54 | fourier_scale: 16 55 | cond_pe: False 56 | num_pos_feats: 128 57 | cond_feature_size: [ 80, 80 ] 58 | 59 | data: 60 | name: edge 61 | img_folder: '/data/yeyunfan/edge_detection_datasets/datasets/NYUD_100%/test/image' 62 | augment_horizontal_flip: True 63 | batch_size: 8 64 | num_workers: 4 65 | 66 | sampler: 67 | sample_type: "slide" 68 | stride: [240, 240] 69 | batch_size: 1 70 | sample_num: 300 71 | use_ema: False 72 | save_folder: "./results" 73 | ckpt_path: "/data/huang/diffusion_edge/checkpoints/NYUD_swin_unet12_no_resize_disloss//model-17.pt" -------------------------------------------------------------------------------- /configs/NYUD_train.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | model_type: const_sde 3 | model_name: cond_unet 4 | image_size: [320, 320] 5 | input_keys: ['image', 'cond'] 6 | ckpt_path: 7 | ignore_keys: [ ] 8 | 
only_model: False 9 | timesteps: 1000 10 | train_sample: -1 11 | sampling_timesteps: 10 12 | loss_type: l2 13 | objective: pred_KC 14 | start_dist: normal 15 | perceptual_weight: 0 16 | scale_factor: 0.3 17 | scale_by_std: True 18 | default_scale: True 19 | scale_by_softsign: False 20 | eps: !!float 1e-4 21 | weighting_loss: True 22 | use_disloss: True 23 | 24 | first_stage: 25 | embed_dim: 3 26 | lossconfig: 27 | disc_start: 50001 28 | kl_weight: 0.000001 29 | disc_weight: 0.5 30 | disc_in_channels: 1 31 | ddconfig: 32 | double_z: True 33 | z_channels: 3 34 | resolution: [ 320, 320 ] 35 | in_channels: 1 36 | out_ch: 1 37 | ch: 128 38 | ch_mult: [1, 2, 4] # num_down = len(ch_mult)-1 39 | num_res_blocks: 2 40 | attn_resolutions: [ ] 41 | dropout: 0.0 42 | ckpt_path: './checkpoints/first_stage_total_320.pt' 43 | 44 | unet: 45 | dim: 128 46 | cond_net: swin 47 | fix_bb: False 48 | channels: 3 49 | out_mul: 1 50 | dim_mults: [ 1, 2, 4, 4, ] # num_down = len(dim_mults) 51 | cond_in_dim: 3 52 | cond_dim: 128 53 | cond_dim_mults: [ 2, 4 ] # num_down = len(cond_dim_mults) 54 | # window_sizes1: [ [4, 4], [2, 2], [1, 1], [1, 1] ] 55 | # window_sizes2: [ [4, 4], [2, 2], [1, 1], [1, 1] ] 56 | window_sizes1: [ [ 8, 8 ], [ 4, 4 ], [ 2, 2 ], [ 1, 1 ] ] 57 | window_sizes2: [ [ 8, 8 ], [ 4, 4 ], [ 2, 2 ], [ 1, 1 ] ] 58 | fourier_scale: 16 59 | cond_pe: False 60 | num_pos_feats: 128 61 | cond_feature_size: [ 80, 80 ] 62 | input_size: [80, 80] 63 | 64 | data: 65 | name: edge 66 | crop_type: rand_crop 67 | img_folder: '/data/yeyunfan/edge_detection_datasets/datasets/NYUD_100%/train' 68 | augment_horizontal_flip: True 69 | batch_size: 8 70 | num_workers: 8 71 | 72 | trainer: 73 | gradient_accumulate_every: 2 74 | lr: !!float 5e-5 75 | min_lr: !!float 5e-6 76 | train_num_steps: 100000 77 | save_and_sample_every: 5000 78 | enable_resume: False 79 | log_freq: 1000 80 | results_folder: "./training/NYUD_swin_unet12_no_resize_disloss" 81 | amp: False 82 | fp16: False 83 | resume_milestone: 0 84 | test_before: True 85 | ema_update_after_step: 10000 86 | ema_update_every: 10 87 | -------------------------------------------------------------------------------- /configs/default.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | model_type: const_sde 3 | model_name: cond_unet 4 | image_size: [320, 320] 5 | input_keys: ['image', 'cond'] 6 | ckpt_path: 7 | ignore_keys: [ ] 8 | only_model: False 9 | timesteps: 1000 10 | train_sample: -1 11 | sampling_timesteps: 1 12 | loss_type: l2 13 | objective: pred_noise 14 | start_dist: normal 15 | perceptual_weight: 0 16 | scale_factor: 0.3 17 | scale_by_std: True 18 | default_scale: True 19 | scale_by_softsign: False 20 | eps: !!float 1e-4 21 | weighting_loss: False 22 | first_stage: 23 | embed_dim: 3 24 | lossconfig: 25 | disc_start: 50001 26 | kl_weight: 0.000001 27 | disc_weight: 0.5 28 | disc_in_channels: 1 29 | ddconfig: 30 | double_z: True 31 | z_channels: 3 32 | resolution: [ 320, 320 ] 33 | in_channels: 1 34 | out_ch: 1 35 | ch: 128 36 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 37 | num_res_blocks: 2 38 | attn_resolutions: [ ] 39 | dropout: 0.0 40 | ckpt_path: 41 | unet: 42 | dim: 128 43 | cond_net: swin 44 | without_pretrain: False 45 | channels: 3 46 | out_mul: 1 47 | dim_mults: [ 1, 2, 4, 4, ] # num_down = len(dim_mults) 48 | cond_in_dim: 3 49 | cond_dim: 128 50 | cond_dim_mults: [ 2, 4 ] # num_down = len(cond_dim_mults) 51 | # window_sizes1: [ [4, 4], [2, 2], [1, 1], [1, 1] ] 52 | # window_sizes2: [ [4, 4], [2, 2], 
[1, 1], [1, 1] ] 53 | window_sizes1: [ [ 8, 8 ], [ 4, 4 ], [ 2, 2 ], [ 1, 1 ] ] 54 | window_sizes2: [ [ 8, 8 ], [ 4, 4 ], [ 2, 2 ], [ 1, 1 ] ] 55 | fourier_scale: 16 56 | cond_pe: False 57 | num_pos_feats: 128 58 | cond_feature_size: [ 80, 80 ] 59 | 60 | data: 61 | name: edge 62 | img_folder: '/data/yeyunfan/edge_detection_datasets/datasets/BSDS_test' 63 | augment_horizontal_flip: True 64 | batch_size: 8 65 | num_workers: 4 66 | 67 | sampler: 68 | sample_type: "slide" 69 | stride: [240, 240] 70 | batch_size: 1 71 | sample_num: 300 72 | use_ema: True 73 | save_folder: 74 | ckpt_path: -------------------------------------------------------------------------------- /configs/first_stage_d4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | embed_dim: 3 3 | lossconfig: 4 | disc_start: 50001 5 | kl_weight: 0.000001 6 | disc_weight: 0.5 7 | disc_in_channels: 1 8 | ddconfig: 9 | double_z: True 10 | z_channels: 3 11 | resolution: [320, 320] 12 | in_channels: 1 13 | out_ch: 1 14 | ch: 128 15 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 16 | num_res_blocks: 2 17 | attn_resolutions: [ ] 18 | dropout: 0.0 19 | ckpt_path: #'/home/zhuchenyang/project/hyx/data/pretrain_weights/model-kl-d4.ckpt' 20 | 21 | data: 22 | name: edge 23 | img_folder: '/home/zhuchenyang/project/hyx/data/total_edges' 24 | augment_horizontal_flip: True 25 | batch_size: 8 26 | 27 | trainer: 28 | gradient_accumulate_every: 2 29 | lr: !!float 5e-6 30 | min_lr: !!float 5e-7 31 | train_num_steps: 150000 32 | save_and_sample_every: 10000 33 | log_freq: 100 34 | results_folder: '/home/zhuchenyang/project/hyx/data/total_edges/results_ae_kl_320x320_d4' 35 | amp: False 36 | fp16: False -------------------------------------------------------------------------------- /denoising_diffusion_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | # from denoising_diffusion_pytorch.denoising_diffusion_pytorch import GaussianDiffusion, Unet, Trainer 2 | -------------------------------------------------------------------------------- /denoising_diffusion_pytorch/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import sys 4 | # .path.append() 5 | from taming.modules.losses.vqperceptual import * 6 | 7 | 8 | class LPIPSWithDiscriminator(nn.Module): 9 | def __init__(self, *, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 10 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 11 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 12 | disc_loss="hinge"): 13 | 14 | super().__init__() 15 | assert disc_loss in ["hinge", "vanilla"] 16 | self.kl_weight = kl_weight 17 | self.pixel_weight = pixelloss_weight 18 | self.perceptual_loss = LPIPS().eval() 19 | self.perceptual_weight = perceptual_weight 20 | # output log variance 21 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 22 | 23 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 24 | n_layers=disc_num_layers, 25 | use_actnorm=use_actnorm 26 | ).apply(weights_init) 27 | self.discriminator_iter_start = disc_start 28 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 29 | self.disc_factor = disc_factor 30 | self.discriminator_weight = disc_weight 31 | self.disc_conditional = disc_conditional 32 | 33 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 34 | if last_layer is not None: 35 | 
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 36 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 37 | else: 38 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 39 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 40 | 41 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 42 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 43 | d_weight = d_weight * self.discriminator_weight 44 | return d_weight 45 | 46 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 47 | global_step, last_layer=None, cond=None, split="train", 48 | weights=None): 49 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + \ 50 | F.mse_loss(inputs, reconstructions, reduction="none") 51 | if self.perceptual_weight > 0: 52 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 53 | rec_loss = rec_loss + self.perceptual_weight * p_loss 54 | 55 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 56 | weighted_nll_loss = nll_loss 57 | if weights is not None: 58 | weighted_nll_loss = weights*nll_loss 59 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 60 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 61 | kl_loss = posteriors.kl() 62 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 63 | 64 | # now the GAN part 65 | if optimizer_idx == 0: 66 | # generator update 67 | if cond is None: 68 | assert not self.disc_conditional 69 | logits_fake = self.discriminator(reconstructions.contiguous()) 70 | else: 71 | assert self.disc_conditional 72 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 73 | g_loss = -torch.mean(logits_fake) 74 | 75 | if self.disc_factor > 0.0: 76 | try: 77 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 78 | except RuntimeError: 79 | assert not self.training 80 | d_weight = torch.tensor(0.0) 81 | else: 82 | d_weight = torch.tensor(0.0) 83 | 84 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 85 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 86 | 87 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 88 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 89 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 90 | "{}/d_weight".format(split): d_weight.detach(), 91 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 92 | "{}/g_loss".format(split): g_loss.detach().mean(), 93 | } 94 | return loss, log 95 | 96 | if optimizer_idx == 1: 97 | # second pass for discriminator update 98 | if cond is None: 99 | logits_real = self.discriminator(inputs.contiguous().detach()) 100 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 101 | else: 102 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 103 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 104 | 105 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 106 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 107 | 108 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 109 | "{}/logits_real".format(split): 
logits_real.detach().mean(), 110 | "{}/logits_fake".format(split): logits_fake.detach().mean() 111 | } 112 | return d_loss, log 113 | 114 | -------------------------------------------------------------------------------- /denoising_diffusion_pytorch/quantization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import Parameter 4 | 5 | 6 | def weight_quantization(b): 7 | def uniform_quant(x, b): 8 | xdiv = x.mul((2 ** b - 1)) 9 | xhard = xdiv.round().div(2 ** b - 1) 10 | return xhard 11 | 12 | class _pq(torch.autograd.Function): 13 | @staticmethod 14 | def forward(ctx, input, alpha): 15 | input.div_(alpha) # weights are first divided by alpha 16 | input_c = input.clamp(min=-1, max=1) # then clipped to [-1,1] 17 | sign = input_c.sign() 18 | input_abs = input_c.abs() 19 | input_q = uniform_quant(input_abs, b).mul(sign) 20 | ctx.save_for_backward(input, input_q) 21 | input_q = input_q.mul(alpha) # rescale to the original range 22 | return input_q 23 | 24 | @staticmethod 25 | def backward(ctx, grad_output): 26 | grad_input = grad_output.clone() # grad for weights will not be clipped 27 | input, input_q = ctx.saved_tensors 28 | i = (input.abs() > 1.).float() 29 | sign = input.sign() 30 | grad_alpha = (grad_output * (sign * i + (input_q - input) * (1 - i))).sum() 31 | return grad_input, grad_alpha 32 | 33 | return _pq().apply 34 | 35 | 36 | class weight_quantize_fn(nn.Module): 37 | def __init__(self, bit_w): 38 | super(weight_quantize_fn, self).__init__() 39 | assert bit_w > 0 40 | 41 | self.bit_w = bit_w - 1 42 | self.weight_q = weight_quantization(b=self.bit_w) 43 | self.register_parameter('w_alpha', Parameter(torch.tensor(3.0), requires_grad=True)) 44 | 45 | def forward(self, weight): 46 | mean = weight.data.mean() 47 | std = weight.data.std() 48 | weight = weight.add(-mean).div(std) # weights normalization 49 | weight_q = self.weight_q(weight, self.w_alpha) 50 | return weight_q 51 | 52 | def change_bit(self, bit_w): 53 | self.bit_w = bit_w - 1 54 | self.weight_q = weight_quantization(b=self.bit_w) 55 | 56 | def act_quantization(b, signed=False): 57 | def uniform_quant(x, b=3): 58 | xdiv = x.mul(2 ** b - 1) 59 | xhard = xdiv.round().div(2 ** b - 1) 60 | return xhard 61 | 62 | class _uq(torch.autograd.Function): 63 | @staticmethod 64 | def forward(ctx, input, alpha): 65 | input = input.div(alpha) 66 | input_c = input.clamp(min=-1, max=1) if signed else input.clamp(max=1) 67 | input_q = uniform_quant(input_c, b) 68 | ctx.save_for_backward(input, input_q) 69 | input_q = input_q.mul(alpha) 70 | return input_q 71 | 72 | @staticmethod 73 | def backward(ctx, grad_output): 74 | grad_input = grad_output.clone() 75 | input, input_q = ctx.saved_tensors 76 | i = (input.abs() > 1.).float() 77 | sign = input.sign() 78 | grad_alpha = (grad_output * (sign * i + (input_q - input) * (1 - i))).sum() 79 | grad_input = grad_input * (1 - i) 80 | return grad_input, grad_alpha 81 | 82 | return _uq().apply 83 | 84 | class act_quantize_fn(nn.Module): 85 | def __init__(self, bit_a, signed=False): 86 | super(act_quantize_fn, self).__init__() 87 | self.bit_a = bit_a 88 | self.signed = signed 89 | if signed: 90 | self.bit_a -= 1 91 | assert bit_a > 0 92 | 93 | self.act_q = act_quantization(b=self.bit_a, signed=signed) 94 | self.register_parameter('a_alpha', Parameter(torch.tensor(8.0), requires_grad=True)) 95 | 96 | def forward(self, x): 97 | return self.act_q(x, self.a_alpha) 98 | 99 | def change_bit(self, bit_a): 100 | 
self.bit_a = bit_a 101 | if self.signed: 102 | self.bit_a -= 1 103 | self.act_q = act_quantization(b=self.bit_a, signed=self.signed) -------------------------------------------------------------------------------- /denoising_diffusion_pytorch/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import time 4 | import logging 5 | import math 6 | 7 | def create_logger(root_dir, des=''): 8 | root_output_dir = Path(root_dir) 9 | # set up logger 10 | if not root_output_dir.exists(): 11 | print('=> creating {}'.format(root_output_dir)) 12 | root_output_dir.mkdir(exist_ok=True, parents=True) 13 | time_str = time.strftime('%Y-%m-%d-%H-%M') 14 | log_file = '{}_{}.log'.format(time_str, des) 15 | final_log_file = root_output_dir / log_file 16 | head = '%(asctime)-15s %(message)s' 17 | logging.basicConfig(filename=str(final_log_file), format=head) 18 | logger = logging.getLogger() 19 | logger.setLevel(logging.INFO) 20 | console = logging.StreamHandler() 21 | logging.getLogger('').addHandler(console) 22 | return logger 23 | 24 | def exists(x): 25 | return x is not None 26 | 27 | def default(val, d): 28 | if exists(val): 29 | return val 30 | return d() if callable(d) else d 31 | 32 | def identity(t, *args, **kwargs): 33 | return t 34 | 35 | def cycle(dl): 36 | while True: 37 | for data in dl: 38 | yield data 39 | 40 | def has_int_squareroot(num): 41 | return (math.sqrt(num) ** 2) == num 42 | 43 | def num_to_groups(num, divisor): 44 | groups = num // divisor 45 | remainder = num % divisor 46 | arr = [divisor] * groups 47 | if remainder > 0: 48 | arr.append(remainder) 49 | return arr 50 | 51 | def convert_image_to_fn(img_type, image): 52 | if image.mode != img_type: 53 | return image.convert(img_type) 54 | return image 55 | 56 | # normalization functions 57 | 58 | def normalize_to_neg_one_to_one(img): 59 | return img * 2 - 1 60 | 61 | def unnormalize_to_zero_to_one(t): 62 | return (t + 1) * 0.5 63 | 64 | def dict2str(dict): 65 | s = '' 66 | for k, v in dict.items(): 67 | s += "{}: {:.5f}, ".format(k, v) 68 | return s -------------------------------------------------------------------------------- /denoising_diffusion_pytorch/wavelet.py: -------------------------------------------------------------------------------- 1 | import pywt 2 | import pywt.data 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Function 6 | import torch.nn.functional as F 7 | 8 | 9 | def create_wavelet_filter(wave, in_size, out_size, type=torch.float): 10 | w = pywt.Wavelet(wave) 11 | dec_hi = torch.tensor(w.dec_hi[::-1], dtype=type) 12 | dec_lo = torch.tensor(w.dec_lo[::-1], dtype=type) 13 | dec_filters = torch.stack([dec_lo.unsqueeze(0) * dec_lo.unsqueeze(1), 14 | dec_lo.unsqueeze(0) * dec_hi.unsqueeze(1), 15 | dec_hi.unsqueeze(0) * dec_lo.unsqueeze(1), 16 | dec_hi.unsqueeze(0) * dec_hi.unsqueeze(1)], dim=0) 17 | 18 | dec_filters = dec_filters[:, None].repeat(in_size, 1, 1, 1) 19 | 20 | rec_hi = torch.tensor(w.rec_hi[::-1], dtype=type).flip(dims=[0]) 21 | rec_lo = torch.tensor(w.rec_lo[::-1], dtype=type).flip(dims=[0]) 22 | rec_filters = torch.stack([rec_lo.unsqueeze(0) * rec_lo.unsqueeze(1), 23 | rec_lo.unsqueeze(0) * rec_hi.unsqueeze(1), 24 | rec_hi.unsqueeze(0) * rec_lo.unsqueeze(1), 25 | rec_hi.unsqueeze(0) * rec_hi.unsqueeze(1)], dim=0) 26 | 27 | rec_filters = rec_filters[:, None].repeat(out_size, 1, 1, 1) 28 | 29 | return dec_filters, rec_filters 30 | 31 | 32 | def wt(x, filters, in_size, level): 33 | _, _, h, w = 
x.shape 34 | pad = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1) 35 | res = F.conv2d(x, filters, stride=2, groups=in_size, padding=pad) 36 | if level > 1: 37 | res[:, ::4] = wt(res[:, ::4], filters, in_size, level - 1) 38 | res = res.reshape(-1, 2, h // 2, w // 2).transpose(1, 2).reshape(-1, in_size, h, w) 39 | return res 40 | 41 | 42 | def iwt(x, inv_filters, in_size, level): 43 | _, _, h, w = x.shape 44 | pad = (inv_filters.shape[2] // 2 - 1, inv_filters.shape[3] // 2 - 1) 45 | res = x.reshape(-1, h // 2, 2, w // 2).transpose(1, 2).reshape(-1, 4 * in_size, h // 2, w // 2) 46 | if level > 1: 47 | res[:, ::4] = iwt(res[:, ::4], inv_filters, in_size, level - 1) 48 | res = F.conv_transpose2d(res, inv_filters, stride=2, groups=in_size, padding=pad) 49 | return res 50 | 51 | 52 | def get_inverse_transform(weights, in_size, level): 53 | class InverseWaveletTransform(Function): 54 | 55 | @staticmethod 56 | def forward(ctx, input): 57 | with torch.no_grad(): 58 | x = iwt(input, weights, in_size, level) 59 | return x 60 | 61 | @staticmethod 62 | def backward(ctx, grad_output): 63 | grad = wt(grad_output, weights, in_size, level) 64 | return grad, None 65 | 66 | return InverseWaveletTransform().apply 67 | 68 | 69 | def get_transform(weights, in_size, level): 70 | class WaveletTransform(Function): 71 | 72 | @staticmethod 73 | def forward(ctx, input): 74 | with torch.no_grad(): 75 | x = wt(input, weights, in_size, level) 76 | return x 77 | 78 | @staticmethod 79 | def backward(ctx, grad_output): 80 | grad = iwt(grad_output, weights, in_size, level) 81 | return grad, None 82 | 83 | return WaveletTransform().apply -------------------------------------------------------------------------------- /denoising_diffusion_pytorch/wcc.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Tuple 2 | 3 | import torch 4 | from torch import nn as nn 5 | from torch.nn import functional as F 6 | 7 | from denoising_diffusion_pytorch.quantization import weight_quantize_fn, act_quantize_fn 8 | from denoising_diffusion_pytorch import wavelet 9 | 10 | 11 | class WCC(nn.Conv1d): 12 | def __init__(self, in_channels: int, 13 | out_channels: int, 14 | stride: Union[int, Tuple] = 1, 15 | padding: Union[int, Tuple] = 0, 16 | dilation: Union[int, Tuple] = 1, 17 | groups: int = 1, 18 | bias: bool = False, 19 | levels: int = 3, 20 | compress_rate: float = 0.25, 21 | bit_w: int = 8, 22 | bit_a: int = 8, 23 | wt_type: str = "db1"): 24 | super(WCC, self).__init__(in_channels, out_channels, 1, stride, padding, dilation, groups, bias) 25 | self.layer_type = 'WCC' 26 | self.bit_w = bit_w 27 | self.bit_a = bit_a 28 | 29 | self.weight_quant = weight_quantize_fn(self.bit_w) 30 | self.act_quant = act_quantize_fn(self.bit_a, signed=True) 31 | 32 | self.levels = levels 33 | self.wt_type = wt_type 34 | self.compress_rate = compress_rate 35 | 36 | dec_filters, rec_filters = wavelet.create_wavelet_filter(wave=self.wt_type, 37 | in_size=in_channels, 38 | out_size=out_channels) 39 | self.wt_filters = nn.Parameter(dec_filters, requires_grad=False) 40 | self.iwt_filters = nn.Parameter(rec_filters, requires_grad=False) 41 | self.wt = wavelet.get_transform(self.wt_filters, in_channels, levels) 42 | self.iwt = wavelet.get_inverse_transform(self.iwt_filters, out_channels, levels) 43 | 44 | self.get_pad = lambda n: ((2 ** levels) - n) % (2 ** levels) 45 | 46 | def forward(self, x): 47 | in_shape = x.shape 48 | pads = (0, self.get_pad(in_shape[2]), 0, self.get_pad(in_shape[3])) 49 | x = 
F.pad(x, pads) # pad to match 2^(levels) 50 | 51 | weight_q = self.weight_quant(self.weight) # quantize weights 52 | x = self.wt(x) # H 53 | topk, ids = self.compress(x) # T 54 | topk_q = self.act_quant(topk) # quantize activations 55 | topk_q = F.conv1d(topk_q, weight_q, self.bias, self.stride, self.padding, self.dilation, self.groups) # K_1x1 56 | x = self.decompress(topk_q, ids, x.shape) # T^T 57 | x = self.iwt(x) # H^T 58 | 59 | x = x[:, :, :in_shape[2], :in_shape[3]] # remove pads 60 | return x 61 | 62 | def compress(self, x): 63 | b, c, h, w = x.shape 64 | acc = x.norm(dim=1).pow(2) 65 | acc = acc.view(b, h * w) 66 | k = int(h * w * self.compress_rate) 67 | ids = acc.topk(k, dim=1, sorted=False)[1] 68 | ids.unsqueeze_(dim=1) 69 | topk = x.reshape((b, c, h * w)).gather(dim=2, index=ids.repeat(1, c, 1)) 70 | return topk, ids 71 | 72 | def decompress(self, topk, ids, shape): 73 | b, _, h, w = shape 74 | ids = ids.repeat(1, self.out_channels, 1) 75 | x = torch.zeros(size=(b, self.out_channels, h * w), requires_grad=True, device=topk.device) 76 | x = x.scatter(dim=2, index=ids, src=topk) 77 | x = x.reshape((b, self.out_channels, h, w)) 78 | return x 79 | 80 | def change_wt_params(self, compress_rate, levels, wt_type="db1"): 81 | self.compress_rate = compress_rate 82 | self.levels = levels 83 | dec_filters, rec_filters = wavelet.create_wavelet_filter(wave=self.wt_type, 84 | in_size=self.in_channels, 85 | out_size=self.out_channels) 86 | self.wt_filters = nn.Parameter(dec_filters, requires_grad=False) 87 | self.iwt_filters = nn.Parameter(rec_filters, requires_grad=False) 88 | self.wt = wavelet.get_transform(self.wt_filters, self.in_channels, levels) 89 | self.iwt = wavelet.get_inverse_transform(self.iwt_filters, self.out_channels, levels) 90 | 91 | def change_bit(self, bit_w, bit_a): 92 | self.bit_w = bit_w 93 | self.bit_a = bit_a 94 | self.weight_quant.change_bit(bit_w) 95 | self.act_quant.change_bit(bit_a) 96 | 97 | if __name__ == '__main__': 98 | wcc = WCC(80, 80) 99 | x = torch.rand(1, 80, 80, 80) 100 | y = wcc(x) 101 | pause = 0 -------------------------------------------------------------------------------- /metrics/datasets.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from contextlib import redirect_stdout 3 | 4 | import torch 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | from torchvision.datasets import CIFAR10, STL10 8 | 9 | from metrics.helpers import vassert 10 | 11 | 12 | class TransformPILtoRGBTensor: 13 | def __call__(self, img): 14 | vassert(type(img) is Image.Image, 'Input is not a PIL.Image') 15 | width, height = img.size 16 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(img.tobytes())).view(height, width, 3) 17 | img = img.permute(2, 0, 1) 18 | return img 19 | 20 | 21 | class ImagesPathDataset(Dataset): 22 | def __init__(self, files, transforms=None): 23 | self.files = files 24 | self.transforms = TransformPILtoRGBTensor() if transforms is None else transforms 25 | 26 | def __len__(self): 27 | return len(self.files) 28 | 29 | def __getitem__(self, i): 30 | path = self.files[i] 31 | img = Image.open(path).convert('RGB') 32 | img = self.transforms(img) 33 | return img 34 | 35 | 36 | class Cifar10_RGB(CIFAR10): 37 | def __init__(self, *args, **kwargs): 38 | with redirect_stdout(sys.stderr): 39 | super().__init__(*args, **kwargs) 40 | 41 | def __getitem__(self, index): 42 | img, target = super().__getitem__(index) 43 | return img 44 | 45 | 46 | class STL10_RGB(STL10): 47 | def 
__init__(self, *args, **kwargs): 48 | with redirect_stdout(sys.stderr): 49 | super().__init__(*args, **kwargs) 50 | 51 | def __getitem__(self, index): 52 | img, target = super().__getitem__(index) 53 | return img 54 | 55 | 56 | class RandomlyGeneratedDataset(Dataset): 57 | def __init__(self, num_samples, *dimensions, dtype=torch.uint8, seed=2021): 58 | vassert(dtype == torch.uint8, 'Unsupported dtype') 59 | rng_stash = torch.get_rng_state() 60 | try: 61 | torch.manual_seed(seed) 62 | self.imgs = torch.randint(0, 255, (num_samples, *dimensions), dtype=dtype) 63 | finally: 64 | torch.set_rng_state(rng_stash) 65 | 66 | def __len__(self): 67 | return self.imgs.shape[0] 68 | 69 | def __getitem__(self, i): 70 | return self.imgs[i] 71 | -------------------------------------------------------------------------------- /metrics/defaults.py: -------------------------------------------------------------------------------- 1 | DEFAULTS = { 2 | 'input1': None, 3 | 'input2': None, 4 | 'cuda': True, 5 | 'batch_size': 64, 6 | 'isc': False, 7 | 'fid': False, 8 | 'kid': False, 9 | 'ppl': False, 10 | 'feature_extractor': 'inception-v3-compat', 11 | 'feature_layer_isc': 'logits_unbiased', 12 | 'feature_layer_fid': '2048', 13 | 'feature_layer_kid': '2048', 14 | 'feature_extractor_weights_path': None, 15 | 'isc_splits': 10, 16 | 'kid_subsets': 100, 17 | 'kid_subset_size': 1000, 18 | 'kid_degree': 3, 19 | 'kid_gamma': None, 20 | 'kid_coef0': 1, 21 | 'ppl_epsilon': 1e-4, 22 | 'ppl_reduction': 'mean', 23 | 'ppl_sample_similarity': 'lpips-vgg16', 24 | 'ppl_sample_similarity_resize': 64, 25 | 'ppl_sample_similarity_dtype': 'uint8', 26 | 'ppl_discard_percentile_lower': 1, 27 | 'ppl_discard_percentile_higher': 99, 28 | 'ppl_z_interp_mode': 'lerp', 29 | 'samples_shuffle': True, 30 | 'samples_find_deep': False, 31 | 'samples_find_ext': 'png,jpg,jpeg', 32 | 'samples_ext_lossy': 'jpg,jpeg', 33 | 'datasets_root': None, 34 | 'datasets_download': True, 35 | 'cache_root': None, 36 | 'cache': True, 37 | 'input1_cache_name': None, 38 | 'input1_model_z_type': 'normal', 39 | 'input1_model_z_size': None, 40 | 'input1_model_num_classes': 0, 41 | 'input1_model_num_samples': None, 42 | 'input2_cache_name': None, 43 | 'input2_model_z_type': 'normal', 44 | 'input2_model_z_size': None, 45 | 'input2_model_num_classes': 0, 46 | 'input2_model_num_samples': None, 47 | 'rng_seed': 2020, 48 | 'save_cpu_ram': False, 49 | 'verbose': True, 50 | } 51 | -------------------------------------------------------------------------------- /metrics/feature_extractor_base.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from metrics.helpers import vassert 4 | 5 | 6 | class FeatureExtractorBase(nn.Module): 7 | def __init__(self, name, features_list): 8 | """ 9 | Base class for feature extractors that can be used in :func:`calculate_metrics`. 10 | 11 | Args: 12 | 13 | name (str): Unique name of the subclassed feature extractor, must be the same as used in 14 | :func:`register_feature_extractor`. 15 | 16 | features_list (list): List of feature names, provided by the subclassed feature extractor. 
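Example (illustrative sketch only; ``MyInceptionExtractor`` is a hypothetical subclass registered via :func:`register_feature_extractor`, and the feature names mirror the library defaults '2048' and 'logits_unbiased'):
    extractor = MyInceptionExtractor('inception-v3-compat', ['2048', 'logits_unbiased'])
    features = extractor.convert_features_tuple_to_dict(extractor(uint8_image_batch))
    pool_features = features['2048']  # e.g. the features used for FID statistics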
17 | """ 18 | super(FeatureExtractorBase, self).__init__() 19 | vassert(type(name) is str, 'Feature extractor name must be a string') 20 | vassert(type(features_list) in (list, tuple), 'Wrong features list type') 21 | vassert( 22 | all((a in self.get_provided_features_list() for a in features_list)), 23 | f'Requested features {tuple(features_list)} are not on the list provided by the selected feature extractor ' 24 | f'{self.get_provided_features_list()}' 25 | ) 26 | vassert(len(features_list) == len(set(features_list)), 'Duplicate features requested') 27 | self.name = name 28 | self.features_list = features_list 29 | 30 | def get_name(self): 31 | return self.name 32 | 33 | @staticmethod 34 | def get_provided_features_list(): 35 | """ 36 | Returns a tuple of feature names, extracted by the subclassed feature extractor. 37 | """ 38 | raise NotImplementedError 39 | 40 | def get_requested_features_list(self): 41 | return self.features_list 42 | 43 | def convert_features_tuple_to_dict(self, features): 44 | # The only compound return type of the forward function amenable to JIT tracing is tuple. 45 | # This function simply helps to recover the mapping. 46 | vassert( 47 | type(features) is tuple and len(features) == len(self.features_list), 48 | 'Features must be the output of forward function' 49 | ) 50 | return dict(((name, feature) for name, feature in zip(self.features_list, features))) 51 | 52 | def forward(self, input): 53 | """ 54 | Returns a tuple of tensors extracted from the `input`, in the same order as they are provided by 55 | `get_provided_features_list()`. 56 | """ 57 | raise NotImplementedError 58 | -------------------------------------------------------------------------------- /metrics/generative_model_base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import torch 4 | 5 | 6 | class GenerativeModelBase(ABC, torch.nn.Module): 7 | """ 8 | Base class for generative models that can be used as inputs in :func:`calculate_metrics`. 9 | """ 10 | 11 | @property 12 | @abstractmethod 13 | def z_size(self): 14 | """ 15 | Size of the noise dimension of the generative model (positive integer). 16 | """ 17 | pass 18 | 19 | @property 20 | @abstractmethod 21 | def z_type(self): 22 | """ 23 | Type of the noise used by the generative model (see :ref:`registry ` for a list of preregistered noise 24 | types, see :func:`register_noise_source` for registering a new noise type). 25 | """ 26 | pass 27 | 28 | @property 29 | @abstractmethod 30 | def num_classes(self): 31 | """ 32 | Number of classes used by a conditional generative model. Must return zero for unconditional models. 
33 | """ 34 | pass 35 | -------------------------------------------------------------------------------- /metrics/helpers.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | 4 | from metrics.defaults import DEFAULTS 5 | 6 | 7 | def vassert(truecond, message): 8 | if not truecond: 9 | raise ValueError(message) 10 | 11 | 12 | def vprint(verbose, message): 13 | if verbose: 14 | print(message, file=sys.stderr) 15 | 16 | 17 | def get_kwarg(name, kwargs): 18 | return kwargs.get(name, DEFAULTS[name]) 19 | 20 | 21 | def json_decode_string(s): 22 | try: 23 | out = json.loads(s) 24 | except json.JSONDecodeError as e: 25 | print(f'Failed to decode JSON string: {s}', file=sys.stderr) 26 | raise 27 | return out 28 | -------------------------------------------------------------------------------- /metrics/interpolate_compat_tensorflow.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.nn.modules.utils import _ntuple 6 | 7 | 8 | def interpolate_bilinear_2d_like_tensorflow1x(input, size=None, scale_factor=None, align_corners=None, method='slow'): 9 | r"""Down/up samples the input to either the given :attr:`size` or the given :attr:`scale_factor` 10 | 11 | Epsilon-exact bilinear interpolation as it is implemented in TensorFlow 1.x: 12 | https://github.com/tensorflow/tensorflow/blob/f66daa493e7383052b2b44def2933f61faf196e0/tensorflow/core/kernels/image_resizer_state.h#L41 13 | https://github.com/tensorflow/tensorflow/blob/6795a8c3a3678fb805b6a8ba806af77ddfe61628/tensorflow/core/kernels/resize_bilinear_op.cc#L85 14 | as per proposal: 15 | https://github.com/pytorch/pytorch/issues/10604#issuecomment-465783319 16 | 17 | Related materials: 18 | https://hackernoon.com/how-tensorflows-tf-image-resize-stole-60-days-of-my-life-aba5eb093f35 19 | https://jricheimer.github.io/tensorflow/2019/02/11/resize-confusion/ 20 | https://machinethink.net/blog/coreml-upsampling/ 21 | 22 | Currently only 2D spatial sampling is supported, i.e. expected inputs are 4-D in shape. 23 | 24 | The input dimensions are interpreted in the form: 25 | `mini-batch x channels x height x width`. 26 | 27 | Args: 28 | input (Tensor): the input tensor 29 | size (Tuple[int, int]): output spatial size. 30 | scale_factor (float or Tuple[float]): multiplier for spatial size. Has to match input size if it is a tuple. 31 | align_corners (bool, optional): Same meaning as in TensorFlow 1.x. 
32 | method (str, optional): 33 | 'slow' (1e-4 L_inf error on GPU, bit-exact on CPU, with checkerboard 32x32->299x299), or 34 | 'fast' (1e-3 L_inf error on GPU and CPU, with checkerboard 32x32->299x299) 35 | """ 36 | if method not in ('slow', 'fast'): 37 | raise ValueError('how_exact can only be one of "slow", "fast"') 38 | 39 | if input.dim() != 4: 40 | raise ValueError('input must be a 4-D tensor') 41 | 42 | if not torch.is_floating_point(input): 43 | raise ValueError('input must be of floating point dtype') 44 | 45 | if size is not None and (type(size) not in (tuple, list) or len(size) != 2): 46 | raise ValueError('size must be a list or a tuple of two elements') 47 | 48 | if align_corners is None: 49 | raise ValueError('align_corners is not specified (use this function for a complete determinism)') 50 | 51 | def _check_size_scale_factor(dim): 52 | if size is None and scale_factor is None: 53 | raise ValueError('either size or scale_factor should be defined') 54 | if size is not None and scale_factor is not None: 55 | raise ValueError('only one of size or scale_factor should be defined') 56 | if scale_factor is not None and isinstance(scale_factor, tuple) and len(scale_factor) != dim: 57 | raise ValueError('scale_factor shape must match input shape. ' 58 | 'Input is {}D, scale_factor size is {}'.format(dim, len(scale_factor))) 59 | 60 | is_tracing = torch._C._get_tracing_state() 61 | 62 | def _output_size(dim): 63 | _check_size_scale_factor(dim) 64 | if size is not None: 65 | if is_tracing: 66 | return [torch.tensor(i) for i in size] 67 | else: 68 | return size 69 | scale_factors = _ntuple(dim)(scale_factor) 70 | # math.floor might return float in py2.7 71 | 72 | # make scale_factor a tensor in tracing so constant doesn't get baked in 73 | if is_tracing: 74 | return [ 75 | (torch.floor((input.size(i + 2).float() * torch.tensor(scale_factors[i], dtype=torch.float32)).float())) 76 | for i in range(dim) 77 | ] 78 | else: 79 | return [int(math.floor(float(input.size(i + 2)) * scale_factors[i])) for i in range(dim)] 80 | 81 | def tf_calculate_resize_scale(in_size, out_size): 82 | if align_corners: 83 | if is_tracing: 84 | return (in_size - 1) / (out_size.float() - 1).clamp(min=1) 85 | else: 86 | return (in_size - 1) / max(1, out_size - 1) 87 | else: 88 | if is_tracing: 89 | return in_size / out_size.float() 90 | else: 91 | return in_size / out_size 92 | 93 | out_size = _output_size(2) 94 | scale_x = tf_calculate_resize_scale(input.shape[3], out_size[1]) 95 | scale_y = tf_calculate_resize_scale(input.shape[2], out_size[0]) 96 | 97 | def resample_using_grid_sample(): 98 | grid_x = torch.arange(0, out_size[1], 1, dtype=input.dtype, device=input.device) 99 | grid_x = grid_x * (2 * scale_x / (input.shape[3] - 1)) - 1 100 | 101 | grid_y = torch.arange(0, out_size[0], 1, dtype=input.dtype, device=input.device) 102 | grid_y = grid_y * (2 * scale_y / (input.shape[2] - 1)) - 1 103 | 104 | grid_x = grid_x.view(1, out_size[1]).repeat(out_size[0], 1) 105 | grid_y = grid_y.view(out_size[0], 1).repeat(1, out_size[1]) 106 | 107 | grid_xy = torch.cat((grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)), dim=2).unsqueeze(0) 108 | grid_xy = grid_xy.repeat(input.shape[0], 1, 1, 1) 109 | 110 | out = F.grid_sample(input, grid_xy, mode='bilinear', padding_mode='border', align_corners=True) 111 | return out 112 | 113 | def resample_manually(): 114 | grid_x = torch.arange(0, out_size[1], 1, dtype=input.dtype, device=input.device) 115 | grid_x = grid_x * torch.tensor(scale_x, dtype=torch.float32) 116 | grid_x_lo = 
grid_x.long() 117 | grid_x_hi = (grid_x_lo + 1).clamp_max(input.shape[3] - 1) 118 | grid_dx = grid_x - grid_x_lo.float() 119 | 120 | grid_y = torch.arange(0, out_size[0], 1, dtype=input.dtype, device=input.device) 121 | grid_y = grid_y * torch.tensor(scale_y, dtype=torch.float32) 122 | grid_y_lo = grid_y.long() 123 | grid_y_hi = (grid_y_lo + 1).clamp_max(input.shape[2] - 1) 124 | grid_dy = grid_y - grid_y_lo.float() 125 | 126 | # could be improved with index_select 127 | in_00 = input[:, :, grid_y_lo, :][:, :, :, grid_x_lo] 128 | in_01 = input[:, :, grid_y_lo, :][:, :, :, grid_x_hi] 129 | in_10 = input[:, :, grid_y_hi, :][:, :, :, grid_x_lo] 130 | in_11 = input[:, :, grid_y_hi, :][:, :, :, grid_x_hi] 131 | 132 | in_0 = in_00 + (in_01 - in_00) * grid_dx.view(1, 1, 1, out_size[1]) 133 | in_1 = in_10 + (in_11 - in_10) * grid_dx.view(1, 1, 1, out_size[1]) 134 | out = in_0 + (in_1 - in_0) * grid_dy.view(1, 1, out_size[0], 1) 135 | 136 | return out 137 | 138 | if method == 'slow': 139 | out = resample_manually() 140 | else: 141 | out = resample_using_grid_sample() 142 | 143 | return out 144 | -------------------------------------------------------------------------------- /metrics/metric_fid.py: -------------------------------------------------------------------------------- 1 | # Functions fid_features_to_statistics and fid_statistics_to_metric are adapted from 2 | # https://github.com/bioinf-jku/TTUR/blob/master/fid.py commit id d4baae8 3 | # Distributed under Apache License 2.0: https://github.com/bioinf-jku/TTUR/blob/master/LICENSE 4 | 5 | import numpy as np 6 | import scipy.linalg 7 | import torch 8 | 9 | from metrics.helpers import get_kwarg, vprint 10 | from metrics.utils import get_cacheable_input_name, cache_lookup_one_recompute_on_miss, \ 11 | extract_featuresdict_from_input_id_cached, create_feature_extractor 12 | 13 | KEY_METRIC_FID = 'frechet_inception_distance' 14 | 15 | 16 | def fid_features_to_statistics(features): 17 | assert torch.is_tensor(features) and features.dim() == 2 18 | features = features.numpy() 19 | mu = np.mean(features, axis=0) 20 | sigma = np.cov(features, rowvar=False) 21 | return { 22 | 'mu': mu, 23 | 'sigma': sigma, 24 | } 25 | 26 | 27 | def fid_statistics_to_metric(stat_1, stat_2, verbose): 28 | eps = 1e-6 29 | 30 | mu1, sigma1 = stat_1['mu'], stat_1['sigma'] 31 | mu2, sigma2 = stat_2['mu'], stat_2['sigma'] 32 | assert mu1.shape == mu2.shape and mu1.dtype == mu2.dtype 33 | assert sigma1.shape == sigma2.shape and sigma1.dtype == sigma2.dtype 34 | 35 | mu1 = np.atleast_1d(mu1) 36 | mu2 = np.atleast_1d(mu2) 37 | 38 | sigma1 = np.atleast_2d(sigma1) 39 | sigma2 = np.atleast_2d(sigma2) 40 | 41 | assert mu1.shape == mu2.shape, 'Training and test mean vectors have different lengths' 42 | assert sigma1.shape == sigma2.shape, 'Training and test covariances have different dimensions' 43 | 44 | diff = mu1 - mu2 45 | 46 | # Product might be almost singular 47 | covmean, _ = scipy.linalg.sqrtm(sigma1.dot(sigma2), disp=False) 48 | if not np.isfinite(covmean).all(): 49 | vprint(verbose, 50 | f'WARNING: fid calculation produces singular product; ' 51 | f'adding {eps} to diagonal of cov estimates' 52 | ) 53 | offset = np.eye(sigma1.shape[0]) * eps 54 | covmean = scipy.linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset), disp=verbose) 55 | 56 | # Numerical error might give slight imaginary component 57 | if np.iscomplexobj(covmean): 58 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 59 | m = np.max(np.abs(covmean.imag)) 60 | assert False, 'Imaginary component 
{}'.format(m) 61 | covmean = covmean.real 62 | 63 | tr_covmean = np.trace(covmean) 64 | 65 | out = { 66 | KEY_METRIC_FID: float(diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean) 67 | } 68 | 69 | vprint(verbose, f'Frechet Inception Distance: {out[KEY_METRIC_FID]}') 70 | 71 | return out 72 | 73 | 74 | def fid_featuresdict_to_statistics(featuresdict, feat_layer_name): 75 | features = featuresdict[feat_layer_name] 76 | statistics = fid_features_to_statistics(features) 77 | return statistics 78 | 79 | 80 | def fid_featuresdict_to_statistics_cached( 81 | featuresdict, cacheable_input_name, feat_extractor, feat_layer_name, **kwargs 82 | ): 83 | 84 | def fn_recompute(): 85 | return fid_featuresdict_to_statistics(featuresdict, feat_layer_name) 86 | 87 | if cacheable_input_name is not None: 88 | feat_extractor_name = feat_extractor.get_name() 89 | cached_name = f'{cacheable_input_name}-{feat_extractor_name}-stat-fid-{feat_layer_name}' 90 | stat = cache_lookup_one_recompute_on_miss(cached_name, fn_recompute, **kwargs) 91 | else: 92 | stat = fn_recompute() 93 | return stat 94 | 95 | 96 | def fid_input_id_to_statistics(input_id, feat_extractor, feat_layer_name, **kwargs): 97 | featuresdict = extract_featuresdict_from_input_id_cached(input_id, feat_extractor, **kwargs) 98 | return fid_featuresdict_to_statistics(featuresdict, feat_layer_name) 99 | 100 | 101 | def fid_input_id_to_statistics_cached(input_id, feat_extractor, feat_layer_name, **kwargs): 102 | 103 | def fn_recompute(): 104 | return fid_input_id_to_statistics(input_id, feat_extractor, feat_layer_name, **kwargs) 105 | 106 | cacheable_input_name = get_cacheable_input_name(input_id, **kwargs) 107 | 108 | if cacheable_input_name is not None: 109 | feat_extractor_name = feat_extractor.get_name() 110 | cached_name = f'{cacheable_input_name}-{feat_extractor_name}-stat-fid-{feat_layer_name}' 111 | stat = cache_lookup_one_recompute_on_miss(cached_name, fn_recompute, **kwargs) 112 | else: 113 | stat = fn_recompute() 114 | return stat 115 | 116 | 117 | def fid_inputs_to_metric(feat_extractor, **kwargs): 118 | feat_layer_name = get_kwarg('feature_layer_fid', kwargs) 119 | verbose = get_kwarg('verbose', kwargs) 120 | 121 | vprint(verbose, f'Extracting statistics from input 1') 122 | stats_1 = fid_input_id_to_statistics_cached(1, feat_extractor, feat_layer_name, **kwargs) 123 | 124 | vprint(verbose, f'Extracting statistics from input 2') 125 | stats_2 = fid_input_id_to_statistics_cached(2, feat_extractor, feat_layer_name, **kwargs) 126 | 127 | metric = fid_statistics_to_metric(stats_1, stats_2, get_kwarg('verbose', kwargs)) 128 | return metric 129 | 130 | 131 | def calculate_fid(**kwargs): 132 | feature_extractor = get_kwarg('feature_extractor', kwargs) 133 | feat_layer_name = get_kwarg('feature_layer_fid', kwargs) 134 | feat_extractor = create_feature_extractor(feature_extractor, [feat_layer_name], **kwargs) 135 | metric = fid_inputs_to_metric(feat_extractor, **kwargs) 136 | return metric 137 | -------------------------------------------------------------------------------- /metrics/metric_isc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from metrics.helpers import get_kwarg, vprint 5 | from metrics.utils import extract_featuresdict_from_input_id_cached, create_feature_extractor 6 | 7 | KEY_METRIC_ISC_MEAN = 'inception_score_mean' 8 | KEY_METRIC_ISC_STD = 'inception_score_std' 9 | 10 | 11 | def isc_features_to_metric(feature, splits=10, 
shuffle=True, rng_seed=2020): 12 | assert torch.is_tensor(feature) and feature.dim() == 2 13 | N, C = feature.shape 14 | if shuffle: 15 | rng = np.random.RandomState(rng_seed) 16 | feature = feature[rng.permutation(N), :] 17 | feature = feature.double() 18 | 19 | p = feature.softmax(dim=1) 20 | log_p = feature.log_softmax(dim=1) 21 | 22 | scores = [] 23 | for i in range(splits): 24 | p_chunk = p[(i * N // splits): ((i + 1) * N // splits), :] 25 | log_p_chunk = log_p[(i * N // splits): ((i + 1) * N // splits), :] 26 | q_chunk = p_chunk.mean(dim=0, keepdim=True) 27 | kl = p_chunk * (log_p_chunk - q_chunk.log()) 28 | kl = kl.sum(dim=1).mean().exp().item() 29 | scores.append(kl) 30 | 31 | return { 32 | KEY_METRIC_ISC_MEAN: float(np.mean(scores)), 33 | KEY_METRIC_ISC_STD: float(np.std(scores)), 34 | } 35 | 36 | 37 | def isc_featuresdict_to_metric(featuresdict, feat_layer_name, **kwargs): 38 | features = featuresdict[feat_layer_name] 39 | 40 | out = isc_features_to_metric( 41 | features, 42 | get_kwarg('isc_splits', kwargs), 43 | get_kwarg('samples_shuffle', kwargs), 44 | get_kwarg('rng_seed', kwargs), 45 | ) 46 | 47 | vprint(get_kwarg('verbose', kwargs), f'Inception Score: {out[KEY_METRIC_ISC_MEAN]} ± {out[KEY_METRIC_ISC_STD]}') 48 | 49 | return out 50 | 51 | 52 | def isc_input_id_to_metric(input_id, feat_extractor, feat_layer_name, **kwargs): 53 | featuresdict = extract_featuresdict_from_input_id_cached(input_id, feat_extractor, **kwargs) 54 | return isc_featuresdict_to_metric(featuresdict, feat_layer_name, **kwargs) 55 | 56 | 57 | def calculate_isc(input_id, **kwargs): 58 | feature_extractor = get_kwarg('feature_extractor', kwargs) 59 | feat_layer_name = get_kwarg('feature_layer_isc', kwargs) 60 | feat_extractor = create_feature_extractor(feature_extractor, [feat_layer_name], **kwargs) 61 | metric = isc_input_id_to_metric(input_id, feat_extractor, feat_layer_name, **kwargs) 62 | return metric 63 | -------------------------------------------------------------------------------- /metrics/metric_kid.py: -------------------------------------------------------------------------------- 1 | # Functions mmd2 and polynomial_kernel are adapted from 2 | # https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py 3 | # Distributed under BSD 3-Clause: https://github.com/mbinkowski/MMD-GAN/blob/master/LICENSE 4 | 5 | import numpy as np 6 | import torch 7 | from tqdm import tqdm 8 | 9 | from metrics.helpers import get_kwarg, vassert, vprint 10 | from metrics.utils import create_feature_extractor, extract_featuresdict_from_input_id_cached 11 | 12 | KEY_METRIC_KID_MEAN = 'kernel_inception_distance_mean' 13 | KEY_METRIC_KID_STD = 'kernel_inception_distance_std' 14 | 15 | 16 | def mmd2(K_XX, K_XY, K_YY, unit_diagonal=False, mmd_est='unbiased'): 17 | vassert(mmd_est in ('biased', 'unbiased', 'u-statistic'), 'Invalid value of mmd_est') 18 | 19 | m = K_XX.shape[0] 20 | assert K_XX.shape == (m, m) 21 | assert K_XY.shape == (m, m) 22 | assert K_YY.shape == (m, m) 23 | 24 | # Get the various sums of kernels that we'll use 25 | # Kts drop the diagonal, but we don't need to compute them explicitly 26 | if unit_diagonal: 27 | diag_X = diag_Y = 1 28 | sum_diag_X = sum_diag_Y = m 29 | else: 30 | diag_X = np.diagonal(K_XX) 31 | diag_Y = np.diagonal(K_YY) 32 | 33 | sum_diag_X = diag_X.sum() 34 | sum_diag_Y = diag_Y.sum() 35 | 36 | Kt_XX_sums = K_XX.sum(axis=1) - diag_X 37 | Kt_YY_sums = K_YY.sum(axis=1) - diag_Y 38 | K_XY_sums_0 = K_XY.sum(axis=0) 39 | 40 | Kt_XX_sum = Kt_XX_sums.sum() 41 | Kt_YY_sum = 
Kt_YY_sums.sum() 42 | K_XY_sum = K_XY_sums_0.sum() 43 | 44 | if mmd_est == 'biased': 45 | mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * m) 46 | + (Kt_YY_sum + sum_diag_Y) / (m * m) 47 | - 2 * K_XY_sum / (m * m)) 48 | else: 49 | mmd2 = (Kt_XX_sum + Kt_YY_sum) / (m * (m-1)) 50 | if mmd_est == 'unbiased': 51 | mmd2 -= 2 * K_XY_sum / (m * m) 52 | else: 53 | mmd2 -= 2 * (K_XY_sum - np.trace(K_XY)) / (m * (m-1)) 54 | 55 | return mmd2 56 | 57 | 58 | def polynomial_kernel(X, Y, degree=3, gamma=None, coef0=1): 59 | if gamma is None: 60 | gamma = 1.0 / X.shape[1] 61 | K = (np.matmul(X, Y.T) * gamma + coef0) ** degree 62 | return K 63 | 64 | 65 | def polynomial_mmd(features_1, features_2, degree, gamma, coef0): 66 | k_11 = polynomial_kernel(features_1, features_1, degree=degree, gamma=gamma, coef0=coef0) 67 | k_22 = polynomial_kernel(features_2, features_2, degree=degree, gamma=gamma, coef0=coef0) 68 | k_12 = polynomial_kernel(features_1, features_2, degree=degree, gamma=gamma, coef0=coef0) 69 | return mmd2(k_11, k_12, k_22) 70 | 71 | 72 | def kid_features_to_metric(features_1, features_2, **kwargs): 73 | assert torch.is_tensor(features_1) and features_1.dim() == 2 74 | assert torch.is_tensor(features_2) and features_2.dim() == 2 75 | assert features_1.shape[1] == features_2.shape[1] 76 | 77 | kid_subsets = get_kwarg('kid_subsets', kwargs) 78 | kid_subset_size = get_kwarg('kid_subset_size', kwargs) 79 | verbose = get_kwarg('verbose', kwargs) 80 | 81 | n_samples_1, n_samples_2 = len(features_1), len(features_2) 82 | vassert( 83 | n_samples_1 >= kid_subset_size and n_samples_2 >= kid_subset_size, 84 | f'KID subset size {kid_subset_size} cannot be larger than the number of samples (input_1: {n_samples_1}, ' 85 | f'input_2: {n_samples_2}). Consider using "kid_subset_size" kwarg or "--kid-subset-size" command line key to ' 86 | f'proceed.'
87 | ) 88 | 89 | features_1 = features_1.cpu().numpy() 90 | features_2 = features_2.cpu().numpy() 91 | 92 | mmds = np.zeros(kid_subsets) 93 | rng = np.random.RandomState(get_kwarg('rng_seed', kwargs)) 94 | 95 | for i in tqdm( 96 | range(kid_subsets), disable=not verbose, leave=False, unit='subsets', 97 | desc='Kernel Inception Distance' 98 | ): 99 | f1 = features_1[rng.choice(n_samples_1, kid_subset_size, replace=False)] 100 | f2 = features_2[rng.choice(n_samples_2, kid_subset_size, replace=False)] 101 | o = polynomial_mmd( 102 | f1, 103 | f2, 104 | get_kwarg('kid_degree', kwargs), 105 | get_kwarg('kid_gamma', kwargs), 106 | get_kwarg('kid_coef0', kwargs), 107 | ) 108 | mmds[i] = o 109 | 110 | out = { 111 | KEY_METRIC_KID_MEAN: float(np.mean(mmds)), 112 | KEY_METRIC_KID_STD: float(np.std(mmds)), 113 | } 114 | 115 | vprint(verbose, f'Kernel Inception Distance: {out[KEY_METRIC_KID_MEAN]} ± {out[KEY_METRIC_KID_STD]}') 116 | 117 | return out 118 | 119 | 120 | def kid_featuresdict_to_metric(featuresdict_1, featuresdict_2, feat_layer_name, **kwargs): 121 | features_1 = featuresdict_1[feat_layer_name] 122 | features_2 = featuresdict_2[feat_layer_name] 123 | metric = kid_features_to_metric(features_1, features_2, **kwargs) 124 | return metric 125 | 126 | 127 | def calculate_kid(**kwargs): 128 | feature_extractor = get_kwarg('feature_extractor', kwargs) 129 | feat_layer_name = get_kwarg('feature_layer_kid', kwargs) 130 | feat_extractor = create_feature_extractor(feature_extractor, [feat_layer_name], **kwargs) 131 | featuresdict_1 = extract_featuresdict_from_input_id_cached(1, feat_extractor, **kwargs) 132 | featuresdict_2 = extract_featuresdict_from_input_id_cached(2, feat_extractor, **kwargs) 133 | metric = kid_featuresdict_to_metric(featuresdict_1, featuresdict_2, feat_layer_name, **kwargs) 134 | return metric 135 | -------------------------------------------------------------------------------- /metrics/metric_ppl.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tqdm import tqdm 4 | 5 | from metrics.generative_model_base import GenerativeModelBase 6 | from metrics.helpers import get_kwarg, vassert, vprint 7 | from metrics.utils import sample_random, batch_interp, create_sample_similarity, \ 8 | prepare_input_descriptor_from_input_id, prepare_input_from_descriptor 9 | 10 | KEY_METRIC_PPL_RAW = 'perceptual_path_length_raw' 11 | KEY_METRIC_PPL_MEAN = 'perceptual_path_length_mean' 12 | KEY_METRIC_PPL_STD = 'perceptual_path_length_std' 13 | 14 | 15 | def calculate_ppl(input_id, **kwargs): 16 | """ 17 | Inspired by https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py 18 | """ 19 | batch_size = get_kwarg('batch_size', kwargs) 20 | is_cuda = get_kwarg('cuda', kwargs) 21 | verbose = get_kwarg('verbose', kwargs) 22 | epsilon = get_kwarg('ppl_epsilon', kwargs) 23 | interp = get_kwarg('ppl_z_interp_mode', kwargs) 24 | reduction = get_kwarg('ppl_reduction', kwargs) 25 | similarity_name = get_kwarg('ppl_sample_similarity', kwargs) 26 | sample_similarity_resize = get_kwarg('ppl_sample_similarity_resize', kwargs) 27 | sample_similarity_dtype = get_kwarg('ppl_sample_similarity_dtype', kwargs) 28 | discard_percentile_lower = get_kwarg('ppl_discard_percentile_lower', kwargs) 29 | discard_percentile_higher = get_kwarg('ppl_discard_percentile_higher', kwargs) 30 | 31 | input_desc = prepare_input_descriptor_from_input_id(input_id, **kwargs) 32 | model = prepare_input_from_descriptor(input_desc, **kwargs) 
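# Descriptive note (added comment): the prepared input must wrap a generative model. In the loop below,
# each latent is paired with a second latent an `epsilon` step away along the chosen interpolation
# (batch_interp), both are decoded in a single forward pass, and the sample-similarity distance between
# the two renders divided by epsilon ** 2 gives the per-sample path lengths collected in `distances`.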
33 | vassert( 34 | isinstance(model, GenerativeModelBase), 35 | 'Input needs to be an instance of GenerativeModelBase, which can be either passed programmatically by wrapping ' 36 | 'a model with GenerativeModelModuleWrapper, or via command line by specifying a path to a ONNX or PTH (JIT) ' 37 | 'model and a set of input1_model_* arguments' 38 | ) 39 | 40 | if is_cuda: 41 | model.cuda() 42 | 43 | input_model_num_samples = input_desc['input_model_num_samples'] 44 | input_model_num_classes = model.num_classes 45 | input_model_z_size = model.z_size 46 | input_model_z_type = model.z_type 47 | 48 | vassert(input_model_num_classes >= 0, 'Model can be unconditional (0 classes) or conditional (positive)') 49 | vassert(type(input_model_z_size) is int and input_model_z_size > 0, 50 | 'Dimensionality of generator noise not specified ("input1_model_z_size" argument)') 51 | vassert(type(epsilon) is float and epsilon > 0, 'Epsilon must be a small positive floating point number') 52 | vassert(type(input_model_num_samples) is int and input_model_num_samples > 0, 'Number of samples must be positive') 53 | vassert(reduction in ('none', 'mean'), 'Reduction must be one of [none, mean]') 54 | vassert(discard_percentile_lower is None or 0 < discard_percentile_lower < 100, 'Invalid percentile') 55 | vassert(discard_percentile_higher is None or 0 < discard_percentile_higher < 100, 'Invalid percentile') 56 | if discard_percentile_lower is not None and discard_percentile_higher is not None: 57 | vassert(0 < discard_percentile_lower < discard_percentile_higher < 100, 'Invalid percentiles') 58 | 59 | sample_similarity = create_sample_similarity( 60 | similarity_name, 61 | sample_similarity_resize=sample_similarity_resize, 62 | sample_similarity_dtype=sample_similarity_dtype, 63 | **kwargs 64 | ) 65 | 66 | is_cond = input_desc['input_model_num_classes'] > 0 67 | 68 | rng = np.random.RandomState(get_kwarg('rng_seed', kwargs)) 69 | 70 | lat_e0 = sample_random(rng, (input_model_num_samples, input_model_z_size), input_model_z_type) 71 | lat_e1 = sample_random(rng, (input_model_num_samples, input_model_z_size), input_model_z_type) 72 | lat_e1 = batch_interp(lat_e0, lat_e1, epsilon, interp) 73 | 74 | labels = None 75 | if is_cond: 76 | labels = torch.from_numpy(rng.randint(0, input_model_num_classes, (input_model_num_samples,))) 77 | 78 | distances = [] 79 | 80 | with tqdm(disable=not verbose, leave=False, unit='samples', total=input_model_num_samples, 81 | desc='Perceptual Path Length') as t, torch.no_grad(): 82 | for begin_id in range(0, input_model_num_samples, batch_size): 83 | end_id = min(begin_id + batch_size, input_model_num_samples) 84 | batch_sz = end_id - begin_id 85 | 86 | batch_lat_e0 = lat_e0[begin_id:end_id] 87 | batch_lat_e1 = lat_e1[begin_id:end_id] 88 | if is_cond: 89 | batch_labels = labels[begin_id:end_id] 90 | 91 | if is_cuda: 92 | batch_lat_e0 = batch_lat_e0.cuda(non_blocking=True) 93 | batch_lat_e1 = batch_lat_e1.cuda(non_blocking=True) 94 | if is_cond: 95 | batch_labels = batch_labels.cuda(non_blocking=True) 96 | 97 | if is_cond: 98 | rgb_e01 = model.forward( 99 | torch.cat((batch_lat_e0, batch_lat_e1), dim=0), 100 | torch.cat((batch_labels, batch_labels), dim=0) 101 | ) 102 | else: 103 | rgb_e01 = model.forward( 104 | torch.cat((batch_lat_e0, batch_lat_e1), dim=0) 105 | ) 106 | rgb_e0, rgb_e1 = rgb_e01.chunk(2) 107 | 108 | sim = sample_similarity(rgb_e0, rgb_e1) 109 | dist_lat_e01 = sim / (epsilon ** 2) 110 | distances.append(dist_lat_e01.cpu().numpy()) 111 | 112 | t.update(batch_sz) 113 | 114 | 
distances = np.concatenate(distances, axis=0) 115 | 116 | cond, lo, hi = None, None, None 117 | if discard_percentile_lower is not None: 118 | lo = np.percentile(distances, discard_percentile_lower, interpolation='lower') 119 | cond = lo <= distances 120 | if discard_percentile_higher is not None: 121 | hi = np.percentile(distances, discard_percentile_higher, interpolation='higher') 122 | cond = np.logical_and(cond, distances <= hi) 123 | if cond is not None: 124 | distances = np.extract(cond, distances) 125 | 126 | out = { 127 | KEY_METRIC_PPL_MEAN: float(np.mean(distances)), 128 | KEY_METRIC_PPL_STD: float(np.std(distances)) 129 | } 130 | if reduction == 'none': 131 | out[KEY_METRIC_PPL_RAW] = distances 132 | 133 | vprint(verbose, f'Perceptual Path Length: {out[KEY_METRIC_PPL_MEAN]} ± {out[KEY_METRIC_PPL_STD]}') 134 | 135 | return out 136 | -------------------------------------------------------------------------------- /metrics/noise.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def batch_normalize_last_dim(v, eps=1e-7): 5 | return v / (v ** 2).sum(dim=-1, keepdim=True).sqrt().clamp_min(eps) 6 | 7 | 8 | def random_normal(rng, shape): 9 | return torch.from_numpy(rng.randn(*shape)).float() 10 | 11 | 12 | def random_unit(rng, shape): 13 | return batch_normalize_last_dim(torch.from_numpy(rng.rand(*shape)).float()) 14 | 15 | 16 | def random_uniform_0_1(rng, shape): 17 | return torch.from_numpy(rng.rand(*shape)).float() 18 | 19 | 20 | def batch_lerp(a, b, t): 21 | return a + (b - a) * t 22 | 23 | 24 | def batch_slerp_any(a, b, t, eps=1e-7): 25 | assert torch.is_tensor(a) and torch.is_tensor(b) and a.dim() >= 2 and a.shape == b.shape 26 | ndims, N = a.dim() - 1, a.shape[-1] 27 | a_1 = batch_normalize_last_dim(a, eps) 28 | b_1 = batch_normalize_last_dim(b, eps) 29 | d = (a_1 * b_1).sum(dim=-1, keepdim=True) 30 | mask_zero = (a_1.norm(dim=-1, keepdim=True) < eps) | (b_1.norm(dim=-1, keepdim=True) < eps) 31 | mask_collinear = (d > 1 - eps) | (d < -1 + eps) 32 | mask_lerp = (mask_zero | mask_collinear).repeat([1 for _ in range(ndims)] + [N]) 33 | omega = d.acos() 34 | denom = omega.sin().clamp_min(eps) 35 | coef_a = ((1 - t) * omega).sin() / denom 36 | coef_b = (t * omega).sin() / denom 37 | out = coef_a * a + coef_b * b 38 | out[mask_lerp] = batch_lerp(a, b, t)[mask_lerp] 39 | return out 40 | 41 | 42 | def batch_slerp_unit(a, b, t, eps=1e-7): 43 | out = batch_slerp_any(a, b, t, eps) 44 | out = batch_normalize_last_dim(out, eps) 45 | return out 46 | -------------------------------------------------------------------------------- /metrics/sample_similarity_base.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from metrics.helpers import vassert 4 | 5 | 6 | class SampleSimilarityBase(nn.Module): 7 | def __init__(self, name): 8 | """ 9 | Base class for samples similarity measures that can be used in :func:`calculate_metrics`. 10 | 11 | Args: 12 | 13 | name (str): Unique name of the subclassed sample similarity measure, must be the same as used in 14 | :func:`register_sample_similarity`. 15 | """ 16 | super(SampleSimilarityBase, self).__init__() 17 | vassert(type(name) is str, 'Sample similarity name must be a string') 18 | self.name = name 19 | 20 | def get_name(self): 21 | return self.name 22 | 23 | def forward(self, *args): 24 | """ 25 | Returns the value of sample similarity between the inputs. 
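For example (illustrative note), the LPIPS measure registered as 'lpips-vgg16' (see sample_similarity_lpips.py) returns one distance per pair of same-shaped image batches; calculate_ppl in metric_ppl.py divides these distances by epsilon ** 2 to obtain perceptual path lengths. The exact return shape is defined by the subclass.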
26 | """ 27 | raise NotImplementedError 28 | -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | #torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113 2 | accelerate==0.16.0 3 | einops 4 | ema-pytorch 5 | pytorch-lightning==1.9.3 6 | scikit-learn 7 | scipy 8 | thop 9 | timm==0.6.12 10 | tensorboard 11 | fvcore 12 | albumentations 13 | omegaconf 14 | numpy==1.23.5 -------------------------------------------------------------------------------- /taming/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/taming/__init__.py -------------------------------------------------------------------------------- /taming/data/ade20k.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import albumentations 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | 8 | from taming.data.sflckr import SegmentationBase # for examples included in repo 9 | 10 | 11 | class Examples(SegmentationBase): 12 | def __init__(self, size=256, random_crop=False, interpolation="bicubic"): 13 | super().__init__(data_csv="data/ade20k_examples.txt", 14 | data_root="data/ade20k_images", 15 | segmentation_root="data/ade20k_segmentations", 16 | size=size, random_crop=random_crop, 17 | interpolation=interpolation, 18 | n_labels=151, shift_segmentation=False) 19 | 20 | 21 | # With semantic map and scene label 22 | class ADE20kBase(Dataset): 23 | def __init__(self, config=None, size=None, random_crop=False, interpolation="bicubic", crop_size=None): 24 | self.split = self.get_split() 25 | self.n_labels = 151 # unknown + 150 26 | self.data_csv = {"train": "data/ade20k_train.txt", 27 | "validation": "data/ade20k_test.txt"}[self.split] 28 | self.data_root = "data/ade20k_root" 29 | with open(os.path.join(self.data_root, "sceneCategories.txt"), "r") as f: 30 | self.scene_categories = f.read().splitlines() 31 | self.scene_categories = dict(line.split() for line in self.scene_categories) 32 | with open(self.data_csv, "r") as f: 33 | self.image_paths = f.read().splitlines() 34 | self._length = len(self.image_paths) 35 | self.labels = { 36 | "relative_file_path_": [l for l in self.image_paths], 37 | "file_path_": [os.path.join(self.data_root, "images", l) 38 | for l in self.image_paths], 39 | "relative_segmentation_path_": [l.replace(".jpg", ".png") 40 | for l in self.image_paths], 41 | "segmentation_path_": [os.path.join(self.data_root, "annotations", 42 | l.replace(".jpg", ".png")) 43 | for l in self.image_paths], 44 | "scene_category": [self.scene_categories[l.split("/")[1].replace(".jpg", "")] 45 | for l in self.image_paths], 46 | } 47 | 48 | size = None if size is not None and size<=0 else size 49 | self.size = size 50 | if crop_size is None: 51 | self.crop_size = size if size is not None else None 52 | else: 53 | self.crop_size = crop_size 54 | if self.size is not None: 55 | self.interpolation = interpolation 56 | self.interpolation = { 57 | "nearest": cv2.INTER_NEAREST, 58 | "bilinear": cv2.INTER_LINEAR, 59 | "bicubic": cv2.INTER_CUBIC, 60 | "area": cv2.INTER_AREA, 61 | "lanczos": cv2.INTER_LANCZOS4}[self.interpolation] 62 | self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size, 63 
| interpolation=self.interpolation) 64 | self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size, 65 | interpolation=cv2.INTER_NEAREST) 66 | 67 | if crop_size is not None: 68 | self.center_crop = not random_crop 69 | if self.center_crop: 70 | self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size) 71 | else: 72 | self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size) 73 | self.preprocessor = self.cropper 74 | 75 | def __len__(self): 76 | return self._length 77 | 78 | def __getitem__(self, i): 79 | example = dict((k, self.labels[k][i]) for k in self.labels) 80 | image = Image.open(example["file_path_"]) 81 | if not image.mode == "RGB": 82 | image = image.convert("RGB") 83 | image = np.array(image).astype(np.uint8) 84 | if self.size is not None: 85 | image = self.image_rescaler(image=image)["image"] 86 | segmentation = Image.open(example["segmentation_path_"]) 87 | segmentation = np.array(segmentation).astype(np.uint8) 88 | if self.size is not None: 89 | segmentation = self.segmentation_rescaler(image=segmentation)["image"] 90 | if self.size is not None: 91 | processed = self.preprocessor(image=image, mask=segmentation) 92 | else: 93 | processed = {"image": image, "mask": segmentation} 94 | example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32) 95 | segmentation = processed["mask"] 96 | onehot = np.eye(self.n_labels)[segmentation] 97 | example["segmentation"] = onehot 98 | return example 99 | 100 | 101 | class ADE20kTrain(ADE20kBase): 102 | # default to random_crop=True 103 | def __init__(self, config=None, size=None, random_crop=True, interpolation="bicubic", crop_size=None): 104 | super().__init__(config=config, size=size, random_crop=random_crop, 105 | interpolation=interpolation, crop_size=crop_size) 106 | 107 | def get_split(self): 108 | return "train" 109 | 110 | 111 | class ADE20kValidation(ADE20kBase): 112 | def get_split(self): 113 | return "validation" 114 | 115 | 116 | if __name__ == "__main__": 117 | dset = ADE20kValidation() 118 | ex = dset[0] 119 | for k in ["image", "scene_category", "segmentation"]: 120 | print(type(ex[k])) 121 | try: 122 | print(ex[k].shape) 123 | except: 124 | print(ex[k]) 125 | -------------------------------------------------------------------------------- /taming/data/annotated_objects_coco.py: -------------------------------------------------------------------------------- 1 | import json 2 | from itertools import chain 3 | from pathlib import Path 4 | from typing import Iterable, Dict, List, Callable, Any 5 | from collections import defaultdict 6 | 7 | from tqdm import tqdm 8 | 9 | from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset 10 | from taming.data.helper_types import Annotation, ImageDescription, Category 11 | 12 | COCO_PATH_STRUCTURE = { 13 | 'train': { 14 | 'top_level': '', 15 | 'instances_annotations': 'annotations/instances_train2017.json', 16 | 'stuff_annotations': 'annotations/stuff_train2017.json', 17 | 'files': 'train2017' 18 | }, 19 | 'validation': { 20 | 'top_level': '', 21 | 'instances_annotations': 'annotations/instances_val2017.json', 22 | 'stuff_annotations': 'annotations/stuff_val2017.json', 23 | 'files': 'val2017' 24 | } 25 | } 26 | 27 | 28 | def load_image_descriptions(description_json: List[Dict]) -> Dict[str, ImageDescription]: 29 | return { 30 | str(img['id']): ImageDescription( 31 | id=img['id'], 32 | license=img.get('license'), 33 | file_name=img['file_name'], 34 | coco_url=img['coco_url'], 35 | 
original_size=(img['width'], img['height']), 36 | date_captured=img.get('date_captured'), 37 | flickr_url=img.get('flickr_url') 38 | ) 39 | for img in description_json 40 | } 41 | 42 | 43 | def load_categories(category_json: Iterable) -> Dict[str, Category]: 44 | return {str(cat['id']): Category(id=str(cat['id']), super_category=cat['supercategory'], name=cat['name']) 45 | for cat in category_json if cat['name'] != 'other'} 46 | 47 | 48 | def load_annotations(annotations_json: List[Dict], image_descriptions: Dict[str, ImageDescription], 49 | category_no_for_id: Callable[[str], int], split: str) -> Dict[str, List[Annotation]]: 50 | annotations = defaultdict(list) 51 | total = sum(len(a) for a in annotations_json) 52 | for ann in tqdm(chain(*annotations_json), f'Loading {split} annotations', total=total): 53 | image_id = str(ann['image_id']) 54 | if image_id not in image_descriptions: 55 | raise ValueError(f'image_id [{image_id}] has no image description.') 56 | category_id = ann['category_id'] 57 | try: 58 | category_no = category_no_for_id(str(category_id)) 59 | except KeyError: 60 | continue 61 | 62 | width, height = image_descriptions[image_id].original_size 63 | bbox = (ann['bbox'][0] / width, ann['bbox'][1] / height, ann['bbox'][2] / width, ann['bbox'][3] / height) 64 | 65 | annotations[image_id].append( 66 | Annotation( 67 | id=ann['id'], 68 | area=bbox[2]*bbox[3], # use bbox area 69 | is_group_of=ann['iscrowd'], 70 | image_id=ann['image_id'], 71 | bbox=bbox, 72 | category_id=str(category_id), 73 | category_no=category_no 74 | ) 75 | ) 76 | return dict(annotations) 77 | 78 | 79 | class AnnotatedObjectsCoco(AnnotatedObjectsDataset): 80 | def __init__(self, use_things: bool = True, use_stuff: bool = True, **kwargs): 81 | """ 82 | @param data_path: is the path to the following folder structure: 83 | coco/ 84 | ├── annotations 85 | │ ├── instances_train2017.json 86 | │ ├── instances_val2017.json 87 | │ ├── stuff_train2017.json 88 | │ └── stuff_val2017.json 89 | ├── train2017 90 | │ ├── 000000000009.jpg 91 | │ ├── 000000000025.jpg 92 | │ └── ... 93 | ├── val2017 94 | │ ├── 000000000139.jpg 95 | │ ├── 000000000285.jpg 96 | │ └── ... 
97 | @param: split: one of 'train' or 'validation' 98 | @param: desired image size (give square images) 99 | """ 100 | super().__init__(**kwargs) 101 | self.use_things = use_things 102 | self.use_stuff = use_stuff 103 | 104 | with open(self.paths['instances_annotations']) as f: 105 | inst_data_json = json.load(f) 106 | with open(self.paths['stuff_annotations']) as f: 107 | stuff_data_json = json.load(f) 108 | 109 | category_jsons = [] 110 | annotation_jsons = [] 111 | if self.use_things: 112 | category_jsons.append(inst_data_json['categories']) 113 | annotation_jsons.append(inst_data_json['annotations']) 114 | if self.use_stuff: 115 | category_jsons.append(stuff_data_json['categories']) 116 | annotation_jsons.append(stuff_data_json['annotations']) 117 | 118 | self.categories = load_categories(chain(*category_jsons)) 119 | self.filter_categories() 120 | self.setup_category_id_and_number() 121 | 122 | self.image_descriptions = load_image_descriptions(inst_data_json['images']) 123 | annotations = load_annotations(annotation_jsons, self.image_descriptions, self.get_category_number, self.split) 124 | self.annotations = self.filter_object_number(annotations, self.min_object_area, 125 | self.min_objects_per_image, self.max_objects_per_image) 126 | self.image_ids = list(self.annotations.keys()) 127 | self.clean_up_annotations_and_image_descriptions() 128 | 129 | def get_path_structure(self) -> Dict[str, str]: 130 | if self.split not in COCO_PATH_STRUCTURE: 131 | raise ValueError(f'Split [{self.split} does not exist for COCO data.]') 132 | return COCO_PATH_STRUCTURE[self.split] 133 | 134 | def get_image_path(self, image_id: str) -> Path: 135 | return self.paths['files'].joinpath(self.image_descriptions[str(image_id)].file_name) 136 | 137 | def get_image_description(self, image_id: str) -> Dict[str, Any]: 138 | # noinspection PyProtectedMember 139 | return self.image_descriptions[image_id]._asdict() 140 | -------------------------------------------------------------------------------- /taming/data/annotated_objects_open_images.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from csv import DictReader, reader as TupleReader 3 | from pathlib import Path 4 | from typing import Dict, List, Any 5 | import warnings 6 | 7 | from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset 8 | from taming.data.helper_types import Annotation, Category 9 | from tqdm import tqdm 10 | 11 | OPEN_IMAGES_STRUCTURE = { 12 | 'train': { 13 | 'top_level': '', 14 | 'class_descriptions': 'class-descriptions-boxable.csv', 15 | 'annotations': 'oidv6-train-annotations-bbox.csv', 16 | 'file_list': 'train-images-boxable.csv', 17 | 'files': 'train' 18 | }, 19 | 'validation': { 20 | 'top_level': '', 21 | 'class_descriptions': 'class-descriptions-boxable.csv', 22 | 'annotations': 'validation-annotations-bbox.csv', 23 | 'file_list': 'validation-images.csv', 24 | 'files': 'validation' 25 | }, 26 | 'test': { 27 | 'top_level': '', 28 | 'class_descriptions': 'class-descriptions-boxable.csv', 29 | 'annotations': 'test-annotations-bbox.csv', 30 | 'file_list': 'test-images.csv', 31 | 'files': 'test' 32 | } 33 | } 34 | 35 | 36 | def load_annotations(descriptor_path: Path, min_object_area: float, category_mapping: Dict[str, str], 37 | category_no_for_id: Dict[str, int]) -> Dict[str, List[Annotation]]: 38 | annotations: Dict[str, List[Annotation]] = defaultdict(list) 39 | with open(descriptor_path) as file: 40 | reader = DictReader(file) 41 | for i, 
row in tqdm(enumerate(reader), total=14620000, desc='Loading OpenImages annotations'): 42 | width = float(row['XMax']) - float(row['XMin']) 43 | height = float(row['YMax']) - float(row['YMin']) 44 | area = width * height 45 | category_id = row['LabelName'] 46 | if category_id in category_mapping: 47 | category_id = category_mapping[category_id] 48 | if area >= min_object_area and category_id in category_no_for_id: 49 | annotations[row['ImageID']].append( 50 | Annotation( 51 | id=i, 52 | image_id=row['ImageID'], 53 | source=row['Source'], 54 | category_id=category_id, 55 | category_no=category_no_for_id[category_id], 56 | confidence=float(row['Confidence']), 57 | bbox=(float(row['XMin']), float(row['YMin']), width, height), 58 | area=area, 59 | is_occluded=bool(int(row['IsOccluded'])), 60 | is_truncated=bool(int(row['IsTruncated'])), 61 | is_group_of=bool(int(row['IsGroupOf'])), 62 | is_depiction=bool(int(row['IsDepiction'])), 63 | is_inside=bool(int(row['IsInside'])) 64 | ) 65 | ) 66 | if 'train' in str(descriptor_path) and i < 14000000: 67 | warnings.warn(f'Running with subset of Open Images. Train dataset has length [{len(annotations)}].') 68 | return dict(annotations) 69 | 70 | 71 | def load_image_ids(csv_path: Path) -> List[str]: 72 | with open(csv_path) as file: 73 | reader = DictReader(file) 74 | return [row['image_name'] for row in reader] 75 | 76 | 77 | def load_categories(csv_path: Path) -> Dict[str, Category]: 78 | with open(csv_path) as file: 79 | reader = TupleReader(file) 80 | return {row[0]: Category(id=row[0], name=row[1], super_category=None) for row in reader} 81 | 82 | 83 | class AnnotatedObjectsOpenImages(AnnotatedObjectsDataset): 84 | def __init__(self, use_additional_parameters: bool, **kwargs): 85 | """ 86 | @param data_path: is the path to the following folder structure: 87 | open_images/ 88 | │ oidv6-train-annotations-bbox.csv 89 | ├── class-descriptions-boxable.csv 90 | ├── oidv6-train-annotations-bbox.csv 91 | ├── test 92 | │ ├── 000026e7ee790996.jpg 93 | │ ├── 000062a39995e348.jpg 94 | │ └── ... 95 | ├── test-annotations-bbox.csv 96 | ├── test-images.csv 97 | ├── train 98 | │ ├── 000002b66c9c498e.jpg 99 | │ ├── 000002b97e5471a0.jpg 100 | │ └── ... 101 | ├── train-images-boxable.csv 102 | ├── validation 103 | │ ├── 0001eeaf4aed83f9.jpg 104 | │ ├── 0004886b7d043cfd.jpg 105 | │ └── ... 
106 | ├── validation-annotations-bbox.csv 107 | └── validation-images.csv 108 | @param: split: one of 'train', 'validation' or 'test' 109 | @param: desired image size (returns square images) 110 | """ 111 | 112 | super().__init__(**kwargs) 113 | self.use_additional_parameters = use_additional_parameters 114 | 115 | self.categories = load_categories(self.paths['class_descriptions']) 116 | self.filter_categories() 117 | self.setup_category_id_and_number() 118 | 119 | self.image_descriptions = {} 120 | annotations = load_annotations(self.paths['annotations'], self.min_object_area, self.category_mapping, 121 | self.category_number) 122 | self.annotations = self.filter_object_number(annotations, self.min_object_area, self.min_objects_per_image, 123 | self.max_objects_per_image) 124 | self.image_ids = list(self.annotations.keys()) 125 | self.clean_up_annotations_and_image_descriptions() 126 | 127 | def get_path_structure(self) -> Dict[str, str]: 128 | if self.split not in OPEN_IMAGES_STRUCTURE: 129 | raise ValueError(f'Split [{self.split} does not exist for Open Images data.]') 130 | return OPEN_IMAGES_STRUCTURE[self.split] 131 | 132 | def get_image_path(self, image_id: str) -> Path: 133 | return self.paths['files'].joinpath(f'{image_id:0>16}.jpg') 134 | 135 | def get_image_description(self, image_id: str) -> Dict[str, Any]: 136 | image_path = self.get_image_path(image_id) 137 | return {'file_path': str(image_path), 'file_name': image_path.name} 138 | -------------------------------------------------------------------------------- /taming/data/base.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import numpy as np 3 | import albumentations 4 | from PIL import Image 5 | from torch.utils.data import Dataset, ConcatDataset 6 | 7 | 8 | class ConcatDatasetWithIndex(ConcatDataset): 9 | """Modified from original pytorch code to return dataset idx""" 10 | def __getitem__(self, idx): 11 | if idx < 0: 12 | if -idx > len(self): 13 | raise ValueError("absolute value of index should not exceed dataset length") 14 | idx = len(self) + idx 15 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 16 | if dataset_idx == 0: 17 | sample_idx = idx 18 | else: 19 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 20 | return self.datasets[dataset_idx][sample_idx], dataset_idx 21 | 22 | 23 | class ImagePaths(Dataset): 24 | def __init__(self, paths, size=None, random_crop=False, labels=None): 25 | self.size = size 26 | self.random_crop = random_crop 27 | 28 | self.labels = dict() if labels is None else labels 29 | self.labels["file_path_"] = paths 30 | self._length = len(paths) 31 | 32 | if self.size is not None and self.size > 0: 33 | self.rescaler = albumentations.SmallestMaxSize(max_size = self.size) 34 | if not self.random_crop: 35 | self.cropper = albumentations.CenterCrop(height=self.size,width=self.size) 36 | else: 37 | self.cropper = albumentations.RandomCrop(height=self.size,width=self.size) 38 | self.preprocessor = albumentations.Compose([self.rescaler, self.cropper]) 39 | else: 40 | self.preprocessor = lambda **kwargs: kwargs 41 | 42 | def __len__(self): 43 | return self._length 44 | 45 | def preprocess_image(self, image_path): 46 | image = Image.open(image_path) 47 | if not image.mode == "RGB": 48 | image = image.convert("RGB") 49 | image = np.array(image).astype(np.uint8) 50 | image = self.preprocessor(image=image)["image"] 51 | image = (image/127.5 - 1.0).astype(np.float32) 52 | return image 53 | 54 | def __getitem__(self, i): 55 
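# Descriptive note (added): this __getitem__ assembles one example per index — the image
# decoded, rescaled/cropped and normalised to [-1, 1] by preprocess_image, plus every
# per-path label registered in self.labels (at minimum "file_path_").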
| example = dict() 56 | example["image"] = self.preprocess_image(self.labels["file_path_"][i]) 57 | for k in self.labels: 58 | example[k] = self.labels[k][i] 59 | return example 60 | 61 | 62 | class NumpyPaths(ImagePaths): 63 | def preprocess_image(self, image_path): 64 | image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024 65 | image = np.transpose(image, (1,2,0)) 66 | image = Image.fromarray(image, mode="RGB") 67 | image = np.array(image).astype(np.uint8) 68 | image = self.preprocessor(image=image)["image"] 69 | image = (image/127.5 - 1.0).astype(np.float32) 70 | return image 71 | -------------------------------------------------------------------------------- /taming/data/conditional_builder/objects_bbox.py: -------------------------------------------------------------------------------- 1 | from itertools import cycle 2 | from typing import List, Tuple, Callable, Optional 3 | 4 | from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont 5 | from more_itertools.recipes import grouper 6 | from taming.data.image_transforms import convert_pil_to_tensor 7 | from torch import LongTensor, Tensor 8 | 9 | from taming.data.helper_types import BoundingBox, Annotation 10 | from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder 11 | from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \ 12 | pad_list, get_plot_font_size, absolute_bbox 13 | 14 | 15 | class ObjectsBoundingBoxConditionalBuilder(ObjectsCenterPointsConditionalBuilder): 16 | @property 17 | def object_descriptor_length(self) -> int: 18 | return 3 19 | 20 | def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]: 21 | object_triples = [ 22 | (self.object_representation(ann), *self.token_pair_from_bbox(ann.bbox)) 23 | for ann in annotations 24 | ] 25 | empty_triple = (self.none, self.none, self.none) 26 | object_triples = pad_list(object_triples, empty_triple, self.no_max_objects) 27 | return object_triples 28 | 29 | def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]: 30 | conditional_list = conditional.tolist() 31 | crop_coordinates = None 32 | if self.encode_crop: 33 | crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1]) 34 | conditional_list = conditional_list[:-2] 35 | object_triples = grouper(conditional_list, 3) 36 | assert conditional.shape[0] == self.embedding_dim 37 | return [ 38 | (object_triple[0], self.bbox_from_token_pair(object_triple[1], object_triple[2])) 39 | for object_triple in object_triples if object_triple[0] != self.none 40 | ], crop_coordinates 41 | 42 | def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int], 43 | line_width: int = 3, font_size: Optional[int] = None) -> Tensor: 44 | plot = pil_image.new('RGB', figure_size, WHITE) 45 | draw = pil_img_draw.Draw(plot) 46 | font = ImageFont.truetype( 47 | "/usr/share/fonts/truetype/lato/Lato-Regular.ttf", 48 | size=get_plot_font_size(font_size, figure_size) 49 | ) 50 | width, height = plot.size 51 | description, crop_coordinates = self.inverse_build(conditional) 52 | for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)): 53 | annotation = self.representation_to_annotation(representation) 54 | class_label = label_for_category_no(annotation.category_no) + ' ' + additional_parameters_string(annotation) 55 | bbox = absolute_bbox(bbox, 
width, height) 56 | draw.rectangle(bbox, outline=color, width=line_width) 57 | draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK, font=font) 58 | if crop_coordinates is not None: 59 | draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width) 60 | return convert_pil_to_tensor(plot) / 127.5 - 1. 61 | -------------------------------------------------------------------------------- /taming/data/conditional_builder/utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from typing import List, Any, Tuple, Optional 3 | 4 | from taming.data.helper_types import BoundingBox, Annotation 5 | 6 | # source: seaborn, color palette tab10 7 | COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188), 8 | (139, 85, 74), (226, 118, 193), (126, 126, 126), (187, 188, 33), (22, 189, 206)] 9 | BLACK = (0, 0, 0) 10 | GRAY_75 = (63, 63, 63) 11 | GRAY_50 = (127, 127, 127) 12 | GRAY_25 = (191, 191, 191) 13 | WHITE = (255, 255, 255) 14 | FULL_CROP = (0., 0., 1., 1.) 15 | 16 | 17 | def intersection_area(rectangle1: BoundingBox, rectangle2: BoundingBox) -> float: 18 | """ 19 | Give intersection area of two rectangles. 20 | @param rectangle1: (x0, y0, w, h) of first rectangle 21 | @param rectangle2: (x0, y0, w, h) of second rectangle 22 | """ 23 | rectangle1 = rectangle1[0], rectangle1[1], rectangle1[0] + rectangle1[2], rectangle1[1] + rectangle1[3] 24 | rectangle2 = rectangle2[0], rectangle2[1], rectangle2[0] + rectangle2[2], rectangle2[1] + rectangle2[3] 25 | x_overlap = max(0., min(rectangle1[2], rectangle2[2]) - max(rectangle1[0], rectangle2[0])) 26 | y_overlap = max(0., min(rectangle1[3], rectangle2[3]) - max(rectangle1[1], rectangle2[1])) 27 | return x_overlap * y_overlap 28 | 29 | 30 | def horizontally_flip_bbox(bbox: BoundingBox) -> BoundingBox: 31 | return 1 - (bbox[0] + bbox[2]), bbox[1], bbox[2], bbox[3] 32 | 33 | 34 | def absolute_bbox(relative_bbox: BoundingBox, width: int, height: int) -> Tuple[int, int, int, int]: 35 | bbox = relative_bbox 36 | bbox = bbox[0] * width, bbox[1] * height, (bbox[0] + bbox[2]) * width, (bbox[1] + bbox[3]) * height 37 | return int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]) 38 | 39 | 40 | def pad_list(list_: List, pad_element: Any, pad_to_length: int) -> List: 41 | return list_ + [pad_element for _ in range(pad_to_length - len(list_))] 42 | 43 | 44 | def rescale_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox, flip: bool) -> \ 45 | List[Annotation]: 46 | def clamp(x: float): 47 | return max(min(x, 1.), 0.) 
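    # Descriptive note (added): rescale_bbox below re-expresses a bbox given in whole-image
    # relative coordinates in the frame of the crop — subtract the crop origin, divide by the
    # crop size, then clamp/trim so x0, y0, w, h stay inside [0, 1]; flipping mirrors x0.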
48 | 49 | def rescale_bbox(bbox: BoundingBox) -> BoundingBox: 50 | x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) 51 | y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) 52 | w = min(bbox[2] / crop_coordinates[2], 1 - x0) 53 | h = min(bbox[3] / crop_coordinates[3], 1 - y0) 54 | if flip: 55 | x0 = 1 - (x0 + w) 56 | return x0, y0, w, h 57 | 58 | return [a._replace(bbox=rescale_bbox(a.bbox)) for a in annotations] 59 | 60 | 61 | def filter_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox) -> List: 62 | return [a for a in annotations if intersection_area(a.bbox, crop_coordinates) > 0.0] 63 | 64 | 65 | def additional_parameters_string(annotation: Annotation, short: bool = True) -> str: 66 | sl = slice(1) if short else slice(None) 67 | string = '' 68 | if not (annotation.is_group_of or annotation.is_occluded or annotation.is_depiction or annotation.is_inside): 69 | return string 70 | if annotation.is_group_of: 71 | string += 'group'[sl] + ',' 72 | if annotation.is_occluded: 73 | string += 'occluded'[sl] + ',' 74 | if annotation.is_depiction: 75 | string += 'depiction'[sl] + ',' 76 | if annotation.is_inside: 77 | string += 'inside'[sl] 78 | return '(' + string.strip(",") + ')' 79 | 80 | 81 | def get_plot_font_size(font_size: Optional[int], figure_size: Tuple[int, int]) -> int: 82 | if font_size is None: 83 | font_size = 10 84 | if max(figure_size) >= 256: 85 | font_size = 12 86 | if max(figure_size) >= 512: 87 | font_size = 15 88 | return font_size 89 | 90 | 91 | def get_circle_size(figure_size: Tuple[int, int]) -> int: 92 | circle_size = 2 93 | if max(figure_size) >= 256: 94 | circle_size = 3 95 | if max(figure_size) >= 512: 96 | circle_size = 4 97 | return circle_size 98 | 99 | 100 | def load_object_from_string(object_string: str) -> Any: 101 | """ 102 | Source: https://stackoverflow.com/a/10773699 103 | """ 104 | module_name, class_name = object_string.rsplit(".", 1) 105 | return getattr(importlib.import_module(module_name), class_name) 106 | -------------------------------------------------------------------------------- /taming/data/custom.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import albumentations 4 | from torch.utils.data import Dataset 5 | 6 | from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex 7 | 8 | 9 | class CustomBase(Dataset): 10 | def __init__(self, *args, **kwargs): 11 | super().__init__() 12 | self.data = None 13 | 14 | def __len__(self): 15 | return len(self.data) 16 | 17 | def __getitem__(self, i): 18 | example = self.data[i] 19 | return example 20 | 21 | 22 | 23 | class CustomTrain(CustomBase): 24 | def __init__(self, size, training_images_list_file): 25 | super().__init__() 26 | with open(training_images_list_file, "r") as f: 27 | paths = f.read().splitlines() 28 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 29 | 30 | 31 | class CustomTest(CustomBase): 32 | def __init__(self, size, test_images_list_file): 33 | super().__init__() 34 | with open(test_images_list_file, "r") as f: 35 | paths = f.read().splitlines() 36 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 37 | 38 | 39 | -------------------------------------------------------------------------------- /taming/data/faceshq.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import albumentations 4 | from torch.utils.data import Dataset 5 | 6 | 
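# Descriptive note (added): FacesHQ pairs two sources — CelebA-HQ samples stored as .npy
# arrays (loaded via NumpyPaths) and FFHQ images on disk (loaded via ImagePaths);
# ConcatDatasetWithIndex also returns which sub-dataset an item came from, which the
# FacesHQ* wrappers below expose as the "class" label.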
from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex 7 | 8 | 9 | class FacesBase(Dataset): 10 | def __init__(self, *args, **kwargs): 11 | super().__init__() 12 | self.data = None 13 | self.keys = None 14 | 15 | def __len__(self): 16 | return len(self.data) 17 | 18 | def __getitem__(self, i): 19 | example = self.data[i] 20 | ex = {} 21 | if self.keys is not None: 22 | for k in self.keys: 23 | ex[k] = example[k] 24 | else: 25 | ex = example 26 | return ex 27 | 28 | 29 | class CelebAHQTrain(FacesBase): 30 | def __init__(self, size, keys=None): 31 | super().__init__() 32 | root = "data/celebahq" 33 | with open("data/celebahqtrain.txt", "r") as f: 34 | relpaths = f.read().splitlines() 35 | paths = [os.path.join(root, relpath) for relpath in relpaths] 36 | self.data = NumpyPaths(paths=paths, size=size, random_crop=False) 37 | self.keys = keys 38 | 39 | 40 | class CelebAHQValidation(FacesBase): 41 | def __init__(self, size, keys=None): 42 | super().__init__() 43 | root = "data/celebahq" 44 | with open("data/celebahqvalidation.txt", "r") as f: 45 | relpaths = f.read().splitlines() 46 | paths = [os.path.join(root, relpath) for relpath in relpaths] 47 | self.data = NumpyPaths(paths=paths, size=size, random_crop=False) 48 | self.keys = keys 49 | 50 | 51 | class FFHQTrain(FacesBase): 52 | def __init__(self, size, keys=None): 53 | super().__init__() 54 | root = "data/ffhq" 55 | with open("data/ffhqtrain.txt", "r") as f: 56 | relpaths = f.read().splitlines() 57 | paths = [os.path.join(root, relpath) for relpath in relpaths] 58 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 59 | self.keys = keys 60 | 61 | 62 | class FFHQValidation(FacesBase): 63 | def __init__(self, size, keys=None): 64 | super().__init__() 65 | root = "data/ffhq" 66 | with open("data/ffhqvalidation.txt", "r") as f: 67 | relpaths = f.read().splitlines() 68 | paths = [os.path.join(root, relpath) for relpath in relpaths] 69 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 70 | self.keys = keys 71 | 72 | 73 | class FacesHQTrain(Dataset): 74 | # CelebAHQ [0] + FFHQ [1] 75 | def __init__(self, size, keys=None, crop_size=None, coord=False): 76 | d1 = CelebAHQTrain(size=size, keys=keys) 77 | d2 = FFHQTrain(size=size, keys=keys) 78 | self.data = ConcatDatasetWithIndex([d1, d2]) 79 | self.coord = coord 80 | if crop_size is not None: 81 | self.cropper = albumentations.RandomCrop(height=crop_size,width=crop_size) 82 | if self.coord: 83 | self.cropper = albumentations.Compose([self.cropper], 84 | additional_targets={"coord": "image"}) 85 | 86 | def __len__(self): 87 | return len(self.data) 88 | 89 | def __getitem__(self, i): 90 | ex, y = self.data[i] 91 | if hasattr(self, "cropper"): 92 | if not self.coord: 93 | out = self.cropper(image=ex["image"]) 94 | ex["image"] = out["image"] 95 | else: 96 | h,w,_ = ex["image"].shape 97 | coord = np.arange(h*w).reshape(h,w,1)/(h*w) 98 | out = self.cropper(image=ex["image"], coord=coord) 99 | ex["image"] = out["image"] 100 | ex["coord"] = out["coord"] 101 | ex["class"] = y 102 | return ex 103 | 104 | 105 | class FacesHQValidation(Dataset): 106 | # CelebAHQ [0] + FFHQ [1] 107 | def __init__(self, size, keys=None, crop_size=None, coord=False): 108 | d1 = CelebAHQValidation(size=size, keys=keys) 109 | d2 = FFHQValidation(size=size, keys=keys) 110 | self.data = ConcatDatasetWithIndex([d1, d2]) 111 | self.coord = coord 112 | if crop_size is not None: 113 | self.cropper = albumentations.CenterCrop(height=crop_size,width=crop_size) 114 | if self.coord: 115 
| self.cropper = albumentations.Compose([self.cropper], 116 | additional_targets={"coord": "image"}) 117 | 118 | def __len__(self): 119 | return len(self.data) 120 | 121 | def __getitem__(self, i): 122 | ex, y = self.data[i] 123 | if hasattr(self, "cropper"): 124 | if not self.coord: 125 | out = self.cropper(image=ex["image"]) 126 | ex["image"] = out["image"] 127 | else: 128 | h,w,_ = ex["image"].shape 129 | coord = np.arange(h*w).reshape(h,w,1)/(h*w) 130 | out = self.cropper(image=ex["image"], coord=coord) 131 | ex["image"] = out["image"] 132 | ex["coord"] = out["coord"] 133 | ex["class"] = y 134 | return ex 135 | -------------------------------------------------------------------------------- /taming/data/helper_types.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple, Optional, NamedTuple, Union 2 | from PIL.Image import Image as pil_image 3 | from torch import Tensor 4 | 5 | try: 6 | from typing import Literal 7 | except ImportError: 8 | from typing_extensions import Literal 9 | 10 | Image = Union[Tensor, pil_image] 11 | BoundingBox = Tuple[float, float, float, float] # x0, y0, w, h 12 | CropMethodType = Literal['none', 'random', 'center', 'random-2d'] 13 | SplitType = Literal['train', 'validation', 'test'] 14 | 15 | 16 | class ImageDescription(NamedTuple): 17 | id: int 18 | file_name: str 19 | original_size: Tuple[int, int] # w, h 20 | url: Optional[str] = None 21 | license: Optional[int] = None 22 | coco_url: Optional[str] = None 23 | date_captured: Optional[str] = None 24 | flickr_url: Optional[str] = None 25 | flickr_id: Optional[str] = None 26 | coco_id: Optional[str] = None 27 | 28 | 29 | class Category(NamedTuple): 30 | id: str 31 | super_category: Optional[str] 32 | name: str 33 | 34 | 35 | class Annotation(NamedTuple): 36 | area: float 37 | image_id: str 38 | bbox: BoundingBox 39 | category_no: int 40 | category_id: str 41 | id: Optional[int] = None 42 | source: Optional[str] = None 43 | confidence: Optional[float] = None 44 | is_group_of: Optional[bool] = None 45 | is_truncated: Optional[bool] = None 46 | is_occluded: Optional[bool] = None 47 | is_depiction: Optional[bool] = None 48 | is_inside: Optional[bool] = None 49 | segmentation: Optional[Dict] = None 50 | -------------------------------------------------------------------------------- /taming/data/image_transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | import warnings 3 | from typing import Union 4 | 5 | import torch 6 | from torch import Tensor 7 | from torchvision.transforms import RandomCrop, functional as F, CenterCrop, RandomHorizontalFlip, PILToTensor 8 | from torchvision.transforms.functional import _get_image_size as get_image_size 9 | 10 | from taming.data.helper_types import BoundingBox, Image 11 | 12 | pil_to_tensor = PILToTensor() 13 | 14 | 15 | def convert_pil_to_tensor(image: Image) -> Tensor: 16 | with warnings.catch_warnings(): 17 | # to filter PyTorch UserWarning as described here: https://github.com/pytorch/vision/issues/2194 18 | warnings.simplefilter("ignore") 19 | return pil_to_tensor(image) 20 | 21 | 22 | class RandomCrop1dReturnCoordinates(RandomCrop): 23 | def forward(self, img: Image) -> (BoundingBox, Image): 24 | """ 25 | Additionally to cropping, returns the relative coordinates of the crop bounding box. 26 | Args: 27 | img (PIL Image or Tensor): Image to be cropped. 28 | 29 | Returns: 30 | Bounding box: x0, y0, w, h 31 | PIL Image or Tensor: Cropped image. 
32 | 33 | Based on: 34 | torchvision.transforms.RandomCrop, torchvision 1.7.0 35 | """ 36 | if self.padding is not None: 37 | img = F.pad(img, self.padding, self.fill, self.padding_mode) 38 | 39 | width, height = get_image_size(img) 40 | # pad the width if needed 41 | if self.pad_if_needed and width < self.size[1]: 42 | padding = [self.size[1] - width, 0] 43 | img = F.pad(img, padding, self.fill, self.padding_mode) 44 | # pad the height if needed 45 | if self.pad_if_needed and height < self.size[0]: 46 | padding = [0, self.size[0] - height] 47 | img = F.pad(img, padding, self.fill, self.padding_mode) 48 | 49 | i, j, h, w = self.get_params(img, self.size) 50 | bbox = (j / width, i / height, w / width, h / height) # x0, y0, w, h 51 | return bbox, F.crop(img, i, j, h, w) 52 | 53 | 54 | class Random2dCropReturnCoordinates(torch.nn.Module): 55 | """ 56 | Additionally to cropping, returns the relative coordinates of the crop bounding box. 57 | Args: 58 | img (PIL Image or Tensor): Image to be cropped. 59 | 60 | Returns: 61 | Bounding box: x0, y0, w, h 62 | PIL Image or Tensor: Cropped image. 63 | 64 | Based on: 65 | torchvision.transforms.RandomCrop, torchvision 1.7.0 66 | """ 67 | 68 | def __init__(self, min_size: int): 69 | super().__init__() 70 | self.min_size = min_size 71 | 72 | def forward(self, img: Image) -> (BoundingBox, Image): 73 | width, height = get_image_size(img) 74 | max_size = min(width, height) 75 | if max_size <= self.min_size: 76 | size = max_size 77 | else: 78 | size = random.randint(self.min_size, max_size) 79 | top = random.randint(0, height - size) 80 | left = random.randint(0, width - size) 81 | bbox = left / width, top / height, size / width, size / height 82 | return bbox, F.crop(img, top, left, size, size) 83 | 84 | 85 | class CenterCropReturnCoordinates(CenterCrop): 86 | @staticmethod 87 | def get_bbox_of_center_crop(width: int, height: int) -> BoundingBox: 88 | if width > height: 89 | w = height / width 90 | h = 1.0 91 | x0 = 0.5 - w / 2 92 | y0 = 0. 93 | else: 94 | w = 1.0 95 | h = width / height 96 | x0 = 0. 97 | y0 = 0.5 - h / 2 98 | return x0, y0, w, h 99 | 100 | def forward(self, img: Union[Image, Tensor]) -> (BoundingBox, Union[Image, Tensor]): 101 | """ 102 | Additionally to cropping, returns the relative coordinates of the crop bounding box. 103 | Args: 104 | img (PIL Image or Tensor): Image to be cropped. 105 | 106 | Returns: 107 | Bounding box: x0, y0, w, h 108 | PIL Image or Tensor: Cropped image. 109 | Based on: 110 | torchvision.transforms.RandomHorizontalFlip (version 1.7.0) 111 | """ 112 | width, height = get_image_size(img) 113 | return self.get_bbox_of_center_crop(width, height), F.center_crop(img, self.size) 114 | 115 | 116 | class RandomHorizontalFlipReturn(RandomHorizontalFlip): 117 | def forward(self, img: Image) -> (bool, Image): 118 | """ 119 | Additionally to flipping, returns a boolean whether it was flipped or not. 120 | Args: 121 | img (PIL Image or Tensor): Image to be flipped. 122 | 123 | Returns: 124 | flipped: whether the image was flipped or not 125 | PIL Image or Tensor: Randomly flipped image. 
126 | 127 | Based on: 128 | torchvision.transforms.RandomHorizontalFlip (version 1.7.0) 129 | """ 130 | if torch.rand(1) < self.p: 131 | return True, F.hflip(img) 132 | return False, img 133 | -------------------------------------------------------------------------------- /taming/data/sflckr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import albumentations 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | 8 | 9 | class SegmentationBase(Dataset): 10 | def __init__(self, 11 | data_csv, data_root, segmentation_root, 12 | size=None, random_crop=False, interpolation="bicubic", 13 | n_labels=182, shift_segmentation=False, 14 | ): 15 | self.n_labels = n_labels 16 | self.shift_segmentation = shift_segmentation 17 | self.data_csv = data_csv 18 | self.data_root = data_root 19 | self.segmentation_root = segmentation_root 20 | with open(self.data_csv, "r") as f: 21 | self.image_paths = f.read().splitlines() 22 | self._length = len(self.image_paths) 23 | self.labels = { 24 | "relative_file_path_": [l for l in self.image_paths], 25 | "file_path_": [os.path.join(self.data_root, l) 26 | for l in self.image_paths], 27 | "segmentation_path_": [os.path.join(self.segmentation_root, l.replace(".jpg", ".png")) 28 | for l in self.image_paths] 29 | } 30 | 31 | size = None if size is not None and size<=0 else size 32 | self.size = size 33 | if self.size is not None: 34 | self.interpolation = interpolation 35 | self.interpolation = { 36 | "nearest": cv2.INTER_NEAREST, 37 | "bilinear": cv2.INTER_LINEAR, 38 | "bicubic": cv2.INTER_CUBIC, 39 | "area": cv2.INTER_AREA, 40 | "lanczos": cv2.INTER_LANCZOS4}[self.interpolation] 41 | self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size, 42 | interpolation=self.interpolation) 43 | self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size, 44 | interpolation=cv2.INTER_NEAREST) 45 | self.center_crop = not random_crop 46 | if self.center_crop: 47 | self.cropper = albumentations.CenterCrop(height=self.size, width=self.size) 48 | else: 49 | self.cropper = albumentations.RandomCrop(height=self.size, width=self.size) 50 | self.preprocessor = self.cropper 51 | 52 | def __len__(self): 53 | return self._length 54 | 55 | def __getitem__(self, i): 56 | example = dict((k, self.labels[k][i]) for k in self.labels) 57 | image = Image.open(example["file_path_"]) 58 | if not image.mode == "RGB": 59 | image = image.convert("RGB") 60 | image = np.array(image).astype(np.uint8) 61 | if self.size is not None: 62 | image = self.image_rescaler(image=image)["image"] 63 | segmentation = Image.open(example["segmentation_path_"]) 64 | assert segmentation.mode == "L", segmentation.mode 65 | segmentation = np.array(segmentation).astype(np.uint8) 66 | if self.shift_segmentation: 67 | # used to support segmentations containing unlabeled==255 label 68 | segmentation = segmentation+1 69 | if self.size is not None: 70 | segmentation = self.segmentation_rescaler(image=segmentation)["image"] 71 | if self.size is not None: 72 | processed = self.preprocessor(image=image, 73 | mask=segmentation 74 | ) 75 | else: 76 | processed = {"image": image, 77 | "mask": segmentation 78 | } 79 | example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32) 80 | segmentation = processed["mask"] 81 | onehot = np.eye(self.n_labels)[segmentation] 82 | example["segmentation"] = onehot 83 | return example 84 | 85 | 86 | class Examples(SegmentationBase): 87 | def 
__init__(self, size=None, random_crop=False, interpolation="bicubic"): 88 | super().__init__(data_csv="data/sflckr_examples.txt", 89 | data_root="data/sflckr_images", 90 | segmentation_root="data/sflckr_segmentations", 91 | size=size, random_crop=random_crop, interpolation=interpolation) 92 | -------------------------------------------------------------------------------- /taming/data/utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import os 3 | import tarfile 4 | import urllib 5 | import zipfile 6 | from pathlib import Path 7 | 8 | import numpy as np 9 | import torch 10 | from taming.data.helper_types import Annotation 11 | from torch._six import string_classes 12 | from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format 13 | from tqdm import tqdm 14 | 15 | 16 | def unpack(path): 17 | if path.endswith("tar.gz"): 18 | with tarfile.open(path, "r:gz") as tar: 19 | tar.extractall(path=os.path.split(path)[0]) 20 | elif path.endswith("tar"): 21 | with tarfile.open(path, "r:") as tar: 22 | tar.extractall(path=os.path.split(path)[0]) 23 | elif path.endswith("zip"): 24 | with zipfile.ZipFile(path, "r") as f: 25 | f.extractall(path=os.path.split(path)[0]) 26 | else: 27 | raise NotImplementedError( 28 | "Unknown file extension: {}".format(os.path.splitext(path)[1]) 29 | ) 30 | 31 | 32 | def reporthook(bar): 33 | """tqdm progress bar for downloads.""" 34 | 35 | def hook(b=1, bsize=1, tsize=None): 36 | if tsize is not None: 37 | bar.total = tsize 38 | bar.update(b * bsize - bar.n) 39 | 40 | return hook 41 | 42 | 43 | def get_root(name): 44 | base = "data/" 45 | root = os.path.join(base, name) 46 | os.makedirs(root, exist_ok=True) 47 | return root 48 | 49 | 50 | def is_prepared(root): 51 | return Path(root).joinpath(".ready").exists() 52 | 53 | 54 | def mark_prepared(root): 55 | Path(root).joinpath(".ready").touch() 56 | 57 | 58 | def prompt_download(file_, source, target_dir, content_dir=None): 59 | targetpath = os.path.join(target_dir, file_) 60 | while not os.path.exists(targetpath): 61 | if content_dir is not None and os.path.exists( 62 | os.path.join(target_dir, content_dir) 63 | ): 64 | break 65 | print( 66 | "Please download '{}' from '{}' to '{}'.".format(file_, source, targetpath) 67 | ) 68 | if content_dir is not None: 69 | print( 70 | "Or place its content into '{}'.".format( 71 | os.path.join(target_dir, content_dir) 72 | ) 73 | ) 74 | input("Press Enter when done...") 75 | return targetpath 76 | 77 | 78 | def download_url(file_, url, target_dir): 79 | targetpath = os.path.join(target_dir, file_) 80 | os.makedirs(target_dir, exist_ok=True) 81 | with tqdm( 82 | unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_ 83 | ) as bar: 84 | urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar)) 85 | return targetpath 86 | 87 | 88 | def download_urls(urls, target_dir): 89 | paths = dict() 90 | for fname, url in urls.items(): 91 | outpath = download_url(fname, url, target_dir) 92 | paths[fname] = outpath 93 | return paths 94 | 95 | 96 | def quadratic_crop(x, bbox, alpha=1.0): 97 | """bbox is xmin, ymin, xmax, ymax""" 98 | im_h, im_w = x.shape[:2] 99 | bbox = np.array(bbox, dtype=np.float32) 100 | bbox = np.clip(bbox, 0, max(im_h, im_w)) 101 | center = 0.5 * (bbox[0] + bbox[2]), 0.5 * (bbox[1] + bbox[3]) 102 | w = bbox[2] - bbox[0] 103 | h = bbox[3] - bbox[1] 104 | l = int(alpha * max(w, h)) 105 | l = max(l, 2) 106 | 107 | required_padding = -1 * min( 108 | 
center[0] - l, center[1] - l, im_w - (center[0] + l), im_h - (center[1] + l) 109 | ) 110 | required_padding = int(np.ceil(required_padding)) 111 | if required_padding > 0: 112 | padding = [ 113 | [required_padding, required_padding], 114 | [required_padding, required_padding], 115 | ] 116 | padding += [[0, 0]] * (len(x.shape) - 2) 117 | x = np.pad(x, padding, "reflect") 118 | center = center[0] + required_padding, center[1] + required_padding 119 | xmin = int(center[0] - l / 2) 120 | ymin = int(center[1] - l / 2) 121 | return np.array(x[ymin : ymin + l, xmin : xmin + l, ...]) 122 | 123 | 124 | def custom_collate(batch): 125 | r"""source: pytorch 1.9.0, only one modification to original code """ 126 | 127 | elem = batch[0] 128 | elem_type = type(elem) 129 | if isinstance(elem, torch.Tensor): 130 | out = None 131 | if torch.utils.data.get_worker_info() is not None: 132 | # If we're in a background process, concatenate directly into a 133 | # shared memory tensor to avoid an extra copy 134 | numel = sum([x.numel() for x in batch]) 135 | storage = elem.storage()._new_shared(numel) 136 | out = elem.new(storage) 137 | return torch.stack(batch, 0, out=out) 138 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 139 | and elem_type.__name__ != 'string_': 140 | if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap': 141 | # array of string classes and object 142 | if np_str_obj_array_pattern.search(elem.dtype.str) is not None: 143 | raise TypeError(default_collate_err_msg_format.format(elem.dtype)) 144 | 145 | return custom_collate([torch.as_tensor(b) for b in batch]) 146 | elif elem.shape == (): # scalars 147 | return torch.as_tensor(batch) 148 | elif isinstance(elem, float): 149 | return torch.tensor(batch, dtype=torch.float64) 150 | elif isinstance(elem, int): 151 | return torch.tensor(batch) 152 | elif isinstance(elem, string_classes): 153 | return batch 154 | elif isinstance(elem, collections.abc.Mapping): 155 | return {key: custom_collate([d[key] for d in batch]) for key in elem} 156 | elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple 157 | return elem_type(*(custom_collate(samples) for samples in zip(*batch))) 158 | if isinstance(elem, collections.abc.Sequence) and isinstance(elem[0], Annotation): # added 159 | return batch # added 160 | elif isinstance(elem, collections.abc.Sequence): 161 | # check to make sure that the elements in batch have consistent size 162 | it = iter(batch) 163 | elem_size = len(next(it)) 164 | if not all(len(elem) == elem_size for elem in it): 165 | raise RuntimeError('each element in list of batch should be of equal size') 166 | transposed = zip(*batch) 167 | return [custom_collate(samples) for samples in transposed] 168 | 169 | raise TypeError(default_collate_err_msg_format.format(elem_type)) 170 | -------------------------------------------------------------------------------- /taming/modules/autoencoder/lpips/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuHuangAI/DiffusionEdge/b03617e84b79336fd1f4d650105580e026ab26a7/taming/modules/autoencoder/lpips/vgg.pth -------------------------------------------------------------------------------- /taming/modules/discriminator/model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch.nn as nn 3 | 4 | 5 | from taming.modules.util import ActNorm 6 | 7 | 8 | def weights_init(m): 9 | classname = m.__class__.__name__ 10 | 
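    # Descriptive note (added): DCGAN-style initialisation dispatched by class name —
    # conv weights drawn from N(0, 0.02), BatchNorm weights from N(1, 0.02) with zero bias.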
if classname.find('Conv') != -1: 11 | nn.init.normal_(m.weight.data, 0.0, 0.02) 12 | elif classname.find('BatchNorm') != -1: 13 | nn.init.normal_(m.weight.data, 1.0, 0.02) 14 | nn.init.constant_(m.bias.data, 0) 15 | 16 | 17 | class NLayerDiscriminator(nn.Module): 18 | """Defines a PatchGAN discriminator as in Pix2Pix 19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 20 | """ 21 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 22 | """Construct a PatchGAN discriminator 23 | Parameters: 24 | input_nc (int) -- the number of channels in input images 25 | ndf (int) -- the number of filters in the last conv layer 26 | n_layers (int) -- the number of conv layers in the discriminator 27 | norm_layer -- normalization layer 28 | """ 29 | super(NLayerDiscriminator, self).__init__() 30 | if not use_actnorm: 31 | norm_layer = nn.BatchNorm2d 32 | else: 33 | norm_layer = ActNorm 34 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 35 | use_bias = norm_layer.func != nn.BatchNorm2d 36 | else: 37 | use_bias = norm_layer != nn.BatchNorm2d 38 | 39 | kw = 4 40 | padw = 1 41 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 42 | nf_mult = 1 43 | nf_mult_prev = 1 44 | for n in range(1, n_layers): # gradually increase the number of filters 45 | nf_mult_prev = nf_mult 46 | nf_mult = min(2 ** n, 8) 47 | sequence += [ 48 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 49 | norm_layer(ndf * nf_mult), 50 | nn.LeakyReLU(0.2, True) 51 | ] 52 | 53 | nf_mult_prev = nf_mult 54 | nf_mult = min(2 ** n_layers, 8) 55 | sequence += [ 56 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 57 | norm_layer(ndf * nf_mult), 58 | nn.LeakyReLU(0.2, True) 59 | ] 60 | 61 | sequence += [ 62 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 63 | self.main = nn.Sequential(*sequence) 64 | 65 | def forward(self, input): 66 | """Standard forward.""" 67 | return self.main(input) 68 | 69 | class NLayerDiscriminator2(nn.Module): 70 | """Defines a PatchGAN discriminator as in Pix2Pix 71 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 72 | """ 73 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 74 | """Construct a PatchGAN discriminator 75 | Parameters: 76 | input_nc (int) -- the number of channels in input images 77 | ndf (int) -- the number of filters in the last conv layer 78 | n_layers (int) -- the number of conv layers in the discriminator 79 | norm_layer -- normalization layer 80 | """ 81 | super(NLayerDiscriminator2, self).__init__() 82 | if not use_actnorm: 83 | norm_layer = nn.BatchNorm3d 84 | else: 85 | norm_layer = ActNorm 86 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 87 | use_bias = norm_layer.func != nn.BatchNorm3d 88 | else: 89 | use_bias = norm_layer != nn.BatchNorm3d 90 | 91 | kw = 4 92 | padw = 1 93 | sequence = [nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 94 | nf_mult = 1 95 | nf_mult_prev = 1 96 | for n in range(1, n_layers): # gradually increase the number of filters 97 | nf_mult_prev = nf_mult 98 | nf_mult = min(2 ** n, 8) 99 | sequence += [ 100 | nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, 
stride=2, 101 | padding=padw, bias=use_bias, groups=8), 102 | norm_layer(ndf * nf_mult), 103 | nn.LeakyReLU(0.2, True) 104 | ] 105 | 106 | nf_mult_prev = nf_mult 107 | nf_mult = min(2 ** n_layers, 8) 108 | sequence += [ 109 | nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, 110 | padding=padw, bias=use_bias, groups=8), 111 | norm_layer(ndf * nf_mult), 112 | nn.LeakyReLU(0.2, True) 113 | ] 114 | 115 | sequence += [ 116 | nn.Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw), 117 | # nn.Sigmoid() 118 | ] # output 1 channel prediction map 119 | self.main = nn.Sequential(*sequence) 120 | 121 | def forward(self, input): 122 | """Standard forward.""" 123 | return self.main(input) 124 | 125 | if __name__ == "__main__": 126 | import torch 127 | model = NLayerDiscriminator2(input_nc=3, ndf=64, n_layers=3) 128 | x = torch.rand(1, 3, 64, 64, 64) 129 | with torch.no_grad(): 130 | y = model(x) 131 | pause = 0 -------------------------------------------------------------------------------- /taming/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from taming.modules.losses.vqperceptual import DummyLoss 2 | 3 | -------------------------------------------------------------------------------- /taming/modules/losses/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torchvision import models 6 | from collections import namedtuple 7 | 8 | from .util import get_ckpt_path 9 | 10 | class LPIPS(nn.Module): 11 | # Learned perceptual metric 12 | def __init__(self, use_dropout=True): 13 | super().__init__() 14 | self.scaling_layer = ScalingLayer() 15 | self.chns = [64, 128, 256, 512, 512] # vg16 features 16 | self.net = vgg16(pretrained=True, requires_grad=False) 17 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 18 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 19 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 20 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 21 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 22 | self.load_from_pretrained() 23 | for param in self.parameters(): 24 | param.requires_grad = False 25 | 26 | def load_from_pretrained(self, name="vgg_lpips"): 27 | ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips") 28 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 29 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 30 | 31 | @classmethod 32 | def from_pretrained(cls, name="vgg_lpips"): 33 | if name != "vgg_lpips": 34 | raise NotImplementedError 35 | model = cls() 36 | ckpt = get_ckpt_path(name) 37 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 38 | return model 39 | 40 | def forward(self, input, target): 41 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 42 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 43 | feats0, feats1, diffs = {}, {}, {} 44 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 45 | for kk in range(len(self.chns)): 46 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 47 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 48 | 49 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in 
range(len(self.chns))] 50 | val = res[0] 51 | for l in range(1, len(self.chns)): 52 | val += res[l] 53 | return val 54 | 55 | 56 | class ScalingLayer(nn.Module): 57 | def __init__(self): 58 | super(ScalingLayer, self).__init__() 59 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 60 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 61 | 62 | def forward(self, inp): 63 | return (inp - self.shift) / self.scale 64 | 65 | 66 | class NetLinLayer(nn.Module): 67 | """ A single linear layer which does a 1x1 conv """ 68 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 69 | super(NetLinLayer, self).__init__() 70 | layers = [nn.Dropout(), ] if (use_dropout) else [] 71 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 72 | self.model = nn.Sequential(*layers) 73 | 74 | 75 | class vgg16(torch.nn.Module): 76 | def __init__(self, requires_grad=False, pretrained=True): 77 | super(vgg16, self).__init__() 78 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 79 | self.slice1 = torch.nn.Sequential() 80 | self.slice2 = torch.nn.Sequential() 81 | self.slice3 = torch.nn.Sequential() 82 | self.slice4 = torch.nn.Sequential() 83 | self.slice5 = torch.nn.Sequential() 84 | self.N_slices = 5 85 | for x in range(4): 86 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 87 | for x in range(4, 9): 88 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 89 | for x in range(9, 16): 90 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 91 | for x in range(16, 23): 92 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 93 | for x in range(23, 30): 94 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 95 | if not requires_grad: 96 | for param in self.parameters(): 97 | param.requires_grad = False 98 | 99 | def forward(self, X): 100 | h = self.slice1(X) 101 | h_relu1_2 = h 102 | h = self.slice2(h) 103 | h_relu2_2 = h 104 | h = self.slice3(h) 105 | h_relu3_3 = h 106 | h = self.slice4(h) 107 | h_relu4_3 = h 108 | h = self.slice5(h) 109 | h_relu5_3 = h 110 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 111 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 112 | return out 113 | 114 | 115 | def normalize_tensor(x,eps=1e-10): 116 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) 117 | return x/(norm_factor+eps) 118 | 119 | 120 | def spatial_average(x, keepdim=True): 121 | return x.mean([2,3],keepdim=keepdim) 122 | 123 | -------------------------------------------------------------------------------- /taming/modules/losses/segmentation.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class BCELoss(nn.Module): 6 | def forward(self, prediction, target): 7 | loss = F.binary_cross_entropy_with_logits(prediction,target) 8 | return loss, {} 9 | 10 | 11 | class BCELossWithQuant(nn.Module): 12 | def __init__(self, codebook_weight=1.): 13 | super().__init__() 14 | self.codebook_weight = codebook_weight 15 | 16 | def forward(self, qloss, target, prediction, split): 17 | bce_loss = F.binary_cross_entropy_with_logits(prediction,target) 18 | loss = bce_loss + self.codebook_weight*qloss 19 | return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(), 20 | "{}/bce_loss".format(split): bce_loss.detach().mean(), 21 | "{}/quant_loss".format(split): 
qloss.detach().mean() 22 | } 23 | -------------------------------------------------------------------------------- /taming/modules/losses/util.py: -------------------------------------------------------------------------------- 1 | import os, hashlib 2 | import requests 3 | from tqdm import tqdm 4 | 5 | URL_MAP = { 6 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" 7 | } 8 | 9 | CKPT_MAP = { 10 | "vgg_lpips": "vgg.pth" 11 | } 12 | 13 | MD5_MAP = { 14 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" 15 | } 16 | 17 | 18 | def download(url, local_path, chunk_size=1024): 19 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 20 | with requests.get(url, stream=True) as r: 21 | total_size = int(r.headers.get("content-length", 0)) 22 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 23 | with open(local_path, "wb") as f: 24 | for data in r.iter_content(chunk_size=chunk_size): 25 | if data: 26 | f.write(data) 27 | pbar.update(chunk_size) 28 | 29 | 30 | def md5_hash(path): 31 | with open(path, "rb") as f: 32 | content = f.read() 33 | return hashlib.md5(content).hexdigest() 34 | 35 | 36 | def get_ckpt_path(name, root, check=False): 37 | assert name in URL_MAP 38 | path = os.path.join(root, CKPT_MAP[name]) 39 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 40 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 41 | download(URL_MAP[name], path) 42 | md5 = md5_hash(path) 43 | assert md5 == MD5_MAP[name], md5 44 | return path 45 | 46 | 47 | class KeyNotFoundError(Exception): 48 | def __init__(self, cause, keys=None, visited=None): 49 | self.cause = cause 50 | self.keys = keys 51 | self.visited = visited 52 | messages = list() 53 | if keys is not None: 54 | messages.append("Key not found: {}".format(keys)) 55 | if visited is not None: 56 | messages.append("Visited: {}".format(visited)) 57 | messages.append("Cause:\n{}".format(cause)) 58 | message = "\n".join(messages) 59 | super().__init__(message) 60 | 61 | 62 | def retrieve( 63 | list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False 64 | ): 65 | """Given a nested list or dict return the desired value at key expanding 66 | callable nodes if necessary and :attr:`expand` is ``True``. The expansion 67 | is done in-place. 68 | 69 | Parameters 70 | ---------- 71 | list_or_dict : list or dict 72 | Possibly nested list or dictionary. 73 | key : str 74 | key/to/value, path like string describing all keys necessary to 75 | consider to get to the desired value. List indices can also be 76 | passed here. 77 | splitval : str 78 | String that defines the delimiter between keys of the 79 | different depth levels in `key`. 80 | default : obj 81 | Value returned if :attr:`key` is not found. 82 | expand : bool 83 | Whether to expand callable nodes on the path or not. 84 | 85 | Returns 86 | ------- 87 | The desired value or if :attr:`default` is not ``None`` and the 88 | :attr:`key` is not found returns ``default``. 89 | 90 | Raises 91 | ------ 92 | Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is 93 | ``None``. 94 | """ 95 | 96 | keys = key.split(splitval) 97 | 98 | success = True 99 | try: 100 | visited = [] 101 | parent = None 102 | last_key = None 103 | for key in keys: 104 | if callable(list_or_dict): 105 | if not expand: 106 | raise KeyNotFoundError( 107 | ValueError( 108 | "Trying to get past callable node with expand=False." 
109 | ), 110 | keys=keys, 111 | visited=visited, 112 | ) 113 | list_or_dict = list_or_dict() 114 | parent[last_key] = list_or_dict 115 | 116 | last_key = key 117 | parent = list_or_dict 118 | 119 | try: 120 | if isinstance(list_or_dict, dict): 121 | list_or_dict = list_or_dict[key] 122 | else: 123 | list_or_dict = list_or_dict[int(key)] 124 | except (KeyError, IndexError, ValueError) as e: 125 | raise KeyNotFoundError(e, keys=keys, visited=visited) 126 | 127 | visited += [key] 128 | # final expansion of retrieved value 129 | if expand and callable(list_or_dict): 130 | list_or_dict = list_or_dict() 131 | parent[last_key] = list_or_dict 132 | except KeyNotFoundError as e: 133 | if default is None: 134 | raise e 135 | else: 136 | list_or_dict = default 137 | success = False 138 | 139 | if not pass_success: 140 | return list_or_dict 141 | else: 142 | return list_or_dict, success 143 | 144 | 145 | if __name__ == "__main__": 146 | config = {"keya": "a", 147 | "keyb": "b", 148 | "keyc": 149 | {"cc1": 1, 150 | "cc2": 2, 151 | } 152 | } 153 | from omegaconf import OmegaConf 154 | config = OmegaConf.create(config) 155 | print(config) 156 | retrieve(config, "keya") 157 | 158 | -------------------------------------------------------------------------------- /taming/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from taming.modules.losses.lpips import LPIPS 6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init, NLayerDiscriminator2 7 | 8 | 9 | class DummyLoss(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | 14 | def adopt_weight(weight, global_step, threshold=0, value=0.): 15 | if global_step < threshold: 16 | weight = value 17 | return weight 18 | 19 | 20 | def hinge_d_loss(logits_real, logits_fake): 21 | loss_real = torch.mean(F.relu(1. - logits_real)) 22 | loss_fake = torch.mean(F.relu(1. 
+ logits_fake)) 23 | d_loss = 0.5 * (loss_real + loss_fake) 24 | return d_loss 25 | 26 | 27 | def vanilla_d_loss(logits_real, logits_fake): 28 | d_loss = 0.5 * ( 29 | torch.mean(torch.nn.functional.softplus(-logits_real)) + 30 | torch.mean(torch.nn.functional.softplus(logits_fake))) 31 | return d_loss 32 | 33 | 34 | class VQLPIPSWithDiscriminator(nn.Module): 35 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 36 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 37 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 38 | disc_ndf=64, disc_loss="hinge"): 39 | super().__init__() 40 | assert disc_loss in ["hinge", "vanilla"] 41 | self.codebook_weight = codebook_weight 42 | self.pixel_weight = pixelloss_weight 43 | self.perceptual_loss = LPIPS().eval() 44 | self.perceptual_weight = perceptual_weight 45 | 46 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 47 | n_layers=disc_num_layers, 48 | use_actnorm=use_actnorm, 49 | ndf=disc_ndf 50 | ).apply(weights_init) 51 | self.discriminator_iter_start = disc_start 52 | if disc_loss == "hinge": 53 | self.disc_loss = hinge_d_loss 54 | elif disc_loss == "vanilla": 55 | self.disc_loss = vanilla_d_loss 56 | else: 57 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 58 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 59 | self.disc_factor = disc_factor 60 | self.discriminator_weight = disc_weight 61 | self.disc_conditional = disc_conditional 62 | 63 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 64 | if last_layer is not None: 65 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 66 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 67 | else: 68 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 69 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 70 | 71 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 72 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 73 | d_weight = d_weight * self.discriminator_weight 74 | return d_weight 75 | 76 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 77 | global_step, last_layer=None, cond=None, split="train"): 78 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 79 | if self.perceptual_weight > 0: 80 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 81 | rec_loss = rec_loss + self.perceptual_weight * p_loss 82 | else: 83 | p_loss = torch.tensor([0.0]) 84 | 85 | nll_loss = rec_loss 86 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 87 | nll_loss = torch.mean(nll_loss) 88 | 89 | # now the GAN part 90 | if optimizer_idx == 0: 91 | # generator update 92 | if cond is None: 93 | assert not self.disc_conditional 94 | logits_fake = self.discriminator(reconstructions.contiguous()) 95 | else: 96 | assert self.disc_conditional 97 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 98 | g_loss = -torch.mean(logits_fake) 99 | 100 | try: 101 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 102 | except RuntimeError: 103 | assert not self.training 104 | d_weight = torch.tensor(0.0) 105 | 106 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 107 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 108 | 109 | 
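            # Descriptive note (added): the generator objective assembled above combines the
            # reconstruction/perceptual NLL, the adversarial term scaled by the adaptive weight
            # d_weight and gated by disc_factor until discriminator_iter_start, and the weighted
            # codebook (commitment) loss.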
log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 110 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 111 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 112 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 113 | "{}/p_loss".format(split): p_loss.detach().mean(), 114 | "{}/d_weight".format(split): d_weight.detach(), 115 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 116 | "{}/g_loss".format(split): g_loss.detach().mean(), 117 | } 118 | return loss, log 119 | 120 | if optimizer_idx == 1: 121 | # second pass for discriminator update 122 | if cond is None: 123 | logits_real = self.discriminator(inputs.contiguous().detach()) 124 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 125 | else: 126 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 127 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 128 | 129 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 130 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 131 | 132 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 133 | "{}/logits_real".format(split): logits_real.detach().mean(), 134 | "{}/logits_fake".format(split): logits_fake.detach().mean() 135 | } 136 | return d_loss, log 137 | -------------------------------------------------------------------------------- /taming/modules/misc/coord.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class CoordStage(object): 4 | def __init__(self, n_embed, down_factor): 5 | self.n_embed = n_embed 6 | self.down_factor = down_factor 7 | 8 | def eval(self): 9 | return self 10 | 11 | def encode(self, c): 12 | """fake vqmodel interface""" 13 | assert 0.0 <= c.min() and c.max() <= 1.0 14 | b,ch,h,w = c.shape 15 | assert ch == 1 16 | 17 | c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor, 18 | mode="area") 19 | c = c.clamp(0.0, 1.0) 20 | c = self.n_embed*c 21 | c_quant = c.round() 22 | c_ind = c_quant.to(dtype=torch.long) 23 | 24 | info = None, None, c_ind 25 | return c_quant, None, info 26 | 27 | def decode(self, c): 28 | c = c/self.n_embed 29 | c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor, 30 | mode="nearest") 31 | return c 32 | -------------------------------------------------------------------------------- /taming/modules/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def count_params(model): 6 | total_params = sum(p.numel() for p in model.parameters()) 7 | return total_params 8 | 9 | 10 | class ActNorm(nn.Module): 11 | def __init__(self, num_features, logdet=False, affine=True, 12 | allow_reverse_init=False): 13 | assert affine 14 | super().__init__() 15 | self.logdet = logdet 16 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 17 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 18 | self.allow_reverse_init = allow_reverse_init 19 | 20 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) 21 | 22 | def initialize(self, input): 23 | with torch.no_grad(): 24 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 25 | mean = ( 26 | flatten.mean(1) 27 | .unsqueeze(1) 28 | .unsqueeze(2) 29 | .unsqueeze(3) 30 | .permute(1, 0, 2, 3) 31 | ) 32 | std = ( 33 | flatten.std(1) 34 | 
.unsqueeze(1) 35 | .unsqueeze(2) 36 | .unsqueeze(3) 37 | .permute(1, 0, 2, 3) 38 | ) 39 | 40 | self.loc.data.copy_(-mean) 41 | self.scale.data.copy_(1 / (std + 1e-6)) 42 | 43 | def forward(self, input, reverse=False): 44 | if reverse: 45 | return self.reverse(input) 46 | if len(input.shape) == 2: 47 | input = input[:,:,None,None] 48 | squeeze = True 49 | else: 50 | squeeze = False 51 | 52 | _, _, height, width = input.shape 53 | 54 | if self.training and self.initialized.item() == 0: 55 | self.initialize(input) 56 | self.initialized.fill_(1) 57 | 58 | h = self.scale * (input + self.loc) 59 | 60 | if squeeze: 61 | h = h.squeeze(-1).squeeze(-1) 62 | 63 | if self.logdet: 64 | log_abs = torch.log(torch.abs(self.scale)) 65 | logdet = height*width*torch.sum(log_abs) 66 | logdet = logdet * torch.ones(input.shape[0]).to(input) 67 | return h, logdet 68 | 69 | return h 70 | 71 | def reverse(self, output): 72 | if self.training and self.initialized.item() == 0: 73 | if not self.allow_reverse_init: 74 | raise RuntimeError( 75 | "Initializing ActNorm in reverse direction is " 76 | "disabled by default. Use allow_reverse_init=True to enable." 77 | ) 78 | else: 79 | self.initialize(output) 80 | self.initialized.fill_(1) 81 | 82 | if len(output.shape) == 2: 83 | output = output[:,:,None,None] 84 | squeeze = True 85 | else: 86 | squeeze = False 87 | 88 | h = output / self.scale - self.loc 89 | 90 | if squeeze: 91 | h = h.squeeze(-1).squeeze(-1) 92 | return h 93 | 94 | 95 | class AbstractEncoder(nn.Module): 96 | def __init__(self): 97 | super().__init__() 98 | 99 | def encode(self, *args, **kwargs): 100 | raise NotImplementedError 101 | 102 | 103 | class Labelator(AbstractEncoder): 104 | """Net2Net Interface for Class-Conditional Model""" 105 | def __init__(self, n_classes, quantize_interface=True): 106 | super().__init__() 107 | self.n_classes = n_classes 108 | self.quantize_interface = quantize_interface 109 | 110 | def encode(self, c): 111 | c = c[:,None] 112 | if self.quantize_interface: 113 | return c, None, [None, None, c.long()] 114 | return c 115 | 116 | 117 | class SOSProvider(AbstractEncoder): 118 | # for unconditional training 119 | def __init__(self, sos_token, quantize_interface=True): 120 | super().__init__() 121 | self.sos_token = sos_token 122 | self.quantize_interface = quantize_interface 123 | 124 | def encode(self, x): 125 | # get batch size from data and replicate sos_token 126 | c = torch.ones(x.shape[0], 1)*self.sos_token 127 | c = c.long().to(x.device) 128 | if self.quantize_interface: 129 | return c, None, [None, None, c] 130 | return c 131 | -------------------------------------------------------------------------------- /taming/util.py: -------------------------------------------------------------------------------- 1 | import os, hashlib 2 | import requests 3 | from tqdm import tqdm 4 | 5 | URL_MAP = { 6 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" 7 | } 8 | 9 | CKPT_MAP = { 10 | "vgg_lpips": "vgg.pth" 11 | } 12 | 13 | MD5_MAP = { 14 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" 15 | } 16 | 17 | 18 | def download(url, local_path, chunk_size=1024): 19 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 20 | with requests.get(url, stream=True) as r: 21 | total_size = int(r.headers.get("content-length", 0)) 22 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 23 | with open(local_path, "wb") as f: 24 | for data in r.iter_content(chunk_size=chunk_size): 25 | if data: 26 | f.write(data) 27 | pbar.update(chunk_size) 28 | 29 | 
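# Illustrative usage sketch (added annotation, not part of the original taming/util.py).
# `download` above streams a remote file to `local_path` in `chunk_size` chunks while
# updating a tqdm progress bar. A minimal example, assuming network access and write
# permission to the chosen folder; URL_MAP and CKPT_MAP are defined earlier in this
# module, and `os` is already imported at the top of the file.
def _example_fetch_vgg_lpips(root="checkpoints"):
    target = os.path.join(root, CKPT_MAP["vgg_lpips"])
    download(URL_MAP["vgg_lpips"], target)
    return target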
30 | def md5_hash(path): 31 | with open(path, "rb") as f: 32 | content = f.read() 33 | return hashlib.md5(content).hexdigest() 34 | 35 | 36 | def get_ckpt_path(name, root, check=False): 37 | assert name in URL_MAP 38 | path = os.path.join(root, CKPT_MAP[name]) 39 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 40 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 41 | download(URL_MAP[name], path) 42 | md5 = md5_hash(path) 43 | assert md5 == MD5_MAP[name], md5 44 | return path 45 | 46 | 47 | class KeyNotFoundError(Exception): 48 | def __init__(self, cause, keys=None, visited=None): 49 | self.cause = cause 50 | self.keys = keys 51 | self.visited = visited 52 | messages = list() 53 | if keys is not None: 54 | messages.append("Key not found: {}".format(keys)) 55 | if visited is not None: 56 | messages.append("Visited: {}".format(visited)) 57 | messages.append("Cause:\n{}".format(cause)) 58 | message = "\n".join(messages) 59 | super().__init__(message) 60 | 61 | 62 | def retrieve( 63 | list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False 64 | ): 65 | """Given a nested list or dict return the desired value at key expanding 66 | callable nodes if necessary and :attr:`expand` is ``True``. The expansion 67 | is done in-place. 68 | 69 | Parameters 70 | ---------- 71 | list_or_dict : list or dict 72 | Possibly nested list or dictionary. 73 | key : str 74 | key/to/value, path like string describing all keys necessary to 75 | consider to get to the desired value. List indices can also be 76 | passed here. 77 | splitval : str 78 | String that defines the delimiter between keys of the 79 | different depth levels in `key`. 80 | default : obj 81 | Value returned if :attr:`key` is not found. 82 | expand : bool 83 | Whether to expand callable nodes on the path or not. 84 | 85 | Returns 86 | ------- 87 | The desired value or if :attr:`default` is not ``None`` and the 88 | :attr:`key` is not found returns ``default``. 89 | 90 | Raises 91 | ------ 92 | Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is 93 | ``None``. 94 | """ 95 | 96 | keys = key.split(splitval) 97 | 98 | success = True 99 | try: 100 | visited = [] 101 | parent = None 102 | last_key = None 103 | for key in keys: 104 | if callable(list_or_dict): 105 | if not expand: 106 | raise KeyNotFoundError( 107 | ValueError( 108 | "Trying to get past callable node with expand=False." 
109 | ), 110 | keys=keys, 111 | visited=visited, 112 | ) 113 | list_or_dict = list_or_dict() 114 | parent[last_key] = list_or_dict 115 | 116 | last_key = key 117 | parent = list_or_dict 118 | 119 | try: 120 | if isinstance(list_or_dict, dict): 121 | list_or_dict = list_or_dict[key] 122 | else: 123 | list_or_dict = list_or_dict[int(key)] 124 | except (KeyError, IndexError, ValueError) as e: 125 | raise KeyNotFoundError(e, keys=keys, visited=visited) 126 | 127 | visited += [key] 128 | # final expansion of retrieved value 129 | if expand and callable(list_or_dict): 130 | list_or_dict = list_or_dict() 131 | parent[last_key] = list_or_dict 132 | except KeyNotFoundError as e: 133 | if default is None: 134 | raise e 135 | else: 136 | list_or_dict = default 137 | success = False 138 | 139 | if not pass_success: 140 | return list_or_dict 141 | else: 142 | return list_or_dict, success 143 | 144 | 145 | if __name__ == "__main__": 146 | config = {"keya": "a", 147 | "keyb": "b", 148 | "keyc": 149 | {"cc1": 1, 150 | "cc2": 2, 151 | } 152 | } 153 | from omegaconf import OmegaConf 154 | config = OmegaConf.create(config) 155 | print(config) 156 | retrieve(config, "keya") 157 | 158 | -------------------------------------------------------------------------------- /unet_plus/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /unet_plus/ema.py: -------------------------------------------------------------------------------- 1 | # Modified from https://raw.githubusercontent.com/fadel/pytorch_ema/master/torch_ema/ema.py 2 | 3 | from __future__ import division 4 | from __future__ import unicode_literals 5 | 6 | import torch 7 | 8 | 9 | # Partially based on: https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py 10 | class ExponentialMovingAverage: 11 | """ 12 | Maintains (exponential) moving average of a set of parameters. 13 | """ 14 | 15 | def __init__(self, parameters, decay, use_num_updates=True): 16 | """ 17 | Args: 18 | parameters: Iterable of `torch.nn.Parameter`; usually the result of 19 | `model.parameters()`. 20 | decay: The exponential decay. 21 | use_num_updates: Whether to use number of updates when computing 22 | averages. 23 | """ 24 | if decay < 0.0 or decay > 1.0: 25 | raise ValueError('Decay must be between 0 and 1') 26 | self.decay = decay 27 | self.num_updates = 0 if use_num_updates else None 28 | self.shadow_params = [p.clone().detach() 29 | for p in parameters if p.requires_grad] 30 | self.collected_params = [] 31 | 32 | def update(self, parameters): 33 | """ 34 | Update currently maintained parameters. 35 | 36 | Call this every time the parameters are updated, such as the result of 37 | the `optimizer.step()` call. 
38 | 39 | Args: 40 | parameters: Iterable of `torch.nn.Parameter`; usually the same set of 41 | parameters used to initialize this object. 42 | """ 43 | decay = self.decay 44 | if self.num_updates is not None: 45 | self.num_updates += 1 46 | decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates)) 47 | one_minus_decay = 1.0 - decay 48 | with torch.no_grad(): 49 | parameters = [p for p in parameters if p.requires_grad] 50 | for s_param, param in zip(self.shadow_params, parameters): 51 | s_param.sub_(one_minus_decay * (s_param - param)) 52 | 53 | def copy_to(self, parameters): 54 | """ 55 | Copy current parameters into given collection of parameters. 56 | 57 | Args: 58 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 59 | updated with the stored moving averages. 60 | """ 61 | parameters = [p for p in parameters if p.requires_grad] 62 | for s_param, param in zip(self.shadow_params, parameters): 63 | if param.requires_grad: 64 | param.data.copy_(s_param.data) 65 | 66 | def store(self, parameters): 67 | """ 68 | Save the current parameters for restoring later. 69 | 70 | Args: 71 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 72 | temporarily stored. 73 | """ 74 | self.collected_params = [param.clone() for param in parameters] 75 | 76 | def restore(self, parameters): 77 | """ 78 | Restore the parameters stored with the `store` method. 79 | Useful to validate the model with EMA parameters without affecting the 80 | original optimization process. Store the parameters before the 81 | `copy_to` method. After validation (or model saving), use this to 82 | restore the former parameters. 83 | 84 | Args: 85 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 86 | updated with the stored parameters. 
87 | """ 88 | for c_param, param in zip(self.collected_params, parameters): 89 | param.data.copy_(c_param.data) 90 | 91 | def state_dict(self): 92 | return dict(decay=self.decay, num_updates=self.num_updates, 93 | shadow_params=self.shadow_params) 94 | 95 | def load_state_dict(self, state_dict): 96 | self.decay = state_dict['decay'] 97 | self.num_updates = state_dict['num_updates'] 98 | self.shadow_params = state_dict['shadow_params'] -------------------------------------------------------------------------------- /unet_plus/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /unet_plus/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | fused = load( 12 | "fused", 13 | sources=[ 14 | os.path.join(module_path, "fused_bias_act.cpp"), 15 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class FusedLeakyReLUFunctionBackward(Function): 21 | @staticmethod 22 | def forward(ctx, grad_output, out, negative_slope, scale): 23 | ctx.save_for_backward(out) 24 | ctx.negative_slope = negative_slope 25 | ctx.scale = scale 26 | 27 | empty = grad_output.new_empty(0) 28 | 29 | grad_input = fused.fused_bias_act( 30 | grad_output, empty, out, 3, 1, negative_slope, scale 31 | ) 32 | 33 | dim = [0] 34 | 35 | if grad_input.ndim > 2: 36 | dim += list(range(2, grad_input.ndim)) 37 | 38 | grad_bias = grad_input.sum(dim).detach() 39 | 40 | return grad_input, grad_bias 41 | 42 | @staticmethod 43 | def backward(ctx, gradgrad_input, gradgrad_bias): 44 | out, = ctx.saved_tensors 45 | gradgrad_out = fused.fused_bias_act( 46 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 47 | ) 48 | 49 | return gradgrad_out, None, None, None 50 | 51 | 52 | class FusedLeakyReLUFunction(Function): 53 | @staticmethod 54 | def forward(ctx, input, bias, negative_slope, scale): 55 | empty = input.new_empty(0) 56 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 57 | ctx.save_for_backward(out) 58 | ctx.negative_slope = negative_slope 59 | ctx.scale = scale 60 | 61 | return out 62 | 63 | @staticmethod 64 | def backward(ctx, grad_output): 65 | out, = ctx.saved_tensors 66 | 67 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 68 | grad_output, out, ctx.negative_slope, ctx.scale 69 | ) 70 | 71 | return grad_input, grad_bias, None, None 72 | 73 | 74 | class FusedLeakyReLU(nn.Module): 75 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 76 | super().__init__() 77 | 78 | self.bias = nn.Parameter(torch.zeros(channel)) 79 | self.negative_slope = negative_slope 80 | self.scale = scale 81 | 82 | def forward(self, input): 83 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 84 | 85 | 86 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 87 | if input.device.type == "cpu": 88 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 89 | return ( 90 | F.leaky_relu( 91 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 92 | ) 93 | * scale 94 | ) 95 | 96 | else: 97 | return 
FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 98 | -------------------------------------------------------------------------------- /unet_plus/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /unet_plus/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 
1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /unet_plus/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } -------------------------------------------------------------------------------- /unet_plus/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | 8 | 9 | module_path = os.path.dirname(__file__) 10 | upfirdn2d_op = load( 11 | "upfirdn2d", 12 | sources=[ 13 | os.path.join(module_path, "upfirdn2d.cpp"), 14 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 15 | ], 16 | ) 17 | 18 | 19 | class UpFirDn2dBackward(Function): 20 | @staticmethod 21 | def forward( 22 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 23 | ): 24 | 25 | up_x, up_y = up 26 | down_x, down_y = down 27 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 28 | 29 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 30 | 31 | grad_input = upfirdn2d_op.upfirdn2d( 32 | grad_output, 33 | grad_kernel, 34 | down_x, 35 | down_y, 36 | up_x, 37 | up_y, 38 | g_pad_x0, 39 | g_pad_x1, 40 | g_pad_y0, 41 | g_pad_y1, 42 | ) 43 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 44 | 45 | ctx.save_for_backward(kernel) 46 | 47 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 48 | 49 | ctx.up_x = up_x 50 | ctx.up_y = up_y 51 | ctx.down_x = down_x 52 | ctx.down_y = down_y 53 | ctx.pad_x0 = pad_x0 54 | ctx.pad_x1 = pad_x1 55 | ctx.pad_y0 = pad_y0 56 | ctx.pad_y1 = pad_y1 57 | ctx.in_size = in_size 58 | ctx.out_size = out_size 59 | 60 | return grad_input 61 | 62 | @staticmethod 63 | def backward(ctx, gradgrad_input): 64 | kernel, = ctx.saved_tensors 65 | 66 | gradgrad_input = 
gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 67 | 68 | gradgrad_out = upfirdn2d_op.upfirdn2d( 69 | gradgrad_input, 70 | kernel, 71 | ctx.up_x, 72 | ctx.up_y, 73 | ctx.down_x, 74 | ctx.down_y, 75 | ctx.pad_x0, 76 | ctx.pad_x1, 77 | ctx.pad_y0, 78 | ctx.pad_y1, 79 | ) 80 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 81 | gradgrad_out = gradgrad_out.view( 82 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 83 | ) 84 | 85 | return gradgrad_out, None, None, None, None, None, None, None, None 86 | 87 | 88 | class UpFirDn2d(Function): 89 | @staticmethod 90 | def forward(ctx, input, kernel, up, down, pad): 91 | up_x, up_y = up 92 | down_x, down_y = down 93 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 94 | 95 | kernel_h, kernel_w = kernel.shape 96 | batch, channel, in_h, in_w = input.shape 97 | ctx.in_size = input.shape 98 | 99 | input = input.reshape(-1, in_h, in_w, 1) 100 | 101 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 102 | 103 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 104 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 105 | ctx.out_size = (out_h, out_w) 106 | 107 | ctx.up = (up_x, up_y) 108 | ctx.down = (down_x, down_y) 109 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 110 | 111 | g_pad_x0 = kernel_w - pad_x0 - 1 112 | g_pad_y0 = kernel_h - pad_y0 - 1 113 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 114 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 115 | 116 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 117 | 118 | out = upfirdn2d_op.upfirdn2d( 119 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 120 | ) 121 | # out = out.view(major, out_h, out_w, minor) 122 | out = out.view(-1, channel, out_h, out_w) 123 | 124 | return out 125 | 126 | @staticmethod 127 | def backward(ctx, grad_output): 128 | kernel, grad_kernel = ctx.saved_tensors 129 | 130 | grad_input = UpFirDn2dBackward.apply( 131 | grad_output, 132 | kernel, 133 | grad_kernel, 134 | ctx.up, 135 | ctx.down, 136 | ctx.pad, 137 | ctx.g_pad, 138 | ctx.in_size, 139 | ctx.out_size, 140 | ) 141 | 142 | return grad_input, None, None, None, None 143 | 144 | 145 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 146 | if input.device.type == "cpu": 147 | out = upfirdn2d_native( 148 | input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] 149 | ) 150 | 151 | else: 152 | out = UpFirDn2d.apply( 153 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 154 | ) 155 | 156 | return out 157 | 158 | 159 | def upfirdn2d_native( 160 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 161 | ): 162 | _, channel, in_h, in_w = input.shape 163 | input = input.reshape(-1, in_h, in_w, 1) 164 | 165 | _, in_h, in_w, minor = input.shape 166 | kernel_h, kernel_w = kernel.shape 167 | 168 | out = input.view(-1, in_h, 1, in_w, 1, minor) 169 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 170 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 171 | 172 | out = F.pad( 173 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 174 | ) 175 | out = out[ 176 | :, 177 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 178 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 179 | :, 180 | ] 181 | 182 | out = out.permute(0, 3, 1, 2) 183 | out = out.reshape( 184 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 185 | ) 186 | w = torch.flip(kernel, [0, 
1]).view(1, 1, kernel_h, kernel_w) 187 | out = F.conv2d(out, w) 188 | out = out.reshape( 189 | -1, 190 | minor, 191 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 192 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 193 | ) 194 | out = out.permute(0, 2, 3, 1) 195 | out = out[:, ::down_y, ::down_x, :] 196 | 197 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 198 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 199 | 200 | return out.view(-1, channel, out_h, out_w) 201 | -------------------------------------------------------------------------------- /unet_plus/unet_pp.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # pylint: skip-file 17 | """DDPM model. 18 | 19 | This code is the pytorch equivalent of: 20 | https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/models/unet.py 21 | """ 22 | import torch 23 | import torch.nn as nn 24 | import functools 25 | 26 | from unet_plus import utils, layers, normalization 27 | 28 | RefineBlock = layers.RefineBlock 29 | ResidualBlock = layers.ResidualBlock 30 | ResnetBlockDDPM = layers.ResnetBlockDDPM 31 | Upsample = layers.Upsample 32 | Downsample = layers.Downsample 33 | conv3x3 = layers.ddpm_conv3x3 34 | get_act = layers.get_act 35 | get_normalization = normalization.get_normalization 36 | default_initializer = layers.default_init 37 | 38 | 39 | # @utils.register_model(name='ddpm') 40 | class UnetPlus(nn.Module): 41 | def __init__(self, config): 42 | super().__init__() 43 | self.act = act = get_act(config) 44 | self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config))) 45 | 46 | self.nf = nf = config.nf 47 | ch_mult = config.ch_mult 48 | self.num_res_blocks = num_res_blocks = config.num_res_blocks 49 | self.attn_resolutions = attn_resolutions = config.attn_resolutions 50 | dropout = config.dropout 51 | resamp_with_conv = config.resamp_with_conv 52 | self.num_resolutions = num_resolutions = len(ch_mult) 53 | self.all_resolutions = all_resolutions = [config.image_size // (2 ** i) for i in range(num_resolutions)] 54 | 55 | AttnBlock = functools.partial(layers.AttnBlock) 56 | self.conditional = conditional = config.conditional 57 | ResnetBlock = functools.partial(ResnetBlockDDPM, act=act, temb_dim=4 * nf, dropout=dropout) 58 | if conditional: 59 | # Condition on noise levels. 
60 | modules = [nn.Linear(nf, nf * 4)] 61 | modules[0].weight.data = default_initializer()(modules[0].weight.data.shape) 62 | nn.init.zeros_(modules[0].bias) 63 | modules.append(nn.Linear(nf * 4, nf * 4)) 64 | modules[1].weight.data = default_initializer()(modules[1].weight.data.shape) 65 | nn.init.zeros_(modules[1].bias) 66 | 67 | # self.centered = config.data.centered 68 | channels = config.in_channels 69 | 70 | # Downsampling block 71 | modules.append(conv3x3(channels, nf)) 72 | hs_c = [nf] 73 | in_ch = nf 74 | for i_level in range(num_resolutions): 75 | # Residual blocks for this resolution 76 | for i_block in range(num_res_blocks): 77 | out_ch = nf * ch_mult[i_level] 78 | modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch)) 79 | in_ch = out_ch 80 | if all_resolutions[i_level] in attn_resolutions: 81 | modules.append(AttnBlock(channels=in_ch)) 82 | hs_c.append(in_ch) 83 | if i_level != num_resolutions - 1: 84 | modules.append(Downsample(channels=in_ch, with_conv=resamp_with_conv)) 85 | hs_c.append(in_ch) 86 | 87 | in_ch = hs_c[-1] 88 | modules.append(ResnetBlock(in_ch=in_ch)) 89 | modules.append(AttnBlock(channels=in_ch)) 90 | modules.append(ResnetBlock(in_ch=in_ch)) 91 | 92 | # Upsampling block 93 | for i_level in reversed(range(num_resolutions)): 94 | for i_block in range(num_res_blocks + 1): 95 | out_ch = nf * ch_mult[i_level] 96 | modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch)) 97 | in_ch = out_ch 98 | if all_resolutions[i_level] in attn_resolutions: 99 | modules.append(AttnBlock(channels=in_ch)) 100 | if i_level != 0: 101 | modules.append(Upsample(channels=in_ch, with_conv=resamp_with_conv)) 102 | 103 | assert not hs_c 104 | modules.append(nn.GroupNorm(num_channels=in_ch, num_groups=32, eps=1e-6)) 105 | modules.append(conv3x3(in_ch, channels, init_scale=0.)) 106 | self.all_modules = nn.ModuleList(modules) 107 | 108 | self.scale_by_sigma = config.scale_by_sigma 109 | 110 | def forward(self, x, times=None): 111 | modules = self.all_modules 112 | m_idx = 0 113 | if times is not None: 114 | # timestep/scale embedding 115 | timesteps = times 116 | temb = layers.get_timestep_embedding(timesteps, self.nf) 117 | temb = modules[m_idx](temb) 118 | m_idx += 1 119 | temb = modules[m_idx](self.act(temb)) 120 | m_idx += 1 121 | else: 122 | temb = times 123 | 124 | # if self.centered: 125 | # # Input is in [-1, 1] 126 | # h = x 127 | # else: 128 | # # Input is in [0, 1] 129 | # h = 2 * x - 1. 
130 | h = x 131 | 132 | # Downsampling block 133 | hs = [modules[m_idx](h)] 134 | m_idx += 1 135 | for i_level in range(self.num_resolutions): 136 | # Residual blocks for this resolution 137 | for i_block in range(self.num_res_blocks): 138 | h = modules[m_idx](hs[-1], temb) 139 | m_idx += 1 140 | if h.shape[-1] in self.attn_resolutions: 141 | h = modules[m_idx](h) 142 | m_idx += 1 143 | hs.append(h) 144 | if i_level != self.num_resolutions - 1: 145 | hs.append(modules[m_idx](hs[-1])) 146 | m_idx += 1 147 | 148 | h = hs[-1] 149 | h = modules[m_idx](h, temb) 150 | m_idx += 1 151 | h = modules[m_idx](h) 152 | m_idx += 1 153 | h = modules[m_idx](h, temb) 154 | m_idx += 1 155 | 156 | # Upsampling block 157 | for i_level in reversed(range(self.num_resolutions)): 158 | for i_block in range(self.num_res_blocks + 1): 159 | h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb) 160 | m_idx += 1 161 | if h.shape[-1] in self.attn_resolutions: 162 | h = modules[m_idx](h) 163 | m_idx += 1 164 | if i_level != 0: 165 | h = modules[m_idx](h) 166 | m_idx += 1 167 | 168 | assert not hs 169 | h = self.act(modules[m_idx](h)) 170 | m_idx += 1 171 | h = modules[m_idx](h) 172 | m_idx += 1 173 | assert m_idx == len(modules) 174 | 175 | if self.scale_by_sigma: 176 | # Divide the output by sigmas. Useful for training with the NCSN loss. 177 | # The DDPM loss scales the network output by sigma in the loss function, 178 | # so no need of doing it here. 179 | used_sigmas = self.sigmas[times, None, None, None] 180 | h = h / used_sigmas 181 | 182 | return h 183 | 184 | if __name__ == '__main__': 185 | 186 | pass -------------------------------------------------------------------------------- /unet_plus/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """All functions and modules related to model definition. 17 | """ 18 | 19 | import torch 20 | # import sde_lib 21 | import numpy as np 22 | 23 | 24 | _MODELS = {} 25 | 26 | 27 | def register_model(cls=None, *, name=None): 28 | """A decorator for registering model classes.""" 29 | 30 | def _register(cls): 31 | if name is None: 32 | local_name = cls.__name__ 33 | else: 34 | local_name = name 35 | if local_name in _MODELS: 36 | raise ValueError(f'Already registered model with name: {local_name}') 37 | _MODELS[local_name] = cls 38 | return cls 39 | 40 | if cls is None: 41 | return _register 42 | else: 43 | return _register(cls) 44 | 45 | 46 | def get_model(name): 47 | return _MODELS[name] 48 | 49 | 50 | def get_sigmas(config): 51 | """Get sigmas --- the set of noise levels for SMLD from config files. 
52 | Args: 53 | config: A ConfigDict object parsed from the config file 54 | Returns: 55 | sigmas: a jax numpy arrary of noise levels 56 | """ 57 | sigmas = np.exp( 58 | np.linspace(np.log(config.sigma_max), np.log(config.sigma_min), config.num_scales)) 59 | 60 | return sigmas 61 | 62 | 63 | def get_ddpm_params(config): 64 | """Get betas and alphas --- parameters used in the original DDPM paper.""" 65 | num_diffusion_timesteps = 1000 66 | # parameters need to be adapted if number of time steps differs from 1000 67 | beta_start = config.beta_min / config.num_scales 68 | beta_end = config.beta_max / config.num_scales 69 | betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) 70 | 71 | alphas = 1. - betas 72 | alphas_cumprod = np.cumprod(alphas, axis=0) 73 | sqrt_alphas_cumprod = np.sqrt(alphas_cumprod) 74 | sqrt_1m_alphas_cumprod = np.sqrt(1. - alphas_cumprod) 75 | 76 | return { 77 | 'betas': betas, 78 | 'alphas': alphas, 79 | 'alphas_cumprod': alphas_cumprod, 80 | 'sqrt_alphas_cumprod': sqrt_alphas_cumprod, 81 | 'sqrt_1m_alphas_cumprod': sqrt_1m_alphas_cumprod, 82 | 'beta_min': beta_start * (num_diffusion_timesteps - 1), 83 | 'beta_max': beta_end * (num_diffusion_timesteps - 1), 84 | 'num_diffusion_timesteps': num_diffusion_timesteps 85 | } 86 | 87 | 88 | def create_model(config): 89 | """Create the score model.""" 90 | model_name = config.name 91 | score_model = get_model(model_name)(config) 92 | score_model = score_model.to(config.device) 93 | score_model = torch.nn.DataParallel(score_model) 94 | return score_model 95 | 96 | 97 | def get_model_fn(model, train=False): 98 | """Create a function to give the output of the score-based model. 99 | 100 | Args: 101 | model: The score model. 102 | train: `True` for training and `False` for evaluation. 103 | 104 | Returns: 105 | A model function. 106 | """ 107 | 108 | def model_fn(x, labels): 109 | """Compute the output of the score-based model. 110 | 111 | Args: 112 | x: A mini-batch of input data. 113 | labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently 114 | for different models. 115 | 116 | Returns: 117 | A tuple of (model output, new mutable states) 118 | """ 119 | if not train: 120 | model.eval() 121 | return model(x, labels) 122 | else: 123 | model.train() 124 | return model(x, labels) 125 | 126 | return model_fn 127 | 128 | ''' 129 | def get_score_fn(sde, model, train=False, continuous=False): 130 | """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function. 131 | 132 | Args: 133 | sde: An `sde_lib.SDE` object that represents the forward SDE. 134 | model: A score model. 135 | train: `True` for training and `False` for evaluation. 136 | continuous: If `True`, the score-based model is expected to directly take continuous time steps. 137 | 138 | Returns: 139 | A score function. 140 | """ 141 | model_fn = get_model_fn(model, train=train) 142 | 143 | if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE): 144 | def score_fn(x, t): 145 | # Scale neural network output by standard deviation and flip sign 146 | if continuous or isinstance(sde, sde_lib.subVPSDE): 147 | # For VP-trained models, t=0 corresponds to the lowest noise level 148 | # The maximum value of time embedding is assumed to 999 for 149 | # continuously-trained models. 
150 | labels = t * 999 151 | score = model_fn(x, labels) 152 | std = sde.marginal_prob(torch.zeros_like(x), t)[1] 153 | else: 154 | # For VP-trained models, t=0 corresponds to the lowest noise level 155 | labels = t * (sde.N - 1) 156 | score = model_fn(x, labels) 157 | std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()] 158 | 159 | score = -score / std[:, None, None, None] 160 | return score 161 | 162 | elif isinstance(sde, sde_lib.VESDE): 163 | def score_fn(x, t): 164 | if continuous: 165 | labels = sde.marginal_prob(torch.zeros_like(x), t)[1] 166 | else: 167 | # For VE-trained models, t=0 corresponds to the highest noise level 168 | labels = sde.T - t 169 | labels *= sde.N - 1 170 | labels = torch.round(labels).long() 171 | 172 | score = model_fn(x, labels) 173 | return score 174 | 175 | else: 176 | raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.") 177 | 178 | return score_fn 179 | ''' 180 | 181 | def to_flattened_numpy(x): 182 | """Flatten a torch tensor `x` and convert it to numpy.""" 183 | return x.detach().cpu().numpy().reshape((-1,)) 184 | 185 | 186 | def from_flattened_numpy(x, shape): 187 | """Form a torch tensor with the given `shape` from a flattened numpy array `x`.""" 188 | return torch.from_numpy(x.reshape(shape)) --------------------------------------------------------------------------------
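The helpers in `unet_plus/utils.py` only read plain attributes from a parsed config object, so they can be exercised on their own. The snippet below is an illustrative sketch rather than code from this repository: it builds a stand-in config with `types.SimpleNamespace` (the numeric values are placeholders; real runs take them from the YAML files under `configs/`) and checks the outputs of `get_sigmas` and `get_ddpm_params`.
~~~
# Illustrative sketch only: a stand-in config for the unet_plus/utils.py helpers.
from types import SimpleNamespace

import numpy as np

from unet_plus.utils import get_sigmas, get_ddpm_params

cfg = SimpleNamespace(sigma_max=50.0, sigma_min=0.01, num_scales=1000,
                      beta_min=0.1, beta_max=20.0)

sigmas = get_sigmas(cfg)      # geometric ladder from sigma_max down to sigma_min
ddpm = get_ddpm_params(cfg)   # betas/alphas tables for the DDPM parameterization

assert sigmas.shape == (cfg.num_scales,)
assert np.isclose(sigmas[0], cfg.sigma_max) and np.isclose(sigmas[-1], cfg.sigma_min)
assert ddpm["alphas_cumprod"].shape == (ddpm["num_diffusion_timesteps"],)
~~~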