├── README.md ├── config ├── data │ └── test.yaml ├── default.yaml ├── sampler │ └── edm.yaml └── task │ ├── inpainting_rand.yaml │ └── phase_retrieval.yaml ├── dataset └── dataset.txt ├── detail.png ├── example_image └── example.PNG ├── exp.png ├── forward_operator ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-39.pyc │ ├── fastmri_utils.cpython-310.pyc │ ├── fastmri_utils.cpython-39.pyc │ ├── resizer.cpython-310.pyc │ └── resizer.cpython-39.pyc ├── bkse │ ├── LICENSE │ ├── README.md │ ├── data │ │ ├── GOPRO_dataset.py │ │ ├── REDS_dataset.py │ │ ├── __init__.py │ │ ├── data_sampler.py │ │ ├── mix_dataset.py │ │ └── util.py │ ├── data_augmentation.py │ ├── domain_specific_deblur.py │ ├── experiments │ │ └── pretrained │ │ │ ├── a.txt │ │ │ └── kernel.pth │ ├── generate_blur.py │ ├── generic_deblur.py │ ├── kernel_encoding │ │ ├── base_model.py │ │ ├── image_base_model.py │ │ └── kernel_wizard.py │ ├── models │ │ ├── __init__.py │ │ ├── arch_util.py │ │ ├── backbones │ │ │ ├── resnet.py │ │ │ ├── skip │ │ │ │ ├── concat.py │ │ │ │ ├── downsampler.py │ │ │ │ ├── non_local_dot_product.py │ │ │ │ ├── skip.py │ │ │ │ └── util.py │ │ │ └── unet_parts.py │ │ ├── deblurring │ │ │ ├── a.txt │ │ │ ├── image_deblur.py │ │ │ └── joint_deblur.py │ │ ├── dips.py │ │ ├── dsd │ │ │ ├── bicubic.py │ │ │ ├── dsd.py │ │ │ ├── dsd_stylegan.py │ │ │ ├── dsd_stylegan2.py │ │ │ ├── op │ │ │ │ ├── __init__.py │ │ │ │ ├── fused_act.py │ │ │ │ ├── fused_bias_act.cpp │ │ │ │ ├── fused_bias_act_kernel.cu │ │ │ │ ├── upfirdn2d.cpp │ │ │ │ ├── upfirdn2d.py │ │ │ │ └── upfirdn2d_kernel.cu │ │ │ ├── spherical_optimizer.py │ │ │ ├── stylegan.py │ │ │ └── stylegan2.py │ │ ├── kernel_encoding │ │ │ ├── base_model.py │ │ │ ├── image_base_model.py │ │ │ └── kernel_wizard.py │ │ ├── losses │ │ │ ├── charbonnier_loss.py │ │ │ ├── dsd_loss.py │ │ │ ├── gan_loss.py │ │ │ ├── hyper_laplacian_penalty.py │ │ │ ├── perceptual_loss.py │ │ │ └── ssim_loss.py │ │ └── lr_scheduler.py │ ├── options │ │ ├── __init__.py │ │ ├── data_augmentation │ │ │ └── default.yml │ │ └── options.py │ ├── requirements.txt │ ├── scripts │ │ ├── a.txt │ │ ├── create_lmdb.py │ │ └── download_dataset.py │ ├── train.py │ ├── train_script.sh │ └── utils │ │ ├── __init__.py │ │ ├── a.txt │ │ └── util.py ├── fastmri_utils.py ├── resizer.py └── util.py ├── main.py ├── ptp_utils.py ├── requirements.txt ├── seq_align.py └── seq_aligner.py /README.md: -------------------------------------------------------------------------------- 1 | # PostEdit: Posterior Sampling for Efficient Zero-Shot Image Editing 2 | 3 | > ⚡️ PostEdit is both inversion- and training-free, requiring approximately 1.5 seconds and 18 GB of GPU memory to generate high-quality results. 4 | > 5 | > 💥 PostEdit is accepted as a poster at the International Conference on Learning Representations (ICLR) 2025! 6 | 7 | 8 | ![exp](exp.png) 9 | ![detail](detail.png) 10 | [Paper](https://arxiv.org/pdf/2410.04844) 11 | 12 | 13 | ## Setup 14 | 15 | This code was tested with Python 3.9 and [PyTorch](https://pytorch.org/) 2.4.0, using pre-trained models from [huggingface / diffusers](https://github.com/huggingface/diffusers#readme). 16 | Specifically, we implemented our method on top of [LCM](https://arxiv.org/pdf/2310.04378). 17 | Additional required packages are listed in the requirements file. 18 | The code was tested on a single NVIDIA A100 GPU. 
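A quick way to sanity-check the environment before running anything (a minimal illustrative sketch, not a script shipped with the repo; the versions in the comments are simply the ones listed above):

```python
# Illustrative environment check for the Setup notes above.
import torch
import diffusers

print(torch.__version__)                  # tested with 2.4.0
print(diffusers.__version__)
print(torch.cuda.is_available())          # tested on a single NVIDIA A100 GPU
if torch.cuda.is_available():
    total_gib = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"GPU memory: {total_gib:.1f} GiB")   # editing needs roughly 18 GB
```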
19 | 20 | ## Preparation 21 | 22 | ### Dataset 23 | Download [PIE-Bench](https://docs.google.com/forms/d/e/1FAIpQLSftGgDwLLMwrad9pX3Odbnd4UXGvcRuXDkRp6BT1nPk8fcH_g/viewform) dataset, and place it in your `PIE_Bench_PATH`. 24 | 25 | ### Installation 26 | 27 | Download the code: 28 | ``` 29 | git clone https://github.com/TFNTF/PostEdit.git 30 | ``` 31 | Download pre-trained models: 32 | 33 | ```python 34 | from diffusers import DiffusionPipeline 35 | import torch 36 | 37 | pipe = DiffusionPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7") 38 | 39 | ``` 40 | 41 | ## Quickstart 42 | ``` python 43 | pip install -r requirements.txt 44 | python main.py 45 | (Optional) Save a specific image to "all_images" file for single image editing. 46 | ``` 47 | 48 | 49 | ## Citation 50 | 51 | ``` bibtex 52 | @article{DBLP:journals/corr/abs-2410-04844, 53 | author = {Feng Tian and 54 | Yixuan Li and 55 | Yichao Yan and 56 | Shanyan Guan and 57 | Yanhao Ge and 58 | Xiaokang Yang}, 59 | title = {PostEdit: Posterior Sampling for Efficient Zero-Shot Image Editing}, 60 | journal = {ICLR}, 61 | year = {2025}, 62 | } 63 | ``` 64 | 65 | ## Acknowledgements 66 | 67 | We thank vivo for granting us access to GPUs. 68 | 69 | ## Contact 70 | 71 | If you have any questions, feel free to contact me through email (tf1021@sjtu.edu.cn). Enjoy! 72 | 73 | -------------------------------------------------------------------------------- /config/data/test.yaml: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /config/default.yaml: -------------------------------------------------------------------------------- 1 | gpu: 0 2 | batch_size: 100 3 | num_runs: 1 4 | save_dir: ./results 5 | name: demo 6 | wandb: False 7 | save_samples: True 8 | save_traj: False 9 | save_traj_raw_data: False 10 | show_eval: False 11 | eval_fn_list: ['psnr', 'lpips'] 12 | seed: 42 -------------------------------------------------------------------------------- /config/sampler/edm.yaml: -------------------------------------------------------------------------------- 1 | latent: False 2 | 3 | annealing_scheduler_config: 4 | num_steps: 49 5 | sigma_max: 100 6 | sigma_min: 0.1 7 | sigma_final: 0 8 | schedule: 'linear' 9 | timestep: 'poly-7' 10 | 11 | diffusion_scheduler_config: 12 | num_steps: 5 13 | sigma_min: 0.01 14 | sigma_final: 0 15 | schedule: 'linear' 16 | timestep: 'poly-7' 17 | -------------------------------------------------------------------------------- /config/task/inpainting_rand.yaml: -------------------------------------------------------------------------------- 1 | lgvd_config: 2 | num_steps: 200 3 | lr: 1e-5 4 | tau: 0.01 5 | lr_min_ratio: 0.01 6 | 7 | 8 | operator: 9 | name: inpainting 10 | mask_type: random 11 | mask_prob_range: [0.50, 0.51] # for random 12 | resolution: 64 13 | sigma: 0.05 14 | -------------------------------------------------------------------------------- /config/task/phase_retrieval.yaml: -------------------------------------------------------------------------------- 1 | lgvd_config: 2 | num_steps: 10000 3 | lr: 5e-5 4 | tau: 0.01 5 | lr_min_ratio: 0.01 6 | 7 | 8 | operator: 9 | name: phase_retrieval 10 | oversample: 2 11 | sigma: 0.05 -------------------------------------------------------------------------------- /dataset/dataset.txt: -------------------------------------------------------------------------------- 1 | Download dataset and put it here. 
(For example, Pie-Bench) 2 | -------------------------------------------------------------------------------- /detail.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TFNTF/PostEdit/3dcf4a680f8e438e9a5e8a28f7a2c6fb2fb2b475/detail.png -------------------------------------------------------------------------------- /example_image/example.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TFNTF/PostEdit/3dcf4a680f8e438e9a5e8a28f7a2c6fb2fb2b475/example_image/example.PNG -------------------------------------------------------------------------------- /exp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TFNTF/PostEdit/3dcf4a680f8e438e9a5e8a28f7a2c6fb2fb2b475/exp.png -------------------------------------------------------------------------------- /forward_operator/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TFNTF/PostEdit/3dcf4a680f8e438e9a5e8a28f7a2c6fb2fb2b475/forward_operator/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /forward_operator/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TFNTF/PostEdit/3dcf4a680f8e438e9a5e8a28f7a2c6fb2fb2b475/forward_operator/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /forward_operator/__pycache__/fastmri_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TFNTF/PostEdit/3dcf4a680f8e438e9a5e8a28f7a2c6fb2fb2b475/forward_operator/__pycache__/fastmri_utils.cpython-310.pyc -------------------------------------------------------------------------------- /forward_operator/__pycache__/fastmri_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TFNTF/PostEdit/3dcf4a680f8e438e9a5e8a28f7a2c6fb2fb2b475/forward_operator/__pycache__/fastmri_utils.cpython-39.pyc -------------------------------------------------------------------------------- /forward_operator/__pycache__/resizer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TFNTF/PostEdit/3dcf4a680f8e438e9a5e8a28f7a2c6fb2fb2b475/forward_operator/__pycache__/resizer.cpython-310.pyc -------------------------------------------------------------------------------- /forward_operator/__pycache__/resizer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TFNTF/PostEdit/3dcf4a680f8e438e9a5e8a28f7a2c6fb2fb2b475/forward_operator/__pycache__/resizer.cpython-39.pyc -------------------------------------------------------------------------------- /forward_operator/bkse/README.md: -------------------------------------------------------------------------------- 1 | # Exploring Image Deblurring via Encoded Blur Kernel Space 2 | 3 | ## About the project 4 | 5 | We introduce a method to encode the blur operators of an arbitrary dataset of sharp-blur image pairs into a blur kernel space. 
Assuming the encoded kernel space is close enough to in-the-wild blur operators, we propose an alternating optimization algorithm for blind image deblurring. It approximates an unseen blur operator by a kernel in the encoded space and searches for the corresponding sharp image. Due to the method's design, the encoded kernel space is fully differentiable, thus can be easily adopted in deep neural network models. 6 | 7 | ![Blur kernel space](imgs/teaser.jpg) 8 | 9 | Detail of the method and experimental results can be found in [our following paper](https://arxiv.org/abs/2104.00317): 10 | ``` 11 | @inproceedings{m_Tran-etal-CVPR21, 12 |   author = {Phong Tran and Anh Tran and Quynh Phung and Minh Hoai}, 13 |   title = {Explore Image Deblurring via Encoded Blur Kernel Space}, 14 |   year = {2021}, 15 |   booktitle = {Proceedings of the {IEEE} Conference on Computer Vision and Pattern Recognition (CVPR)} 16 | } 17 | ``` 18 | Please CITE our paper whenever this repository is used to help produce published results or incorporated into other software. 19 | 20 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1GDvbr4WQUibaEhQVzYPPObV4STn9NAot?usp=sharing) 21 | 22 | ## Table of Content 23 | 24 | * [About the Project](#about-the-project) 25 | * [Getting Started](#getting-started) 26 | * [Prerequisites](#prerequisites) 27 | * [Installation](#installation) 28 | * [Using the pretrained model](#Using-the-pretrained-model) 29 | * [Training and evaluation](#Training-and-evaluation) 30 | * [Model Zoo](#Model-zoo) 31 | 32 | ## Getting started 33 | 34 | ### Prerequisites 35 | 36 | * Python >= 3.7 37 | * Pytorch >= 1.4.0 38 | * CUDA >= 10.0 39 | 40 | ### Installation 41 | 42 | ``` sh 43 | git clone https://github.com/VinAIResearch/blur-kernel-space-exploring.git 44 | cd blur-kernel-space-exploring 45 | 46 | 47 | conda create -n BlurKernelSpace -y python=3.7 48 | conda activate BlurKernelSpace 49 | conda install --file requirements.txt 50 | ``` 51 | 52 | ## Training and evaluation 53 | ### Preparing datasets 54 | You can download the datasets in the [model zoo section](#model-zoo). 55 | 56 | To use your customized dataset, your dataset must be organized as follow: 57 | ``` 58 | root 59 | ├── blur_imgs 60 | ├── 000 61 | ├──── 00000000.png 62 | ├──── 00000001.png 63 | ├──── ... 64 | ├── 001 65 | ├──── 00000000.png 66 | ├──── 00000001.png 67 | ├──── ... 68 | ├── sharp_imgs 69 | ├── 000 70 | ├──── 00000000.png 71 | ├──── 00000001.png 72 | ├──── ... 73 | ├── 001 74 | ├──── 00000000.png 75 | ├──── 00000001.png 76 | ├──── ... 77 | ``` 78 | where `root`, `blur_imgs`, and `sharp_imgs` folders can have arbitrary names. For example, let `root, blur_imgs, sharp_imgs` be `REDS, train_blur, train_sharp` respectively (That is, you are using the REDS training set), then use the following scripts to create the lmdb dataset: 79 | ```sh 80 | python create_lmdb.py --H 720 --W 1280 --C 3 --img_folder REDS/train_sharp --name train_sharp_wval --save_path ../datasets/REDS/train_sharp_wval.lmdb 81 | python create_lmdb.py --H 720 --W 1280 --C 3 --img_folder REDS/train_blur --name train_blur_wval --save_path ../datasets/REDS/train_blur_wval.lmdb 82 | ``` 83 | where `(H, C, W)` is the shape of the images (note that all images in the dataset must have the same shape), `img_folder` is the folder that contains the images, `name` is the name of the dataset, and `save_path` is the save destination (`save_path` must end with `.lmdb`). 
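If you want to verify an lmdb before training, the snippet below is a minimal sketch of reading one frame back. It assumes the EDVR-style layout that the dataset classes in `data/` expect (raw uint8 HWC bytes stored under keys such as `000_00000000`); if you modified `create_lmdb.py`, adjust accordingly.

```python
# Sanity-check sketch: read one REDS frame back from the lmdb created above.
# Assumes raw uint8 bytes per key, reshaped to (H, W, C) = (720, 1280, 3).
import lmdb
import numpy as np

env = lmdb.open("../datasets/REDS/train_sharp_wval.lmdb",
                readonly=True, lock=False, readahead=False, meminit=False)
with env.begin(write=False) as txn:
    buf = txn.get("000_00000000".encode("ascii"))   # key format: <folder>_<frame index>
img = np.frombuffer(buf, dtype=np.uint8).reshape(720, 1280, 3)
print(img.shape, img.dtype)
```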
84 | 85 | When the script is finished, two folders `train_sharp_wval.lmdb` and `train_blur_wval.lmdb` will be created in `./REDS`. 86 | 87 | 88 | ### Training 89 | To do image deblurring, data augmentation, and blur generation, you first need to train the blur encoding network (The F function in the paper). This is the only network that you need to train. After creating the dataset, change the value of `dataroot_HQ` and `dataroot_LQ` in `options/kernel_encoding/REDS/woVAE.yml` to the paths of the sharp and blur lmdb datasets that were created before, then use the following script to train the model: 90 | ``` 91 | python train.py -opt options/kernel_encoding/REDS/woVAE.yml 92 | ``` 93 | 94 | where `opt` is the path to yaml file that contains training configurations. You can find some default configurations in the `options` folder. Checkpoints, training states, and logs will be saved in `experiments/modelName`. You can change the configurations (learning rate, hyper-parameters, network structure, etc) in the yaml file. 95 | 96 | ### Testing 97 | #### Data augmentation 98 | To augment a given dataset, first, create an lmdb dataset using `scripts/create_lmdb.py` as before. Then use the following script: 99 | ``` 100 | python data_augmentation.py --target_H=720 --target_W=1280 \ 101 | --source_H=720 --source_W=1280\ 102 | --augmented_H=256 --augmented_W=256\ 103 | --source_LQ_root=datasets/REDS/train_blur_wval.lmdb \ 104 | --source_HQ_root=datasets/REDS/train_sharp_wval.lmdb \ 105 | --target_HQ_root=datasets/REDS/test_sharp_wval.lmdb \ 106 | --save_path=results/GOPRO_augmented \ 107 | --num_images=10 \ 108 | --yml_path=options/data_augmentation/default.yml 109 | ``` 110 | `(target_H, target_W)`, `(source_H, source_W)`, and `(augmented_H, augmented_W)` are the desired shapes of the target images, source images, and augmented images respectively. `source_LQ_root`, `source_HQ_root`, and `target_HQ_root` are the paths of the lmdb datasets for the reference blur-sharp pairs and the input sharp images that were created before. `num_images` is the size of the augmented dataset. `model_path` is the path of the trained model. `yml_path` is the path to the model configuration file. Results will be saved in `save_path`. 111 | 112 | ![Data augmentation examples](imgs/results/augmentation.jpg) 113 | 114 | #### Generate novel blur kernels 115 | To generate a blur image given a sharp image, use the following command: 116 | ```sh 117 | python generate_blur.py --yml_path=options/generate_blur/default.yml \ 118 | --image_path=imgs/sharp_imgs/mushishi.png \ 119 | --num_samples=10 120 | --save_path=./res.png 121 | ``` 122 | where `model_path` is the path of the pre-trained model, `yml_path` is the path of the configuration file. `image_path` is the path of the sharp image. After running the script, a blur image corresponding to the sharp image will be saved in `save_path`. Here is some expected output: 123 | ![kernel generating examples](imgs/results/generate_blur.jpg) 124 | **Note**: This only works with models that were trained with `--VAE` flag. The size of input images must be divisible by 128. 125 | 126 | #### Generic Deblurring 127 | To deblur a blurry image, use the following command: 128 | ```sh 129 | python generic_deblur.py --image_path imgs/blur_imgs/blur1.png --yml_path options/generic_deblur/default.yml --save_path ./res.png 130 | ``` 131 | where `image_path` is the path of the blurry image. `yml_path` is the path of the configuration file. The deblurred image will be saved to `save_path`. 
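The same deblurring can also be driven from Python rather than the CLI. The sketch below simply mirrors what `generic_deblur.py` does internally, using the `JointDeblur` class from `models/deblurring/joint_deblur.py`:

```python
# Programmatic variant of generic_deblur.py (mirrors the shipped script).
import cv2
import yaml
from models.deblurring.joint_deblur import JointDeblur

with open("options/generic_deblur/default.yml", "rb") as f:
    opt = yaml.safe_load(f)

model = JointDeblur(opt)
blur_img = cv2.cvtColor(cv2.imread("imgs/blur_imgs/blur1.png"), cv2.COLOR_BGR2RGB)
sharp_img = model.deblur(blur_img)   # returns an image that cv2.imwrite can save directly
cv2.imwrite("res.png", sharp_img)
```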
132 | 133 | ![Image deblurring examples](imgs/results/general_deblurring.jpg) 134 | 135 | #### Deblurring using sharp image prior 136 | [mapping]: https://drive.google.com/uc?id=14R6iHGf5iuVx3DMNsACAl7eBr7Vdpd0k 137 | [synthesis]: https://drive.google.com/uc?id=1TCViX1YpQyRsklTVYEJwdbmK91vklCo8 138 | [pretrained model]: https://drive.google.com/file/d/1PQutd-JboOCOZqmd95XWxWrO8gGEvRcO/view 139 | First, you need to download the pre-trained styleGAN or styleGAN2 networks. If you want to use styleGAN, download the [mapping] and [synthesis] networks, then rename and copy them to `experiments/pretrained/stylegan_mapping.pt` and `experiments/pretrained/stylegan_synthesis.pt` respectively. If you want to use styleGAN2 instead, download the [pretrained model], then rename and copy it to `experiments/pretrained/stylegan2.pt`. 140 | 141 | To deblur a blurry image using styleGAN latent space as the sharp image prior, you can use one of the following commands: 142 | ```sh 143 | python domain_specific_deblur.py --input_dir imgs/blur_faces \ 144 | --output_dir experiments/domain_specific_deblur/results \ 145 | --yml_path options/domain_specific_deblur/stylegan.yml # Use latent space of stylegan 146 | python domain_specific_deblur.py --input_dir imgs/blur_faces \ 147 | --output_dir experiments/domain_specific_deblur/results \ 148 | --yml_path options/domain_specific_deblur/stylegan2.yml # Use latent space of stylegan2 149 | ``` 150 | Results will be saved in `experiments/domain_specific_deblur/results`. 151 | **Note**: Generally, the code still works with images that have the size divisible by 128. However, since our blur kernels are not uniform, the size of the kernel increases as the size of the image increases. 152 | 153 | ![PULSE-like Deblurring examples](imgs/results/domain_specific_deblur.jpg) 154 | 155 | ## Model Zoo 156 | Pretrained models and corresponding datasets are provided in the below table. After downloading the datasets and models, follow the instructions in the [testing section](#testing) to do data augmentation, generating blur images, or image deblurring. 
157 | 158 | [REDS]: https://seungjunnah.github.io/Datasets/reds.html 159 | [GOPRO]: https://seungjunnah.github.io/Datasets/gopro 160 | 161 | [REDS woVAE]: https://drive.google.com/file/d/1QSRbxvZPZoPy2bp-KOCbTk8l49RX-zj9/view?usp=sharing 162 | [GOPRO woVAE]: https://drive.google.com/file/d/1xUvRmusWa0PaFej1Kxu11Te33v0JvEeL/view?usp=drive_link 163 | [GOPRO wVAE]: https://drive.google.com/file/d/1vRoDpIsrTRYZKsOMPNbPcMtFDpCT6Foy/view?usp=drive_link 164 | [GOPRO + REDS woVAE]: https://drive.google.com/file/d/169R0hEs3rNeloj-m1rGS4YjW38pu-LFD/view?usp=sharing 165 | 166 | |Model name | dataset(s) | status | 167 | |:-----------------------|:---------------:|-------------------------:| 168 | |[REDS woVAE] | [REDS] | :heavy_check_mark: | 169 | |[GOPRO woVAE] | [GOPRO] | :heavy_check_mark: | 170 | |[GOPRO wVAE] | [GOPRO] | :heavy_check_mark: | 171 | |[GOPRO + REDS woVAE] | [GOPRO], [REDS] | :heavy_check_mark: | 172 | 173 | 174 | ## Notes and references 175 | The training code is borrowed from the EDVR project: https://github.com/xinntao/EDVR 176 | 177 | The backbone code is borrowed from the DeblurGAN project: https://github.com/KupynOrest/DeblurGAN 178 | 179 | The styleGAN code is borrowed from the PULSE project: https://github.com/adamian98/pulse 180 | 181 | The stylegan2 code is borrowed from https://github.com/rosinality/stylegan2-pytorch 182 | -------------------------------------------------------------------------------- /forward_operator/bkse/data/GOPRO_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | GOPRO dataset 3 | support reading images from lmdb, image folder and memcached 4 | """ 5 | import logging 6 | import os.path as osp 7 | import pickle 8 | import random 9 | 10 | import cv2 11 | import data.util as util 12 | import lmdb 13 | import numpy as np 14 | import torch 15 | import torch.utils.data as data 16 | 17 | 18 | try: 19 | import mc # import memcached 20 | except ImportError: 21 | pass 22 | 23 | logger = logging.getLogger("base") 24 | 25 | 26 | class GOPRODataset(data.Dataset): 27 | """ 28 | Reading the training GOPRO dataset 29 | key example: 000_00000000 30 | HQ: Ground-Truth; 31 | LQ: Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames 32 | support reading N LQ frames, N = 1, 3, 5, 7 33 | """ 34 | 35 | def __init__(self, opt): 36 | super(GOPRODataset, self).__init__() 37 | self.opt = opt 38 | # temporal augmentation 39 | 40 | self.HQ_root, self.LQ_root = opt["dataroot_HQ"], opt["dataroot_LQ"] 41 | self.N_frames = opt["N_frames"] 42 | self.data_type = self.opt["data_type"] 43 | # directly load image keys 44 | if self.data_type == "lmdb": 45 | self.paths_HQ, _ = util.get_image_paths(self.data_type, opt["dataroot_HQ"]) 46 | logger.info("Using lmdb meta info for cache keys.") 47 | elif opt["cache_keys"]: 48 | logger.info("Using cache keys: {}".format(opt["cache_keys"])) 49 | self.paths_HQ = pickle.load(open(opt["cache_keys"], "rb"))["keys"] 50 | else: 51 | raise ValueError( 52 | "Need to create cache keys (meta_info.pkl) \ 53 | by running [create_lmdb.py]" 54 | ) 55 | 56 | assert self.paths_HQ, "Error: HQ path is empty." 
57 | 58 | if self.data_type == "lmdb": 59 | self.HQ_env, self.LQ_env = None, None 60 | elif self.data_type == "mc": # memcached 61 | self.mclient = None 62 | elif self.data_type == "img": 63 | pass 64 | else: 65 | raise ValueError("Wrong data type: {}".format(self.data_type)) 66 | 67 | def _init_lmdb(self): 68 | # https://github.com/chainer/chainermn/issues/129 69 | self.HQ_env = lmdb.open(self.opt["dataroot_HQ"], readonly=True, lock=False, readahead=False, meminit=False) 70 | self.LQ_env = lmdb.open(self.opt["dataroot_LQ"], readonly=True, lock=False, readahead=False, meminit=False) 71 | 72 | def _ensure_memcached(self): 73 | if self.mclient is None: 74 | # specify the config files 75 | server_list_config_file = None 76 | client_config_file = None 77 | self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file, client_config_file) 78 | 79 | def _read_img_mc(self, path): 80 | """ Return BGR, HWC, [0, 255], uint8""" 81 | value = mc.pyvector() 82 | self.mclient.Get(path, value) 83 | value_buf = mc.ConvertBuffer(value) 84 | img_array = np.frombuffer(value_buf, np.uint8) 85 | img = cv2.imdecode(img_array, cv2.IMREAD_UNCHANGED) 86 | return img 87 | 88 | def _read_img_mc_BGR(self, path, name_a, name_b): 89 | """ 90 | Read BGR channels separately and then combine for 1M limits in cluster 91 | """ 92 | img_B = self._read_img_mc(osp.join(path + "_B", name_a, name_b + ".png")) 93 | img_G = self._read_img_mc(osp.join(path + "_G", name_a, name_b + ".png")) 94 | img_R = self._read_img_mc(osp.join(path + "_R", name_a, name_b + ".png")) 95 | img = cv2.merge((img_B, img_G, img_R)) 96 | return img 97 | 98 | def __getitem__(self, index): 99 | if self.data_type == "mc": 100 | self._ensure_memcached() 101 | elif self.data_type == "lmdb" and (self.HQ_env is None or self.LQ_env is None): 102 | self._init_lmdb() 103 | 104 | HQ_size = self.opt["HQ_size"] 105 | key = self.paths_HQ[index] 106 | 107 | # get the HQ image (as the center frame) 108 | img_HQ = util.read_img(self.HQ_env, key, (3, 720, 1280)) 109 | 110 | # get LQ images 111 | img_LQ = util.read_img(self.LQ_env, key, (3, 720, 1280)) 112 | 113 | if self.opt["phase"] == "train": 114 | _, H, W = 3, 720, 1280 # LQ size 115 | # randomly crop 116 | rnd_h = random.randint(0, max(0, H - HQ_size)) 117 | rnd_w = random.randint(0, max(0, W - HQ_size)) 118 | img_LQ = img_LQ[rnd_h : rnd_h + HQ_size, rnd_w : rnd_w + HQ_size, :] 119 | img_HQ = img_HQ[rnd_h : rnd_h + HQ_size, rnd_w : rnd_w + HQ_size, :] 120 | 121 | # augmentation - flip, rotate 122 | imgs = [img_HQ, img_LQ] 123 | rlt = util.augment(imgs, self.opt["use_flip"], self.opt["use_rot"]) 124 | img_HQ = rlt[0] 125 | img_LQ = rlt[1] 126 | 127 | # BGR to RGB, HWC to CHW, numpy to tensor 128 | img_LQ = img_LQ[:, :, [2, 1, 0]] 129 | img_HQ = img_HQ[:, :, [2, 1, 0]] 130 | img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float() 131 | img_HQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_HQ, (2, 0, 1)))).float() 132 | return {"LQ": img_LQ, "HQ": img_HQ, "key": key} 133 | 134 | def __len__(self): 135 | return len(self.paths_HQ) 136 | -------------------------------------------------------------------------------- /forward_operator/bkse/data/REDS_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | REDS dataset 3 | support reading images from lmdb, image folder and memcached 4 | """ 5 | import logging 6 | import os.path as osp 7 | import pickle 8 | import random 9 | 10 | import cv2 11 | import data.util as util 12 | 
import lmdb 13 | import numpy as np 14 | import torch 15 | import torch.utils.data as data 16 | 17 | 18 | try: 19 | import mc # import memcached 20 | except ImportError: 21 | pass 22 | 23 | logger = logging.getLogger("base") 24 | 25 | 26 | class REDSDataset(data.Dataset): 27 | """ 28 | Reading the training REDS dataset 29 | key example: 000_00000000 30 | HQ: Ground-Truth; 31 | LQ: Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames 32 | support reading N LQ frames, N = 1, 3, 5, 7 33 | """ 34 | 35 | def __init__(self, opt): 36 | super(REDSDataset, self).__init__() 37 | self.opt = opt 38 | # temporal augmentation 39 | 40 | self.HQ_root, self.LQ_root = opt["dataroot_HQ"], opt["dataroot_LQ"] 41 | self.N_frames = opt["N_frames"] 42 | self.data_type = self.opt["data_type"] 43 | # directly load image keys 44 | if self.data_type == "lmdb": 45 | self.paths_HQ, _ = util.get_image_paths(self.data_type, opt["dataroot_HQ"]) 46 | logger.info("Using lmdb meta info for cache keys.") 47 | elif opt["cache_keys"]: 48 | logger.info("Using cache keys: {}".format(opt["cache_keys"])) 49 | self.paths_HQ = pickle.load(open(opt["cache_keys"], "rb"))["keys"] 50 | else: 51 | raise ValueError( 52 | "Need to create cache keys (meta_info.pkl) \ 53 | by running [create_lmdb.py]" 54 | ) 55 | 56 | # remove the REDS4 for testing 57 | self.paths_HQ = [v for v in self.paths_HQ if v.split("_")[0] not in ["000", "011", "015", "020"]] 58 | assert self.paths_HQ, "Error: HQ path is empty." 59 | 60 | if self.data_type == "lmdb": 61 | self.HQ_env, self.LQ_env = None, None 62 | elif self.data_type == "mc": # memcached 63 | self.mclient = None 64 | elif self.data_type == "img": 65 | pass 66 | else: 67 | raise ValueError("Wrong data type: {}".format(self.data_type)) 68 | 69 | def _init_lmdb(self): 70 | # https://github.com/chainer/chainermn/issues/129 71 | self.HQ_env = lmdb.open(self.opt["dataroot_HQ"], readonly=True, lock=False, readahead=False, meminit=False) 72 | self.LQ_env = lmdb.open(self.opt["dataroot_LQ"], readonly=True, lock=False, readahead=False, meminit=False) 73 | 74 | def _ensure_memcached(self): 75 | if self.mclient is None: 76 | # specify the config files 77 | server_list_config_file = None 78 | client_config_file = None 79 | self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file, client_config_file) 80 | 81 | def _read_img_mc(self, path): 82 | """ Return BGR, HWC, [0, 255], uint8""" 83 | value = mc.pyvector() 84 | self.mclient.Get(path, value) 85 | value_buf = mc.ConvertBuffer(value) 86 | img_array = np.frombuffer(value_buf, np.uint8) 87 | img = cv2.imdecode(img_array, cv2.IMREAD_UNCHANGED) 88 | return img 89 | 90 | def _read_img_mc_BGR(self, path, name_a, name_b): 91 | """ 92 | Read BGR channels separately and then combine for 1M limits in cluster 93 | """ 94 | img_B = self._read_img_mc(osp.join(path + "_B", name_a, name_b + ".png")) 95 | img_G = self._read_img_mc(osp.join(path + "_G", name_a, name_b + ".png")) 96 | img_R = self._read_img_mc(osp.join(path + "_R", name_a, name_b + ".png")) 97 | img = cv2.merge((img_B, img_G, img_R)) 98 | return img 99 | 100 | def __getitem__(self, index): 101 | if self.data_type == "mc": 102 | self._ensure_memcached() 103 | elif self.data_type == "lmdb" and (self.HQ_env is None or self.LQ_env is None): 104 | self._init_lmdb() 105 | 106 | HQ_size = self.opt["HQ_size"] 107 | key = self.paths_HQ[index] 108 | name_a, name_b = key.split("_") 109 | 110 | # get the HQ image 111 | img_HQ = util.read_img(self.HQ_env, key, (3, 720, 1280)) 112 | 113 | # get the LQ 
image 114 | img_LQ = util.read_img(self.LQ_env, key, (3, 720, 1280)) 115 | 116 | if self.opt["phase"] == "train": 117 | _, H, W = 3, 720, 1280 # LQ size 118 | # randomly crop 119 | rnd_h = random.randint(0, max(0, H - HQ_size)) 120 | rnd_w = random.randint(0, max(0, W - HQ_size)) 121 | img_LQ = img_LQ[rnd_h : rnd_h + HQ_size, rnd_w : rnd_w + HQ_size, :] 122 | img_HQ = img_HQ[rnd_h : rnd_h + HQ_size, rnd_w : rnd_w + HQ_size, :] 123 | 124 | # augmentation - flip, rotate 125 | imgs = [img_HQ, img_LQ] 126 | rlt = util.augment(imgs, self.opt["use_flip"], self.opt["use_rot"]) 127 | img_HQ = rlt[0] 128 | img_LQ = rlt[1] 129 | 130 | # BGR to RGB, HWC to CHW, numpy to tensor 131 | img_LQ = img_LQ[:, :, [2, 1, 0]] 132 | img_HQ = img_HQ[:, :, [2, 1, 0]] 133 | img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float() 134 | img_HQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_HQ, (2, 0, 1)))).float() 135 | 136 | return {"LQ": img_LQ, "HQ": img_HQ} 137 | 138 | def __len__(self): 139 | return len(self.paths_HQ) 140 | -------------------------------------------------------------------------------- /forward_operator/bkse/data/__init__.py: -------------------------------------------------------------------------------- 1 | """create dataset and dataloader""" 2 | import logging 3 | 4 | import torch 5 | import torch.utils.data 6 | 7 | 8 | def create_dataloader(dataset, dataset_opt, opt=None, sampler=None): 9 | phase = dataset_opt["phase"] 10 | if phase == "train": 11 | if opt["dist"]: 12 | world_size = torch.distributed.get_world_size() 13 | num_workers = dataset_opt["n_workers"] 14 | assert dataset_opt["batch_size"] % world_size == 0 15 | batch_size = dataset_opt["batch_size"] // world_size 16 | shuffle = False 17 | else: 18 | num_workers = dataset_opt["n_workers"] * len(opt["gpu_ids"]) 19 | batch_size = dataset_opt["batch_size"] 20 | shuffle = True 21 | return torch.utils.data.DataLoader( 22 | dataset, 23 | batch_size=batch_size, 24 | shuffle=shuffle, 25 | num_workers=num_workers, 26 | sampler=sampler, 27 | drop_last=True, 28 | pin_memory=False, 29 | ) 30 | else: 31 | return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1, pin_memory=False) 32 | 33 | 34 | def create_dataset(dataset_opt): 35 | mode = dataset_opt["mode"] 36 | # datasets for image restoration 37 | if mode == "REDS": 38 | from data.REDS_dataset import REDSDataset as D 39 | elif mode == "GOPRO": 40 | from data.GOPRO_dataset import GOPRODataset as D 41 | elif mode == "fewshot": 42 | from data.fewshot_dataset import FewShotDataset as D 43 | elif mode == "levin": 44 | from data.levin_dataset import LevinDataset as D 45 | elif mode == "mix": 46 | from data.mix_dataset import MixDataset as D 47 | else: 48 | raise NotImplementedError(f"Dataset {mode} is not recognized.") 49 | dataset = D(dataset_opt) 50 | 51 | logger = logging.getLogger("base") 52 | logger.info("Dataset [{:s} - {:s}] is created.".format(dataset.__class__.__name__, dataset_opt["name"])) 53 | return dataset 54 | -------------------------------------------------------------------------------- /forward_operator/bkse/data/data_sampler.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from torch.utils.data.distributed.DistributedSampler 3 | Support enlarging the dataset for *iteration-oriented* training, 4 | for saving time when restart the dataloader after each epoch 5 | """ 6 | import math 7 | 8 | import torch 9 | import torch.distributed as dist 10 | from 
torch.utils.data.sampler import Sampler 11 | 12 | 13 | class DistIterSampler(Sampler): 14 | """Sampler that restricts data loading to a subset of the dataset. 15 | 16 | It is especially useful in conjunction with 17 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 18 | process can pass a DistributedSampler instance as a DataLoader sampler, 19 | and load a subset of the original dataset that is exclusive to it. 20 | 21 | .. note:: 22 | Dataset is assumed to be of constant size. 23 | 24 | Arguments: 25 | dataset: Dataset used for sampling. 26 | num_replicas (optional): Number of processes participating in 27 | distributed training. 28 | rank (optional): Rank of the current process within num_replicas. 29 | """ 30 | 31 | def __init__(self, dataset, num_replicas=None, rank=None, ratio=100): 32 | if num_replicas is None: 33 | if not dist.is_available(): 34 | raise RuntimeError( 35 | "Requires distributed \ 36 | package to be available" 37 | ) 38 | num_replicas = dist.get_world_size() 39 | if rank is None: 40 | if not dist.is_available(): 41 | raise RuntimeError( 42 | "Requires distributed \ 43 | package to be available" 44 | ) 45 | rank = dist.get_rank() 46 | self.dataset = dataset 47 | self.num_replicas = num_replicas 48 | self.rank = rank 49 | self.epoch = 0 50 | self.num_samples = int(math.ceil(len(self.dataset) * ratio / self.num_replicas)) 51 | self.total_size = self.num_samples * self.num_replicas 52 | 53 | def __iter__(self): 54 | # deterministically shuffle based on epoch 55 | g = torch.Generator() 56 | g.manual_seed(self.epoch) 57 | indices = torch.randperm(self.total_size, generator=g).tolist() 58 | 59 | dsize = len(self.dataset) 60 | indices = [v % dsize for v in indices] 61 | 62 | # subsample 63 | indices = indices[self.rank : self.total_size : self.num_replicas] 64 | assert len(indices) == self.num_samples 65 | 66 | return iter(indices) 67 | 68 | def __len__(self): 69 | return self.num_samples 70 | 71 | def set_epoch(self, epoch): 72 | self.epoch = epoch 73 | -------------------------------------------------------------------------------- /forward_operator/bkse/data/mix_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Mix dataset 3 | support reading images from lmdb 4 | """ 5 | import logging 6 | import random 7 | 8 | import data.util as util 9 | import lmdb 10 | import numpy as np 11 | import torch 12 | import torch.utils.data as data 13 | 14 | 15 | logger = logging.getLogger("base") 16 | 17 | 18 | class MixDataset(data.Dataset): 19 | """ 20 | Reading the training REDS dataset 21 | key example: 000_00000000 22 | HQ: Ground-Truth; 23 | LQ: Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames 24 | support reading N LQ frames, N = 1, 3, 5, 7 25 | """ 26 | 27 | def __init__(self, opt): 28 | super(MixDataset, self).__init__() 29 | self.opt = opt 30 | # temporal augmentation 31 | 32 | self.HQ_roots = opt["dataroots_HQ"] 33 | self.LQ_roots = opt["dataroots_LQ"] 34 | self.use_identical = opt["identical_loss"] 35 | dataset_weights = opt["dataset_weights"] 36 | self.data_type = "lmdb" 37 | # directly load image keys 38 | self.HQ_envs, self.LQ_envs = None, None 39 | self.paths_HQ = [] 40 | for idx, (HQ_root, LQ_root) in enumerate(zip(self.HQ_roots, self.LQ_roots)): 41 | paths_HQ, _ = util.get_image_paths(self.data_type, HQ_root) 42 | self.paths_HQ += list(zip([idx] * len(paths_HQ), paths_HQ)) * dataset_weights[idx] 43 | random.shuffle(self.paths_HQ) 44 | logger.info("Using lmdb meta info for cache 
keys.") 45 | 46 | def _init_lmdb(self): 47 | self.HQ_envs, self.LQ_envs = [], [] 48 | for HQ_root, LQ_root in zip(self.HQ_roots, self.LQ_roots): 49 | self.HQ_envs.append(lmdb.open(HQ_root, readonly=True, lock=False, readahead=False, meminit=False)) 50 | self.LQ_envs.append(lmdb.open(LQ_root, readonly=True, lock=False, readahead=False, meminit=False)) 51 | 52 | def __getitem__(self, index): 53 | if self.HQ_envs is None: 54 | self._init_lmdb() 55 | 56 | HQ_size = self.opt["HQ_size"] 57 | env_idx, key = self.paths_HQ[index] 58 | name_a, name_b = key.split("_") 59 | target_frame_idx = int(name_b) 60 | 61 | # determine the neighbor frames 62 | # ensure not exceeding the borders 63 | neighbor_list = [target_frame_idx] 64 | name_b = "{:08d}".format(neighbor_list[0]) 65 | 66 | # get the HQ image (as the center frame) 67 | img_HQ_l = [] 68 | for v in neighbor_list: 69 | img_HQ = util.read_img(self.HQ_envs[env_idx], "{}_{:08d}".format(name_a, v), (3, 720, 1280)) 70 | img_HQ_l.append(img_HQ) 71 | 72 | # get LQ images 73 | img_LQ = util.read_img(self.LQ_envs[env_idx], "{}_{:08d}".format(name_a, neighbor_list[-1]), (3, 720, 1280)) 74 | if self.opt["phase"] == "train": 75 | _, H, W = 3, 720, 1280 # LQ size 76 | # randomly crop 77 | rnd_h = random.randint(0, max(0, H - HQ_size)) 78 | rnd_w = random.randint(0, max(0, W - HQ_size)) 79 | img_LQ = img_LQ[rnd_h : rnd_h + HQ_size, rnd_w : rnd_w + HQ_size, :] 80 | img_HQ_l = [v[rnd_h : rnd_h + HQ_size, rnd_w : rnd_w + HQ_size, :] for v in img_HQ_l] 81 | 82 | # augmentation - flip, rotate 83 | img_HQ_l.append(img_LQ) 84 | rlt = util.augment(img_HQ_l, self.opt["use_flip"], self.opt["use_rot"]) 85 | img_HQ_l = rlt[0:-1] 86 | img_LQ = rlt[-1] 87 | 88 | # stack LQ images to NHWC, N is the frame number 89 | img_HQs = np.stack(img_HQ_l, axis=0) 90 | # BGR to RGB, HWC to CHW, numpy to tensor 91 | img_LQ = img_LQ[:, :, [2, 1, 0]] 92 | img_HQs = img_HQs[:, :, :, [2, 1, 0]] 93 | img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float() 94 | img_HQs = torch.from_numpy(np.ascontiguousarray(np.transpose(img_HQs, (0, 3, 1, 2)))).float() 95 | # print(img_LQ.shape, img_HQs.shape) 96 | 97 | if self.use_identical and np.random.randint(0, 10) == 0: 98 | img_LQ = img_HQs[-1, :, :, :] 99 | return {"LQ": img_LQ, "HQs": img_HQs, "identical_w": 10} 100 | 101 | return {"LQ": img_LQ, "HQs": img_HQs, "identical_w": 0} 102 | 103 | def __len__(self): 104 | return len(self.paths_HQ) 105 | -------------------------------------------------------------------------------- /forward_operator/bkse/data_augmentation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import os.path as osp 5 | import random 6 | 7 | import cv2 8 | import data.util as data_util 9 | import lmdb 10 | import numpy as np 11 | import torch 12 | import utils.util as util 13 | import yaml 14 | from models.kernel_encoding.kernel_wizard import KernelWizard 15 | 16 | 17 | def read_image(env, key, x, y, h, w): 18 | img = data_util.read_img(env, key, (3, 720, 1280)) 19 | img = np.transpose(img[x : x + h, y : y + w, [2, 1, 0]], (2, 0, 1)) 20 | return img 21 | 22 | 23 | def main(): 24 | device = torch.device("cuda") 25 | 26 | parser = argparse.ArgumentParser(description="Kernel extractor testing") 27 | 28 | parser.add_argument("--source_H", action="store", help="source image height", type=int, required=True) 29 | parser.add_argument("--source_W", action="store", help="source image width", type=int, required=True) 
30 | parser.add_argument("--target_H", action="store", help="target image height", type=int, required=True) 31 | parser.add_argument("--target_W", action="store", help="target image width", type=int, required=True) 32 | parser.add_argument( 33 | "--augmented_H", action="store", help="desired height of the augmented images", type=int, required=True 34 | ) 35 | parser.add_argument( 36 | "--augmented_W", action="store", help="desired width of the augmented images", type=int, required=True 37 | ) 38 | 39 | parser.add_argument( 40 | "--source_LQ_root", action="store", help="source low-quality dataroot", type=str, required=True 41 | ) 42 | parser.add_argument( 43 | "--source_HQ_root", action="store", help="source high-quality dataroot", type=str, required=True 44 | ) 45 | parser.add_argument( 46 | "--target_HQ_root", action="store", help="target high-quality dataroot", type=str, required=True 47 | ) 48 | parser.add_argument("--save_path", action="store", help="save path", type=str, required=True) 49 | parser.add_argument("--yml_path", action="store", help="yml path", type=str, required=True) 50 | parser.add_argument( 51 | "--num_images", action="store", help="number of desire augmented images", type=int, required=True 52 | ) 53 | 54 | args = parser.parse_args() 55 | 56 | source_LQ_root = args.source_LQ_root 57 | source_HQ_root = args.source_HQ_root 58 | target_HQ_root = args.target_HQ_root 59 | 60 | save_path = args.save_path 61 | source_H, source_W = args.source_H, args.source_W 62 | target_H, target_W = args.target_H, args.target_W 63 | augmented_H, augmented_W = args.augmented_H, args.augmented_W 64 | yml_path = args.yml_path 65 | num_images = args.num_images 66 | 67 | # Initializing logger 68 | logger = logging.getLogger("base") 69 | os.makedirs(save_path, exist_ok=True) 70 | util.setup_logger("base", save_path, "test", level=logging.INFO, screen=True, tofile=True) 71 | logger.info("source LQ root: {}".format(source_LQ_root)) 72 | logger.info("source HQ root: {}".format(source_HQ_root)) 73 | logger.info("target HQ root: {}".format(target_HQ_root)) 74 | logger.info("augmented height: {}".format(augmented_H)) 75 | logger.info("augmented width: {}".format(augmented_W)) 76 | logger.info("Number of augmented images: {}".format(num_images)) 77 | 78 | # Initializing mode 79 | logger.info("Loading model...") 80 | with open(yml_path, "r") as f: 81 | print(yml_path) 82 | opt = yaml.load(f)["KernelWizard"] 83 | model_path = opt["pretrained"] 84 | model = KernelWizard(opt) 85 | model.eval() 86 | model.load_state_dict(torch.load(model_path)) 87 | model = model.to(device) 88 | logger.info("Done") 89 | 90 | # processing data 91 | source_HQ_env = lmdb.open(source_HQ_root, readonly=True, lock=False, readahead=False, meminit=False) 92 | source_LQ_env = lmdb.open(source_LQ_root, readonly=True, lock=False, readahead=False, meminit=False) 93 | target_HQ_env = lmdb.open(target_HQ_root, readonly=True, lock=False, readahead=False, meminit=False) 94 | paths_source_HQ, _ = data_util.get_image_paths("lmdb", source_HQ_root) 95 | paths_target_HQ, _ = data_util.get_image_paths("lmdb", target_HQ_root) 96 | 97 | psnr_avg = 0 98 | 99 | for i in range(num_images): 100 | source_key = np.random.choice(paths_source_HQ) 101 | target_key = np.random.choice(paths_target_HQ) 102 | 103 | source_rnd_h = random.randint(0, max(0, source_H - augmented_H)) 104 | source_rnd_w = random.randint(0, max(0, source_W - augmented_W)) 105 | target_rnd_h = random.randint(0, max(0, target_H - augmented_H)) 106 | target_rnd_w = random.randint(0, 
max(0, target_W - augmented_W)) 107 | 108 | source_LQ = read_image(source_LQ_env, source_key, source_rnd_h, source_rnd_w, augmented_H, augmented_W) 109 | source_HQ = read_image(source_HQ_env, source_key, source_rnd_h, source_rnd_w, augmented_H, augmented_W) 110 | target_HQ = read_image(target_HQ_env, target_key, target_rnd_h, target_rnd_w, augmented_H, augmented_W) 111 | 112 | source_LQ = torch.Tensor(source_LQ).unsqueeze(0).to(device) 113 | source_HQ = torch.Tensor(source_HQ).unsqueeze(0).to(device) 114 | target_HQ = torch.Tensor(target_HQ).unsqueeze(0).to(device) 115 | 116 | with torch.no_grad(): 117 | kernel_mean, kernel_sigma = model(source_HQ, source_LQ) 118 | kernel = kernel_mean + kernel_sigma * torch.randn_like(kernel_mean) 119 | fake_source_LQ = model.adaptKernel(source_HQ, kernel) 120 | target_LQ = model.adaptKernel(target_HQ, kernel) 121 | 122 | LQ_img = util.tensor2img(source_LQ) 123 | fake_LQ_img = util.tensor2img(fake_source_LQ) 124 | target_LQ_img = util.tensor2img(target_LQ) 125 | target_HQ_img = util.tensor2img(target_HQ) 126 | 127 | target_HQ_dst = osp.join(save_path, "sharp/{:03d}/{:08d}.png".format(i // 100, i % 100)) 128 | target_LQ_dst = osp.join(save_path, "blur/{:03d}/{:08d}.png".format(i // 100, i % 100)) 129 | 130 | os.makedirs(osp.dirname(target_HQ_dst), exist_ok=True) 131 | os.makedirs(osp.dirname(target_LQ_dst), exist_ok=True) 132 | 133 | cv2.imwrite(target_HQ_dst, target_HQ_img) 134 | cv2.imwrite(target_LQ_dst, target_LQ_img) 135 | # torch.save(kernel, osp.join(osp.dirname(target_LQ_dst), f'kernel{i:03d}.pth')) 136 | 137 | psnr = util.calculate_psnr(LQ_img, fake_LQ_img) 138 | 139 | logger.info("Reconstruction PSNR of image #{:03d}/{:03d}: {:.2f}db".format(i, num_images, psnr)) 140 | psnr_avg += psnr 141 | 142 | logger.info("Average reconstruction PSNR: {:.2f}db".format(psnr_avg / num_images)) 143 | 144 | 145 | main() 146 | -------------------------------------------------------------------------------- /forward_operator/bkse/domain_specific_deblur.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from math import ceil, log10 3 | from pathlib import Path 4 | 5 | import torchvision 6 | import yaml 7 | from PIL import Image 8 | from torch.nn import DataParallel 9 | from torch.utils.data import DataLoader, Dataset 10 | 11 | 12 | class Images(Dataset): 13 | def __init__(self, root_dir, duplicates): 14 | self.root_path = Path(root_dir) 15 | self.image_list = list(self.root_path.glob("*.png")) 16 | self.duplicates = ( 17 | duplicates # Number of times to duplicate the image in the dataset to produce multiple HR images 18 | ) 19 | 20 | def __len__(self): 21 | return self.duplicates * len(self.image_list) 22 | 23 | def __getitem__(self, idx): 24 | img_path = self.image_list[idx // self.duplicates] 25 | image = torchvision.transforms.ToTensor()(Image.open(img_path)) 26 | if self.duplicates == 1: 27 | return image, img_path.stem 28 | else: 29 | return image, img_path.stem + f"_{(idx % self.duplicates)+1}" 30 | 31 | 32 | parser = argparse.ArgumentParser(description="PULSE") 33 | 34 | # I/O arguments 35 | parser.add_argument("--input_dir", type=str, default="imgs/blur_faces", help="input data directory") 36 | parser.add_argument( 37 | "--output_dir", type=str, default="experiments/domain_specific_deblur/results", help="output data directory" 38 | ) 39 | parser.add_argument( 40 | "--cache_dir", 41 | type=str, 42 | default="experiments/domain_specific_deblur/cache", 43 | help="cache directory for model weights", 44 | 
) 45 | parser.add_argument( 46 | "--yml_path", type=str, default="options/domain_specific_deblur/stylegan2.yml", help="configuration file" 47 | ) 48 | 49 | kwargs = vars(parser.parse_args()) 50 | 51 | with open(kwargs["yml_path"], "rb") as f: 52 | opt = yaml.safe_load(f) 53 | 54 | dataset = Images(kwargs["input_dir"], duplicates=opt["duplicates"]) 55 | out_path = Path(kwargs["output_dir"]) 56 | out_path.mkdir(parents=True, exist_ok=True) 57 | 58 | dataloader = DataLoader(dataset, batch_size=opt["batch_size"]) 59 | 60 | if opt["stylegan_ver"] == 1: 61 | from models.dsd.dsd_stylegan import DSDStyleGAN 62 | 63 | model = DSDStyleGAN(opt=opt, cache_dir=kwargs["cache_dir"]) 64 | else: 65 | from models.dsd.dsd_stylegan2 import DSDStyleGAN2 66 | 67 | model = DSDStyleGAN2(opt=opt, cache_dir=kwargs["cache_dir"]) 68 | 69 | model = DataParallel(model) 70 | 71 | toPIL = torchvision.transforms.ToPILImage() 72 | 73 | for ref_im, ref_im_name in dataloader: 74 | if opt["save_intermediate"]: 75 | padding = ceil(log10(100)) 76 | for i in range(opt["batch_size"]): 77 | int_path_HR = Path(out_path / ref_im_name[i] / "HR") 78 | int_path_LR = Path(out_path / ref_im_name[i] / "LR") 79 | int_path_HR.mkdir(parents=True, exist_ok=True) 80 | int_path_LR.mkdir(parents=True, exist_ok=True) 81 | for j, (HR, LR) in enumerate(model(ref_im)): 82 | for i in range(opt["batch_size"]): 83 | toPIL(HR[i].cpu().detach().clamp(0, 1)).save(int_path_HR / f"{ref_im_name[i]}_{j:0{padding}}.png") 84 | toPIL(LR[i].cpu().detach().clamp(0, 1)).save(int_path_LR / f"{ref_im_name[i]}_{j:0{padding}}.png") 85 | else: 86 | # out_im = model(ref_im,**kwargs) 87 | for j, (HR, LR) in enumerate(model(ref_im)): 88 | for i in range(opt["batch_size"]): 89 | toPIL(HR[i].cpu().detach().clamp(0, 1)).save(out_path / f"{ref_im_name[i]}.png") 90 | -------------------------------------------------------------------------------- /forward_operator/bkse/experiments/pretrained/a.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /forward_operator/bkse/experiments/pretrained/kernel.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TFNTF/PostEdit/3dcf4a680f8e438e9a5e8a28f7a2c6fb2fb2b475/forward_operator/bkse/experiments/pretrained/kernel.pth -------------------------------------------------------------------------------- /forward_operator/bkse/generate_blur.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import cv2 4 | import numpy as np 5 | import os.path as osp 6 | import torch 7 | import utils.util as util 8 | import yaml 9 | from models.kernel_encoding.kernel_wizard import KernelWizard 10 | 11 | 12 | def main(): 13 | device = torch.device("cuda") 14 | 15 | parser = argparse.ArgumentParser(description="Kernel extractor testing") 16 | 17 | parser.add_argument("--image_path", action="store", help="image path", type=str, required=True) 18 | parser.add_argument("--yml_path", action="store", help="yml path", type=str, required=True) 19 | parser.add_argument("--save_path", action="store", help="save path", type=str, default=".") 20 | parser.add_argument("--num_samples", action="store", help="number of samples", type=int, default=1) 21 | 22 | args = parser.parse_args() 23 | 24 | image_path = args.image_path 25 | yml_path = args.yml_path 26 | num_samples = args.num_samples 27 | 28 | # Initializing mode 29 | with 
open(yml_path, "r") as f: 30 | opt = yaml.load(f)["KernelWizard"] 31 | model_path = opt["pretrained"] 32 | model = KernelWizard(opt) 33 | model.eval() 34 | model.load_state_dict(torch.load(model_path)) 35 | model = model.to(device) 36 | 37 | HQ = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB) / 255.0 38 | HQ = np.transpose(HQ, (2, 0, 1)) 39 | HQ_tensor = torch.Tensor(HQ).unsqueeze(0).to(device).cuda() 40 | 41 | for i in range(num_samples): 42 | print(f"Sample #{i}/{num_samples}") 43 | with torch.no_grad(): 44 | kernel = torch.randn((1, 512, 2, 2)).cuda() * 1.2 45 | LQ_tensor = model.adaptKernel(HQ_tensor, kernel) 46 | 47 | dst = osp.join(args.save_path, f"blur{i:03d}.png") 48 | LQ_img = util.tensor2img(LQ_tensor) 49 | 50 | cv2.imwrite(dst, LQ_img) 51 | 52 | 53 | main() 54 | -------------------------------------------------------------------------------- /forward_operator/bkse/generic_deblur.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import cv2 4 | import yaml 5 | from models.deblurring.joint_deblur import JointDeblur 6 | 7 | 8 | def main(): 9 | parser = argparse.ArgumentParser(description="Kernel extractor testing") 10 | 11 | parser.add_argument("--image_path", action="store", help="image path", type=str, required=True) 12 | parser.add_argument("--save_path", action="store", help="save path", type=str, default="res.png") 13 | parser.add_argument("--yml_path", action="store", help="yml path", type=str, required=True) 14 | 15 | args = parser.parse_args() 16 | 17 | # Initializing mode 18 | with open(args.yml_path, "rb") as f: 19 | opt = yaml.safe_load(f) 20 | model = JointDeblur(opt) 21 | 22 | blur_img = cv2.cvtColor(cv2.imread(args.image_path), cv2.COLOR_BGR2RGB) 23 | sharp_img = model.deblur(blur_img) 24 | 25 | cv2.imwrite(args.save_path, sharp_img) 26 | 27 | 28 | main() 29 | -------------------------------------------------------------------------------- /forward_operator/bkse/kernel_encoding/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn.parallel import DistributedDataParallel 7 | 8 | 9 | class BaseModel: 10 | def __init__(self, opt): 11 | self.opt = opt 12 | self.device = torch.device("cuda" if opt["gpu_ids"] is not None else "cpu") 13 | self.is_train = opt["is_train"] 14 | self.schedulers = [] 15 | self.optimizers = [] 16 | 17 | def feed_data(self, data): 18 | pass 19 | 20 | def optimize_parameters(self): 21 | pass 22 | 23 | def get_current_visuals(self): 24 | pass 25 | 26 | def get_current_losses(self): 27 | pass 28 | 29 | def print_network(self): 30 | pass 31 | 32 | def save(self, label): 33 | pass 34 | 35 | def load(self): 36 | pass 37 | 38 | def _set_lr(self, lr_groups_l): 39 | """Set learning rate for warmup 40 | lr_groups_l: list for lr_groups. 
each for a optimizer""" 41 | for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): 42 | for param_group, lr in zip(optimizer.param_groups, lr_groups): 43 | param_group["lr"] = lr 44 | 45 | def _get_init_lr(self): 46 | """Get the initial lr, which is set by the scheduler""" 47 | init_lr_groups_l = [] 48 | for optimizer in self.optimizers: 49 | init_lr_groups_l.append([v["initial_lr"] for v in optimizer.param_groups]) 50 | return init_lr_groups_l 51 | 52 | def update_learning_rate(self, cur_iter, warmup_iter=-1): 53 | for scheduler in self.schedulers: 54 | scheduler.step() 55 | # set up warm-up learning rate 56 | if cur_iter < warmup_iter: 57 | # get initial lr for each group 58 | init_lr_g_l = self._get_init_lr() 59 | # modify warming-up learning rates 60 | warm_up_lr_l = [] 61 | for init_lr_g in init_lr_g_l: 62 | warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g]) 63 | # set learning rate 64 | self._set_lr(warm_up_lr_l) 65 | 66 | def get_current_learning_rate(self): 67 | return [param_group["lr"] for param_group in self.optimizers[0].param_groups] 68 | 69 | def get_network_description(self, network): 70 | """Get the string and total parameters of the network""" 71 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 72 | network = network.module 73 | return str(network), sum(map(lambda x: x.numel(), network.parameters())) 74 | 75 | def save_network(self, network, network_label, iter_label): 76 | save_filename = "{}_{}.pth".format(iter_label, network_label) 77 | save_path = os.path.join(self.opt["path"]["models"], save_filename) 78 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 79 | network = network.module 80 | state_dict = network.state_dict() 81 | for key, param in state_dict.items(): 82 | state_dict[key] = param.cpu() 83 | torch.save(state_dict, save_path) 84 | 85 | def load_network(self, load_path, network, strict=True, prefix=""): 86 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 87 | network = network.module 88 | load_net = torch.load(load_path) 89 | load_net_clean = OrderedDict() # remove unnecessary 'module.' 
90 | for k, v in load_net.items(): 91 | if k.startswith("module."): 92 | load_net_clean[k[7:]] = v 93 | else: 94 | load_net_clean[k] = v 95 | load_net.update(load_net_clean) 96 | 97 | model_dict = network.state_dict() 98 | for k, v in load_net.items(): 99 | k = prefix + k 100 | if (k in model_dict) and (v.shape == model_dict[k].shape): 101 | model_dict[k] = v 102 | else: 103 | print("Load failed:", k) 104 | 105 | network.load_state_dict(model_dict, strict=True) 106 | 107 | def save_training_state(self, epoch, iter_step): 108 | """ 109 | Save training state during training, 110 | which will be used for resuming 111 | """ 112 | 113 | state = {"epoch": epoch, "iter": iter_step, "schedulers": [], "optimizers": []} 114 | for s in self.schedulers: 115 | state["schedulers"].append(s.state_dict()) 116 | for o in self.optimizers: 117 | state["optimizers"].append(o.state_dict()) 118 | save_filename = "{}.state".format(iter_step) 119 | save_path = os.path.join(self.opt["path"]["training_state"], save_filename) 120 | torch.save(state, save_path) 121 | 122 | def resume_training(self, resume_state): 123 | """Resume the optimizers and schedulers for training""" 124 | resume_optimizers = resume_state["optimizers"] 125 | resume_schedulers = resume_state["schedulers"] 126 | assert len(resume_optimizers) == len(self.optimizers), "Wrong lengths of optimizers" 127 | assert len(resume_schedulers) == len(self.schedulers), "Wrong lengths of schedulers" 128 | for i, o in enumerate(resume_optimizers): 129 | self.optimizers[i].load_state_dict(o) 130 | for i, s in enumerate(resume_schedulers): 131 | self.schedulers[i].load_state_dict(s) 132 | -------------------------------------------------------------------------------- /forward_operator/bkse/kernel_encoding/image_base_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import OrderedDict 3 | 4 | import models.lr_scheduler as lr_scheduler 5 | import torch 6 | import torch.nn as nn 7 | from models.kernel_encoding.base_model import BaseModel 8 | from models.kernel_encoding.kernel_wizard import KernelWizard 9 | from models.losses.charbonnier_loss import CharbonnierLoss 10 | from torch.nn.parallel import DataParallel, DistributedDataParallel 11 | 12 | 13 | logger = logging.getLogger("base") 14 | 15 | 16 | class ImageBaseModel(BaseModel): 17 | def __init__(self, opt): 18 | super(ImageBaseModel, self).__init__(opt) 19 | 20 | if opt["dist"]: 21 | self.rank = torch.distributed.get_rank() 22 | else: 23 | self.rank = -1 # non dist training 24 | train_opt = opt["train"] 25 | 26 | # define network and load pretrained models 27 | self.netG = KernelWizard(opt["KernelWizard"]).to(self.device) 28 | self.use_vae = opt["KernelWizard"]["use_vae"] 29 | if opt["dist"]: 30 | self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()]) 31 | else: 32 | self.netG = DataParallel(self.netG) 33 | # print network 34 | self.print_network() 35 | self.load() 36 | 37 | if self.is_train: 38 | self.netG.train() 39 | 40 | # loss 41 | loss_type = train_opt["pixel_criterion"] 42 | if loss_type == "l1": 43 | self.cri_pix = nn.L1Loss(reduction="sum").to(self.device) 44 | elif loss_type == "l2": 45 | self.cri_pix = nn.MSELoss(reduction="sum").to(self.device) 46 | elif loss_type == "cb": 47 | self.cri_pix = CharbonnierLoss().to(self.device) 48 | else: 49 | raise NotImplementedError( 50 | "Loss type [{:s}] is not\ 51 | recognized.".format( 52 | loss_type 53 | ) 54 | ) 55 | self.l_pix_w = 
train_opt["pixel_weight"] 56 | self.l_kl_w = train_opt["kl_weight"] 57 | 58 | # optimizers 59 | wd_G = train_opt["weight_decay_G"] if train_opt["weight_decay_G"] else 0 60 | params = [] 61 | for k, v in self.netG.named_parameters(): 62 | if v.requires_grad: 63 | params.append(v) 64 | else: 65 | if self.rank <= 0: 66 | logger.warning( 67 | "Params [{:s}] will not\ 68 | optimize.".format( 69 | k 70 | ) 71 | ) 72 | optim_params = [ 73 | {"params": params, "lr": train_opt["lr_G"]}, 74 | ] 75 | 76 | self.optimizer_G = torch.optim.Adam( 77 | optim_params, lr=train_opt["lr_G"], weight_decay=wd_G, betas=(train_opt["beta1"], train_opt["beta2"]) 78 | ) 79 | self.optimizers.append(self.optimizer_G) 80 | 81 | # schedulers 82 | if train_opt["lr_scheme"] == "MultiStepLR": 83 | for optimizer in self.optimizers: 84 | self.schedulers.append( 85 | lr_scheduler.MultiStepLR_Restart( 86 | optimizer, 87 | train_opt["lr_steps"], 88 | restarts=train_opt["restarts"], 89 | weights=train_opt["restart_weights"], 90 | gamma=train_opt["lr_gamma"], 91 | clear_state=train_opt["clear_state"], 92 | ) 93 | ) 94 | elif train_opt["lr_scheme"] == "CosineAnnealingLR_Restart": 95 | for optimizer in self.optimizers: 96 | self.schedulers.append( 97 | lr_scheduler.CosineAnnealingLR_Restart( 98 | optimizer, 99 | train_opt["T_period"], 100 | eta_min=train_opt["eta_min"], 101 | restarts=train_opt["restarts"], 102 | weights=train_opt["restart_weights"], 103 | ) 104 | ) 105 | else: 106 | raise NotImplementedError() 107 | 108 | self.log_dict = OrderedDict() 109 | 110 | def feed_data(self, data, need_GT=True): 111 | self.LQ = data["LQ"].to(self.device) 112 | self.HQ = data["HQ"].to(self.device) 113 | 114 | def set_params_lr_zero(self, groups): 115 | # fix normal module 116 | for group in groups: 117 | self.optimizers[0].param_groups[group]["lr"] = 0 118 | 119 | def optimize_parameters(self, step): 120 | batchsz, _, _, _ = self.LQ.shape 121 | 122 | self.optimizer_G.zero_grad() 123 | kernel_mean, kernel_sigma = self.netG(self.HQ, self.LQ) 124 | 125 | kernel = kernel_mean + kernel_sigma * torch.randn_like(kernel_mean) 126 | self.fake_LQ = self.netG.module.adaptKernel(self.HQ, kernel) 127 | 128 | l_pix = self.l_pix_w * self.cri_pix(self.fake_LQ, self.LQ) 129 | l_total = l_pix 130 | 131 | if self.use_vae: 132 | KL_divergence = ( 133 | self.l_kl_w 134 | * torch.sum( 135 | torch.pow(kernel_mean, 2) 136 | + torch.pow(kernel_sigma, 2) 137 | - torch.log(1e-8 + torch.pow(kernel_sigma, 2)) 138 | - 1 139 | ).sum() 140 | ) 141 | l_total += KL_divergence 142 | self.log_dict["l_KL"] = KL_divergence.item() / batchsz 143 | 144 | l_total.backward() 145 | self.optimizer_G.step() 146 | 147 | # set log 148 | self.log_dict["l_pix"] = l_pix.item() / batchsz 149 | self.log_dict["l_total"] = l_total.item() / batchsz 150 | 151 | def test(self): 152 | self.netG.eval() 153 | with torch.no_grad(): 154 | self.fake_H = self.netG(self.var_L) 155 | self.netG.train() 156 | 157 | def get_current_log(self): 158 | return self.log_dict 159 | 160 | def get_current_visuals(self, need_GT=True): 161 | out_dict = OrderedDict() 162 | out_dict["LQ"] = self.LQ.detach()[0].float().cpu() 163 | out_dict["rlt"] = self.fake_LQ.detach()[0].float().cpu() 164 | return out_dict 165 | 166 | def print_network(self): 167 | s, n = self.get_network_description(self.netG) 168 | if isinstance(self.netG, nn.DataParallel): 169 | net_struc_str = "{} - {}".format(self.netG.__class__.__name__, self.netG.module.__class__.__name__) 170 | else: 171 | net_struc_str = "{}".format(self.netG.__class__.__name__) 
172 | if self.rank <= 0: 173 | logger.info( 174 | "Network G structure: {}, \ 175 | with parameters: {:,d}".format( 176 | net_struc_str, n 177 | ) 178 | ) 179 | logger.info(s) 180 | 181 | def load(self): 182 | if self.opt["path"]["pretrain_model_G"]: 183 | load_path_G = self.opt["path"]["pretrain_model_G"] 184 | if load_path_G is not None: 185 | logger.info( 186 | "Loading model for G [{:s}]\ 187 | ...".format( 188 | load_path_G 189 | ) 190 | ) 191 | self.load_network(load_path_G, self.netG, self.opt["path"]["strict_load"]) 192 | 193 | def save(self, iter_label): 194 | self.save_network(self.netG, "G", iter_label) 195 | -------------------------------------------------------------------------------- /forward_operator/bkse/kernel_encoding/kernel_wizard.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | from .. import arch_util 4 | import torch 5 | import torch.nn as nn 6 | from ..backbones.resnet import ResidualBlock_noBN, ResnetBlock 7 | from ..backbones.unet_parts import UnetSkipConnectionBlock 8 | 9 | 10 | # The function F in the paper 11 | class KernelExtractor(nn.Module): 12 | def __init__(self, opt): 13 | super(KernelExtractor, self).__init__() 14 | 15 | nf = opt["nf"] 16 | self.kernel_dim = opt["kernel_dim"] 17 | self.use_sharp = opt["KernelExtractor"]["use_sharp"] 18 | self.use_vae = opt["use_vae"] 19 | 20 | # Blur estimator 21 | norm_layer = arch_util.get_norm_layer(opt["KernelExtractor"]["norm"]) 22 | n_blocks = opt["KernelExtractor"]["n_blocks"] 23 | padding_type = opt["KernelExtractor"]["padding_type"] 24 | use_dropout = opt["KernelExtractor"]["use_dropout"] 25 | if type(norm_layer) == functools.partial: 26 | use_bias = norm_layer.func == nn.InstanceNorm2d 27 | else: 28 | use_bias = norm_layer == nn.InstanceNorm2d 29 | 30 | input_nc = nf * 2 if self.use_sharp else nf 31 | output_nc = self.kernel_dim * 2 if self.use_vae else self.kernel_dim 32 | 33 | model = [ 34 | nn.ReflectionPad2d(3), 35 | nn.Conv2d(input_nc, nf, kernel_size=7, padding=0, bias=use_bias), 36 | norm_layer(nf), 37 | nn.ReLU(True), 38 | ] 39 | 40 | n_downsampling = 5 41 | for i in range(n_downsampling): # add downsampling layers 42 | mult = 2 ** i 43 | inc = min(nf * mult, output_nc) 44 | ouc = min(nf * mult * 2, output_nc) 45 | model += [ 46 | nn.Conv2d(inc, ouc, kernel_size=3, stride=2, padding=1, bias=use_bias), 47 | norm_layer(nf * mult * 2), 48 | nn.ReLU(True), 49 | ] 50 | 51 | for i in range(n_blocks): # add ResNet blocks 52 | model += [ 53 | ResnetBlock( 54 | output_nc, 55 | padding_type=padding_type, 56 | norm_layer=norm_layer, 57 | use_dropout=use_dropout, 58 | use_bias=use_bias, 59 | ) 60 | ] 61 | 62 | self.model = nn.Sequential(*model) 63 | 64 | def forward(self, sharp, blur): 65 | output = self.model(torch.cat((sharp, blur), dim=1)) 66 | if self.use_vae: 67 | return output[:, : self.kernel_dim, :, :], output[:, self.kernel_dim :, :, :] 68 | 69 | return output, torch.zeros_like(output).cuda() 70 | 71 | 72 | # The function G in the paper 73 | class KernelAdapter(nn.Module): 74 | def __init__(self, opt): 75 | super(KernelAdapter, self).__init__() 76 | input_nc = opt["nf"] 77 | output_nc = opt["nf"] 78 | ngf = opt["nf"] 79 | norm_layer = arch_util.get_norm_layer(opt["Adapter"]["norm"]) 80 | 81 | # construct unet structure 82 | unet_block = UnetSkipConnectionBlock( 83 | ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True 84 | ) 85 | # gradually reduce the number of filters from ngf * 8 to ngf 86 | unet_block = 
UnetSkipConnectionBlock( 87 | ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer 88 | ) 89 | unet_block = UnetSkipConnectionBlock( 90 | ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer 91 | ) 92 | unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 93 | self.model = UnetSkipConnectionBlock( 94 | output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer 95 | ) 96 | 97 | def forward(self, x, k): 98 | """Standard forward""" 99 | return self.model(x, k) 100 | 101 | 102 | class KernelWizard(nn.Module): 103 | def __init__(self, opt): 104 | super(KernelWizard, self).__init__() 105 | lrelu = nn.LeakyReLU(negative_slope=0.1) 106 | front_RBs = opt["front_RBs"] 107 | back_RBs = opt["back_RBs"] 108 | num_image_channels = opt["input_nc"] 109 | nf = opt["nf"] 110 | 111 | # Features extraction 112 | resBlock_noBN_f = functools.partial(ResidualBlock_noBN, nf=nf) 113 | feature_extractor = [] 114 | 115 | feature_extractor.append(nn.Conv2d(num_image_channels, nf, 3, 1, 1, bias=True)) 116 | feature_extractor.append(lrelu) 117 | feature_extractor.append(nn.Conv2d(nf, nf, 3, 2, 1, bias=True)) 118 | feature_extractor.append(lrelu) 119 | feature_extractor.append(nn.Conv2d(nf, nf, 3, 2, 1, bias=True)) 120 | feature_extractor.append(lrelu) 121 | 122 | for i in range(front_RBs): 123 | feature_extractor.append(resBlock_noBN_f()) 124 | 125 | self.feature_extractor = nn.Sequential(*feature_extractor) 126 | 127 | # Kernel extractor 128 | self.kernel_extractor = KernelExtractor(opt) 129 | 130 | # kernel adapter 131 | self.adapter = KernelAdapter(opt) 132 | 133 | # Reconstruction 134 | recon_trunk = [] 135 | for i in range(back_RBs): 136 | recon_trunk.append(resBlock_noBN_f()) 137 | 138 | # upsampling 139 | recon_trunk.append(nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)) 140 | recon_trunk.append(nn.PixelShuffle(2)) 141 | recon_trunk.append(lrelu) 142 | recon_trunk.append(nn.Conv2d(nf, 64 * 4, 3, 1, 1, bias=True)) 143 | recon_trunk.append(nn.PixelShuffle(2)) 144 | recon_trunk.append(lrelu) 145 | recon_trunk.append(nn.Conv2d(64, 64, 3, 1, 1, bias=True)) 146 | recon_trunk.append(lrelu) 147 | recon_trunk.append(nn.Conv2d(64, num_image_channels, 3, 1, 1, bias=True)) 148 | 149 | self.recon_trunk = nn.Sequential(*recon_trunk) 150 | 151 | def adaptKernel(self, x_sharp, kernel): 152 | B, C, H, W = x_sharp.shape 153 | base = x_sharp 154 | 155 | x_sharp = self.feature_extractor(x_sharp) 156 | 157 | out = self.adapter(x_sharp, kernel) 158 | out = self.recon_trunk(out) 159 | out += base 160 | 161 | return out 162 | 163 | def forward(self, x_sharp, x_blur): 164 | x_sharp = self.feature_extractor(x_sharp) 165 | x_blur = self.feature_extractor(x_blur) 166 | 167 | output = self.kernel_extractor(x_sharp, x_blur) 168 | return output 169 | -------------------------------------------------------------------------------- /forward_operator/bkse/models/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | logger = logging.getLogger("base") 5 | 6 | 7 | def create_model(opt): 8 | model = opt["model"] 9 | if model == "image_base": 10 | from models.kernel_encoding.image_base_model import ImageBaseModel as M 11 | else: 12 | raise NotImplementedError("Model [{:s}] not recognized.".format(model)) 13 | m = M(opt) 14 | logger.info("Model [{:s}] is created.".format(m.__class__.__name__)) 15 | return m 16 | 
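For orientation: `create_model` above dispatches purely on `opt["model"]`, and only `"image_base"` (mapped to `ImageBaseModel`) is registered; every other value raises `NotImplementedError`. A minimal sketch of that dispatch is shown below — the option groups that `ImageBaseModel.__init__` would additionally need are indicated only in comments, and it assumes the `bkse` package root is on `PYTHONPATH` so that `models` is importable.

```python
# Minimal sketch of the factory dispatch (assumes the bkse root is on PYTHONPATH).
from models import create_model

# A full options dict (normally parsed from YAML by options/options.py) would also carry
# the groups read by ImageBaseModel.__init__ above:
#   opt["dist"], opt["is_train"]
#   opt["KernelWizard"]  -> nf, kernel_dim, use_vae, front_RBs, back_RBs, ...
#   opt["train"]         -> pixel_criterion, lr_G, lr_scheme, restarts, ...
#   opt["path"]          -> models, pretrain_model_G, strict_load
try:
    create_model({"model": "unknown"})
except NotImplementedError as err:
    print(err)  # -> Model [unknown] not recognized.
```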
-------------------------------------------------------------------------------- /forward_operator/bkse/models/arch_util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch.nn as nn 4 | import torch.nn.init as init 5 | 6 | 7 | class Identity(nn.Module): 8 | def forward(self, x): 9 | return x 10 | 11 | 12 | def get_norm_layer(norm_type="instance"): 13 | """Return a normalization layer 14 | Parameters: 15 | norm_type (str) -- the name of the normalization 16 | layer: batch | instance | none 17 | 18 | For BatchNorm, we use learnable affine parameters and 19 | track running statistics (mean/stddev). 20 | 21 | For InstanceNorm, we do not use learnable affine 22 | parameters. We do not track running statistics. 23 | """ 24 | if norm_type == "batch": 25 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) 26 | elif norm_type == "instance": 27 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) 28 | elif norm_type == "none": 29 | 30 | def norm_layer(x): 31 | return Identity() 32 | 33 | else: 34 | raise NotImplementedError( 35 | f"normalization layer {norm_type}\ 36 | is not found" 37 | ) 38 | return norm_layer 39 | 40 | 41 | def initialize_weights(net_l, scale=1): 42 | if not isinstance(net_l, list): 43 | net_l = [net_l] 44 | for net in net_l: 45 | for m in net.modules(): 46 | if isinstance(m, nn.Conv2d): 47 | init.kaiming_normal_(m.weight, a=0, mode="fan_in") 48 | m.weight.data *= scale # for residual block 49 | if m.bias is not None: 50 | m.bias.data.zero_() 51 | elif isinstance(m, nn.Linear): 52 | init.kaiming_normal_(m.weight, a=0, mode="fan_in") 53 | m.weight.data *= scale 54 | if m.bias is not None: 55 | m.bias.data.zero_() 56 | elif isinstance(m, nn.BatchNorm2d): 57 | init.constant_(m.weight, 1) 58 | init.constant_(m.bias.data, 0.0) 59 | -------------------------------------------------------------------------------- /forward_operator/bkse/models/backbones/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from ..arch_util import initialize_weights 4 | 5 | 6 | class ResnetBlock(nn.Module): 7 | """Define a Resnet block""" 8 | 9 | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): 10 | """Initialize the Resnet block 11 | A resnet block is a conv block with skip connections 12 | We construct a conv block with build_conv_block function, 13 | and implement skip connections in function. 14 | Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf 15 | """ 16 | super(ResnetBlock, self).__init__() 17 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) 18 | 19 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): 20 | """Construct a convolutional block. 21 | Parameters: 22 | dim (int) -- the number of channels in the conv layer. 23 | padding_type (str) -- the name of padding 24 | layer: reflect | replicate | zero 25 | norm_layer -- normalization layer 26 | use_dropout (bool) -- if use dropout layers. 
27 | use_bias (bool) -- if the conv layer uses bias or not 28 | Returns a conv block (with a conv layer, a normalization layer, 29 | and a non-linearity layer (ReLU)) 30 | """ 31 | conv_block = [] 32 | p = 0 33 | if padding_type == "reflect": 34 | conv_block += [nn.ReflectionPad2d(1)] 35 | elif padding_type == "replicate": 36 | conv_block += [nn.ReplicationPad2d(1)] 37 | elif padding_type == "zero": 38 | p = 1 39 | else: 40 | raise NotImplementedError( 41 | f"padding {padding_type} \ 42 | is not implemented" 43 | ) 44 | 45 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)] 46 | if use_dropout: 47 | conv_block += [nn.Dropout(0.5)] 48 | 49 | p = 0 50 | if padding_type == "reflect": 51 | conv_block += [nn.ReflectionPad2d(1)] 52 | elif padding_type == "replicate": 53 | conv_block += [nn.ReplicationPad2d(1)] 54 | elif padding_type == "zero": 55 | p = 1 56 | else: 57 | raise NotImplementedError( 58 | f"padding {padding_type} \ 59 | is not implemented" 60 | ) 61 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)] 62 | 63 | return nn.Sequential(*conv_block) 64 | 65 | def forward(self, x): 66 | """Forward function (with skip connections)""" 67 | out = x + self.conv_block(x) # add skip connections 68 | return out 69 | 70 | 71 | class ResidualBlock_noBN(nn.Module): 72 | """Residual block w/o BN 73 | ---Conv-ReLU-Conv-+- 74 | |________________| 75 | """ 76 | 77 | def __init__(self, nf=64): 78 | super(ResidualBlock_noBN, self).__init__() 79 | self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 80 | self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 81 | 82 | # initialization 83 | initialize_weights([self.conv1, self.conv2], 0.1) 84 | 85 | def forward(self, x): 86 | identity = x 87 | out = F.relu(self.conv1(x), inplace=False) 88 | out = self.conv2(out) 89 | return identity + out 90 | -------------------------------------------------------------------------------- /forward_operator/bkse/models/backbones/skip/concat.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class Concat(nn.Module): 7 | def __init__(self, dim, *args): 8 | super(Concat, self).__init__() 9 | self.dim = dim 10 | 11 | for idx, module in enumerate(args): 12 | self.add_module(str(idx), module) 13 | 14 | def forward(self, input): 15 | inputs = [] 16 | for module in self._modules.values(): 17 | inputs.append(module(input)) 18 | 19 | inputs_shapes2 = [x.shape[2] for x in inputs] 20 | inputs_shapes3 = [x.shape[3] for x in inputs] 21 | 22 | if np.all(np.array(inputs_shapes2) == min(inputs_shapes2)) and np.all( 23 | np.array(inputs_shapes3) == min(inputs_shapes3) 24 | ): 25 | inputs_ = inputs 26 | else: 27 | target_shape2 = min(inputs_shapes2) 28 | target_shape3 = min(inputs_shapes3) 29 | 30 | inputs_ = [] 31 | for inp in inputs: 32 | diff2 = (inp.size(2) - target_shape2) // 2 33 | diff3 = (inp.size(3) - target_shape3) // 2 34 | inputs_.append(inp[:, :, diff2 : diff2 + target_shape2, diff3 : diff3 + target_shape3]) 35 | 36 | return torch.cat(inputs_, dim=self.dim) 37 | 38 | def __len__(self): 39 | return len(self._modules) 40 | -------------------------------------------------------------------------------- /forward_operator/bkse/models/backbones/skip/downsampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class 
Downsampler(nn.Module): 7 | """ 8 | http://www.realitypixels.com/turk/computergraphics/ResamplingFilters.pdf 9 | """ 10 | 11 | def __init__( 12 | self, n_planes, factor, kernel_type, phase=0, kernel_width=None, support=None, sigma=None, preserve_size=False 13 | ): 14 | super(Downsampler, self).__init__() 15 | 16 | assert phase in [0, 0.5], "phase should be 0 or 0.5" 17 | 18 | if kernel_type == "lanczos2": 19 | support = 2 20 | kernel_width = 4 * factor + 1 21 | kernel_type_ = "lanczos" 22 | 23 | elif kernel_type == "lanczos3": 24 | support = 3 25 | kernel_width = 6 * factor + 1 26 | kernel_type_ = "lanczos" 27 | 28 | elif kernel_type == "gauss12": 29 | kernel_width = 7 30 | sigma = 1 / 2 31 | kernel_type_ = "gauss" 32 | 33 | elif kernel_type == "gauss1sq2": 34 | kernel_width = 9 35 | sigma = 1.0 / np.sqrt(2) 36 | kernel_type_ = "gauss" 37 | 38 | elif kernel_type in ["lanczos", "gauss", "box"]: 39 | kernel_type_ = kernel_type 40 | 41 | else: 42 | assert False, "wrong name kernel" 43 | 44 | # note that `kernel width` will be different to actual size for phase = 1/2 45 | self.kernel = get_kernel(factor, kernel_type_, phase, kernel_width, support=support, sigma=sigma) 46 | 47 | downsampler = nn.Conv2d(n_planes, n_planes, kernel_size=self.kernel.shape, stride=factor, padding=0) 48 | downsampler.weight.data[:] = 0 49 | downsampler.bias.data[:] = 0 50 | 51 | kernel_torch = torch.from_numpy(self.kernel) 52 | for i in range(n_planes): 53 | downsampler.weight.data[i, i] = kernel_torch 54 | 55 | self.downsampler_ = downsampler 56 | 57 | if preserve_size: 58 | 59 | if self.kernel.shape[0] % 2 == 1: 60 | pad = int((self.kernel.shape[0] - 1) / 2.0) 61 | else: 62 | pad = int((self.kernel.shape[0] - factor) / 2.0) 63 | 64 | self.padding = nn.ReplicationPad2d(pad) 65 | 66 | self.preserve_size = preserve_size 67 | 68 | def forward(self, input): 69 | if self.preserve_size: 70 | x = self.padding(input) 71 | else: 72 | x = input 73 | self.x = x 74 | return self.downsampler_(x) 75 | 76 | 77 | class Blurconv(nn.Module): 78 | """ 79 | http://www.realitypixels.com/turk/computergraphics/ResamplingFilters.pdf 80 | """ 81 | 82 | def __init__(self, n_planes=1, preserve_size=False): 83 | super(Blurconv, self).__init__() 84 | 85 | # self.kernel = kernel 86 | # blurconv = nn.Conv2d(n_planes, n_planes, kernel_size=self.kernel.shape, stride=1, padding=0) 87 | # blurconvr.weight.data = self.kernel 88 | # blurconv.bias.data[:] = 0 89 | self.n_planes = n_planes 90 | self.preserve_size = preserve_size 91 | 92 | # kernel_torch = torch.from_numpy(self.kernel) 93 | # for i in range(n_planes): 94 | # blurconv.weight.data[i, i] = kernel_torch 95 | 96 | # self.blurconv_ = blurconv 97 | # 98 | # if preserve_size: 99 | # 100 | # if self.kernel.shape[0] % 2 == 1: 101 | # pad = int((self.kernel.shape[0] - 1) / 2.) 102 | # else: 103 | # pad = int((self.kernel.shape[0] - factor) / 2.) 
104 | # 105 | # self.padding = nn.ReplicationPad2d(pad) 106 | # 107 | # self.preserve_size = preserve_size 108 | 109 | def forward(self, input, kernel): 110 | if self.preserve_size: 111 | if kernel.shape[0] % 2 == 1: 112 | pad = int((kernel.shape[3] - 1) / 2.0) 113 | else: 114 | pad = int((kernel.shape[3] - 1.0) / 2.0) 115 | padding = nn.ReplicationPad2d(pad) 116 | x = padding(input) 117 | else: 118 | x = input 119 | 120 | blurconv = nn.Conv2d( 121 | self.n_planes, self.n_planes, kernel_size=kernel.size(3), stride=1, padding=0, bias=False 122 | ).cuda() 123 | 124 | blurconv.weight.data[:] = kernel 125 | 126 | return blurconv(x) 127 | 128 | 129 | class Blurconv2(nn.Module): 130 | """ 131 | http://www.realitypixels.com/turk/computergraphics/ResamplingFilters.pdf 132 | """ 133 | 134 | def __init__(self, n_planes=1, preserve_size=False, k_size=21): 135 | super(Blurconv2, self).__init__() 136 | 137 | self.n_planes = n_planes 138 | self.k_size = k_size 139 | self.preserve_size = preserve_size 140 | self.blurconv = nn.Conv2d(self.n_planes, self.n_planes, kernel_size=k_size, stride=1, padding=0, bias=False) 141 | 142 | # self.blurconv.weight.data[:] /= self.blurconv.weight.data.sum() 143 | def forward(self, input): 144 | if self.preserve_size: 145 | pad = int((self.k_size - 1.0) / 2.0) 146 | padding = nn.ReplicationPad2d(pad) 147 | x = padding(input) 148 | else: 149 | x = input 150 | # self.blurconv.weight.data[:] /= self.blurconv.weight.data.sum() 151 | return self.blurconv(x) 152 | 153 | 154 | def get_kernel(factor, kernel_type, phase, kernel_width, support=None, sigma=None): 155 | assert kernel_type in ["lanczos", "gauss", "box"] 156 | 157 | # factor = float(factor) 158 | if phase == 0.5 and kernel_type != "box": 159 | kernel = np.zeros([kernel_width - 1, kernel_width - 1]) 160 | else: 161 | kernel = np.zeros([kernel_width, kernel_width]) 162 | 163 | if kernel_type == "box": 164 | assert phase == 0.5, "Box filter is always half-phased" 165 | kernel[:] = 1.0 / (kernel_width * kernel_width) 166 | 167 | elif kernel_type == "gauss": 168 | assert sigma, "sigma is not specified" 169 | assert phase != 0.5, "phase 1/2 for gauss not implemented" 170 | 171 | center = (kernel_width + 1.0) / 2.0 172 | print(center, kernel_width) 173 | sigma_sq = sigma * sigma 174 | 175 | for i in range(1, kernel.shape[0] + 1): 176 | for j in range(1, kernel.shape[1] + 1): 177 | di = (i - center) / 2.0 178 | dj = (j - center) / 2.0 179 | kernel[i - 1][j - 1] = np.exp(-(di * di + dj * dj) / (2 * sigma_sq)) 180 | kernel[i - 1][j - 1] = kernel[i - 1][j - 1] / (2.0 * np.pi * sigma_sq) 181 | elif kernel_type == "lanczos": 182 | assert support, "support is not specified" 183 | center = (kernel_width + 1) / 2.0 184 | 185 | for i in range(1, kernel.shape[0] + 1): 186 | for j in range(1, kernel.shape[1] + 1): 187 | 188 | if phase == 0.5: 189 | di = abs(i + 0.5 - center) / factor 190 | dj = abs(j + 0.5 - center) / factor 191 | else: 192 | di = abs(i - center) / factor 193 | dj = abs(j - center) / factor 194 | 195 | val = 1 196 | if di != 0: 197 | val = val * support * np.sin(np.pi * di) * np.sin(np.pi * di / support) 198 | val = val / (np.pi * np.pi * di * di) 199 | 200 | if dj != 0: 201 | val = val * support * np.sin(np.pi * dj) * np.sin(np.pi * dj / support) 202 | val = val / (np.pi * np.pi * dj * dj) 203 | 204 | kernel[i - 1][j - 1] = val 205 | 206 | else: 207 | assert False, "wrong method name" 208 | 209 | kernel /= kernel.sum() 210 | 211 | return kernel 212 | 213 | 214 | # a = Downsampler(n_planes=3, factor=2, 
kernel_type='lanczos2', phase='1', preserve_size=True) 215 | 216 | 217 | ################# 218 | # Learnable downsampler 219 | 220 | # KS = 32 221 | # dow = nn.Sequential(nn.ReplicationPad2d(int((KS - factor) / 2.)), nn.Conv2d(1,1,KS,factor)) 222 | 223 | # class Apply(nn.Module): 224 | # def __init__(self, what, dim, *args): 225 | # super(Apply, self).__init__() 226 | # self.dim = dim 227 | 228 | # self.what = what 229 | 230 | # def forward(self, input): 231 | # inputs = [] 232 | # for i in range(input.size(self.dim)): 233 | # inputs.append(self.what(input.narrow(self.dim, i, 1))) 234 | 235 | # return torch.cat(inputs, dim=self.dim) 236 | 237 | # def __len__(self): 238 | # return len(self._modules) 239 | 240 | # downs = Apply(dow, 1) 241 | # downs.type(dtype)(net_input.type(dtype)).size() 242 | -------------------------------------------------------------------------------- /forward_operator/bkse/models/backbones/skip/non_local_dot_product.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class _NonLocalBlockND(nn.Module): 6 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): 7 | super(_NonLocalBlockND, self).__init__() 8 | 9 | assert dimension in [1, 2, 3] 10 | 11 | self.dimension = dimension 12 | self.sub_sample = sub_sample 13 | 14 | self.in_channels = in_channels 15 | self.inter_channels = inter_channels 16 | 17 | if self.inter_channels is None: 18 | self.inter_channels = in_channels // 2 19 | if self.inter_channels == 0: 20 | self.inter_channels = 1 21 | 22 | if dimension == 3: 23 | conv_nd = nn.Conv3d 24 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 25 | bn = nn.BatchNorm3d 26 | elif dimension == 2: 27 | conv_nd = nn.Conv2d 28 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 29 | bn = nn.BatchNorm2d 30 | else: 31 | conv_nd = nn.Conv1d 32 | max_pool_layer = nn.MaxPool1d(kernel_size=(2)) 33 | bn = nn.BatchNorm1d 34 | 35 | self.g = conv_nd( 36 | in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0 37 | ) 38 | 39 | if bn_layer: 40 | self.W = nn.Sequential( 41 | conv_nd( 42 | in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0 43 | ), 44 | bn(self.in_channels), 45 | ) 46 | nn.init.constant_(self.W[1].weight, 0) 47 | nn.init.constant_(self.W[1].bias, 0) 48 | else: 49 | self.W = conv_nd( 50 | in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0 51 | ) 52 | nn.init.constant_(self.W.weight, 0) 53 | nn.init.constant_(self.W.bias, 0) 54 | 55 | self.theta = conv_nd( 56 | in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0 57 | ) 58 | 59 | self.phi = conv_nd( 60 | in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0 61 | ) 62 | 63 | if sub_sample: 64 | self.g = nn.Sequential(self.g, max_pool_layer) 65 | self.phi = nn.Sequential(self.phi, max_pool_layer) 66 | 67 | def forward(self, x): 68 | """ 69 | :param x: (b, c, t, h, w) 70 | :return: 71 | """ 72 | 73 | batch_size = x.size(0) 74 | 75 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 76 | g_x = g_x.permute(0, 2, 1) 77 | 78 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 79 | theta_x = theta_x.permute(0, 2, 1) 80 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 81 | f = torch.matmul(theta_x, phi_x) 82 | N = f.size(-1) 83 | 
f_div_C = f / N 84 | 85 | y = torch.matmul(f_div_C, g_x) 86 | y = y.permute(0, 2, 1).contiguous() 87 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 88 | W_y = self.W(y) 89 | z = W_y + x 90 | 91 | return z 92 | 93 | 94 | class NONLocalBlock1D(_NonLocalBlockND): 95 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 96 | super(NONLocalBlock1D, self).__init__( 97 | in_channels, inter_channels=inter_channels, dimension=1, sub_sample=sub_sample, bn_layer=bn_layer 98 | ) 99 | 100 | 101 | class NONLocalBlock2D(_NonLocalBlockND): 102 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 103 | super(NONLocalBlock2D, self).__init__( 104 | in_channels, inter_channels=inter_channels, dimension=2, sub_sample=sub_sample, bn_layer=bn_layer 105 | ) 106 | 107 | 108 | class NONLocalBlock3D(_NonLocalBlockND): 109 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 110 | super(NONLocalBlock3D, self).__init__( 111 | in_channels, inter_channels=inter_channels, dimension=3, sub_sample=sub_sample, bn_layer=bn_layer 112 | ) 113 | 114 | 115 | if __name__ == "__main__": 116 | for (sub_sample, bn_layer) in [(True, True), (False, False), (True, False), (False, True)]: 117 | img = torch.zeros(2, 3, 20) 118 | net = NONLocalBlock1D(3, sub_sample=sub_sample, bn_layer=bn_layer) 119 | out = net(img) 120 | print(out.size()) 121 | 122 | img = torch.zeros(2, 3, 20, 20) 123 | net = NONLocalBlock2D(3, sub_sample=sub_sample, bn_layer=bn_layer) 124 | out = net(img) 125 | print(out.size()) 126 | 127 | img = torch.randn(2, 3, 8, 20, 20) 128 | net = NONLocalBlock3D(3, sub_sample=sub_sample, bn_layer=bn_layer) 129 | out = net(img) 130 | print(out.size()) 131 | -------------------------------------------------------------------------------- /forward_operator/bkse/models/backbones/skip/skip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .concat import Concat 5 | from .non_local_dot_product import NONLocalBlock2D 6 | from .util import get_activation, get_conv 7 | 8 | 9 | def add_module(self, module): 10 | self.add_module(str(len(self) + 1), module) 11 | 12 | 13 | torch.nn.Module.add = add_module 14 | 15 | 16 | def skip( 17 | num_input_channels=2, 18 | num_output_channels=3, 19 | num_channels_down=[16, 32, 64, 128, 128], 20 | num_channels_up=[16, 32, 64, 128, 128], 21 | num_channels_skip=[4, 4, 4, 4, 4], 22 | filter_size_down=3, 23 | filter_size_up=3, 24 | filter_skip_size=1, 25 | need_sigmoid=True, 26 | need_bias=True, 27 | pad="zero", 28 | upsample_mode="nearest", 29 | downsample_mode="stride", 30 | act_fun="LeakyReLU", 31 | need1x1_up=True, 32 | ): 33 | """Assembles encoder-decoder with skip connections. 34 | 35 | Arguments: 36 | act_fun: Either string 'LeakyReLU|Swish|ELU|none' or module (e.g. 
nn.ReLU) 37 | pad (string): zero|reflection (default: 'zero') 38 | upsample_mode (string): 'nearest|bilinear' (default: 'nearest') 39 | downsample_mode (string): 'stride|avg|max|lanczos2' (default: 'stride') 40 | 41 | """ 42 | assert len(num_channels_down) == len(num_channels_up) == len(num_channels_skip) 43 | 44 | n_scales = len(num_channels_down) 45 | 46 | if not (isinstance(upsample_mode, list) or isinstance(upsample_mode, tuple)): 47 | upsample_mode = [upsample_mode] * n_scales 48 | 49 | if not (isinstance(downsample_mode, list) or isinstance(downsample_mode, tuple)): 50 | downsample_mode = [downsample_mode] * n_scales 51 | 52 | if not (isinstance(filter_size_down, list) or isinstance(filter_size_down, tuple)): 53 | filter_size_down = [filter_size_down] * n_scales 54 | 55 | if not (isinstance(filter_size_up, list) or isinstance(filter_size_up, tuple)): 56 | filter_size_up = [filter_size_up] * n_scales 57 | 58 | last_scale = n_scales - 1 59 | 60 | model = nn.Sequential() 61 | model_tmp = model 62 | 63 | input_depth = num_input_channels 64 | for i in range(len(num_channels_down)): 65 | 66 | deeper = nn.Sequential() 67 | skip = nn.Sequential() 68 | 69 | if num_channels_skip[i] != 0: 70 | model_tmp.add(Concat(1, skip, deeper)) 71 | else: 72 | model_tmp.add(deeper) 73 | 74 | model_tmp.add( 75 | nn.BatchNorm2d(num_channels_skip[i] + (num_channels_up[i + 1] if i < last_scale else num_channels_down[i])) 76 | ) 77 | 78 | if num_channels_skip[i] != 0: 79 | skip.add(get_conv(input_depth, num_channels_skip[i], filter_skip_size, bias=need_bias, pad=pad)) 80 | skip.add(nn.BatchNorm2d(num_channels_skip[i])) 81 | skip.add(get_activation(act_fun)) 82 | 83 | # skip.add(Concat(2, GenNoise(nums_noise[i]), skip_part)) 84 | 85 | deeper.add( 86 | get_conv( 87 | input_depth, 88 | num_channels_down[i], 89 | filter_size_down[i], 90 | 2, 91 | bias=need_bias, 92 | pad=pad, 93 | downsample_mode=downsample_mode[i], 94 | ) 95 | ) 96 | deeper.add(nn.BatchNorm2d(num_channels_down[i])) 97 | deeper.add(get_activation(act_fun)) 98 | if i > 1: 99 | deeper.add(NONLocalBlock2D(in_channels=num_channels_down[i])) 100 | deeper.add(get_conv(num_channels_down[i], num_channels_down[i], filter_size_down[i], bias=need_bias, pad=pad)) 101 | deeper.add(nn.BatchNorm2d(num_channels_down[i])) 102 | deeper.add(get_activation(act_fun)) 103 | 104 | deeper_main = nn.Sequential() 105 | 106 | if i == len(num_channels_down) - 1: 107 | # The deepest 108 | k = num_channels_down[i] 109 | else: 110 | deeper.add(deeper_main) 111 | k = num_channels_up[i + 1] 112 | 113 | deeper.add(nn.Upsample(scale_factor=2, mode=upsample_mode[i])) 114 | 115 | model_tmp.add( 116 | get_conv(num_channels_skip[i] + k, num_channels_up[i], filter_size_up[i], 1, bias=need_bias, pad=pad) 117 | ) 118 | model_tmp.add(nn.BatchNorm2d(num_channels_up[i])) 119 | model_tmp.add(get_activation(act_fun)) 120 | 121 | if need1x1_up: 122 | model_tmp.add(get_conv(num_channels_up[i], num_channels_up[i], 1, bias=need_bias, pad=pad)) 123 | model_tmp.add(nn.BatchNorm2d(num_channels_up[i])) 124 | model_tmp.add(get_activation(act_fun)) 125 | 126 | input_depth = num_channels_down[i] 127 | model_tmp = deeper_main 128 | 129 | model.add(get_conv(num_channels_up[0], num_output_channels, 1, bias=need_bias, pad=pad)) 130 | if need_sigmoid: 131 | model.add(nn.Sigmoid()) 132 | 133 | return model 134 | -------------------------------------------------------------------------------- /forward_operator/bkse/models/backbones/skip/util.py: 
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .downsampler import Downsampler 4 | 5 | 6 | class Swish(nn.Module): 7 | """ 8 | https://arxiv.org/abs/1710.05941 9 | The hype was so huge that I could not help but try it 10 | """ 11 | 12 | def __init__(self): 13 | super(Swish, self).__init__() 14 | self.s = nn.Sigmoid() 15 | 16 | def forward(self, x): 17 | return x * self.s(x) 18 | 19 | 20 | def get_conv(in_f, out_f, kernel_size, stride=1, bias=True, pad="zero", downsample_mode="stride"): 21 | downsampler = None 22 | if stride != 1 and downsample_mode != "stride": 23 | 24 | if downsample_mode == "avg": 25 | downsampler = nn.AvgPool2d(stride, stride) 26 | elif downsample_mode == "max": 27 | downsampler = nn.MaxPool2d(stride, stride) 28 | elif downsample_mode in ["lanczos2", "lanczos3"]: 29 | downsampler = Downsampler( 30 | n_planes=out_f, factor=stride, kernel_type=downsample_mode, phase=0.5, preserve_size=True 31 | ) 32 | else: 33 | assert False 34 | 35 | stride = 1 36 | 37 | padder = None 38 | to_pad = int((kernel_size - 1) / 2) 39 | if pad == "reflection": 40 | padder = nn.ReflectionPad2d(to_pad) 41 | to_pad = 0 42 | 43 | convolver = nn.Conv2d(in_f, out_f, kernel_size, stride, padding=to_pad, bias=bias) 44 | 45 | layers = filter(lambda x: x is not None, [padder, convolver, downsampler]) 46 | return nn.Sequential(*layers) 47 | 48 | 49 | def get_activation(act_fun="LeakyReLU"): 50 | """ 51 | Either string defining an activation function or module (e.g. nn.ReLU) 52 | """ 53 | if isinstance(act_fun, str): 54 | if act_fun == "LeakyReLU": 55 | return nn.LeakyReLU(0.2, inplace=True) 56 | elif act_fun == "Swish": 57 | return Swish() 58 | elif act_fun == "ELU": 59 | return nn.ELU() 60 | elif act_fun == "none": 61 | return nn.Sequential() 62 | else: 63 | assert False 64 | else: 65 | return act_fun() 66 | -------------------------------------------------------------------------------- /forward_operator/bkse/models/backbones/unet_parts.py: -------------------------------------------------------------------------------- 1 | """ Parts of the U-Net model """ 2 | 3 | import functools 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class DoubleConv(nn.Module): 10 | """(convolution => [BN] => ReLU) * 2""" 11 | 12 | def __init__(self, in_channels, out_channels, mid_channels=None): 13 | super().__init__() 14 | if not mid_channels: 15 | mid_channels = out_channels 16 | self.double_conv = nn.Sequential( 17 | nn.Conv2d(in_channels, mid_channels, kernel_size=5, padding=2), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(mid_channels, out_channels, kernel_size=5, padding=2), 20 | nn.ReLU(inplace=True), 21 | ) 22 | 23 | def forward(self, x): 24 | return self.double_conv(x) 25 | 26 | 27 | class UnetSkipConnectionBlock(nn.Module): 28 | """Defines the Unet submodule with skip connection. 29 | X -------------------identity---------------------- 30 | |-- downsampling -- |submodule| -- upsampling --| 31 | """ 32 | 33 | def __init__( 34 | self, 35 | outer_nc, 36 | inner_nc, 37 | input_nc=None, 38 | submodule=None, 39 | outermost=False, 40 | innermost=False, 41 | norm_layer=nn.BatchNorm2d, 42 | use_dropout=False, 43 | ): 44 | """Construct a Unet submodule with skip connections. 
45 | Parameters: 46 | outer_nc (int) -- the number of filters in the outer conv layer 47 | inner_nc (int) -- the number of filters in the inner conv layer 48 | input_nc (int) -- the number of channels in input images/features 49 | submodule (UnetSkipConnectionBlock) --previously defined submodules 50 | outermost (bool) -- if this module is the outermost module 51 | innermost (bool) -- if this module is the innermost module 52 | norm_layer -- normalization layer 53 | use_dropout (bool) -- if use dropout layers. 54 | """ 55 | super(UnetSkipConnectionBlock, self).__init__() 56 | self.outermost = outermost 57 | self.innermost = innermost 58 | if type(norm_layer) == functools.partial: 59 | use_bias = norm_layer.func == nn.InstanceNorm2d 60 | else: 61 | use_bias = norm_layer == nn.InstanceNorm2d 62 | if input_nc is None: 63 | input_nc = outer_nc 64 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias) 65 | downrelu = nn.LeakyReLU(0.2, True) 66 | downnorm = norm_layer(inner_nc) 67 | uprelu = nn.ReLU(True) 68 | upnorm = norm_layer(outer_nc) 69 | 70 | if outermost: 71 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1) 72 | # upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) 73 | # upconv = DoubleConv(inner_nc * 2, outer_nc) 74 | up = [uprelu, upconv, nn.Tanh()] 75 | down = [downconv] 76 | self.down = nn.Sequential(*down) 77 | self.submodule = submodule 78 | self.up = nn.Sequential(*up) 79 | elif innermost: 80 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias) 81 | # upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) 82 | # upconv = DoubleConv(inner_nc * 2, outer_nc) 83 | down = [downrelu, downconv] 84 | up = [uprelu, upconv, upnorm] 85 | self.down = nn.Sequential(*down) 86 | self.up = nn.Sequential(*up) 87 | else: 88 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias) 89 | # upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) 90 | # upconv = DoubleConv(inner_nc * 2, outer_nc) 91 | down = [downrelu, downconv, downnorm] 92 | up = [uprelu, upconv, upnorm] 93 | if use_dropout: 94 | up += [nn.Dropout(0.5)] 95 | 96 | self.down = nn.Sequential(*down) 97 | self.submodule = submodule 98 | self.up = nn.Sequential(*up) 99 | 100 | def forward(self, x, noise): 101 | 102 | if self.outermost: 103 | return self.up(self.submodule(self.down(x), noise)) 104 | elif self.innermost: # add skip connections 105 | if noise is None: 106 | noise = torch.randn((1, 512, 8, 8)).cuda() * 0.0007 107 | return torch.cat((self.up(torch.cat((self.down(x), noise), dim=1)), x), dim=1) 108 | else: 109 | return torch.cat((self.up(self.submodule(self.down(x), noise)), x), dim=1) 110 | -------------------------------------------------------------------------------- /forward_operator/bkse/models/deblurring/a.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /forward_operator/bkse/models/deblurring/image_deblur.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import utils.util as util 4 | from models.dips import ImageDIP, KernelDIP 5 | from models.kernel_encoding.kernel_wizard import KernelWizard 6 | from models.losses.hyper_laplacian_penalty import HyperLaplacianPenalty 7 | from 
models.losses.perceptual_loss import PerceptualLoss 8 | from models.losses.ssim_loss import SSIM 9 | from torch.optim.lr_scheduler import StepLR 10 | from tqdm import tqdm 11 | 12 | 13 | class ImageDeblur: 14 | def __init__(self, opt): 15 | self.opt = opt 16 | 17 | # losses 18 | self.ssim_loss = SSIM().cuda() 19 | self.mse = nn.MSELoss().cuda() 20 | self.perceptual_loss = PerceptualLoss().cuda() 21 | self.laplace_penalty = HyperLaplacianPenalty(3, 0.66).cuda() 22 | 23 | self.kernel_wizard = KernelWizard(opt["KernelWizard"]).cuda() 24 | self.kernel_wizard.load_state_dict(torch.load(opt["KernelWizard"]["pretrained"])) 25 | 26 | for k, v in self.kernel_wizard.named_parameters(): 27 | v.requires_grad = False 28 | 29 | def reset_optimizers(self): 30 | self.x_optimizer = torch.optim.Adam(self.x_dip.parameters(), lr=self.opt["x_lr"]) 31 | self.k_optimizer = torch.optim.Adam(self.k_dip.parameters(), lr=self.opt["k_lr"]) 32 | 33 | self.x_scheduler = StepLR(self.x_optimizer, step_size=self.opt["num_iters"] // 5, gamma=0.7) 34 | 35 | self.k_scheduler = StepLR(self.k_optimizer, step_size=self.opt["num_iters"] // 5, gamma=0.7) 36 | 37 | def prepare_DIPs(self): 38 | # x is stand for the sharp image, k is stand for the kernel 39 | self.x_dip = ImageDIP(self.opt["ImageDIP"]).cuda() 40 | self.k_dip = KernelDIP(self.opt["KernelDIP"]).cuda() 41 | 42 | # fixed input vectors of DIPs 43 | # zk and zx are the length of the corresponding vectors 44 | self.dip_zk = util.get_noise(64, "noise", (64, 64)).cuda() 45 | self.dip_zx = util.get_noise(8, "noise", self.opt["img_size"]).cuda() 46 | 47 | def warmup(self, warmup_x, warmup_k): 48 | # Input vector of DIPs is sampled from N(z, I) 49 | reg_noise_std = self.opt["reg_noise_std"] 50 | 51 | for step in tqdm(range(self.opt["num_warmup_iters"])): 52 | self.x_optimizer.zero_grad() 53 | dip_zx_rand = self.dip_zx + reg_noise_std * torch.randn_like(self.dip_zx).cuda() 54 | x = self.x_dip(dip_zx_rand) 55 | 56 | loss = self.mse(x, warmup_x) 57 | loss.backward() 58 | self.x_optimizer.step() 59 | 60 | print("Warming up k DIP") 61 | for step in tqdm(range(self.opt["num_warmup_iters"])): 62 | self.k_optimizer.zero_grad() 63 | dip_zk_rand = self.dip_zk + reg_noise_std * torch.randn_like(self.dip_zk).cuda() 64 | k = self.k_dip(dip_zk_rand) 65 | 66 | loss = self.mse(k, warmup_k) 67 | loss.backward() 68 | self.k_optimizer.step() 69 | 70 | def deblur(self, img): 71 | pass 72 | -------------------------------------------------------------------------------- /forward_operator/bkse/models/deblurring/joint_deblur.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import utils.util as util 3 | from models.deblurring.image_deblur import ImageDeblur 4 | from tqdm import tqdm 5 | 6 | 7 | class JointDeblur(ImageDeblur): 8 | def __init__(self, opt): 9 | super(JointDeblur, self).__init__(opt) 10 | 11 | def deblur(self, y): 12 | """Deblur image 13 | Args: 14 | y: Blur image 15 | """ 16 | y = util.img2tensor(y).unsqueeze(0).cuda() 17 | 18 | self.prepare_DIPs() 19 | self.reset_optimizers() 20 | 21 | warmup_k = torch.load(self.opt["warmup_k_path"]).cuda() 22 | self.warmup(y, warmup_k) 23 | 24 | # Input vector of DIPs is sampled from N(z, I) 25 | 26 | print("Deblurring") 27 | reg_noise_std = self.opt["reg_noise_std"] 28 | for step in tqdm(range(self.opt["num_iters"])): 29 | dip_zx_rand = self.dip_zx + reg_noise_std * torch.randn_like(self.dip_zx).cuda() 30 | dip_zk_rand = self.dip_zk + reg_noise_std * torch.randn_like(self.dip_zk).cuda() 31 | 32 
| self.x_optimizer.zero_grad() 33 | self.k_optimizer.zero_grad() 34 | 35 | self.x_scheduler.step() 36 | self.k_scheduler.step() 37 | 38 | x = self.x_dip(dip_zx_rand) 39 | k = self.k_dip(dip_zk_rand) 40 | 41 | fake_y = self.kernel_wizard.adaptKernel(x, k) 42 | 43 | if step < self.opt["num_iters"] // 2: 44 | total_loss = 6e-1 * self.perceptual_loss(fake_y, y) 45 | total_loss += 1 - self.ssim_loss(fake_y, y) 46 | total_loss += 5e-5 * torch.norm(k) 47 | total_loss += 2e-2 * self.laplace_penalty(x) 48 | else: 49 | total_loss = self.perceptual_loss(fake_y, y) 50 | total_loss += 5e-2 * self.laplace_penalty(x) 51 | total_loss += 5e-4 * torch.norm(k) 52 | 53 | total_loss.backward() 54 | 55 | self.x_optimizer.step() 56 | self.k_optimizer.step() 57 | 58 | # debugging 59 | # if step % 100 == 0: 60 | # print(torch.norm(k)) 61 | # print(f"{self.k_optimizer.param_groups[0]['lr']:.3e}") 62 | 63 | return util.tensor2img(x.detach()) 64 | -------------------------------------------------------------------------------- /forward_operator/bkse/models/dips.py: -------------------------------------------------------------------------------- 1 | import models.arch_util as arch_util 2 | import torch.nn as nn 3 | from models.backbones.resnet import ResnetBlock 4 | from models.backbones.skip.skip import skip 5 | 6 | 7 | class KernelDIP(nn.Module): 8 | """ 9 | DIP (Deep Image Prior) for blur kernel 10 | """ 11 | 12 | def __init__(self, opt): 13 | super(KernelDIP, self).__init__() 14 | 15 | norm_layer = arch_util.get_norm_layer("none") 16 | n_blocks = opt["n_blocks"] 17 | nf = opt["nf"] 18 | padding_type = opt["padding_type"] 19 | use_dropout = opt["use_dropout"] 20 | kernel_dim = opt["kernel_dim"] 21 | 22 | input_nc = 64 23 | model = [ 24 | nn.ReflectionPad2d(3), 25 | nn.Conv2d(input_nc, nf, kernel_size=7, padding=0, bias=True), 26 | norm_layer(nf), 27 | nn.ReLU(True), 28 | ] 29 | 30 | n_downsampling = 5 31 | for i in range(n_downsampling): # add downsampling layers 32 | mult = 2 ** i 33 | input_nc = min(nf * mult, kernel_dim) 34 | output_nc = min(nf * mult * 2, kernel_dim) 35 | model += [ 36 | nn.Conv2d(input_nc, output_nc, kernel_size=3, stride=2, padding=1, bias=True), 37 | norm_layer(nf * mult * 2), 38 | nn.ReLU(True), 39 | ] 40 | 41 | for i in range(n_blocks): # add ResNet blocks 42 | model += [ 43 | ResnetBlock( 44 | kernel_dim, 45 | padding_type=padding_type, 46 | norm_layer=norm_layer, 47 | use_dropout=use_dropout, 48 | use_bias=True, 49 | ) 50 | ] 51 | 52 | self.model = nn.Sequential(*model) 53 | 54 | def forward(self, noise): 55 | return self.model(noise) 56 | 57 | 58 | class ImageDIP(nn.Module): 59 | """ 60 | DIP (Deep Image Prior) for sharp image 61 | """ 62 | 63 | def __init__(self, opt): 64 | super(ImageDIP, self).__init__() 65 | 66 | input_nc = opt["input_nc"] 67 | output_nc = opt["output_nc"] 68 | 69 | self.model = skip( 70 | input_nc, 71 | output_nc, 72 | num_channels_down=[128, 128, 128, 128, 128], 73 | num_channels_up=[128, 128, 128, 128, 128], 74 | num_channels_skip=[16, 16, 16, 16, 16], 75 | upsample_mode="bilinear", 76 | need_sigmoid=True, 77 | need_bias=True, 78 | pad=opt["padding_type"], 79 | act_fun="LeakyReLU", 80 | ) 81 | 82 | def forward(self, img): 83 | return self.model(img) 84 | -------------------------------------------------------------------------------- /forward_operator/bkse/models/dsd/bicubic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class 
BicubicDownSample(nn.Module): 7 | def bicubic_kernel(self, x, a=-0.50): 8 | """ 9 | This equation is exactly copied from the website below: 10 | https://clouard.users.greyc.fr/Pantheon/experiments/rescaling/index-en.html#bicubic 11 | """ 12 | abs_x = torch.abs(x) 13 | if abs_x <= 1.0: 14 | return (a + 2.0) * torch.pow(abs_x, 3.0) - (a + 3.0) * torch.pow(abs_x, 2.0) + 1 15 | elif 1.0 < abs_x < 2.0: 16 | return a * torch.pow(abs_x, 3) - 5.0 * a * torch.pow(abs_x, 2.0) + 8.0 * a * abs_x - 4.0 * a 17 | else: 18 | return 0.0 19 | 20 | def __init__(self, factor=4, cuda=True, padding="reflect"): 21 | super().__init__() 22 | self.factor = factor 23 | size = factor * 4 24 | k = torch.tensor( 25 | [self.bicubic_kernel((i - torch.floor(torch.tensor(size / 2)) + 0.5) / factor) for i in range(size)], 26 | dtype=torch.float32, 27 | ) 28 | k = k / torch.sum(k) 29 | # k = torch.einsum('i,j->ij', (k, k)) 30 | k1 = torch.reshape(k, shape=(1, 1, size, 1)) 31 | self.k1 = torch.cat([k1, k1, k1], dim=0) 32 | k2 = torch.reshape(k, shape=(1, 1, 1, size)) 33 | self.k2 = torch.cat([k2, k2, k2], dim=0) 34 | self.cuda = ".cuda" if cuda else "" 35 | self.padding = padding 36 | for param in self.parameters(): 37 | param.requires_grad = False 38 | 39 | def forward(self, x, nhwc=False, clip_round=False, byte_output=False): 40 | # x = torch.from_numpy(x).type('torch.FloatTensor') 41 | filter_height = self.factor * 4 42 | filter_width = self.factor * 4 43 | stride = self.factor 44 | 45 | pad_along_height = max(filter_height - stride, 0) 46 | pad_along_width = max(filter_width - stride, 0) 47 | filters1 = self.k1.type("torch{}.FloatTensor".format(self.cuda)) 48 | filters2 = self.k2.type("torch{}.FloatTensor".format(self.cuda)) 49 | 50 | # compute actual padding values for each side 51 | pad_top = pad_along_height // 2 52 | pad_bottom = pad_along_height - pad_top 53 | pad_left = pad_along_width // 2 54 | pad_right = pad_along_width - pad_left 55 | 56 | # apply mirror padding 57 | if nhwc: 58 | x = torch.transpose(torch.transpose(x, 2, 3), 1, 2) # NHWC to NCHW 59 | 60 | # downscaling performed by 1-d convolution 61 | x = F.pad(x, (0, 0, pad_top, pad_bottom), self.padding) 62 | x = F.conv2d(input=x, weight=filters1, stride=(stride, 1), groups=3) 63 | if clip_round: 64 | x = torch.clamp(torch.round(x), 0.0, 255.0) 65 | 66 | x = F.pad(x, (pad_left, pad_right, 0, 0), self.padding) 67 | x = F.conv2d(input=x, weight=filters2, stride=(1, stride), groups=3) 68 | if clip_round: 69 | x = torch.clamp(torch.round(x), 0.0, 255.0) 70 | 71 | if nhwc: 72 | x = torch.transpose(torch.transpose(x, 1, 3), 1, 2) 73 | if byte_output: 74 | return x.type("torch.{}.ByteTensor".format(self.cuda)) 75 | else: 76 | return x 77 | -------------------------------------------------------------------------------- /forward_operator/bkse/models/dsd/dsd.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import utils.util as util 8 | from models.dips import KernelDIP 9 | from models.dsd.spherical_optimizer import SphericalOptimizer 10 | from torch.optim.lr_scheduler import StepLR 11 | from tqdm import tqdm 12 | 13 | 14 | class DSD(torch.nn.Module): 15 | def __init__(self, opt, cache_dir): 16 | super(DSD, self).__init__() 17 | 18 | self.opt = opt 19 | 20 | self.verbose = opt["verbose"] 21 | cache_dir = Path(cache_dir) 22 | cache_dir.mkdir(parents=True, exist_ok=True) 23 | 24 | # Initialize 
synthesis network 25 | if self.verbose: 26 | print("Loading Synthesis Network") 27 | self.load_synthesis_network() 28 | if self.verbose: 29 | print("Synthesis Network loaded!") 30 | 31 | self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2) 32 | 33 | self.initialize_mapping_network() 34 | 35 | def initialize_dip(self): 36 | self.dip_zk = util.get_noise(64, "noise", (64, 64)).cuda().detach() 37 | self.k_dip = KernelDIP(self.opt["KernelDIP"]).cuda() 38 | 39 | def initialize_latent_space(self): 40 | pass 41 | 42 | def initialize_optimizers(self): 43 | # Optimizer for k 44 | self.optimizer_k = torch.optim.Adam(self.k_dip.parameters(), lr=self.opt["k_lr"]) 45 | self.scheduler_k = StepLR( 46 | self.optimizer_k, step_size=self.opt["num_epochs"] * self.opt["num_k_iters"] // 5, gamma=0.7 47 | ) 48 | 49 | # Optimizer for x 50 | optimizer_dict = { 51 | "sgd": torch.optim.SGD, 52 | "adam": torch.optim.Adam, 53 | "sgdm": partial(torch.optim.SGD, momentum=0.9), 54 | "adamax": torch.optim.Adamax, 55 | } 56 | optimizer_func = optimizer_dict[self.opt["optimizer_name"]] 57 | self.optimizer_x = SphericalOptimizer(optimizer_func, self.latent_x_var_list, lr=self.opt["x_lr"]) 58 | 59 | steps = self.opt["num_epochs"] * self.opt["num_x_iters"] 60 | schedule_dict = { 61 | "fixed": lambda x: 1, 62 | "linear1cycle": lambda x: (9 * (1 - np.abs(x / steps - 1 / 2) * 2) + 1) / 10, 63 | "linear1cycledrop": lambda x: (9 * (1 - np.abs(x / (0.9 * steps) - 1 / 2) * 2) + 1) / 10 64 | if x < 0.9 * steps 65 | else 1 / 10 + (x - 0.9 * steps) / (0.1 * steps) * (1 / 1000 - 1 / 10), 66 | } 67 | schedule_func = schedule_dict[self.opt["lr_schedule"]] 68 | self.scheduler_x = torch.optim.lr_scheduler.LambdaLR(self.optimizer_x.opt, schedule_func) 69 | 70 | def warmup_dip(self): 71 | self.reg_noise_std = self.opt["reg_noise_std"] 72 | warmup_k = torch.load("experiments/pretrained/kernel.pth") 73 | 74 | mse = nn.MSELoss().cuda() 75 | 76 | print("Warming up k DIP") 77 | for step in tqdm(range(self.opt["num_warmup_iters"])): 78 | self.optimizer_k.zero_grad() 79 | dip_zk_rand = self.dip_zk + self.reg_noise_std * torch.randn_like(self.dip_zk).cuda() 80 | k = self.k_dip(dip_zk_rand) 81 | 82 | loss = mse(k, warmup_k) 83 | loss.backward() 84 | self.optimizer_k.step() 85 | 86 | def optimize_k_step(self, epoch): 87 | # Optimize k 88 | tq_k = tqdm(range(self.opt["num_k_iters"])) 89 | for j in tq_k: 90 | for p in self.k_dip.parameters(): 91 | p.requires_grad = True 92 | for p in self.latent_x_var_list: 93 | p.requires_grad = False 94 | 95 | self.optimizer_k.zero_grad() 96 | 97 | # Duplicate latent in case tile_latent = True 98 | if self.opt["tile_latent"]: 99 | latent_in = self.latent.expand(-1, 14, -1) 100 | else: 101 | latent_in = self.latent 102 | 103 | dip_zk_rand = self.dip_zk + self.reg_noise_std * torch.randn_like(self.dip_zk).cuda() 104 | # Apply learned linear mapping to match latent distribution to that of the mapping network 105 | latent_in = self.lrelu(latent_in * self.gaussian_fit["std"] + self.gaussian_fit["mean"]) 106 | 107 | # Normalize image to [0,1] instead of [-1,1] 108 | self.gen_im = self.get_gen_im(latent_in) 109 | self.gen_ker = self.k_dip(dip_zk_rand) 110 | 111 | # Calculate Losses 112 | loss, loss_dict = self.loss_builder(latent_in, self.gen_im, self.gen_ker, epoch) 113 | self.cur_loss = loss.cpu().detach().numpy() 114 | 115 | loss.backward() 116 | self.optimizer_k.step() 117 | self.scheduler_k.step() 118 | 119 | msg = " | ".join("{}: {:.4f}".format(k, v) for k, v in loss_dict.items()) 120 | tq_k.set_postfix(loss=msg) 121 | 
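    # Note: optimize_k_step() above and optimize_x_step() below are the two halves of an
    # alternating minimization. forward() calls them in turn once per epoch, and the
    # requires_grad toggles at the top of each loop freeze the latent/noise variables
    # while the kernel DIP is updated, and vice versa.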
122 | def optimize_x_step(self, epoch): 123 | tq_x = tqdm(range(self.opt["num_x_iters"])) 124 | for j in tq_x: 125 | for p in self.k_dip.parameters(): 126 | p.requires_grad = False 127 | for p in self.latent_x_var_list: 128 | p.requires_grad = True 129 | 130 | self.optimizer_x.opt.zero_grad() 131 | 132 | # Duplicate latent in case tile_latent = True 133 | if self.opt["tile_latent"]: 134 | latent_in = self.latent.expand(-1, 14, -1) 135 | else: 136 | latent_in = self.latent 137 | 138 | dip_zk_rand = self.dip_zk + self.reg_noise_std * torch.randn_like(self.dip_zk).cuda() 139 | # Apply learned linear mapping to match latent distribution to that of the mapping network 140 | latent_in = self.lrelu(latent_in * self.gaussian_fit["std"] + self.gaussian_fit["mean"]) 141 | 142 | # Normalize image to [0,1] instead of [-1,1] 143 | self.gen_im = self.get_gen_im(latent_in) 144 | self.gen_ker = self.k_dip(dip_zk_rand) 145 | 146 | # Calculate Losses 147 | loss, loss_dict = self.loss_builder(latent_in, self.gen_im, self.gen_ker, epoch) 148 | self.cur_loss = loss.cpu().detach().numpy() 149 | 150 | loss.backward() 151 | self.optimizer_x.step() 152 | self.scheduler_x.step() 153 | 154 | msg = " | ".join("{}: {:.4f}".format(k, v) for k, v in loss_dict.items()) 155 | tq_x.set_postfix(loss=msg) 156 | 157 | def log(self): 158 | if self.cur_loss < self.min_loss: 159 | self.min_loss = self.cur_loss 160 | self.best_im = self.gen_im.clone() 161 | self.best_ker = self.gen_ker.clone() 162 | 163 | def forward(self, ref_im): 164 | if self.opt["seed"]: 165 | seed = self.opt["seed"] 166 | torch.manual_seed(seed) 167 | torch.cuda.manual_seed(seed) 168 | torch.backends.cudnn.deterministic = True 169 | 170 | self.initialize_dip() 171 | self.initialize_latent_space() 172 | self.initialize_optimizers() 173 | self.warmup_dip() 174 | 175 | self.min_loss = np.inf 176 | self.gen_im = None 177 | self.initialize_loss(ref_im) 178 | 179 | if self.verbose: 180 | print("Optimizing") 181 | 182 | for epoch in range(self.opt["num_epochs"]): 183 | print("Step: {}".format(epoch + 1)) 184 | 185 | self.optimize_x_step(epoch) 186 | self.log() 187 | self.optimize_k_step(epoch) 188 | self.log() 189 | 190 | if self.opt["save_intermediate"]: 191 | yield ( 192 | self.best_im.cpu().detach().clamp(0, 1), 193 | self.loss_builder.get_blur_img(self.best_im, self.best_ker), 194 | ) 195 | -------------------------------------------------------------------------------- /forward_operator/bkse/models/dsd/dsd_stylegan.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | from models.dsd.dsd import DSD 5 | from models.dsd.stylegan import G_mapping, G_synthesis 6 | from models.losses.dsd_loss import LossBuilderStyleGAN 7 | 8 | 9 | class DSDStyleGAN(DSD): 10 | def __init__(self, opt, cache_dir): 11 | super(DSDStyleGAN, self).__init__(opt, cache_dir) 12 | 13 | def load_synthesis_network(self): 14 | self.synthesis = G_synthesis().cuda() 15 | self.synthesis.load_state_dict(torch.load("experiments/pretrained/stylegan_synthesis.pt")) 16 | for v in self.synthesis.parameters(): 17 | v.requires_grad = False 18 | 19 | def initialize_mapping_network(self): 20 | if Path("experiments/pretrained/gaussian_fit_stylegan.pt").exists(): 21 | self.gaussian_fit = torch.load("experiments/pretrained/gaussian_fit_stylegan.pt") 22 | else: 23 | if self.verbose: 24 | print("\tRunning Mapping Network") 25 | 26 | mapping = G_mapping().cuda() 27 | 
mapping.load_state_dict(torch.load("experiments/pretrained/stylegan_mapping.pt")) 28 | with torch.no_grad(): 29 | torch.manual_seed(0) 30 | latent = torch.randn((1000000, 512), dtype=torch.float32, device="cuda") 31 | latent_out = torch.nn.LeakyReLU(5)(mapping(latent)) 32 | self.gaussian_fit = {"mean": latent_out.mean(0), "std": latent_out.std(0)} 33 | torch.save(self.gaussian_fit, "experiments/pretrained/gaussian_fit_stylegan.pt") 34 | if self.verbose: 35 | print('\tSaved "gaussian_fit_stylegan.pt"') 36 | 37 | def initialize_latent_space(self): 38 | batch_size = self.opt["batch_size"] 39 | 40 | # Generate latent tensor 41 | if self.opt["tile_latent"]: 42 | self.latent = torch.randn((batch_size, 1, 512), dtype=torch.float, requires_grad=True, device="cuda") 43 | else: 44 | self.latent = torch.randn((batch_size, 18, 512), dtype=torch.float, requires_grad=True, device="cuda") 45 | 46 | # Generate list of noise tensors 47 | noise = [] # stores all of the noise tensors 48 | noise_vars = [] # stores the noise tensors that we want to optimize on 49 | 50 | noise_type = self.opt["noise_type"] 51 | bad_noise_layers = self.opt["bad_noise_layers"] 52 | for i in range(18): 53 | # dimension of the ith noise tensor 54 | res = (batch_size, 1, 2 ** (i // 2 + 2), 2 ** (i // 2 + 2)) 55 | 56 | if noise_type == "zero" or i in [int(layer) for layer in bad_noise_layers.split(".")]: 57 | new_noise = torch.zeros(res, dtype=torch.float, device="cuda") 58 | new_noise.requires_grad = False 59 | elif noise_type == "fixed": 60 | new_noise = torch.randn(res, dtype=torch.float, device="cuda") 61 | new_noise.requires_grad = False 62 | elif noise_type == "trainable": 63 | new_noise = torch.randn(res, dtype=torch.float, device="cuda") 64 | if i < self.opt["num_trainable_noise_layers"]: 65 | new_noise.requires_grad = True 66 | noise_vars.append(new_noise) 67 | else: 68 | new_noise.requires_grad = False 69 | else: 70 | raise Exception("unknown noise type") 71 | 72 | noise.append(new_noise) 73 | 74 | self.latent_x_var_list = [self.latent] + noise_vars 75 | self.noise = noise 76 | 77 | def initialize_loss(self, ref_im): 78 | self.loss_builder = LossBuilderStyleGAN(ref_im, self.opt).cuda() 79 | 80 | def get_gen_im(self, latent_in): 81 | return (self.synthesis(latent_in, self.noise) + 1) / 2 82 | -------------------------------------------------------------------------------- /forward_operator/bkse/models/dsd/dsd_stylegan2.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | from models.dsd.dsd import DSD 5 | from models.dsd.stylegan2 import Generator 6 | from models.losses.dsd_loss import LossBuilderStyleGAN2 7 | 8 | 9 | class DSDStyleGAN2(DSD): 10 | def __init__(self, opt, cache_dir): 11 | super(DSDStyleGAN2, self).__init__(opt, cache_dir) 12 | 13 | def load_synthesis_network(self): 14 | self.synthesis = Generator(size=256, style_dim=512, n_mlp=8).cuda() 15 | self.synthesis.load_state_dict(torch.load("experiments/pretrained/stylegan2.pt")["g_ema"], strict=False) 16 | for v in self.synthesis.parameters(): 17 | v.requires_grad = False 18 | 19 | def initialize_mapping_network(self): 20 | if Path("experiments/pretrained/gaussian_fit_stylegan2.pt").exists(): 21 | self.gaussian_fit = torch.load("experiments/pretrained/gaussian_fit_stylegan2.pt") 22 | else: 23 | if self.verbose: 24 | print("\tRunning Mapping Network") 25 | with torch.no_grad(): 26 | torch.manual_seed(0) 27 | latent = torch.randn((1000000, 512), dtype=torch.float32, device="cuda") 28 
| latent_out = torch.nn.LeakyReLU(5)(self.synthesis.get_latent(latent)) 29 | self.gaussian_fit = {"mean": latent_out.mean(0), "std": latent_out.std(0)} 30 | torch.save(self.gaussian_fit, "experiments/pretrained/gaussian_fit_stylegan2.pt") 31 | if self.verbose: 32 | print('\tSaved "gaussian_fit_stylegan2.pt"') 33 | 34 | def initialize_latent_space(self): 35 | batch_size = self.opt["batch_size"] 36 | 37 | # Generate latent tensor 38 | if self.opt["tile_latent"]: 39 | self.latent = torch.randn((batch_size, 1, 512), dtype=torch.float, requires_grad=True, device="cuda") 40 | else: 41 | self.latent = torch.randn((batch_size, 14, 512), dtype=torch.float, requires_grad=True, device="cuda") 42 | 43 | # Generate list of noise tensors 44 | noise = [] # stores all of the noise tensors 45 | noise_vars = [] # stores the noise tensors that we want to optimize on 46 | 47 | for i in range(14): 48 | res = (i + 5) // 2 49 | res = [1, 1, 2 ** res, 2 ** res] 50 | 51 | noise_type = self.opt["noise_type"] 52 | bad_noise_layers = self.opt["bad_noise_layers"] 53 | if noise_type == "zero" or i in [int(layer) for layer in bad_noise_layers.split(".")]: 54 | new_noise = torch.zeros(res, dtype=torch.float, device="cuda") 55 | new_noise.requires_grad = False 56 | elif noise_type == "fixed": 57 | new_noise = torch.randn(res, dtype=torch.float, device="cuda") 58 | new_noise.requires_grad = False 59 | elif noise_type == "trainable": 60 | new_noise = torch.randn(res, dtype=torch.float, device="cuda") 61 | if i < self.opt["num_trainable_noise_layers"]: 62 | new_noise.requires_grad = True 63 | noise_vars.append(new_noise) 64 | else: 65 | new_noise.requires_grad = False 66 | else: 67 | raise Exception("unknown noise type") 68 | 69 | noise.append(new_noise) 70 | 71 | self.latent_x_var_list = [self.latent] + noise_vars 72 | self.noise = noise 73 | 74 | def initialize_loss(self, ref_im): 75 | self.loss_builder = LossBuilderStyleGAN2(ref_im, self.opt).cuda() 76 | 77 | def get_gen_im(self, latent_in): 78 | return (self.synthesis([latent_in], input_is_latent=True, noise=self.noise)[0] + 1) / 2 79 | -------------------------------------------------------------------------------- /forward_operator/bkse/models/dsd/op/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TFNTF/PostEdit/3dcf4a680f8e438e9a5e8a28f7a2c6fb2fb2b475/forward_operator/bkse/models/dsd/op/__init__.py -------------------------------------------------------------------------------- /forward_operator/bkse/models/dsd/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Function 6 | from torch.nn import functional as F 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | fused = load( 12 | "fused", 13 | sources=[ 14 | os.path.join(module_path, "fused_bias_act.cpp"), 15 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class FusedLeakyReLUFunctionBackward(Function): 21 | @staticmethod 22 | def forward(ctx, grad_output, out, bias, negative_slope, scale): 23 | ctx.save_for_backward(out) 24 | ctx.negative_slope = negative_slope 25 | ctx.scale = scale 26 | 27 | empty = grad_output.new_empty(0) 28 | 29 | grad_input = fused.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale) 30 | 31 | dim = [0] 32 | 33 | if grad_input.ndim > 2: 34 | dim += list(range(2, 
grad_input.ndim)) 35 | 36 | if bias: 37 | grad_bias = grad_input.sum(dim).detach() 38 | 39 | else: 40 | grad_bias = None 41 | 42 | return grad_input, grad_bias 43 | 44 | @staticmethod 45 | def backward(ctx, gradgrad_input, gradgrad_bias): 46 | (out,) = ctx.saved_tensors 47 | gradgrad_out = fused.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale) 48 | 49 | return gradgrad_out, None, None, None, None 50 | 51 | 52 | class FusedLeakyReLUFunction(Function): 53 | @staticmethod 54 | def forward(ctx, input, bias, negative_slope, scale): 55 | empty = input.new_empty(0) 56 | 57 | if bias is None: 58 | bias = empty 59 | 60 | ctx.bias = bias is not None 61 | 62 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 63 | ctx.save_for_backward(out) 64 | ctx.negative_slope = negative_slope 65 | ctx.scale = scale 66 | 67 | return out 68 | 69 | @staticmethod 70 | def backward(ctx, grad_output): 71 | (out,) = ctx.saved_tensors 72 | 73 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 74 | grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale 75 | ) 76 | 77 | return grad_input, grad_bias, None, None 78 | 79 | 80 | class FusedLeakyReLU(nn.Module): 81 | def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5): 82 | super().__init__() 83 | 84 | if bias: 85 | self.bias = nn.Parameter(torch.zeros(channel)) 86 | 87 | else: 88 | self.bias = None 89 | 90 | self.negative_slope = negative_slope 91 | self.scale = scale 92 | 93 | def forward(self, input): 94 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 95 | 96 | 97 | def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5): 98 | if input.device.type == "cpu": 99 | if bias is not None: 100 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 101 | return F.leaky_relu(input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2) * scale 102 | 103 | else: 104 | return F.leaky_relu(input, negative_slope=0.2) * scale 105 | 106 | else: 107 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 108 | -------------------------------------------------------------------------------- /forward_operator/bkse/models/dsd/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /forward_operator/bkse/models/dsd/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 
2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include <torch/types.h> 8 | 9 | #include <ATen/ATen.h> 10 | #include <ATen/AccumulateType.h> 11 | #include <ATen/cuda/CUDAContext.h> 12 | #include <ATen/cuda/CUDAApplyUtils.cuh> 13 | 14 | #include <cuda.h> 15 | #include <cuda_runtime.h> 16 | 17 | 18 | template <typename scalar_t> 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>( 81 | y.data_ptr<scalar_t>(), 82 | x.data_ptr<scalar_t>(), 83 | b.data_ptr<scalar_t>(), 84 | ref.data_ptr<scalar_t>(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /forward_operator/bkse/models/dsd/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
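    // Exposed to Python as upfirdn2d(input, kernel, up_x, up_y, down_x, down_y,
    // pad_x0, pad_x1, pad_y0, pad_y1); upfirdn2d.py (next file) loads this extension
    // as upfirdn2d_op and wraps the call in an autograd Function.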
22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } -------------------------------------------------------------------------------- /forward_operator/bkse/models/dsd/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.autograd import Function 5 | from torch.nn import functional as F 6 | from torch.utils.cpp_extension import load 7 | 8 | 9 | module_path = os.path.dirname(__file__) 10 | upfirdn2d_op = load( 11 | "upfirdn2d", 12 | sources=[ 13 | os.path.join(module_path, "upfirdn2d.cpp"), 14 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 15 | ], 16 | ) 17 | 18 | 19 | class UpFirDn2dBackward(Function): 20 | @staticmethod 21 | def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size): 22 | 23 | up_x, up_y = up 24 | down_x, down_y = down 25 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 26 | 27 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 28 | 29 | grad_input = upfirdn2d_op.upfirdn2d( 30 | grad_output, 31 | grad_kernel, 32 | down_x, 33 | down_y, 34 | up_x, 35 | up_y, 36 | g_pad_x0, 37 | g_pad_x1, 38 | g_pad_y0, 39 | g_pad_y1, 40 | ) 41 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 42 | 43 | ctx.save_for_backward(kernel) 44 | 45 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 46 | 47 | ctx.up_x = up_x 48 | ctx.up_y = up_y 49 | ctx.down_x = down_x 50 | ctx.down_y = down_y 51 | ctx.pad_x0 = pad_x0 52 | ctx.pad_x1 = pad_x1 53 | ctx.pad_y0 = pad_y0 54 | ctx.pad_y1 = pad_y1 55 | ctx.in_size = in_size 56 | ctx.out_size = out_size 57 | 58 | return grad_input 59 | 60 | @staticmethod 61 | def backward(ctx, gradgrad_input): 62 | (kernel,) = ctx.saved_tensors 63 | 64 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 65 | 66 | gradgrad_out = upfirdn2d_op.upfirdn2d( 67 | gradgrad_input, 68 | kernel, 69 | ctx.up_x, 70 | ctx.up_y, 71 | ctx.down_x, 72 | ctx.down_y, 73 | ctx.pad_x0, 74 | ctx.pad_x1, 75 | ctx.pad_y0, 76 | ctx.pad_y1, 77 | ) 78 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 79 | gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]) 80 | 81 | return gradgrad_out, None, None, None, None, None, None, None, None 82 | 83 | 84 | class UpFirDn2d(Function): 85 | @staticmethod 86 | def forward(ctx, input, kernel, up, down, pad): 87 | up_x, up_y = up 88 | down_x, down_y = down 89 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 90 | 91 | kernel_h, kernel_w = kernel.shape 92 | batch, channel, in_h, in_w = input.shape 93 | ctx.in_size = input.shape 94 | 95 | input = input.reshape(-1, in_h, in_w, 1) 96 | 97 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 98 | 99 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 100 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 101 | ctx.out_size = (out_h, out_w) 102 | 103 | ctx.up = (up_x, up_y) 104 | ctx.down = (down_x, down_y) 105 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 106 | 107 | g_pad_x0 = kernel_w - pad_x0 - 1 108 | g_pad_y0 = kernel_h - pad_y0 - 1 109 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 110 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 111 | 112 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 113 | 114 | out = upfirdn2d_op.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1) 115 | # out = out.view(major, out_h, out_w, 
minor) 116 | out = out.view(-1, channel, out_h, out_w) 117 | 118 | return out 119 | 120 | @staticmethod 121 | def backward(ctx, grad_output): 122 | kernel, grad_kernel = ctx.saved_tensors 123 | 124 | grad_input = UpFirDn2dBackward.apply( 125 | grad_output, 126 | kernel, 127 | grad_kernel, 128 | ctx.up, 129 | ctx.down, 130 | ctx.pad, 131 | ctx.g_pad, 132 | ctx.in_size, 133 | ctx.out_size, 134 | ) 135 | 136 | return grad_input, None, None, None, None 137 | 138 | 139 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 140 | if input.device.type == "cpu": 141 | out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) 142 | 143 | else: 144 | out = UpFirDn2d.apply(input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])) 145 | 146 | return out 147 | 148 | 149 | def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): 150 | _, channel, in_h, in_w = input.shape 151 | input = input.reshape(-1, in_h, in_w, 1) 152 | 153 | _, in_h, in_w, minor = input.shape 154 | kernel_h, kernel_w = kernel.shape 155 | 156 | out = input.view(-1, in_h, 1, in_w, 1, minor) 157 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 158 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 159 | 160 | out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) 161 | out = out[ 162 | :, 163 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 164 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 165 | :, 166 | ] 167 | 168 | out = out.permute(0, 3, 1, 2) 169 | out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) 170 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 171 | out = F.conv2d(out, w) 172 | out = out.reshape( 173 | -1, 174 | minor, 175 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 176 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 177 | ) 178 | out = out.permute(0, 2, 3, 1) 179 | out = out[:, ::down_y, ::down_x, :] 180 | 181 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 182 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 183 | 184 | return out.view(-1, channel, out_h, out_w) 185 | -------------------------------------------------------------------------------- /forward_operator/bkse/models/dsd/spherical_optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import Optimizer 3 | 4 | 5 | # Spherical Optimizer Class 6 | # Uses the first two dimensions as batch information 7 | # Optimizes over the surface of a sphere using the initial radius throughout 8 | # 9 | # Example Usage: 10 | # opt = SphericalOptimizer(torch.optim.SGD, [x], lr=0.01) 11 | 12 | 13 | class SphericalOptimizer(Optimizer): 14 | def __init__(self, optimizer, params, **kwargs): 15 | self.opt = optimizer(params, **kwargs) 16 | self.params = params 17 | with torch.no_grad(): 18 | self.radii = { 19 | param: (param.pow(2).sum(tuple(range(2, param.ndim)), keepdim=True) + 1e-9).sqrt() for param in params 20 | } 21 | 22 | @torch.no_grad() 23 | def step(self, closure=None): 24 | loss = self.opt.step(closure) 25 | for param in self.params: 26 | param.data.div_((param.pow(2).sum(tuple(range(2, param.ndim)), keepdim=True) + 1e-9).sqrt()) 27 | param.mul_(self.radii[param]) 28 | 29 | return loss 30 | -------------------------------------------------------------------------------- /forward_operator/bkse/models/kernel_encoding/base_model.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn.parallel import DistributedDataParallel 7 | 8 | 9 | class BaseModel: 10 | def __init__(self, opt): 11 | self.opt = opt 12 | self.device = torch.device("cuda" if opt["gpu_ids"] is not None else "cpu") 13 | self.is_train = opt["is_train"] 14 | self.schedulers = [] 15 | self.optimizers = [] 16 | 17 | def feed_data(self, data): 18 | pass 19 | 20 | def optimize_parameters(self): 21 | pass 22 | 23 | def get_current_visuals(self): 24 | pass 25 | 26 | def get_current_losses(self): 27 | pass 28 | 29 | def print_network(self): 30 | pass 31 | 32 | def save(self, label): 33 | pass 34 | 35 | def load(self): 36 | pass 37 | 38 | def _set_lr(self, lr_groups_l): 39 | """Set learning rate for warmup 40 | lr_groups_l: list for lr_groups. each for a optimizer""" 41 | for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): 42 | for param_group, lr in zip(optimizer.param_groups, lr_groups): 43 | param_group["lr"] = lr 44 | 45 | def _get_init_lr(self): 46 | """Get the initial lr, which is set by the scheduler""" 47 | init_lr_groups_l = [] 48 | for optimizer in self.optimizers: 49 | init_lr_groups_l.append([v["initial_lr"] for v in optimizer.param_groups]) 50 | return init_lr_groups_l 51 | 52 | def update_learning_rate(self, cur_iter, warmup_iter=-1): 53 | for scheduler in self.schedulers: 54 | scheduler.step() 55 | # set up warm-up learning rate 56 | if cur_iter < warmup_iter: 57 | # get initial lr for each group 58 | init_lr_g_l = self._get_init_lr() 59 | # modify warming-up learning rates 60 | warm_up_lr_l = [] 61 | for init_lr_g in init_lr_g_l: 62 | warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g]) 63 | # set learning rate 64 | self._set_lr(warm_up_lr_l) 65 | 66 | def get_current_learning_rate(self): 67 | return [param_group["lr"] for param_group in self.optimizers[0].param_groups] 68 | 69 | def get_network_description(self, network): 70 | """Get the string and total parameters of the network""" 71 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 72 | network = network.module 73 | return str(network), sum(map(lambda x: x.numel(), network.parameters())) 74 | 75 | def save_network(self, network, network_label, iter_label): 76 | save_filename = "{}_{}.pth".format(iter_label, network_label) 77 | save_path = os.path.join(self.opt["path"]["models"], save_filename) 78 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 79 | network = network.module 80 | state_dict = network.state_dict() 81 | for key, param in state_dict.items(): 82 | state_dict[key] = param.cpu() 83 | torch.save(state_dict, save_path) 84 | 85 | def load_network(self, load_path, network, strict=True, prefix=""): 86 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 87 | network = network.module 88 | load_net = torch.load(load_path) 89 | load_net_clean = OrderedDict() # remove unnecessary 'module.' 
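        # Checkpoints written from nn.DataParallel / DistributedDataParallel wrappers
        # prefix every parameter key with "module."; the loop below strips it so the
        # keys match the bare network, e.g. "module.conv_first.weight" ->
        # "conv_first.weight" (illustrative key name).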
90 | for k, v in load_net.items(): 91 | if k.startswith("module."): 92 | load_net_clean[k[7:]] = v 93 | else: 94 | load_net_clean[k] = v 95 | load_net.update(load_net_clean) 96 | 97 | model_dict = network.state_dict() 98 | for k, v in load_net.items(): 99 | k = prefix + k 100 | if (k in model_dict) and (v.shape == model_dict[k].shape): 101 | model_dict[k] = v 102 | else: 103 | print("Load failed:", k) 104 | 105 | network.load_state_dict(model_dict, strict=True) 106 | 107 | def save_training_state(self, epoch, iter_step): 108 | """ 109 | Save training state during training, 110 | which will be used for resuming 111 | """ 112 | 113 | state = {"epoch": epoch, "iter": iter_step, "schedulers": [], "optimizers": []} 114 | for s in self.schedulers: 115 | state["schedulers"].append(s.state_dict()) 116 | for o in self.optimizers: 117 | state["optimizers"].append(o.state_dict()) 118 | save_filename = "{}.state".format(iter_step) 119 | save_path = os.path.join(self.opt["path"]["training_state"], save_filename) 120 | torch.save(state, save_path) 121 | 122 | def resume_training(self, resume_state): 123 | """Resume the optimizers and schedulers for training""" 124 | resume_optimizers = resume_state["optimizers"] 125 | resume_schedulers = resume_state["schedulers"] 126 | assert len(resume_optimizers) == len(self.optimizers), "Wrong lengths of optimizers" 127 | assert len(resume_schedulers) == len(self.schedulers), "Wrong lengths of schedulers" 128 | for i, o in enumerate(resume_optimizers): 129 | self.optimizers[i].load_state_dict(o) 130 | for i, s in enumerate(resume_schedulers): 131 | self.schedulers[i].load_state_dict(s) 132 | -------------------------------------------------------------------------------- /forward_operator/bkse/models/kernel_encoding/image_base_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import OrderedDict 3 | 4 | import models.lr_scheduler as lr_scheduler 5 | import torch 6 | import torch.nn as nn 7 | from models.kernel_encoding.base_model import BaseModel 8 | from models.kernel_encoding.kernel_wizard import KernelWizard 9 | from models.losses.charbonnier_loss import CharbonnierLoss 10 | from torch.nn.parallel import DataParallel, DistributedDataParallel 11 | 12 | 13 | logger = logging.getLogger("base") 14 | 15 | 16 | class ImageBaseModel(BaseModel): 17 | def __init__(self, opt): 18 | super(ImageBaseModel, self).__init__(opt) 19 | 20 | if opt["dist"]: 21 | self.rank = torch.distributed.get_rank() 22 | else: 23 | self.rank = -1 # non dist training 24 | train_opt = opt["train"] 25 | 26 | # define network and load pretrained models 27 | self.netG = KernelWizard(opt["KernelWizard"]).to(self.device) 28 | self.use_vae = opt["KernelWizard"]["use_vae"] 29 | if opt["dist"]: 30 | self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()]) 31 | else: 32 | self.netG = DataParallel(self.netG) 33 | # print network 34 | self.print_network() 35 | self.load() 36 | 37 | if self.is_train: 38 | self.netG.train() 39 | 40 | # loss 41 | loss_type = train_opt["pixel_criterion"] 42 | if loss_type == "l1": 43 | self.cri_pix = nn.L1Loss(reduction="sum").to(self.device) 44 | elif loss_type == "l2": 45 | self.cri_pix = nn.MSELoss(reduction="sum").to(self.device) 46 | elif loss_type == "cb": 47 | self.cri_pix = CharbonnierLoss().to(self.device) 48 | else: 49 | raise NotImplementedError( 50 | "Loss type [{:s}] is not\ 51 | recognized.".format( 52 | loss_type 53 | ) 54 | ) 55 | self.l_pix_w = 
train_opt["pixel_weight"] 56 | self.l_kl_w = train_opt["kl_weight"] 57 | 58 | # optimizers 59 | wd_G = train_opt["weight_decay_G"] if train_opt["weight_decay_G"] else 0 60 | params = [] 61 | for k, v in self.netG.named_parameters(): 62 | if v.requires_grad: 63 | params.append(v) 64 | else: 65 | if self.rank <= 0: 66 | logger.warning( 67 | "Params [{:s}] will not\ 68 | optimize.".format( 69 | k 70 | ) 71 | ) 72 | optim_params = [ 73 | {"params": params, "lr": train_opt["lr_G"]}, 74 | ] 75 | 76 | self.optimizer_G = torch.optim.Adam( 77 | optim_params, lr=train_opt["lr_G"], weight_decay=wd_G, betas=(train_opt["beta1"], train_opt["beta2"]) 78 | ) 79 | self.optimizers.append(self.optimizer_G) 80 | 81 | # schedulers 82 | if train_opt["lr_scheme"] == "MultiStepLR": 83 | for optimizer in self.optimizers: 84 | self.schedulers.append( 85 | lr_scheduler.MultiStepLR_Restart( 86 | optimizer, 87 | train_opt["lr_steps"], 88 | restarts=train_opt["restarts"], 89 | weights=train_opt["restart_weights"], 90 | gamma=train_opt["lr_gamma"], 91 | clear_state=train_opt["clear_state"], 92 | ) 93 | ) 94 | elif train_opt["lr_scheme"] == "CosineAnnealingLR_Restart": 95 | for optimizer in self.optimizers: 96 | self.schedulers.append( 97 | lr_scheduler.CosineAnnealingLR_Restart( 98 | optimizer, 99 | train_opt["T_period"], 100 | eta_min=train_opt["eta_min"], 101 | restarts=train_opt["restarts"], 102 | weights=train_opt["restart_weights"], 103 | ) 104 | ) 105 | else: 106 | raise NotImplementedError() 107 | 108 | self.log_dict = OrderedDict() 109 | 110 | def feed_data(self, data, need_GT=True): 111 | self.LQ = data["LQ"].to(self.device) 112 | self.HQ = data["HQ"].to(self.device) 113 | 114 | def set_params_lr_zero(self, groups): 115 | # fix normal module 116 | for group in groups: 117 | self.optimizers[0].param_groups[group]["lr"] = 0 118 | 119 | def optimize_parameters(self, step): 120 | batchsz, _, _, _ = self.LQ.shape 121 | 122 | self.optimizer_G.zero_grad() 123 | kernel_mean, kernel_sigma = self.netG(self.HQ, self.LQ) 124 | 125 | kernel = kernel_mean + kernel_sigma * torch.randn_like(kernel_mean) 126 | self.fake_LQ = self.netG.module.adaptKernel(self.HQ, kernel) 127 | 128 | l_pix = self.l_pix_w * self.cri_pix(self.fake_LQ, self.LQ) 129 | l_total = l_pix 130 | 131 | if self.use_vae: 132 | KL_divergence = ( 133 | self.l_kl_w 134 | * torch.sum( 135 | torch.pow(kernel_mean, 2) 136 | + torch.pow(kernel_sigma, 2) 137 | - torch.log(1e-8 + torch.pow(kernel_sigma, 2)) 138 | - 1 139 | ).sum() 140 | ) 141 | l_total += KL_divergence 142 | self.log_dict["l_KL"] = KL_divergence.item() / batchsz 143 | 144 | l_total.backward() 145 | self.optimizer_G.step() 146 | 147 | # set log 148 | self.log_dict["l_pix"] = l_pix.item() / batchsz 149 | self.log_dict["l_total"] = l_total.item() / batchsz 150 | 151 | def test(self): 152 | self.netG.eval() 153 | with torch.no_grad(): 154 | self.fake_H = self.netG(self.var_L) 155 | self.netG.train() 156 | 157 | def get_current_log(self): 158 | return self.log_dict 159 | 160 | def get_current_visuals(self, need_GT=True): 161 | out_dict = OrderedDict() 162 | out_dict["LQ"] = self.LQ.detach()[0].float().cpu() 163 | out_dict["rlt"] = self.fake_LQ.detach()[0].float().cpu() 164 | return out_dict 165 | 166 | def print_network(self): 167 | s, n = self.get_network_description(self.netG) 168 | if isinstance(self.netG, nn.DataParallel): 169 | net_struc_str = "{} - {}".format(self.netG.__class__.__name__, self.netG.module.__class__.__name__) 170 | else: 171 | net_struc_str = "{}".format(self.netG.__class__.__name__) 
172 | if self.rank <= 0: 173 | logger.info( 174 | "Network G structure: {}, \ 175 | with parameters: {:,d}".format( 176 | net_struc_str, n 177 | ) 178 | ) 179 | logger.info(s) 180 | 181 | def load(self): 182 | if self.opt["path"]["pretrain_model_G"]: 183 | load_path_G = self.opt["path"]["pretrain_model_G"] 184 | if load_path_G is not None: 185 | logger.info( 186 | "Loading model for G [{:s}]\ 187 | ...".format( 188 | load_path_G 189 | ) 190 | ) 191 | self.load_network(load_path_G, self.netG, self.opt["path"]["strict_load"]) 192 | 193 | def save(self, iter_label): 194 | self.save_network(self.netG, "G", iter_label) 195 | -------------------------------------------------------------------------------- /forward_operator/bkse/models/kernel_encoding/kernel_wizard.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | from .. import arch_util 4 | import torch 5 | import torch.nn as nn 6 | from ..backbones.resnet import ResidualBlock_noBN, ResnetBlock 7 | from ..backbones.unet_parts import UnetSkipConnectionBlock 8 | 9 | 10 | # The function F in the paper 11 | class KernelExtractor(nn.Module): 12 | def __init__(self, opt): 13 | super(KernelExtractor, self).__init__() 14 | 15 | nf = opt["nf"] 16 | self.kernel_dim = opt["kernel_dim"] 17 | self.use_sharp = opt["KernelExtractor"]["use_sharp"] 18 | self.use_vae = opt["use_vae"] 19 | 20 | # Blur estimator 21 | norm_layer = arch_util.get_norm_layer(opt["KernelExtractor"]["norm"]) 22 | n_blocks = opt["KernelExtractor"]["n_blocks"] 23 | padding_type = opt["KernelExtractor"]["padding_type"] 24 | use_dropout = opt["KernelExtractor"]["use_dropout"] 25 | if type(norm_layer) == functools.partial: 26 | use_bias = norm_layer.func == nn.InstanceNorm2d 27 | else: 28 | use_bias = norm_layer == nn.InstanceNorm2d 29 | 30 | input_nc = nf * 2 if self.use_sharp else nf 31 | output_nc = self.kernel_dim * 2 if self.use_vae else self.kernel_dim 32 | 33 | model = [ 34 | nn.ReflectionPad2d(3), 35 | nn.Conv2d(input_nc, nf, kernel_size=7, padding=0, bias=use_bias), 36 | norm_layer(nf), 37 | nn.ReLU(True), 38 | ] 39 | 40 | n_downsampling = 5 41 | for i in range(n_downsampling): # add downsampling layers 42 | mult = 2 ** i 43 | inc = min(nf * mult, output_nc) 44 | ouc = min(nf * mult * 2, output_nc) 45 | model += [ 46 | nn.Conv2d(inc, ouc, kernel_size=3, stride=2, padding=1, bias=use_bias), 47 | norm_layer(nf * mult * 2), 48 | nn.ReLU(True), 49 | ] 50 | 51 | for i in range(n_blocks): # add ResNet blocks 52 | model += [ 53 | ResnetBlock( 54 | output_nc, 55 | padding_type=padding_type, 56 | norm_layer=norm_layer, 57 | use_dropout=use_dropout, 58 | use_bias=use_bias, 59 | ) 60 | ] 61 | 62 | self.model = nn.Sequential(*model) 63 | 64 | def forward(self, sharp, blur): 65 | output = self.model(torch.cat((sharp, blur), dim=1)) 66 | if self.use_vae: 67 | return output[:, : self.kernel_dim, :, :], output[:, self.kernel_dim :, :, :] 68 | 69 | return output, torch.zeros_like(output).cuda() 70 | 71 | 72 | # The function G in the paper 73 | class KernelAdapter(nn.Module): 74 | def __init__(self, opt): 75 | super(KernelAdapter, self).__init__() 76 | input_nc = opt["nf"] 77 | output_nc = opt["nf"] 78 | ngf = opt["nf"] 79 | norm_layer = arch_util.get_norm_layer(opt["Adapter"]["norm"]) 80 | 81 | # construct unet structure 82 | unet_block = UnetSkipConnectionBlock( 83 | ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True 84 | ) 85 | # gradually reduce the number of filters from ngf * 8 to ngf 86 | 
unet_block = UnetSkipConnectionBlock( 87 | ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer 88 | ) 89 | unet_block = UnetSkipConnectionBlock( 90 | ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer 91 | ) 92 | unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 93 | self.model = UnetSkipConnectionBlock( 94 | output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer 95 | ) 96 | 97 | def forward(self, x, k): 98 | """Standard forward""" 99 | return self.model(x, k) 100 | 101 | 102 | class KernelWizard(nn.Module): 103 | def __init__(self, opt): 104 | super(KernelWizard, self).__init__() 105 | lrelu = nn.LeakyReLU(negative_slope=0.1) 106 | front_RBs = opt["front_RBs"] 107 | back_RBs = opt["back_RBs"] 108 | num_image_channels = opt["input_nc"] 109 | nf = opt["nf"] 110 | 111 | # Features extraction 112 | resBlock_noBN_f = functools.partial(ResidualBlock_noBN, nf=nf) 113 | feature_extractor = [] 114 | 115 | feature_extractor.append(nn.Conv2d(num_image_channels, nf, 3, 1, 1, bias=True)) 116 | feature_extractor.append(lrelu) 117 | feature_extractor.append(nn.Conv2d(nf, nf, 3, 2, 1, bias=True)) 118 | feature_extractor.append(lrelu) 119 | feature_extractor.append(nn.Conv2d(nf, nf, 3, 2, 1, bias=True)) 120 | feature_extractor.append(lrelu) 121 | 122 | for i in range(front_RBs): 123 | feature_extractor.append(resBlock_noBN_f()) 124 | 125 | self.feature_extractor = nn.Sequential(*feature_extractor) 126 | 127 | # Kernel extractor 128 | self.kernel_extractor = KernelExtractor(opt) 129 | 130 | # kernel adapter 131 | self.adapter = KernelAdapter(opt) 132 | 133 | # Reconstruction 134 | recon_trunk = [] 135 | for i in range(back_RBs): 136 | recon_trunk.append(resBlock_noBN_f()) 137 | 138 | # upsampling 139 | recon_trunk.append(nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)) 140 | recon_trunk.append(nn.PixelShuffle(2)) 141 | recon_trunk.append(lrelu) 142 | recon_trunk.append(nn.Conv2d(nf, 64 * 4, 3, 1, 1, bias=True)) 143 | recon_trunk.append(nn.PixelShuffle(2)) 144 | recon_trunk.append(lrelu) 145 | recon_trunk.append(nn.Conv2d(64, 64, 3, 1, 1, bias=True)) 146 | recon_trunk.append(lrelu) 147 | recon_trunk.append(nn.Conv2d(64, num_image_channels, 3, 1, 1, bias=True)) 148 | 149 | self.recon_trunk = nn.Sequential(*recon_trunk) 150 | 151 | def adaptKernel(self, x_sharp, kernel): 152 | B, C, H, W = x_sharp.shape 153 | base = x_sharp 154 | 155 | x_sharp = self.feature_extractor(x_sharp) 156 | 157 | out = self.adapter(x_sharp, kernel) 158 | out = self.recon_trunk(out) 159 | out += base 160 | 161 | return out 162 | 163 | def forward(self, x_sharp, x_blur): 164 | x_sharp = self.feature_extractor(x_sharp) 165 | x_blur = self.feature_extractor(x_blur) 166 | 167 | output = self.kernel_extractor(x_sharp, x_blur) 168 | return output 169 | -------------------------------------------------------------------------------- /forward_operator/bkse/models/losses/charbonnier_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class CharbonnierLoss(nn.Module): 6 | """Charbonnier Loss (L1)""" 7 | 8 | def __init__(self, eps=1e-6): 9 | super(CharbonnierLoss, self).__init__() 10 | self.eps = eps 11 | 12 | def forward(self, x, y): 13 | diff = x - y 14 | loss = torch.sum(torch.sqrt(diff * diff + self.eps)) 15 | return loss 16 | -------------------------------------------------------------------------------- 
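A minimal sketch of how the CharbonnierLoss above can be called; the import path follows this repository's layout, and the tensor shapes are placeholders rather than values used by the training code:

```python
import torch

from models.losses.charbonnier_loss import CharbonnierLoss

criterion = CharbonnierLoss(eps=1e-6)
pred = torch.rand(2, 3, 64, 64, requires_grad=True)  # placeholder restored batch
target = torch.rand(2, 3, 64, 64)                    # placeholder ground truth
loss = criterion(pred, target)  # sum of sqrt((pred - target)**2 + eps) over all elements
loss.backward()
```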
/forward_operator/bkse/models/losses/dsd_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models.dsd.bicubic import BicubicDownSample 3 | from models.kernel_encoding.kernel_wizard import KernelWizard 4 | from models.losses.ssim_loss import SSIM 5 | 6 | 7 | class LossBuilder(torch.nn.Module): 8 | def __init__(self, ref_im, opt): 9 | super(LossBuilder, self).__init__() 10 | assert ref_im.shape[2] == ref_im.shape[3] 11 | self.ref_im = ref_im 12 | loss_str = opt["loss_str"] 13 | self.parsed_loss = [loss_term.split("*") for loss_term in loss_str.split("+")] 14 | self.eps = opt["eps"] 15 | 16 | self.ssim = SSIM().cuda() 17 | 18 | self.D = KernelWizard(opt["KernelWizard"]).cuda() 19 | self.D.load_state_dict(torch.load(opt["KernelWizard"]["pretrained"])) 20 | for v in self.D.parameters(): 21 | v.requires_grad = False 22 | 23 | # Takes a list of tensors, flattens them, and concatenates them into a vector 24 | # Used to calculate euclidian distance between lists of tensors 25 | def flatcat(self, l): 26 | l = l if (isinstance(l, list)) else [l] 27 | return torch.cat([x.flatten() for x in l], dim=0) 28 | 29 | def _loss_l2(self, gen_im_lr, ref_im, **kwargs): 30 | return (gen_im_lr - ref_im).pow(2).mean((1, 2, 3)).clamp(min=self.eps).sum() 31 | 32 | def _loss_l1(self, gen_im_lr, ref_im, **kwargs): 33 | return 10 * ((gen_im_lr - ref_im).abs().mean((1, 2, 3)).clamp(min=self.eps).sum()) 34 | 35 | # Uses geodesic distance on sphere to sum pairwise distances of the 18 vectors 36 | def _loss_geocross(self, latent, **kwargs): 37 | pass 38 | 39 | 40 | class LossBuilderStyleGAN(LossBuilder): 41 | def __init__(self, ref_im, opt): 42 | super(LossBuilderStyleGAN, self).__init__(ref_im, opt) 43 | im_size = ref_im.shape[2] 44 | factor = opt["output_size"] // im_size 45 | assert im_size * factor == opt["output_size"] 46 | self.bicub = BicubicDownSample(factor=factor) 47 | 48 | # Uses geodesic distance on sphere to sum pairwise distances of the 18 vectors 49 | def _loss_geocross(self, latent, **kwargs): 50 | if latent.shape[1] == 1: 51 | return 0 52 | else: 53 | X = latent.view(-1, 1, 18, 512) 54 | Y = latent.view(-1, 18, 1, 512) 55 | A = ((X - Y).pow(2).sum(-1) + 1e-9).sqrt() 56 | B = ((X + Y).pow(2).sum(-1) + 1e-9).sqrt() 57 | D = 2 * torch.atan2(A, B) 58 | D = ((D.pow(2) * 512).mean((1, 2)) / 8.0).sum() 59 | return D 60 | 61 | def forward(self, latent, gen_im, kernel, step): 62 | var_dict = { 63 | "latent": latent, 64 | "gen_im_lr": self.D.adaptKernel(self.bicub(gen_im), kernel), 65 | "ref_im": self.ref_im, 66 | } 67 | loss = 0 68 | loss_fun_dict = { 69 | "L2": self._loss_l2, 70 | "L1": self._loss_l1, 71 | "GEOCROSS": self._loss_geocross, 72 | } 73 | losses = {} 74 | 75 | for weight, loss_type in self.parsed_loss: 76 | tmp_loss = loss_fun_dict[loss_type](**var_dict) 77 | losses[loss_type] = tmp_loss 78 | loss += float(weight) * tmp_loss 79 | loss += 5e-5 * torch.norm(kernel) 80 | losses["Norm"] = torch.norm(kernel) 81 | 82 | return loss, losses 83 | 84 | def get_blur_img(self, sharp_img, kernel): 85 | return self.D.adaptKernel(self.bicub(sharp_img), kernel).cpu().detach().clamp(0, 1) 86 | 87 | 88 | class LossBuilderStyleGAN2(LossBuilder): 89 | def __init__(self, ref_im, opt): 90 | super(LossBuilderStyleGAN2, self).__init__(ref_im, opt) 91 | 92 | # Uses geodesic distance on sphere to sum pairwise distances of the 18 vectors 93 | def _loss_geocross(self, latent, **kwargs): 94 | if latent.shape[1] == 1: 95 | return 0 96 | else: 97 | X = latent.view(-1, 
1, 14, 512) 98 | Y = latent.view(-1, 14, 1, 512) 99 | A = ((X - Y).pow(2).sum(-1) + 1e-9).sqrt() 100 | B = ((X + Y).pow(2).sum(-1) + 1e-9).sqrt() 101 | D = 2 * torch.atan2(A, B) 102 | D = ((D.pow(2) * 512).mean((1, 2)) / 6.0).sum() 103 | return D 104 | 105 | def forward(self, latent, gen_im, kernel, step): 106 | var_dict = { 107 | "latent": latent, 108 | "gen_im_lr": self.D.adaptKernel(gen_im, kernel), 109 | "ref_im": self.ref_im, 110 | } 111 | loss = 0 112 | loss_fun_dict = { 113 | "L2": self._loss_l2, 114 | "L1": self._loss_l1, 115 | "GEOCROSS": self._loss_geocross, 116 | } 117 | losses = {} 118 | 119 | for weight, loss_type in self.parsed_loss: 120 | tmp_loss = loss_fun_dict[loss_type](**var_dict) 121 | losses[loss_type] = tmp_loss 122 | loss += float(weight) * tmp_loss 123 | loss += 1e-4 * torch.norm(kernel) 124 | losses["Norm"] = torch.norm(kernel) 125 | 126 | return loss, losses 127 | 128 | def get_blur_img(self, sharp_img, kernel): 129 | return self.D.adaptKernel(sharp_img, kernel).cpu().detach().clamp(0, 1) 130 | -------------------------------------------------------------------------------- /forward_operator/bkse/models/losses/gan_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | # Define GAN loss: [vanilla | lsgan | wgan-gp] 6 | class GANLoss(nn.Module): 7 | def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0): 8 | super(GANLoss, self).__init__() 9 | self.gan_type = gan_type.lower() 10 | self.real_label_val = real_label_val 11 | self.fake_label_val = fake_label_val 12 | 13 | if self.gan_type == "gan" or self.gan_type == "ragan": 14 | self.loss = nn.BCEWithLogitsLoss() 15 | elif self.gan_type == "lsgan": 16 | self.loss = nn.MSELoss() 17 | elif self.gan_type == "wgan-gp": 18 | 19 | def wgan_loss(input, target): 20 | # target is boolean 21 | return -1 * input.mean() if target else input.mean() 22 | 23 | self.loss = wgan_loss 24 | else: 25 | raise NotImplementedError("GAN type [{:s}] is not found".format(self.gan_type)) 26 | 27 | def get_target_label(self, input, target_is_real): 28 | if self.gan_type == "wgan-gp": 29 | return target_is_real 30 | if target_is_real: 31 | return torch.empty_like(input).fill_(self.real_label_val) 32 | else: 33 | return torch.empty_like(input).fill_(self.fake_label_val) 34 | 35 | def forward(self, input, target_is_real): 36 | target_label = self.get_target_label(input, target_is_real) 37 | loss = self.loss(input, target_label) 38 | return loss 39 | -------------------------------------------------------------------------------- /forward_operator/bkse/models/losses/hyper_laplacian_penalty.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class HyperLaplacianPenalty(nn.Module): 7 | def __init__(self, num_channels, alpha, eps=1e-6): 8 | super(HyperLaplacianPenalty, self).__init__() 9 | 10 | self.alpha = alpha 11 | self.eps = eps 12 | 13 | self.Kx = torch.Tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]).cuda() 14 | self.Kx = self.Kx.expand(1, num_channels, 3, 3) 15 | self.Kx.requires_grad = False 16 | self.Ky = torch.Tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]).cuda() 17 | self.Ky = self.Ky.expand(1, num_channels, 3, 3) 18 | self.Ky.requires_grad = False 19 | 20 | def forward(self, x): 21 | gradX = F.conv2d(x, self.Kx, stride=1, padding=1) 22 | gradY = F.conv2d(x, self.Ky, stride=1, padding=1) 23 | grad = torch.sqrt(gradX ** 2 + 
gradY ** 2 + self.eps) 24 | 25 | loss = (grad ** self.alpha).mean() 26 | 27 | return loss 28 | -------------------------------------------------------------------------------- /forward_operator/bkse/models/losses/perceptual_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | 5 | 6 | class StyleLoss(nn.Module): 7 | r""" 8 | Perceptual loss, VGG-based 9 | https://arxiv.org/abs/1603.08155 10 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py 11 | """ 12 | 13 | def __init__(self): 14 | super(StyleLoss, self).__init__() 15 | self.add_module("vgg", VGG19()) 16 | self.criterion = torch.nn.L1Loss() 17 | 18 | def compute_gram(self, x): 19 | b, ch, h, w = x.size() 20 | f = x.view(b, ch, w * h) 21 | f_T = f.transpose(1, 2) 22 | G = f.bmm(f_T) / (h * w * ch) 23 | 24 | return G 25 | 26 | def __call__(self, x, y): 27 | # Compute features 28 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 29 | 30 | # Compute loss 31 | style_loss = 0.0 32 | style_loss += self.criterion(self.compute_gram(x_vgg["relu2_2"]), self.compute_gram(y_vgg["relu2_2"])) 33 | style_loss += self.criterion(self.compute_gram(x_vgg["relu3_4"]), self.compute_gram(y_vgg["relu3_4"])) 34 | style_loss += self.criterion(self.compute_gram(x_vgg["relu4_4"]), self.compute_gram(y_vgg["relu4_4"])) 35 | style_loss += self.criterion(self.compute_gram(x_vgg["relu5_2"]), self.compute_gram(y_vgg["relu5_2"])) 36 | 37 | return style_loss 38 | 39 | 40 | class PerceptualLoss(nn.Module): 41 | r""" 42 | Perceptual loss, VGG-based 43 | https://arxiv.org/abs/1603.08155 44 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py 45 | """ 46 | 47 | def __init__(self, weights=[0.2, 0.4, 0.8, 1.0, 3.0]): 48 | super(PerceptualLoss, self).__init__() 49 | self.add_module("vgg", VGG19()) 50 | self.criterion = torch.nn.L1Loss() 51 | self.weights = weights 52 | 53 | def __call__(self, x, y): 54 | # Compute features 55 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 56 | 57 | content_loss = 0.0 58 | content_loss += self.weights[0] * self.criterion(x_vgg["relu1_1"], y_vgg["relu1_1"]) 59 | content_loss += self.weights[1] * self.criterion(x_vgg["relu2_1"], y_vgg["relu2_1"]) 60 | content_loss += self.weights[2] * self.criterion(x_vgg["relu3_1"], y_vgg["relu3_1"]) 61 | content_loss += self.weights[3] * self.criterion(x_vgg["relu4_1"], y_vgg["relu4_1"]) 62 | content_loss += self.weights[4] * self.criterion(x_vgg["relu5_1"], y_vgg["relu5_1"]) 63 | 64 | return content_loss 65 | 66 | 67 | class VGG19(torch.nn.Module): 68 | def __init__(self): 69 | super(VGG19, self).__init__() 70 | features = models.vgg19(pretrained=True).features 71 | self.relu1_1 = torch.nn.Sequential() 72 | self.relu1_2 = torch.nn.Sequential() 73 | 74 | self.relu2_1 = torch.nn.Sequential() 75 | self.relu2_2 = torch.nn.Sequential() 76 | 77 | self.relu3_1 = torch.nn.Sequential() 78 | self.relu3_2 = torch.nn.Sequential() 79 | self.relu3_3 = torch.nn.Sequential() 80 | self.relu3_4 = torch.nn.Sequential() 81 | 82 | self.relu4_1 = torch.nn.Sequential() 83 | self.relu4_2 = torch.nn.Sequential() 84 | self.relu4_3 = torch.nn.Sequential() 85 | self.relu4_4 = torch.nn.Sequential() 86 | 87 | self.relu5_1 = torch.nn.Sequential() 88 | self.relu5_2 = torch.nn.Sequential() 89 | self.relu5_3 = torch.nn.Sequential() 90 | self.relu5_4 = torch.nn.Sequential() 91 | 92 | for x in range(2): 93 | self.relu1_1.add_module(str(x), features[x]) 94 | 95 | for x in range(2, 4): 96 | 
self.relu1_2.add_module(str(x), features[x]) 97 | 98 | for x in range(4, 7): 99 | self.relu2_1.add_module(str(x), features[x]) 100 | 101 | for x in range(7, 9): 102 | self.relu2_2.add_module(str(x), features[x]) 103 | 104 | for x in range(9, 12): 105 | self.relu3_1.add_module(str(x), features[x]) 106 | 107 | for x in range(12, 14): 108 | self.relu3_2.add_module(str(x), features[x]) 109 | 110 | for x in range(14, 16): 111 | self.relu3_2.add_module(str(x), features[x]) 112 | 113 | for x in range(16, 18): 114 | self.relu3_4.add_module(str(x), features[x]) 115 | 116 | for x in range(18, 21): 117 | self.relu4_1.add_module(str(x), features[x]) 118 | 119 | for x in range(21, 23): 120 | self.relu4_2.add_module(str(x), features[x]) 121 | 122 | for x in range(23, 25): 123 | self.relu4_3.add_module(str(x), features[x]) 124 | 125 | for x in range(25, 27): 126 | self.relu4_4.add_module(str(x), features[x]) 127 | 128 | for x in range(27, 30): 129 | self.relu5_1.add_module(str(x), features[x]) 130 | 131 | for x in range(30, 32): 132 | self.relu5_2.add_module(str(x), features[x]) 133 | 134 | for x in range(32, 34): 135 | self.relu5_3.add_module(str(x), features[x]) 136 | 137 | for x in range(34, 36): 138 | self.relu5_4.add_module(str(x), features[x]) 139 | 140 | # don't need the gradients, just want the features 141 | for param in self.parameters(): 142 | param.requires_grad = False 143 | 144 | def forward(self, x): 145 | relu1_1 = self.relu1_1(x) 146 | relu1_2 = self.relu1_2(relu1_1) 147 | 148 | relu2_1 = self.relu2_1(relu1_2) 149 | relu2_2 = self.relu2_2(relu2_1) 150 | 151 | relu3_1 = self.relu3_1(relu2_2) 152 | relu3_2 = self.relu3_2(relu3_1) 153 | relu3_3 = self.relu3_3(relu3_2) 154 | relu3_4 = self.relu3_4(relu3_3) 155 | 156 | relu4_1 = self.relu4_1(relu3_4) 157 | relu4_2 = self.relu4_2(relu4_1) 158 | relu4_3 = self.relu4_3(relu4_2) 159 | relu4_4 = self.relu4_4(relu4_3) 160 | 161 | relu5_1 = self.relu5_1(relu4_4) 162 | relu5_2 = self.relu5_2(relu5_1) 163 | relu5_3 = self.relu5_3(relu5_2) 164 | relu5_4 = self.relu5_4(relu5_3) 165 | 166 | out = { 167 | "relu1_1": relu1_1, 168 | "relu1_2": relu1_2, 169 | "relu2_1": relu2_1, 170 | "relu2_2": relu2_2, 171 | "relu3_1": relu3_1, 172 | "relu3_2": relu3_2, 173 | "relu3_3": relu3_3, 174 | "relu3_4": relu3_4, 175 | "relu4_1": relu4_1, 176 | "relu4_2": relu4_2, 177 | "relu4_3": relu4_3, 178 | "relu4_4": relu4_4, 179 | "relu5_1": relu5_1, 180 | "relu5_2": relu5_2, 181 | "relu5_3": relu5_3, 182 | "relu5_4": relu5_4, 183 | } 184 | return out 185 | -------------------------------------------------------------------------------- /forward_operator/bkse/models/losses/ssim_loss.py: -------------------------------------------------------------------------------- 1 | from math import exp 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | 7 | 8 | class SSIM(torch.nn.Module): 9 | @staticmethod 10 | def gaussian(window_size, sigma): 11 | gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma ** 2)) for x in range(window_size)]) 12 | return gauss / gauss.sum() 13 | 14 | @staticmethod 15 | def create_window(window_size, channel): 16 | _1D_window = SSIM.gaussian(window_size, 1.5).unsqueeze(1) 17 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 18 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 19 | return window 20 | 21 | @staticmethod 22 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 23 | mu1 = F.conv2d(img1, 
window, padding=window_size // 2, groups=channel) 24 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 25 | 26 | mu1_sq = mu1.pow(2) 27 | mu2_sq = mu2.pow(2) 28 | mu1_mu2 = mu1 * mu2 29 | 30 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 31 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 32 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 33 | 34 | C1 = 0.01 ** 2 35 | C2 = 0.03 ** 2 36 | 37 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 38 | 39 | if size_average: 40 | return ssim_map.mean() 41 | else: 42 | return ssim_map.mean(1).mean(1).mean(1) 43 | 44 | def __init__(self, window_size=11, size_average=True): 45 | super(SSIM, self).__init__() 46 | self.window_size = window_size 47 | self.size_average = size_average 48 | self.channel = 1 49 | self.window = self.create_window(window_size, self.channel) 50 | 51 | def forward(self, img1, img2): 52 | (_, channel, _, _) = img1.size() 53 | 54 | if channel == self.channel and self.window.data.type() == img1.data.type(): 55 | window = self.window 56 | else: 57 | window = self.create_window(self.window_size, channel) 58 | 59 | if img1.is_cuda: 60 | window = window.cuda(img1.get_device()) 61 | window = window.type_as(img1) 62 | 63 | self.window = window 64 | self.channel = channel 65 | 66 | return self._ssim(img1, img2, window, self.window_size, channel, self.size_average) 67 | -------------------------------------------------------------------------------- /forward_operator/bkse/models/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import Counter, defaultdict 3 | 4 | import torch 5 | from torch.optim.lr_scheduler import _LRScheduler 6 | 7 | 8 | class MultiStepLR_Restart(_LRScheduler): 9 | def __init__( 10 | self, optimizer, milestones, restarts=None, weights=None, gamma=0.1, clear_state=False, last_epoch=-1 11 | ): 12 | self.milestones = Counter(milestones) 13 | self.gamma = gamma 14 | self.clear_state = clear_state 15 | self.restarts = restarts if restarts else [0] 16 | self.restarts = [v + 1 for v in self.restarts] 17 | self.restart_weights = weights if weights else [1] 18 | assert len(self.restarts) == len(self.restart_weights), "restarts and their weights do not match." 
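        # Illustrative behaviour (placeholder values, not repo defaults): with
        # milestones=[50, 100], gamma=0.5, restarts=[80] and weights=[1], get_lr()
        # halves the lr at steps 50 and 100 and resets it to initial_lr * 1 at step
        # 81, because the restart indices were shifted by +1 above.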
19 | super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch) 20 | 21 | def get_lr(self): 22 | if self.last_epoch in self.restarts: 23 | if self.clear_state: 24 | self.optimizer.state = defaultdict(dict) 25 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 26 | return [group["initial_lr"] * weight for group in self.optimizer.param_groups] 27 | if self.last_epoch not in self.milestones: 28 | return [group["lr"] for group in self.optimizer.param_groups] 29 | return [group["lr"] * self.gamma ** self.milestones[self.last_epoch] for group in self.optimizer.param_groups] 30 | 31 | 32 | class CosineAnnealingLR_Restart(_LRScheduler): 33 | def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1): 34 | self.T_period = T_period 35 | self.T_max = self.T_period[0] # current T period 36 | self.eta_min = eta_min 37 | self.restarts = restarts if restarts else [0] 38 | self.restarts = [v + 1 for v in self.restarts] 39 | self.restart_weights = weights if weights else [1] 40 | self.last_restart = 0 41 | assert len(self.restarts) == len(self.restart_weights), "restarts and their weights do not match." 42 | super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch) 43 | 44 | def get_lr(self): 45 | if self.last_epoch == 0: 46 | return self.base_lrs 47 | elif self.last_epoch in self.restarts: 48 | self.last_restart = self.last_epoch 49 | self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1] 50 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 51 | return [group["initial_lr"] * weight for group in self.optimizer.param_groups] 52 | elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0: 53 | return [ 54 | group["lr"] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 55 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 56 | ] 57 | return [ 58 | (1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) 59 | / (1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) 60 | * (group["lr"] - self.eta_min) 61 | + self.eta_min 62 | for group in self.optimizer.param_groups 63 | ] 64 | 65 | 66 | if __name__ == "__main__": 67 | optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0, betas=(0.9, 0.99)) 68 | ############################## 69 | # MultiStepLR_Restart 70 | ############################## 71 | # Original 72 | lr_steps = [200000, 400000, 600000, 800000] 73 | restarts = None 74 | restart_weights = None 75 | 76 | # two 77 | lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000] 78 | restarts = [500000] 79 | restart_weights = [1] 80 | 81 | # four 82 | lr_steps = [ 83 | 50000, 84 | 100000, 85 | 150000, 86 | 200000, 87 | 240000, 88 | 300000, 89 | 350000, 90 | 400000, 91 | 450000, 92 | 490000, 93 | 550000, 94 | 600000, 95 | 650000, 96 | 700000, 97 | 740000, 98 | 800000, 99 | 850000, 100 | 900000, 101 | 950000, 102 | 990000, 103 | ] 104 | restarts = [250000, 500000, 750000] 105 | restart_weights = [1, 1, 1] 106 | 107 | scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5, clear_state=False) 108 | 109 | ############################## 110 | # Cosine Annealing Restart 111 | ############################## 112 | # two 113 | T_period = [500000, 500000] 114 | restarts = [500000] 115 | restart_weights = [1] 116 | 117 | # four 118 | T_period = [250000, 250000, 250000, 250000] 119 | restarts = [250000, 500000, 
750000] 120 | restart_weights = [1, 1, 1] 121 | 122 | scheduler = CosineAnnealingLR_Restart( 123 | optimizer, T_period, eta_min=1e-7, restarts=restarts, weights=restart_weights 124 | ) 125 | 126 | ############################## 127 | # Draw figure 128 | ############################## 129 | N_iter = 1000000 130 | lr_l = list(range(N_iter)) 131 | for i in range(N_iter): 132 | scheduler.step() 133 | current_lr = optimizer.param_groups[0]["lr"] 134 | lr_l[i] = current_lr 135 | 136 | import matplotlib as mpl 137 | import matplotlib.ticker as mtick 138 | from matplotlib import pyplot as plt 139 | 140 | mpl.style.use("default") 141 | import seaborn 142 | 143 | seaborn.set(style="whitegrid") 144 | seaborn.set_context("paper") 145 | 146 | plt.figure(1) 147 | plt.subplot(111) 148 | plt.ticklabel_format(style="sci", axis="x", scilimits=(0, 0)) 149 | plt.title("Title", fontsize=16, color="k") 150 | plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label="learning rate scheme") 151 | legend = plt.legend(loc="upper right", shadow=False) 152 | ax = plt.gca() 153 | labels = ax.get_xticks().tolist() 154 | for k, v in enumerate(labels): 155 | labels[k] = str(int(v / 1000)) + "K" 156 | ax.set_xticklabels(labels) 157 | ax.yaxis.set_major_formatter(mtick.FormatStrFormatter("%.1e")) 158 | 159 | ax.set_ylabel("Learning rate") 160 | ax.set_xlabel("Iteration") 161 | fig = plt.gcf() 162 | plt.show() 163 | -------------------------------------------------------------------------------- /forward_operator/bkse/options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TFNTF/PostEdit/3dcf4a680f8e438e9a5e8a28f7a2c6fb2fb2b475/forward_operator/bkse/options/__init__.py -------------------------------------------------------------------------------- /forward_operator/bkse/options/data_augmentation/default.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | gpu_ids: [0] 3 | 4 | #### network structures 5 | KernelWizard: 6 | pretrained: experiments/pretrained/GOPRO_woVAE.pth 7 | input_nc: 3 8 | nf: 64 9 | front_RBs: 10 10 | back_RBs: 20 11 | N_frames: 1 12 | kernel_dim: 512 13 | use_vae: false 14 | KernelExtractor: 15 | norm: none 16 | use_sharp: true 17 | n_blocks: 4 18 | padding_type: reflect 19 | use_dropout: false 20 | Adapter: 21 | norm: none 22 | use_dropout: false 23 | -------------------------------------------------------------------------------- /forward_operator/bkse/options/options.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import os.path as osp 4 | 5 | import yaml 6 | from utils.util import OrderedYaml 7 | 8 | 9 | Loader, Dumper = OrderedYaml() 10 | 11 | 12 | def parse(opt_path, is_train=True): 13 | with open(opt_path, mode="r") as f: 14 | opt = yaml.load(f, Loader=Loader) 15 | # export CUDA_VISIBLE_DEVICES 16 | gpu_list = ",".join(str(x) for x in opt["gpu_ids"]) 17 | os.environ["CUDA_VISIBLE_DEVICES"] = gpu_list 18 | print("export CUDA_VISIBLE_DEVICES=" + gpu_list) 19 | 20 | opt["is_train"] = is_train 21 | if opt["distortion"] == "sr": 22 | scale = opt["scale"] 23 | 24 | # datasets 25 | for phase, dataset in opt["datasets"].items(): 26 | phase = phase.split("_")[0] 27 | dataset["phase"] = phase 28 | if opt["distortion"] == "sr": 29 | dataset["scale"] = scale 30 | is_lmdb = False 31 | if dataset.get("dataroot_GT", None) is not None: 32 | dataset["dataroot_GT"] = 
osp.expanduser(dataset["dataroot_GT"]) 33 | if dataset["dataroot_GT"].endswith("lmdb"): 34 | is_lmdb = True 35 | if dataset.get("dataroot_LQ", None) is not None: 36 | dataset["dataroot_LQ"] = osp.expanduser(dataset["dataroot_LQ"]) 37 | if dataset["dataroot_LQ"].endswith("lmdb"): 38 | is_lmdb = True 39 | dataset["data_type"] = "lmdb" if is_lmdb else "img" 40 | if dataset["mode"].endswith("mc"): # for memcached 41 | dataset["data_type"] = "mc" 42 | dataset["mode"] = dataset["mode"].replace("_mc", "") 43 | 44 | # path 45 | for key, path in opt["path"].items(): 46 | if path and key in opt["path"] and key != "strict_load": 47 | opt["path"][key] = osp.expanduser(path) 48 | opt["path"]["root"] = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) 49 | if is_train: 50 | experiments_root = osp.join(opt["path"]["root"], "experiments", opt["name"]) 51 | opt["path"]["experiments_root"] = experiments_root 52 | opt["path"]["models"] = osp.join(experiments_root, "models") 53 | opt["path"]["training_state"] = osp.join(experiments_root, "training_state") 54 | opt["path"]["log"] = experiments_root 55 | opt["path"]["val_images"] = osp.join(experiments_root, "val_images") 56 | 57 | # change some options for debug mode 58 | if "debug" in opt["name"]: 59 | opt["train"]["val_freq"] = 8 60 | opt["logger"]["print_freq"] = 1 61 | opt["logger"]["save_checkpoint_freq"] = 8 62 | else: # test 63 | results_root = osp.join(opt["path"]["root"], "results", opt["name"]) 64 | opt["path"]["results_root"] = results_root 65 | opt["path"]["log"] = results_root 66 | 67 | # network 68 | if opt["distortion"] == "sr": 69 | opt["network_G"]["scale"] = scale 70 | 71 | return opt 72 | 73 | 74 | def dict2str(opt, indent_l=1): 75 | """dict to string for logger""" 76 | msg = "" 77 | for k, v in opt.items(): 78 | if isinstance(v, dict): 79 | msg += " " * (indent_l * 2) + k + ":[\n" 80 | msg += dict2str(v, indent_l + 1) 81 | msg += " " * (indent_l * 2) + "]\n" 82 | else: 83 | msg += " " * (indent_l * 2) + k + ": " + str(v) + "\n" 84 | return msg 85 | 86 | 87 | class NoneDict(dict): 88 | def __missing__(self, key): 89 | return None 90 | 91 | 92 | # convert to NoneDict, which return None for missing key. 93 | def dict_to_nonedict(opt): 94 | if isinstance(opt, dict): 95 | new_opt = dict() 96 | for key, sub_opt in opt.items(): 97 | new_opt[key] = dict_to_nonedict(sub_opt) 98 | return NoneDict(**new_opt) 99 | elif isinstance(opt, list): 100 | return [dict_to_nonedict(sub_opt) for sub_opt in opt] 101 | else: 102 | return opt 103 | 104 | 105 | def check_resume(opt, resume_iter): 106 | """Check resume states and pretrain_model paths""" 107 | logger = logging.getLogger("base") 108 | if opt["path"]["resume_state"]: 109 | if ( 110 | opt["path"].get("pretrain_model_G", None) is not None 111 | or opt["path"].get("pretrain_model_D", None) is not None 112 | ): 113 | logger.warning( 114 | "pretrain_model path will be ignored \ 115 | when resuming training." 
116 | )
117 |
118 | opt["path"]["pretrain_model_G"] = osp.join(opt["path"]["models"], "{}_G.pth".format(resume_iter))
119 | logger.info("Set [pretrain_model_G] to " + opt["path"]["pretrain_model_G"])
120 | if "gan" in opt["model"]:
121 | opt["path"]["pretrain_model_D"] = osp.join(opt["path"]["models"], "{}_D.pth".format(resume_iter))
122 | logger.info("Set [pretrain_model_D] to " + opt["path"]["pretrain_model_D"])
123 |
--------------------------------------------------------------------------------
/forward_operator/bkse/requirements.txt:
--------------------------------------------------------------------------------
1 | torch >= 1.4.0
2 | torchvision >= 0.5.0
3 | pyyaml
4 | opencv-python
5 | numpy
6 | lmdb
7 | tqdm
8 | tensorboard >= 1.15.0
9 | ninja
10 |
--------------------------------------------------------------------------------
/forward_operator/bkse/scripts/a.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/forward_operator/bkse/scripts/create_lmdb.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os.path as osp
3 | import pickle
4 | import sys
5 | from multiprocessing import Pool
6 |
7 | import cv2
8 | import lmdb
9 |
10 |
11 | sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
12 | import data.util as data_util # noqa: E402
13 | import utils.util as util # noqa: E402
14 |
15 |
16 | def read_image_worker(path, key):
17 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
18 | return (key, img)
19 |
20 |
21 | def create_dataset(name, img_folder, lmdb_save_path, H_dst, W_dst, C_dst):
22 | """Create lmdb for the dataset, each image with a fixed size
23 | key pattern: folder_frameid
24 | """
25 | # configurations
26 | read_all_imgs = False # whether to read all images into memory with multiprocessing
27 | # Set False to use limited memory
28 | BATCH = 5000 # After BATCH images, lmdb commits, if read_all_imgs = False
29 | n_thread = 40
30 | ########################################################
31 | if not lmdb_save_path.endswith(".lmdb"):
32 | raise ValueError("lmdb_save_path must end with 'lmdb'.")
33 | if osp.exists(lmdb_save_path):
34 | print("Folder [{:s}] already exists. Exit...".format(lmdb_save_path))
35 | sys.exit(1)
36 |
37 | # read all the image paths to a list
38 | print("Reading image path list ...")
39 | all_img_list = data_util._get_paths_from_images(img_folder)
40 | keys = []
41 | for img_path in all_img_list:
42 | split_rlt = img_path.split("/")
43 | folder = split_rlt[-2]
44 | img_name = split_rlt[-1].split(".png")[0]
45 | keys.append(folder + "_" + img_name)
46 |
47 | if read_all_imgs:
48 | # read all images to memory (multiprocessing)
49 | dataset = {} # store all image data. list cannot keep the order, use dict
50 | print("Read images with multiprocessing, #thread: {} ...".format(n_thread))
51 | pbar = util.ProgressBar(len(all_img_list))
52 |
53 | def mycallback(arg):
54 | """get the image data and update pbar"""
55 | key = arg[0]
56 | dataset[key] = arg[1]
57 | pbar.update("Reading {}".format(key))
58 |
59 | pool = Pool(n_thread)
60 | for path, key in zip(all_img_list, keys):
61 | pool.apply_async(read_image_worker, args=(path, key), callback=mycallback)
62 | pool.close()
63 | pool.join()
64 | print("Finish reading {} images.\nWrite lmdb...".format(len(all_img_list)))
65 |
66 | # create lmdb environment
67 | data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes
68 | print("data size per image is: ", data_size_per_img)
69 | data_size = data_size_per_img * len(all_img_list)
70 | env = lmdb.open(lmdb_save_path, map_size=data_size * 10)
71 |
72 | # write data to lmdb
73 | pbar = util.ProgressBar(len(all_img_list))
74 | txn = env.begin(write=True)
75 | for idx, (path, key) in enumerate(zip(all_img_list, keys)):
76 | pbar.update("Write {}".format(key))
77 | key_byte = key.encode("ascii")
78 | data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED)
79 |
80 | assert len(data.shape) > 2 or C_dst == 1, "different shape"
81 |
82 | if C_dst == 1:
83 | H, W = data.shape
84 | assert H == H_dst and W == W_dst, "different shape."
85 | else:
86 | H, W, C = data.shape
87 | assert H == H_dst and W == W_dst and C == 3, "different shape."
88 | txn.put(key_byte, data)
89 | if not read_all_imgs and idx % BATCH == 0:
90 | txn.commit()
91 | txn = env.begin(write=True)
92 | txn.commit()
93 | env.close()
94 | print("Finish writing lmdb.")
95 |
96 | # create meta information
97 | meta_info = {}
98 | meta_info["name"] = name
99 | channel = C_dst
100 | meta_info["resolution"] = "{}_{}_{}".format(channel, H_dst, W_dst)
101 | meta_info["keys"] = keys
102 | pickle.dump(meta_info, open(osp.join(lmdb_save_path, "meta_info.pkl"), "wb"))
103 | print("Finish creating lmdb meta info.")
104 |
105 |
106 | if __name__ == "__main__":
107 | parser = argparse.ArgumentParser(description="Create an LMDB dataset from an image folder")
108 |
109 | parser.add_argument("--H", action="store", help="source image height", type=int, required=True)
110 | parser.add_argument("--W", action="store", help="source image width", type=int, required=True)
111 | parser.add_argument("--C", action="store", help="number of source image channels", type=int, required=True)
112 | parser.add_argument("--img_folder", action="store", help="img folder", type=str, required=True)
113 | parser.add_argument("--save_path", action="store", help="save path", type=str, default=".")
114 | parser.add_argument("--name", action="store", help="dataset name", type=str, required=True)
115 |
116 | args = parser.parse_args()
117 | create_dataset(args.name, args.img_folder, args.save_path, args.H, args.W, args.C)
118 |
--------------------------------------------------------------------------------
/forward_operator/bkse/scripts/download_dataset.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import os.path as osp
4 |
5 | import requests
6 |
7 |
8 | def download_file_from_google_drive(file_id, destination):
9 | os.makedirs(osp.dirname(destination), exist_ok=True)
10 | URL = "https://docs.google.com/uc?export=download"
11 |
12 | session = requests.Session()
13 |
14 | response = session.get(URL, params={"id": file_id}, stream=True)
15 | token =
get_confirm_token(response) 16 | 17 | if token: 18 | params = {"id": file_id, "confirm": token} 19 | response = session.get(URL, params=params, stream=True) 20 | 21 | save_response_content(response, destination) 22 | 23 | 24 | def get_confirm_token(response): 25 | for key, value in response.cookies.items(): 26 | if key.startswith("download_warning"): 27 | return value 28 | 29 | return None 30 | 31 | 32 | def save_response_content(response, destination): 33 | CHUNK_SIZE = 32768 34 | 35 | with open(destination, "wb") as f: 36 | for chunk in response.iter_content(CHUNK_SIZE): 37 | if chunk: # filter out keep-alive new chunks 38 | f.write(chunk) 39 | 40 | 41 | if __name__ == "__main__": 42 | dataset_ids = { 43 | "GOPRO_Large": "1H0PIXvJH4c40pk7ou6nAwoxuR4Qh_Sa2", 44 | "train_sharp": "1YLksKtMhd2mWyVSkvhDaDLWSc1qYNCz-", 45 | "train_blur": "1Be2cgzuuXibcqAuJekDgvHq4MLYkCgR8", 46 | "val_sharp": "1MGeObVQ1-Z29f-myDP7-8c3u0_xECKXq", 47 | "val_blur": "1N8z2yD0GDWmh6U4d4EADERtcUgDzGrHx", 48 | "test_blur": "1dr0--ZBKqr4P1M8lek6JKD1Vd6bhhrZT", 49 | } 50 | 51 | parser = argparse.ArgumentParser( 52 | description="Download REDS dataset from google drive to current folder", allow_abbrev=False 53 | ) 54 | 55 | parser.add_argument("--REDS_train_sharp", action="store_true", help="download REDS train_sharp.zip") 56 | parser.add_argument("--REDS_train_blur", action="store_true", help="download REDS train_blur.zip") 57 | parser.add_argument("--REDS_val_sharp", action="store_true", help="download REDS val_sharp.zip") 58 | parser.add_argument("--REDS_val_blur", action="store_true", help="download REDS val_blur.zip") 59 | parser.add_argument("--GOPRO", action="store_true", help="download GOPRO_Large.zip") 60 | 61 | args = parser.parse_args() 62 | 63 | if args.REDS_train_sharp: 64 | download_file_from_google_drive(dataset_ids["train_sharp"], "REDS/train_sharp.zip") 65 | if args.REDS_train_blur: 66 | download_file_from_google_drive(dataset_ids["train_blur"], "REDS/train_blur.zip") 67 | if args.REDS_val_sharp: 68 | download_file_from_google_drive(dataset_ids["val_sharp"], "REDS/val_sharp.zip") 69 | if args.REDS_val_blur: 70 | download_file_from_google_drive(dataset_ids["val_blur"], "REDS/val_blur.zip") 71 | if args.GOPRO: 72 | download_file_from_google_drive(dataset_ids["GOPRO_Large"], "GOPRO/GOPRO.zip") 73 | -------------------------------------------------------------------------------- /forward_operator/bkse/train_script.sh: -------------------------------------------------------------------------------- 1 | python3.7 train.py -opt options/REDS/wsharp_woVAE.yml 2 | -------------------------------------------------------------------------------- /forward_operator/bkse/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TFNTF/PostEdit/3dcf4a680f8e438e9a5e8a28f7a2c6fb2fb2b475/forward_operator/bkse/utils/__init__.py -------------------------------------------------------------------------------- /forward_operator/bkse/utils/a.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /forward_operator/fastmri_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | This source code is licensed under the MIT license found in the 4 | LICENSE file in the root directory of this source tree. 
5 | """ 6 | 7 | from typing import List, Optional 8 | 9 | import torch 10 | from packaging import version 11 | 12 | if version.parse(torch.__version__) >= version.parse("1.7.0"): 13 | import torch.fft # type: ignore 14 | 15 | 16 | def fft2c_old(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: 17 | """ 18 | Apply centered 2 dimensional Fast Fourier Transform. 19 | Args: 20 | data: Complex valued input data containing at least 3 dimensions: 21 | dimensions -3 & -2 are spatial dimensions and dimension -1 has size 22 | 2. All other dimensions are assumed to be batch dimensions. 23 | norm: Whether to include normalization. Must be one of ``"backward"`` 24 | or ``"ortho"``. See ``torch.fft.fft`` on PyTorch 1.9.0 for details. 25 | Returns: 26 | The FFT of the input. 27 | """ 28 | if not data.shape[-1] == 2: 29 | raise ValueError("Tensor does not have separate complex dim.") 30 | if norm not in ("ortho", "backward"): 31 | raise ValueError("norm must be 'ortho' or 'backward'.") 32 | normalized = True if norm == "ortho" else False 33 | 34 | data = ifftshift(data, dim=[-3, -2]) 35 | data = torch.fft(data, 2, normalized=normalized) 36 | data = fftshift(data, dim=[-3, -2]) 37 | 38 | return data 39 | 40 | 41 | def ifft2c_old(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: 42 | """ 43 | Apply centered 2-dimensional Inverse Fast Fourier Transform. 44 | Args: 45 | data: Complex valued input data containing at least 3 dimensions: 46 | dimensions -3 & -2 are spatial dimensions and dimension -1 has size 47 | 2. All other dimensions are assumed to be batch dimensions. 48 | norm: Whether to include normalization. Must be one of ``"backward"`` 49 | or ``"ortho"``. See ``torch.fft.ifft`` on PyTorch 1.9.0 for 50 | details. 51 | Returns: 52 | The IFFT of the input. 53 | """ 54 | if not data.shape[-1] == 2: 55 | raise ValueError("Tensor does not have separate complex dim.") 56 | if norm not in ("ortho", "backward"): 57 | raise ValueError("norm must be 'ortho' or 'backward'.") 58 | normalized = True if norm == "ortho" else False 59 | 60 | data = ifftshift(data, dim=[-3, -2]) 61 | data = torch.ifft(data, 2, normalized=normalized) 62 | data = fftshift(data, dim=[-3, -2]) 63 | 64 | return data 65 | 66 | 67 | def fft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: 68 | """ 69 | Apply centered 2 dimensional Fast Fourier Transform. 70 | Args: 71 | data: Complex valued input data containing at least 3 dimensions: 72 | dimensions -3 & -2 are spatial dimensions and dimension -1 has size 73 | 2. All other dimensions are assumed to be batch dimensions. 74 | norm: Normalization mode. See ``torch.fft.fft``. 75 | Returns: 76 | The FFT of the input. 77 | """ 78 | if not data.shape[-1] == 2: 79 | raise ValueError("Tensor does not have separate complex dim.") 80 | 81 | data = ifftshift(data, dim=[-3, -2]) 82 | data = torch.view_as_real( 83 | torch.fft.fftn( # type: ignore 84 | torch.view_as_complex(data), dim=(-2, -1), norm=norm 85 | ) 86 | ) 87 | data = fftshift(data, dim=[-3, -2]) 88 | 89 | return data 90 | 91 | 92 | def ifft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: 93 | """ 94 | Apply centered 2-dimensional Inverse Fast Fourier Transform. 95 | Args: 96 | data: Complex valued input data containing at least 3 dimensions: 97 | dimensions -3 & -2 are spatial dimensions and dimension -1 has size 98 | 2. All other dimensions are assumed to be batch dimensions. shape: (B,H,W,2) 99 | norm: Normalization mode. See ``torch.fft.ifft``. 100 | Returns: 101 | The IFFT of the input. 
102 | """ 103 | if not data.shape[-1] == 2: 104 | raise ValueError("Tensor does not have separate complex dim.") 105 | 106 | data = ifftshift(data, dim=[-3, -2]) 107 | data = torch.view_as_real( 108 | torch.fft.ifftn( # type: ignore 109 | torch.view_as_complex(data), dim=(-2, -1), norm=norm 110 | ) 111 | ) 112 | data = fftshift(data, dim=[-3, -2]) 113 | 114 | return data 115 | 116 | 117 | # Helper functions 118 | 119 | 120 | def roll_one_dim(x: torch.Tensor, shift: int, dim: int) -> torch.Tensor: 121 | """ 122 | Similar to roll but for only one dim. 123 | Args: 124 | x: A PyTorch tensor. 125 | shift: Amount to roll. 126 | dim: Which dimension to roll. 127 | Returns: 128 | Rolled version of x. 129 | """ 130 | shift = shift % x.size(dim) 131 | if shift == 0: 132 | return x 133 | 134 | left = x.narrow(dim, 0, x.size(dim) - shift) 135 | right = x.narrow(dim, x.size(dim) - shift, shift) 136 | 137 | return torch.cat((right, left), dim=dim) 138 | 139 | 140 | def roll( 141 | x: torch.Tensor, 142 | shift: List[int], 143 | dim: List[int], 144 | ) -> torch.Tensor: 145 | """ 146 | Similar to np.roll but applies to PyTorch Tensors. 147 | Args: 148 | x: A PyTorch tensor. 149 | shift: Amount to roll. 150 | dim: Which dimension to roll. 151 | Returns: 152 | Rolled version of x. 153 | """ 154 | if len(shift) != len(dim): 155 | raise ValueError("len(shift) must match len(dim)") 156 | 157 | for (s, d) in zip(shift, dim): 158 | x = roll_one_dim(x, s, d) 159 | 160 | return x 161 | 162 | 163 | def fftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor: 164 | """ 165 | Similar to np.fft.fftshift but applies to PyTorch Tensors 166 | Args: 167 | x: A PyTorch tensor. 168 | dim: Which dimension to fftshift. 169 | Returns: 170 | fftshifted version of x. 171 | """ 172 | if dim is None: 173 | # this weird code is necessary for toch.jit.script typing 174 | dim = [0] * (x.dim()) 175 | for i in range(1, x.dim()): 176 | dim[i] = i 177 | 178 | # also necessary for torch.jit.script 179 | shift = [0] * len(dim) 180 | for i, dim_num in enumerate(dim): 181 | shift[i] = x.shape[dim_num] // 2 182 | 183 | return roll(x, shift, dim) 184 | 185 | 186 | def ifftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor: 187 | """ 188 | Similar to np.fft.ifftshift but applies to PyTorch Tensors 189 | Args: 190 | x: A PyTorch tensor. 191 | dim: Which dimension to ifftshift. 192 | Returns: 193 | ifftshifted version of x. 
194 | """ 195 | if dim is None: 196 | # this weird code is necessary for toch.jit.script typing 197 | dim = [0] * (x.dim()) 198 | for i in range(1, x.dim()): 199 | dim[i] = i 200 | 201 | # also necessary for torch.jit.script 202 | shift = [0] * len(dim) 203 | for i, dim_num in enumerate(dim): 204 | shift[i] = (x.shape[dim_num] + 1) // 2 205 | 206 | return roll(x, shift, dim) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.33.0 2 | asttokens==2.4.1 3 | certifi==2024.7.4 4 | charset-normalizer==3.3.2 5 | click==8.1.7 6 | comm==0.2.2 7 | contourpy==1.3.0 8 | cycler==0.12.1 9 | decorator==5.1.1 10 | diffusers @ git+https://github.com/huggingface/diffusers@1ca0a75567da1ca5a97681310c1b57e9f527a84a 11 | docker-pycreds==0.4.0 12 | exceptiongroup==1.2.2 13 | executing==2.0.1 14 | filelock==3.15.4 15 | fonttools==4.55.3 16 | fsspec==2024.6.1 17 | ftfy==6.2.0 18 | gitdb==4.0.11 19 | GitPython==3.1.43 20 | huggingface-hub==0.24.5 21 | idna==3.7 22 | importlib_metadata==8.2.0 23 | importlib_resources==6.4.5 24 | ipdb==0.13.13 25 | ipython==8.18.1 26 | ipywidgets==8.1.3 27 | jedi==0.19.1 28 | Jinja2==3.1.4 29 | joblib==1.4.2 30 | jupyterlab_widgets==3.0.11 31 | kiwisolver==1.4.7 32 | MarkupSafe==2.1.5 33 | matplotlib==3.9.3 34 | matplotlib-inline==0.1.7 35 | mc_bin_client==1.0.1 36 | mpmath==1.3.0 37 | networkx==3.2.1 38 | nltk==3.9.1 39 | numpy==1.26.4 40 | nvidia-cublas-cu12==12.1.3.1 41 | nvidia-cuda-cupti-cu12==12.1.105 42 | nvidia-cuda-nvrtc-cu12==12.1.105 43 | nvidia-cuda-runtime-cu12==12.1.105 44 | nvidia-cudnn-cu12==9.1.0.70 45 | nvidia-cufft-cu12==11.0.2.54 46 | nvidia-curand-cu12==10.3.2.106 47 | nvidia-cusolver-cu12==11.4.5.107 48 | nvidia-cusparse-cu12==12.1.0.106 49 | nvidia-nccl-cu12==2.20.5 50 | nvidia-nvjitlink-cu12==12.6.20 51 | nvidia-nvtx-cu12==12.1.105 52 | opencv-python==4.10.0.84 53 | packaging==24.1 54 | pandas==2.2.3 55 | parso==0.8.4 56 | peft==0.13.2 57 | pexpect==4.9.0 58 | pillow==10.4.0 59 | piq==0.8.0 60 | platformdirs==4.2.2 61 | prettytable==3.11.0 62 | prompt_toolkit==3.0.47 63 | protobuf==5.27.4 64 | psutil==6.0.0 65 | ptyprocess==0.7.0 66 | pure_eval==0.2.3 67 | Pygments==2.18.0 68 | pyparsing==3.2.0 69 | python-dateutil==2.9.0.post0 70 | pytz==2024.2 71 | PyYAML==6.0.1 72 | regex==2024.7.24 73 | requests==2.32.3 74 | safetensors==0.4.3 75 | scipy==1.13.1 76 | sentencepiece==0.2.0 77 | sentry-sdk==2.13.0 78 | setproctitle==1.3.3 79 | six==1.16.0 80 | smmap==5.0.1 81 | stack-data==0.6.3 82 | sympy==1.13.1 83 | tap==0.2 84 | tokenizers==0.19.1 85 | tomli==2.0.1 86 | torch==2.4.0 87 | torchvision==0.19.0 88 | tqdm==4.66.5 89 | traitlets==5.14.3 90 | transformers==4.43.3 91 | triton==3.0.0 92 | typing_extensions==4.12.2 93 | tzdata==2024.2 94 | urllib3==2.2.2 95 | wandb==0.17.7 96 | wcwidth==0.2.13 97 | widgetsnbextension==4.0.11 98 | zipp==3.19.2 99 | -------------------------------------------------------------------------------- /seq_aligner.py: -------------------------------------------------------------------------------- 1 | #Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import torch 15 | import numpy as np 16 | 17 | 18 | class ScoreParams: 19 | 20 | def __init__(self, gap, match, mismatch): 21 | self.gap = gap 22 | self.match = match 23 | self.mismatch = mismatch 24 | 25 | def mis_match_char(self, x, y): 26 | if x != y: 27 | return self.mismatch 28 | else: 29 | return self.match 30 | 31 | 32 | def get_matrix(size_x, size_y, gap): 33 | matrix = [] 34 | for i in range(len(size_x) + 1): 35 | sub_matrix = [] 36 | for j in range(len(size_y) + 1): 37 | sub_matrix.append(0) 38 | matrix.append(sub_matrix) 39 | for j in range(1, len(size_y) + 1): 40 | matrix[0][j] = j*gap 41 | for i in range(1, len(size_x) + 1): 42 | matrix[i][0] = i*gap 43 | return matrix 44 | 45 | 46 | def get_matrix(size_x, size_y, gap): 47 | matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32) 48 | matrix[0, 1:] = (np.arange(size_y) + 1) * gap 49 | matrix[1:, 0] = (np.arange(size_x) + 1) * gap 50 | return matrix 51 | 52 | 53 | def get_traceback_matrix(size_x, size_y): 54 | matrix = np.zeros((size_x + 1, size_y +1), dtype=np.int32) 55 | matrix[0, 1:] = 1 56 | matrix[1:, 0] = 2 57 | matrix[0, 0] = 4 58 | return matrix 59 | 60 | 61 | def global_align(x, y, score): 62 | matrix = get_matrix(len(x), len(y), score.gap) 63 | trace_back = get_traceback_matrix(len(x), len(y)) 64 | for i in range(1, len(x) + 1): 65 | for j in range(1, len(y) + 1): 66 | left = matrix[i, j - 1] + score.gap 67 | up = matrix[i - 1, j] + score.gap 68 | diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1]) 69 | matrix[i, j] = max(left, up, diag) 70 | if matrix[i, j] == left: 71 | trace_back[i, j] = 1 72 | elif matrix[i, j] == up: 73 | trace_back[i, j] = 2 74 | else: 75 | trace_back[i, j] = 3 76 | return matrix, trace_back 77 | 78 | 79 | def get_aligned_sequences(x, y, trace_back): 80 | x_seq = [] 81 | y_seq = [] 82 | i = len(x) 83 | j = len(y) 84 | mapper_y_to_x = [] 85 | while i > 0 or j > 0: 86 | if trace_back[i, j] == 3: 87 | x_seq.append(x[i-1]) 88 | y_seq.append(y[j-1]) 89 | i = i-1 90 | j = j-1 91 | mapper_y_to_x.append((j, i)) 92 | elif trace_back[i][j] == 1: 93 | x_seq.append('-') 94 | y_seq.append(y[j-1]) 95 | j = j-1 96 | mapper_y_to_x.append((j, -1)) 97 | elif trace_back[i][j] == 2: 98 | x_seq.append(x[i-1]) 99 | y_seq.append('-') 100 | i = i-1 101 | elif trace_back[i][j] == 4: 102 | break 103 | mapper_y_to_x.reverse() 104 | return x_seq, y_seq, torch.tensor(mapper_y_to_x, dtype=torch.int64) 105 | 106 | 107 | def get_mapper(x: str, y: str, tokenizer, max_len=77): 108 | x_seq = tokenizer.encode(x) 109 | y_seq = tokenizer.encode(y) 110 | score = ScoreParams(0, 1, -1) 111 | matrix, trace_back = global_align(x_seq, y_seq, score) 112 | mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1] 113 | alphas = torch.ones(max_len) 114 | alphas[: mapper_base.shape[0]] = mapper_base[:, 1].ne(-1).float() 115 | mapper = torch.zeros(max_len, dtype=torch.int64) 116 | mapper[:mapper_base.shape[0]] = mapper_base[:, 1] 117 | mapper[mapper_base.shape[0]:] = len(y_seq) + torch.arange(max_len - len(y_seq)) 118 | return mapper, alphas 119 | 120 | 121 | def 
get_refinement_mapper(prompts, tokenizer, max_len=77): 122 | x_seq = prompts[0] 123 | mappers, alphas = [], [] 124 | for i in range(1, len(prompts)): 125 | mapper, alpha = get_mapper(x_seq, prompts[i], tokenizer, max_len) 126 | mappers.append(mapper) 127 | alphas.append(alpha) 128 | return torch.stack(mappers), torch.stack(alphas) 129 | 130 | 131 | def get_word_inds(text: str, word_place: int, tokenizer): 132 | split_text = text.split(" ") 133 | if type(word_place) is str: 134 | word_place = [i for i, word in enumerate(split_text) if word_place == word] 135 | elif type(word_place) is int: 136 | word_place = [word_place] 137 | out = [] 138 | if len(word_place) > 0: 139 | words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1] 140 | cur_len, ptr = 0, 0 141 | 142 | for i in range(len(words_encode)): 143 | cur_len += len(words_encode[i]) 144 | if ptr in word_place: 145 | out.append(i + 1) 146 | if cur_len >= len(split_text[ptr]): 147 | ptr += 1 148 | cur_len = 0 149 | return np.array(out) 150 | 151 | 152 | def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77): 153 | words_x = x.split(' ') 154 | words_y = y.split(' ') 155 | if len(words_x) != len(words_y): 156 | raise ValueError(f"attention replacement edit can only be applied on prompts with the same length" 157 | f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words.") 158 | inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]] 159 | inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace] 160 | inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace] 161 | mapper = np.zeros((max_len, max_len)) 162 | i = j = 0 163 | cur_inds = 0 164 | while i < max_len and j < max_len: 165 | if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i: 166 | inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds] 167 | if len(inds_source_) == len(inds_target_): 168 | mapper[inds_source_, inds_target_] = 1 169 | else: 170 | ratio = 1 / len(inds_target_) 171 | for i_t in inds_target_: 172 | mapper[inds_source_, i_t] = ratio 173 | cur_inds += 1 174 | i += len(inds_source_) 175 | j += len(inds_target_) 176 | elif cur_inds < len(inds_source): 177 | mapper[i, j] = 1 178 | i += 1 179 | j += 1 180 | else: 181 | mapper[j, j] = 1 182 | i += 1 183 | j += 1 184 | 185 | return torch.from_numpy(mapper).float() 186 | 187 | 188 | 189 | def get_replacement_mapper(prompts, tokenizer, max_len=77): 190 | x_seq = prompts[0] 191 | mappers = [] 192 | for i in range(1, len(prompts)): 193 | mapper = get_replacement_mapper_(x_seq, prompts[i], tokenizer, max_len) 194 | mappers.append(mapper) 195 | return torch.stack(mappers) 196 | 197 | --------------------------------------------------------------------------------
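A minimal usage sketch for the alignment utilities in `seq_aligner.py` above, assuming a CLIP tokenizer loaded via huggingface / transformers; the model id and the example prompts are illustrative assumptions and are not taken from this repository:

```python
# Sketch only: drive the prompt-alignment helpers from seq_aligner.py with a CLIP tokenizer.
from transformers import CLIPTokenizer

import seq_aligner

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

# Refinement edit: the target prompt may add words relative to the source prompt.
prompts = ["a photo of a cat", "a photo of a sleeping cat"]
mappers, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer)
print(mappers.shape, alphas.shape)  # (1, 77) token-index map and (1, 77) alpha mask

# Replacement edit: source and target prompts must have the same number of words.
prompts = ["a photo of a cat", "a photo of a dog"]
mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer)
print(mapper.shape)  # (1, 77, 77) soft token-to-token mapping
```

`get_refinement_mapper` aligns prompts of different lengths via the global alignment routine, while `get_replacement_mapper` raises a `ValueError` when the two prompts do not contain the same number of words.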