├── .idea ├── .name ├── PiSA-SR.iml ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml ├── modules.xml ├── vcs.xml └── workspace.xml ├── README.md ├── figs ├── AIGC1.png ├── AIGC2.png ├── AIGC3.png ├── AIGC4.png ├── AIGC5.png ├── comparison.png ├── fig1_github.png ├── framework.png ├── realworld1.png ├── realworld2.png ├── realworld3.png └── realworld4.png ├── pisasr.py ├── ram ├── __init__.py ├── configs │ ├── condition_config.json │ ├── med_config.json │ ├── q2l_config.json │ └── swin │ │ ├── config_swinB_384.json │ │ ├── config_swinL_384.json │ │ └── config_swinL_444.json ├── data │ ├── ram_tag_list.txt │ ├── ram_tag_list_chinese.txt │ ├── ram_tag_list_threshold.txt │ └── tag_list.txt ├── inference.py ├── models │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── bert.cpython-310.pyc │ │ ├── bert_lora.cpython-310.pyc │ │ ├── condition_network.cpython-310.pyc │ │ ├── ram.cpython-310.pyc │ │ ├── ram_condition.cpython-310.pyc │ │ ├── ram_lora.cpython-310.pyc │ │ ├── ram_swin_bert_lora.cpython-310.pyc │ │ ├── ram_swin_lora.cpython-310.pyc │ │ ├── swin_transformer.cpython-310.pyc │ │ ├── swin_transformer_lora.cpython-310.pyc │ │ ├── tag2text.cpython-310.pyc │ │ ├── tag2text_lora.cpython-310.pyc │ │ ├── utils.cpython-310.pyc │ │ └── vit.cpython-310.pyc │ ├── bert.py │ ├── bert_lora.py │ ├── ram.py │ ├── ram_lora.py │ ├── swin_transformer.py │ ├── swin_transformer_lora.py │ ├── tag2text.py │ ├── tag2text_lora.py │ ├── utils.py │ └── vit.py ├── transform.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── metrics.cpython-310.pyc │ └── openset_utils.cpython-310.pyc │ ├── metrics.py │ └── openset_utils.py ├── requirements.txt ├── scripts ├── get_path.py ├── test │ ├── test_adjustable.sh │ └── test_default.sh └── train │ └── train_pisasr.sh ├── src ├── datasets │ ├── dataset.py │ ├── params.yml │ └── realesrgan.py ├── models │ ├── autoencoder_kl.py │ └── unet_2d_condition.py └── my_utils │ ├── __pycache__ │ ├── devices.cpython-310.pyc │ ├── training_lr_utils.cpython-310.pyc │ ├── training_utils.cpython-310.pyc │ ├── training_utils_aigc.cpython-310.pyc │ ├── training_utils_project.cpython-310.pyc │ ├── training_utils_realsr.cpython-310.pyc │ ├── training_utils_realsr_sdxl_vsd_nostage.cpython-310.pyc │ ├── training_utils_realsr_vsd.cpython-310.pyc │ ├── training_utils_realsr_vsd_controlnet_nostage.cpython-310.pyc │ ├── training_utils_realsr_vsd_nostage.cpython-310.pyc │ ├── training_utils_realsr_vsd_nostage_0513.cpython-310.pyc │ ├── training_utils_realsr_vsd_stage1.cpython-310.pyc │ ├── training_utils_realsr_vsd_stage2.cpython-310.pyc │ ├── training_utils_realsr_vsd_turbo_nostage.cpython-310.pyc │ ├── training_utils_res.cpython-310.pyc │ ├── training_utils_seesr_vsd_nostage.cpython-310.pyc │ ├── training_utils_wanghui.cpython-310.pyc │ ├── vaehook.cpython-310.pyc │ └── wavelet_color_fix.cpython-310.pyc │ ├── devices.py │ ├── training_utils.py │ ├── vaehook.py │ └── wavelet_color_fix.py ├── test_pisasr.py └── train_pisasr.py /.idea/.name: -------------------------------------------------------------------------------- 1 | test_pisasr.py -------------------------------------------------------------------------------- /.idea/PiSA-SR.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 14 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/workspace.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 10 | 11 | 13 | 14 | 16 | 17 | 18 | 21 | 27 | 28 | 29 | 30 | 31 | 1733479124231 32 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |
3 | # Pixel-level and Semantic-level Adjustable Super-resolution: A Dual-LoRA Approach
4 | 5 | 🚩 Accepted by CVPR2025 6 | 7 | 8 | 9 | 10 | [Lingchen Sun](https://scholar.google.com/citations?hl=zh-CN&tzom=-480&user=ZCDjTn8AAAAJ)1,2 11 | | [Rongyuan Wu](https://scholar.google.com/citations?user=A-U8zE8AAAAJ&hl=zh-CN)1,2 | 12 | [Zhiyuan Ma](https://scholar.google.com/citations?user=F15mLDYAAAAJ&hl=en)1 | 13 | [Shuaizheng Liu](https://scholar.google.com/citations?user=wzdCc-QAAAAJ&hl=en)1,2 | 14 | [Qiaosi Yi](https://dblp.org/pid/249/8335.html)1,2 | 15 | [Lei Zhang](https://www4.comp.polyu.edu.hk/~cslzhang)1,2 16 | 17 | 1The Hong Kong Polytechnic University, 2OPPO Research Institute 18 |
19 | 20 | 21 | ## ⏰ Update 22 | - **2025.3.25**: Training code is released. 23 | - **2025.1.2**: Code and models are released. 24 | - **2024.12.4**: The paper and this repo are released. 25 | 26 | :star: If PiSA-SR is helpful to your images or projects, please help star this repo. Thanks! :hugs: 27 | 28 | ## 🌟 Overview Framework 29 | 30 | ![PiSA-SR](figs/framework.png) 31 | 32 | 33 | (a) Training procedure of PiSA-SR. During the training process, two LoRA modules are respectively optimized for pixel-level and semantic-level enhancement. 34 | 35 | (b) Inference procedure of PiSA-SR. During the inference stage, users can use the default setting to reconstruct the high-quality image in one-step diffusion or adjust λpix and λsem to control the strengths of pixel-level and semantic-level enhancement. 36 | ## 😍 Visual Results 37 | ### Demo on Real-world SR 38 | [](https://imgsli.com/MzM0NDE3) [](https://imgsli.com/MzM0NDIz) [](https://imgsli.com/MzM0NDIx) [](https://imgsli.com/MzM0NDI2) 39 | 40 | ### Demo on AIGC Enhancement 41 | [](https://imgsli.com/MzM0NDI4) [](https://imgsli.com/MzM0NDMx) [](https://imgsli.com/MzM0NDM1) [](https://imgsli.com/MzM0NDM0) [](https://imgsli.com/MzM0NDM2) 42 | 43 | ### Adjustable SR Results 44 | 45 | PiSA-SR 46 | 47 | 48 | By increasing the guidance scale λpix on the pixel-level LoRA module, the image degradations such as noise and compression artifacts can be gradually removed; however, a too-strong λpix will make the SR image over-smoothed. By increasing the guidance scale λsem on the semantic-level LoRA module, the SR images will have more semantic details; nonetheless, a too-high λsem will generate visual artifacts. 49 | 50 | ### Comparisons with Other DM-Based SR Methods 51 | ![PiSA-SR](figs/comparison.png) 52 | 53 | ## ⚙ Dependencies and Installation 54 | ```shell 55 | ## git clone this repository 56 | git clone https://github.com/csslc/PiSA-SR 57 | cd PiSA-SR 58 | 59 | 60 | # create an environment 61 | conda create -n PiSA-SR python=3.10 62 | conda activate PiSA-SR 63 | pip install --upgrade pip 64 | pip install -r requirements.txt 65 | ``` 66 | 67 | ## 🍭 Quick Inference 68 | #### Step 1: Download the pretrained models 69 | - Download the pretrained SD-2.1-base models from [HuggingFace](https://huggingface.co/stabilityai/stable-diffusion-2-1-base). 70 | - Download the RAM model from [HuggingFace](https://huggingface.co/spaces/xinyu1205/recognize-anything/blob/main/ram_swin_large_14m.pth) and save the model to the [folder](src/ram_pretrain_model). 71 | - Download the PiSA-SR model from [`GoogleDrive`](https://drive.google.com/drive/folders/1oLetijWNd59xwJE5oU-eXylQBifxWdss?usp=drive_link) or [`BaiduNetdisk(pwd: pisa)`](https://pan.baidu.com/s/1wcMVp9vmsDrLnK0yTAH2Ig) and put the models in the `preset/models`: 72 | 73 | #### Step 2: Prepare testing data 74 | You can put the testing images in the `preset/test_datasets`. 
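With the models from Step 1 and the test images from Step 2 in place, the default paths used by the commands in Step 3 correspond to a layout roughly like the one below. This is only a sketch assembled from the paths mentioned in this README; adjust it if you store the files elsewhere.

```
preset/
├── models/
│   ├── stable-diffusion-2-1-base/      # SD-2.1-base downloaded from HuggingFace
│   └── pisa_sr.pkl                     # PiSA-SR weights from GoogleDrive / BaiduNetdisk
└── test_datasets/                      # your low-quality input images
src/
└── ram_pretrain_model/
    └── ram_swin_large_14m.pth          # RAM model used for prompt extraction
```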
75 | 76 | #### Step 3: Running testing command 77 | For default setting: 78 | ``` 79 | python test_pisasr.py \ 80 | --pretrained_model_path preset/models/stable-diffusion-2-1-base \ 81 | --pretrained_path preset/models/pisa_sr.pkl \ 82 | --process_size 512 \ 83 | --upscale 4 \ 84 | --input_image preset/test_datasets \ 85 | --output_dir experiments/test \ 86 | --default 87 | ``` 88 | 89 | For adjustable setting: 90 | ``` 91 | python test_pisasr.py \ 92 | --pretrained_model_path preset/models/stable-diffusion-2-1-base \ 93 | --pretrained_path preset/models/pisa_sr.pkl \ 94 | --process_size 512 \ 95 | --upscale 4 \ 96 | --input_image preset/test_datasets \ 97 | --output_dir experiments/test \ 98 | --lambda_pix 1.0 \ 99 | --lambda_sem 1.0 100 | ``` 101 | 🛠️You can adjust `lambda_pix` and `lambda_sem` to **control the strengths of pixel-wise fidelity and semantic-level details**. 102 | 103 | We integrate [tile_diffusion](https://github.com/albarji/mixture-of-diffusers) and [tile_vae](https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111/tree/main) to the [test_pisasr.py](test_pisasr.py) to save the GPU memory for inference. 104 | You can change the tile size and stride according to the VRAM of your device. 105 | 106 | ``` 107 | python test_pisasr.py \ 108 | --pretrained_model_path preset/models/stable-diffusion-2-1-base \ 109 | --pretrained_path preset/models/pisa_sr.pkl \ 110 | --process_size 512 \ 111 | --upscale 4 \ 112 | --input_image preset/test_datasets \ 113 | --output_dir experiments/test \ 114 | --latent_tiled_size 96 \ 115 | --latent_tiled_overlap 32 \ 116 | --vae_encoder_tiled_size 1024 \ 117 | --vae_decoder_tiled_size 224 \ 118 | --default 119 | ``` 120 | 121 | ## 🚋 Train 122 | #### Step1: Prepare training data 123 | Generate txt file for the training set. 124 | Fill in the required information in [get_path](scripts/get_path.py) and run, then you can obtain the txt file recording the paths of ground-truth images. 125 | You can save the txt file into `preset/gt_path.txt`. 126 | The high-quality ground-truth images can be selected from your training dataset, and the txt file can be saved in `preset/gt_selected_path`. 127 | 128 | #### Step2: Train Model 129 | 1. Download pretrained [Stable Diffusion v2.1](https://huggingface.co/stabilityai/stable-diffusion-2-1-base) to provide generative capabilities. 130 | 131 | ```shell 132 | wget https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.ckpt --no-check-certificate 133 | ``` 134 | 135 | 2. Download [RAM](https://huggingface.co/spaces/xinyu1205/recognize-anything/blob/main/ram_swin_large_14m.pth) model for extracting text prompt, and put the model into `src/ram_pretrain_model`. 136 | 137 | 3. Start training. 
138 | ```shell 139 | CUDA_VISIBLE_DEVICES="0,1,2,3," accelerate launch train_pisasr.py \ 140 | --pretrained_model_path="preset/models/stable-diffusion-2-1-base" \ 141 | --pretrained_model_path_csd="preset/models/stable-diffusion-2-1-base" \ 142 | --dataset_txt_paths="preset/gt_path.txt" \ 143 | --highquality_dataset_txt_paths="preset/gt_selected_path.txt" \ 144 | --dataset_test_folder="preset/testfolder" \ 145 | --learning_rate=5e-5 \ 146 | --train_batch_size=4 \ 147 | --prob=0.1 \ 148 | --gradient_accumulation_steps=1 \ 149 | --enable_xformers_memory_efficient_attention --checkpointing_steps 500 \ 150 | --seed 123 \ 151 | --output_dir="experiments/train-pisasr" \ 152 | --cfg_csd 7.5 \ 153 | --timesteps1 1 \ 154 | --lambda_lpips=2.0 \ 155 | --lambda_l2=1.0 \ 156 | --lambda_csd=1.0 \ 157 | --pix_steps=4000 \ 158 | --lora_rank_unet_pix=4 \ 159 | --lora_rank_unet_sem=4 \ 160 | --min_dm_step_ratio=0.02 \ 161 | --max_dm_step_ratio=0.5 \ 162 | --null_text_ratio=0.5 \ 163 | --align_method="adain" \ 164 | --deg_file_path="params.yml" \ 165 | --tracker_project_name "PiSASR" \ 166 | --is_module True 167 | ``` 168 | 169 | ### Citations 170 | If our code helps your research or work, please consider citing our paper. 171 | The following are BibTeX references: 172 | 173 | ``` 174 | @article{sun2024pisasr, 175 | title={Pixel-level and Semantic-level Adjustable Super-resolution: A Dual-LoRA Approach}, 176 | author={Sun, Lingchen and Wu, Rongyuan and Ma, Zhiyuan and Liu, Shuaizheng and Yi, Qiaosi and Zhang, Lei}, 177 | booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition}, 178 | year={2025} 179 | } 180 | ``` 181 | 182 | ### License 183 | This project is released under the [Apache 2.0 license](LICENSE). 184 | 185 | ### Acknowledgement 186 | This project is based on [OSEDiff](https://github.com/cswry/OSEDiff). Thanks for the awesome work. 187 | 188 | ### Contact 189 | If you have any questions, please contact: ling-chen.sun@connect.polyu.hk 190 | 191 | 192 |
193 | statistics 194 | 195 | ![visitors](https://visitor-badge.laobi.icu/badge?page_id=csslc/PiSA-SR) 196 | 197 |
198 | -------------------------------------------------------------------------------- /figs/AIGC1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/figs/AIGC1.png -------------------------------------------------------------------------------- /figs/AIGC2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/figs/AIGC2.png -------------------------------------------------------------------------------- /figs/AIGC3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/figs/AIGC3.png -------------------------------------------------------------------------------- /figs/AIGC4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/figs/AIGC4.png -------------------------------------------------------------------------------- /figs/AIGC5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/figs/AIGC5.png -------------------------------------------------------------------------------- /figs/comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/figs/comparison.png -------------------------------------------------------------------------------- /figs/fig1_github.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/figs/fig1_github.png -------------------------------------------------------------------------------- /figs/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/figs/framework.png -------------------------------------------------------------------------------- /figs/realworld1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/figs/realworld1.png -------------------------------------------------------------------------------- /figs/realworld2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/figs/realworld2.png -------------------------------------------------------------------------------- /figs/realworld3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/figs/realworld3.png -------------------------------------------------------------------------------- /figs/realworld4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/figs/realworld4.png -------------------------------------------------------------------------------- 
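The `ram` package below is a vendored copy of the Recognize Anything Model used to extract tag prompts for PiSA-SR. The following is a minimal usage sketch, not part of the original repository: the function names come from `ram/__init__.py` and `ram/models/__init__.py`, the checkpoint path is the one the README asks you to download, the example image path is hypothetical, and the exact `get_transform` argument may differ slightly.

```python
import torch
from PIL import Image

from ram import get_transform, inference_ram
from ram.models import ram

device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the swin_l / 384px RAM tagger from the checkpoint referenced in the README.
model = ram(pretrained="src/ram_pretrain_model/ram_swin_large_14m.pth",
            image_size=384, vit="swin_l").eval().to(device)

# Preprocess a (hypothetical) test image and run tagging.
transform = get_transform(image_size=384)
image = transform(Image.open("preset/test_datasets/example.png").convert("RGB")).unsqueeze(0).to(device)

tags = inference_ram(image, model)   # list with one comma-separated tag string per image in the batch
print(tags[0])
```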
/ram/__init__.py: -------------------------------------------------------------------------------- 1 | from .inference import inference_tag2text, inference_ram, inference_ram_openset 2 | from .transform import get_transform 3 | -------------------------------------------------------------------------------- /ram/configs/condition_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "nf": 64 3 | } -------------------------------------------------------------------------------- /ram/configs/med_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertModel" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30524, 19 | "encoder_width": 768, 20 | "add_cross_attention": true 21 | } -------------------------------------------------------------------------------- /ram/configs/q2l_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertModel" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 4, 15 | "num_hidden_layers": 2, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30522, 19 | "encoder_width": 768, 20 | "add_cross_attention": true, 21 | "add_tag_cross_attention": false 22 | } -------------------------------------------------------------------------------- /ram/configs/swin/config_swinB_384.json: -------------------------------------------------------------------------------- 1 | { 2 | "ckpt": "pretrain_model/swin_base_patch4_window7_224_22k.pth", 3 | "vision_width": 1024, 4 | "image_res": 384, 5 | "window_size": 12, 6 | "embed_dim": 128, 7 | "depths": [ 2, 2, 18, 2 ], 8 | "num_heads": [ 4, 8, 16, 32 ] 9 | } -------------------------------------------------------------------------------- /ram/configs/swin/config_swinL_384.json: -------------------------------------------------------------------------------- 1 | { 2 | "ckpt": "pretrain_model/swin_large_patch4_window12_384_22k.pth", 3 | "vision_width": 1536, 4 | "image_res": 384, 5 | "window_size": 12, 6 | "embed_dim": 192, 7 | "depths": [ 2, 2, 18, 2 ], 8 | "num_heads": [ 6, 12, 24, 48 ] 9 | } -------------------------------------------------------------------------------- /ram/configs/swin/config_swinL_444.json: -------------------------------------------------------------------------------- 1 | { 2 | "ckpt": "pretrain_model/swin_large_patch4_window12_384_22k.pth", 3 | "vision_width": 1536, 4 | "image_res": 444, 5 | "window_size": 12, 6 | "embed_dim": 192, 7 | "depths": [ 2, 2, 18, 2 ], 8 | "num_heads": [ 6, 12, 24, 48 ] 9 | } -------------------------------------------------------------------------------- /ram/inference.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * The Inference of RAM and Tag2Text Models 3 | * Written by Xinyu Huang 4 | ''' 5 | 
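# Summary of the helpers defined below (all are re-exported by ram/__init__.py):
#   inference_tag2text(image, model, input_tag) -> (predicted tags, user tags or None, caption)
#   inference_ram(image, model)                 -> list of comma-separated tag strings
#   inference_ram_openset(image, model)         -> open-set tags joined by ' | '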
import torch 6 | 7 | 8 | def inference_tag2text(image, model, input_tag="None"): 9 | 10 | with torch.no_grad(): 11 | caption, tag_predict = model.generate(image, 12 | tag_input=None, 13 | max_length=50, 14 | return_tag_predict=True) 15 | 16 | if input_tag == '' or input_tag == 'none' or input_tag == 'None': 17 | return tag_predict[0], None, caption[0] 18 | 19 | # If user input specified tags: 20 | else: 21 | input_tag_list = [] 22 | input_tag_list.append(input_tag.replace(',', ' | ')) 23 | 24 | with torch.no_grad(): 25 | caption, input_tag = model.generate(image, 26 | tag_input=input_tag_list, 27 | max_length=50, 28 | return_tag_predict=True) 29 | 30 | return tag_predict[0], input_tag[0], caption[0] 31 | 32 | 33 | def inference_ram(image, model): 34 | 35 | with torch.no_grad(): 36 | tags, tags_chinese = model.generate_tag(image) 37 | 38 | return tags 39 | 40 | 41 | def inference_ram_openset(image, model): 42 | 43 | with torch.no_grad(): 44 | tags = model.generate_tag_openset(image) 45 | 46 | return tags[0] 47 | -------------------------------------------------------------------------------- /ram/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .ram import ram 2 | from .tag2text import tag2text 3 | -------------------------------------------------------------------------------- /ram/models/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/ram/models/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /ram/models/__pycache__/bert.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/ram/models/__pycache__/bert.cpython-310.pyc -------------------------------------------------------------------------------- /ram/models/__pycache__/bert_lora.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/ram/models/__pycache__/bert_lora.cpython-310.pyc -------------------------------------------------------------------------------- /ram/models/__pycache__/condition_network.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/ram/models/__pycache__/condition_network.cpython-310.pyc -------------------------------------------------------------------------------- /ram/models/__pycache__/ram.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/ram/models/__pycache__/ram.cpython-310.pyc -------------------------------------------------------------------------------- /ram/models/__pycache__/ram_condition.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/ram/models/__pycache__/ram_condition.cpython-310.pyc -------------------------------------------------------------------------------- /ram/models/__pycache__/ram_lora.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/ram/models/__pycache__/ram_lora.cpython-310.pyc -------------------------------------------------------------------------------- /ram/models/__pycache__/ram_swin_bert_lora.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/ram/models/__pycache__/ram_swin_bert_lora.cpython-310.pyc -------------------------------------------------------------------------------- /ram/models/__pycache__/ram_swin_lora.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/ram/models/__pycache__/ram_swin_lora.cpython-310.pyc -------------------------------------------------------------------------------- /ram/models/__pycache__/swin_transformer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/ram/models/__pycache__/swin_transformer.cpython-310.pyc -------------------------------------------------------------------------------- /ram/models/__pycache__/swin_transformer_lora.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/ram/models/__pycache__/swin_transformer_lora.cpython-310.pyc -------------------------------------------------------------------------------- /ram/models/__pycache__/tag2text.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/ram/models/__pycache__/tag2text.cpython-310.pyc -------------------------------------------------------------------------------- /ram/models/__pycache__/tag2text_lora.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/ram/models/__pycache__/tag2text_lora.cpython-310.pyc -------------------------------------------------------------------------------- /ram/models/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/ram/models/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /ram/models/__pycache__/vit.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/ram/models/__pycache__/vit.cpython-310.pyc -------------------------------------------------------------------------------- /ram/models/ram.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * The Recognize Anything Model (RAM) 3 | * Written by Xinyu Huang 4 | ''' 5 | import json 6 | import warnings 7 | 8 | import numpy as np 9 | import torch 10 | from torch import nn 11 | 12 | from .bert import BertConfig, BertLMHeadModel, BertModel 13 | from .swin_transformer import 
SwinTransformer 14 | from .utils import * 15 | 16 | warnings.filterwarnings("ignore") 17 | 18 | 19 | 20 | class RAM(nn.Module): 21 | def __init__(self, 22 | med_config=f'{CONFIG_PATH}/configs/med_config.json', 23 | image_size=384, 24 | vit='base', 25 | vit_grad_ckpt=False, 26 | vit_ckpt_layer=0, 27 | prompt='a picture of ', 28 | threshold=0.68, 29 | delete_tag_index=[], 30 | tag_list=f'{CONFIG_PATH}/data/ram_tag_list.txt', 31 | tag_list_chinese=f'{CONFIG_PATH}/data/ram_tag_list_chinese.txt'): 32 | r""" The Recognize Anything Model (RAM) inference module. 33 | RAM is a strong image tagging model, which can recognize any common category with high accuracy. 34 | Described in the paper " Recognize Anything: A Strong Image Tagging Model" https://recognize-anything.github.io/ 35 | 36 | Args: 37 | med_config (str): path for the mixture of encoder-decoder model's configuration file 38 | image_size (int): input image size 39 | vit (str): model size of vision transformer 40 | threshold (int): tagging threshold 41 | delete_tag_index (list): delete some tags that may disturb captioning 42 | """ 43 | super().__init__() 44 | 45 | # create image encoder 46 | if vit == 'swin_b': 47 | if image_size == 224: 48 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_224.json' 49 | elif image_size == 384: 50 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_384.json' 51 | vision_config = read_json(vision_config_path) 52 | assert image_size == vision_config['image_res'] 53 | # assert config['patch_size'] == 32 54 | vision_width = vision_config['vision_width'] 55 | 56 | self.visual_encoder = SwinTransformer( 57 | img_size=vision_config['image_res'], 58 | patch_size=4, 59 | in_chans=3, 60 | embed_dim=vision_config['embed_dim'], 61 | depths=vision_config['depths'], 62 | num_heads=vision_config['num_heads'], 63 | window_size=vision_config['window_size'], 64 | mlp_ratio=4., 65 | qkv_bias=True, 66 | drop_rate=0.0, 67 | drop_path_rate=0.1, 68 | ape=False, 69 | patch_norm=True, 70 | use_checkpoint=False) 71 | 72 | elif vit == 'swin_l': 73 | if image_size == 224: 74 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_224.json' 75 | elif image_size == 384: 76 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_384.json' 77 | elif image_size == 444: 78 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_444.json' 79 | vision_config = read_json(vision_config_path) 80 | assert image_size == vision_config['image_res'] 81 | # assert config['patch_size'] == 32 82 | vision_width = vision_config['vision_width'] 83 | 84 | self.visual_encoder = SwinTransformer( 85 | img_size=vision_config['image_res'], 86 | patch_size=4, 87 | in_chans=3, 88 | embed_dim=vision_config['embed_dim'], 89 | depths=vision_config['depths'], 90 | num_heads=vision_config['num_heads'], 91 | window_size=vision_config['window_size'], 92 | mlp_ratio=4., 93 | qkv_bias=True, 94 | drop_rate=0.0, 95 | drop_path_rate=0.1, 96 | ape=False, 97 | patch_norm=True, 98 | use_checkpoint=False) 99 | 100 | else: 101 | self.visual_encoder, vision_width = create_vit( 102 | vit, image_size, vit_grad_ckpt, vit_ckpt_layer) 103 | 104 | # create tokenzier 105 | self.tokenizer = init_tokenizer() 106 | 107 | # Tag2Text employ encoder-decoder architecture for image-tag-text generation: image-tag interaction encoder and image-tag-text decoder 108 | # create image-tag interaction encoder 109 | encoder_config = BertConfig.from_json_file(med_config) 110 | encoder_config.encoder_width = 512 111 | self.tag_encoder = 
BertModel(config=encoder_config, 112 | add_pooling_layer=False) 113 | 114 | # create image-tag-text decoder 115 | decoder_config = BertConfig.from_json_file(med_config) 116 | self.text_decoder = BertLMHeadModel(config=decoder_config) 117 | 118 | self.delete_tag_index = delete_tag_index 119 | self.prompt = prompt 120 | self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1 121 | 122 | # load tag list 123 | self.tag_list = self.load_tag_list(tag_list) 124 | self.tag_list_chinese = self.load_tag_list(tag_list_chinese) 125 | 126 | # create image-tag recognition decoder 127 | self.threshold = threshold 128 | self.num_class = len(self.tag_list) 129 | q2l_config = BertConfig.from_json_file(f'{CONFIG_PATH}/configs/q2l_config.json') 130 | q2l_config.encoder_width = 512 131 | self.tagging_head = BertModel(config=q2l_config, 132 | add_pooling_layer=False) 133 | self.tagging_head.resize_token_embeddings(len(self.tokenizer)) 134 | # self.label_embed = nn.Embedding(self.num_class, q2l_config.hidden_size) 135 | self.label_embed = nn.Parameter(torch.zeros(self.num_class, q2l_config.encoder_width)) 136 | 137 | if q2l_config.hidden_size != 512: 138 | self.wordvec_proj = nn.Linear(512, q2l_config.hidden_size) 139 | else: 140 | self.wordvec_proj = nn.Identity() 141 | 142 | self.fc = nn.Linear(q2l_config.hidden_size, 1) 143 | 144 | self.del_selfattention() 145 | 146 | # share weights of the lowest 2-layer of "image-tag interaction encoder" with the "image-tag recogntion decoder" 147 | tie_encoder_decoder_weights(self.tag_encoder, self.tagging_head, '', 148 | ' ') 149 | self.image_proj = nn.Linear(vision_width, 512) 150 | # self.label_embed = nn.Parameter(torch.load(f'{CONFIG_PATH}/data/textual_label_embedding.pth',map_location='cpu').float()) 151 | 152 | # adjust thresholds for some tags 153 | self.class_threshold = torch.ones(self.num_class) * self.threshold 154 | ram_class_threshold_path = f'{CONFIG_PATH}/data/ram_tag_list_threshold.txt' 155 | with open(ram_class_threshold_path, 'r', encoding='utf-8') as f: 156 | ram_class_threshold = [float(s.strip()) for s in f] 157 | for key,value in enumerate(ram_class_threshold): 158 | self.class_threshold[key] = value 159 | 160 | def load_tag_list(self, tag_list_file): 161 | with open(tag_list_file, 'r', encoding="utf-8") as f: 162 | tag_list = f.read().splitlines() 163 | tag_list = np.array(tag_list) 164 | return tag_list 165 | 166 | # delete self-attention layer of image-tag recognition decoder to reduce computation, follower Query2Label 167 | def del_selfattention(self): 168 | del self.tagging_head.embeddings 169 | for layer in self.tagging_head.encoder.layer: 170 | del layer.attention 171 | 172 | def condition_forward(self, 173 | image, 174 | threshold=0.68, 175 | condition_flag=None, 176 | tag_input=None, 177 | only_feature=True, 178 | ): 179 | 180 | label_embed = torch.nn.functional.relu(self.wordvec_proj(self.label_embed)) 181 | 182 | image_embeds = self.image_proj(self.visual_encoder(image)) 183 | if only_feature: 184 | return image_embeds 185 | else: 186 | image_atts = torch.ones(image_embeds.size()[:-1], 187 | dtype=torch.long).to(image.device) 188 | 189 | # recognized image tags using image-tag recogntiion decoder 190 | image_cls_embeds = image_embeds[:, 0, :] 191 | image_spatial_embeds = image_embeds[:, 1:, :] 192 | 193 | bs = image_spatial_embeds.shape[0] 194 | label_embed = label_embed.unsqueeze(0).repeat(bs, 1, 1) 195 | tagging_embed = self.tagging_head( 196 | encoder_embeds=label_embed, 197 | encoder_hidden_states=image_embeds, 198 | 
encoder_attention_mask=image_atts, 199 | return_dict=False, 200 | mode='tagging', 201 | ) 202 | 203 | logits = self.fc(tagging_embed[0]).squeeze(-1) 204 | 205 | targets = torch.where( 206 | torch.sigmoid(logits) > self.class_threshold.to(image.device), 207 | torch.tensor(1.0).to(image.device), 208 | torch.zeros(self.num_class).to(image.device)) 209 | 210 | return image_embeds, logits, targets 211 | 212 | def generate_tag(self, 213 | image, 214 | threshold=0.68, 215 | tag_input=None, 216 | ): 217 | 218 | label_embed = torch.nn.functional.relu(self.wordvec_proj(self.label_embed)) 219 | 220 | image_embeds = self.image_proj(self.visual_encoder(image)) 221 | image_atts = torch.ones(image_embeds.size()[:-1], 222 | dtype=torch.long).to(image.device) 223 | 224 | # recognized image tags using image-tag recogntiion decoder 225 | image_cls_embeds = image_embeds[:, 0, :] 226 | image_spatial_embeds = image_embeds[:, 1:, :] 227 | 228 | bs = image_spatial_embeds.shape[0] 229 | label_embed = label_embed.unsqueeze(0).repeat(bs, 1, 1) 230 | tagging_embed = self.tagging_head( 231 | encoder_embeds=label_embed, 232 | encoder_hidden_states=image_embeds, 233 | encoder_attention_mask=image_atts, 234 | return_dict=False, 235 | mode='tagging', 236 | ) 237 | 238 | logits = self.fc(tagging_embed[0]).squeeze(-1) 239 | 240 | targets = torch.where( 241 | torch.sigmoid(logits) > self.class_threshold.to(image.device), 242 | torch.tensor(1.0).to(image.device), 243 | torch.zeros(self.num_class).to(image.device)) 244 | 245 | tag = targets.cpu().numpy() 246 | tag[:,self.delete_tag_index] = 0 247 | tag_output = [] 248 | tag_output_chinese = [] 249 | for b in range(bs): 250 | index = np.argwhere(tag[b] == 1) 251 | token = self.tag_list[index].squeeze(axis=1) 252 | # tag_output.append(' | '.join(token)) 253 | tag_output.append(', '.join(token)) 254 | token_chinese = self.tag_list_chinese[index].squeeze(axis=1) 255 | # tag_output_chinese.append(' | '.join(token_chinese)) 256 | tag_output_chinese.append(', '.join(token_chinese)) 257 | 258 | 259 | return tag_output, tag_output_chinese 260 | 261 | def generate_tag_openset(self, 262 | image, 263 | threshold=0.68, 264 | tag_input=None, 265 | ): 266 | 267 | label_embed = torch.nn.functional.relu(self.wordvec_proj(self.label_embed)) 268 | 269 | image_embeds = self.image_proj(self.visual_encoder(image)) 270 | image_atts = torch.ones(image_embeds.size()[:-1], 271 | dtype=torch.long).to(image.device) 272 | 273 | # recognized image tags using image-tag recogntiion decoder 274 | image_cls_embeds = image_embeds[:, 0, :] 275 | image_spatial_embeds = image_embeds[:, 1:, :] 276 | 277 | bs = image_spatial_embeds.shape[0] 278 | label_embed = label_embed.unsqueeze(0).repeat(bs, 1, 1) 279 | tagging_embed = self.tagging_head( 280 | encoder_embeds=label_embed, 281 | encoder_hidden_states=image_embeds, 282 | encoder_attention_mask=image_atts, 283 | return_dict=False, 284 | mode='tagging', 285 | ) 286 | 287 | logits = self.fc(tagging_embed[0]).squeeze(-1) 288 | 289 | targets = torch.where( 290 | torch.sigmoid(logits) > self.class_threshold.to(image.device), 291 | torch.tensor(1.0).to(image.device), 292 | torch.zeros(self.num_class).to(image.device)) 293 | 294 | tag = targets.cpu().numpy() 295 | tag[:,self.delete_tag_index] = 0 296 | tag_output = [] 297 | for b in range(bs): 298 | index = np.argwhere(tag[b] == 1) 299 | token = self.tag_list[index].squeeze(axis=1) 300 | tag_output.append(' | '.join(token)) 301 | 302 | return tag_output 303 | 304 | 305 | # load RAM pretrained model parameters 306 | def 
ram(pretrained='', **kwargs): 307 | model = RAM(**kwargs) 308 | if pretrained: 309 | if kwargs['vit'] == 'swin_b': 310 | model, msg = load_checkpoint_swinbase(model, pretrained, kwargs) 311 | elif kwargs['vit'] == 'swin_l': 312 | model, msg = load_checkpoint_swinlarge(model, pretrained, kwargs) 313 | else: 314 | model, msg = load_checkpoint(model, pretrained) 315 | print('vit:', kwargs['vit']) 316 | # print('msg', msg) 317 | return model 318 | -------------------------------------------------------------------------------- /ram/models/ram_lora.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * The Recognize Anything Model (RAM) 3 | * Written by Xinyu Huang 4 | ''' 5 | import json 6 | import warnings 7 | 8 | import numpy as np 9 | import torch 10 | from torch import nn 11 | 12 | 13 | from .bert_lora import BertConfig, BertLMHeadModel, BertModel 14 | from .swin_transformer_lora import SwinTransformer 15 | from .utils import * 16 | 17 | warnings.filterwarnings("ignore") 18 | 19 | 20 | 21 | class RAMLora(nn.Module): 22 | def __init__(self, 23 | condition_config=f'{CONFIG_PATH}/configs/condition_config.json', 24 | med_config=f'{CONFIG_PATH}/configs/med_config.json', 25 | image_size=384, 26 | vit='base', 27 | vit_grad_ckpt=False, 28 | vit_ckpt_layer=0, 29 | prompt='a picture of ', 30 | threshold=0.68, 31 | max_threthold=0.9, 32 | add_threthold=0, 33 | delete_tag_index=[], 34 | tag_list=f'{CONFIG_PATH}/data/ram_tag_list.txt', 35 | tag_list_chinese=f'{CONFIG_PATH}/data/ram_tag_list_chinese.txt'): 36 | r""" The Recognize Anything Model (RAM) inference module. 37 | RAM is a strong image tagging model, which can recognize any common category with high accuracy. 38 | Described in the paper " Recognize Anything: A Strong Image Tagging Model" https://recognize-anything.github.io/ 39 | 40 | Args: 41 | med_config (str): path for the mixture of encoder-decoder model's configuration file 42 | image_size (int): input image size 43 | vit (str): model size of vision transformer 44 | threshold (int): tagging threshold 45 | delete_tag_index (list): delete some tags that may disturb captioning 46 | """ 47 | super().__init__() 48 | 49 | # create image encoder 50 | if vit == 'swin_b': 51 | if image_size == 224: 52 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_224.json' 53 | elif image_size == 384: 54 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_384.json' 55 | vision_config = read_json(vision_config_path) 56 | assert image_size == vision_config['image_res'] 57 | # assert config['patch_size'] == 32 58 | vision_width = vision_config['vision_width'] 59 | 60 | self.visual_encoder = SwinTransformer( 61 | img_size=vision_config['image_res'], 62 | patch_size=4, 63 | in_chans=3, 64 | embed_dim=vision_config['embed_dim'], 65 | depths=vision_config['depths'], 66 | num_heads=vision_config['num_heads'], 67 | window_size=vision_config['window_size'], 68 | mlp_ratio=4., 69 | qkv_bias=True, 70 | drop_rate=0.0, 71 | drop_path_rate=0.1, 72 | ape=False, 73 | patch_norm=True, 74 | use_checkpoint=False) 75 | 76 | elif vit == 'swin_l': 77 | if image_size == 224: 78 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_224.json' 79 | elif image_size == 384: 80 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_384.json' 81 | elif image_size == 444: 82 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_444.json' 83 | vision_config = read_json(vision_config_path) 84 | assert image_size == vision_config['image_res'] 
85 | # assert config['patch_size'] == 32 86 | vision_width = vision_config['vision_width'] 87 | 88 | self.visual_encoder = SwinTransformer( 89 | img_size=vision_config['image_res'], 90 | patch_size=4, 91 | in_chans=3, 92 | embed_dim=vision_config['embed_dim'], 93 | depths=vision_config['depths'], 94 | num_heads=vision_config['num_heads'], 95 | window_size=vision_config['window_size'], 96 | mlp_ratio=4., 97 | qkv_bias=True, 98 | drop_rate=0.0, 99 | drop_path_rate=0.1, 100 | ape=False, 101 | patch_norm=True, 102 | use_checkpoint=False) 103 | 104 | else: 105 | self.visual_encoder, vision_width = create_vit( 106 | vit, image_size, vit_grad_ckpt, vit_ckpt_layer) 107 | 108 | # create tokenzier 109 | self.tokenizer = init_tokenizer() 110 | 111 | # Tag2Text employ encoder-decoder architecture for image-tag-text generation: image-tag interaction encoder and image-tag-text decoder 112 | # create image-tag interaction encoder 113 | encoder_config = BertConfig.from_json_file(med_config) 114 | encoder_config.encoder_width = 512 115 | self.tag_encoder = BertModel(config=encoder_config, 116 | add_pooling_layer=False) 117 | 118 | # create image-tag-text decoder 119 | decoder_config = BertConfig.from_json_file(med_config) 120 | self.text_decoder = BertLMHeadModel(config=decoder_config) 121 | 122 | self.delete_tag_index = delete_tag_index 123 | self.prompt = prompt 124 | self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1 125 | 126 | # load tag list 127 | self.tag_list = self.load_tag_list(tag_list) 128 | self.tag_list_chinese = self.load_tag_list(tag_list_chinese) 129 | 130 | # create image-tag recognition decoder 131 | self.threshold = threshold 132 | self.num_class = len(self.tag_list) 133 | q2l_config = BertConfig.from_json_file(f'{CONFIG_PATH}/configs/q2l_config.json') 134 | q2l_config.encoder_width = 512 135 | self.tagging_head = BertModel(config=q2l_config, 136 | add_pooling_layer=False) 137 | self.tagging_head.resize_token_embeddings(len(self.tokenizer)) 138 | # self.label_embed = nn.Embedding(self.num_class, q2l_config.hidden_size) 139 | self.label_embed = nn.Parameter(torch.zeros(self.num_class, q2l_config.encoder_width)) 140 | 141 | if q2l_config.hidden_size != 512: 142 | self.wordvec_proj = nn.Linear(512, q2l_config.hidden_size) 143 | else: 144 | self.wordvec_proj = nn.Identity() 145 | 146 | self.fc = nn.Linear(q2l_config.hidden_size, 1) 147 | 148 | self.del_selfattention() 149 | 150 | # share weights of the lowest 2-layer of "image-tag interaction encoder" with the "image-tag recogntion decoder" 151 | tie_encoder_decoder_weights(self.tag_encoder, self.tagging_head, '', 152 | ' ') 153 | self.image_proj = nn.Linear(vision_width, 512) 154 | # self.label_embed = nn.Parameter(torch.load(f'{CONFIG_PATH}/data/textual_label_embedding.pth',map_location='cpu').float()) 155 | 156 | # adjust thresholds for some tags 157 | self.class_threshold = torch.ones(self.num_class) * self.threshold 158 | 159 | print(f'Loading default thretholds from .txt....') 160 | ram_class_threshold_path = f'{CONFIG_PATH}/data/ram_tag_list_threshold.txt' 161 | with open(ram_class_threshold_path, 'r', encoding='utf-8') as f: 162 | ram_class_threshold = [float(s.strip()) for s in f] 163 | for key,value in enumerate(ram_class_threshold): 164 | if value > max_threthold: 165 | self.class_threshold[key] = value 166 | else: 167 | self.class_threshold[key] = min(value + add_threthold, max_threthold) 168 | 169 | 170 | 171 | def load_tag_list(self, tag_list_file): 172 | with open(tag_list_file, 'r', encoding="utf-8") as f: 
173 | tag_list = f.read().splitlines() 174 | tag_list = np.array(tag_list) 175 | return tag_list 176 | 177 | # delete self-attention layer of image-tag recognition decoder to reduce computation, follower Query2Label 178 | def del_selfattention(self): 179 | del self.tagging_head.embeddings 180 | for layer in self.tagging_head.encoder.layer: 181 | del layer.attention 182 | 183 | def generate_image_embeds(self, 184 | image, 185 | condition=False 186 | ): 187 | 188 | image_embeds = self.image_proj(self.visual_encoder(image)) 189 | 190 | return image_embeds 191 | 192 | def generate_tag(self, 193 | image, 194 | threshold=0.68, 195 | tag_input=None, 196 | ): 197 | 198 | label_embed = torch.nn.functional.relu(self.wordvec_proj(self.label_embed)) 199 | 200 | image_embeds = self.image_proj(self.visual_encoder(image)) 201 | 202 | image_atts = torch.ones(image_embeds.size()[:-1], 203 | dtype=torch.long).to(image.device) 204 | 205 | # recognized image tags using image-tag recogntiion decoder 206 | image_cls_embeds = image_embeds[:, 0, :] 207 | image_spatial_embeds = image_embeds[:, 1:, :] 208 | 209 | bs = image_spatial_embeds.shape[0] 210 | label_embed = label_embed.unsqueeze(0).repeat(bs, 1, 1) 211 | tagging_embed = self.tagging_head( 212 | encoder_embeds=label_embed, 213 | encoder_hidden_states=image_embeds, 214 | encoder_attention_mask=image_atts, 215 | return_dict=False, 216 | mode='tagging', 217 | ) 218 | 219 | logits = self.fc(tagging_embed[0]).squeeze(-1) 220 | targets = torch.where( 221 | torch.sigmoid(logits) > self.class_threshold.to(image.device), 222 | torch.tensor(1.0).to(image.device), 223 | torch.zeros(self.num_class).to(image.device)) 224 | 225 | tag = targets.cpu().numpy() 226 | tag[:,self.delete_tag_index] = 0 227 | tag_output = [] 228 | tag_output_chinese = [] 229 | for b in range(bs): 230 | index = np.argwhere(tag[b] == 1) 231 | token = self.tag_list[index].squeeze(axis=1) 232 | # tag_output.append(' | '.join(token)) 233 | tag_output.append(', '.join(token)) 234 | token_chinese = self.tag_list_chinese[index].squeeze(axis=1) 235 | # tag_output_chinese.append(' | '.join(token_chinese)) 236 | tag_output_chinese.append(', '.join(token_chinese)) 237 | 238 | 239 | return tag_output, tag_output_chinese 240 | 241 | 242 | 243 | def condition_forward(self, 244 | image, 245 | threshold=0.68, 246 | condition_flag=None, 247 | tag_input=None, 248 | only_feature=True 249 | ): 250 | 251 | label_embed = torch.nn.functional.relu(self.wordvec_proj(self.label_embed)) 252 | image_embeds = self.image_proj(self.visual_encoder(image)) 253 | 254 | if only_feature: 255 | return image_embeds 256 | else: 257 | image_atts = torch.ones(image_embeds.size()[:-1], 258 | dtype=torch.long).to(image.device) 259 | 260 | # recognized image tags using image-tag recogntiion decoder 261 | image_cls_embeds = image_embeds[:, 0, :] 262 | image_spatial_embeds = image_embeds[:, 1:, :] 263 | 264 | bs = image_spatial_embeds.shape[0] 265 | label_embed = label_embed.unsqueeze(0).repeat(bs, 1, 1) 266 | tagging_embed = self.tagging_head( 267 | encoder_embeds=label_embed, 268 | encoder_hidden_states=image_embeds, 269 | encoder_attention_mask=image_atts, 270 | return_dict=False, 271 | mode='tagging', 272 | ) 273 | 274 | logits = self.fc(tagging_embed[0]).squeeze(-1) 275 | 276 | targets = torch.where( 277 | torch.sigmoid(logits) > self.class_threshold.to(image.device), 278 | torch.tensor(1.0).to(image.device), 279 | torch.zeros(self.num_class).to(image.device)) 280 | 281 | return image_embeds, logits, targets 282 | 283 | def 
generate_tag_openset(self, 284 | image, 285 | threshold=0.68, 286 | tag_input=None, 287 | ): 288 | 289 | label_embed = torch.nn.functional.relu(self.wordvec_proj(self.label_embed)) 290 | 291 | image_embeds = self.image_proj(self.visual_encoder(image)) 292 | image_atts = torch.ones(image_embeds.size()[:-1], 293 | dtype=torch.long).to(image.device) 294 | 295 | # recognized image tags using image-tag recogntiion decoder 296 | image_cls_embeds = image_embeds[:, 0, :] 297 | image_spatial_embeds = image_embeds[:, 1:, :] 298 | 299 | bs = image_spatial_embeds.shape[0] 300 | label_embed = label_embed.unsqueeze(0).repeat(bs, 1, 1) 301 | tagging_embed = self.tagging_head( 302 | encoder_embeds=label_embed, 303 | encoder_hidden_states=image_embeds, 304 | encoder_attention_mask=image_atts, 305 | return_dict=False, 306 | mode='tagging', 307 | ) 308 | 309 | logits = self.fc(tagging_embed[0]).squeeze(-1) 310 | 311 | targets = torch.where( 312 | torch.sigmoid(logits) > self.class_threshold.to(image.device), 313 | torch.tensor(1.0).to(image.device), 314 | torch.zeros(self.num_class).to(image.device)) 315 | 316 | tag = targets.cpu().numpy() 317 | tag[:,self.delete_tag_index] = 0 318 | tag_output = [] 319 | for b in range(bs): 320 | index = np.argwhere(tag[b] == 1) 321 | token = self.tag_list[index].squeeze(axis=1) 322 | tag_output.append(' | '.join(token)) 323 | 324 | return tag_output 325 | 326 | 327 | # load RAM pretrained model parameters 328 | def ram(pretrained='', pretrained_condition='', **kwargs): 329 | model = RAMLora(**kwargs) 330 | 331 | if pretrained: 332 | if kwargs['vit'] == 'swin_b': 333 | model, msg = load_checkpoint_swinbase(model, pretrained, kwargs) 334 | elif kwargs['vit'] == 'swin_l': 335 | model, msg = load_checkpoint_swinlarge(model, pretrained, kwargs) 336 | else: 337 | model, msg = load_checkpoint(model, pretrained) 338 | print('vit:', kwargs['vit']) 339 | 340 | if pretrained_condition: 341 | model.load_state_dict(torch.load(pretrained_condition), strict=False) 342 | print(f'load lora from {pretrained_condition}') 343 | 344 | return model 345 | -------------------------------------------------------------------------------- /ram/models/tag2text.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * The Tag2Text Model 3 | * Written by Xinyu Huang 4 | ''' 5 | import numpy as np 6 | import json 7 | import torch 8 | import warnings 9 | 10 | from torch import nn 11 | from .bert import BertConfig, BertModel, BertLMHeadModel 12 | from .swin_transformer import SwinTransformer 13 | 14 | from .utils import * 15 | 16 | warnings.filterwarnings("ignore") 17 | 18 | 19 | class Tag2Text(nn.Module): 20 | 21 | def __init__(self, 22 | med_config=f'{CONFIG_PATH}/configs/med_config.json', 23 | image_size=384, 24 | vit='base', 25 | vit_grad_ckpt=False, 26 | vit_ckpt_layer=0, 27 | prompt='a picture of ', 28 | threshold=0.68, 29 | delete_tag_index=[127,2961, 3351, 3265, 3338, 3355, 3359], 30 | tag_list=f'{CONFIG_PATH}/data/tag_list.txt'): 31 | r""" Tag2Text inference module, both captioning and tagging are included. 32 | Tag2Text is an efficient and controllable vision-language pre-training framework. 
33 | Described in the paper "Tag2Text: Guiding Vision-Language Model via Image Tagging" https://arxiv.org/abs/2303.05657 34 | 35 | Args: 36 | med_config (str): path for the mixture of encoder-decoder model's configuration file 37 | image_size (int): input image size 38 | vit (str): model size of vision transformer 39 | threshold (int): tagging threshold 40 | delete_tag_index (list): delete some tags that may disturb captioning 41 | """ 42 | super().__init__() 43 | 44 | # create image encoder 45 | if vit == 'swin_b': 46 | if image_size == 224: 47 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_224.json' 48 | elif image_size == 384: 49 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_384.json' 50 | vision_config = read_json(vision_config_path) 51 | assert image_size == vision_config['image_res'] 52 | # assert config['patch_size'] == 32 53 | vision_width = vision_config['vision_width'] 54 | 55 | self.visual_encoder = SwinTransformer( 56 | img_size=vision_config['image_res'], 57 | patch_size=4, 58 | in_chans=3, 59 | embed_dim=vision_config['embed_dim'], 60 | depths=vision_config['depths'], 61 | num_heads=vision_config['num_heads'], 62 | window_size=vision_config['window_size'], 63 | mlp_ratio=4., 64 | qkv_bias=True, 65 | drop_rate=0.0, 66 | drop_path_rate=0.1, 67 | ape=False, 68 | patch_norm=True, 69 | use_checkpoint=False) 70 | 71 | else: 72 | self.visual_encoder, vision_width = create_vit( 73 | vit, image_size, vit_grad_ckpt, vit_ckpt_layer) 74 | 75 | # create tokenzier 76 | self.tokenizer = init_tokenizer() 77 | 78 | # Tag2Text employ encoder-decoder architecture for image-tag-text generation: image-tag interaction encoder and image-tag-text decoder 79 | # create image-tag interaction encoder 80 | encoder_config = BertConfig.from_json_file(med_config) 81 | encoder_config.encoder_width = vision_width 82 | self.tag_encoder = BertModel(config=encoder_config, 83 | add_pooling_layer=False) 84 | 85 | # create image-tag-text decoder 86 | decoder_config = BertConfig.from_json_file(med_config) 87 | self.text_decoder = BertLMHeadModel(config=decoder_config) 88 | 89 | # delete some tags that may disturb captioning 90 | # 127: "quarter"; 2961: "back"; 3351: "two"; 3265: "three"; 3338: "four"; 3355: "five"; 3359: "one" 91 | self.delete_tag_index = delete_tag_index 92 | self.prompt = prompt 93 | self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1 94 | 95 | # load tag list 96 | self.tag_list = self.load_tag_list(tag_list) 97 | 98 | # create image-tag recognition decoder 99 | self.threshold = threshold 100 | self.num_class = len(self.tag_list) 101 | q2l_config = BertConfig.from_json_file(f'{CONFIG_PATH}/configs/q2l_config.json') 102 | q2l_config.encoder_width = vision_width 103 | self.tagging_head = BertModel(config=q2l_config, 104 | add_pooling_layer=False) 105 | self.tagging_head.resize_token_embeddings(len(self.tokenizer)) 106 | self.label_embed = nn.Embedding(self.num_class, q2l_config.hidden_size) 107 | self.fc = GroupWiseLinear(self.num_class, 108 | q2l_config.hidden_size, 109 | bias=True) 110 | self.del_selfattention() 111 | 112 | self.tagging_loss_function = AsymmetricLoss(gamma_neg=7, 113 | gamma_pos=0, 114 | clip=0.05) 115 | 116 | # share weights of the lowest 2-layer of "image-tag interaction encoder" with the "image-tag recogntion decoder" 117 | tie_encoder_decoder_weights(self.tag_encoder, self.tagging_head, '', 118 | ' ') 119 | 120 | # adjust thresholds for some tags 121 | # default threshold: 0.68 122 | # 2701: "person"; 2828: "man"; 1167: 
"woman"; 123 | tag_thrshold = {2701:0.7, 2828: 0.7, 1167: 0.7} 124 | self.class_threshold = torch.ones(self.num_class) * self.threshold 125 | for key,value in tag_thrshold.items(): 126 | self.class_threshold[key] = value 127 | 128 | def load_tag_list(self, tag_list_file): 129 | with open(tag_list_file, 'r') as f: 130 | tag_list = f.read().splitlines() 131 | tag_list = np.array(tag_list) 132 | return tag_list 133 | 134 | # delete self-attention layer of image-tag recognition decoder to reduce computation, follower Query2Label 135 | def del_selfattention(self): 136 | del self.tagging_head.embeddings 137 | for layer in self.tagging_head.encoder.layer: 138 | del layer.attention 139 | 140 | 141 | def forward(self, image, caption, tag): 142 | """ 143 | call function as forward 144 | 145 | Args: 146 | image: type: torch.Tensor shape: batch_size * 3 * 384 * 384 147 | caption: type: list[string] len: batch_size 148 | tag: type: torch.Tensor shape: batch * class_num (e.g. 3429) value: positive sample is 1.0, negative sample is 0.0 149 | 150 | Returns: 151 | loss: type: torch.Tensor 152 | """ 153 | 154 | image_embeds = self.visual_encoder(image) 155 | image_atts = torch.ones(image_embeds.size()[:-1], 156 | dtype=torch.long).to(image.device) 157 | 158 | ##================= Image Tagging ================## 159 | bs = image_embeds.shape[0] 160 | label_embed = self.label_embed.weight.unsqueeze(0).repeat(bs, 1, 1) 161 | 162 | tagging_embed = self.tagging_head( 163 | encoder_embeds=label_embed, 164 | encoder_hidden_states=image_embeds, 165 | encoder_attention_mask=image_atts, 166 | return_dict=False, 167 | mode='tagging', 168 | ) 169 | 170 | logits = self.fc(tagging_embed[0]) 171 | 172 | loss_tag = self.tagging_loss_function(logits, tag) 173 | 174 | ##================= Image-Tag-Text Generation ================## 175 | tag = tag.cpu().numpy() 176 | tag_input = [] 177 | for b in range(bs): 178 | index = np.argwhere(tag[b] == 1) 179 | token = self.tag_list[index].squeeze(axis=1) 180 | tag_input.append(' | '.join(token)) 181 | 182 | # tokenizer input tags 183 | tag_input_tokenzier = self.tokenizer(tag_input, 184 | padding='max_length', 185 | truncation=True, 186 | max_length=40, 187 | return_tensors="pt").to( 188 | image.device) 189 | encoder_input_ids = tag_input_tokenzier.input_ids 190 | encoder_input_ids[:, 0] = self.tokenizer.enc_token_id 191 | 192 | # put input tag into image-tag interaction encoder to interact with image embeddings 193 | output_tagembedding = self.tag_encoder( 194 | encoder_input_ids, 195 | attention_mask=tag_input_tokenzier.attention_mask, 196 | encoder_hidden_states=image_embeds, 197 | encoder_attention_mask=image_atts, 198 | return_dict=True, 199 | ) 200 | 201 | text = self.tokenizer(caption, 202 | padding='longest', 203 | truncation=True, 204 | max_length=40, 205 | return_tensors="pt").to( 206 | image.device) 207 | 208 | decoder_input_ids = text.input_ids 209 | decoder_input_ids[:,0] = self.tokenizer.bos_token_id 210 | 211 | decoder_targets = decoder_input_ids.masked_fill( 212 | decoder_input_ids == self.tokenizer.pad_token_id, -100) 213 | decoder_targets[:,:self.prompt_length] = -100 214 | 215 | decoder_output = self.text_decoder(decoder_input_ids, 216 | attention_mask = text.attention_mask, 217 | encoder_hidden_states = output_tagembedding.last_hidden_state, 218 | encoder_attention_mask = None, 219 | labels = decoder_targets, 220 | return_dict = True, 221 | ) 222 | 223 | loss_t2t = decoder_output.loss 224 | 225 | # balance loss scale 226 | loss = loss_t2t + 
loss_tag/(loss_tag/loss_t2t).detach() 227 | 228 | return loss 229 | 230 | def generate_image_embeds(self, 231 | image, 232 | condition=False 233 | ): 234 | 235 | image_embeds = self.visual_encoder(image) 236 | 237 | return image_embeds 238 | 239 | def condition_forward(self, 240 | image, 241 | sample=False, 242 | num_beams=3, 243 | max_length=30, 244 | min_length=10, 245 | top_p=0.9, 246 | repetition_penalty=1.0, 247 | tag_input=None, 248 | return_tag_predict=False): 249 | 250 | image_embeds = self.visual_encoder(image) 251 | image_atts = torch.ones(image_embeds.size()[:-1], 252 | dtype=torch.long).to(image.device) 253 | 254 | # if not user specified tags, recognized image tags using image-tag recogntiion decoder 255 | 256 | 257 | bs = image_embeds.shape[0] 258 | label_embed = self.label_embed.weight.unsqueeze(0).repeat(bs, 1, 1) 259 | tagging_embed = self.tagging_head( 260 | encoder_embeds=label_embed, 261 | encoder_hidden_states=image_embeds, 262 | encoder_attention_mask=image_atts, 263 | return_dict=False, 264 | mode='tagging', 265 | ) 266 | 267 | logits = self.fc(tagging_embed[0]) 268 | 269 | targets = torch.where( 270 | torch.sigmoid(logits) > self.class_threshold.to(image.device), 271 | torch.tensor(1.0).to(image.device), 272 | torch.zeros(self.num_class).to(image.device)) 273 | 274 | # delete some tags that may disturb captioning 275 | targets[:, self.delete_tag_index] = 0 276 | 277 | return image_embeds, logits, targets 278 | 279 | 280 | def generate(self, 281 | image, 282 | sample=False, 283 | num_beams=3, 284 | max_length=30, 285 | min_length=10, 286 | top_p=0.9, 287 | repetition_penalty=1.0, 288 | tag_input=None, 289 | return_tag_predict=False): 290 | 291 | image_embeds = self.visual_encoder(image) 292 | image_atts = torch.ones(image_embeds.size()[:-1], 293 | dtype=torch.long).to(image.device) 294 | 295 | # if not user specified tags, recognized image tags using image-tag recogntiion decoder 296 | if tag_input == None: 297 | 298 | bs = image_embeds.shape[0] 299 | label_embed = self.label_embed.weight.unsqueeze(0).repeat(bs, 1, 1) 300 | tagging_embed = self.tagging_head( 301 | encoder_embeds=label_embed, 302 | encoder_hidden_states=image_embeds, 303 | encoder_attention_mask=image_atts, 304 | return_dict=False, 305 | mode='tagging', 306 | ) 307 | 308 | logits = self.fc(tagging_embed[0]) 309 | 310 | targets = torch.where( 311 | torch.sigmoid(logits) > self.class_threshold.to(image.device), 312 | torch.tensor(1.0).to(image.device), 313 | torch.zeros(self.num_class).to(image.device)) 314 | 315 | tag = targets.cpu().numpy() 316 | 317 | # delete some tags that may disturb captioning 318 | tag[:, self.delete_tag_index] = 0 319 | 320 | tag_input = [] 321 | for b in range(bs): 322 | index = np.argwhere(tag[b] == 1) 323 | token = self.tag_list[index].squeeze(axis=1) 324 | tag_input.append(', '.join(token)) 325 | 326 | tag_output = tag_input 327 | 328 | # beam search for text generation(default) 329 | if not sample: 330 | image_embeds = image_embeds.repeat_interleave(num_beams, dim=0) 331 | tag_input_temp = [] 332 | for tag in tag_input: 333 | for i in range(num_beams): 334 | tag_input_temp.append(tag) 335 | tag_input = tag_input_temp 336 | 337 | image_atts = torch.ones(image_embeds.size()[:-1], 338 | dtype=torch.long).to(image.device) 339 | 340 | # tokenizer input tags 341 | tag_input_tokenzier = self.tokenizer(tag_input, 342 | padding='max_length', 343 | truncation=True, 344 | max_length=40, 345 | return_tensors="pt").to( 346 | image.device) 347 | encoder_input_ids = 
tag_input_tokenzier.input_ids 348 | encoder_input_ids[:, 0] = self.tokenizer.enc_token_id 349 | 350 | # put input tag into image-tag interaction encoder to interact with image embeddings 351 | output_tagembedding = self.tag_encoder( 352 | encoder_input_ids, 353 | attention_mask=tag_input_tokenzier.attention_mask, 354 | encoder_hidden_states=image_embeds, 355 | encoder_attention_mask=image_atts, 356 | return_dict=True, 357 | ) 358 | 359 | # prompt trick for better captioning, followed BLIP 360 | prompt = [self.prompt] * image.size(0) 361 | input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to( 362 | image.device) 363 | input_ids[:, 0] = self.tokenizer.bos_token_id 364 | input_ids = input_ids[:, :-1] 365 | 366 | if sample: 367 | # nucleus sampling 368 | model_kwargs = { 369 | "encoder_hidden_states": output_tagembedding.last_hidden_state, 370 | "encoder_attention_mask": None 371 | } 372 | outputs = self.text_decoder.generate( 373 | input_ids=input_ids, 374 | max_length=max_length, 375 | min_length=min_length, 376 | do_sample=True, 377 | top_p=top_p, 378 | num_return_sequences=1, 379 | eos_token_id=self.tokenizer.sep_token_id, 380 | pad_token_id=self.tokenizer.pad_token_id, 381 | repetition_penalty=1.1, 382 | **model_kwargs) 383 | else: 384 | # beam search (default) 385 | model_kwargs = { 386 | "encoder_hidden_states": output_tagembedding.last_hidden_state, 387 | "encoder_attention_mask": None 388 | } 389 | outputs = self.text_decoder.generate( 390 | input_ids=input_ids, 391 | max_length=max_length, 392 | min_length=min_length, 393 | num_beams=num_beams, 394 | eos_token_id=self.tokenizer.sep_token_id, 395 | pad_token_id=self.tokenizer.pad_token_id, 396 | repetition_penalty=repetition_penalty, 397 | **model_kwargs) 398 | 399 | captions = [] 400 | for output in outputs: 401 | caption = self.tokenizer.decode(output, skip_special_tokens=True) 402 | captions.append(caption[len(self.prompt):]) 403 | if return_tag_predict == True: 404 | return captions, tag_output 405 | return captions 406 | 407 | 408 | # load Tag2Text pretrained model parameters 409 | def tag2text(pretrained='', **kwargs): 410 | model = Tag2Text(**kwargs) 411 | if pretrained: 412 | if kwargs['vit'] == 'swin_b': 413 | model, msg = load_checkpoint_swinbase(model, pretrained, kwargs) 414 | else: 415 | model, msg = load_checkpoint(model, pretrained) 416 | print('vit:', kwargs['vit']) 417 | # print('msg', msg) 418 | return model 419 | 420 | -------------------------------------------------------------------------------- /ram/models/tag2text_lora.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * The Tag2Text Model 3 | * Written by Xinyu Huang 4 | ''' 5 | import numpy as np 6 | import json 7 | import torch 8 | import warnings 9 | 10 | from torch import nn 11 | from .bert_lora import BertConfig, BertModel, BertLMHeadModel 12 | from .swin_transformer_lora import SwinTransformer 13 | 14 | from .utils import * 15 | 16 | warnings.filterwarnings("ignore") 17 | 18 | 19 | class Tag2Text(nn.Module): 20 | 21 | def __init__(self, 22 | med_config=f'{CONFIG_PATH}/configs/med_config.json', 23 | image_size=384, 24 | vit='base', 25 | vit_grad_ckpt=False, 26 | vit_ckpt_layer=0, 27 | prompt='a picture of ', 28 | threshold=0.68, 29 | delete_tag_index=[127,2961, 3351, 3265, 3338, 3355, 3359], 30 | tag_list=f'{CONFIG_PATH}/data/tag_list.txt'): 31 | r""" Tag2Text inference module, both captioning and tagging are included. 
32 | Tag2Text is an efficient and controllable vision-language pre-training framework. 33 | Described in the paper "Tag2Text: Guiding Vision-Language Model via Image Tagging" https://arxiv.org/abs/2303.05657 34 | 35 | Args: 36 | med_config (str): path for the mixture of encoder-decoder model's configuration file 37 | image_size (int): input image size 38 | vit (str): model size of vision transformer 39 | threshold (int): tagging threshold 40 | delete_tag_index (list): delete some tags that may disturb captioning 41 | """ 42 | super().__init__() 43 | 44 | # create image encoder 45 | if vit == 'swin_b': 46 | if image_size == 224: 47 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_224.json' 48 | elif image_size == 384: 49 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_384.json' 50 | vision_config = read_json(vision_config_path) 51 | assert image_size == vision_config['image_res'] 52 | # assert config['patch_size'] == 32 53 | vision_width = vision_config['vision_width'] 54 | 55 | self.visual_encoder = SwinTransformer( 56 | img_size=vision_config['image_res'], 57 | patch_size=4, 58 | in_chans=3, 59 | embed_dim=vision_config['embed_dim'], 60 | depths=vision_config['depths'], 61 | num_heads=vision_config['num_heads'], 62 | window_size=vision_config['window_size'], 63 | mlp_ratio=4., 64 | qkv_bias=True, 65 | drop_rate=0.0, 66 | drop_path_rate=0.1, 67 | ape=False, 68 | patch_norm=True, 69 | use_checkpoint=False) 70 | 71 | else: 72 | self.visual_encoder, vision_width = create_vit( 73 | vit, image_size, vit_grad_ckpt, vit_ckpt_layer) 74 | 75 | # create tokenzier 76 | self.tokenizer = init_tokenizer() 77 | 78 | # Tag2Text employ encoder-decoder architecture for image-tag-text generation: image-tag interaction encoder and image-tag-text decoder 79 | # create image-tag interaction encoder 80 | encoder_config = BertConfig.from_json_file(med_config) 81 | encoder_config.encoder_width = vision_width 82 | self.tag_encoder = BertModel(config=encoder_config, 83 | add_pooling_layer=False) 84 | 85 | # create image-tag-text decoder 86 | decoder_config = BertConfig.from_json_file(med_config) 87 | self.text_decoder = BertLMHeadModel(config=decoder_config) 88 | 89 | # delete some tags that may disturb captioning 90 | # 127: "quarter"; 2961: "back"; 3351: "two"; 3265: "three"; 3338: "four"; 3355: "five"; 3359: "one" 91 | self.delete_tag_index = delete_tag_index 92 | self.prompt = prompt 93 | self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1 94 | 95 | # load tag list 96 | self.tag_list = self.load_tag_list(tag_list) 97 | 98 | # create image-tag recognition decoder 99 | self.threshold = threshold 100 | self.num_class = len(self.tag_list) 101 | q2l_config = BertConfig.from_json_file(f'{CONFIG_PATH}/configs/q2l_config.json') 102 | q2l_config.encoder_width = vision_width 103 | self.tagging_head = BertModel(config=q2l_config, 104 | add_pooling_layer=False) 105 | self.tagging_head.resize_token_embeddings(len(self.tokenizer)) 106 | self.label_embed = nn.Embedding(self.num_class, q2l_config.hidden_size) 107 | self.fc = GroupWiseLinear(self.num_class, 108 | q2l_config.hidden_size, 109 | bias=True) 110 | self.del_selfattention() 111 | 112 | self.tagging_loss_function = AsymmetricLoss(gamma_neg=7, 113 | gamma_pos=0, 114 | clip=0.05) 115 | 116 | # share weights of the lowest 2-layer of "image-tag interaction encoder" with the "image-tag recogntion decoder" 117 | tie_encoder_decoder_weights(self.tag_encoder, self.tagging_head, '', 118 | ' ') 119 | 120 | # adjust thresholds for 
some tags 121 | # default threshold: 0.68 122 | # 2701: "person"; 2828: "man"; 1167: "woman"; 123 | tag_thrshold = {2701:0.7, 2828: 0.7, 1167: 0.7} 124 | self.class_threshold = torch.ones(self.num_class) * self.threshold 125 | for key,value in tag_thrshold.items(): 126 | self.class_threshold[key] = value 127 | 128 | def load_tag_list(self, tag_list_file): 129 | with open(tag_list_file, 'r') as f: 130 | tag_list = f.read().splitlines() 131 | tag_list = np.array(tag_list) 132 | return tag_list 133 | 134 | # delete self-attention layer of image-tag recognition decoder to reduce computation, follower Query2Label 135 | def del_selfattention(self): 136 | del self.tagging_head.embeddings 137 | for layer in self.tagging_head.encoder.layer: 138 | del layer.attention 139 | 140 | 141 | def forward(self, image, caption, tag): 142 | """ 143 | call function as forward 144 | 145 | Args: 146 | image: type: torch.Tensor shape: batch_size * 3 * 384 * 384 147 | caption: type: list[string] len: batch_size 148 | tag: type: torch.Tensor shape: batch * class_num (e.g. 3429) value: positive sample is 1.0, negative sample is 0.0 149 | 150 | Returns: 151 | loss: type: torch.Tensor 152 | """ 153 | 154 | image_embeds = self.visual_encoder(image) 155 | image_atts = torch.ones(image_embeds.size()[:-1], 156 | dtype=torch.long).to(image.device) 157 | 158 | ##================= Image Tagging ================## 159 | bs = image_embeds.shape[0] 160 | label_embed = self.label_embed.weight.unsqueeze(0).repeat(bs, 1, 1) 161 | 162 | tagging_embed = self.tagging_head( 163 | encoder_embeds=label_embed, 164 | encoder_hidden_states=image_embeds, 165 | encoder_attention_mask=image_atts, 166 | return_dict=False, 167 | mode='tagging', 168 | ) 169 | 170 | logits = self.fc(tagging_embed[0]) 171 | 172 | loss_tag = self.tagging_loss_function(logits, tag) 173 | 174 | ##================= Image-Tag-Text Generation ================## 175 | tag = tag.cpu().numpy() 176 | tag_input = [] 177 | for b in range(bs): 178 | index = np.argwhere(tag[b] == 1) 179 | token = self.tag_list[index].squeeze(axis=1) 180 | tag_input.append(' | '.join(token)) 181 | 182 | # tokenizer input tags 183 | tag_input_tokenzier = self.tokenizer(tag_input, 184 | padding='max_length', 185 | truncation=True, 186 | max_length=40, 187 | return_tensors="pt").to( 188 | image.device) 189 | encoder_input_ids = tag_input_tokenzier.input_ids 190 | encoder_input_ids[:, 0] = self.tokenizer.enc_token_id 191 | 192 | # put input tag into image-tag interaction encoder to interact with image embeddings 193 | output_tagembedding = self.tag_encoder( 194 | encoder_input_ids, 195 | attention_mask=tag_input_tokenzier.attention_mask, 196 | encoder_hidden_states=image_embeds, 197 | encoder_attention_mask=image_atts, 198 | return_dict=True, 199 | ) 200 | 201 | text = self.tokenizer(caption, 202 | padding='longest', 203 | truncation=True, 204 | max_length=40, 205 | return_tensors="pt").to( 206 | image.device) 207 | 208 | decoder_input_ids = text.input_ids 209 | decoder_input_ids[:,0] = self.tokenizer.bos_token_id 210 | 211 | decoder_targets = decoder_input_ids.masked_fill( 212 | decoder_input_ids == self.tokenizer.pad_token_id, -100) 213 | decoder_targets[:,:self.prompt_length] = -100 214 | 215 | decoder_output = self.text_decoder(decoder_input_ids, 216 | attention_mask = text.attention_mask, 217 | encoder_hidden_states = output_tagembedding.last_hidden_state, 218 | encoder_attention_mask = None, 219 | labels = decoder_targets, 220 | return_dict = True, 221 | ) 222 | 223 | loss_t2t = 
decoder_output.loss 224 | 225 | # balance loss scale 226 | loss = loss_t2t + loss_tag/(loss_tag/loss_t2t).detach() 227 | 228 | return loss 229 | 230 | def generate_image_embeds(self, 231 | image, 232 | condition=False 233 | ): 234 | 235 | image_embeds = self.visual_encoder(image) 236 | 237 | return image_embeds 238 | 239 | def condition_forward(self, 240 | image, 241 | sample=False, 242 | num_beams=3, 243 | max_length=30, 244 | min_length=10, 245 | top_p=0.9, 246 | repetition_penalty=1.0, 247 | tag_input=None, 248 | return_tag_predict=False): 249 | 250 | image_embeds = self.visual_encoder(image) 251 | image_atts = torch.ones(image_embeds.size()[:-1], 252 | dtype=torch.long).to(image.device) 253 | 254 | # if not user specified tags, recognized image tags using image-tag recogntiion decoder 255 | 256 | 257 | bs = image_embeds.shape[0] 258 | label_embed = self.label_embed.weight.unsqueeze(0).repeat(bs, 1, 1) 259 | tagging_embed = self.tagging_head( 260 | encoder_embeds=label_embed, 261 | encoder_hidden_states=image_embeds, 262 | encoder_attention_mask=image_atts, 263 | return_dict=False, 264 | mode='tagging', 265 | ) 266 | 267 | logits = self.fc(tagging_embed[0]) 268 | 269 | targets = torch.where( 270 | torch.sigmoid(logits) > self.class_threshold.to(image.device), 271 | torch.tensor(1.0).to(image.device), 272 | torch.zeros(self.num_class).to(image.device)) 273 | 274 | # delete some tags that may disturb captioning 275 | targets[:, self.delete_tag_index] = 0 276 | 277 | return image_embeds, logits, targets 278 | 279 | 280 | def generate(self, 281 | image, 282 | sample=False, 283 | num_beams=3, 284 | max_length=30, 285 | min_length=10, 286 | top_p=0.9, 287 | repetition_penalty=1.0, 288 | tag_input=None, 289 | return_tag_predict=False): 290 | 291 | image_embeds = self.visual_encoder(image) 292 | image_atts = torch.ones(image_embeds.size()[:-1], 293 | dtype=torch.long).to(image.device) 294 | 295 | # if not user specified tags, recognized image tags using image-tag recogntiion decoder 296 | if tag_input == None: 297 | 298 | bs = image_embeds.shape[0] 299 | label_embed = self.label_embed.weight.unsqueeze(0).repeat(bs, 1, 1) 300 | tagging_embed = self.tagging_head( 301 | encoder_embeds=label_embed, 302 | encoder_hidden_states=image_embeds, 303 | encoder_attention_mask=image_atts, 304 | return_dict=False, 305 | mode='tagging', 306 | ) 307 | 308 | logits = self.fc(tagging_embed[0]) 309 | 310 | targets = torch.where( 311 | torch.sigmoid(logits) > self.class_threshold.to(image.device), 312 | torch.tensor(1.0).to(image.device), 313 | torch.zeros(self.num_class).to(image.device)) 314 | 315 | tag = targets.cpu().numpy() 316 | 317 | # delete some tags that may disturb captioning 318 | tag[:, self.delete_tag_index] = 0 319 | 320 | tag_input = [] 321 | for b in range(bs): 322 | index = np.argwhere(tag[b] == 1) 323 | token = self.tag_list[index].squeeze(axis=1) 324 | tag_input.append(', '.join(token)) 325 | 326 | tag_output = tag_input 327 | 328 | # beam search for text generation(default) 329 | if not sample: 330 | image_embeds = image_embeds.repeat_interleave(num_beams, dim=0) 331 | tag_input_temp = [] 332 | for tag in tag_input: 333 | for i in range(num_beams): 334 | tag_input_temp.append(tag) 335 | tag_input = tag_input_temp 336 | 337 | image_atts = torch.ones(image_embeds.size()[:-1], 338 | dtype=torch.long).to(image.device) 339 | 340 | # tokenizer input tags 341 | tag_input_tokenzier = self.tokenizer(tag_input, 342 | padding='max_length', 343 | truncation=True, 344 | max_length=40, 345 | 
return_tensors="pt").to( 346 | image.device) 347 | encoder_input_ids = tag_input_tokenzier.input_ids 348 | encoder_input_ids[:, 0] = self.tokenizer.enc_token_id 349 | 350 | # put input tag into image-tag interaction encoder to interact with image embeddings 351 | output_tagembedding = self.tag_encoder( 352 | encoder_input_ids, 353 | attention_mask=tag_input_tokenzier.attention_mask, 354 | encoder_hidden_states=image_embeds, 355 | encoder_attention_mask=image_atts, 356 | return_dict=True, 357 | ) 358 | 359 | # prompt trick for better captioning, followed BLIP 360 | prompt = [self.prompt] * image.size(0) 361 | input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to( 362 | image.device) 363 | input_ids[:, 0] = self.tokenizer.bos_token_id 364 | input_ids = input_ids[:, :-1] 365 | 366 | if sample: 367 | # nucleus sampling 368 | model_kwargs = { 369 | "encoder_hidden_states": output_tagembedding.last_hidden_state, 370 | "encoder_attention_mask": None 371 | } 372 | outputs = self.text_decoder.generate( 373 | input_ids=input_ids, 374 | max_length=max_length, 375 | min_length=min_length, 376 | do_sample=True, 377 | top_p=top_p, 378 | num_return_sequences=1, 379 | eos_token_id=self.tokenizer.sep_token_id, 380 | pad_token_id=self.tokenizer.pad_token_id, 381 | repetition_penalty=1.1, 382 | **model_kwargs) 383 | else: 384 | # beam search (default) 385 | model_kwargs = { 386 | "encoder_hidden_states": output_tagembedding.last_hidden_state, 387 | "encoder_attention_mask": None 388 | } 389 | outputs = self.text_decoder.generate( 390 | input_ids=input_ids, 391 | max_length=max_length, 392 | min_length=min_length, 393 | num_beams=num_beams, 394 | eos_token_id=self.tokenizer.sep_token_id, 395 | pad_token_id=self.tokenizer.pad_token_id, 396 | repetition_penalty=repetition_penalty, 397 | **model_kwargs) 398 | 399 | captions = [] 400 | for output in outputs: 401 | caption = self.tokenizer.decode(output, skip_special_tokens=True) 402 | captions.append(caption[len(self.prompt):]) 403 | if return_tag_predict == True: 404 | return captions, tag_output 405 | return captions 406 | 407 | 408 | # load Tag2Text pretrained model parameters 409 | def tag2text(pretrained='', **kwargs): 410 | model = Tag2Text(**kwargs) 411 | if pretrained: 412 | if kwargs['vit'] == 'swin_b': 413 | model, msg = load_checkpoint_swinbase(model, pretrained, kwargs) 414 | else: 415 | model, msg = load_checkpoint(model, pretrained) 416 | print('vit:', kwargs['vit']) 417 | # print('msg', msg) 418 | return model 419 | 420 | -------------------------------------------------------------------------------- /ram/models/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import math 5 | 6 | from torch import nn 7 | from typing import List 8 | from transformers import BertTokenizer 9 | from urllib.parse import urlparse 10 | from timm.models.hub import download_cached_file 11 | from .vit import interpolate_pos_embed 12 | from .swin_transformer import interpolate_relative_pos_embed 13 | from pathlib import Path 14 | CONFIG_PATH=(Path(__file__).resolve().parents[1]) 15 | 16 | def read_json(rpath): 17 | with open(rpath, 'r') as f: 18 | return json.load(f) 19 | 20 | 21 | def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, 22 | base_model_prefix: str, skip_key: str): 23 | uninitialized_encoder_weights: List[str] = [] 24 | if decoder.__class__ != encoder.__class__: 25 | logger.info( 26 | f"{decoder.__class__} and {encoder.__class__} are not 
equal. In this case make sure that all encoder weights are correctly initialized." 27 | ) 28 | 29 | def tie_encoder_to_decoder_recursively( 30 | decoder_pointer: nn.Module, 31 | encoder_pointer: nn.Module, 32 | module_name: str, 33 | uninitialized_encoder_weights: List[str], 34 | skip_key: str, 35 | depth=0, 36 | ): 37 | assert isinstance(decoder_pointer, nn.Module) and isinstance( 38 | encoder_pointer, nn.Module 39 | ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module" 40 | if hasattr(decoder_pointer, "weight") and skip_key not in module_name: 41 | assert hasattr(encoder_pointer, "weight") 42 | encoder_pointer.weight = decoder_pointer.weight 43 | if hasattr(decoder_pointer, "bias"): 44 | assert hasattr(encoder_pointer, "bias") 45 | encoder_pointer.bias = decoder_pointer.bias 46 | print(module_name + ' is tied') 47 | return 48 | 49 | encoder_modules = encoder_pointer._modules 50 | decoder_modules = decoder_pointer._modules 51 | if len(decoder_modules) > 0: 52 | assert ( 53 | len(encoder_modules) > 0 54 | ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}" 55 | 56 | all_encoder_weights = set([ 57 | module_name + "/" + sub_name 58 | for sub_name in encoder_modules.keys() 59 | ]) 60 | encoder_layer_pos = 0 61 | for name, module in decoder_modules.items(): 62 | if name.isdigit(): 63 | encoder_name = str(int(name) + encoder_layer_pos) 64 | decoder_name = name 65 | if not isinstance( 66 | decoder_modules[decoder_name], 67 | type(encoder_modules[encoder_name])) and len( 68 | encoder_modules) != len(decoder_modules): 69 | # this can happen if the name corresponds to the position in a list module list of layers 70 | # in this case the decoder has added a cross-attention that the encoder does not have 71 | # thus skip this step and subtract one layer pos from encoder 72 | encoder_layer_pos -= 1 73 | continue 74 | elif name not in encoder_modules: 75 | continue 76 | elif depth > 500: 77 | raise ValueError( 78 | "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model." 79 | ) 80 | else: 81 | decoder_name = encoder_name = name 82 | tie_encoder_to_decoder_recursively( 83 | decoder_modules[decoder_name], 84 | encoder_modules[encoder_name], 85 | module_name + "/" + name, 86 | uninitialized_encoder_weights, 87 | skip_key, 88 | depth=depth + 1, 89 | ) 90 | all_encoder_weights.remove(module_name + "/" + encoder_name) 91 | 92 | uninitialized_encoder_weights += list(all_encoder_weights) 93 | 94 | # tie weights recursively 95 | tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, 96 | uninitialized_encoder_weights, skip_key) 97 | 98 | 99 | class GroupWiseLinear(nn.Module): 100 | # could be changed to: 101 | # output = torch.einsum('ijk,zjk->ij', x, self.W) 102 | # or output = torch.einsum('ijk,jk->ij', x, self.W[0]) 103 | def __init__(self, num_class, hidden_dim, bias=True): 104 | super().__init__() 105 | self.num_class = num_class 106 | self.hidden_dim = hidden_dim 107 | self.bias = bias 108 | 109 | self.W = nn.Parameter(torch.Tensor(1, num_class, hidden_dim)) 110 | if bias: 111 | self.b = nn.Parameter(torch.Tensor(1, num_class)) 112 | self.reset_parameters() 113 | 114 | def reset_parameters(self): 115 | stdv = 1. 
/ math.sqrt(self.W.size(2)) 116 | for i in range(self.num_class): 117 | self.W[0][i].data.uniform_(-stdv, stdv) 118 | if self.bias: 119 | for i in range(self.num_class): 120 | self.b[0][i].data.uniform_(-stdv, stdv) 121 | 122 | def forward(self, x): 123 | # x: B,K,d 124 | x = (self.W * x).sum(-1) 125 | if self.bias: 126 | x = x + self.b 127 | return x 128 | 129 | 130 | def init_tokenizer(): 131 | # tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 132 | tokenizer = BertTokenizer.from_pretrained('/home/notebook/data/group/LowLevelLLM/LLM/bert-base-uncased', local_files_only=True) 133 | tokenizer.add_special_tokens({'bos_token': '[DEC]'}) 134 | tokenizer.add_special_tokens({'additional_special_tokens': ['[ENC]']}) 135 | tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0] 136 | return tokenizer 137 | 138 | 139 | def create_vit(vit, 140 | image_size, 141 | use_grad_checkpointing=False, 142 | ckpt_layer=0, 143 | drop_path_rate=0): 144 | 145 | assert vit in ['base', 'large'], "vit parameter must be base or large" 146 | if vit == 'base': 147 | vision_width = 768 148 | visual_encoder = VisionTransformer( 149 | img_size=image_size, 150 | patch_size=16, 151 | embed_dim=vision_width, 152 | depth=12, 153 | num_heads=12, 154 | use_grad_checkpointing=use_grad_checkpointing, 155 | ckpt_layer=ckpt_layer, 156 | drop_path_rate=0 or drop_path_rate) 157 | elif vit == 'large': 158 | vision_width = 1024 159 | visual_encoder = VisionTransformer( 160 | img_size=image_size, 161 | patch_size=16, 162 | embed_dim=vision_width, 163 | depth=24, 164 | num_heads=16, 165 | use_grad_checkpointing=use_grad_checkpointing, 166 | ckpt_layer=ckpt_layer, 167 | drop_path_rate=0.1 or drop_path_rate) 168 | return visual_encoder, vision_width 169 | 170 | 171 | def is_url(url_or_filename): 172 | parsed = urlparse(url_or_filename) 173 | return parsed.scheme in ("http", "https") 174 | 175 | 176 | def load_checkpoint(model, url_or_filename): 177 | if is_url(url_or_filename): 178 | cached_file = download_cached_file(url_or_filename, 179 | check_hash=False, 180 | progress=True) 181 | checkpoint = torch.load(cached_file, map_location='cpu') 182 | elif os.path.isfile(url_or_filename): 183 | checkpoint = torch.load(url_or_filename, map_location='cpu') 184 | else: 185 | raise RuntimeError('checkpoint url or path is invalid') 186 | 187 | state_dict = checkpoint['model'] 188 | 189 | state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed( 190 | state_dict['visual_encoder.pos_embed'], model.visual_encoder) 191 | if 'visual_encoder_m.pos_embed' in model.state_dict().keys(): 192 | state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed( 193 | state_dict['visual_encoder_m.pos_embed'], model.visual_encoder_m) 194 | for key in model.state_dict().keys(): 195 | if key in state_dict.keys(): 196 | if state_dict[key].shape != model.state_dict()[key].shape: 197 | del state_dict[key] 198 | 199 | msg = model.load_state_dict(state_dict, strict=False) 200 | print('load checkpoint from %s' % url_or_filename) 201 | return model, msg 202 | 203 | # def load_checkpoint_condition(model, url_or_filename): 204 | def load_checkpoint_swinlarge_condition(model, url_or_filename, kwargs): 205 | if kwargs['image_size'] == 224: 206 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_224.json' 207 | elif kwargs['image_size'] == 384: 208 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_384.json' 209 | window_size = read_json(vision_config_path)['window_size'] 210 | print('--------------') 211 | 
print(url_or_filename) 212 | print('--------------') 213 | if is_url(url_or_filename): 214 | cached_file = download_cached_file(url_or_filename, 215 | check_hash=False, 216 | progress=True) 217 | checkpoint = torch.load(cached_file, map_location='cpu') 218 | elif os.path.isfile(url_or_filename): 219 | checkpoint = torch.load(url_or_filename, map_location='cpu') 220 | else: 221 | raise RuntimeError('checkpoint url or path is invalid') 222 | 223 | state_dict = checkpoint['params'] 224 | 225 | for k in list(state_dict.keys()): 226 | if 'relative_position_bias_table' in k: 227 | dst_num_pos = (2 * window_size - 1)**2 228 | state_dict[k] = interpolate_relative_pos_embed(state_dict[k], 229 | dst_num_pos, 230 | param_name=k) 231 | elif ('relative_position_index' in k) or ('attn_mask' in k): 232 | del state_dict[k] 233 | elif "vision_multi" in k: 234 | state_dict[k.replace("vision_multi", 235 | "tagging_head")] = state_dict.pop(k) 236 | 237 | msg = model.load_state_dict(state_dict, strict=False) 238 | print('load checkpoint from %s' % url_or_filename) 239 | return model, msg 240 | 241 | 242 | def load_checkpoint_swinbase(model, url_or_filename, kwargs): 243 | if kwargs['image_size'] == 224: 244 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_224.json' 245 | elif kwargs['image_size'] == 384: 246 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_384.json' 247 | window_size = read_json(vision_config_path)['window_size'] 248 | print('--------------') 249 | print(url_or_filename) 250 | print('--------------') 251 | if is_url(url_or_filename): 252 | cached_file = download_cached_file(url_or_filename, 253 | check_hash=False, 254 | progress=True) 255 | checkpoint = torch.load(cached_file, map_location='cpu') 256 | elif os.path.isfile(url_or_filename): 257 | checkpoint = torch.load(url_or_filename, map_location='cpu') 258 | else: 259 | raise RuntimeError('checkpoint url or path is invalid') 260 | 261 | state_dict = checkpoint['model'] 262 | 263 | for k in list(state_dict.keys()): 264 | if 'relative_position_bias_table' in k: 265 | dst_num_pos = (2 * window_size - 1)**2 266 | state_dict[k] = interpolate_relative_pos_embed(state_dict[k], 267 | dst_num_pos, 268 | param_name=k) 269 | elif ('relative_position_index' in k) or ('attn_mask' in k): 270 | del state_dict[k] 271 | elif "vision_multi" in k: 272 | state_dict[k.replace("vision_multi", 273 | "tagging_head")] = state_dict.pop(k) 274 | 275 | msg = model.load_state_dict(state_dict, strict=False) 276 | print('load checkpoint from %s' % url_or_filename) 277 | return model, msg 278 | 279 | 280 | def load_checkpoint_swinlarge(model, url_or_filename, kwargs): 281 | if kwargs['image_size'] == 224: 282 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_224.json' 283 | elif kwargs['image_size'] == 384: 284 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_384.json' 285 | window_size = read_json(vision_config_path)['window_size'] 286 | print('--------------') 287 | print(url_or_filename) 288 | print('--------------') 289 | if is_url(url_or_filename): 290 | cached_file = download_cached_file(url_or_filename, 291 | check_hash=False, 292 | progress=True) 293 | checkpoint = torch.load(cached_file, map_location='cpu') 294 | elif os.path.isfile(url_or_filename): 295 | checkpoint = torch.load(url_or_filename, map_location='cpu') 296 | else: 297 | raise RuntimeError('checkpoint url or path is invalid') 298 | 299 | state_dict = checkpoint['model'] 300 | 301 | for k in list(state_dict.keys()): 302 | if 
'relative_position_bias_table' in k: 303 | dst_num_pos = (2 * window_size - 1)**2 304 | state_dict[k] = interpolate_relative_pos_embed(state_dict[k], 305 | dst_num_pos, 306 | param_name=k) 307 | elif ('relative_position_index' in k) or ('attn_mask' in k): 308 | del state_dict[k] 309 | elif "vision_multi" in k: 310 | state_dict[k.replace("vision_multi", 311 | "tagging_head")] = state_dict.pop(k) 312 | 313 | msg = model.load_state_dict(state_dict, strict=False) 314 | print('load checkpoint from %s' % url_or_filename) 315 | return model, msg 316 | 317 | 318 | # Tagging loss function 319 | # copy from https://github.com/Alibaba-MIIL/ASL/blob/main/src/loss_functions/losses.py 320 | class AsymmetricLoss(nn.Module): 321 | def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=True): 322 | super(AsymmetricLoss, self).__init__() 323 | 324 | self.gamma_neg = gamma_neg 325 | self.gamma_pos = gamma_pos 326 | self.clip = clip 327 | self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss 328 | self.eps = eps 329 | 330 | def forward(self, x, y): 331 | """" 332 | Parameters 333 | ---------- 334 | x: input logits 335 | y: targets (multi-label binarized vector) 336 | """ 337 | 338 | # Calculating Probabilities 339 | x_sigmoid = torch.sigmoid(x) 340 | xs_pos = x_sigmoid 341 | xs_neg = 1 - x_sigmoid 342 | 343 | # Asymmetric Clipping 344 | if self.clip is not None and self.clip > 0: 345 | xs_neg = (xs_neg + self.clip).clamp(max=1) 346 | 347 | # Basic CE calculation 348 | los_pos = y * torch.log(xs_pos.clamp(min=self.eps)) 349 | los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps)) 350 | loss = los_pos + los_neg 351 | 352 | # Asymmetric Focusing 353 | if self.gamma_neg > 0 or self.gamma_pos > 0: 354 | if self.disable_torch_grad_focal_loss: 355 | torch.set_grad_enabled(False) 356 | pt0 = xs_pos * y 357 | pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p 358 | pt = pt0 + pt1 359 | one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y) 360 | one_sided_w = torch.pow(1 - pt, one_sided_gamma) 361 | if self.disable_torch_grad_focal_loss: 362 | torch.set_grad_enabled(True) 363 | loss *= one_sided_w 364 | 365 | return -loss.sum() -------------------------------------------------------------------------------- /ram/models/vit.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | * Based on timm code base 8 | * https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | ''' 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from functools import partial 15 | 16 | from timm.models.vision_transformer import _cfg, PatchEmbed 17 | from timm.models.registry import register_model 18 | from timm.models.layers import trunc_normal_, DropPath 19 | from timm.models.helpers import named_apply, adapt_input_conv 20 | 21 | from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper 22 | 23 | class Mlp(nn.Module): 24 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks 25 | """ 26 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 27 | super().__init__() 28 | out_features = out_features or in_features 29 | hidden_features = hidden_features or in_features 30 | self.fc1 = nn.Linear(in_features, hidden_features) 31 | self.act = act_layer() 32 | self.fc2 = nn.Linear(hidden_features, out_features) 33 | self.drop = nn.Dropout(drop) 34 | 35 | def forward(self, x): 36 | x = self.fc1(x) 37 | x = self.act(x) 38 | x = self.drop(x) 39 | x = self.fc2(x) 40 | x = self.drop(x) 41 | return x 42 | 43 | 44 | class Attention(nn.Module): 45 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 46 | super().__init__() 47 | self.num_heads = num_heads 48 | head_dim = dim // num_heads 49 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 50 | self.scale = qk_scale or head_dim ** -0.5 51 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 52 | self.attn_drop = nn.Dropout(attn_drop) 53 | self.proj = nn.Linear(dim, dim) 54 | self.proj_drop = nn.Dropout(proj_drop) 55 | self.attn_gradients = None 56 | self.attention_map = None 57 | 58 | def save_attn_gradients(self, attn_gradients): 59 | self.attn_gradients = attn_gradients 60 | 61 | def get_attn_gradients(self): 62 | return self.attn_gradients 63 | 64 | def save_attention_map(self, attention_map): 65 | self.attention_map = attention_map 66 | 67 | def get_attention_map(self): 68 | return self.attention_map 69 | 70 | def forward(self, x, register_hook=False): 71 | B, N, C = x.shape 72 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 73 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 74 | 75 | attn = (q @ k.transpose(-2, -1)) * self.scale 76 | attn = attn.softmax(dim=-1) 77 | attn = self.attn_drop(attn) 78 | 79 | if register_hook: 80 | self.save_attention_map(attn) 81 | attn.register_hook(self.save_attn_gradients) 82 | 83 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 84 | x = self.proj(x) 85 | x = self.proj_drop(x) 86 | return x 87 | 88 | 89 | class Block(nn.Module): 90 | 91 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 92 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False): 93 | super().__init__() 94 | self.norm1 = norm_layer(dim) 95 | self.attn = Attention( 96 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 97 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 98 | 
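# Note on the line that follows (an explanatory comment, assuming timm's standard DropPath
# behavior): DropPath implements stochastic depth. During training it zeroes the whole
# residual branch for a randomly chosen subset of the samples in the batch, with probability
# drop_path, and rescales the surviving samples by 1 / (1 - drop_path); at inference time it
# acts as an identity mapping.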
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 99 | self.norm2 = norm_layer(dim) 100 | mlp_hidden_dim = int(dim * mlp_ratio) 101 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 102 | 103 | if use_grad_checkpointing: 104 | self.attn = checkpoint_wrapper(self.attn) 105 | self.mlp = checkpoint_wrapper(self.mlp) 106 | 107 | def forward(self, x, register_hook=False): 108 | x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook)) 109 | x = x + self.drop_path(self.mlp(self.norm2(x))) 110 | return x 111 | 112 | 113 | class VisionTransformer(nn.Module): 114 | """ Vision Transformer 115 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - 116 | https://arxiv.org/abs/2010.11929 117 | """ 118 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 119 | num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, 120 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None, 121 | use_grad_checkpointing=False, ckpt_layer=0): 122 | """ 123 | Args: 124 | img_size (int, tuple): input image size 125 | patch_size (int, tuple): patch size 126 | in_chans (int): number of input channels 127 | num_classes (int): number of classes for classification head 128 | embed_dim (int): embedding dimension 129 | depth (int): depth of transformer 130 | num_heads (int): number of attention heads 131 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 132 | qkv_bias (bool): enable bias for qkv if True 133 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set 134 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 135 | drop_rate (float): dropout rate 136 | attn_drop_rate (float): attention dropout rate 137 | drop_path_rate (float): stochastic depth rate 138 | norm_layer: (nn.Module): normalization layer 139 | """ 140 | super().__init__() 141 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 142 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 143 | 144 | self.patch_embed = PatchEmbed( 145 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 146 | 147 | num_patches = self.patch_embed.num_patches 148 | 149 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 150 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 151 | self.pos_drop = nn.Dropout(p=drop_rate) 152 | 153 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 154 | self.blocks = nn.ModuleList([ 155 | Block( 156 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 157 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 158 | use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer) 159 | ) 160 | for i in range(depth)]) 161 | self.norm = norm_layer(embed_dim) 162 | 163 | trunc_normal_(self.pos_embed, std=.02) 164 | trunc_normal_(self.cls_token, std=.02) 165 | self.apply(self._init_weights) 166 | 167 | def _init_weights(self, m): 168 | if isinstance(m, nn.Linear): 169 | trunc_normal_(m.weight, std=.02) 170 | if isinstance(m, nn.Linear) and m.bias is not None: 171 | nn.init.constant_(m.bias, 0) 172 | elif isinstance(m, nn.LayerNorm): 173 | nn.init.constant_(m.bias, 0) 174 | nn.init.constant_(m.weight, 
1.0) 175 | 176 | @torch.jit.ignore 177 | def no_weight_decay(self): 178 | return {'pos_embed', 'cls_token'} 179 | 180 | def forward(self, x, register_blk=-1): 181 | B = x.shape[0] 182 | x = self.patch_embed(x) 183 | 184 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 185 | x = torch.cat((cls_tokens, x), dim=1) 186 | 187 | x = x + self.pos_embed[:,:x.size(1),:] 188 | x = self.pos_drop(x) 189 | 190 | for i,blk in enumerate(self.blocks): 191 | x = blk(x, register_blk==i) 192 | x = self.norm(x) 193 | 194 | return x 195 | 196 | @torch.jit.ignore() 197 | def load_pretrained(self, checkpoint_path, prefix=''): 198 | _load_weights(self, checkpoint_path, prefix) 199 | 200 | 201 | @torch.no_grad() 202 | def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): 203 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation 204 | """ 205 | import numpy as np 206 | 207 | def _n2p(w, t=True): 208 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: 209 | w = w.flatten() 210 | if t: 211 | if w.ndim == 4: 212 | w = w.transpose([3, 2, 0, 1]) 213 | elif w.ndim == 3: 214 | w = w.transpose([2, 0, 1]) 215 | elif w.ndim == 2: 216 | w = w.transpose([1, 0]) 217 | return torch.from_numpy(w) 218 | 219 | w = np.load(checkpoint_path) 220 | if not prefix and 'opt/target/embedding/kernel' in w: 221 | prefix = 'opt/target/' 222 | 223 | if hasattr(model.patch_embed, 'backbone'): 224 | # hybrid 225 | backbone = model.patch_embed.backbone 226 | stem_only = not hasattr(backbone, 'stem') 227 | stem = backbone if stem_only else backbone.stem 228 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) 229 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) 230 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) 231 | if not stem_only: 232 | for i, stage in enumerate(backbone.stages): 233 | for j, block in enumerate(stage.blocks): 234 | bp = f'{prefix}block{i + 1}/unit{j + 1}/' 235 | for r in range(3): 236 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) 237 | getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) 238 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) 239 | if block.downsample is not None: 240 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) 241 | block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) 242 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) 243 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) 244 | else: 245 | embed_conv_w = adapt_input_conv( 246 | model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) 247 | model.patch_embed.proj.weight.copy_(embed_conv_w) 248 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) 249 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) 250 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) 251 | if pos_embed_w.shape != model.pos_embed.shape: 252 | pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights 253 | pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) 254 | model.pos_embed.copy_(pos_embed_w) 255 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) 256 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) 257 | # if isinstance(model.head, 
nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: 258 | # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) 259 | # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) 260 | # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: 261 | # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) 262 | # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) 263 | for i, block in enumerate(model.blocks.children()): 264 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/' 265 | mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' 266 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 267 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 268 | block.attn.qkv.weight.copy_(torch.cat([ 269 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) 270 | block.attn.qkv.bias.copy_(torch.cat([ 271 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 272 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 273 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 274 | for r in range(2): 275 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) 276 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) 277 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) 278 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) 279 | 280 | 281 | def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder): 282 | # interpolate position embedding 283 | embedding_size = pos_embed_checkpoint.shape[-1] 284 | num_patches = visual_encoder.patch_embed.num_patches 285 | num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches 286 | # height (== width) for the checkpoint position embedding 287 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 288 | # height (== width) for the new position embedding 289 | new_size = int(num_patches ** 0.5) 290 | 291 | if orig_size!=new_size: 292 | # class_token and dist_token are kept unchanged 293 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 294 | # only the position tokens are interpolated 295 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 296 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 297 | pos_tokens = torch.nn.functional.interpolate( 298 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 299 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 300 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 301 | print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2)) 302 | 303 | return new_pos_embed 304 | else: 305 | return pos_embed_checkpoint -------------------------------------------------------------------------------- /ram/transform.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import Normalize, Compose, Resize, ToTensor 2 | 3 | 4 | def convert_to_rgb(image): 5 | return image.convert("RGB") 6 | 7 | def get_transform(image_size=384): 8 | return Compose([ 9 | convert_to_rgb, 10 | Resize((image_size, image_size)), 11 | ToTensor(), 12 | Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 13 | ]) 14 | 
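# Minimal usage sketch (illustrative only, not one of the repository files): get_transform()
# above builds the normalized 384x384 tensor expected by the tagging models in ram/models,
# and tag2text() in ram/models/tag2text.py builds the captioning/tagging model. The image and
# checkpoint paths below are placeholders, and init_tokenizer() in ram/models/utils.py points
# to a hard-coded local bert-base-uncased directory that may need to be changed before this runs.
import torch
from PIL import Image

from ram.transform import get_transform
from ram.models.tag2text import tag2text

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# preprocess: RGB conversion, resize to 384x384, ToTensor, ImageNet normalization
transform = get_transform(image_size=384)
image = transform(Image.open('example.jpg')).unsqueeze(0).to(device)  # 1 x 3 x 384 x 384

# placeholder checkpoint path; vit='swin_b' selects the Swin-B checkpoint-loading branch
model = tag2text(pretrained='preset/models/tag2text_swin_14m.pth',
                 image_size=384, vit='swin_b').eval().to(device)

with torch.no_grad():
    # with return_tag_predict=True, generate() returns both captions and the predicted tags
    captions, tags = model.generate(image, return_tag_predict=True)
print('tags:', tags[0])
print('caption:', captions[0])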
-------------------------------------------------------------------------------- /ram/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .metrics import get_mAP, get_PR 2 | from .openset_utils import build_openset_label_embedding 3 | -------------------------------------------------------------------------------- /ram/utils/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/ram/utils/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /ram/utils/__pycache__/metrics.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/ram/utils/__pycache__/metrics.cpython-310.pyc -------------------------------------------------------------------------------- /ram/utils/__pycache__/openset_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/ram/utils/__pycache__/openset_utils.cpython-310.pyc -------------------------------------------------------------------------------- /ram/utils/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import numpy as np 4 | from numpy import ndarray 5 | 6 | 7 | def get_mAP( 8 | preds: ndarray, 9 | gt_file: str, 10 | taglist: List[str] 11 | ) -> Tuple[float, ndarray]: 12 | assert preds.shape[1] == len(taglist) 13 | 14 | # When mapping categories from test datasets to our system, there might be 15 | # multiple vs one situation due to different semantic definitions of tags. 16 | # So there can be duplicate tags in `taglist`. This special case is taken 17 | # into account. 
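# Explanatory comment on the lines that follow: tag2idxs maps each tag name to every column
# index that carries it, so a ground-truth tag is credited to all of its duplicate columns
# when the target matrix is built.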
18 | tag2idxs = {} 19 | for idx, tag in enumerate(taglist): 20 | if tag not in tag2idxs: 21 | tag2idxs[tag] = [] 22 | tag2idxs[tag].append(idx) 23 | 24 | # build targets 25 | targets = np.zeros_like(preds) 26 | with open(gt_file, "r") as f: 27 | lines = [line.strip("\n").split(",") for line in f.readlines()] 28 | assert len(lines) == targets.shape[0] 29 | for i, line in enumerate(lines): 30 | for tag in line[1:]: 31 | targets[i, tag2idxs[tag]] = 1.0 32 | 33 | # compute average precision for each class 34 | APs = np.zeros(preds.shape[1]) 35 | for k in range(preds.shape[1]): 36 | APs[k] = _average_precision(preds[:, k], targets[:, k]) 37 | 38 | return APs.mean(), APs 39 | 40 | 41 | def _average_precision(output: ndarray, target: ndarray) -> float: 42 | epsilon = 1e-8 43 | 44 | # sort examples 45 | indices = output.argsort()[::-1] 46 | # Computes prec@i 47 | total_count_ = np.cumsum(np.ones((len(output), 1))) 48 | 49 | target_ = target[indices] 50 | ind = target_ == 1 51 | pos_count_ = np.cumsum(ind) 52 | total = pos_count_[-1] 53 | pos_count_[np.logical_not(ind)] = 0 54 | pp = pos_count_ / total_count_ 55 | precision_at_i_ = np.sum(pp) 56 | precision_at_i = precision_at_i_ / (total + epsilon) 57 | 58 | return precision_at_i 59 | 60 | 61 | def get_PR( 62 | pred_file: str, 63 | gt_file: str, 64 | taglist: List[str] 65 | ) -> Tuple[float, float, ndarray, ndarray]: 66 | # When mapping categories from test datasets to our system, there might be 67 | # multiple vs one situation due to different semantic definitions of tags. 68 | # So there can be duplicate tags in `taglist`. This special case is taken 69 | # into account. 70 | tag2idxs = {} 71 | for idx, tag in enumerate(taglist): 72 | if tag not in tag2idxs: 73 | tag2idxs[tag] = [] 74 | tag2idxs[tag].append(idx) 75 | 76 | # build preds 77 | with open(pred_file, "r", encoding="utf-8") as f: 78 | lines = [line.strip().split(",") for line in f.readlines()] 79 | preds = np.zeros((len(lines), len(tag2idxs)), dtype=bool) 80 | for i, line in enumerate(lines): 81 | for tag in line[1:]: 82 | preds[i, tag2idxs[tag]] = True 83 | 84 | # build targets 85 | with open(gt_file, "r", encoding="utf-8") as f: 86 | lines = [line.strip().split(",") for line in f.readlines()] 87 | targets = np.zeros((len(lines), len(tag2idxs)), dtype=bool) 88 | for i, line in enumerate(lines): 89 | for tag in line[1:]: 90 | targets[i, tag2idxs[tag]] = True 91 | 92 | assert preds.shape == targets.shape 93 | 94 | # calculate P and R 95 | TPs = ( preds & targets).sum(axis=0) # noqa: E201, E222 96 | FPs = ( preds & ~targets).sum(axis=0) # noqa: E201, E222 97 | FNs = (~preds & targets).sum(axis=0) # noqa: E201, E222 98 | eps = 1.e-9 99 | Ps = TPs / (TPs + FPs + eps) 100 | Rs = TPs / (TPs + FNs + eps) 101 | 102 | return Ps.mean(), Rs.mean(), Ps, Rs 103 | -------------------------------------------------------------------------------- /ram/utils/openset_utils.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import torch 5 | import torch.nn as nn 6 | from clip import clip 7 | 8 | 9 | def article(name): 10 | return "an" if name[0] in "aeiou" else "a" 11 | 12 | 13 | def processed_name(name, rm_dot=False): 14 | # _ for lvis 15 | # / for obj365 16 | res = name.replace("_", " ").replace("/", " or ").lower() 17 | if rm_dot: 18 | res = res.rstrip(".") 19 | return res 20 | 21 | 22 | single_template = ["a photo of a {}."] 23 | 24 | multiple_templates = [ 25 | "There is {article} {} in the scene.", 26 | "There is the {} in the scene.", 27 | "a photo 
of {article} {} in the scene.", 28 | "a photo of the {} in the scene.", 29 | "a photo of one {} in the scene.", 30 | "itap of {article} {}.", 31 | "itap of my {}.", # itap: I took a picture of 32 | "itap of the {}.", 33 | "a photo of {article} {}.", 34 | "a photo of my {}.", 35 | "a photo of the {}.", 36 | "a photo of one {}.", 37 | "a photo of many {}.", 38 | "a good photo of {article} {}.", 39 | "a good photo of the {}.", 40 | "a bad photo of {article} {}.", 41 | "a bad photo of the {}.", 42 | "a photo of a nice {}.", 43 | "a photo of the nice {}.", 44 | "a photo of a cool {}.", 45 | "a photo of the cool {}.", 46 | "a photo of a weird {}.", 47 | "a photo of the weird {}.", 48 | "a photo of a small {}.", 49 | "a photo of the small {}.", 50 | "a photo of a large {}.", 51 | "a photo of the large {}.", 52 | "a photo of a clean {}.", 53 | "a photo of the clean {}.", 54 | "a photo of a dirty {}.", 55 | "a photo of the dirty {}.", 56 | "a bright photo of {article} {}.", 57 | "a bright photo of the {}.", 58 | "a dark photo of {article} {}.", 59 | "a dark photo of the {}.", 60 | "a photo of a hard to see {}.", 61 | "a photo of the hard to see {}.", 62 | "a low resolution photo of {article} {}.", 63 | "a low resolution photo of the {}.", 64 | "a cropped photo of {article} {}.", 65 | "a cropped photo of the {}.", 66 | "a close-up photo of {article} {}.", 67 | "a close-up photo of the {}.", 68 | "a jpeg corrupted photo of {article} {}.", 69 | "a jpeg corrupted photo of the {}.", 70 | "a blurry photo of {article} {}.", 71 | "a blurry photo of the {}.", 72 | "a pixelated photo of {article} {}.", 73 | "a pixelated photo of the {}.", 74 | "a black and white photo of the {}.", 75 | "a black and white photo of {article} {}.", 76 | "a plastic {}.", 77 | "the plastic {}.", 78 | "a toy {}.", 79 | "the toy {}.", 80 | "a plushie {}.", 81 | "the plushie {}.", 82 | "a cartoon {}.", 83 | "the cartoon {}.", 84 | "an embroidered {}.", 85 | "the embroidered {}.", 86 | "a painting of the {}.", 87 | "a painting of a {}.", 88 | ] 89 | 90 | 91 | openimages_rare_unseen = ['Aerial photography', 92 | 'Aircraft engine', 93 | 'Ale', 94 | 'Aloe', 95 | 'Amphibian', 96 | 'Angling', 97 | 'Anole', 98 | 'Antique car', 99 | 'Arcade game', 100 | 'Arthropod', 101 | 'Assault rifle', 102 | 'Athletic shoe', 103 | 'Auto racing', 104 | 'Backlighting', 105 | 'Bagpipes', 106 | 'Ball game', 107 | 'Barbecue chicken', 108 | 'Barechested', 109 | 'Barquentine', 110 | 'Beef tenderloin', 111 | 'Billiard room', 112 | 'Billiards', 113 | 'Bird of prey', 114 | 'Black swan', 115 | 'Black-and-white', 116 | 'Blond', 117 | 'Boating', 118 | 'Bonbon', 119 | 'Bottled water', 120 | 'Bouldering', 121 | 'Bovine', 122 | 'Bratwurst', 123 | 'Breadboard', 124 | 'Briefs', 125 | 'Brisket', 126 | 'Brochette', 127 | 'Calabaza', 128 | 'Camera operator', 129 | 'Canola', 130 | 'Childbirth', 131 | 'Chordophone', 132 | 'Church bell', 133 | 'Classical sculpture', 134 | 'Close-up', 135 | 'Cobblestone', 136 | 'Coca-cola', 137 | 'Combat sport', 138 | 'Comics', 139 | 'Compact car', 140 | 'Computer speaker', 141 | 'Cookies and crackers', 142 | 'Coral reef fish', 143 | 'Corn on the cob', 144 | 'Cosmetics', 145 | 'Crocodilia', 146 | 'Digital camera', 147 | 'Dishware', 148 | 'Divemaster', 149 | 'Dobermann', 150 | 'Dog walking', 151 | 'Domestic rabbit', 152 | 'Domestic short-haired cat', 153 | 'Double-decker bus', 154 | 'Drums', 155 | 'Electric guitar', 156 | 'Electric piano', 157 | 'Electronic instrument', 158 | 'Equestrianism', 159 | 'Equitation', 160 | 'Erinaceidae', 161 | 
'Extreme sport', 162 | 'Falafel', 163 | 'Figure skating', 164 | 'Filling station', 165 | 'Fire apparatus', 166 | 'Firearm', 167 | 'Flatbread', 168 | 'Floristry', 169 | 'Forklift truck', 170 | 'Freight transport', 171 | 'Fried food', 172 | 'Fried noodles', 173 | 'Frigate', 174 | 'Frozen yogurt', 175 | 'Frying', 176 | 'Full moon', 177 | 'Galleon', 178 | 'Glacial landform', 179 | 'Gliding', 180 | 'Go-kart', 181 | 'Goats', 182 | 'Grappling', 183 | 'Great white shark', 184 | 'Gumbo', 185 | 'Gun turret', 186 | 'Hair coloring', 187 | 'Halter', 188 | 'Headphones', 189 | 'Heavy cruiser', 190 | 'Herding', 191 | 'High-speed rail', 192 | 'Holding hands', 193 | 'Horse and buggy', 194 | 'Horse racing', 195 | 'Hound', 196 | 'Hunting knife', 197 | 'Hurdling', 198 | 'Inflatable', 199 | 'Jackfruit', 200 | 'Jeans', 201 | 'Jiaozi', 202 | 'Junk food', 203 | 'Khinkali', 204 | 'Kitesurfing', 205 | 'Lawn game', 206 | 'Leaf vegetable', 207 | 'Lechon', 208 | 'Lifebuoy', 209 | 'Locust', 210 | 'Lumpia', 211 | 'Luxury vehicle', 212 | 'Machine tool', 213 | 'Medical imaging', 214 | 'Melee weapon', 215 | 'Microcontroller', 216 | 'Middle ages', 217 | 'Military person', 218 | 'Military vehicle', 219 | 'Milky way', 220 | 'Miniature Poodle', 221 | 'Modern dance', 222 | 'Molluscs', 223 | 'Monoplane', 224 | 'Motorcycling', 225 | 'Musical theatre', 226 | 'Narcissus', 227 | 'Nest box', 228 | 'Newsagent\'s shop', 229 | 'Nile crocodile', 230 | 'Nordic skiing', 231 | 'Nuclear power plant', 232 | 'Orator', 233 | 'Outdoor shoe', 234 | 'Parachuting', 235 | 'Pasta salad', 236 | 'Peafowl', 237 | 'Pelmeni', 238 | 'Perching bird', 239 | 'Performance car', 240 | 'Personal water craft', 241 | 'Pit bull', 242 | 'Plant stem', 243 | 'Pork chop', 244 | 'Portrait photography', 245 | 'Primate', 246 | 'Procyonidae', 247 | 'Prosciutto', 248 | 'Public speaking', 249 | 'Racewalking', 250 | 'Ramen', 251 | 'Rear-view mirror', 252 | 'Residential area', 253 | 'Ribs', 254 | 'Rice ball', 255 | 'Road cycling', 256 | 'Roller skating', 257 | 'Roman temple', 258 | 'Rowing', 259 | 'Rural area', 260 | 'Sailboat racing', 261 | 'Scaled reptile', 262 | 'Scuba diving', 263 | 'Senior citizen', 264 | 'Shallot', 265 | 'Shinto shrine', 266 | 'Shooting range', 267 | 'Siberian husky', 268 | 'Sledding', 269 | 'Soba', 270 | 'Solar energy', 271 | 'Sport climbing', 272 | 'Sport utility vehicle', 273 | 'Steamed rice', 274 | 'Stemware', 275 | 'Sumo', 276 | 'Surfing Equipment', 277 | 'Team sport', 278 | 'Touring car', 279 | 'Toy block', 280 | 'Trampolining', 281 | 'Underwater diving', 282 | 'Vegetarian food', 283 | 'Wallaby', 284 | 'Water polo', 285 | 'Watercolor paint', 286 | 'Whiskers', 287 | 'Wind wave', 288 | 'Woodwind instrument', 289 | 'Yakitori', 290 | 'Zeppelin'] 291 | 292 | 293 | def build_openset_label_embedding(categories=None): 294 | if categories is None: 295 | categories = openimages_rare_unseen 296 | # model, _ = clip.load("ViT-B/16") 297 | model, _ = clip.load("ViT-B-16.pt") 298 | templates = multiple_templates 299 | 300 | run_on_gpu = torch.cuda.is_available() 301 | 302 | with torch.no_grad(): 303 | openset_label_embedding = [] 304 | for category in categories: 305 | texts = [ 306 | template.format( 307 | processed_name(category, rm_dot=True), article=article(category) 308 | ) 309 | for template in templates 310 | ] 311 | texts = [ 312 | "This is " + text if text.startswith("a") or text.startswith("the") else text 313 | for text in texts 314 | ] 315 | texts = clip.tokenize(texts) # tokenize 316 | if run_on_gpu: 317 | texts = texts.cuda() 318 | model = 
model.cuda() 319 | text_embeddings = model.encode_text(texts) 320 | text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True) 321 | text_embedding = text_embeddings.mean(dim=0) 322 | text_embedding /= text_embedding.norm() 323 | openset_label_embedding.append(text_embedding) 324 | openset_label_embedding = torch.stack(openset_label_embedding, dim=1) 325 | if run_on_gpu: 326 | openset_label_embedding = openset_label_embedding.cuda() 327 | 328 | openset_label_embedding = openset_label_embedding.t() 329 | return openset_label_embedding, categories 330 | 331 | 332 | 333 | 334 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers==0.25.0 2 | torch==2.0.1 3 | transformers==4.28.1 4 | xformers==0.0.20 5 | einops==0.7.0 6 | open-clip-torch==2.20.0 7 | peft==0.9.0 8 | Pillow==9.5.0 9 | PyYAML==6.0 10 | huggingface_hub==0.25.2 11 | numpy==1.23.5 12 | loralib 13 | basicsr 14 | fairscale 15 | -------------------------------------------------------------------------------- /scripts/get_path.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def write_png_paths(folder_path, txt_path): 4 | with open(txt_path, 'w') as f: 5 | for root, dirs, files in os.walk(folder_path): 6 | for file in files: 7 | if file.endswith('.png'): 8 | f.write(os.path.join(root, file) + '\n') 9 | 10 | # Example usage: 11 | folder_path = '' 12 | txt_path = '/gt_path.txt' 13 | write_png_paths(folder_path, txt_path) -------------------------------------------------------------------------------- /scripts/test/test_adjustable.sh: -------------------------------------------------------------------------------- 1 | 2 | python test_pisasr.py \ 3 | --pretrained_model_path preset/models/stable-diffusion-2-1-base \ 4 | --pretrained_path preset/models/pisa_sr.pkl \ 5 | --process_size 512 \ 6 | --upscale 4 \ 7 | --input_image preset/test_datasets \ 8 | --output_dir experiments/test \ 9 | --lambda_pix 1.0 \ 10 | --lambda_sem 1.0 11 | -------------------------------------------------------------------------------- /scripts/test/test_default.sh: -------------------------------------------------------------------------------- 1 | 2 | python test_pisasr.py \ 3 | --pretrained_model_path preset/models/stable-diffusion-2-1-base \ 4 | --pretrained_path preset/models/pisa_sr.pkl \ 5 | --process_size 512 \ 6 | --upscale 4 \ 7 | --input_image preset/test_datasets \ 8 | --output_dir experiments/test \ 9 | --default 10 | 11 | -------------------------------------------------------------------------------- /scripts/train/train_pisasr.sh: -------------------------------------------------------------------------------- 1 | 2 | CUDA_VISIBLE_DEVICES="0,1,2,3," accelerate launch train_pisasr.py \ 3 | --pretrained_model_path="preset/models/stable-diffusion-2-1-base" \ 4 | --pretrained_model_path_csd="preset/models/stable-diffusion-2-1-base" \ 5 | --dataset_txt_paths="preset/gt_path.txt" \ 6 | --highquality_dataset_txt_paths="preset/gt_selected_path.txt" \ 7 | --dataset_test_folder="preset/testfolder" \ 8 | --learning_rate=5e-5 \ 9 | --train_batch_size=4 \ 10 | --prob=0.0 \ 11 | --gradient_accumulation_steps=1 \ 12 | --enable_xformers_memory_efficient_attention --checkpointing_steps 500 \ 13 | --seed 123 \ 14 | --output_dir="experiments/train-pisasr" \ 15 | --cfg_csd 7.5 \ 16 | --timesteps1 1 \ 17 | --lambda_lpips=2.0 \ 18 | --lambda_l2=1.0 \ 19 | --lambda_csd=1.0 
\ 20 | --pix_steps=4000 \ 21 | --lora_rank_unet_pix=4 \ 22 | --lora_rank_unet_sem=4 \ 23 | --min_dm_step_ratio=0.02 \ 24 | --max_dm_step_ratio=0.5 \ 25 | --null_text_ratio=0.5 \ 26 | --align_method="adain" \ 27 | --deg_file_path="params.yml" \ 28 | --tracker_project_name "PiSASR" \ 29 | --is_module True 30 | -------------------------------------------------------------------------------- /src/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch 4 | from PIL import Image 5 | from torchvision import transforms 6 | import torchvision.transforms.functional as F 7 | from pathlib import Path 8 | 9 | import numpy as np 10 | from src.datasets.realesrgan import RealESRGAN_degradation 11 | 12 | 13 | 14 | class PairedSROnlineTxtDataset(torch.utils.data.Dataset): 15 | def __init__(self, split=None, args=None): 16 | super().__init__() 17 | 18 | self.args = args 19 | self.split = split 20 | if split == 'train': 21 | self.degradation = RealESRGAN_degradation(args.deg_file_path, device='cpu') 22 | self.crop_preproc = transforms.Compose([ 23 | transforms.RandomCrop((args.resolution_ori, args.resolution_ori)), 24 | transforms.Resize((args.resolution_tgt, args.resolution_tgt)), 25 | transforms.RandomHorizontalFlip(), 26 | ]) 27 | with open(args.dataset_txt_paths, 'r') as f: 28 | self.gt_list = [line.strip() for line in f.readlines()] 29 | if args.highquality_dataset_txt_paths is not None: 30 | with open(args.highquality_dataset_txt_paths, 'r') as f: 31 | self.hq_gt_list = [line.strip() for line in f.readlines()] 32 | 33 | elif split == 'test': 34 | self.input_folder = os.path.join(args.dataset_test_folder, "test_SR_bicubic") 35 | self.output_folder = os.path.join(args.dataset_test_folder, "test_HR") 36 | self.lr_list = [] 37 | self.gt_list = [] 38 | lr_names = os.listdir(os.path.join(self.input_folder)) 39 | gt_names = os.listdir(os.path.join(self.output_folder)) 40 | assert len(lr_names) == len(gt_names) 41 | for i in range(len(lr_names)): 42 | self.lr_list.append(os.path.join(self.input_folder, lr_names[i])) 43 | self.gt_list.append(os.path.join(self.output_folder,gt_names[i])) 44 | self.crop_preproc = transforms.Compose([ 45 | transforms.RandomCrop((args.resolution_ori, args.resolution_ori)), 46 | transforms.Resize((args.resolution_tgt, args.resolution_tgt)), 47 | ]) 48 | assert len(self.lr_list) == len(self.gt_list) 49 | 50 | def __len__(self): 51 | return len(self.gt_list) 52 | 53 | def __getitem__(self, idx): 54 | 55 | if self.split == 'train': 56 | if self.args.highquality_dataset_txt_paths is not None: 57 | if np.random.uniform() < self.args.prob: 58 | gt_img = Image.open(self.gt_list[idx]).convert('RGB') 59 | else: 60 | idx = random.sample(range(0, len(self.hq_gt_list)), 1) 61 | gt_img = Image.open(self.hq_gt_list[idx[0]]).convert('RGB') 62 | else: 63 | gt_img = Image.open(self.gt_list[idx]).convert('RGB') 64 | gt_img = self.crop_preproc(gt_img) 65 | 66 | output_t, img_t = self.degradation.degrade_process(np.asarray(gt_img)/255., resize_bak=True) 67 | output_t, img_t = output_t.squeeze(0), img_t.squeeze(0) 68 | 69 | # input images scaled to -1,1 70 | img_t = F.normalize(img_t, mean=[0.5], std=[0.5]) 71 | # output images scaled to -1,1 72 | output_t = F.normalize(output_t, mean=[0.5], std=[0.5]) 73 | 74 | example = {} 75 | # example["prompt"] = caption 76 | example["neg_prompt"] = self.args.neg_prompt_csd 77 | example["null_prompt"] = "" 78 | example["output_pixel_values"] = output_t 79 | 
example["conditioning_pixel_values"] = img_t 80 | 81 | return example 82 | 83 | elif self.split == 'test': 84 | input_img = Image.open(self.lr_list[idx]).convert('RGB') 85 | output_img = Image.open(self.gt_list[idx]).convert('RGB') 86 | img_t = self.crop_preproc(input_img) 87 | output_t = self.crop_preproc(output_img) 88 | # input images scaled to -1, 1 89 | img_t = F.to_tensor(img_t) 90 | img_t = F.normalize(img_t, mean=[0.5], std=[0.5]) 91 | # output images scaled to -1,1 92 | output_t = F.to_tensor(output_t) 93 | output_t = F.normalize(output_t, mean=[0.5], std=[0.5]) 94 | 95 | example = {} 96 | example["neg_prompt"] = self.args.neg_prompt_csd 97 | example["null_prompt"] = "" 98 | example["output_pixel_values"] = output_t 99 | example["conditioning_pixel_values"] = img_t 100 | example["base_name"] = os.path.basename(self.lr_list[idx]) 101 | 102 | return example 103 | -------------------------------------------------------------------------------- /src/datasets/params.yml: -------------------------------------------------------------------------------- 1 | scale: 4 2 | color_jitter_prob: 0.0 3 | gray_prob: 0.0 4 | 5 | # the first degradation process 6 | resize_prob: [0.2, 0.7, 0.1] # up, down, keep 7 | resize_range: [0.3, 1.5] 8 | gaussian_noise_prob: 0.5 9 | noise_range: [1, 15] 10 | poisson_scale_range: [0.05, 2.0] 11 | gray_noise_prob: 0.4 12 | jpeg_range: [60, 95] 13 | 14 | # the second degradation process 15 | second_blur_prob: 0.5 16 | resize_prob2: [0.3, 0.4, 0.3] # up, down, keep 17 | resize_range2: [0.6, 1.2] 18 | gaussian_noise_prob2: 0.5 19 | noise_range2: [1, 12] 20 | poisson_scale_range2: [0.05, 1.0] 21 | gray_noise_prob2: 0.4 22 | jpeg_range2: [60, 100] 23 | 24 | kernel_info: 25 | blur_kernel_size: 21 26 | kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] 27 | kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] 28 | sinc_prob: 0.1 29 | blur_sigma: [0.2, 3] 30 | betag_range: [0.5, 4] 31 | betap_range: [1, 2] 32 | 33 | blur_kernel_size2: 21 34 | kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] 35 | kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] 36 | sinc_prob2: 0.1 37 | blur_sigma2: [0.2, 1.5] 38 | betag_range2: [0.5, 4] 39 | betap_range2: [1, 2] 40 | 41 | final_sinc_prob: 0.8 -------------------------------------------------------------------------------- /src/datasets/realesrgan.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import glob 5 | import math 6 | import yaml 7 | import random 8 | from collections import OrderedDict 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | from basicsr.data.transforms import augment 13 | from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels 14 | from basicsr.utils import DiffJPEG, USMSharp, img2tensor, tensor2img 15 | from basicsr.utils.img_process_util import filter2D 16 | from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt 17 | from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation, 18 | normalize, rgb_to_grayscale) 19 | 20 | cur_path = os.path.dirname(os.path.abspath(__file__)) 21 | 22 | def ordered_yaml(): 23 | """Support OrderedDict for yaml. 24 | 25 | Returns: 26 | yaml Loader and Dumper. 
27 | """ 28 | try: 29 | from yaml import CDumper as Dumper 30 | from yaml import CLoader as Loader 31 | except ImportError: 32 | from yaml import Dumper, Loader 33 | 34 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG 35 | 36 | def dict_representer(dumper, data): 37 | return dumper.represent_dict(data.items()) 38 | 39 | def dict_constructor(loader, node): 40 | return OrderedDict(loader.construct_pairs(node)) 41 | 42 | Dumper.add_representer(OrderedDict, dict_representer) 43 | Loader.add_constructor(_mapping_tag, dict_constructor) 44 | return Loader, Dumper 45 | 46 | def opt_parse(opt_path): 47 | with open(opt_path, mode='r') as f: 48 | Loader, _ = ordered_yaml() 49 | opt = yaml.load(f, Loader=Loader) # ignore_security_alert_wait_for_fix RCE 50 | 51 | return opt 52 | 53 | class RealESRGAN_degradation(object): 54 | def __init__(self, opt_name='params_realesrgan.yml', device='cpu'): 55 | opt_path = f'{cur_path}/{opt_name}' 56 | self.opt = opt_parse(opt_path) 57 | self.device = device #torch.device('cpu') 58 | optk = self.opt['kernel_info'] 59 | 60 | # blur settings for the first degradation 61 | self.blur_kernel_size = optk['blur_kernel_size'] 62 | self.kernel_list = optk['kernel_list'] 63 | self.kernel_prob = optk['kernel_prob'] 64 | self.blur_sigma = optk['blur_sigma'] 65 | self.betag_range = optk['betag_range'] 66 | self.betap_range = optk['betap_range'] 67 | self.sinc_prob = optk['sinc_prob'] 68 | 69 | # blur settings for the second degradation 70 | self.blur_kernel_size2 = optk['blur_kernel_size2'] 71 | self.kernel_list2 = optk['kernel_list2'] 72 | self.kernel_prob2 = optk['kernel_prob2'] 73 | self.blur_sigma2 = optk['blur_sigma2'] 74 | self.betag_range2 = optk['betag_range2'] 75 | self.betap_range2 = optk['betap_range2'] 76 | self.sinc_prob2 = optk['sinc_prob2'] 77 | 78 | # a final sinc filter 79 | self.final_sinc_prob = optk['final_sinc_prob'] 80 | 81 | self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21 82 | self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect 83 | self.pulse_tensor[10, 10] = 1 84 | 85 | self.jpeger = DiffJPEG(differentiable=False).to(self.device) 86 | self.usm_shaper = USMSharp().to(self.device) 87 | 88 | def color_jitter_pt(self, img, brightness, contrast, saturation, hue): 89 | fn_idx = torch.randperm(4) 90 | for fn_id in fn_idx: 91 | if fn_id == 0 and brightness is not None: 92 | brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item() 93 | img = adjust_brightness(img, brightness_factor) 94 | 95 | if fn_id == 1 and contrast is not None: 96 | contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item() 97 | img = adjust_contrast(img, contrast_factor) 98 | 99 | if fn_id == 2 and saturation is not None: 100 | saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item() 101 | img = adjust_saturation(img, saturation_factor) 102 | 103 | if fn_id == 3 and hue is not None: 104 | hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item() 105 | img = adjust_hue(img, hue_factor) 106 | return img 107 | 108 | def random_augment(self, img_gt): 109 | # random horizontal flip 110 | img_gt, status = augment(img_gt, hflip=True, rotation=False, return_status=True) 111 | """ 112 | # random color jitter 113 | if np.random.uniform() < self.opt['color_jitter_prob']: 114 | jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32) 115 | img_gt = img_gt + jitter_val 116 | img_gt = np.clip(img_gt, 0, 1) 117 | 
118 | # random grayscale 119 | if np.random.uniform() < self.opt['gray_prob']: 120 | #img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY) 121 | img_gt = cv2.cvtColor(img_gt, cv2.COLOR_RGB2GRAY) 122 | img_gt = np.tile(img_gt[:, :, None], [1, 1, 3]) 123 | """ 124 | # BGR to RGB, HWC to CHW, numpy to tensor 125 | img_gt = img2tensor([img_gt], bgr2rgb=False, float32=True)[0].unsqueeze(0) 126 | 127 | return img_gt 128 | 129 | def random_kernels(self): 130 | # ------------------------ Generate kernels (used in the first degradation) ------------------------ # 131 | kernel_size = random.choice(self.kernel_range) 132 | if np.random.uniform() < self.sinc_prob: 133 | # this sinc filter setting is for kernels ranging from [7, 21] 134 | if kernel_size < 13: 135 | omega_c = np.random.uniform(np.pi / 3, np.pi) 136 | else: 137 | omega_c = np.random.uniform(np.pi / 5, np.pi) 138 | kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False) 139 | else: 140 | kernel = random_mixed_kernels( 141 | self.kernel_list, 142 | self.kernel_prob, 143 | kernel_size, 144 | self.blur_sigma, 145 | self.blur_sigma, [-math.pi, math.pi], 146 | self.betag_range, 147 | self.betap_range, 148 | noise_range=None) 149 | # pad kernel 150 | pad_size = (21 - kernel_size) // 2 151 | kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size))) 152 | 153 | # ------------------------ Generate kernels (used in the second degradation) ------------------------ # 154 | kernel_size = random.choice(self.kernel_range) 155 | if np.random.uniform() < self.sinc_prob2: 156 | if kernel_size < 13: 157 | omega_c = np.random.uniform(np.pi / 3, np.pi) 158 | else: 159 | omega_c = np.random.uniform(np.pi / 5, np.pi) 160 | kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False) 161 | else: 162 | kernel2 = random_mixed_kernels( 163 | self.kernel_list2, 164 | self.kernel_prob2, 165 | kernel_size, 166 | self.blur_sigma2, 167 | self.blur_sigma2, [-math.pi, math.pi], 168 | self.betag_range2, 169 | self.betap_range2, 170 | noise_range=None) 171 | 172 | # pad kernel 173 | pad_size = (21 - kernel_size) // 2 174 | kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size))) 175 | 176 | # ------------------------------------- sinc kernel ------------------------------------- # 177 | if np.random.uniform() < self.final_sinc_prob: 178 | kernel_size = random.choice(self.kernel_range) 179 | omega_c = np.random.uniform(np.pi / 3, np.pi) 180 | sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21) 181 | sinc_kernel = torch.FloatTensor(sinc_kernel) 182 | else: 183 | sinc_kernel = self.pulse_tensor 184 | 185 | kernel = torch.FloatTensor(kernel) 186 | kernel2 = torch.FloatTensor(kernel2) 187 | 188 | return kernel, kernel2, sinc_kernel 189 | 190 | @torch.no_grad() 191 | def degrade_process(self, img_gt, resize_bak=False): 192 | img_gt = self.random_augment(img_gt) 193 | kernel1, kernel2, sinc_kernel = self.random_kernels() 194 | img_gt, kernel1, kernel2, sinc_kernel = img_gt.to(self.device), kernel1.to(self.device), kernel2.to(self.device), sinc_kernel.to(self.device) 195 | #img_gt = self.usm_shaper(img_gt) # shaper gt 196 | ori_h, ori_w = img_gt.size()[2:4] 197 | 198 | #scale_final = random.randint(4, 16) 199 | scale_final = 4 200 | 201 | # ----------------------- The first degradation process ----------------------- # 202 | # blur 203 | out = filter2D(img_gt, kernel1) 204 | # random resize 205 | updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0] 206 | if updown_type == 'up': 207 | scale 
= np.random.uniform(1, self.opt['resize_range'][1]) 208 | elif updown_type == 'down': 209 | scale = np.random.uniform(self.opt['resize_range'][0], 1) 210 | else: 211 | scale = 1 212 | mode = random.choice(['area', 'bilinear', 'bicubic']) 213 | out = F.interpolate(out, scale_factor=scale, mode=mode) 214 | # noise 215 | gray_noise_prob = self.opt['gray_noise_prob'] 216 | if np.random.uniform() < self.opt['gaussian_noise_prob']: 217 | out = random_add_gaussian_noise_pt( 218 | out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob) 219 | else: 220 | out = random_add_poisson_noise_pt( 221 | out, 222 | scale_range=self.opt['poisson_scale_range'], 223 | gray_prob=gray_noise_prob, 224 | clip=True, 225 | rounds=False) 226 | # JPEG compression 227 | jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range']) 228 | out = torch.clamp(out, 0, 1) 229 | out = self.jpeger(out, quality=jpeg_p) 230 | 231 | # ----------------------- The second degradation process ----------------------- # 232 | # blur 233 | if np.random.uniform() < self.opt['second_blur_prob']: 234 | out = filter2D(out, kernel2) 235 | # random resize 236 | updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0] 237 | if updown_type == 'up': 238 | scale = np.random.uniform(1, self.opt['resize_range2'][1]) 239 | elif updown_type == 'down': 240 | scale = np.random.uniform(self.opt['resize_range2'][0], 1) 241 | else: 242 | scale = 1 243 | mode = random.choice(['area', 'bilinear', 'bicubic']) 244 | out = F.interpolate( 245 | out, size=(int(ori_h / scale_final * scale), int(ori_w / scale_final * scale)), mode=mode) 246 | # noise 247 | gray_noise_prob = self.opt['gray_noise_prob2'] 248 | if np.random.uniform() < self.opt['gaussian_noise_prob2']: 249 | out = random_add_gaussian_noise_pt( 250 | out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob) 251 | else: 252 | out = random_add_poisson_noise_pt( 253 | out, 254 | scale_range=self.opt['poisson_scale_range2'], 255 | gray_prob=gray_noise_prob, 256 | clip=True, 257 | rounds=False) 258 | 259 | # JPEG compression + the final sinc filter 260 | # We also need to resize images to desired sizes. We group [resize back + sinc filter] together 261 | # as one operation. 262 | # We consider two orders: 263 | # 1. [resize back + sinc filter] + JPEG compression 264 | # 2. JPEG compression + [resize back + sinc filter] 265 | # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines. 
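# The coin flip below selects one of the two orderings described above with equal probability.
# Both branches end with the image resized to (ori_h // scale_final, ori_w // scale_final) and
# convolved with `sinc_kernel`; they differ only in whether JPEG compression is applied after
# (first branch) or before (second branch) the [resize back + sinc filter] operation.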
266 | if np.random.uniform() < 0.5: 267 | # resize back + the final sinc filter 268 | mode = random.choice(['area', 'bilinear', 'bicubic']) 269 | out = F.interpolate(out, size=(ori_h // scale_final, ori_w // scale_final), mode=mode) 270 | out = filter2D(out, sinc_kernel) 271 | # JPEG compression 272 | jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) 273 | out = torch.clamp(out, 0, 1) 274 | out = self.jpeger(out, quality=jpeg_p) 275 | else: 276 | # JPEG compression 277 | jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) 278 | out = torch.clamp(out, 0, 1) 279 | out = self.jpeger(out, quality=jpeg_p) 280 | # resize back + the final sinc filter 281 | mode = random.choice(['area', 'bilinear', 'bicubic']) 282 | out = F.interpolate(out, size=(ori_h // scale_final, ori_w // scale_final), mode=mode) 283 | out = filter2D(out, sinc_kernel) 284 | 285 | if np.random.uniform() < self.opt['gray_prob']: 286 | out = rgb_to_grayscale(out, num_output_channels=1) 287 | 288 | if np.random.uniform() < self.opt['color_jitter_prob']: 289 | brightness = self.opt.get('brightness', (0.5, 1.5)) 290 | contrast = self.opt.get('contrast', (0.5, 1.5)) 291 | saturation = self.opt.get('saturation', (0, 1.5)) 292 | hue = self.opt.get('hue', (-0.1, 0.1)) 293 | out = self.color_jitter_pt(out, brightness, contrast, saturation, hue) 294 | 295 | if resize_bak: 296 | mode = random.choice(['area', 'bilinear', 'bicubic']) 297 | out = F.interpolate(out, size=(ori_h, ori_w), mode=mode) 298 | # clamp and round 299 | img_lq = torch.clamp((out * 255.0).round(), 0, 255) / 255. 300 | 301 | return img_gt, img_lq -------------------------------------------------------------------------------- /src/models/autoencoder_kl.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from typing import Dict, Optional, Tuple, Union 15 | 16 | import torch 17 | import torch.nn as nn 18 | 19 | from diffusers.configuration_utils import ConfigMixin, register_to_config 20 | from diffusers.loaders import FromOriginalVAEMixin 21 | from diffusers.utils.accelerate_utils import apply_forward_hook 22 | from diffusers.models.attention_processor import ( 23 | ADDED_KV_ATTENTION_PROCESSORS, 24 | CROSS_ATTENTION_PROCESSORS, 25 | Attention, 26 | AttentionProcessor, 27 | AttnAddedKVProcessor, 28 | AttnProcessor, 29 | ) 30 | from diffusers.models.modeling_outputs import AutoencoderKLOutput 31 | from diffusers.models.modeling_utils import ModelMixin 32 | from diffusers.models.autoencoders.vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder 33 | 34 | class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin): 35 | r""" 36 | A VAE model with KL loss for encoding images into latents and decoding latent representations into images. 37 | 38 | This model inherits from [`ModelMixin`]. 
Check the superclass documentation for its generic methods implemented 39 | for all models (such as downloading or saving). 40 | 41 | Parameters: 42 | in_channels (int, *optional*, defaults to 3): Number of channels in the input image. 43 | out_channels (int, *optional*, defaults to 3): Number of channels in the output. 44 | down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): 45 | Tuple of downsample block types. 46 | up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): 47 | Tuple of upsample block types. 48 | block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): 49 | Tuple of block output channels. 50 | act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. 51 | latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. 52 | sample_size (`int`, *optional*, defaults to `32`): Sample input size. 53 | scaling_factor (`float`, *optional*, defaults to 0.18215): 54 | The component-wise standard deviation of the trained latent space computed using the first batch of the 55 | training set. This is used to scale the latent space to have unit variance when training the diffusion 56 | model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the 57 | diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 58 | / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image 59 | Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. 60 | force_upcast (`bool`, *optional*, defaults to `True`): 61 | If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL.
VAE 62 | can be fine-tuned / trained to a lower range without loosing too much precision in which case 63 | `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix 64 | """ 65 | 66 | _supports_gradient_checkpointing = True 67 | 68 | @register_to_config 69 | def __init__( 70 | self, 71 | in_channels: int = 3, 72 | out_channels: int = 3, 73 | down_block_types: Tuple[str] = ("DownEncoderBlock2D",), 74 | up_block_types: Tuple[str] = ("UpDecoderBlock2D",), 75 | block_out_channels: Tuple[int] = (64,), 76 | layers_per_block: int = 1, 77 | act_fn: str = "silu", 78 | latent_channels: int = 4, 79 | norm_num_groups: int = 32, 80 | sample_size: int = 32, 81 | scaling_factor: float = 0.18215, 82 | force_upcast: float = True, 83 | ): 84 | super().__init__() 85 | 86 | # pass init params to Encoder 87 | self.encoder = Encoder( 88 | in_channels=in_channels, 89 | out_channels=latent_channels, 90 | down_block_types=down_block_types, 91 | block_out_channels=block_out_channels, 92 | layers_per_block=layers_per_block, 93 | act_fn=act_fn, 94 | norm_num_groups=norm_num_groups, 95 | double_z=True, 96 | ) 97 | 98 | # pass init params to Decoder 99 | self.decoder = Decoder( 100 | in_channels=latent_channels, 101 | out_channels=out_channels, 102 | up_block_types=up_block_types, 103 | block_out_channels=block_out_channels, 104 | layers_per_block=layers_per_block, 105 | norm_num_groups=norm_num_groups, 106 | act_fn=act_fn, 107 | ) 108 | 109 | self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) 110 | self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) 111 | 112 | self.use_slicing = False 113 | self.use_tiling = False 114 | 115 | # only relevant if vae tiling is enabled 116 | self.tile_sample_min_size = self.config.sample_size 117 | sample_size = ( 118 | self.config.sample_size[0] 119 | if isinstance(self.config.sample_size, (list, tuple)) 120 | else self.config.sample_size 121 | ) 122 | self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) 123 | self.tile_overlap_factor = 0.25 124 | 125 | def _set_gradient_checkpointing(self, module, value=False): 126 | if isinstance(module, (Encoder, Decoder)): 127 | module.gradient_checkpointing = value 128 | 129 | def enable_tiling(self, use_tiling: bool = True): 130 | r""" 131 | Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to 132 | compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow 133 | processing larger images. 134 | """ 135 | self.use_tiling = use_tiling 136 | 137 | def disable_tiling(self): 138 | r""" 139 | Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing 140 | decoding in one step. 141 | """ 142 | self.enable_tiling(False) 143 | 144 | def enable_slicing(self): 145 | r""" 146 | Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to 147 | compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. 148 | """ 149 | self.use_slicing = True 150 | 151 | def disable_slicing(self): 152 | r""" 153 | Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing 154 | decoding in one step. 
155 | """ 156 | self.use_slicing = False 157 | 158 | @property 159 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors 160 | def attn_processors(self) -> Dict[str, AttentionProcessor]: 161 | r""" 162 | Returns: 163 | `dict` of attention processors: A dictionary containing all attention processors used in the model with 164 | indexed by its weight name. 165 | """ 166 | # set recursively 167 | processors = {} 168 | 169 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): 170 | if hasattr(module, "get_processor"): 171 | processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) 172 | 173 | for sub_name, child in module.named_children(): 174 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 175 | 176 | return processors 177 | 178 | for name, module in self.named_children(): 179 | fn_recursive_add_processors(name, module, processors) 180 | 181 | return processors 182 | 183 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor 184 | def set_attn_processor( 185 | self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False 186 | ): 187 | r""" 188 | Sets the attention processor to use to compute attention. 189 | 190 | Parameters: 191 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 192 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 193 | for **all** `Attention` layers. 194 | 195 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 196 | processor. This is strongly recommended when setting trainable attention processors. 197 | 198 | """ 199 | count = len(self.attn_processors.keys()) 200 | 201 | if isinstance(processor, dict) and len(processor) != count: 202 | raise ValueError( 203 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 204 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 205 | ) 206 | 207 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 208 | if hasattr(module, "set_processor"): 209 | if not isinstance(processor, dict): 210 | module.set_processor(processor, _remove_lora=_remove_lora) 211 | else: 212 | module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) 213 | 214 | for sub_name, child in module.named_children(): 215 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 216 | 217 | for name, module in self.named_children(): 218 | fn_recursive_attn_processor(name, module, processor) 219 | 220 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor 221 | def set_default_attn_processor(self): 222 | """ 223 | Disables custom attention processors and sets the default attention implementation. 
224 | """ 225 | if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): 226 | processor = AttnAddedKVProcessor() 227 | elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): 228 | processor = AttnProcessor() 229 | else: 230 | raise ValueError( 231 | f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" 232 | ) 233 | 234 | self.set_attn_processor(processor, _remove_lora=True) 235 | 236 | @apply_forward_hook 237 | def encode( 238 | self, x: torch.FloatTensor, return_dict: bool = True 239 | ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: 240 | """ 241 | Encode a batch of images into latents. 242 | 243 | Args: 244 | x (`torch.FloatTensor`): Input batch of images. 245 | return_dict (`bool`, *optional*, defaults to `True`): 246 | Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. 247 | 248 | Returns: 249 | The latent representations of the encoded images. If `return_dict` is True, a 250 | [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. 251 | """ 252 | if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size): 253 | return self.tiled_encode(x, return_dict=return_dict) 254 | 255 | if self.use_slicing and x.shape[0] > 1: 256 | encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)] 257 | h = torch.cat(encoded_slices) 258 | else: 259 | h = self.encoder(x) 260 | 261 | moments = self.quant_conv(h.to(dtype=self.quant_conv.weight.dtype)) 262 | posterior = DiagonalGaussianDistribution(moments) 263 | 264 | if not return_dict: 265 | return (posterior,) 266 | 267 | return AutoencoderKLOutput(latent_dist=posterior) 268 | 269 | def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: 270 | if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size): 271 | return self.tiled_decode(z, return_dict=return_dict) 272 | 273 | z = self.post_quant_conv(z.to(dtype=self.post_quant_conv.weight.dtype)) 274 | dec = self.decoder(z) 275 | 276 | if not return_dict: 277 | return (dec,) 278 | 279 | return DecoderOutput(sample=dec) 280 | 281 | @apply_forward_hook 282 | def decode( 283 | self, z: torch.FloatTensor, return_dict: bool = True, generator=None 284 | ) -> Union[DecoderOutput, torch.FloatTensor]: 285 | """ 286 | Decode a batch of images. 287 | 288 | Args: 289 | z (`torch.FloatTensor`): Input batch of latent vectors. 290 | return_dict (`bool`, *optional*, defaults to `True`): 291 | Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. 292 | 293 | Returns: 294 | [`~models.vae.DecoderOutput`] or `tuple`: 295 | If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is 296 | returned. 
297 | 298 | """ 299 | if self.use_slicing and z.shape[0] > 1: 300 | decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] 301 | decoded = torch.cat(decoded_slices) 302 | else: 303 | decoded = self._decode(z).sample 304 | 305 | if not return_dict: 306 | return (decoded,) 307 | 308 | return DecoderOutput(sample=decoded) 309 | 310 | def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: 311 | blend_extent = min(a.shape[2], b.shape[2], blend_extent) 312 | for y in range(blend_extent): 313 | b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent) 314 | return b 315 | 316 | def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: 317 | blend_extent = min(a.shape[3], b.shape[3], blend_extent) 318 | for x in range(blend_extent): 319 | b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) 320 | return b 321 | 322 | def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: 323 | r"""Encode a batch of images using a tiled encoder. 324 | 325 | When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several 326 | steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is 327 | different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the 328 | tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the 329 | output, but they should be much less noticeable. 330 | 331 | Args: 332 | x (`torch.FloatTensor`): Input batch of images. 333 | return_dict (`bool`, *optional*, defaults to `True`): 334 | Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. 335 | 336 | Returns: 337 | [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`: 338 | If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain 339 | `tuple` is returned. 340 | """ 341 | overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) 342 | blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) 343 | row_limit = self.tile_latent_min_size - blend_extent 344 | 345 | # Split the image into 512x512 tiles and encode them separately. 
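# For a pixel-space tile of size `tile_sample_min_size` and overlap factor `tile_overlap_factor`:
# `overlap_size` is the stride between tile origins, so neighbouring tiles overlap by
# `tile_sample_min_size - overlap_size` pixels; `blend_extent` is the number of latent rows/columns
# over which `blend_v`/`blend_h` linearly ramp from the previous tile to the current one; and
# `row_limit` is the crop kept from each blended tile before the rows are concatenated into the
# final latent grid.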
346 | rows = [] 347 | for i in range(0, x.shape[2], overlap_size): 348 | row = [] 349 | for j in range(0, x.shape[3], overlap_size): 350 | tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] 351 | tile = self.encoder(tile) 352 | tile = self.quant_conv(tile) 353 | row.append(tile) 354 | rows.append(row) 355 | result_rows = [] 356 | for i, row in enumerate(rows): 357 | result_row = [] 358 | for j, tile in enumerate(row): 359 | # blend the above tile and the left tile 360 | # to the current tile and add the current tile to the result row 361 | if i > 0: 362 | tile = self.blend_v(rows[i - 1][j], tile, blend_extent) 363 | if j > 0: 364 | tile = self.blend_h(row[j - 1], tile, blend_extent) 365 | result_row.append(tile[:, :, :row_limit, :row_limit]) 366 | result_rows.append(torch.cat(result_row, dim=3)) 367 | 368 | moments = torch.cat(result_rows, dim=2) 369 | posterior = DiagonalGaussianDistribution(moments) 370 | 371 | if not return_dict: 372 | return (posterior,) 373 | 374 | return AutoencoderKLOutput(latent_dist=posterior) 375 | 376 | def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: 377 | r""" 378 | Decode a batch of images using a tiled decoder. 379 | 380 | Args: 381 | z (`torch.FloatTensor`): Input batch of latent vectors. 382 | return_dict (`bool`, *optional*, defaults to `True`): 383 | Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. 384 | 385 | Returns: 386 | [`~models.vae.DecoderOutput`] or `tuple`: 387 | If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is 388 | returned. 389 | """ 390 | overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) 391 | blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) 392 | row_limit = self.tile_sample_min_size - blend_extent 393 | 394 | # Split z into overlapping 64x64 tiles and decode them separately. 395 | # The tiles have an overlap to avoid seams between tiles. 396 | rows = [] 397 | for i in range(0, z.shape[2], overlap_size): 398 | row = [] 399 | for j in range(0, z.shape[3], overlap_size): 400 | tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size] 401 | tile = self.post_quant_conv(tile) 402 | decoded = self.decoder(tile) 403 | row.append(decoded) 404 | rows.append(row) 405 | result_rows = [] 406 | for i, row in enumerate(rows): 407 | result_row = [] 408 | for j, tile in enumerate(row): 409 | # blend the above tile and the left tile 410 | # to the current tile and add the current tile to the result row 411 | if i > 0: 412 | tile = self.blend_v(rows[i - 1][j], tile, blend_extent) 413 | if j > 0: 414 | tile = self.blend_h(row[j - 1], tile, blend_extent) 415 | result_row.append(tile[:, :, :row_limit, :row_limit]) 416 | result_rows.append(torch.cat(result_row, dim=3)) 417 | 418 | dec = torch.cat(result_rows, dim=2) 419 | if not return_dict: 420 | return (dec,) 421 | 422 | return DecoderOutput(sample=dec) 423 | 424 | def forward( 425 | self, 426 | sample: torch.FloatTensor, 427 | sample_posterior: bool = False, 428 | return_dict: bool = True, 429 | generator: Optional[torch.Generator] = None, 430 | ) -> Union[DecoderOutput, torch.FloatTensor]: 431 | r""" 432 | Args: 433 | sample (`torch.FloatTensor`): Input sample. 434 | sample_posterior (`bool`, *optional*, defaults to `False`): 435 | Whether to sample from the posterior. 
436 | return_dict (`bool`, *optional*, defaults to `True`): 437 | Whether or not to return a [`DecoderOutput`] instead of a plain tuple. 438 | """ 439 | x = sample 440 | posterior = self.encode(x).latent_dist 441 | if sample_posterior: 442 | z = posterior.sample(generator=generator) 443 | else: 444 | z = posterior.mode() 445 | dec = self.decode(z).sample 446 | 447 | if not return_dict: 448 | return (dec,) 449 | 450 | return DecoderOutput(sample=dec) 451 | 452 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections 453 | def fuse_qkv_projections(self): 454 | """ 455 | Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, 456 | key, value) are fused. For cross-attention modules, key and value projection matrices are fused. 457 | 458 | 459 | 460 | This API is 🧪 experimental. 461 | 462 | 463 | """ 464 | self.original_attn_processors = None 465 | 466 | for _, attn_processor in self.attn_processors.items(): 467 | if "Added" in str(attn_processor.__class__.__name__): 468 | raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") 469 | 470 | self.original_attn_processors = self.attn_processors 471 | 472 | for module in self.modules(): 473 | if isinstance(module, Attention): 474 | module.fuse_projections(fuse=True) 475 | 476 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections 477 | def unfuse_qkv_projections(self): 478 | """Disables the fused QKV projection if enabled. 479 | 480 | 481 | 482 | This API is 🧪 experimental. 483 | 484 | 485 | 486 | """ 487 | if self.original_attn_processors is not None: 488 | self.set_attn_processor(self.original_attn_processors) 489 | 490 | 491 | 492 | def merge_and_unload( 493 | self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None 494 | ) -> torch.nn.Module: 495 | 496 | return self._unload_and_optionally_merge( 497 | progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names 498 | ) 499 | 500 | def _unload_and_optionally_merge( 501 | self, 502 | merge=True, 503 | progressbar: bool = False, 504 | safe_merge: bool = False, 505 | adapter_names: Optional[list[str]] = None, 506 | ): 507 | from tqdm import tqdm 508 | from peft.tuners.tuners_utils import onload_layer 509 | from peft.utils import _get_submodules, ModulesToSaveWrapper 510 | 511 | key_list = [key for key, _ in self.named_modules() if "lora_" not in key] 512 | desc = "Unloading " + ("and merging " if merge else "") + "model" 513 | for key in tqdm(key_list, disable=not progressbar, desc=desc): 514 | try: 515 | parent, target, target_name = _get_submodules(self, key) 516 | except AttributeError: 517 | continue 518 | with onload_layer(target): 519 | if hasattr(target, "base_layer"): 520 | if merge: 521 | target.merge(safe_merge=safe_merge, adapter_names=adapter_names) 522 | self._replace_module(parent, target_name, target.get_base_layer(), target) 523 | elif isinstance(target, ModulesToSaveWrapper): 524 | # save any additional trainable modules part of `modules_to_save` 525 | new_module = target.modules_to_save[target.active_adapter] 526 | if hasattr(new_module, "base_layer"): 527 | # check if the module is itself a tuner layer 528 | if merge: 529 | new_module.merge(safe_merge=safe_merge, adapter_names=adapter_names) 530 | new_module = new_module.get_base_layer() 531 | setattr(parent, target_name, new_module) 532 | 533 | return self 534 | 535 | def _replace_module(self, parent, 
child_name, new_module, child): 536 | setattr(parent, child_name, new_module) 537 | # It's not necessary to set requires_grad here, as that is handled by 538 | # _mark_only_adapters_as_trainable 539 | 540 | # child layer wraps the original module, unpack it 541 | if hasattr(child, "base_layer"): 542 | child = child.base_layer 543 | 544 | if not hasattr(new_module, "base_layer"): 545 | new_module.weight = child.weight 546 | if hasattr(child, "bias"): 547 | new_module.bias = child.bias 548 | 549 | if getattr(child, "state", None) is not None: 550 | if hasattr(new_module, "base_layer"): 551 | new_module.base_layer.state = child.state 552 | else: 553 | new_module.state = child.state 554 | new_module.to(child.weight.device) 555 | 556 | # dispatch to correct device 557 | for name, module in new_module.named_modules(): 558 | if ("lora_" in name) or ("ranknum" in name): 559 | weight = child.qweight if hasattr(child, "qweight") else child.weight 560 | module.to(weight.device) -------------------------------------------------------------------------------- /src/my_utils/__pycache__/devices.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/src/my_utils/__pycache__/devices.cpython-310.pyc -------------------------------------------------------------------------------- /src/my_utils/__pycache__/training_lr_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/src/my_utils/__pycache__/training_lr_utils.cpython-310.pyc -------------------------------------------------------------------------------- /src/my_utils/__pycache__/training_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/src/my_utils/__pycache__/training_utils.cpython-310.pyc -------------------------------------------------------------------------------- /src/my_utils/__pycache__/training_utils_aigc.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/src/my_utils/__pycache__/training_utils_aigc.cpython-310.pyc -------------------------------------------------------------------------------- /src/my_utils/__pycache__/training_utils_project.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/src/my_utils/__pycache__/training_utils_project.cpython-310.pyc -------------------------------------------------------------------------------- /src/my_utils/__pycache__/training_utils_realsr.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/src/my_utils/__pycache__/training_utils_realsr.cpython-310.pyc -------------------------------------------------------------------------------- /src/my_utils/__pycache__/training_utils_realsr_sdxl_vsd_nostage.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/src/my_utils/__pycache__/training_utils_realsr_sdxl_vsd_nostage.cpython-310.pyc -------------------------------------------------------------------------------- /src/my_utils/__pycache__/training_utils_realsr_vsd.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/src/my_utils/__pycache__/training_utils_realsr_vsd.cpython-310.pyc -------------------------------------------------------------------------------- /src/my_utils/__pycache__/training_utils_realsr_vsd_controlnet_nostage.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/src/my_utils/__pycache__/training_utils_realsr_vsd_controlnet_nostage.cpython-310.pyc -------------------------------------------------------------------------------- /src/my_utils/__pycache__/training_utils_realsr_vsd_nostage.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/src/my_utils/__pycache__/training_utils_realsr_vsd_nostage.cpython-310.pyc -------------------------------------------------------------------------------- /src/my_utils/__pycache__/training_utils_realsr_vsd_nostage_0513.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/src/my_utils/__pycache__/training_utils_realsr_vsd_nostage_0513.cpython-310.pyc -------------------------------------------------------------------------------- /src/my_utils/__pycache__/training_utils_realsr_vsd_stage1.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/src/my_utils/__pycache__/training_utils_realsr_vsd_stage1.cpython-310.pyc -------------------------------------------------------------------------------- /src/my_utils/__pycache__/training_utils_realsr_vsd_stage2.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/src/my_utils/__pycache__/training_utils_realsr_vsd_stage2.cpython-310.pyc -------------------------------------------------------------------------------- /src/my_utils/__pycache__/training_utils_realsr_vsd_turbo_nostage.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/src/my_utils/__pycache__/training_utils_realsr_vsd_turbo_nostage.cpython-310.pyc -------------------------------------------------------------------------------- /src/my_utils/__pycache__/training_utils_res.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/src/my_utils/__pycache__/training_utils_res.cpython-310.pyc -------------------------------------------------------------------------------- /src/my_utils/__pycache__/training_utils_seesr_vsd_nostage.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/src/my_utils/__pycache__/training_utils_seesr_vsd_nostage.cpython-310.pyc -------------------------------------------------------------------------------- /src/my_utils/__pycache__/training_utils_wanghui.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/src/my_utils/__pycache__/training_utils_wanghui.cpython-310.pyc -------------------------------------------------------------------------------- /src/my_utils/__pycache__/vaehook.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/src/my_utils/__pycache__/vaehook.cpython-310.pyc -------------------------------------------------------------------------------- /src/my_utils/__pycache__/wavelet_color_fix.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csslc/PiSA-SR/96fa0e7ad42972968ae3be3683da39c5d6ecc067/src/my_utils/__pycache__/wavelet_color_fix.cpython-310.pyc -------------------------------------------------------------------------------- /src/my_utils/devices.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import contextlib 3 | from functools import lru_cache 4 | 5 | import torch 6 | #from modules import errors 7 | 8 | if sys.platform == "darwin": 9 | from modules import mac_specific 10 | 11 | 12 | def has_mps() -> bool: 13 | if sys.platform != "darwin": 14 | return False 15 | else: 16 | return mac_specific.has_mps 17 | 18 | 19 | def get_cuda_device_string(): 20 | return "cuda" 21 | 22 | 23 | def get_optimal_device_name(): 24 | if torch.cuda.is_available(): 25 | return get_cuda_device_string() 26 | 27 | if has_mps(): 28 | return "mps" 29 | 30 | return "cpu" 31 | 32 | 33 | def get_optimal_device(): 34 | return torch.device(get_optimal_device_name()) 35 | 36 | 37 | def get_device_for(task): 38 | return get_optimal_device() 39 | 40 | 41 | def torch_gc(): 42 | 43 | if torch.cuda.is_available(): 44 | with torch.cuda.device(get_cuda_device_string()): 45 | torch.cuda.empty_cache() 46 | torch.cuda.ipc_collect() 47 | 48 | if has_mps(): 49 | mac_specific.torch_mps_gc() 50 | 51 | 52 | def enable_tf32(): 53 | if torch.cuda.is_available(): 54 | 55 | # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't 56 | # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407 57 | if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())): 58 | torch.backends.cudnn.benchmark = True 59 | 60 | torch.backends.cuda.matmul.allow_tf32 = True 61 | torch.backends.cudnn.allow_tf32 = True 62 | 63 | 64 | enable_tf32() 65 | #errors.run(enable_tf32, "Enabling TF32") 66 | 67 | cpu = torch.device("cpu") 68 | device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = torch.device("cuda") 69 | dtype = torch.float16 70 | dtype_vae = torch.float16 71 | dtype_unet = torch.float16 72 | unet_needs_upcast = False 73 | 74 | 75 | def cond_cast_unet(input): 76 | return input.to(dtype_unet) if unet_needs_upcast else input 77 | 78 | 79 | def cond_cast_float(input): 80 | return input.float() if 
unet_needs_upcast else input 81 | 82 | 83 | def randn(seed, shape): 84 | torch.manual_seed(seed) 85 | return torch.randn(shape, device=device) 86 | 87 | 88 | def randn_without_seed(shape): 89 | return torch.randn(shape, device=device) 90 | 91 | 92 | def autocast(disable=False): 93 | if disable: 94 | return contextlib.nullcontext() 95 | 96 | return torch.autocast("cuda") 97 | 98 | 99 | def without_autocast(disable=False): 100 | return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext() 101 | 102 | 103 | class NansException(Exception): 104 | pass 105 | 106 | 107 | def test_for_nans(x, where): 108 | if not torch.all(torch.isnan(x)).item(): 109 | return 110 | 111 | if where == "unet": 112 | message = "A tensor with all NaNs was produced in Unet." 113 | 114 | elif where == "vae": 115 | message = "A tensor with all NaNs was produced in VAE." 116 | 117 | else: 118 | message = "A tensor with all NaNs was produced." 119 | 120 | message += " Use --disable-nan-check commandline argument to disable this check." 121 | 122 | raise NansException(message) 123 | 124 | 125 | @lru_cache 126 | def first_time_calculation(): 127 | """ 128 | just do any calculation with pytorch layers - the first time this is done it allocates about 700MB of memory and 129 | spends about 2.7 seconds doing that, at least with NVidia. 130 | """ 131 | 132 | x = torch.zeros((1, 1)).to(device, dtype) 133 | linear = torch.nn.Linear(1, 1).to(device, dtype) 134 | linear(x) 135 | 136 | x = torch.zeros((1, 1, 3, 3)).to(device, dtype) 137 | conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype) 138 | conv2d(x) 139 | -------------------------------------------------------------------------------- /src/my_utils/training_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | from torchvision import transforms 5 | import torchvision.transforms.functional as F 6 | from pathlib import Path 7 | 8 | def parse_args(input_args=None): 9 | 10 | parser = argparse.ArgumentParser() 11 | 12 | parser.add_argument("--is_module", default=False) 13 | parser.add_argument("--tracker_project_name", type=str, default="pisasr2025") 14 | 15 | # args for the loss function 16 | parser.add_argument("--lambda_lpips", default=2.0, type=float) 17 | parser.add_argument("--lambda_l2", default=1.0, type=float) 18 | parser.add_argument("--lambda_csd", default=1.0, type=float) 19 | 20 | # args for the csd training 21 | parser.add_argument("--pretrained_model_path_csd", default='preset/models/stable-diffusion-2-1-base', type=str) 22 | parser.add_argument("--min_dm_step_ratio", default=0.02, type=float) 23 | parser.add_argument("--max_dm_step_ratio", default=0.98, type=float) 24 | parser.add_argument("--neg_prompt_csd", default="painting, oil painting, illustration, drawing, art, sketch, oil painting, cartoon, CG Style, 3D render, unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, signature, jpeg artifacts, deformed, lowres, over-smooth", type=str) 25 | parser.add_argument("--pos_prompt_csd", default="", type=str) 26 | parser.add_argument("--cfg_csd", default=1.0, type=float) 27 | 28 | # args for the `t` test 29 | parser.add_argument("--timesteps1", default=1, type=float) 30 | # details about the model architecture 31 | parser.add_argument("--pretrained_model_path", default='preset/models/stable-diffusion-2-1-base') 32 | # unet lora setting 33 | parser.add_argument("--lora_rank_unet_pix", 
default=4, type=int) 34 | parser.add_argument("--lora_rank_unet_sem", default=4, type=int) 35 | 36 | # dataset options 37 | parser.add_argument("--dataset_txt_paths", default='/gt_path.txt', type=str) 38 | parser.add_argument("--highquality_dataset_txt_paths", default='/gt_selected_path.txt', type=str) 39 | parser.add_argument("--dataset_test_folder", 40 | default="/testfolder") 41 | parser.add_argument("--null_text_ratio", default=0., type=float) 42 | parser.add_argument("--prob", default=0.5, type=float) 43 | parser.add_argument("--resolution_ori", type=int, default=512,) 44 | parser.add_argument("--resolution_tgt", type=int, default=512,) 45 | 46 | # resume 47 | parser.add_argument("--resume_ckpt", default=None, type=str) 48 | 49 | # training details 50 | parser.add_argument("--output_dir", default='experiments/oup') 51 | parser.add_argument("--seed", type=int, default=123, help="A seed for reproducible training.") 52 | parser.add_argument("--train_batch_size", type=int, default=2, help="Batch size (per device) for the training dataloader.") 53 | parser.add_argument("--num_training_epochs", type=int, default=10000) 54 | parser.add_argument("--max_train_steps", type=int, default=100000,) 55 | parser.add_argument("--pix_steps", type=int, default=10,) 56 | parser.add_argument("--checkpointing_steps", type=int, default=500,) 57 | parser.add_argument("--eval_freq", type=int, default=500, ) 58 | parser.add_argument("--gradient_accumulation_steps", type=int, default=2, help="Number of updates steps to accumulate before performing a backward/update pass.",) 59 | parser.add_argument("--gradient_checkpointing", action="store_true",) 60 | parser.add_argument("--learning_rate", type=float, default=5e-5) 61 | parser.add_argument("--lr_scheduler", type=str, default="constant", 62 | help=( 63 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 64 | ' "constant", "constant_with_warmup"]' 65 | ), 66 | ) 67 | parser.add_argument("--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler.") 68 | parser.add_argument("--lr_num_cycles", type=int, default=1, 69 | help="Number of hard resets of the lr in cosine_with_restarts scheduler.", 70 | ) 71 | parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") 72 | 73 | parser.add_argument("--dataloader_num_workers", type=int, default=0,) 74 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 75 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 76 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 77 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 78 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 79 | parser.add_argument("--allow_tf32", action="store_true", 80 | help=( 81 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" 82 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 83 | ), 84 | ) 85 | parser.add_argument("--report_to", type=str, default="tensorboard", 86 | help=( 87 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 88 | ' (default), `"wandb"` and `"comet_ml"`. 
Use `"all"` to report to all integrations.' 89 | ), 90 | ) 91 | parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["no", "fp16", "bf16"],) 92 | parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.") 93 | parser.add_argument("--set_grads_to_none", action="store_true",) 94 | 95 | parser.add_argument("--logging_dir", type=str, default="logs") 96 | parser.add_argument("--deg_file_path", default="params.yml", type=str) 97 | parser.add_argument("--align_method", type=str, choices=['wavelet', 'adain', 'nofix'], default='adain') 98 | 99 | 100 | if input_args is not None: 101 | args = parser.parse_args(input_args) 102 | else: 103 | args = parser.parse_args() 104 | 105 | return args -------------------------------------------------------------------------------- /src/my_utils/wavelet_color_fix.py: -------------------------------------------------------------------------------- 1 | ''' 2 | # -------------------------------------------------------------------------------- 3 | # Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py) 4 | # -------------------------------------------------------------------------------- 5 | ''' 6 | 7 | import torch 8 | from PIL import Image 9 | from torch import Tensor 10 | from torch.nn import functional as F 11 | 12 | from torchvision.transforms import ToTensor, ToPILImage 13 | 14 | def adain_color_fix(target: Image, source: Image): 15 | # Convert images to tensors 16 | to_tensor = ToTensor() 17 | target_tensor = to_tensor(target).unsqueeze(0) 18 | source_tensor = to_tensor(source).unsqueeze(0) 19 | 20 | # Apply adaptive instance normalization 21 | result_tensor = adaptive_instance_normalization(target_tensor, source_tensor) 22 | 23 | # Convert tensor back to image 24 | to_image = ToPILImage() 25 | result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0)) 26 | 27 | return result_image 28 | 29 | def wavelet_color_fix(target: Image, source: Image): 30 | # Convert images to tensors 31 | to_tensor = ToTensor() 32 | target_tensor = to_tensor(target).unsqueeze(0) 33 | source_tensor = to_tensor(source).unsqueeze(0) 34 | 35 | # Apply wavelet reconstruction 36 | result_tensor = wavelet_reconstruction(target_tensor, source_tensor) 37 | 38 | # Convert tensor back to image 39 | to_image = ToPILImage() 40 | result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0)) 41 | 42 | return result_image 43 | 44 | def calc_mean_std(feat: Tensor, eps=1e-5): 45 | """Calculate mean and std for adaptive_instance_normalization. 46 | Args: 47 | feat (Tensor): 4D tensor. 48 | eps (float): A small value added to the variance to avoid 49 | divide-by-zero. Default: 1e-5. 50 | """ 51 | size = feat.size() 52 | assert len(size) == 4, 'The input feature should be a 4D tensor.' 53 | b, c = size[:2] 54 | feat_var = feat.reshape(b, c, -1).var(dim=2) + eps 55 | feat_std = feat_var.sqrt().reshape(b, c, 1, 1) 56 | feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1) 57 | return feat_mean, feat_std 58 | 59 | def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor): 60 | """Adaptive instance normalization. 61 | Adjust the reference features to have similar color and illumination 62 | as those in the degraded features. 63 | Args: 64 | content_feat (Tensor): The reference feature. 65 | style_feat (Tensor): The degraded features. 
66 | """ 67 | size = content_feat.size() 68 | style_mean, style_std = calc_mean_std(style_feat) 69 | content_mean, content_std = calc_mean_std(content_feat) 70 | normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) 71 | return normalized_feat * style_std.expand(size) + style_mean.expand(size) 72 | 73 | def wavelet_blur(image: Tensor, radius: int): 74 | """ 75 | Apply wavelet blur to the input tensor. 76 | """ 77 | # input shape: (1, 3, H, W) 78 | # convolution kernel 79 | kernel_vals = [ 80 | [0.0625, 0.125, 0.0625], 81 | [0.125, 0.25, 0.125], 82 | [0.0625, 0.125, 0.0625], 83 | ] 84 | kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device) 85 | # add channel dimensions to the kernel to make it a 4D tensor 86 | kernel = kernel[None, None] 87 | # repeat the kernel across all input channels 88 | kernel = kernel.repeat(3, 1, 1, 1) 89 | image = F.pad(image, (radius, radius, radius, radius), mode='replicate') 90 | # apply convolution 91 | output = F.conv2d(image, kernel, groups=3, dilation=radius) 92 | return output 93 | 94 | def wavelet_decomposition(image: Tensor, levels=5): 95 | """ 96 | Apply wavelet decomposition to the input tensor. 97 | This function only returns the low frequency & the high frequency. 98 | """ 99 | high_freq = torch.zeros_like(image) 100 | for i in range(levels): 101 | radius = 2 ** i 102 | low_freq = wavelet_blur(image, radius) 103 | high_freq += (image - low_freq) 104 | image = low_freq 105 | 106 | return high_freq, low_freq 107 | 108 | def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor): 109 | """ 110 | Apply wavelet decomposition, so that the content will have the same color as the style. 111 | """ 112 | # calculate the wavelet decomposition of the content feature 113 | content_high_freq, content_low_freq = wavelet_decomposition(content_feat) 114 | del content_low_freq 115 | # calculate the wavelet decomposition of the style feature 116 | style_high_freq, style_low_freq = wavelet_decomposition(style_feat) 117 | del style_high_freq 118 | # reconstruct the content feature with the style's high frequency 119 | return content_high_freq + style_low_freq 120 | -------------------------------------------------------------------------------- /test_pisasr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | from PIL import Image 5 | import torch 6 | from torchvision import transforms 7 | import torchvision.transforms.functional as F 8 | 9 | from pisasr import PiSASR_eval 10 | from src.my_utils.wavelet_color_fix import adain_color_fix, wavelet_color_fix 11 | 12 | import glob 13 | 14 | 15 | def pisa_sr(args): 16 | # Initialize the model 17 | model = PiSASR_eval(args) 18 | model.set_eval() 19 | 20 | # Get all input images 21 | if os.path.isdir(args.input_image): 22 | image_names = sorted(glob.glob(f'{args.input_image}/*.png')) 23 | else: 24 | image_names = [args.input_image] 25 | 26 | # Make the output directory 27 | os.makedirs(args.output_dir, exist_ok=True) 28 | print(f'There are {len(image_names)} images.') 29 | 30 | time_records = [] 31 | for image_name in image_names: 32 | # Ensure the input image is a multiple of 8 33 | input_image = Image.open(image_name).convert('RGB') 34 | ori_width, ori_height = input_image.size 35 | rscale = args.upscale 36 | resize_flag = False 37 | 38 | if ori_width < args.process_size // rscale or ori_height < args.process_size // rscale: 39 | scale = (args.process_size // rscale) / 
min(ori_width, ori_height) 40 | input_image = input_image.resize((int(scale * ori_width), int(scale * ori_height))) 41 | resize_flag = True 42 | 43 | input_image = input_image.resize((input_image.size[0] * rscale, input_image.size[1] * rscale)) 44 | new_width = input_image.width - input_image.width % 8 45 | new_height = input_image.height - input_image.height % 8 46 | input_image = input_image.resize((new_width, new_height), Image.LANCZOS) 47 | bname = os.path.basename(image_name) 48 | 49 | # Get caption (you can add the text prompt here) 50 | validation_prompt = '' 51 | 52 | # Run the super-resolution model on the image 53 | with torch.no_grad(): 54 | c_t = F.to_tensor(input_image).unsqueeze(0).cuda() * 2 - 1 55 | inference_time, output_image = model(args.default, c_t, prompt=validation_prompt) 56 | 57 | print(f"Inference time: {inference_time:.4f} seconds") 58 | time_records.append(inference_time) 59 | 60 | output_image = output_image * 0.5 + 0.5 61 | output_image = torch.clip(output_image, 0, 1) 62 | output_pil = transforms.ToPILImage()(output_image[0].cpu()) 63 | 64 | if args.align_method == 'adain': 65 | output_pil = adain_color_fix(target=output_pil, source=input_image) 66 | elif args.align_method == 'wavelet': 67 | output_pil = wavelet_color_fix(target=output_pil, source=input_image) 68 | 69 | if resize_flag: 70 | output_pil = output_pil.resize((int(args.upscale * ori_width), int(args.upscale * ori_height))) 71 | output_pil.save(os.path.join(args.output_dir, bname)) 72 | 73 | # Calculate the average inference time, excluding the first few runs for warm-up 74 | if len(time_records) > 3: 75 | average_time = np.mean(time_records[3:]) 76 | else: 77 | average_time = np.mean(time_records) 78 | print(f"Average inference time: {average_time:.4f} seconds") 79 | 80 | 81 | if __name__ == "__main__": 82 | parser = argparse.ArgumentParser() 83 | parser.add_argument('--input_image', '-i', type=str, default='preset/test_datasets', help="path to the input image") 84 | parser.add_argument('--output_dir', '-o', type=str, default='experiments/test', help="the directory to save the output") 85 | parser.add_argument("--pretrained_model_path", type=str, default='preset/models/stable-diffusion-2-1-base') 86 | parser.add_argument('--pretrained_path', type=str, default='preset/models/pisa_sr.pkl', help="path to a model state dict to be used") 87 | parser.add_argument('--seed', type=int, default=42, help="Random seed to be used") 88 | parser.add_argument("--process_size", type=int, default=512) 89 | parser.add_argument("--upscale", type=int, default=4) 90 | parser.add_argument("--align_method", type=str, choices=['wavelet', 'adain', 'nofix'], default="adain") 91 | parser.add_argument("--lambda_pix", default=1.0, type=float, help="the scale for pixel-level enhancement") 92 | parser.add_argument("--lambda_sem", default=1.0, type=float, help="the scale for semantic-level enhancement") 93 | parser.add_argument("--vae_decoder_tiled_size", type=int, default=224) 94 | parser.add_argument("--vae_encoder_tiled_size", type=int, default=1024) 95 | parser.add_argument("--latent_tiled_size", type=int, default=96) 96 | parser.add_argument("--latent_tiled_overlap", type=int, default=32) 97 | parser.add_argument("--mixed_precision", type=str, default="fp16") 98 | parser.add_argument("--default", action="store_true", help="use the default setting or the adjustable setting") 99 | 100 | args = parser.parse_args() 101 | 102 | # Call the processing function 103 | pisa_sr(args) 104 | --------------------------------------------------------------------------------
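Note (editorial, not a repository file): the sketch below condenses the input pre-processing that test_pisasr.py performs before the forward pass. Small inputs are first enlarged so that the upscaled result reaches process_size on the short side, the image is then upsampled by the SR factor, cropped to a multiple of 8 (the spatial downsampling factor of the Stable Diffusion VAE), and mapped to [-1, 1] with a batch dimension. The helper name preprocess_lr_image is introduced here for illustration only.

from PIL import Image
import torch
import torchvision.transforms.functional as TF

def preprocess_lr_image(path: str, upscale: int = 4, process_size: int = 512) -> torch.Tensor:
    # Load the low-resolution image
    img = Image.open(path).convert('RGB')
    w, h = img.size
    # Enlarge very small inputs so the upscaled result reaches process_size
    if w < process_size // upscale or h < process_size // upscale:
        s = (process_size // upscale) / min(w, h)
        img = img.resize((int(s * w), int(s * h)))
    # Upsample by the SR factor, then crop to a multiple of 8 for the VAE
    img = img.resize((img.width * upscale, img.height * upscale))
    img = img.resize((img.width - img.width % 8, img.height - img.height % 8), Image.LANCZOS)
    # Map pixel values from [0, 1] to [-1, 1] and add a batch dimension
    return TF.to_tensor(img).unsqueeze(0) * 2 - 1

On the output side, the script maps the prediction back with output * 0.5 + 0.5, clips it to [0, 1], and optionally applies adain_color_fix or wavelet_color_fix (from src/my_utils/wavelet_color_fix.py) to match the colors of the input image.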
/train_pisasr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gc 3 | import lpips 4 | import clip 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | import torch.utils.checkpoint 9 | import transformers 10 | from accelerate import Accelerator 11 | from accelerate.utils import set_seed 12 | from PIL import Image 13 | from torchvision import transforms 14 | from tqdm.auto import tqdm 15 | 16 | import diffusers 17 | from diffusers.utils.import_utils import is_xformers_available 18 | from diffusers.optimization import get_scheduler 19 | 20 | from pisasr import CSDLoss, PiSASR 21 | from src.my_utils.training_utils import parse_args 22 | from src.datasets.dataset import PairedSROnlineTxtDataset 23 | 24 | from pathlib import Path 25 | from accelerate.utils import set_seed, ProjectConfiguration 26 | from accelerate import DistributedDataParallelKwargs 27 | 28 | from src.my_utils.wavelet_color_fix import adain_color_fix, wavelet_color_fix 29 | import random 30 | 31 | def main(args): 32 | logging_dir = Path(args.output_dir, args.logging_dir) 33 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) 34 | ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) 35 | 36 | accelerator = Accelerator( 37 | gradient_accumulation_steps=args.gradient_accumulation_steps, 38 | mixed_precision=args.mixed_precision, 39 | log_with=args.report_to, 40 | project_config=accelerator_project_config, 41 | kwargs_handlers=[ddp_kwargs], 42 | ) 43 | 44 | if accelerator.is_local_main_process: 45 | transformers.utils.logging.set_verbosity_warning() 46 | diffusers.utils.logging.set_verbosity_info() 47 | else: 48 | transformers.utils.logging.set_verbosity_error() 49 | diffusers.utils.logging.set_verbosity_error() 50 | 51 | if args.seed is not None: 52 | set_seed(args.seed) 53 | 54 | if accelerator.is_main_process: 55 | os.makedirs(os.path.join(args.output_dir, "checkpoints"), exist_ok=True) 56 | os.makedirs(os.path.join(args.output_dir, "eval"), exist_ok=True) 57 | 58 | net_pisasr = PiSASR(args) 59 | 60 | if args.enable_xformers_memory_efficient_attention: 61 | if is_xformers_available(): 62 | net_pisasr.unet.enable_xformers_memory_efficient_attention() 63 | else: 64 | raise ValueError("xformers is not available, please install it by running `pip install xformers`") 65 | 66 | if args.gradient_checkpointing: 67 | net_pisasr.unet.enable_gradient_checkpointing() 68 | 69 | if args.allow_tf32: 70 | torch.backends.cuda.matmul.allow_tf32 = True 71 | 72 | # init CSDLoss model 73 | net_csd = CSDLoss(args=args, accelerator=accelerator) 74 | net_csd.requires_grad_(False) 75 | 76 | net_lpips = lpips.LPIPS(net='vgg').cuda() 77 | net_lpips.requires_grad_(False) 78 | 79 | # # set gen adapter 80 | net_pisasr.unet.set_adapter(['default_encoder_pix', 'default_decoder_pix', 'default_others_pix']) 81 | net_pisasr.set_train_pix() # first to remove degradation 82 | 83 | # make the optimizer 84 | layers_to_opt = [] 85 | for n, _p in net_pisasr.unet.named_parameters(): 86 | if "lora" in n: 87 | layers_to_opt.append(_p) 88 | 89 | optimizer = torch.optim.AdamW(layers_to_opt, lr=args.learning_rate, 90 | betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, 91 | eps=args.adam_epsilon,) 92 | lr_scheduler = get_scheduler(args.lr_scheduler, optimizer=optimizer, 93 | num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, 94 | num_training_steps=args.max_train_steps * 
accelerator.num_processes, 95 | num_cycles=args.lr_num_cycles, power=args.lr_power,) 96 | 97 | # initialize the dataset 98 | dataset_train = PairedSROnlineTxtDataset(split="train", args=args) 99 | dataset_val = PairedSROnlineTxtDataset(split="test", args=args) 100 | dl_train = torch.utils.data.DataLoader(dataset_train, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers) 101 | dl_val = torch.utils.data.DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=0) 102 | 103 | 104 | # init RAM for text prompt extractor 105 | from ram.models.ram_lora import ram 106 | from ram import inference_ram as inference 107 | ram_transforms = transforms.Compose([ 108 | transforms.Resize((384, 384)), 109 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 110 | ]) 111 | RAM = ram(pretrained='src/ram_pretrain_model/ram_swin_large_14m.pth', 112 | pretrained_condition=None, 113 | image_size=384, 114 | vit='swin_l') 115 | RAM.eval() 116 | RAM.to("cuda", dtype=torch.float16) 117 | 118 | # Prepare everything with our `accelerator`. 119 | net_pisasr, optimizer, dl_train, lr_scheduler = accelerator.prepare( 120 | net_pisasr, optimizer, dl_train, lr_scheduler 121 | ) 122 | net_lpips = accelerator.prepare(net_lpips) 123 | 124 | weight_dtype = torch.float32 125 | if accelerator.mixed_precision == "fp16": 126 | weight_dtype = torch.float16 127 | elif accelerator.mixed_precision == "bf16": 128 | weight_dtype = torch.bfloat16 129 | 130 | # We need to initialize the trackers we use, and also store our configuration. 131 | # The trackers initializes automatically on the main process. 132 | if accelerator.is_main_process: 133 | tracker_config = dict(vars(args)) 134 | accelerator.init_trackers(args.tracker_project_name, config=tracker_config) 135 | 136 | progress_bar = tqdm(range(0, args.max_train_steps), initial=0, desc="Steps", 137 | disable=not accelerator.is_local_main_process,) 138 | 139 | # start the training loop 140 | global_step = 0 141 | lambda_l2 = args.lambda_l2 142 | lambda_lpips = 0 143 | lambda_csd = 0 144 | if args.resume_ckpt is not None: 145 | args.pix_steps = 1 146 | for epoch in range(0, args.num_training_epochs): 147 | for step, batch in enumerate(dl_train): 148 | with accelerator.accumulate(net_pisasr): 149 | x_src = batch["conditioning_pixel_values"] 150 | x_tgt = batch["output_pixel_values"] 151 | 152 | # get text prompts from GT 153 | x_tgt_ram = ram_transforms(x_tgt*0.5+0.5) 154 | caption = inference(x_tgt_ram.to(dtype=torch.float16), RAM) 155 | batch["prompt"] = [f'{each_caption}, {args.pos_prompt_csd}' for each_caption in caption] 156 | 157 | if global_step == args.pix_steps: 158 | # begin the semantic optimization 159 | if args.is_module: 160 | net_pisasr.module.unet.set_adapter(['default_encoder_pix', 'default_decoder_pix', 'default_others_pix','default_encoder_sem', 'default_decoder_sem', 'default_others_sem']) 161 | net_pisasr.module.set_train_sem() 162 | else: 163 | net_pisasr.unet.set_adapter(['default_encoder_pix', 'default_decoder_pix', 'default_others_pix','default_encoder_sem', 'default_decoder_sem', 'default_others_sem']) 164 | net_pisasr.set_train_sem() 165 | 166 | lambda_l2 = args.lambda_l2 167 | lambda_lpips = args.lambda_lpips 168 | lambda_csd = args.lambda_csd 169 | 170 | x_tgt_pred, latents_pred, prompt_embeds, neg_prompt_embeds = net_pisasr(x_src, x_tgt, batch=batch, args=args) 171 | loss_l2 = F.mse_loss(x_tgt_pred.float(), x_tgt.float(), reduction="mean") * lambda_l2 172 | loss_lpips = 
net_lpips(x_tgt_pred.float(), x_tgt.float()).mean() * lambda_lpips 173 | loss = loss_l2 + loss_lpips 174 | # reg loss 175 | loss_csd = net_csd.cal_csd(latents_pred, prompt_embeds, neg_prompt_embeds, args, ) * lambda_csd 176 | loss = loss + loss_csd 177 | accelerator.backward(loss) 178 | if accelerator.sync_gradients: 179 | accelerator.clip_grad_norm_(layers_to_opt, args.max_grad_norm) 180 | optimizer.step() 181 | lr_scheduler.step() 182 | optimizer.zero_grad(set_to_none=args.set_grads_to_none) 183 | 184 | if accelerator.sync_gradients: 185 | progress_bar.update(1) 186 | global_step += 1 187 | 188 | if accelerator.is_main_process: 189 | logs = {} 190 | # log all the losses 191 | logs["loss_csd"] = loss_csd.detach().item() 192 | logs["loss_l2"] = loss_l2.detach().item() 193 | logs["loss_lpips"] = loss_lpips.detach().item() 194 | progress_bar.set_postfix(**logs) 195 | 196 | # checkpoint the model 197 | if global_step % args.checkpointing_steps == 1: 198 | outf = os.path.join(args.output_dir, "checkpoints", f"model_{global_step}.pkl") 199 | accelerator.unwrap_model(net_pisasr).save_model(outf) 200 | 201 | # test 202 | if global_step % args.eval_freq == 1: 203 | os.makedirs(os.path.join(args.output_dir, "eval", f"fid_{global_step}"), exist_ok=True) 204 | for step, batch_val in enumerate(dl_val): 205 | x_src = batch_val["conditioning_pixel_values"].cuda() 206 | x_tgt = batch_val["output_pixel_values"].cuda() 207 | x_basename = batch_val["base_name"][0] 208 | B, C, H, W = x_src.shape 209 | assert B == 1, "Use batch size 1 for eval." 210 | with torch.no_grad(): 211 | # get text prompts from LR 212 | x_src_ram = ram_transforms(x_src * 0.5 + 0.5) 213 | caption = inference(x_src_ram.to(dtype=torch.float16), RAM) 214 | batch_val["prompt"] = caption 215 | # forward pass 216 | x_tgt_pred, latents_pred, _, _ = accelerator.unwrap_model(net_pisasr)(x_src, x_tgt, 217 | batch=batch_val, 218 | args=args) 219 | # save the output 220 | output_pil = transforms.ToPILImage()(x_tgt_pred[0].cpu() * 0.5 + 0.5) 221 | input_image = transforms.ToPILImage()(x_src[0].cpu() * 0.5 + 0.5) 222 | if args.align_method == 'adain': 223 | output_pil = adain_color_fix(target=output_pil, source=input_image) 224 | elif args.align_method == 'wavelet': 225 | output_pil = wavelet_color_fix(target=output_pil, source=input_image) 226 | else: 227 | pass 228 | outf = os.path.join(args.output_dir, "eval", f"fid_{global_step}", f"{x_basename}") 229 | output_pil.save(outf) 230 | gc.collect() 231 | torch.cuda.empty_cache() 232 | accelerator.log(logs, step=global_step) 233 | 234 | accelerator.log(logs, step=global_step) 235 | 236 | if __name__ == "__main__": 237 | args = parse_args() 238 | main(args) 239 | --------------------------------------------------------------------------------
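Note (editorial, not a repository file): train_pisasr.py optimizes in two phases. For the first --pix_steps optimization steps only the pixel-level LoRA adapters are active and the objective is the L2 loss alone (lambda_lpips and lambda_csd start at zero); once global_step reaches --pix_steps, the semantic-level adapters are enabled as well and the LPIPS and CSD terms are added with their configured weights. (When --resume_ckpt is given, the script shortens the pixel phase by forcing pix_steps to 1.) A condensed sketch of that weight schedule follows; the helper name loss_weights_for_step is hypothetical.

def loss_weights_for_step(global_step: int, pix_steps: int,
                          lambda_l2: float, lambda_lpips: float, lambda_csd: float):
    """Return (w_l2, w_lpips, w_csd) for the current optimization step."""
    if global_step < pix_steps:
        # Pixel-level phase: only the pixel LoRA adapters are optimized,
        # supervised by the L2 reconstruction loss alone.
        return lambda_l2, 0.0, 0.0
    # Semantic phase: the semantic LoRA adapters are enabled as well and the
    # LPIPS and CSD regularization terms join the objective.
    return lambda_l2, lambda_lpips, lambda_csd

# Example with --pix_steps 10: step 3 uses only L2, step 25 uses all three terms.
assert loss_weights_for_step(3, 10, 1.0, 2.0, 1.0) == (1.0, 0.0, 0.0)
assert loss_weights_for_step(25, 10, 1.0, 2.0, 1.0) == (1.0, 2.0, 1.0)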