├── .gitignore ├── LICENSE ├── README.md ├── app.py ├── configs ├── inference │ ├── controlnet_c_3b_canny.yaml │ ├── controlnet_c_3b_identity.yaml │ ├── controlnet_c_3b_inpainting.yaml │ ├── controlnet_c_3b_sr.yaml │ ├── lora_c_3b.yaml │ ├── stage_b_1b.yaml │ ├── stage_b_3b.yaml │ ├── stage_c_1b.yaml │ └── stage_c_3b.yaml └── training │ ├── cfg_control_lr.yaml │ ├── lora_personalization.yaml │ └── t2i.yaml ├── core ├── __init__.py ├── data │ ├── __init__.py │ ├── bucketeer.py │ ├── bucketeer_deg.py │ └── deg_kair_utils │ │ ├── test.bmp │ │ ├── test.png │ │ ├── utils_alignfaces.py │ │ ├── utils_blindsr.py │ │ ├── utils_bnorm.py │ │ ├── utils_deblur.py │ │ ├── utils_dist.py │ │ ├── utils_googledownload.py │ │ ├── utils_image.py │ │ ├── utils_lmdb.py │ │ ├── utils_logger.py │ │ ├── utils_mat.py │ │ ├── utils_matconvnet.py │ │ ├── utils_model.py │ │ ├── utils_modelsummary.py │ │ ├── utils_option.py │ │ ├── utils_params.py │ │ ├── utils_receptivefield.py │ │ ├── utils_regularizers.py │ │ ├── utils_sisr.py │ │ ├── utils_video.py │ │ └── utils_videoio.py ├── scripts │ ├── __init__.py │ └── cli.py ├── templates │ ├── __init__.py │ └── diffusion.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-39.pyc │ ├── base_dto.cpython-310.pyc │ ├── base_dto.cpython-39.pyc │ ├── save_and_load.cpython-310.pyc │ └── save_and_load.cpython-39.pyc │ ├── base_dto.py │ └── save_and_load.py ├── figures ├── California_000490.jpg ├── example_dataset │ ├── 000008.jpg │ ├── 000008.json │ ├── 000012.jpg │ └── 000012.json ├── example_lora_cat │ ├── 1_B0004902.jpg │ ├── 2_B0005089.jpg │ ├── 3_B0005163.jpg │ ├── 4_cat.jpg │ ├── 5_cat.jpg │ ├── 6_cat.jpg │ ├── 7_cat.jpg │ ├── 8_20240611230541.jpg │ ├── 9_20240611230549.jpg │ └── README.md └── teaser.jpg ├── gdf ├── __init__.py ├── loss_weights.py ├── noise_conditions.py ├── readme.md ├── samplers.py ├── scalers.py ├── schedulers.py └── targets.py ├── inference ├── __init__.py ├── test_controlnet.py ├── test_personalized.py ├── test_t2i.py └── utils.py ├── models └── models_checklist.txt ├── modules ├── __init__.py ├── cnet_modules │ ├── face_id │ │ ├── __pycache__ │ │ │ └── arcface.cpython-310.pyc │ │ └── arcface.py │ ├── inpainting │ │ ├── __pycache__ │ │ │ ├── saliency_model.cpython-310.pyc │ │ │ └── saliency_model.cpython-39.pyc │ │ ├── saliency_model.pt │ │ └── saliency_model.py │ └── pidinet │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── model.cpython-310.pyc │ │ ├── model.cpython-39.pyc │ │ ├── util.cpython-310.pyc │ │ └── util.cpython-39.pyc │ │ ├── ckpts │ │ └── table5_pidinet.pth │ │ ├── model.py │ │ └── util.py ├── common.py ├── common_ckpt.py ├── controlnet.py ├── effnet.py ├── inr_fea_res_lite.py ├── lora.py ├── model_4stage_lite.py ├── previewer.py ├── resnet.py ├── speed_util.py ├── stage_a.py ├── stage_b.py └── stage_c.py ├── prompt_list.txt ├── requirements.txt └── train ├── __init__.py ├── base.py ├── dist_core.py ├── train_b.py ├── train_c.py ├── train_c_lora.py ├── train_personalized.py ├── train_t2i.py └── train_ultrapixel_control.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.yml 2 | *.out 3 | dist_file_* 4 | __pycache__/* 5 | */__pycache__/* 6 | */**/__pycache__/* 7 | *_latest_output.jpg 8 | *_sample.jpg 9 | jobs/*.sh 10 | .ipynb_checkpoints 11 | *.safetensors 12 | *_test.yaml 13 | .DS_Store -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # UltraPixel: Advancing Ultra-High-Resolution Image Synthesis to New Peaks (NeurIPS 2024) 2 | 3 | [![arXiv](https://img.shields.io/badge/arXiv-paper-red)](https://arxiv.org/abs/2407.02158) 4 | [![Full Paper](https://img.shields.io/badge/Full_Paper-PDF-blue)](https://drive.google.com/file/d/1X18HH9kj7ltAnZorrkD84RJEdsJu4gDF/view?usp=sharing) 5 | [![Project Homepage](https://img.shields.io/badge/Project-Homepage-brightgreen)](https://jingjingrenabc.github.io/ultrapixel/) 6 | [![Hugging Face Demo](https://img.shields.io/badge/Hugging_Face-Demo-yellow)](https://huggingface.co/spaces/roubaofeipi/UltraPixel-demo) 7 | 8 | UltraPixel is designed to create exceptionally high-quality, detail-rich images at various resolutions, pushing the boundaries of ultra-high-resolution image synthesis. For more details and to see more stunning images, please visit the [Project Page](https://jingjingrenabc.github.io/ultrapixel/). The [arXiv version](https://arxiv.org/abs/2407.02158) of the paper contains compressed images, while the [full paper](https://drive.google.com/file/d/1X18HH9kj7ltAnZorrkD84RJEdsJu4gDF/view?usp=sharing) features uncompressed, high-quality images. 9 | 10 | ## 🔥 **Updates:** 11 | - **`2024/09/26`**: 🎉 UltraPixel has been accepted to NeurIPS 2024! 12 | - **`2024/09/19`**: 🤗 We released the [HuggingFace Space](https://huggingface.co/spaces/roubaofeipi/UltraPixel-demo), thanks to the HF team and [Gradio](https://github.com/gradio-app/gradio)! A Gradio interface for text-to-image inference is also provided; please see the Inference section. 13 | - **`2024/09/19`**: We have updated the versions of PyTorch and Torchvision in our environment. On an RTX 4090 GPU, generating a 2560×5120 image (without stage_a_tiled) now takes approximately 60 seconds, compared to about three minutes in the previous setup. 14 | 15 | ![teaser](figures/teaser.jpg) 16 | 17 | ## Getting Started 18 | **1.** Install the dependencies by running: 19 | ``` 20 | pip install -r requirements.txt 21 | ``` 22 | **2.** Download the pre-trained models following the [StableCascade model downloading instructions](https://github.com/Stability-AI/StableCascade/tree/master/models). The small-big combination is used by default: the small model for Stage B and the big model for Stage C, in bfloat16 format. The big-big setting is also supported, but small-big is more efficient. 23 | 24 | **3.** Download the newly added UltraPixel parameters from [here](https://huggingface.co/roubaofeipi/UltraPixel). 25 | 26 | **Note**: All model download URLs are listed [here](./models/models_checklist.txt). The downloaded files should be placed in the [models](./models) directory. 27 | 28 | ## Inference 29 | ### Text-guided Image Generation 30 | We provide a Gradio interface for inference. Launch it with: 31 | ``` 32 | CUDA_VISIBLE_DEVICES=0 python app.py 33 | ``` 34 | Or generate an image by running: 35 | ``` 36 | CUDA_VISIBLE_DEVICES=0 python inference/test_t2i.py 37 | ``` 38 | **Tips**: To generate aesthetic images, use detailed prompts with specific descriptions. It's recommended to include elements such as the subject, background, colors, lighting, and mood, and enhance your prompts with high-quality modifiers like "high quality", "rich detail", "8k", "photo-realistic", "cinematic", and "perfection". For example, use "A breathtaking sunset over a serene mountain range, with vibrant orange and purple hues in the sky, high quality, rich detail, 8k, photo-realistic, cinematic lighting, perfection". 
Be concise but detailed, specific and clear, and experiment with different word combinations for the best results. 39 | 40 | Several example prompts are provided [here](./prompt_list.txt). 41 | 42 | It is recommended to add "--stage_a_tiled" to use tiled decoding in Stage A and save memory. 43 | 44 | The tables below show memory requirements and running times on different GPUs. For the A100 with 80 GB memory, tiled decoding is not necessary. 45 | 46 | **On 80G A100:** 47 | | Resolution | Stage C | Stage B | Stage A | 48 | |---------------------|----------|---------|--------| 49 | |2048*2048 |15.9G / 12s | 14.5G / 4s |**w/o tiled**: 11.2G / 1s | 50 | |4096*4096 |18.7G / 52s | 19.7G / 26s |**w/o tiled**: 45.3G / 2s, **tiled**: 9.3G / 128s| 51 | 52 | **On 32G V100** (only works using float32 on Stages C and B): 53 | | Resolution | Stage C | Stage B | Stage A | 54 | |---------------------|----------|---------|-------------------------------| 55 | |2048*2048 |16.7G / 83s | 11.7G / 22s |**w/o tiled**: 10.1G / 2s | 56 | |4096*4096 |18.0G / 287s | 22.7G / 172s |**w/o tiled**: OOM, **tiled**: 9.0G / 305s| 57 | 58 | **On 24G RTX4090:** 59 | | Resolution | Stage C | Stage B | Stage A | 60 | |---------------------|----------|---------|-------------------------------| 61 | |2048*2048 |15.5G / 83s | 13.2G / 22s |**w/o tiled**: 11.3G / 1s | 62 | |4096*4096 |19.9G / 153s | 23.4G / 44s |**w/o tiled**: OOM, **tiled**: 11.3G / 114s | 63 | 64 | ### Personalized Image Generation 65 | The repo provides a personalized model of a cat. Download the personalized model [here](https://huggingface.co/roubaofeipi/UltraPixel/blob/main/lora_cat.safetensors) and run the following command to generate personalized results. Note that in the text prompt you need to use the identifier "cat [roubaobao]" to refer to the cat. 66 | ``` 67 | CUDA_VISIBLE_DEVICES=0 python inference/test_personalized.py 68 | ``` 69 | ### ControlNet Image Generation 70 | Download the Canny [ControlNet](https://huggingface.co/stabilityai/stable-cascade/resolve/main/controlnet/canny.safetensors) provided by StableCascade and run the command: 71 | ``` 72 | CUDA_VISIBLE_DEVICES=0 python inference/test_controlnet.py 73 | ``` 74 | Note that the ControlNet is used without further fine-tuning, so the highest supported resolution is 4K, e.g., 3840 * 2160 or 2048 * 2048. 75 | 76 | 77 | ## T2I Training 78 | Put all your images and captions into a folder. An example training dataset is provided [here](./figures/example_dataset) for reference. 79 | Start training by running: 80 | ``` 81 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train/train_t2i.py configs/training/t2i.yaml 82 | ``` 83 | 84 | 85 | ## Personalized Training 86 | Put all your images into a folder. An example training dataset is provided [here](./figures/example_dataset). The training prompt can be described as: a photo of a cat [roubaobao].
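For reference, the bundled cat images used for the personalized LoRA sit in a plain, flat folder (listing taken from the repository tree above); your own dataset can follow the same layout:
```
figures/example_lora_cat/
├── 1_B0004902.jpg
├── 2_B0005089.jpg
├── 3_B0005163.jpg
├── 4_cat.jpg
├── 5_cat.jpg
├── 6_cat.jpg
├── 7_cat.jpg
├── 8_20240611230541.jpg
├── 9_20240611230549.jpg
└── README.md
```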
87 | 88 | Start training by running: 89 | ``` 90 | CUDA_VISIBLE_DEVICES=0,1 python train/train_personalized.py \ 91 | configs/training/lora_personalization.yaml 92 | ``` 93 | 94 | ## Citation 95 | ```bibtex 96 | @article{ren2024ultrapixel, 97 | title={UltraPixel: Advancing Ultra-High-Resolution Image Synthesis to New Peaks}, 98 | author={Ren, Jingjing and Li, Wenbo and Chen, Haoyu and Pei, Renjing and Shao, Bin and Guo, Yong and Peng, Long and Song, Fenglong and Zhu, Lei}, 99 | journal={arXiv preprint arXiv:2407.02158}, 100 | year={2024} 101 | } 102 | ``` 103 | ## Contact Information 104 | To reach out to the paper’s authors, please refer to the contact information provided on the [project page](https://jingjingrenabc.github.io/ultrapixel/). 105 | 106 | ## Acknowledgements 107 | This project is build upon [StableCascade](https://github.com/Stability-AI/StableCascade) and [Trans-inr](https://github.com/yinboc/trans-inr). Thanks for their code sharing :) 108 | -------------------------------------------------------------------------------- /configs/inference/controlnet_c_3b_canny.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | model_version: 3.6B 3 | dtype: bfloat16 4 | 5 | # ControlNet specific 6 | controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63] 7 | controlnet_filter: CannyFilter 8 | controlnet_filter_params: 9 | resize: 224 10 | 11 | effnet_checkpoint_path: models/effnet_encoder.safetensors 12 | previewer_checkpoint_path: models/previewer.safetensors 13 | generator_checkpoint_path: models/stage_c_bf16.safetensors 14 | controlnet_checkpoint_path: models/canny.safetensors 15 | -------------------------------------------------------------------------------- /configs/inference/controlnet_c_3b_identity.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | model_version: 3.6B 3 | dtype: bfloat16 4 | 5 | # ControlNet specific 6 | controlnet_bottleneck_mode: 'simple' 7 | controlnet_blocks: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] 8 | controlnet_filter: IdentityFilter 9 | controlnet_filter_params: 10 | max_faces: 4 11 | p_drop: 0.00 12 | p_full: 0.0 13 | 14 | effnet_checkpoint_path: models/effnet_encoder.safetensors 15 | previewer_checkpoint_path: models/previewer.safetensors 16 | generator_checkpoint_path: models/stage_c_bf16.safetensors 17 | controlnet_checkpoint_path: 18 | -------------------------------------------------------------------------------- /configs/inference/controlnet_c_3b_inpainting.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | model_version: 3.6B 3 | dtype: bfloat16 4 | 5 | # ControlNet specific 6 | controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63] 7 | controlnet_filter: InpaintFilter 8 | controlnet_filter_params: 9 | thresold: [0.04, 0.4] 10 | p_outpaint: 0.4 11 | 12 | effnet_checkpoint_path: models/effnet_encoder.safetensors 13 | previewer_checkpoint_path: models/previewer.safetensors 14 | generator_checkpoint_path: models/stage_c_bf16.safetensors 15 | controlnet_checkpoint_path: models/inpainting.safetensors 16 | -------------------------------------------------------------------------------- /configs/inference/controlnet_c_3b_sr.yaml: 
-------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | model_version: 3.6B 3 | dtype: bfloat16 4 | 5 | # ControlNet specific 6 | controlnet_bottleneck_mode: 'large' 7 | controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63] 8 | controlnet_filter: SREffnetFilter 9 | controlnet_filter_params: 10 | scale_factor: 0.5 11 | 12 | effnet_checkpoint_path: models/effnet_encoder.safetensors 13 | previewer_checkpoint_path: models/previewer.safetensors 14 | generator_checkpoint_path: models/stage_c_bf16.safetensors 15 | controlnet_checkpoint_path: models/super_resolution.safetensors 16 | -------------------------------------------------------------------------------- /configs/inference/lora_c_3b.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | model_version: 3.6B 3 | dtype: bfloat16 4 | 5 | # LoRA specific 6 | module_filters: ['.attn'] 7 | rank: 4 8 | train_tokens: 9 | # - ['^snail', null] # token starts with "snail" -> "snail" & "snails", don't need to be reinitialized 10 | - ['[fernando]', '^dog'] # custom token [snail], initialize as avg of snail & snails 11 | 12 | effnet_checkpoint_path: models/effnet_encoder.safetensors 13 | previewer_checkpoint_path: models/previewer.safetensors 14 | generator_checkpoint_path: models/stage_c_bf16.safetensors 15 | lora_checkpoint_path: models/lora_fernando_10k.safetensors 16 | -------------------------------------------------------------------------------- /configs/inference/stage_b_1b.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | model_version: 700M 3 | dtype: bfloat16 4 | 5 | # For demonstration purposes in reconstruct_images.ipynb 6 | webdataset_path: path to your dataset 7 | batch_size: 1 8 | image_size: 2048 9 | grad_accum_steps: 1 10 | 11 | effnet_checkpoint_path: models/effnet_encoder.safetensors 12 | stage_a_checkpoint_path: models/stage_a.safetensors 13 | generator_checkpoint_path: models/stage_b_lite_bf16.safetensors -------------------------------------------------------------------------------- /configs/inference/stage_b_3b.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | model_version: 3B 3 | dtype: bfloat16 4 | 5 | # For demonstration purposes in reconstruct_images.ipynb 6 | webdataset_path: path to your dataset 7 | batch_size: 4 8 | image_size: 1024 9 | grad_accum_steps: 1 10 | 11 | effnet_checkpoint_path: models/effnet_encoder.safetensors 12 | stage_a_checkpoint_path: models/stage_a.safetensors 13 | generator_checkpoint_path: models/stage_b_lite_bf16.safetensors -------------------------------------------------------------------------------- /configs/inference/stage_c_1b.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | model_version: 1B 3 | dtype: bfloat16 4 | 5 | effnet_checkpoint_path: models/effnet_encoder.safetensors 6 | previewer_checkpoint_path: models/previewer.safetensors 7 | generator_checkpoint_path: models/stage_c_lite_bf16.safetensors -------------------------------------------------------------------------------- /configs/inference/stage_c_3b.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | model_version: 3.6B 3 | dtype: bfloat16 4 | 5 | effnet_checkpoint_path: models/effnet_encoder.safetensors 6 | previewer_checkpoint_path: models/previewer.safetensors 7 | generator_checkpoint_path: 
models/stage_c_bf16.safetensors -------------------------------------------------------------------------------- /configs/training/cfg_control_lr.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | experiment_id: Ultrapixel_controlnet 3 | 4 | checkpoint_path: checkpoint output path 5 | output_path: visual results output path 6 | model_version: 3.6B 7 | dtype: float32 8 | # # WandB 9 | # wandb_project: StableCascade 10 | # wandb_entity: wandb_username 11 | #module_filters: ['.depthwise', '.mapper', '.attn', '.channelwise' ] 12 | #rank: 32 13 | # TRAINING PARAMS 14 | lr: 1.0e-4 15 | batch_size: 12 16 | #image_size: [1536, 2048, 2560, 3072, 4096] 17 | image_size: [1024, 2048, 2560, 3072, 3584, 3840, 4096, 4608] 18 | #image_size: [ 1024, 1536, 2048, 2560, 3072, 3584, 3840, 4096, 4608] 19 | #image_size: [ 1024, 1280] 20 | multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] 21 | grad_accum_steps: 2 22 | updates: 40000 23 | backup_every: 5000 24 | save_every: 256 25 | warmup_updates: 1 26 | use_fsdp: True 27 | 28 | # ControlNet specific 29 | controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63] 30 | controlnet_filter: CannyFilter 31 | controlnet_filter_params: 32 | resize: 224 33 | # offset_noise: 0.1 34 | 35 | # GDF 36 | adaptive_loss_weight: True 37 | 38 | ema_start_iters: 10 39 | ema_iters: 50 40 | ema_beta: 0.9 41 | 42 | webdataset_path: path to your training dataset 43 | effnet_checkpoint_path: models/effnet_encoder.safetensors 44 | previewer_checkpoint_path: models/previewer.safetensors 45 | generator_checkpoint_path: models/stage_c_bf16.safetensors 46 | controlnet_checkpoint_path: pretrained controlnet path 47 | 48 | -------------------------------------------------------------------------------- /configs/training/lora_personalization.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | experiment_id: roubao_cat_personalized 3 | 4 | checkpoint_path: checkpoint output path 5 | output_path: visual results output path 6 | model_version: 3.6B 7 | dtype: float32 8 | 9 | module_filters: [ '.attn'] 10 | rank: 4 11 | train_tokens: 12 | # - ['^snail', null] # token starts with "snail" -> "snail" & "snails", don't need to be reinitialized 13 | - ['[roubaobao]', '^cat'] # custom token [snail], initialize as avg of snail & snails 14 | # TRAINING PARAMS 15 | lr: 1.0e-4 16 | batch_size: 4 17 | 18 | image_size: [1024, 2048, 2560, 3072, 3584, 3840, 4096, 4608] 19 | multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] 20 | grad_accum_steps: 2 21 | updates: 40000 22 | backup_every: 5000 23 | save_every: 512 24 | warmup_updates: 1 25 | use_ddp: True 26 | 27 | # GDF 28 | adaptive_loss_weight: True 29 | 30 | 31 | tmp_prompt: a photo of a cat [roubaobao] 32 | webdataset_path: path to your personalized training dataset 33 | effnet_checkpoint_path: models/effnet_encoder.safetensors 34 | previewer_checkpoint_path: models/previewer.safetensors 35 | generator_checkpoint_path: models/stage_c_bf16.safetensors 36 | ultrapixel_path: models/ultrapixel_t2i.safetensors 37 | 38 | -------------------------------------------------------------------------------- /configs/training/t2i.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | experiment_id: ultrapixel_t2i 3 | #strc_fixlrt_norm3_lite_1024_hrft_newdata 4 | checkpoint_path: checkpoint output path #output model directory 5 | output_path: visual 
results output path #experiment output directory 6 | model_version: 3.6B # finetune large stage c model of stablecascade 7 | dtype: float32 8 | 9 | 10 | # TRAINING PARAMS 11 | lr: 1.0e-4 12 | batch_size: 4 # gpu_number * num_per_gpu * grad_accum_steps 13 | image_size: [1024, 2048, 2560, 3072, 3584, 3840, 4096, 4608] # possible image resolution 14 | multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] 15 | grad_accum_steps: 2 16 | updates: 40000 17 | backup_every: 5000 18 | save_every: 256 19 | warmup_updates: 1 20 | use_ddp: True 21 | 22 | # GDF 23 | adaptive_loss_weight: True 24 | 25 | 26 | webdataset_path: path to your personalized training dataset 27 | effnet_checkpoint_path: models/effnet_encoder.safetensors 28 | previewer_checkpoint_path: models/previewer.safetensors 29 | generator_checkpoint_path: models/stage_c_bf16.safetensors -------------------------------------------------------------------------------- /core/data/__init__.py: -------------------------------------------------------------------------------- 1 | import json 2 | import subprocess 3 | import yaml 4 | import os 5 | from .bucketeer import Bucketeer 6 | 7 | class MultiFilter(): 8 | def __init__(self, rules, default=False): 9 | self.rules = rules 10 | self.default = default 11 | 12 | def __call__(self, x): 13 | try: 14 | x_json = x['json'] 15 | if isinstance(x_json, bytes): 16 | x_json = json.loads(x_json) 17 | validations = [] 18 | for k, r in self.rules.items(): 19 | if isinstance(k, tuple): 20 | v = r(*[x_json[kv] for kv in k]) 21 | else: 22 | v = r(x_json[k]) 23 | validations.append(v) 24 | return all(validations) 25 | except Exception: 26 | return False 27 | 28 | class MultiGetter(): 29 | def __init__(self, rules): 30 | self.rules = rules 31 | 32 | def __call__(self, x_json): 33 | if isinstance(x_json, bytes): 34 | x_json = json.loads(x_json) 35 | outputs = [] 36 | for k, r in self.rules.items(): 37 | if isinstance(k, tuple): 38 | v = r(*[x_json[kv] for kv in k]) 39 | else: 40 | v = r(x_json[k]) 41 | outputs.append(v) 42 | if len(outputs) == 1: 43 | outputs = outputs[0] 44 | return outputs 45 | 46 | def setup_webdataset_path(paths, cache_path=None): 47 | if cache_path is None or not os.path.exists(cache_path): 48 | tar_paths = [] 49 | if isinstance(paths, str): 50 | paths = [paths] 51 | for path in paths: 52 | if path.strip().endswith(".tar"): 53 | # Avoid looking up s3 if we already have a tar file 54 | tar_paths.append(path) 55 | continue 56 | bucket = "/".join(path.split("/")[:3]) 57 | result = subprocess.run([f"aws s3 ls {path} --recursive | awk '{{print $4}}'"], stdout=subprocess.PIPE, shell=True, check=True) 58 | files = result.stdout.decode('utf-8').split() 59 | files = [f"{bucket}/{f}" for f in files if f.endswith(".tar")] 60 | tar_paths += files 61 | 62 | with open(cache_path, 'w', encoding='utf-8') as outfile: 63 | yaml.dump(tar_paths, outfile, default_flow_style=False) 64 | else: 65 | with open(cache_path, 'r', encoding='utf-8') as file: 66 | tar_paths = yaml.safe_load(file) 67 | 68 | tar_paths_str = ",".join([f"{p}" for p in tar_paths]) 69 | return f"pipe:aws s3 cp {{ {tar_paths_str} }} -" 70 | -------------------------------------------------------------------------------- /core/data/bucketeer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import numpy as np 4 | from torchtools.transforms import SmartCrop 5 | import math 6 | 7 | class Bucketeer(): 8 | def __init__(self, dataloader, 
density=256*256, factor=8, ratios=[1/1, 1/2, 3/4, 3/5, 4/5, 6/9, 9/16], reverse_list=True, randomize_p=0.3, randomize_q=0.2, crop_mode='random', p_random_ratio=0.0, interpolate_nearest=False): 9 | assert crop_mode in ['center', 'random', 'smart'] 10 | self.crop_mode = crop_mode 11 | self.ratios = ratios 12 | if reverse_list: 13 | for r in list(ratios): 14 | if 1/r not in self.ratios: 15 | self.ratios.append(1/r) 16 | self.sizes = {} 17 | for dd in density: 18 | self.sizes[dd]= [(int(((dd/r)**0.5//factor)*factor), int(((dd*r)**0.5//factor)*factor)) for r in ratios] 19 | 20 | self.batch_size = dataloader.batch_size 21 | self.iterator = iter(dataloader) 22 | all_sizes = [] 23 | for k, vs in self.sizes.items(): 24 | all_sizes += vs 25 | self.buckets = {s: [] for s in all_sizes} 26 | self.smartcrop = SmartCrop(int(density**0.5), randomize_p, randomize_q) if self.crop_mode=='smart' else None 27 | self.p_random_ratio = p_random_ratio 28 | self.interpolate_nearest = interpolate_nearest 29 | 30 | def get_available_batch(self): 31 | for b in self.buckets: 32 | if len(self.buckets[b]) >= self.batch_size: 33 | batch = self.buckets[b][:self.batch_size] 34 | self.buckets[b] = self.buckets[b][self.batch_size:] 35 | return batch 36 | return None 37 | 38 | def get_closest_size(self, x): 39 | w, h = x.size(-1), x.size(-2) 40 | 41 | 42 | best_size_idx = np.argmin([abs(w/h-r) for r in self.ratios]) 43 | find_dict = {dd : abs(w*h - self.sizes[dd][best_size_idx][0]*self.sizes[dd][best_size_idx][1]) for dd, vv in self.sizes.items()} 44 | min_ = find_dict[list(find_dict.keys())[0]] 45 | find_size = self.sizes[list(find_dict.keys())[0]][best_size_idx] 46 | for dd, val in find_dict.items(): 47 | if val < min_: 48 | min_ = val 49 | find_size = self.sizes[dd][best_size_idx] 50 | 51 | return find_size 52 | 53 | def get_resize_size(self, orig_size, tgt_size): 54 | if (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0: 55 | alt_min = int(math.ceil(max(tgt_size)*min(orig_size)/max(orig_size))) 56 | resize_size = max(alt_min, min(tgt_size)) 57 | else: 58 | alt_max = int(math.ceil(min(tgt_size)*max(orig_size)/min(orig_size))) 59 | resize_size = max(alt_max, max(tgt_size)) 60 | 61 | return resize_size 62 | 63 | def __next__(self): 64 | batch = self.get_available_batch() 65 | while batch is None: 66 | elements = next(self.iterator) 67 | for dct in elements: 68 | img = dct['images'] 69 | size = self.get_closest_size(img) 70 | resize_size = self.get_resize_size(img.shape[-2:], size) 71 | 72 | if self.interpolate_nearest: 73 | img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST) 74 | else: 75 | img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True) 76 | if self.crop_mode == 'center': 77 | img = torchvision.transforms.functional.center_crop(img, size) 78 | elif self.crop_mode == 'random': 79 | img = torchvision.transforms.RandomCrop(size)(img) 80 | elif self.crop_mode == 'smart': 81 | self.smartcrop.output_size = size 82 | img = self.smartcrop(img) 83 | 84 | self.buckets[size].append({**{'images': img}, **{k:dct[k] for k in dct if k != 'images'}}) 85 | batch = self.get_available_batch() 86 | 87 | out = {k:[batch[i][k] for i in range(len(batch))] for k in batch[0]} 88 | return {k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()} 89 | 
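# ---------------------------------------------------------------------------
# Minimal usage sketch (hypothetical; the actual training scripts build their
# own dataloaders). Bucketeer expects a DataLoader whose collate_fn returns
# the raw list of sample dicts; it groups samples into resolution buckets and
# yields stacked batches via next(). ToyImageDataset and the parameter values
# below are illustrative assumptions only.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from torch.utils.data import DataLoader, Dataset

    class ToyImageDataset(Dataset):
        def __len__(self):
            return 64

        def __getitem__(self, idx):
            # each sample is a dict holding a CHW image tensor plus metadata
            return {'images': torch.rand(3, 512, 512), 'captions': f'sample {idx}'}

    loader = DataLoader(ToyImageDataset(), batch_size=4,
                        collate_fn=lambda samples: samples)  # keep the list of dicts
    bucketeer = Bucketeer(loader, density=[512*512, 1024*1024], factor=8,
                          ratios=[1/1, 1/2, 3/4], crop_mode='center')
    batch = next(bucketeer)  # {'images': (4, 3, H, W) tensor, 'captions': [...]}
    print(batch['images'].shape)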
-------------------------------------------------------------------------------- /core/data/bucketeer_deg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import numpy as np 4 | from torchtools.transforms import SmartCrop 5 | import math 6 | 7 | class Bucketeer(): 8 | def __init__(self, dataloader, density=256*256, factor=8, ratios=[1/1, 1/2, 3/4, 3/5, 4/5, 6/9, 9/16], reverse_list=True, randomize_p=0.3, randomize_q=0.2, crop_mode='random', p_random_ratio=0.0, interpolate_nearest=False): 9 | assert crop_mode in ['center', 'random', 'smart'] 10 | self.crop_mode = crop_mode 11 | self.ratios = ratios 12 | if reverse_list: 13 | for r in list(ratios): 14 | if 1/r not in self.ratios: 15 | self.ratios.append(1/r) 16 | self.sizes = {} 17 | for dd in density: 18 | self.sizes[dd]= [(int(((dd/r)**0.5//factor)*factor), int(((dd*r)**0.5//factor)*factor)) for r in ratios] 19 | print('in line 17 buckteer', self.sizes) 20 | self.batch_size = dataloader.batch_size 21 | self.iterator = iter(dataloader) 22 | all_sizes = [] 23 | for k, vs in self.sizes.items(): 24 | all_sizes += vs 25 | self.buckets = {s: [] for s in all_sizes} 26 | self.smartcrop = SmartCrop(int(density**0.5), randomize_p, randomize_q) if self.crop_mode=='smart' else None 27 | self.p_random_ratio = p_random_ratio 28 | self.interpolate_nearest = interpolate_nearest 29 | 30 | def get_available_batch(self): 31 | for b in self.buckets: 32 | if len(self.buckets[b]) >= self.batch_size: 33 | batch = self.buckets[b][:self.batch_size] 34 | self.buckets[b] = self.buckets[b][self.batch_size:] 35 | return batch 36 | return None 37 | 38 | def get_closest_size(self, x): 39 | w, h = x.size(-1), x.size(-2) 40 | #if self.p_random_ratio > 0 and np.random.rand() < self.p_random_ratio: 41 | # best_size_idx = np.random.randint(len(self.ratios)) 42 | #print('in line 41 get closes size', best_size_idx, x.shape, self.p_random_ratio) 43 | #else: 44 | 45 | best_size_idx = np.argmin([abs(w/h-r) for r in self.ratios]) 46 | find_dict = {dd : abs(w*h - self.sizes[dd][best_size_idx][0]*self.sizes[dd][best_size_idx][1]) for dd, vv in self.sizes.items()} 47 | min_ = find_dict[list(find_dict.keys())[0]] 48 | find_size = self.sizes[list(find_dict.keys())[0]][best_size_idx] 49 | for dd, val in find_dict.items(): 50 | if val < min_: 51 | min_ = val 52 | find_size = self.sizes[dd][best_size_idx] 53 | 54 | return find_size 55 | 56 | def get_resize_size(self, orig_size, tgt_size): 57 | if (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0: 58 | alt_min = int(math.ceil(max(tgt_size)*min(orig_size)/max(orig_size))) 59 | resize_size = max(alt_min, min(tgt_size)) 60 | else: 61 | alt_max = int(math.ceil(min(tgt_size)*max(orig_size)/min(orig_size))) 62 | resize_size = max(alt_max, max(tgt_size)) 63 | #print('in line 50', orig_size, tgt_size, resize_size) 64 | return resize_size 65 | 66 | def __next__(self): 67 | batch = self.get_available_batch() 68 | while batch is None: 69 | elements = next(self.iterator) 70 | for dct in elements: 71 | img = dct['images'] 72 | size = self.get_closest_size(img) 73 | resize_size = self.get_resize_size(img.shape[-2:], size) 74 | #print('in line 74', img.size(), resize_size) 75 | if self.interpolate_nearest: 76 | img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST) 77 | else: 78 | img = torchvision.transforms.functional.resize(img, resize_size, 
interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True) 79 | if self.crop_mode == 'center': 80 | img = torchvision.transforms.functional.center_crop(img, size) 81 | elif self.crop_mode == 'random': 82 | img = torchvision.transforms.RandomCrop(size)(img) 83 | elif self.crop_mode == 'smart': 84 | self.smartcrop.output_size = size 85 | img = self.smartcrop(img) 86 | print('in line 86 bucketeer', type(img), img.shape, torch.max(img), torch.min(img)) 87 | self.buckets[size].append({**{'images': img}, **{k:dct[k] for k in dct if k != 'images'}}) 88 | batch = self.get_available_batch() 89 | 90 | out = {k:[batch[i][k] for i in range(len(batch))] for k in batch[0]} 91 | return {k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()} 92 | -------------------------------------------------------------------------------- /core/data/deg_kair_utils/test.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catcathh/UltraPixel/e85d3bc9b0eaedbe1b5aea8a5c2624caab9df1f3/core/data/deg_kair_utils/test.bmp -------------------------------------------------------------------------------- /core/data/deg_kair_utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catcathh/UltraPixel/e85d3bc9b0eaedbe1b5aea8a5c2624caab9df1f3/core/data/deg_kair_utils/test.png -------------------------------------------------------------------------------- /core/data/deg_kair_utils/utils_alignfaces.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Apr 24 15:43:29 2017 4 | @author: zhaoy 5 | """ 6 | import cv2 7 | import numpy as np 8 | from skimage import transform as trans 9 | 10 | # reference facial points, a list of coordinates (x,y) 11 | REFERENCE_FACIAL_POINTS = [ 12 | [30.29459953, 51.69630051], 13 | [65.53179932, 51.50139999], 14 | [48.02519989, 71.73660278], 15 | [33.54930115, 92.3655014], 16 | [62.72990036, 92.20410156] 17 | ] 18 | 19 | DEFAULT_CROP_SIZE = (96, 112) 20 | 21 | 22 | def _umeyama(src, dst, estimate_scale=True, scale=1.0): 23 | """Estimate N-D similarity transformation with or without scaling. 24 | Parameters 25 | ---------- 26 | src : (M, N) array 27 | Source coordinates. 28 | dst : (M, N) array 29 | Destination coordinates. 30 | estimate_scale : bool 31 | Whether to estimate scaling factor. 32 | Returns 33 | ------- 34 | T : (N + 1, N + 1) 35 | The homogeneous similarity transformation matrix. The matrix contains 36 | NaN values only if the problem is not well-conditioned. 37 | References 38 | ---------- 39 | .. [1] "Least-squares estimation of transformation parameters between two 40 | point patterns", Shinji Umeyama, PAMI 1991, :DOI:`10.1109/34.88573` 41 | """ 42 | 43 | num = src.shape[0] 44 | dim = src.shape[1] 45 | 46 | # Compute mean of src and dst. 47 | src_mean = src.mean(axis=0) 48 | dst_mean = dst.mean(axis=0) 49 | 50 | # Subtract mean from src and dst. 51 | src_demean = src - src_mean 52 | dst_demean = dst - dst_mean 53 | 54 | # Eq. (38). 55 | A = dst_demean.T @ src_demean / num 56 | 57 | # Eq. (39). 58 | d = np.ones((dim,), dtype=np.double) 59 | if np.linalg.det(A) < 0: 60 | d[dim - 1] = -1 61 | 62 | T = np.eye(dim + 1, dtype=np.double) 63 | 64 | U, S, V = np.linalg.svd(A) 65 | 66 | # Eq. (40) and (43). 
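    # How the rotation is assembled depends on the rank of the covariance A:
    # rank 0 means the point sets are degenerate and no transform can be
    # estimated (NaNs are returned); rank dim-1 needs the determinant check
    # below to pick a proper rotation rather than a reflection; full rank uses
    # the sign-correction vector d directly (Umeyama's reflection handling).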
67 | rank = np.linalg.matrix_rank(A) 68 | if rank == 0: 69 | return np.nan * T 70 | elif rank == dim - 1: 71 | if np.linalg.det(U) * np.linalg.det(V) > 0: 72 | T[:dim, :dim] = U @ V 73 | else: 74 | s = d[dim - 1] 75 | d[dim - 1] = -1 76 | T[:dim, :dim] = U @ np.diag(d) @ V 77 | d[dim - 1] = s 78 | else: 79 | T[:dim, :dim] = U @ np.diag(d) @ V 80 | 81 | if estimate_scale: 82 | # Eq. (41) and (42). 83 | scale = 1.0 / src_demean.var(axis=0).sum() * (S @ d) 84 | else: 85 | scale = scale 86 | 87 | T[:dim, dim] = dst_mean - scale * (T[:dim, :dim] @ src_mean.T) 88 | T[:dim, :dim] *= scale 89 | 90 | return T, scale 91 | 92 | 93 | class FaceWarpException(Exception): 94 | def __str__(self): 95 | return 'In File {}:{}'.format( 96 | __file__, super.__str__(self)) 97 | 98 | 99 | def get_reference_facial_points(output_size=None, 100 | inner_padding_factor=0.0, 101 | outer_padding=(0, 0), 102 | default_square=False): 103 | tmp_5pts = np.array(REFERENCE_FACIAL_POINTS) 104 | tmp_crop_size = np.array(DEFAULT_CROP_SIZE) 105 | 106 | # 0) make the inner region a square 107 | if default_square: 108 | size_diff = max(tmp_crop_size) - tmp_crop_size 109 | tmp_5pts += size_diff / 2 110 | tmp_crop_size += size_diff 111 | 112 | if (output_size and 113 | output_size[0] == tmp_crop_size[0] and 114 | output_size[1] == tmp_crop_size[1]): 115 | print('output_size == DEFAULT_CROP_SIZE {}: return default reference points'.format(tmp_crop_size)) 116 | return tmp_5pts 117 | 118 | if (inner_padding_factor == 0 and 119 | outer_padding == (0, 0)): 120 | if output_size is None: 121 | print('No paddings to do: return default reference points') 122 | return tmp_5pts 123 | else: 124 | raise FaceWarpException( 125 | 'No paddings to do, output_size must be None or {}'.format(tmp_crop_size)) 126 | 127 | # check output size 128 | if not (0 <= inner_padding_factor <= 1.0): 129 | raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)') 130 | 131 | if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) 132 | and output_size is None): 133 | output_size = tmp_crop_size * \ 134 | (1 + inner_padding_factor * 2).astype(np.int32) 135 | output_size += np.array(outer_padding) 136 | print(' deduced from paddings, output_size = ', output_size) 137 | 138 | if not (outer_padding[0] < output_size[0] 139 | and outer_padding[1] < output_size[1]): 140 | raise FaceWarpException('Not (outer_padding[0] < output_size[0]' 141 | 'and outer_padding[1] < output_size[1])') 142 | 143 | # 1) pad the inner region according inner_padding_factor 144 | # print('---> STEP1: pad the inner region according inner_padding_factor') 145 | if inner_padding_factor > 0: 146 | size_diff = tmp_crop_size * inner_padding_factor * 2 147 | tmp_5pts += size_diff / 2 148 | tmp_crop_size += np.round(size_diff).astype(np.int32) 149 | 150 | # print(' crop_size = ', tmp_crop_size) 151 | # print(' reference_5pts = ', tmp_5pts) 152 | 153 | # 2) resize the padded inner region 154 | # print('---> STEP2: resize the padded inner region') 155 | size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2 156 | # print(' crop_size = ', tmp_crop_size) 157 | # print(' size_bf_outer_pad = ', size_bf_outer_pad) 158 | 159 | if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]: 160 | raise FaceWarpException('Must have (output_size - outer_padding)' 161 | '= some_scale * (crop_size * (1.0 + inner_padding_factor)') 162 | 163 | scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0] 164 | # print(' resize 
scale_factor = ', scale_factor) 165 | tmp_5pts = tmp_5pts * scale_factor 166 | # size_diff = tmp_crop_size * (scale_factor - min(scale_factor)) 167 | # tmp_5pts = tmp_5pts + size_diff / 2 168 | tmp_crop_size = size_bf_outer_pad 169 | # print(' crop_size = ', tmp_crop_size) 170 | # print(' reference_5pts = ', tmp_5pts) 171 | 172 | # 3) add outer_padding to make output_size 173 | reference_5point = tmp_5pts + np.array(outer_padding) 174 | tmp_crop_size = output_size 175 | # print('---> STEP3: add outer_padding to make output_size') 176 | # print(' crop_size = ', tmp_crop_size) 177 | # print(' reference_5pts = ', tmp_5pts) 178 | # 179 | # print('===> end get_reference_facial_points\n') 180 | 181 | return reference_5point 182 | 183 | 184 | def get_affine_transform_matrix(src_pts, dst_pts): 185 | tfm = np.float32([[1, 0, 0], [0, 1, 0]]) 186 | n_pts = src_pts.shape[0] 187 | ones = np.ones((n_pts, 1), src_pts.dtype) 188 | src_pts_ = np.hstack([src_pts, ones]) 189 | dst_pts_ = np.hstack([dst_pts, ones]) 190 | 191 | A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_) 192 | 193 | if rank == 3: 194 | tfm = np.float32([ 195 | [A[0, 0], A[1, 0], A[2, 0]], 196 | [A[0, 1], A[1, 1], A[2, 1]] 197 | ]) 198 | elif rank == 2: 199 | tfm = np.float32([ 200 | [A[0, 0], A[1, 0], 0], 201 | [A[0, 1], A[1, 1], 0] 202 | ]) 203 | 204 | return tfm 205 | 206 | 207 | def warp_and_crop_face(src_img, 208 | facial_pts, 209 | reference_pts=None, 210 | crop_size=(96, 112), 211 | align_type='smilarity'): #smilarity cv2_affine affine 212 | if reference_pts is None: 213 | if crop_size[0] == 96 and crop_size[1] == 112: 214 | reference_pts = REFERENCE_FACIAL_POINTS 215 | else: 216 | default_square = False 217 | inner_padding_factor = 0 218 | outer_padding = (0, 0) 219 | output_size = crop_size 220 | 221 | reference_pts = get_reference_facial_points(output_size, 222 | inner_padding_factor, 223 | outer_padding, 224 | default_square) 225 | 226 | ref_pts = np.float32(reference_pts) 227 | ref_pts_shp = ref_pts.shape 228 | if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2: 229 | raise FaceWarpException( 230 | 'reference_pts.shape must be (K,2) or (2,K) and K>2') 231 | 232 | if ref_pts_shp[0] == 2: 233 | ref_pts = ref_pts.T 234 | 235 | src_pts = np.float32(facial_pts) 236 | src_pts_shp = src_pts.shape 237 | if max(src_pts_shp) < 3 or min(src_pts_shp) != 2: 238 | raise FaceWarpException( 239 | 'facial_pts.shape must be (K,2) or (2,K) and K>2') 240 | 241 | if src_pts_shp[0] == 2: 242 | src_pts = src_pts.T 243 | 244 | if src_pts.shape != ref_pts.shape: 245 | raise FaceWarpException( 246 | 'facial_pts and reference_pts must have the same shape') 247 | 248 | if align_type is 'cv2_affine': 249 | tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3]) 250 | tfm_inv = cv2.getAffineTransform(ref_pts[0:3], src_pts[0:3]) 251 | elif align_type is 'affine': 252 | tfm = get_affine_transform_matrix(src_pts, ref_pts) 253 | tfm_inv = get_affine_transform_matrix(ref_pts, src_pts) 254 | else: 255 | params, scale = _umeyama(src_pts, ref_pts) 256 | tfm = params[:2, :] 257 | 258 | params, _ = _umeyama(ref_pts, src_pts, False, scale=1.0/scale) 259 | tfm_inv = params[:2, :] 260 | 261 | face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]), flags=3) 262 | 263 | return face_img, tfm_inv 264 | -------------------------------------------------------------------------------- /core/data/deg_kair_utils/utils_bnorm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | 
""" 6 | # -------------------------------------------- 7 | # Batch Normalization 8 | # -------------------------------------------- 9 | 10 | # Kai Zhang (cskaizhang@gmail.com) 11 | # https://github.com/cszn 12 | # 01/Jan/2019 13 | # -------------------------------------------- 14 | """ 15 | 16 | 17 | # -------------------------------------------- 18 | # remove/delete specified layer 19 | # -------------------------------------------- 20 | def deleteLayer(model, layer_type=nn.BatchNorm2d): 21 | ''' Kai Zhang, 11/Jan/2019. 22 | ''' 23 | for k, m in list(model.named_children()): 24 | if isinstance(m, layer_type): 25 | del model._modules[k] 26 | deleteLayer(m, layer_type) 27 | 28 | 29 | # -------------------------------------------- 30 | # merge bn, "conv+bn" --> "conv" 31 | # -------------------------------------------- 32 | def merge_bn(model): 33 | ''' Kai Zhang, 11/Jan/2019. 34 | merge all 'Conv+BN' (or 'TConv+BN') into 'Conv' (or 'TConv') 35 | based on https://github.com/pytorch/pytorch/pull/901 36 | ''' 37 | prev_m = None 38 | for k, m in list(model.named_children()): 39 | if (isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d)) and (isinstance(prev_m, nn.Conv2d) or isinstance(prev_m, nn.Linear) or isinstance(prev_m, nn.ConvTranspose2d)): 40 | 41 | w = prev_m.weight.data 42 | 43 | if prev_m.bias is None: 44 | zeros = torch.Tensor(prev_m.out_channels).zero_().type(w.type()) 45 | prev_m.bias = nn.Parameter(zeros) 46 | b = prev_m.bias.data 47 | 48 | invstd = m.running_var.clone().add_(m.eps).pow_(-0.5) 49 | if isinstance(prev_m, nn.ConvTranspose2d): 50 | w.mul_(invstd.view(1, w.size(1), 1, 1).expand_as(w)) 51 | else: 52 | w.mul_(invstd.view(w.size(0), 1, 1, 1).expand_as(w)) 53 | b.add_(-m.running_mean).mul_(invstd) 54 | if m.affine: 55 | if isinstance(prev_m, nn.ConvTranspose2d): 56 | w.mul_(m.weight.data.view(1, w.size(1), 1, 1).expand_as(w)) 57 | else: 58 | w.mul_(m.weight.data.view(w.size(0), 1, 1, 1).expand_as(w)) 59 | b.mul_(m.weight.data).add_(m.bias.data) 60 | 61 | del model._modules[k] 62 | prev_m = m 63 | merge_bn(m) 64 | 65 | 66 | # -------------------------------------------- 67 | # add bn, "conv" --> "conv+bn" 68 | # -------------------------------------------- 69 | def add_bn(model): 70 | ''' Kai Zhang, 11/Jan/2019. 71 | ''' 72 | for k, m in list(model.named_children()): 73 | if (isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose2d)): 74 | b = nn.BatchNorm2d(m.out_channels, momentum=0.1, affine=True) 75 | b.weight.data.fill_(1) 76 | new_m = nn.Sequential(model._modules[k], b) 77 | model._modules[k] = new_m 78 | add_bn(m) 79 | 80 | 81 | # -------------------------------------------- 82 | # tidy model after removing bn 83 | # -------------------------------------------- 84 | def tidy_sequential(model): 85 | ''' Kai Zhang, 11/Jan/2019. 
86 | ''' 87 | for k, m in list(model.named_children()): 88 | if isinstance(m, nn.Sequential): 89 | if m.__len__() == 1: 90 | model._modules[k] = m.__getitem__(0) 91 | tidy_sequential(m) 92 | -------------------------------------------------------------------------------- /core/data/deg_kair_utils/utils_dist.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 2 | import functools 3 | import os 4 | import subprocess 5 | import torch 6 | import torch.distributed as dist 7 | import torch.multiprocessing as mp 8 | 9 | 10 | # ---------------------------------- 11 | # init 12 | # ---------------------------------- 13 | def init_dist(launcher, backend='nccl', **kwargs): 14 | if mp.get_start_method(allow_none=True) is None: 15 | mp.set_start_method('spawn') 16 | if launcher == 'pytorch': 17 | _init_dist_pytorch(backend, **kwargs) 18 | elif launcher == 'slurm': 19 | _init_dist_slurm(backend, **kwargs) 20 | else: 21 | raise ValueError(f'Invalid launcher type: {launcher}') 22 | 23 | 24 | def _init_dist_pytorch(backend, **kwargs): 25 | rank = int(os.environ['RANK']) 26 | num_gpus = torch.cuda.device_count() 27 | torch.cuda.set_device(rank % num_gpus) 28 | dist.init_process_group(backend=backend, **kwargs) 29 | 30 | 31 | def _init_dist_slurm(backend, port=None): 32 | """Initialize slurm distributed training environment. 33 | If argument ``port`` is not specified, then the master port will be system 34 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 35 | environment variable, then a default port ``29500`` will be used. 36 | Args: 37 | backend (str): Backend of torch.distributed. 38 | port (int, optional): Master port. Defaults to None. 
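    The master address is taken from the first host in ``SLURM_NODELIST``
    (resolved with ``scontrol show hostname``); rank, local rank, and world
    size are exported as environment variables before calling
    ``dist.init_process_group``.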
39 | """ 40 | proc_id = int(os.environ['SLURM_PROCID']) 41 | ntasks = int(os.environ['SLURM_NTASKS']) 42 | node_list = os.environ['SLURM_NODELIST'] 43 | num_gpus = torch.cuda.device_count() 44 | torch.cuda.set_device(proc_id % num_gpus) 45 | addr = subprocess.getoutput( 46 | f'scontrol show hostname {node_list} | head -n1') 47 | # specify master port 48 | if port is not None: 49 | os.environ['MASTER_PORT'] = str(port) 50 | elif 'MASTER_PORT' in os.environ: 51 | pass # use MASTER_PORT in the environment variable 52 | else: 53 | # 29500 is torch.distributed default port 54 | os.environ['MASTER_PORT'] = '29500' 55 | os.environ['MASTER_ADDR'] = addr 56 | os.environ['WORLD_SIZE'] = str(ntasks) 57 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 58 | os.environ['RANK'] = str(proc_id) 59 | dist.init_process_group(backend=backend) 60 | 61 | 62 | 63 | # ---------------------------------- 64 | # get rank and world_size 65 | # ---------------------------------- 66 | def get_dist_info(): 67 | if dist.is_available(): 68 | initialized = dist.is_initialized() 69 | else: 70 | initialized = False 71 | if initialized: 72 | rank = dist.get_rank() 73 | world_size = dist.get_world_size() 74 | else: 75 | rank = 0 76 | world_size = 1 77 | return rank, world_size 78 | 79 | 80 | def get_rank(): 81 | if not dist.is_available(): 82 | return 0 83 | 84 | if not dist.is_initialized(): 85 | return 0 86 | 87 | return dist.get_rank() 88 | 89 | 90 | def get_world_size(): 91 | if not dist.is_available(): 92 | return 1 93 | 94 | if not dist.is_initialized(): 95 | return 1 96 | 97 | return dist.get_world_size() 98 | 99 | 100 | def master_only(func): 101 | 102 | @functools.wraps(func) 103 | def wrapper(*args, **kwargs): 104 | rank, _ = get_dist_info() 105 | if rank == 0: 106 | return func(*args, **kwargs) 107 | 108 | return wrapper 109 | 110 | 111 | 112 | 113 | 114 | 115 | # ---------------------------------- 116 | # operation across ranks 117 | # ---------------------------------- 118 | def reduce_sum(tensor): 119 | if not dist.is_available(): 120 | return tensor 121 | 122 | if not dist.is_initialized(): 123 | return tensor 124 | 125 | tensor = tensor.clone() 126 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM) 127 | 128 | return tensor 129 | 130 | 131 | def gather_grad(params): 132 | world_size = get_world_size() 133 | 134 | if world_size == 1: 135 | return 136 | 137 | for param in params: 138 | if param.grad is not None: 139 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) 140 | param.grad.data.div_(world_size) 141 | 142 | 143 | def all_gather(data): 144 | world_size = get_world_size() 145 | 146 | if world_size == 1: 147 | return [data] 148 | 149 | buffer = pickle.dumps(data) 150 | storage = torch.ByteStorage.from_buffer(buffer) 151 | tensor = torch.ByteTensor(storage).to('cuda') 152 | 153 | local_size = torch.IntTensor([tensor.numel()]).to('cuda') 154 | size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)] 155 | dist.all_gather(size_list, local_size) 156 | size_list = [int(size.item()) for size in size_list] 157 | max_size = max(size_list) 158 | 159 | tensor_list = [] 160 | for _ in size_list: 161 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda')) 162 | 163 | if local_size != max_size: 164 | padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda') 165 | tensor = torch.cat((tensor, padding), 0) 166 | 167 | dist.all_gather(tensor_list, tensor) 168 | 169 | data_list = [] 170 | 171 | for size, tensor in zip(size_list, tensor_list): 172 | buffer = 
tensor.cpu().numpy().tobytes()[:size] 173 | data_list.append(pickle.loads(buffer)) 174 | 175 | return data_list 176 | 177 | 178 | def reduce_loss_dict(loss_dict): 179 | world_size = get_world_size() 180 | 181 | if world_size < 2: 182 | return loss_dict 183 | 184 | with torch.no_grad(): 185 | keys = [] 186 | losses = [] 187 | 188 | for k in sorted(loss_dict.keys()): 189 | keys.append(k) 190 | losses.append(loss_dict[k]) 191 | 192 | losses = torch.stack(losses, 0) 193 | dist.reduce(losses, dst=0) 194 | 195 | if dist.get_rank() == 0: 196 | losses /= world_size 197 | 198 | reduced_losses = {k: v for k, v in zip(keys, losses)} 199 | 200 | return reduced_losses 201 | 202 | -------------------------------------------------------------------------------- /core/data/deg_kair_utils/utils_googledownload.py: -------------------------------------------------------------------------------- 1 | import math 2 | import requests 3 | from tqdm import tqdm 4 | 5 | 6 | ''' 7 | borrowed from 8 | https://github.com/xinntao/BasicSR/blob/28883e15eedc3381d23235ff3cf7c454c4be87e6/basicsr/utils/download_util.py 9 | ''' 10 | 11 | 12 | def sizeof_fmt(size, suffix='B'): 13 | """Get human readable file size. 14 | Args: 15 | size (int): File size. 16 | suffix (str): Suffix. Default: 'B'. 17 | Return: 18 | str: Formated file siz. 19 | """ 20 | for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: 21 | if abs(size) < 1024.0: 22 | return f'{size:3.1f} {unit}{suffix}' 23 | size /= 1024.0 24 | return f'{size:3.1f} Y{suffix}' 25 | 26 | 27 | def download_file_from_google_drive(file_id, save_path): 28 | """Download files from google drive. 29 | Ref: 30 | https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 31 | Args: 32 | file_id (str): File id. 33 | save_path (str): Save path. 
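    Notes:
        Handles Google Drive's download-confirmation token for large files and
        reports progress with a tqdm bar while writing the file to ``save_path``.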
34 | """ 35 | 36 | session = requests.Session() 37 | URL = 'https://docs.google.com/uc?export=download' 38 | params = {'id': file_id} 39 | 40 | response = session.get(URL, params=params, stream=True) 41 | token = get_confirm_token(response) 42 | if token: 43 | params['confirm'] = token 44 | response = session.get(URL, params=params, stream=True) 45 | 46 | # get file size 47 | response_file_size = session.get( 48 | URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) 49 | if 'Content-Range' in response_file_size.headers: 50 | file_size = int( 51 | response_file_size.headers['Content-Range'].split('/')[1]) 52 | else: 53 | file_size = None 54 | 55 | save_response_content(response, save_path, file_size) 56 | 57 | 58 | def get_confirm_token(response): 59 | for key, value in response.cookies.items(): 60 | if key.startswith('download_warning'): 61 | return value 62 | return None 63 | 64 | 65 | def save_response_content(response, 66 | destination, 67 | file_size=None, 68 | chunk_size=32768): 69 | if file_size is not None: 70 | pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') 71 | 72 | readable_file_size = sizeof_fmt(file_size) 73 | else: 74 | pbar = None 75 | 76 | with open(destination, 'wb') as f: 77 | downloaded_size = 0 78 | for chunk in response.iter_content(chunk_size): 79 | downloaded_size += chunk_size 80 | if pbar is not None: 81 | pbar.update(1) 82 | pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} ' 83 | f'/ {readable_file_size}') 84 | if chunk: # filter out keep-alive new chunks 85 | f.write(chunk) 86 | if pbar is not None: 87 | pbar.close() 88 | 89 | 90 | if __name__ == "__main__": 91 | file_id = '1WNULM1e8gRNvsngVscsQ8tpaOqJ4mYtv' 92 | save_path = 'BSRGAN.pth' 93 | download_file_from_google_drive(file_id, save_path) 94 | -------------------------------------------------------------------------------- /core/data/deg_kair_utils/utils_lmdb.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import lmdb 3 | import sys 4 | from multiprocessing import Pool 5 | from os import path as osp 6 | from tqdm import tqdm 7 | 8 | 9 | def make_lmdb_from_imgs(data_path, 10 | lmdb_path, 11 | img_path_list, 12 | keys, 13 | batch=5000, 14 | compress_level=1, 15 | multiprocessing_read=False, 16 | n_thread=40, 17 | map_size=None): 18 | """Make lmdb from images. 19 | 20 | Contents of lmdb. The file structure is: 21 | example.lmdb 22 | ├── data.mdb 23 | ├── lock.mdb 24 | ├── meta_info.txt 25 | 26 | The data.mdb and lock.mdb are standard lmdb files and you can refer to 27 | https://lmdb.readthedocs.io/en/release/ for more details. 28 | 29 | The meta_info.txt is a specified txt file to record the meta information 30 | of our datasets. It will be automatically created when preparing 31 | datasets by our provided dataset tools. 32 | Each line in the txt file records 1)image name (with extension), 33 | 2)image shape, and 3)compression level, separated by a white space. 34 | 35 | For example, the meta information could be: 36 | `000_00000000.png (720,1280,3) 1`, which means: 37 | 1) image name (with extension): 000_00000000.png; 38 | 2) image shape: (720,1280,3); 39 | 3) compression level: 1 40 | 41 | We use the image name without extension as the lmdb key. 42 | 43 | If `multiprocessing_read` is True, it will read all the images to memory 44 | using multiprocessing. Thus, your server needs to have enough memory. 45 | 46 | Args: 47 | data_path (str): Data path for reading images. 48 | lmdb_path (str): Lmdb save path. 
49 | img_path_list (str): Image path list. 50 | keys (str): Used for lmdb keys. 51 | batch (int): After processing batch images, lmdb commits. 52 | Default: 5000. 53 | compress_level (int): Compress level when encoding images. Default: 1. 54 | multiprocessing_read (bool): Whether use multiprocessing to read all 55 | the images to memory. Default: False. 56 | n_thread (int): For multiprocessing. 57 | map_size (int | None): Map size for lmdb env. If None, use the 58 | estimated size from images. Default: None 59 | """ 60 | 61 | assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, ' 62 | f'but got {len(img_path_list)} and {len(keys)}') 63 | print(f'Create lmdb for {data_path}, save to {lmdb_path}...') 64 | print(f'Totoal images: {len(img_path_list)}') 65 | if not lmdb_path.endswith('.lmdb'): 66 | raise ValueError("lmdb_path must end with '.lmdb'.") 67 | if osp.exists(lmdb_path): 68 | print(f'Folder {lmdb_path} already exists. Exit.') 69 | sys.exit(1) 70 | 71 | if multiprocessing_read: 72 | # read all the images to memory (multiprocessing) 73 | dataset = {} # use dict to keep the order for multiprocessing 74 | shapes = {} 75 | print(f'Read images with multiprocessing, #thread: {n_thread} ...') 76 | pbar = tqdm(total=len(img_path_list), unit='image') 77 | 78 | def callback(arg): 79 | """get the image data and update pbar.""" 80 | key, dataset[key], shapes[key] = arg 81 | pbar.update(1) 82 | pbar.set_description(f'Read {key}') 83 | 84 | pool = Pool(n_thread) 85 | for path, key in zip(img_path_list, keys): 86 | pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback) 87 | pool.close() 88 | pool.join() 89 | pbar.close() 90 | print(f'Finish reading {len(img_path_list)} images.') 91 | 92 | # create lmdb environment 93 | if map_size is None: 94 | # obtain data size for one image 95 | img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED) 96 | _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) 97 | data_size_per_img = img_byte.nbytes 98 | print('Data size per image is: ', data_size_per_img) 99 | data_size = data_size_per_img * len(img_path_list) 100 | map_size = data_size * 10 101 | 102 | env = lmdb.open(lmdb_path, map_size=map_size) 103 | 104 | # write data to lmdb 105 | pbar = tqdm(total=len(img_path_list), unit='chunk') 106 | txn = env.begin(write=True) 107 | txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') 108 | for idx, (path, key) in enumerate(zip(img_path_list, keys)): 109 | pbar.update(1) 110 | pbar.set_description(f'Write {key}') 111 | key_byte = key.encode('ascii') 112 | if multiprocessing_read: 113 | img_byte = dataset[key] 114 | h, w, c = shapes[key] 115 | else: 116 | _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level) 117 | h, w, c = img_shape 118 | 119 | txn.put(key_byte, img_byte) 120 | # write meta information 121 | txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n') 122 | if idx % batch == 0: 123 | txn.commit() 124 | txn = env.begin(write=True) 125 | pbar.close() 126 | txn.commit() 127 | env.close() 128 | txt_file.close() 129 | print('\nFinish writing lmdb.') 130 | 131 | 132 | def read_img_worker(path, key, compress_level): 133 | """Read image worker. 134 | 135 | Args: 136 | path (str): Image path. 137 | key (str): Image key. 138 | compress_level (int): Compress level when encoding images. 139 | 140 | Returns: 141 | str: Image key. 142 | byte: Image byte. 
143 | tuple[int]: Image shape. 144 | """ 145 | 146 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED) 147 | # deal with `libpng error: Read Error` 148 | if img is None: 149 | print(f'To deal with `libpng error: Read Error`, use PIL to load {path}') 150 | from PIL import Image 151 | import numpy as np 152 | img = Image.open(path) 153 | img = np.asanyarray(img) 154 | img = img[:, :, [2, 1, 0]] 155 | 156 | if img.ndim == 2: 157 | h, w = img.shape 158 | c = 1 159 | else: 160 | h, w, c = img.shape 161 | _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) 162 | return (key, img_byte, (h, w, c)) 163 | 164 | 165 | class LmdbMaker(): 166 | """LMDB Maker. 167 | 168 | Args: 169 | lmdb_path (str): Lmdb save path. 170 | map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB. 171 | batch (int): After processing batch images, lmdb commits. 172 | Default: 5000. 173 | compress_level (int): Compress level when encoding images. Default: 1. 174 | """ 175 | 176 | def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1): 177 | if not lmdb_path.endswith('.lmdb'): 178 | raise ValueError("lmdb_path must end with '.lmdb'.") 179 | if osp.exists(lmdb_path): 180 | print(f'Folder {lmdb_path} already exists. Exit.') 181 | sys.exit(1) 182 | 183 | self.lmdb_path = lmdb_path 184 | self.batch = batch 185 | self.compress_level = compress_level 186 | self.env = lmdb.open(lmdb_path, map_size=map_size) 187 | self.txn = self.env.begin(write=True) 188 | self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') 189 | self.counter = 0 190 | 191 | def put(self, img_byte, key, img_shape): 192 | self.counter += 1 193 | key_byte = key.encode('ascii') 194 | self.txn.put(key_byte, img_byte) 195 | # write meta information 196 | h, w, c = img_shape 197 | self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n') 198 | if self.counter % self.batch == 0: 199 | self.txn.commit() 200 | self.txn = self.env.begin(write=True) 201 | 202 | def close(self): 203 | self.txn.commit() 204 | self.env.close() 205 | self.txt_file.close() 206 | -------------------------------------------------------------------------------- /core/data/deg_kair_utils/utils_logger.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import datetime 3 | import logging 4 | 5 | 6 | ''' 7 | # -------------------------------------------- 8 | # Kai Zhang (github: https://github.com/cszn) 9 | # 03/Mar/2019 10 | # -------------------------------------------- 11 | # https://github.com/xinntao/BasicSR 12 | # -------------------------------------------- 13 | ''' 14 | 15 | 16 | def log(*args, **kwargs): 17 | print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"), *args, **kwargs) 18 | 19 | 20 | ''' 21 | # -------------------------------------------- 22 | # logger 23 | # -------------------------------------------- 24 | ''' 25 | 26 | 27 | def logger_info(logger_name, log_path='default_logger.log'): 28 | ''' set up logger 29 | modified by Kai Zhang (github: https://github.com/cszn) 30 | ''' 31 | log = logging.getLogger(logger_name) 32 | if log.hasHandlers(): 33 | print('LogHandlers exist!') 34 | else: 35 | print('LogHandlers setup!') 36 | level = logging.INFO 37 | formatter = logging.Formatter('%(asctime)s.%(msecs)03d : %(message)s', datefmt='%y-%m-%d %H:%M:%S') 38 | fh = logging.FileHandler(log_path, mode='a') 39 | fh.setFormatter(formatter) 40 | log.setLevel(level) 41 | log.addHandler(fh) 42 | # print(len(log.handlers)) 43 | 44 | sh = 
logging.StreamHandler() 45 | sh.setFormatter(formatter) 46 | log.addHandler(sh) 47 | 48 | 49 | ''' 50 | # -------------------------------------------- 51 | # print to file and std_out simultaneously 52 | # -------------------------------------------- 53 | ''' 54 | 55 | 56 | class logger_print(object): 57 | def __init__(self, log_path="default.log"): 58 | self.terminal = sys.stdout 59 | self.log = open(log_path, 'a') 60 | 61 | def write(self, message): 62 | self.terminal.write(message) 63 | self.log.write(message) # write the message 64 | 65 | def flush(self): 66 | pass 67 | -------------------------------------------------------------------------------- /core/data/deg_kair_utils/utils_mat.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import scipy.io as spio 4 | import pandas as pd 5 | 6 | 7 | def loadmat(filename): 8 | ''' 9 | this function should be called instead of direct spio.loadmat 10 | as it cures the problem of not properly recovering python dictionaries 11 | from mat files. It calls the function check keys to cure all entries 12 | which are still mat-objects 13 | ''' 14 | data = spio.loadmat(filename, struct_as_record=False, squeeze_me=True) 15 | return dict_to_nonedict(_check_keys(data)) 16 | 17 | def _check_keys(dict): 18 | ''' 19 | checks if entries in dictionary are mat-objects. If yes 20 | todict is called to change them to nested dictionaries 21 | ''' 22 | for key in dict: 23 | if isinstance(dict[key], spio.matlab.mio5_params.mat_struct): 24 | dict[key] = _todict(dict[key]) 25 | return dict 26 | 27 | def _todict(matobj): 28 | ''' 29 | A recursive function which constructs from matobjects nested dictionaries 30 | ''' 31 | dict = {} 32 | for strg in matobj._fieldnames: 33 | elem = matobj.__dict__[strg] 34 | if isinstance(elem, spio.matlab.mio5_params.mat_struct): 35 | dict[strg] = _todict(elem) 36 | else: 37 | dict[strg] = elem 38 | return dict 39 | 40 | 41 | def dict_to_nonedict(opt): 42 | if isinstance(opt, dict): 43 | new_opt = dict() 44 | for key, sub_opt in opt.items(): 45 | new_opt[key] = dict_to_nonedict(sub_opt) 46 | return NoneDict(**new_opt) 47 | elif isinstance(opt, list): 48 | return [dict_to_nonedict(sub_opt) for sub_opt in opt] 49 | else: 50 | return opt 51 | 52 | 53 | class NoneDict(dict): 54 | def __missing__(self, key): 55 | return None 56 | 57 | 58 | def mat2json(mat_path=None, filepath = None): 59 | """ 60 | Converts .mat file to .json and writes new file 61 | Parameters 62 | ---------- 63 | mat_path: Str 64 | path/filename .mat存放路径 65 | filepath: Str 66 | 如果需要保存成json, 添加这一路径. 
Otherwise it is not saved. 67 | Returns 68 | The converted dictionary 69 | ------- 70 | None 71 | Examples 72 | -------- 73 | >>> mat2json(blah blah) 74 | """ 75 | 76 | matlabFile = loadmat(mat_path) 77 | #pop all those dumb fields that don't let you jsonize file 78 | matlabFile.pop('__header__') 79 | matlabFile.pop('__version__') 80 | matlabFile.pop('__globals__') 81 | #jsonize the file - orientation is 'index' 82 | matlabFile = pd.Series(matlabFile).to_json() 83 | 84 | if filepath: 85 | json_path = os.path.splitext(os.path.split(mat_path)[1])[0] + '.json' 86 | with open(json_path, 'w') as f: 87 | f.write(matlabFile) 88 | return matlabFile -------------------------------------------------------------------------------- /core/data/deg_kair_utils/utils_matconvnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | import torch 4 | from collections import OrderedDict 5 | 6 | # import scipy.io as io 7 | import hdf5storage 8 | 9 | """ 10 | # -------------------------------------------- 11 | # Convert matconvnet SimpleNN model into pytorch model 12 | # -------------------------------------------- 13 | # Kai Zhang (cskaizhang@gmail.com) 14 | # https://github.com/cszn 15 | # 28/Nov/2019 16 | # -------------------------------------------- 17 | """ 18 | 19 | 20 | def weights2tensor(x, squeeze=False, in_features=None, out_features=None): 21 | """Modified version of https://github.com/albanie/pytorch-mcn 22 | Adjust memory layout and load weights as torch tensor 23 | Args: 24 | x (ndarray): a numpy array, corresponding to a set of network weights 25 | stored in column major order 26 | squeeze (bool) [False]: whether to squeeze the tensor (i.e. remove 27 | singletons from the trailing dimensions. So after converting to 28 | pytorch layout (C_out, C_in, H, W), if the shape is (A, B, 1, 1) 29 | it will be reshaped to a matrix with shape (A,B). 30 | in_features (int :: None): used to reshape weights for a linear block. 31 | out_features (int :: None): used to reshape weights for a linear block.
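For instance (a sketch; the filter size is made up), a MatConvNet convolution kernel stored column-major as (H, W, C_in, C_out) is permuted to PyTorch's (C_out, C_in, H, W) layout:

    w_mcn = np.random.randn(3, 3, 64, 128).astype(np.float32)  # hypothetical 3x3 conv, 64 -> 128 channels
    w_pt = weights2tensor(w_mcn)
    print(w_pt.shape)  # torch.Size([128, 64, 3, 3])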
32 | Returns: 33 | torch.tensor: a permuted sets of weights, matching the pytorch layout 34 | convention 35 | """ 36 | if x.ndim == 4: 37 | x = x.transpose((3, 2, 0, 1)) 38 | # for FFDNet, pixel-shuffle layer 39 | # if x.shape[1]==13: 40 | # x=x[:,[0,2,1,3, 4,6,5,7, 8,10,9,11, 12],:,:] 41 | # if x.shape[0]==12: 42 | # x=x[[0,2,1,3, 4,6,5,7, 8,10,9,11],:,:,:] 43 | # if x.shape[1]==5: 44 | # x=x[:,[0,2,1,3, 4],:,:] 45 | # if x.shape[0]==4: 46 | # x=x[[0,2,1,3],:,:,:] 47 | ## for SRMD, pixel-shuffle layer 48 | # if x.shape[0]==12: 49 | # x=x[[0,2,1,3, 4,6,5,7, 8,10,9,11],:,:,:] 50 | # if x.shape[0]==27: 51 | # x=x[[0,3,6,1,4,7,2,5,8, 0+9,3+9,6+9,1+9,4+9,7+9,2+9,5+9,8+9, 0+18,3+18,6+18,1+18,4+18,7+18,2+18,5+18,8+18],:,:,:] 52 | # if x.shape[0]==48: 53 | # x=x[[0,4,8,12,1,5,9,13,2,6,10,14,3,7,11,15, 0+16,4+16,8+16,12+16,1+16,5+16,9+16,13+16,2+16,6+16,10+16,14+16,3+16,7+16,11+16,15+16, 0+32,4+32,8+32,12+32,1+32,5+32,9+32,13+32,2+32,6+32,10+32,14+32,3+32,7+32,11+32,15+32],:,:,:] 54 | 55 | elif x.ndim == 3: # add by Kai 56 | x = x[:,:,:,None] 57 | x = x.transpose((3, 2, 0, 1)) 58 | elif x.ndim == 2: 59 | if x.shape[1] == 1: 60 | x = x.flatten() 61 | if squeeze: 62 | if in_features and out_features: 63 | x = x.reshape((out_features, in_features)) 64 | x = np.squeeze(x) 65 | return torch.from_numpy(np.ascontiguousarray(x)) 66 | 67 | 68 | def save_model(network, save_path): 69 | state_dict = network.state_dict() 70 | for key, param in state_dict.items(): 71 | state_dict[key] = param.cpu() 72 | torch.save(state_dict, save_path) 73 | 74 | 75 | if __name__ == '__main__': 76 | 77 | 78 | # from utils import utils_logger 79 | # import logging 80 | # utils_logger.logger_info('a', 'a.log') 81 | # logger = logging.getLogger('a') 82 | # 83 | # mcn = hdf5storage.loadmat('/model_zoo/matfile/FFDNet_Clip_gray.mat') 84 | mcn = hdf5storage.loadmat('models/modelcolor.mat') 85 | 86 | 87 | #logger.info(mcn['CNNdenoiser'][0][0][0][1][0][0][0][0]) 88 | 89 | mat_net = OrderedDict() 90 | for idx in range(25): 91 | mat_net[str(idx)] = OrderedDict() 92 | count = -1 93 | 94 | print(idx) 95 | for i in range(13): 96 | 97 | if mcn['CNNdenoiser'][0][idx][0][i][0][0][0][0] == 'conv': 98 | 99 | count += 1 100 | w = mcn['CNNdenoiser'][0][idx][0][i][0][1][0][0] 101 | # print(w.shape) 102 | w = weights2tensor(w) 103 | # print(w.shape) 104 | 105 | b = mcn['CNNdenoiser'][0][idx][0][i][0][1][0][1] 106 | b = weights2tensor(b) 107 | print(b.shape) 108 | 109 | mat_net[str(idx)]['model.{:d}.weight'.format(count*2)] = w 110 | mat_net[str(idx)]['model.{:d}.bias'.format(count*2)] = b 111 | 112 | torch.save(mat_net, 'model_zoo/modelcolor.pth') 113 | 114 | 115 | 116 | # from models.network_dncnn import IRCNN as net 117 | # network = net(in_nc=3, out_nc=3, nc=64) 118 | # state_dict = network.state_dict() 119 | # 120 | # #show_kv(state_dict) 121 | # 122 | # for i in range(len(mcn['net'][0][0][0])): 123 | # print(mcn['net'][0][0][0][i][0][0][0][0]) 124 | # 125 | # count = -1 126 | # mat_net = OrderedDict() 127 | # for i in range(len(mcn['net'][0][0][0])): 128 | # if mcn['net'][0][0][0][i][0][0][0][0] == 'conv': 129 | # 130 | # count += 1 131 | # w = mcn['net'][0][0][0][i][0][1][0][0] 132 | # print(w.shape) 133 | # w = weights2tensor(w) 134 | # print(w.shape) 135 | # 136 | # b = mcn['net'][0][0][0][i][0][1][0][1] 137 | # b = weights2tensor(b) 138 | # print(b.shape) 139 | # 140 | # mat_net['model.{:d}.weight'.format(count*2)] = w 141 | # mat_net['model.{:d}.bias'.format(count*2)] = b 142 | # 143 | # torch.save(mat_net, 
'E:/pytorch/KAIR_ongoing/model_zoo/ffdnet_gray_clip.pth') 144 | # 145 | # 146 | # 147 | # crt_net = torch.load('E:/pytorch/KAIR_ongoing/model_zoo/imdn_x4.pth') 148 | # def show_kv(net): 149 | # for k, v in net.items(): 150 | # print(k) 151 | # 152 | # show_kv(crt_net) 153 | 154 | 155 | # from models.network_dncnn import DnCNN as net 156 | # network = net(in_nc=2, out_nc=1, nc=64, nb=20, act_mode='R') 157 | 158 | # from models.network_srmd import SRMD as net 159 | # #network = net(in_nc=1, out_nc=1, nc=64, nb=15, act_mode='R') 160 | # network = net(in_nc=19, out_nc=3, nc=128, nb=12, upscale=4, act_mode='R', upsample_mode='pixelshuffle') 161 | # 162 | # from models.network_rrdb import RRDB as net 163 | # network = net(in_nc=3, out_nc=3, nc=64, nb=23, gc=32, upscale=4, act_mode='L', upsample_mode='upconv') 164 | # 165 | # state_dict = network.state_dict() 166 | # for key, param in state_dict.items(): 167 | # print(key) 168 | # from models.network_imdn import IMDN as net 169 | # network = net(in_nc=3, out_nc=3, nc=64, nb=8, upscale=4, act_mode='L', upsample_mode='pixelshuffle') 170 | # state_dict = network.state_dict() 171 | # mat_net = OrderedDict() 172 | # for ((key, param),(key2, param2)) in zip(state_dict.items(), crt_net.items()): 173 | # mat_net[key] = param2 174 | # torch.save(mat_net, 'model_zoo/imdn_x4_1.pth') 175 | # 176 | 177 | # net_old = torch.load('net_old.pth') 178 | # def show_kv(net): 179 | # for k, v in net.items(): 180 | # print(k) 181 | # 182 | # show_kv(net_old) 183 | # from models.network_dpsr import MSRResNet_prior as net 184 | # model = net(in_nc=4, out_nc=3, nc=96, nb=16, upscale=4, act_mode='R', upsample_mode='pixelshuffle') 185 | # state_dict = network.state_dict() 186 | # net_new = OrderedDict() 187 | # for ((key, param),(key_old, param_old)) in zip(state_dict.items(), net_old.items()): 188 | # net_new[key] = param_old 189 | # torch.save(net_new, 'net_new.pth') 190 | 191 | 192 | # print(key) 193 | # print(param.size()) 194 | 195 | 196 | 197 | # run utils/utils_matconvnet.py 198 | -------------------------------------------------------------------------------- /core/data/deg_kair_utils/utils_option.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | from datetime import datetime 4 | import json 5 | import re 6 | import glob 7 | 8 | 9 | ''' 10 | # -------------------------------------------- 11 | # Kai Zhang (github: https://github.com/cszn) 12 | # 03/Mar/2019 13 | # -------------------------------------------- 14 | # https://github.com/xinntao/BasicSR 15 | # -------------------------------------------- 16 | ''' 17 | 18 | 19 | def get_timestamp(): 20 | return datetime.now().strftime('_%y%m%d_%H%M%S') 21 | 22 | 23 | def parse(opt_path, is_train=True): 24 | 25 | # ---------------------------------------- 26 | # remove comments starting with '//' 27 | # ---------------------------------------- 28 | json_str = '' 29 | with open(opt_path, 'r') as f: 30 | for line in f: 31 | line = line.split('//')[0] + '\n' 32 | json_str += line 33 | 34 | # ---------------------------------------- 35 | # initialize opt 36 | # ---------------------------------------- 37 | opt = json.loads(json_str, object_pairs_hook=OrderedDict) 38 | 39 | opt['opt_path'] = opt_path 40 | opt['is_train'] = is_train 41 | 42 | # ---------------------------------------- 43 | # set default 44 | # ---------------------------------------- 45 | if 'merge_bn' not in opt: 46 | opt['merge_bn'] = False 47 | 
opt['merge_bn_startpoint'] = -1 48 | 49 | if 'scale' not in opt: 50 | opt['scale'] = 1 51 | 52 | # ---------------------------------------- 53 | # datasets 54 | # ---------------------------------------- 55 | for phase, dataset in opt['datasets'].items(): 56 | phase = phase.split('_')[0] 57 | dataset['phase'] = phase 58 | dataset['scale'] = opt['scale'] # broadcast 59 | dataset['n_channels'] = opt['n_channels'] # broadcast 60 | if 'dataroot_H' in dataset and dataset['dataroot_H'] is not None: 61 | dataset['dataroot_H'] = os.path.expanduser(dataset['dataroot_H']) 62 | if 'dataroot_L' in dataset and dataset['dataroot_L'] is not None: 63 | dataset['dataroot_L'] = os.path.expanduser(dataset['dataroot_L']) 64 | 65 | # ---------------------------------------- 66 | # path 67 | # ---------------------------------------- 68 | for key, path in opt['path'].items(): 69 | if path and key in opt['path']: 70 | opt['path'][key] = os.path.expanduser(path) 71 | 72 | path_task = os.path.join(opt['path']['root'], opt['task']) 73 | opt['path']['task'] = path_task 74 | opt['path']['log'] = path_task 75 | opt['path']['options'] = os.path.join(path_task, 'options') 76 | 77 | if is_train: 78 | opt['path']['models'] = os.path.join(path_task, 'models') 79 | opt['path']['images'] = os.path.join(path_task, 'images') 80 | else: # test 81 | opt['path']['images'] = os.path.join(path_task, 'test_images') 82 | 83 | # ---------------------------------------- 84 | # network 85 | # ---------------------------------------- 86 | opt['netG']['scale'] = opt['scale'] if 'scale' in opt else 1 87 | 88 | # ---------------------------------------- 89 | # GPU devices 90 | # ---------------------------------------- 91 | gpu_list = ','.join(str(x) for x in opt['gpu_ids']) 92 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list 93 | print('export CUDA_VISIBLE_DEVICES=' + gpu_list) 94 | 95 | # ---------------------------------------- 96 | # default setting for distributeddataparallel 97 | # ---------------------------------------- 98 | if 'find_unused_parameters' not in opt: 99 | opt['find_unused_parameters'] = True 100 | if 'use_static_graph' not in opt: 101 | opt['use_static_graph'] = False 102 | if 'dist' not in opt: 103 | opt['dist'] = False 104 | opt['num_gpu'] = len(opt['gpu_ids']) 105 | print('number of GPUs is: ' + str(opt['num_gpu'])) 106 | 107 | # ---------------------------------------- 108 | # default setting for perceptual loss 109 | # ---------------------------------------- 110 | if 'F_feature_layer' not in opt['train']: 111 | opt['train']['F_feature_layer'] = 34 # 25; [2,7,16,25,34] 112 | if 'F_weights' not in opt['train']: 113 | opt['train']['F_weights'] = 1.0 # 1.0; [0.1,0.1,1.0,1.0,1.0] 114 | if 'F_lossfn_type' not in opt['train']: 115 | opt['train']['F_lossfn_type'] = 'l1' 116 | if 'F_use_input_norm' not in opt['train']: 117 | opt['train']['F_use_input_norm'] = True 118 | if 'F_use_range_norm' not in opt['train']: 119 | opt['train']['F_use_range_norm'] = False 120 | 121 | # ---------------------------------------- 122 | # default setting for optimizer 123 | # ---------------------------------------- 124 | if 'G_optimizer_type' not in opt['train']: 125 | opt['train']['G_optimizer_type'] = "adam" 126 | if 'G_optimizer_betas' not in opt['train']: 127 | opt['train']['G_optimizer_betas'] = [0.9,0.999] 128 | if 'G_scheduler_restart_weights' not in opt['train']: 129 | opt['train']['G_scheduler_restart_weights'] = 1 130 | if 'G_optimizer_wd' not in opt['train']: 131 | opt['train']['G_optimizer_wd'] = 0 132 | if 'G_optimizer_reuse' 
not in opt['train']: 133 | opt['train']['G_optimizer_reuse'] = False 134 | if 'netD' in opt and 'D_optimizer_reuse' not in opt['train']: 135 | opt['train']['D_optimizer_reuse'] = False 136 | 137 | # ---------------------------------------- 138 | # default setting of strict for model loading 139 | # ---------------------------------------- 140 | if 'G_param_strict' not in opt['train']: 141 | opt['train']['G_param_strict'] = True 142 | if 'netD' in opt and 'D_param_strict' not in opt['path']: 143 | opt['train']['D_param_strict'] = True 144 | if 'E_param_strict' not in opt['path']: 145 | opt['train']['E_param_strict'] = True 146 | 147 | # ---------------------------------------- 148 | # Exponential Moving Average 149 | # ---------------------------------------- 150 | if 'E_decay' not in opt['train']: 151 | opt['train']['E_decay'] = 0 152 | 153 | # ---------------------------------------- 154 | # default setting for discriminator 155 | # ---------------------------------------- 156 | if 'netD' in opt: 157 | if 'net_type' not in opt['netD']: 158 | opt['netD']['net_type'] = 'discriminator_patchgan' # discriminator_unet 159 | if 'in_nc' not in opt['netD']: 160 | opt['netD']['in_nc'] = 3 161 | if 'base_nc' not in opt['netD']: 162 | opt['netD']['base_nc'] = 64 163 | if 'n_layers' not in opt['netD']: 164 | opt['netD']['n_layers'] = 3 165 | if 'norm_type' not in opt['netD']: 166 | opt['netD']['norm_type'] = 'spectral' 167 | 168 | 169 | return opt 170 | 171 | 172 | def find_last_checkpoint(save_dir, net_type='G', pretrained_path=None): 173 | """ 174 | Args: 175 | save_dir: model folder 176 | net_type: 'G' or 'D' or 'optimizerG' or 'optimizerD' 177 | pretrained_path: pretrained model path. If save_dir does not have any model, load from pretrained_path 178 | 179 | Return: 180 | init_iter: iteration number 181 | init_path: model path 182 | """ 183 | file_list = glob.glob(os.path.join(save_dir, '*_{}.pth'.format(net_type))) 184 | if file_list: 185 | iter_exist = [] 186 | for file_ in file_list: 187 | iter_current = re.findall(r"(\d+)_{}.pth".format(net_type), file_) 188 | iter_exist.append(int(iter_current[0])) 189 | init_iter = max(iter_exist) 190 | init_path = os.path.join(save_dir, '{}_{}.pth'.format(init_iter, net_type)) 191 | else: 192 | init_iter = 0 193 | init_path = pretrained_path 194 | return init_iter, init_path 195 | 196 | 197 | ''' 198 | # -------------------------------------------- 199 | # convert the opt into json file 200 | # -------------------------------------------- 201 | ''' 202 | 203 | 204 | def save(opt): 205 | opt_path = opt['opt_path'] 206 | opt_path_copy = opt['path']['options'] 207 | dirname, filename_ext = os.path.split(opt_path) 208 | filename, ext = os.path.splitext(filename_ext) 209 | dump_path = os.path.join(opt_path_copy, filename+get_timestamp()+ext) 210 | with open(dump_path, 'w') as dump_file: 211 | json.dump(opt, dump_file, indent=2) 212 | 213 | 214 | ''' 215 | # -------------------------------------------- 216 | # dict to string for logger 217 | # -------------------------------------------- 218 | ''' 219 | 220 | 221 | def dict2str(opt, indent_l=1): 222 | msg = '' 223 | for k, v in opt.items(): 224 | if isinstance(v, dict): 225 | msg += ' ' * (indent_l * 2) + k + ':[\n' 226 | msg += dict2str(v, indent_l + 1) 227 | msg += ' ' * (indent_l * 2) + ']\n' 228 | else: 229 | msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n' 230 | return msg 231 | 232 | 233 | ''' 234 | # -------------------------------------------- 235 | # convert OrderedDict to NoneDict, 236 | # return 
None for missing key 237 | # -------------------------------------------- 238 | ''' 239 | 240 | 241 | def dict_to_nonedict(opt): 242 | if isinstance(opt, dict): 243 | new_opt = dict() 244 | for key, sub_opt in opt.items(): 245 | new_opt[key] = dict_to_nonedict(sub_opt) 246 | return NoneDict(**new_opt) 247 | elif isinstance(opt, list): 248 | return [dict_to_nonedict(sub_opt) for sub_opt in opt] 249 | else: 250 | return opt 251 | 252 | 253 | class NoneDict(dict): 254 | def __missing__(self, key): 255 | return None 256 | -------------------------------------------------------------------------------- /core/data/deg_kair_utils/utils_params.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torchvision 4 | 5 | from models import basicblock as B 6 | 7 | def show_kv(net): 8 | for k, v in net.items(): 9 | print(k) 10 | 11 | # should run train debug mode first to get an initial model 12 | #crt_net = torch.load('../../experiments/debug_SRResNet_bicx4_in3nf64nb16/models/8_G.pth') 13 | # 14 | #for k, v in crt_net.items(): 15 | # print(k) 16 | #for k, v in crt_net.items(): 17 | # if k in pretrained_net: 18 | # crt_net[k] = pretrained_net[k] 19 | # print('replace ... ', k) 20 | 21 | # x2 -> x4 22 | #crt_net['model.5.weight'] = pretrained_net['model.2.weight'] 23 | #crt_net['model.5.bias'] = pretrained_net['model.2.bias'] 24 | #crt_net['model.8.weight'] = pretrained_net['model.5.weight'] 25 | #crt_net['model.8.bias'] = pretrained_net['model.5.bias'] 26 | #crt_net['model.10.weight'] = pretrained_net['model.7.weight'] 27 | #crt_net['model.10.bias'] = pretrained_net['model.7.bias'] 28 | #torch.save(crt_net, '../pretrained_tmp.pth') 29 | 30 | # x2 -> x3 31 | ''' 32 | in_filter = pretrained_net['model.2.weight'] # 256, 64, 3, 3 33 | new_filter = torch.Tensor(576, 64, 3, 3) 34 | new_filter[0:256, :, :, :] = in_filter 35 | new_filter[256:512, :, :, :] = in_filter 36 | new_filter[512:, :, :, :] = in_filter[0:576-512, :, :, :] 37 | crt_net['model.2.weight'] = new_filter 38 | 39 | in_bias = pretrained_net['model.2.bias'] # 256, 64, 3, 3 40 | new_bias = torch.Tensor(576) 41 | new_bias[0:256] = in_bias 42 | new_bias[256:512] = in_bias 43 | new_bias[512:] = in_bias[0:576 - 512] 44 | crt_net['model.2.bias'] = new_bias 45 | 46 | torch.save(crt_net, '../pretrained_tmp.pth') 47 | ''' 48 | 49 | # x2 -> x8 50 | ''' 51 | crt_net['model.5.weight'] = pretrained_net['model.2.weight'] 52 | crt_net['model.5.bias'] = pretrained_net['model.2.bias'] 53 | crt_net['model.8.weight'] = pretrained_net['model.2.weight'] 54 | crt_net['model.8.bias'] = pretrained_net['model.2.bias'] 55 | crt_net['model.11.weight'] = pretrained_net['model.5.weight'] 56 | crt_net['model.11.bias'] = pretrained_net['model.5.bias'] 57 | crt_net['model.13.weight'] = pretrained_net['model.7.weight'] 58 | crt_net['model.13.bias'] = pretrained_net['model.7.bias'] 59 | torch.save(crt_net, '../pretrained_tmp.pth') 60 | ''' 61 | 62 | # x3/4/8 RGB -> Y 63 | 64 | def rgb2gray_net(net, only_input=True): 65 | 66 | if only_input: 67 | in_filter = net['0.weight'] 68 | in_new_filter = in_filter[:,0,:,:]*0.2989 + in_filter[:,1,:,:]*0.587 + in_filter[:,2,:,:]*0.114 69 | in_new_filter.unsqueeze_(1) 70 | net['0.weight'] = in_new_filter 71 | 72 | # out_filter = pretrained_net['model.13.weight'] 73 | # out_new_filter = out_filter[0, :, :, :] * 0.2989 + out_filter[1, :, :, :] * 0.587 + \ 74 | # out_filter[2, :, :, :] * 0.114 75 | # out_new_filter.unsqueeze_(0) 76 | # crt_net['model.13.weight'] = out_new_filter 77 | # 
out_bias = pretrained_net['model.13.bias'] 78 | # out_new_bias = out_bias[0] * 0.2989 + out_bias[1] * 0.587 + out_bias[2] * 0.114 79 | # out_new_bias = torch.Tensor(1).fill_(out_new_bias) 80 | # crt_net['model.13.bias'] = out_new_bias 81 | 82 | # torch.save(crt_net, '../pretrained_tmp.pth') 83 | 84 | return net 85 | 86 | 87 | 88 | if __name__ == '__main__': 89 | 90 | net = torchvision.models.vgg19(pretrained=True) 91 | for k,v in net.features.named_parameters(): 92 | if k=='0.weight': 93 | in_new_filter = v[:,0,:,:]*0.2989 + v[:,1,:,:]*0.587 + v[:,2,:,:]*0.114 94 | in_new_filter.unsqueeze_(1) 95 | v = in_new_filter 96 | print(v.shape) 97 | print(v[0,0,0,0]) 98 | if k=='0.bias': 99 | in_new_bias = v 100 | print(v[0]) 101 | 102 | print(net.features[0]) 103 | 104 | net.features[0] = B.conv(1, 64, mode='C') 105 | 106 | print(net.features[0]) 107 | net.features[0].weight.data=in_new_filter 108 | net.features[0].bias.data=in_new_bias 109 | 110 | for k,v in net.features.named_parameters(): 111 | if k=='0.weight': 112 | print(v[0,0,0,0]) 113 | if k=='0.bias': 114 | print(v[0]) 115 | 116 | # transfer parameters of old model to new one 117 | model_old = torch.load(model_path) 118 | state_dict = model.state_dict() 119 | for ((key, param),(key2, param2)) in zip(model_old.items(), state_dict.items()): 120 | state_dict[key2] = param 121 | print([key, key2]) 122 | # print([param.size(), param2.size()]) 123 | torch.save(state_dict, 'model_new.pth') 124 | 125 | 126 | # rgb2gray_net(net) 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | -------------------------------------------------------------------------------- /core/data/deg_kair_utils/utils_receptivefield.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # online calculation: https://fomoro.com/research/article/receptive-field-calculator# 4 | 5 | # [filter size, stride, padding] 6 | #Assume the two dimensions are the same 7 | #Each kernel requires the following parameters: 8 | # - k_i: kernel size 9 | # - s_i: stride 10 | # - p_i: padding (if padding is uneven, right padding will higher than left padding; "SAME" option in tensorflow) 11 | # 12 | #Each layer i requires the following parameters to be fully represented: 13 | # - n_i: number of feature (data layer has n_1 = imagesize ) 14 | # - j_i: distance (projected to image pixel distance) between center of two adjacent features 15 | # - r_i: receptive field of a feature in layer i 16 | # - start_i: position of the first feature's receptive field in layer i (idx start from 0, negative means the center fall into padding) 17 | 18 | import math 19 | 20 | def outFromIn(conv, layerIn): 21 | n_in = layerIn[0] 22 | j_in = layerIn[1] 23 | r_in = layerIn[2] 24 | start_in = layerIn[3] 25 | k = conv[0] 26 | s = conv[1] 27 | p = conv[2] 28 | 29 | n_out = math.floor((n_in - k + 2*p)/s) + 1 30 | actualP = (n_out-1)*s - n_in + k 31 | pR = math.ceil(actualP/2) 32 | pL = math.floor(actualP/2) 33 | 34 | j_out = j_in * s 35 | r_out = r_in + (k - 1)*j_in 36 | start_out = start_in + ((k-1)/2 - pL)*j_in 37 | return n_out, j_out, r_out, start_out 38 | 39 | def printLayer(layer, layer_name): 40 | print(layer_name + ":") 41 | print(" n features: %s jump: %s receptive size: %s start: %s " % (layer[0], layer[1], layer[2], layer[3])) 42 | 43 | 44 | 45 | layerInfos = [] 46 | if __name__ == '__main__': 47 | 48 | convnet = [[3,1,1],[3,1,1],[3,1,1],[4,2,1],[2,2,0],[3,1,1]] 49 | layer_names = 
['conv1','conv2','conv3','conv4','conv5','conv6','conv7','conv8','conv9','conv10','conv11','conv12'] 50 | imsize = 128 51 | 52 | print ("-------Net summary------") 53 | currentLayer = [imsize, 1, 1, 0.5] 54 | printLayer(currentLayer, "input image") 55 | for i in range(len(convnet)): 56 | currentLayer = outFromIn(convnet[i], currentLayer) 57 | layerInfos.append(currentLayer) 58 | printLayer(currentLayer, layer_names[i]) 59 | 60 | 61 | # run utils/utils_receptivefield.py 62 | -------------------------------------------------------------------------------- /core/data/deg_kair_utils/utils_regularizers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | ''' 6 | # -------------------------------------------- 7 | # Kai Zhang (github: https://github.com/cszn) 8 | # 03/Mar/2019 9 | # -------------------------------------------- 10 | ''' 11 | 12 | 13 | # -------------------------------------------- 14 | # SVD Orthogonal Regularization 15 | # -------------------------------------------- 16 | def regularizer_orth(m): 17 | """ 18 | # ---------------------------------------- 19 | # SVD Orthogonal Regularization 20 | # ---------------------------------------- 21 | # Applies regularization to the training by performing the 22 | # orthogonalization technique described in the paper 23 | # This function is to be called by the torch.nn.Module.apply() method, 24 | # which applies svd_orthogonalization() to every layer of the model. 25 | # usage: net.apply(regularizer_orth) 26 | # ---------------------------------------- 27 | """ 28 | classname = m.__class__.__name__ 29 | if classname.find('Conv') != -1: 30 | w = m.weight.data.clone() 31 | c_out, c_in, f1, f2 = w.size() 32 | # dtype = m.weight.data.type() 33 | w = w.permute(2, 3, 1, 0).contiguous().view(f1*f2*c_in, c_out) 34 | # self.netG.apply(svd_orthogonalization) 35 | u, s, v = torch.svd(w) 36 | s[s > 1.5] = s[s > 1.5] - 1e-4 37 | s[s < 0.5] = s[s < 0.5] + 1e-4 38 | w = torch.mm(torch.mm(u, torch.diag(s)), v.t()) 39 | m.weight.data = w.view(f1, f2, c_in, c_out).permute(3, 2, 0, 1) # .type(dtype) 40 | else: 41 | pass 42 | 43 | 44 | # -------------------------------------------- 45 | # SVD Orthogonal Regularization 46 | # -------------------------------------------- 47 | def regularizer_orth2(m): 48 | """ 49 | # ---------------------------------------- 50 | # Applies regularization to the training by performing the 51 | # orthogonalization technique described in the paper 52 | # This function is to be called by the torch.nn.Module.apply() method, 53 | # which applies svd_orthogonalization() to every layer of the model. 
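# A rough usage sketch (the toy module and the 20-step interval below are
# illustrative, not prescribed by this file): all three regularizers in this
# file are meant to be applied with nn.Module.apply(), typically every few
# optimizer steps during training.
#
#   net = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(),
#                       nn.Conv2d(64, 3, 3, padding=1))
#   for step in range(1, 1001):
#       ...  # forward / backward / optimizer.step()
#       if step % 20 == 0:
#           net.apply(regularizer_orth)  # or regularizer_orth2 / regularizer_clip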
54 | # usage: net.apply(regularizer_orth2) 55 | # ---------------------------------------- 56 | """ 57 | classname = m.__class__.__name__ 58 | if classname.find('Conv') != -1: 59 | w = m.weight.data.clone() 60 | c_out, c_in, f1, f2 = w.size() 61 | # dtype = m.weight.data.type() 62 | w = w.permute(2, 3, 1, 0).contiguous().view(f1*f2*c_in, c_out) 63 | u, s, v = torch.svd(w) 64 | s_mean = s.mean() 65 | s[s > 1.5*s_mean] = s[s > 1.5*s_mean] - 1e-4 66 | s[s < 0.5*s_mean] = s[s < 0.5*s_mean] + 1e-4 67 | w = torch.mm(torch.mm(u, torch.diag(s)), v.t()) 68 | m.weight.data = w.view(f1, f2, c_in, c_out).permute(3, 2, 0, 1) # .type(dtype) 69 | else: 70 | pass 71 | 72 | 73 | 74 | def regularizer_clip(m): 75 | """ 76 | # ---------------------------------------- 77 | # usage: net.apply(regularizer_clip) 78 | # ---------------------------------------- 79 | """ 80 | eps = 1e-4 81 | c_min = -1.5 82 | c_max = 1.5 83 | 84 | classname = m.__class__.__name__ 85 | if classname.find('Conv') != -1 or classname.find('Linear') != -1: 86 | w = m.weight.data.clone() 87 | w[w > c_max] -= eps 88 | w[w < c_min] += eps 89 | m.weight.data = w 90 | 91 | if m.bias is not None: 92 | b = m.bias.data.clone() 93 | b[b > c_max] -= eps 94 | b[b < c_min] += eps 95 | m.bias.data = b 96 | 97 | # elif classname.find('BatchNorm2d') != -1: 98 | # 99 | # rv = m.running_var.data.clone() 100 | # rm = m.running_mean.data.clone() 101 | # 102 | # if m.affine: 103 | # m.weight.data 104 | # m.bias.data 105 | -------------------------------------------------------------------------------- /core/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catcathh/UltraPixel/e85d3bc9b0eaedbe1b5aea8a5c2624caab9df1f3/core/scripts/__init__.py -------------------------------------------------------------------------------- /core/scripts/cli.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | from .. import WarpCore 4 | from .. import templates 5 | 6 | 7 | def template_init(args): 8 | return '''' 9 | 10 | 11 | '''.strip() 12 | 13 | 14 | def init_template(args): 15 | parser = argparse.ArgumentParser(description='WarpCore template init tool') 16 | parser.add_argument('-t', '--template', type=str, default='WarpCore') 17 | args = parser.parse_args(args) 18 | 19 | if args.template == 'WarpCore': 20 | template_cls = WarpCore 21 | else: 22 | try: 23 | template_cls = __import__(args.template) 24 | except ModuleNotFoundError: 25 | template_cls = getattr(templates, args.template) 26 | print(template_cls) 27 | 28 | 29 | def main(): 30 | if len(sys.argv) < 2: 31 | print('Usage: core ') 32 | sys.exit(1) 33 | if sys.argv[1] == 'init': 34 | init_template(sys.argv[2:]) 35 | else: 36 | print('Unknown command') 37 | sys.exit(1) 38 | 39 | 40 | if __name__ == '__main__': 41 | main() 42 | -------------------------------------------------------------------------------- /core/templates/__init__.py: -------------------------------------------------------------------------------- 1 | from .diffusion import DiffusionCore -------------------------------------------------------------------------------- /core/templates/diffusion.py: -------------------------------------------------------------------------------- 1 | from .. 
import WarpCore 2 | from ..utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary 3 | from abc import abstractmethod 4 | from dataclasses import dataclass 5 | import torch 6 | from torch import nn 7 | from torch.utils.data import DataLoader 8 | from gdf import GDF 9 | import numpy as np 10 | from tqdm import tqdm 11 | import wandb 12 | 13 | import webdataset as wds 14 | from webdataset.handlers import warn_and_continue 15 | from torch.distributed import barrier 16 | from enum import Enum 17 | 18 | class TargetReparametrization(Enum): 19 | EPSILON = 'epsilon' 20 | X0 = 'x0' 21 | 22 | class DiffusionCore(WarpCore): 23 | @dataclass(frozen=True) 24 | class Config(WarpCore.Config): 25 | # TRAINING PARAMS 26 | lr: float = EXPECTED_TRAIN 27 | grad_accum_steps: int = EXPECTED_TRAIN 28 | batch_size: int = EXPECTED_TRAIN 29 | updates: int = EXPECTED_TRAIN 30 | warmup_updates: int = EXPECTED_TRAIN 31 | save_every: int = 500 32 | backup_every: int = 20000 33 | use_fsdp: bool = True 34 | 35 | # EMA UPDATE 36 | ema_start_iters: int = None 37 | ema_iters: int = None 38 | ema_beta: float = None 39 | 40 | # GDF setting 41 | gdf_target_reparametrization: TargetReparametrization = None # epsilon or x0 42 | 43 | @dataclass() # not frozen, means that fields are mutable. Doesn't support EXPECTED 44 | class Info(WarpCore.Info): 45 | ema_loss: float = None 46 | 47 | @dataclass(frozen=True) 48 | class Models(WarpCore.Models): 49 | generator : nn.Module = EXPECTED 50 | generator_ema : nn.Module = None # optional 51 | 52 | @dataclass(frozen=True) 53 | class Optimizers(WarpCore.Optimizers): 54 | generator : any = EXPECTED 55 | 56 | @dataclass(frozen=True) 57 | class Schedulers(WarpCore.Schedulers): 58 | generator: any = None 59 | 60 | @dataclass(frozen=True) 61 | class Extras(WarpCore.Extras): 62 | gdf: GDF = EXPECTED 63 | sampling_configs: dict = EXPECTED 64 | 65 | # -------------------------------------------- 66 | info: Info 67 | config: Config 68 | 69 | @abstractmethod 70 | def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor: 71 | raise NotImplementedError("This method needs to be overriden") 72 | 73 | @abstractmethod 74 | def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor: 75 | raise NotImplementedError("This method needs to be overriden") 76 | 77 | @abstractmethod 78 | def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False): 79 | raise NotImplementedError("This method needs to be overriden") 80 | 81 | @abstractmethod 82 | def webdataset_path(self, extras: Extras): 83 | raise NotImplementedError("This method needs to be overriden") 84 | 85 | @abstractmethod 86 | def webdataset_filters(self, extras: Extras): 87 | raise NotImplementedError("This method needs to be overriden") 88 | 89 | @abstractmethod 90 | def webdataset_preprocessors(self, extras: Extras): 91 | raise NotImplementedError("This method needs to be overriden") 92 | 93 | @abstractmethod 94 | def sample(self, models: Models, data: WarpCore.Data, extras: Extras): 95 | raise NotImplementedError("This method needs to be overriden") 96 | # ------------- 97 | 98 | def setup_data(self, extras: Extras) -> WarpCore.Data: 99 | # SETUP DATASET 100 | dataset_path = self.webdataset_path(extras) 101 | preprocessors = self.webdataset_preprocessors(extras) 102 | filters = self.webdataset_filters(extras) 103 | 104 | handler = warn_and_continue # None 105 | # handler = None 106 | dataset = 
wds.WebDataset( 107 | dataset_path, resampled=True, handler=handler 108 | ).select(filters).shuffle(690, handler=handler).decode( 109 | "pilrgb", handler=handler 110 | ).to_tuple( 111 | *[p[0] for p in preprocessors], handler=handler 112 | ).map_tuple( 113 | *[p[1] for p in preprocessors], handler=handler 114 | ).map(lambda x: {p[2]:x[i] for i, p in enumerate(preprocessors)}) 115 | 116 | # SETUP DATALOADER 117 | real_batch_size = self.config.batch_size//(self.world_size*self.config.grad_accum_steps) 118 | dataloader = DataLoader( 119 | dataset, batch_size=real_batch_size, num_workers=8, pin_memory=True 120 | ) 121 | 122 | return self.Data(dataset=dataset, dataloader=dataloader, iterator=iter(dataloader)) 123 | 124 | def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models): 125 | batch = next(data.iterator) 126 | 127 | with torch.no_grad(): 128 | conditions = self.get_conditions(batch, models, extras) 129 | latents = self.encode_latents(batch, models, extras) 130 | noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1) 131 | 132 | # FORWARD PASS 133 | with torch.cuda.amp.autocast(dtype=torch.bfloat16): 134 | pred = models.generator(noised, noise_cond, **conditions) 135 | if self.config.gdf_target_reparametrization == TargetReparametrization.EPSILON: 136 | pred = extras.gdf.undiffuse(noised, logSNR, pred)[1] # transform whatever prediction to epsilon to use in the loss 137 | target = noise 138 | elif self.config.gdf_target_reparametrization == TargetReparametrization.X0: 139 | pred = extras.gdf.undiffuse(noised, logSNR, pred)[0] # transform whatever prediction to x0 to use in the loss 140 | target = latents 141 | loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3]) 142 | loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps 143 | 144 | return loss, loss_adjusted 145 | 146 | def train(self, data: WarpCore.Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers): 147 | start_iter = self.info.iter+1 148 | max_iters = self.config.updates * self.config.grad_accum_steps 149 | if self.is_main_node: 150 | print(f"STARTING AT STEP: {start_iter}/{max_iters}") 151 | 152 | pbar = tqdm(range(start_iter, max_iters+1)) if self.is_main_node else range(start_iter, max_iters+1) # <--- DDP 153 | models.generator.train() 154 | for i in pbar: 155 | # FORWARD PASS 156 | loss, loss_adjusted = self.forward_pass(data, extras, models) 157 | 158 | # BACKWARD PASS 159 | if i % self.config.grad_accum_steps == 0 or i == max_iters: 160 | loss_adjusted.backward() 161 | grad_norm = nn.utils.clip_grad_norm_(models.generator.parameters(), 1.0) 162 | optimizers_dict = optimizers.to_dict() 163 | for k in optimizers_dict: 164 | optimizers_dict[k].step() 165 | schedulers_dict = schedulers.to_dict() 166 | for k in schedulers_dict: 167 | schedulers_dict[k].step() 168 | models.generator.zero_grad(set_to_none=True) 169 | self.info.total_steps += 1 170 | else: 171 | with models.generator.no_sync(): 172 | loss_adjusted.backward() 173 | self.info.iter = i 174 | 175 | # UPDATE EMA 176 | if models.generator_ema is not None and i % self.config.ema_iters == 0: 177 | update_weights_ema( 178 | models.generator_ema, models.generator, 179 | beta=(self.config.ema_beta if i > self.config.ema_start_iters else 0) 180 | ) 181 | 182 | # UPDATE LOSS METRICS 183 | self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01 184 | 185 | 
if self.is_main_node and self.config.wandb_project is not None and np.isnan(loss.mean().item()) or np.isnan(grad_norm.item()): 186 | wandb.alert( 187 | title=f"NaN value encountered in training run {self.info.wandb_run_id}", 188 | text=f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}", 189 | wait_duration=60*30 190 | ) 191 | 192 | if self.is_main_node: 193 | logs = { 194 | 'loss': self.info.ema_loss, 195 | 'raw_loss': loss.mean().item(), 196 | 'grad_norm': grad_norm.item(), 197 | 'lr': optimizers.generator.param_groups[0]['lr'], 198 | 'total_steps': self.info.total_steps, 199 | } 200 | 201 | pbar.set_postfix(logs) 202 | if self.config.wandb_project is not None: 203 | wandb.log(logs) 204 | 205 | if i == 1 or i % (self.config.save_every*self.config.grad_accum_steps) == 0 or i == max_iters: 206 | # SAVE AND CHECKPOINT STUFF 207 | if np.isnan(loss.mean().item()): 208 | if self.is_main_node and self.config.wandb_project is not None: 209 | tqdm.write("Skipping sampling & checkpoint because the loss is NaN") 210 | wandb.alert(title=f"Skipping sampling & checkpoint for training run {self.config.run_id}", text=f"Skipping sampling & checkpoint at {self.info.total_steps} for training run {self.info.wandb_run_id} iters because loss is NaN") 211 | else: 212 | self.save_checkpoints(models, optimizers) 213 | if self.is_main_node: 214 | create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/') 215 | self.sample(models, data, extras) 216 | 217 | def models_to_save(self): 218 | return ['generator', 'generator_ema'] 219 | 220 | def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None): 221 | barrier() 222 | suffix = '' if suffix is None else suffix 223 | self.save_info(self.info, suffix=suffix) 224 | models_dict = models.to_dict() 225 | optimizers_dict = optimizers.to_dict() 226 | for key in self.models_to_save(): 227 | model = models_dict[key] 228 | if model is not None: 229 | self.save_model(model, f"{key}{suffix}", is_fsdp=self.config.use_fsdp) 230 | for key in optimizers_dict: 231 | optimizer = optimizers_dict[key] 232 | if optimizer is not None: 233 | self.save_optimizer(optimizer, f'{key}_optim{suffix}', fsdp_model=models.generator if self.config.use_fsdp else None) 234 | if suffix == '' and self.info.total_steps > 1 and self.info.total_steps % self.config.backup_every == 0: 235 | self.save_checkpoints(models, optimizers, suffix=f"_{self.info.total_steps//1000}k") 236 | torch.cuda.empty_cache() 237 | -------------------------------------------------------------------------------- /core/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_dto import Base, nested_dto, EXPECTED, EXPECTED_TRAIN 2 | from .save_and_load import create_folder_if_necessary, safe_save, load_or_fail 3 | 4 | # MOVE IT SOMERWHERE ELSE 5 | def update_weights_ema(tgt_model, src_model, beta=0.999): 6 | for self_params, src_params in zip(tgt_model.parameters(), src_model.parameters()): 7 | self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1-beta) 8 | for self_buffers, src_buffers in zip(tgt_model.buffers(), src_model.buffers()): 9 | self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1-beta) -------------------------------------------------------------------------------- /core/utils/__pycache__/__init__.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/catcathh/UltraPixel/e85d3bc9b0eaedbe1b5aea8a5c2624caab9df1f3/core/utils/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /core/utils/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catcathh/UltraPixel/e85d3bc9b0eaedbe1b5aea8a5c2624caab9df1f3/core/utils/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /core/utils/__pycache__/base_dto.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catcathh/UltraPixel/e85d3bc9b0eaedbe1b5aea8a5c2624caab9df1f3/core/utils/__pycache__/base_dto.cpython-310.pyc -------------------------------------------------------------------------------- /core/utils/__pycache__/base_dto.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catcathh/UltraPixel/e85d3bc9b0eaedbe1b5aea8a5c2624caab9df1f3/core/utils/__pycache__/base_dto.cpython-39.pyc -------------------------------------------------------------------------------- /core/utils/__pycache__/save_and_load.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catcathh/UltraPixel/e85d3bc9b0eaedbe1b5aea8a5c2624caab9df1f3/core/utils/__pycache__/save_and_load.cpython-310.pyc -------------------------------------------------------------------------------- /core/utils/__pycache__/save_and_load.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catcathh/UltraPixel/e85d3bc9b0eaedbe1b5aea8a5c2624caab9df1f3/core/utils/__pycache__/save_and_load.cpython-39.pyc -------------------------------------------------------------------------------- /core/utils/base_dto.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from dataclasses import dataclass, _MISSING_TYPE 3 | from munch import Munch 4 | 5 | EXPECTED = "___REQUIRED___" 6 | EXPECTED_TRAIN = "___REQUIRED_TRAIN___" 7 | 8 | # pylint: disable=invalid-field-call 9 | def nested_dto(x, raw=False): 10 | return dataclasses.field(default_factory=lambda: x if raw else Munch.fromDict(x)) 11 | 12 | @dataclass(frozen=True) 13 | class Base: 14 | training: bool = None 15 | def __new__(cls, **kwargs): 16 | training = kwargs.get('training', True) 17 | setteable_fields = cls.setteable_fields(**kwargs) 18 | mandatory_fields = cls.mandatory_fields(**kwargs) 19 | invalid_kwargs = [ 20 | {k: v} for k, v in kwargs.items() if k not in setteable_fields or v == EXPECTED or (v == EXPECTED_TRAIN and training is not False) 21 | ] 22 | print(mandatory_fields) 23 | assert ( 24 | len(invalid_kwargs) == 0 25 | ), f"Invalid fields detected when initializing this DTO: {invalid_kwargs}.\nDeclare this field and set it to None or EXPECTED in order to make it setteable." 26 | missing_kwargs = [f for f in mandatory_fields if f not in kwargs] 27 | assert ( 28 | len(missing_kwargs) == 0 29 | ), f"Required fields missing initializing this DTO: {missing_kwargs}." 
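# A made-up illustration (class and field names are hypothetical) of how a DTO
# is declared on top of Base: fields defaulting to EXPECTED / EXPECTED_TRAIN
# become mandatory, fields defaulting to None stay optional.
#
#   @dataclass(frozen=True)
#   class MyConfig(Base):
#       experiment_id: str = EXPECTED      # always required
#       lr: float = EXPECTED_TRAIN         # required unless training=False is passed
#       checkpoint_path: str = None        # optional
#
#   cfg = MyConfig.from_dict({'experiment_id': 'demo', 'lr': 1e-4})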
30 | return object.__new__(cls) 31 | 32 | 33 | @classmethod 34 | def setteable_fields(cls, **kwargs): 35 | return [f.name for f in dataclasses.fields(cls) if f.default is None or isinstance(f.default, _MISSING_TYPE) or f.default == EXPECTED or f.default == EXPECTED_TRAIN] 36 | 37 | @classmethod 38 | def mandatory_fields(cls, **kwargs): 39 | training = kwargs.get('training', True) 40 | return [f.name for f in dataclasses.fields(cls) if isinstance(f.default, _MISSING_TYPE) and isinstance(f.default_factory, _MISSING_TYPE) or f.default == EXPECTED or (f.default == EXPECTED_TRAIN and training is not False)] 41 | 42 | @classmethod 43 | def from_dict(cls, kwargs): 44 | for k in kwargs: 45 | if isinstance(kwargs[k], (dict, list, tuple)): 46 | kwargs[k] = Munch.fromDict(kwargs[k]) 47 | return cls(**kwargs) 48 | 49 | def to_dict(self): 50 | # selfdict = dataclasses.asdict(self) # needs to pickle stuff, doesn't support some more complex classes 51 | selfdict = {} 52 | for k in dataclasses.fields(self): 53 | selfdict[k.name] = getattr(self, k.name) 54 | if isinstance(selfdict[k.name], Munch): 55 | selfdict[k.name] = selfdict[k.name].toDict() 56 | return selfdict 57 | -------------------------------------------------------------------------------- /core/utils/save_and_load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import json 4 | from pathlib import Path 5 | import safetensors 6 | import wandb 7 | 8 | 9 | def create_folder_if_necessary(path): 10 | path = "/".join(path.split("/")[:-1]) 11 | Path(path).mkdir(parents=True, exist_ok=True) 12 | 13 | 14 | def safe_save(ckpt, path): 15 | try: 16 | os.remove(f"{path}.bak") 17 | except OSError: 18 | pass 19 | try: 20 | os.rename(path, f"{path}.bak") 21 | except OSError: 22 | pass 23 | if path.endswith(".pt") or path.endswith(".ckpt"): 24 | torch.save(ckpt, path) 25 | elif path.endswith(".json"): 26 | with open(path, "w", encoding="utf-8") as f: 27 | json.dump(ckpt, f, indent=4) 28 | elif path.endswith(".safetensors"): 29 | safetensors.torch.save_file(ckpt, path) 30 | else: 31 | raise ValueError(f"File extension not supported: {path}") 32 | 33 | 34 | def load_or_fail(path, wandb_run_id=None): 35 | accepted_extensions = [".pt", ".ckpt", ".json", ".safetensors"] 36 | try: 37 | assert any( 38 | [path.endswith(ext) for ext in accepted_extensions] 39 | ), f"Automatic loading not supported for this extension: {path}" 40 | if not os.path.exists(path): 41 | checkpoint = None 42 | elif path.endswith(".pt") or path.endswith(".ckpt"): 43 | checkpoint = torch.load(path, map_location="cpu") 44 | elif path.endswith(".json"): 45 | with open(path, "r", encoding="utf-8") as f: 46 | checkpoint = json.load(f) 47 | elif path.endswith(".safetensors"): 48 | checkpoint = {} 49 | with safetensors.safe_open(path, framework="pt", device="cpu") as f: 50 | for key in f.keys(): 51 | checkpoint[key] = f.get_tensor(key) 52 | return checkpoint 53 | except Exception as e: 54 | if wandb_run_id is not None: 55 | wandb.alert( 56 | title=f"Corrupt checkpoint for run {wandb_run_id}", 57 | text=f"Training {wandb_run_id} tried to load checkpoint {path} and failed", 58 | ) 59 | raise e 60 | -------------------------------------------------------------------------------- /figures/California_000490.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catcathh/UltraPixel/e85d3bc9b0eaedbe1b5aea8a5c2624caab9df1f3/figures/California_000490.jpg 
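Referring back to core/utils/save_and_load.py above, a minimal round-trip sketch of the checkpoint helpers (the output path is a hypothetical placeholder):

    import torch
    from core.utils import create_folder_if_necessary, safe_save, load_or_fail

    path = 'output/demo/generator.pt'       # hypothetical location
    create_folder_if_necessary(path)        # creates 'output/demo/'
    safe_save({'w': torch.zeros(4)}, path)  # any existing file is first renamed to '<path>.bak'
    ckpt = load_or_fail(path)               # would return None if the file did not exist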
-------------------------------------------------------------------------------- /figures/example_dataset/000008.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catcathh/UltraPixel/e85d3bc9b0eaedbe1b5aea8a5c2624caab9df1f3/figures/example_dataset/000008.jpg -------------------------------------------------------------------------------- /figures/example_dataset/000008.json: -------------------------------------------------------------------------------- 1 | { "caption": "The image captures the iconic Shard, a modern skyscraper that stands as the tallest building in the United Kingdom. The Shard, with its glass and steel structure, pierces the sky, its pointed top reaching towards the heavens. The photograph is taken from a low angle, which emphasizes the height and grandeur of the building. The sky forms a beautiful backdrop, painted in hues of pinkish-orange, suggesting that the photo was taken at sunset. The Shard is nestled between two other buildings, their presence subtly hinted at in the foreground. The image does not contain any discernible text or countable objects, and there are no visible actions taking place. The relative positions of the objects confirm that the Shard is the central focus of the image, with the other buildings and the sky providing context to its location. The image is devoid of any aesthetic descriptions, focusing solely on the factual representation of the scene." 2 | } -------------------------------------------------------------------------------- /figures/example_dataset/000012.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catcathh/UltraPixel/e85d3bc9b0eaedbe1b5aea8a5c2624caab9df1f3/figures/example_dataset/000012.jpg -------------------------------------------------------------------------------- /figures/example_dataset/000012.json: -------------------------------------------------------------------------------- 1 | {"caption": "cars in a road during daytime"} -------------------------------------------------------------------------------- /figures/example_lora_cat/1_B0004902.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catcathh/UltraPixel/e85d3bc9b0eaedbe1b5aea8a5c2624caab9df1f3/figures/example_lora_cat/1_B0004902.jpg -------------------------------------------------------------------------------- /figures/example_lora_cat/2_B0005089.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catcathh/UltraPixel/e85d3bc9b0eaedbe1b5aea8a5c2624caab9df1f3/figures/example_lora_cat/2_B0005089.jpg -------------------------------------------------------------------------------- /figures/example_lora_cat/3_B0005163.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catcathh/UltraPixel/e85d3bc9b0eaedbe1b5aea8a5c2624caab9df1f3/figures/example_lora_cat/3_B0005163.jpg -------------------------------------------------------------------------------- /figures/example_lora_cat/4_cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catcathh/UltraPixel/e85d3bc9b0eaedbe1b5aea8a5c2624caab9df1f3/figures/example_lora_cat/4_cat.jpg -------------------------------------------------------------------------------- /figures/example_lora_cat/5_cat.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/catcathh/UltraPixel/e85d3bc9b0eaedbe1b5aea8a5c2624caab9df1f3/figures/example_lora_cat/5_cat.jpg -------------------------------------------------------------------------------- /figures/example_lora_cat/6_cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catcathh/UltraPixel/e85d3bc9b0eaedbe1b5aea8a5c2624caab9df1f3/figures/example_lora_cat/6_cat.jpg -------------------------------------------------------------------------------- /figures/example_lora_cat/7_cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catcathh/UltraPixel/e85d3bc9b0eaedbe1b5aea8a5c2624caab9df1f3/figures/example_lora_cat/7_cat.jpg -------------------------------------------------------------------------------- /figures/example_lora_cat/8_20240611230541.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catcathh/UltraPixel/e85d3bc9b0eaedbe1b5aea8a5c2624caab9df1f3/figures/example_lora_cat/8_20240611230541.jpg -------------------------------------------------------------------------------- /figures/example_lora_cat/9_20240611230549.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catcathh/UltraPixel/e85d3bc9b0eaedbe1b5aea8a5c2624caab9df1f3/figures/example_lora_cat/9_20240611230549.jpg -------------------------------------------------------------------------------- /figures/example_lora_cat/README.md: -------------------------------------------------------------------------------- 1 | example dataset for lora personalization 2 | -------------------------------------------------------------------------------- /figures/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catcathh/UltraPixel/e85d3bc9b0eaedbe1b5aea8a5c2624caab9df1f3/figures/teaser.jpg -------------------------------------------------------------------------------- /gdf/loss_weights.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | # --- Loss Weighting 5 | class BaseLossWeight(): 6 | def weight(self, logSNR): 7 | raise NotImplementedError("this method needs to be overridden") 8 | 9 | def __call__(self, logSNR, *args, shift=1, clamp_range=None, **kwargs): 10 | clamp_range = [-1e9, 1e9] if clamp_range is None else clamp_range 11 | if shift != 1: 12 | logSNR = logSNR.clone() + 2 * np.log(shift) 13 | return self.weight(logSNR, *args, **kwargs).clamp(*clamp_range) 14 | 15 | class ComposedLossWeight(BaseLossWeight): 16 | def __init__(self, div, mul): 17 | self.mul = [mul] if isinstance(mul, BaseLossWeight) else mul 18 | self.div = [div] if isinstance(div, BaseLossWeight) else div 19 | 20 | def weight(self, logSNR): 21 | prod, div = 1, 1 22 | for m in self.mul: 23 | prod *= m.weight(logSNR) 24 | for d in self.div: 25 | div *= d.weight(logSNR) 26 | return prod/div 27 | 28 | class ConstantLossWeight(BaseLossWeight): 29 | def __init__(self, v=1): 30 | self.v = v 31 | 32 | def weight(self, logSNR): 33 | return torch.ones_like(logSNR) * self.v 34 | 35 | class SNRLossWeight(BaseLossWeight): 36 | def weight(self, logSNR): 37 | return logSNR.exp() 38 | 39 | class P2LossWeight(BaseLossWeight): 40 | def __init__(self, k=1.0, 
gamma=1.0, s=1.0): 41 | self.k, self.gamma, self.s = k, gamma, s 42 | 43 | def weight(self, logSNR): 44 | return (self.k + (logSNR * self.s).exp()) ** -self.gamma 45 | 46 | class SNRPlusOneLossWeight(BaseLossWeight): 47 | def weight(self, logSNR): 48 | return logSNR.exp() + 1 49 | 50 | class MinSNRLossWeight(BaseLossWeight): 51 | def __init__(self, max_snr=5): 52 | self.max_snr = max_snr 53 | 54 | def weight(self, logSNR): 55 | return logSNR.exp().clamp(max=self.max_snr) 56 | 57 | class MinSNRPlusOneLossWeight(BaseLossWeight): 58 | def __init__(self, max_snr=5): 59 | self.max_snr = max_snr 60 | 61 | def weight(self, logSNR): 62 | return (logSNR.exp() + 1).clamp(max=self.max_snr) 63 | 64 | class TruncatedSNRLossWeight(BaseLossWeight): 65 | def __init__(self, min_snr=1): 66 | self.min_snr = min_snr 67 | 68 | def weight(self, logSNR): 69 | return logSNR.exp().clamp(min=self.min_snr) 70 | 71 | class SechLossWeight(BaseLossWeight): 72 | def __init__(self, div=2): 73 | self.div = div 74 | 75 | def weight(self, logSNR): 76 | return 1/(logSNR/self.div).cosh() 77 | 78 | class DebiasedLossWeight(BaseLossWeight): 79 | def weight(self, logSNR): 80 | return 1/logSNR.exp().sqrt() 81 | 82 | class SigmoidLossWeight(BaseLossWeight): 83 | def __init__(self, s=1): 84 | self.s = s 85 | 86 | def weight(self, logSNR): 87 | return (logSNR * self.s).sigmoid() 88 | 89 | class AdaptiveLossWeight(BaseLossWeight): 90 | def __init__(self, logsnr_range=[-10, 10], buckets=300, weight_range=[1e-7, 1e7]): 91 | self.bucket_ranges = torch.linspace(logsnr_range[0], logsnr_range[1], buckets-1) 92 | self.bucket_losses = torch.ones(buckets) 93 | self.weight_range = weight_range 94 | 95 | def weight(self, logSNR): 96 | indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR) 97 | return (1/self.bucket_losses.to(logSNR.device)[indices]).clamp(*self.weight_range) 98 | 99 | def update_buckets(self, logSNR, loss, beta=0.99): 100 | indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR).cpu() 101 | self.bucket_losses[indices] = self.bucket_losses[indices]*beta + loss.detach().cpu() * (1-beta) 102 | -------------------------------------------------------------------------------- /gdf/noise_conditions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | class BaseNoiseCond(): 5 | def __init__(self, *args, shift=1, clamp_range=None, **kwargs): 6 | clamp_range = [-1e9, 1e9] if clamp_range is None else clamp_range 7 | self.shift = shift 8 | self.clamp_range = clamp_range 9 | self.setup(*args, **kwargs) 10 | 11 | def setup(self, *args, **kwargs): 12 | pass # this method is optional, override it if required 13 | 14 | def cond(self, logSNR): 15 | raise NotImplementedError("this method needs to be overriden") 16 | 17 | def __call__(self, logSNR): 18 | if self.shift != 1: 19 | logSNR = logSNR.clone() + 2 * np.log(self.shift) 20 | return self.cond(logSNR).clamp(*self.clamp_range) 21 | 22 | class CosineTNoiseCond(BaseNoiseCond): 23 | def setup(self, s=0.008, clamp_range=[0, 1]): # [0.0001, 0.9999] 24 | self.s = torch.tensor([s]) 25 | self.clamp_range = clamp_range 26 | self.min_var = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2 27 | 28 | def cond(self, logSNR): 29 | var = logSNR.sigmoid() 30 | var = var.clamp(*self.clamp_range) 31 | s, min_var = self.s.to(var.device), self.min_var.to(var.device) 32 | t = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s 33 | return t 34 | 35 | class 
EDMNoiseCond(BaseNoiseCond): 36 | def cond(self, logSNR): 37 | return -logSNR/8 38 | 39 | class SigmoidNoiseCond(BaseNoiseCond): 40 | def cond(self, logSNR): 41 | return (-logSNR).sigmoid() 42 | 43 | class LogSNRNoiseCond(BaseNoiseCond): 44 | def cond(self, logSNR): 45 | return logSNR 46 | 47 | class EDMSigmaNoiseCond(BaseNoiseCond): 48 | def setup(self, sigma_data=1): 49 | self.sigma_data = sigma_data 50 | 51 | def cond(self, logSNR): 52 | return torch.exp(-logSNR / 2) * self.sigma_data 53 | 54 | class RectifiedFlowsNoiseCond(BaseNoiseCond): 55 | def cond(self, logSNR): 56 | _a = logSNR.exp() - 1 57 | _a[_a == 0] = 1e-3 # Avoid division by zero 58 | a = 1 + (2-(2**2 + 4*_a)**0.5) / (2*_a) 59 | return a 60 | 61 | # Any NoiseCond that cannot be described easily as a continuous function of t 62 | # It needs to define self.x and self.y in the setup() method 63 | class PiecewiseLinearNoiseCond(BaseNoiseCond): 64 | def setup(self): 65 | self.x = None 66 | self.y = None 67 | 68 | def piecewise_linear(self, y, xs, ys): 69 | indices = (len(xs)-2) - torch.searchsorted(ys.flip(dims=(-1,))[:-2], y) 70 | x_min, x_max = xs[indices], xs[indices+1] 71 | y_min, y_max = ys[indices], ys[indices+1] 72 | x = x_min + (x_max - x_min) * (y - y_min) / (y_max - y_min) 73 | return x 74 | 75 | def cond(self, logSNR): 76 | var = logSNR.sigmoid() 77 | t = self.piecewise_linear(var, self.x.to(var.device), self.y.to(var.device)) # .mul(1000).round().clamp(min=0) 78 | return t 79 | 80 | class StableDiffusionNoiseCond(PiecewiseLinearNoiseCond): 81 | def setup(self, linear_range=[0.00085, 0.012], total_steps=1000): 82 | self.total_steps = total_steps 83 | linear_range_sqrt = [r**0.5 for r in linear_range] 84 | self.x = torch.linspace(0, 1, total_steps+1) 85 | 86 | alphas = 1-(linear_range_sqrt[0]*(1-self.x) + linear_range_sqrt[1]*self.x)**2 87 | self.y = alphas.cumprod(dim=-1) 88 | 89 | def cond(self, logSNR): 90 | return super().cond(logSNR).clamp(0, 1) 91 | 92 | class DiscreteNoiseCond(BaseNoiseCond): 93 | def setup(self, noise_cond, steps=1000, continuous_range=[0, 1]): 94 | self.noise_cond = noise_cond 95 | self.steps = steps 96 | self.continuous_range = continuous_range 97 | 98 | def cond(self, logSNR): 99 | cond = self.noise_cond(logSNR) 100 | cond = (cond-self.continuous_range[0]) / (self.continuous_range[1]-self.continuous_range[0]) 101 | return cond.mul(self.steps).long() 102 | -------------------------------------------------------------------------------- /gdf/readme.md: -------------------------------------------------------------------------------- 1 | # Generic Diffusion Framework (GDF) 2 | 3 | # Basic usage 4 | GDF is a simple framework for working with diffusion models. It implements most common diffusion frameworks (DDPM / DDIM 5 | , EDM, Rectified Flows, etc.) 
and makes it very easy to switch between them or combine different parts of different 6 | frameworks. 7 | 8 | Using GDF is very straightforward. First, define an instance of the GDF class: 9 | 10 | ```python 11 | from gdf import GDF 12 | from gdf import CosineSchedule 13 | from gdf import VPScaler, EpsilonTarget, CosineTNoiseCond, P2LossWeight 14 | 15 | gdf = GDF( 16 | schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]), 17 | input_scaler=VPScaler(), target=EpsilonTarget(), 18 | noise_cond=CosineTNoiseCond(), 19 | loss_weight=P2LossWeight(), 20 | ) 21 | ``` 22 | 23 | You need to define the following components: 24 | * **Train Schedule**: This returns the logSNR schedule used during training; some of the schedulers can be configured. A train schedule is called with a batch size and randomly samples values from the defined distribution. 25 | * **Sample Schedule**: The schedule used later on when sampling. It might be different from the training schedule. 26 | * **Input Scaler**: Whether you want to use Variance Preserving or LERP (rectified flows) scaling 27 | * **Target**: What the target is during training, usually epsilon, x0 or v 28 | * **Noise Conditioning**: You could pass the logSNR directly to your model, but usually a normalized value is used instead; for example, the EDM framework proposes `-logSNR/8` 29 | * **Loss Weight**: There are many proposed loss weighting strategies; here you define which one you'll use 30 | 31 | All of these classes are very simple logSNR-centric definitions; for example, the VPScaler is defined as just: 32 | ```python 33 | class VPScaler(): 34 | def __call__(self, logSNR): 35 | a_squared = logSNR.sigmoid() 36 | a = a_squared.sqrt() 37 | b = (1-a_squared).sqrt() 38 | return a, b 39 | 40 | ``` 41 | 42 | So it's very easy to extend this framework with custom schedulers, scalers, targets, loss weights, etc. 43 | 44 | ### Training 45 | 46 | When you define your training loop, you can get everything you need by just doing: 47 | ```python 48 | shift, loss_shift = 1, 1 # these can be set to higher values, as the Simple Diffusion paper suggests for high resolutions 49 | for inputs, extra_conditions in dataloader_iterator: 50 | noised, noise, target, logSNR, noise_cond, loss_weight = gdf.diffuse(inputs, shift=shift, loss_shift=loss_shift) 51 | pred = diffusion_model(noised, noise_cond, extra_conditions) 52 | 53 | loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3]) 54 | loss_adjusted = (loss * loss_weight).mean() 55 | 56 | loss_adjusted.backward() 57 | optimizer.step() 58 | optimizer.zero_grad(set_to_none=True) 59 | ``` 60 | 61 | And that's all: you now have a diffusion model training loop, and it's very easy to customize the different elements of the 62 | training through the GDF class.
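
If you use `AdaptiveLossWeight` instead of `P2LossWeight`, it additionally exposes an `update_buckets(logSNR, loss)` method that keeps a running (EMA) estimate of the loss per logSNR bucket. The sketch below is an illustrative variation of the loop above, not the repository's training scripts; it assumes the same placeholder `dataloader_iterator`, `diffusion_model` and `optimizer` as before:

```python
import torch.nn as nn

from gdf import GDF, CosineSchedule
from gdf import VPScaler, EpsilonTarget, CosineTNoiseCond, AdaptiveLossWeight

# Same setup as above, but keep a handle on the loss weight so its
# logSNR buckets can be refreshed from the observed training loss.
adaptive_loss_weight = AdaptiveLossWeight()
gdf = GDF(
    schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]),
    input_scaler=VPScaler(), target=EpsilonTarget(),
    noise_cond=CosineTNoiseCond(),
    loss_weight=adaptive_loss_weight,
)

shift, loss_shift = 1, 1
for inputs, extra_conditions in dataloader_iterator:
    noised, noise, target, logSNR, noise_cond, loss_weight = gdf.diffuse(inputs, shift=shift, loss_shift=loss_shift)
    pred = diffusion_model(noised, noise_cond, extra_conditions)

    loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])
    loss_adjusted = (loss * loss_weight).mean()

    # Feed the unweighted per-sample loss back into the logSNR buckets
    # (update_buckets detaches it internally and applies an EMA).
    adaptive_loss_weight.update_buckets(logSNR, loss)

    loss_adjusted.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
```

`AdaptiveTrainSchedule` in `gdf/schedulers.py` follows the same pattern, with its own `update_buckets` method for re-weighting which logSNR buckets get sampled during training.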
63 | 64 | ### Sampling 65 | 66 | The other important part is sampling, when you want to use this framework to sample you can just do the following: 67 | 68 | ```python 69 | from gdf import DDPMSampler 70 | 71 | shift = 1 72 | sampling_configs = { 73 | "timesteps": 30, "cfg": 7, "sampler": DDPMSampler(gdf), "shift": shift, 74 | "schedule": CosineSchedule(clamp_range=[0.0001, 0.9999]) 75 | } 76 | 77 | *_, (sampled, _, _) = gdf.sample( 78 | diffusion_model, {"cond": extra_conditions}, latents.shape, 79 | unconditional_inputs= {"cond": torch.zeros_like(extra_conditions)}, 80 | device=device, **sampling_configs 81 | ) 82 | ``` 83 | 84 | # Available modules 85 | 86 | TODO 87 | -------------------------------------------------------------------------------- /gdf/samplers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class SimpleSampler(): 4 | def __init__(self, gdf): 5 | self.gdf = gdf 6 | self.current_step = -1 7 | 8 | def __call__(self, *args, **kwargs): 9 | self.current_step += 1 10 | return self.step(*args, **kwargs) 11 | 12 | def init_x(self, shape): 13 | return torch.randn(*shape) 14 | 15 | def step(self, x, x0, epsilon, logSNR, logSNR_prev): 16 | raise NotImplementedError("You should override the 'apply' function.") 17 | 18 | class DDIMSampler(SimpleSampler): 19 | def step(self, x, x0, epsilon, logSNR, logSNR_prev, eta=0): 20 | a, b = self.gdf.input_scaler(logSNR) 21 | if len(a.shape) == 1: 22 | a, b = a.view(-1, *[1]*(len(x0.shape)-1)), b.view(-1, *[1]*(len(x0.shape)-1)) 23 | 24 | a_prev, b_prev = self.gdf.input_scaler(logSNR_prev) 25 | if len(a_prev.shape) == 1: 26 | a_prev, b_prev = a_prev.view(-1, *[1]*(len(x0.shape)-1)), b_prev.view(-1, *[1]*(len(x0.shape)-1)) 27 | 28 | sigma_tau = eta * (b_prev**2 / b**2).sqrt() * (1 - a**2 / a_prev**2).sqrt() if eta > 0 else 0 29 | # x = a_prev * x0 + (1 - a_prev**2 - sigma_tau ** 2).sqrt() * epsilon + sigma_tau * torch.randn_like(x0) 30 | x = a_prev * x0 + (b_prev**2 - sigma_tau**2).sqrt() * epsilon + sigma_tau * torch.randn_like(x0) 31 | return x 32 | 33 | class DDPMSampler(DDIMSampler): 34 | def step(self, x, x0, epsilon, logSNR, logSNR_prev, eta=1): 35 | return super().step(x, x0, epsilon, logSNR, logSNR_prev, eta) 36 | 37 | class LCMSampler(SimpleSampler): 38 | def step(self, x, x0, epsilon, logSNR, logSNR_prev): 39 | a_prev, b_prev = self.gdf.input_scaler(logSNR_prev) 40 | if len(a_prev.shape) == 1: 41 | a_prev, b_prev = a_prev.view(-1, *[1]*(len(x0.shape)-1)), b_prev.view(-1, *[1]*(len(x0.shape)-1)) 42 | return x0 * a_prev + torch.randn_like(epsilon) * b_prev 43 | -------------------------------------------------------------------------------- /gdf/scalers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class BaseScaler(): 4 | def __init__(self): 5 | self.stretched_limits = None 6 | 7 | def setup_limits(self, schedule, input_scaler, stretch_max=True, stretch_min=True, shift=1): 8 | min_logSNR = schedule(torch.ones(1), shift=shift) 9 | max_logSNR = schedule(torch.zeros(1), shift=shift) 10 | 11 | min_a, max_b = [v.item() for v in input_scaler(min_logSNR)] if stretch_max else [0, 1] 12 | max_a, min_b = [v.item() for v in input_scaler(max_logSNR)] if stretch_min else [1, 0] 13 | self.stretched_limits = [min_a, max_a, min_b, max_b] 14 | return self.stretched_limits 15 | 16 | def stretch_limits(self, a, b): 17 | min_a, max_a, min_b, max_b = self.stretched_limits 18 | return (a - min_a) / (max_a - min_a), (b - min_b) / (max_b 
- min_b) 19 | 20 | def scalers(self, logSNR): 21 | raise NotImplementedError("this method needs to be overridden") 22 | 23 | def __call__(self, logSNR): 24 | a, b = self.scalers(logSNR) 25 | if self.stretched_limits is not None: 26 | a, b = self.stretch_limits(a, b) 27 | return a, b 28 | 29 | class VPScaler(BaseScaler): 30 | def scalers(self, logSNR): 31 | a_squared = logSNR.sigmoid() 32 | a = a_squared.sqrt() 33 | b = (1-a_squared).sqrt() 34 | return a, b 35 | 36 | class LERPScaler(BaseScaler): 37 | def scalers(self, logSNR): 38 | _a = logSNR.exp() - 1 39 | _a[_a == 0] = 1e-3 # Avoid division by zero 40 | a = 1 + (2-(2**2 + 4*_a)**0.5) / (2*_a) 41 | b = 1-a 42 | return a, b 43 | -------------------------------------------------------------------------------- /gdf/schedulers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | class BaseSchedule(): 5 | def __init__(self, *args, force_limits=True, discrete_steps=None, shift=1, **kwargs): 6 | self.setup(*args, **kwargs) 7 | self.limits = None 8 | self.discrete_steps = discrete_steps 9 | self.shift = shift 10 | if force_limits: 11 | self.reset_limits() 12 | 13 | def reset_limits(self, shift=1, disable=False): 14 | try: 15 | self.limits = None if disable else self(torch.tensor([1.0, 0.0]), shift=shift).tolist() # min, max 16 | return self.limits 17 | except Exception: 18 | print("WARNING: this schedule doesn't support t and will be unbounded") 19 | return None 20 | 21 | def setup(self, *args, **kwargs): 22 | raise NotImplementedError("this method needs to be overriden") 23 | 24 | def schedule(self, *args, **kwargs): 25 | raise NotImplementedError("this method needs to be overriden") 26 | 27 | def __call__(self, t, *args, shift=1, **kwargs): 28 | if isinstance(t, torch.Tensor): 29 | batch_size = None 30 | if self.discrete_steps is not None: 31 | if t.dtype != torch.long: 32 | t = (t * (self.discrete_steps-1)).round().long() 33 | t = t / (self.discrete_steps-1) 34 | t = t.clamp(0, 1) 35 | else: 36 | batch_size = t 37 | t = None 38 | logSNR = self.schedule(t, batch_size, *args, **kwargs) 39 | if shift*self.shift != 1: 40 | logSNR += 2 * np.log(1/(shift*self.shift)) 41 | if self.limits is not None: 42 | logSNR = logSNR.clamp(*self.limits) 43 | return logSNR 44 | 45 | class CosineSchedule(BaseSchedule): 46 | def setup(self, s=0.008, clamp_range=[0.0001, 0.9999], norm_instead=False): 47 | self.s = torch.tensor([s]) 48 | self.clamp_range = clamp_range 49 | self.norm_instead = norm_instead 50 | self.min_var = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2 51 | 52 | def schedule(self, t, batch_size): 53 | if t is None: 54 | t = (1-torch.rand(batch_size)).add(0.001).clamp(0.001, 1.0) 55 | s, min_var = self.s.to(t.device), self.min_var.to(t.device) 56 | var = torch.cos((s + t)/(1+s) * torch.pi * 0.5).clamp(0, 1) ** 2 / min_var 57 | if self.norm_instead: 58 | var = var * (self.clamp_range[1]-self.clamp_range[0]) + self.clamp_range[0] 59 | else: 60 | var = var.clamp(*self.clamp_range) 61 | logSNR = (var/(1-var)).log() 62 | return logSNR 63 | 64 | class CosineSchedule2(BaseSchedule): 65 | def setup(self, logsnr_range=[-15, 15]): 66 | self.t_min = np.arctan(np.exp(-0.5 * logsnr_range[1])) 67 | self.t_max = np.arctan(np.exp(-0.5 * logsnr_range[0])) 68 | 69 | def schedule(self, t, batch_size): 70 | if t is None: 71 | t = 1-torch.rand(batch_size) 72 | return -2 * (self.t_min + t*(self.t_max-self.t_min)).tan().log() 73 | 74 | class SqrtSchedule(BaseSchedule): 75 | def 
setup(self, s=1e-4, clamp_range=[0.0001, 0.9999], norm_instead=False): 76 | self.s = s 77 | self.clamp_range = clamp_range 78 | self.norm_instead = norm_instead 79 | 80 | def schedule(self, t, batch_size): 81 | if t is None: 82 | t = 1-torch.rand(batch_size) 83 | var = 1 - (t + self.s)**0.5 84 | if self.norm_instead: 85 | var = var * (self.clamp_range[1]-self.clamp_range[0]) + self.clamp_range[0] 86 | else: 87 | var = var.clamp(*self.clamp_range) 88 | logSNR = (var/(1-var)).log() 89 | return logSNR 90 | 91 | class RectifiedFlowsSchedule(BaseSchedule): 92 | def setup(self, logsnr_range=[-15, 15]): 93 | self.logsnr_range = logsnr_range 94 | 95 | def schedule(self, t, batch_size): 96 | if t is None: 97 | t = 1-torch.rand(batch_size) 98 | logSNR = (((1-t)**2)/(t**2)).log() 99 | logSNR = logSNR.clamp(*self.logsnr_range) 100 | return logSNR 101 | 102 | class EDMSampleSchedule(BaseSchedule): 103 | def setup(self, sigma_range=[0.002, 80], p=7): 104 | self.sigma_range = sigma_range 105 | self.p = p 106 | 107 | def schedule(self, t, batch_size): 108 | if t is None: 109 | t = 1-torch.rand(batch_size) 110 | smin, smax, p = *self.sigma_range, self.p 111 | sigma = (smax ** (1/p) + (1-t) * (smin ** (1/p) - smax ** (1/p))) ** p 112 | logSNR = (1/sigma**2).log() 113 | return logSNR 114 | 115 | class EDMTrainSchedule(BaseSchedule): 116 | def setup(self, mu=-1.2, std=1.2): 117 | self.mu = mu 118 | self.std = std 119 | 120 | def schedule(self, t, batch_size): 121 | if t is not None: 122 | raise Exception("EDMTrainSchedule doesn't support passing timesteps: t") 123 | logSNR = -2*(torch.randn(batch_size) * self.std - self.mu) 124 | return logSNR 125 | 126 | class LinearSchedule(BaseSchedule): 127 | def setup(self, logsnr_range=[-10, 10]): 128 | self.logsnr_range = logsnr_range 129 | 130 | def schedule(self, t, batch_size): 131 | if t is None: 132 | t = 1-torch.rand(batch_size) 133 | logSNR = t * (self.logsnr_range[0]-self.logsnr_range[1]) + self.logsnr_range[1] 134 | return logSNR 135 | 136 | # Any schedule that cannot be described easily as a continuous function of t 137 | # It needs to define self.x and self.y in the setup() method 138 | class PiecewiseLinearSchedule(BaseSchedule): 139 | def setup(self): 140 | self.x = None 141 | self.y = None 142 | 143 | def piecewise_linear(self, x, xs, ys): 144 | indices = torch.searchsorted(xs[:-1], x) - 1 145 | x_min, x_max = xs[indices], xs[indices+1] 146 | y_min, y_max = ys[indices], ys[indices+1] 147 | var = y_min + (y_max - y_min) * (x - x_min) / (x_max - x_min) 148 | return var 149 | 150 | def schedule(self, t, batch_size): 151 | if t is None: 152 | t = 1-torch.rand(batch_size) 153 | var = self.piecewise_linear(t, self.x.to(t.device), self.y.to(t.device)) 154 | logSNR = (var/(1-var)).log() 155 | return logSNR 156 | 157 | class StableDiffusionSchedule(PiecewiseLinearSchedule): 158 | def setup(self, linear_range=[0.00085, 0.012], total_steps=1000): 159 | linear_range_sqrt = [r**0.5 for r in linear_range] 160 | self.x = torch.linspace(0, 1, total_steps+1) 161 | 162 | alphas = 1-(linear_range_sqrt[0]*(1-self.x) + linear_range_sqrt[1]*self.x)**2 163 | self.y = alphas.cumprod(dim=-1) 164 | 165 | class AdaptiveTrainSchedule(BaseSchedule): 166 | def setup(self, logsnr_range=[-10, 10], buckets=100, min_probs=0.0): 167 | th = torch.linspace(logsnr_range[0], logsnr_range[1], buckets+1) 168 | self.bucket_ranges = torch.tensor([(th[i], th[i+1]) for i in range(buckets)]) 169 | self.bucket_probs = torch.ones(buckets) 170 | self.min_probs = min_probs 171 | 172 | def schedule(self, 
t, batch_size): 173 | if t is not None: 174 | raise Exception("AdaptiveTrainSchedule doesn't support passing timesteps: t") 175 | norm_probs = ((self.bucket_probs+self.min_probs) / (self.bucket_probs+self.min_probs).sum()) 176 | buckets = torch.multinomial(norm_probs, batch_size, replacement=True) 177 | ranges = self.bucket_ranges[buckets] 178 | logSNR = torch.rand(batch_size) * (ranges[:, 1]-ranges[:, 0]) + ranges[:, 0] 179 | return logSNR 180 | 181 | def update_buckets(self, logSNR, loss, beta=0.99): 182 | range_mtx = self.bucket_ranges.unsqueeze(0).expand(logSNR.size(0), -1, -1).to(logSNR.device) 183 | range_mask = (range_mtx[:, :, 0] <= logSNR[:, None]) * (range_mtx[:, :, 1] > logSNR[:, None]).float() 184 | range_idx = range_mask.argmax(-1).cpu() 185 | self.bucket_probs[range_idx] = self.bucket_probs[range_idx] * beta + loss.detach().cpu() * (1-beta) 186 | 187 | class InterpolatedSchedule(BaseSchedule): 188 | def setup(self, scheduler1, scheduler2, shifts=[1.0, 1.0]): 189 | self.scheduler1 = scheduler1 190 | self.scheduler2 = scheduler2 191 | self.shifts = shifts 192 | 193 | def schedule(self, t, batch_size): 194 | if t is None: 195 | t = 1-torch.rand(batch_size) 196 | t = t.clamp(1e-7, 1-1e-7) # avoid infinities multiplied by 0 which cause nan 197 | low_logSNR = self.scheduler1(t, shift=self.shifts[0]) 198 | high_logSNR = self.scheduler2(t, shift=self.shifts[1]) 199 | return low_logSNR * t + high_logSNR * (1-t) 200 | 201 | -------------------------------------------------------------------------------- /gdf/targets.py: -------------------------------------------------------------------------------- 1 | class EpsilonTarget(): 2 | def __call__(self, x0, epsilon, logSNR, a, b): 3 | return epsilon 4 | 5 | def x0(self, noised, pred, logSNR, a, b): 6 | return (noised - pred * b) / a 7 | 8 | def epsilon(self, noised, pred, logSNR, a, b): 9 | return pred 10 | def noise_givenx0_noised(self, x0, noised , logSNR, a, b): 11 | return (noised - a * x0) / b 12 | def xt(self, x0, noise, logSNR, a, b): 13 | 14 | return x0 * a + noise*b 15 | class X0Target(): 16 | def __call__(self, x0, epsilon, logSNR, a, b): 17 | return x0 18 | 19 | def x0(self, noised, pred, logSNR, a, b): 20 | return pred 21 | 22 | def epsilon(self, noised, pred, logSNR, a, b): 23 | return (noised - pred * a) / b 24 | 25 | class VTarget(): 26 | def __call__(self, x0, epsilon, logSNR, a, b): 27 | return a * epsilon - b * x0 28 | 29 | def x0(self, noised, pred, logSNR, a, b): 30 | squared_sum = a**2 + b**2 31 | return a/squared_sum * noised - b/squared_sum * pred 32 | 33 | def epsilon(self, noised, pred, logSNR, a, b): 34 | squared_sum = a**2 + b**2 35 | return b/squared_sum * noised + a/squared_sum * pred 36 | 37 | class RectifiedFlowsTarget(): 38 | def __call__(self, x0, epsilon, logSNR, a, b): 39 | return epsilon - x0 40 | 41 | def x0(self, noised, pred, logSNR, a, b): 42 | return noised - pred * b 43 | 44 | def epsilon(self, noised, pred, logSNR, a, b): 45 | return noised + pred * a 46 | -------------------------------------------------------------------------------- /inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catcathh/UltraPixel/e85d3bc9b0eaedbe1b5aea8a5c2624caab9df1f3/inference/__init__.py -------------------------------------------------------------------------------- /inference/test_controlnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import torch 4 | 
import torchvision 5 | from tqdm import tqdm 6 | import sys 7 | sys.path.append(os.path.abspath('./')) 8 | 9 | from inference.utils import * 10 | from core.utils import load_or_fail 11 | from train import WurstCore_control_lrguide, WurstCoreB 12 | from PIL import Image 13 | from core.utils import load_or_fail 14 | import math 15 | import argparse 16 | import time 17 | import random 18 | import numpy as np 19 | def parse_args(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument( '--height', type=int, default=3840, help='image height') 22 | parser.add_argument('--width', type=int, default=2160, help='image width') 23 | parser.add_argument('--control_weight', type=float, default=0.70, help='[ 0.3, 0.8]') 24 | parser.add_argument('--dtype', type=str, default='bf16', help=' if bf16 does not work, change it to float32 ') 25 | parser.add_argument('--seed', type=int, default=123, help='random seed') 26 | parser.add_argument('--config_c', type=str, 27 | default='configs/training/cfg_control_lr.yaml' ,help='config file for stage c, latent generation') 28 | parser.add_argument('--config_b', type=str, 29 | default='configs/inference/stage_b_1b.yaml' ,help='config file for stage b, latent decoding') 30 | parser.add_argument( '--prompt', type=str, 31 | default='A peaceful lake surrounded by mountain, white cloud in the sky, high quality,', help='text prompt') 32 | parser.add_argument( '--num_image', type=int, default=4, help='how many images generated') 33 | parser.add_argument( '--output_dir', type=str, default='figures/controlnet_results/', help='output directory for generated image') 34 | parser.add_argument( '--stage_a_tiled', action='store_true', help='whther or nor to use tiled decoding for stage a to save memory') 35 | parser.add_argument( '--pretrained_path', type=str, default='models/ultrapixel_t2i.safetensors', help='pretrained path of newly added paramter of UltraPixel') 36 | parser.add_argument( '--canny_source_url', type=str, default="figures/California_000490.jpg", help='image used to extract canny edge map') 37 | 38 | args = parser.parse_args() 39 | return args 40 | 41 | 42 | if __name__ == "__main__": 43 | 44 | args = parse_args() 45 | width = args.width 46 | height = args.height 47 | torch.manual_seed(args.seed) 48 | random.seed(args.seed) 49 | np.random.seed(args.seed) 50 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 51 | dtype = torch.bfloat16 if args.dtype == 'bf16' else torch.float 52 | 53 | 54 | # SETUP STAGE C 55 | with open(args.config_c, "r", encoding="utf-8") as file: 56 | loaded_config = yaml.safe_load(file) 57 | core = WurstCore_control_lrguide(config_dict=loaded_config, device=device, training=False) 58 | 59 | # SETUP STAGE B 60 | with open(args.config_b, "r", encoding="utf-8") as file: 61 | config_file_b = yaml.safe_load(file) 62 | 63 | core_b = WurstCoreB(config_dict=config_file_b, device=device, training=False) 64 | 65 | extras = core.setup_extras_pre() 66 | models = core.setup_models(extras) 67 | models.generator.eval().requires_grad_(False) 68 | print("CONTROLNET READY") 69 | 70 | extras_b = core_b.setup_extras_pre() 71 | models_b = core_b.setup_models(extras_b, skip_clip=True) 72 | models_b = WurstCoreB.Models( 73 | **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model} 74 | ) 75 | models_b.generator.eval().requires_grad_(False) 76 | print("STAGE B READY") 77 | 78 | batch_size = 1 79 | save_dir = args.output_dir 80 | url = args.canny_source_url 81 | images = 
resize_image(Image.open(url).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1) 82 | batch = {'images': images} 83 | 84 | 85 | 86 | 87 | 88 | 89 | cnet_multiplier = args.control_weight # 0.8 0.6 0.3 control strength 90 | caption_list = [args.prompt] * args.num_image 91 | height_lr, width_lr = get_target_lr_size(height / width, std_size=32) 92 | stage_c_latent_shape_lr, stage_b_latent_shape_lr = calculate_latent_sizes(height_lr, width_lr, batch_size=batch_size) 93 | stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size) 94 | 95 | 96 | 97 | 98 | if not os.path.exists(save_dir): 99 | os.makedirs(save_dir) 100 | 101 | 102 | sdd = torch.load(args.pretrained_path, map_location='cpu') 103 | collect_sd = {} 104 | for k, v in sdd.items(): 105 | collect_sd[k[7:]] = v 106 | models.train_norm.load_state_dict(collect_sd, strict=True) 107 | 108 | 109 | 110 | 111 | models.controlnet.load_state_dict(load_or_fail(core.config.controlnet_checkpoint_path), strict=True) 112 | # Stage C Parameters 113 | extras.sampling_configs['cfg'] = 1 114 | extras.sampling_configs['shift'] = 2 115 | extras.sampling_configs['timesteps'] = 20 116 | extras.sampling_configs['t_start'] = 1.0 117 | 118 | # Stage B Parameters 119 | extras_b.sampling_configs['cfg'] = 1.1 120 | extras_b.sampling_configs['shift'] = 1 121 | extras_b.sampling_configs['timesteps'] = 10 122 | extras_b.sampling_configs['t_start'] = 1.0 123 | 124 | # PREPARE CONDITIONS 125 | 126 | 127 | 128 | 129 | for out_cnt, caption in enumerate(caption_list): 130 | with torch.no_grad(): 131 | 132 | batch['captions'] = [caption + ' high quality'] * batch_size 133 | conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) 134 | unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) 135 | 136 | cnet, cnet_input = core.get_cnet(batch, models, extras) 137 | cnet_uncond = cnet 138 | conditions['cnet'] = [c.clone() * cnet_multiplier if c is not None else c for c in cnet] 139 | unconditions['cnet'] = [c.clone() * cnet_multiplier if c is not None else c for c in cnet_uncond] 140 | edge_images = show_images(cnet_input) 141 | models.generator.cuda() 142 | for idx, img in enumerate(edge_images): 143 | img.save(os.path.join(save_dir, f"edge_{url.split('/')[-1]}")) 144 | 145 | 146 | print('STAGE C GENERATION***************************') 147 | with torch.cuda.amp.autocast(dtype=dtype): 148 | sampled_c = generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device, conditions, unconditions) 149 | models.generator.cpu() 150 | torch.cuda.empty_cache() 151 | 152 | conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False) 153 | unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True) 154 | 155 | conditions_b['effnet'] = sampled_c 156 | unconditions_b['effnet'] = torch.zeros_like(sampled_c) 157 | print('STAGE B + A DECODING***************************') 158 | with torch.cuda.amp.autocast(dtype=dtype): 159 | sampled = decode_b(conditions_b, unconditions_b, models_b, stage_b_latent_shape, extras_b, device, stage_a_tiled=args.stage_a_tiled) 160 | 161 | torch.cuda.empty_cache() 162 | imgs = show_images(sampled) 163 | 164 | for idx, img in enumerate(imgs): 165 | img.save(os.path.join(save_dir, args.prompt[:20]+'_' + str(out_cnt).zfill(5) + '.jpg')) 166 | print('finished! 
Results at ', save_dir ) 167 | -------------------------------------------------------------------------------- /inference/test_personalized.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import yaml 4 | import torch 5 | from tqdm import tqdm 6 | import sys 7 | sys.path.append(os.path.abspath('./')) 8 | from inference.utils import * 9 | from train import WurstCoreB 10 | from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight 11 | from train import WurstCore_personalized as WurstCoreC 12 | import torch.nn.functional as F 13 | import numpy as np 14 | import random 15 | import math 16 | import argparse 17 | 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument( '--height', type=int, default=3072, help='image height') 22 | parser.add_argument('--width', type=int, default=4096, help='image width') 23 | parser.add_argument('--dtype', type=str, default='bf16', help=' if bf16 does not work, change it to float32 ') 24 | parser.add_argument('--seed', type=int, default=23, help='random seed') 25 | parser.add_argument('--config_c', type=str, 26 | default="configs/training/lora_personalization.yaml" ,help='config file for stage c, latent generation') 27 | parser.add_argument('--config_b', type=str, 28 | default='configs/inference/stage_b_1b.yaml' ,help='config file for stage b, latent decoding') 29 | parser.add_argument( '--prompt', type=str, 30 | default='A photo of cat [roubaobao] with sunglasses, Time Square in the background, high quality, detail rich, 8k', help='text prompt') 31 | parser.add_argument( '--num_image', type=int, default=4, help='how many images generated') 32 | parser.add_argument( '--output_dir', type=str, default='figures/personalized/', help='output directory for generated image') 33 | parser.add_argument( '--stage_a_tiled', action='store_true', help='whther or nor to use tiled decoding for stage a to save memory') 34 | parser.add_argument( '--pretrained_path_lora', type=str, default='models/lora_cat.safetensors',help='pretrained path of personalized lora parameter') 35 | parser.add_argument( '--pretrained_path', type=str, default='models/ultrapixel_t2i.safetensors', help='pretrained path of newly added paramter of UltraPixel') 36 | args = parser.parse_args() 37 | return args 38 | 39 | if __name__ == "__main__": 40 | args = parse_args() 41 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 42 | torch.manual_seed(args.seed) 43 | random.seed(args.seed) 44 | np.random.seed(args.seed) 45 | dtype = torch.bfloat16 if args.dtype == 'bf16' else torch.float 46 | 47 | 48 | # SETUP STAGE C 49 | with open(args.config_c, "r", encoding="utf-8") as file: 50 | loaded_config = yaml.safe_load(file) 51 | core = WurstCoreC(config_dict=loaded_config, device=device, training=False) 52 | 53 | # SETUP STAGE B 54 | with open(args.config_b, "r", encoding="utf-8") as file: 55 | config_file_b = yaml.safe_load(file) 56 | core_b = WurstCoreB(config_dict=config_file_b, device=device, training=False) 57 | 58 | extras = core.setup_extras_pre() 59 | models = core.setup_models(extras) 60 | models.generator.eval().requires_grad_(False) 61 | print("STAGE C READY") 62 | 63 | extras_b = core_b.setup_extras_pre() 64 | models_b = core_b.setup_models(extras_b, skip_clip=True) 65 | models_b = WurstCoreB.Models( 66 | **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model} 67 | ) 68 | 
models_b.generator.bfloat16().eval().requires_grad_(False) 69 | print("STAGE B READY") 70 | 71 | 72 | batch_size = 1 73 | captions = [args.prompt] * args.num_image 74 | height, width = args.height, args.width 75 | save_dir = args.output_dir 76 | 77 | if not os.path.exists(save_dir): 78 | os.makedirs(save_dir) 79 | 80 | 81 | pretrained_pth = args.pretrained_path 82 | sdd = torch.load(pretrained_pth, map_location='cpu') 83 | collect_sd = {} 84 | for k, v in sdd.items(): 85 | collect_sd[k[7:]] = v 86 | 87 | models.train_norm.load_state_dict(collect_sd) 88 | 89 | 90 | pretrained_pth_lora = args.pretrained_path_lora 91 | sdd = torch.load(pretrained_pth_lora, map_location='cpu') 92 | collect_sd = {} 93 | for k, v in sdd.items(): 94 | collect_sd[k[7:]] = v 95 | 96 | models.train_lora.load_state_dict(collect_sd) 97 | 98 | 99 | models.generator.eval() 100 | models.train_norm.eval() 101 | 102 | 103 | height_lr, width_lr = get_target_lr_size(height / width, std_size=32) 104 | stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size) 105 | stage_c_latent_shape_lr, stage_b_latent_shape_lr = calculate_latent_sizes(height_lr, width_lr, batch_size=batch_size) 106 | 107 | # Stage C Parameters 108 | 109 | extras.sampling_configs['cfg'] = 4 110 | extras.sampling_configs['shift'] = 1 111 | extras.sampling_configs['timesteps'] = 20 112 | extras.sampling_configs['t_start'] = 1.0 113 | extras.sampling_configs['sampler'] = DDPMSampler(extras.gdf) 114 | 115 | 116 | 117 | # Stage B Parameters 118 | 119 | extras_b.sampling_configs['cfg'] = 1.1 120 | extras_b.sampling_configs['shift'] = 1 121 | extras_b.sampling_configs['timesteps'] = 10 122 | extras_b.sampling_configs['t_start'] = 1.0 123 | 124 | 125 | for cnt, caption in enumerate(captions): 126 | 127 | batch = {'captions': [caption] * batch_size} 128 | conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) 129 | unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) 130 | 131 | conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False) 132 | unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True) 133 | 134 | 135 | 136 | 137 | for cnt, caption in enumerate(captions): 138 | 139 | 140 | batch = {'captions': [caption] * batch_size} 141 | conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) 142 | unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) 143 | 144 | conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False) 145 | unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True) 146 | 147 | 148 | with torch.no_grad(): 149 | 150 | 151 | models.generator.cuda() 152 | print('STAGE C GENERATION***************************') 153 | with torch.cuda.amp.autocast(dtype=dtype): 154 | sampled_c = generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device) 155 | 156 | 157 | 158 | models.generator.cpu() 159 | torch.cuda.empty_cache() 160 | 161 | conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False) 162 | unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True) 163 | 
conditions_b['effnet'] = sampled_c 164 | unconditions_b['effnet'] = torch.zeros_like(sampled_c) 165 | print('STAGE B + A DECODING***************************') 166 | 167 | with torch.cuda.amp.autocast(dtype=dtype): 168 | sampled = decode_b(conditions_b, unconditions_b, models_b, stage_b_latent_shape, extras_b, device, stage_a_tiled=args.stage_a_tiled) 169 | 170 | torch.cuda.empty_cache() 171 | imgs = show_images(sampled) 172 | for idx, img in enumerate(imgs): 173 | print(os.path.join(save_dir, args.prompt[:20]+'_' + str(cnt).zfill(5) + '.jpg'), idx) 174 | img.save(os.path.join(save_dir, args.prompt[:20]+'_' + str(cnt).zfill(5) + '.jpg')) 175 | 176 | 177 | print('finished! Results at ', save_dir ) 178 | 179 | 180 | 181 | -------------------------------------------------------------------------------- /inference/test_t2i.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import yaml 4 | import torch 5 | from tqdm import tqdm 6 | import sys 7 | sys.path.append(os.path.abspath('./')) 8 | from inference.utils import * 9 | from core.utils import load_or_fail 10 | from train import WurstCoreB 11 | from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight 12 | from train import WurstCore_t2i as WurstCoreC 13 | import torch.nn.functional as F 14 | from core.utils import load_or_fail 15 | import numpy as np 16 | import random 17 | import math 18 | import argparse 19 | from einops import rearrange 20 | import math 21 | #inrfft_3b_strc_WurstCore 22 | def parse_args(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument( '--height', type=int, default=2560, help='image height') 25 | parser.add_argument('--width', type=int, default=5120, help='image width') 26 | parser.add_argument('--seed', type=int, default=123, help='random seed') 27 | parser.add_argument('--dtype', type=str, default='bf16', help=' if bf16 does not work, change it to float32 ') 28 | parser.add_argument('--config_c', type=str, 29 | default='configs/training/t2i.yaml' ,help='config file for stage c, latent generation') 30 | parser.add_argument('--config_b', type=str, 31 | default='configs/inference/stage_b_1b.yaml' ,help='config file for stage b, latent decoding') 32 | parser.add_argument( '--prompt', type=str, 33 | default='A photo-realistic image of a west highland white terrier in the garden, high quality, detail rich, 8K', help='text prompt') 34 | parser.add_argument( '--num_image', type=int, default=10, help='how many images generated') 35 | parser.add_argument( '--output_dir', type=str, default='figures/output_results/', help='output directory for generated image') 36 | parser.add_argument( '--stage_a_tiled', action='store_true', help='whther or nor to use tiled decoding for stage a to save memory') 37 | parser.add_argument( '--pretrained_path', type=str, default='models/ultrapixel_t2i.safetensors', help='pretrained path of newly added paramter of UltraPixel') 38 | args = parser.parse_args() 39 | return args 40 | 41 | 42 | 43 | if __name__ == "__main__": 44 | 45 | args = parse_args() 46 | print(args) 47 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 48 | print(device) 49 | torch.manual_seed(args.seed) 50 | random.seed(args.seed) 51 | np.random.seed(args.seed) 52 | dtype = torch.bfloat16 if args.dtype == 'bf16' else torch.float 53 | #gdf = gdf_refine( 54 | # schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]), 55 | # input_scaler=VPScaler(), target=EpsilonTarget(), 56 | # noise_cond=CosineTNoiseCond(), 57 
| # loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(), 58 | # ) 59 | # SETUP STAGE C 60 | config_file = args.config_c 61 | with open(config_file, "r", encoding="utf-8") as file: 62 | loaded_config = yaml.safe_load(file) 63 | 64 | core = WurstCoreC(config_dict=loaded_config, device=device, training=False) 65 | 66 | # SETUP STAGE B 67 | config_file_b = args.config_b 68 | with open(config_file_b, "r", encoding="utf-8") as file: 69 | config_file_b = yaml.safe_load(file) 70 | 71 | core_b = WurstCoreB(config_dict=config_file_b, device=device, training=False) 72 | 73 | extras = core.setup_extras_pre() 74 | models = core.setup_models(extras) 75 | models.generator.eval().requires_grad_(False) 76 | print("STAGE C READY") 77 | 78 | extras_b = core_b.setup_extras_pre() 79 | models_b = core_b.setup_models(extras_b, skip_clip=True) 80 | models_b = WurstCoreB.Models( 81 | **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model} 82 | ) 83 | models_b.generator.bfloat16().eval().requires_grad_(False) 84 | print("STAGE B READY") 85 | 86 | captions = [args.prompt] * args.num_image 87 | 88 | 89 | height, width = args.height, args.width 90 | save_dir = args.output_dir 91 | 92 | if not os.path.exists(save_dir): 93 | os.makedirs(save_dir) 94 | 95 | pretrained_path = args.pretrained_path 96 | sdd = torch.load(pretrained_path, map_location='cpu') 97 | collect_sd = {} 98 | for k, v in sdd.items(): 99 | collect_sd[k[7:]] = v 100 | 101 | models.train_norm.load_state_dict(collect_sd) 102 | 103 | 104 | models.generator.eval() 105 | models.train_norm.eval() 106 | 107 | batch_size=1 108 | height_lr, width_lr = get_target_lr_size(height / width, std_size=32) 109 | stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size) 110 | stage_c_latent_shape_lr, stage_b_latent_shape_lr = calculate_latent_sizes(height_lr, width_lr, batch_size=batch_size) 111 | 112 | # Stage C Parameters 113 | extras.sampling_configs['cfg'] = 4 114 | extras.sampling_configs['shift'] = 1 115 | extras.sampling_configs['timesteps'] = 20 116 | extras.sampling_configs['t_start'] = 1.0 117 | extras.sampling_configs['sampler'] = DDPMSampler(extras.gdf) 118 | 119 | 120 | 121 | # Stage B Parameters 122 | extras_b.sampling_configs['cfg'] = 1.1 123 | extras_b.sampling_configs['shift'] = 1 124 | extras_b.sampling_configs['timesteps'] = 10 125 | extras_b.sampling_configs['t_start'] = 1.0 126 | 127 | 128 | 129 | 130 | for cnt, caption in enumerate(captions): 131 | 132 | 133 | batch = {'captions': [caption] * batch_size} 134 | conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) 135 | unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) 136 | 137 | conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False) 138 | unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True) 139 | 140 | 141 | with torch.no_grad(): 142 | 143 | 144 | models.generator.cuda() 145 | print('STAGE C GENERATION***************************') 146 | with torch.cuda.amp.autocast(dtype=dtype): 147 | sampled_c = generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device) 148 | 149 | 150 | 151 | models.generator.cpu() 152 | torch.cuda.empty_cache() 153 | 154 | conditions_b = core_b.get_conditions(batch, models_b, extras_b, 
is_eval=True, is_unconditional=False) 155 | unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True) 156 | conditions_b['effnet'] = sampled_c 157 | unconditions_b['effnet'] = torch.zeros_like(sampled_c) 158 | print('STAGE B + A DECODING***************************') 159 | 160 | with torch.cuda.amp.autocast(dtype=dtype): 161 | sampled = decode_b(conditions_b, unconditions_b, models_b, stage_b_latent_shape, extras_b, device, stage_a_tiled=args.stage_a_tiled) 162 | 163 | torch.cuda.empty_cache() 164 | imgs = show_images(sampled) 165 | for idx, img in enumerate(imgs): 166 | print(os.path.join(save_dir, args.prompt[:20]+'_' + str(cnt).zfill(5) + '.jpg'), idx) 167 | img.save(os.path.join(save_dir, args.prompt[:20]+'_' + str(cnt).zfill(5) + '.jpg')) 168 | 169 | 170 | print('finished! Results at ', save_dir ) 171 | -------------------------------------------------------------------------------- /inference/utils.py: -------------------------------------------------------------------------------- 1 | import PIL 2 | import torch 3 | import requests 4 | import torchvision 5 | from math import ceil 6 | from io import BytesIO 7 | import matplotlib.pyplot as plt 8 | import torchvision.transforms.functional as F 9 | import math 10 | from tqdm import tqdm 11 | def download_image(url): 12 | return PIL.Image.open(requests.get(url, stream=True).raw).convert("RGB") 13 | 14 | 15 | def resize_image(image, size=768): 16 | tensor_image = F.to_tensor(image) 17 | resized_image = F.resize(tensor_image, size, antialias=True) 18 | return resized_image 19 | 20 | 21 | def downscale_images(images, factor=3/4): 22 | scaled_height, scaled_width = int(((images.size(-2)*factor)//32)*32), int(((images.size(-1)*factor)//32)*32) 23 | scaled_image = torchvision.transforms.functional.resize(images, (scaled_height, scaled_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST) 24 | return scaled_image 25 | 26 | 27 | 28 | def calculate_latent_sizes(height=1024, width=1024, batch_size=4, compression_factor_b=42.67, compression_factor_a=4.0): 29 | resolution_multiple = 42.67 30 | latent_height = ceil(height / compression_factor_b) 31 | latent_width = ceil(width / compression_factor_b) 32 | stage_c_latent_shape = (batch_size, 16, latent_height, latent_width) 33 | 34 | latent_height = ceil(height / compression_factor_a) 35 | latent_width = ceil(width / compression_factor_a) 36 | stage_b_latent_shape = (batch_size, 4, latent_height, latent_width) 37 | 38 | return stage_c_latent_shape, stage_b_latent_shape 39 | 40 | 41 | def get_views(H, W, window_size=64, stride=16): 42 | ''' 43 | - H, W: height and width of the latent 44 | ''' 45 | num_blocks_height = (H - window_size) // stride + 1 46 | num_blocks_width = (W - window_size) // stride + 1 47 | total_num_blocks = int(num_blocks_height * num_blocks_width) 48 | views = [] 49 | for i in range(total_num_blocks): 50 | h_start = int((i // num_blocks_width) * stride) 51 | h_end = h_start + window_size 52 | w_start = int((i % num_blocks_width) * stride) 53 | w_end = w_start + window_size 54 | views.append((h_start, h_end, w_start, w_end)) 55 | return views 56 | 57 | 58 | 59 | def show_images(images, rows=None, cols=None, **kwargs): 60 | if images.size(1) == 1: 61 | images = images.repeat(1, 3, 1, 1) 62 | elif images.size(1) > 3: 63 | images = images[:, :3] 64 | 65 | if rows is None: 66 | rows = 1 67 | if cols is None: 68 | cols = images.size(0) // rows 69 | 70 | _, _, h, w = images.shape 71 | 72 | imgs = [] 73 | for i, img in 
enumerate(images): 74 | imgs.append( torchvision.transforms.functional.to_pil_image(img.clamp(0, 1))) 75 | 76 | return imgs 77 | 78 | 79 | 80 | def decode_b(conditions_b, unconditions_b, models_b, bshape, extras_b, device, \ 81 | stage_a_tiled=False, num_instance=4, patch_size=256, stride=24): 82 | 83 | 84 | sampling_b = extras_b.gdf.sample( 85 | models_b.generator.half(), conditions_b, bshape, 86 | unconditions_b, device=device, 87 | **extras_b.sampling_configs, 88 | ) 89 | models_b.generator.cuda() 90 | for (sampled_b, _, _) in tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']): 91 | sampled_b = sampled_b 92 | models_b.generator.cpu() 93 | torch.cuda.empty_cache() 94 | if stage_a_tiled: 95 | with torch.cuda.amp.autocast(dtype=torch.float16): 96 | padding = (stride*2, stride*2, stride*2, stride*2) 97 | sampled_b = torch.nn.functional.pad(sampled_b, padding, mode='reflect') 98 | count = torch.zeros((sampled_b.shape[0], 3, sampled_b.shape[-2]*4, sampled_b.shape[-1]*4), requires_grad=False, device=sampled_b.device) 99 | sampled = torch.zeros((sampled_b.shape[0], 3, sampled_b.shape[-2]*4, sampled_b.shape[-1]*4), requires_grad=False, device=sampled_b.device) 100 | views = get_views(sampled_b.shape[-2], sampled_b.shape[-1], window_size=patch_size, stride=stride) 101 | 102 | for view_idx, (h_start, h_end, w_start, w_end) in enumerate(tqdm(views, total=len(views))): 103 | 104 | sampled[:, :, h_start*4:h_end*4, w_start*4:w_end*4] += models_b.stage_a.decode(sampled_b[:, :, h_start:h_end, w_start:w_end]).float() 105 | count[:, :, h_start*4:h_end*4, w_start*4:w_end*4] += 1 106 | sampled /= count 107 | sampled = sampled[:, :, stride*4*2:-stride*4*2, stride*4*2:-stride*4*2] 108 | else: 109 | 110 | sampled = models_b.stage_a.decode(sampled_b, tiled_decoding=stage_a_tiled) 111 | 112 | return sampled.float() 113 | 114 | 115 | def generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device, conditions=None, unconditions=None): 116 | if conditions is None: 117 | conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) 118 | if unconditions is None: 119 | unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) 120 | sampling_c = extras.gdf.sample( 121 | models.generator, conditions, stage_c_latent_shape, stage_c_latent_shape_lr, 122 | unconditions, device=device, **extras.sampling_configs, 123 | ) 124 | for idx, (sampled_c, sampled_c_curr, _, _) in enumerate(tqdm(sampling_c, total=extras.sampling_configs['timesteps'])): 125 | sampled_c = sampled_c 126 | return sampled_c 127 | 128 | def get_target_lr_size(ratio, std_size=24): 129 | w, h = int(std_size / math.sqrt(ratio)), int(std_size * math.sqrt(ratio)) 130 | return (h * 32 , w *32 ) 131 | 132 | -------------------------------------------------------------------------------- /models/models_checklist.txt: -------------------------------------------------------------------------------- 1 | https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_a.safetensors 2 | https://huggingface.co/stabilityai/StableWurst/resolve/main/previewer.safetensors 3 | https://huggingface.co/stabilityai/StableWurst/resolve/main/effnet_encoder.safetensors 4 | https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_lite_bf16.safetensors 5 | https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_bf16.safetensors 6 | 
https://huggingface.co/roubaofeipi/UltraPixel/blob/main/ultrapixel_t2i.safetensors 7 | https://huggingface.co/roubaofeipi/UltraPixel/blob/main/lora_cat.safetensors (only required for personalization) -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .effnet import EfficientNetEncoder 2 | from .stage_c import StageC 3 | from .stage_c import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock 4 | from .previewer import Previewer 5 | from .controlnet import ControlNet, ControlNetDeliverer 6 | from . import controlnet as controlnet_filters 7 | -------------------------------------------------------------------------------- /modules/cnet_modules/face_id/__pycache__/arcface.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catcathh/UltraPixel/e85d3bc9b0eaedbe1b5aea8a5c2624caab9df1f3/modules/cnet_modules/face_id/__pycache__/arcface.cpython-310.pyc -------------------------------------------------------------------------------- /modules/cnet_modules/face_id/arcface.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import onnx, onnx2torch, cv2 3 | import torch 4 | from insightface.utils import face_align 5 | 6 | 7 | class ArcFaceRecognizer: 8 | def __init__(self, model_file=None, device='cpu', dtype=torch.float32): 9 | assert model_file is not None 10 | self.model_file = model_file 11 | 12 | self.device = device 13 | self.dtype = dtype 14 | self.model = onnx2torch.convert(onnx.load(model_file)).to(device=device, dtype=dtype) 15 | for param in self.model.parameters(): 16 | param.requires_grad = False 17 | self.model.eval() 18 | 19 | self.input_mean = 127.5 20 | self.input_std = 127.5 21 | self.input_size = (112, 112) 22 | self.input_shape = ['None', 3, 112, 112] 23 | 24 | def get(self, img, face): 25 | aimg = face_align.norm_crop(img, landmark=face.kps, image_size=self.input_size[0]) 26 | face.embedding = self.get_feat(aimg).flatten() 27 | return face.embedding 28 | 29 | def compute_sim(self, feat1, feat2): 30 | from numpy.linalg import norm 31 | feat1 = feat1.ravel() 32 | feat2 = feat2.ravel() 33 | sim = np.dot(feat1, feat2) / (norm(feat1) * norm(feat2)) 34 | return sim 35 | 36 | def get_feat(self, imgs): 37 | if not isinstance(imgs, list): 38 | imgs = [imgs] 39 | input_size = self.input_size 40 | 41 | blob = cv2.dnn.blobFromImages(imgs, 1.0 / self.input_std, input_size, 42 | (self.input_mean, self.input_mean, self.input_mean), swapRB=True) 43 | 44 | blob_torch = torch.tensor(blob).to(device=self.device, dtype=self.dtype) 45 | net_out = self.model(blob_torch) 46 | return net_out[0].float().cpu() 47 | 48 | 49 | def distance2bbox(points, distance, max_shape=None): 50 | """Decode distance prediction to bounding box. 51 | 52 | Args: 53 | points (Tensor): Shape (n, 2), [x, y]. 54 | distance (Tensor): Distance from the given point to 4 55 | boundaries (left, top, right, bottom). 56 | max_shape (tuple): Shape of the image. 57 | 58 | Returns: 59 | Tensor: Decoded bboxes. 
60 | """ 61 | x1 = points[:, 0] - distance[:, 0] 62 | y1 = points[:, 1] - distance[:, 1] 63 | x2 = points[:, 0] + distance[:, 2] 64 | y2 = points[:, 1] + distance[:, 3] 65 | if max_shape is not None: 66 | x1 = x1.clamp(min=0, max=max_shape[1]) 67 | y1 = y1.clamp(min=0, max=max_shape[0]) 68 | x2 = x2.clamp(min=0, max=max_shape[1]) 69 | y2 = y2.clamp(min=0, max=max_shape[0]) 70 | return np.stack([x1, y1, x2, y2], axis=-1) 71 | 72 | 73 | def distance2kps(points, distance, max_shape=None): 74 | """Decode distance prediction to bounding box. 75 | 76 | Args: 77 | points (Tensor): Shape (n, 2), [x, y]. 78 | distance (Tensor): Distance from the given point to 4 79 | boundaries (left, top, right, bottom). 80 | max_shape (tuple): Shape of the image. 81 | 82 | Returns: 83 | Tensor: Decoded bboxes. 84 | """ 85 | preds = [] 86 | for i in range(0, distance.shape[1], 2): 87 | px = points[:, i % 2] + distance[:, i] 88 | py = points[:, i % 2 + 1] + distance[:, i + 1] 89 | if max_shape is not None: 90 | px = px.clamp(min=0, max=max_shape[1]) 91 | py = py.clamp(min=0, max=max_shape[0]) 92 | preds.append(px) 93 | preds.append(py) 94 | return np.stack(preds, axis=-1) 95 | 96 | 97 | class FaceDetector: 98 | def __init__(self, model_file=None, dtype=torch.float32, device='cuda'): 99 | self.model_file = model_file 100 | self.taskname = 'detection' 101 | self.center_cache = {} 102 | self.nms_thresh = 0.4 103 | self.det_thresh = 0.5 104 | 105 | self.device = device 106 | self.dtype = dtype 107 | self.model = onnx2torch.convert(onnx.load(model_file)).to(device=device, dtype=dtype) 108 | for param in self.model.parameters(): 109 | param.requires_grad = False 110 | self.model.eval() 111 | 112 | input_shape = (320, 320) 113 | self.input_size = input_shape 114 | self.input_shape = input_shape 115 | 116 | self.input_mean = 127.5 117 | self.input_std = 128.0 118 | self._anchor_ratio = 1.0 119 | self._num_anchors = 1 120 | self.fmc = 3 121 | self._feat_stride_fpn = [8, 16, 32] 122 | self._num_anchors = 2 123 | self.use_kps = True 124 | 125 | self.det_thresh = 0.5 126 | self.nms_thresh = 0.4 127 | 128 | def forward(self, img, threshold): 129 | scores_list = [] 130 | bboxes_list = [] 131 | kpss_list = [] 132 | input_size = tuple(img.shape[0:2][::-1]) 133 | blob = cv2.dnn.blobFromImage(img, 1.0 / self.input_std, input_size, 134 | (self.input_mean, self.input_mean, self.input_mean), swapRB=True) 135 | blob_torch = torch.tensor(blob).to(device=self.device, dtype=self.dtype) 136 | net_outs_torch = self.model(blob_torch) 137 | # print(list(map(lambda x: x.shape, net_outs_torch))) 138 | net_outs = list(map(lambda x: x.float().cpu().numpy(), net_outs_torch)) 139 | 140 | input_height = blob.shape[2] 141 | input_width = blob.shape[3] 142 | fmc = self.fmc 143 | for idx, stride in enumerate(self._feat_stride_fpn): 144 | scores = net_outs[idx] 145 | bbox_preds = net_outs[idx + fmc] 146 | bbox_preds = bbox_preds * stride 147 | if self.use_kps: 148 | kps_preds = net_outs[idx + fmc * 2] * stride 149 | height = input_height // stride 150 | width = input_width // stride 151 | K = height * width 152 | key = (height, width, stride) 153 | if key in self.center_cache: 154 | anchor_centers = self.center_cache[key] 155 | else: 156 | # solution-1, c style: 157 | # anchor_centers = np.zeros( (height, width, 2), dtype=np.float32 ) 158 | # for i in range(height): 159 | # anchor_centers[i, :, 1] = i 160 | # for i in range(width): 161 | # anchor_centers[:, i, 0] = i 162 | 163 | # solution-2: 164 | # ax = np.arange(width, dtype=np.float32) 165 | # ay = 
np.arange(height, dtype=np.float32) 166 | # xv, yv = np.meshgrid(np.arange(width), np.arange(height)) 167 | # anchor_centers = np.stack([xv, yv], axis=-1).astype(np.float32) 168 | 169 | # solution-3: 170 | anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32) 171 | # print(anchor_centers.shape) 172 | 173 | anchor_centers = (anchor_centers * stride).reshape((-1, 2)) 174 | if self._num_anchors > 1: 175 | anchor_centers = np.stack([anchor_centers] * self._num_anchors, axis=1).reshape((-1, 2)) 176 | if len(self.center_cache) < 100: 177 | self.center_cache[key] = anchor_centers 178 | 179 | pos_inds = np.where(scores >= threshold)[0] 180 | bboxes = distance2bbox(anchor_centers, bbox_preds) 181 | pos_scores = scores[pos_inds] 182 | pos_bboxes = bboxes[pos_inds] 183 | scores_list.append(pos_scores) 184 | bboxes_list.append(pos_bboxes) 185 | if self.use_kps: 186 | kpss = distance2kps(anchor_centers, kps_preds) 187 | # kpss = kps_preds 188 | kpss = kpss.reshape((kpss.shape[0], -1, 2)) 189 | pos_kpss = kpss[pos_inds] 190 | kpss_list.append(pos_kpss) 191 | return scores_list, bboxes_list, kpss_list 192 | 193 | def detect(self, img, input_size=None, max_num=0, metric='default'): 194 | assert input_size is not None or self.input_size is not None 195 | input_size = self.input_size if input_size is None else input_size 196 | 197 | im_ratio = float(img.shape[0]) / img.shape[1] 198 | model_ratio = float(input_size[1]) / input_size[0] 199 | if im_ratio > model_ratio: 200 | new_height = input_size[1] 201 | new_width = int(new_height / im_ratio) 202 | else: 203 | new_width = input_size[0] 204 | new_height = int(new_width * im_ratio) 205 | det_scale = float(new_height) / img.shape[0] 206 | resized_img = cv2.resize(img, (new_width, new_height)) 207 | det_img = np.zeros((input_size[1], input_size[0], 3), dtype=np.uint8) 208 | det_img[:new_height, :new_width, :] = resized_img 209 | 210 | scores_list, bboxes_list, kpss_list = self.forward(det_img, self.det_thresh) 211 | 212 | scores = np.vstack(scores_list) 213 | scores_ravel = scores.ravel() 214 | order = scores_ravel.argsort()[::-1] 215 | bboxes = np.vstack(bboxes_list) / det_scale 216 | if self.use_kps: 217 | kpss = np.vstack(kpss_list) / det_scale 218 | pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False) 219 | pre_det = pre_det[order, :] 220 | keep = self.nms(pre_det) 221 | det = pre_det[keep, :] 222 | if self.use_kps: 223 | kpss = kpss[order, :, :] 224 | kpss = kpss[keep, :, :] 225 | else: 226 | kpss = None 227 | if max_num > 0 and det.shape[0] > max_num: 228 | area = (det[:, 2] - det[:, 0]) * (det[:, 3] - 229 | det[:, 1]) 230 | img_center = img.shape[0] // 2, img.shape[1] // 2 231 | offsets = np.vstack([ 232 | (det[:, 0] + det[:, 2]) / 2 - img_center[1], 233 | (det[:, 1] + det[:, 3]) / 2 - img_center[0] 234 | ]) 235 | offset_dist_squared = np.sum(np.power(offsets, 2.0), 0) 236 | if metric == 'max': 237 | values = area 238 | else: 239 | values = area - offset_dist_squared * 2.0 # some extra weight on the centering 240 | bindex = np.argsort( 241 | values)[::-1] # some extra weight on the centering 242 | bindex = bindex[0:max_num] 243 | det = det[bindex, :] 244 | if kpss is not None: 245 | kpss = kpss[bindex, :] 246 | return det, kpss 247 | 248 | def nms(self, dets): 249 | thresh = self.nms_thresh 250 | x1 = dets[:, 0] 251 | y1 = dets[:, 1] 252 | x2 = dets[:, 2] 253 | y2 = dets[:, 3] 254 | scores = dets[:, 4] 255 | 256 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 257 | order = scores.argsort()[::-1] 258 | 259 | keep = 
[] 260 | while order.size > 0: 261 | i = order[0] 262 | keep.append(i) 263 | xx1 = np.maximum(x1[i], x1[order[1:]]) 264 | yy1 = np.maximum(y1[i], y1[order[1:]]) 265 | xx2 = np.minimum(x2[i], x2[order[1:]]) 266 | yy2 = np.minimum(y2[i], y2[order[1:]]) 267 | 268 | w = np.maximum(0.0, xx2 - xx1 + 1) 269 | h = np.maximum(0.0, yy2 - yy1 + 1) 270 | inter = w * h 271 | ovr = inter / (areas[i] + areas[order[1:]] - inter) 272 | 273 | inds = np.where(ovr <= thresh)[0] 274 | order = order[inds + 1] 275 | 276 | return keep 277 | -------------------------------------------------------------------------------- /modules/cnet_modules/inpainting/__pycache__/saliency_model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catcathh/UltraPixel/e85d3bc9b0eaedbe1b5aea8a5c2624caab9df1f3/modules/cnet_modules/inpainting/__pycache__/saliency_model.cpython-310.pyc -------------------------------------------------------------------------------- /modules/cnet_modules/inpainting/__pycache__/saliency_model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catcathh/UltraPixel/e85d3bc9b0eaedbe1b5aea8a5c2624caab9df1f3/modules/cnet_modules/inpainting/__pycache__/saliency_model.cpython-39.pyc -------------------------------------------------------------------------------- /modules/cnet_modules/inpainting/saliency_model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catcathh/UltraPixel/e85d3bc9b0eaedbe1b5aea8a5c2624caab9df1f3/modules/cnet_modules/inpainting/saliency_model.pt -------------------------------------------------------------------------------- /modules/cnet_modules/inpainting/saliency_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torch import nn 4 | from PIL import Image 5 | import numpy as np 6 | import os 7 | 8 | 9 | # MICRO RESNET 10 | class ResBlock(nn.Module): 11 | def __init__(self, channels): 12 | super(ResBlock, self).__init__() 13 | 14 | self.resblock = nn.Sequential( 15 | nn.ReflectionPad2d(1), 16 | nn.Conv2d(channels, channels, kernel_size=3), 17 | nn.InstanceNorm2d(channels, affine=True), 18 | nn.ReLU(), 19 | nn.ReflectionPad2d(1), 20 | nn.Conv2d(channels, channels, kernel_size=3), 21 | nn.InstanceNorm2d(channels, affine=True), 22 | ) 23 | 24 | def forward(self, x): 25 | out = self.resblock(x) 26 | return out + x 27 | 28 | 29 | class Upsample2d(nn.Module): 30 | def __init__(self, scale_factor): 31 | super(Upsample2d, self).__init__() 32 | 33 | self.interp = nn.functional.interpolate 34 | self.scale_factor = scale_factor 35 | 36 | def forward(self, x): 37 | x = self.interp(x, scale_factor=self.scale_factor, mode='nearest') 38 | return x 39 | 40 | 41 | class MicroResNet(nn.Module): 42 | def __init__(self): 43 | super(MicroResNet, self).__init__() 44 | 45 | self.downsampler = nn.Sequential( 46 | nn.ReflectionPad2d(4), 47 | nn.Conv2d(3, 8, kernel_size=9, stride=4), 48 | nn.InstanceNorm2d(8, affine=True), 49 | nn.ReLU(), 50 | nn.ReflectionPad2d(1), 51 | nn.Conv2d(8, 16, kernel_size=3, stride=2), 52 | nn.InstanceNorm2d(16, affine=True), 53 | nn.ReLU(), 54 | nn.ReflectionPad2d(1), 55 | nn.Conv2d(16, 32, kernel_size=3, stride=2), 56 | nn.InstanceNorm2d(32, affine=True), 57 | nn.ReLU(), 58 | ) 59 | 60 | self.residual = nn.Sequential( 61 | ResBlock(32), 62 | nn.Conv2d(32, 64, kernel_size=1, 
bias=False, groups=32), 63 | ResBlock(64), 64 | ) 65 | 66 | self.segmentator = nn.Sequential( 67 | nn.ReflectionPad2d(1), 68 | nn.Conv2d(64, 16, kernel_size=3), 69 | nn.InstanceNorm2d(16, affine=True), 70 | nn.ReLU(), 71 | Upsample2d(scale_factor=2), 72 | nn.ReflectionPad2d(4), 73 | nn.Conv2d(16, 1, kernel_size=9), 74 | nn.Sigmoid() 75 | ) 76 | 77 | def forward(self, x): 78 | out = self.downsampler(x) 79 | out = self.residual(out) 80 | out = self.segmentator(out) 81 | return out 82 | -------------------------------------------------------------------------------- /modules/cnet_modules/pidinet/__init__.py: -------------------------------------------------------------------------------- 1 | # Pidinet 2 | # https://github.com/hellozhuo/pidinet 3 | 4 | import os 5 | import torch 6 | import numpy as np 7 | from einops import rearrange 8 | from .model import pidinet 9 | from .util import annotator_ckpts_path, safe_step 10 | 11 | 12 | class PidiNetDetector: 13 | def __init__(self, device): 14 | remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/table5_pidinet.pth" 15 | modelpath = os.path.join(annotator_ckpts_path, "table5_pidinet.pth") 16 | if not os.path.exists(modelpath): 17 | from basicsr.utils.download_util import load_file_from_url 18 | load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path) 19 | self.netNetwork = pidinet() 20 | self.netNetwork.load_state_dict( 21 | {k.replace('module.', ''): v for k, v in torch.load(modelpath)['state_dict'].items()}) 22 | self.netNetwork.to(device).eval().requires_grad_(False) 23 | 24 | def __call__(self, input_image): # , safe=False): 25 | return self.netNetwork(input_image)[-1] 26 | # assert input_image.ndim == 3 27 | # input_image = input_image[:, :, ::-1].copy() 28 | # with torch.no_grad(): 29 | # image_pidi = torch.from_numpy(input_image).float().cuda() 30 | # image_pidi = image_pidi / 255.0 31 | # image_pidi = rearrange(image_pidi, 'h w c -> 1 c h w') 32 | # edge = self.netNetwork(image_pidi)[-1] 33 | 34 | # if safe: 35 | # edge = safe_step(edge) 36 | # edge = (edge * 255.0).clip(0, 255).astype(np.uint8) 37 | # return edge[0][0] 38 | -------------------------------------------------------------------------------- /modules/cnet_modules/pidinet/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catcathh/UltraPixel/e85d3bc9b0eaedbe1b5aea8a5c2624caab9df1f3/modules/cnet_modules/pidinet/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /modules/cnet_modules/pidinet/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catcathh/UltraPixel/e85d3bc9b0eaedbe1b5aea8a5c2624caab9df1f3/modules/cnet_modules/pidinet/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /modules/cnet_modules/pidinet/__pycache__/model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catcathh/UltraPixel/e85d3bc9b0eaedbe1b5aea8a5c2624caab9df1f3/modules/cnet_modules/pidinet/__pycache__/model.cpython-310.pyc -------------------------------------------------------------------------------- /modules/cnet_modules/pidinet/__pycache__/model.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/catcathh/UltraPixel/e85d3bc9b0eaedbe1b5aea8a5c2624caab9df1f3/modules/cnet_modules/pidinet/__pycache__/model.cpython-39.pyc -------------------------------------------------------------------------------- /modules/cnet_modules/pidinet/__pycache__/util.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catcathh/UltraPixel/e85d3bc9b0eaedbe1b5aea8a5c2624caab9df1f3/modules/cnet_modules/pidinet/__pycache__/util.cpython-310.pyc -------------------------------------------------------------------------------- /modules/cnet_modules/pidinet/__pycache__/util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catcathh/UltraPixel/e85d3bc9b0eaedbe1b5aea8a5c2624caab9df1f3/modules/cnet_modules/pidinet/__pycache__/util.cpython-39.pyc -------------------------------------------------------------------------------- /modules/cnet_modules/pidinet/ckpts/table5_pidinet.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catcathh/UltraPixel/e85d3bc9b0eaedbe1b5aea8a5c2624caab9df1f3/modules/cnet_modules/pidinet/ckpts/table5_pidinet.pth -------------------------------------------------------------------------------- /modules/cnet_modules/pidinet/util.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import cv2 5 | import os 6 | 7 | annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts') 8 | 9 | 10 | def HWC3(x): 11 | assert x.dtype == np.uint8 12 | if x.ndim == 2: 13 | x = x[:, :, None] 14 | assert x.ndim == 3 15 | H, W, C = x.shape 16 | assert C == 1 or C == 3 or C == 4 17 | if C == 3: 18 | return x 19 | if C == 1: 20 | return np.concatenate([x, x, x], axis=2) 21 | if C == 4: 22 | color = x[:, :, 0:3].astype(np.float32) 23 | alpha = x[:, :, 3:4].astype(np.float32) / 255.0 24 | y = color * alpha + 255.0 * (1.0 - alpha) 25 | y = y.clip(0, 255).astype(np.uint8) 26 | return y 27 | 28 | 29 | def resize_image(input_image, resolution): 30 | H, W, C = input_image.shape 31 | H = float(H) 32 | W = float(W) 33 | k = float(resolution) / min(H, W) 34 | H *= k 35 | W *= k 36 | H = int(np.round(H / 64.0)) * 64 37 | W = int(np.round(W / 64.0)) * 64 38 | img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) 39 | return img 40 | 41 | 42 | def nms(x, t, s): 43 | x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s) 44 | 45 | f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8) 46 | f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8) 47 | f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8) 48 | f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8) 49 | 50 | y = np.zeros_like(x) 51 | 52 | for f in [f1, f2, f3, f4]: 53 | np.putmask(y, cv2.dilate(x, kernel=f) == x, x) 54 | 55 | z = np.zeros_like(y, dtype=np.uint8) 56 | z[y > t] = 255 57 | return z 58 | 59 | 60 | def make_noise_disk(H, W, C, F): 61 | noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C)) 62 | noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC) 63 | noise = noise[F: F + H, F: F + W] 64 | noise -= np.min(noise) 65 | noise /= np.max(noise) 66 | if C == 1: 67 | noise = noise[:, :, None] 68 | return noise 69 | 70 | 71 | def min_max_norm(x): 72 | x -= np.min(x) 73 | x /= 
np.maximum(np.max(x), 1e-5) 74 | return x 75 | 76 | 77 | def safe_step(x, step=2): 78 | y = x.astype(np.float32) * float(step + 1) 79 | y = y.astype(np.int32).astype(np.float32) / float(step) 80 | return y 81 | 82 | 83 | def img2mask(img, H, W, low=10, high=90): 84 | assert img.ndim == 3 or img.ndim == 2 85 | assert img.dtype == np.uint8 86 | 87 | if img.ndim == 3: 88 | y = img[:, :, random.randrange(0, img.shape[2])] 89 | else: 90 | y = img 91 | 92 | y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC) 93 | 94 | if random.uniform(0, 1) < 0.5: 95 | y = 255 - y 96 | 97 | return y < np.percentile(y, random.randrange(low, high)) 98 | -------------------------------------------------------------------------------- /modules/common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from einops import rearrange 6 | import torch.fft as fft 7 | class Linear(torch.nn.Linear): 8 | def reset_parameters(self): 9 | return None 10 | 11 | class Conv2d(torch.nn.Conv2d): 12 | def reset_parameters(self): 13 | return None 14 | 15 | 16 | 17 | class Attention2D(nn.Module): 18 | def __init__(self, c, nhead, dropout=0.0): 19 | super().__init__() 20 | self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True) 21 | 22 | def forward(self, x, kv, self_attn=False): 23 | orig_shape = x.shape 24 | x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4 25 | if self_attn: 26 | #print('in line 23 algong self att ', kv.shape, x.shape) 27 | kv = torch.cat([x, kv], dim=1) 28 | #if x.shape[1] >= 72 * 72: 29 | # x = x * math.sqrt(math.log(64*64, 24*24)) 30 | 31 | x = self.attn(x, kv, kv, need_weights=False)[0] 32 | x = x.permute(0, 2, 1).view(*orig_shape) 33 | return x 34 | 35 | 36 | class LayerNorm2d(nn.LayerNorm): 37 | def __init__(self, *args, **kwargs): 38 | super().__init__(*args, **kwargs) 39 | 40 | def forward(self, x): 41 | return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) 42 | 43 | class GlobalResponseNorm(nn.Module): 44 | "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105" 45 | def __init__(self, dim): 46 | super().__init__() 47 | self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) 48 | self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) 49 | 50 | def forward(self, x): 51 | Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) 52 | Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) 53 | return self.gamma * (x * Nx) + self.beta + x 54 | 55 | 56 | class ResBlock(nn.Module): 57 | def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0): # , num_heads=4, expansion=2): 58 | super().__init__() 59 | self.depthwise = Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) 60 | # self.depthwise = SAMBlock(c, num_heads, expansion) 61 | self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) 62 | self.channelwise = nn.Sequential( 63 | Linear(c + c_skip, c * 4), 64 | nn.GELU(), 65 | GlobalResponseNorm(c * 4), 66 | nn.Dropout(dropout), 67 | Linear(c * 4, c) 68 | ) 69 | 70 | def forward(self, x, x_skip=None): 71 | x_res = x 72 | x = self.norm(self.depthwise(x)) 73 | if x_skip is not None: 74 | x = torch.cat([x, x_skip], dim=1) 75 | x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) 76 | return x + x_res 77 | 78 | 79 | class AttnBlock(nn.Module): 80 | def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): 81 | 
super().__init__() 82 | self.self_attn = self_attn 83 | self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) 84 | self.attention = Attention2D(c, nhead, dropout) 85 | self.kv_mapper = nn.Sequential( 86 | nn.SiLU(), 87 | Linear(c_cond, c) 88 | ) 89 | 90 | def forward(self, x, kv): 91 | kv = self.kv_mapper(kv) 92 | res = self.attention(self.norm(x), kv, self_attn=self.self_attn) 93 | 94 | #print(torch.unique(res), torch.unique(x), self.self_attn) 95 | #scale = math.sqrt(math.log(x.shape[-2] * x.shape[-1], 24*24)) 96 | x = x + res 97 | 98 | return x 99 | 100 | class FeedForwardBlock(nn.Module): 101 | def __init__(self, c, dropout=0.0): 102 | super().__init__() 103 | self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) 104 | self.channelwise = nn.Sequential( 105 | Linear(c, c * 4), 106 | nn.GELU(), 107 | GlobalResponseNorm(c * 4), 108 | nn.Dropout(dropout), 109 | Linear(c * 4, c) 110 | ) 111 | 112 | def forward(self, x): 113 | x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2) 114 | return x 115 | 116 | 117 | class TimestepBlock(nn.Module): 118 | def __init__(self, c, c_timestep, conds=['sca']): 119 | super().__init__() 120 | self.mapper = Linear(c_timestep, c * 2) 121 | self.conds = conds 122 | for cname in conds: 123 | setattr(self, f"mapper_{cname}", Linear(c_timestep, c * 2)) 124 | 125 | def forward(self, x, t): 126 | t = t.chunk(len(self.conds) + 1, dim=1) 127 | a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1) 128 | for i, c in enumerate(self.conds): 129 | ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1) 130 | a, b = a + ac, b + bc 131 | return x * (1 + a) + b 132 | -------------------------------------------------------------------------------- /modules/effnet.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | from torch import nn 3 | 4 | 5 | # EfficientNet 6 | class EfficientNetEncoder(nn.Module): 7 | def __init__(self, c_latent=16): 8 | super().__init__() 9 | self.backbone = torchvision.models.efficientnet_v2_s().features.eval() 10 | self.mapper = nn.Sequential( 11 | nn.Conv2d(1280, c_latent, kernel_size=1, bias=False), 12 | nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1 13 | ) 14 | 15 | def forward(self, x): 16 | return self.mapper(self.backbone(x)) 17 | 18 | -------------------------------------------------------------------------------- /modules/lora.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LoRA(nn.Module): 6 | def __init__(self, layer, name='weight', rank=16, alpha=1): 7 | super().__init__() 8 | weight = getattr(layer, name) 9 | self.lora_down = nn.Parameter(torch.zeros((rank, weight.size(1)))) 10 | self.lora_up = nn.Parameter(torch.zeros((weight.size(0), rank))) 11 | nn.init.normal_(self.lora_up, mean=0, std=1) 12 | 13 | self.scale = alpha / rank 14 | self.enabled = True 15 | 16 | def forward(self, original_weights): 17 | if self.enabled: 18 | lora_shape = list(original_weights.shape[:2]) + [1] * (len(original_weights.shape) - 2) 19 | lora_weights = torch.matmul(self.lora_up.clone(), self.lora_down.clone()).view(*lora_shape) * self.scale 20 | return original_weights + lora_weights 21 | else: 22 | return original_weights 23 | 24 | 25 | def apply_lora(model, filters=None, rank=16): 26 | def check_parameter(module, name): 27 | return hasattr(module, name) and not 
torch.nn.utils.parametrize.is_parametrized(module, name) and isinstance( 28 | getattr(module, name), nn.Parameter) 29 | 30 | for name, module in model.named_modules(): 31 | if filters is None or any([f in name for f in filters]): 32 | if check_parameter(module, "weight"): 33 | device, dtype = module.weight.device, module.weight.dtype 34 | torch.nn.utils.parametrize.register_parametrization(module, 'weight', LoRA(module, "weight", rank=rank).to(dtype).to(device)) 35 | elif check_parameter(module, "in_proj_weight"): 36 | device, dtype = module.in_proj_weight.device, module.in_proj_weight.dtype 37 | torch.nn.utils.parametrize.register_parametrization(module, 'in_proj_weight', LoRA(module, "in_proj_weight", rank=rank).to(dtype).to(device)) 38 | 39 | 40 | class ReToken(nn.Module): 41 | def __init__(self, indices=None): 42 | super().__init__() 43 | assert indices is not None 44 | self.embeddings = nn.Parameter(torch.zeros(len(indices), 1280)) 45 | self.register_buffer('indices', torch.tensor(indices)) 46 | self.enabled = True 47 | 48 | def forward(self, embeddings): 49 | if self.enabled: 50 | embeddings = embeddings.clone() 51 | for i, idx in enumerate(self.indices): 52 | embeddings[idx] += self.embeddings[i] 53 | return embeddings 54 | 55 | 56 | def apply_retoken(module, indices=None): 57 | def check_parameter(module, name): 58 | return hasattr(module, name) and not torch.nn.utils.parametrize.is_parametrized(module, name) and isinstance( 59 | getattr(module, name), nn.Parameter) 60 | 61 | if check_parameter(module, "weight"): 62 | device, dtype = module.weight.device, module.weight.dtype 63 | torch.nn.utils.parametrize.register_parametrization(module, 'weight', ReToken(indices=indices).to(dtype).to(device)) 64 | 65 | 66 | def remove_lora(model, leave_parametrized=True): 67 | for module in model.modules(): 68 | if torch.nn.utils.parametrize.is_parametrized(module, "weight"): 69 | nn.utils.parametrize.remove_parametrizations(module, "weight", leave_parametrized=leave_parametrized) 70 | elif torch.nn.utils.parametrize.is_parametrized(module, "in_proj_weight"): 71 | nn.utils.parametrize.remove_parametrizations(module, "in_proj_weight", leave_parametrized=leave_parametrized) 72 | -------------------------------------------------------------------------------- /modules/previewer.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | # Fast Decoder for Stage C latents. E.g. 
16 x 24 x 24 -> 3 x 192 x 192 5 | class Previewer(nn.Module): 6 | def __init__(self, c_in=16, c_hidden=512, c_out=3): 7 | super().__init__() 8 | self.blocks = nn.Sequential( 9 | nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels 10 | nn.GELU(), 11 | nn.BatchNorm2d(c_hidden), 12 | 13 | nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1), 14 | nn.GELU(), 15 | nn.BatchNorm2d(c_hidden), 16 | 17 | nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32 18 | nn.GELU(), 19 | nn.BatchNorm2d(c_hidden // 2), 20 | 21 | nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1), 22 | nn.GELU(), 23 | nn.BatchNorm2d(c_hidden // 2), 24 | 25 | nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64 26 | nn.GELU(), 27 | nn.BatchNorm2d(c_hidden // 4), 28 | 29 | nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), 30 | nn.GELU(), 31 | nn.BatchNorm2d(c_hidden // 4), 32 | 33 | nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128 34 | nn.GELU(), 35 | nn.BatchNorm2d(c_hidden // 4), 36 | 37 | nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), 38 | nn.GELU(), 39 | nn.BatchNorm2d(c_hidden // 4), 40 | 41 | nn.Conv2d(c_hidden // 4, c_out, kernel_size=1), 42 | ) 43 | 44 | def forward(self, x): 45 | return self.blocks(x) 46 | -------------------------------------------------------------------------------- /modules/resnet.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catcathh/UltraPixel/e85d3bc9b0eaedbe1b5aea8a5c2624caab9df1f3/modules/resnet.py -------------------------------------------------------------------------------- /modules/speed_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | from einops import repeat 7 | class CheckpointFunction(torch.autograd.Function): 8 | @staticmethod 9 | def forward(ctx, run_function, length, *args): 10 | ctx.run_function = run_function 11 | ctx.input_tensors = list(args[:length]) 12 | ctx.input_params = list(args[length:]) 13 | ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), 14 | "dtype": torch.get_autocast_gpu_dtype(), 15 | "cache_enabled": torch.is_autocast_cache_enabled()} 16 | with torch.no_grad(): 17 | output_tensors = ctx.run_function(*ctx.input_tensors) 18 | return output_tensors 19 | 20 | @staticmethod 21 | def backward(ctx, *output_grads): 22 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 23 | with torch.enable_grad(), \ 24 | torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): 25 | # Fixes a bug where the first op in run_function modifies the 26 | # Tensor storage in place, which is not allowed for detach()'d 27 | # Tensors. 28 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 29 | output_tensors = ctx.run_function(*shallow_copies) 30 | input_grads = torch.autograd.grad( 31 | output_tensors, 32 | ctx.input_tensors + ctx.input_params, 33 | output_grads, 34 | allow_unused=True, 35 | ) 36 | del ctx.input_tensors 37 | del ctx.input_params 38 | del output_tensors 39 | return (None, None) + input_grads 40 | 41 | def checkpoint(func, inputs, params, flag): 42 | """ 43 | Evaluate a function without caching intermediate activations, allowing for 44 | reduced memory at the expense of extra compute in the backward pass. 45 | :param func: the function to evaluate. 
46 | :param inputs: the argument sequence to pass to `func`. 47 | :param params: a sequence of parameters `func` depends on but does not 48 | explicitly take as arguments. 49 | :param flag: if False, disable gradient checkpointing. 50 | """ 51 | if flag: 52 | args = tuple(inputs) + tuple(params) 53 | return CheckpointFunction.apply(func, len(inputs), *args) 54 | else: 55 | return func(*inputs) -------------------------------------------------------------------------------- /modules/stage_a.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torchtools.nn import VectorQuantize 4 | from einops import rearrange 5 | import torch.nn.functional as F 6 | import math 7 | class ResBlock(nn.Module): 8 | def __init__(self, c, c_hidden): 9 | super().__init__() 10 | # depthwise/attention 11 | self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) 12 | self.depthwise = nn.Sequential( 13 | nn.ReplicationPad2d(1), 14 | nn.Conv2d(c, c, kernel_size=3, groups=c) 15 | ) 16 | 17 | # channelwise 18 | self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) 19 | self.channelwise = nn.Sequential( 20 | nn.Linear(c, c_hidden), 21 | nn.GELU(), 22 | nn.Linear(c_hidden, c), 23 | ) 24 | 25 | self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) 26 | 27 | # Init weights 28 | def _basic_init(module): 29 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 30 | torch.nn.init.xavier_uniform_(module.weight) 31 | if module.bias is not None: 32 | nn.init.constant_(module.bias, 0) 33 | 34 | self.apply(_basic_init) 35 | 36 | def _norm(self, x, norm): 37 | return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) 38 | 39 | def forward(self, x): 40 | 41 | mods = self.gammas 42 | 43 | x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1] 44 | 45 | #x = x.to(torch.float64) 46 | x = x + self.depthwise(x_temp) * mods[2] 47 | 48 | x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4] 49 | x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5] 50 | 51 | return x 52 | 53 | 54 | def extract_patches(tensor, patch_size, stride): 55 | b, c, H, W = tensor.shape 56 | pad_h = (patch_size - (H - patch_size) % stride) % stride 57 | pad_w = (patch_size - (W - patch_size) % stride) % stride 58 | tensor = F.pad(tensor, (0, pad_w, 0, pad_h), mode='reflect') 59 | 60 | 61 | patches = tensor.unfold(2, patch_size, stride).unfold(3, patch_size, stride) 62 | patches = patches.contiguous().view(b, c, -1, patch_size, patch_size) 63 | patches = patches.permute(0, 2, 1, 3, 4) 64 | return patches, (H, W) 65 | 66 | def fuse_patches(patches, patch_size, stride, H, W): 67 | 68 | b, num_patches, c, _, _ = patches.shape 69 | patches = patches.permute(0, 2, 1, 3, 4) 70 | 71 | 72 | 73 | pad_h = (patch_size - (H - patch_size) % stride) % stride 74 | pad_w = (patch_size - (W - patch_size) % stride) % stride 75 | out_h = H + pad_h 76 | out_w = W + pad_w 77 | patches = patches.contiguous().view(b, c , -1, patch_size*patch_size ).permute(0, 1, 3, 2) 78 | patches = patches.contiguous().view(b, c*patch_size*patch_size, -1) 79 | 80 | tensor = F.fold(patches, output_size=(out_h, out_w), kernel_size=patch_size, stride=stride) 81 | overlap_cnt = F.fold(torch.ones_like(patches), output_size=(out_h, out_w), kernel_size=patch_size, stride=stride) 82 | tensor = tensor / overlap_cnt 83 | print('end fuse patch', tensor.shape, (tensor.dtype)) 84 | return tensor[:, :, :H, :W] 85 | 86 | 87 | 88 | class StageA(nn.Module): 89 
| def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192, 90 | scale_factor=0.43): # 0.3764 91 | super().__init__() 92 | self.c_latent = c_latent 93 | self.scale_factor = scale_factor 94 | c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))] 95 | 96 | # Encoder blocks 97 | self.in_block = nn.Sequential( 98 | nn.PixelUnshuffle(2), 99 | nn.Conv2d(3 * 4, c_levels[0], kernel_size=1) 100 | ) 101 | down_blocks = [] 102 | for i in range(levels): 103 | if i > 0: 104 | down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1)) 105 | block = ResBlock(c_levels[i], c_levels[i] * 4) 106 | down_blocks.append(block) 107 | down_blocks.append(nn.Sequential( 108 | nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False), 109 | nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1 110 | )) 111 | self.down_blocks = nn.Sequential(*down_blocks) 112 | self.down_blocks[0] 113 | 114 | self.codebook_size = codebook_size 115 | self.vquantizer = VectorQuantize(c_latent, k=codebook_size) 116 | 117 | # Decoder blocks 118 | up_blocks = [nn.Sequential( 119 | nn.Conv2d(c_latent, c_levels[-1], kernel_size=1) 120 | )] 121 | for i in range(levels): 122 | for j in range(bottleneck_blocks if i == 0 else 1): 123 | block = ResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4) 124 | up_blocks.append(block) 125 | if i < levels - 1: 126 | up_blocks.append( 127 | nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2, 128 | padding=1)) 129 | self.up_blocks = nn.Sequential(*up_blocks) 130 | self.out_block = nn.Sequential( 131 | nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1), 132 | nn.PixelShuffle(2), 133 | ) 134 | 135 | def encode(self, x, quantize=False): 136 | x = self.in_block(x) 137 | x = self.down_blocks(x) 138 | if quantize: 139 | qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1) 140 | return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25 141 | else: 142 | return x / self.scale_factor, None, None, None 143 | 144 | 145 | 146 | def decode(self, x, tiled_decoding=False): 147 | x = x * self.scale_factor 148 | x = self.up_blocks(x) 149 | x = self.out_block(x) 150 | return x 151 | 152 | def forward(self, x, quantize=False): 153 | qe, x, _, vq_loss = self.encode(x, quantize) 154 | x = self.decode(qe) 155 | return x, vq_loss 156 | 157 | 158 | class Discriminator(nn.Module): 159 | def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6): 160 | super().__init__() 161 | d = max(depth - 3, 3) 162 | layers = [ 163 | nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)), 164 | nn.LeakyReLU(0.2), 165 | ] 166 | for i in range(depth - 1): 167 | c_in = c_hidden // (2 ** max((d - i), 0)) 168 | c_out = c_hidden // (2 ** max((d - 1 - i), 0)) 169 | layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))) 170 | layers.append(nn.InstanceNorm2d(c_out)) 171 | layers.append(nn.LeakyReLU(0.2)) 172 | self.encoder = nn.Sequential(*layers) 173 | self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1) 174 | self.logits = nn.Sigmoid() 175 | 176 | def forward(self, x, cond=None): 177 | x = self.encoder(x) 178 | if cond is not None: 179 | cond = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, x.size(-2), x.size(-1)) 180 | x = torch.cat([x, cond], dim=1) 181 | x = self.shuffle(x) 182 | x = self.logits(x) 183 
| return x 184 | -------------------------------------------------------------------------------- /prompt_list.txt: -------------------------------------------------------------------------------- 1 | A close-up of a blooming peony, with layers of soft, pink petals, a delicate fragrance, and dewdrops glistening 2 | in the early morning light. 3 | 4 | A detailed view of a blooming magnolia tree, with large, white flowers and dark green leaves, set against a 5 | clear blue sky. 6 | 7 | A close-up portrait of a young woman with flawless skin, vibrant red lipstick, and wavy brown hair, wearing 8 | a vintage floral dress and standing in front of a blooming garden. 9 | 10 | The image features a snow-covered mountain range with a large, snow-covered mountain in the background. 11 | The mountain is surrounded by a forest of trees, and the sky is filled with clouds. The scene is set during the 12 | winter season, with snow covering the ground and the trees. 13 | 14 | Crocodile in a sweater. 15 | 16 | A vibrant anime scene of a young girl with long, flowing pink hair, big sparkling blue eyes, and a school 17 | uniform, standing under a cherry blossom tree with petals falling around her. The background shows a 18 | traditional Japanese school with cherry blossoms in full bloom. 19 | 20 | A playful Labrador retriever puppy with a shiny, golden coat, chasing a red ball in a spacious backyard, with 21 | green grass and a wooden fence. 22 | 23 | A cozy, rustic log cabin nestled in a snow-covered forest, with smoke rising from the stone chimney, warm 24 | lights glowing from the windows, and a path of footprints leading to the front door. 25 | 26 | A highly detailed, high-quality image of the Banff National Park in Canada. The turquoise waters of Lake 27 | Louise are surrounded by snow-capped mountains and dense pine forests. A wooden canoe is docked at the 28 | edge of the lake. The sky is a clear, bright blue, and the air is crisp and fresh. 29 | 30 | A highly detailed, high-quality image of a Shih Tzu receiving a bath in a home bathroom. The dog is standing 31 | in a tub, covered in suds, with a slightly wet and adorable look. The background includes bathroom fixtures, 32 | towels, and a clean, tiled floor. 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --find-links https://download.pytorch.org/whl/torch_stable.html 2 | accelerate>=0.25.0 3 | torch==2.0.1+cu118 4 | torchvision==0.15.2+cu118 5 | transformers>=4.30.0 6 | numpy>=1.23.5 7 | kornia>=0.7.0 8 | insightface>=0.7.3 9 | opencv-python>=4.8.1.78 10 | tqdm>=4.66.1 11 | matplotlib>=3.7.4 12 | webdataset>=0.2.79 13 | wandb>=0.16.2 14 | munch>=4.0.0 15 | onnxruntime>=1.16.3 16 | einops>=0.7.0 17 | onnx2torch>=1.5.13 18 | warmup-scheduler @ git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git 19 | torchtools @ git+https://github.com/pabloppp/pytorch-tools 20 | -------------------------------------------------------------------------------- /train/__init__.py: -------------------------------------------------------------------------------- 1 | from .train_b import WurstCore as WurstCoreB 2 | from .train_c import WurstCore as WurstCoreC 3 | from .train_t2i import WurstCore as WurstCore_t2i 4 | from .train_ultrapixel_control import WurstCore as WurstCore_control_lrguide 5 | from .train_personalized import WurstCore as WurstCore_personalized -------------------------------------------------------------------------------- /train/dist_core.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | 5 | def get_world_size(): 6 | """Find OMPI world size without calling mpi functions 7 | :rtype: int 8 | """ 9 | if os.environ.get('PMI_SIZE') is not None: 10 | return int(os.environ.get('PMI_SIZE') or 1) 11 | elif os.environ.get('OMPI_COMM_WORLD_SIZE') is not None: 12 | return int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1) 13 | else: 14 | return torch.cuda.device_count() 15 | 16 | 17 | def get_global_rank(): 18 | """Find OMPI world rank without calling mpi functions 19 | :rtype: int 20 | """ 21 | if os.environ.get('PMI_RANK') is not None: 22 | return int(os.environ.get('PMI_RANK') or 0) 23 | elif os.environ.get('OMPI_COMM_WORLD_RANK') is not None: 24 | return int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0) 25 | else: 26 | return 0 27 | 28 | 29 | def get_local_rank(): 30 | """Find OMPI local rank without calling mpi functions 31 | :rtype: int 32 | """ 33 | if os.environ.get('MPI_LOCALRANKID') is not None: 34 | return int(os.environ.get('MPI_LOCALRANKID') or 0) 35 | elif os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') is not None: 36 | return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') or 0) 37 | else: 38 | return 0 39 | 40 | 41 | def get_master_ip(): 42 | if os.environ.get('AZ_BATCH_MASTER_NODE') is not None: 43 | return os.environ.get('AZ_BATCH_MASTER_NODE').split(':')[0] 44 | elif os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') is not None: 45 | return os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') 46 | else: 47 | return "127.0.0.1" 48 | --------------------------------------------------------------------------------
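Usage sketch: the helpers in train/dist_core.py above read MPI/PMI-style environment variables to recover the world size, global rank, local rank, and master address. A minimal way they might be wired into torch.distributed initialization is sketched below; this is an illustrative example, not code from the repository, and the NCCL backend, the init_distributed name, and the default port 29500 are assumptions.

import torch
import torch.distributed as dist

from train.dist_core import (
    get_world_size, get_global_rank, get_local_rank, get_master_ip,
)


def init_distributed(port: int = 29500):
    """Illustrative sketch: build a process group from the same
    environment variables that dist_core.py inspects."""
    world_size = get_world_size()
    rank = get_global_rank()
    local_rank = get_local_rank()

    # Bind this process to its local GPU before creating the process group.
    torch.cuda.set_device(local_rank)

    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://{get_master_ip()}:{port}",
        world_size=world_size,
        rank=rank,
    )
    return rank, local_rank, world_size

Under a launcher such as torchrun, which exports RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT directly, the same init_process_group call can instead rely on the env:// initialization method.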